diff options
| author | Ben Konrath | 2012-11-01 14:06:56 +0100 |
|---|---|---|
| committer | Ben Konrath | 2012-11-01 14:06:56 +0100 |
| commit | 9c82f9717e58f1bb250d5fd4b27619dbcbbd1f21 (patch) | |
| tree | e976854e6871a8b826e91d8eb16d9a139b90664f /rest_framework | |
| parent | c24997df3b943e5d7a3b2e101508e4b79ee82dc4 (diff) | |
| parent | 204db7bdaa59cd17f762d6cf0e6a8623c2cc9939 (diff) | |
| download | django-rest-framework-9c82f9717e58f1bb250d5fd4b27619dbcbbd1f21.tar.bz2 | |
Merge branch 'master' into restframework2-filter
Diffstat (limited to 'rest_framework')
35 files changed, 1045 insertions, 543 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index ee5bd2f2..30c78ebc 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -1,10 +1,10 @@ """ -The :mod:`authentication` module provides a set of pluggable authentication classes. - -Authentication behavior is provided by mixing the :class:`mixins.RequestMixin` class into a :class:`View` class. +Provides a set of pluggable authentication policies. """ from django.contrib.auth import authenticate +from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError +from rest_framework import exceptions from rest_framework.compat import CsrfViewMiddleware from rest_framework.authtoken.models import Token import base64 @@ -17,25 +17,14 @@ class BaseAuthentication(object): def authenticate(self, request): """ - Authenticate the :obj:`request` and return a :obj:`User` or :const:`None`. [*]_ - - .. [*] The authentication context *will* typically be a :obj:`User`, - but it need not be. It can be any user-like object so long as the - permissions classes (see the :mod:`permissions` module) on the view can - handle the object and use it to determine if the request has the required - permissions or not. - - This can be an important distinction if you're implementing some token - based authentication mechanism, where the authentication context - may be more involved than simply mapping to a :obj:`User`. + Authenticate the request and return a two-tuple of (user, token). """ - return None + raise NotImplementedError(".authenticate() must be overridden.") class BasicAuthentication(BaseAuthentication): """ - Base class for HTTP Basic authentication. - Subclasses should implement `.authenticate_credentials()`. + HTTP Basic authentication against username/password. """ def authenticate(self, request): @@ -43,8 +32,6 @@ class BasicAuthentication(BaseAuthentication): Returns a `User` if a correct username and password have been supplied using HTTP Basic authentication. Otherwise returns `None`. """ - from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError - if 'HTTP_AUTHORIZATION' in request.META: auth = request.META['HTTP_AUTHORIZATION'].split() if len(auth) == 2 and auth[0].lower() == "basic": @@ -54,7 +41,8 @@ class BasicAuthentication(BaseAuthentication): return None try: - userid, password = smart_unicode(auth_parts[0]), smart_unicode(auth_parts[2]) + userid = smart_unicode(auth_parts[0]) + password = smart_unicode(auth_parts[2]) except DjangoUnicodeDecodeError: return None @@ -62,15 +50,6 @@ class BasicAuthentication(BaseAuthentication): def authenticate_credentials(self, userid, password): """ - Given the Basic authentication userid and password, authenticate - and return a user instance. - """ - raise NotImplementedError('.authenticate_credentials() must be overridden') - - -class UserBasicAuthentication(BasicAuthentication): - def authenticate_credentials(self, userid, password): - """ Authenticate the userid and password against username and password. """ user = authenticate(username=userid, password=password) @@ -85,20 +64,31 @@ class SessionAuthentication(BaseAuthentication): def authenticate(self, request): """ - Returns a :obj:`User` if the request session currently has a logged in user. - Otherwise returns :const:`None`. + Returns a `User` if the request session currently has a logged in user. + Otherwise returns `None`. """ # Get the underlying HttpRequest object http_request = request._request user = getattr(http_request, 'user', None) - if user and user.is_active: - # Enforce CSRF validation for session based authentication. - resp = CsrfViewMiddleware().process_view(http_request, None, (), {}) + # Unauthenticated, CSRF validation not required + if not user or not user.is_active: + return + + # Enforce CSRF validation for session based authentication. + class CSRFCheck(CsrfViewMiddleware): + def _reject(self, request, reason): + # Return the failure reason instead of an HttpResponse + return reason + + reason = CSRFCheck().process_view(http_request, None, (), {}) + if reason: + # CSRF failed, bail with explicit error message + raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) - if resp is None: # csrf passed - return (user, None) + # CSRF passed with authenticated user + return (user, None) class TokenAuthentication(BaseAuthentication): diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7664c400..b0367a32 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -1,6 +1,8 @@ """ -The :mod:`compat` module provides support for backwards compatibility with older versions of django/python. +The `compat` module provides support for backwards compatibility with older +versions of django/python, and compatbility wrappers around optional packages. """ +# flake8: noqa import django # cStringIO only if it's available, otherwise StringIO diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 948973ae..a231f191 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -10,8 +10,18 @@ def api_view(http_method_names): def decorator(func): - class WrappedAPIView(APIView): - pass + WrappedAPIView = type( + 'WrappedAPIView', + (APIView,), + {'__doc__': func.__doc__} + ) + + # Note, the above allows us to set the docstring. + # It is the equivelent of: + # + # class WrappedAPIView(APIView): + # pass + # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this allowed_methods = set(http_method_names) | set(('options',)) WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 572425b9..89479deb 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -31,14 +31,6 @@ class PermissionDenied(APIException): self.detail = detail or self.default_detail -class InvalidFormat(APIException): - status_code = status.HTTP_404_NOT_FOUND - default_detail = "Format suffix '.%s' not found." - - def __init__(self, format, detail=None): - self.detail = (detail or self.default_detail) % format - - class MethodNotAllowed(APIException): status_code = status.HTTP_405_METHOD_NOT_ALLOWED default_detail = "Method '%s' not allowed." diff --git a/rest_framework/fields.py b/rest_framework/fields.py index bb9a523d..73c8f72b 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -5,13 +5,15 @@ import warnings from django.core import validators from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.core.urlresolvers import resolve +from django.core.urlresolvers import resolve, get_script_prefix from django.conf import settings +from django.forms import widgets from django.utils.encoding import is_protected_type, smart_unicode from django.utils.translation import ugettext_lazy as _ from rest_framework.reverse import reverse from rest_framework.compat import parse_date, parse_datetime from rest_framework.compat import timezone +from urlparse import urlparse def is_simple_callable(obj): @@ -42,7 +44,7 @@ class Field(object): Called to set up a field prior to field_to_native or field_from_native. parent - The parent serializer. - model_field - The model field this field corrosponds to, if one exists. + model_field - The model field this field corresponds to, if one exists. """ self.parent = parent self.root = parent.root or parent @@ -70,6 +72,8 @@ class Field(object): value = obj for component in self.source.split('.'): value = getattr(value, component) + if is_simple_callable(value): + value = value() else: value = getattr(obj, field_name) return self.to_native(value) @@ -105,15 +109,20 @@ class WritableField(Field): 'required': _('This field is required.'), 'invalid': _('Invalid value.'), } + widget = widgets.TextInput + default = None + + def __init__(self, source=None, read_only=False, required=None, + validators=[], error_messages=None, widget=None, + default=None, blank=None): - def __init__(self, source=None, readonly=False, required=None, - validators=[], error_messages=None): super(WritableField, self).__init__(source=source) - self.readonly = readonly + + self.read_only = read_only if required is None: - self.required = not(readonly) + self.required = not(read_only) else: - assert not readonly, "Cannot set required=True and readonly=True" + assert not read_only, "Cannot set required=True and read_only=True" self.required = required messages = {} @@ -123,6 +132,14 @@ class WritableField(Field): self.error_messages = messages self.validators = self.default_validators + validators + self.default = default or self.default + self.blank = blank + + # Widgets are ony used for HTML forms. + widget = widget or self.widget + if isinstance(widget, type): + widget = widget() + self.widget = widget def validate(self, value): if value in validators.EMPTY_VALUES and self.required: @@ -151,15 +168,18 @@ class WritableField(Field): Given a dictionary and a field name, updates the dictionary `into`, with the field and it's deserialized value. """ - if self.readonly: + if self.read_only: return try: native = data[field_name] except KeyError: - if self.required: - raise ValidationError(self.error_messages['required']) - return + if self.default is not None: + native = self.default + else: + if self.required: + raise ValidationError(self.error_messages['required']) + return value = self.from_native(native) if self.source == '*': @@ -179,7 +199,7 @@ class WritableField(Field): class ModelField(WritableField): """ - A generic field that can be used against an arbirtrary model field. + A generic field that can be used against an arbitrary model field. """ def __init__(self, *args, **kwargs): try: @@ -191,9 +211,9 @@ class ModelField(WritableField): def from_native(self, value): try: rel = self.model_field.rel + return rel.to._meta.get_field(rel.field_name).to_python(value) except: return self.model_field.to_python(value) - return rel.to._meta.get_field(rel.field_name).to_python(value) def field_to_native(self, obj, field_name): value = self.model_field._get_val_from_obj(obj) @@ -222,8 +242,11 @@ class RelatedField(WritableField): return self.to_native(value) def field_from_native(self, data, field_name, into): + if self.read_only: + return + value = data.get(field_name) - into[(self.source or field_name) + '_id'] = self.from_native(value) + into[(self.source or field_name)] = self.from_native(value) class ManyRelatedMixin(object): @@ -235,6 +258,9 @@ class ManyRelatedMixin(object): return [self.to_native(item) for item in value.all()] def field_from_native(self, data, field_name, into): + if self.read_only: + return + try: # Form data value = data.getlist(self.source or field_name) @@ -264,6 +290,15 @@ class PrimaryKeyRelatedField(RelatedField): def to_native(self, pk): return pk + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + return self.queryset.get(pk=data) + except ObjectDoesNotExist: + raise ValidationError('Invalid hyperlink - object does not exist.') + def field_to_native(self, obj, field_name): try: # Prefer obj.serializable_value for performance reasons @@ -307,14 +342,16 @@ class HyperlinkedRelatedField(RelatedField): self.view_name = kwargs.pop('view_name') except: raise ValueError("Hyperlinked field requires 'view_name' kwarg") + self.format = kwargs.pop('format', None) super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) def to_native(self, obj): view_name = self.view_name request = self.context.get('request', None) + format = self.format or self.context.get('format', None) kwargs = {self.pk_url_kwarg: obj.pk} try: - return reverse(view_name, kwargs=kwargs, request=request) + return reverse(view_name, kwargs=kwargs, request=request, format=format) except: pass @@ -325,13 +362,13 @@ class HyperlinkedRelatedField(RelatedField): kwargs = {self.slug_url_kwarg: slug} try: - return reverse(self.view_name, kwargs=kwargs, request=request) + return reverse(self.view_name, kwargs=kwargs, request=request, format=format) except: pass kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} try: - return reverse(self.view_name, kwargs=kwargs, request=request) + return reverse(self.view_name, kwargs=kwargs, request=request, format=format) except: pass @@ -340,6 +377,16 @@ class HyperlinkedRelatedField(RelatedField): def from_native(self, value): # Convert URL -> model instance pk # TODO: Use values_list + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + if value.startswith('http:') or value.startswith('https:'): + # If needed convert absolute URLs to relative path + value = urlparse(value).path + prefix = get_script_prefix() + if value.startswith(prefix): + value = '/' + value[len(prefix):] + try: match = resolve(value) except: @@ -353,7 +400,7 @@ class HyperlinkedRelatedField(RelatedField): # Try explicit primary key. if pk is not None: - return pk + queryset = self.queryset.filter(pk=pk) # Next, try looking up by slug. elif slug is not None: slug_field = self.get_slug_field() @@ -366,7 +413,7 @@ class HyperlinkedRelatedField(RelatedField): obj = queryset.get() except ObjectDoesNotExist: raise ValidationError('Invalid hyperlink - object does not exist.') - return obj.pk + return obj class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField): @@ -381,33 +428,38 @@ class HyperlinkedIdentityField(Field): # TODO: Make this mandatory, and have the HyperlinkedModelSerializer # set it on-the-fly self.view_name = kwargs.pop('view_name', None) + self.format = kwargs.pop('format', None) super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) def field_to_native(self, obj, field_name): request = self.context.get('request', None) + format = self.format or self.context.get('format', None) view_name = self.view_name or self.parent.opts.view_name view_kwargs = {'pk': obj.pk} - return reverse(view_name, kwargs=view_kwargs, request=request) + return reverse(view_name, kwargs=view_kwargs, request=request, format=format) ##### Typed Fields ##### class BooleanField(WritableField): type_name = 'BooleanField' + widget = widgets.CheckboxInput default_error_messages = { 'invalid': _(u"'%s' value must be either True or False."), } + empty = False + + # Note: we set default to `False` in order to fill in missing value not + # supplied by html form. TODO: Fix so that only html form input gets + # this behavior. + default = False def from_native(self, value): - if value in (True, False): - # if value is 1 or 0 than it's equal to True or False, but we want - # to return a true bool for semantic reasons. - return bool(value) if value in ('t', 'True', '1'): return True if value in ('f', 'False', '0'): return False - raise ValidationError(self.error_messages['invalid'] % value) + return bool(value) class CharField(WritableField): @@ -421,12 +473,68 @@ class CharField(WritableField): if max_length is not None: self.validators.append(validators.MaxLengthValidator(max_length)) + def validate(self, value): + """ + Validates that the value is supplied (if required). + """ + # if empty string and allow blank + if self.blank and not value: + return + else: + super(CharField, self).validate(value) + def from_native(self, value): if isinstance(value, basestring) or value is None: return value return smart_unicode(value) +class ChoiceField(WritableField): + type_name = 'ChoiceField' + widget = widgets.Select + default_error_messages = { + 'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'), + } + + def __init__(self, choices=(), *args, **kwargs): + super(ChoiceField, self).__init__(*args, **kwargs) + self.choices = choices + + def _get_choices(self): + return self._choices + + def _set_choices(self, value): + # Setting choices also sets the choices on the widget. + # choices can be any iterable, but we call list() on it because + # it will be consumed more than once. + self._choices = self.widget.choices = list(value) + + choices = property(_get_choices, _set_choices) + + def validate(self, value): + """ + Validates that the input is in self.choices. + """ + super(ChoiceField, self).validate(value) + if value and not self.valid_value(value): + raise ValidationError(self.error_messages['invalid_choice'] % {'value': value}) + + def valid_value(self, value): + """ + Check to see if the provided value is a valid choice. + """ + for k, v in self.choices: + if isinstance(v, (list, tuple)): + # This is an optgroup, so look inside the group for options + for k2, v2 in v: + if value == smart_unicode(k2): + return True + else: + if value == smart_unicode(k): + return True + return False + + class EmailField(CharField): type_name = 'EmailField' @@ -436,7 +544,10 @@ class EmailField(CharField): default_validators = [validators.validate_email] def from_native(self, value): - return super(EmailField, self).from_native(value).strip() + ret = super(EmailField, self).from_native(value) + if ret is None: + return None + return ret.strip() def __deepcopy__(self, memo): result = copy.copy(self) @@ -458,8 +569,9 @@ class DateField(WritableField): empty = None def from_native(self, value): - if value is None: - return value + if value in validators.EMPTY_VALUES: + return None + if isinstance(value, datetime.datetime): if timezone and settings.USE_TZ and timezone.is_aware(value): # Convert aware datetimes to the default time zone @@ -497,8 +609,9 @@ class DateTimeField(WritableField): empty = None def from_native(self, value): - if value is None: - return value + if value in validators.EMPTY_VALUES: + return None + if isinstance(value, datetime.datetime): return value if isinstance(value, datetime.date): @@ -556,6 +669,7 @@ class IntegerField(WritableField): def from_native(self, value): if value in validators.EMPTY_VALUES: return None + try: value = int(str(value)) except (ValueError, TypeError): @@ -571,8 +685,9 @@ class FloatField(WritableField): } def from_native(self, value): - if value is None: - return value + if value in validators.EMPTY_VALUES: + return None + try: return float(value) except (TypeError, ValueError): diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 3b2bea3b..063382bb 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -10,12 +10,12 @@ import django_filters ### Base classes for the generic views ### -class BaseView(views.APIView): +class GenericAPIView(views.APIView): """ Base class for all other generic views. """ serializer_class = None - model_serializer_class = api_settings.MODEL_SERIALIZER + model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS def get_serializer_context(self): """ @@ -51,12 +51,12 @@ class BaseView(views.APIView): return serializer_class(data, instance=instance, context=context) -class MultipleObjectBaseView(MultipleObjectMixin, BaseView): +class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): """ Base class for generic views onto a queryset. """ - pagination_serializer_class = api_settings.PAGINATION_SERIALIZER + pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS paginate_by = api_settings.PAGINATE_BY filter_class = None filter_fields = None @@ -106,7 +106,7 @@ class MultipleObjectBaseView(MultipleObjectMixin, BaseView): return pagination_serializer_class(instance=page, context=context) -class SingleObjectBaseView(SingleObjectMixin, BaseView): +class SingleObjectAPIView(SingleObjectMixin, GenericAPIView): """ Base class for generic views onto a model instance. """ @@ -117,7 +117,7 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView): """ Override default to add support for object-level permissions. """ - obj = super(SingleObjectBaseView, self).get_object() + obj = super(SingleObjectAPIView, self).get_object() if not self.has_permission(self.request, obj): self.permission_denied(self.request) return obj @@ -126,8 +126,19 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView): ### Concrete view classes that provide method handlers ### ### by composing the mixin classes with a base view. ### + +class CreateAPIView(mixins.CreateModelMixin, + GenericAPIView): + + """ + Concrete view for creating a model instance. + """ + def post(self, request, *args, **kwargs): + return self.create(request, *args, **kwargs) + + class ListAPIView(mixins.ListModelMixin, - MultipleObjectBaseView): + MultipleObjectAPIView): """ Concrete view for listing a queryset. """ @@ -135,9 +146,38 @@ class ListAPIView(mixins.ListModelMixin, return self.list(request, *args, **kwargs) +class RetrieveAPIView(mixins.RetrieveModelMixin, + SingleObjectAPIView): + """ + Concrete view for retrieving a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + +class DestroyAPIView(mixins.DestroyModelMixin, + SingleObjectAPIView): + + """ + Concrete view for deleting a model instance. + """ + def delete(self, request, *args, **kwargs): + return self.destroy(request, *args, **kwargs) + + +class UpdateAPIView(mixins.UpdateModelMixin, + SingleObjectAPIView): + + """ + Concrete view for updating a model instance. + """ + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, - MultipleObjectBaseView): + MultipleObjectAPIView): """ Concrete view for listing a queryset or creating a model instance. """ @@ -148,18 +188,9 @@ class ListCreateAPIView(mixins.ListModelMixin, return self.create(request, *args, **kwargs) -class RetrieveAPIView(mixins.RetrieveModelMixin, - SingleObjectBaseView): - """ - Concrete view for retrieving a model instance. - """ - def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) - - class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, - SingleObjectBaseView): + SingleObjectAPIView): """ Concrete view for retrieving or deleting a model instance. """ @@ -173,7 +204,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, - SingleObjectBaseView): + SingleObjectAPIView): """ Concrete view for retrieving, updating or deleting a model instance. """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 04626fb0..b0cc043a 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -3,9 +3,6 @@ Basic building blocks for generic class based views. We don't bind behaviour to http method handlers yet, which allows mixin classes to be composed in interesting ways. - -Eg. Use mixins to build a Resource class, and have a Router class - perform the binding of http methods to actions for us. """ from django.http import Http404 from rest_framework import status @@ -20,10 +17,14 @@ class CreateModelMixin(object): def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA) if serializer.is_valid(): + self.pre_save(serializer.object) self.object = serializer.save() return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def pre_save(self, obj): + pass + class ListModelMixin(object): """ @@ -46,7 +47,8 @@ class ListModelMixin(object): # which may be `None` to disable pagination. page_size = self.get_paginate_by(self.object_list) if page_size: - paginator, page, queryset, is_paginated = self.paginate_queryset(self.object_list, page_size) + packed = self.paginate_queryset(self.object_list, page_size) + paginator, page, queryset, is_paginated = packed serializer = self.get_pagination_serializer(page) else: serializer = self.get_serializer(instance=self.object_list) @@ -73,26 +75,25 @@ class UpdateModelMixin(object): def update(self, request, *args, **kwargs): try: self.object = self.get_object() + success_status = status.HTTP_200_OK except Http404: self.object = None + success_status = status.HTTP_201_CREATED serializer = self.get_serializer(data=request.DATA, instance=self.object) if serializer.is_valid(): - if self.object is None: - # If PUT occurs to a non existant object, we need to set any - # attributes on the object that are implicit in the URL. - self.update_urlconf_attributes(serializer.object) + self.pre_save(serializer.object) self.object = serializer.save() - return Response(serializer.data) + return Response(serializer.data, status=success_status) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def update_urlconf_attributes(self, obj): + def pre_save(self, obj): """ - When update (re)creates an object, we need to set any attributes that - are tied to the URLconf. + Set any attributes on the object that are implicit in the request. """ + # pk and/or slug attributes are implicit in the URL. pk = self.kwargs.get(self.pk_url_kwarg, None) if pk: setattr(obj, 'pk', pk) diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 8b22f669..dae38477 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -1,48 +1,38 @@ +from django.http import Http404 from rest_framework import exceptions from rest_framework.settings import api_settings from rest_framework.utils.mediatypes import order_by_precedence, media_type_matches class BaseContentNegotiation(object): - def negotiate(self, request, renderers, format=None, force=False): - raise NotImplementedError('.negotiate() must be implemented') + def select_parser(self, request, parsers): + raise NotImplementedError('.select_parser() must be implemented') + def select_renderer(self, request, renderers, format_suffix=None): + raise NotImplementedError('.select_renderer() must be implemented') -class DefaultContentNegotiation(object): + +class DefaultContentNegotiation(BaseContentNegotiation): settings = api_settings - def select_parser(self, parsers, media_type): + def select_parser(self, request, parsers): """ Given a list of parsers and a media type, return the appropriate parser to handle the incoming request. """ for parser in parsers: - if media_type_matches(parser.media_type, media_type): + if media_type_matches(parser.media_type, request.content_type): return parser return None - def negotiate(self, request, renderers, format=None, force=False): + def select_renderer(self, request, renderers, format_suffix=None): """ Given a request and a list of renderers, return a two-tuple of: (renderer, media type). - - If force is set, then suppress exceptions, and forcibly return a - fallback renderer and media_type. - """ - try: - return self.unforced_negotiate(request, renderers, format) - except (exceptions.InvalidFormat, exceptions.NotAcceptable): - if force: - return (renderers[0], renderers[0].media_type) - raise - - def unforced_negotiate(self, request, renderers, format=None): - """ - As `.negotiate()`, but does not take the optional `force` agument, - or suppress exceptions. """ # Allow URL style format override. eg. "?format=json - format = format or request.GET.get(self.settings.URL_FORMAT_OVERRIDE) + format_query_param = self.settings.URL_FORMAT_OVERRIDE + format = format_suffix or request.GET.get(format_query_param) if format: renderers = self.filter_renderers(renderers, format) @@ -77,7 +67,7 @@ class DefaultContentNegotiation(object): renderers = [renderer for renderer in renderers if renderer.format == format] if not renderers: - raise exceptions.InvalidFormat(format) + raise Http404 return renderers def get_accept_list(self, request): diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 5325a64b..4841676c 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -1,14 +1,8 @@ """ -Django supports parsing the content of an HTTP request, but only for form POST requests. -That behavior is sufficient for dealing with standard HTML forms, but it doesn't map well -to general HTTP requests. +Parsers are used to parse the content of incoming HTTP requests. -We need a method to be able to: - -1.) Determine the parsed content on a request for methods other than POST (eg typically also PUT) - -2.) Determine the parsed content on a request for media types other than application/x-www-form-urlencoded - and multipart/form-data. (eg also handle multipart/json) +They give us a generic way of being able to handle various media types +on the request, such as form content or json encoded data. """ from django.http import QueryDict @@ -21,7 +15,6 @@ from xml.etree import ElementTree as ET from xml.parsers.expat import ExpatError import datetime import decimal -from io import BytesIO class DataAndFiles(object): @@ -33,29 +26,18 @@ class DataAndFiles(object): class BaseParser(object): """ All parsers should extend `BaseParser`, specifying a `media_type` - attribute, and overriding the `.parse_stream()` method. + attribute, and overriding the `.parse()` method. """ media_type = None - def parse(self, string_or_stream, **opts): - """ - The main entry point to parsers. This is a light wrapper around - `parse_stream`, that instead handles both string and stream objects. + def parse(self, stream, media_type=None, parser_context=None): """ - if isinstance(string_or_stream, basestring): - stream = BytesIO(string_or_stream) - else: - stream = string_or_stream - return self.parse_stream(stream, **opts) - - def parse_stream(self, stream, **opts): - """ - Given a stream to read from, return the deserialized output. - Should return parsed data, or a DataAndFiles object consisting of the + Given a stream to read from, return the parsed representation. + Should return parsed data, or a `DataAndFiles` object consisting of the parsed data and files. """ - raise NotImplementedError(".parse_stream() must be overridden.") + raise NotImplementedError(".parse() must be overridden.") class JSONParser(BaseParser): @@ -65,7 +47,7 @@ class JSONParser(BaseParser): media_type = 'application/json' - def parse_stream(self, stream, **opts): + def parse(self, stream, media_type=None, parser_context=None): """ Returns a 2-tuple of `(data, files)`. @@ -85,7 +67,7 @@ class YAMLParser(BaseParser): media_type = 'application/yaml' - def parse_stream(self, stream, **opts): + def parse(self, stream, media_type=None, parser_context=None): """ Returns a 2-tuple of `(data, files)`. @@ -98,23 +80,6 @@ class YAMLParser(BaseParser): raise ParseError('YAML parse error - %s' % unicode(exc)) -class PlainTextParser(BaseParser): - """ - Plain text parser. - """ - - media_type = 'text/plain' - - def parse_stream(self, stream, **opts): - """ - Returns a 2-tuple of `(data, files)`. - - `data` will simply be a string representing the body of the request. - `files` will always be `None`. - """ - return stream.read() - - class FormParser(BaseParser): """ Parser for form data. @@ -122,7 +87,7 @@ class FormParser(BaseParser): media_type = 'application/x-www-form-urlencoded' - def parse_stream(self, stream, **opts): + def parse(self, stream, media_type=None, parser_context=None): """ Returns a 2-tuple of `(data, files)`. @@ -140,15 +105,18 @@ class MultiPartParser(BaseParser): media_type = 'multipart/form-data' - def parse_stream(self, stream, **opts): + def parse(self, stream, media_type=None, parser_context=None): """ Returns a DataAndFiles object. `.data` will be a `QueryDict` containing all the form parameters. `.files` will be a `QueryDict` containing all the form files. """ - meta = opts['meta'] - upload_handlers = opts['upload_handlers'] + parser_context = parser_context or {} + request = parser_context['request'] + meta = request.META + upload_handlers = request.upload_handlers + try: parser = DjangoMultiPartParser(meta, stream, upload_handlers) data, files = parser.parse() @@ -164,7 +132,7 @@ class XMLParser(BaseParser): media_type = 'application/xml' - def parse_stream(self, stream, **opts): + def parse(self, stream, media_type=None, parser_context=None): try: tree = ET.parse(stream) except (ExpatError, ETParseError, ValueError), exc: diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 13ea39ea..655b78a3 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -1,8 +1,5 @@ """ -The :mod:`permissions` module bundles a set of permission classes that are used -for checking if a request passes a certain set of constraints. - -Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class. +Provides a set of pluggable permission policies. """ @@ -16,11 +13,22 @@ class BasePermission(object): def has_permission(self, request, view, obj=None): """ - Should simply return, or raise an :exc:`response.ImmediateResponse`. + Return `True` if permission is granted, `False` otherwise. """ raise NotImplementedError(".has_permission() must be overridden.") +class AllowAny(BasePermission): + """ + Allow any access. + This isn't strictly required, since you could use an empty + permission_classes list, but it's useful because it makes the intention + more explicit. + """ + def has_permission(self, request, view, obj=None): + return True + + class IsAuthenticated(BasePermission): """ Allows access only to authenticated users. @@ -64,7 +72,8 @@ class DjangoModelPermissions(BasePermission): It ensures that the user is authenticated, and has the appropriate `add`/`change`/`delete` permissions on the model. - This permission should only be used on views with a `ModelResource`. + This permission will only be applied against view classes that + provide a `.model` attribute, such as the generic class-based views. """ # Map methods into required permission codes. @@ -87,12 +96,15 @@ class DjangoModelPermissions(BasePermission): """ kwargs = { 'app_label': model_cls._meta.app_label, - 'model_name': model_cls._meta.module_name + 'model_name': model_cls._meta.module_name } return [perm % kwargs for perm in self.perms_map[method]] def has_permission(self, request, view, obj=None): - model_cls = view.model + model_cls = getattr(view, 'model', None) + if not model_cls: + return True + perms = self.get_required_permissions(request.method, model_cls) if (request.user and diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index e5e4134b..8dff0c77 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -1,12 +1,15 @@ """ -Renderers are used to serialize a View's output into specific media types. +Renderers are used to serialize a response into specific media types. -Django REST framework also provides HTML and PlainText renderers that help self-document the API, -by serializing the output along with documentation regarding the View, output status and headers, -and providing forms and links depending on the allowed methods, renderers and parsers on the View. +They give us a generic way of being able to handle various media types +on the response, such as JSON encoded data or HTML output. + +REST framework also provides an HTML renderer the renders the browseable API. """ +import copy import string from django import forms +from django.http.multipartparser import parse_header from django.template import RequestContext, loader from django.utils import simplejson as json from rest_framework.compat import yaml @@ -16,15 +19,14 @@ from rest_framework.request import clone_request from rest_framework.utils import dict2xml from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework.utils.mediatypes import get_media_type_params from rest_framework import VERSION from rest_framework import serializers, parsers class BaseRenderer(object): """ - All renderers must extend this class, set the :attr:`media_type` attribute, - and override the :meth:`render` method. + All renderers should extend this class, setting the `media_type` + and `format` attributes, and override the `.render()` method. """ media_type = None @@ -58,7 +60,7 @@ class JSONRenderer(BaseRenderer): if accepted_media_type: # If the media type looks like 'application/json; indent=4', # then pretty print the result. - params = get_media_type_params(accepted_media_type) + base_media_type, params = parse_header(accepted_media_type) indent = params.get('indent', indent) try: indent = max(min(int(indent), 8), 0) @@ -137,13 +139,24 @@ class YAMLRenderer(BaseRenderer): return yaml.dump(data, stream=None, Dumper=self.encoder) -class HTMLRenderer(BaseRenderer): +class TemplateHTMLRenderer(BaseRenderer): """ - A Base class provided for convenience. + An HTML renderer for use with templates. + + The data supplied to the Response object should be a dictionary that will + be used as context for the template. + + The template name is determined by (in order of preference): + + 1. An explicit `.template_name` attribute set on the response. + 2. An explicit `.template_name` attribute set on this class. + 3. The return result of calling `view.get_template_names()`. - Render the object simply by using the given template. - To create a template renderer, subclass this class, and set - the :attr:`media_type` and :attr:`template` attributes. + For example: + data = {'users': User.objects.all()} + return Response(data, template_name='users.html') + + For pre-rendered HTML, see StaticHTMLRenderer. """ media_type = 'text/html' @@ -186,6 +199,26 @@ class HTMLRenderer(BaseRenderer): raise ConfigurationError('Returned a template response with no template_name') +class StaticHTMLRenderer(BaseRenderer): + """ + An HTML renderer class that simply returns pre-rendered HTML. + + The data supplied to the Response object should be a string representing + the pre-rendered HTML content. + + For example: + data = '<html><body>example</body></html>' + return Response(data) + + For template rendered HTML, see TemplateHTMLRenderer. + """ + media_type = 'text/html' + format = 'html' + + def render(self, data, accepted_media_type=None, renderer_context=None): + return data + + class BrowsableAPIRenderer(BaseRenderer): """ HTML renderer used to self-document the API. @@ -222,11 +255,9 @@ class BrowsableAPIRenderer(BaseRenderer): return content - def get_form(self, view, method, request): + def show_form_for_method(self, view, method, request, obj): """ - Get a form, possibly bound to either the input or output data. - In the absence on of the Resource having an associated form then - provide a form that can be used to submit arbitrary content. + Returns True if a form should be shown for this method. """ if not method in view.allowed_methods: return # Not a valid method @@ -235,22 +266,14 @@ class BrowsableAPIRenderer(BaseRenderer): return # Cannot use form overloading request = clone_request(request, method) - if not view.has_permission(request): - return # Don't have permission - - if method == 'DELETE' or method == 'OPTIONS': - return True # Don't actually need to return a form - - if (not getattr(view, 'get_serializer', None) or - not parsers.FormParser in getattr(view, 'parser_classes')): - media_types = [parser.media_type for parser in view.parser_classes] - return self.get_generic_content_form(media_types) - - ##### - # TODO: This is a little bit of a hack. Actually we'd like to remove - # this and just render serializer fields to html directly. + try: + if not view.has_permission(request, obj): + return # Don't have permission + except: + return # Don't have permission and exception explicitly raise + return True - # We need to map our Fields to Django's Fields. + def serializer_to_form_fields(self, serializer): field_mapping = { serializers.FloatField: forms.FloatField, serializers.IntegerField: forms.IntegerField, @@ -260,32 +283,69 @@ class BrowsableAPIRenderer(BaseRenderer): serializers.CharField: forms.CharField, serializers.BooleanField: forms.BooleanField, serializers.PrimaryKeyRelatedField: forms.ModelChoiceField, - serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField + serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField, + serializers.HyperlinkedRelatedField: forms.ModelChoiceField, + serializers.ManyHyperlinkedRelatedField: forms.ModelMultipleChoiceField } - # Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python fields = {} - obj, data = None, None - if getattr(view, 'object', None): - obj = view.object - - serializer = view.get_serializer(instance=obj) for k, v in serializer.get_fields(True).items(): - if getattr(v, 'readonly', True): + if getattr(v, 'read_only', True): continue kwargs = {} + kwargs['required'] = v.required + if getattr(v, 'queryset', None): - kwargs['queryset'] = getattr(v, 'queryset', None) + kwargs['queryset'] = v.queryset + + if getattr(v, 'widget', None): + widget = copy.deepcopy(v.widget) + # If choices have friendly readable names, + # then add in the identities too + if getattr(widget, 'choices', None): + choices = widget.choices + if any([ident != desc for (ident, desc) in choices]): + choices = [(ident, "%s (%s)" % (desc, ident)) + for (ident, desc) in choices] + widget.choices = choices + kwargs['widget'] = widget + + if getattr(v, 'default', None) is not None: + kwargs['initial'] = v.default + + kwargs['label'] = k try: fields[k] = field_mapping[v.__class__](**kwargs) except KeyError: - fields[k] = forms.CharField() + fields[k] = forms.CharField(**kwargs) + return fields + + def get_form(self, view, method, request): + """ + Get a form, possibly bound to either the input or output data. + In the absence on of the Resource having an associated form then + provide a form that can be used to submit arbitrary content. + """ + obj = getattr(view, 'object', None) + if not self.show_form_for_method(view, method, request, obj): + return + + if method == 'DELETE' or method == 'OPTIONS': + return True # Don't actually need to return a form + + if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes: + media_types = [parser.media_type for parser in view.parser_classes] + return self.get_generic_content_form(media_types) + + serializer = view.get_serializer(instance=obj) + fields = self.serializer_to_form_fields(serializer) + # Creating an on the fly form see: + # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields) - if obj and not view.request.method == 'DELETE': # Don't fill in the form when the object is deleted - data = serializer.data + data = (obj is not None) and serializer.data or None form_instance = OnTheFlyForm(data) return form_instance diff --git a/rest_framework/request.py b/rest_framework/request.py index 0a57d376..a1827ba4 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -11,9 +11,18 @@ The wrapped request then offers a richer API, in particular : """ from StringIO import StringIO +from django.http.multipartparser import parse_header from rest_framework import exceptions from rest_framework.settings import api_settings -from rest_framework.utils.mediatypes import is_form_media_type + + +def is_form_media_type(media_type): + """ + Return True if the media type is a valid form media type. + """ + base_media_type, params = parse_header(media_type) + return (base_media_type == 'application/x-www-form-urlencoded' or + base_media_type == 'multipart/form-data') class Empty(object): @@ -35,7 +44,8 @@ def clone_request(request, method): """ ret = Request(request._request, request.parsers, - request.authenticators) + request.authenticators, + request.parser_context) ret._data = request._data ret._files = request._files ret._content_type = request._content_type @@ -65,19 +75,24 @@ class Request(object): _CONTENTTYPE_PARAM = api_settings.FORM_CONTENTTYPE_OVERRIDE def __init__(self, request, parsers=None, authenticators=None, - negotiator=None): + negotiator=None, parser_context=None): self._request = request self.parsers = parsers or () self.authenticators = authenticators or () self.negotiator = negotiator or self._default_negotiator() + self.parser_context = parser_context self._data = Empty self._files = Empty self._method = Empty self._content_type = Empty self._stream = Empty + if self.parser_context is None: + self.parser_context = {} + self.parser_context['request'] = self + def _default_negotiator(self): - return api_settings.DEFAULT_CONTENT_NEGOTIATION() + return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() @property def method(self): @@ -96,7 +111,7 @@ class Request(object): """ Returns the content type header. - This should be used instead of ``request.META.get('HTTP_CONTENT_TYPE')``, + This should be used instead of `request.META.get('HTTP_CONTENT_TYPE')`, as it allows the content type to be overridden by using a hidden form field on a form POST request. """ @@ -245,16 +260,19 @@ class Request(object): May raise an `UnsupportedMediaType`, or `ParseError` exception. """ - if self.stream is None or self.content_type is None: + stream = self.stream + media_type = self.content_type + + if stream is None or media_type is None: return (None, None) - parser = self.negotiator.select_parser(self.parsers, self.content_type) + parser = self.negotiator.select_parser(self, self.parsers) if not parser: - raise exceptions.UnsupportedMediaType(self.content_type) + raise exceptions.UnsupportedMediaType(media_type) + + parsed = parser.parse(stream, media_type, self.parser_context) - parsed = parser.parse(self.stream, meta=self.META, - upload_handlers=self.upload_handlers) # Parser classes may return the raw data, or a # DataAndFiles object. Unpack the result as required. try: diff --git a/rest_framework/resources.py b/rest_framework/resources.py deleted file mode 100644 index bb3d581f..00000000 --- a/rest_framework/resources.py +++ /dev/null @@ -1,95 +0,0 @@ -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -from functools import update_wrapper -import inspect -from django.utils.decorators import classonlymethod -from rest_framework import views, generics - - -def wrapped(source, dest): - """ - Copy public, non-method attributes from source to dest, and return dest. - """ - for attr in [attr for attr in dir(source) - if not attr.startswith('_') and not inspect.ismethod(attr)]: - setattr(dest, attr, getattr(source, attr)) - return dest - - -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class ResourceMixin(object): - """ - Clone Django's `View.as_view()` behaviour *except* using REST framework's - 'method -> action' binding for resources. - """ - - @classonlymethod - def as_view(cls, actions, **initkwargs): - """ - Main entry point for a request-response process. - """ - # sanitize keyword arguments - for key in initkwargs: - if key in cls.http_method_names: - raise TypeError("You tried to pass in the %s method name as a " - "keyword argument to %s(). Don't do that." - % (key, cls.__name__)) - if not hasattr(cls, key): - raise TypeError("%s() received an invalid keyword %r" % ( - cls.__name__, key)) - - def view(request, *args, **kwargs): - self = cls(**initkwargs) - - # Bind methods to actions - for method, action in actions.items(): - handler = getattr(self, action) - setattr(self, method, handler) - - # As you were, solider. - if hasattr(self, 'get') and not hasattr(self, 'head'): - self.head = self.get - return self.dispatch(request, *args, **kwargs) - - # take name and docstring from class - update_wrapper(view, cls, updated=()) - - # and possible attributes set by decorators - # like csrf_exempt from dispatch - update_wrapper(view, cls.dispatch, assigned=()) - return view - - -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class Resource(ResourceMixin, views.APIView): - pass - - -##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY ##### - -class ModelResource(ResourceMixin, views.APIView): - root_class = generics.ListCreateAPIView - detail_class = generics.RetrieveUpdateDestroyAPIView - - def root_view(self): - return wrapped(self, self.root_class()) - - def detail_view(self): - return wrapped(self, self.detail_class()) - - def list(self, request, *args, **kwargs): - return self.root_view().list(request, args, kwargs) - - def create(self, request, *args, **kwargs): - return self.root_view().create(request, args, kwargs) - - def retrieve(self, request, *args, **kwargs): - return self.detail_view().retrieve(request, args, kwargs) - - def update(self, request, *args, **kwargs): - return self.detail_view().update(request, args, kwargs) - - def destroy(self, request, *args, **kwargs): - return self.detail_view().destroy(request, args, kwargs) diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index ba663f98..c9db02f0 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -5,13 +5,15 @@ from django.core.urlresolvers import reverse as django_reverse from django.utils.functional import lazy -def reverse(viewname, *args, **kwargs): +def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): """ Same as `django.core.urlresolvers.reverse`, but optionally takes a request and returns a fully qualified URL, using the request to get the base URL. """ - request = kwargs.pop('request', None) - url = django_reverse(viewname, *args, **kwargs) + if format is not None: + kwargs = kwargs or {} + kwargs['format'] = format + url = django_reverse(viewname, args=args, kwargs=kwargs, **extra) if request: return request.build_absolute_uri(url) return url diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py index b2438c9b..1bd0a5fc 100755 --- a/rest_framework/runtests/runtests.py +++ b/rest_framework/runtests/runtests.py @@ -32,7 +32,7 @@ def main(): else: print usage() sys.exit(1) - failures = test_runner.run_tests(['rest_framework' + test_case]) + failures = test_runner.run_tests(['tests' + test_case]) sys.exit(failures) diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 67de82c8..951b1e72 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -91,6 +91,7 @@ INSTALLED_APPS = ( # 'django.contrib.admindocs', 'rest_framework', 'rest_framework.authtoken', + 'rest_framework.tests' ) STATIC_URL = '/static/' @@ -100,14 +101,6 @@ import django if django.VERSION < (1, 3): INSTALLED_APPS += ('staticfiles',) -# OAuth support is optional, so we only test oauth if it's installed. -try: - import oauth_provider -except ImportError: - pass -else: - INSTALLED_APPS += ('oauth_provider',) - # If we're running on the Jenkins server we want to archive the coverage reports as XML. import os if os.environ.get('HUDSON_URL', None): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 06330017..3d134a74 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -3,6 +3,7 @@ import datetime import types from decimal import Decimal from django.db import models +from django.forms import widgets from django.utils.datastructures import SortedDict from rest_framework.compat import get_concrete_model from rest_framework.fields import * @@ -22,10 +23,6 @@ class SortedDictWithMetadata(SortedDict, DictWithMetadata): pass -class RecursionOccured(BaseException): - pass - - def _is_protected_type(obj): """ True if the object is a native datatype that does not need to @@ -33,10 +30,10 @@ def _is_protected_type(obj): """ return isinstance(obj, ( types.NoneType, - int, long, - datetime.datetime, datetime.date, datetime.time, - float, Decimal, - basestring) + int, long, + datetime.datetime, datetime.date, datetime.time, + float, Decimal, + basestring) ) @@ -73,7 +70,7 @@ class SerializerOptions(object): Meta class options for Serializer """ def __init__(self, meta): - self.nested = getattr(meta, 'nested', False) + self.depth = getattr(meta, 'depth', 0) self.fields = getattr(meta, 'fields', ()) self.exclude = getattr(meta, 'exclude', ()) @@ -92,7 +89,6 @@ class BaseSerializer(Field): self.parent = None self.root = None - self.stack = [] self.context = context or {} self.init_data = data @@ -151,14 +147,11 @@ class BaseSerializer(Field): def initialize(self, parent): """ Same behaviour as usual Field, except that we need to keep track - of state so that we can deal with handling maximum depth and recursion. + of state so that we can deal with handling maximum depth. """ super(BaseSerializer, self).initialize(parent) - self.stack = parent.stack[:] - if parent.opts.nested and not isinstance(parent.opts.nested, bool): - self.opts.nested = parent.opts.nested - 1 - else: - self.opts.nested = parent.opts.nested + if parent.opts.depth: + self.opts.depth = parent.opts.depth - 1 ##### # Methods to convert or revert from objects <--> primative representations. @@ -174,21 +167,13 @@ class BaseSerializer(Field): Core of serialization. Convert an object into a dictionary of serialized field values. """ - if obj in self.stack and not self.source == '*': - raise RecursionOccured() - self.stack.append(obj) - ret = self._dict_class() ret.fields = {} - fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested) + fields = self.get_fields(serialize=True, obj=obj, nested=bool(self.opts.depth)) for field_name, field in fields.items(): key = self.get_field_key(field_name) - try: - value = field.field_to_native(obj, field_name) - except RecursionOccured: - field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name] - value = field.field_to_native(obj, field_name) + value = field.field_to_native(obj, field_name) ret[key] = value ret.fields[key] = field return ret @@ -198,7 +183,7 @@ class BaseSerializer(Field): Core of deserialization, together with `restore_object`. Converts a dictionary of data into a dictionary of deserialized fields. """ - fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested) + fields = self.get_fields(serialize=False, data=data, nested=bool(self.opts.depth)) reverted_data = {} for field_name, field in fields.items(): try: @@ -208,6 +193,35 @@ class BaseSerializer(Field): return reverted_data + def perform_validation(self, attrs): + """ + Run `validate_<fieldname>()` and `validate()` methods on the serializer + """ + # TODO: refactor this so we're not determining the fields again + fields = self.get_fields(serialize=False, data=attrs, nested=bool(self.opts.depth)) + + for field_name, field in fields.items(): + try: + validate_method = getattr(self, 'validate_%s' % field_name, None) + if validate_method: + source = field.source or field_name + attrs = validate_method(attrs, source) + except ValidationError as err: + self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) + + try: + attrs = self.validate(attrs) + except ValidationError as err: + self._errors['non_field_errors'] = err.messages + + return attrs + + def validate(self, attrs): + """ + Stub method, to be overridden in Serializer subclasses + """ + return attrs + def restore_object(self, attrs, instance=None): """ Deserialize a dictionary of attributes into an object instance. @@ -241,17 +255,31 @@ class BaseSerializer(Field): self._errors = {} if data is not None: attrs = self.restore_fields(data) + attrs = self.perform_validation(attrs) else: - self._errors['non_field_errors'] = 'No input provided' + self._errors['non_field_errors'] = ['No input provided'] if not self._errors: return self.restore_object(attrs, instance=getattr(self, 'object', None)) + def field_to_native(self, obj, field_name): + """ + Override default so that we can apply ModelSerializer as a nested + field to relationships. + """ + obj = getattr(obj, self.source or field_name) + + # If the object has an "all" method, assume it's a relationship + if is_simple_callable(getattr(obj, 'all', None)): + return [self.to_native(item) for item in obj.all()] + + return self.to_native(obj) + @property def errors(self): """ Run deserialization and return error data, - setting self.object if no errors occured. + setting self.object if no errors occurred. """ if self._errors is None: obj = self.from_native(self.init_data) @@ -295,16 +323,6 @@ class ModelSerializer(Serializer): """ _options_class = ModelSerializerOptions - def field_to_native(self, obj, field_name): - """ - Override default so that we can apply ModelSerializer as a nested - field to relationships. - """ - obj = getattr(obj, self.source or field_name) - if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'): - return [self.to_native(item) for item in obj.all()] - return self.to_native(obj) - def default_fields(self, serialize, obj=None, data=None, nested=False): """ Return all the fields that should be serialized for the model. @@ -374,25 +392,43 @@ class ModelSerializer(Serializer): """ Creates a default instance of a basic non-relational field. """ + kwargs = {} + + kwargs['blank'] = model_field.blank + + if model_field.null: + kwargs['required'] = False + + if model_field.has_default(): + kwargs['required'] = False + kwargs['default'] = model_field.get_default() + + if model_field.__class__ == models.TextField: + kwargs['widget'] = widgets.Textarea + + # TODO: TypedChoiceField? + if model_field.flatchoices: # This ModelField contains choices + kwargs['choices'] = model_field.flatchoices + return ChoiceField(**kwargs) + field_mapping = { models.FloatField: FloatField, models.IntegerField: IntegerField, + models.PositiveIntegerField: IntegerField, + models.SmallIntegerField: IntegerField, + models.PositiveSmallIntegerField: IntegerField, models.DateTimeField: DateTimeField, models.DateField: DateField, models.EmailField: EmailField, models.CharField: CharField, + models.TextField: CharField, models.CommaSeparatedIntegerField: CharField, models.BooleanField: BooleanField, } try: - ret = field_mapping[model_field.__class__]() + return field_mapping[model_field.__class__](**kwargs) except KeyError: - ret = ModelField(model_field=model_field) - - if model_field.default: - ret.required = False - - return ret + return ModelField(model_field=model_field, **kwargs) def restore_object(self, attrs, instance=None): """ diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 5ebe7ba5..9c40a214 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -3,11 +3,11 @@ Settings for REST framework are all namespaced in the REST_FRAMEWORK setting. For example your project's `settings.py` file might look like this: REST_FRAMEWORK = { - 'DEFAULT_RENDERERS': ( + 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'rest_framework.renderers.YAMLRenderer', ) - 'DEFAULT_PARSERS': ( + 'DEFAULT_PARSER_CLASSES': ( 'rest_framework.parsers.JSONParser', 'rest_framework.parsers.YAMLParser', ) @@ -24,30 +24,36 @@ from django.utils import importlib USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) DEFAULTS = { - 'DEFAULT_RENDERERS': ( + 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'rest_framework.renderers.BrowsableAPIRenderer', ), - 'DEFAULT_PARSERS': ( + 'DEFAULT_PARSER_CLASSES': ( 'rest_framework.parsers.JSONParser', 'rest_framework.parsers.FormParser', 'rest_framework.parsers.MultiPartParser' ), - 'DEFAULT_AUTHENTICATION': ( + 'DEFAULT_AUTHENTICATION_CLASSES': ( 'rest_framework.authentication.SessionAuthentication', - 'rest_framework.authentication.UserBasicAuthentication' + 'rest_framework.authentication.BasicAuthentication' ), - 'DEFAULT_PERMISSIONS': (), - 'DEFAULT_THROTTLES': (), - 'DEFAULT_CONTENT_NEGOTIATION': + 'DEFAULT_PERMISSION_CLASSES': ( + 'rest_framework.permissions.AllowAny', + ), + 'DEFAULT_THROTTLE_CLASSES': ( + ), + + 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', + 'DEFAULT_MODEL_SERIALIZER_CLASS': + 'rest_framework.serializers.ModelSerializer', + 'DEFAULT_PAGINATION_SERIALIZER_CLASS': + 'rest_framework.pagination.PaginationSerializer', + 'DEFAULT_THROTTLE_RATES': { 'user': None, 'anon': None, }, - - 'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer', - 'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer', 'PAGINATE_BY': None, 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', @@ -65,14 +71,14 @@ DEFAULTS = { # List of settings that may be in string import notation. IMPORT_STRINGS = ( - 'DEFAULT_RENDERERS', - 'DEFAULT_PARSERS', - 'DEFAULT_AUTHENTICATION', - 'DEFAULT_PERMISSIONS', - 'DEFAULT_THROTTLES', - 'DEFAULT_CONTENT_NEGOTIATION', - 'MODEL_SERIALIZER', - 'PAGINATION_SERIALIZER', + 'DEFAULT_RENDERER_CLASSES', + 'DEFAULT_PARSER_CLASSES', + 'DEFAULT_AUTHENTICATION_CLASSES', + 'DEFAULT_PERMISSION_CLASSES', + 'DEFAULT_THROTTLE_CLASSES', + 'DEFAULT_CONTENT_NEGOTIATION_CLASS', + 'DEFAULT_MODEL_SERIALIZER_CLASS', + 'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) @@ -111,7 +117,7 @@ class APISettings(object): For example: from rest_framework.settings import api_settings - print api_settings.DEFAULT_RENDERERS + print api_settings.DEFAULT_RENDERER_CLASSES Any setting with string import paths will be automatically resolved and return the class, rather than the string literal. diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css index 739b9300..e29da395 100644 --- a/rest_framework/static/rest_framework/css/default.css +++ b/rest_framework/static/rest_framework/css/default.css @@ -32,6 +32,10 @@ h2, h3 { margin-right: 1em; } +ul.breadcrumb { + margin: 58px 0 0 0; +} + /* To allow tooltips to work on disabled elements */ .disabled-tooltip-shield { position: absolute; @@ -55,6 +59,7 @@ pre { .page-header { border-bottom: none; padding-bottom: 0px; + margin-bottom: 20px; } @@ -65,7 +70,7 @@ html{ background: none; } -body, .navbar .navbar-inner .container-fluid{ +body, .navbar .navbar-inner .container-fluid { max-width: 1150px; margin: 0 auto; } @@ -76,13 +81,14 @@ body{ } #content{ - margin: 40px 0 0 0; + margin: 0; } /* custom navigation styles */ .wrapper .navbar{ - width:100%; + width: 100%; position: absolute; - left:0; + left: 0; + top: 0; } .navbar .navbar-inner{ diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 5ac6ef67..e0f79481 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -109,7 +109,7 @@ <div class="content-main"> <div class="page-header"><h1>{{ name }}</h1></div> - <p class="resource-description">{{ description }}</p> + {{ description }} <div class="request-info"> <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre> diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index 65af512e..c1271399 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -3,42 +3,50 @@ <html> <head> - <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/style.css'/> + <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/> + <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/> + <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/> </head> - <body class="login"> + <body class="container"> - <div id="container"> - - <div id="header"> - <div id="branding"> - <h1 id="site-name">Django REST framework</h1> +<div class="container-fluid" style="margin-top: 30px"> + <div class="row-fluid"> + + <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> + <div class="row-fluid"> + <div> + <h3 style="margin: 0 0 20px;">Django REST framework</h3> </div> - </div> + </div><!-- /row fluid --> - <div id="content" class="colM"> - <div id="content-main"> - <form method="post" action="{% url 'rest_framework:login' %}" id="login-form"> + <div class="row-fluid"> + <div> + <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> {% csrf_token %} - <div class="form-row"> - <label for="id_username">Username:</label> {{ form.username }} + <div id="div_id_username" class="clearfix control-group"> + <div class="controls" style="height: 30px"> + <Label class="span4" style="margin-top: 3px">Username:</label> + <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> + </div> </div> - <div class="form-row"> - <label for="id_password">Password:</label> {{ form.password }} - <input type="hidden" name="next" value="{{ next }}" /> + <div id="div_id_password" class="clearfix control-group"> + <div class="controls" style="height: 30px"> + <Label class="span4" style="margin-top: 3px">Password:</label> + <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> + </div> </div> - <div class="form-row"> - <label> </label><input type="submit" value="Log in"> + <input type="hidden" name="next" value="{{ next }}" /> + <div class="form-actions-no-box"> + <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> </div> </form> - <script type="text/javascript"> - document.getElementById('id_username').focus() - </script> </div> - <br class="clear"> - </div> + </div><!-- /row fluid --> + </div><!--/span--> - <div id="footer"></div> + </div><!-- /.row-fluid --> + </div> </div> </body> diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py index adeaf6da..e69de29b 100644 --- a/rest_framework/tests/__init__.py +++ b/rest_framework/tests/__init__.py @@ -1,13 +0,0 @@ -""" -Force import of all modules in this package in order to get the standard test -runner to pick up the tests. Yowzers. -""" -import os - -modules = [filename.rsplit('.', 1)[0] - for filename in os.listdir(os.path.dirname(__file__)) - if filename.endswith('.py') and not filename.startswith('_')] -__test__ = dict() - -for module in modules: - exec("from rest_framework.tests.%s import *" % module) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index f4263478..a8279ef2 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -2,7 +2,7 @@ from django.test import TestCase from django.test.client import RequestFactory from django.utils import simplejson as json from rest_framework import generics, serializers, status -from rest_framework.tests.models import BasicModel, Comment +from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel factory = RequestFactory() @@ -22,6 +22,22 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView): model = BasicModel +class SlugSerializer(serializers.ModelSerializer): + slug = serializers.Field() # read only + + class Meta: + model = SlugBasedModel + exclude = ('id',) + + +class SlugBasedInstanceView(InstanceView): + """ + A model with a slug-field. + """ + model = SlugBasedModel + serializer_class = SlugSerializer + + class TestRootView(TestCase): def setUp(self): """ @@ -129,6 +145,7 @@ class TestInstanceView(TestCase): for obj in self.objects.all() ] self.view = InstanceView.as_view() + self.slug_based_view = SlugBasedInstanceView.as_view() def test_get_instance_view(self): """ @@ -198,7 +215,7 @@ class TestInstanceView(TestCase): def test_put_cannot_set_id(self): """ - POST requests to create a new object should not be able to set the id. + PUT requests to create a new object should not be able to set the id. """ content = {'id': 999, 'text': 'foobar'} request = factory.put('/1', json.dumps(content), @@ -219,11 +236,39 @@ class TestInstanceView(TestCase): request = factory.put('/1', json.dumps(content), content_type='application/json') response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) self.assertEquals(updated.text, 'foobar') + def test_put_as_create_on_id_based_url(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + at the requested url if it doesn't exist. + """ + content = {'text': 'foobar'} + # pk fields can not be created on demand, only the database can set th pk for a new object + request = factory.put('/5', json.dumps(content), + content_type='application/json') + response = self.view(request, pk=5).render() + self.assertEquals(response.status_code, status.HTTP_201_CREATED) + new_obj = self.objects.get(pk=5) + self.assertEquals(new_obj.text, 'foobar') + + def test_put_as_create_on_slug_based_url(self): + """ + PUT requests to RetrieveUpdateDestroyAPIView should create an object + at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. + """ + content = {'text': 'foobar'} + request = factory.put('/test_slug', json.dumps(content), + content_type='application/json') + response = self.slug_based_view(request, slug='test_slug').render() + self.assertEquals(response.status_code, status.HTTP_201_CREATED) + self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'}) + new_obj = SlugBasedModel.objects.get(slug='test_slug') + self.assertEquals(new_obj.text, 'foobar') + # Regression test for #285 diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py index da2f83c3..10d7e31d 100644 --- a/rest_framework/tests/htmlrenderer.py +++ b/rest_framework/tests/htmlrenderer.py @@ -3,12 +3,12 @@ from django.test import TestCase from django.template import TemplateDoesNotExist, Template import django.template.loader from rest_framework.decorators import api_view, renderer_classes -from rest_framework.renderers import HTMLRenderer +from rest_framework.renderers import TemplateHTMLRenderer from rest_framework.response import Response @api_view(('GET',)) -@renderer_classes((HTMLRenderer,)) +@renderer_classes((TemplateHTMLRenderer,)) def example(request): """ A view that can returns an HTML representation. @@ -22,7 +22,7 @@ urlpatterns = patterns('', ) -class HTMLRendererTests(TestCase): +class TemplateHTMLRendererTests(TestCase): urls = 'rest_framework.tests.htmlrenderer' def setUp(self): diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 5532a8ee..92c3691e 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -2,11 +2,19 @@ from django.conf.urls.defaults import patterns, url from django.test import TestCase from django.test.client import RequestFactory from rest_framework import generics, status, serializers -from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel +from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment factory = RequestFactory() +class BlogPostCommentSerializer(serializers.Serializer): + text = serializers.CharField() + blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail', queryset=BlogPost.objects.all()) + + def restore_object(self, attrs, instance=None): + return BlogPostComment(**attrs) + + class BasicList(generics.ListCreateAPIView): model = BasicModel model_serializer_class = serializers.HyperlinkedModelSerializer @@ -32,12 +40,22 @@ class ManyToManyDetail(generics.RetrieveAPIView): model_serializer_class = serializers.HyperlinkedModelSerializer +class BlogPostCommentListCreate(generics.ListCreateAPIView): + model = BlogPostComment + model_serializer_class = BlogPostCommentSerializer + + +class BlogPostDetail(generics.RetrieveAPIView): + model = BlogPost + urlpatterns = patterns('', url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), + url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), + url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list') ) @@ -124,3 +142,27 @@ class TestManyToManyHyperlinkedView(TestCase): response = self.detail_view(request, pk=1).render() self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.data, self.data[0]) + + +class TestCreateWithForeignKeys(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create a blog post + """ + self.post = BlogPost.objects.create(title="Test post") + self.create_view = BlogPostCommentListCreate.as_view() + + def test_create_comment(self): + + data = { + 'text': 'A test comment', + 'blog_post_url': 'http://testserver/posts/1/' + } + + request = factory.post('/comments/', data=data) + response = self.create_view(request).render() + self.assertEqual(response.status_code, 201) + self.assertEqual(self.post.blogpostcomment_set.count(), 1) + self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 780c9dba..9efedbc4 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -40,7 +40,7 @@ class RESTFrameworkModel(models.Model): Base for test models that sets app_label, so they play nicely. """ class Meta: - app_label = 'rest_framework' + app_label = 'tests' abstract = True @@ -52,6 +52,11 @@ class BasicModel(RESTFrameworkModel): text = models.CharField(max_length=100) +class SlugBasedModel(RESTFrameworkModel): + text = models.CharField(max_length=100) + slug = models.SlugField(max_length=32) + + class DefaultValueModel(RESTFrameworkModel): text = models.CharField(default='foobar', max_length=100) @@ -63,6 +68,11 @@ class CallableDefaultValueModel(RESTFrameworkModel): class ManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) + +class ReadOnlyManyToManyModel(RESTFrameworkModel): + text = models.CharField(max_length=100, default='anchor') + rel = models.ManyToManyField(Anchor) + # Models to test generic relations @@ -98,3 +108,28 @@ class Comment(RESTFrameworkModel): email = models.EmailField() content = models.CharField(max_length=200) created = models.DateTimeField(auto_now_add=True) + + +class ActionItem(RESTFrameworkModel): + title = models.CharField(max_length=200) + done = models.BooleanField(default=False) + + +# Models for reverse relations +class BlogPost(RESTFrameworkModel): + title = models.CharField(max_length=100) + + +class BlogPostComment(RESTFrameworkModel): + text = models.TextField() + blog_post = models.ForeignKey(BlogPost) + + +class Person(RESTFrameworkModel): + name = models.CharField(max_length=10) + age = models.IntegerField(null=True, blank=True) + + +# Model for issue #324 +class BlankFieldModel(RESTFrameworkModel): + title = models.CharField(max_length=100, blank=True) diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/negotiation.py index d8265b43..e06354ea 100644 --- a/rest_framework/tests/negotiation.py +++ b/rest_framework/tests/negotiation.py @@ -18,20 +18,20 @@ class TestAcceptedMediaType(TestCase): self.renderers = [MockJSONRenderer(), MockHTMLRenderer()] self.negotiator = DefaultContentNegotiation() - def negotiate(self, request): - return self.negotiator.negotiate(request, self.renderers) + def select_renderer(self, request): + return self.negotiator.select_renderer(request, self.renderers) def test_client_without_accept_use_renderer(self): request = factory.get('/') - accepted_renderer, accepted_media_type = self.negotiate(request) + accepted_renderer, accepted_media_type = self.select_renderer(request) self.assertEquals(accepted_media_type, 'application/json') def test_client_underspecifies_accept_use_renderer(self): request = factory.get('/', HTTP_ACCEPT='*/*') - accepted_renderer, accepted_media_type = self.negotiate(request) + accepted_renderer, accepted_media_type = self.select_renderer(request) self.assertEquals(accepted_media_type, 'application/json') def test_client_overspecifies_accept_use_client(self): request = factory.get('/', HTTP_ACCEPT='application/json; indent=8') - accepted_renderer, accepted_media_type = self.negotiate(request) + accepted_renderer, accepted_media_type = self.select_renderer(request) self.assertEquals(accepted_media_type, 'application/json; indent=8') diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 7b24b036..ff48f3fa 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -10,9 +10,9 @@ from rest_framework import status from rest_framework.authentication import SessionAuthentication from django.test.client import RequestFactory from rest_framework.parsers import ( + BaseParser, FormParser, MultiPartParser, - PlainTextParser, JSONParser ) from rest_framework.request import Request @@ -24,6 +24,19 @@ from rest_framework.views import APIView factory = RequestFactory() +class PlainTextParser(BaseParser): + media_type = 'text/plain' + + def parse(self, stream, media_type=None, parser_context=None): + """ + Returns a 2-tuple of `(data, files)`. + + `data` will simply be a string representing the body of the request. + `files` will always be `None`. + """ + return stream.read() + + class TestMethodOverloading(TestCase): def test_method(self): """ diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 256987ad..d4b43862 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -4,6 +4,11 @@ from rest_framework import serializers from rest_framework.tests.models import * +class SubComment(object): + def __init__(self, sub_comment): + self.sub_comment = sub_comment + + class Comment(object): def __init__(self, email, content, created): self.email = email @@ -14,11 +19,16 @@ class Comment(object): return all([getattr(self, attr) == getattr(other, attr) for attr in ('email', 'content', 'created')]) + def get_sub_comment(self): + sub_comment = SubComment('And Merry Christmas!') + return sub_comment + class CommentSerializer(serializers.Serializer): email = serializers.EmailField() content = serializers.CharField(max_length=1000) created = serializers.DateTimeField() + sub_comment = serializers.Field(source='get_sub_comment.sub_comment') def restore_object(self, data, instance=None): if instance is None: @@ -28,6 +38,16 @@ class CommentSerializer(serializers.Serializer): return instance +class ActionItemSerializer(serializers.ModelSerializer): + class Meta: + model = ActionItem + + +class PersonSerializer(serializers.ModelSerializer): + class Meta: + model = Person + + class BasicTests(TestCase): def setUp(self): self.comment = Comment( @@ -38,7 +58,14 @@ class BasicTests(TestCase): self.data = { 'email': 'tom@example.com', 'content': 'Happy new year!', - 'created': datetime.datetime(2012, 1, 1) + 'created': datetime.datetime(2012, 1, 1), + 'sub_comment': 'This wont change' + } + self.expected = { + 'email': 'tom@example.com', + 'content': 'Happy new year!', + 'created': datetime.datetime(2012, 1, 1), + 'sub_comment': 'And Merry Christmas!' } def test_empty(self): @@ -46,14 +73,14 @@ class BasicTests(TestCase): expected = { 'email': '', 'content': '', - 'created': None + 'created': None, + 'sub_comment': '' } self.assertEquals(serializer.data, expected) def test_retrieve(self): serializer = CommentSerializer(instance=self.comment) - expected = self.data - self.assertEquals(serializer.data, expected) + self.assertEquals(serializer.data, self.expected) def test_create(self): serializer = CommentSerializer(self.data) @@ -61,6 +88,7 @@ class BasicTests(TestCase): self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.object, expected) self.assertFalse(serializer.object is expected) + self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') def test_update(self): serializer = CommentSerializer(self.data, instance=self.comment) @@ -68,6 +96,7 @@ class BasicTests(TestCase): self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.object, expected) self.assertTrue(serializer.object is expected) + self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') class ValidationTests(TestCase): @@ -82,6 +111,8 @@ class ValidationTests(TestCase): 'content': 'x' * 1001, 'created': datetime.datetime(2012, 1, 1) } + self.actionitem = ActionItem('Some to do item', + ) def test_create(self): serializer = CommentSerializer(self.data) @@ -102,6 +133,74 @@ class ValidationTests(TestCase): self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.errors, {'email': [u'This field is required.']}) + def test_missing_bool_with_default(self): + """Make sure that a boolean value with a 'False' value is not + mistaken for not having a default.""" + data = { + 'title': 'Some action item', + #No 'done' value. + } + serializer = ActionItemSerializer(data, instance=self.actionitem) + self.assertEquals(serializer.is_valid(), True) + self.assertEquals(serializer.errors, {}) + + def test_field_validation(self): + + class CommentSerializerWithFieldValidator(CommentSerializer): + + def validate_content(self, attrs, source): + value = attrs[source] + if "test" not in value: + raise serializers.ValidationError("Test not in value") + return attrs + + data = { + 'email': 'tom@example.com', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + + serializer = CommentSerializerWithFieldValidator(data) + self.assertTrue(serializer.is_valid()) + + data['content'] = 'This should not validate' + + serializer = CommentSerializerWithFieldValidator(data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) + + def test_cross_field_validation(self): + + class CommentSerializerWithCrossFieldValidator(CommentSerializer): + + def validate(self, attrs): + if attrs["email"] not in attrs["content"]: + raise serializers.ValidationError("Email address not in content") + return attrs + + data = { + 'email': 'tom@example.com', + 'content': 'A comment from tom@example.com', + 'created': datetime.datetime(2012, 1, 1) + } + + serializer = CommentSerializerWithCrossFieldValidator(data) + self.assertTrue(serializer.is_valid()) + + data['content'] = 'A comment from foo@bar.com' + + serializer = CommentSerializerWithCrossFieldValidator(data) + self.assertFalse(serializer.is_valid()) + self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']}) + + def test_null_is_true_fields(self): + """ + Omitting a value for null-field should validate. + """ + serializer = PersonSerializer({'name': 'marko'}) + self.assertEquals(serializer.is_valid(), True) + self.assertEquals(serializer.errors, {}) + class MetadataTests(TestCase): def test_empty(self): @@ -212,6 +311,61 @@ class ManyToManyTests(TestCase): self.assertEquals(list(instance.rel.all()), []) +class ReadOnlyManyToManyTests(TestCase): + def setUp(self): + class ReadOnlyManyToManySerializer(serializers.ModelSerializer): + rel = serializers.ManyRelatedField(read_only=True) + + class Meta: + model = ReadOnlyManyToManyModel + + self.serializer_class = ReadOnlyManyToManySerializer + + # An anchor instance to use for the relationship + self.anchor = Anchor() + self.anchor.save() + + # A model instance with a many to many relationship to the anchor + self.instance = ReadOnlyManyToManyModel() + self.instance.save() + self.instance.rel.add(self.anchor) + + # A serialized representation of the model instance + self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'} + + def test_update(self): + """ + Attempt to update an instance of a model with a ManyToMany + relationship. Not updated due to read_only=True + """ + new_anchor = Anchor() + new_anchor.save() + data = {'rel': [self.anchor.id, new_anchor.id]} + serializer = self.serializer_class(data, instance=self.instance) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) + self.assertEquals(instance.pk, 1) + # rel is still as original (1 entry) + self.assertEquals(list(instance.rel.all()), [self.anchor]) + + def test_update_without_relationship(self): + """ + Attempt to update an instance of a model where many to ManyToMany + relationship is not supplied. Not updated due to read_only=True + """ + new_anchor = Anchor() + new_anchor.save() + data = {} + serializer = self.serializer_class(data, instance=self.instance) + self.assertEquals(serializer.is_valid(), True) + instance = serializer.save() + self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) + self.assertEquals(instance.pk, 1) + # rel is still as original (1 entry) + self.assertEquals(list(instance.rel.all()), [self.anchor]) + + class DefaultValueTests(TestCase): def setUp(self): class DefaultValueSerializer(serializers.ModelSerializer): @@ -266,3 +420,81 @@ class CallableDefaultValueTests(TestCase): self.assertEquals(len(self.objects.all()), 1) self.assertEquals(instance.pk, 1) self.assertEquals(instance.text, 'overridden') + + +class ManyRelatedTests(TestCase): + def setUp(self): + + class BlogPostCommentSerializer(serializers.Serializer): + text = serializers.CharField() + + class BlogPostSerializer(serializers.Serializer): + title = serializers.CharField() + comments = BlogPostCommentSerializer(source='blogpostcomment_set') + + self.serializer_class = BlogPostSerializer + + def test_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + serializer = self.serializer_class(instance=post) + expected = { + 'title': 'Test blog post', + 'comments': [ + {'text': 'I hate this blog post'}, + {'text': 'I love this blog post'} + ] + } + + self.assertEqual(serializer.data, expected) + + +# Test for issue #324 +class BlankFieldTests(TestCase): + def setUp(self): + + class BlankFieldModelSerializer(serializers.ModelSerializer): + class Meta: + model = BlankFieldModel + + class BlankFieldSerializer(serializers.Serializer): + title = serializers.CharField(blank=True) + + class NotBlankFieldModelSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + class NotBlankFieldSerializer(serializers.Serializer): + title = serializers.CharField() + + self.model_serializer_class = BlankFieldModelSerializer + self.serializer_class = BlankFieldSerializer + self.not_blank_model_serializer_class = NotBlankFieldModelSerializer + self.not_blank_serializer_class = NotBlankFieldSerializer + self.data = {'title': ''} + + def test_create_blank_field(self): + serializer = self.serializer_class(self.data) + self.assertEquals(serializer.is_valid(), True) + + def test_create_model_blank_field(self): + serializer = self.model_serializer_class(self.data) + self.assertEquals(serializer.is_valid(), True) + + def test_create_not_blank_field(self): + """ + Test to ensure blank data in a field not marked as blank=True + is considered invalid in a non-model serializer + """ + serializer = self.not_blank_serializer_class(self.data) + self.assertEquals(serializer.is_valid(), False) + + def test_create_model_not_blank_field(self): + """ + Test to ensure blank data in a field not marked as blank=True + is considered invalid in a model serializer + """ + serializer = self.not_blank_model_serializer_class(self.data) + self.assertEquals(serializer.is_valid(), False) diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py new file mode 100644 index 00000000..adeaf6da --- /dev/null +++ b/rest_framework/tests/tests.py @@ -0,0 +1,13 @@ +""" +Force import of all modules in this package in order to get the standard test +runner to pick up the tests. Yowzers. +""" +import os + +modules = [filename.rsplit('.', 1)[0] + for filename in os.listdir(os.path.dirname(__file__)) + if filename.endswith('.py') and not filename.startswith('_')] +__test__ = dict() + +for module in modules: + exec("from rest_framework.tests.%s import *" % module) diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py index b390c42f..c032985e 100644 --- a/rest_framework/tests/validators.py +++ b/rest_framework/tests/validators.py @@ -285,7 +285,7 @@ # uiop = models.CharField(max_length=256, blank=True) # @property -# def readonly(self): +# def read_only(self): # return 'read only' # class MockResource(ModelResource): @@ -298,7 +298,7 @@ # def test_property_fields_are_allowed_on_model_forms(self): # """Validation on ModelForms may include property fields that exist on the Model to be included in the input.""" -# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only'} +# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'} # self.assertEqual(self.validator.validate_request(content, None), content) # def test_property_fields_are_not_required_on_model_forms(self): @@ -310,19 +310,19 @@ # """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only', 'extra': 'extra'} +# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'} # self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # def test_validate_requires_fields_on_model_forms(self): # """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'readonly': 'read only'} +# content = {'read_only': 'read only'} # self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # def test_validate_does_not_require_blankable_fields_on_model_forms(self): # """Test standard ModelForm validation behaviour - fields with blank=True are not required.""" -# content = {'qwerty': 'example', 'readonly': 'read only'} +# content = {'qwerty': 'example', 'read_only': 'read only'} # self.validator.validate_request(content, None) # def test_model_form_validator_uses_model_forms(self): diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 566c277d..8fe64248 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -16,13 +16,13 @@ class BaseThrottle(object): def wait(self): """ - Optionally, return a recommeded number of seconds to wait before + Optionally, return a recommended number of seconds to wait before the next request. """ return None -class SimpleRateThottle(BaseThrottle): +class SimpleRateThrottle(BaseThrottle): """ A simple cache implementation, that only requires `.get_cache_key()` to be overridden. @@ -60,7 +60,7 @@ class SimpleRateThottle(BaseThrottle): Determine the string representation of the allowed request rate. """ if not getattr(self, 'scope', None): - msg = ("You must set either `.scope` or `.rate` for '%s' thottle" % + msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__) raise exceptions.ConfigurationError(msg) @@ -133,11 +133,11 @@ class SimpleRateThottle(BaseThrottle): return remaining_duration / float(available_requests) -class AnonRateThrottle(SimpleRateThottle): +class AnonRateThrottle(SimpleRateThrottle): """ Limits the rate of API calls that may be made by a anonymous users. - The IP address of the request will be used as the unqiue cache key. + The IP address of the request will be used as the unique cache key. """ scope = 'anon' @@ -153,7 +153,7 @@ class AnonRateThrottle(SimpleRateThottle): } -class UserRateThrottle(SimpleRateThottle): +class UserRateThrottle(SimpleRateThrottle): """ Limits the rate of API calls that may be made by a given user. @@ -175,7 +175,7 @@ class UserRateThrottle(SimpleRateThottle): } -class ScopedRateThrottle(SimpleRateThottle): +class ScopedRateThrottle(SimpleRateThrottle): """ Limits the rate of API calls by different amounts for various parts of the API. Any view that has the `throttle_scope` property set will be diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 386c78a2..316ccd19 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -2,26 +2,23 @@ from django.conf.urls.defaults import url from rest_framework.settings import api_settings -def format_suffix_patterns(urlpatterns, suffix_required=False, - suffix_kwarg=None, allowed=None): +def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): """ Supplement existing urlpatterns with corrosponding patterns that also include a '.format' suffix. Retains urlpattern ordering. + urlpatterns: + A list of URL patterns. + suffix_required: If `True`, only suffixed URLs will be generated, and non-suffixed URLs will not be used. Defaults to `False`. - suffix_kwarg: - The name of the kwarg that will be passed to the view. - Defaults to 'format'. - allowed: An optional tuple/list of allowed suffixes. eg ['json', 'api'] Defaults to `None`, which allows any suffix. - """ - suffix_kwarg = suffix_kwarg or api_settings.FORMAT_SUFFIX_KWARG + suffix_kwarg = api_settings.FORMAT_SUFFIX_KWARG if allowed: if len(allowed) == 1: allowed_pattern = allowed[0] diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py index 5eba7fb2..ee7f3a54 100644 --- a/rest_framework/utils/mediatypes.py +++ b/rest_framework/utils/mediatypes.py @@ -25,32 +25,6 @@ def media_type_matches(lhs, rhs): return lhs.match(rhs) -def is_form_media_type(media_type): - """ - Return True if the media type is a valid form media type as defined by the HTML4 spec. - (NB. HTML5 also adds text/plain to the list of valid form media types, but we don't support this here) - """ - media_type = _MediaType(media_type) - return media_type.full_type == 'application/x-www-form-urlencoded' or \ - media_type.full_type == 'multipart/form-data' - - -def add_media_type_param(media_type, key, val): - """ - Add a key, value parameter to a media type string, and return the new media type string. - """ - media_type = _MediaType(media_type) - media_type.params[key] = val - return str(media_type) - - -def get_media_type_params(media_type): - """ - Return a dictionary of the parameters on the given media type. - """ - return _MediaType(media_type).params - - def order_by_precedence(media_type_lst): """ Returns a list of sets of media type strings, ordered by precedence. diff --git a/rest_framework/views.py b/rest_framework/views.py index b3f36085..71e1fe6c 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,8 +1,5 @@ """ -The :mod:`views` module provides the Views you will most probably -be subclassing in your implementation. - -By setting or modifying class attributes on your view, you change it's predefined behaviour. +Provides an APIView class that is used as the base of all class-based views. """ import re @@ -57,12 +54,12 @@ def _camelcase_to_spaces(content): class APIView(View): settings = api_settings - renderer_classes = api_settings.DEFAULT_RENDERERS - parser_classes = api_settings.DEFAULT_PARSERS - authentication_classes = api_settings.DEFAULT_AUTHENTICATION - throttle_classes = api_settings.DEFAULT_THROTTLES - permission_classes = api_settings.DEFAULT_PERMISSIONS - content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION + renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES + parser_classes = api_settings.DEFAULT_PARSER_CLASSES + authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES + throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES + permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES + content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS @classmethod def as_view(cls, **initkwargs): @@ -159,18 +156,31 @@ class APIView(View): """ raise exceptions.Throttled(wait) + def get_parser_context(self, http_request): + """ + Returns a dict that is passed through to Parser.parse(), + as the `parser_context` keyword argument. + """ + # Note: Additionally `request` will also be added to the context + # by the Request object. + return { + 'view': self, + 'args': getattr(self, 'args', ()), + 'kwargs': getattr(self, 'kwargs', {}) + } + def get_renderer_context(self): """ - Returns a dict that is passed through to the Renderer.render(), + Returns a dict that is passed through to Renderer.render(), as the `renderer_context` keyword argument. """ - # Note: Additionally 'response' will also be set on the context, + # Note: Additionally 'response' will also be added to the context, # by the Response object. return { 'view': self, - 'request': self.request, - 'args': self.args, - 'kwargs': self.kwargs + 'args': getattr(self, 'args', ()), + 'kwargs': getattr(self, 'kwargs', {}), + 'request': getattr(self, 'request', None) } # API policy instantiation methods @@ -208,7 +218,7 @@ class APIView(View): def get_throttles(self): """ - Instantiates and returns the list of thottles that this view uses. + Instantiates and returns the list of throttles that this view uses. """ return [throttle() for throttle in self.throttle_classes] @@ -228,7 +238,13 @@ class APIView(View): """ renderers = self.get_renderers() conneg = self.get_content_negotiator() - return conneg.negotiate(request, renderers, self.format_kwarg, force) + + try: + return conneg.select_renderer(request, renderers, self.format_kwarg) + except: + if force: + return (renderers[0], renderers[0].media_type) + raise def has_permission(self, request, obj=None): """ @@ -253,10 +269,13 @@ class APIView(View): """ Returns the initial request object. """ + parser_context = self.get_parser_context(request) + return Request(request, parsers=self.get_parsers(), authenticators=self.get_authenticators(), - negotiator=self.get_content_negotiator()) + negotiator=self.get_content_negotiator(), + parser_context=parser_context) def initial(self, request, *args, **kwargs): """ |
