diff options
| author | Tom Christie | 2013-01-15 17:53:24 +0000 |
|---|---|---|
| committer | Tom Christie | 2013-01-15 17:53:24 +0000 |
| commit | 71e55cc4f6300959398f7aef4a8d91b6a6a2af57 (patch) | |
| tree | 68c2080034263d897741da33cbc5e09746006257 /rest_framework | |
| parent | 52847a215d4e8de88e81d9ae79ce8bee9a36a9a2 (diff) | |
| parent | e1076cfb49b6293aa837cf7bdb4c11988892c598 (diff) | |
| download | django-rest-framework-71e55cc4f6300959398f7aef4a8d91b6a6a2af57.tar.bz2 | |
Merge with latest master
Diffstat (limited to 'rest_framework')
71 files changed, 4268 insertions, 697 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 557f5943..bc267fad 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,3 +1,3 @@ -__version__ = '2.0.0' +__version__ = '2.1.16' VERSION = __version__ # synonym diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py index 9d750381..f4e052e4 100644 --- a/rest_framework/authtoken/migrations/0001_initial.py +++ b/rest_framework/authtoken/migrations/0001_initial.py @@ -5,13 +5,21 @@ from south.v2 import SchemaMigration from django.db import models +try: + from django.contrib.auth import get_user_model +except ImportError: # django < 1.5 + from django.contrib.auth.models import User +else: + User = get_user_model() + + class Migration(SchemaMigration): def forwards(self, orm): # Adding model 'Token' db.create_table('authtoken_token', ( ('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)), - ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['auth.User'])), + ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])), ('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)), )) db.send_create_signal('authtoken', ['Token']) @@ -36,7 +44,7 @@ class Migration(SchemaMigration): 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), 'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}) }, - 'auth.user': { + "%s.%s" % (User._meta.app_label, User._meta.module_name): { 'Meta': {'object_name': 'User'}, 'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), 'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}), @@ -56,7 +64,7 @@ class Migration(SchemaMigration): 'Meta': {'object_name': 'Token'}, 'created': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}), 'key': ('django.db.models.fields.CharField', [], {'max_length': '40', 'primary_key': 'True'}), - 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['auth.User']"}) + 'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['%s.%s']" % (User._meta.app_label, User._meta.object_name)}) }, 'contenttypes.contenttype': { 'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"}, diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 5b3071aa..4da2aa62 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -1,6 +1,7 @@ import uuid import hmac from hashlib import sha1 +from rest_framework.compat import User from django.db import models @@ -9,7 +10,7 @@ class Token(models.Model): The default authorization token model. """ key = models.CharField(max_length=40, primary_key=True) - user = models.OneToOneField('auth.User', related_name='auth_token') + user = models.OneToOneField(User, related_name='auth_token') created = models.DateTimeField(auto_now_add=True) def save(self, *args, **kwargs): diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py new file mode 100644 index 00000000..60a3740e --- /dev/null +++ b/rest_framework/authtoken/serializers.py @@ -0,0 +1,24 @@ +from django.contrib.auth import authenticate +from rest_framework import serializers + + +class AuthTokenSerializer(serializers.Serializer): + username = serializers.CharField() + password = serializers.CharField() + + def validate(self, attrs): + username = attrs.get('username') + password = attrs.get('password') + + if username and password: + user = authenticate(username=username, password=password) + + if user: + if not user.is_active: + raise serializers.ValidationError('User account is disabled.') + attrs['user'] = user + return attrs + else: + raise serializers.ValidationError('Unable to login with provided credentials.') + else: + raise serializers.ValidationError('Must include "username" and "password"') diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index e69de29b..7c03cb76 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -0,0 +1,26 @@ +from rest_framework.views import APIView +from rest_framework import status +from rest_framework import parsers +from rest_framework import renderers +from rest_framework.response import Response +from rest_framework.authtoken.models import Token +from rest_framework.authtoken.serializers import AuthTokenSerializer + + +class ObtainAuthToken(APIView): + throttle_classes = () + permission_classes = () + parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) + renderer_classes = (renderers.JSONRenderer,) + serializer_class = AuthTokenSerializer + model = Token + + def post(self, request): + serializer = self.serializer_class(data=request.DATA) + if serializer.is_valid(): + token, created = Token.objects.get_or_create(user=serializer.object['user']) + return Response({'token': token.key}) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +obtain_auth_token = ObtainAuthToken.as_view() diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7664c400..5508f6c0 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -1,8 +1,23 @@ """ -The :mod:`compat` module provides support for backwards compatibility with older versions of django/python. +The `compat` module provides support for backwards compatibility with older +versions of django/python, and compatibility wrappers around optional packages. """ +# flake8: noqa import django +# location of patterns, url, include changes in 1.4 onwards +try: + from django.conf.urls import patterns, url, include +except: + from django.conf.urls.defaults import patterns, url, include + +# django-filter is optional +try: + import django_filters +except: + django_filters = None + + # cStringIO only if it's available, otherwise StringIO try: import cStringIO as StringIO @@ -10,6 +25,16 @@ except ImportError: import StringIO +# Try to import PIL in either of the two ways it can end up installed. +try: + from PIL import Image +except ImportError: + try: + import Image + except ImportError: + Image = None + + def get_concrete_model(model_cls): try: return model_cls._meta.concrete_model @@ -18,6 +43,20 @@ def get_concrete_model(model_cls): return model_cls +# Django 1.5 add support for custom auth user model +if django.VERSION >= (1, 5): + from django.conf import settings + if hasattr(settings, 'AUTH_USER_MODEL'): + User = settings.AUTH_USER_MODEL + else: + from django.contrib.auth.models import User +else: + try: + from django.contrib.auth.models import User + except ImportError: + raise ImportError(u"User model is not to be found.") + + # First implementation of Django class-based views did not include head method # in base View class - https://code.djangoproject.com/ticket/15668 if django.VERSION >= (1, 4): @@ -57,6 +96,12 @@ else: update_wrapper(view, cls.dispatch, assigned=()) return view +# Taken from @markotibold's attempt at supporting PATCH. +# https://github.com/markotibold/django-rest-framework/tree/patch +http_method_names = set(View.http_method_names) +http_method_names.add('patch') +View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django + # PUT, DELETE do not require CSRF until 1.4. They should. Make it better. if django.VERSION >= (1, 4): from django.middleware.csrf import CsrfViewMiddleware @@ -331,7 +376,7 @@ try: """ extensions = ['headerid(level=2)'] - safe_mode = False, + safe_mode = False md = markdown.Markdown(extensions=extensions, safe_mode=safe_mode) return md.convert(text) @@ -346,33 +391,6 @@ except ImportError: yaml = None -import unittest -try: - import unittest.skip -except ImportError: # python < 2.7 - from unittest import TestCase - import functools - - def skip(reason): - # Pasted from py27/lib/unittest/case.py - """ - Unconditionally skip a test. - """ - def decorator(test_item): - if not (isinstance(test_item, type) and issubclass(test_item, TestCase)): - @functools.wraps(test_item) - def skip_wrapper(*args, **kwargs): - pass - test_item = skip_wrapper - - test_item.__unittest_skip__ = True - test_item.__unittest_skip_why__ = reason - return test_item - return decorator - - unittest.skip = skip - - # xml.etree.parse only throws ParseError for python >= 2.7 try: from xml.etree import ParseError as ETParseError diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 948973ae..1b710a03 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -10,8 +10,18 @@ def api_view(http_method_names): def decorator(func): - class WrappedAPIView(APIView): - pass + WrappedAPIView = type( + 'WrappedAPIView', + (APIView,), + {'__doc__': func.__doc__} + ) + + # Note, the above allows us to set the docstring. + # It is the equivalent of: + # + # class WrappedAPIView(APIView): + # pass + # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this allowed_methods = set(http_method_names) | set(('options',)) WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 572425b9..89479deb 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -31,14 +31,6 @@ class PermissionDenied(APIException): self.detail = detail or self.default_detail -class InvalidFormat(APIException): - status_code = status.HTTP_404_NOT_FOUND - default_detail = "Format suffix '.%s' not found." - - def __init__(self, format, detail=None): - self.detail = (detail or self.default_detail) % format - - class MethodNotAllowed(APIException): status_code = status.HTTP_405_METHOD_NOT_ALLOWED default_detail = "Method '%s' not allowed." diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 6ed37823..998911e1 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,16 +1,18 @@ import copy import datetime import inspect +import re import warnings +from io import BytesIO + from django.core import validators -from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.core.urlresolvers import resolve +from django.core.exceptions import ValidationError from django.conf import settings +from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type, smart_unicode from django.utils.translation import ugettext_lazy as _ -from rest_framework.reverse import reverse from rest_framework.compat import parse_date, parse_datetime from rest_framework.compat import timezone @@ -26,9 +28,12 @@ def is_simple_callable(obj): class Field(object): + read_only = True creation_counter = 0 empty = '' type_name = None + _use_files = None + form_field_class = forms.CharField def __init__(self, source=None): self.parent = None @@ -38,18 +43,20 @@ class Field(object): self.source = source - def initialize(self, parent): + def initialize(self, parent, field_name): """ Called to set up a field prior to field_to_native or field_from_native. parent - The parent serializer. - model_field - The model field this field corrosponds to, if one exists. + model_field - The model field this field corresponds to, if one exists. """ self.parent = parent self.root = parent.root or parent self.context = self.root.context + if self.root.partial: + self.required = False - def field_from_native(self, data, field_name, into): + def field_from_native(self, data, files, field_name, into): """ Given a dictionary and a field name, updates the dictionary `into`, with the field and it's deserialized value. @@ -88,6 +95,8 @@ class Field(object): return value elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): return [self.to_native(item) for item in value] + elif isinstance(value, dict): + return dict(map(self.to_native, (k, v)) for k, v in value.items()) return smart_unicode(value) def attributes(self): @@ -111,17 +120,17 @@ class WritableField(Field): widget = widgets.TextInput default = None - def __init__(self, source=None, readonly=False, required=None, + def __init__(self, source=None, read_only=False, required=None, validators=[], error_messages=None, widget=None, - default=None): + default=None, blank=None): super(WritableField, self).__init__(source=source) - self.readonly = readonly + self.read_only = read_only if required is None: - self.required = not(readonly) + self.required = not(read_only) else: - assert not readonly, "Cannot set required=True and readonly=True" + assert not (read_only and required), "Cannot set required=True and read_only=True" self.required = required messages = {} @@ -131,7 +140,8 @@ class WritableField(Field): self.error_messages = messages self.validators = self.default_validators + validators - self.default = default or self.default + self.default = default if default is not None else self.default + self.blank = blank # Widgets are ony used for HTML forms. widget = widget or self.widget @@ -161,18 +171,23 @@ class WritableField(Field): if errors: raise ValidationError(errors) - def field_from_native(self, data, field_name, into): + def field_from_native(self, data, files, field_name, into): """ Given a dictionary and a field name, updates the dictionary `into`, with the field and it's deserialized value. """ - if self.readonly: + if self.read_only: return try: - native = data[field_name] + if self._use_files: + files = files or {} + native = files[field_name] + else: + native = data[field_name] except KeyError: - if self.default is not None: + if self.default is not None and not self.root.partial: + # Note: partial updates shouldn't set defaults native = self.default else: if self.required: @@ -197,21 +212,32 @@ class WritableField(Field): class ModelField(WritableField): """ - A generic field that can be used against an arbirtrary model field. + A generic field that can be used against an arbitrary model field. """ def __init__(self, *args, **kwargs): try: self.model_field = kwargs.pop('model_field') except: raise ValueError("ModelField requires 'model_field' kwarg") + + self.min_length = kwargs.pop('min_length', + getattr(self.model_field, 'min_length', None)) + self.max_length = kwargs.pop('max_length', + getattr(self.model_field, 'max_length', None)) + super(ModelField, self).__init__(*args, **kwargs) + if self.min_length is not None: + self.validators.append(validators.MinLengthValidator(self.min_length)) + if self.max_length is not None: + self.validators.append(validators.MaxLengthValidator(self.max_length)) + def from_native(self, value): - try: - rel = self.model_field.rel - except: + rel = getattr(self.model_field, "rel", None) + if rel is not None: + return rel.to._meta.get_field(rel.field_name).to_python(value) + else: return self.model_field.to_python(value) - return rel.to._meta.get_field(rel.field_name).to_python(value) def field_to_native(self, obj, field_name): value = self.model_field._get_val_from_obj(obj) @@ -224,200 +250,12 @@ class ModelField(WritableField): "type": self.model_field.get_internal_type() } -##### Relational fields ##### - - -class RelatedField(WritableField): - """ - Base class for related model fields. - """ - def __init__(self, *args, **kwargs): - self.queryset = kwargs.pop('queryset', None) - super(RelatedField, self).__init__(*args, **kwargs) - - def field_to_native(self, obj, field_name): - value = getattr(obj, self.source or field_name) - return self.to_native(value) - - def field_from_native(self, data, field_name, into): - if self.readonly: - return - - value = data.get(field_name) - into[(self.source or field_name) + '_id'] = self.from_native(value) - - -class ManyRelatedMixin(object): - """ - Mixin to convert a related field to a many related field. - """ - def field_to_native(self, obj, field_name): - value = getattr(obj, self.source or field_name) - return [self.to_native(item) for item in value.all()] - - def field_from_native(self, data, field_name, into): - if self.readonly: - return - - try: - # Form data - value = data.getlist(self.source or field_name) - except: - # Non-form data - value = data.get(self.source or field_name) - else: - if value == ['']: - value = [] - into[field_name] = [self.from_native(item) for item in value] - - -class ManyRelatedField(ManyRelatedMixin, RelatedField): - """ - Base class for related model managers. - """ - pass - - -### PrimaryKey relationships - -class PrimaryKeyRelatedField(RelatedField): - """ - Serializes a related field or related object to a pk value. - """ - - def to_native(self, pk): - return pk - - def field_to_native(self, obj, field_name): - try: - # Prefer obj.serializable_value for performance reasons - pk = obj.serializable_value(self.source or field_name) - except AttributeError: - # RelatedObject (reverse relationship) - obj = getattr(obj, self.source or field_name) - return self.to_native(obj.pk) - # Forward relationship - return self.to_native(pk) - - -class ManyPrimaryKeyRelatedField(ManyRelatedField): - """ - Serializes a to-many related field or related manager to a pk value. - """ - def to_native(self, pk): - return pk - - def field_to_native(self, obj, field_name): - try: - # Prefer obj.serializable_value for performance reasons - queryset = obj.serializable_value(self.source or field_name) - except AttributeError: - # RelatedManager (reverse relationship) - queryset = getattr(obj, self.source or field_name) - return [self.to_native(item.pk) for item in queryset.all()] - # Forward relationship - return [self.to_native(item.pk) for item in queryset.all()] - - -### Hyperlinked relationships - -class HyperlinkedRelatedField(RelatedField): - pk_url_kwarg = 'pk' - slug_url_kwarg = 'slug' - slug_field = 'slug' - - def __init__(self, *args, **kwargs): - try: - self.view_name = kwargs.pop('view_name') - except: - raise ValueError("Hyperlinked field requires 'view_name' kwarg") - super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) - - def to_native(self, obj): - view_name = self.view_name - request = self.context.get('request', None) - kwargs = {self.pk_url_kwarg: obj.pk} - try: - return reverse(view_name, kwargs=kwargs, request=request) - except: - pass - - slug = getattr(obj, self.slug_field, None) - - if not slug: - raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) - - kwargs = {self.slug_url_kwarg: slug} - try: - return reverse(self.view_name, kwargs=kwargs, request=request) - except: - pass - - kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} - try: - return reverse(self.view_name, kwargs=kwargs, request=request) - except: - pass - - raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) - - def from_native(self, value): - # Convert URL -> model instance pk - # TODO: Use values_list - try: - match = resolve(value) - except: - raise ValidationError('Invalid hyperlink - No URL match') - - if match.url_name != self.view_name: - raise ValidationError('Invalid hyperlink - Incorrect URL match') - - pk = match.kwargs.get(self.pk_url_kwarg, None) - slug = match.kwargs.get(self.slug_url_kwarg, None) - - # Try explicit primary key. - if pk is not None: - return pk - # Next, try looking up by slug. - elif slug is not None: - slug_field = self.get_slug_field() - queryset = self.queryset.filter(**{slug_field: slug}) - # If none of those are defined, it's an error. - else: - raise ValidationError('Invalid hyperlink') - - try: - obj = queryset.get() - except ObjectDoesNotExist: - raise ValidationError('Invalid hyperlink - object does not exist.') - return obj.pk - - -class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField): - pass - - -class HyperlinkedIdentityField(Field): - """ - A field that represents the model's identity using a hyperlink. - """ - def __init__(self, *args, **kwargs): - # TODO: Make this mandatory, and have the HyperlinkedModelSerializer - # set it on-the-fly - self.view_name = kwargs.pop('view_name', None) - super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) - - def field_to_native(self, obj, field_name): - request = self.context.get('request', None) - view_name = self.view_name or self.parent.opts.view_name - view_kwargs = {'pk': obj.pk} - return reverse(view_name, kwargs=view_kwargs, request=request) - ##### Typed Fields ##### class BooleanField(WritableField): type_name = 'BooleanField' + form_field_class = forms.BooleanField widget = widgets.CheckboxInput default_error_messages = { 'invalid': _(u"'%s' value must be either True or False."), @@ -430,15 +268,16 @@ class BooleanField(WritableField): default = False def from_native(self, value): - if value in ('t', 'True', '1'): + if value in ('true', 't', 'True', '1'): return True - if value in ('f', 'False', '0'): + if value in ('false', 'f', 'False', '0'): return False return bool(value) class CharField(WritableField): type_name = 'CharField' + form_field_class = forms.CharField def __init__(self, max_length=None, min_length=None, *args, **kwargs): self.max_length, self.min_length = max_length, min_length @@ -448,14 +287,42 @@ class CharField(WritableField): if max_length is not None: self.validators.append(validators.MaxLengthValidator(max_length)) + def validate(self, value): + """ + Validates that the value is supplied (if required). + """ + # if empty string and allow blank + if self.blank and not value: + return + else: + super(CharField, self).validate(value) + def from_native(self, value): if isinstance(value, basestring) or value is None: return value return smart_unicode(value) +class URLField(CharField): + type_name = 'URLField' + + def __init__(self, **kwargs): + kwargs['max_length'] = kwargs.get('max_length', 200) + kwargs['validators'] = [validators.URLValidator()] + super(URLField, self).__init__(**kwargs) + + +class SlugField(CharField): + type_name = 'SlugField' + + def __init__(self, *args, **kwargs): + kwargs['max_length'] = kwargs.get('max_length', 50) + super(SlugField, self).__init__(*args, **kwargs) + + class ChoiceField(WritableField): type_name = 'ChoiceField' + form_field_class = forms.ChoiceField widget = widgets.Select default_error_messages = { 'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'), @@ -495,13 +362,14 @@ class ChoiceField(WritableField): if value == smart_unicode(k2): return True else: - if value == smart_unicode(k): + if value == smart_unicode(k) or value == k: return True return False class EmailField(CharField): type_name = 'EmailField' + form_field_class = forms.EmailField default_error_messages = { 'invalid': _('Enter a valid e-mail address.'), @@ -509,7 +377,10 @@ class EmailField(CharField): default_validators = [validators.validate_email] def from_native(self, value): - return super(EmailField, self).from_native(value).strip() + ret = super(EmailField, self).from_native(value) + if ret is None: + return None + return ret.strip() def __deepcopy__(self, memo): result = copy.copy(self) @@ -519,8 +390,39 @@ class EmailField(CharField): return result +class RegexField(CharField): + type_name = 'RegexField' + form_field_class = forms.RegexField + + def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): + super(RegexField, self).__init__(max_length, min_length, *args, **kwargs) + self.regex = regex + + def _get_regex(self): + return self._regex + + def _set_regex(self, regex): + if isinstance(regex, basestring): + regex = re.compile(regex) + self._regex = regex + if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: + self.validators.remove(self._regex_validator) + self._regex_validator = validators.RegexValidator(regex=regex) + self.validators.append(self._regex_validator) + + regex = property(_get_regex, _set_regex) + + def __deepcopy__(self, memo): + result = copy.copy(self) + memo[id(self)] = result + result.validators = self.validators[:] + return result + + class DateField(WritableField): type_name = 'DateField' + widget = widgets.DateInput + form_field_class = forms.DateField default_error_messages = { 'invalid': _(u"'%s' value has an invalid date format. It must be " @@ -531,8 +433,9 @@ class DateField(WritableField): empty = None def from_native(self, value): - if value is None: - return value + if value in validators.EMPTY_VALUES: + return None + if isinstance(value, datetime.datetime): if timezone and settings.USE_TZ and timezone.is_aware(value): # Convert aware datetimes to the default time zone @@ -557,6 +460,8 @@ class DateField(WritableField): class DateTimeField(WritableField): type_name = 'DateTimeField' + widget = widgets.DateTimeInput + form_field_class = forms.DateTimeField default_error_messages = { 'invalid': _(u"'%s' value has an invalid format. It must be in " @@ -570,8 +475,9 @@ class DateTimeField(WritableField): empty = None def from_native(self, value): - if value is None: - return value + if value in validators.EMPTY_VALUES: + return None + if isinstance(value, datetime.datetime): return value if isinstance(value, datetime.date): @@ -610,6 +516,7 @@ class DateTimeField(WritableField): class IntegerField(WritableField): type_name = 'IntegerField' + form_field_class = forms.IntegerField default_error_messages = { 'invalid': _('Enter a whole number.'), @@ -629,6 +536,7 @@ class IntegerField(WritableField): def from_native(self, value): if value in validators.EMPTY_VALUES: return None + try: value = int(str(value)) except (ValueError, TypeError): @@ -638,16 +546,123 @@ class IntegerField(WritableField): class FloatField(WritableField): type_name = 'FloatField' + form_field_class = forms.FloatField default_error_messages = { 'invalid': _("'%s' value must be a float."), } def from_native(self, value): - if value is None: - return value + if value in validators.EMPTY_VALUES: + return None + try: return float(value) except (TypeError, ValueError): msg = self.error_messages['invalid'] % value raise ValidationError(msg) + + +class FileField(WritableField): + _use_files = True + type_name = 'FileField' + form_field_class = forms.FileField + widget = widgets.FileInput + + default_error_messages = { + 'invalid': _("No file was submitted. Check the encoding type on the form."), + 'missing': _("No file was submitted."), + 'empty': _("The submitted file is empty."), + 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'), + 'contradiction': _('Please either submit a file or check the clear checkbox, not both.') + } + + def __init__(self, *args, **kwargs): + self.max_length = kwargs.pop('max_length', None) + self.allow_empty_file = kwargs.pop('allow_empty_file', False) + super(FileField, self).__init__(*args, **kwargs) + + def from_native(self, data): + if data in validators.EMPTY_VALUES: + return None + + # UploadedFile objects should have name and size attributes. + try: + file_name = data.name + file_size = data.size + except AttributeError: + raise ValidationError(self.error_messages['invalid']) + + if self.max_length is not None and len(file_name) > self.max_length: + error_values = {'max': self.max_length, 'length': len(file_name)} + raise ValidationError(self.error_messages['max_length'] % error_values) + if not file_name: + raise ValidationError(self.error_messages['invalid']) + if not self.allow_empty_file and not file_size: + raise ValidationError(self.error_messages['empty']) + + return data + + def to_native(self, value): + return value.name + + +class ImageField(FileField): + _use_files = True + form_field_class = forms.ImageField + + default_error_messages = { + 'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."), + } + + def from_native(self, data): + """ + Checks that the file-upload field data contains a valid image (GIF, JPG, + PNG, possibly others -- whatever the Python Imaging Library supports). + """ + f = super(ImageField, self).from_native(data) + if f is None: + return None + + from compat import Image + assert Image is not None, 'PIL must be installed for ImageField support' + + # We need to get a file object for PIL. We might have a path or we might + # have to read the data into memory. + if hasattr(data, 'temporary_file_path'): + file = data.temporary_file_path() + else: + if hasattr(data, 'read'): + file = BytesIO(data.read()) + else: + file = BytesIO(data['content']) + + try: + # load() could spot a truncated JPEG, but it loads the entire + # image in memory, which is a DoS vector. See #3848 and #18520. + # verify() must be called immediately after the constructor. + Image.open(file).verify() + except ImportError: + # Under PyPy, it is possible to import PIL. However, the underlying + # _imaging C module isn't available, so an ImportError will be + # raised. Catch and re-raise. + raise + except Exception: # Python Imaging Library doesn't recognize it as an image + raise ValidationError(self.error_messages['invalid_image']) + if hasattr(f, 'seek') and callable(f.seek): + f.seek(0) + return f + + +class SerializerMethodField(Field): + """ + A field that gets its value by calling a method on the serializer it's attached to. + """ + + def __init__(self, method_name): + self.method_name = method_name + super(SerializerMethodField, self).__init__() + + def field_to_native(self, obj, field_name): + value = getattr(self.parent, self.method_name)(obj) + return self.to_native(value) diff --git a/rest_framework/filters.py b/rest_framework/filters.py new file mode 100644 index 00000000..bcc87660 --- /dev/null +++ b/rest_framework/filters.py @@ -0,0 +1,59 @@ +from rest_framework.compat import django_filters + +FilterSet = django_filters and django_filters.FilterSet or None + + +class BaseFilterBackend(object): + """ + A base class from which all filter backend classes should inherit. + """ + + def filter_queryset(self, request, queryset, view): + """ + Return a filtered queryset. + """ + raise NotImplementedError(".filter_queryset() must be overridden.") + + +class DjangoFilterBackend(BaseFilterBackend): + """ + A filter backend that uses django-filter. + """ + default_filter_set = FilterSet + + def __init__(self): + assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed' + + def get_filter_class(self, view): + """ + Return the django-filters `FilterSet` used to filter the queryset. + """ + filter_class = getattr(view, 'filter_class', None) + filter_fields = getattr(view, 'filter_fields', None) + view_model = getattr(view, 'model', None) + + if filter_class: + filter_model = filter_class.Meta.model + + assert issubclass(filter_model, view_model), \ + 'FilterSet model %s does not match view model %s' % \ + (filter_model, view_model) + + return filter_class + + if filter_fields: + class AutoFilterSet(self.default_filter_set): + class Meta: + model = view_model + fields = filter_fields + return AutoFilterSet + + return None + + def filter_queryset(self, request, queryset, view): + filter_class = self.get_filter_class(view) + + if filter_class: + return filter_class(request.GET, queryset=queryset) + + return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 81014026..f575470e 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -1,5 +1,5 @@ """ -Generic views that provide commmonly needed behaviour. +Generic views that provide commonly needed behaviour. """ from rest_framework import views, mixins @@ -14,6 +14,8 @@ class GenericAPIView(views.APIView): """ Base class for all other generic views. """ + + model = None serializer_class = None model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS @@ -30,8 +32,10 @@ class GenericAPIView(views.APIView): def get_serializer_class(self): """ Return the class to use for the serializer. - Use `self.serializer_class`, falling back to constructing a - model serializer class from `self.model_serializer_class` + + Defaults to using `self.serializer_class`, falls back to constructing a + model serializer class using `self.model_serializer_class`, with + `self.model` as the model. """ serializer_class = self.serializer_class @@ -43,12 +47,19 @@ class GenericAPIView(views.APIView): return serializer_class - def get_serializer(self, data=None, files=None, instance=None): - # TODO: add support for files - # TODO: add support for seperate serializer/deserializer + def get_serializer(self, instance=None, data=None, + files=None, partial=False): + """ + Return the serializer instance that should be used for validating and + deserializing input, and for serializing output. + """ serializer_class = self.get_serializer_class() context = self.get_serializer_context() - return serializer_class(data, instance=instance, context=context) + return serializer_class(instance, data=data, files=files, + partial=partial, context=context) + + def pre_save(self, obj): + pass class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): @@ -56,37 +67,59 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): Base class for generic views onto a queryset. """ - pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS paginate_by = api_settings.PAGINATE_BY + paginate_by_param = api_settings.PAGINATE_BY_PARAM + pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS + filter_backend = api_settings.FILTER_BACKEND - def get_pagination_serializer_class(self): + def filter_queryset(self, queryset): """ - Return the class to use for the pagination serializer. + Given a queryset, filter it with whichever filter backend is in use. + """ + if not self.filter_backend: + return queryset + backend = self.filter_backend() + return backend.filter_queryset(self.request, queryset, self) + + def get_pagination_serializer(self, page=None): + """ + Return a serializer instance to use with paginated data. """ class SerializerClass(self.pagination_serializer_class): class Meta: object_serializer_class = self.get_serializer_class() - return SerializerClass - - def get_pagination_serializer(self, page=None): - pagination_serializer_class = self.get_pagination_serializer_class() + pagination_serializer_class = SerializerClass context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) + def get_paginate_by(self, queryset): + """ + Return the size of pages to use with pagination. + """ + if self.paginate_by_param: + query_params = self.request.QUERY_PARAMS + try: + return int(query_params[self.paginate_by_param]) + except (KeyError, ValueError): + pass + return self.paginate_by + class SingleObjectAPIView(SingleObjectMixin, GenericAPIView): """ Base class for generic views onto a model instance. """ + pk_url_kwarg = 'pk' # Not provided in Django 1.3 slug_url_kwarg = 'slug' # Not provided in Django 1.3 + slug_field = 'slug' - def get_object(self): + def get_object(self, queryset=None): """ Override default to add support for object-level permissions. """ - obj = super(SingleObjectAPIView, self).get_object() + obj = super(SingleObjectAPIView, self).get_object(queryset) if not self.has_permission(self.request, obj): self.permission_denied(self.request) return obj @@ -125,7 +158,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin, class DestroyAPIView(mixins.DestroyModelMixin, - SingleObjectAPIView): + SingleObjectAPIView): """ Concrete view for deleting a model instance. @@ -143,6 +176,10 @@ class UpdateAPIView(mixins.UpdateModelMixin, def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) + def patch(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, @@ -157,6 +194,23 @@ class ListCreateAPIView(mixins.ListModelMixin, return self.create(request, *args, **kwargs) +class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + SingleObjectAPIView): + """ + Concrete view for retrieving, updating a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + + class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, SingleObjectAPIView): @@ -183,5 +237,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) + def patch(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 9bd566da..e0ae216e 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -3,9 +3,6 @@ Basic building blocks for generic class based views. We don't bind behaviour to http method handlers yet, which allows mixin classes to be composed in interesting ways. - -Eg. Use mixins to build a Resource class, and have a Router class - perform the binding of http methods to actions for us. """ from django.http import Http404 from rest_framework import status @@ -18,30 +15,42 @@ class CreateModelMixin(object): Should be mixed in with any `BaseView`. """ def create(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.DATA) + serializer = self.get_serializer(data=request.DATA, files=request.FILES) + if serializer.is_valid(): self.pre_save(serializer.object) self.object = serializer.save() - return Response(serializer.data, status=status.HTTP_201_CREATED) + headers = self.get_success_headers(serializer.data) + return Response(serializer.data, status=status.HTTP_201_CREATED, + headers=headers) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def get_success_headers(self, data): + try: + return {'Location': data['url']} + except (TypeError, KeyError): + return {} + class ListModelMixin(object): """ List a queryset. - Should be mixed in with `MultipleObjectBaseView`. + Should be mixed in with `MultipleObjectAPIView`. """ empty_error = u"Empty list and '%(class_name)s.allow_empty' is False." def list(self, request, *args, **kwargs): - self.object_list = self.get_queryset() + queryset = self.get_queryset() + self.object_list = self.filter_queryset(queryset) # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. allow_empty = self.get_allow_empty() - if not allow_empty and len(self.object_list) == 0: - error_args = {'class_name': self.__class__.__name__} - raise Http404(self.empty_error % error_args) + if not allow_empty and not self.object_list: + class_name = self.__class__.__name__ + error_msg = self.empty_error % {'class_name': class_name} + raise Http404(error_msg) # Pagination size is set by the `.paginate_by` attribute, # which may be `None` to disable pagination. @@ -51,7 +60,7 @@ class ListModelMixin(object): paginator, page, queryset, is_paginated = packed serializer = self.get_pagination_serializer(page) else: - serializer = self.get_serializer(instance=self.object_list) + serializer = self.get_serializer(self.object_list) return Response(serializer.data) @@ -63,7 +72,7 @@ class RetrieveModelMixin(object): """ def retrieve(self, request, *args, **kwargs): self.object = self.get_object() - serializer = self.get_serializer(instance=self.object) + serializer = self.get_serializer(self.object) return Response(serializer.data) @@ -73,17 +82,21 @@ class UpdateModelMixin(object): Should be mixed in with `SingleObjectBaseView`. """ def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) try: self.object = self.get_object() + success_status_code = status.HTTP_200_OK except Http404: self.object = None + success_status_code = status.HTTP_201_CREATED - serializer = self.get_serializer(data=request.DATA, instance=self.object) + serializer = self.get_serializer(self.object, data=request.DATA, + files=request.FILES, partial=partial) if serializer.is_valid(): self.pre_save(serializer.object) self.object = serializer.save() - return Response(serializer.data) + return Response(serializer.data, status=success_status_code) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -101,6 +114,11 @@ class UpdateModelMixin(object): slug_field = self.get_slug_field() setattr(obj, slug_field, slug) + # Ensure we clean the attributes so that we don't eg return integer + # pk using a string representation, as provided by the url conf kwarg. + if hasattr(obj, 'full_clean'): + obj.full_clean() + class DestroyModelMixin(object): """ @@ -108,6 +126,6 @@ class DestroyModelMixin(object): Should be mixed in with `SingleObjectBaseView`. """ def destroy(self, request, *args, **kwargs): - self.object = self.get_object() - self.object.delete() + obj = self.get_object() + obj.delete() return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 444f8056..ee2800a6 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -1,6 +1,8 @@ +from django.http import Http404 from rest_framework import exceptions from rest_framework.settings import api_settings from rest_framework.utils.mediatypes import order_by_precedence, media_type_matches +from rest_framework.utils.mediatypes import _MediaType class BaseContentNegotiation(object): @@ -47,7 +49,8 @@ class DefaultContentNegotiation(BaseContentNegotiation): for media_type in media_type_set: if media_type_matches(renderer.media_type, media_type): # Return the most specific media type as accepted. - if len(renderer.media_type) > len(media_type): + if (_MediaType(renderer.media_type).precedence > + _MediaType(media_type).precedence): # Eg client requests '*/*' # Accepted media type is 'application/json' return renderer, renderer.media_type @@ -66,7 +69,7 @@ class DefaultContentNegotiation(BaseContentNegotiation): renderers = [renderer for renderer in renderers if renderer.format == format] if not renderers: - raise exceptions.InvalidFormat(format) + raise Http404 return renderers def get_accept_list(self, request): diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 131718fd..d241ade7 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,4 +1,5 @@ from rest_framework import serializers +from rest_framework.templatetags.rest_framework import replace_query_param # TODO: Support URLconf kwarg-style paging @@ -7,30 +8,30 @@ class NextPageField(serializers.Field): """ Field that returns a link to the next page in paginated results. """ + page_field = 'page' + def to_native(self, value): if not value.has_next(): return None page = value.next_page_number() request = self.context.get('request') - relative_url = '?page=%d' % page - if request: - return request.build_absolute_uri(relative_url) - return relative_url + url = request and request.build_absolute_uri() or '' + return replace_query_param(url, self.page_field, page) class PreviousPageField(serializers.Field): """ Field that returns a link to the previous page in paginated results. """ + page_field = 'page' + def to_native(self, value): if not value.has_previous(): return None page = value.previous_page_number() request = self.context.get('request') - relative_url = '?page=%d' % page - if request: - return request.build_absolute_uri('?page=%d' % page) - return relative_url + url = request and request.build_absolute_uri() or '' + return replace_query_param(url, self.page_field, page) class PaginationSerializerOptions(serializers.SerializerOptions): diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 4841676c..149d6431 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -8,11 +8,11 @@ on the request, such as form content or json encoded data. from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParserError -from django.utils import simplejson as json from rest_framework.compat import yaml, ETParseError from rest_framework.exceptions import ParseError from xml.etree import ElementTree as ET from xml.parsers.expat import ExpatError +import json import datetime import decimal diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 6f848cee..655b78a3 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -18,6 +18,17 @@ class BasePermission(object): raise NotImplementedError(".has_permission() must be overridden.") +class AllowAny(BasePermission): + """ + Allow any access. + This isn't strictly required, since you could use an empty + permission_classes list, but it's useful because it makes the intention + more explicit. + """ + def has_permission(self, request, view, obj=None): + return True + + class IsAuthenticated(BasePermission): """ Allows access only to authenticated users. @@ -85,7 +96,7 @@ class DjangoModelPermissions(BasePermission): """ kwargs = { 'app_label': model_cls._meta.app_label, - 'model_name': model_cls._meta.module_name + 'model_name': model_cls._meta.module_name } return [perm % kwargs for perm in self.perms_map[method]] diff --git a/rest_framework/relations.py b/rest_framework/relations.py new file mode 100644 index 00000000..5e4552b7 --- /dev/null +++ b/rest_framework/relations.py @@ -0,0 +1,509 @@ +from django.core.exceptions import ObjectDoesNotExist, ValidationError +from django.core.urlresolvers import resolve, get_script_prefix +from django import forms +from django.forms import widgets +from django.forms.models import ModelChoiceIterator +from django.utils.encoding import smart_unicode +from django.utils.translation import ugettext_lazy as _ +from rest_framework.fields import Field, WritableField +from rest_framework.reverse import reverse +from urlparse import urlparse + +##### Relational fields ##### + + +# Not actually Writable, but subclasses may need to be. +class RelatedField(WritableField): + """ + Base class for related model fields. + + If not overridden, this represents a to-one relationship, using the unicode + representation of the target. + """ + widget = widgets.Select + cache_choices = False + empty_label = None + default_read_only = True # TODO: Remove this + + def __init__(self, *args, **kwargs): + self.queryset = kwargs.pop('queryset', None) + self.null = kwargs.pop('null', False) + super(RelatedField, self).__init__(*args, **kwargs) + self.read_only = kwargs.pop('read_only', self.default_read_only) + + def initialize(self, parent, field_name): + super(RelatedField, self).initialize(parent, field_name) + if self.queryset is None and not self.read_only: + try: + manager = getattr(self.parent.opts.model, self.source or field_name) + if hasattr(manager, 'related'): # Forward + self.queryset = manager.related.model._default_manager.all() + else: # Reverse + self.queryset = manager.field.rel.to._default_manager.all() + except: + raise + msg = ('Serializer related fields must include a `queryset`' + + ' argument or set `read_only=True') + raise Exception(msg) + + ### We need this stuff to make form choices work... + + # def __deepcopy__(self, memo): + # result = super(RelatedField, self).__deepcopy__(memo) + # result.queryset = result.queryset + # return result + + def prepare_value(self, obj): + return self.to_native(obj) + + def label_from_instance(self, obj): + """ + Return a readable representation for use with eg. select widgets. + """ + desc = smart_unicode(obj) + ident = smart_unicode(self.to_native(obj)) + if desc == ident: + return desc + return "%s - %s" % (desc, ident) + + def _get_queryset(self): + return self._queryset + + def _set_queryset(self, queryset): + self._queryset = queryset + self.widget.choices = self.choices + + queryset = property(_get_queryset, _set_queryset) + + def _get_choices(self): + # If self._choices is set, then somebody must have manually set + # the property self.choices. In this case, just return self._choices. + if hasattr(self, '_choices'): + return self._choices + + # Otherwise, execute the QuerySet in self.queryset to determine the + # choices dynamically. Return a fresh ModelChoiceIterator that has not been + # consumed. Note that we're instantiating a new ModelChoiceIterator *each* + # time _get_choices() is called (and, thus, each time self.choices is + # accessed) so that we can ensure the QuerySet has not been consumed. This + # construct might look complicated but it allows for lazy evaluation of + # the queryset. + return ModelChoiceIterator(self) + + def _set_choices(self, value): + # Setting choices also sets the choices on the widget. + # choices can be any iterable, but we call list() on it because + # it will be consumed more than once. + self._choices = self.widget.choices = list(value) + + choices = property(_get_choices, _set_choices) + + ### Regular serializer stuff... + + def field_to_native(self, obj, field_name): + try: + value = getattr(obj, self.source or field_name) + except ObjectDoesNotExist: + return None + return self.to_native(value) + + def field_from_native(self, data, files, field_name, into): + if self.read_only: + return + + try: + value = data[field_name] + except KeyError: + if self.required: + raise ValidationError(self.error_messages['required']) + return + + if value in (None, '') and not self.null: + raise ValidationError('Value may not be null') + elif value in (None, '') and self.null: + into[(self.source or field_name)] = None + else: + into[(self.source or field_name)] = self.from_native(value) + + +class ManyRelatedMixin(object): + """ + Mixin to convert a related field to a many related field. + """ + widget = widgets.SelectMultiple + + def field_to_native(self, obj, field_name): + value = getattr(obj, self.source or field_name) + return [self.to_native(item) for item in value.all()] + + def field_from_native(self, data, files, field_name, into): + if self.read_only: + return + + try: + # Form data + value = data.getlist(self.source or field_name) + except: + # Non-form data + value = data.get(self.source or field_name) + else: + if value == ['']: + value = [] + + into[field_name] = [self.from_native(item) for item in value] + + +class ManyRelatedField(ManyRelatedMixin, RelatedField): + """ + Base class for related model managers. + + If not overridden, this represents a to-many relationship, using the unicode + representations of the target, and is read-only. + """ + pass + + +### PrimaryKey relationships + +class PrimaryKeyRelatedField(RelatedField): + """ + Represents a to-one relationship as a pk value. + """ + default_read_only = False + form_field_class = forms.ChoiceField + + default_error_messages = { + 'does_not_exist': _("Invalid pk '%s' - object does not exist."), + 'invalid': _('Invalid value.'), + } + + # TODO: Remove these field hacks... + def prepare_value(self, obj): + return self.to_native(obj.pk) + + def label_from_instance(self, obj): + """ + Return a readable representation for use with eg. select widgets. + """ + desc = smart_unicode(obj) + ident = smart_unicode(self.to_native(obj.pk)) + if desc == ident: + return desc + return "%s - %s" % (desc, ident) + + # TODO: Possibly change this to just take `obj`, through prob less performant + def to_native(self, pk): + return pk + + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + return self.queryset.get(pk=data) + except ObjectDoesNotExist: + msg = self.error_messages['does_not_exist'] % smart_unicode(data) + raise ValidationError(msg) + except (TypeError, ValueError): + msg = self.error_messages['invalid'] + raise ValidationError(msg) + + def field_to_native(self, obj, field_name): + try: + # Prefer obj.serializable_value for performance reasons + pk = obj.serializable_value(self.source or field_name) + except AttributeError: + # RelatedObject (reverse relationship) + try: + obj = getattr(obj, self.source or field_name) + except ObjectDoesNotExist: + return None + return self.to_native(obj.pk) + # Forward relationship + return self.to_native(pk) + + +class ManyPrimaryKeyRelatedField(ManyRelatedField): + """ + Represents a to-many relationship as a pk value. + """ + default_read_only = False + form_field_class = forms.MultipleChoiceField + + default_error_messages = { + 'does_not_exist': _("Invalid pk '%s' - object does not exist."), + 'invalid': _('Invalid value.'), + } + + def prepare_value(self, obj): + return self.to_native(obj.pk) + + def label_from_instance(self, obj): + """ + Return a readable representation for use with eg. select widgets. + """ + desc = smart_unicode(obj) + ident = smart_unicode(self.to_native(obj.pk)) + if desc == ident: + return desc + return "%s - %s" % (desc, ident) + + def to_native(self, pk): + return pk + + def field_to_native(self, obj, field_name): + try: + # Prefer obj.serializable_value for performance reasons + queryset = obj.serializable_value(self.source or field_name) + except AttributeError: + # RelatedManager (reverse relationship) + queryset = getattr(obj, self.source or field_name) + return [self.to_native(item.pk) for item in queryset.all()] + # Forward relationship + return [self.to_native(item.pk) for item in queryset.all()] + + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + return self.queryset.get(pk=data) + except ObjectDoesNotExist: + msg = self.error_messages['does_not_exist'] % smart_unicode(data) + raise ValidationError(msg) + except (TypeError, ValueError): + msg = self.error_messages['invalid'] + raise ValidationError(msg) + +### Slug relationships + + +class SlugRelatedField(RelatedField): + default_read_only = False + form_field_class = forms.ChoiceField + + default_error_messages = { + 'does_not_exist': _("Object with %s=%s does not exist."), + 'invalid': _('Invalid value.'), + } + + def __init__(self, *args, **kwargs): + self.slug_field = kwargs.pop('slug_field', None) + assert self.slug_field, 'slug_field is required' + super(SlugRelatedField, self).__init__(*args, **kwargs) + + def to_native(self, obj): + return getattr(obj, self.slug_field) + + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + return self.queryset.get(**{self.slug_field: data}) + except ObjectDoesNotExist: + raise ValidationError(self.error_messages['does_not_exist'] % + (self.slug_field, unicode(data))) + except (TypeError, ValueError): + msg = self.error_messages['invalid'] + raise ValidationError(msg) + + +class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField): + form_field_class = forms.MultipleChoiceField + + +### Hyperlinked relationships + +class HyperlinkedRelatedField(RelatedField): + """ + Represents a to-one relationship, using hyperlinking. + """ + pk_url_kwarg = 'pk' + slug_field = 'slug' + slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden + default_read_only = False + form_field_class = forms.ChoiceField + + default_error_messages = { + 'no_match': _('Invalid hyperlink - No URL match'), + 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), + 'configuration_error': _('Invalid hyperlink due to configuration error'), + 'does_not_exist': _("Invalid hyperlink - object does not exist."), + 'invalid': _('Invalid value.'), + } + + def __init__(self, *args, **kwargs): + try: + self.view_name = kwargs.pop('view_name') + except: + raise ValueError("Hyperlinked field requires 'view_name' kwarg") + + self.slug_field = kwargs.pop('slug_field', self.slug_field) + default_slug_kwarg = self.slug_url_kwarg or self.slug_field + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) + self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) + + self.format = kwargs.pop('format', None) + super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) + + def get_slug_field(self): + """ + Get the name of a slug field to be used to look up by slug. + """ + return self.slug_field + + def to_native(self, obj): + view_name = self.view_name + request = self.context.get('request', None) + format = self.format or self.context.get('format', None) + pk = getattr(obj, 'pk', None) + if pk is None: + return + kwargs = {self.pk_url_kwarg: pk} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except: + pass + + slug = getattr(obj, self.slug_field, None) + + if not slug: + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except: + pass + + kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except: + pass + + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + + def from_native(self, value): + # Convert URL -> model instance pk + # TODO: Use values_list + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + http_prefix = value.startswith('http:') or value.startswith('https:') + except AttributeError: + msg = self.error_messages['invalid'] + raise ValidationError(msg) + + if http_prefix: + # If needed convert absolute URLs to relative path + value = urlparse(value).path + prefix = get_script_prefix() + if value.startswith(prefix): + value = '/' + value[len(prefix):] + + try: + match = resolve(value) + except: + raise ValidationError(self.error_messages['no_match']) + + if match.view_name != self.view_name: + raise ValidationError(self.error_messages['incorrect_match']) + + pk = match.kwargs.get(self.pk_url_kwarg, None) + slug = match.kwargs.get(self.slug_url_kwarg, None) + + # Try explicit primary key. + if pk is not None: + queryset = self.queryset.filter(pk=pk) + # Next, try looking up by slug. + elif slug is not None: + slug_field = self.get_slug_field() + queryset = self.queryset.filter(**{slug_field: slug}) + # If none of those are defined, it's probably a configuation error. + else: + raise ValidationError(self.error_messages['configuration_error']) + + try: + obj = queryset.get() + except ObjectDoesNotExist: + raise ValidationError(self.error_messages['does_not_exist']) + except (TypeError, ValueError): + msg = self.error_messages['invalid'] + raise ValidationError(msg) + + return obj + + +class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField): + """ + Represents a to-many relationship, using hyperlinking. + """ + form_field_class = forms.MultipleChoiceField + + +class HyperlinkedIdentityField(Field): + """ + Represents the instance, or a property on the instance, using hyperlinking. + """ + pk_url_kwarg = 'pk' + slug_field = 'slug' + slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden + + def __init__(self, *args, **kwargs): + # TODO: Make view_name mandatory, and have the + # HyperlinkedModelSerializer set it on-the-fly + self.view_name = kwargs.pop('view_name', None) + # Optionally the format of the target hyperlink may be specified + self.format = kwargs.pop('format', None) + + self.slug_field = kwargs.pop('slug_field', self.slug_field) + default_slug_kwarg = self.slug_url_kwarg or self.slug_field + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) + self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) + + super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) + + def field_to_native(self, obj, field_name): + request = self.context.get('request', None) + format = self.context.get('format', None) + view_name = self.view_name or self.parent.opts.view_name + kwargs = {self.pk_url_kwarg: obj.pk} + + # By default use whatever format is given for the current context + # unless the target is a different type to the source. + # + # Eg. Consider a HyperlinkedIdentityField pointing from a json + # representation to an html property of that representation... + # + # '/snippets/1/' should link to '/snippets/1/highlight/' + # ...but... + # '/snippets/1/.json' should link to '/snippets/1/highlight/.html' + if format and self.format and self.format != format: + format = self.format + + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except: + pass + + slug = getattr(obj, self.slug_field, None) + + if not slug: + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except: + pass + + kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except: + pass + + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index c64fb517..0a34abaa 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -4,14 +4,14 @@ Renderers are used to serialize a response into specific media types. They give us a generic way of being able to handle various media types on the response, such as JSON encoded data or HTML output. -REST framework also provides an HTML renderer the renders the browseable API. +REST framework also provides an HTML renderer the renders the browsable API. """ import copy import string +import json from django import forms from django.http.multipartparser import parse_header -from django.template import RequestContext, loader -from django.utils import simplejson as json +from django.template import RequestContext, loader, Template from rest_framework.compat import yaml from rest_framework.exceptions import ConfigurationError from rest_framework.settings import api_settings @@ -19,8 +19,8 @@ from rest_framework.request import clone_request from rest_framework.utils import dict2xml from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework import VERSION -from rest_framework import serializers, parsers +from rest_framework import VERSION, status +from rest_framework import parsers class BaseRenderer(object): @@ -100,7 +100,7 @@ class JSONPRenderer(JSONRenderer): callback = self.get_callback(renderer_context) json = super(JSONPRenderer, self).render(data, accepted_media_type, renderer_context) - return "%s(%s);" % (callback, json) + return u"%s(%s);" % (callback, json) class XMLRenderer(BaseRenderer): @@ -139,18 +139,33 @@ class YAMLRenderer(BaseRenderer): return yaml.dump(data, stream=None, Dumper=self.encoder) -class HTMLRenderer(BaseRenderer): +class TemplateHTMLRenderer(BaseRenderer): """ - A Base class provided for convenience. + An HTML renderer for use with templates. - Render the object simply by using the given template. - To create a template renderer, subclass this class, and set - the :attr:`media_type` and :attr:`template` attributes. + The data supplied to the Response object should be a dictionary that will + be used as context for the template. + + The template name is determined by (in order of preference): + + 1. An explicit `.template_name` attribute set on the response. + 2. An explicit `.template_name` attribute set on this class. + 3. The return result of calling `view.get_template_names()`. + + For example: + data = {'users': User.objects.all()} + return Response(data, template_name='users.html') + + For pre-rendered HTML, see StaticHTMLRenderer. """ media_type = 'text/html' format = 'html' template_name = None + exception_template_names = [ + '%(status_code)s.html', + 'api_exception.html' + ] def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -167,15 +182,21 @@ class HTMLRenderer(BaseRenderer): request = renderer_context['request'] response = renderer_context['response'] - template_names = self.get_template_names(response, view) - template = self.resolve_template(template_names) - context = self.resolve_context(data, request) + if response.exception: + template = self.get_exception_template(response) + else: + template_names = self.get_template_names(response, view) + template = self.resolve_template(template_names) + + context = self.resolve_context(data, request, response) return template.render(context) def resolve_template(self, template_names): return loader.select_template(template_names) - def resolve_context(self, data, request): + def resolve_context(self, data, request, response): + if response.exception: + data['status_code'] = response.status_code return RequestContext(request, data) def get_template_names(self, response, view): @@ -187,6 +208,48 @@ class HTMLRenderer(BaseRenderer): return view.get_template_names() raise ConfigurationError('Returned a template response with no template_name') + def get_exception_template(self, response): + template_names = [name % {'status_code': response.status_code} + for name in self.exception_template_names] + + try: + # Try to find an appropriate error template + return self.resolve_template(template_names) + except: + # Fall back to using eg '404 Not Found' + return Template('%d %s' % (response.status_code, + response.status_text.title())) + + +# Note, subclass TemplateHTMLRenderer simply for the exception behavior +class StaticHTMLRenderer(TemplateHTMLRenderer): + """ + An HTML renderer class that simply returns pre-rendered HTML. + + The data supplied to the Response object should be a string representing + the pre-rendered HTML content. + + For example: + data = '<html><body>example</body></html>' + return Response(data) + + For template rendered HTML, see TemplateHTMLRenderer. + """ + media_type = 'text/html' + format = 'html' + + def render(self, data, accepted_media_type=None, renderer_context=None): + renderer_context = renderer_context or {} + response = renderer_context['response'] + + if response and response.exception: + request = renderer_context['request'] + template = self.get_exception_template(response) + context = self.resolve_context(data, request, response) + return template.render(context) + + return data + class BrowsableAPIRenderer(BaseRenderer): """ @@ -224,7 +287,7 @@ class BrowsableAPIRenderer(BaseRenderer): return content - def show_form_for_method(self, view, method, request): + def show_form_for_method(self, view, method, request, obj): """ Returns True if a form should be shown for this method. """ @@ -236,46 +299,32 @@ class BrowsableAPIRenderer(BaseRenderer): request = clone_request(request, method) try: - if not view.has_permission(request): + if not view.has_permission(request, obj): return # Don't have permission except: return # Don't have permission and exception explicitly raise return True def serializer_to_form_fields(self, serializer): - field_mapping = { - serializers.FloatField: forms.FloatField, - serializers.IntegerField: forms.IntegerField, - serializers.DateTimeField: forms.DateTimeField, - serializers.DateField: forms.DateField, - serializers.EmailField: forms.EmailField, - serializers.CharField: forms.CharField, - serializers.BooleanField: forms.BooleanField, - serializers.PrimaryKeyRelatedField: forms.ModelChoiceField, - serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField - } - fields = {} - for k, v in serializer.get_fields(True).items(): - if getattr(v, 'readonly', True): + for k, v in serializer.get_fields().items(): + if getattr(v, 'read_only', True): continue kwargs = {} kwargs['required'] = v.required - if getattr(v, 'queryset', None): - kwargs['queryset'] = v.queryset + #if getattr(v, 'queryset', None): + # kwargs['queryset'] = v.queryset + + if getattr(v, 'choices', None) is not None: + kwargs['choices'] = v.choices + + if getattr(v, 'regex', None) is not None: + kwargs['regex'] = v.regex if getattr(v, 'widget', None): widget = copy.deepcopy(v.widget) - # If choices have friendly readable names, - # then add in the identities too - if getattr(widget, 'choices', None): - choices = widget.choices - if any([ident != desc for (ident, desc) in choices]): - choices = [(ident, "%s (%s)" % (desc, ident)) - for (ident, desc) in choices] - widget.choices = choices kwargs['widget'] = widget if getattr(v, 'default', None) is not None: @@ -283,10 +332,7 @@ class BrowsableAPIRenderer(BaseRenderer): kwargs['label'] = k - try: - fields[k] = field_mapping[v.__class__](**kwargs) - except KeyError: - fields[k] = forms.CharField(**kwargs) + fields[k] = v.form_field_class(**kwargs) return fields def get_form(self, view, method, request): @@ -295,7 +341,8 @@ class BrowsableAPIRenderer(BaseRenderer): In the absence on of the Resource having an associated form then provide a form that can be used to submit arbitrary content. """ - if not self.show_form_for_method(view, method, request): + obj = getattr(view, 'object', None) + if not self.show_form_for_method(view, method, request, obj): return if method == 'DELETE' or method == 'OPTIONS': @@ -305,17 +352,13 @@ class BrowsableAPIRenderer(BaseRenderer): media_types = [parser.media_type for parser in view.parser_classes] return self.get_generic_content_form(media_types) - # Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python - obj, data = None, None - if getattr(view, 'object', None): - obj = view.object - serializer = view.get_serializer(instance=obj) fields = self.serializer_to_form_fields(serializer) + # Creating an on the fly form see: + # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields) - if obj: - data = serializer.data + data = (obj is not None) and serializer.data or None form_instance = OnTheFlyForm(data) return form_instance @@ -416,7 +459,7 @@ class BrowsableAPIRenderer(BaseRenderer): # Munge DELETE Response code to allow us to return content # (Do this *after* we've rendered the template so that we include # the normal deletion response code in the output) - if response.status_code == 204: - response.status_code = 200 + if response.status_code == status.HTTP_204_NO_CONTENT: + response.status_code = status.HTTP_200_OK return ret diff --git a/rest_framework/request.py b/rest_framework/request.py index 5870be82..b7133608 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -21,8 +21,8 @@ def is_form_media_type(media_type): Return True if the media type is a valid form media type. """ base_media_type, params = parse_header(media_type) - return base_media_type == 'application/x-www-form-urlencoded' or \ - base_media_type == 'multipart/form-data' + return (base_media_type == 'application/x-www-form-urlencoded' or + base_media_type == 'multipart/form-data') class Empty(object): @@ -169,6 +169,15 @@ class Request(object): self._user, self._auth = self._authenticate() return self._user + @user.setter + def user(self, value): + """ + Sets the user on the current request. This is necessary to maintain + compatilbility with django.contrib.auth where the user proprety is + set in the login and logout functions. + """ + self._user = value + @property def auth(self): """ @@ -179,6 +188,14 @@ class Request(object): self._user, self._auth = self._authenticate() return self._auth + @auth.setter + def auth(self, value): + """ + Sets any non-user authentication information associated with the + request, such as an authentication token. + """ + self._auth = value + def _load_data_and_files(self): """ Parses the request content into self.DATA and self.FILES. diff --git a/rest_framework/response.py b/rest_framework/response.py index 7a459c8f..be78c43a 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -9,18 +9,23 @@ class Response(SimpleTemplateResponse): """ def __init__(self, data=None, status=200, - template_name=None, headers=None): + template_name=None, headers=None, + exception=False): """ Alters the init arguments slightly. For example, drop 'template_name', and instead use 'data'. - Setting 'renderer' and 'media_type' will typically be defered, + Setting 'renderer' and 'media_type' will typically be deferred, For example being set automatically by the `APIView`. """ super(Response, self).__init__(None, status=status) self.data = data - self.headers = headers and headers[:] or [] self.template_name = template_name + self.exception = exception + + if headers: + for name,value in headers.iteritems(): + self[name] = value @property def rendered_content(self): @@ -45,3 +50,13 @@ class Response(SimpleTemplateResponse): # TODO: Deprecate and use a template tag instead # TODO: Status code text for RFC 6585 status codes return STATUS_CODE_TEXT.get(self.status_code, '') + + def __getstate__(self): + """ + Remove attributes from the response that shouldn't be cached + """ + state = super(Response, self).__getstate__() + for key in ('accepted_renderer', 'renderer_context', 'data'): + if key in state: + del state[key] + return state diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index ba663f98..c9db02f0 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -5,13 +5,15 @@ from django.core.urlresolvers import reverse as django_reverse from django.utils.functional import lazy -def reverse(viewname, *args, **kwargs): +def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): """ Same as `django.core.urlresolvers.reverse`, but optionally takes a request and returns a fully qualified URL, using the request to get the base URL. """ - request = kwargs.pop('request', None) - url = django_reverse(viewname, *args, **kwargs) + if format is not None: + kwargs = kwargs or {} + kwargs['format'] = format + url = django_reverse(viewname, args=args, kwargs=kwargs, **extra) if request: return request.build_absolute_uri(url) return url diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py index ea2e3d45..bcab1d14 100755 --- a/rest_framework/runtests/runcoverage.py +++ b/rest_framework/runtests/runcoverage.py @@ -8,6 +8,9 @@ Useful tool to run the test suite for rest_framework and generate a coverage rep # http://code.djangoproject.com/svn/django/trunk/tests/runtests.py import os import sys + +# fix sys path so we don't need to setup PYTHONPATH +sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings' from coverage import coverage @@ -32,10 +35,10 @@ def main(): 'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.', DeprecationWarning ) - failures = TestRunner(['rest_framework']) + failures = TestRunner(['tests']) else: test_runner = TestRunner() - failures = test_runner.run_tests(['rest_framework']) + failures = test_runner.run_tests(['tests']) cov.stop() # Discover the list of all modules that we should test coverage for @@ -55,6 +58,12 @@ def main(): if 'compat.py' in files: files.remove('compat.py') + # Same applies to template tags module. + # This module has to include branching on Django versions, + # so it's never possible for it to have full coverage. + if 'rest_framework.py' in files: + files.remove('rest_framework.py') + cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')]) cov.report(cov_files) diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py index b2438c9b..505994e2 100755 --- a/rest_framework/runtests/runtests.py +++ b/rest_framework/runtests/runtests.py @@ -5,6 +5,9 @@ # http://code.djangoproject.com/svn/django/trunk/tests/runtests.py import os import sys + +# fix sys path so we don't need to setup PYTHONPATH +sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings' from django.conf import settings @@ -32,7 +35,7 @@ def main(): else: print usage() sys.exit(1) - failures = test_runner.run_tests(['rest_framework' + test_case]) + failures = test_runner.run_tests(['tests' + test_case]) sys.exit(failures) diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 67de82c8..dd5d9dc3 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -21,6 +21,12 @@ DATABASES = { } } +CACHES = { + 'default': { + 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', + } +} + # Local time zone for this installation. Choices can be found here: # http://en.wikipedia.org/wiki/List_of_tz_zones_by_name # although not all choices may be available on all operating systems. @@ -91,6 +97,7 @@ INSTALLED_APPS = ( # 'django.contrib.admindocs', 'rest_framework', 'rest_framework.authtoken', + 'rest_framework.tests' ) STATIC_URL = '/static/' @@ -100,13 +107,6 @@ import django if django.VERSION < (1, 3): INSTALLED_APPS += ('staticfiles',) -# OAuth support is optional, so we only test oauth if it's installed. -try: - import oauth_provider -except ImportError: - pass -else: - INSTALLED_APPS += ('oauth_provider',) # If we're running on the Jenkins server we want to archive the coverage reports as XML. import os diff --git a/rest_framework/runtests/urls.py b/rest_framework/runtests/urls.py index 4b7da787..ed5baeae 100644 --- a/rest_framework/runtests/urls.py +++ b/rest_framework/runtests/urls.py @@ -1,7 +1,7 @@ """ Blank URLConf just to keep runtests.py happy. """ -from django.conf.urls.defaults import * +from rest_framework.compat import patterns urlpatterns = patterns('', ) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8ee9a0ec..27458f96 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -3,8 +3,18 @@ import datetime import types from decimal import Decimal from django.db import models +from django.forms import widgets from django.utils.datastructures import SortedDict from rest_framework.compat import get_concrete_model + +# Note: We do the following so that users of the framework can use this style: +# +# example_field = serializers.CharField(...) +# +# This helps keep the seperation between model fields, form fields, and +# serializer fields more explicit. + +from rest_framework.relations import * from rest_framework.fields import * @@ -12,7 +22,16 @@ class DictWithMetadata(dict): """ A dict-like object, that can have additional properties attached. """ - pass + def __getstate__(self): + """ + Used by pickle (e.g., caching). + Overriden to remove metadata from the dict, since it shouldn't be pickled + and may in some instances be unpickleable. + """ + # return an instance of the first dict in MRO that isn't a DictWithMetadata + for base in self.__class__.__mro__: + if not isinstance(base, DictWithMetadata) and isinstance(base, dict): + return base(self) class SortedDictWithMetadata(SortedDict, DictWithMetadata): @@ -22,10 +41,6 @@ class SortedDictWithMetadata(SortedDict, DictWithMetadata): pass -class RecursionOccured(BaseException): - pass - - def _is_protected_type(obj): """ True if the object is a native datatype that does not need to @@ -33,10 +48,10 @@ def _is_protected_type(obj): """ return isinstance(obj, ( types.NoneType, - int, long, - datetime.datetime, datetime.date, datetime.time, - float, Decimal, - basestring) + int, long, + datetime.datetime, datetime.date, datetime.time, + float, Decimal, + basestring) ) @@ -54,7 +69,7 @@ def _get_declared_fields(bases, attrs): # If this class is subclassing another Serializer, add that Serializer's # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to the correct order of fields. + # in order to maintain the correct order of fields. for base in bases[::-1]: if hasattr(base, 'base_fields'): fields = base.base_fields.items() + fields @@ -73,7 +88,7 @@ class SerializerOptions(object): Meta class options for Serializer """ def __init__(self, meta): - self.nested = getattr(meta, 'nested', False) + self.depth = getattr(meta, 'depth', 0) self.fields = getattr(meta, 'fields', ()) self.exclude = getattr(meta, 'exclude', ()) @@ -83,51 +98,53 @@ class BaseSerializer(Field): pass _options_class = SerializerOptions - _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations. + _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations. - def __init__(self, data=None, instance=None, context=None, **kwargs): + def __init__(self, instance=None, data=None, files=None, + context=None, partial=False, **kwargs): super(BaseSerializer, self).__init__(**kwargs) - self.fields = copy.deepcopy(self.base_fields) self.opts = self._options_class(self.Meta) self.parent = None self.root = None + self.partial = partial - self.stack = [] self.context = context or {} self.init_data = data + self.init_files = files self.object = instance + self.fields = self.get_fields() self._data = None + self._files = None self._errors = None ##### # Methods to determine which fields to use when (de)serializing objects. - def default_fields(self, serialize, obj=None, data=None, nested=False): + def get_default_fields(self): """ Return the complete set of default fields for the object, as a dict. """ return {} - def get_fields(self, serialize, obj=None, data=None, nested=False): + def get_fields(self): """ Returns the complete set of fields for the object as a dict. This will be the set of any explicitly declared fields, - plus the set of fields returned by default_fields(). + plus the set of fields returned by get_default_fields(). """ ret = SortedDict() # Get the explicitly declared fields - for key, field in self.fields.items(): + base_fields = copy.deepcopy(self.base_fields) + for key, field in base_fields.items(): ret[key] = field - # Set up the field - field.initialize(parent=self) # Add in the default fields - fields = self.default_fields(serialize, obj, data, nested) - for key, val in fields.items(): + default_fields = self.get_default_fields() + for key, val in default_fields.items(): if key not in ret: ret[key] = val @@ -143,25 +160,25 @@ class BaseSerializer(Field): for key in self.opts.exclude: ret.pop(key, None) + for key, field in ret.items(): + field.initialize(parent=self, field_name=key) + return ret ##### # Field methods - used when the serializer class is itself used as a field. - def initialize(self, parent): + def initialize(self, parent, field_name): """ Same behaviour as usual Field, except that we need to keep track - of state so that we can deal with handling maximum depth and recursion. + of state so that we can deal with handling maximum depth. """ - super(BaseSerializer, self).initialize(parent) - self.stack = parent.stack[:] - if parent.opts.nested and not isinstance(parent.opts.nested, bool): - self.opts.nested = parent.opts.nested - 1 - else: - self.opts.nested = parent.opts.nested + super(BaseSerializer, self).initialize(parent, field_name) + if parent.opts.depth: + self.opts.depth = parent.opts.depth - 1 ##### - # Methods to convert or revert from objects <--> primative representations. + # Methods to convert or revert from objects <--> primitive representations. def get_field_key(self, field_name): """ @@ -174,35 +191,32 @@ class BaseSerializer(Field): Core of serialization. Convert an object into a dictionary of serialized field values. """ - if obj in self.stack and not self.source == '*': - raise RecursionOccured() - self.stack.append(obj) - ret = self._dict_class() ret.fields = {} - fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested) - for field_name, field in fields.items(): + for field_name, field in self.fields.items(): + field.initialize(parent=self, field_name=field_name) key = self.get_field_key(field_name) - try: - value = field.field_to_native(obj, field_name) - except RecursionOccured: - field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name] - value = field.field_to_native(obj, field_name) + value = field.field_to_native(obj, field_name) ret[key] = value ret.fields[key] = field return ret - def restore_fields(self, data): + def restore_fields(self, data, files): """ Core of deserialization, together with `restore_object`. Converts a dictionary of data into a dictionary of deserialized fields. """ - fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested) reverted_data = {} - for field_name, field in fields.items(): + + if data is not None and not isinstance(data, dict): + self._errors['non_field_errors'] = [u'Invalid data'] + return None + + for field_name, field in self.fields.items(): + field.initialize(parent=self, field_name=field_name) try: - field.field_from_native(data, field_name, reverted_data) + field.field_from_native(data, files, field_name, reverted_data) except ValidationError as err: self._errors[field_name] = list(err.messages) @@ -212,9 +226,7 @@ class BaseSerializer(Field): """ Run `validate_<fieldname>()` and `validate()` methods on the serializer """ - fields = self.get_fields(serialize=False, data=attrs, nested=self.opts.nested) - - for field_name, field in fields.items(): + for field_name, field in self.fields.items(): try: validate_method = getattr(self, 'validate_%s' % field_name, None) if validate_method: @@ -223,10 +235,18 @@ class BaseSerializer(Field): except ValidationError as err: self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - try: - attrs = self.validate(attrs) - except ValidationError as err: - self._errors['non_field_errors'] = err.messages + # If there are already errors, we don't run .validate() because + # field-validation failed and thus `attrs` may not be complete. + # which in turn can cause inconsistent validation errors. + if not self._errors: + try: + attrs = self.validate(attrs) + except ValidationError as err: + if hasattr(err, 'message_dict'): + for field_name, error_messages in err.message_dict.items(): + self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) + elif hasattr(err, 'messages'): + self._errors['non_field_errors'] = err.messages return attrs @@ -249,26 +269,23 @@ class BaseSerializer(Field): def to_native(self, obj): """ - Serialize objects -> primatives. + Serialize objects -> primitives. """ - if isinstance(obj, dict): - return dict([(key, self.to_native(val)) - for (key, val) in obj.items()]) - elif hasattr(obj, '__iter__'): - return [self.to_native(item) for item in obj] + if hasattr(obj, '__iter__'): + return [self.convert_object(item) for item in obj] return self.convert_object(obj) - def from_native(self, data): + def from_native(self, data, files): """ - Deserialize primatives -> objects. + Deserialize primitives -> objects. """ if hasattr(data, '__iter__') and not isinstance(data, dict): # TODO: error data when deserializing lists - return (self.from_native(item) for item in data) + return [self.from_native(item, None) for item in data] self._errors = {} - if data is not None: - attrs = self.restore_fields(data) + if data is not None or files is not None: + attrs = self.restore_fields(data, files) attrs = self.perform_validation(attrs) else: self._errors['non_field_errors'] = ['No input provided'] @@ -281,22 +298,36 @@ class BaseSerializer(Field): Override default so that we can apply ModelSerializer as a nested field to relationships. """ - obj = getattr(obj, self.source or field_name) + try: + if self.source: + for component in self.source.split('.'): + obj = getattr(obj, component) + if is_simple_callable(obj): + obj = obj() + else: + obj = getattr(obj, field_name) + if is_simple_callable(obj): + obj = obj() + except ObjectDoesNotExist: + return None # If the object has an "all" method, assume it's a relationship if is_simple_callable(getattr(obj, 'all', None)): return [self.to_native(item) for item in obj.all()] + if obj is None: + return None + return self.to_native(obj) @property def errors(self): """ Run deserialization and return error data, - setting self.object if no errors occured. + setting self.object if no errors occurred. """ if self._errors is None: - obj = self.from_native(self.init_data) + obj = self.from_native(self.init_data, self.init_files) if not self._errors: self.object = obj return self._errors @@ -329,6 +360,7 @@ class ModelSerializerOptions(SerializerOptions): def __init__(self, meta): super(ModelSerializerOptions, self).__init__(meta) self.model = getattr(meta, 'model', None) + self.read_only_fields = getattr(meta, 'read_only_fields', ()) class ModelSerializer(Serializer): @@ -337,16 +369,10 @@ class ModelSerializer(Serializer): """ _options_class = ModelSerializerOptions - def default_fields(self, serialize, obj=None, data=None, nested=False): + def get_default_fields(self): """ Return all the fields that should be serialized for the model. """ - # TODO: Modfiy this so that it's called on init, and drop - # serialize/obj/data arguments. - # - # We *could* provide a hook for dynamic fields, but - # it'd be nice if the default was to generate fields statically - # at the point of __init__ cls = self.opts.model opts = get_concrete_model(cls)._meta @@ -358,6 +384,7 @@ class ModelSerializer(Serializer): fields += [field for field in opts.many_to_many if field.serialize] ret = SortedDict() + nested = bool(self.opts.depth) is_pk = True # First field in the list is the pk for model_field in fields: @@ -374,22 +401,30 @@ class ModelSerializer(Serializer): field = self.get_field(model_field) if field: - field.initialize(parent=self) ret[model_field.name] = field + for field_name in self.opts.read_only_fields: + assert field_name in ret, \ + "read_only_fields on '%s' included invalid item '%s'" % \ + (self.__class__.__name__, field_name) + ret[field_name].read_only = True + return ret def get_pk_field(self, model_field): """ Returns a default instance of the pk field. """ - return Field() + return self.get_field(model_field) def get_nested_field(self, model_field): """ Creates a default instance of a nested relational field. """ - return ModelSerializer() + class NestedModelSerializer(ModelSerializer): + class Meta: + model = model_field.rel.to + return NestedModelSerializer() def get_related_field(self, model_field, to_many=False): """ @@ -397,20 +432,43 @@ class ModelSerializer(Serializer): """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - queryset = model_field.rel.to._default_manager + kwargs = { + 'null': model_field.null or model_field.blank, + 'queryset': model_field.rel.to._default_manager + } + if to_many: - return ManyPrimaryKeyRelatedField(queryset=queryset) - return PrimaryKeyRelatedField(queryset=queryset) + return ManyPrimaryKeyRelatedField(**kwargs) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): """ Creates a default instance of a basic non-relational field. """ kwargs = {} + + kwargs['blank'] = model_field.blank + + if model_field.null or model_field.blank: + kwargs['required'] = False + + if isinstance(model_field, models.AutoField) or not model_field.editable: + kwargs['read_only'] = True + if model_field.has_default(): kwargs['required'] = False + kwargs['default'] = model_field.get_default() + + if model_field.__class__ == models.TextField: + kwargs['widget'] = widgets.Textarea + + # TODO: TypedChoiceField? + if model_field.flatchoices: # This ModelField contains choices + kwargs['choices'] = model_field.flatchoices + return ChoiceField(**kwargs) field_mapping = { + models.AutoField: IntegerField, models.FloatField: FloatField, models.IntegerField: IntegerField, models.PositiveIntegerField: IntegerField, @@ -420,42 +478,86 @@ class ModelSerializer(Serializer): models.DateField: DateField, models.EmailField: EmailField, models.CharField: CharField, + models.URLField: URLField, + models.SlugField: SlugField, models.TextField: CharField, models.CommaSeparatedIntegerField: CharField, models.BooleanField: BooleanField, + models.FileField: FileField, + models.ImageField: ImageField, } try: return field_mapping[model_field.__class__](**kwargs) except KeyError: return ModelField(model_field=model_field, **kwargs) + def get_validation_exclusions(self): + """ + Return a list of field names to exclude from model validation. + """ + cls = self.opts.model + opts = get_concrete_model(cls)._meta + exclusions = [field.name for field in opts.fields + opts.many_to_many] + for field_name, field in self.fields.items(): + if field_name in exclusions and not field.read_only: + exclusions.remove(field_name) + return exclusions + def restore_object(self, attrs, instance=None): """ Restore the model instance. """ self.m2m_data = {} + self.related_data = {} - if instance: - for key, val in attrs.items(): - setattr(instance, key, val) - return instance + # Reverse fk relations + for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.related_data[field_name] = attrs.pop(field_name) + # Reverse m2m relations + for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.m2m_data[field_name] = attrs.pop(field_name) + + # Forward m2m relations for field in self.opts.model._meta.many_to_many: if field.name in attrs: self.m2m_data[field.name] = attrs.pop(field.name) - return self.opts.model(**attrs) - def save(self, save_m2m=True): + if instance is not None: + for key, val in attrs.items(): + setattr(instance, key, val) + + else: + instance = self.opts.model(**attrs) + + try: + instance.full_clean(exclude=self.get_validation_exclusions()) + except ValidationError, err: + self._errors = err.message_dict + return None + + return instance + + def save(self): """ Save the deserialized object and return it. """ self.object.save() - if self.m2m_data and save_m2m: + if getattr(self, 'm2m_data', None): for accessor_name, object_list in self.m2m_data.items(): setattr(self.object, accessor_name, object_list) self.m2m_data = {} + if getattr(self, 'related_data', None): + for accessor_name, object_list in self.related_data.items(): + setattr(self.object, accessor_name, object_list) + self.related_data = {} + return self.object @@ -502,9 +604,9 @@ class HyperlinkedModelSerializer(ModelSerializer): # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) rel = model_field.rel.to - queryset = rel._default_manager kwargs = { - 'queryset': queryset, + 'null': model_field.null, + 'queryset': rel._default_manager, 'view_name': self._get_default_view_name(rel) } if to_many: diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 3c508294..5c77c55c 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -37,11 +37,14 @@ DEFAULTS = { 'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.BasicAuthentication' ), - 'DEFAULT_PERMISSION_CLASSES': (), - 'DEFAULT_THROTTLE_CLASSES': (), + 'DEFAULT_PERMISSION_CLASSES': ( + 'rest_framework.permissions.AllowAny', + ), + 'DEFAULT_THROTTLE_CLASSES': ( + ), + 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', - 'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer', 'DEFAULT_PAGINATION_SERIALIZER_CLASS': @@ -51,18 +54,26 @@ DEFAULTS = { 'user': None, 'anon': None, }, + + # Pagination 'PAGINATE_BY': None, + 'PAGINATE_BY_PARAM': None, + + # Filtering + 'FILTER_BACKEND': None, + # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, + # Browser enhancements 'FORM_METHOD_OVERRIDE': '_method', 'FORM_CONTENT_OVERRIDE': '_content', 'FORM_CONTENTTYPE_OVERRIDE': '_content_type', 'URL_ACCEPT_OVERRIDE': 'accept', 'URL_FORMAT_OVERRIDE': 'format', - 'FORMAT_SUFFIX_KWARG': 'format' + 'FORMAT_SUFFIX_KWARG': 'format', } @@ -76,6 +87,7 @@ IMPORT_STRINGS = ( 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) @@ -103,8 +115,8 @@ def import_from_string(val, setting_name): module_path, class_name = '.'.join(parts[:-1]), parts[-1] module = importlib.import_module(module_path) return getattr(module, class_name) - except: - msg = "Could not import '%s' for API setting '%s'" % (val, setting_name) + except ImportError as e: + msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) raise ImportError(msg) @@ -139,8 +151,15 @@ class APISettings(object): if val and attr in self.import_strings: val = perform_import(val, attr) + self.validate_setting(attr, val) + # Cache the result setattr(self, attr, val) return val + def validate_setting(self, attr, val): + if attr == 'FILTER_BACKEND' and val is not None: + # Make sure we can initialize the class + val() + api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS) diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css index 739b9300..b2e41b99 100644 --- a/rest_framework/static/rest_framework/css/default.css +++ b/rest_framework/static/rest_framework/css/default.css @@ -32,6 +32,17 @@ h2, h3 { margin-right: 1em; } +ul.breadcrumb { + margin: 58px 0 0 0; +} + +form select, form input, form textarea { + width: 90%; +} + +form select[multiple] { + height: 150px; +} /* To allow tooltips to work on disabled elements */ .disabled-tooltip-shield { position: absolute; @@ -55,6 +66,7 @@ pre { .page-header { border-bottom: none; padding-bottom: 0px; + margin-bottom: 20px; } @@ -65,7 +77,7 @@ html{ background: none; } -body, .navbar .navbar-inner .container-fluid{ +body, .navbar .navbar-inner .container-fluid { max-width: 1150px; margin: 0 auto; } @@ -76,13 +88,14 @@ body{ } #content{ - margin: 40px 0 0 0; + margin: 0; } /* custom navigation styles */ .wrapper .navbar{ - width:100%; + width: 100%; position: absolute; - left:0; + left: 0; + top: 0; } .navbar .navbar-inner{ diff --git a/rest_framework/status.py b/rest_framework/status.py index f3a5e481..a1eb48da 100644 --- a/rest_framework/status.py +++ b/rest_framework/status.py @@ -49,4 +49,4 @@ HTTP_502_BAD_GATEWAY = 502 HTTP_503_SERVICE_UNAVAILABLE = 503 HTTP_504_GATEWAY_TIMEOUT = 504 HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505 -HTTP_511_NETWORD_AUTHENTICATION_REQUIRED = 511 +HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511 diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 5ac6ef67..092bf2e4 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -1,6 +1,5 @@ {% load url from future %} {% load rest_framework %} -{% load static %} <!DOCTYPE html> <html> <head> @@ -14,10 +13,10 @@ <title>{% block title %}Django REST framework{% endblock %}</title> {% block style %} - <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/> - <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/> - <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/prettify.css'/> - <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> {% endblock %} {% endblock %} @@ -109,11 +108,11 @@ <div class="content-main"> <div class="page-header"><h1>{{ name }}</h1></div> - <p class="resource-description">{{ description }}</p> + {{ description }} <div class="request-info"> <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre> - <div> + </div> <div class="response-info"> <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %} {% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span> @@ -131,12 +130,12 @@ {% csrf_token %} {{ post_form.non_field_errors }} {% for field in post_form %} - <div class="control-group {% if field.errors %}error{% endif %}"> + <div class="control-group"> <!--{% if field.errors %}error{% endif %}--> {{ field.label_tag|add_class:"control-label" }} <div class="controls"> - {{ field|add_class:"input-xlarge" }} + {{ field }} <span class="help-inline">{{ field.help_text }}</span> - {{ field.errors|add_class:"help-block" }} + <!--{{ field.errors|add_class:"help-block" }}--> </div> </div> {% endfor %} @@ -156,12 +155,12 @@ {% csrf_token %} {{ put_form.non_field_errors }} {% for field in put_form %} - <div class="control-group {% if field.errors %}error{% endif %}"> + <div class="control-group"> <!--{% if field.errors %}error{% endif %}--> {{ field.label_tag|add_class:"control-label" }} <div class="controls"> - {{ field|add_class:"input-xlarge" }} + {{ field }} <span class='help-inline'>{{ field.help_text }}</span> - {{ field.errors|add_class:"help-block" }} + <!--{{ field.errors|add_class:"help-block" }}--> </div> </div> {% endfor %} @@ -195,10 +194,10 @@ {% endblock %} {% block script %} - <script src="{% get_static_prefix %}rest_framework/js/jquery-1.8.1-min.js"></script> - <script src="{% get_static_prefix %}rest_framework/js/bootstrap.min.js"></script> - <script src="{% get_static_prefix %}rest_framework/js/prettify-min.js"></script> - <script src="{% get_static_prefix %}rest_framework/js/default.js"></script> + <script src="{% static "rest_framework/js/jquery-1.8.1-min.js" %}"></script> + <script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script> + <script src="{% static "rest_framework/js/prettify-min.js" %}"></script> + <script src="{% static "rest_framework/js/default.js" %}"></script> {% endblock %} </body> </html> diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index 65af512e..6e2bd8d4 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -1,44 +1,52 @@ {% load url from future %} -{% load static %} +{% load rest_framework %} <html> <head> - <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/style.css'/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> </head> - <body class="login"> + <body class="container"> - <div id="container"> - - <div id="header"> - <div id="branding"> - <h1 id="site-name">Django REST framework</h1> +<div class="container-fluid" style="margin-top: 30px"> + <div class="row-fluid"> + + <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> + <div class="row-fluid"> + <div> + <h3 style="margin: 0 0 20px;">Django REST framework</h3> </div> - </div> + </div><!-- /row fluid --> - <div id="content" class="colM"> - <div id="content-main"> - <form method="post" action="{% url 'rest_framework:login' %}" id="login-form"> + <div class="row-fluid"> + <div> + <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> {% csrf_token %} - <div class="form-row"> - <label for="id_username">Username:</label> {{ form.username }} + <div id="div_id_username" class="clearfix control-group"> + <div class="controls" style="height: 30px"> + <Label class="span4" style="margin-top: 3px">Username:</label> + <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> + </div> </div> - <div class="form-row"> - <label for="id_password">Password:</label> {{ form.password }} - <input type="hidden" name="next" value="{{ next }}" /> + <div id="div_id_password" class="clearfix control-group"> + <div class="controls" style="height: 30px"> + <Label class="span4" style="margin-top: 3px">Password:</label> + <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> + </div> </div> - <div class="form-row"> - <label> </label><input type="submit" value="Log in"> + <input type="hidden" name="next" value="{{ next }}" /> + <div class="form-actions-no-box"> + <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> </div> </form> - <script type="text/javascript"> - document.getElementById('id_username').focus() - </script> </div> - <br class="clear"> - </div> + </div><!-- /row fluid --> + </div><!--/span--> - <div id="footer"></div> + </div><!-- /.row-fluid --> + </div> </div> </body> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c9b6eb10..82fcdfe7 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -11,6 +11,101 @@ import string register = template.Library() +# Note we don't use 'load staticfiles', because we need a 1.3 compatible +# version, so instead we include the `static` template tag ourselves. + +# When 1.3 becomes unsupported by REST framework, we can instead start to +# use the {% load staticfiles %} tag, remove the following code, +# and add a dependancy that `django.contrib.staticfiles` must be installed. + +# Note: We can't put this into the `compat` module because the compat import +# from rest_framework.compat import ... +# conflicts with this rest_framework template tag module. + +try: # Django 1.5+ + from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode + + @register.tag('static') + def do_static(parser, token): + return StaticFilesNode.handle_token(parser, token) + +except: + try: # Django 1.4 + from django.contrib.staticfiles.storage import staticfiles_storage + + @register.simple_tag + def static(path): + """ + A template tag that returns the URL to a file + using staticfiles' storage backend + """ + return staticfiles_storage.url(path) + + except: # Django 1.3 + from urlparse import urljoin + from django import template + from django.templatetags.static import PrefixNode + + class StaticNode(template.Node): + def __init__(self, varname=None, path=None): + if path is None: + raise template.TemplateSyntaxError( + "Static template nodes must be given a path to return.") + self.path = path + self.varname = varname + + def url(self, context): + path = self.path.resolve(context) + return self.handle_simple(path) + + def render(self, context): + url = self.url(context) + if self.varname is None: + return url + context[self.varname] = url + return '' + + @classmethod + def handle_simple(cls, path): + return urljoin(PrefixNode.handle_simple("STATIC_URL"), path) + + @classmethod + def handle_token(cls, parser, token): + """ + Class method to parse prefix node and return a Node. + """ + bits = token.split_contents() + + if len(bits) < 2: + raise template.TemplateSyntaxError( + "'%s' takes at least one argument (path to file)" % bits[0]) + + path = parser.compile_filter(bits[1]) + + if len(bits) >= 2 and bits[-2] == 'as': + varname = bits[3] + else: + varname = None + + return cls(varname, path) + + @register.tag('static') + def do_static_13(parser, token): + return StaticNode.handle_token(parser, token) + + +def replace_query_param(url, key, val): + """ + Given a URL and a key/val pair, set or replace an item in the query + parameters of the URL, and return the new URL. + """ + (scheme, netloc, path, query, fragment) = urlsplit(url) + query_dict = QueryDict(query).copy() + query_dict[key] = val + query = query_dict.urlencode() + return urlunsplit((scheme, netloc, path, query, fragment)) + + # Regex for adding classes to html snippets class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') @@ -31,19 +126,6 @@ hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '| trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') -# Helper function for 'add_query_param' -def replace_query_param(url, key, val): - """ - Given a URL and a key/val pair, set or replace an item in the query - parameters of the URL, and return the new URL. - """ - (scheme, netloc, path, query, fragment) = urlsplit(url) - query_dict = QueryDict(query).copy() - query_dict[key] = val - query = query_dict.urlencode() - return urlunsplit((scheme, netloc, path, query, fragment)) - - # And the template tags themselves... @register.simple_tag diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py index adeaf6da..e69de29b 100644 --- a/rest_framework/tests/__init__.py +++ b/rest_framework/tests/__init__.py @@ -1,13 +0,0 @@ -""" -Force import of all modules in this package in order to get the standard test -runner to pick up the tests. Yowzers. -""" -import os - -modules = [filename.rsplit('.', 1)[0] - for filename in os.listdir(os.path.dirname(__file__)) - if filename.endswith('.py') and not filename.startswith('_')] -__test__ = dict() - -for module in modules: - exec("from rest_framework.tests.%s import *" % module) diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 8ab4c4e4..e86041bc 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -1,16 +1,14 @@ -from django.conf.urls.defaults import patterns from django.contrib.auth.models import User -from django.test import Client, TestCase - -from django.utils import simplejson as json from django.http import HttpResponse +from django.test import Client, TestCase -from rest_framework.views import APIView from rest_framework import permissions - from rest_framework.authtoken.models import Token from rest_framework.authentication import TokenAuthentication +from rest_framework.compat import patterns +from rest_framework.views import APIView +import json import base64 @@ -27,6 +25,7 @@ MockView.authentication_classes += (TokenAuthentication,) urlpatterns = patterns('', (r'^$', MockView.as_view()), + (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), ) @@ -152,3 +151,33 @@ class TokenAuthTests(TestCase): self.token.delete() token = Token.objects.create(user=self.user) self.assertTrue(bool(token.key)) + + def test_token_login_json(self): + """Ensure token login view using JSON POST works.""" + client = Client(enforce_csrf_checks=True) + response = client.post('/auth-token/', + json.dumps({'username': self.username, 'password': self.password}), 'application/json') + self.assertEqual(response.status_code, 200) + self.assertEqual(json.loads(response.content)['token'], self.key) + + def test_token_login_json_bad_creds(self): + """Ensure token login view using JSON POST fails if bad credentials are used.""" + client = Client(enforce_csrf_checks=True) + response = client.post('/auth-token/', + json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') + self.assertEqual(response.status_code, 400) + + def test_token_login_json_missing_fields(self): + """Ensure token login view using JSON POST fails if missing fields.""" + client = Client(enforce_csrf_checks=True) + response = client.post('/auth-token/', + json.dumps({'username': self.username}), 'application/json') + self.assertEqual(response.status_code, 400) + + def test_token_login_form(self): + """Ensure token login view using form POST works.""" + client = Client(enforce_csrf_checks=True) + response = client.post('/auth-token/', + {'username': self.username, 'password': self.password}) + self.assertEqual(response.status_code, 200) + self.assertEqual(json.loads(response.content)['token'], self.key) diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py index 647ab96d..df891683 100644 --- a/rest_framework/tests/breadcrumbs.py +++ b/rest_framework/tests/breadcrumbs.py @@ -1,5 +1,5 @@ -from django.conf.urls.defaults import patterns, url from django.test import TestCase +from rest_framework.compat import patterns, url from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework.views import APIView diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 41864d71..5e6bce4e 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -1,7 +1,6 @@ from django.test import TestCase from rest_framework import status from rest_framework.response import Response -from django.test.client import RequestFactory from rest_framework.renderers import JSONRenderer from rest_framework.parsers import JSONParser from rest_framework.authentication import BasicAuthentication @@ -17,6 +16,8 @@ from rest_framework.decorators import ( permission_classes, ) +from rest_framework.tests.utils import RequestFactory + class DecoratorTestCase(TestCase): @@ -63,6 +64,20 @@ class DecoratorTestCase(TestCase): response = view(request) self.assertEqual(response.status_code, 405) + def test_calling_patch_method(self): + + @api_view(['GET', 'PATCH']) + def view(request): + return Response({}) + + request = self.factory.patch('/') + response = view(request) + self.assertEqual(response.status_code, 200) + + request = self.factory.post('/') + response = view(request) + self.assertEqual(response.status_code, 405) + def test_renderer_classes(self): @api_view(['GET']) diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/extras/__init__.py diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py new file mode 100644 index 00000000..68263d94 --- /dev/null +++ b/rest_framework/tests/extras/bad_import.py @@ -0,0 +1 @@ +raise ValueError diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py new file mode 100644 index 00000000..8068272d --- /dev/null +++ b/rest_framework/tests/fields.py @@ -0,0 +1,49 @@ +""" +General serializer field tests. +""" + +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class TimestampedModel(models.Model): + added = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + + +class CharPrimaryKeyModel(models.Model): + id = models.CharField(max_length=20, primary_key=True) + + +class TimestampedModelSerializer(serializers.ModelSerializer): + class Meta: + model = TimestampedModel + + +class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): + class Meta: + model = CharPrimaryKeyModel + + +class ReadOnlyFieldTests(TestCase): + def test_auto_now_fields_read_only(self): + """ + auto_now and auto_now_add fields should be read_only by default. + """ + serializer = TimestampedModelSerializer() + self.assertEquals(serializer.fields['added'].read_only, True) + + def test_auto_pk_fields_read_only(self): + """ + AutoField fields should be read_only by default. + """ + serializer = TimestampedModelSerializer() + self.assertEquals(serializer.fields['id'].read_only, True) + + def test_non_auto_pk_fields_not_read_only(self): + """ + PK fields other than AutoField fields should not be read_only by default. + """ + serializer = CharPrimaryKeyModelSerializer() + self.assertEquals(serializer.fields['id'].read_only, False) diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py index 61d7f7b1..446e23c0 100644 --- a/rest_framework/tests/files.py +++ b/rest_framework/tests/files.py @@ -1,34 +1,51 @@ -# from django.test import TestCase -# from django import forms +import StringIO +import datetime -# from django.test.client import RequestFactory -# from rest_framework.views import View -# from rest_framework.response import Response +from django.test import TestCase -# import StringIO +from rest_framework import serializers -# class UploadFilesTests(TestCase): -# """Check uploading of files""" -# def setUp(self): -# self.factory = RequestFactory() +class UploadedFile(object): + def __init__(self, file, created=None): + self.file = file + self.created = created or datetime.datetime.now() -# def test_upload_file(self): -# class FileForm(forms.Form): -# file = forms.FileField() +class UploadedFileSerializer(serializers.Serializer): + file = serializers.FileField() + created = serializers.DateTimeField() -# class MockView(View): -# permissions = () -# form = FileForm + def restore_object(self, attrs, instance=None): + if instance: + instance.file = attrs['file'] + instance.created = attrs['created'] + return instance + return UploadedFile(**attrs) -# def post(self, request, *args, **kwargs): -# return Response({'FILE_NAME': self.CONTENT['file'].name, -# 'FILE_CONTENT': self.CONTENT['file'].read()}) -# file = StringIO.StringIO('stuff') -# file.name = 'stuff.txt' -# request = self.factory.post('/', {'file': file}) -# view = MockView.as_view() -# response = view(request) -# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"}) +class FileSerializerTests(TestCase): + def test_create(self): + now = datetime.datetime.now() + file = StringIO.StringIO('stuff') + file.name = 'stuff.txt' + file.size = file.len + serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) + uploaded_file = UploadedFile(file=file, created=now) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.object.created, uploaded_file.created) + self.assertEquals(serializer.object.file, uploaded_file.file) + self.assertFalse(serializer.object is uploaded_file) + + def test_creation_failure(self): + """ + Passing files=None should result in an ValidationError + + Regression test for: + https://github.com/tomchristie/django-rest-framework/issues/542 + """ + now = datetime.datetime.now() + + serializer = UploadedFileSerializer(data={'created': now}) + self.assertFalse(serializer.is_valid()) + self.assertIn('file', serializer.errors) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py new file mode 100644 index 00000000..af2e6c2e --- /dev/null +++ b/rest_framework/tests/filterset.py @@ -0,0 +1,168 @@ +import datetime +from decimal import Decimal +from django.test import TestCase +from django.test.client import RequestFactory +from django.utils import unittest +from rest_framework import generics, status, filters +from rest_framework.compat import django_filters +from rest_framework.tests.models import FilterableItem, BasicModel + +factory = RequestFactory() + + +if django_filters: + # Basic filter on a list view. + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_fields = ['decimal', 'date'] + filter_backend = filters.DjangoFilterBackend + + # These class are used to test a filter class. + class SeveralFieldsFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + decimal = django_filters.NumberFilter(lookup_type='lt') + date = django_filters.DateFilter(lookup_type='gt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterClassRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + # These classes are used to test a misconfigured filter class. + class MisconfiguredFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + + class Meta: + model = BasicModel + fields = ['text'] + + class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = MisconfiguredFilter + filter_backend = filters.DjangoFilterBackend + + +class IntegrationTestFiltering(TestCase): + """ + Integration tests for filtered list views. + """ + + def setUp(self): + """ + Create 10 FilterableItem instances. + """ + base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) + for i in range(10): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() + ] + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_fields_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + view = FilterFieldsRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEquals(response.data, expected_data) + + # Tests that the date filter works. + search_date = datetime.date(2012, 9, 22) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] == search_date] + self.assertEquals(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_class_root_view(self): + """ + GET requests to filtered ListCreateAPIView that have a filter_class set + should return filtered results. + """ + view = FilterClassRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + # Tests that the decimal filter set with 'lt' in the filter class works. + search_decimal = Decimal('4.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] < search_decimal] + self.assertEquals(response.data, expected_data) + + # Tests that the date filter set with 'gt' in the filter class works. + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date] + self.assertEquals(response.data, expected_data) + + # Tests that the text filter set with 'icontains' in the filter class works. + search_text = 'ff' + request = factory.get('/?text=%s' % search_text) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if search_text in f['text'].lower()] + self.assertEquals(response.data, expected_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date and + f['decimal'] < search_decimal] + self.assertEquals(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_incorrectly_configured_filter(self): + """ + An error should be displayed when the filter class is misconfigured. + """ + view = IncorrectlyConfiguredRootView.as_view() + + request = factory.get('/') + self.assertRaises(AssertionError, view, request) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_unknown_filter(self): + """ + GET requests with filters that aren't configured should return 200. + """ + view = FilterFieldsRootView.as_view() + + search_integer = 10 + request = factory.get('/?integer=%s' % search_integer) + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py index 1d7e33bc..bc7378e1 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/genericrelations.py @@ -25,7 +25,7 @@ class TestGenericRelations(TestCase): model = Bookmark exclude = ('id',) - serializer = BookmarkSerializer(instance=self.bookmark) + serializer = BookmarkSerializer(self.bookmark) expected = { 'tags': [u'django', u'python'], 'url': u'https://www.djangoproject.com/' diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index f4263478..4799a04b 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,8 +1,9 @@ +import json +from django.db import models from django.test import TestCase -from django.test.client import RequestFactory -from django.utils import simplejson as json from rest_framework import generics, serializers, status -from rest_framework.tests.models import BasicModel, Comment +from rest_framework.tests.utils import RequestFactory +from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel factory = RequestFactory() @@ -22,6 +23,22 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView): model = BasicModel +class SlugSerializer(serializers.ModelSerializer): + slug = serializers.Field() # read only + + class Meta: + model = SlugBasedModel + exclude = ('id',) + + +class SlugBasedInstanceView(InstanceView): + """ + A model with a slug-field. + """ + model = SlugBasedModel + serializer_class = SlugSerializer + + class TestRootView(TestCase): def setUp(self): """ @@ -129,6 +146,7 @@ class TestInstanceView(TestCase): for obj in self.objects.all() ] self.view = InstanceView.as_view() + self.slug_based_view = SlugBasedInstanceView.as_view() def test_get_instance_view(self): """ @@ -157,6 +175,20 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') + response = self.view(request, pk='1').render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) + updated = self.objects.get(id=1) + self.assertEquals(updated.text, 'foobar') + + def test_patch_instance_view(self): + """ + PATCH requests to RetrieveUpdateDestroyAPIView should update an object. + """ + content = {'text': 'foobar'} + request = factory.patch('/1', json.dumps(content), + content_type='application/json') + response = self.view(request, pk=1).render() self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) @@ -198,7 +230,7 @@ class TestInstanceView(TestCase): def test_put_cannot_set_id(self): """ - POST requests to create a new object should not be able to set the id. + PUT requests to create a new object should not be able to set the id. """ content = {'id': 999, 'text': 'foobar'} request = factory.put('/1', json.dumps(content), @@ -219,11 +251,39 @@ class TestInstanceView(TestCase): request = factory.put('/1', json.dumps(content), content_type='application/json') response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) self.assertEquals(updated.text, 'foobar') + def test_put_as_create_on_id_based_url(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + at the requested url if it doesn't exist. + """ + content = {'text': 'foobar'} + # pk fields can not be created on demand, only the database can set th pk for a new object + request = factory.put('/5', json.dumps(content), + content_type='application/json') + response = self.view(request, pk=5).render() + self.assertEquals(response.status_code, status.HTTP_201_CREATED) + new_obj = self.objects.get(pk=5) + self.assertEquals(new_obj.text, 'foobar') + + def test_put_as_create_on_slug_based_url(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. + """ + content = {'text': 'foobar'} + request = factory.put('/test_slug', json.dumps(content), + content_type='application/json') + response = self.slug_based_view(request, slug='test_slug').render() + self.assertEquals(response.status_code, status.HTTP_201_CREATED) + self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'}) + new_obj = SlugBasedModel.objects.get(slug='test_slug') + self.assertEquals(new_obj.text, 'foobar') + # Regression test for #285 @@ -256,3 +316,36 @@ class TestCreateModelWithAutoNowAddField(TestCase): self.assertEquals(response.status_code, status.HTTP_201_CREATED) created = self.objects.get(id=1) self.assertEquals(created.content, 'foobar') + + +# Test for particularly ugly reression with m2m in browseable API +class ClassB(models.Model): + name = models.CharField(max_length=255) + + +class ClassA(models.Model): + name = models.CharField(max_length=255) + childs = models.ManyToManyField(ClassB, blank=True, null=True) + + +class ClassASerializer(serializers.ModelSerializer): + childs = serializers.ManyPrimaryKeyRelatedField(source='childs') + + class Meta: + model = ClassA + + +class ExampleView(generics.ListCreateAPIView): + serializer_class = ClassASerializer + model = ClassA + + +class TestM2MBrowseableAPI(TestCase): + def test_m2m_in_browseable_api(self): + """ + Test for particularly ugly reression with m2m in browseable API + """ + request = factory.get('/', HTTP_ACCEPT='text/html') + view = ExampleView().as_view() + response = view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py index da2f83c3..54096206 100644 --- a/rest_framework/tests/htmlrenderer.py +++ b/rest_framework/tests/htmlrenderer.py @@ -1,14 +1,16 @@ -from django.conf.urls.defaults import patterns, url +from django.core.exceptions import PermissionDenied +from django.http import Http404 from django.test import TestCase from django.template import TemplateDoesNotExist, Template import django.template.loader +from rest_framework.compat import patterns, url from rest_framework.decorators import api_view, renderer_classes -from rest_framework.renderers import HTMLRenderer +from rest_framework.renderers import TemplateHTMLRenderer from rest_framework.response import Response @api_view(('GET',)) -@renderer_classes((HTMLRenderer,)) +@renderer_classes((TemplateHTMLRenderer,)) def example(request): """ A view that can returns an HTML representation. @@ -17,12 +19,26 @@ def example(request): return Response(data, template_name='example.html') +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def permission_denied(request): + raise PermissionDenied() + + +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def not_found(request): + raise Http404() + + urlpatterns = patterns('', url(r'^$', example), + url(r'^permission_denied$', permission_denied), + url(r'^not_found$', not_found), ) -class HTMLRendererTests(TestCase): +class TemplateHTMLRendererTests(TestCase): urls = 'rest_framework.tests.htmlrenderer' def setUp(self): @@ -48,3 +64,52 @@ class HTMLRendererTests(TestCase): response = self.client.get('/') self.assertContains(response, "example: foobar") self.assertEquals(response['Content-Type'], 'text/html') + + def test_not_found_html_view(self): + response = self.client.get('/not_found') + self.assertEquals(response.status_code, 404) + self.assertEquals(response.content, "404 Not Found") + self.assertEquals(response['Content-Type'], 'text/html') + + def test_permission_denied_html_view(self): + response = self.client.get('/permission_denied') + self.assertEquals(response.status_code, 403) + self.assertEquals(response.content, "403 Forbidden") + self.assertEquals(response['Content-Type'], 'text/html') + + +class TemplateHTMLRendererExceptionTests(TestCase): + urls = 'rest_framework.tests.htmlrenderer' + + def setUp(self): + """ + Monkeypatch get_template + """ + self.get_template = django.template.loader.get_template + + def get_template(template_name): + if template_name == '404.html': + return Template("404: {{ detail }}") + if template_name == '403.html': + return Template("403: {{ detail }}") + raise TemplateDoesNotExist(template_name) + + django.template.loader.get_template = get_template + + def tearDown(self): + """ + Revert monkeypatching + """ + django.template.loader.get_template = self.get_template + + def test_not_found_html_view_with_template(self): + response = self.client.get('/not_found') + self.assertEquals(response.status_code, 404) + self.assertEquals(response.content, "404: Not found") + self.assertEquals(response['Content-Type'], 'text/html') + + def test_permission_denied_html_view_with_template(self): + response = self.client.get('/permission_denied') + self.assertEquals(response.status_code, 403) + self.assertEquals(response.content, "403: Permission denied") + self.assertEquals(response['Content-Type'], 'text/html') diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 5532a8ee..c6a8224b 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -1,12 +1,31 @@ -from django.conf.urls.defaults import patterns, url +import json from django.test import TestCase from django.test.client import RequestFactory from rest_framework import generics, status, serializers -from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel +from rest_framework.compat import patterns, url +from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel factory = RequestFactory() +class BlogPostCommentSerializer(serializers.ModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') + text = serializers.CharField() + blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') + + class Meta: + model = BlogPostComment + fields = ('text', 'blog_post_url', 'url') + + +class PhotoSerializer(serializers.Serializer): + description = serializers.CharField() + album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') + + def restore_object(self, attrs, instance=None): + return Photo(**attrs) + + class BasicList(generics.ListCreateAPIView): model = BasicModel model_serializer_class = serializers.HyperlinkedModelSerializer @@ -32,12 +51,46 @@ class ManyToManyDetail(generics.RetrieveAPIView): model_serializer_class = serializers.HyperlinkedModelSerializer +class BlogPostCommentListCreate(generics.ListCreateAPIView): + model = BlogPostComment + serializer_class = BlogPostCommentSerializer + + +class BlogPostCommentDetail(generics.RetrieveAPIView): + model = BlogPostComment + serializer_class = BlogPostCommentSerializer + + +class BlogPostDetail(generics.RetrieveAPIView): + model = BlogPost + + +class PhotoListCreate(generics.ListCreateAPIView): + model = Photo + model_serializer_class = PhotoSerializer + + +class AlbumDetail(generics.RetrieveAPIView): + model = Album + + +class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): + model = OptionalRelationModel + model_serializer_class = serializers.HyperlinkedModelSerializer + + urlpatterns = patterns('', url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), + url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), + url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), + url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), + url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), + url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), + url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'), ) @@ -112,7 +165,7 @@ class TestManyToManyHyperlinkedView(TestCase): GET requests to ListCreateAPIView should return list of objects. """ request = factory.get('/manytomany/') - response = self.list_view(request).render() + response = self.list_view(request) self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.data, self.data) @@ -121,6 +174,89 @@ class TestManyToManyHyperlinkedView(TestCase): GET requests to ListCreateAPIView should return list of objects. """ request = factory.get('/manytomany/1/') - response = self.detail_view(request, pk=1).render() + response = self.detail_view(request, pk=1) self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.data, self.data[0]) + + +class TestCreateWithForeignKeys(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create a blog post + """ + self.post = BlogPost.objects.create(title="Test post") + self.create_view = BlogPostCommentListCreate.as_view() + + def test_create_comment(self): + + data = { + 'text': 'A test comment', + 'blog_post_url': 'http://testserver/posts/1/' + } + + request = factory.post('/comments/', data=data) + response = self.create_view(request) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response['Location'], 'http://testserver/comments/1/') + self.assertEqual(self.post.blogpostcomment_set.count(), 1) + self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') + + +class TestCreateWithForeignKeysAndCustomSlug(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create an Album + """ + self.post = Album.objects.create(title='test-album') + self.list_create_view = PhotoListCreate.as_view() + + def test_create_photo(self): + + data = { + 'description': 'A test photo', + 'album_url': 'http://testserver/albums/test-album/' + } + + request = factory.post('/photos/', data=data) + response = self.list_create_view(request) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') + self.assertEqual(self.post.photo_set.count(), 1) + self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') + + +class TestOptionalRelationHyperlinkedView(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create 1 OptionalRelationModel intances. + """ + OptionalRelationModel().save() + self.objects = OptionalRelationModel.objects + self.detail_view = OptionalRelationDetail.as_view() + self.data = {"url": "http://testserver/optionalrelation/1/", "other": None} + + def test_get_detail_view(self): + """ + GET requests to RetrieveAPIView with optional relations should return None + for non existing relations. + """ + request = factory.get('/optionalrelationmodel-detail/1') + response = self.detail_view(request, pk=1) + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data, self.data) + + def test_put_detail_view(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView with optional relations + should accept None for non existing relations. + """ + response = self.client.put('/optionalrelation/1/', + data=json.dumps(self.data), + content_type='application/json') + self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 97cd0849..93f09761 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -35,15 +35,27 @@ def foobar(): return 'foobar' +class CustomField(models.CharField): + + def __init__(self, *args, **kwargs): + kwargs['max_length'] = 12 + super(CustomField, self).__init__(*args, **kwargs) + + class RESTFrameworkModel(models.Model): """ Base for test models that sets app_label, so they play nicely. """ class Meta: - app_label = 'rest_framework' + app_label = 'tests' abstract = True +class HasPositiveIntegerAsChoice(RESTFrameworkModel): + some_choices = ((1, 'A'), (2, 'B'), (3, 'C')) + some_integer = models.PositiveIntegerField(choices=some_choices) + + class Anchor(RESTFrameworkModel): text = models.CharField(max_length=100, default='anchor') @@ -52,8 +64,14 @@ class BasicModel(RESTFrameworkModel): text = models.CharField(max_length=100) +class SlugBasedModel(RESTFrameworkModel): + text = models.CharField(max_length=100) + slug = models.SlugField(max_length=32) + + class DefaultValueModel(RESTFrameworkModel): text = models.CharField(default='foobar', max_length=100) + extra = models.CharField(blank=True, null=True, max_length=100) class CallableDefaultValueModel(RESTFrameworkModel): @@ -62,12 +80,12 @@ class CallableDefaultValueModel(RESTFrameworkModel): class ManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) - + class ReadOnlyManyToManyModel(RESTFrameworkModel): text = models.CharField(max_length=100, default='anchor') rel = models.ManyToManyField(Anchor) - + # Models to test generic relations @@ -90,6 +108,13 @@ class Bookmark(RESTFrameworkModel): tags = GenericRelation(TaggedItem) +# Model to test filtering. +class FilterableItem(RESTFrameworkModel): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + # Model for regression test for #285 class Comment(RESTFrameworkModel): @@ -101,13 +126,93 @@ class Comment(RESTFrameworkModel): class ActionItem(RESTFrameworkModel): title = models.CharField(max_length=200) done = models.BooleanField(default=False) + info = CustomField(default='---', max_length=12) # Models for reverse relations +class Person(RESTFrameworkModel): + name = models.CharField(max_length=10) + age = models.IntegerField(null=True, blank=True) + + @property + def info(self): + return { + 'name': self.name, + 'age': self.age, + } + + class BlogPost(RESTFrameworkModel): title = models.CharField(max_length=100) + writer = models.ForeignKey(Person, null=True, blank=True) + + def get_first_comment(self): + return self.blogpostcomment_set.all()[0] class BlogPostComment(RESTFrameworkModel): text = models.TextField() blog_post = models.ForeignKey(BlogPost) + + +class Album(RESTFrameworkModel): + title = models.CharField(max_length=100, unique=True) + + +class Photo(RESTFrameworkModel): + description = models.TextField() + album = models.ForeignKey(Album) + + +# Model for issue #324 +class BlankFieldModel(RESTFrameworkModel): + title = models.CharField(max_length=100, blank=True, null=False) + + +# Model for issue #380 +class OptionalRelationModel(RESTFrameworkModel): + other = models.ForeignKey('OptionalRelationModel', blank=True, null=True) + + +# Model for RegexField +class Book(RESTFrameworkModel): + isbn = models.CharField(max_length=13) + + +# Models for relations tests +# ManyToMany +class ManyToManyTarget(RESTFrameworkModel): + name = models.CharField(max_length=100) + + +class ManyToManySource(RESTFrameworkModel): + name = models.CharField(max_length=100) + targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') + + +# ForeignKey +class ForeignKeyTarget(RESTFrameworkModel): + name = models.CharField(max_length=100) + + +class ForeignKeySource(RESTFrameworkModel): + name = models.CharField(max_length=100) + target = models.ForeignKey(ForeignKeyTarget, related_name='sources') + + +# Nullable ForeignKey +class NullableForeignKeySource(RESTFrameworkModel): + name = models.CharField(max_length=100) + target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, + related_name='nullable_sources') + + +# OneToOne +class OneToOneTarget(RESTFrameworkModel): + name = models.CharField(max_length=100) + + +class NullableOneToOneSource(RESTFrameworkModel): + name = models.CharField(max_length=100) + target = models.OneToOneField(OneToOneTarget, null=True, blank=True, + related_name='nullable_source') diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py index 1f8468e8..f12e3b97 100644 --- a/rest_framework/tests/modelviews.py +++ b/rest_framework/tests/modelviews.py @@ -1,4 +1,4 @@ -# from django.conf.urls.defaults import patterns, url +# from rest_framework.compat import patterns, url # from django.forms import ModelForm # from django.contrib.auth.models import Group, User # from rest_framework.resources import ModelResource diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index a939c9ef..3b550877 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,8 +1,12 @@ +import datetime +from decimal import Decimal from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory -from rest_framework import generics, status, pagination -from rest_framework.tests.models import BasicModel +from django.utils import unittest +from rest_framework import generics, status, pagination, filters, serializers +from rest_framework.compat import django_filters +from rest_framework.tests.models import BasicModel, FilterableItem factory = RequestFactory() @@ -15,6 +19,36 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 +if django_filters: + class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_class = DecimalFilter + filter_backend = filters.DjangoFilterBackend + + +class DefaultPageSizeKwargView(generics.ListAPIView): + """ + View for testing default paginate_by_param usage + """ + model = BasicModel + + +class PaginateByParamView(generics.ListAPIView): + """ + View for testing custom paginate_by_param usage + """ + model = BasicModel + paginate_by_param = 'page_size' + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -22,7 +56,7 @@ class IntegrationTestPagination(TestCase): def setUp(self): """ - Create 26 BasicModel intances. + Create 26 BasicModel instances. """ for char in 'abcdefghijklmnopqrstuvwxyz': BasicModel(text=char * 3).save() @@ -62,9 +96,66 @@ class IntegrationTestPagination(TestCase): self.assertNotEquals(response.data['previous'], None) +class IntegrationTestPaginationAndFiltering(TestCase): + + def setUp(self): + """ + Create 50 FilterableItem instances. + """ + base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) + for i in range(26): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() + ] + self.view = FilterFieldsRootView.as_view() + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_paginated_filtered_root_view(self): + """ + GET requests to paginated filtered ListCreateAPIView should return + paginated results. The next and previous links should preserve the + filtered parameters. + """ + request = factory.get('/?decimal=15.20') + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[:10]) + self.assertNotEquals(response.data['next'], None) + self.assertEquals(response.data['previous'], None) + + request = factory.get(response.data['next']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[10:15]) + self.assertEquals(response.data['next'], None) + self.assertNotEquals(response.data['previous'], None) + + request = factory.get(response.data['previous']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 15) + self.assertEquals(response.data['results'], self.data[:10]) + self.assertNotEquals(response.data['next'], None) + self.assertEquals(response.data['previous'], None) + + +class PassOnContextPaginationSerializer(pagination.PaginationSerializer): + class Meta: + object_serializer_class = serializers.Serializer + + class UnitTestPagination(TestCase): """ - Unit tests for pagination of primative objects. + Unit tests for pagination of primitive objects. """ def setUp(self): @@ -74,14 +165,117 @@ class UnitTestPagination(TestCase): self.last_page = paginator.page(3) def test_native_pagination(self): - serializer = pagination.PaginationSerializer(instance=self.first_page) + serializer = pagination.PaginationSerializer(self.first_page) self.assertEquals(serializer.data['count'], 26) self.assertEquals(serializer.data['next'], '?page=2') self.assertEquals(serializer.data['previous'], None) self.assertEquals(serializer.data['results'], self.objects[:10]) - serializer = pagination.PaginationSerializer(instance=self.last_page) + serializer = pagination.PaginationSerializer(self.last_page) self.assertEquals(serializer.data['count'], 26) self.assertEquals(serializer.data['next'], None) self.assertEquals(serializer.data['previous'], '?page=2') self.assertEquals(serializer.data['results'], self.objects[20:]) + + def test_context_available_in_result(self): + """ + Ensure context gets passed through to the object serializer. + """ + serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) + serializer.data + results = serializer.fields[serializer.results_field] + self.assertEquals(serializer.context, results.context) + + +class TestUnpaginated(TestCase): + """ + Tests for list views without pagination. + """ + + def setUp(self): + """ + Create 13 BasicModel instances. + """ + for i in range(13): + BasicModel(text=i).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = DefaultPageSizeKwargView.as_view() + + def test_unpaginated(self): + """ + Tests the default page size for this view. + no page size --> no limit --> no meta data + """ + request = factory.get('/') + response = self.view(request) + self.assertEquals(response.data, self.data) + + +class TestCustomPaginateByParam(TestCase): + """ + Tests for list views with default page size kwarg + """ + + def setUp(self): + """ + Create 13 BasicModel instances. + """ + for i in range(13): + BasicModel(text=i).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = PaginateByParamView.as_view() + + def test_default_page_size(self): + """ + Tests the default page size for this view. + no page size --> no limit --> no meta data + """ + request = factory.get('/') + response = self.view(request).render() + self.assertEquals(response.data, self.data) + + def test_paginate_by_param(self): + """ + If paginate_by_param is set, the new kwarg should limit per view requests. + """ + request = factory.get('/?page_size=5') + response = self.view(request).render() + self.assertEquals(response.data['count'], 13) + self.assertEquals(response.data['results'], self.data[:5]) + + +class CustomField(serializers.Field): + def to_native(self, value): + if not 'view' in self.context: + raise RuntimeError("context isn't getting passed into custom field") + return "value" + + +class BasicModelSerializer(serializers.Serializer): + text = CustomField() + + +class TestContextPassedToCustomField(TestCase): + def setUp(self): + BasicModel.objects.create(text='ala ma kota') + + def test_with_pagination(self): + class ListView(generics.ListCreateAPIView): + model = BasicModel + serializer_class = BasicModelSerializer + paginate_by = 1 + + self.view = ListView.as_view() + request = factory.get('/') + response = self.view(request).render() + + self.assertEquals(response.status_code, status.HTTP_200_OK) + diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py new file mode 100644 index 00000000..91daea8a --- /dev/null +++ b/rest_framework/tests/relations.py @@ -0,0 +1,33 @@ +""" +General tests for relational fields. +""" + +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class NullModel(models.Model): + pass + + +class FieldTests(TestCase): + def test_pk_related_field_with_empty_string(self): + """ + Regression test for #446 + + https://github.com/tomchristie/django-rest-framework/issues/446 + """ + field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_hyperlinked_related_field_with_empty_string(self): + field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_slug_related_field_with_empty_string(self): + field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py new file mode 100644 index 00000000..57913670 --- /dev/null +++ b/rest_framework/tests/relations_hyperlink.py @@ -0,0 +1,434 @@ +from django.test import TestCase +from rest_framework import serializers +from rest_framework.compat import patterns, url +from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource + +def dummy_view(request, pk): + pass + +urlpatterns = patterns('', + url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), + url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), + url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), + url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), + url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), + url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), + url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), +) + +class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): + sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') + + class Meta: + model = ManyToManyTarget + + +class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = ManyToManySource + + +class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): + sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail') + + class Meta: + model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = ForeignKeySource + + +# Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = NullableForeignKeySource + + +# OneToOne +class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): + nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') + + class Meta: + model = OneToOneTarget + + +# TODO: Add test that .data cannot be accessed prior to .is_valid + +class HyperlinkedManyToManyTests(TestCase): + urls = 'rest_framework.tests.relations_hyperlink' + + def setUp(self): + for idx in range(1, 4): + target = ManyToManyTarget(name='target-%d' % idx) + target.save() + source = ManyToManySource(name='source-%d' % idx) + source.save() + for target in ManyToManyTarget.objects.all(): + source.targets.add(target) + + def test_many_to_many_retrieve(self): + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset) + expected = [ + {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, + {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, + {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_many_to_many_retrieve(self): + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset) + expected = [ + {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, + {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, + {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} + ] + self.assertEquals(serializer.data, expected) + + def test_many_to_many_update(self): + data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} + instance = ManyToManySource.objects.get(pk=1) + serializer = ManyToManySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() + self.assertEquals(serializer.data, data) + + # Ensure source 1 is updated, and everything else is as expected + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset) + expected = [ + {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, + {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, + {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_many_to_many_update(self): + data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']} + instance = ManyToManyTarget.objects.get(pk=1) + serializer = ManyToManyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() + self.assertEquals(serializer.data, data) + + # Ensure target 1 is updated, and everything else is as expected + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset) + expected = [ + {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}, + {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, + {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} + + ] + self.assertEquals(serializer.data, expected) + + def test_many_to_many_create(self): + data = {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} + serializer = ManyToManySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is added, and everything else is as expected + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset) + expected = [ + {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, + {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, + {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, + {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_many_to_many_create(self): + data = {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} + serializer = ManyToManyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'target-4') + + # Ensure target 4 is added, and everything else is as expected + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset) + expected = [ + {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, + {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, + {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}, + {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} + ] + self.assertEquals(serializer.data, expected) + + +class HyperlinkedForeignKeyTests(TestCase): + urls = 'rest_framework.tests.relations_hyperlink' + + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, + {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, + {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_retrieve(self): + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, + {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update(self): + data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}, + {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, + {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_update(self): + data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + # We shouldn't have saved anything to the db yet since save + # hasn't been called. + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, + {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, + ] + self.assertEquals(new_serializer.data, expected) + + serializer.save() + self.assertEquals(serializer.data, data) + + # Ensure target 2 is update, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, + {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create(self): + data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'} + serializer = ForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 1 is updated, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, + {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, + {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, + {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}, + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_create(self): + data = {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} + serializer = ForeignKeyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'target-3') + + # Ensure target 4 is added, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, + {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, + {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_invalid_null(self): + data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + + +class HyperlinkedNullableForeignKeyTests(TestCase): + urls = 'rest_framework.tests.relations_hyperlink' + + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None + source = NullableForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, + {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, + {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create_with_valid_null(self): + data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, + {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, + {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, + {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''} + expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, expected_data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, + {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, + {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, + {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_valid_null(self): + data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, + {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, + {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''} + expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, expected_data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, + {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, + {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, + ] + self.assertEquals(serializer.data, expected) + + # reverse foreign keys MUST be read_only + # In the general case they do not provide .remove() or .clear() + # and cannot be arbitrarily set. + + # def test_reverse_foreign_key_update(self): + # data = {'id': 1, 'name': u'target-1', 'sources': [1]} + # instance = ForeignKeyTarget.objects.get(pk=1) + # serializer = ForeignKeyTargetSerializer(instance, data=data) + # self.assertTrue(serializer.is_valid()) + # self.assertEquals(serializer.data, data) + # serializer.save() + + # # Ensure target 1 is updated, and everything else is as expected + # queryset = ForeignKeyTarget.objects.all() + # serializer = ForeignKeyTargetSerializer(queryset) + # expected = [ + # {'id': 1, 'name': u'target-1', 'sources': [1]}, + # {'id': 2, 'name': u'target-2', 'sources': []}, + # ] + # self.assertEquals(serializer.data, expected) + + +class HyperlinkedNullableOneToOneTests(TestCase): + urls = 'rest_framework.tests.relations_hyperlink' + + def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset) + expected = [ + {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'}, + {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None}, + ] + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py new file mode 100644 index 00000000..0e129fae --- /dev/null +++ b/rest_framework/tests/relations_nested.py @@ -0,0 +1,114 @@ +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + depth = 1 + model = ForeignKeySource + + +class FlatForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeySource + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): + sources = FlatForeignKeySourceSerializer() + + class Meta: + model = ForeignKeyTarget + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + depth = 1 + model = NullableForeignKeySource + + +class NullableOneToOneSourceSerializer(serializers.ModelSerializer): + class Meta: + model = NullableOneToOneSource + + +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): + nullable_source = NullableOneToOneSourceSerializer() + + class Meta: + model = OneToOneTarget + + +class ReverseForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, + {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, + {'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}}, + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_retrieve(self): + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [ + {'id': 1, 'name': u'source-1', 'target': 1}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': 1}, + ]}, + {'id': 2, 'name': u'target-2', 'sources': [ + ]} + ] + self.assertEquals(serializer.data, expected) + + +class NestedNullableForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None + source = NullableForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, + {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, + {'id': 3, 'name': u'source-3', 'target': None}, + ] + self.assertEquals(serializer.data, expected) + + +class NestedNullableOneToOneTests(TestCase): + def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}}, + {'id': 2, 'name': u'target-2', 'nullable_source': None}, + ] + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py new file mode 100644 index 00000000..54835860 --- /dev/null +++ b/rest_framework/tests/relations_pk.py @@ -0,0 +1,411 @@ +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource + + +class ManyToManyTargetSerializer(serializers.ModelSerializer): + sources = serializers.ManyPrimaryKeyRelatedField() + + class Meta: + model = ManyToManyTarget + + +class ManyToManySourceSerializer(serializers.ModelSerializer): + class Meta: + model = ManyToManySource + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): + sources = serializers.ManyPrimaryKeyRelatedField() + + class Meta: + model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeySource + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + model = NullableForeignKeySource + + +# OneToOne +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): + nullable_source = serializers.PrimaryKeyRelatedField() + + class Meta: + model = OneToOneTarget + + +# TODO: Add test that .data cannot be accessed prior to .is_valid + +class PKManyToManyTests(TestCase): + def setUp(self): + for idx in range(1, 4): + target = ManyToManyTarget(name='target-%d' % idx) + target.save() + source = ManyToManySource(name='source-%d' % idx) + source.save() + for target in ManyToManyTarget.objects.all(): + source.targets.add(target) + + def test_many_to_many_retrieve(self): + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'targets': [1]}, + {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_many_to_many_retrieve(self): + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': u'target-3', 'sources': [3]} + ] + self.assertEquals(serializer.data, expected) + + def test_many_to_many_update(self): + data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]} + instance = ManyToManySource.objects.get(pk=1) + serializer = ManyToManySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() + self.assertEquals(serializer.data, data) + + # Ensure source 1 is updated, and everything else is as expected + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}, + {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_many_to_many_update(self): + data = {'id': 1, 'name': u'target-1', 'sources': [1]} + instance = ManyToManyTarget.objects.get(pk=1) + serializer = ManyToManyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() + self.assertEquals(serializer.data, data) + + # Ensure target 1 is updated, and everything else is as expected + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [1]}, + {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': u'target-3', 'sources': [3]} + ] + self.assertEquals(serializer.data, expected) + + def test_many_to_many_create(self): + data = {'id': 4, 'name': u'source-4', 'targets': [1, 3]} + serializer = ManyToManySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is added, and everything else is as expected + queryset = ManyToManySource.objects.all() + serializer = ManyToManySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'targets': [1]}, + {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}, + {'id': 4, 'name': u'source-4', 'targets': [1, 3]}, + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_many_to_many_create(self): + data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]} + serializer = ManyToManyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'target-4') + + # Ensure target 4 is added, and everything else is as expected + queryset = ManyToManyTarget.objects.all() + serializer = ManyToManyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': u'target-3', 'sources': [3]}, + {'id': 4, 'name': u'target-4', 'sources': [1, 3]} + ] + self.assertEquals(serializer.data, expected) + + +class PKForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 1}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': 1} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_retrieve(self): + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': u'target-2', 'sources': []}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update(self): + data = {'id': 1, 'name': u'source-1', 'target': 2} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 2}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': 1} + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_update(self): + data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + # We shouldn't have saved anything to the db yet since save + # hasn't been called. + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': u'target-2', 'sources': []}, + ] + self.assertEquals(new_serializer.data, expected) + + serializer.save() + self.assertEquals(serializer.data, data) + + # Ensure target 2 is update, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [2]}, + {'id': 2, 'name': u'target-2', 'sources': [1, 3]}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create(self): + data = {'id': 4, 'name': u'source-4', 'target': 2} + serializer = ForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is added, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 1}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': 1}, + {'id': 4, 'name': u'source-4', 'target': 2}, + ] + self.assertEquals(serializer.data, expected) + + def test_reverse_foreign_key_create(self): + data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]} + serializer = ForeignKeyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'target-3') + + # Ensure target 3 is added, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'sources': [2]}, + {'id': 2, 'name': u'target-2', 'sources': []}, + {'id': 3, 'name': u'target-3', 'sources': [1, 3]}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_invalid_null(self): + data = {'id': 1, 'name': u'source-1', 'target': None} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + + +class PKNullableForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None + source = NullableForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 1}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': None}, + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create_with_valid_null(self): + data = {'id': 4, 'name': u'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 1}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': None}, + {'id': 4, 'name': u'source-4', 'target': None} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_create_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'id': 4, 'name': u'source-4', 'target': ''} + expected_data = {'id': 4, 'name': u'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEquals(serializer.data, expected_data) + self.assertEqual(obj.name, u'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': 1}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': None}, + {'id': 4, 'name': u'source-4', 'target': None} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_valid_null(self): + data = {'id': 1, 'name': u'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': None}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': None} + ] + self.assertEquals(serializer.data, expected) + + def test_foreign_key_update_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'id': 1, 'name': u'source-1', 'target': ''} + expected_data = {'id': 1, 'name': u'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.data, expected_data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset) + expected = [ + {'id': 1, 'name': u'source-1', 'target': None}, + {'id': 2, 'name': u'source-2', 'target': 1}, + {'id': 3, 'name': u'source-3', 'target': None} + ] + self.assertEquals(serializer.data, expected) + + # reverse foreign keys MUST be read_only + # In the general case they do not provide .remove() or .clear() + # and cannot be arbitrarily set. + + # def test_reverse_foreign_key_update(self): + # data = {'id': 1, 'name': u'target-1', 'sources': [1]} + # instance = ForeignKeyTarget.objects.get(pk=1) + # serializer = ForeignKeyTargetSerializer(instance, data=data) + # self.assertTrue(serializer.is_valid()) + # self.assertEquals(serializer.data, data) + # serializer.save() + + # # Ensure target 1 is updated, and everything else is as expected + # queryset = ForeignKeyTarget.objects.all() + # serializer = ForeignKeyTargetSerializer(queryset) + # expected = [ + # {'id': 1, 'name': u'target-1', 'sources': [1]}, + # {'id': 2, 'name': u'target-2', 'sources': []}, + # ] + # self.assertEquals(serializer.data, expected) + + +class PKNullableOneToOneTests(TestCase): + def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset) + expected = [ + {'id': 1, 'name': u'target-1', 'nullable_source': 1}, + {'id': 2, 'name': u'target-2', 'nullable_source': None}, + ] + self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py index 48d8d9bd..c1b4e624 100644 --- a/rest_framework/tests/renderers.py +++ b/rest_framework/tests/renderers.py @@ -1,11 +1,12 @@ +import pickle import re -from django.conf.urls.defaults import patterns, url, include +from django.core.cache import cache from django.test import TestCase from django.test.client import RequestFactory from rest_framework import status, permissions -from rest_framework.compat import yaml +from rest_framework.compat import yaml, patterns, url, include from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ @@ -83,6 +84,7 @@ class HTMLView1(APIView): urlpatterns = patterns('', url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), + url(r'^cache$', MockGETView.as_view()), url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])), url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])), url(r'^html$', HTMLView.as_view()), @@ -416,3 +418,89 @@ class XMLRendererTestCase(TestCase): self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) self.assertTrue(xml.endswith('</root>')) self.assertTrue(string in xml, '%r not in %r' % (string, xml)) + + +# Tests for caching issue, #346 +class CacheRenderTest(TestCase): + """ + Tests specific to caching responses + """ + + urls = 'rest_framework.tests.renderers' + + cache_key = 'just_a_cache_key' + + @classmethod + def _get_pickling_errors(cls, obj, seen=None): + """ Return any errors that would be raised if `obj' is pickled + Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897 + """ + if seen == None: + seen = [] + try: + state = obj.__getstate__() + except AttributeError: + return + if state == None: + return + if isinstance(state, tuple): + if not isinstance(state[0], dict): + state = state[1] + else: + state = state[0].update(state[1]) + result = {} + for i in state: + try: + pickle.dumps(state[i], protocol=2) + except pickle.PicklingError: + if not state[i] in seen: + seen.append(state[i]) + result[i] = cls._get_pickling_errors(state[i], seen) + return result + + def http_resp(self, http_method, url): + """ + Simple wrapper for Client http requests + Removes the `client' and `request' attributes from as they are + added by django.test.client.Client and not part of caching + responses outside of tests. + """ + method = getattr(self.client, http_method) + resp = method(url) + del resp.client, resp.request + return resp + + def test_obj_pickling(self): + """ + Test that responses are properly pickled + """ + resp = self.http_resp('get', '/cache') + + # Make sure that no pickling errors occurred + self.assertEqual(self._get_pickling_errors(resp), {}) + + # Unfortunately LocMem backend doesn't raise PickleErrors but returns + # None instead. + cache.set(self.cache_key, resp) + self.assertTrue(cache.get(self.cache_key) is not None) + + def test_head_caching(self): + """ + Test caching of HEAD requests + """ + resp = self.http_resp('head', '/cache') + cache.set(self.cache_key, resp) + + cached_resp = cache.get(self.cache_key) + self.assertIsInstance(cached_resp, Response) + + def test_get_caching(self): + """ + Test caching of GET requests + """ + resp = self.http_resp('get', '/cache') + cache.set(self.cache_key, resp) + + cached_resp = cache.get(self.cache_key) + self.assertIsInstance(cached_resp, Response) + self.assertEqual(cached_resp.content, resp.content) diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index ff48f3fa..4b032405 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -1,14 +1,15 @@ """ Tests for content parsing, and form-overloaded content parsing. """ -from django.conf.urls.defaults import patterns +import json from django.contrib.auth.models import User +from django.contrib.auth import authenticate, login, logout +from django.contrib.sessions.middleware import SessionMiddleware from django.test import TestCase, Client -from django.utils import simplejson as json - +from django.test.client import RequestFactory from rest_framework import status from rest_framework.authentication import SessionAuthentication -from django.test.client import RequestFactory +from rest_framework.compat import patterns from rest_framework.parsers import ( BaseParser, FormParser, @@ -276,3 +277,37 @@ class TestContentParsingWithAuthentication(TestCase): # response = self.csrf_client.post('/', content) # self.assertEqual(status.OK, response.status_code, "POST data is malformed") + + +class TestUserSetter(TestCase): + + def setUp(self): + # Pass request object through session middleware so session is + # available to login and logout functions + self.request = Request(factory.get('/')) + SessionMiddleware().process_request(self.request) + + User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') + self.user = authenticate(username='ringo', password='yellow') + + def test_user_can_be_set(self): + self.request.user = self.user + self.assertEqual(self.request.user, self.user) + + def test_user_can_login(self): + login(self.request, self.user) + self.assertEqual(self.request.user, self.user) + + def test_user_can_logout(self): + self.request.user = self.user + self.assertFalse(self.request.user.is_anonymous()) + logout(self.request) + self.assertTrue(self.request.user.is_anonymous()) + + +class TestAuthSetter(TestCase): + + def test_auth_can_be_set(self): + request = Request(factory.get('/')) + request.auth = 'DUMMY' + self.assertEqual(request.auth, 'DUMMY') diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py index 18b6af39..875f4d42 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/response.py @@ -1,8 +1,5 @@ -import unittest - -from django.conf.urls.defaults import patterns, url, include from django.test import TestCase - +from rest_framework.compat import patterns, url, include from rest_framework.response import Response from rest_framework.views import APIView from rest_framework import status @@ -131,12 +128,6 @@ class RendererIntegrationTests(TestCase): self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.status_code, DUMMYSTATUS) - @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse') - def test_unsatisfiable_accept_header_on_request_returns_406_status(self): - """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" - resp = self.client.get('/', HTTP_ACCEPT='foo/bar') - self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) - def test_specified_renderer_serializes_content_on_format_query(self): """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py index fd9a7d64..8c86e1fb 100644 --- a/rest_framework/tests/reverse.py +++ b/rest_framework/tests/reverse.py @@ -1,6 +1,6 @@ -from django.conf.urls.defaults import patterns, url from django.test import TestCase from django.test.client import RequestFactory +from rest_framework.compat import patterns, url from rest_framework.reverse import reverse factory = RequestFactory() diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 936f15aa..bd96ba23 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -1,13 +1,16 @@ import datetime +import pickle from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import * +from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, + BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel, + ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) class SubComment(object): def __init__(self, sub_comment): self.sub_comment = sub_comment - + class Comment(object): def __init__(self, email, content, created): @@ -18,7 +21,7 @@ class Comment(object): def __eq__(self, other): return all([getattr(self, attr) == getattr(other, attr) for attr in ('email', 'content', 'created')]) - + def get_sub_comment(self): sub_comment = SubComment('And Merry Christmas!') return sub_comment @@ -29,7 +32,7 @@ class CommentSerializer(serializers.Serializer): content = serializers.CharField(max_length=1000) created = serializers.DateTimeField() sub_comment = serializers.Field(source='get_sub_comment.sub_comment') - + def restore_object(self, data, instance=None): if instance is None: return Comment(**data) @@ -38,10 +41,41 @@ class CommentSerializer(serializers.Serializer): return instance +class BookSerializer(serializers.ModelSerializer): + isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'}) + + class Meta: + model = Book + + class ActionItemSerializer(serializers.ModelSerializer): + class Meta: model = ActionItem + +class PersonSerializer(serializers.ModelSerializer): + info = serializers.Field(source='info') + + class Meta: + model = Person + fields = ('name', 'age', 'info') + read_only_fields = ('age',) + + +class AlbumsSerializer(serializers.ModelSerializer): + + class Meta: + model = Album + fields = ['title'] # lists are also valid options + + +class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): + class Meta: + model = HasPositiveIntegerAsChoice + fields = ['some_integer'] + + class BasicTests(TestCase): def setUp(self): self.comment = Comment( @@ -61,6 +95,9 @@ class BasicTests(TestCase): 'created': datetime.datetime(2012, 1, 1), 'sub_comment': 'And Merry Christmas!' } + self.person_data = {'name': 'dwight', 'age': 35} + self.person = Person(**self.person_data) + self.person.save() def test_empty(self): serializer = CommentSerializer() @@ -73,11 +110,11 @@ class BasicTests(TestCase): self.assertEquals(serializer.data, expected) def test_retrieve(self): - serializer = CommentSerializer(instance=self.comment) + serializer = CommentSerializer(self.comment) self.assertEquals(serializer.data, self.expected) def test_create(self): - serializer = CommentSerializer(self.data) + serializer = CommentSerializer(data=self.data) expected = self.comment self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.object, expected) @@ -85,13 +122,54 @@ class BasicTests(TestCase): self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') def test_update(self): - serializer = CommentSerializer(self.data, instance=self.comment) + serializer = CommentSerializer(self.comment, data=self.data) expected = self.comment self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.object, expected) self.assertTrue(serializer.object is expected) self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') + def test_partial_update(self): + msg = 'Merry New Year!' + partial_data = {'content': msg} + serializer = CommentSerializer(self.comment, data=partial_data) + self.assertEquals(serializer.is_valid(), False) + serializer = CommentSerializer(self.comment, data=partial_data, partial=True) + expected = self.comment + self.assertEqual(serializer.is_valid(), True) + self.assertEquals(serializer.object, expected) + self.assertTrue(serializer.object is expected) + self.assertEquals(serializer.data['content'], msg) + + def test_model_fields_as_expected(self): + """ + Make sure that the fields returned are the same as defined + in the Meta data + """ + serializer = PersonSerializer(self.person) + self.assertEquals(set(serializer.data.keys()), + set(['name', 'age', 'info'])) + + def test_field_with_dictionary(self): + """ + Make sure that dictionaries from fields are left intact + """ + serializer = PersonSerializer(self.person) + expected = self.person_data + self.assertEquals(serializer.data['info'], expected) + + def test_read_only_fields(self): + """ + Attempting to update fields set as read_only should have no effect. + """ + + serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + self.assertEquals(serializer.errors, {}) + # Assert age is unchanged (35) + self.assertEquals(instance.age, self.person_data['age']) + class ValidationTests(TestCase): def setUp(self): @@ -104,17 +182,17 @@ class ValidationTests(TestCase): 'email': 'tom@example.com', 'content': 'x' * 1001, 'created': datetime.datetime(2012, 1, 1) - } - self.actionitem = ActionItem('Some to do item', + } + self.actionitem = ActionItem(title='Some to do item', ) def test_create(self): - serializer = CommentSerializer(self.data) + serializer = CommentSerializer(data=self.data) self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) def test_update(self): - serializer = CommentSerializer(self.data, instance=self.comment) + serializer = CommentSerializer(self.comment, data=self.data) self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) @@ -123,7 +201,7 @@ class ValidationTests(TestCase): 'content': 'xxx', 'created': datetime.datetime(2012, 1, 1) } - serializer = CommentSerializer(data, instance=self.comment) + serializer = CommentSerializer(self.comment, data=data) self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.errors, {'email': [u'This field is required.']}) @@ -131,10 +209,10 @@ class ValidationTests(TestCase): """Make sure that a boolean value with a 'False' value is not mistaken for not having a default.""" data = { - 'title':'Some action item', + 'title': 'Some action item', #No 'done' value. } - serializer = ActionItemSerializer(data, instance=self.actionitem) + serializer = ActionItemSerializer(self.actionitem, data=data) self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.errors, {}) @@ -154,15 +232,34 @@ class ValidationTests(TestCase): 'created': datetime.datetime(2012, 1, 1) } - serializer = CommentSerializerWithFieldValidator(data) + serializer = CommentSerializerWithFieldValidator(data=data) self.assertTrue(serializer.is_valid()) data['content'] = 'This should not validate' - serializer = CommentSerializerWithFieldValidator(data) + serializer = CommentSerializerWithFieldValidator(data=data) self.assertFalse(serializer.is_valid()) self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) + def test_bad_type_data_is_false(self): + """ + Data of the wrong type is not valid. + """ + data = ['i am', 'a', 'list'] + serializer = CommentSerializer(self.comment, data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) + + data = 'and i am a string' + serializer = CommentSerializer(self.comment, data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) + + data = 42 + serializer = CommentSerializer(self.comment, data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) + def test_cross_field_validation(self): class CommentSerializerWithCrossFieldValidator(CommentSerializer): @@ -178,15 +275,109 @@ class ValidationTests(TestCase): 'created': datetime.datetime(2012, 1, 1) } - serializer = CommentSerializerWithCrossFieldValidator(data) + serializer = CommentSerializerWithCrossFieldValidator(data=data) self.assertTrue(serializer.is_valid()) data['content'] = 'A comment from foo@bar.com' - serializer = CommentSerializerWithCrossFieldValidator(data) + serializer = CommentSerializerWithCrossFieldValidator(data=data) self.assertFalse(serializer.is_valid()) self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']}) + def test_null_is_true_fields(self): + """ + Omitting a value for null-field should validate. + """ + serializer = PersonSerializer(data={'name': 'marko'}) + self.assertEquals(serializer.is_valid(), True) + self.assertEquals(serializer.errors, {}) + + def test_modelserializer_max_length_exceeded(self): + data = { + 'title': 'x' * 201, + } + serializer = ActionItemSerializer(data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']}) + + def test_default_modelfield_max_length_exceeded(self): + data = { + 'title': 'Testing "info" field...', + 'info': 'x' * 13, + } + serializer = ActionItemSerializer(data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']}) + + +class PositiveIntegerAsChoiceTests(TestCase): + def test_positive_integer_in_json_is_correctly_parsed(self): + data = {'some_integer':1} + serializer = PositiveIntegerAsChoiceSerializer(data=data) + self.assertEquals(serializer.is_valid(), True) + +class ModelValidationTests(TestCase): + def test_validate_unique(self): + """ + Just check if serializers.ModelSerializer handles unique checks via .full_clean() + """ + serializer = AlbumsSerializer(data={'title': 'a'}) + serializer.is_valid() + serializer.save() + second_serializer = AlbumsSerializer(data={'title': 'a'}) + self.assertFalse(second_serializer.is_valid()) + self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']}) + + def test_foreign_key_with_partial(self): + """ + Test ModelSerializer validation with partial=True + + Specifically test foreign key validation. + """ + + album = Album(title='test') + album.save() + + class PhotoSerializer(serializers.ModelSerializer): + class Meta: + model = Photo + + photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) + self.assertTrue(photo_serializer.is_valid()) + photo = photo_serializer.save() + + # Updating only the album (foreign key) + photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True) + self.assertTrue(photo_serializer.is_valid()) + self.assertTrue(photo_serializer.save()) + + # Updating only the description + photo_serializer = PhotoSerializer(instance=photo, + data={'description': 'new'}, + partial=True) + + self.assertTrue(photo_serializer.is_valid()) + self.assertTrue(photo_serializer.save()) + + +class RegexValidationTest(TestCase): + def test_create_failed(self): + serializer = BookSerializer(data={'isbn': '1234567890'}) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) + + serializer = BookSerializer(data={'isbn': '12345678901234'}) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) + + serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) + + def test_create_success(self): + serializer = BookSerializer(data={'isbn': '1234567890123'}) + self.assertTrue(serializer.is_valid()) + class MetadataTests(TestCase): def test_empty(self): @@ -233,7 +424,7 @@ class ManyToManyTests(TestCase): Create an instance of a model with a ManyToMany relationship. """ data = {'rel': [self.anchor.id]} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 2) @@ -247,7 +438,7 @@ class ManyToManyTests(TestCase): new_anchor = Anchor() new_anchor.save() data = {'rel': [self.anchor.id, new_anchor.id]} - serializer = self.serializer_class(data, instance=self.instance) + serializer = self.serializer_class(self.instance, data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 1) @@ -260,7 +451,7 @@ class ManyToManyTests(TestCase): containing no items. """ data = {'rel': []} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 2) @@ -275,7 +466,7 @@ class ManyToManyTests(TestCase): new_anchor = Anchor() new_anchor.save() data = {'rel': []} - serializer = self.serializer_class(data, instance=self.instance) + serializer = self.serializer_class(self.instance, data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 1) @@ -289,17 +480,19 @@ class ManyToManyTests(TestCase): lists (eg form data). """ data = {'rel': ''} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ManyToManyModel.objects.all()), 2) self.assertEquals(instance.pk, 2) self.assertEquals(list(instance.rel.all()), []) - + + class ReadOnlyManyToManyTests(TestCase): def setUp(self): class ReadOnlyManyToManySerializer(serializers.ModelSerializer): - rel = serializers.ManyRelatedField(readonly=True) + rel = serializers.ManyRelatedField(read_only=True) + class Meta: model = ReadOnlyManyToManyModel @@ -317,16 +510,15 @@ class ReadOnlyManyToManyTests(TestCase): # A serialized representation of the model instance self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'} - def test_update(self): """ Attempt to update an instance of a model with a ManyToMany - relationship. Not updated due to readonly=True + relationship. Not updated due to read_only=True """ new_anchor = Anchor() new_anchor.save() data = {'rel': [self.anchor.id, new_anchor.id]} - serializer = self.serializer_class(data, instance=self.instance) + serializer = self.serializer_class(self.instance, data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) @@ -337,12 +529,12 @@ class ReadOnlyManyToManyTests(TestCase): def test_update_without_relationship(self): """ Attempt to update an instance of a model where many to ManyToMany - relationship is not supplied. Not updated due to readonly=True + relationship is not supplied. Not updated due to read_only=True """ new_anchor = Anchor() new_anchor.save() data = {} - serializer = self.serializer_class(data, instance=self.instance) + serializer = self.serializer_class(self.instance, data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) @@ -362,7 +554,7 @@ class DefaultValueTests(TestCase): def test_create_using_default(self): data = {} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(self.objects.all()), 1) @@ -371,13 +563,28 @@ class DefaultValueTests(TestCase): def test_create_overriding_default(self): data = {'text': 'overridden'} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(self.objects.all()), 1) self.assertEquals(instance.pk, 1) self.assertEquals(instance.text, 'overridden') + def test_partial_update_default(self): + """ Regression test for issue #532 """ + data = {'text': 'overridden'} + serializer = self.serializer_class(data=data, partial=True) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + + data = {'extra': 'extra_value'} + serializer = self.serializer_class(instance=instance, data=data, partial=True) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + + self.assertEquals(instance.extra, 'extra_value') + self.assertEquals(instance.text, 'overridden') + class CallableDefaultValueTests(TestCase): def setUp(self): @@ -390,7 +597,7 @@ class CallableDefaultValueTests(TestCase): def test_create_using_default(self): data = {} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(self.objects.all()), 1) @@ -399,7 +606,7 @@ class CallableDefaultValueTests(TestCase): def test_create_overriding_default(self): data = {'text': 'overridden'} - serializer = self.serializer_class(data) + serializer = self.serializer_class(data=data) self.assertEquals(serializer.is_valid(), True) instance = serializer.save() self.assertEquals(len(self.objects.all()), 1) @@ -408,7 +615,10 @@ class CallableDefaultValueTests(TestCase): class ManyRelatedTests(TestCase): - def setUp(self): + def test_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") class BlogPostCommentSerializer(serializers.Serializer): text = serializers.CharField() @@ -417,14 +627,7 @@ class ManyRelatedTests(TestCase): title = serializers.CharField() comments = BlogPostCommentSerializer(source='blogpostcomment_set') - self.serializer_class = BlogPostSerializer - - def test_reverse_relations(self): - post = BlogPost.objects.create(title="Test blog post") - post.blogpostcomment_set.create(text="I hate this blog post") - post.blogpostcomment_set.create(text="I love this blog post") - - serializer = self.serializer_class(instance=post) + serializer = BlogPostSerializer(instance=post) expected = { 'title': 'Test blog post', 'comments': [ @@ -434,3 +637,267 @@ class ManyRelatedTests(TestCase): } self.assertEqual(serializer.data, expected) + + def test_callable_source(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostCommentSerializer(serializers.Serializer): + text = serializers.CharField() + + class BlogPostSerializer(serializers.Serializer): + title = serializers.CharField() + first_comment = BlogPostCommentSerializer(source='get_first_comment') + + serializer = BlogPostSerializer(post) + + expected = { + 'title': 'Test blog post', + 'first_comment': {'text': 'I love this blog post'} + } + self.assertEqual(serializer.data, expected) + + +class RelatedTraversalTest(TestCase): + def test_nested_traversal(self): + user = Person.objects.create(name="django") + post = BlogPost.objects.create(title="Test blog post", writer=user) + post.blogpostcomment_set.create(text="I love this blog post") + + from rest_framework.tests.models import BlogPostComment + + class PersonSerializer(serializers.ModelSerializer): + class Meta: + model = Person + fields = ("name", "age") + + class BlogPostCommentSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPostComment + fields = ("text", "post_owner") + + text = serializers.CharField() + post_owner = PersonSerializer(source='blog_post.writer') + + class BlogPostSerializer(serializers.Serializer): + title = serializers.CharField() + comments = BlogPostCommentSerializer(source='blogpostcomment_set') + + serializer = BlogPostSerializer(instance=post) + + expected = { + 'title': u'Test blog post', + 'comments': [{ + 'text': u'I love this blog post', + 'post_owner': { + "name": u"django", + "age": None + } + }] + } + + self.assertEqual(serializer.data, expected) + + +class SerializerMethodFieldTests(TestCase): + def setUp(self): + + class BoopSerializer(serializers.Serializer): + beep = serializers.SerializerMethodField('get_beep') + boop = serializers.Field() + boop_count = serializers.SerializerMethodField('get_boop_count') + + def get_beep(self, obj): + return 'hello!' + + def get_boop_count(self, obj): + return len(obj.boop) + + self.serializer_class = BoopSerializer + + def test_serializer_method_field(self): + + class MyModel(object): + boop = ['a', 'b', 'c'] + + source_data = MyModel() + + serializer = self.serializer_class(source_data) + + expected = { + 'beep': u'hello!', + 'boop': [u'a', u'b', u'c'], + 'boop_count': 3, + } + + self.assertEqual(serializer.data, expected) + + +# Test for issue #324 +class BlankFieldTests(TestCase): + def setUp(self): + + class BlankFieldModelSerializer(serializers.ModelSerializer): + class Meta: + model = BlankFieldModel + + class BlankFieldSerializer(serializers.Serializer): + title = serializers.CharField(blank=True) + + class NotBlankFieldModelSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + class NotBlankFieldSerializer(serializers.Serializer): + title = serializers.CharField() + + self.model_serializer_class = BlankFieldModelSerializer + self.serializer_class = BlankFieldSerializer + self.not_blank_model_serializer_class = NotBlankFieldModelSerializer + self.not_blank_serializer_class = NotBlankFieldSerializer + self.data = {'title': ''} + + def test_create_blank_field(self): + serializer = self.serializer_class(data=self.data) + self.assertEquals(serializer.is_valid(), True) + + def test_create_model_blank_field(self): + serializer = self.model_serializer_class(data=self.data) + self.assertEquals(serializer.is_valid(), True) + + def test_create_model_null_field(self): + serializer = self.model_serializer_class(data={'title': None}) + self.assertEquals(serializer.is_valid(), True) + + def test_create_not_blank_field(self): + """ + Test to ensure blank data in a field not marked as blank=True + is considered invalid in a non-model serializer + """ + serializer = self.not_blank_serializer_class(data=self.data) + self.assertEquals(serializer.is_valid(), False) + + def test_create_model_not_blank_field(self): + """ + Test to ensure blank data in a field not marked as blank=True + is considered invalid in a model serializer + """ + serializer = self.not_blank_model_serializer_class(data=self.data) + self.assertEquals(serializer.is_valid(), False) + + def test_create_model_null_field(self): + serializer = self.model_serializer_class(data={}) + self.assertEquals(serializer.is_valid(), True) + + +#test for issue #460 +class SerializerPickleTests(TestCase): + """ + Test pickleability of the output of Serializers + """ + def test_pickle_simple_model_serializer_data(self): + """ + Test simple serializer + """ + pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data) + + def test_pickle_inner_serializer(self): + """ + Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will + have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle. + See DictWithMetadata.__getstate__ + """ + class InnerPersonSerializer(serializers.ModelSerializer): + class Meta: + model = Person + fields = ('name', 'age') + pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data) + + +class DepthTest(TestCase): + def test_implicit_nesting(self): + writer = Person.objects.create(name="django", age=1) + post = BlogPost.objects.create(title="Test blog post", writer=writer) + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + depth = 1 + + serializer = BlogPostSerializer(instance=post) + expected = {'id': 1, 'title': u'Test blog post', + 'writer': {'id': 1, 'name': u'django', 'age': 1}} + + self.assertEqual(serializer.data, expected) + + def test_explicit_nesting(self): + writer = Person.objects.create(name="django", age=1) + post = BlogPost.objects.create(title="Test blog post", writer=writer) + + class PersonSerializer(serializers.ModelSerializer): + class Meta: + model = Person + + class BlogPostSerializer(serializers.ModelSerializer): + writer = PersonSerializer() + + class Meta: + model = BlogPost + + serializer = BlogPostSerializer(instance=post) + expected = {'id': 1, 'title': u'Test blog post', + 'writer': {'id': 1, 'name': u'django', 'age': 1}} + + self.assertEqual(serializer.data, expected) + + +class NestedSerializerContextTests(TestCase): + + def test_nested_serializer_context(self): + """ + Regression for #497 + + https://github.com/tomchristie/django-rest-framework/issues/497 + """ + class PhotoSerializer(serializers.ModelSerializer): + class Meta: + model = Photo + fields = ("description", "callable") + + callable = serializers.SerializerMethodField('_callable') + + def _callable(self, instance): + if not 'context_item' in self.context: + raise RuntimeError("context isn't getting passed into 2nd level nested serializer") + return "success" + + class AlbumSerializer(serializers.ModelSerializer): + class Meta: + model = Album + fields = ("photo_set", "callable") + + photo_set = PhotoSerializer(source="photo_set") + callable = serializers.SerializerMethodField("_callable") + + def _callable(self, instance): + if not 'context_item' in self.context: + raise RuntimeError("context isn't getting passed into 1st level nested serializer") + return "success" + + class AlbumCollection(object): + albums = None + + class AlbumCollectionSerializer(serializers.Serializer): + albums = AlbumSerializer(source="albums") + + album1 = Album.objects.create(title="album 1") + album2 = Album.objects.create(title="album 2") + Photo.objects.create(description="Bigfoot", album=album1) + Photo.objects.create(description="Unicorn", album=album1) + Photo.objects.create(description="Yeti", album=album2) + Photo.objects.create(description="Sasquatch", album=album2) + album_collection = AlbumCollection() + album_collection.albums = [album1, album2] + + # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers + AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py new file mode 100644 index 00000000..0293fdc3 --- /dev/null +++ b/rest_framework/tests/settings.py @@ -0,0 +1,21 @@ +"""Tests for the settings module""" +from django.test import TestCase + +from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS + + +class TestSettings(TestCase): + """Tests relating to the api settings""" + + def test_non_import_errors(self): + """Make sure other errors aren't suppressed.""" + settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) + with self.assertRaises(ValueError): + settings.DEFAULT_MODEL_SERIALIZER_CLASS + + def test_import_error_message_maintained(self): + """Make sure real import errors are captured and raised sensibly.""" + settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) + with self.assertRaises(ImportError) as cm: + settings.DEFAULT_MODEL_SERIALIZER_CLASS + self.assertTrue('ImportError' in str(cm.exception)) diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py index c90224aa..97f492ff 100644 --- a/rest_framework/tests/testcases.py +++ b/rest_framework/tests/testcases.py @@ -6,6 +6,7 @@ from django.test import TestCase NO_SETTING = ('!', None) + class TestSettingsManager(object): """ A class which can modify some Django settings temporarily for a @@ -19,7 +20,7 @@ class TestSettingsManager(object): self._original_settings = {} def set(self, **kwargs): - for k,v in kwargs.iteritems(): + for k, v in kwargs.iteritems(): self._original_settings.setdefault(k, getattr(settings, k, NO_SETTING)) setattr(settings, k, v) @@ -31,7 +32,7 @@ class TestSettingsManager(object): call_command('syncdb', verbosity=0) def revert(self): - for k,v in self._original_settings.iteritems(): + for k, v in self._original_settings.iteritems(): if v == NO_SETTING: delattr(settings, k) else: @@ -57,6 +58,7 @@ class SettingsTestCase(TestCase): def tearDown(self): self.settings_manager.revert() + class TestModelsTestCase(SettingsTestCase): def setUp(self, *args, **kwargs): installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',) diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py new file mode 100644 index 00000000..adeaf6da --- /dev/null +++ b/rest_framework/tests/tests.py @@ -0,0 +1,13 @@ +""" +Force import of all modules in this package in order to get the standard test +runner to pick up the tests. Yowzers. +""" +import os + +modules = [filename.rsplit('.', 1)[0] + for filename in os.listdir(os.path.dirname(__file__)) + if filename.endswith('.py') and not filename.startswith('_')] +__test__ = dict() + +for module in modules: + exec("from rest_framework.tests.%s import *" % module) diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py index 0b94c25b..4b98b941 100644 --- a/rest_framework/tests/throttling.py +++ b/rest_framework/tests/throttling.py @@ -106,7 +106,7 @@ class ThrottlingTests(TestCase): if expect is not None: self.assertEquals(response['X-Throttle-Wait-Seconds'], expect) else: - self.assertFalse('X-Throttle-Wait-Seconds' in response.headers) + self.assertFalse('X-Throttle-Wait-Seconds' in response) def test_seconds_fields(self): """ diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py new file mode 100644 index 00000000..3906adb9 --- /dev/null +++ b/rest_framework/tests/utils.py @@ -0,0 +1,27 @@ +from django.test.client import RequestFactory, FakePayload +from django.test.client import MULTIPART_CONTENT +from urlparse import urlparse + + +class RequestFactory(RequestFactory): + + def __init__(self, **defaults): + super(RequestFactory, self).__init__(**defaults) + + def patch(self, path, data={}, content_type=MULTIPART_CONTENT, + **extra): + "Construct a PATCH request." + + patch_data = self._encode_data(data, content_type) + + parsed = urlparse(path) + r = { + 'CONTENT_LENGTH': len(patch_data), + 'CONTENT_TYPE': content_type, + 'PATH_INFO': self._get_path(parsed), + 'QUERY_STRING': parsed[4], + 'REQUEST_METHOD': 'PATCH', + 'wsgi.input': FakePayload(patch_data), + } + r.update(extra) + return self.request(**r) diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py index b390c42f..c032985e 100644 --- a/rest_framework/tests/validators.py +++ b/rest_framework/tests/validators.py @@ -285,7 +285,7 @@ # uiop = models.CharField(max_length=256, blank=True) # @property -# def readonly(self): +# def read_only(self): # return 'read only' # class MockResource(ModelResource): @@ -298,7 +298,7 @@ # def test_property_fields_are_allowed_on_model_forms(self): # """Validation on ModelForms may include property fields that exist on the Model to be included in the input.""" -# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only'} +# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'} # self.assertEqual(self.validator.validate_request(content, None), content) # def test_property_fields_are_not_required_on_model_forms(self): @@ -310,19 +310,19 @@ # """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only', 'extra': 'extra'} +# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'} # self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # def test_validate_requires_fields_on_model_forms(self): # """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'readonly': 'read only'} +# content = {'read_only': 'read only'} # self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # def test_validate_does_not_require_blankable_fields_on_model_forms(self): # """Test standard ModelForm validation behaviour - fields with blank=True are not required.""" -# content = {'qwerty': 'example', 'readonly': 'read only'} +# content = {'qwerty': 'example', 'read_only': 'read only'} # self.validator.validate_request(content, None) # def test_model_form_validator_uses_model_forms(self): diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py index 43365e07..7cd82656 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/views.py @@ -18,7 +18,7 @@ class BasicView(APIView): return Response({'method': 'POST', 'data': request.DATA}) -@api_view(['GET', 'POST', 'PUT']) +@api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': return {'method': 'GET'} @@ -26,6 +26,8 @@ def basic_view(request): return {'method': 'POST', 'data': request.DATA} elif request.method == 'PUT': return {'method': 'PUT', 'data': request.DATA} + elif request.method == 'PATCH': + return {'method': 'PATCH', 'data': request.DATA} def sanitise_json_error(error_dict): diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 6860e6b9..8fe64248 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -16,7 +16,7 @@ class BaseThrottle(object): def wait(self): """ - Optionally, return a recommeded number of seconds to wait before + Optionally, return a recommended number of seconds to wait before the next request. """ return None diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 316ccd19..143928c9 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,10 +1,10 @@ -from django.conf.urls.defaults import url +from rest_framework.compat import url from rest_framework.settings import api_settings def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): """ - Supplement existing urlpatterns with corrosponding patterns that also + Supplement existing urlpatterns with corresponding patterns that also include a '.format' suffix. Retains urlpattern ordering. urlpatterns: diff --git a/rest_framework/urls.py b/rest_framework/urls.py index 1a81101f..fbe4bc07 100644 --- a/rest_framework/urls.py +++ b/rest_framework/urls.py @@ -1,7 +1,7 @@ """ -Login and logout views for the browseable API. +Login and logout views for the browsable API. -Add these to your root URLconf if you're using the browseable API and +Add these to your root URLconf if you're using the browsable API and your API requires authentication. The urls must be namespaced as 'rest_framework', and you should make sure @@ -12,7 +12,7 @@ your authentication settings include `SessionAuthentication`. url(r'^auth', include('rest_framework.urls', namespace='rest_framework')) ) """ -from django.conf.urls.defaults import patterns, url +from rest_framework.compat import patterns, url template_name = {'template_name': 'rest_framework/login.html'} diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py index a59fff45..84fcb5db 100644 --- a/rest_framework/utils/__init__.py +++ b/rest_framework/utils/__init__.py @@ -1,7 +1,6 @@ from django.utils.encoding import smart_unicode from django.utils.xmlutils import SimplerXMLGenerator from rest_framework.compat import StringIO - import re import xml.etree.ElementTree as ET diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 672d32a3..80e39d46 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -6,7 +6,7 @@ def get_breadcrumbs(url): from rest_framework.views import APIView - def breadcrumbs_recursive(url, breadcrumbs_list, prefix): + def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url.""" try: @@ -16,7 +16,11 @@ def get_breadcrumbs(url): else: # Check if this is a REST framework view, and if so add it to the breadcrumbs if isinstance(getattr(view, 'cls_instance', None), APIView): - breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) + # Don't list the same view twice in a row. + # Probably an optional trailing slash. + if not seen or seen[-1] != view: + breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) + seen.append(view) if url == '': # All done @@ -24,11 +28,11 @@ def get_breadcrumbs(url): elif url.endswith('/'): # Drop trailing slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix) + return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen) # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix) + return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen) prefix = get_script_prefix().rstrip('/') url = url[len(prefix):] - return breadcrumbs_recursive(url, [], prefix) + return breadcrumbs_recursive(url, [], prefix, []) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 2d1fb353..7afe100a 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -4,7 +4,7 @@ Helper classes for parsers. import datetime import decimal import types -from django.utils import simplejson as json +import json from django.utils.datastructures import SortedDict from rest_framework.compat import timezone from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata @@ -12,7 +12,7 @@ from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata class JSONEncoder(json.JSONEncoder): """ - JSONEncoder subclass that knows how to encode date/time, + JSONEncoder subclass that knows how to encode date/time/timedelta, decimal types, and generators. """ def default(self, o): @@ -34,6 +34,8 @@ class JSONEncoder(json.JSONEncoder): if o.microsecond: r = r[:12] return r + elif isinstance(o, datetime.timedelta): + return str(o.total_seconds()) elif isinstance(o, decimal.Decimal): return str(o) elif hasattr(o, '__iter__'): diff --git a/rest_framework/views.py b/rest_framework/views.py index c721be3c..10bdd5a5 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -140,7 +140,7 @@ class APIView(View): def http_method_not_allowed(self, request, *args, **kwargs): """ - Called if `request.method` does not corrospond to a handler method. + Called if `request.method` does not correspond to a handler method. """ raise exceptions.MethodNotAllowed(request.method) @@ -218,7 +218,7 @@ class APIView(View): def get_throttles(self): """ - Instantiates and returns the list of thottles that this view uses. + Instantiates and returns the list of throttles that this view uses. """ return [throttle() for throttle in self.throttle_classes] @@ -320,13 +320,17 @@ class APIView(View): self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait if isinstance(exc, exceptions.APIException): - return Response({'detail': exc.detail}, status=exc.status_code) + return Response({'detail': exc.detail}, + status=exc.status_code, + exception=True) elif isinstance(exc, Http404): return Response({'detail': 'Not found'}, - status=status.HTTP_404_NOT_FOUND) + status=status.HTTP_404_NOT_FOUND, + exception=True) elif isinstance(exc, PermissionDenied): return Response({'detail': 'Permission denied'}, - status=status.HTTP_403_FORBIDDEN) + status=status.HTTP_403_FORBIDDEN, + exception=True) raise # Note: session based authentication is explicitly CSRF validated, |
