diff options
Diffstat (limited to 'rest_framework')
48 files changed, 2120 insertions, 589 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 557f5943..fd176603 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,3 +1,3 @@ -__version__ = '2.0.0' +__version__ = '2.1.2' VERSION = __version__ # synonym diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7664c400..02e50604 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -1,8 +1,17 @@ """ -The :mod:`compat` module provides support for backwards compatibility with older versions of django/python. +The `compat` module provides support for backwards compatibility with older +versions of django/python, and compatbility wrappers around optional packages. """ +# flake8: noqa import django +# django-filter is optional +try: + import django_filters +except: + django_filters = None + + # cStringIO only if it's available, otherwise StringIO try: import cStringIO as StringIO @@ -346,33 +355,6 @@ except ImportError: yaml = None -import unittest -try: - import unittest.skip -except ImportError: # python < 2.7 - from unittest import TestCase - import functools - - def skip(reason): - # Pasted from py27/lib/unittest/case.py - """ - Unconditionally skip a test. - """ - def decorator(test_item): - if not (isinstance(test_item, type) and issubclass(test_item, TestCase)): - @functools.wraps(test_item) - def skip_wrapper(*args, **kwargs): - pass - test_item = skip_wrapper - - test_item.__unittest_skip__ = True - test_item.__unittest_skip_why__ = reason - return test_item - return decorator - - unittest.skip = skip - - # xml.etree.parse only throws ParseError for python >= 2.7 try: from xml.etree import ParseError as ETParseError diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 948973ae..a231f191 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -10,8 +10,18 @@ def api_view(http_method_names): def decorator(func): - class WrappedAPIView(APIView): - pass + WrappedAPIView = type( + 'WrappedAPIView', + (APIView,), + {'__doc__': func.__doc__} + ) + + # Note, the above allows us to set the docstring. + # It is the equivelent of: + # + # class WrappedAPIView(APIView): + # pass + # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this allowed_methods = set(http_method_names) | set(('options',)) WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 6ae0c95c..d635351c 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -47,14 +47,6 @@ class PermissionDenied(APIException): self.detail = detail or self.default_detail -class InvalidFormat(APIException): - status_code = status.HTTP_404_NOT_FOUND - default_detail = "Format suffix '.%s' not found." - - def __init__(self, format, detail=None): - self.detail = (detail or self.default_detail) % format - - class MethodNotAllowed(APIException): status_code = status.HTTP_405_METHOD_NOT_ALLOWED default_detail = "Method '%s' not allowed." diff --git a/rest_framework/fields.py b/rest_framework/fields.py index bb9a523d..a4e29a30 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -5,13 +5,16 @@ import warnings from django.core import validators from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.core.urlresolvers import resolve +from django.core.urlresolvers import resolve, get_script_prefix from django.conf import settings +from django.forms import widgets +from django.forms.models import ModelChoiceIterator from django.utils.encoding import is_protected_type, smart_unicode from django.utils.translation import ugettext_lazy as _ from rest_framework.reverse import reverse from rest_framework.compat import parse_date, parse_datetime from rest_framework.compat import timezone +from urlparse import urlparse def is_simple_callable(obj): @@ -37,12 +40,12 @@ class Field(object): self.source = source - def initialize(self, parent): + def initialize(self, parent, field_name): """ Called to set up a field prior to field_to_native or field_from_native. parent - The parent serializer. - model_field - The model field this field corrosponds to, if one exists. + model_field - The model field this field corresponds to, if one exists. """ self.parent = parent self.root = parent.root or parent @@ -70,6 +73,8 @@ class Field(object): value = obj for component in self.source.split('.'): value = getattr(value, component) + if is_simple_callable(value): + value = value() else: value = getattr(obj, field_name) return self.to_native(value) @@ -85,6 +90,8 @@ class Field(object): return value elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): return [self.to_native(item) for item in value] + elif isinstance(value, dict): + return dict(map(self.to_native, (k, v)) for k, v in value.items()) return smart_unicode(value) def attributes(self): @@ -105,15 +112,20 @@ class WritableField(Field): 'required': _('This field is required.'), 'invalid': _('Invalid value.'), } + widget = widgets.TextInput + default = None + + def __init__(self, source=None, read_only=False, required=None, + validators=[], error_messages=None, widget=None, + default=None, blank=None): - def __init__(self, source=None, readonly=False, required=None, - validators=[], error_messages=None): super(WritableField, self).__init__(source=source) - self.readonly = readonly + + self.read_only = read_only if required is None: - self.required = not(readonly) + self.required = not(read_only) else: - assert not readonly, "Cannot set required=True and readonly=True" + assert not read_only, "Cannot set required=True and read_only=True" self.required = required messages = {} @@ -123,6 +135,14 @@ class WritableField(Field): self.error_messages = messages self.validators = self.default_validators + validators + self.default = default if default is not None else self.default + self.blank = blank + + # Widgets are ony used for HTML forms. + widget = widget or self.widget + if isinstance(widget, type): + widget = widget() + self.widget = widget def validate(self, value): if value in validators.EMPTY_VALUES and self.required: @@ -151,15 +171,18 @@ class WritableField(Field): Given a dictionary and a field name, updates the dictionary `into`, with the field and it's deserialized value. """ - if self.readonly: + if self.read_only: return try: native = data[field_name] except KeyError: - if self.required: - raise ValidationError(self.error_messages['required']) - return + if self.default is not None: + native = self.default + else: + if self.required: + raise ValidationError(self.error_messages['required']) + return value = self.from_native(native) if self.source == '*': @@ -179,7 +202,7 @@ class WritableField(Field): class ModelField(WritableField): """ - A generic field that can be used against an arbirtrary model field. + A generic field that can be used against an arbitrary model field. """ def __init__(self, *args, **kwargs): try: @@ -189,11 +212,11 @@ class ModelField(WritableField): super(ModelField, self).__init__(*args, **kwargs) def from_native(self, value): - try: - rel = self.model_field.rel - except: + rel = getattr(self.model_field, "rel", None) + if rel is not None: + return rel.to._meta.get_field(rel.field_name).to_python(value) + else: return self.model_field.to_python(value) - return rel.to._meta.get_field(rel.field_name).to_python(value) def field_to_native(self, obj, field_name): value = self.model_field._get_val_from_obj(obj) @@ -209,32 +232,119 @@ class ModelField(WritableField): ##### Relational fields ##### +# Not actually Writable, but subclasses may need to be. class RelatedField(WritableField): """ Base class for related model fields. + + If not overridden, this represents a to-one relationship, using the unicode + representation of the target. """ + widget = widgets.Select + cache_choices = False + empty_label = None + default_read_only = True # TODO: Remove this + def __init__(self, *args, **kwargs): self.queryset = kwargs.pop('queryset', None) super(RelatedField, self).__init__(*args, **kwargs) + self.read_only = kwargs.pop('read_only', self.default_read_only) + + def initialize(self, parent, field_name): + super(RelatedField, self).initialize(parent, field_name) + if self.queryset is None and not self.read_only: + try: + manager = getattr(self.parent.opts.model, self.source or field_name) + if hasattr(manager, 'related'): # Forward + self.queryset = manager.related.model._default_manager.all() + else: # Reverse + self.queryset = manager.field.rel.to._default_manager.all() + except: + raise + msg = ('Serializer related fields must include a `queryset`' + + ' argument or set `read_only=True') + raise Exception(msg) + + ### We need this stuff to make form choices work... + + # def __deepcopy__(self, memo): + # result = super(RelatedField, self).__deepcopy__(memo) + # result.queryset = result.queryset + # return result + + def prepare_value(self, obj): + return self.to_native(obj) + + def label_from_instance(self, obj): + """ + Return a readable representation for use with eg. select widgets. + """ + desc = smart_unicode(obj) + ident = smart_unicode(self.to_native(obj)) + if desc == ident: + return desc + return "%s - %s" % (desc, ident) + + def _get_queryset(self): + return self._queryset + + def _set_queryset(self, queryset): + self._queryset = queryset + self.widget.choices = self.choices + + queryset = property(_get_queryset, _set_queryset) + + def _get_choices(self): + # If self._choices is set, then somebody must have manually set + # the property self.choices. In this case, just return self._choices. + if hasattr(self, '_choices'): + return self._choices + + # Otherwise, execute the QuerySet in self.queryset to determine the + # choices dynamically. Return a fresh ModelChoiceIterator that has not been + # consumed. Note that we're instantiating a new ModelChoiceIterator *each* + # time _get_choices() is called (and, thus, each time self.choices is + # accessed) so that we can ensure the QuerySet has not been consumed. This + # construct might look complicated but it allows for lazy evaluation of + # the queryset. + return ModelChoiceIterator(self) + + def _set_choices(self, value): + # Setting choices also sets the choices on the widget. + # choices can be any iterable, but we call list() on it because + # it will be consumed more than once. + self._choices = self.widget.choices = list(value) + + choices = property(_get_choices, _set_choices) + + ### Regular serializier stuff... def field_to_native(self, obj, field_name): value = getattr(obj, self.source or field_name) return self.to_native(value) def field_from_native(self, data, field_name, into): + if self.read_only: + return + value = data.get(field_name) - into[(self.source or field_name) + '_id'] = self.from_native(value) + into[(self.source or field_name)] = self.from_native(value) class ManyRelatedMixin(object): """ Mixin to convert a related field to a many related field. """ + widget = widgets.SelectMultiple + def field_to_native(self, obj, field_name): value = getattr(obj, self.source or field_name) return [self.to_native(item) for item in value.all()] def field_from_native(self, data, field_name, into): + if self.read_only: + return + try: # Form data value = data.getlist(self.source or field_name) @@ -250,6 +360,9 @@ class ManyRelatedMixin(object): class ManyRelatedField(ManyRelatedMixin, RelatedField): """ Base class for related model managers. + + If not overridden, this represents a to-many relationship, using the unicode + representations of the target, and is read-only. """ pass @@ -258,12 +371,38 @@ class ManyRelatedField(ManyRelatedMixin, RelatedField): class PrimaryKeyRelatedField(RelatedField): """ - Serializes a related field or related object to a pk value. + Represents a to-one relationship as a pk value. """ + default_read_only = False + + # TODO: Remove these field hacks... + def prepare_value(self, obj): + return self.to_native(obj.pk) + def label_from_instance(self, obj): + """ + Return a readable representation for use with eg. select widgets. + """ + desc = smart_unicode(obj) + ident = smart_unicode(self.to_native(obj.pk)) + if desc == ident: + return desc + return "%s - %s" % (desc, ident) + + # TODO: Possibly change this to just take `obj`, through prob less performant def to_native(self, pk): return pk + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + return self.queryset.get(pk=data) + except ObjectDoesNotExist: + msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) + raise ValidationError(msg) + def field_to_native(self, obj, field_name): try: # Prefer obj.serializable_value for performance reasons @@ -278,8 +417,23 @@ class PrimaryKeyRelatedField(RelatedField): class ManyPrimaryKeyRelatedField(ManyRelatedField): """ - Serializes a to-many related field or related manager to a pk value. + Represents a to-many relationship as a pk value. """ + default_read_only = False + + def prepare_value(self, obj): + return self.to_native(obj.pk) + + def label_from_instance(self, obj): + """ + Return a readable representation for use with eg. select widgets. + """ + desc = smart_unicode(obj) + ident = smart_unicode(self.to_native(obj.pk)) + if desc == ident: + return desc + return "%s - %s" % (desc, ident) + def to_native(self, pk): return pk @@ -294,27 +448,83 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): # Forward relationship return [self.to_native(item.pk) for item in queryset.all()] + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + return self.queryset.get(pk=data) + except ObjectDoesNotExist: + msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) + raise ValidationError(msg) + +### Slug relationships + + +class SlugRelatedField(RelatedField): + default_read_only = False + + def __init__(self, *args, **kwargs): + self.slug_field = kwargs.pop('slug_field', None) + assert self.slug_field, 'slug_field is required' + super(SlugRelatedField, self).__init__(*args, **kwargs) + + def to_native(self, obj): + return getattr(obj, self.slug_field) + + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + return self.queryset.get(**{self.slug_field: data}) + except ObjectDoesNotExist: + raise ValidationError('Object with %s=%s does not exist.' % + (self.slug_field, unicode(data))) + + +class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField): + pass + ### Hyperlinked relationships class HyperlinkedRelatedField(RelatedField): + """ + Represents a to-one relationship, using hyperlinking. + """ pk_url_kwarg = 'pk' - slug_url_kwarg = 'slug' slug_field = 'slug' + slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden + default_read_only = False def __init__(self, *args, **kwargs): try: self.view_name = kwargs.pop('view_name') except: raise ValueError("Hyperlinked field requires 'view_name' kwarg") + + self.slug_field = kwargs.pop('slug_field', self.slug_field) + default_slug_kwarg = self.slug_url_kwarg or self.slug_field + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) + self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) + + self.format = kwargs.pop('format', None) super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) + def get_slug_field(self): + """ + Get the name of a slug field to be used to look up by slug. + """ + return self.slug_field + def to_native(self, obj): view_name = self.view_name request = self.context.get('request', None) + format = self.format or self.context.get('format', None) kwargs = {self.pk_url_kwarg: obj.pk} try: - return reverse(view_name, kwargs=kwargs, request=request) + return reverse(view_name, kwargs=kwargs, request=request, format=format) except: pass @@ -325,13 +535,13 @@ class HyperlinkedRelatedField(RelatedField): kwargs = {self.slug_url_kwarg: slug} try: - return reverse(self.view_name, kwargs=kwargs, request=request) + return reverse(self.view_name, kwargs=kwargs, request=request, format=format) except: pass kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} try: - return reverse(self.view_name, kwargs=kwargs, request=request) + return reverse(self.view_name, kwargs=kwargs, request=request, format=format) except: pass @@ -340,6 +550,16 @@ class HyperlinkedRelatedField(RelatedField): def from_native(self, value): # Convert URL -> model instance pk # TODO: Use values_list + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + if value.startswith('http:') or value.startswith('https:'): + # If needed convert absolute URLs to relative path + value = urlparse(value).path + prefix = get_script_prefix() + if value.startswith(prefix): + value = '/' + value[len(prefix):] + try: match = resolve(value) except: @@ -353,7 +573,7 @@ class HyperlinkedRelatedField(RelatedField): # Try explicit primary key. if pk is not None: - return pk + queryset = self.queryset.filter(pk=pk) # Next, try looking up by slug. elif slug is not None: slug_field = self.get_slug_field() @@ -366,48 +586,88 @@ class HyperlinkedRelatedField(RelatedField): obj = queryset.get() except ObjectDoesNotExist: raise ValidationError('Invalid hyperlink - object does not exist.') - return obj.pk + return obj class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField): + """ + Represents a to-many relationship, using hyperlinking. + """ pass class HyperlinkedIdentityField(Field): """ - A field that represents the model's identity using a hyperlink. + Represents the instance, or a property on the instance, using hyperlinking. """ + pk_url_kwarg = 'pk' + slug_field = 'slug' + slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden + def __init__(self, *args, **kwargs): - # TODO: Make this mandatory, and have the HyperlinkedModelSerializer - # set it on-the-fly + # TODO: Make view_name mandatory, and have the + # HyperlinkedModelSerializer set it on-the-fly self.view_name = kwargs.pop('view_name', None) + self.format = kwargs.pop('format', None) + + self.slug_field = kwargs.pop('slug_field', self.slug_field) + default_slug_kwarg = self.slug_url_kwarg or self.slug_field + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) + self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) + super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) def field_to_native(self, obj, field_name): request = self.context.get('request', None) + format = self.format or self.context.get('format', None) view_name = self.view_name or self.parent.opts.view_name - view_kwargs = {'pk': obj.pk} - return reverse(view_name, kwargs=view_kwargs, request=request) + kwargs = {self.pk_url_kwarg: obj.pk} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except: + pass + + slug = getattr(obj, self.slug_field, None) + + if not slug: + raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) + + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(self.view_name, kwargs=kwargs, request=request, format=format) + except: + pass + + kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + try: + return reverse(self.view_name, kwargs=kwargs, request=request, format=format) + except: + pass + + raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) ##### Typed Fields ##### class BooleanField(WritableField): type_name = 'BooleanField' + widget = widgets.CheckboxInput default_error_messages = { 'invalid': _(u"'%s' value must be either True or False."), } + empty = False + + # Note: we set default to `False` in order to fill in missing value not + # supplied by html form. TODO: Fix so that only html form input gets + # this behavior. + default = False def from_native(self, value): - if value in (True, False): - # if value is 1 or 0 than it's equal to True or False, but we want - # to return a true bool for semantic reasons. - return bool(value) if value in ('t', 'True', '1'): return True if value in ('f', 'False', '0'): return False - raise ValidationError(self.error_messages['invalid'] % value) + return bool(value) class CharField(WritableField): @@ -421,12 +681,68 @@ class CharField(WritableField): if max_length is not None: self.validators.append(validators.MaxLengthValidator(max_length)) + def validate(self, value): + """ + Validates that the value is supplied (if required). + """ + # if empty string and allow blank + if self.blank and not value: + return + else: + super(CharField, self).validate(value) + def from_native(self, value): if isinstance(value, basestring) or value is None: return value return smart_unicode(value) +class ChoiceField(WritableField): + type_name = 'ChoiceField' + widget = widgets.Select + default_error_messages = { + 'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'), + } + + def __init__(self, choices=(), *args, **kwargs): + super(ChoiceField, self).__init__(*args, **kwargs) + self.choices = choices + + def _get_choices(self): + return self._choices + + def _set_choices(self, value): + # Setting choices also sets the choices on the widget. + # choices can be any iterable, but we call list() on it because + # it will be consumed more than once. + self._choices = self.widget.choices = list(value) + + choices = property(_get_choices, _set_choices) + + def validate(self, value): + """ + Validates that the input is in self.choices. + """ + super(ChoiceField, self).validate(value) + if value and not self.valid_value(value): + raise ValidationError(self.error_messages['invalid_choice'] % {'value': value}) + + def valid_value(self, value): + """ + Check to see if the provided value is a valid choice. + """ + for k, v in self.choices: + if isinstance(v, (list, tuple)): + # This is an optgroup, so look inside the group for options + for k2, v2 in v: + if value == smart_unicode(k2): + return True + else: + if value == smart_unicode(k): + return True + return False + + class EmailField(CharField): type_name = 'EmailField' @@ -436,7 +752,10 @@ class EmailField(CharField): default_validators = [validators.validate_email] def from_native(self, value): - return super(EmailField, self).from_native(value).strip() + ret = super(EmailField, self).from_native(value) + if ret is None: + return None + return ret.strip() def __deepcopy__(self, memo): result = copy.copy(self) @@ -458,8 +777,9 @@ class DateField(WritableField): empty = None def from_native(self, value): - if value is None: - return value + if value in validators.EMPTY_VALUES: + return None + if isinstance(value, datetime.datetime): if timezone and settings.USE_TZ and timezone.is_aware(value): # Convert aware datetimes to the default time zone @@ -497,8 +817,9 @@ class DateTimeField(WritableField): empty = None def from_native(self, value): - if value is None: - return value + if value in validators.EMPTY_VALUES: + return None + if isinstance(value, datetime.datetime): return value if isinstance(value, datetime.date): @@ -556,6 +877,7 @@ class IntegerField(WritableField): def from_native(self, value): if value in validators.EMPTY_VALUES: return None + try: value = int(str(value)) except (ValueError, TypeError): @@ -571,8 +893,9 @@ class FloatField(WritableField): } def from_native(self, value): - if value is None: - return value + if value in validators.EMPTY_VALUES: + return None + try: return float(value) except (TypeError, ValueError): diff --git a/rest_framework/filters.py b/rest_framework/filters.py new file mode 100644 index 00000000..ccae4825 --- /dev/null +++ b/rest_framework/filters.py @@ -0,0 +1,59 @@ +from rest_framework.compat import django_filters + +FilterSet = django_filters and django_filters.FilterSet or None + + +class BaseFilterBackend(object): + """ + A base class from which all filter backend classes should inherit. + """ + + def filter_queryset(self, request, queryset, view): + """ + Return a filtered queryset. + """ + raise NotImplementedError(".filter_queryset() must be overridden.") + + +class DjangoFilterBackend(BaseFilterBackend): + """ + A filter backend that uses django-filter. + """ + default_filter_set = FilterSet + + def __init__(self): + assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed' + + def get_filter_class(self, view): + """ + Return the django-filters `FilterSet` used to filter the queryset. + """ + filter_class = getattr(view, 'filter_class', None) + filter_fields = getattr(view, 'filter_fields', None) + view_model = getattr(view, 'model', None) + + if filter_class: + filter_model = filter_class.Meta.model + + assert issubclass(filter_model, view_model), \ + 'FilterSet model %s does not match view model %s' % \ + (filter_model, view_model) + + return filter_class + + if filter_fields: + class AutoFilterSet(self.default_filter_set): + class Meta: + model = view_model + fields = filter_fields + return AutoFilterSet + + return None + + def filter_queryset(self, request, queryset, view): + filter_class = self.get_filter_class(view) + + if filter_class: + return filter_class(request.GET, queryset=queryset) + + return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 59739d01..ebd06e45 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -1,5 +1,5 @@ """ -Generic views that provide commmonly needed behaviour. +Generic views that provide commonly needed behaviour. """ from rest_framework import views, mixins @@ -10,12 +10,12 @@ from django.views.generic.list import MultipleObjectMixin ### Base classes for the generic views ### -class BaseView(views.APIView): +class GenericAPIView(views.APIView): """ Base class for all other generic views. """ serializer_class = None - model_serializer_class = api_settings.MODEL_SERIALIZER + model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS def get_serializer_context(self): """ @@ -43,21 +43,31 @@ class BaseView(views.APIView): return serializer_class - def get_serializer(self, data=None, files=None, instance=None): + def get_serializer(self, instance=None, data=None, files=None): # TODO: add support for files # TODO: add support for seperate serializer/deserializer serializer_class = self.get_serializer_class() context = self.get_serializer_context() - return serializer_class(data, instance=instance, context=context) + return serializer_class(instance, data=data, context=context) -class MultipleObjectBaseView(MultipleObjectMixin, BaseView): +class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): """ Base class for generic views onto a queryset. """ - pagination_serializer_class = api_settings.PAGINATION_SERIALIZER + pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS paginate_by = api_settings.PAGINATE_BY + filter_backend = api_settings.FILTER_BACKEND + + def filter_queryset(self, queryset): + if not self.filter_backend: + return queryset + backend = self.filter_backend() + return backend.filter_queryset(self.request, queryset, self) + + def get_filtered_queryset(self): + return self.filter_queryset(self.get_queryset()) def get_pagination_serializer_class(self): """ @@ -75,7 +85,7 @@ class MultipleObjectBaseView(MultipleObjectMixin, BaseView): return pagination_serializer_class(instance=page, context=context) -class SingleObjectBaseView(SingleObjectMixin, BaseView): +class SingleObjectAPIView(SingleObjectMixin, GenericAPIView): """ Base class for generic views onto a model instance. """ @@ -86,7 +96,7 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView): """ Override default to add support for object-level permissions. """ - obj = super(SingleObjectBaseView, self).get_object() + obj = super(SingleObjectAPIView, self).get_object() if not self.has_permission(self.request, obj): self.permission_denied(self.request) return obj @@ -95,8 +105,19 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView): ### Concrete view classes that provide method handlers ### ### by composing the mixin classes with a base view. ### + +class CreateAPIView(mixins.CreateModelMixin, + GenericAPIView): + + """ + Concrete view for creating a model instance. + """ + def post(self, request, *args, **kwargs): + return self.create(request, *args, **kwargs) + + class ListAPIView(mixins.ListModelMixin, - MultipleObjectBaseView): + MultipleObjectAPIView): """ Concrete view for listing a queryset. """ @@ -104,9 +125,38 @@ class ListAPIView(mixins.ListModelMixin, return self.list(request, *args, **kwargs) +class RetrieveAPIView(mixins.RetrieveModelMixin, + SingleObjectAPIView): + """ + Concrete view for retrieving a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class DestroyAPIView(mixins.DestroyModelMixin, + SingleObjectAPIView): + + """ + Concrete view for deleting a model instance. + """ + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + +class UpdateAPIView(mixins.UpdateModelMixin, + SingleObjectAPIView): + + """ + Concrete view for updating a model instance. + """ + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, - MultipleObjectBaseView): + MultipleObjectAPIView): """ Concrete view for listing a queryset or creating a model instance. """ @@ -117,18 +167,9 @@ class ListCreateAPIView(mixins.ListModelMixin, return self.create(request, *args, **kwargs) -class RetrieveAPIView(mixins.RetrieveModelMixin, - SingleObjectBaseView): - """ - Concrete view for retrieving a model instance. - """ - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, - SingleObjectBaseView): + SingleObjectAPIView): """ Concrete view for retrieving or deleting a model instance. """ @@ -142,7 +183,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, - SingleObjectBaseView): + SingleObjectAPIView): """ Concrete view for retrieving, updating or deleting a model instance. """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 29153e18..c3625a88 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -3,9 +3,6 @@ Basic building blocks for generic class based views. We don't bind behaviour to http method handlers yet, which allows mixin classes to be composed in interesting ways. - -Eg. Use mixins to build a Resource class, and have a Router class - perform the binding of http methods to actions for us. """ from django.http import Http404 from rest_framework import status @@ -20,20 +17,24 @@ class CreateModelMixin(object): def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA) if serializer.is_valid(): + self.pre_save(serializer.object) self.object = serializer.save() return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def pre_save(self, obj): + pass + class ListModelMixin(object): """ List a queryset. - Should be mixed in with `MultipleObjectBaseView`. + Should be mixed in with `MultipleObjectAPIView`. """ empty_error = u"Empty list and '%(class_name)s.allow_empty' is False." def list(self, request, *args, **kwargs): - self.object_list = self.get_queryset() + self.object_list = self.get_filtered_queryset() # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. @@ -46,10 +47,11 @@ class ListModelMixin(object): # which may be `None` to disable pagination. page_size = self.get_paginate_by(self.object_list) if page_size: - paginator, page, queryset, is_paginated = self.paginate_queryset(self.object_list, page_size) + packed = self.paginate_queryset(self.object_list, page_size) + paginator, page, queryset, is_paginated = packed serializer = self.get_pagination_serializer(page) else: - serializer = self.get_serializer(instance=self.object_list) + serializer = self.get_serializer(self.object_list) return Response(serializer.data) @@ -61,7 +63,7 @@ class RetrieveModelMixin(object): """ def retrieve(self, request, *args, **kwargs): self.object = self.get_object() - serializer = self.get_serializer(instance=self.object) + serializer = self.get_serializer(self.object) return Response(serializer.data) @@ -73,26 +75,25 @@ class UpdateModelMixin(object): def update(self, request, *args, **kwargs): try: self.object = self.get_object() + success_status = status.HTTP_200_OK except Http404: self.object = None + success_status = status.HTTP_201_CREATED - serializer = self.get_serializer(data=request.DATA, instance=self.object) + serializer = self.get_serializer(self.object, data=request.DATA) if serializer.is_valid(): - if self.object is None: - # If PUT occurs to a non existant object, we need to set any - # attributes on the object that are implicit in the URL. - self.update_urlconf_attributes(serializer.object) + self.pre_save(serializer.object) self.object = serializer.save() - return Response(serializer.data) + return Response(serializer.data, status=success_status) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def update_urlconf_attributes(self, obj): + def pre_save(self, obj): """ - When update (re)creates an object, we need to set any attributes that - are tied to the URLconf. + Set any attributes on the object that are implicit in the request. """ + # pk and/or slug attributes are implicit in the URL. pk = self.kwargs.get(self.pk_url_kwarg, None) if pk: setattr(obj, 'pk', pk) diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 8b22f669..dae38477 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -1,48 +1,38 @@ +from django.http import Http404 from rest_framework import exceptions from rest_framework.settings import api_settings from rest_framework.utils.mediatypes import order_by_precedence, media_type_matches class BaseContentNegotiation(object): - def negotiate(self, request, renderers, format=None, force=False): - raise NotImplementedError('.negotiate() must be implemented') + def select_parser(self, request, parsers): + raise NotImplementedError('.select_parser() must be implemented') + def select_renderer(self, request, renderers, format_suffix=None): + raise NotImplementedError('.select_renderer() must be implemented') -class DefaultContentNegotiation(object): + +class DefaultContentNegotiation(BaseContentNegotiation): settings = api_settings - def select_parser(self, parsers, media_type): + def select_parser(self, request, parsers): """ Given a list of parsers and a media type, return the appropriate parser to handle the incoming request. """ for parser in parsers: - if media_type_matches(parser.media_type, media_type): + if media_type_matches(parser.media_type, request.content_type): return parser return None - def negotiate(self, request, renderers, format=None, force=False): + def select_renderer(self, request, renderers, format_suffix=None): """ Given a request and a list of renderers, return a two-tuple of: (renderer, media type). - - If force is set, then suppress exceptions, and forcibly return a - fallback renderer and media_type. - """ - try: - return self.unforced_negotiate(request, renderers, format) - except (exceptions.InvalidFormat, exceptions.NotAcceptable): - if force: - return (renderers[0], renderers[0].media_type) - raise - - def unforced_negotiate(self, request, renderers, format=None): - """ - As `.negotiate()`, but does not take the optional `force` agument, - or suppress exceptions. """ # Allow URL style format override. eg. "?format=json - format = format or request.GET.get(self.settings.URL_FORMAT_OVERRIDE) + format_query_param = self.settings.URL_FORMAT_OVERRIDE + format = format_suffix or request.GET.get(format_query_param) if format: renderers = self.filter_renderers(renderers, format) @@ -77,7 +67,7 @@ class DefaultContentNegotiation(object): renderers = [renderer for renderer in renderers if renderer.format == format] if not renderers: - raise exceptions.InvalidFormat(format) + raise Http404 return renderers def get_accept_list(self, request): diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 131718fd..d241ade7 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,4 +1,5 @@ from rest_framework import serializers +from rest_framework.templatetags.rest_framework import replace_query_param # TODO: Support URLconf kwarg-style paging @@ -7,30 +8,30 @@ class NextPageField(serializers.Field): """ Field that returns a link to the next page in paginated results. """ + page_field = 'page' + def to_native(self, value): if not value.has_next(): return None page = value.next_page_number() request = self.context.get('request') - relative_url = '?page=%d' % page - if request: - return request.build_absolute_uri(relative_url) - return relative_url + url = request and request.build_absolute_uri() or '' + return replace_query_param(url, self.page_field, page) class PreviousPageField(serializers.Field): """ Field that returns a link to the previous page in paginated results. """ + page_field = 'page' + def to_native(self, value): if not value.has_previous(): return None page = value.previous_page_number() request = self.context.get('request') - relative_url = '?page=%d' % page - if request: - return request.build_absolute_uri('?page=%d' % page) - return relative_url + url = request and request.build_absolute_uri() or '' + return replace_query_param(url, self.page_field, page) class PaginationSerializerOptions(serializers.SerializerOptions): diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 048b71e1..4841676c 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -1,14 +1,8 @@ """ -Django supports parsing the content of an HTTP request, but only for form POST requests. -That behavior is sufficient for dealing with standard HTML forms, but it doesn't map well -to general HTTP requests. +Parsers are used to parse the content of incoming HTTP requests. -We need a method to be able to: - -1.) Determine the parsed content on a request for methods other than POST (eg typically also PUT) - -2.) Determine the parsed content on a request for media types other than application/x-www-form-urlencoded - and multipart/form-data. (eg also handle multipart/json) +They give us a generic way of being able to handle various media types +on the request, such as form content or json encoded data. """ from django.http import QueryDict @@ -21,7 +15,6 @@ from xml.etree import ElementTree as ET from xml.parsers.expat import ExpatError import datetime import decimal -from io import BytesIO class DataAndFiles(object): @@ -33,29 +26,18 @@ class DataAndFiles(object): class BaseParser(object): """ All parsers should extend `BaseParser`, specifying a `media_type` - attribute, and overriding the `.parse_stream()` method. + attribute, and overriding the `.parse()` method. """ media_type = None - def parse(self, string_or_stream, parser_context=None): + def parse(self, stream, media_type=None, parser_context=None): """ - The main entry point to parsers. This is a light wrapper around - `parse_stream`, that instead handles both string and stream objects. - """ - if isinstance(string_or_stream, basestring): - stream = BytesIO(string_or_stream) - else: - stream = string_or_stream - return self.parse_stream(stream, parser_context) - - def parse_stream(self, stream, parser_context=None): - """ - Given a stream to read from, return the deserialized output. - Should return parsed data, or a DataAndFiles object consisting of the + Given a stream to read from, return the parsed representation. + Should return parsed data, or a `DataAndFiles` object consisting of the parsed data and files. """ - raise NotImplementedError(".parse_stream() must be overridden.") + raise NotImplementedError(".parse() must be overridden.") class JSONParser(BaseParser): @@ -65,7 +47,7 @@ class JSONParser(BaseParser): media_type = 'application/json' - def parse_stream(self, stream, parser_context=None): + def parse(self, stream, media_type=None, parser_context=None): """ Returns a 2-tuple of `(data, files)`. @@ -85,7 +67,7 @@ class YAMLParser(BaseParser): media_type = 'application/yaml' - def parse_stream(self, stream, parser_context=None): + def parse(self, stream, media_type=None, parser_context=None): """ Returns a 2-tuple of `(data, files)`. @@ -105,7 +87,7 @@ class FormParser(BaseParser): media_type = 'application/x-www-form-urlencoded' - def parse_stream(self, stream, parser_context=None): + def parse(self, stream, media_type=None, parser_context=None): """ Returns a 2-tuple of `(data, files)`. @@ -123,7 +105,7 @@ class MultiPartParser(BaseParser): media_type = 'multipart/form-data' - def parse_stream(self, stream, parser_context=None): + def parse(self, stream, media_type=None, parser_context=None): """ Returns a DataAndFiles object. @@ -131,8 +113,10 @@ class MultiPartParser(BaseParser): `.files` will be a `QueryDict` containing all the form files. """ parser_context = parser_context or {} - meta = parser_context['meta'] - upload_handlers = parser_context['upload_handlers'] + request = parser_context['request'] + meta = request.META + upload_handlers = request.upload_handlers + try: parser = DjangoMultiPartParser(meta, stream, upload_handlers) data, files = parser.parse() @@ -148,7 +132,7 @@ class XMLParser(BaseParser): media_type = 'application/xml' - def parse_stream(self, stream, parser_context=None): + def parse(self, stream, media_type=None, parser_context=None): try: tree = ET.parse(stream) except (ExpatError, ETParseError, ValueError), exc: diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 6f848cee..655b78a3 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -18,6 +18,17 @@ class BasePermission(object): raise NotImplementedError(".has_permission() must be overridden.") +class AllowAny(BasePermission): + """ + Allow any access. + This isn't strictly required, since you could use an empty + permission_classes list, but it's useful because it makes the intention + more explicit. + """ + def has_permission(self, request, view, obj=None): + return True + + class IsAuthenticated(BasePermission): """ Allows access only to authenticated users. @@ -85,7 +96,7 @@ class DjangoModelPermissions(BasePermission): """ kwargs = { 'app_label': model_cls._meta.app_label, - 'model_name': model_cls._meta.module_name + 'model_name': model_cls._meta.module_name } return [perm % kwargs for perm in self.perms_map[method]] diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 94d253c9..22fd6e74 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -1,14 +1,16 @@ """ -Renderers are used to serialize a View's output into specific media types. +Renderers are used to serialize a response into specific media types. -Django REST framework also provides HTML and PlainText renderers that help self-document the API, -by serializing the output along with documentation regarding the View, output status and headers, -and providing forms and links depending on the allowed methods, renderers and parsers on the View. +They give us a generic way of being able to handle various media types +on the response, such as JSON encoded data or HTML output. + +REST framework also provides an HTML renderer the renders the browseable API. """ +import copy import string from django import forms from django.http.multipartparser import parse_header -from django.template import RequestContext, loader +from django.template import RequestContext, loader, Template from django.utils import simplejson as json from rest_framework.compat import yaml from rest_framework.exceptions import ConfigurationError @@ -23,8 +25,8 @@ from rest_framework import serializers, parsers class BaseRenderer(object): """ - All renderers must extend this class, set the :attr:`media_type` attribute, - and override the :meth:`render` method. + All renderers should extend this class, setting the `media_type` + and `format` attributes, and override the `.render()` method. """ media_type = None @@ -98,7 +100,7 @@ class JSONPRenderer(JSONRenderer): callback = self.get_callback(renderer_context) json = super(JSONPRenderer, self).render(data, accepted_media_type, renderer_context) - return "%s(%s);" % (callback, json) + return u"%s(%s);" % (callback, json) class XMLRenderer(BaseRenderer): @@ -137,18 +139,33 @@ class YAMLRenderer(BaseRenderer): return yaml.dump(data, stream=None, Dumper=self.encoder) -class HTMLRenderer(BaseRenderer): +class TemplateHTMLRenderer(BaseRenderer): """ - A Base class provided for convenience. + An HTML renderer for use with templates. + + The data supplied to the Response object should be a dictionary that will + be used as context for the template. + + The template name is determined by (in order of preference): + + 1. An explicit `.template_name` attribute set on the response. + 2. An explicit `.template_name` attribute set on this class. + 3. The return result of calling `view.get_template_names()`. - Render the object simply by using the given template. - To create a template renderer, subclass this class, and set - the :attr:`media_type` and :attr:`template` attributes. + For example: + data = {'users': User.objects.all()} + return Response(data, template_name='users.html') + + For pre-rendered HTML, see StaticHTMLRenderer. """ media_type = 'text/html' format = 'html' template_name = None + exception_template_names = [ + '%(status_code)s.html', + 'api_exception.html' + ] def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -165,15 +182,21 @@ class HTMLRenderer(BaseRenderer): request = renderer_context['request'] response = renderer_context['response'] - template_names = self.get_template_names(response, view) - template = self.resolve_template(template_names) - context = self.resolve_context(data, request) + if response.exception: + template = self.get_exception_template(response) + else: + template_names = self.get_template_names(response, view) + template = self.resolve_template(template_names) + + context = self.resolve_context(data, request, response) return template.render(context) def resolve_template(self, template_names): return loader.select_template(template_names) - def resolve_context(self, data, request): + def resolve_context(self, data, request, response): + if response.exception: + data['status_code'] = response.status_code return RequestContext(request, data) def get_template_names(self, response, view): @@ -185,6 +208,48 @@ class HTMLRenderer(BaseRenderer): return view.get_template_names() raise ConfigurationError('Returned a template response with no template_name') + def get_exception_template(self, response): + template_names = [name % {'status_code': response.status_code} + for name in self.exception_template_names] + + try: + # Try to find an appropriate error template + return self.resolve_template(template_names) + except: + # Fall back to using eg '404 Not Found' + return Template('%d %s' % (response.status_code, + response.status_text.title())) + + +# Note, subclass TemplateHTMLRenderer simply for the exception behavior +class StaticHTMLRenderer(TemplateHTMLRenderer): + """ + An HTML renderer class that simply returns pre-rendered HTML. + + The data supplied to the Response object should be a string representing + the pre-rendered HTML content. + + For example: + data = '<html><body>example</body></html>' + return Response(data) + + For template rendered HTML, see TemplateHTMLRenderer. + """ + media_type = 'text/html' + format = 'html' + + def render(self, data, accepted_media_type=None, renderer_context=None): + renderer_context = renderer_context or {} + response = renderer_context['response'] + + if response and response.exception: + request = renderer_context['request'] + template = self.get_exception_template(response) + context = self.resolve_context(data, request, response) + return template.render(context) + + return data + class BrowsableAPIRenderer(BaseRenderer): """ @@ -222,11 +287,9 @@ class BrowsableAPIRenderer(BaseRenderer): return content - def get_form(self, view, method, request): + def show_form_for_method(self, view, method, request, obj): """ - Get a form, possibly bound to either the input or output data. - In the absence on of the Resource having an associated form then - provide a form that can be used to submit arbitrary content. + Returns True if a form should be shown for this method. """ if not method in view.allowed_methods: return # Not a valid method @@ -236,24 +299,13 @@ class BrowsableAPIRenderer(BaseRenderer): request = clone_request(request, method) try: - if not view.has_permission(request): + if not view.has_permission(request, obj): return # Don't have permission except: return # Don't have permission and exception explicitly raise + return True - if method == 'DELETE' or method == 'OPTIONS': - return True # Don't actually need to return a form - - if (not getattr(view, 'get_serializer', None) or - not parsers.FormParser in getattr(view, 'parser_classes')): - media_types = [parser.media_type for parser in view.parser_classes] - return self.get_generic_content_form(media_types) - - ##### - # TODO: This is a little bit of a hack. Actually we'd like to remove - # this and just render serializer fields to html directly. - - # We need to map our Fields to Django's Fields. + def serializer_to_form_fields(self, serializer): field_mapping = { serializers.FloatField: forms.FloatField, serializers.IntegerField: forms.IntegerField, @@ -261,34 +313,72 @@ class BrowsableAPIRenderer(BaseRenderer): serializers.DateField: forms.DateField, serializers.EmailField: forms.EmailField, serializers.CharField: forms.CharField, + serializers.ChoiceField: forms.ChoiceField, serializers.BooleanField: forms.BooleanField, - serializers.PrimaryKeyRelatedField: forms.ModelChoiceField, - serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField + serializers.PrimaryKeyRelatedField: forms.ChoiceField, + serializers.ManyPrimaryKeyRelatedField: forms.MultipleChoiceField, + serializers.SlugRelatedField: forms.ChoiceField, + serializers.ManySlugRelatedField: forms.MultipleChoiceField, + serializers.HyperlinkedRelatedField: forms.ChoiceField, + serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField } - # Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python fields = {} - obj, data = None, None - if getattr(view, 'object', None): - obj = view.object - - serializer = view.get_serializer(instance=obj) for k, v in serializer.get_fields(True).items(): - if getattr(v, 'readonly', True): + if getattr(v, 'read_only', True): continue kwargs = {} - if getattr(v, 'queryset', None): - kwargs['queryset'] = getattr(v, 'queryset', None) + kwargs['required'] = v.required + + #if getattr(v, 'queryset', None): + # kwargs['queryset'] = v.queryset + + if getattr(v, 'choices', None) is not None: + kwargs['choices'] = v.choices + + if getattr(v, 'widget', None): + widget = copy.deepcopy(v.widget) + kwargs['widget'] = widget + + if getattr(v, 'default', None) is not None: + kwargs['initial'] = v.default + + kwargs['label'] = k try: fields[k] = field_mapping[v.__class__](**kwargs) except KeyError: - fields[k] = forms.CharField() + if getattr(v, 'choices', None) is not None: + fields[k] = forms.ChoiceField(**kwargs) + else: + fields[k] = forms.CharField(**kwargs) + return fields + + def get_form(self, view, method, request): + """ + Get a form, possibly bound to either the input or output data. + In the absence on of the Resource having an associated form then + provide a form that can be used to submit arbitrary content. + """ + obj = getattr(view, 'object', None) + if not self.show_form_for_method(view, method, request, obj): + return + + if method == 'DELETE' or method == 'OPTIONS': + return True # Don't actually need to return a form + + if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes: + media_types = [parser.media_type for parser in view.parser_classes] + return self.get_generic_content_form(media_types) + + serializer = view.get_serializer(instance=obj) + fields = self.serializer_to_form_fields(serializer) + # Creating an on the fly form see: + # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields) - if obj and not view.request.method == 'DELETE': # Don't fill in the form when the object is deleted - data = serializer.data + data = (obj is not None) and serializer.data or None form_instance = OnTheFlyForm(data) return form_instance diff --git a/rest_framework/request.py b/rest_framework/request.py index 6f9cf09a..a1827ba4 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -21,8 +21,8 @@ def is_form_media_type(media_type): Return True if the media type is a valid form media type. """ base_media_type, params = parse_header(media_type) - return base_media_type == 'application/x-www-form-urlencoded' or \ - base_media_type == 'multipart/form-data' + return (base_media_type == 'application/x-www-form-urlencoded' or + base_media_type == 'multipart/form-data') class Empty(object): @@ -88,16 +88,11 @@ class Request(object): self._stream = Empty if self.parser_context is None: - self.parser_context = self._default_parser_context(request) + self.parser_context = {} + self.parser_context['request'] = self def _default_negotiator(self): - return api_settings.DEFAULT_CONTENT_NEGOTIATION() - - def _default_parser_context(self, request): - return { - 'upload_handlers': request.upload_handlers, - 'meta': request.META, - } + return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() @property def method(self): @@ -265,15 +260,19 @@ class Request(object): May raise an `UnsupportedMediaType`, or `ParseError` exception. """ - if self.stream is None or self.content_type is None: + stream = self.stream + media_type = self.content_type + + if stream is None or media_type is None: return (None, None) - parser = self.negotiator.select_parser(self.parsers, self.content_type) + parser = self.negotiator.select_parser(self, self.parsers) if not parser: - raise exceptions.UnsupportedMediaType(self.content_type) + raise exceptions.UnsupportedMediaType(media_type) + + parsed = parser.parse(stream, media_type, self.parser_context) - parsed = parser.parse(self.stream, self.parser_context) # Parser classes may return the raw data, or a # DataAndFiles object. Unpack the result as required. try: diff --git a/rest_framework/resources.py b/rest_framework/resources.py deleted file mode 100644 index dd8a5471..00000000 --- a/rest_framework/resources.py +++ /dev/null @@ -1,96 +0,0 @@ -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -from functools import update_wrapper -import inspect -from django.utils.decorators import classonlymethod -from rest_framework import views, generics - - -def wrapped(source, dest): - """ - Copy public, non-method attributes from source to dest, and return dest. - """ - for attr in [attr for attr in dir(source) - if not attr.startswith('_') and not inspect.ismethod(attr)]: - setattr(dest, attr, getattr(source, attr)) - return dest - - -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class ResourceMixin(object): - """ - Clone Django's `View.as_view()` behaviour *except* using REST framework's - 'method -> action' binding for resources. - """ - - @classonlymethod - def as_view(cls, actions, **initkwargs): - """ - Main entry point for a request-response process. - """ - # sanitize keyword arguments - for key in initkwargs: - if key in cls.http_method_names: - raise TypeError("You tried to pass in the %s method name as a " - "keyword argument to %s(). Don't do that." - % (key, cls.__name__)) - if not hasattr(cls, key): - raise TypeError("%s() received an invalid keyword %r" % ( - cls.__name__, key)) - - def view(request, *args, **kwargs): - self = cls(**initkwargs) - - # Bind methods to actions - for method, action in actions.items(): - handler = getattr(self, action) - setattr(self, method, handler) - - # As you were, solider. - if hasattr(self, 'get') and not hasattr(self, 'head'): - self.head = self.get - return self.dispatch(request, *args, **kwargs) - - # take name and docstring from class - update_wrapper(view, cls, updated=()) - - # and possible attributes set by decorators - # like csrf_exempt from dispatch - update_wrapper(view, cls.dispatch, assigned=()) - return view - - -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class Resource(ResourceMixin, views.APIView): - pass - - -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class ModelResource(ResourceMixin, views.APIView): - # TODO: Actually delegation won't work - root_class = generics.ListCreateAPIView - detail_class = generics.RetrieveUpdateDestroyAPIView - - def root_view(self): - return wrapped(self, self.root_class()) - - def detail_view(self): - return wrapped(self, self.detail_class()) - - def list(self, request, *args, **kwargs): - return self.root_view().list(request, args, kwargs) - - def create(self, request, *args, **kwargs): - return self.root_view().create(request, args, kwargs) - - def retrieve(self, request, *args, **kwargs): - return self.detail_view().retrieve(request, args, kwargs) - - def update(self, request, *args, **kwargs): - return self.detail_view().update(request, args, kwargs) - - def destroy(self, request, *args, **kwargs): - return self.detail_view().destroy(request, args, kwargs) diff --git a/rest_framework/response.py b/rest_framework/response.py index 7a459c8f..0de01204 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -9,7 +9,8 @@ class Response(SimpleTemplateResponse): """ def __init__(self, data=None, status=200, - template_name=None, headers=None): + template_name=None, headers=None, + exception=False): """ Alters the init arguments slightly. For example, drop 'template_name', and instead use 'data'. @@ -21,6 +22,7 @@ class Response(SimpleTemplateResponse): self.data = data self.headers = headers and headers[:] or [] self.template_name = template_name + self.exception = exception @property def rendered_content(self): @@ -45,3 +47,13 @@ class Response(SimpleTemplateResponse): # TODO: Deprecate and use a template tag instead # TODO: Status code text for RFC 6585 status codes return STATUS_CODE_TEXT.get(self.status_code, '') + + def __getstate__(self): + """ + Remove attributes from the response that shouldn't be cached + """ + state = super(Response, self).__getstate__() + for key in ('accepted_renderer', 'renderer_context', 'data'): + if key in state: + del state[key] + return state diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index ba663f98..c9db02f0 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -5,13 +5,15 @@ from django.core.urlresolvers import reverse as django_reverse from django.utils.functional import lazy -def reverse(viewname, *args, **kwargs): +def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): """ Same as `django.core.urlresolvers.reverse`, but optionally takes a request and returns a fully qualified URL, using the request to get the base URL. """ - request = kwargs.pop('request', None) - url = django_reverse(viewname, *args, **kwargs) + if format is not None: + kwargs = kwargs or {} + kwargs['format'] = format + url = django_reverse(viewname, args=args, kwargs=kwargs, **extra) if request: return request.build_absolute_uri(url) return url diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py index ea2e3d45..0ce379eb 100755 --- a/rest_framework/runtests/runcoverage.py +++ b/rest_framework/runtests/runcoverage.py @@ -32,10 +32,10 @@ def main(): 'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.', DeprecationWarning ) - failures = TestRunner(['rest_framework']) + failures = TestRunner(['tests']) else: test_runner = TestRunner() - failures = test_runner.run_tests(['rest_framework']) + failures = test_runner.run_tests(['tests']) cov.stop() # Discover the list of all modules that we should test coverage for diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py index b2438c9b..1bd0a5fc 100755 --- a/rest_framework/runtests/runtests.py +++ b/rest_framework/runtests/runtests.py @@ -32,7 +32,7 @@ def main(): else: print usage() sys.exit(1) - failures = test_runner.run_tests(['rest_framework' + test_case]) + failures = test_runner.run_tests(['tests' + test_case]) sys.exit(failures) diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 67de82c8..dd5d9dc3 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -21,6 +21,12 @@ DATABASES = { } } +CACHES = { + 'default': { + 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', + } +} + # Local time zone for this installation. Choices can be found here: # http://en.wikipedia.org/wiki/List_of_tz_zones_by_name # although not all choices may be available on all operating systems. @@ -91,6 +97,7 @@ INSTALLED_APPS = ( # 'django.contrib.admindocs', 'rest_framework', 'rest_framework.authtoken', + 'rest_framework.tests' ) STATIC_URL = '/static/' @@ -100,13 +107,6 @@ import django if django.VERSION < (1, 3): INSTALLED_APPS += ('staticfiles',) -# OAuth support is optional, so we only test oauth if it's installed. -try: - import oauth_provider -except ImportError: - pass -else: - INSTALLED_APPS += ('oauth_provider',) # If we're running on the Jenkins server we want to archive the coverage reports as XML. import os diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 13f8cde2..95145d58 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -3,8 +3,18 @@ import datetime import types from decimal import Decimal from django.db import models +from django.forms import widgets from django.utils.datastructures import SortedDict from rest_framework.compat import get_concrete_model + +# Note: We do the following so that users of the framework can use this style: +# +# example_field = serializers.CharField(...) +# +# This helps keep the seperation between model fields, form fields, and +# serializer fields more explicit. + + from rest_framework.fields import * @@ -22,10 +32,6 @@ class SortedDictWithMetadata(SortedDict, DictWithMetadata): pass -class RecursionOccured(BaseException): - pass - - def _is_protected_type(obj): """ True if the object is a native datatype that does not need to @@ -33,10 +39,10 @@ def _is_protected_type(obj): """ return isinstance(obj, ( types.NoneType, - int, long, - datetime.datetime, datetime.date, datetime.time, - float, Decimal, - basestring) + int, long, + datetime.datetime, datetime.date, datetime.time, + float, Decimal, + basestring) ) @@ -73,7 +79,7 @@ class SerializerOptions(object): Meta class options for Serializer """ def __init__(self, meta): - self.nested = getattr(meta, 'nested', False) + self.depth = getattr(meta, 'depth', 0) self.fields = getattr(meta, 'fields', ()) self.exclude = getattr(meta, 'exclude', ()) @@ -85,14 +91,13 @@ class BaseSerializer(Field): _options_class = SerializerOptions _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations. - def __init__(self, data=None, instance=None, context=None, **kwargs): + def __init__(self, instance=None, data=None, context=None, **kwargs): super(BaseSerializer, self).__init__(**kwargs) - self.fields = copy.deepcopy(self.base_fields) self.opts = self._options_class(self.Meta) + self.fields = copy.deepcopy(self.base_fields) self.parent = None self.root = None - self.stack = [] self.context = context or {} self.init_data = data @@ -104,13 +109,13 @@ class BaseSerializer(Field): ##### # Methods to determine which fields to use when (de)serializing objects. - def default_fields(self, serialize, obj=None, data=None, nested=False): + def default_fields(self, nested=False): """ Return the complete set of default fields for the object, as a dict. """ return {} - def get_fields(self, serialize, obj=None, data=None, nested=False): + def get_fields(self, nested=False): """ Returns the complete set of fields for the object as a dict. @@ -123,10 +128,10 @@ class BaseSerializer(Field): for key, field in self.fields.items(): ret[key] = field # Set up the field - field.initialize(parent=self) + field.initialize(parent=self, field_name=key) # Add in the default fields - fields = self.default_fields(serialize, obj, data, nested) + fields = self.default_fields(nested) for key, val in fields.items(): if key not in ret: ret[key] = val @@ -148,17 +153,14 @@ class BaseSerializer(Field): ##### # Field methods - used when the serializer class is itself used as a field. - def initialize(self, parent): + def initialize(self, parent, field_name): """ Same behaviour as usual Field, except that we need to keep track - of state so that we can deal with handling maximum depth and recursion. + of state so that we can deal with handling maximum depth. """ - super(BaseSerializer, self).initialize(parent) - self.stack = parent.stack[:] - if parent.opts.nested and not isinstance(parent.opts.nested, bool): - self.opts.nested = parent.opts.nested - 1 - else: - self.opts.nested = parent.opts.nested + super(BaseSerializer, self).initialize(parent, field_name) + if parent.opts.depth: + self.opts.depth = parent.opts.depth - 1 ##### # Methods to convert or revert from objects <--> primative representations. @@ -174,21 +176,13 @@ class BaseSerializer(Field): Core of serialization. Convert an object into a dictionary of serialized field values. """ - if obj in self.stack and not self.source == '*': - raise RecursionOccured() - self.stack.append(obj) - ret = self._dict_class() ret.fields = {} - fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested) + fields = self.get_fields(nested=bool(self.opts.depth)) for field_name, field in fields.items(): key = self.get_field_key(field_name) - try: - value = field.field_to_native(obj, field_name) - except RecursionOccured: - field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name] - value = field.field_to_native(obj, field_name) + value = field.field_to_native(obj, field_name) ret[key] = value ret.fields[key] = field return ret @@ -198,7 +192,7 @@ class BaseSerializer(Field): Core of deserialization, together with `restore_object`. Converts a dictionary of data into a dictionary of deserialized fields. """ - fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested) + fields = self.get_fields(nested=bool(self.opts.depth)) reverted_data = {} for field_name, field in fields.items(): try: @@ -208,6 +202,35 @@ class BaseSerializer(Field): return reverted_data + def perform_validation(self, attrs): + """ + Run `validate_<fieldname>()` and `validate()` methods on the serializer + """ + # TODO: refactor this so we're not determining the fields again + fields = self.get_fields(nested=bool(self.opts.depth)) + + for field_name, field in fields.items(): + try: + validate_method = getattr(self, 'validate_%s' % field_name, None) + if validate_method: + source = field.source or field_name + attrs = validate_method(attrs, source) + except ValidationError as err: + self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) + + try: + attrs = self.validate(attrs) + except ValidationError as err: + self._errors['non_field_errors'] = err.messages + + return attrs + + def validate(self, attrs): + """ + Stub method, to be overridden in Serializer subclasses + """ + return attrs + def restore_object(self, attrs, instance=None): """ Deserialize a dictionary of attributes into an object instance. @@ -223,11 +246,8 @@ class BaseSerializer(Field): """ Serialize objects -> primatives. """ - if isinstance(obj, dict): - return dict([(key, self.to_native(val)) - for (key, val) in obj.items()]) - elif hasattr(obj, '__iter__'): - return [self.to_native(item) for item in obj] + if hasattr(obj, '__iter__'): + return [self.convert_object(item) for item in obj] return self.convert_object(obj) def from_native(self, data): @@ -241,17 +261,31 @@ class BaseSerializer(Field): self._errors = {} if data is not None: attrs = self.restore_fields(data) + attrs = self.perform_validation(attrs) else: - self._errors['non_field_errors'] = 'No input provided' + self._errors['non_field_errors'] = ['No input provided'] if not self._errors: return self.restore_object(attrs, instance=getattr(self, 'object', None)) + def field_to_native(self, obj, field_name): + """ + Override default so that we can apply ModelSerializer as a nested + field to relationships. + """ + obj = getattr(obj, self.source or field_name) + + # If the object has an "all" method, assume it's a relationship + if is_simple_callable(getattr(obj, 'all', None)): + return [self.to_native(item) for item in obj.all()] + + return self.to_native(obj) + @property def errors(self): """ Run deserialization and return error data, - setting self.object if no errors occured. + setting self.object if no errors occurred. """ if self._errors is None: obj = self.from_native(self.init_data) @@ -295,17 +329,7 @@ class ModelSerializer(Serializer): """ _options_class = ModelSerializerOptions - def field_to_native(self, obj, field_name): - """ - Override default so that we can apply ModelSerializer as a nested - field to relationships. - """ - obj = getattr(obj, self.source or field_name) - if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'): - return [self.to_native(item) for item in obj.all()] - return self.to_native(obj) - - def default_fields(self, serialize, obj=None, data=None, nested=False): + def default_fields(self, nested=False): """ Return all the fields that should be serialized for the model. """ @@ -342,7 +366,7 @@ class ModelSerializer(Serializer): field = self.get_field(model_field) if field: - field.initialize(parent=self) + field.initialize(parent=self, field_name=model_field.name) ret[model_field.name] = field return ret @@ -374,6 +398,25 @@ class ModelSerializer(Serializer): """ Creates a default instance of a basic non-relational field. """ + kwargs = {} + + kwargs['blank'] = model_field.blank + + if model_field.null: + kwargs['required'] = False + + if model_field.has_default(): + kwargs['required'] = False + kwargs['default'] = model_field.get_default() + + if model_field.__class__ == models.TextField: + kwargs['widget'] = widgets.Textarea + + # TODO: TypedChoiceField? + if model_field.flatchoices: # This ModelField contains choices + kwargs['choices'] = model_field.flatchoices + return ChoiceField(**kwargs) + field_mapping = { models.FloatField: FloatField, models.IntegerField: IntegerField, @@ -389,14 +432,9 @@ class ModelSerializer(Serializer): models.BooleanField: BooleanField, } try: - ret = field_mapping[model_field.__class__]() + return field_mapping[model_field.__class__](**kwargs) except KeyError: - ret = ModelField(model_field=model_field) - - if model_field.default: - ret.required = False - - return ret + return ModelField(model_field=model_field, **kwargs) def restore_object(self, attrs, instance=None): """ @@ -409,6 +447,13 @@ class ModelSerializer(Serializer): setattr(instance, key, val) return instance + # Reverse relations + for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.m2m_data[field_name] = attrs.pop(field_name) + + # Forward relations for field in self.opts.model._meta.many_to_many: if field.name in attrs: self.m2m_data[field.name] = attrs.pop(field.name) @@ -420,7 +465,7 @@ class ModelSerializer(Serializer): """ self.object.save() - if self.m2m_data and save_m2m: + if getattr(self, 'm2m_data', None) and save_m2m: for accessor_name, object_list in self.m2m_data.items(): setattr(self.object, accessor_name, object_list) self.m2m_data = {} diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 8bbb2f75..906a7cf6 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -3,11 +3,11 @@ Settings for REST framework are all namespaced in the REST_FRAMEWORK setting. For example your project's `settings.py` file might look like this: REST_FRAMEWORK = { - 'DEFAULT_RENDERERS': ( + 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'rest_framework.renderers.YAMLRenderer', ) - 'DEFAULT_PARSERS': ( + 'DEFAULT_PARSER_CLASSES': ( 'rest_framework.parsers.JSONParser', 'rest_framework.parsers.YAMLParser', ) @@ -24,31 +24,38 @@ from django.utils import importlib USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) DEFAULTS = { - 'DEFAULT_RENDERERS': ( + 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'rest_framework.renderers.BrowsableAPIRenderer', ), - 'DEFAULT_PARSERS': ( + 'DEFAULT_PARSER_CLASSES': ( 'rest_framework.parsers.JSONParser', 'rest_framework.parsers.FormParser', 'rest_framework.parsers.MultiPartParser' ), - 'DEFAULT_AUTHENTICATION': ( + 'DEFAULT_AUTHENTICATION_CLASSES': ( 'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.BasicAuthentication' ), - 'DEFAULT_PERMISSIONS': (), - 'DEFAULT_THROTTLES': (), - 'DEFAULT_CONTENT_NEGOTIATION': + 'DEFAULT_PERMISSION_CLASSES': ( + 'rest_framework.permissions.AllowAny', + ), + 'DEFAULT_THROTTLE_CLASSES': ( + ), + + 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', + 'DEFAULT_MODEL_SERIALIZER_CLASS': + 'rest_framework.serializers.ModelSerializer', + 'DEFAULT_PAGINATION_SERIALIZER_CLASS': + 'rest_framework.pagination.PaginationSerializer', + 'DEFAULT_THROTTLE_RATES': { 'user': None, 'anon': None, }, - - 'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer', - 'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer', 'PAGINATE_BY': None, + 'FILTER_BACKEND': None, 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -65,14 +72,15 @@ DEFAULTS = { # List of settings that may be in string import notation. IMPORT_STRINGS = ( - 'DEFAULT_RENDERERS', - 'DEFAULT_PARSERS', - 'DEFAULT_AUTHENTICATION', - 'DEFAULT_PERMISSIONS', - 'DEFAULT_THROTTLES', - 'DEFAULT_CONTENT_NEGOTIATION', - 'MODEL_SERIALIZER', - 'PAGINATION_SERIALIZER', + 'DEFAULT_RENDERER_CLASSES', + 'DEFAULT_PARSER_CLASSES', + 'DEFAULT_AUTHENTICATION_CLASSES', + 'DEFAULT_PERMISSION_CLASSES', + 'DEFAULT_THROTTLE_CLASSES', + 'DEFAULT_CONTENT_NEGOTIATION_CLASS', + 'DEFAULT_MODEL_SERIALIZER_CLASS', + 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) @@ -111,7 +119,7 @@ class APISettings(object): For example: from rest_framework.settings import api_settings - print api_settings.DEFAULT_RENDERERS + print api_settings.DEFAULT_RENDERER_CLASSES Any setting with string import paths will be automatically resolved and return the class, rather than the string literal. @@ -136,8 +144,15 @@ class APISettings(object): if val and attr in self.import_strings: val = perform_import(val, attr) + self.validate_setting(attr, val) + # Cache the result setattr(self, attr, val) return val + def validate_setting(self, attr, val): + if attr == 'FILTER_BACKEND' and val is not None: + # Make sure we can initilize the class + val() + api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css index 739b9300..b2e41b99 100644 --- a/rest_framework/static/rest_framework/css/default.css +++ b/rest_framework/static/rest_framework/css/default.css @@ -32,6 +32,17 @@ h2, h3 { margin-right: 1em; } +ul.breadcrumb { + margin: 58px 0 0 0; +} + +form select, form input, form textarea { + width: 90%; +} + +form select[multiple] { + height: 150px; +} /* To allow tooltips to work on disabled elements */ .disabled-tooltip-shield { position: absolute; @@ -55,6 +66,7 @@ pre { .page-header { border-bottom: none; padding-bottom: 0px; + margin-bottom: 20px; } @@ -65,7 +77,7 @@ html{ background: none; } -body, .navbar .navbar-inner .container-fluid{ +body, .navbar .navbar-inner .container-fluid { max-width: 1150px; margin: 0 auto; } @@ -76,13 +88,14 @@ body{ } #content{ - margin: 40px 0 0 0; + margin: 0; } /* custom navigation styles */ .wrapper .navbar{ - width:100%; + width: 100%; position: absolute; - left:0; + left: 0; + top: 0; } .navbar .navbar-inner{ diff --git a/rest_framework/status.py b/rest_framework/status.py index f3a5e481..a1eb48da 100644 --- a/rest_framework/status.py +++ b/rest_framework/status.py @@ -49,4 +49,4 @@ HTTP_502_BAD_GATEWAY = 502 HTTP_503_SERVICE_UNAVAILABLE = 503 HTTP_504_GATEWAY_TIMEOUT = 504 HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505 -HTTP_511_NETWORD_AUTHENTICATION_REQUIRED = 511 +HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511 diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 5ac6ef67..fb0e19f0 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -109,7 +109,7 @@ <div class="content-main"> <div class="page-header"><h1>{{ name }}</h1></div> - <p class="resource-description">{{ description }}</p> + {{ description }} <div class="request-info"> <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre> @@ -131,12 +131,12 @@ {% csrf_token %} {{ post_form.non_field_errors }} {% for field in post_form %} - <div class="control-group {% if field.errors %}error{% endif %}"> + <div class="control-group"> <!--{% if field.errors %}error{% endif %}--> {{ field.label_tag|add_class:"control-label" }} <div class="controls"> - {{ field|add_class:"input-xlarge" }} + {{ field }} <span class="help-inline">{{ field.help_text }}</span> - {{ field.errors|add_class:"help-block" }} + <!--{{ field.errors|add_class:"help-block" }}--> </div> </div> {% endfor %} @@ -156,12 +156,12 @@ {% csrf_token %} {{ put_form.non_field_errors }} {% for field in put_form %} - <div class="control-group {% if field.errors %}error{% endif %}"> + <div class="control-group"> <!--{% if field.errors %}error{% endif %}--> {{ field.label_tag|add_class:"control-label" }} <div class="controls"> - {{ field|add_class:"input-xlarge" }} + {{ field }} <span class='help-inline'>{{ field.help_text }}</span> - {{ field.errors|add_class:"help-block" }} + <!--{{ field.errors|add_class:"help-block" }}--> </div> </div> {% endfor %} diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index 65af512e..c1271399 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -3,42 +3,50 @@ <html> <head> - <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/style.css'/> + <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/> + <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/> + <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/> </head> - <body class="login"> + <body class="container"> - <div id="container"> - - <div id="header"> - <div id="branding"> - <h1 id="site-name">Django REST framework</h1> +<div class="container-fluid" style="margin-top: 30px"> + <div class="row-fluid"> + + <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> + <div class="row-fluid"> + <div> + <h3 style="margin: 0 0 20px;">Django REST framework</h3> </div> - </div> + </div><!-- /row fluid --> - <div id="content" class="colM"> - <div id="content-main"> - <form method="post" action="{% url 'rest_framework:login' %}" id="login-form"> + <div class="row-fluid"> + <div> + <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> {% csrf_token %} - <div class="form-row"> - <label for="id_username">Username:</label> {{ form.username }} + <div id="div_id_username" class="clearfix control-group"> + <div class="controls" style="height: 30px"> + <Label class="span4" style="margin-top: 3px">Username:</label> + <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> + </div> </div> - <div class="form-row"> - <label for="id_password">Password:</label> {{ form.password }} - <input type="hidden" name="next" value="{{ next }}" /> + <div id="div_id_password" class="clearfix control-group"> + <div class="controls" style="height: 30px"> + <Label class="span4" style="margin-top: 3px">Password:</label> + <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> + </div> </div> - <div class="form-row"> - <label> </label><input type="submit" value="Log in"> + <input type="hidden" name="next" value="{{ next }}" /> + <div class="form-actions-no-box"> + <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> </div> </form> - <script type="text/javascript"> - document.getElementById('id_username').focus() - </script> </div> - <br class="clear"> - </div> + </div><!-- /row fluid --> + </div><!--/span--> - <div id="footer"></div> + </div><!-- /.row-fluid --> + </div> </div> </body> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c9b6eb10..4e0181ee 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -11,6 +11,18 @@ import string register = template.Library() +def replace_query_param(url, key, val): + """ + Given a URL and a key/val pair, set or replace an item in the query + parameters of the URL, and return the new URL. + """ + (scheme, netloc, path, query, fragment) = urlsplit(url) + query_dict = QueryDict(query).copy() + query_dict[key] = val + query = query_dict.urlencode() + return urlunsplit((scheme, netloc, path, query, fragment)) + + # Regex for adding classes to html snippets class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') @@ -31,19 +43,6 @@ hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '| trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') -# Helper function for 'add_query_param' -def replace_query_param(url, key, val): - """ - Given a URL and a key/val pair, set or replace an item in the query - parameters of the URL, and return the new URL. - """ - (scheme, netloc, path, query, fragment) = urlsplit(url) - query_dict = QueryDict(query).copy() - query_dict[key] = val - query = query_dict.urlencode() - return urlunsplit((scheme, netloc, path, query, fragment)) - - # And the template tags themselves... @register.simple_tag diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py index adeaf6da..e69de29b 100644 --- a/rest_framework/tests/__init__.py +++ b/rest_framework/tests/__init__.py @@ -1,13 +0,0 @@ -""" -Force import of all modules in this package in order to get the standard test -runner to pick up the tests. Yowzers. -""" -import os - -modules = [filename.rsplit('.', 1)[0] - for filename in os.listdir(os.path.dirname(__file__)) - if filename.endswith('.py') and not filename.startswith('_')] -__test__ = dict() - -for module in modules: - exec("from rest_framework.tests.%s import *" % module) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py new file mode 100644 index 00000000..af2e6c2e --- /dev/null +++ b/rest_framework/tests/filterset.py @@ -0,0 +1,168 @@ +import datetime +from decimal import Decimal +from django.test import TestCase +from django.test.client import RequestFactory +from django.utils import unittest +from rest_framework import generics, status, filters +from rest_framework.compat import django_filters +from rest_framework.tests.models import FilterableItem, BasicModel + +factory = RequestFactory() + + +if django_filters: + # Basic filter on a list view. + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_fields = ['decimal', 'date'] + filter_backend = filters.DjangoFilterBackend + + # These class are used to test a filter class. + class SeveralFieldsFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + decimal = django_filters.NumberFilter(lookup_type='lt') + date = django_filters.DateFilter(lookup_type='gt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterClassRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + # These classes are used to test a misconfigured filter class. + class MisconfiguredFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + + class Meta: + model = BasicModel + fields = ['text'] + + class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = MisconfiguredFilter + filter_backend = filters.DjangoFilterBackend + + +class IntegrationTestFiltering(TestCase): + """ + Integration tests for filtered list views. + """ + + def setUp(self): + """ + Create 10 FilterableItem instances. + """ + base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) + for i in range(10): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() + ] + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_fields_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + view = FilterFieldsRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEquals(response.data, expected_data) + + # Tests that the date filter works. + search_date = datetime.date(2012, 9, 22) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] == search_date] + self.assertEquals(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_class_root_view(self): + """ + GET requests to filtered ListCreateAPIView that have a filter_class set + should return filtered results. + """ + view = FilterClassRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + # Tests that the decimal filter set with 'lt' in the filter class works. + search_decimal = Decimal('4.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] < search_decimal] + self.assertEquals(response.data, expected_data) + + # Tests that the date filter set with 'gt' in the filter class works. + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date] + self.assertEquals(response.data, expected_data) + + # Tests that the text filter set with 'icontains' in the filter class works. + search_text = 'ff' + request = factory.get('/?text=%s' % search_text) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if search_text in f['text'].lower()] + self.assertEquals(response.data, expected_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date and + f['decimal'] < search_decimal] + self.assertEquals(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_incorrectly_configured_filter(self): + """ + An error should be displayed when the filter class is misconfigured. + """ + view = IncorrectlyConfiguredRootView.as_view() + + request = factory.get('/') + self.assertRaises(AssertionError, view, request) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_unknown_filter(self): + """ + GET requests with filters that aren't configured should return 200. + """ + view = FilterFieldsRootView.as_view() + + search_integer = 10 + request = factory.get('/?integer=%s' % search_integer) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py index 1d7e33bc..bc7378e1 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/genericrelations.py @@ -25,7 +25,7 @@ class TestGenericRelations(TestCase): model = Bookmark exclude = ('id',) - serializer = BookmarkSerializer(instance=self.bookmark) + serializer = BookmarkSerializer(self.bookmark) expected = { 'tags': [u'django', u'python'], 'url': u'https://www.djangoproject.com/' diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index f4263478..a8279ef2 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -2,7 +2,7 @@ from django.test import TestCase from django.test.client import RequestFactory from django.utils import simplejson as json from rest_framework import generics, serializers, status -from rest_framework.tests.models import BasicModel, Comment +from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel factory = RequestFactory() @@ -22,6 +22,22 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView): model = BasicModel +class SlugSerializer(serializers.ModelSerializer): + slug = serializers.Field() # read only + + class Meta: + model = SlugBasedModel + exclude = ('id',) + + +class SlugBasedInstanceView(InstanceView): + """ + A model with a slug-field. + """ + model = SlugBasedModel + serializer_class = SlugSerializer + + class TestRootView(TestCase): def setUp(self): """ @@ -129,6 +145,7 @@ class TestInstanceView(TestCase): for obj in self.objects.all() ] self.view = InstanceView.as_view() + self.slug_based_view = SlugBasedInstanceView.as_view() def test_get_instance_view(self): """ @@ -198,7 +215,7 @@ class TestInstanceView(TestCase): def test_put_cannot_set_id(self): """ - POST requests to create a new object should not be able to set the id. + PUT requests to create a new object should not be able to set the id. """ content = {'id': 999, 'text': 'foobar'} request = factory.put('/1', json.dumps(content), @@ -219,11 +236,39 @@ class TestInstanceView(TestCase): request = factory.put('/1', json.dumps(content), content_type='application/json') response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) self.assertEquals(updated.text, 'foobar') + def test_put_as_create_on_id_based_url(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + at the requested url if it doesn't exist. + """ + content = {'text': 'foobar'} + # pk fields can not be created on demand, only the database can set th pk for a new object + request = factory.put('/5', json.dumps(content), + content_type='application/json') + response = self.view(request, pk=5).render() + self.assertEquals(response.status_code, status.HTTP_201_CREATED) + new_obj = self.objects.get(pk=5) + self.assertEquals(new_obj.text, 'foobar') + + def test_put_as_create_on_slug_based_url(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. + """ + content = {'text': 'foobar'} + request = factory.put('/test_slug', json.dumps(content), + content_type='application/json') + response = self.slug_based_view(request, slug='test_slug').render() + self.assertEquals(response.status_code, status.HTTP_201_CREATED) + self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'}) + new_obj = SlugBasedModel.objects.get(slug='test_slug') + self.assertEquals(new_obj.text, 'foobar') + # Regression test for #285 diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py index da2f83c3..4caed59e 100644 --- a/rest_framework/tests/htmlrenderer.py +++ b/rest_framework/tests/htmlrenderer.py @@ -1,14 +1,16 @@ +from django.core.exceptions import PermissionDenied from django.conf.urls.defaults import patterns, url +from django.http import Http404 from django.test import TestCase from django.template import TemplateDoesNotExist, Template import django.template.loader from rest_framework.decorators import api_view, renderer_classes -from rest_framework.renderers import HTMLRenderer +from rest_framework.renderers import TemplateHTMLRenderer from rest_framework.response import Response @api_view(('GET',)) -@renderer_classes((HTMLRenderer,)) +@renderer_classes((TemplateHTMLRenderer,)) def example(request): """ A view that can returns an HTML representation. @@ -17,12 +19,26 @@ def example(request): return Response(data, template_name='example.html') +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def permission_denied(request): + raise PermissionDenied() + + +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def not_found(request): + raise Http404() + + urlpatterns = patterns('', url(r'^$', example), + url(r'^permission_denied$', permission_denied), + url(r'^not_found$', not_found), ) -class HTMLRendererTests(TestCase): +class TemplateHTMLRendererTests(TestCase): urls = 'rest_framework.tests.htmlrenderer' def setUp(self): @@ -48,3 +64,52 @@ class HTMLRendererTests(TestCase): response = self.client.get('/') self.assertContains(response, "example: foobar") self.assertEquals(response['Content-Type'], 'text/html') + + def test_not_found_html_view(self): + response = self.client.get('/not_found') + self.assertEquals(response.status_code, 404) + self.assertEquals(response.content, "404 Not Found") + self.assertEquals(response['Content-Type'], 'text/html') + + def test_permission_denied_html_view(self): + response = self.client.get('/permission_denied') + self.assertEquals(response.status_code, 403) + self.assertEquals(response.content, "403 Forbidden") + self.assertEquals(response['Content-Type'], 'text/html') + + +class TemplateHTMLRendererExceptionTests(TestCase): + urls = 'rest_framework.tests.htmlrenderer' + + def setUp(self): + """ + Monkeypatch get_template + """ + self.get_template = django.template.loader.get_template + + def get_template(template_name): + if template_name == '404.html': + return Template("404: {{ detail }}") + if template_name == '403.html': + return Template("403: {{ detail }}") + raise TemplateDoesNotExist(template_name) + + django.template.loader.get_template = get_template + + def tearDown(self): + """ + Revert monkeypatching + """ + django.template.loader.get_template = self.get_template + + def test_not_found_html_view_with_template(self): + response = self.client.get('/not_found') + self.assertEquals(response.status_code, 404) + self.assertEquals(response.content, "404: Not found") + self.assertEquals(response['Content-Type'], 'text/html') + + def test_permission_denied_html_view_with_template(self): + response = self.client.get('/permission_denied') + self.assertEquals(response.status_code, 403) + self.assertEquals(response.content, "403: Permission denied") + self.assertEquals(response['Content-Type'], 'text/html') diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 5532a8ee..f71e2e28 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -2,11 +2,28 @@ from django.conf.urls.defaults import patterns, url from django.test import TestCase from django.test.client import RequestFactory from rest_framework import generics, status, serializers -from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel +from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo factory = RequestFactory() +class BlogPostCommentSerializer(serializers.ModelSerializer): + text = serializers.CharField() + blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') + + class Meta: + model = BlogPostComment + fields = ('text', 'blog_post_url') + + +class PhotoSerializer(serializers.Serializer): + description = serializers.CharField() + album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') + + def restore_object(self, attrs, instance=None): + return Photo(**attrs) + + class BasicList(generics.ListCreateAPIView): model = BasicModel model_serializer_class = serializers.HyperlinkedModelSerializer @@ -32,12 +49,34 @@ class ManyToManyDetail(generics.RetrieveAPIView): model_serializer_class = serializers.HyperlinkedModelSerializer +class BlogPostCommentListCreate(generics.ListCreateAPIView): + model = BlogPostComment + serializer_class = BlogPostCommentSerializer + + +class BlogPostDetail(generics.RetrieveAPIView): + model = BlogPost + + +class PhotoListCreate(generics.ListCreateAPIView): + model = Photo + model_serializer_class = PhotoSerializer + + +class AlbumDetail(generics.RetrieveAPIView): + model = Album + + urlpatterns = patterns('', url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), + url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), + url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), + url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), + url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list') ) @@ -124,3 +163,51 @@ class TestManyToManyHyperlinkedView(TestCase): response = self.detail_view(request, pk=1).render() self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.data, self.data[0]) + + +class TestCreateWithForeignKeys(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create a blog post + """ + self.post = BlogPost.objects.create(title="Test post") + self.create_view = BlogPostCommentListCreate.as_view() + + def test_create_comment(self): + + data = { + 'text': 'A test comment', + 'blog_post_url': 'http://testserver/posts/1/' + } + + request = factory.post('/comments/', data=data) + response = self.create_view(request).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(self.post.blogpostcomment_set.count(), 1) + self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') + + +class TestCreateWithForeignKeysAndCustomSlug(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create an Album + """ + self.post = Album.objects.create(title='test-album') + self.list_create_view = PhotoListCreate.as_view() + + def test_create_photo(self): + + data = { + 'description': 'A test photo', + 'album_url': 'http://testserver/albums/test-album/' + } + + request = factory.post('/photos/', data=data) + response = self.list_create_view(request).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(self.post.photo_set.count(), 1) + self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 6a758f0c..a2aba5be 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -40,7 +40,7 @@ class RESTFrameworkModel(models.Model): Base for test models that sets app_label, so they play nicely. """ class Meta: - app_label = 'rest_framework' + app_label = 'tests' abstract = True @@ -52,6 +52,11 @@ class BasicModel(RESTFrameworkModel): text = models.CharField(max_length=100) +class SlugBasedModel(RESTFrameworkModel): + text = models.CharField(max_length=100) + slug = models.SlugField(max_length=32) + + class DefaultValueModel(RESTFrameworkModel): text = models.CharField(default='foobar', max_length=100) @@ -63,6 +68,11 @@ class CallableDefaultValueModel(RESTFrameworkModel): class ManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) + +class ReadOnlyManyToManyModel(RESTFrameworkModel): + text = models.CharField(max_length=100, default='anchor') + rel = models.ManyToManyField(Anchor) + # Models to test generic relations @@ -85,9 +95,57 @@ class Bookmark(RESTFrameworkModel): tags = GenericRelation(TaggedItem) +# Model to test filtering. +class FilterableItem(RESTFrameworkModel): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + # Model for regression test for #285 class Comment(RESTFrameworkModel): email = models.EmailField() content = models.CharField(max_length=200) created = models.DateTimeField(auto_now_add=True) + + +class ActionItem(RESTFrameworkModel): + title = models.CharField(max_length=200) + done = models.BooleanField(default=False) + + +# Models for reverse relations +class BlogPost(RESTFrameworkModel): + title = models.CharField(max_length=100) + + +class BlogPostComment(RESTFrameworkModel): + text = models.TextField() + blog_post = models.ForeignKey(BlogPost) + + +class Album(RESTFrameworkModel): + title = models.CharField(max_length=100, unique=True) + + +class Photo(RESTFrameworkModel): + description = models.TextField() + album = models.ForeignKey(Album) + + +class Person(RESTFrameworkModel): + name = models.CharField(max_length=10) + age = models.IntegerField(null=True, blank=True) + + @property + def info(self): + return { + 'name': self.name, + 'age': self.age, + } + + +# Model for issue #324 +class BlankFieldModel(RESTFrameworkModel): + title = models.CharField(max_length=100, blank=True) diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/negotiation.py index d8265b43..e06354ea 100644 --- a/rest_framework/tests/negotiation.py +++ b/rest_framework/tests/negotiation.py @@ -18,20 +18,20 @@ class TestAcceptedMediaType(TestCase): self.renderers = [MockJSONRenderer(), MockHTMLRenderer()] self.negotiator = DefaultContentNegotiation() - def negotiate(self, request): - return self.negotiator.negotiate(request, self.renderers) + def select_renderer(self, request): + return self.negotiator.select_renderer(request, self.renderers) def test_client_without_accept_use_renderer(self): request = factory.get('/') - accepted_renderer, accepted_media_type = self.negotiate(request) + accepted_renderer, accepted_media_type = self.select_renderer(request) self.assertEquals(accepted_media_type, 'application/json') def test_client_underspecifies_accept_use_renderer(self): request = factory.get('/', HTTP_ACCEPT='*/*') - accepted_renderer, accepted_media_type = self.negotiate(request) + accepted_renderer, accepted_media_type = self.select_renderer(request) self.assertEquals(accepted_media_type, 'application/json') def test_client_overspecifies_accept_use_client(self): request = factory.get('/', HTTP_ACCEPT='application/json; indent=8') - accepted_renderer, accepted_media_type = self.negotiate(request) + accepted_renderer, accepted_media_type = self.select_renderer(request) self.assertEquals(accepted_media_type, 'application/json; indent=8') diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index a939c9ef..713a7255 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,8 +1,12 @@ +import datetime +from decimal import Decimal from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory -from rest_framework import generics, status, pagination -from rest_framework.tests.models import BasicModel +from django.utils import unittest +from rest_framework import generics, status, pagination, filters +from rest_framework.compat import django_filters +from rest_framework.tests.models import BasicModel, FilterableItem factory = RequestFactory() @@ -15,6 +19,21 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 +if django_filters: + class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_class = DecimalFilter + filter_backend = filters.DjangoFilterBackend + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -22,7 +41,7 @@ class IntegrationTestPagination(TestCase): def setUp(self): """ - Create 26 BasicModel intances. + Create 26 BasicModel instances. """ for char in 'abcdefghijklmnopqrstuvwxyz': BasicModel(text=char * 3).save() @@ -62,6 +81,58 @@ class IntegrationTestPagination(TestCase): self.assertNotEquals(response.data['previous'], None) +class IntegrationTestPaginationAndFiltering(TestCase): + + def setUp(self): + """ + Create 50 FilterableItem instances. + """ + base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) + for i in range(26): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() + ] + self.view = FilterFieldsRootView.as_view() + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_paginated_filtered_root_view(self): + """ + GET requests to paginated filtered ListCreateAPIView should return + paginated results. The next and previous links should preserve the + filtered parameters. + """ + request = factory.get('/?decimal=15.20') + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[:10]) + self.assertNotEquals(response.data['next'], None) + self.assertEquals(response.data['previous'], None) + + request = factory.get(response.data['next']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[10:15]) + self.assertEquals(response.data['next'], None) + self.assertNotEquals(response.data['previous'], None) + + request = factory.get(response.data['previous']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[:10]) + self.assertNotEquals(response.data['next'], None) + self.assertEquals(response.data['previous'], None) + + class UnitTestPagination(TestCase): """ Unit tests for pagination of primative objects. @@ -74,13 +145,13 @@ class UnitTestPagination(TestCase): self.last_page = paginator.page(3) def test_native_pagination(self): - serializer = pagination.PaginationSerializer(instance=self.first_page) + serializer = pagination.PaginationSerializer(self.first_page) self.assertEquals(serializer.data['count'], 26) self.assertEquals(serializer.data['next'], '?page=2') self.assertEquals(serializer.data['previous'], None) self.assertEquals(serializer.data['results'], self.objects[:10]) - serializer = pagination.PaginationSerializer(instance=self.last_page) + serializer = pagination.PaginationSerializer(self.last_page) self.assertEquals(serializer.data['count'], 26) self.assertEquals(serializer.data['next'], None) self.assertEquals(serializer.data['previous'], '?page=2') diff --git a/rest_framework/tests/pk_relations.py b/rest_framework/tests/pk_relations.py new file mode 100644 index 00000000..44ae4040 --- /dev/null +++ b/rest_framework/tests/pk_relations.py @@ -0,0 +1,205 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +# ManyToMany + +class ManyToManyTarget(models.Model): + name = models.CharField(max_length=100) + + +class ManyToManySource(models.Model): + name = models.CharField(max_length=100) + targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') + + +class ManyToManyTargetSerializer(serializers.ModelSerializer): + sources = serializers.ManyPrimaryKeyRelatedField() + + class Meta: + model = ManyToManyTarget + + +class ManyToManySourceSerializer(serializers.ModelSerializer): + class Meta: + model = ManyToManySource + + +# ForeignKey + +class ForeignKeyTarget(models.Model): + name = models.CharField(max_length=100) + + +class ForeignKeySource(models.Model): + name = models.CharField(max_length=100) + target = models.ForeignKey(ForeignKeyTarget, related_name='sources') + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): + sources = serializers.ManyPrimaryKeyRelatedField(read_only=True) + + class Meta: + model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeySource + + +# TODO: Add test that .data cannot be accessed prior to .is_valid + +class PrimaryKeyManyToManyTests(TestCase): + def setUp(self): + for idx in range(1, 4): + target = ManyToManyTarget(name='target-%d' % idx) + target.save() + source = ManyToManySource(name='source-%d' % idx) + source.save() + for target in ManyToManyTarget.objects.all(): + source.targets.add(target) + + def test_many_to_many_retrieve(self): + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'targets': [1]}, + {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_many_to_many_retrieve(self): + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': u'target-3', 'sources': [3]} + ] + self.assertEquals(serializer.data, expected) + + def test_many_to_many_update(self): + data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]} + instance = ManyToManySource.objects.get(pk=1) + serializer = ManyToManySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}, + {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_many_to_many_update(self): + data = {'id': 1, 'name': u'target-1', 'sources': [1]} + instance = ManyToManyTarget.objects.get(pk=1) + serializer = ManyToManyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # Ensure target 1 is updated, and everything else is as expected + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [1]}, + {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': u'target-3', 'sources': [3]} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_many_to_many_create(self): + data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]} + serializer = ManyToManyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'target-4') + + # Ensure target 4 is added, and everything else is as expected + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': u'target-3', 'sources': [3]}, + {'id': 4, 'name': u'target-4', 'sources': [1, 3]} + ] + self.assertEquals(serializer.data, expected) + +class PrimaryKeyForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 1}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': 1} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_retrieve(self): + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': u'target-2', 'sources': []}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update(self): + data = {'id': 1, 'name': u'source-1', 'target': 2} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # # Ensure source 1 is updated, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 2}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': 1} + ] + self.assertEquals(serializer.data, expected) + + # reverse foreign keys MUST be read_only + # In the general case they do not provide .remove() or .clear() + # and cannot be arbitrarily set. + + # def test_reverse_foreign_key_update(self): + # data = {'id': 1, 'name': u'target-1', 'sources': [1]} + # instance = ForeignKeyTarget.objects.get(pk=1) + # serializer = ForeignKeyTargetSerializer(instance, data=data) + # self.assertTrue(serializer.is_valid()) + # self.assertEquals(serializer.data, data) + # serializer.save() + + # # Ensure target 1 is updated, and everything else is as expected + # queryset = ForeignKeyTarget.objects.all() + # serializer = ForeignKeyTargetSerializer(queryset) + # expected = [ + # {'id': 1, 'name': u'target-1', 'sources': [1]}, + # {'id': 2, 'name': u'target-2', 'sources': []}, + # ] + # self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py index 48d8d9bd..9be4b114 100644 --- a/rest_framework/tests/renderers.py +++ b/rest_framework/tests/renderers.py @@ -1,6 +1,8 @@ +import pickle import re from django.conf.urls.defaults import patterns, url, include +from django.core.cache import cache from django.test import TestCase from django.test.client import RequestFactory @@ -83,6 +85,7 @@ class HTMLView1(APIView): urlpatterns = patterns('', url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), + url(r'^cache$', MockGETView.as_view()), url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])), url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])), url(r'^html$', HTMLView.as_view()), @@ -416,3 +419,89 @@ class XMLRendererTestCase(TestCase): self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) self.assertTrue(xml.endswith('</root>')) self.assertTrue(string in xml, '%r not in %r' % (string, xml)) + + +# Tests for caching issue, #346 +class CacheRenderTest(TestCase): + """ + Tests specific to caching responses + """ + + urls = 'rest_framework.tests.renderers' + + cache_key = 'just_a_cache_key' + + @classmethod + def _get_pickling_errors(cls, obj, seen=None): + """ Return any errors that would be raised if `obj' is pickled + Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897 + """ + if seen == None: + seen = [] + try: + state = obj.__getstate__() + except AttributeError: + return + if state == None: + return + if isinstance(state,tuple): + if not isinstance(state[0],dict): + state=state[1] + else: + state=state[0].update(state[1]) + result = {} + for i in state: + try: + pickle.dumps(state[i],protocol=2) + except pickle.PicklingError: + if not state[i] in seen: + seen.append(state[i]) + result[i] = cls._get_pickling_errors(state[i],seen) + return result + + def http_resp(self, http_method, url): + """ + Simple wrapper for Client http requests + Removes the `client' and `request' attributes from as they are + added by django.test.client.Client and not part of caching + responses outside of tests. + """ + method = getattr(self.client, http_method) + resp = method(url) + del resp.client, resp.request + return resp + + def test_obj_pickling(self): + """ + Test that responses are properly pickled + """ + resp = self.http_resp('get', '/cache') + + # Make sure that no pickling errors occurred + self.assertEqual(self._get_pickling_errors(resp), {}) + + # Unfortunately LocMem backend doesn't raise PickleErrors but returns + # None instead. + cache.set(self.cache_key, resp) + self.assertTrue(cache.get(self.cache_key) is not None) + + def test_head_caching(self): + """ + Test caching of HEAD requests + """ + resp = self.http_resp('head', '/cache') + cache.set(self.cache_key, resp) + + cached_resp = cache.get(self.cache_key) + self.assertIsInstance(cached_resp, Response) + + def test_get_caching(self): + """ + Test caching of GET requests + """ + resp = self.http_resp('get', '/cache') + cache.set(self.cache_key, resp) + + cached_resp = cache.get(self.cache_key) + self.assertIsInstance(cached_resp, Response) + self.assertEqual(cached_resp.content, resp.content) diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index f90bebf4..ff48f3fa 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -27,7 +27,7 @@ factory = RequestFactory() class PlainTextParser(BaseParser): media_type = 'text/plain' - def parse_stream(self, stream, parser_context=None): + def parse(self, stream, media_type=None, parser_context=None): """ Returns a 2-tuple of `(data, files)`. diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py index 18b6af39..d7b75450 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/response.py @@ -131,12 +131,6 @@ class RendererIntegrationTests(TestCase): self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.status_code, DUMMYSTATUS) - @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse') - def test_unsatisfiable_accept_header_on_request_returns_406_status(self): - """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" - resp = self.client.get('/', HTTP_ACCEPT='foo/bar') - self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) - def test_specified_renderer_serializes_content_on_format_query(self): """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 256987ad..8d1de429 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -1,7 +1,14 @@ import datetime from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import * +from rest_framework.tests.models import (ActionItem, Anchor, BasicModel, + BlankFieldModel, BlogPost, CallableDefaultValueModel, DefaultValueModel, + ManyToManyModel, Person, ReadOnlyManyToManyModel) + + +class SubComment(object): + def __init__(self, sub_comment): + self.sub_comment = sub_comment class Comment(object): @@ -14,11 +21,16 @@ class Comment(object): return all([getattr(self, attr) == getattr(other, attr) for attr in ('email', 'content', 'created')]) + def get_sub_comment(self): + sub_comment = SubComment('And Merry Christmas!') + return sub_comment + class CommentSerializer(serializers.Serializer): email = serializers.EmailField() content = serializers.CharField(max_length=1000) created = serializers.DateTimeField() + sub_comment = serializers.Field(source='get_sub_comment.sub_comment') def restore_object(self, data, instance=None): if instance is None: @@ -28,6 +40,19 @@ class CommentSerializer(serializers.Serializer): return instance +class ActionItemSerializer(serializers.ModelSerializer): + class Meta: + model = ActionItem + + +class PersonSerializer(serializers.ModelSerializer): + info = serializers.Field(source='info') + + class Meta: + model = Person + fields = ('name', 'age', 'info') + + class BasicTests(TestCase): def setUp(self): self.comment = Comment( @@ -38,36 +63,63 @@ class BasicTests(TestCase): self.data = { 'email': 'tom@example.com', 'content': 'Happy new year!', - 'created': datetime.datetime(2012, 1, 1) + 'created': datetime.datetime(2012, 1, 1), + 'sub_comment': 'This wont change' + } + self.expected = { + 'email': 'tom@example.com', + 'content': 'Happy new year!', + 'created': datetime.datetime(2012, 1, 1), + 'sub_comment': 'And Merry Christmas!' } + self.person_data = {'name': 'dwight', 'age': 35} + self.person = Person(**self.person_data) + self.person.save() def test_empty(self): serializer = CommentSerializer() expected = { 'email': '', 'content': '', - 'created': None + 'created': None, + 'sub_comment': '' } self.assertEquals(serializer.data, expected) def test_retrieve(self): - serializer = CommentSerializer(instance=self.comment) - expected = self.data - self.assertEquals(serializer.data, expected) + serializer = CommentSerializer(self.comment) + self.assertEquals(serializer.data, self.expected) def test_create(self): - serializer = CommentSerializer(self.data) + serializer = CommentSerializer(data=self.data) expected = self.comment self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.object, expected) self.assertFalse(serializer.object is expected) + self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') def test_update(self): - serializer = CommentSerializer(self.data, instance=self.comment) + serializer = CommentSerializer(self.comment, data=self.data) expected = self.comment self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.object, expected) self.assertTrue(serializer.object is expected) + self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') + + def test_model_fields_as_expected(self): + """ Make sure that the fields returned are the same as defined + in the Meta data + """ + serializer = PersonSerializer(self.person) + self.assertEquals(set(serializer.data.keys()), + set(['name', 'age', 'info'])) + + def test_field_with_dictionary(self): + """ Make sure that dictionaries from fields are left intact + """ + serializer = PersonSerializer(self.person) + expected = self.person_data + self.assertEquals(serializer.data['info'], expected) class ValidationTests(TestCase): @@ -82,14 +134,16 @@ class ValidationTests(TestCase): 'content': 'x' * 1001, 'created': datetime.datetime(2012, 1, 1) } + self.actionitem = ActionItem('Some to do item', + ) def test_create(self): - serializer = CommentSerializer(self.data) + serializer = CommentSerializer(data=self.data) self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) def test_update(self): - serializer = CommentSerializer(self.data, instance=self.comment) + serializer = CommentSerializer(self.comment, data=self.data) self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) @@ -98,10 +152,78 @@ class ValidationTests(TestCase): 'content': 'xxx', 'created': datetime.datetime(2012, 1, 1) } - serializer = CommentSerializer(data, instance=self.comment) + serializer = CommentSerializer(self.comment, data=data) self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.errors, {'email': [u'This field is required.']}) + def test_missing_bool_with_default(self): + """Make sure that a boolean value with a 'False' value is not + mistaken for not having a default.""" + data = { + 'title': 'Some action item', + #No 'done' value. + } + serializer = ActionItemSerializer(self.actionitem, data=data) + self.assertEquals(serializer.is_valid(), True) + self.assertEquals(serializer.errors, {}) + + def test_field_validation(self): + + class CommentSerializerWithFieldValidator(CommentSerializer): + + def validate_content(self, attrs, source): + value = attrs[source] + if "test" not in value: + raise serializers.ValidationError("Test not in value") + return attrs + + data = { + 'email': 'tom@example.com', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + + serializer = CommentSerializerWithFieldValidator(data=data) + self.assertTrue(serializer.is_valid()) + + data['content'] = 'This should not validate' + + serializer = CommentSerializerWithFieldValidator(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) + + def test_cross_field_validation(self): + + class CommentSerializerWithCrossFieldValidator(CommentSerializer): + + def validate(self, attrs): + if attrs["email"] not in attrs["content"]: + raise serializers.ValidationError("Email address not in content") + return attrs + + data = { + 'email': 'tom@example.com', + 'content': 'A comment from tom@example.com', + 'created': datetime.datetime(2012, 1, 1) + } + + serializer = CommentSerializerWithCrossFieldValidator(data=data) + self.assertTrue(serializer.is_valid()) + + data['content'] = 'A comment from foo@bar.com' + + serializer = CommentSerializerWithCrossFieldValidator(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']}) + + def test_null_is_true_fields(self): + """ + Omitting a value for null-field should validate. + """ + serializer = PersonSerializer(data={'name': 'marko'}) + self.assertEquals(serializer.is_valid(), True) + self.assertEquals(serializer.errors, {}) + class MetadataTests(TestCase): def test_empty(self): @@ -148,7 +270,7 @@ class ManyToManyTests(TestCase): Create an instance of a model with a ManyToMany relationship. """ data = {'rel': [self.anchor.id]} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 2) @@ -162,7 +284,7 @@ class ManyToManyTests(TestCase): new_anchor = Anchor() new_anchor.save() data = {'rel': [self.anchor.id, new_anchor.id]} - serializer = self.serializer_class(data, instance=self.instance) + serializer = self.serializer_class(self.instance, data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 1) @@ -175,7 +297,7 @@ class ManyToManyTests(TestCase): containing no items. """ data = {'rel': []} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 2) @@ -190,7 +312,7 @@ class ManyToManyTests(TestCase): new_anchor = Anchor() new_anchor.save() data = {'rel': []} - serializer = self.serializer_class(data, instance=self.instance) + serializer = self.serializer_class(self.instance, data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 1) @@ -204,7 +326,7 @@ class ManyToManyTests(TestCase): lists (eg form data). """ data = {'rel': ''} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 2) @@ -212,6 +334,61 @@ class ManyToManyTests(TestCase): self.assertEquals(list(instance.rel.all()), []) +class ReadOnlyManyToManyTests(TestCase): + def setUp(self): + class ReadOnlyManyToManySerializer(serializers.ModelSerializer): + rel = serializers.ManyRelatedField(read_only=True) + + class Meta: + model = ReadOnlyManyToManyModel + + self.serializer_class = ReadOnlyManyToManySerializer + + # An anchor instance to use for the relationship + self.anchor = Anchor() + self.anchor.save() + + # A model instance with a many to many relationship to the anchor + self.instance = ReadOnlyManyToManyModel() + self.instance.save() + self.instance.rel.add(self.anchor) + + # A serialized representation of the model instance + self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'} + + def test_update(self): + """ + Attempt to update an instance of a model with a ManyToMany + relationship. Not updated due to read_only=True + """ + new_anchor = Anchor() + new_anchor.save() + data = {'rel': [self.anchor.id, new_anchor.id]} + serializer = self.serializer_class(self.instance, data=data) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) + self.assertEquals(instance.pk, 1) + # rel is still as original (1 entry) + self.assertEquals(list(instance.rel.all()), [self.anchor]) + + def test_update_without_relationship(self): + """ + Attempt to update an instance of a model where many to ManyToMany + relationship is not supplied. Not updated due to read_only=True + """ + new_anchor = Anchor() + new_anchor.save() + data = {} + serializer = self.serializer_class(self.instance, data=data) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) + self.assertEquals(instance.pk, 1) + # rel is still as original (1 entry) + self.assertEquals(list(instance.rel.all()), [self.anchor]) + + class DefaultValueTests(TestCase): def setUp(self): class DefaultValueSerializer(serializers.ModelSerializer): @@ -223,7 +400,7 @@ class DefaultValueTests(TestCase): def test_create_using_default(self): data = {} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(self.objects.all()), 1) @@ -232,7 +409,7 @@ class DefaultValueTests(TestCase): def test_create_overriding_default(self): data = {'text': 'overridden'} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(self.objects.all()), 1) @@ -251,7 +428,7 @@ class CallableDefaultValueTests(TestCase): def test_create_using_default(self): data = {} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(self.objects.all()), 1) @@ -260,9 +437,87 @@ class CallableDefaultValueTests(TestCase): def test_create_overriding_default(self): data = {'text': 'overridden'} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(self.objects.all()), 1) self.assertEquals(instance.pk, 1) self.assertEquals(instance.text, 'overridden') + + +class ManyRelatedTests(TestCase): + def setUp(self): + + class BlogPostCommentSerializer(serializers.Serializer): + text = serializers.CharField() + + class BlogPostSerializer(serializers.Serializer): + title = serializers.CharField() + comments = BlogPostCommentSerializer(source='blogpostcomment_set') + + self.serializer_class = BlogPostSerializer + + def test_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + serializer = self.serializer_class(instance=post) + expected = { + 'title': 'Test blog post', + 'comments': [ + {'text': 'I hate this blog post'}, + {'text': 'I love this blog post'} + ] + } + + self.assertEqual(serializer.data, expected) + + +# Test for issue #324 +class BlankFieldTests(TestCase): + def setUp(self): + + class BlankFieldModelSerializer(serializers.ModelSerializer): + class Meta: + model = BlankFieldModel + + class BlankFieldSerializer(serializers.Serializer): + title = serializers.CharField(blank=True) + + class NotBlankFieldModelSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + class NotBlankFieldSerializer(serializers.Serializer): + title = serializers.CharField() + + self.model_serializer_class = BlankFieldModelSerializer + self.serializer_class = BlankFieldSerializer + self.not_blank_model_serializer_class = NotBlankFieldModelSerializer + self.not_blank_serializer_class = NotBlankFieldSerializer + self.data = {'title': ''} + + def test_create_blank_field(self): + serializer = self.serializer_class(data=self.data) + self.assertEquals(serializer.is_valid(), True) + + def test_create_model_blank_field(self): + serializer = self.model_serializer_class(data=self.data) + self.assertEquals(serializer.is_valid(), True) + + def test_create_not_blank_field(self): + """ + Test to ensure blank data in a field not marked as blank=True + is considered invalid in a non-model serializer + """ + serializer = self.not_blank_serializer_class(data=self.data) + self.assertEquals(serializer.is_valid(), False) + + def test_create_model_not_blank_field(self): + """ + Test to ensure blank data in a field not marked as blank=True + is considered invalid in a model serializer + """ + serializer = self.not_blank_model_serializer_class(data=self.data) + self.assertEquals(serializer.is_valid(), False) diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py new file mode 100644 index 00000000..adeaf6da --- /dev/null +++ b/rest_framework/tests/tests.py @@ -0,0 +1,13 @@ +""" +Force import of all modules in this package in order to get the standard test +runner to pick up the tests. Yowzers. +""" +import os + +modules = [filename.rsplit('.', 1)[0] + for filename in os.listdir(os.path.dirname(__file__)) + if filename.endswith('.py') and not filename.startswith('_')] +__test__ = dict() + +for module in modules: + exec("from rest_framework.tests.%s import *" % module) diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py index b390c42f..c032985e 100644 --- a/rest_framework/tests/validators.py +++ b/rest_framework/tests/validators.py @@ -285,7 +285,7 @@ # uiop = models.CharField(max_length=256, blank=True) # @property -# def readonly(self): +# def read_only(self): # return 'read only' # class MockResource(ModelResource): @@ -298,7 +298,7 @@ # def test_property_fields_are_allowed_on_model_forms(self): # """Validation on ModelForms may include property fields that exist on the Model to be included in the input.""" -# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only'} +# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'} # self.assertEqual(self.validator.validate_request(content, None), content) # def test_property_fields_are_not_required_on_model_forms(self): @@ -310,19 +310,19 @@ # """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only', 'extra': 'extra'} +# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'} # self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # def test_validate_requires_fields_on_model_forms(self): # """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'readonly': 'read only'} +# content = {'read_only': 'read only'} # self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # def test_validate_does_not_require_blankable_fields_on_model_forms(self): # """Test standard ModelForm validation behaviour - fields with blank=True are not required.""" -# content = {'qwerty': 'example', 'readonly': 'read only'} +# content = {'qwerty': 'example', 'read_only': 'read only'} # self.validator.validate_request(content, None) # def test_model_form_validator_uses_model_forms(self): diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 6e7a0b72..8fe64248 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -16,7 +16,7 @@ class BaseThrottle(object): def wait(self): """ - Optionally, return a recommeded number of seconds to wait before + Optionally, return a recommended number of seconds to wait before the next request. """ return None @@ -60,7 +60,7 @@ class SimpleRateThrottle(BaseThrottle): Determine the string representation of the allowed request rate. """ if not getattr(self, 'scope', None): - msg = ("You must set either `.scope` or `.rate` for '%s' thottle" % + msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__) raise exceptions.ConfigurationError(msg) @@ -137,7 +137,7 @@ class AnonRateThrottle(SimpleRateThrottle): """ Limits the rate of API calls that may be made by a anonymous users. - The IP address of the request will be used as the unqiue cache key. + The IP address of the request will be used as the unique cache key. """ scope = 'anon' diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 386c78a2..316ccd19 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -2,26 +2,23 @@ from django.conf.urls.defaults import url from rest_framework.settings import api_settings -def format_suffix_patterns(urlpatterns, suffix_required=False, - suffix_kwarg=None, allowed=None): +def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): """ Supplement existing urlpatterns with corrosponding patterns that also include a '.format' suffix. Retains urlpattern ordering. + urlpatterns: + A list of URL patterns. + suffix_required: If `True`, only suffixed URLs will be generated, and non-suffixed URLs will not be used. Defaults to `False`. - suffix_kwarg: - The name of the kwarg that will be passed to the view. - Defaults to 'format'. - allowed: An optional tuple/list of allowed suffixes. eg ['json', 'api'] Defaults to `None`, which allows any suffix. - """ - suffix_kwarg = suffix_kwarg or api_settings.FORMAT_SUFFIX_KWARG + suffix_kwarg = api_settings.FORMAT_SUFFIX_KWARG if allowed: if len(allowed) == 1: allowed_pattern = allowed[0] diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py index a59fff45..84fcb5db 100644 --- a/rest_framework/utils/__init__.py +++ b/rest_framework/utils/__init__.py @@ -1,7 +1,6 @@ from django.utils.encoding import smart_unicode from django.utils.xmlutils import SimplerXMLGenerator from rest_framework.compat import StringIO - import re import xml.etree.ElementTree as ET diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 672d32a3..80e39d46 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -6,7 +6,7 @@ def get_breadcrumbs(url): from rest_framework.views import APIView - def breadcrumbs_recursive(url, breadcrumbs_list, prefix): + def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url.""" try: @@ -16,7 +16,11 @@ def get_breadcrumbs(url): else: # Check if this is a REST framework view, and if so add it to the breadcrumbs if isinstance(getattr(view, 'cls_instance', None), APIView): - breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) + # Don't list the same view twice in a row. + # Probably an optional trailing slash. + if not seen or seen[-1] != view: + breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) + seen.append(view) if url == '': # All done @@ -24,11 +28,11 @@ def get_breadcrumbs(url): elif url.endswith('/'): # Drop trailing slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix) + return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen) # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix) + return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen) prefix = get_script_prefix().rstrip('/') url = url[len(prefix):] - return breadcrumbs_recursive(url, [], prefix) + return breadcrumbs_recursive(url, [], prefix, []) diff --git a/rest_framework/views.py b/rest_framework/views.py index 62fc92f9..1afbd697 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -54,12 +54,12 @@ def _camelcase_to_spaces(content): class APIView(View): settings = api_settings - renderer_classes = api_settings.DEFAULT_RENDERERS - parser_classes = api_settings.DEFAULT_PARSERS - authentication_classes = api_settings.DEFAULT_AUTHENTICATION - throttle_classes = api_settings.DEFAULT_THROTTLES - permission_classes = api_settings.DEFAULT_PERMISSIONS - content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION + renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES + parser_classes = api_settings.DEFAULT_PARSER_CLASSES + authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES + throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES + permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES + content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS @classmethod def as_view(cls, **initkwargs): @@ -158,12 +158,15 @@ class APIView(View): def get_parser_context(self, http_request): """ - Returns a dict that is passed through to Parser.parse_stream(), + Returns a dict that is passed through to Parser.parse(), as the `parser_context` keyword argument. """ + # Note: Additionally `request` will also be added to the context + # by the Request object. return { - 'upload_handlers': http_request.upload_handlers, - 'meta': http_request.META, + 'view': self, + 'args': getattr(self, 'args', ()), + 'kwargs': getattr(self, 'kwargs', {}) } def get_renderer_context(self): @@ -171,13 +174,13 @@ class APIView(View): Returns a dict that is passed through to Renderer.render(), as the `renderer_context` keyword argument. """ - # Note: Additionally 'response' will also be set on the context, + # Note: Additionally 'response' will also be added to the context, # by the Response object. return { 'view': self, - 'request': self.request, - 'args': self.args, - 'kwargs': self.kwargs + 'args': getattr(self, 'args', ()), + 'kwargs': getattr(self, 'kwargs', {}), + 'request': getattr(self, 'request', None) } # API policy instantiation methods @@ -215,7 +218,7 @@ class APIView(View): def get_throttles(self): """ - Instantiates and returns the list of thottles that this view uses. + Instantiates and returns the list of throttles that this view uses. """ return [throttle() for throttle in self.throttle_classes] @@ -235,7 +238,13 @@ class APIView(View): """ renderers = self.get_renderers() conneg = self.get_content_negotiator() - return conneg.negotiate(request, renderers, self.format_kwarg, force) + + try: + return conneg.select_renderer(request, renderers, self.format_kwarg) + except: + if force: + return (renderers[0], renderers[0].media_type) + raise def has_permission(self, request, obj=None): """ @@ -311,13 +320,17 @@ class APIView(View): self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait if isinstance(exc, exceptions.APIException): - return Response({'detail': exc.detail}, status=exc.status_code) + return Response({'detail': exc.detail}, + status=exc.status_code, + exception=True) elif isinstance(exc, Http404): return Response({'detail': 'Not found'}, - status=status.HTTP_404_NOT_FOUND) + status=status.HTTP_404_NOT_FOUND, + exception=True) elif isinstance(exc, PermissionDenied): return Response({'detail': 'Permission denied'}, - status=status.HTTP_403_FORBIDDEN) + status=status.HTTP_403_FORBIDDEN, + exception=True) raise # Note: session based authentication is explicitly CSRF validated, |
