diff options
Diffstat (limited to 'rest_framework')
23 files changed, 878 insertions, 277 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 518ba41a..b3a9b0df 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -301,7 +301,10 @@ class WritableField(Field): try: if self.use_files: files = files or {} - native = files[field_name] + try: + native = files[field_name] + except KeyError: + native = data[field_name] else: native = data[field_name] except KeyError: @@ -504,6 +507,11 @@ class ChoiceField(WritableField): return True return False + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + return super(ChoiceField, self).from_native(value) + class EmailField(CharField): type_name = 'EmailField' diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 8e6b8e26..851f8474 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -14,13 +14,15 @@ from rest_framework.settings import api_settings import warnings -def strict_positive_int(integer_string): +def strict_positive_int(integer_string, cutoff=None): """ Cast a string to a strictly positive integer. """ ret = int(integer_string) if ret <= 0: raise ValueError() + if cutoff: + ret = min(ret, cutoff) return ret def get_object_or_404(queryset, **filter_kwargs): @@ -56,6 +58,7 @@ class GenericAPIView(views.APIView): # Pagination settings paginate_by = api_settings.PAGINATE_BY paginate_by_param = api_settings.PAGINATE_BY_PARAM + max_paginate_by = api_settings.MAX_PAGINATE_BY pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS page_kwarg = 'page' @@ -205,9 +208,11 @@ class GenericAPIView(views.APIView): DeprecationWarning, stacklevel=2) if self.paginate_by_param: - query_params = self.request.QUERY_PARAMS try: - return strict_positive_int(query_params[self.paginate_by_param]) + return strict_positive_int( + self.request.QUERY_PARAMS[self.paginate_by_param], + cutoff=self.max_paginate_by + ) except (KeyError, ValueError): pass diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 679dfa6c..2c85d157 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -142,11 +142,16 @@ class UpdateModelMixin(object): try: return self.get_object() except Http404: - # If this is a PUT-as-create operation, we need to ensure that - # we have relevant permissions, as if this was a POST request. - # This will either raise a PermissionDenied exception, - # or simply return None - self.check_permissions(clone_request(self.request, 'POST')) + if self.request.method == 'PUT': + # For PUT-as-create operation, we need to ensure that we have + # relevant permissions, as if this was a POST request. This + # will either raise a PermissionDenied exception, or simply + # return None. + self.check_permissions(clone_request(self.request, 'POST')) + else: + # PATCH requests where the object does not exist should still + # return a 404 response. + raise def pre_save(self, obj): """ diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 96bfac84..98fc0341 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -10,9 +10,9 @@ from django.core.files.uploadhandler import StopFutureHandlers from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter -from rest_framework.compat import yaml, etree +from rest_framework.compat import etree, six, yaml from rest_framework.exceptions import ParseError -from rest_framework.compat import six +from rest_framework import renderers import json import datetime import decimal @@ -47,6 +47,7 @@ class JSONParser(BaseParser): """ media_type = 'application/json' + renderer_class = renderers.UnicodeJSONRenderer def parse(self, stream, media_type=None, parser_context=None): """ @@ -121,7 +122,8 @@ class MultiPartParser(BaseParser): parser_context = parser_context or {} request = parser_context['request'] encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) - meta = request.META + meta = request.META.copy() + meta['CONTENT_TYPE'] = media_type upload_handlers = request.upload_handlers try: @@ -129,7 +131,7 @@ class MultiPartParser(BaseParser): data, files = parser.parse() return DataAndFiles(data, files) except MultiPartParserError as exc: - raise ParseError('Multipart form parse error - %s' % six.u(exc)) + raise ParseError('Multipart form parse error - %s' % str(exc)) class XMLParser(BaseParser): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index f1f7dea7..417925b5 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -126,9 +126,9 @@ class RelatedField(WritableField): value = obj for component in source.split('.'): - value = get_component(value, component) if value is None: break + value = get_component(value, component) except ObjectDoesNotExist: return None @@ -236,6 +236,8 @@ class PrimaryKeyRelatedField(RelatedField): source = self.source or field_name queryset = obj for component in source.split('.'): + if queryset is None: + return [] queryset = get_component(queryset, component) # Forward relationship @@ -556,8 +558,13 @@ class HyperlinkedIdentityField(Field): May raise a `NoReverseMatch` if the `view_name` and `lookup_field` attributes are not configured to correctly match the URL conf. """ - lookup_field = getattr(obj, self.lookup_field) + lookup_field = getattr(obj, self.lookup_field, None) kwargs = {self.lookup_field: lookup_field} + + # Handle unsaved object case + if lookup_field is None: + return None + try: return reverse(view_name, kwargs=kwargs, request=request, format=format) except NoReverseMatch: diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1006e26c..fca67eee 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -21,10 +21,10 @@ from rest_framework.compat import six from rest_framework.compat import smart_text from rest_framework.compat import yaml from rest_framework.settings import api_settings -from rest_framework.request import clone_request +from rest_framework.request import is_form_media_type, override_method from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework import exceptions, parsers, status, VERSION +from rest_framework import exceptions, status, VERSION class BaseRenderer(object): @@ -36,6 +36,7 @@ class BaseRenderer(object): media_type = None format = None charset = 'utf-8' + render_style = 'text' def render(self, data, accepted_media_type=None, renderer_context=None): raise NotImplemented('Renderer class requires .render() to be implemented') @@ -51,16 +52,17 @@ class JSONRenderer(BaseRenderer): format = 'json' encoder_class = encoders.JSONEncoder ensure_ascii = True - charset = 'utf-8' - # Note that JSON encodings must be utf-8, utf-16 or utf-32. + charset = None + # JSON is a binary encoding, that can be encoded as utf-8, utf-16 or utf-32. # See: http://www.ietf.org/rfc/rfc4627.txt + # Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/ def render(self, data, accepted_media_type=None, renderer_context=None): """ Render `data` into JSON. """ if data is None: - return '' + return bytes() # If 'indent' is provided in the context, then pretty print the result. # E.g. If we're being called by the BrowsableAPIRenderer. @@ -85,13 +87,12 @@ class JSONRenderer(BaseRenderer): # and may (or may not) be unicode. # On python 3.x json.dumps() returns unicode strings. if isinstance(ret, six.text_type): - return bytes(ret.encode(self.charset)) + return bytes(ret.encode('utf-8')) return ret class UnicodeJSONRenderer(JSONRenderer): ensure_ascii = False - charset = 'utf-8' """ Renderer which serializes to JSON. Does *not* apply JSON's character escaping for non-ascii characters. @@ -108,6 +109,7 @@ class JSONPRenderer(JSONRenderer): format = 'jsonp' callback_parameter = 'callback' default_callback = 'callback' + charset = 'utf-8' def get_callback(self, renderer_context): """ @@ -316,6 +318,90 @@ class StaticHTMLRenderer(TemplateHTMLRenderer): return data +class HTMLFormRenderer(BaseRenderer): + """ + Renderers serializer data into an HTML form. + + If the serializer was instantiated without an object then this will + return an HTML form not bound to any object, + otherwise it will return an HTML form with the appropriate initial data + populated from the object. + + Note that rendering of field and form errors is not currently supported. + """ + media_type = 'text/html' + format = 'form' + template = 'rest_framework/form.html' + charset = 'utf-8' + + def data_to_form_fields(self, data): + fields = {} + for key, val in data.fields.items(): + if getattr(val, 'read_only', True): + # Don't include read-only fields. + continue + + if getattr(val, 'fields', None): + # Nested data not supported by HTML forms. + continue + + kwargs = {} + kwargs['required'] = val.required + + #if getattr(v, 'queryset', None): + # kwargs['queryset'] = v.queryset + + if getattr(val, 'choices', None) is not None: + kwargs['choices'] = val.choices + + if getattr(val, 'regex', None) is not None: + kwargs['regex'] = val.regex + + if getattr(val, 'widget', None): + widget = copy.deepcopy(val.widget) + kwargs['widget'] = widget + + if getattr(val, 'default', None) is not None: + kwargs['initial'] = val.default + + if getattr(val, 'label', None) is not None: + kwargs['label'] = val.label + + if getattr(val, 'help_text', None) is not None: + kwargs['help_text'] = val.help_text + + fields[key] = val.form_field_class(**kwargs) + + return fields + + def render(self, data, accepted_media_type=None, renderer_context=None): + """ + Render serializer data and return an HTML form, as a string. + """ + # The HTMLFormRenderer currently uses something of a hack to render + # the content, by translating each of the serializer fields into + # an html form field, creating a dynamic form using those fields, + # and then rendering that form. + + # This isn't strictly neccessary, as we could render the serilizer + # fields to HTML directly. The implementation is historical and will + # likely change at some point. + + self.renderer_context = renderer_context or {} + request = renderer_context['request'] + + # Creating an on the fly form see: + # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python + fields = self.data_to_form_fields(data) + DynamicForm = type(str('DynamicForm'), (forms.Form,), fields) + data = None if data.empty else data + + template = loader.get_template(self.template) + context = RequestContext(request, {'form': DynamicForm(data)}) + + return template.render(context) + + class BrowsableAPIRenderer(BaseRenderer): """ HTML renderer used to self-document the API. @@ -324,6 +410,7 @@ class BrowsableAPIRenderer(BaseRenderer): format = 'api' template = 'rest_framework/api.html' charset = 'utf-8' + form_renderer_class = HTMLFormRenderer def get_default_renderer(self, view): """ @@ -348,7 +435,10 @@ class BrowsableAPIRenderer(BaseRenderer): renderer_context['indent'] = 4 content = renderer.render(data, accepted_media_type, renderer_context) - if renderer.charset is None: + render_style = getattr(renderer, 'render_style', 'text') + assert render_style in ['text', 'binary'], 'Expected .render_style ' \ + '"text" or "binary", but got "%s"' % render_style + if render_style == 'binary': return '[%d bytes of binary content]' % len(content) return content @@ -371,130 +461,99 @@ class BrowsableAPIRenderer(BaseRenderer): return False # Doesn't have permissions return True - def serializer_to_form_fields(self, serializer): - fields = {} - for k, v in serializer.get_fields().items(): - if getattr(v, 'read_only', True): - continue - - kwargs = {} - kwargs['required'] = v.required - - #if getattr(v, 'queryset', None): - # kwargs['queryset'] = v.queryset - - if getattr(v, 'choices', None) is not None: - kwargs['choices'] = v.choices - - if getattr(v, 'regex', None) is not None: - kwargs['regex'] = v.regex - - if getattr(v, 'widget', None): - widget = copy.deepcopy(v.widget) - kwargs['widget'] = widget - - if getattr(v, 'default', None) is not None: - kwargs['initial'] = v.default - - if getattr(v, 'label', None) is not None: - kwargs['label'] = v.label - - if getattr(v, 'help_text', None) is not None: - kwargs['help_text'] = v.help_text - - fields[k] = v.form_field_class(**kwargs) - - return fields - - def _get_form(self, view, method, request): - # We need to impersonate a request with the correct method, - # so that eg. any dynamic get_serializer_class methods return the - # correct form for each method. - restore = view.request - request = clone_request(request, method) - view.request = request - try: - return self.get_form(view, method, request) - finally: - view.request = restore - - def _get_raw_data_form(self, view, method, request, media_types): - # We need to impersonate a request with the correct method, - # so that eg. any dynamic get_serializer_class methods return the - # correct form for each method. - restore = view.request - request = clone_request(request, method) - view.request = request - try: - return self.get_raw_data_form(view, method, request, media_types) - finally: - view.request = restore - - def get_form(self, view, method, request): + def get_rendered_html_form(self, view, method, request): """ - Get a form, possibly bound to either the input or output data. - In the absence on of the Resource having an associated form then - provide a form that can be used to submit arbitrary content. + Return a string representing a rendered HTML form, possibly bound to + either the input or output data. + + In the absence of the View having an associated form then return None. """ - obj = getattr(view, 'object', None) - if not self.show_form_for_method(view, method, request, obj): - return + with override_method(view, request, method) as request: + obj = getattr(view, 'object', None) + if not self.show_form_for_method(view, method, request, obj): + return - if method in ('DELETE', 'OPTIONS'): - return True # Don't actually need to return a form + if method in ('DELETE', 'OPTIONS'): + return True # Don't actually need to return a form - if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes: - return + if (not getattr(view, 'get_serializer', None) + or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)): + return - serializer = view.get_serializer(instance=obj) - fields = self.serializer_to_form_fields(serializer) + serializer = view.get_serializer(instance=obj) - # Creating an on the fly form see: - # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python - OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields) - data = (obj is not None) and serializer.data or None - form_instance = OnTheFlyForm(data) - return form_instance + data = serializer.data + form_renderer = self.form_renderer_class() + return form_renderer.render(data, self.accepted_media_type, self.renderer_context) - def get_raw_data_form(self, view, method, request, media_types): + def get_raw_data_form(self, view, method, request): """ Returns a form that allows for arbitrary content types to be tunneled via standard HTML forms. (Which are typically application/x-www-form-urlencoded) """ - - # If we're not using content overloading there's no point in supplying a generic form, - # as the view won't treat the form's value as the content of the request. - if not (api_settings.FORM_CONTENT_OVERRIDE - and api_settings.FORM_CONTENTTYPE_OVERRIDE): - return None - - # Check permissions - obj = getattr(view, 'object', None) - if not self.show_form_for_method(view, method, request, obj): - return - - content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE - content_field = api_settings.FORM_CONTENT_OVERRIDE - choices = [(media_type, media_type) for media_type in media_types] - initial = media_types[0] - - # NB. http://jacobian.org/writing/dynamic-form-generation/ - class GenericContentForm(forms.Form): - def __init__(self): - super(GenericContentForm, self).__init__() - - self.fields[content_type_field] = forms.ChoiceField( - label='Media type', - choices=choices, - initial=initial - ) - self.fields[content_field] = forms.CharField( - label='Content', - widget=forms.Textarea - ) - - return GenericContentForm() + with override_method(view, request, method) as request: + # If we're not using content overloading there's no point in + # supplying a generic form, as the view won't treat the form's + # value as the content of the request. + if not (api_settings.FORM_CONTENT_OVERRIDE + and api_settings.FORM_CONTENTTYPE_OVERRIDE): + return None + + # Check permissions + obj = getattr(view, 'object', None) + if not self.show_form_for_method(view, method, request, obj): + return + + # If possible, serialize the initial content for the generic form + default_parser = view.parser_classes[0] + renderer_class = getattr(default_parser, 'renderer_class', None) + if (hasattr(view, 'get_serializer') and renderer_class): + # View has a serializer defined and parser class has a + # corresponding renderer that can be used to render the data. + + # Get a read-only version of the serializer + serializer = view.get_serializer(instance=obj) + if obj is None: + for name, field in serializer.fields.items(): + if getattr(field, 'read_only', None): + del serializer.fields[name] + + # Render the raw data content + renderer = renderer_class() + accepted = self.accepted_media_type + context = self.renderer_context.copy() + context['indent'] = 4 + content = renderer.render(serializer.data, accepted, context) + else: + content = None + + # Generate a generic form that includes a content type field, + # and a content field. + content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE + content_field = api_settings.FORM_CONTENT_OVERRIDE + + media_types = [parser.media_type for parser in view.parser_classes] + choices = [(media_type, media_type) for media_type in media_types] + initial = media_types[0] + + # NB. http://jacobian.org/writing/dynamic-form-generation/ + class GenericContentForm(forms.Form): + def __init__(self): + super(GenericContentForm, self).__init__() + + self.fields[content_type_field] = forms.ChoiceField( + label='Media type', + choices=choices, + initial=initial + ) + self.fields[content_field] = forms.CharField( + label='Content', + widget=forms.Textarea, + initial=content + ) + + return GenericContentForm() def get_name(self, view): return view.get_view_name() @@ -509,26 +568,25 @@ class BrowsableAPIRenderer(BaseRenderer): """ Render the HTML for the browsable API representation. """ - accepted_media_type = accepted_media_type or '' - renderer_context = renderer_context or {} + self.accepted_media_type = accepted_media_type or '' + self.renderer_context = renderer_context or {} view = renderer_context['view'] request = renderer_context['request'] response = renderer_context['response'] - media_types = [parser.media_type for parser in view.parser_classes] renderer = self.get_default_renderer(view) content = self.get_content(renderer, data, accepted_media_type, renderer_context) - put_form = self._get_form(view, 'PUT', request) - post_form = self._get_form(view, 'POST', request) - patch_form = self._get_form(view, 'PATCH', request) - delete_form = self._get_form(view, 'DELETE', request) - options_form = self._get_form(view, 'OPTIONS', request) + put_form = self.get_rendered_html_form(view, 'PUT', request) + post_form = self.get_rendered_html_form(view, 'POST', request) + patch_form = self.get_rendered_html_form(view, 'PATCH', request) + delete_form = self.get_rendered_html_form(view, 'DELETE', request) + options_form = self.get_rendered_html_form(view, 'OPTIONS', request) - raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types) - raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types) - raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types) + raw_data_put_form = self.get_raw_data_form(view, 'PUT', request) + raw_data_post_form = self.get_raw_data_form(view, 'POST', request) + raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form name = self.get_name(view) @@ -581,3 +639,4 @@ class MultiPartRenderer(BaseRenderer): def render(self, data, accepted_media_type=None, renderer_context=None): return encode_multipart(self.BOUNDARY, data) + diff --git a/rest_framework/request.py b/rest_framework/request.py index 919716f4..977d4d96 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -28,6 +28,29 @@ def is_form_media_type(media_type): base_media_type == 'multipart/form-data') +class override_method(object): + """ + A context manager that temporarily overrides the method on a request, + additionally setting the `view.request` attribute. + + Usage: + + with override_method(view, request, 'POST') as request: + ... # Do stuff with `view` and `request` + """ + def __init__(self, view, request, method): + self.view = view + self.request = request + self.method = method + + def __enter__(self): + self.view.request = clone_request(self.request, self.method) + return self.view.request + + def __exit__(self, *args, **kwarg): + self.view.request = self.request + + class Empty(object): """ Placeholder for unset attributes. diff --git a/rest_framework/routers.py b/rest_framework/routers.py index b761ba9a..1c7a8158 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -213,7 +213,11 @@ class SimpleRouter(BaseRouter): Given a viewset, return the portion of URL regex that is used to match against a single instance. """ - base_regex = '(?P<{lookup_field}>[^/]+)' + if self.trailing_slash: + base_regex = '(?P<{lookup_field}>[^/]+)' + else: + # Don't consume `.json` style suffixes + base_regex = '(?P<{lookup_field}>[^/.]+)' lookup_field = getattr(viewset, 'lookup_field', 'pk') return base_regex.format(lookup_field=lookup_field) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b3850157..f1775762 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -31,6 +31,9 @@ from rest_framework.relations import * from rest_framework.fields import * +class RelationsList(list): + _deleted = [] + class NestedValidationError(ValidationError): """ The default ValidationError behavior is to stringify each item in the list @@ -160,7 +163,6 @@ class BaseSerializer(WritableField): self._data = None self._files = None self._errors = None - self._deleted = None if many and instance is not None and not hasattr(instance, '__iter__'): raise ValueError('instance should be a queryset or other iterable with many=True') @@ -297,7 +299,8 @@ class BaseSerializer(WritableField): Serialize objects -> primitives. """ ret = self._dict_class() - ret.fields = {} + ret.fields = self._dict_class() + ret.empty = obj is None for field_name, field in self.fields.items(): field.initialize(parent=self, field_name=field_name) @@ -330,14 +333,15 @@ class BaseSerializer(WritableField): if self.source == '*': return self.to_native(obj) + # Get the raw field value try: source = self.source or field_name value = obj for component in source.split('.'): - value = get_component(value, component) if value is None: break + value = get_component(value, component) except ObjectDoesNotExist: return None @@ -372,6 +376,7 @@ class BaseSerializer(WritableField): # Set the serializer object if it exists obj = getattr(self.parent.object, field_name) if self.parent.object else None + obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj if self.source == '*': if value: @@ -385,7 +390,8 @@ class BaseSerializer(WritableField): 'data': value, 'context': self.context, 'partial': self.partial, - 'many': self.many + 'many': self.many, + 'allow_add_remove': self.allow_add_remove } serializer = self.__class__(**kwargs) @@ -418,8 +424,17 @@ class BaseSerializer(WritableField): if self._errors is None: data, files = self.init_data, self.init_files - if self.many: - ret = [] + if self.many is not None: + many = self.many + else: + many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) + if many: + warnings.warn('Implict list/queryset serialization is deprecated. ' + 'Use the `many=True` flag when instantiating the serializer.', + DeprecationWarning, stacklevel=3) + + if many: + ret = RelationsList() errors = [] update = self.object is not None @@ -446,8 +461,8 @@ class BaseSerializer(WritableField): ret.append(self.from_native(item, None)) errors.append(self._errors) - if update: - self._deleted = identity_to_objects.values() + if update and self.allow_add_remove: + ret._deleted = identity_to_objects.values() self._errors = any(errors) and errors or [] else: @@ -490,12 +505,12 @@ class BaseSerializer(WritableField): """ if isinstance(self.object, list): [self.save_object(item, **kwargs) for item in self.object] + + if self.object._deleted: + [self.delete_object(item) for item in self.object._deleted] else: self.save_object(self.object, **kwargs) - if self.allow_add_remove and self._deleted: - [self.delete_object(item) for item in self._deleted] - return self.object def metadata(self): @@ -771,9 +786,12 @@ class ModelSerializer(Serializer): cls = self.opts.model opts = get_concrete_model(cls)._meta exclusions = [field.name for field in opts.fields + opts.many_to_many] + for field_name, field in self.fields.items(): field_name = field.source or field_name - if field_name in exclusions and not field.read_only: + if field_name in exclusions \ + and not field.read_only \ + and not isinstance(field, Serializer): exclusions.remove(field_name) return exclusions @@ -799,6 +817,7 @@ class ModelSerializer(Serializer): """ m2m_data = {} related_data = {} + nested_forward_relations = {} meta = self.opts.model._meta # Reverse fk or one-to-one relations @@ -818,6 +837,12 @@ class ModelSerializer(Serializer): if field.name in attrs: m2m_data[field.name] = attrs.pop(field.name) + # Nested forward relations - These need to be marked so we can save + # them before saving the parent model instance. + for field_name in attrs.keys(): + if isinstance(self.fields.get(field_name, None), Serializer): + nested_forward_relations[field_name] = attrs[field_name] + # Update an existing instance... if instance is not None: for key, val in attrs.items(): @@ -833,6 +858,7 @@ class ModelSerializer(Serializer): # at the point of save. instance._related_data = related_data instance._m2m_data = m2m_data + instance._nested_forward_relations = nested_forward_relations return instance @@ -848,6 +874,14 @@ class ModelSerializer(Serializer): """ Save the deserialized object and return it. """ + if getattr(obj, '_nested_forward_relations', None): + # Nested relationships need to be saved before we can save the + # parent instance. + for field_name, sub_object in obj._nested_forward_relations.items(): + if sub_object: + self.save_object(sub_object) + setattr(obj, field_name, sub_object) + obj.save(**kwargs) if getattr(obj, '_m2m_data', None): @@ -857,7 +891,25 @@ class ModelSerializer(Serializer): if getattr(obj, '_related_data', None): for accessor_name, related in obj._related_data.items(): - setattr(obj, accessor_name, related) + if isinstance(related, RelationsList): + # Nested reverse fk relationship + for related_item in related: + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related_item, fk_field, obj) + self.save_object(related_item) + + # Delete any removed objects + if related._deleted: + [self.delete_object(item) for item in related._deleted] + + elif isinstance(related, models.Model): + # Nested reverse one-one relationship + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related, fk_field, obj) + self.save_object(related) + else: + # Reverse FK or reverse one-one + setattr(obj, accessor_name, related) del(obj._related_data) @@ -879,6 +931,7 @@ class HyperlinkedModelSerializer(ModelSerializer): _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' _hyperlink_field_class = HyperlinkedRelatedField + _hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -887,7 +940,7 @@ class HyperlinkedModelSerializer(ModelSerializer): self.opts.view_name = self._get_default_view_name(self.opts.model) if 'url' not in fields: - url_field = HyperlinkedIdentityField( + url_field = self._hyperlink_identify_field_class( view_name=self.opts.view_name, lookup_field=self.opts.lookup_field ) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 7d25e513..8c084751 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -48,7 +48,6 @@ DEFAULTS = { ), 'DEFAULT_THROTTLE_CLASSES': ( ), - 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', @@ -68,15 +67,16 @@ DEFAULTS = { # Pagination 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, - - # View configuration - 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', - 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', + 'MAX_PAGINATE_BY': None, # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, + # View configuration + 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', + 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', + # Testing 'TEST_REQUEST_RENDERER_CLASSES': ( 'rest_framework.renderers.MultiPartRenderer', diff --git a/rest_framework/static/rest_framework/js/default.js b/rest_framework/static/rest_framework/js/default.js index c74829d7..bcb1964d 100644 --- a/rest_framework/static/rest_framework/js/default.js +++ b/rest_framework/static/rest_framework/js/default.js @@ -1,13 +1,56 @@ +function getCookie(c_name) +{ + // From http://www.w3schools.com/js/js_cookies.asp + var c_value = document.cookie; + var c_start = c_value.indexOf(" " + c_name + "="); + if (c_start == -1) { + c_start = c_value.indexOf(c_name + "="); + } + if (c_start == -1) { + c_value = null; + } else { + c_start = c_value.indexOf("=", c_start) + 1; + var c_end = c_value.indexOf(";", c_start); + if (c_end == -1) { + c_end = c_value.length; + } + c_value = unescape(c_value.substring(c_start,c_end)); + } + return c_value; +} + +// JSON highlighting. prettyPrint(); +// Bootstrap tooltips. $('.js-tooltip').tooltip({ delay: 1000 }); +// Deal with rounded tab styling after tab clicks. $('a[data-toggle="tab"]:first').on('shown', function (e) { $(e.target).parents('.tabbable').addClass('first-tab-active'); }); $('a[data-toggle="tab"]:not(:first)').on('shown', function (e) { $(e.target).parents('.tabbable').removeClass('first-tab-active'); }); -$('.form-switcher a:first').tab('show'); + +$('a[data-toggle="tab"]').click(function(){ + document.cookie="tabstyle=" + this.name + "; path=/"; +}); + +// Store tab preference in cookies & display appropriate tab on load. +var selectedTab = null; +var selectedTabName = getCookie('tabstyle'); + +if (selectedTabName) { + selectedTab = $('.form-switcher a[name=' + selectedTabName + ']'); +} + +if (selectedTab && selectedTab.length > 0) { + // Display whichever tab is selected. + selectedTab.tab('show'); +} else { + // If no tab selected, display rightmost tab. + $('.form-switcher a:first').tab('show'); +} diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 51f9c291..aa90e90c 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -128,17 +128,17 @@ <div {% if post_form %}class="tabbable"{% endif %}> {% if post_form %} <ul class="nav nav-tabs form-switcher"> - <li><a href="#object-form" data-toggle="tab">HTML form</a></li> - <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li> + <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li> + <li><a name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li> </ul> {% endif %} <div class="well tab-content"> {% if post_form %} <div class="tab-pane" id="object-form"> {% with form=post_form %} - <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> + <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal"> <fieldset> - {% include "rest_framework/form.html" %} + {{ post_form }} <div class="form-actions"> <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button> </div> @@ -167,23 +167,21 @@ <div {% if put_form %}class="tabbable"{% endif %}> {% if put_form %} <ul class="nav nav-tabs form-switcher"> - <li><a href="#object-form" data-toggle="tab">HTML form</a></li> - <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li> + <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li> + <li><a name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li> </ul> {% endif %} <div class="well tab-content"> {% if put_form %} <div class="tab-pane" id="object-form"> - {% with form=put_form %} - <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> + <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal"> <fieldset> - {% include "rest_framework/form.html" %} + {{ put_form }} <div class="form-actions"> <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button> </div> </fieldset> </form> - {% endwith %} </div> {% endif %} <div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form"> diff --git a/rest_framework/test.py b/rest_framework/test.py index a18f5a29..234d10a4 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient): """ self.handler._force_user = user self.handler._force_token = token + if user is None: + self.logout() # Also clear any possible session info if required def request(self, **kwargs): # Ensure that any credentials set get added to every request. diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index ebccba7d..34fbab9c 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase): f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES) self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES) + result = f.from_native('') + self.assertEqual(result, None) + class EmailFieldTests(TestCase): """ diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py index 487046ac..c13c38b8 100644 --- a/rest_framework/tests/test_files.py +++ b/rest_framework/tests/test_files.py @@ -7,13 +7,13 @@ import datetime class UploadedFile(object): - def __init__(self, file, created=None): + def __init__(self, file=None, created=None): self.file = file self.created = created or datetime.datetime.now() class UploadedFileSerializer(serializers.Serializer): - file = serializers.FileField() + file = serializers.FileField(required=False) created = serializers.DateTimeField() def restore_object(self, attrs, instance=None): @@ -47,5 +47,36 @@ class FileSerializerTests(TestCase): now = datetime.datetime.now() serializer = UploadedFileSerializer(data={'created': now}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.created, now) + self.assertIsNone(serializer.object.file) + + def test_remove_with_empty_string(self): + """ + Passing empty string as data should cause file to be removed + + Test for: + https://github.com/tomchristie/django-rest-framework/issues/937 + """ + now = datetime.datetime.now() + file = BytesIO(six.b('stuff')) + file.name = 'stuff.txt' + file.size = len(file.getvalue()) + + uploaded_file = UploadedFile(file=file, created=now) + + serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.created, uploaded_file.created) + self.assertIsNone(serializer.object.file) + + def test_validation_error_with_non_file(self): + """ + Passing non-files should raise a validation error. + """ + now = datetime.datetime.now() + errmsg = 'No file was submitted. Check the encoding type on the form.' + + serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) self.assertFalse(serializer.is_valid()) - self.assertIn('file', serializer.errors) + self.assertEqual(serializer.errors, {'file': [errmsg]}) diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 1550880b..7a87d389 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -338,6 +338,17 @@ class TestInstanceView(TestCase): new_obj = SlugBasedModel.objects.get(slug='test_slug') self.assertEqual(new_obj.text, 'foobar') + def test_patch_cannot_create_an_object(self): + """ + PATCH requests should not be able to create objects. + """ + data = {'text': 'foobar'} + request = factory.patch('/999', data, format='json') + with self.assertNumQueries(1): + response = self.view(request, pk=999).render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertFalse(self.objects.filter(id=999).exists()) + class TestOverriddenGetObject(TestCase): """ diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index 85d4640e..4170d4b6 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView): paginate_by_param = 'page_size' +class MaxPaginateByView(generics.ListAPIView): + """ + View for testing custom max_paginate_by usage + """ + model = BasicModel + paginate_by = 3 + max_paginate_by = 5 + paginate_by_param = 'page_size' + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase): self.assertEqual(response.data['results'], self.data[:5]) +class TestMaxPaginateByParam(TestCase): + """ + Tests for list views with max_paginate_by kwarg + """ + + def setUp(self): + """ + Create 13 BasicModel instances. + """ + for i in range(13): + BasicModel(text=i).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = MaxPaginateByView.as_view() + + def test_max_paginate_by(self): + """ + If max_paginate_by is set, it should limit page size for the view. + """ + request = factory.get('/?page_size=10') + response = self.view(request).render() + self.assertEqual(response.data['count'], 13) + self.assertEqual(response.data['results'], self.data[:5]) + + def test_max_paginate_by_without_page_size_param(self): + """ + If max_paginate_by is set, but client does not specifiy page_size, + standard `paginate_by` behavior should be used. + """ + request = factory.get('/') + response = self.view(request).render() + self.assertEqual(response.data['results'], self.data[:3]) + + ### Tests for context in pagination serializers class CustomField(serializers.Field): diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index f6d006b3..d393b0c3 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -1,107 +1,328 @@ from __future__ import unicode_literals +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource -class ForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeySource - fields = ('id', 'name', 'target') - depth = 1 +class OneToOneTarget(models.Model): + name = models.CharField(max_length=100) -class ForeignKeyTargetSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeyTarget - fields = ('id', 'name', 'sources') - depth = 1 +class OneToOneSource(models.Model): + name = models.CharField(max_length=100) + target = models.OneToOneField(OneToOneTarget, related_name='source', + null=True, blank=True) -class NullableForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - model = NullableForeignKeySource - fields = ('id', 'name', 'target') - depth = 1 +class OneToManyTarget(models.Model): + name = models.CharField(max_length=100) -class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - class Meta: - model = OneToOneTarget - fields = ('id', 'name', 'nullable_source') - depth = 1 +class OneToManySource(models.Model): + name = models.CharField(max_length=100) + target = models.ForeignKey(OneToManyTarget, related_name='sources') -class ReverseForeignKeyTests(TestCase): +class ReverseNestedOneToOneTests(TestCase): def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - new_target = ForeignKeyTarget(name='target-2') - new_target.save() + class OneToOneSourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneSource + fields = ('id', 'name') + + class OneToOneTargetSerializer(serializers.ModelSerializer): + source = OneToOneSourceSerializer() + + class Meta: + model = OneToOneTarget + fields = ('id', 'name', 'source') + + self.Serializer = OneToOneTargetSerializer + for idx in range(1, 4): - source = ForeignKeySource(name='source-%d' % idx, target=target) + target = OneToOneTarget(name='target-%d' % idx) + target.save() + source = OneToOneSource(name='source-%d' % idx, target=target) source.save() - def test_foreign_key_retrieve(self): - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) + def test_one_to_one_retrieve(self): + queryset = OneToOneTarget.objects.all() + serializer = self.Serializer(queryset, many=True) expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}} ] self.assertEqual(serializer.data, expected) - def test_reverse_foreign_key_retrieve(self): - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) + def test_one_to_one_create(self): + data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} + serializer = self.Serializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-4') + + # Ensure (target 4, target_source 4, source 4) are added, and + # everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = self.Serializer(queryset, many=True) expected = [ - {'id': 1, 'name': 'target-1', 'sources': [ - {'id': 1, 'name': 'source-1', 'target': 1}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': 1}, - ]}, - {'id': 2, 'name': 'target-2', 'sources': [ - ]} + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, + {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} ] self.assertEqual(serializer.data, expected) + def test_one_to_one_create_with_invalid_data(self): + data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) -class NestedNullableForeignKeyTests(TestCase): + def test_one_to_one_update(self): + data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} + instance = OneToOneTarget.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3-updated') + + # Ensure (target 3, target_source 3, source 3) are updated, + # and everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} + ] + self.assertEqual(serializer.data, expected) + + +class ForwardNestedOneToOneTests(TestCase): def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() + class OneToOneTargetSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneTarget + fields = ('id', 'name') + + class OneToOneSourceSerializer(serializers.ModelSerializer): + target = OneToOneTargetSerializer() + + class Meta: + model = OneToOneSource + fields = ('id', 'name', 'target') + + self.Serializer = OneToOneSourceSerializer + for idx in range(1, 4): - if idx == 3: - target = None - source = NullableForeignKeySource(name='source-%d' % idx, target=target) + target = OneToOneTarget(name='target-%d' % idx) + target.save() + source = OneToOneSource(name='source-%d' % idx, target=target) source.save() - def test_foreign_key_retrieve_with_null(self): - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) + def test_one_to_one_retrieve(self): + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_one_create(self): + data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} + serializer = self.Serializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') + + # Ensure (target 4, target_source 4, source 4) are added, and + # everything else is as expected. + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, + {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_one_create_with_invalid_data(self): + data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) + + def test_one_to_one_update(self): + data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} + instance = OneToOneSource.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-3-updated') + + # Ensure (target 3, target_source 3, source 3) are updated, + # and everything else is as expected. + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} ] self.assertEqual(serializer.data, expected) + def test_one_to_one_update_to_null(self): + data = {'id': 3, 'name': 'source-3-updated', 'target': None} + instance = OneToOneSource.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() -class NestedNullableOneToOneTests(TestCase): + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-3-updated') + self.assertEqual(obj.target, None) + + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3-updated', 'target': None} + ] + self.assertEqual(serializer.data, expected) + + # TODO: Nullable 1-1 tests + # def test_one_to_one_delete(self): + # data = {'id': 3, 'name': 'target-3', 'target_source': None} + # instance = OneToOneTarget.objects.get(pk=3) + # serializer = self.Serializer(instance, data=data) + # self.assertTrue(serializer.is_valid()) + # serializer.save() + + # # Ensure (target_source 3, source 3) are deleted, + # # and everything else is as expected. + # queryset = OneToOneTarget.objects.all() + # serializer = self.Serializer(queryset) + # expected = [ + # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + # {'id': 3, 'name': 'target-3', 'source': None} + # ] + # self.assertEqual(serializer.data, expected) + + +class ReverseNestedOneToManyTests(TestCase): def setUp(self): - target = OneToOneTarget(name='target-1') + class OneToManySourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToManySource + fields = ('id', 'name') + + class OneToManyTargetSerializer(serializers.ModelSerializer): + sources = OneToManySourceSerializer(many=True, allow_add_remove=True) + + class Meta: + model = OneToManyTarget + fields = ('id', 'name', 'sources') + + self.Serializer = OneToManyTargetSerializer + + target = OneToManyTarget(name='target-1') target.save() - new_target = OneToOneTarget(name='target-2') - new_target.save() - source = NullableOneToOneSource(name='source-1', target=target) - source.save() + for idx in range(1, 4): + source = OneToManySource(name='source-%d' % idx, target=target) + source.save() - def test_reverse_foreign_key_retrieve_with_null(self): - queryset = OneToOneTarget.objects.all() - serializer = NullableOneToOneTargetSerializer(queryset, many=True) + def test_one_to_many_retrieve(self): + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]}, + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1') + + # Ensure source 4 is added, and everything else is as + # expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset, many=True) expected = [ - {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, - {'id': 2, 'name': 'target-2', 'nullable_source': None}, + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create_with_invalid_data(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4}]} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) + + def test_one_to_many_update(self): + data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1-updated') + + # Ensure (target 1, source 1) are updated, + # and everything else is as expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_delete(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 3, 'name': 'source-3'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() + + # Ensure source 2 is deleted, and everything else is as + # expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 3, 'name': 'source-3'}]} + ] self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index c3597e38..3f456fef 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase): self.urls = self.router.urls def test_urls_can_have_trailing_slash_removed(self): - expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] + expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$'] for idx in range(len(expected)): self.assertEqual(expected[idx], self.urls[idx].regex.pattern) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 49d45fc2..48b8956b 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -17,8 +17,18 @@ def view(request): }) +@api_view(['GET', 'POST']) +def session_view(request): + active_session = request.session.get('active_session', False) + request.session['active_session'] = True + return Response({ + 'active_session': active_session + }) + + urlpatterns = patterns('', url(r'^view/$', view), + url(r'^session-view/$', session_view), ) @@ -46,6 +56,26 @@ class TestAPITestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['user'], 'example') + def test_force_authenticate_with_sessions(self): + """ + Setting `.force_authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.force_authenticate(user) + + # First request does not yet have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + + # Subsequant requests have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], True) + + # Force authenticating as `None` should also logout the user session. + self.client.force_authenticate(None) + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + def test_csrf_exempt_by_default(self): """ By default, the test client is CSRF exempt. diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 65b45593..a946d837 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -2,7 +2,7 @@ Provides various throttling policies. """ from __future__ import unicode_literals -from django.core.cache import cache +from django.core.cache import cache as default_cache from django.core.exceptions import ImproperlyConfigured from rest_framework.settings import api_settings import time @@ -39,6 +39,7 @@ class SimpleRateThrottle(BaseThrottle): Previous request information used for throttling is stored in the cache. """ + cache = default_cache timer = time.time cache_format = 'throtte_%(scope)s_%(ident)s' scope = None @@ -99,7 +100,7 @@ class SimpleRateThrottle(BaseThrottle): if self.key is None: return True - self.history = cache.get(self.key, []) + self.history = self.cache.get(self.key, []) self.now = self.timer() # Drop any requests from the history which have now passed the @@ -116,7 +117,7 @@ class SimpleRateThrottle(BaseThrottle): into the cache. """ self.history.insert(0, self.now) - cache.set(self.key, self.history, self.duration) + self.cache.set(self.key, self.history, self.duration) return True def throttle_failure(self): @@ -151,7 +152,9 @@ class AnonRateThrottle(SimpleRateThrottle): if request.user.is_authenticated(): return None # Only throttle unauthenticated requests. - ident = request.META.get('REMOTE_ADDR', None) + ident = request.META.get('HTTP_X_FORWARDED_FOR') + if ident is None: + ident = request.META.get('REMOTE_ADDR') return self.cache_format % { 'scope': self.scope, diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 0384faba..e6690d17 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -8,8 +8,11 @@ def get_breadcrumbs(url): tuple of (name, url). """ + from rest_framework.settings import api_settings from rest_framework.views import APIView + view_name_func = api_settings.VIEW_NAME_FUNCTION + def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): """ Add tuples of (name, url) to the breadcrumbs list, @@ -28,8 +31,8 @@ def get_breadcrumbs(url): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - instance = view.cls() - name = instance.get_view_name() + suffix = getattr(view, 'suffix', None) + name = view_name_func(cls, suffix) breadcrumbs_list.insert(0, (name, prefix + url)) seen.append(view) diff --git a/rest_framework/views.py b/rest_framework/views.py index 727a9f95..4cff0422 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -15,8 +15,14 @@ from rest_framework.settings import api_settings from rest_framework.utils import formatting -def get_view_name(cls, suffix=None): - name = cls.__name__ +def get_view_name(view_cls, suffix=None): + """ + Given a view class, return a textual name to represent the view. + This name is used in the browsable API, and in OPTIONS responses. + + This function is the default for the `VIEW_NAME_FUNCTION` setting. + """ + name = view_cls.__name__ name = formatting.remove_trailing_string(name, 'View') name = formatting.remove_trailing_string(name, 'ViewSet') name = formatting.camelcase_to_spaces(name) @@ -25,17 +31,56 @@ def get_view_name(cls, suffix=None): return name -def get_view_description(cls, html=False): - description = cls.__doc__ or '' +def get_view_description(view_cls, html=False): + """ + Given a view class, return a textual description to represent the view. + This name is used in the browsable API, and in OPTIONS responses. + + This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting. + """ + description = view_cls.__doc__ or '' description = formatting.dedent(smart_text(description)) if html: return formatting.markup_description(description) return description +def exception_handler(exc): + """ + Returns the response that should be used for any given exception. + + By default we handle the REST framework `APIException`, and also + Django's builtin `Http404` and `PermissionDenied` exceptions. + + Any unhandled exceptions may return `None`, which will cause a 500 error + to be raised. + """ + if isinstance(exc, exceptions.APIException): + headers = {} + if getattr(exc, 'auth_header', None): + headers['WWW-Authenticate'] = exc.auth_header + if getattr(exc, 'wait', None): + headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + + return Response({'detail': exc.detail}, + status=exc.status_code, + headers=headers) + + elif isinstance(exc, Http404): + return Response({'detail': 'Not found'}, + status=status.HTTP_404_NOT_FOUND) + + elif isinstance(exc, PermissionDenied): + return Response({'detail': 'Permission denied'}, + status=status.HTTP_403_FORBIDDEN) + + # Note: Unhandled exceptions will raise a 500 error. + return None + + class APIView(View): - settings = api_settings + # The following policies may be set at either globally, or per-view. renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES parser_classes = api_settings.DEFAULT_PARSER_CLASSES authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES @@ -43,6 +88,9 @@ class APIView(View): permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS + # Allow dependancy injection of other settings to make testing easier. + settings = api_settings + @classmethod def as_view(cls, **initkwargs): """ @@ -133,7 +181,7 @@ class APIView(View): Return the view name, as used in OPTIONS responses and in the browsable API. """ - func = api_settings.VIEW_NAME_FUNCTION + func = self.settings.VIEW_NAME_FUNCTION return func(self.__class__, getattr(self, 'suffix', None)) def get_view_description(self, html=False): @@ -141,7 +189,7 @@ class APIView(View): Return some descriptive text for the view, as used in OPTIONS responses and in the browsable API. """ - func = api_settings.VIEW_DESCRIPTION_FUNCTION + func = self.settings.VIEW_DESCRIPTION_FUNCTION return func(self.__class__, html) # API policy instantiation methods @@ -303,33 +351,23 @@ class APIView(View): Handle any exception that occurs, by returning an appropriate response, or re-raising the error. """ - if isinstance(exc, exceptions.Throttled) and exc.wait is not None: - # Throttle wait header - self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait - if isinstance(exc, (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)): # WWW-Authenticate header for 401 responses, else coerce to 403 auth_header = self.get_authenticate_header(self.request) if auth_header: - self.headers['WWW-Authenticate'] = auth_header + exc.auth_header = auth_header else: exc.status_code = status.HTTP_403_FORBIDDEN - if isinstance(exc, exceptions.APIException): - return Response({'detail': exc.detail}, - status=exc.status_code, - exception=True) - elif isinstance(exc, Http404): - return Response({'detail': 'Not found'}, - status=status.HTTP_404_NOT_FOUND, - exception=True) - elif isinstance(exc, PermissionDenied): - return Response({'detail': 'Permission denied'}, - status=status.HTTP_403_FORBIDDEN, - exception=True) - raise + response = exception_handler(exc) + + if response is None: + raise + + response.exception = True + return response # Note: session based authentication is explicitly CSRF validated, # all other authentication is CSRF exempt. |
