diff options
| author | Tom Christie | 2014-11-13 21:11:13 +0000 | 
|---|---|---|
| committer | Tom Christie | 2014-11-13 21:11:13 +0000 | 
| commit | 992330055eeb5d787ddd7d62dfc9121a2256fd9b (patch) | |
| tree | 4d8b44256721e69558c3b643ac7945e4320c0f91 | |
| parent | 78a741be27f5007d6fa2f73c6cedf04bfe638f9c (diff) | |
| download | django-rest-framework-992330055eeb5d787ddd7d62dfc9121a2256fd9b.tar.bz2 | |
Refactor many
| -rw-r--r-- | rest_framework/relations.py | 22 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 25 | 
2 files changed, 37 insertions, 10 deletions
| diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 6dc02a11..79c8057b 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -10,9 +10,17 @@ from django.utils.translation import ugettext_lazy as _  class PKOnlyObject(object): +    """ +    This is a mock object, used for when we only need the pk of the object +    instance, but still want to return an object with a .pk attribute, +    in order to keep the same interface as a regular model instance. +    """      def __init__(self, pk):          self.pk = pk + +# We assume that 'validators' are intended for the child serializer, +# rather than the parent serializer.  MANY_RELATION_KWARGS = (      'read_only', 'write_only', 'required', 'default', 'initial', 'source',      'label', 'help_text', 'style', 'error_messages' @@ -36,13 +44,17 @@ class RelatedField(Field):          # We override this method in order to automagically create          # `ManyRelatedField` classes instead when `many=True` is set.          if kwargs.pop('many', False): -            list_kwargs = {'child_relation': cls(*args, **kwargs)} -            for key in kwargs.keys(): -                if key in MANY_RELATION_KWARGS: -                    list_kwargs[key] = kwargs[key] -            return ManyRelatedField(**list_kwargs) +            return cls.many_init(*args, **kwargs)          return super(RelatedField, cls).__new__(cls, *args, **kwargs) +    @classmethod +    def many_init(cls, *args, **kwargs): +        list_kwargs = {'child_relation': cls(*args, **kwargs)} +        for key in kwargs.keys(): +            if key in MANY_RELATION_KWARGS: +                list_kwargs[key] = kwargs[key] +        return ManyRelatedField(**list_kwargs) +      def run_validation(self, data=empty):          # We force empty strings to None values for relational fields.          if data == '': diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index a4aeeeb7..70bba8ab 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -46,6 +46,9 @@ import warnings  from rest_framework.relations import *  # NOQA  from rest_framework.fields import *  # NOQA + +# We assume that 'validators' are intended for the child serializer, +# rather than the parent serializer.  LIST_SERIALIZER_KWARGS = (      'read_only', 'write_only', 'required', 'default', 'initial', 'source',      'label', 'help_text', 'style', 'error_messages', @@ -73,13 +76,25 @@ class BaseSerializer(Field):          # We override this method in order to automagically create          # `ListSerializer` classes instead when `many=True` is set.          if kwargs.pop('many', False): -            list_kwargs = {'child': cls(*args, **kwargs)} -            for key in kwargs.keys(): -                if key in LIST_SERIALIZER_KWARGS: -                    list_kwargs[key] = kwargs[key] -            return ListSerializer(*args, **list_kwargs) +            return cls.many_init(*args, **kwargs)          return super(BaseSerializer, cls).__new__(cls, *args, **kwargs) +    @classmethod +    def many_init(cls, *args, **kwargs): +        """ +        This method implements the creation of a `ListSerializer` parent +        class when `many=True` is used. You can customize it if you need to +        control which keyword arguments are passed to the parent, and +        which are passed to the child. +        """ +        child_serializer = cls(*args, **kwargs) +        list_kwargs = {'child': child_serializer} +        list_kwargs.update(dict([ +            (key, value) for key, value in kwargs.items() +            if key in LIST_SERIALIZER_KWARGS +        ])) +        return ListSerializer(*args, **list_kwargs) +      def to_internal_value(self, data):          raise NotImplementedError('`to_internal_value()` must be implemented.') | 
