diff options
| -rw-r--r-- | rest_framework/fields.py | 2 | ||||
| -rw-r--r-- | tests/test_fields.py | 19 | 
2 files changed, 21 insertions, 0 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 13301f31..c327f11b 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -114,6 +114,8 @@ class CreateOnlyDefault:      def set_context(self, serializer_field):          self.is_update = serializer_field.parent.instance is not None +        if callable(self.default) and hasattr(self.default, 'set_context'): +            self.default.set_context(serializer_field)      def __call__(self):          if self.is_update: diff --git a/tests/test_fields.py b/tests/test_fields.py index 7f5f8102..1aa528da 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -317,6 +317,25 @@ class TestCreateOnlyDefault:              'text': 'example',          } +    def test_create_only_default_callable_sets_context(self): +        """ +        CreateOnlyDefault instances with a callable default should set_context +        on the callable if possible +        """ +        class TestCallableDefault: +            def set_context(self, serializer_field): +                self.field = serializer_field + +            def __call__(self): +                return "success" if hasattr(self, 'field') else "failure" + +        class TestSerializer(serializers.Serializer): +            context_set = serializers.CharField(default=serializers.CreateOnlyDefault(TestCallableDefault())) + +        serializer = TestSerializer(data={}) +        assert serializer.is_valid() +        assert serializer.validated_data['context_set'] == 'success' +  # Tests for field input and output values.  # ---------------------------------------- | 
