aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/compat.py6
-rw-r--r--rest_framework/fields.py444
-rw-r--r--rest_framework/generics.py31
-rw-r--r--rest_framework/mixins.py19
-rw-r--r--rest_framework/relations.py460
-rwxr-xr-xrest_framework/runtests/runcoverage.py9
-rwxr-xr-xrest_framework/runtests/runtests.py8
-rw-r--r--rest_framework/runtests/urls.py2
-rw-r--r--rest_framework/serializers.py45
-rw-r--r--rest_framework/templates/rest_framework/base.html17
-rw-r--r--rest_framework/templates/rest_framework/login.html8
-rw-r--r--rest_framework/templatetags/rest_framework.py83
-rw-r--r--rest_framework/tests/decorators.py16
-rw-r--r--rest_framework/tests/files.py14
-rw-r--r--rest_framework/tests/generics.py16
-rw-r--r--rest_framework/tests/models.py28
-rw-r--r--rest_framework/tests/modelviews.py2
-rw-r--r--rest_framework/tests/relations_hyperlink.py115
-rw-r--r--rest_framework/tests/relations_nested.py55
-rw-r--r--rest_framework/tests/relations_pk.py121
-rw-r--r--rest_framework/tests/serializer.py4
-rw-r--r--rest_framework/tests/utils.py27
-rw-r--r--rest_framework/tests/views.py4
-rw-r--r--rest_framework/urlpatterns.py2
25 files changed, 968 insertions, 570 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 02bc6fc1..151ba832 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '2.1.12'
+__version__ = '2.1.14'
VERSION = __version__ # synonym
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 86952fb8..5508f6c0 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -96,6 +96,12 @@ else:
update_wrapper(view, cls.dispatch, assigned=())
return view
+# Taken from @markotibold's attempt at supporting PATCH.
+# https://github.com/markotibold/django-rest-framework/tree/patch
+http_method_names = set(View.http_method_names)
+http_method_names.add('patch')
+View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django
+
# PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
if django.VERSION >= (1, 4):
from django.middleware.csrf import CsrfViewMiddleware
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 5ae8ff48..c406a2f3 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -7,18 +7,14 @@ import warnings
from io import BytesIO
from django.core import validators
-from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.core.urlresolvers import resolve, get_script_prefix
+from django.core.exceptions import ValidationError
from django.conf import settings
from django import forms
from django.forms import widgets
-from django.forms.models import ModelChoiceIterator
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
-from rest_framework.reverse import reverse
from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone
-from urlparse import urlparse
def is_simple_callable(obj):
@@ -185,6 +181,7 @@ class WritableField(Field):
try:
if self._use_files:
+ files = files or {}
native = files[field_name]
else:
native = data[field_name]
@@ -254,443 +251,6 @@ class ModelField(WritableField):
"type": self.model_field.get_internal_type()
}
-##### Relational fields #####
-
-
-# Not actually Writable, but subclasses may need to be.
-class RelatedField(WritableField):
- """
- Base class for related model fields.
-
- If not overridden, this represents a to-one relationship, using the unicode
- representation of the target.
- """
- widget = widgets.Select
- cache_choices = False
- empty_label = None
- default_read_only = True # TODO: Remove this
-
- def __init__(self, *args, **kwargs):
- self.queryset = kwargs.pop('queryset', None)
- self.null = kwargs.pop('null', False)
- super(RelatedField, self).__init__(*args, **kwargs)
- self.read_only = kwargs.pop('read_only', self.default_read_only)
-
- def initialize(self, parent, field_name):
- super(RelatedField, self).initialize(parent, field_name)
- if self.queryset is None and not self.read_only:
- try:
- manager = getattr(self.parent.opts.model, self.source or field_name)
- if hasattr(manager, 'related'): # Forward
- self.queryset = manager.related.model._default_manager.all()
- else: # Reverse
- self.queryset = manager.field.rel.to._default_manager.all()
- except:
- raise
- msg = ('Serializer related fields must include a `queryset`' +
- ' argument or set `read_only=True')
- raise Exception(msg)
-
- ### We need this stuff to make form choices work...
-
- # def __deepcopy__(self, memo):
- # result = super(RelatedField, self).__deepcopy__(memo)
- # result.queryset = result.queryset
- # return result
-
- def prepare_value(self, obj):
- return self.to_native(obj)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_unicode(obj)
- ident = smart_unicode(self.to_native(obj))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- def _get_queryset(self):
- return self._queryset
-
- def _set_queryset(self, queryset):
- self._queryset = queryset
- self.widget.choices = self.choices
-
- queryset = property(_get_queryset, _set_queryset)
-
- def _get_choices(self):
- # If self._choices is set, then somebody must have manually set
- # the property self.choices. In this case, just return self._choices.
- if hasattr(self, '_choices'):
- return self._choices
-
- # Otherwise, execute the QuerySet in self.queryset to determine the
- # choices dynamically. Return a fresh ModelChoiceIterator that has not been
- # consumed. Note that we're instantiating a new ModelChoiceIterator *each*
- # time _get_choices() is called (and, thus, each time self.choices is
- # accessed) so that we can ensure the QuerySet has not been consumed. This
- # construct might look complicated but it allows for lazy evaluation of
- # the queryset.
- return ModelChoiceIterator(self)
-
- def _set_choices(self, value):
- # Setting choices also sets the choices on the widget.
- # choices can be any iterable, but we call list() on it because
- # it will be consumed more than once.
- self._choices = self.widget.choices = list(value)
-
- choices = property(_get_choices, _set_choices)
-
- ### Regular serializer stuff...
-
- def field_to_native(self, obj, field_name):
- value = getattr(obj, self.source or field_name)
- return self.to_native(value)
-
- def field_from_native(self, data, files, field_name, into):
- if self.read_only:
- return
-
- try:
- value = data[field_name]
- except KeyError:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
-
- if value in (None, '') and not self.null:
- raise ValidationError('Value may not be null')
- elif value in (None, '') and self.null:
- into[(self.source or field_name)] = None
- else:
- into[(self.source or field_name)] = self.from_native(value)
-
-
-class ManyRelatedMixin(object):
- """
- Mixin to convert a related field to a many related field.
- """
- widget = widgets.SelectMultiple
-
- def field_to_native(self, obj, field_name):
- value = getattr(obj, self.source or field_name)
- return [self.to_native(item) for item in value.all()]
-
- def field_from_native(self, data, files, field_name, into):
- if self.read_only:
- return
-
- try:
- # Form data
- value = data.getlist(self.source or field_name)
- except:
- # Non-form data
- value = data.get(self.source or field_name)
- else:
- if value == ['']:
- value = []
-
- into[field_name] = [self.from_native(item) for item in value]
-
-
-class ManyRelatedField(ManyRelatedMixin, RelatedField):
- """
- Base class for related model managers.
-
- If not overridden, this represents a to-many relationship, using the unicode
- representations of the target, and is read-only.
- """
- pass
-
-
-### PrimaryKey relationships
-
-class PrimaryKeyRelatedField(RelatedField):
- """
- Represents a to-one relationship as a pk value.
- """
- default_read_only = False
- form_field_class = forms.ChoiceField
-
- # TODO: Remove these field hacks...
- def prepare_value(self, obj):
- return self.to_native(obj.pk)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_unicode(obj)
- ident = smart_unicode(self.to_native(obj.pk))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- # TODO: Possibly change this to just take `obj`, through prob less performant
- def to_native(self, pk):
- return pk
-
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- try:
- return self.queryset.get(pk=data)
- except ObjectDoesNotExist:
- msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
- raise ValidationError(msg)
-
- def field_to_native(self, obj, field_name):
- try:
- # Prefer obj.serializable_value for performance reasons
- pk = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedObject (reverse relationship)
- obj = getattr(obj, self.source or field_name)
- return self.to_native(obj.pk)
- # Forward relationship
- return self.to_native(pk)
-
-
-class ManyPrimaryKeyRelatedField(ManyRelatedField):
- """
- Represents a to-many relationship as a pk value.
- """
- default_read_only = False
- form_field_class = forms.MultipleChoiceField
-
- def prepare_value(self, obj):
- return self.to_native(obj.pk)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_unicode(obj)
- ident = smart_unicode(self.to_native(obj.pk))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- def to_native(self, pk):
- return pk
-
- def field_to_native(self, obj, field_name):
- try:
- # Prefer obj.serializable_value for performance reasons
- queryset = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedManager (reverse relationship)
- queryset = getattr(obj, self.source or field_name)
- return [self.to_native(item.pk) for item in queryset.all()]
- # Forward relationship
- return [self.to_native(item.pk) for item in queryset.all()]
-
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- try:
- return self.queryset.get(pk=data)
- except ObjectDoesNotExist:
- msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
- raise ValidationError(msg)
-
-### Slug relationships
-
-
-class SlugRelatedField(RelatedField):
- default_read_only = False
- form_field_class = forms.ChoiceField
-
- def __init__(self, *args, **kwargs):
- self.slug_field = kwargs.pop('slug_field', None)
- assert self.slug_field, 'slug_field is required'
- super(SlugRelatedField, self).__init__(*args, **kwargs)
-
- def to_native(self, obj):
- return getattr(obj, self.slug_field)
-
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- try:
- return self.queryset.get(**{self.slug_field: data})
- except ObjectDoesNotExist:
- raise ValidationError('Object with %s=%s does not exist.' %
- (self.slug_field, unicode(data)))
-
-
-class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField):
- form_field_class = forms.MultipleChoiceField
-
-
-### Hyperlinked relationships
-
-class HyperlinkedRelatedField(RelatedField):
- """
- Represents a to-one relationship, using hyperlinking.
- """
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
- default_read_only = False
- form_field_class = forms.ChoiceField
-
- def __init__(self, *args, **kwargs):
- try:
- self.view_name = kwargs.pop('view_name')
- except:
- raise ValueError("Hyperlinked field requires 'view_name' kwarg")
-
- self.slug_field = kwargs.pop('slug_field', self.slug_field)
- default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
- self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
-
- self.format = kwargs.pop('format', None)
- super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
-
- def get_slug_field(self):
- """
- Get the name of a slug field to be used to look up by slug.
- """
- return self.slug_field
-
- def to_native(self, obj):
- view_name = self.view_name
- request = self.context.get('request', None)
- format = self.format or self.context.get('format', None)
- pk = getattr(obj, 'pk', None)
- if pk is None:
- return
- kwargs = {self.pk_url_kwarg: pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except:
- pass
-
- slug = getattr(obj, self.slug_field, None)
-
- if not slug:
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
-
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
- except:
- pass
-
- kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
- try:
- return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
- except:
- pass
-
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
-
- def from_native(self, value):
- # Convert URL -> model instance pk
- # TODO: Use values_list
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- if value.startswith('http:') or value.startswith('https:'):
- # If needed convert absolute URLs to relative path
- value = urlparse(value).path
- prefix = get_script_prefix()
- if value.startswith(prefix):
- value = '/' + value[len(prefix):]
-
- try:
- match = resolve(value)
- except:
- raise ValidationError('Invalid hyperlink - No URL match')
-
- if match.url_name != self.view_name:
- raise ValidationError('Invalid hyperlink - Incorrect URL match')
-
- pk = match.kwargs.get(self.pk_url_kwarg, None)
- slug = match.kwargs.get(self.slug_url_kwarg, None)
-
- # Try explicit primary key.
- if pk is not None:
- queryset = self.queryset.filter(pk=pk)
- # Next, try looking up by slug.
- elif slug is not None:
- slug_field = self.get_slug_field()
- queryset = self.queryset.filter(**{slug_field: slug})
- # If none of those are defined, it's an error.
- else:
- raise ValidationError('Invalid hyperlink')
-
- try:
- obj = queryset.get()
- except ObjectDoesNotExist:
- raise ValidationError('Invalid hyperlink - object does not exist.')
- return obj
-
-
-class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
- """
- Represents a to-many relationship, using hyperlinking.
- """
- form_field_class = forms.MultipleChoiceField
-
-
-class HyperlinkedIdentityField(Field):
- """
- Represents the instance, or a property on the instance, using hyperlinking.
- """
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
-
- def __init__(self, *args, **kwargs):
- # TODO: Make view_name mandatory, and have the
- # HyperlinkedModelSerializer set it on-the-fly
- self.view_name = kwargs.pop('view_name', None)
- self.format = kwargs.pop('format', None)
-
- self.slug_field = kwargs.pop('slug_field', self.slug_field)
- default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
- self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
-
- super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
- request = self.context.get('request', None)
- format = self.format or self.context.get('format', None)
- view_name = self.view_name or self.parent.opts.view_name
- kwargs = {self.pk_url_kwarg: obj.pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except:
- pass
-
- slug = getattr(obj, self.slug_field, None)
-
- if not slug:
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
-
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
- except:
- pass
-
- kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
- try:
- return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
- except:
- pass
-
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
-
##### Typed Fields #####
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index dd8dfcf8..19f2b704 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -47,14 +47,16 @@ class GenericAPIView(views.APIView):
return serializer_class
- def get_serializer(self, instance=None, data=None, files=None):
+ def get_serializer(self, instance=None, data=None,
+ files=None, partial=False):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(instance, data=data, files=files, context=context)
+ return serializer_class(instance, data=data, files=files,
+ partial=partial, context=context)
class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
@@ -171,6 +173,10 @@ class UpdateAPIView(mixins.UpdateModelMixin,
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
+ def patch(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
@@ -185,6 +191,23 @@ class ListCreateAPIView(mixins.ListModelMixin,
return self.create(request, *args, **kwargs)
+class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
+ mixins.UpdateModelMixin,
+ SingleObjectAPIView):
+ """
+ Concrete view for retrieving, updating a model instance.
+ """
+ def get(self, request, *args, **kwargs):
+ return self.retrieve(request, *args, **kwargs)
+
+ def put(self, request, *args, **kwargs):
+ return self.update(request, *args, **kwargs)
+
+ def patch(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
+
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
SingleObjectAPIView):
@@ -211,5 +234,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
+ def patch(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 2700606d..43581ae9 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -16,11 +16,14 @@ class CreateModelMixin(object):
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
+
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
headers = self.get_success_headers(serializer.data)
- return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
+ return Response(serializer.data, status=status.HTTP_201_CREATED,
+ headers=headers)
+
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get_success_headers(self, data):
@@ -82,20 +85,21 @@ class UpdateModelMixin(object):
Should be mixed in with `SingleObjectBaseView`.
"""
def update(self, request, *args, **kwargs):
+ partial = kwargs.pop('partial', False)
try:
self.object = self.get_object()
- created = False
+ success_status_code = status.HTTP_200_OK
except Http404:
self.object = None
- created = True
+ success_status_code = status.HTTP_201_CREATED
- serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES)
+ serializer = self.get_serializer(self.object, data=request.DATA,
+ files=request.FILES, partial=partial)
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
- status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK
- return Response(serializer.data, status=status_code)
+ return Response(serializer.data, status=success_status_code)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -115,7 +119,8 @@ class UpdateModelMixin(object):
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
- obj.full_clean()
+ if hasattr(obj, 'full_clean'):
+ obj.full_clean()
class DestroyModelMixin(object):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
new file mode 100644
index 00000000..686dcf04
--- /dev/null
+++ b/rest_framework/relations.py
@@ -0,0 +1,460 @@
+from django.core.exceptions import ObjectDoesNotExist, ValidationError
+from django.core.urlresolvers import resolve, get_script_prefix
+from django import forms
+from django.forms import widgets
+from django.forms.models import ModelChoiceIterator
+from django.utils.encoding import smart_unicode
+from rest_framework.fields import Field, WritableField
+from rest_framework.reverse import reverse
+from urlparse import urlparse
+
+##### Relational fields #####
+
+
+# Not actually Writable, but subclasses may need to be.
+class RelatedField(WritableField):
+ """
+ Base class for related model fields.
+
+ If not overridden, this represents a to-one relationship, using the unicode
+ representation of the target.
+ """
+ widget = widgets.Select
+ cache_choices = False
+ empty_label = None
+ default_read_only = True # TODO: Remove this
+
+ def __init__(self, *args, **kwargs):
+ self.queryset = kwargs.pop('queryset', None)
+ self.null = kwargs.pop('null', False)
+ super(RelatedField, self).__init__(*args, **kwargs)
+ self.read_only = kwargs.pop('read_only', self.default_read_only)
+
+ def initialize(self, parent, field_name):
+ super(RelatedField, self).initialize(parent, field_name)
+ if self.queryset is None and not self.read_only:
+ try:
+ manager = getattr(self.parent.opts.model, self.source or field_name)
+ if hasattr(manager, 'related'): # Forward
+ self.queryset = manager.related.model._default_manager.all()
+ else: # Reverse
+ self.queryset = manager.field.rel.to._default_manager.all()
+ except:
+ raise
+ msg = ('Serializer related fields must include a `queryset`' +
+ ' argument or set `read_only=True')
+ raise Exception(msg)
+
+ ### We need this stuff to make form choices work...
+
+ # def __deepcopy__(self, memo):
+ # result = super(RelatedField, self).__deepcopy__(memo)
+ # result.queryset = result.queryset
+ # return result
+
+ def prepare_value(self, obj):
+ return self.to_native(obj)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ def _get_queryset(self):
+ return self._queryset
+
+ def _set_queryset(self, queryset):
+ self._queryset = queryset
+ self.widget.choices = self.choices
+
+ queryset = property(_get_queryset, _set_queryset)
+
+ def _get_choices(self):
+ # If self._choices is set, then somebody must have manually set
+ # the property self.choices. In this case, just return self._choices.
+ if hasattr(self, '_choices'):
+ return self._choices
+
+ # Otherwise, execute the QuerySet in self.queryset to determine the
+ # choices dynamically. Return a fresh ModelChoiceIterator that has not been
+ # consumed. Note that we're instantiating a new ModelChoiceIterator *each*
+ # time _get_choices() is called (and, thus, each time self.choices is
+ # accessed) so that we can ensure the QuerySet has not been consumed. This
+ # construct might look complicated but it allows for lazy evaluation of
+ # the queryset.
+ return ModelChoiceIterator(self)
+
+ def _set_choices(self, value):
+ # Setting choices also sets the choices on the widget.
+ # choices can be any iterable, but we call list() on it because
+ # it will be consumed more than once.
+ self._choices = self.widget.choices = list(value)
+
+ choices = property(_get_choices, _set_choices)
+
+ ### Regular serializer stuff...
+
+ def field_to_native(self, obj, field_name):
+ value = getattr(obj, self.source or field_name)
+ return self.to_native(value)
+
+ def field_from_native(self, data, files, field_name, into):
+ if self.read_only:
+ return
+
+ try:
+ value = data[field_name]
+ except KeyError:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
+ return
+
+ if value in (None, '') and not self.null:
+ raise ValidationError('Value may not be null')
+ elif value in (None, '') and self.null:
+ into[(self.source or field_name)] = None
+ else:
+ into[(self.source or field_name)] = self.from_native(value)
+
+
+class ManyRelatedMixin(object):
+ """
+ Mixin to convert a related field to a many related field.
+ """
+ widget = widgets.SelectMultiple
+
+ def field_to_native(self, obj, field_name):
+ value = getattr(obj, self.source or field_name)
+ return [self.to_native(item) for item in value.all()]
+
+ def field_from_native(self, data, files, field_name, into):
+ if self.read_only:
+ return
+
+ try:
+ # Form data
+ value = data.getlist(self.source or field_name)
+ except:
+ # Non-form data
+ value = data.get(self.source or field_name)
+ else:
+ if value == ['']:
+ value = []
+
+ into[field_name] = [self.from_native(item) for item in value]
+
+
+class ManyRelatedField(ManyRelatedMixin, RelatedField):
+ """
+ Base class for related model managers.
+
+ If not overridden, this represents a to-many relationship, using the unicode
+ representations of the target, and is read-only.
+ """
+ pass
+
+
+### PrimaryKey relationships
+
+class PrimaryKeyRelatedField(RelatedField):
+ """
+ Represents a to-one relationship as a pk value.
+ """
+ default_read_only = False
+ form_field_class = forms.ChoiceField
+
+ # TODO: Remove these field hacks...
+ def prepare_value(self, obj):
+ return self.to_native(obj.pk)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj.pk))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ # TODO: Possibly change this to just take `obj`, through prob less performant
+ def to_native(self, pk):
+ return pk
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
+ raise ValidationError(msg)
+
+ def field_to_native(self, obj, field_name):
+ try:
+ # Prefer obj.serializable_value for performance reasons
+ pk = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedObject (reverse relationship)
+ obj = getattr(obj, self.source or field_name)
+ return self.to_native(obj.pk)
+ # Forward relationship
+ return self.to_native(pk)
+
+
+class ManyPrimaryKeyRelatedField(ManyRelatedField):
+ """
+ Represents a to-many relationship as a pk value.
+ """
+ default_read_only = False
+ form_field_class = forms.MultipleChoiceField
+
+ def prepare_value(self, obj):
+ return self.to_native(obj.pk)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj.pk))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ def to_native(self, pk):
+ return pk
+
+ def field_to_native(self, obj, field_name):
+ try:
+ # Prefer obj.serializable_value for performance reasons
+ queryset = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedManager (reverse relationship)
+ queryset = getattr(obj, self.source or field_name)
+ return [self.to_native(item.pk) for item in queryset.all()]
+ # Forward relationship
+ return [self.to_native(item.pk) for item in queryset.all()]
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
+ raise ValidationError(msg)
+
+### Slug relationships
+
+
+class SlugRelatedField(RelatedField):
+ default_read_only = False
+ form_field_class = forms.ChoiceField
+
+ def __init__(self, *args, **kwargs):
+ self.slug_field = kwargs.pop('slug_field', None)
+ assert self.slug_field, 'slug_field is required'
+ super(SlugRelatedField, self).__init__(*args, **kwargs)
+
+ def to_native(self, obj):
+ return getattr(obj, self.slug_field)
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(**{self.slug_field: data})
+ except ObjectDoesNotExist:
+ raise ValidationError('Object with %s=%s does not exist.' %
+ (self.slug_field, unicode(data)))
+
+
+class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField):
+ form_field_class = forms.MultipleChoiceField
+
+
+### Hyperlinked relationships
+
+class HyperlinkedRelatedField(RelatedField):
+ """
+ Represents a to-one relationship, using hyperlinking.
+ """
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+ default_read_only = False
+ form_field_class = forms.ChoiceField
+
+ def __init__(self, *args, **kwargs):
+ try:
+ self.view_name = kwargs.pop('view_name')
+ except:
+ raise ValueError("Hyperlinked field requires 'view_name' kwarg")
+
+ self.slug_field = kwargs.pop('slug_field', self.slug_field)
+ default_slug_kwarg = self.slug_url_kwarg or self.slug_field
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
+ self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
+
+ self.format = kwargs.pop('format', None)
+ super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
+
+ def get_slug_field(self):
+ """
+ Get the name of a slug field to be used to look up by slug.
+ """
+ return self.slug_field
+
+ def to_native(self, obj):
+ view_name = self.view_name
+ request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
+ pk = getattr(obj, 'pk', None)
+ if pk is None:
+ return
+ kwargs = {self.pk_url_kwarg: pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ slug = getattr(obj, self.slug_field, None)
+
+ if not slug:
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+
+ def from_native(self, value):
+ # Convert URL -> model instance pk
+ # TODO: Use values_list
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ if value.startswith('http:') or value.startswith('https:'):
+ # If needed convert absolute URLs to relative path
+ value = urlparse(value).path
+ prefix = get_script_prefix()
+ if value.startswith(prefix):
+ value = '/' + value[len(prefix):]
+
+ try:
+ match = resolve(value)
+ except:
+ raise ValidationError('Invalid hyperlink - No URL match')
+
+ if match.url_name != self.view_name:
+ raise ValidationError('Invalid hyperlink - Incorrect URL match')
+
+ pk = match.kwargs.get(self.pk_url_kwarg, None)
+ slug = match.kwargs.get(self.slug_url_kwarg, None)
+
+ # Try explicit primary key.
+ if pk is not None:
+ queryset = self.queryset.filter(pk=pk)
+ # Next, try looking up by slug.
+ elif slug is not None:
+ slug_field = self.get_slug_field()
+ queryset = self.queryset.filter(**{slug_field: slug})
+ # If none of those are defined, it's an error.
+ else:
+ raise ValidationError('Invalid hyperlink')
+
+ try:
+ obj = queryset.get()
+ except ObjectDoesNotExist:
+ raise ValidationError('Invalid hyperlink - object does not exist.')
+ return obj
+
+
+class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
+ """
+ Represents a to-many relationship, using hyperlinking.
+ """
+ form_field_class = forms.MultipleChoiceField
+
+
+class HyperlinkedIdentityField(Field):
+ """
+ Represents the instance, or a property on the instance, using hyperlinking.
+ """
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+
+ def __init__(self, *args, **kwargs):
+ # TODO: Make view_name mandatory, and have the
+ # HyperlinkedModelSerializer set it on-the-fly
+ self.view_name = kwargs.pop('view_name', None)
+ # Optionally the format of the target hyperlink may be specified
+ self.format = kwargs.pop('format', None)
+
+ self.slug_field = kwargs.pop('slug_field', self.slug_field)
+ default_slug_kwarg = self.slug_url_kwarg or self.slug_field
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
+ self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
+
+ super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
+
+ def field_to_native(self, obj, field_name):
+ request = self.context.get('request', None)
+ format = self.context.get('format', None)
+ view_name = self.view_name or self.parent.opts.view_name
+ kwargs = {self.pk_url_kwarg: obj.pk}
+
+ # By default use whatever format is given for the current context
+ # unless the target is a different type to the source.
+ #
+ # Eg. Consider a HyperlinkedIdentityField pointing from a json
+ # representation to an html property of that representation...
+ #
+ # '/snippets/1/' should link to '/snippets/1/highlight/'
+ # ...but...
+ # '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
+ if format and self.format and self.format != format:
+ format = self.format
+
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ slug = getattr(obj, self.slug_field, None)
+
+ if not slug:
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py
index 0ce379eb..bcab1d14 100755
--- a/rest_framework/runtests/runcoverage.py
+++ b/rest_framework/runtests/runcoverage.py
@@ -8,6 +8,9 @@ Useful tool to run the test suite for rest_framework and generate a coverage rep
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
+
+# fix sys path so we don't need to setup PYTHONPATH
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
from coverage import coverage
@@ -55,6 +58,12 @@ def main():
if 'compat.py' in files:
files.remove('compat.py')
+ # Same applies to template tags module.
+ # This module has to include branching on Django versions,
+ # so it's never possible for it to have full coverage.
+ if 'rest_framework.py' in files:
+ files.remove('rest_framework.py')
+
cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')])
cov.report(cov_files)
diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py
index 729ef26a..505994e2 100755
--- a/rest_framework/runtests/runtests.py
+++ b/rest_framework/runtests/runtests.py
@@ -5,11 +5,9 @@
# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
import os
import sys
-"""
-Need to fix sys path so following works without specifically messing with PYTHONPATH
-python ./rest_framework/runtests/runtests.py
-"""
-sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
+
+# fix sys path so we don't need to setup PYTHONPATH
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
from django.conf import settings
diff --git a/rest_framework/runtests/urls.py b/rest_framework/runtests/urls.py
index 4b7da787..ed5baeae 100644
--- a/rest_framework/runtests/urls.py
+++ b/rest_framework/runtests/urls.py
@@ -1,7 +1,7 @@
"""
Blank URLConf just to keep runtests.py happy.
"""
-from django.conf.urls.defaults import *
+from rest_framework.compat import patterns
urlpatterns = patterns('',
)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 8156bc18..bd54db4c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -14,7 +14,7 @@ from rest_framework.compat import get_concrete_model
# This helps keep the seperation between model fields, form fields, and
# serializer fields more explicit.
-
+from rest_framework.relations import *
from rest_framework.fields import *
@@ -307,6 +307,9 @@ class BaseSerializer(Field):
if is_simple_callable(getattr(obj, 'all', None)):
return [self.to_native(item) for item in obj.all()]
+ if obj is None:
+ return None
+
return self.to_native(obj)
@property
@@ -438,7 +441,7 @@ class ModelSerializer(Serializer):
kwargs['blank'] = model_field.blank
- if model_field.null:
+ if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.has_default():
@@ -493,23 +496,30 @@ class ModelSerializer(Serializer):
Restore the model instance.
"""
self.m2m_data = {}
+ self.related_data = {}
+
+ # Reverse fk relations
+ for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model():
+ field_name = obj.field.related_query_name()
+ if field_name in attrs:
+ self.related_data[field_name] = attrs.pop(field_name)
+
+ # Reverse m2m relations
+ for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model():
+ field_name = obj.field.related_query_name()
+ if field_name in attrs:
+ self.m2m_data[field_name] = attrs.pop(field_name)
+
+ # Forward m2m relations
+ for field in self.opts.model._meta.many_to_many:
+ if field.name in attrs:
+ self.m2m_data[field.name] = attrs.pop(field.name)
if instance is not None:
for key, val in attrs.items():
setattr(instance, key, val)
else:
- # Reverse relations
- for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model():
- field_name = obj.field.related_query_name()
- if field_name in attrs:
- self.m2m_data[field_name] = attrs.pop(field_name)
-
- # Forward relations
- for field in self.opts.model._meta.many_to_many:
- if field.name in attrs:
- self.m2m_data[field.name] = attrs.pop(field.name)
-
instance = self.opts.model(**attrs)
try:
@@ -520,17 +530,22 @@ class ModelSerializer(Serializer):
return instance
- def save(self, save_m2m=True):
+ def save(self):
"""
Save the deserialized object and return it.
"""
self.object.save()
- if getattr(self, 'm2m_data', None) and save_m2m:
+ if getattr(self, 'm2m_data', None):
for accessor_name, object_list in self.m2m_data.items():
setattr(self.object, accessor_name, object_list)
self.m2m_data = {}
+ if getattr(self, 'related_data', None):
+ for accessor_name, object_list in self.related_data.items():
+ setattr(self.object, accessor_name, object_list)
+ self.related_data = {}
+
return self.object
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index fb0e19f0..42e49cb9 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -1,6 +1,5 @@
{% load url from future %}
{% load rest_framework %}
-{% load static %}
<!DOCTYPE html>
<html>
<head>
@@ -14,10 +13,10 @@
<title>{% block title %}Django REST framework{% endblock %}</title>
{% block style %}
- <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/>
- <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/>
- <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/prettify.css'/>
- <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
{% endblock %}
{% endblock %}
@@ -195,10 +194,10 @@
{% endblock %}
{% block script %}
- <script src="{% get_static_prefix %}rest_framework/js/jquery-1.8.1-min.js"></script>
- <script src="{% get_static_prefix %}rest_framework/js/bootstrap.min.js"></script>
- <script src="{% get_static_prefix %}rest_framework/js/prettify-min.js"></script>
- <script src="{% get_static_prefix %}rest_framework/js/default.js"></script>
+ <script src="{% static "rest_framework/js/jquery-1.8.1-min.js" %}"></script>
+ <script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script>
+ <script src="{% static "rest_framework/js/prettify-min.js" %}"></script>
+ <script src="{% static "rest_framework/js/default.js" %}"></script>
{% endblock %}
</body>
</html>
diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html
index c1271399..6e2bd8d4 100644
--- a/rest_framework/templates/rest_framework/login.html
+++ b/rest_framework/templates/rest_framework/login.html
@@ -1,11 +1,11 @@
{% load url from future %}
-{% load static %}
+{% load rest_framework %}
<html>
<head>
- <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/>
- <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/>
- <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/>
+ <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/>
</head>
<body class="container">
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index 4e0181ee..09c658bc 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -11,6 +11,89 @@ import string
register = template.Library()
+# Note we don't use 'load staticfiles', because we need a 1.3 compatible
+# version, so instead we include the `static` template tag ourselves.
+
+# When 1.3 becomes unsupported by REST framework, we can instead start to
+# use the {% load staticfiles %} tag, remove the following code,
+# and add a dependancy that `django.contrib.staticfiles` must be installed.
+
+# Note: We can't put this into the `compat` module because the compat import
+# from rest_framework.compat import ...
+# conflicts with this rest_framework template tag module.
+
+try: # Django 1.5+
+ from django.contrib.staticfiles.templatetags import StaticFilesNode
+
+ @register.tag('static')
+ def do_static(parser, token):
+ return StaticFilesNode.handle_token(parser, token)
+
+except:
+ try: # Django 1.4
+ from django.contrib.staticfiles.storage import staticfiles_storage
+
+ @register.simple_tag
+ def static(path):
+ """
+ A template tag that returns the URL to a file
+ using staticfiles' storage backend
+ """
+ return staticfiles_storage.url(path)
+
+ except: # Django 1.3
+ from urlparse import urljoin
+ from django import template
+ from django.templatetags.static import PrefixNode
+
+ class StaticNode(template.Node):
+ def __init__(self, varname=None, path=None):
+ if path is None:
+ raise template.TemplateSyntaxError(
+ "Static template nodes must be given a path to return.")
+ self.path = path
+ self.varname = varname
+
+ def url(self, context):
+ path = self.path.resolve(context)
+ return self.handle_simple(path)
+
+ def render(self, context):
+ url = self.url(context)
+ if self.varname is None:
+ return url
+ context[self.varname] = url
+ return ''
+
+ @classmethod
+ def handle_simple(cls, path):
+ return urljoin(PrefixNode.handle_simple("STATIC_URL"), path)
+
+ @classmethod
+ def handle_token(cls, parser, token):
+ """
+ Class method to parse prefix node and return a Node.
+ """
+ bits = token.split_contents()
+
+ if len(bits) < 2:
+ raise template.TemplateSyntaxError(
+ "'%s' takes at least one argument (path to file)" % bits[0])
+
+ path = parser.compile_filter(bits[1])
+
+ if len(bits) >= 2 and bits[-2] == 'as':
+ varname = bits[3]
+ else:
+ varname = None
+
+ return cls(varname, path)
+
+ @register.tag('static')
+ def do_static_13(parser, token):
+ return StaticNode.handle_token(parser, token)
+
+
def replace_query_param(url, key, val):
"""
Given a URL and a key/val pair, set or replace an item in the query
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
index 8079c8cb..bc44a45b 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/decorators.py
@@ -17,6 +17,8 @@ from rest_framework.decorators import (
permission_classes,
)
+from rest_framework.tests.utils import RequestFactory
+
class DecoratorTestCase(TestCase):
@@ -63,6 +65,20 @@ class DecoratorTestCase(TestCase):
response = view(request)
self.assertEqual(response.status_code, 405)
+ def test_calling_patch_method(self):
+
+ @api_view(['GET', 'PATCH'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.patch('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 405)
+
def test_renderer_classes(self):
@api_view(['GET'])
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py
index 5dd57b7c..446e23c0 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/files.py
@@ -25,7 +25,6 @@ class UploadedFileSerializer(serializers.Serializer):
class FileSerializerTests(TestCase):
-
def test_create(self):
now = datetime.datetime.now()
file = StringIO.StringIO('stuff')
@@ -37,3 +36,16 @@ class FileSerializerTests(TestCase):
self.assertEquals(serializer.object.created, uploaded_file.created)
self.assertEquals(serializer.object.file, uploaded_file.file)
self.assertFalse(serializer.object is uploaded_file)
+
+ def test_creation_failure(self):
+ """
+ Passing files=None should result in an ValidationError
+
+ Regression test for:
+ https://github.com/tomchristie/django-rest-framework/issues/542
+ """
+ now = datetime.datetime.now()
+
+ serializer = UploadedFileSerializer(data={'created': now})
+ self.assertFalse(serializer.is_valid())
+ self.assertIn('file', serializer.errors)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index 7c24d84e..843017eb 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,8 +1,8 @@
from django.db import models
from django.test import TestCase
-from django.test.client import RequestFactory
from django.utils import simplejson as json
from rest_framework import generics, serializers, status
+from rest_framework.tests.utils import RequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
@@ -181,6 +181,20 @@ class TestInstanceView(TestCase):
updated = self.objects.get(id=1)
self.assertEquals(updated.text, 'foobar')
+ def test_patch_instance_view(self):
+ """
+ PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ content = {'text': 'foobar'}
+ request = factory.patch('/1', json.dumps(content),
+ content_type='application/json')
+
+ response = self.view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEquals(updated.text, 'foobar')
+
def test_delete_instance_view(self):
"""
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index ae8e42b8..59c35074 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -177,3 +177,31 @@ class OptionalRelationModel(RESTFrameworkModel):
# Model for RegexField
class Book(RESTFrameworkModel):
isbn = models.CharField(max_length=13)
+
+
+# Models for relations tests
+# ManyToMany
+class ManyToManyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+
+
+# ForeignKey
+class ForeignKeyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
+ related_name='nullable_sources')
diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py
index 1f8468e8..f12e3b97 100644
--- a/rest_framework/tests/modelviews.py
+++ b/rest_framework/tests/modelviews.py
@@ -1,4 +1,4 @@
-# from django.conf.urls.defaults import patterns, url
+# from rest_framework.compat import patterns, url
# from django.forms import ModelForm
# from django.contrib.auth.models import Group, User
# from rest_framework.resources import ModelResource
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index 53ce0074..a7f8a035 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -2,7 +2,7 @@ from django.db import models
from django.test import TestCase
from rest_framework import serializers
from rest_framework.compat import patterns, url
-
+from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource
def dummy_view(request, pk):
pass
@@ -15,18 +15,6 @@ urlpatterns = patterns('',
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
)
-
-# ManyToMany
-
-class ManyToManyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ManyToManySource(models.Model):
- name = models.CharField(max_length=100)
- targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
-
-
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail')
@@ -39,19 +27,8 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
model = ManyToManySource
-# ForeignKey
-
-class ForeignKeyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
-
-
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail', read_only=True)
+ sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail')
class Meta:
model = ForeignKeyTarget
@@ -114,8 +91,8 @@ class HyperlinkedManyToManyTests(TestCase):
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
serializer.save()
+ self.assertEquals(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
@@ -132,8 +109,8 @@ class HyperlinkedManyToManyTests(TestCase):
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
serializer.save()
+ self.assertEquals(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
@@ -234,6 +211,70 @@ class HyperlinkedForeignKeyTests(TestCase):
]
self.assertEquals(serializer.data, expected)
+ def test_reverse_foreign_key_update(self):
+ data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'},
+ {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-3')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
+ {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
def test_foreign_key_update_with_invalid_null(self):
data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
@@ -249,9 +290,21 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
+ if idx == 3:
+ target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
def test_foreign_key_create_with_valid_null(self):
data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
@@ -266,7 +319,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
{'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
]
self.assertEquals(serializer.data, expected)
@@ -290,7 +343,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
{'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
]
self.assertEquals(serializer.data, expected)
@@ -309,7 +362,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
]
self.assertEquals(serializer.data, expected)
@@ -332,7 +385,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'},
+ {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
]
self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index 3482c252..5710c1ef 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -1,31 +1,33 @@
from django.db import models
from django.test import TestCase
from rest_framework import serializers
+from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource
-# ForeignKey
-
-class ForeignKeyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ depth = 1
+ model = ForeignKeySource
-class ForeignKeySourceSerializer(serializers.ModelSerializer):
+class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = ForeignKeySourceSerializer()
+ sources = FlatForeignKeySourceSerializer()
class Meta:
model = ForeignKeyTarget
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ depth = 1
+ model = NullableForeignKeySource
+
+
class ReverseForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
@@ -36,6 +38,16 @@ class ReverseForeignKeyTests(TestCase):
source = ForeignKeySource(name='source-%d' % idx, target=target)
source.save()
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}},
+ ]
+ self.assertEquals(serializer.data, expected)
+
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset)
@@ -49,3 +61,24 @@ class ReverseForeignKeyTests(TestCase):
]}
]
self.assertEquals(serializer.data, expected)
+
+
+class NestedNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index e3360939..af6da2c0 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -1,17 +1,7 @@
from django.db import models
from django.test import TestCase
from rest_framework import serializers
-
-
-# ManyToMany
-
-class ManyToManyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ManyToManySource(models.Model):
- name = models.CharField(max_length=100)
- targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource
class ManyToManyTargetSerializer(serializers.ModelSerializer):
@@ -26,19 +16,8 @@ class ManyToManySourceSerializer(serializers.ModelSerializer):
model = ManyToManySource
-# ForeignKey
-
-class ForeignKeyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
-
-
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.ManyPrimaryKeyRelatedField(read_only=True)
+ sources = serializers.ManyPrimaryKeyRelatedField()
class Meta:
model = ForeignKeyTarget
@@ -49,14 +28,6 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer):
model = ForeignKeySource
-# Nullable ForeignKey
-
-class NullableForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
- related_name='nullable_sources')
-
-
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
@@ -99,8 +70,8 @@ class PKManyToManyTests(TestCase):
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
serializer.save()
+ self.assertEquals(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
@@ -117,8 +88,8 @@ class PKManyToManyTests(TestCase):
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
serializer.save()
+ self.assertEquals(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
@@ -216,6 +187,70 @@ class PKForeignKeyTests(TestCase):
]
self.assertEquals(serializer.data, expected)
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [2]},
+ {'id': 2, 'name': u'target-2', 'sources': [1, 3]},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': u'source-4', 'target': 2}
+ serializer = ForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1},
+ {'id': 4, 'name': u'source-4', 'target': 2},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [2]},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ {'id': 3, 'name': u'target-3', 'sources': [1, 3]},
+ ]
+ self.assertEquals(serializer.data, expected)
+
def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': u'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
@@ -229,9 +264,21 @@ class PKNullableForeignKeyTests(TestCase):
target = ForeignKeyTarget(name='target-1')
target.save()
for idx in range(1, 4):
+ if idx == 3:
+ target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source.save()
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
def test_foreign_key_create_with_valid_null(self):
data = {'id': 4, 'name': u'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
@@ -246,7 +293,7 @@ class PKNullableForeignKeyTests(TestCase):
expected = [
{'id': 1, 'name': u'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None},
{'id': 4, 'name': u'source-4', 'target': None}
]
self.assertEquals(serializer.data, expected)
@@ -270,7 +317,7 @@ class PKNullableForeignKeyTests(TestCase):
expected = [
{'id': 1, 'name': u'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': None},
{'id': 4, 'name': u'source-4', 'target': None}
]
self.assertEquals(serializer.data, expected)
@@ -289,7 +336,7 @@ class PKNullableForeignKeyTests(TestCase):
expected = [
{'id': 1, 'name': u'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1}
+ {'id': 3, 'name': u'source-3', 'target': None}
]
self.assertEquals(serializer.data, expected)
@@ -312,7 +359,7 @@ class PKNullableForeignKeyTests(TestCase):
expected = [
{'id': 1, 'name': u'source-1', 'target': None},
{'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1}
+ {'id': 3, 'name': u'source-3', 'target': None}
]
self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index aac1f68a..8767385e 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -765,6 +765,10 @@ class BlankFieldTests(TestCase):
serializer = self.not_blank_model_serializer_class(data=self.data)
self.assertEquals(serializer.is_valid(), False)
+ def test_create_model_null_field(self):
+ serializer = self.model_serializer_class(data={})
+ self.assertEquals(serializer.is_valid(), True)
+
#test for issue #460
class SerializerPickleTests(TestCase):
diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py
new file mode 100644
index 00000000..3906adb9
--- /dev/null
+++ b/rest_framework/tests/utils.py
@@ -0,0 +1,27 @@
+from django.test.client import RequestFactory, FakePayload
+from django.test.client import MULTIPART_CONTENT
+from urlparse import urlparse
+
+
+class RequestFactory(RequestFactory):
+
+ def __init__(self, **defaults):
+ super(RequestFactory, self).__init__(**defaults)
+
+ def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
+ **extra):
+ "Construct a PATCH request."
+
+ patch_data = self._encode_data(data, content_type)
+
+ parsed = urlparse(path)
+ r = {
+ 'CONTENT_LENGTH': len(patch_data),
+ 'CONTENT_TYPE': content_type,
+ 'PATH_INFO': self._get_path(parsed),
+ 'QUERY_STRING': parsed[4],
+ 'REQUEST_METHOD': 'PATCH',
+ 'wsgi.input': FakePayload(patch_data),
+ }
+ r.update(extra)
+ return self.request(**r)
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
index 43365e07..7cd82656 100644
--- a/rest_framework/tests/views.py
+++ b/rest_framework/tests/views.py
@@ -18,7 +18,7 @@ class BasicView(APIView):
return Response({'method': 'POST', 'data': request.DATA})
-@api_view(['GET', 'POST', 'PUT'])
+@api_view(['GET', 'POST', 'PUT', 'PATCH'])
def basic_view(request):
if request.method == 'GET':
return {'method': 'GET'}
@@ -26,6 +26,8 @@ def basic_view(request):
return {'method': 'POST', 'data': request.DATA}
elif request.method == 'PUT':
return {'method': 'PUT', 'data': request.DATA}
+ elif request.method == 'PATCH':
+ return {'method': 'PATCH', 'data': request.DATA}
def sanitise_json_error(error_dict):
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 0ad926fa..143928c9 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -1,4 +1,4 @@
-from django.conf.urls.defaults import url
+from rest_framework.compat import url
from rest_framework.settings import api_settings