diff options
| author | Tom Christie | 2012-10-04 13:28:14 +0100 |
|---|---|---|
| committer | Tom Christie | 2012-10-04 13:28:14 +0100 |
| commit | 3a06dde8848dd18810b04db4b7dcb5f8bd768c29 (patch) | |
| tree | 29c08ff7cb4793ea9c9ba62c1179146fcf58146c /rest_framework/fields.py | |
| parent | d89d6887d2eb8293348cb1a7a043a05352819cb8 (diff) | |
| download | django-rest-framework-3a06dde8848dd18810b04db4b7dcb5f8bd768c29.tar.bz2 | |
Clean up field classes
Diffstat (limited to 'rest_framework/fields.py')
| -rw-r--r-- | rest_framework/fields.py | 292 |
1 files changed, 156 insertions, 136 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 09ccc4ff..32f2d122 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -26,21 +26,88 @@ def is_simple_callable(obj): class Field(object): creation_counter = 0 - default_validators = [] - default_error_messages = { - 'required': _('This field is required.'), - 'invalid': _('Invalid value.'), - } empty = '' - def __init__(self, source=None, readonly=False, required=None, - validators=[], error_messages=None): + def __init__(self, source=None): self.parent = None self.creation_counter = Field.creation_counter Field.creation_counter += 1 self.source = source + + def initialize(self, parent): + """ + Called to set up a field prior to field_to_native or field_from_native. + + parent - The parent serializer. + model_field - The model field this field corrosponds to, if one exists. + """ + self.parent = parent + self.root = parent.root or parent + self.context = self.root.context + + def field_from_native(self, data, field_name, into): + """ + Given a dictionary and a field name, updates the dictionary `into`, + with the field and it's deserialized value. + """ + return + + def field_to_native(self, obj, field_name): + """ + Given and object and a field name, returns the value that should be + serialized for that field. + """ + if obj is None: + return self.empty + + if self.source == '*': + return self.to_native(obj) + + if self.source: + value = obj + for component in self.source.split('.'): + value = getattr(value, component) + else: + value = getattr(obj, field_name) + return self.to_native(value) + + def to_native(self, value): + """ + Converts the field's value into it's simple representation. + """ + if is_simple_callable(value): + value = value() + + if is_protected_type(value): + return value + elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): + return [self.to_native(item) for item in value] + return smart_unicode(value) + + def attributes(self): + """ + Returns a dictionary of attributes to be used when serializing to xml. + """ + if getattr(self, 'type_name', None): + return {'type': self.type_name} + return {} + + +class WritableField(Field): + """ + Base for read/write fields. + """ + default_validators = [] + default_error_messages = { + 'required': _('This field is required.'), + 'invalid': _('Invalid value.'), + } + + def __init__(self, source=None, readonly=False, required=None, + validators=[], error_messages=None): + super(WritableField, self).__init__(source=source) self.readonly = readonly if required is None: self.required = not(readonly) @@ -56,19 +123,6 @@ class Field(object): self.validators = self.default_validators + validators - def initialize(self, parent, model_field=None): - """ - Called to set up a field prior to field_to_native or field_from_native. - - parent - The parent serializer. - model_field - The model field this field corrosponds to, if one exists. - """ - self.parent = parent - self.root = parent.root or parent - self.context = self.root.context - if model_field: - self.model_field = model_field - def validate(self, value): if value in validators.EMPTY_VALUES and self.required: raise ValidationError(self.error_messages['required']) @@ -117,96 +171,75 @@ class Field(object): """ Reverts a simple representation back to the field's value. """ - if hasattr(self, 'model_field'): - try: - return self.model_field.rel.to._meta.get_field(self.model_field.rel.field_name).to_python(value) - except: - return self.model_field.to_python(value) return value - def field_to_native(self, obj, field_name): - """ - Given and object and a field name, returns the value that should be - serialized for that field. - """ - if obj is None: - return self.empty - if self.source == '*': - return self.to_native(obj) - - self.obj = obj # Need to hang onto this in the case of model fields - if hasattr(self, 'model_field'): - return self.to_native(self.model_field._get_val_from_obj(obj)) - - if self.source: - value = obj - for component in self.source.split('.'): - value = getattr(value, component) - else: - value = getattr(obj, field_name) - return self.to_native(value) +class ModelField(WritableField): + """ + A generic field that can be used against an arbirtrary model field. + """ + def __init__(self, *args, **kwargs): + try: + self.model_field = kwargs.pop('model_field') + except: + raise ValueError("ModelField requires 'model_field' kwarg") + super(ModelField, self).__init__(*args, **kwargs) - def to_native(self, value): - """ - Converts the field's value into it's simple representation. - """ - if is_simple_callable(value): - value = value() + def from_native(self, value): + try: + rel = self.model_field.rel + except: + return self.model_field.to_python(value) + return rel.to._meta.get_field(rel.field_name).to_python(value) + def field_to_native(self, obj, field_name): + value = self.model_field._get_val_from_obj(obj) if is_protected_type(value): return value - elif hasattr(self, 'model_field'): - return self.model_field.value_to_string(self.obj) - elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): - return [self.to_native(item) for item in value] - return smart_unicode(value) + return self.model_field.value_to_string(self.obj) def attributes(self): - """ - Returns a dictionary of attributes to be used when serializing to xml. - """ - try: - return { - "type": self.model_field.get_internal_type() - } - except AttributeError: - return {} - + return { + "type": self.model_field.get_internal_type() + } -class HyperlinkedIdentityField(Field): - def field_to_native(self, obj, field_name): - request = self.context.get('request', None) - view_name = self.parent.opts.view_name - view_kwargs = {'pk': obj.pk} - return reverse(view_name, kwargs=view_kwargs, request=request) +##### Relational fields ##### -class RelatedField(Field): +class RelatedField(WritableField): """ - A base class for model related fields or related managers. - - Subclass this and override `convert` to define custom behaviour when - serializing related objects. + Base class for related model fields. """ def __init__(self, *args, **kwargs): self.queryset = kwargs.pop('queryset', None) super(RelatedField, self).__init__(*args, **kwargs) def field_to_native(self, obj, field_name): - obj = getattr(obj, self.source or field_name) - if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'): - return [self.to_native(item) for item in obj.all()] - return self.to_native(obj) + value = getattr(obj, self.source or field_name) + return self.to_native(value) - def attributes(self): + def field_from_native(self, data, field_name, into): + value = data.get(field_name) + into[(self.source or field_name) + '_id'] = self.from_native(value) + + +class ManyRelatedField(RelatedField): + """ + Base class for related model managers. + """ + def field_to_native(self, obj, field_name): + value = getattr(obj, self.source or field_name) + return [self.to_native(item) for item in value.all()] + + def field_from_native(self, data, field_name, into): try: - return { - "rel": self.model_field.rel.__class__.__name__, - "to": smart_unicode(self.model_field.rel.to._meta) - } - except AttributeError: - return {} + value = data.getlist(self.source or field_name) + except: + value = data.get(self.source or field_name) + else: + if value == ['']: + value = [] + into[field_name] = [self.from_native(item) for item in value] class PrimaryKeyRelatedField(RelatedField): @@ -215,20 +248,11 @@ class PrimaryKeyRelatedField(RelatedField): """ def to_native(self, pk): - """ - You can subclass this method to provide different serialization - behavior based on the pk. - """ return pk def field_to_native(self, obj, field_name): - # This is only implemented for performance reasons - # - # We could leave the default `RelatedField.field_to_native()` in place, - # and inside just implement `to_native()` as `return obj.pk` - # - # That would involve an extra database lookup. try: + # Prefer obj.serializable_value for performance reasons pk = obj.serializable_value(self.source or field_name) except AttributeError: # RelatedObject (reverse relationship) @@ -237,18 +261,17 @@ class PrimaryKeyRelatedField(RelatedField): # Forward relationship return self.to_native(pk) - def field_from_native(self, data, field_name, into): - value = data.get(field_name) - into[field_name + '_id'] = self.from_native(value) - -class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField): +class ManyPrimaryKeyRelatedField(ManyRelatedField): """ Serializes a to-many related field or related manager to a pk value. """ + def to_native(self, pk): + return pk def field_to_native(self, obj, field_name): try: + # Prefer obj.serializable_value for performance reasons queryset = obj.serializable_value(self.source or field_name) except AttributeError: # RelatedManager (reverse relationship) @@ -257,40 +280,25 @@ class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField): # Forward relationship return [self.to_native(item.pk) for item in queryset.all()] - def field_from_native(self, data, field_name, into): - try: - value = data.getlist(field_name) - except: - value = data.get(field_name) - else: - if value == ['']: - value = [] - into[field_name] = [self.from_native(item) for item in value] - -class NaturalKeyRelatedField(RelatedField): +class HyperlinkedIdentityField(Field): """ - Serializes a model related field or related manager to a natural key value. + A field that represents the model's identity using a hyperlink. """ - is_natural_key = True # XML renderer handles these differently - - def to_native(self, obj): - if hasattr(obj, 'natural_key'): - return obj.natural_key() - return obj + def __init__(self, *args, **kwargs): + pass - def field_from_native(self, data, field_name, into): - value = data.get(field_name) - into[self.model_field.attname] = self.from_native(value) + def field_to_native(self, obj, field_name): + request = self.context.get('request', None) + view_name = self.parent.opts.view_name + view_kwargs = {'pk': obj.pk} + return reverse(view_name, kwargs=view_kwargs, request=request) - def from_native(self, value): - # TODO: Support 'using' : db = options.pop('using', DEFAULT_DB_ALIAS) - manager = self.model_field.rel.to._default_manager - manager = manager.db_manager(DEFAULT_DB_ALIAS) - return manager.get_by_natural_key(*value).pk +##### Typed Fields ##### -class BooleanField(Field): +class BooleanField(WritableField): + type_name = 'BooleanField' default_error_messages = { 'invalid': _(u"'%s' value must be either True or False."), } @@ -307,7 +315,9 @@ class BooleanField(Field): raise ValidationError(self.error_messages['invalid'] % value) -class CharField(Field): +class CharField(WritableField): + type_name = 'CharField' + def __init__(self, max_length=None, min_length=None, *args, **kwargs): self.max_length, self.min_length = max_length, min_length super(CharField, self).__init__(*args, **kwargs) @@ -323,6 +333,8 @@ class CharField(Field): class EmailField(CharField): + type_name = 'EmailField' + default_error_messages = { 'invalid': _('Enter a valid e-mail address.'), } @@ -339,7 +351,9 @@ class EmailField(CharField): return result -class DateField(Field): +class DateField(WritableField): + type_name = 'DateField' + default_error_messages = { 'invalid': _(u"'%s' value has an invalid date format. It must be " u"in YYYY-MM-DD format."), @@ -373,7 +387,9 @@ class DateField(Field): raise ValidationError(msg) -class DateTimeField(Field): +class DateTimeField(WritableField): + type_name = 'DateTimeField' + default_error_messages = { 'invalid': _(u"'%s' value has an invalid format. It must be in " u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."), @@ -424,7 +440,9 @@ class DateTimeField(Field): raise ValidationError(msg) -class IntegerField(Field): +class IntegerField(WritableField): + type_name = 'IntegerField' + default_error_messages = { 'invalid': _('Enter a whole number.'), 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), @@ -450,7 +468,9 @@ class IntegerField(Field): return value -class FloatField(Field): +class FloatField(WritableField): + type_name = 'FloatField' + default_error_messages = { 'invalid': _("'%s' value must be a float."), } |
