diff options
| -rw-r--r-- | rest_framework/fields.py | 15 | ||||
| -rw-r--r-- | rest_framework/generics.py | 2 | ||||
| -rw-r--r-- | rest_framework/relations.py | 73 | ||||
| -rw-r--r-- | tests/test_generics.py | 4 | 
4 files changed, 49 insertions, 45 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index a96f9ba8..4f06d186 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -508,7 +508,7 @@ class DecimalField(Field):  class DateField(Field):      default_error_messages = { -        'invalid': _("Date has wrong format. Use one of these formats instead: %s"), +        'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),      }      input_formats = api_settings.DATE_INPUT_FORMATS      format = api_settings.DATE_FORMAT @@ -551,8 +551,7 @@ class DateField(Field):                      return parsed.date()          humanized_format = humanize_datetime.date_formats(self.input_formats) -        msg = self.error_messages['invalid'] % humanized_format -        raise ValidationError(msg) +        self.fail('invalid', format=humanized_format)      def to_representation(self, value):          if value is None or self.format is None: @@ -568,7 +567,7 @@ class DateField(Field):  class DateTimeField(Field):      default_error_messages = { -        'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), +        'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),      }      input_formats = api_settings.DATETIME_INPUT_FORMATS      format = api_settings.DATETIME_FORMAT @@ -617,8 +616,7 @@ class DateTimeField(Field):                      return parsed          humanized_format = humanize_datetime.datetime_formats(self.input_formats) -        msg = self.error_messages['invalid'] % humanized_format -        raise ValidationError(msg) +        self.fail('invalid', format=humanized_format)      def to_representation(self, value):          if value is None or self.format is None: @@ -634,7 +632,7 @@ class DateTimeField(Field):  class TimeField(Field):      default_error_messages = { -        'invalid': _("Time has wrong format. Use one of these formats instead: %s"), +        'invalid': _('Time has wrong format. Use one of these formats instead: {format}'),      }      input_formats = api_settings.TIME_INPUT_FORMATS      format = api_settings.TIME_FORMAT @@ -669,8 +667,7 @@ class TimeField(Field):                      return parsed.time()          humanized_format = humanize_datetime.time_formats(self.input_formats) -        msg = self.error_messages['invalid'] % humanized_format -        raise ValidationError(msg) +        self.fail('invalid', format=humanized_format)      def to_representation(self, value):          if value is None or self.format is None: diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 338d56a6..eb6b64ef 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -216,7 +216,7 @@ class GenericAPIView(views.APIView):          )          queryset = self.queryset -        if isinstance(self.queryset, QuerySet): +        if isinstance(queryset, QuerySet):              # Ensure queryset is re-evaluated on each request.              queryset = queryset.all()          return queryset diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 30a252db..e23a4152 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -2,28 +2,35 @@ from rest_framework.fields import Field  from rest_framework.reverse import reverse  from django.core.exceptions import ObjectDoesNotExist  from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch +from django.db.models.query import QuerySet  from rest_framework.compat import urlparse -def get_default_queryset(serializer_class, field_name): -    manager = getattr(serializer_class.opts.model, field_name) -    if hasattr(manager, 'related'): -        # Forward relationships -        return manager.related.model._default_manager.all() -    # Reverse relationships -    return manager.field.rel.to._default_manager.all() - -  class RelatedField(Field):      def __init__(self, **kwargs):          self.queryset = kwargs.pop('queryset', None)          self.many = kwargs.pop('many', False) +        assert self.queryset is not None or kwargs.get('read_only', False), ( +            'Relational field must provide a `queryset` argument, ' +            'or set read_only=`True`.' +        )          super(RelatedField, self).__init__(**kwargs) -    def bind(self, field_name, parent, root): -        super(RelatedField, self).bind(field_name, parent, root) -        if self.queryset is None and not self.read_only: -            self.queryset = get_default_queryset(parent, self.source) +    def get_queryset(self): +        queryset = self.queryset +        if isinstance(queryset, QuerySet): +            # Ensure queryset is re-evaluated whenever used. +            queryset = queryset.all() +        return queryset + + +class StringRelatedField(Field): +    def __init__(self, **kwargs): +        kwargs['read_only'] = True +        super(StringRelatedField, self).__init__(**kwargs) + +    def to_representation(self, value): +        return str(value)  class PrimaryKeyRelatedField(RelatedField): @@ -33,9 +40,9 @@ class PrimaryKeyRelatedField(RelatedField):          'incorrect_type': 'Incorrect type.  Expected pk value, received {data_type}.',      } -    def from_native(self, data): +    def to_internal_value(self, data):          try: -            return self.queryset.get(pk=data) +            return self.get_queryset().get(pk=data)          except ObjectDoesNotExist:              self.fail('does_not_exist', pk_value=data)          except (TypeError, ValueError): @@ -68,9 +75,9 @@ class HyperlinkedRelatedField(RelatedField):          """          lookup_value = view_kwargs[self.lookup_url_kwarg]          lookup_kwargs = {self.lookup_field: lookup_value} -        return self.queryset.get(**lookup_kwargs) +        return self.get_queryset().get(**lookup_kwargs) -    def from_native(self, value): +    def to_internal_value(self, value):          try:              http_prefix = value.startswith(('http:', 'https:'))          except AttributeError: @@ -102,13 +109,26 @@ class HyperlinkedIdentityField(RelatedField):      def __init__(self, **kwargs):          kwargs['read_only'] = True +        kwargs['source'] = '*'          self.view_name = kwargs.pop('view_name')          self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)          self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)          super(HyperlinkedIdentityField, self).__init__(**kwargs) -    def get_attribute(self, instance): -        return instance +    def get_url(self, obj, view_name, request, format): +        """ +        Given an object, return the URL that hyperlinks to the object. + +        May raise a `NoReverseMatch` if the `view_name` and `lookup_field` +        attributes are not configured to correctly match the URL conf. +        """ +        # Unsaved objects will not yet have a valid URL. +        if obj.pk is None: +            return None + +        lookup_value = getattr(obj, self.lookup_field) +        kwargs = {self.lookup_url_kwarg: lookup_value} +        return reverse(view_name, kwargs=kwargs, request=request, format=format)      def to_representation(self, value):          request = self.context.get('request', None) @@ -144,21 +164,6 @@ class HyperlinkedIdentityField(RelatedField):              )              raise Exception(msg % self.view_name) -    def get_url(self, obj, view_name, request, format): -        """ -        Given an object, return the URL that hyperlinks to the object. - -        May raise a `NoReverseMatch` if the `view_name` and `lookup_field` -        attributes are not configured to correctly match the URL conf. -        """ -        # Unsaved objects will not yet have a valid URL. -        if obj.pk is None: -            return None - -        lookup_value = getattr(obj, self.lookup_field) -        kwargs = {self.lookup_url_kwarg: lookup_value} -        return reverse(view_name, kwargs=kwargs, request=request, format=format) -  class SlugRelatedField(RelatedField):      def __init__(self, **kwargs): diff --git a/tests/test_generics.py b/tests/test_generics.py index 17bfca2f..51004edf 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -547,7 +547,9 @@ class ClassA(models.Model):  class ClassASerializer(serializers.ModelSerializer): -    childs = serializers.PrimaryKeyRelatedField(many=True, source='childs') +    childs = serializers.PrimaryKeyRelatedField( +        many=True, queryset=ClassB.objects.all() +    )      class Meta:          model = ClassA | 
