diff options
Diffstat (limited to 'djangorestframework')
| -rw-r--r-- | djangorestframework/mixins.py | 17 | ||||
| -rw-r--r-- | djangorestframework/permissions.py | 21 | ||||
| -rw-r--r-- | djangorestframework/renderers.py | 2 | ||||
| -rw-r--r-- | djangorestframework/resources.py | 129 | ||||
| -rw-r--r-- | djangorestframework/serializer.py | 315 | ||||
| -rw-r--r-- | djangorestframework/tests/resources.py | 60 | ||||
| -rw-r--r-- | djangorestframework/tests/serializer.py | 117 | ||||
| -rw-r--r-- | djangorestframework/tests/throttling.py | 55 | ||||
| -rw-r--r-- | djangorestframework/views.py | 13 | 
9 files changed, 510 insertions, 219 deletions
diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index 11e3bb38..910d06ae 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -466,7 +466,7 @@ class InstanceMixin(object):              # We do a little dance when we store the view callable...              # we need to store it wrapped in a 1-tuple, so that inspect will treat it              # as a function when we later look it up (rather than turning it into a method). -            # This makes sure our URL reversing works ok.       +            # This makes sure our URL reversing works ok.              resource.view_callable = (view,)          return view @@ -479,6 +479,7 @@ class ReadModelMixin(object):      """      def get(self, request, *args, **kwargs):          model = self.resource.model +          try:              if args:                  # If we have any none kwargs then assume the last represents the primrary key @@ -498,6 +499,7 @@ class CreateModelMixin(object):      """      def post(self, request, *args, **kwargs):                  model = self.resource.model +          # translated 'related_field' kwargs into 'related_field_id'          for related_name in [field.name for field in model._meta.fields if isinstance(field, RelatedField)]:              if kwargs.has_key(related_name): @@ -522,6 +524,7 @@ class UpdateModelMixin(object):      """      def put(self, request, *args, **kwargs):          model = self.resource.model +                  # TODO: update on the url of a non-existing resource url doesn't work correctly at the moment - will end up with a new url           try:              if args: @@ -547,6 +550,7 @@ class DeleteModelMixin(object):      """      def delete(self, request, *args, **kwargs):          model = self.resource.model +          try:              if args:                  # If we have any none kwargs then assume the last represents the primrary key @@ -581,8 +585,15 @@ class ListModelMixin(object):      queryset = None      def get(self, request, *args, **kwargs): -        queryset = self.queryset if self.queryset else self.resource.model.objects.all() -        ordering = getattr(self.resource, 'ordering', None) +        model = self.resource.model + +        queryset = self.queryset if self.queryset else model.objects.all() + +        if hasattr(self, 'resource'): +            ordering = getattr(self.resource, 'ordering', None) +        else: +            ordering = None +          if ordering:              args = as_tuple(ordering)              queryset = queryset.order_by(*args) diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index 4825a174..7dcabcf0 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -137,7 +137,7 @@ class BaseThrottle(BasePermission):          # Drop any requests from the history which have now passed the          # throttle duration -        while self.history and self.history[0] <= self.now - self.duration: +        while self.history and self.history[-1] <= self.now - self.duration:              self.history.pop()          if len(self.history) >= self.num_requests:              self.throttle_failure() @@ -151,23 +151,32 @@ class BaseThrottle(BasePermission):          """          self.history.insert(0, self.now)          cache.set(self.key, self.history, self.duration) -        self.view.add_header('X-Throttle', 'status=SUCCESS; next=%s sec' % self.next()) +        header = 'status=SUCCESS; next=%s sec' % self.next() +        self.view.add_header('X-Throttle', header)      def throttle_failure(self):          """          Called when a request to the API has failed due to throttling.          Raises a '503 service unavailable' response.          """ -        self.view.add_header('X-Throttle', 'status=FAILURE; next=%s sec' % self.next()) +        header = 'status=FAILURE; next=%s sec' % self.next() +        self.view.add_header('X-Throttle', header)          raise _503_SERVICE_UNAVAILABLE      def next(self):          """          Returns the recommended next request time in seconds.          """ -        return '%.2f' % (self.duration / (self.num_requests - len(self.history) *1.0 + 1))   -     -     +        if self.history: +            remaining_duration = self.duration - (self.now - self.history[-1]) +        else: +            remaining_duration = self.duration + +        available_requests = self.num_requests - len(self.history) + 1 + +        return '%.2f' % (remaining_duration / float(available_requests)) + +  class PerUserThrottling(BaseThrottle):      """      Limits the rate of API calls that may be made by a given user. diff --git a/djangorestframework/renderers.py b/djangorestframework/renderers.py index 9834ba5e..7aa8777c 100644 --- a/djangorestframework/renderers.py +++ b/djangorestframework/renderers.py @@ -181,7 +181,7 @@ class DocumentingTemplateRenderer(BaseRenderer):          # Get the form instance if we have one bound to the input          form_instance = None -        if method == view.method.lower(): +        if method == getattr(view, 'method', view.request.method).lower():              form_instance = getattr(view, 'bound_form_instance', None)          if not form_instance and hasattr(view, 'get_bound_form'): diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py index 07c97d43..08f9e0ae 100644 --- a/djangorestframework/resources.py +++ b/djangorestframework/resources.py @@ -6,6 +6,7 @@ from django.db.models.fields.related import RelatedField  from django.utils.encoding import smart_unicode  from djangorestframework.response import ErrorResponse +from djangorestframework.serializer import Serializer  from djangorestframework.utils import as_tuple  import decimal @@ -13,122 +14,9 @@ import inspect  import re -# TODO: _IgnoreFieldException -# Map model classes to resource classes -#_model_to_resource = {} - -def _model_to_dict(instance, resource=None): -    """ -    Given a model instance, return a ``dict`` representing the model. -     -    The implementation is similar to Django's ``django.forms.model_to_dict``, except: - -    * It doesn't coerce related objects into primary keys. -    * It doesn't drop ``editable=False`` fields. -    * It also supports attribute or method fields on the instance or resource. -    """ -    opts = instance._meta -    data = {} - -    #print [rel.name for rel in opts.get_all_related_objects()] -    #related = [rel.get_accessor_name() for rel in opts.get_all_related_objects()] -    #print [getattr(instance, rel) for rel in related] -    #if resource.fields: -    #    fields = resource.fields -    #else: -    #    fields = set(opts.fields + opts.many_to_many) -     -    fields = resource and resource.fields or () -    include = resource and resource.include or () -    exclude = resource and resource.exclude or () - -    extra_fields = fields and list(fields) or list(include) - -    # Model fields -    for f in opts.fields + opts.many_to_many: -        if fields and not f.name in fields: -            continue -        if exclude and f.name in exclude: -            continue -        if isinstance(f, models.ForeignKey): -            data[f.name] = getattr(instance, f.name) -        else: -            data[f.name] = f.value_from_object(instance) -         -        if extra_fields and f.name in extra_fields: -            extra_fields.remove(f.name) -     -    # Method fields -    for fname in extra_fields: -         -        if isinstance(fname, (tuple, list)): -            fname, fields = fname -        else: -            fname, fields = fname, False - -        try: -            if hasattr(resource, fname): -                # check the resource first, to allow it to override fields -                obj = getattr(resource, fname) -                # if it's a method like foo(self, instance), then call it  -                if inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) == 2: -                    obj = obj(instance) -            elif hasattr(instance, fname): -                # now check the object instance -                obj = getattr(instance, fname) -            else: -                continue -     -            # TODO: It would be nicer if this didn't recurse here. -            # Let's keep _model_to_dict flat, and _object_to_data recursive. -            if fields: -                Resource = type('Resource', (object,), {'fields': fields, -                                                        'include': (), -                                                        'exclude': ()}) -                data[fname] = _object_to_data(obj, Resource()) -            else: -                data[fname] = _object_to_data(obj) - -        except NoReverseMatch: -            # Ug, bit of a hack for now -            pass -    -    return data - - -def _object_to_data(obj, resource=None): -    """ -    Convert an object into a serializable representation. -    """ -    if isinstance(obj, dict): -        # dictionaries -        # TODO: apply same _model_to_dict logic fields/exclude here -        return dict([ (key, _object_to_data(val)) for key, val in obj.iteritems() ]) -    if isinstance(obj, (tuple, list, set, QuerySet)): -        # basic iterables -        return [_object_to_data(item, resource) for item in obj] -    if isinstance(obj, models.Manager): -        # Manager objects -        return [_object_to_data(item, resource) for item in obj.all()] -    if isinstance(obj, models.Model): -        # Model instances -        return _object_to_data(_model_to_dict(obj, resource)) -    if isinstance(obj, decimal.Decimal): -        # Decimals (force to string representation) -        return str(obj) -    if inspect.isfunction(obj) and not inspect.getargspec(obj)[0]: -        # function with no args -        return _object_to_data(obj(), resource) -    if inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1: -        # bound method -        return _object_to_data(obj(), resource) - -    return smart_unicode(obj, strings_only=True) - - -class BaseResource(object): +class BaseResource(Serializer):      """      Base class for all Resource classes, which simply defines the interface they provide.      """ @@ -136,7 +24,8 @@ class BaseResource(object):      include = None      exclude = None -    def __init__(self, view): +    def __init__(self, view=None, depth=None, stack=[], **kwargs): +        super(BaseResource, self).__init__(depth, stack, **kwargs)          self.view = view      def validate_request(self, data, files=None): @@ -150,7 +39,7 @@ class BaseResource(object):          """          Given the response content, filter it into a serializable object.          """ -        return _object_to_data(obj, self) +        return self.serialize(obj)  class Resource(BaseResource): @@ -297,7 +186,7 @@ class FormResource(Resource):          """          # A form on the view overrides a form on the resource. -        form = getattr(self.view, 'form', self.form) +        form = getattr(self.view, 'form', None) or self.form          # Use the requested method or determine the request method          if method is None and hasattr(self.view, 'request') and hasattr(self.view, 'method'): @@ -390,8 +279,8 @@ class ModelResource(FormResource):          """          super(ModelResource, self).__init__(view) -        if getattr(view, 'model', None): -            self.model = view.model +        self.model = getattr(view, 'model', None) or self.model +      def validate_request(self, data, files=None):          """ @@ -476,7 +365,7 @@ class ModelResource(FormResource):                      if isinstance(attr, models.Model):                          instance_attrs[param] = attr.pk                      else: -                        instance_attrs[param] = attr     +                        instance_attrs[param] = attr                  try:                      return reverse(self.view_callable[0], kwargs=instance_attrs) diff --git a/djangorestframework/serializer.py b/djangorestframework/serializer.py new file mode 100644 index 00000000..8d73d623 --- /dev/null +++ b/djangorestframework/serializer.py @@ -0,0 +1,315 @@ +""" +Customizable serialization. +""" +from django.db import models +from django.db.models.query import QuerySet +from django.db.models.fields.related import RelatedField +from django.utils.encoding import smart_unicode + +import decimal +import inspect +import types + + +# We register serializer classes, so that we can refer to them by their +# class names, if there are cyclical serialization heirachys. +_serializers = {} + + +def _field_to_tuple(field): +    """ +    Convert an item in the `fields` attribute into a 2-tuple.  +    """ +    if isinstance(field, (tuple, list)): +        return (field[0], field[1]) +    return (field, None) + +def _fields_to_list(fields): +    """ +    Return a list of field names. +    """ +    return [_field_to_tuple(field)[0] for field in fields or ()] + +def _fields_to_dict(fields): +    """ +    Return a `dict` of field name -> None, or tuple of fields, or Serializer class +    """ +    return dict([_field_to_tuple(field) for field in fields or ()]) + + +class _SkipField(Exception): +    """ +    Signals that a serialized field should be ignored. +    We use this mechanism as the default behavior for ensuring +    that we don't infinitely recurse when dealing with nested data. +    """ +    pass + + +class _RegisterSerializer(type): +    """ +    Metaclass to register serializers. +    """ +    def __new__(cls, name, bases, attrs): +        # Build the class and register it. +        ret = super(_RegisterSerializer, cls).__new__(cls, name, bases, attrs)  +        _serializers[name] = ret +        return ret + + +class Serializer(object): +    """ +    Converts python objects into plain old native types suitable for +    serialization.  In particular it handles models and querysets. +     +    The output format is specified by setting a number of attributes +    on the class. + +    You may also override any of the serialization methods, to provide +    for more flexible behavior. +  +    Valid output types include anything that may be directly rendered into +    json, xml etc... +    """ +    __metaclass__ = _RegisterSerializer + +    fields = ()  +    """ +    Specify the fields to be serialized on a model or dict. +    Overrides `include` and `exclude`. +    """ + +    include = () +    """ +    Fields to add to the default set to be serialized on a model/dict. +    """ + +    exclude = () +    """ +    Fields to remove from the default set to be serialized on a model/dict. +    """ + +    rename = {} +    """ +    A dict of key->name to use for the field keys. +    """ + +    related_serializer = None +    """ +    The default serializer class to use for any related models. +    """ + +    depth = None +    """ +    The maximum depth to serialize to, or `None`. +    """ + + +    def __init__(self, depth=None, stack=[], **kwargs): +        self.depth = depth or self.depth +        self.stack = stack +         + +    def get_fields(self, obj): +        """ +        Return the set of field names/keys to use for a model instance/dict. +        """ +        fields = self.fields + +        # If `fields` is not set, we use the default fields and modify +        # them with `include` and `exclude` +        if not fields: +            default = self.get_default_fields(obj) +            include = self.include or () +            exclude = self.exclude or () +            fields = set(default + list(include)) - set(exclude) + +        else: +            fields = _fields_to_list(self.fields) + +        return fields + + +    def get_default_fields(self, obj): +        """ +        Return the default list of field names/keys for a model instance/dict. +        These are used if `fields` is not given. +        """ +        if isinstance(obj, models.Model): +            opts = obj._meta +            return [field.name for field in opts.fields + opts.many_to_many] +        else: +            return obj.keys() + + +    def get_related_serializer(self, key): +        info = _fields_to_dict(self.fields).get(key, None) + +        # If an element in `fields` is a 2-tuple of (str, tuple) +        # then the second element of the tuple is the fields to +        # set on the related serializer +        if isinstance(info, (list, tuple)): +            class OnTheFlySerializer(Serializer): +                fields = info +            return OnTheFlySerializer + +        # If an element in `fields` is a 2-tuple of (str, Serializer) +        # then the second element of the tuple is the Serializer +        # class to use for that field. +        elif isinstance(info, type) and issubclass(info, Serializer): +            return info + +        # If an element in `fields` is a 2-tuple of (str, str) +        # then the second element of the tuple is the name of the Serializer +        # class to use for that field. +        # +        # Black magic to deal with cyclical Serializer dependancies. +        # Similar to what Django does for cyclically related models. +        elif isinstance(info, str) and info in _serializers: +            return _serializers[info] +         +        # Otherwise use `related_serializer` or fall back to `Serializer` +        return getattr(self, 'related_serializer') or Serializer + + +    def serialize_key(self, key): +        """ +        Keys serialize to their string value, +        unless they exist in the `rename` dict. +        """ +        return getattr(self.rename, key, key) + + +    def serialize_val(self, key, obj): +        """ +        Convert a model field or dict value into a serializable representation. +        """ +        related_serializer = self.get_related_serializer(key) +      +        if self.depth is None: +            depth = None +        elif self.depth <= 0: +            return self.serialize_max_depth(obj) +        else: +            depth = self.depth - 1 + +        if any([obj is elem for elem in self.stack]): +            return self.serialize_recursion(obj) +        else: +            stack = self.stack[:] +            stack.append(obj) + +        return related_serializer(depth=depth, stack=stack).serialize(obj) + + +    def serialize_max_depth(self, obj): +        """ +        Determine how objects should be serialized once `depth` is exceeded. +        The default behavior is to ignore the field. +        """ +        raise _SkipField + + +    def serialize_recursion(self, obj): +        """ +        Determine how objects should be serialized if recursion occurs. +        The default behavior is to ignore the field. +        """ +        raise _SkipField + + +    def serialize_model(self, instance): +        """ +        Given a model instance or dict, serialize it to a dict.. +        """ +        data = {} + +        fields = self.get_fields(instance) + +        # serialize each required field  +        for fname in fields: +            if hasattr(self, fname): +                # check for a method 'fname' on self first +                meth = getattr(self, fname) +                if inspect.ismethod(meth) and len(inspect.getargspec(meth)[0]) == 2: +                    obj = meth(instance) +            elif hasattr(instance, fname): +                # now check for an attribute 'fname' on the instance +                obj = getattr(instance, fname) +            elif fname in instance: +                # finally check for a key 'fname' on the instance +                obj = instance[fname] +            else: +                continue + +            try: +                key = self.serialize_key(fname) +                val = self.serialize_val(fname, obj) +                data[key] = val +            except _SkipField: +                pass + +        return data + + +    def serialize_iter(self, obj): +        """ +        Convert iterables into a serializable representation. +        """ +        return [self.serialize(item) for item in obj] + + +    def serialize_func(self, obj): +        """ +        Convert no-arg methods and functions into a serializable representation. +        """ +        return self.serialize(obj()) + + +    def serialize_manager(self, obj): +        """ +        Convert a model manager into a serializable representation. +        """ +        return self.serialize_iter(obj.all()) + + +    def serialize_decimal(self, obj): +        """ +        Convert a Decimal instance into a serializable representation. +        """ +        return str(obj) + + +    def serialize_fallback(self, obj): +        """ +        Convert any unhandled object into a serializable representation. +        """ +        return smart_unicode(obj, strings_only=True) +  +  +    def serialize(self, obj): +        """ +        Convert any object into a serializable representation. +        """ +         +        if isinstance(obj, (dict, models.Model)): +            # Model instances & dictionaries +            return self.serialize_model(obj) +        elif isinstance(obj, (tuple, list, set, QuerySet, types.GeneratorType)): +            # basic iterables +            return self.serialize_iter(obj) +        elif isinstance(obj, models.Manager): +            # Manager objects +            return self.serialize_manager(obj) +        elif isinstance(obj, decimal.Decimal): +            # Decimals (force to string representation) +            return self.serialize_decimal(obj) +        elif inspect.isfunction(obj) and not inspect.getargspec(obj)[0]: +            # function with no args +            return self.serialize_func(obj) +        elif inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1: +            # bound method +            return self.serialize_func(obj) + +        # fall back to smart unicode +        return self.serialize_fallback(obj) diff --git a/djangorestframework/tests/resources.py b/djangorestframework/tests/resources.py deleted file mode 100644 index 088e3159..00000000 --- a/djangorestframework/tests/resources.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Tests for the resource module""" -from django.test import TestCase -from djangorestframework.resources import _object_to_data - -from django.db import models - -import datetime -import decimal - -class TestObjectToData(TestCase):  -    """Tests for the _object_to_data function""" - -    def test_decimal(self): -        """Decimals need to be converted to a string representation.""" -        self.assertEquals(_object_to_data(decimal.Decimal('1.5')), '1.5') - -    def test_function(self): -        """Functions with no arguments should be called.""" -        def foo(): -            return 1 -        self.assertEquals(_object_to_data(foo), 1) - -    def test_method(self): -        """Methods with only a ``self`` argument should be called.""" -        class Foo(object): -            def foo(self): -                return 1 -        self.assertEquals(_object_to_data(Foo().foo), 1) - -    def test_datetime(self): -        """datetime objects are left as-is.""" -        now = datetime.datetime.now() -        self.assertEquals(_object_to_data(now), now) -     -    def test_tuples(self): -        """ Test tuple serialisation """ -        class M1(models.Model): -            field1 = models.CharField() -            field2 = models.CharField() -         -        class M2(models.Model): -            field = models.OneToOneField(M1) -         -        class M3(models.Model): -            field = models.ForeignKey(M1) -         -        m1 = M1(field1='foo', field2='bar') -        m2 = M2(field=m1) -        m3 = M3(field=m1) -         -        Resource = type('Resource', (object,), {'fields':(), 'include':(), 'exclude':()}) -         -        r = Resource() -        r.fields = (('field', ('field1')),) - -        self.assertEqual(_object_to_data(m2, r), dict(field=dict(field1=u'foo'))) -         -        r.fields = (('field', ('field2')),) -        self.assertEqual(_object_to_data(m3, r), dict(field=dict(field2=u'bar'))) -         diff --git a/djangorestframework/tests/serializer.py b/djangorestframework/tests/serializer.py new file mode 100644 index 00000000..783e941e --- /dev/null +++ b/djangorestframework/tests/serializer.py @@ -0,0 +1,117 @@ +"""Tests for the resource module""" +from django.test import TestCase +from djangorestframework.serializer import Serializer + +from django.db import models + +import datetime +import decimal + +class TestObjectToData(TestCase):  +    """ +    Tests for the Serializer class. +    """ + +    def setUp(self): +        self.serializer = Serializer() +        self.serialize = self.serializer.serialize + +    def test_decimal(self): +        """Decimals need to be converted to a string representation.""" +        self.assertEquals(self.serialize(decimal.Decimal('1.5')), '1.5') + +    def test_function(self): +        """Functions with no arguments should be called.""" +        def foo(): +            return 1 +        self.assertEquals(self.serialize(foo), 1) + +    def test_method(self): +        """Methods with only a ``self`` argument should be called.""" +        class Foo(object): +            def foo(self): +                return 1 +        self.assertEquals(self.serialize(Foo().foo), 1) + +    def test_datetime(self): +        """ +        datetime objects are left as-is. +        """ +        now = datetime.datetime.now() +        self.assertEquals(self.serialize(now), now) + + +class TestFieldNesting(TestCase): +    """ +    Test nesting the fields in the Serializer class +    """ +    def setUp(self): +        self.serializer = Serializer() +        self.serialize = self.serializer.serialize + +        class M1(models.Model): +            field1 = models.CharField() +            field2 = models.CharField() + +        class M2(models.Model): +            field = models.OneToOneField(M1) + +        class M3(models.Model): +            field = models.ForeignKey(M1) + +        self.m1 = M1(field1='foo', field2='bar') +        self.m2 = M2(field=self.m1) +        self.m3 = M3(field=self.m1) + + +    def test_tuple_nesting(self): +        """ +        Test tuple nesting on `fields` attr +        """ +        class SerializerM2(Serializer): +            fields = (('field', ('field1',)),) + +        class SerializerM3(Serializer): +            fields = (('field', ('field2',)),) + +        self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}}) +        self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}}) + + +    def test_serializer_class_nesting(self): +        """ +        Test related model serialization +        """ +        class NestedM2(Serializer): +            fields = ('field1', ) + +        class NestedM3(Serializer): +            fields = ('field2', ) + +        class SerializerM2(Serializer): +            fields = [('field', NestedM2)] + +        class SerializerM3(Serializer): +            fields = [('field', NestedM3)] + +        self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}}) +        self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}}) + +    def test_serializer_classname_nesting(self): +        """ +        Test related model serialization +        """ +        class SerializerM2(Serializer): +            fields = [('field', 'NestedM2')] + +        class SerializerM3(Serializer): +            fields = [('field', 'NestedM3')] + +        class NestedM2(Serializer): +            fields = ('field1', ) + +        class NestedM3(Serializer): +            fields = ('field2', ) + +        self.assertEqual(SerializerM2().serialize(self.m2), {'field': {'field1': u'foo'}}) +        self.assertEqual(SerializerM3().serialize(self.m3), {'field': {'field2': u'bar'}}) diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py index 80cfc2e1..b620ee24 100644 --- a/djangorestframework/tests/throttling.py +++ b/djangorestframework/tests/throttling.py @@ -13,23 +13,22 @@ from djangorestframework.resources import FormResource  class MockView(View):      permissions = ( PerUserThrottling, ) -    throttle = '3/sec' # 3 requests per second +    throttle = '3/sec'      def get(self, request):          return 'foo' -class MockView1(MockView): +class MockView_PerViewThrottling(MockView):      permissions = ( PerViewThrottling, ) -class MockView2(MockView): +class MockView_PerResourceThrottling(MockView):          permissions = ( PerResourceThrottling, ) -    #No resource set -     -class MockView3(MockView2):          resource = FormResource -class MockView4(MockView): -    throttle = '3/min' # 3 request per minute +class MockView_MinuteThrottling(MockView): +    throttle = '3/min' +  +   class ThrottlingTests(TestCase):      urls = 'djangorestframework.tests.throttling'    @@ -93,13 +92,13 @@ class ThrottlingTests(TestCase):          """          Ensure request rate is limited globally per View for PerViewThrottles          """ -        self.ensure_is_throttled(MockView1, 503) +        self.ensure_is_throttled(MockView_PerViewThrottling, 503)      def test_request_throttling_is_per_resource(self):          """          Ensure request rate is limited globally per Resource for PerResourceThrottles          """         -        self.ensure_is_throttled(MockView3, 503) +        self.ensure_is_throttled(MockView_PerResourceThrottling, 503)      def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): @@ -108,8 +107,8 @@ class ThrottlingTests(TestCase):          set properly.          """          request = self.factory.get('/') -        for expect in expected_headers: -            self.set_throttle_timer(view, 0) +        for timer, expect in expected_headers: +            self.set_throttle_timer(view, timer)              response = view.as_view()(request)              self.assertEquals(response['X-Throttle'], expect) @@ -118,20 +117,32 @@ class ThrottlingTests(TestCase):          Ensure for second based throttles.          """          self.ensure_response_header_contains_proper_throttle_field(MockView, -         ('status=SUCCESS; next=0.33 sec', -          'status=SUCCESS; next=0.50 sec', -          'status=SUCCESS; next=1.00 sec', -          'status=FAILURE; next=1.00 sec' +         ((0, 'status=SUCCESS; next=0.33 sec'), +          (0, 'status=SUCCESS; next=0.50 sec'), +          (0, 'status=SUCCESS; next=1.00 sec'), +          (0, 'status=FAILURE; next=1.00 sec')           ))      def test_minutes_fields(self):          """          Ensure for minute based throttles.          """ -        self.ensure_response_header_contains_proper_throttle_field(MockView4, -         ('status=SUCCESS; next=20.00 sec', -          'status=SUCCESS; next=30.00 sec', -          'status=SUCCESS; next=60.00 sec', -          'status=FAILURE; next=60.00 sec' +        self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, +         ((0, 'status=SUCCESS; next=20.00 sec'), +          (0, 'status=SUCCESS; next=30.00 sec'), +          (0, 'status=SUCCESS; next=60.00 sec'), +          (0, 'status=FAILURE; next=60.00 sec') +         )) +     +    def test_next_rate_remains_constant_if_followed(self): +        """ +        If a client follows the recommended next request rate, +        the throttling rate should stay constant. +        """ +        self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, +         ((0, 'status=SUCCESS; next=20.00 sec'), +          (20, 'status=SUCCESS; next=20.00 sec'), +          (40, 'status=SUCCESS; next=20.00 sec'), +          (60, 'status=SUCCESS; next=20.00 sec'), +          (80, 'status=SUCCESS; next=20.00 sec')           )) -                                                                    
\ No newline at end of file diff --git a/djangorestframework/views.py b/djangorestframework/views.py index e38207ac..18d064e1 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -64,10 +64,6 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):      """      permissions = ( permissions.FullAnonAccess, ) -    """ -    Headers to be sent with response. -    """ -    headers = {}      @classmethod      def as_view(cls, **initkwargs): @@ -105,12 +101,14 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):          """          pass +      def add_header(self, field, value):          """          Add *field* and *value* to the :attr:`headers` attribute of the :class:`View` class.           """          self.headers[field] = value -     + +      # Note: session based authentication is explicitly CSRF validated,      # all other authentication is CSRF exempt.      @csrf_exempt @@ -118,6 +116,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):          self.request = request          self.args = args          self.kwargs = kwargs +        self.headers = {}          # Calls to 'reverse' will not be fully qualified unless we set the scheme/host/port here.          prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host()) @@ -160,8 +159,8 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):          response.headers['Allow'] = ', '.join(self.allowed_methods)          response.headers['Vary'] = 'Authenticate, Accept' -        # merge with headers possibly set by a Throttle class -        response.headers = dict(response.headers.items() + self.headers.items()) +        # merge with headers possibly set at some point in the view +        response.headers.update(self.headers)          return self.render(response)      | 
