diff options
28 files changed, 881 insertions, 165 deletions
diff --git a/.travis.yml b/.travis.yml index 205feef9..3a7c2d7a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,9 +7,9 @@ python: - "3.3" env: - - DJANGO="django==1.5 --use-mirrors" - - DJANGO="django==1.4.3 --use-mirrors" - - DJANGO="django==1.3.5 --use-mirrors" + - DJANGO="django==1.5.1 --use-mirrors" + - DJANGO="django==1.4.5 --use-mirrors" + - DJANGO="django==1.3.7 --use-mirrors" install: - pip install $DJANGO @@ -18,7 +18,7 @@ install: - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.0 --use-mirrors; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.3 --use-mirrors; fi" - "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4 --use-mirrors; fi" - - "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.6a1 --use-mirrors; fi" + - "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.6 --use-mirrors; fi" - export PYTHONPATH=. script: @@ -27,10 +27,11 @@ script: matrix: exclude: - python: "3.2" - env: DJANGO="django==1.4.3 --use-mirrors" + env: DJANGO="django==1.4.5 --use-mirrors" - python: "3.2" - env: DJANGO="django==1.3.5 --use-mirrors" + env: DJANGO="django==1.3.7 --use-mirrors" - python: "3.3" - env: DJANGO="django==1.4.3 --use-mirrors" + env: DJANGO="django==1.4.5 --use-mirrors" - python: "3.3" - env: DJANGO="django==1.3.5 --use-mirrors" + env: DJANGO="django==1.3.7 --use-mirrors" + diff --git a/docs/api-guide/filtering.md b/docs/api-guide/filtering.md index b5dfc68e..a710ad7d 100644 --- a/docs/api-guide/filtering.md +++ b/docs/api-guide/filtering.md @@ -182,7 +182,7 @@ For more details on using filter sets see the [django-filter documentation][djan The `SearchFilterBackend` class supports simple single query parameter based searching, and is based on the [Django admin's search functionality][search-django-admin]. -The `SearchFilterBackend` class will only be applied if the view has a `search_fields` attribute set. The `search_fields` attribute should be a list of names of text fields on the model. +The `SearchFilterBackend` class will only be applied if the view has a `search_fields` attribute set. The `search_fields` attribute should be a list of names of text type fields on the model, such as `CharField` or `TextField`. class UserListView(generics.ListAPIView): queryset = User.objects.all() @@ -190,7 +190,7 @@ The `SearchFilterBackend` class will only be applied if the view has a `search_f filter_backends = (filters.SearchFilter,) search_fields = ('username', 'email') -This will allow the client to filter the itemss in the list by making queries such as: +This will allow the client to filter the items in the list by making queries such as: http://example.com/api/users?search=russell @@ -198,12 +198,50 @@ You can also perform a related lookup on a ForeignKey or ManyToManyField with th search_fields = ('username', 'email', 'profile__profession') -By default, searches will use case-insensitive partial matches. If the search parameter contains multiple whitespace seperated words, then objects will be returned in the list only if all the provided words are matched. +By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched. + +The search behavior may be restricted by prepending various characters to the `search_fields`. + +* '^' Starts-with search. +* '=' Exact matches. +* '@' Full-text search. (Currently only supported Django's MySQL backend.) + +For example: + + search_fields = ('=username', '=email') For more details, see the [Django documentation][search-django-admin]. --- +## OrderingFilter + +The `OrderingFilter` class supports simple query parameter controlled ordering of results. To specify the result order, set a query parameter named `'order'` to the required field name. For example: + + http://example.com/api/users?ordering=username + +The client may also specify reverse orderings by prefixing the field name with '-', like so: + + http://example.com/api/users?ordering=-username + +Multiple orderings may also be specified: + + http://example.com/api/users?ordering=account,username + +If an `ordering` attribute is set on the view, this will be used as the default ordering. + +Typicaly you'd instead control this by setting `order_by` on the initial queryset, but using the `ordering` parameter on the view allows you to specify the ordering in a way that it can then be passed automatically as context to a rendered template. This makes it possible to automatically render column headers differently if they are being used to order the results. + + class UserListView(generics.ListAPIView): + queryset = User.objects.all() + serializer = UserSerializer + filter_backends = (filters.OrderingFilter,) + ordering = ('username',) + +The `ordering` attribute may be either a string or a list/tuple of strings. + +--- + # Custom generic filtering You can also provide your own generic filtering backend, or write an installable app for other developers to use. @@ -223,7 +261,7 @@ For example, you might need to restrict users to only being able to see objects def filter_queryset(self, request, queryset, view): return queryset.filter(owner=request.user) -We could do the same thing by overriding `get_queryset` on the views, but using a filter backend allows you to more easily add this restriction to multiple views, or to apply it across the entire API. +We could achieve the same behavior by overriding `get_queryset()` on the views, but using a filter backend allows you to more easily add this restriction to multiple views, or to apply it across the entire API. [cite]: https://docs.djangoproject.com/en/dev/topics/db/queries/#retrieving-specific-objects-with-filters [django-filter]: https://github.com/alex/django-filter diff --git a/docs/api-guide/relations.md b/docs/api-guide/relations.md index 756e1562..155c89de 100644 --- a/docs/api-guide/relations.md +++ b/docs/api-guide/relations.md @@ -196,15 +196,13 @@ Would serialize to a representation like this: 'artist': 'Thom Yorke' 'track_listing': 'http://www.example.com/api/track_list/12/', } - + This field is always read-only. **Arguments**: * `view_name` - The view name that should be used as the target of the relationship. **required**. -* `slug_field` - The field on the target that should be used for the lookup. Default is `'slug'`. -* `pk_url_kwarg` - The named url parameter for the pk field lookup. Default is `pk`. -* `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`. +* `lookup_field` - The field on the target that should be used for the lookup. Should correspond to a URL keyword argument on the referenced view. Default is `'pk'`. * `format` - If using format suffixes, hyperlinked fields will use the same format suffix for the target unless overridden by using the `format` argument. --- diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md index e354a43a..cd92dc58 100644 --- a/docs/api-guide/viewsets.md +++ b/docs/api-guide/viewsets.md @@ -13,11 +13,11 @@ A `ViewSet` class is simply **a type of class-based View, that does not provide The method handlers for a `ViewSet` are only bound to the corresponding actions at the point of finalizing the view, using the `.as_view()` method. -Typically, rather than exlicitly registering the views in a viewset in the urlconf, you'll register the viewset with a router class, that automatically determines the urlconf for you. +Typically, rather than explicitly registering the views in a viewset in the urlconf, you'll register the viewset with a router class, that automatically determines the urlconf for you. ## Example -Let's define a simple viewset that can be used to listing or retrieving all the users in the system. +Let's define a simple viewset that can be used to list or retrieve all the users in the system. class UserViewSet(viewsets.ViewSet): """ diff --git a/docs/index.md b/docs/index.md index cc960a98..7c38efd3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -113,8 +113,8 @@ Here's our project's root `urls.py` module: # Routers provide an easy way of automatically determining the URL conf router = routers.DefaultRouter() - router.register(r'users', views.UserViewSet) - router.register(r'groups', views.GroupViewSet) + router.register(r'users', UserViewSet) + router.register(r'groups', GroupViewSet) # Wire up our API using automatic URL routing. diff --git a/docs/topics/credits.md b/docs/topics/credits.md index 4eb78d30..8151b4d3 100644 --- a/docs/topics/credits.md +++ b/docs/topics/credits.md @@ -121,6 +121,9 @@ The following people have helped make REST framework great. * Andrew Hughes - [eyepulp] * Daniel Hepper - [dhepper] * Hamish Campbell - [hamishcampbell] +* Marlon Bailey - [avinash240] +* James Summerfield - [jsummerfield] +* Andy Freeland - [rouge8] Many thanks to everyone who's contributed to the project. @@ -278,3 +281,6 @@ You can also contact [@_tomchristie][twitter] directly on twitter. [eyepulp]: https://github.com/eyepulp [dhepper]: https://github.com/dhepper [hamishcampbell]: https://github.com/hamishcampbell +[avinash240]: https://github.com/avinash240 +[jsummerfield]: https://github.com/jsummerfield +[rouge8]: https://github.com/rouge8 diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 259aafdd..560dd305 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -40,6 +40,20 @@ You can determine your currently installed version using `pip freeze`: ## 2.3.x series +### Master + +* Bugfix: HyperlinkedIdentityField now uses `lookup_field` kwarg. + +### 2.3.2 + +**Date**: 16th May 2013 + +* Added SearchFilter +* Added OrderingFilter +* Added GenericViewSet +* Bugfix: Multiple `@action` and `@link` methods now allowed on viewsets. +* Bugfix: Fix API Root view issue with DjangoModelPermissions + ### 2.3.2 **Date**: 8th May 2013 diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index b4961e2f..0b1e67fb 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.2' +__version__ = '2.3.3' VERSION = __version__ # synonym diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1f38b795..491aa7ed 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -15,10 +15,12 @@ import warnings from django.core import validators from django.core.exceptions import ValidationError from django.conf import settings +from django.db.models.fields import BLANK_CHOICE_DASH from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ +from django.utils.datastructures import SortedDict from rest_framework import ISO_8601 from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time @@ -170,7 +172,11 @@ class Field(object): elif hasattr(value, '__iter__') and not isinstance(value, (dict, six.string_types)): return [self.to_native(item) for item in value] elif isinstance(value, dict): - return dict(map(self.to_native, (k, v)) for k, v in value.items()) + # Make sure we preserve field ordering, if it exists + ret = SortedDict() + for key, val in value.items(): + ret[key] = self.to_native(val) + return ret return smart_text(value) def attributes(self): @@ -402,6 +408,8 @@ class ChoiceField(WritableField): def __init__(self, choices=(), *args, **kwargs): super(ChoiceField, self).__init__(*args, **kwargs) self.choices = choices + if not self.required: + self.choices = BLANK_CHOICE_DASH + self.choices def _get_choices(self): return self._choices diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 57f0f7c8..c058bc71 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -4,7 +4,7 @@ returned by list views. """ from __future__ import unicode_literals from django.db import models -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, six from functools import reduce import operator @@ -32,40 +32,33 @@ class DjangoFilterBackend(BaseFilterBackend): def __init__(self): assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed' - def get_filter_class(self, view): + def get_filter_class(self, view, queryset=None): """ Return the django-filters `FilterSet` used to filter the queryset. """ filter_class = getattr(view, 'filter_class', None) filter_fields = getattr(view, 'filter_fields', None) - model_cls = getattr(view, 'model', None) - queryset = getattr(view, 'queryset', None) - if model_cls is None and queryset is not None: - model_cls = queryset.model if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, model_cls), \ - 'FilterSet model %s does not match view model %s' % \ - (filter_model, model_cls) + assert issubclass(filter_model, queryset.model), \ + 'FilterSet model %s does not match queryset model %s' % \ + (filter_model, queryset.model) return filter_class if filter_fields: - assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \ - 'on a view which does not have a .model or .queryset attribute.' - class AutoFilterSet(self.default_filter_set): class Meta: - model = model_cls + model = queryset.model fields = filter_fields return AutoFilterSet return None def filter_queryset(self, request, queryset, view): - filter_class = self.get_filter_class(view) + filter_class = self.get_filter_class(view, queryset) if filter_class: return filter_class(request.QUERY_PARAMS, queryset=queryset).qs @@ -75,7 +68,14 @@ class DjangoFilterBackend(BaseFilterBackend): class SearchFilter(BaseFilterBackend): search_param = 'search' # The URL query parameter used for the search. - delimiter = None # For example, set to ',' for comma delimited searchs. + + def get_search_terms(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.search_param, '') + return params.replace(',', ' ').split() def construct_search(self, field_name): if field_name.startswith('^'): @@ -91,15 +91,53 @@ class SearchFilter(BaseFilterBackend): search_fields = getattr(view, 'search_fields', None) if not search_fields: - return None + return queryset - search_terms = request.QUERY_PARAMS.get(self.search_param) orm_lookups = [self.construct_search(str(search_field)) for search_field in search_fields] - for search_term in search_terms.split(self.delimiter): + for search_term in self.get_search_terms(request): or_queries = [models.Q(**{orm_lookup: search_term}) for orm_lookup in orm_lookups] queryset = queryset.filter(reduce(operator.or_, or_queries)) return queryset + + +class OrderingFilter(BaseFilterBackend): + ordering_param = 'ordering' # The URL query parameter used for the ordering. + + def get_ordering(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.ordering_param) + if params: + return [param.strip() for param in params.split(',')] + + def get_default_ordering(self, view): + ordering = getattr(view, 'ordering', None) + if isinstance(ordering, six.string_types): + return (ordering,) + return ordering + + def remove_invalid_fields(self, queryset, ordering): + field_names = [field.name for field in queryset.model._meta.fields] + return [term for term in ordering if term.lstrip('-') in field_names] + + def filter_queryset(self, request, queryset, view): + ordering = self.get_ordering(request) + + if ordering: + # Skip any incorrect parameters + ordering = self.remove_invalid_fields(queryset, ordering) + + if not ordering: + # Use 'ordering' attribtue by default + ordering = self.get_default_ordering(view) + + if ordering: + return queryset.order_by(*ordering) + + return queryset diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index ae703771..f3cd5868 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -10,6 +10,7 @@ from django.http import Http404 from rest_framework import status from rest_framework.response import Response from rest_framework.request import clone_request +import warnings def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): @@ -42,7 +43,6 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None) class CreateModelMixin(object): """ Create a model instance. - Should be mixed in with any `GenericAPIView`. """ def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA, files=request.FILES) @@ -67,7 +67,6 @@ class CreateModelMixin(object): class ListModelMixin(object): """ List a queryset. - Should be mixed in with `MultipleObjectAPIView`. """ empty_error = "Empty list and '%(class_name)s.allow_empty' is False." @@ -77,6 +76,12 @@ class ListModelMixin(object): # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. if not self.allow_empty and not self.object_list: + warnings.warn( + 'The `allow_empty` parameter is due to be deprecated. ' + 'To use `allow_empty=False` style behavior, You should override ' + '`get_queryset()` and explicitly raise a 404 on empty querysets.', + PendingDeprecationWarning + ) class_name = self.__class__.__name__ error_msg = self.empty_error % {'class_name': class_name} raise Http404(error_msg) @@ -94,7 +99,6 @@ class ListModelMixin(object): class RetrieveModelMixin(object): """ Retrieve a model instance. - Should be mixed in with `SingleObjectAPIView`. """ def retrieve(self, request, *args, **kwargs): self.object = self.get_object() @@ -105,17 +109,22 @@ class RetrieveModelMixin(object): class UpdateModelMixin(object): """ Update a model instance. - Should be mixed in with `SingleObjectAPIView`. """ - def update(self, request, *args, **kwargs): - partial = kwargs.pop('partial', False) - self.object = None + def get_object_or_none(self): try: - self.object = self.get_object() + return self.get_object() except Http404: # If this is a PUT-as-create operation, we need to ensure that # we have relevant permissions, as if this was a POST request. - self.check_permissions(clone_request(request, 'POST')) + # This will either raise a PermissionDenied exception, + # or simply return None + self.check_permissions(clone_request(self.request, 'POST')) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + self.object = self.get_object_or_none() + + if self.object is None: created = True save_kwargs = {'force_insert': True} success_status_code = status.HTTP_201_CREATED @@ -168,7 +177,6 @@ class UpdateModelMixin(object): class DestroyModelMixin(object): """ Destroy a model instance. - Should be mixed in with `SingleObjectAPIView`. """ def destroy(self, request, *args, **kwargs): obj = self.get_object() diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 751f31a7..45fcfd66 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -126,6 +126,11 @@ class DjangoModelPermissions(BasePermission): if model_cls is None and queryset is not None: model_cls = queryset.model + # Workaround to ensure DjangoModelPermissions are not applied + # to the root view when using DefaultRouter. + if model_cls is None and getattr(view, '_ignore_model_permissions'): + return True + assert model_cls, ('Cannot apply DjangoModelPermissions on a view that' ' does not have `.model` or `.queryset` property.') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index fc5054b2..884b954c 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -221,12 +221,20 @@ class PrimaryKeyRelatedField(RelatedField): def field_to_native(self, obj, field_name): if self.many: # To-many relationship - try: + + queryset = None + if not self.source: # Prefer obj.serializable_value for performance reasons - queryset = obj.serializable_value(self.source or field_name) - except AttributeError: + try: + queryset = obj.serializable_value(field_name) + except AttributeError: + pass + if queryset is None: # RelatedManager (reverse relationship) - queryset = getattr(obj, self.source or field_name) + source = self.source or field_name + queryset = obj + for component in source.split('.'): + queryset = get_component(queryset, component) # Forward relationship return [self.to_native(item.pk) for item in queryset.all()] @@ -465,10 +473,13 @@ class HyperlinkedIdentityField(Field): """ Represents the instance, or a property on the instance, using hyperlinking. """ + lookup_field = 'pk' + read_only = True + + # These are all pending deprecation pk_url_kwarg = 'pk' slug_field = 'slug' slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - read_only = True def __init__(self, *args, **kwargs): # TODO: Make view_name mandatory, and have the @@ -477,6 +488,19 @@ class HyperlinkedIdentityField(Field): # Optionally the format of the target hyperlink may be specified self.format = kwargs.pop('format', None) + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + + # These are pending deprecation + if 'pk_url_kwarg' in kwargs: + msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_url_kwarg' in kwargs: + msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_field' in kwargs: + msg = 'slug_field is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + self.slug_field = kwargs.pop('slug_field', self.slug_field) default_slug_kwarg = self.slug_url_kwarg or self.slug_field self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) @@ -488,7 +512,8 @@ class HyperlinkedIdentityField(Field): request = self.context.get('request', None) format = self.context.get('format', None) view_name = self.view_name or self.parent.opts.view_name - kwargs = {self.pk_url_kwarg: obj.pk} + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} if request is None: warnings.warn("Using `HyperlinkedIdentityField` without including the " diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1917a080..8361cd40 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -336,7 +336,7 @@ class BrowsableAPIRenderer(BaseRenderer): return # Cannot use form overloading try: - view.check_permissions(clone_request(request, method)) + view.check_permissions(request) except exceptions.APIException: return False # Doesn't have permissions return True @@ -372,6 +372,30 @@ class BrowsableAPIRenderer(BaseRenderer): return fields + def _get_form(self, view, method, request): + # We need to impersonate a request with the correct method, + # so that eg. any dynamic get_serializer_class methods return the + # correct form for each method. + restore = view.request + request = clone_request(request, method) + view.request = request + try: + return self.get_form(view, method, request) + finally: + view.request = restore + + def _get_raw_data_form(self, view, method, request, media_types): + # We need to impersonate a request with the correct method, + # so that eg. any dynamic get_serializer_class methods return the + # correct form for each method. + restore = view.request + request = clone_request(request, method) + view.request = request + try: + return self.get_raw_data_form(view, method, request, media_types) + finally: + view.request = restore + def get_form(self, view, method, request): """ Get a form, possibly bound to either the input or output data. @@ -465,15 +489,15 @@ class BrowsableAPIRenderer(BaseRenderer): renderer = self.get_default_renderer(view) content = self.get_content(renderer, data, accepted_media_type, renderer_context) - put_form = self.get_form(view, 'PUT', request) - post_form = self.get_form(view, 'POST', request) - patch_form = self.get_form(view, 'PATCH', request) - delete_form = self.get_form(view, 'DELETE', request) - options_form = self.get_form(view, 'OPTIONS', request) + put_form = self._get_form(view, 'PUT', request) + post_form = self._get_form(view, 'POST', request) + patch_form = self._get_form(view, 'PATCH', request) + delete_form = self._get_form(view, 'DELETE', request) + options_form = self._get_form(view, 'OPTIONS', request) - raw_data_put_form = self.get_raw_data_form(view, 'PUT', request, media_types) - raw_data_post_form = self.get_raw_data_form(view, 'POST', request, media_types) - raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request, media_types) + raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types) + raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types) + raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form name = self.get_name(view) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 0707635a..dba104c3 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -16,7 +16,8 @@ For example, you might have a `urls.py` that looks something like this: from __future__ import unicode_literals from collections import namedtuple -from django.conf.urls import url, patterns +from rest_framework import views +from rest_framework.compat import patterns, url from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.reverse import reverse @@ -127,18 +128,18 @@ class SimpleRouter(BaseRouter): """ # Determine any `@action` or `@link` decorated methods on the viewset - dynamic_routes = {} + dynamic_routes = [] for methodname in dir(viewset): attr = getattr(viewset, methodname) httpmethod = getattr(attr, 'bind_to_method', None) if httpmethod: - dynamic_routes[httpmethod] = methodname + dynamic_routes.append((httpmethod, methodname)) ret = [] for route in self.routes: if route.mapping == {'{httpmethod}': '{methodname}'}: # Dynamic routes (@link or @action decorator) - for httpmethod, methodname in dynamic_routes.items(): + for httpmethod, methodname in dynamic_routes: initkwargs = route.initkwargs.copy() initkwargs.update(getattr(viewset, methodname).kwargs) ret.append(Route( @@ -217,14 +218,16 @@ class DefaultRouter(SimpleRouter): for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) - @api_view(('GET',)) - def api_root(request, format=None): - ret = {} - for key, url_name in api_root_dict.items(): - ret[key] = reverse(url_name, request=request, format=format) - return Response(ret) + class APIRoot(views.APIView): + _ignore_model_permissions = True - return api_root + def get(self, request, format=None): + ret = {} + for key, url_name in api_root_dict.items(): + ret[key] = reverse(url_name, request=request, format=format) + return Response(ret) + + return APIRoot.as_view() def get_urls(self): """ diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 9b519f27..9dd7b545 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -4,6 +4,8 @@ DEBUG = True TEMPLATE_DEBUG = DEBUG DEBUG_PROPAGATE_EXCEPTIONS = True +ALLOWED_HOSTS = ['*'] + ADMINS = ( # ('Your Name', 'your_email@domain.com'), ) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d7a4c9ef..500bb306 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -200,7 +200,7 @@ class BaseSerializer(WritableField): # If 'fields' is specified, use those fields, in that order. if self.opts.fields: - assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple' + assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' new = SortedDict() for key in self.opts.fields: new[key] = ret[key] @@ -208,7 +208,7 @@ class BaseSerializer(WritableField): # Remove anything in 'exclude' if self.opts.exclude: - assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple' + assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' for key in self.opts.exclude: ret.pop(key, None) @@ -649,8 +649,14 @@ class ModelSerializer(Serializer): # Add the `read_only` flag to any fields that have bee specified # in the `read_only_fields` option for field_name in self.opts.read_only_fields: + assert field_name not in self.base_fields.keys(), \ + "field '%s' on serializer '%s' specfied in " \ + "`read_only_fields`, but also added " \ + "as an explict field. Remove it from `read_only_fields`." % \ + (field_name, self.__class__.__name__) assert field_name in ret, \ - "read_only_fields on '%s' included invalid item '%s'" % \ + "Noexistant field '%s' specified in `read_only_fields` " \ + "on serializer '%s'." % \ (self.__class__.__name__, field_name) ret[field_name].read_only = True @@ -699,15 +705,14 @@ class ModelSerializer(Serializer): Creates a default instance of a basic non-relational field. """ kwargs = {} - has_default = model_field.has_default() - if model_field.null or model_field.blank or has_default: + if model_field.null or model_field.blank: kwargs['required'] = False if isinstance(model_field, models.AutoField) or not model_field.editable: kwargs['read_only'] = True - if has_default: + if model_field.has_default(): kwargs['default'] = model_field.get_default() if issubclass(model_field.__class__, models.TextField): diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html index 380d5820..a3e73b6b 100644 --- a/rest_framework/templates/rest_framework/login_base.html +++ b/rest_framework/templates/rest_framework/login_base.html @@ -12,44 +12,40 @@ <body class="container"> -<div class="container-fluid" style="margin-top: 30px"> - <div class="row-fluid"> - - <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> + <div class="container-fluid" style="margin-top: 30px"> <div class="row-fluid"> - <div> - {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %} - </div> - </div><!-- /row fluid --> - - <div class="row-fluid"> - <div> - <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> - {% csrf_token %} - <div id="div_id_username" class="clearfix control-group"> - <div class="controls"> - <Label class="span4">Username:</label> - <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> - </div> - </div> - <div id="div_id_password" class="clearfix control-group"> - <div class="controls"> - <Label class="span4">Password:</label> - <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> - </div> + <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> + <div class="row-fluid"> + <div> + {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %} </div> - <input type="hidden" name="next" value="{{ next }}" /> - <div class="form-actions-no-box"> - <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> + </div><!-- /row fluid --> + + <div class="row-fluid"> + <div> + <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> + {% csrf_token %} + <div id="div_id_username" class="clearfix control-group"> + <div class="controls"> + <Label class="span4">Username:</label> + <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> + </div> + </div> + <div id="div_id_password" class="clearfix control-group"> + <div class="controls"> + <Label class="span4">Password:</label> + <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> + </div> + </div> + <input type="hidden" name="next" value="{{ next }}" /> + <div class="form-actions-no-box"> + <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> + </div> + </form> </div> - </form> - </div> - </div><!-- /row fluid --> - </div><!--/span--> - - </div><!-- /.row-fluid --> - </div> - - </div> + </div><!-- /.row-fluid --> + </div><!--/.well--> + </div><!-- /.row-fluid --> + </div><!-- /.container-fluid --> </body> </html> diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 3cdfa0f6..6b1cdfc7 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -2,13 +2,12 @@ General serializer field tests. """ from __future__ import unicode_literals +from django.utils.datastructures import SortedDict import datetime from decimal import Decimal - from django.db import models from django.test import TestCase from django.core import validators - from rest_framework import serializers from rest_framework.serializers import Serializer @@ -63,6 +62,20 @@ class BasicFieldTests(TestCase): serializer = CharPrimaryKeyModelSerializer() self.assertEqual(serializer.fields['id'].read_only, False) + def test_dict_field_ordering(self): + """ + Field should preserve dictionary ordering, if it exists. + See: https://github.com/tomchristie/django-rest-framework/issues/832 + """ + ret = SortedDict() + ret['c'] = 1 + ret['b'] = 1 + ret['a'] = 1 + ret['z'] = 1 + field = serializers.Field() + keys = list(field.to_native(ret).keys()) + self.assertEqual(keys, ['c', 'b', 'a', 'z']) + class DateFieldTest(TestCase): """ @@ -645,4 +658,30 @@ class DecimalFieldTest(TestCase): s = DecimalSerializer(data={'decimal_field': '12345.6'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
\ No newline at end of file + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) + + +class ChoiceFieldTests(TestCase): + """ + Tests for the ChoiceField options generator + """ + + SAMPLE_CHOICES = [ + ('red', 'Red'), + ('green', 'Green'), + ('blue', 'Blue'), + ] + + def test_choices_required(self): + """ + Make sure proper choices are rendered if field is required + """ + f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES) + self.assertEqual(f.choices, self.SAMPLE_CHOICES) + + def test_choices_not_required(self): + """ + Make sure proper choices (plus blank) are rendered if the field isn't required + """ + f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES) + self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filters.py index e5414232..8ae6d530 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filters.py @@ -24,7 +24,7 @@ if django_filters: class FilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # These class are used to test a filter class. class SeveralFieldsFilter(django_filters.FilterSet): @@ -39,7 +39,7 @@ if django_filters: class FilterClassRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # These classes are used to test a misconfigured filter class. class MisconfiguredFilter(django_filters.FilterSet): @@ -52,12 +52,12 @@ if django_filters: class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = MisconfiguredFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) class FilterClassDetailView(generics.RetrieveAPIView): model = FilterableItem filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # Regression test for #814 class FilterableItemSerializer(serializers.ModelSerializer): @@ -68,11 +68,21 @@ if django_filters: queryset = FilterableItem.objects.all() serializer_class = FilterableItemSerializer filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) + + class GetQuerysetView(generics.ListCreateAPIView): + serializer_class = FilterableItemSerializer + filter_class = SeveralFieldsFilter + filter_backends = (filters.DjangoFilterBackend,) + + def get_queryset(self): + return FilterableItem.objects.all() urlpatterns = patterns('', url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'), + url(r'^get-queryset/$', GetQuerysetView.as_view(), + name='get-queryset-view'), ) @@ -148,6 +158,17 @@ class IntegrationTestFiltering(CommonFilteringTestCase): self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_get_queryset_only(self): + """ + Regression test for #834. + """ + view = GetQuerysetView.as_view() + request = factory.get('/get-queryset/') + view(request).render() + # Used to raise "issubclass() arg 2 must be a class or tuple of classes" + # here when neither `model' nor `queryset' was specified. + + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): """ GET requests to filtered ListCreateAPIView that have a filter_class set @@ -222,7 +243,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): """ Integration tests for filtered detail views. """ - urls = 'rest_framework.tests.filterset' + urls = 'rest_framework.tests.filters' def _get_url(self, item): return reverse('detail-view', kwargs=dict(pk=item.pk)) @@ -335,3 +356,119 @@ class SearchFilterTests(TestCase): {'id': 2, 'title': 'zz', 'text': 'bcd'} ] ) + + +class OrdringFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class OrderingFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # zyx abc + # yxw bcd + # xwv cde + for idx in range(3): + title = ( + chr(ord('z') - idx) + + chr(ord('y') - idx) + + chr(ord('x') - idx) + ) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + OrdringFilterModel(title=title, text=text).save() + + def test_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + ] + ) + + def test_reverse_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=-text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_incorrectfield_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=foobar') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering_using_string(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = 'title' + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index eca50d82..15d87e86 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase -from rest_framework import generics, serializers, status +from rest_framework import generics, renderers, serializers, status from rest_framework.tests.utils import RequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.compat import six @@ -39,6 +39,7 @@ class SlugBasedInstanceView(InstanceView): """ model = SlugBasedModel serializer_class = SlugSerializer + lookup_field = 'slug' class TestRootView(TestCase): @@ -434,22 +435,14 @@ class TestFilterBackendAppliedToViews(TestCase): {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] - self.root_view = RootView.as_view() - self.instance_view = InstanceView.as_view() - self.original_root_backend = getattr(RootView, 'filter_backend') - self.original_instance_backend = getattr(InstanceView, 'filter_backend') - - def tearDown(self): - setattr(RootView, 'filter_backend', self.original_root_backend) - setattr(InstanceView, 'filter_backend', self.original_instance_backend) def test_get_root_view_filters_by_name_with_filter_backend(self): """ GET requests to ListCreateAPIView should return filtered list. """ - setattr(RootView, 'filter_backend', InclusiveFilterBackend) + root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) request = factory.get('/') - response = self.root_view(request).render() + response = root_view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 1) self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) @@ -458,9 +451,9 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to ListCreateAPIView should return empty list when all models are filtered out. """ - setattr(RootView, 'filter_backend', ExclusiveFilterBackend) + root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) request = factory.get('/') - response = self.root_view(request).render() + response = root_view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, []) @@ -468,9 +461,9 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. """ - setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend) + instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) request = factory.get('/1') - response = self.instance_view(request, pk=1).render() + response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.data, {'detail': 'Not found'}) @@ -478,8 +471,40 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded """ - setattr(InstanceView, 'filter_backend', InclusiveFilterBackend) + instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) request = factory.get('/1') - response = self.instance_view(request, pk=1).render() + response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) + + +class TwoFieldModel(models.Model): + field_a = models.CharField(max_length=100) + field_b = models.CharField(max_length=100) + + +class DynamicSerializerView(generics.ListCreateAPIView): + model = TwoFieldModel + renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + + def get_serializer_class(self): + if self.request.method == 'POST': + class DynamicSerializer(serializers.ModelSerializer): + class Meta: + model = TwoFieldModel + fields = ('field_b',) + return DynamicSerializer + return super(DynamicSerializerView, self).get_serializer_class() + + +class TestFilterBackendAppliedToViews(TestCase): + + def test_dynamic_serializer_form_in_browsable_api(self): + """ + GET requests to ListCreateAPIView should return filtered list. + """ + view = DynamicSerializerView.as_view() + request = factory.get('/') + response = view(request).render() + self.assertContains(response, 'field_b') + self.assertNotContains(response, 'field_a') diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 9a61f299..8fc6ba77 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer): return Photo(**attrs) +class AlbumSerializer(serializers.ModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') + + class Meta: + model = Album + fields = ('title', 'url') + + class BasicList(generics.ListCreateAPIView): model = BasicModel model_serializer_class = serializers.HyperlinkedModelSerializer @@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView): class AlbumDetail(generics.RetrieveAPIView): model = Album + serializer_class = AlbumSerializer + lookup_field = 'title' class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): @@ -180,6 +190,36 @@ class TestManyToManyHyperlinkedView(TestCase): self.assertEqual(response.data, self.data[0]) +class TestHyperlinkedIdentityFieldLookup(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create 3 Album instances. + """ + titles = ['foo', 'bar', 'baz'] + for title in titles: + album = Album(title=title) + album.save() + self.detail_view = AlbumDetail.as_view() + self.data = { + 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, + 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, + 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} + } + + def test_lookup_field(self): + """ + GET requests to AlbumDetail view should return serialized Albums + with a url field keyed by `title`. + """ + for album in Album.objects.all(): + request = factory.get('/albums/{0}/'.format(album.title)) + response = self.detail_view(request, title=album.title) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[album.title]) + + class TestCreateWithForeignKeys(TestCase): urls = 'rest_framework.tests.hyperlinkedserializers' diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 894d53d6..e538a78e 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -130,7 +130,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): model = FilterableItem paginate_by = 10 filter_class = DecimalFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) view = FilterFieldsRootView.as_view() @@ -177,7 +177,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): class BasicFilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem paginate_by = 10 - filter_backend = DecimalFilterBackend + filter_backends = (DecimalFilterBackend,) view = BasicFilterFieldsRootView.as_view() diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py index cbf93c65..d19219c9 100644 --- a/rest_framework/tests/relations.py +++ b/rest_framework/tests/relations.py @@ -5,6 +5,7 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase from rest_framework import serializers +from rest_framework.tests.models import BlogPost class NullModel(models.Model): @@ -33,7 +34,7 @@ class FieldTests(TestCase): self.assertRaises(serializers.ValidationError, field.from_native, []) -class TestManyRelateMixin(TestCase): +class TestManyRelatedMixin(TestCase): def test_missing_many_to_many_related_field(self): ''' Regression test for #632 @@ -45,3 +46,55 @@ class TestManyRelateMixin(TestCase): into = {} field.field_from_native({}, None, 'field_name', into) self.assertEqual(into['field_name'], []) + + +# Regression tests for #694 (`source` attribute on related fields) + +class RelatedFieldSourceTests(TestCase): + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='get_blogposts_manager') + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='get_blogposts_queryset') + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='a.b.c') + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index b1eed9a7..b3efbf52 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -4,6 +4,7 @@ from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url from rest_framework.tests.models import ( + BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource ) @@ -16,6 +17,7 @@ def dummy_view(request, pk): pass urlpatterns = patterns('', + url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), @@ -451,3 +453,72 @@ class HyperlinkedNullableOneToOneTests(TestCase): {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ] self.assertEqual(serializer.data, expected) + + +# Regression tests for #694 (`source` attribute on related fields) + +class HyperlinkedRelatedFieldSourceTests(TestCase): + urls = 'rest_framework.tests.relations_hyperlink' + + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='get_blogposts_manager', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='get_blogposts_queryset', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='a.b.c', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index 5ce8b567..0f8c5247 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -1,7 +1,10 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +from rest_framework.tests.models import ( + BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, + NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, +) from rest_framework.compat import six @@ -421,3 +424,55 @@ class PKNullableOneToOneTests(TestCase): {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ] self.assertEqual(serializer.data, expected) + + +# Regression tests for #694 (`source` attribute on related fields) + +class PrimaryKeyRelatedFieldSourceTests(TestCase): + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager') + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset') + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c') + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py new file mode 100644 index 00000000..4e4765cb --- /dev/null +++ b/rest_framework/tests/routers.py @@ -0,0 +1,55 @@ +from __future__ import unicode_literals +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import status +from rest_framework.response import Response +from rest_framework import viewsets +from rest_framework.decorators import link, action +from rest_framework.routers import SimpleRouter +import copy + +factory = RequestFactory() + + +class BasicViewSet(viewsets.ViewSet): + def list(self, request, *args, **kwargs): + return Response({'method': 'list'}) + + @action() + def action1(self, request, *args, **kwargs): + return Response({'method': 'action1'}) + + @action() + def action2(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @link() + def link1(self, request, *args, **kwargs): + return Response({'method': 'link1'}) + + @link() + def link2(self, request, *args, **kwargs): + return Response({'method': 'link2'}) + + +class TestSimpleRouter(TestCase): + def setUp(self): + self.router = SimpleRouter() + + def test_link_and_action_decorator(self): + routes = self.router.get_routes(BasicViewSet) + decorator_routes = routes[2:] + # Make sure all these endpoints exist and none have been clobbered + for i, endpoint in enumerate(['action1', 'action2', 'link1', 'link2']): + route = decorator_routes[i] + # check url listing + self.assertEqual(route.url, + '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint)) + # check method to function mapping + if endpoint.startswith('action'): + method_map = 'post' + else: + method_map = 'get' + self.assertEqual(route.mapping[method_map], endpoint) + + diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 6e732327..d0a8570c 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -1,4 +1,6 @@ from __future__ import unicode_literals +from django.db import models +from django.db.models.fields import BLANK_CHOICE_DASH from django.utils.datastructures import MultiValueDict from django.test import TestCase from rest_framework import serializers @@ -89,6 +91,18 @@ class PersonSerializer(serializers.ModelSerializer): read_only_fields = ('age',) +class PersonSerializerInvalidReadOnly(serializers.ModelSerializer): + """ + Testing for #652. + """ + info = serializers.Field(source='info') + + class Meta: + model = Person + fields = ('name', 'age', 'info') + read_only_fields = ('age', 'info') + + class AlbumsSerializer(serializers.ModelSerializer): class Meta: @@ -236,6 +250,12 @@ class BasicTests(TestCase): # Assert age is unchanged (35) self.assertEqual(instance.age, self.person_data['age']) + def test_invalid_read_only_fields(self): + """ + Regression test for #652. + """ + self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, []) + class DictStyleSerializer(serializers.Serializer): """ @@ -900,23 +920,6 @@ class RelatedTraversalTest(TestCase): self.assertEqual(serializer.data, expected) - def test_queryset_nested_traversal(self): - """ - Relational fields should be able to use methods as their source. - """ - BlogPost.objects.create(title='blah') - - class QuerysetMethodSerializer(serializers.Serializer): - blogposts = serializers.RelatedField(many=True, source='get_all_blogposts') - - class ClassWithQuerysetMethod(object): - def get_all_blogposts(self): - return BlogPost.objects - - obj = ClassWithQuerysetMethod() - serializer = QuerysetMethodSerializer(obj) - self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']}) - class SerializerMethodFieldTests(TestCase): def setUp(self): @@ -1047,6 +1050,73 @@ class SerializerPickleTests(TestCase): repr(pickle.loads(pickle.dumps(data, 0))) +# test for issue #725 +class SeveralChoicesModel(models.Model): + color = models.CharField( + max_length=10, + choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')], + blank=False + ) + drink = models.CharField( + max_length=10, + choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')], + blank=False, + default='beer' + ) + os = models.CharField( + max_length=10, + choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')], + blank=True + ) + music_genre = models.CharField( + max_length=10, + choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')], + blank=True, + default='metal' + ) + + +class SerializerChoiceFields(TestCase): + + def setUp(self): + super(SerializerChoiceFields, self).setUp() + + class SeveralChoicesSerializer(serializers.ModelSerializer): + class Meta: + model = SeveralChoicesModel + fields = ('color', 'drink', 'os', 'music_genre') + + self.several_choices_serializer = SeveralChoicesSerializer + + def test_choices_blank_false_not_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['color'].choices, + [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')] + ) + + def test_choices_blank_false_with_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['drink'].choices, + [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')] + ) + + def test_choices_blank_true_not_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['os'].choices, + BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')] + ) + + def test_choices_blank_true_with_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['music_genre'].choices, + BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')] + ) + + class DepthTest(TestCase): def test_implicit_nesting(self): |
