diff options
| author | Dmitry Mukhin | 2014-08-20 20:04:48 +0400 | 
|---|---|---|
| committer | Dmitry Mukhin | 2014-08-20 20:04:48 +0400 | 
| commit | 3b07d0c9978335e183f369480618b48ff1e1b1ab (patch) | |
| tree | 041027c50d2965da1be7f93b1a6360e07ad976f9 /rest_framework | |
| parent | c3891b6e00daa7a92cca1c88599e046f72926bb4 (diff) | |
| parent | 59b47eac14778767a17e56bd8adc0610417f2878 (diff) | |
| download | django-rest-framework-3b07d0c9978335e183f369480618b48ff1e1b1ab.tar.bz2 | |
Merge branch 'master' into set-retry-after
Conflicts:
	tests/test_throttling.py
Diffstat (limited to 'rest_framework')
101 files changed, 656 insertions, 13857 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 2d76b55d..f30012b9 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,14 +1,14 @@  """ -______ _____ _____ _____    __                                             _     -| ___ \  ___/  ___|_   _|  / _|                                           | |    -| |_/ / |__ \ `--.  | |   | |_ _ __ __ _ _ __ ___   _____      _____  _ __| | __ +______ _____ _____ _____    __ +| ___ \  ___/  ___|_   _|  / _|                                           | | +| |_/ / |__ \ `--.  | |   | |_ _ __ __ _ _ __ ___   _____      _____  _ __| |__  |    /|  __| `--. \ | |   |  _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ / -| |\ \| |___/\__/ / | |   | | | | | (_| | | | | | |  __/\ V  V / (_) | |  |   <  +| |\ \| |___/\__/ / | |   | | | | | (_| | | | | | |  __/\ V  V / (_) | |  |   <  \_| \_\____/\____/  \_/   |_| |_|  \__,_|_| |_| |_|\___| \_/\_/ \___/|_|  |_|\_|  """  __title__ = 'Django REST framework' -__version__ = '2.3.13' +__version__ = '2.3.14'  __author__ = 'Tom Christie'  __license__ = 'BSD 2-Clause'  __copyright__ = 'Copyright 2011-2014 Tom Christie' diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index da9ca510..5721a869 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -6,9 +6,9 @@ import base64  from django.contrib.auth import authenticate  from django.core.exceptions import ImproperlyConfigured +from django.middleware.csrf import CsrfViewMiddleware  from django.conf import settings  from rest_framework import exceptions, HTTP_HEADER_ENCODING -from rest_framework.compat import CsrfViewMiddleware  from rest_framework.compat import oauth, oauth_provider, oauth_provider_store  from rest_framework.compat import oauth2_provider, provider_now, check_nonce  from rest_framework.authtoken.models import Token @@ -21,7 +21,7 @@ def get_authorization_header(request):      Hide some test client ickyness where the header can be unicode.      """      auth = request.META.get('HTTP_AUTHORIZATION', b'') -    if type(auth) == type(''): +    if isinstance(auth, type('')):          # Work around django test client oddness          auth = auth.encode(HTTP_HEADER_ENCODING)      return auth @@ -310,6 +310,13 @@ class OAuth2Authentication(BaseAuthentication):          auth = get_authorization_header(request).split() +        if len(auth) == 1: +            msg = 'Invalid bearer header. No credentials provided.' +            raise exceptions.AuthenticationFailed(msg) +        elif len(auth) > 2: +            msg = 'Invalid bearer header. Token string should not contain spaces.' +            raise exceptions.AuthenticationFailed(msg) +          if auth and auth[0].lower() == b'bearer':              access_token = auth[1]          elif 'access_token' in request.POST: @@ -319,13 +326,6 @@ class OAuth2Authentication(BaseAuthentication):          else:              return None -        if len(auth) == 1: -            msg = 'Invalid bearer header. No credentials provided.' -            raise exceptions.AuthenticationFailed(msg) -        elif len(auth) > 2: -            msg = 'Invalid bearer header. Token string should not contain spaces.' -            raise exceptions.AuthenticationFailed(msg) -          return self.authenticate_credentials(request, access_token)      def authenticate_credentials(self, request, access_token): diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py index d5965e40..2e5d6b47 100644 --- a/rest_framework/authtoken/migrations/0001_initial.py +++ b/rest_framework/authtoken/migrations/0001_initial.py @@ -1,67 +1,27 @@ -# -*- coding: utf-8 -*- -import datetime -from south.db import db -from south.v2 import SchemaMigration -from django.db import models - -from rest_framework.settings import api_settings - - -try: -    from django.contrib.auth import get_user_model -except ImportError: # django < 1.5 -    from django.contrib.auth.models import User -else: -    User = get_user_model() - - -class Migration(SchemaMigration): - -    def forwards(self, orm): -        # Adding model 'Token' -        db.create_table('authtoken_token', ( -            ('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)), -            ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])), -            ('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)), -        )) -        db.send_create_signal('authtoken', ['Token']) - - -    def backwards(self, orm): -        # Deleting model 'Token' -        db.delete_table('authtoken_token') - - -    models = { -        'auth.group': { -            'Meta': {'object_name': 'Group'}, -            'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), -            'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '80'}), -            'permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}) -        }, -        'auth.permission': { -            'Meta': {'ordering': "('content_type__app_label', 'content_type__model', 'codename')", 'unique_together': "(('content_type', 'codename'),)", 'object_name': 'Permission'}, -            'codename': ('django.db.models.fields.CharField', [], {'max_length': '100'}), -            'content_type': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['contenttypes.ContentType']"}), -            'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), -            'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}) -        }, -        "%s.%s" % (User._meta.app_label, User._meta.module_name): { -            'Meta': {'object_name': User._meta.module_name}, -        }, -        'authtoken.token': { -            'Meta': {'object_name': 'Token'}, -            'created': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}), -            'key': ('django.db.models.fields.CharField', [], {'max_length': '40', 'primary_key': 'True'}), -            'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['%s.%s']" % (User._meta.app_label, User._meta.object_name)}) -        }, -        'contenttypes.contenttype': { -            'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"}, -            'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}), -            'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), -            'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}), -            'name': ('django.db.models.fields.CharField', [], {'max_length': '100'}) -        } -    } - -    complete_apps = ['authtoken'] +# encoding: utf8 +from __future__ import unicode_literals + +from django.db import models, migrations +from django.conf import settings + + +class Migration(migrations.Migration): + +    dependencies = [ +        migrations.swappable_dependency(settings.AUTH_USER_MODEL), +    ] + +    operations = [ +        migrations.CreateModel( +            name='Token', +            fields=[ +                ('key', models.CharField(max_length=40, serialize=False, primary_key=True)), +                ('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, to_field='id')), +                ('created', models.DateTimeField(auto_now_add=True)), +            ], +            options={ +                'abstract': False, +            }, +            bases=(models.Model,), +        ), +    ] diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 8eac2cc4..db21d44c 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -1,6 +1,5 @@  import binascii  import os -from hashlib import sha1  from django.conf import settings  from django.db import models @@ -34,7 +33,7 @@ class Token(models.Model):          return super(Token, self).save(*args, **kwargs)      def generate_key(self): -        return binascii.hexlify(os.urandom(20)) +        return binascii.hexlify(os.urandom(20)).decode()      def __unicode__(self):          return self.key diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 60a3740e..99e99ae3 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -1,4 +1,6 @@  from django.contrib.auth import authenticate +from django.utils.translation import ugettext_lazy as _ +  from rest_framework import serializers @@ -15,10 +17,13 @@ class AuthTokenSerializer(serializers.Serializer):              if user:                  if not user.is_active: -                    raise serializers.ValidationError('User account is disabled.') +                    msg = _('User account is disabled.') +                    raise serializers.ValidationError(msg)                  attrs['user'] = user                  return attrs              else: -                raise serializers.ValidationError('Unable to login with provided credentials.') +                msg = _('Unable to login with provided credentials.') +                raise serializers.ValidationError(msg)          else: -            raise serializers.ValidationError('Must include "username" and "password"') +            msg = _('Must include "username" and "password"') +            raise serializers.ValidationError(msg) diff --git a/rest_framework/authtoken/south_migrations/0001_initial.py b/rest_framework/authtoken/south_migrations/0001_initial.py new file mode 100644 index 00000000..926de02b --- /dev/null +++ b/rest_framework/authtoken/south_migrations/0001_initial.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +from south.db import db +from south.v2 import SchemaMigration + +try: +    from django.contrib.auth import get_user_model +except ImportError:  # django < 1.5 +    from django.contrib.auth.models import User +else: +    User = get_user_model() + + +class Migration(SchemaMigration): + +    def forwards(self, orm): +        # Adding model 'Token' +        db.create_table('authtoken_token', ( +            ('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)), +            ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])), +            ('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)), +        )) +        db.send_create_signal('authtoken', ['Token']) + +    def backwards(self, orm): +        # Deleting model 'Token' +        db.delete_table('authtoken_token') + +    models = { +        'auth.group': { +            'Meta': {'object_name': 'Group'}, +            'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), +            'name': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '80'}), +            'permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}) +        }, +        'auth.permission': { +            'Meta': {'ordering': "('content_type__app_label', 'content_type__model', 'codename')", 'unique_together': "(('content_type', 'codename'),)", 'object_name': 'Permission'}, +            'codename': ('django.db.models.fields.CharField', [], {'max_length': '100'}), +            'content_type': ('django.db.models.fields.related.ForeignKey', [], {'to': "orm['contenttypes.ContentType']"}), +            'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), +            'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}) +        }, +        "%s.%s" % (User._meta.app_label, User._meta.module_name): { +            'Meta': {'object_name': User._meta.module_name}, +        }, +        'authtoken.token': { +            'Meta': {'object_name': 'Token'}, +            'created': ('django.db.models.fields.DateTimeField', [], {'auto_now_add': 'True', 'blank': 'True'}), +            'key': ('django.db.models.fields.CharField', [], {'max_length': '40', 'primary_key': 'True'}), +            'user': ('django.db.models.fields.related.OneToOneField', [], {'related_name': "'auth_token'", 'unique': 'True', 'to': "orm['%s.%s']" % (User._meta.app_label, User._meta.object_name)}) +        }, +        'contenttypes.contenttype': { +            'Meta': {'ordering': "('name',)", 'unique_together': "(('app_label', 'model'),)", 'object_name': 'ContentType', 'db_table': "'django_content_type'"}, +            'app_label': ('django.db.models.fields.CharField', [], {'max_length': '100'}), +            'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), +            'model': ('django.db.models.fields.CharField', [], {'max_length': '100'}), +            'name': ('django.db.models.fields.CharField', [], {'max_length': '100'}) +        } +    } + +    complete_apps = ['authtoken'] diff --git a/rest_framework/runtests/__init__.py b/rest_framework/authtoken/south_migrations/__init__.py index e69de29b..e69de29b 100644 --- a/rest_framework/runtests/__init__.py +++ b/rest_framework/authtoken/south_migrations/__init__.py diff --git a/rest_framework/compat.py b/rest_framework/compat.py index d155f554..fa0f0bfb 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -5,25 +5,14 @@ versions of django/python, and compatibility wrappers around optional packages.  # flake8: noqa  from __future__ import unicode_literals -  import django  import inspect  from django.core.exceptions import ImproperlyConfigured  from django.conf import settings +from django.utils import six -# Try to import six from Django, fallback to included `six`. -try: -    from django.utils import six -except ImportError: -    from rest_framework import six - -# location of patterns, url, include changes in 1.4 onwards -try: -    from django.conf.urls import patterns, url, include -except ImportError: -    from django.conf.urls.defaults import patterns, url, include -# Handle django.utils.encoding rename: +# Handle django.utils.encoding rename in 1.5 onwards.  # smart_unicode -> smart_text  # force_unicode -> force_text  try: @@ -42,17 +31,23 @@ try:  except ImportError:      from django.http import HttpResponse as HttpResponseBase +  # django-filter is optional  try:      import django_filters  except ImportError:      django_filters = None -# guardian is optional -try: -    import guardian -except ImportError: -    guardian = None + +# Django-guardian is optional. Import only if guardian is in INSTALLED_APPS +# Fixes (#1712). We keep the try/except for the test suite. +guardian = None +if 'guardian' in settings.INSTALLED_APPS: +    try: +        import guardian +        import guardian.shortcuts  # Fixes #1624 +    except ImportError: +        pass  # cStringIO only if it's available, otherwise StringIO @@ -104,46 +99,13 @@ def get_concrete_model(model_cls):          return model_cls +# View._allowed_methods only present from 1.5 onwards  if django.VERSION >= (1, 5):      from django.views.generic import View  else: -    from django.views.generic import View as _View -    from django.utils.decorators import classonlymethod -    from django.utils.functional import update_wrapper - -    class View(_View): -        # 1.3 does not include head method in base View class -        # See: https://code.djangoproject.com/ticket/15668 -        @classonlymethod -        def as_view(cls, **initkwargs): -            """ -            Main entry point for a request-response process. -            """ -            # sanitize keyword arguments -            for key in initkwargs: -                if key in cls.http_method_names: -                    raise TypeError("You tried to pass in the %s method name as a " -                                    "keyword argument to %s(). Don't do that." -                                    % (key, cls.__name__)) -                if not hasattr(cls, key): -                    raise TypeError("%s() received an invalid keyword %r" % ( -                        cls.__name__, key)) - -            def view(request, *args, **kwargs): -                self = cls(**initkwargs) -                if hasattr(self, 'get') and not hasattr(self, 'head'): -                    self.head = self.get -                return self.dispatch(request, *args, **kwargs) - -            # take name and docstring from class -            update_wrapper(view, cls, updated=()) - -            # and possible attributes set by decorators -            # like csrf_exempt from dispatch -            update_wrapper(view, cls.dispatch, assigned=()) -            return view - -        # _allowed_methods only present from 1.5 onwards +    from django.views.generic import View as DjangoView + +    class View(DjangoView):          def _allowed_methods(self):              return [m.upper() for m in self.http_method_names if hasattr(self, m)] @@ -153,316 +115,16 @@ if 'patch' not in View.http_method_names:      View.http_method_names = View.http_method_names + ['patch'] -# PUT, DELETE do not require CSRF until 1.4.  They should.  Make it better. -if django.VERSION >= (1, 4): -    from django.middleware.csrf import CsrfViewMiddleware -else: -    import hashlib -    import re -    import random -    import logging - -    from django.conf import settings -    from django.core.urlresolvers import get_callable - -    try: -        from logging import NullHandler -    except ImportError: -        class NullHandler(logging.Handler): -            def emit(self, record): -                pass - -    logger = logging.getLogger('django.request') - -    if not logger.handlers: -        logger.addHandler(NullHandler()) - -    def same_origin(url1, url2): -        """ -        Checks if two URLs are 'same-origin' -        """ -        p1, p2 = urlparse.urlparse(url1), urlparse.urlparse(url2) -        return p1[0:2] == p2[0:2] - -    def constant_time_compare(val1, val2): -        """ -        Returns True if the two strings are equal, False otherwise. - -        The time taken is independent of the number of characters that match. -        """ -        if len(val1) != len(val2): -            return False -        result = 0 -        for x, y in zip(val1, val2): -            result |= ord(x) ^ ord(y) -        return result == 0 - -    # Use the system (hardware-based) random number generator if it exists. -    if hasattr(random, 'SystemRandom'): -        randrange = random.SystemRandom().randrange -    else: -        randrange = random.randrange - -    _MAX_CSRF_KEY = 18446744073709551616      # 2 << 63 - -    REASON_NO_REFERER = "Referer checking failed - no Referer." -    REASON_BAD_REFERER = "Referer checking failed - %s does not match %s." -    REASON_NO_CSRF_COOKIE = "CSRF cookie not set." -    REASON_BAD_TOKEN = "CSRF token missing or incorrect." - -    def _get_failure_view(): -        """ -        Returns the view to be used for CSRF rejections -        """ -        return get_callable(settings.CSRF_FAILURE_VIEW) - -    def _get_new_csrf_key(): -        return hashlib.md5("%s%s" % (randrange(0, _MAX_CSRF_KEY), settings.SECRET_KEY)).hexdigest() - -    def get_token(request): -        """ -        Returns the the CSRF token required for a POST form. The token is an -        alphanumeric value. - -        A side effect of calling this function is to make the the csrf_protect -        decorator and the CsrfViewMiddleware add a CSRF cookie and a 'Vary: Cookie' -        header to the outgoing response.  For this reason, you may need to use this -        function lazily, as is done by the csrf context processor. -        """ -        request.META["CSRF_COOKIE_USED"] = True -        return request.META.get("CSRF_COOKIE", None) - -    def _sanitize_token(token): -        # Allow only alphanum, and ensure we return a 'str' for the sake of the post -        # processing middleware. -        token = re.sub('[^a-zA-Z0-9]', '', str(token.decode('ascii', 'ignore'))) -        if token == "": -            # In case the cookie has been truncated to nothing at some point. -            return _get_new_csrf_key() -        else: -            return token - -    class CsrfViewMiddleware(object): -        """ -        Middleware that requires a present and correct csrfmiddlewaretoken -        for POST requests that have a CSRF cookie, and sets an outgoing -        CSRF cookie. - -        This middleware should be used in conjunction with the csrf_token template -        tag. -        """ -        # The _accept and _reject methods currently only exist for the sake of the -        # requires_csrf_token decorator. -        def _accept(self, request): -            # Avoid checking the request twice by adding a custom attribute to -            # request.  This will be relevant when both decorator and middleware -            # are used. -            request.csrf_processing_done = True -            return None - -        def _reject(self, request, reason): -            return _get_failure_view()(request, reason=reason) - -        def process_view(self, request, callback, callback_args, callback_kwargs): - -            if getattr(request, 'csrf_processing_done', False): -                return None - -            try: -                csrf_token = _sanitize_token(request.COOKIES[settings.CSRF_COOKIE_NAME]) -                # Use same token next time -                request.META['CSRF_COOKIE'] = csrf_token -            except KeyError: -                csrf_token = None -                # Generate token and store it in the request, so it's available to the view. -                request.META["CSRF_COOKIE"] = _get_new_csrf_key() - -            # Wait until request.META["CSRF_COOKIE"] has been manipulated before -            # bailing out, so that get_token still works -            if getattr(callback, 'csrf_exempt', False): -                return None - -            # Assume that anything not defined as 'safe' by RC2616 needs protection. -            if request.method not in ('GET', 'HEAD', 'OPTIONS', 'TRACE'): -                if getattr(request, '_dont_enforce_csrf_checks', False): -                    # Mechanism to turn off CSRF checks for test suite.  It comes after -                    # the creation of CSRF cookies, so that everything else continues to -                    # work exactly the same (e.g. cookies are sent etc), but before the -                    # any branches that call reject() -                    return self._accept(request) - -                if request.is_secure(): -                    # Suppose user visits http://example.com/ -                    # An active network attacker,(man-in-the-middle, MITM) sends a -                    # POST form which targets https://example.com/detonate-bomb/ and -                    # submits it via javascript. -                    # -                    # The attacker will need to provide a CSRF cookie and token, but -                    # that is no problem for a MITM and the session independent -                    # nonce we are using. So the MITM can circumvent the CSRF -                    # protection. This is true for any HTTP connection, but anyone -                    # using HTTPS expects better!  For this reason, for -                    # https://example.com/ we need additional protection that treats -                    # http://example.com/ as completely untrusted.  Under HTTPS, -                    # Barth et al. found that the Referer header is missing for -                    # same-domain requests in only about 0.2% of cases or less, so -                    # we can use strict Referer checking. -                    referer = request.META.get('HTTP_REFERER') -                    if referer is None: -                        logger.warning('Forbidden (%s): %s' % (REASON_NO_REFERER, request.path), -                            extra={ -                                'status_code': 403, -                                'request': request, -                            } -                        ) -                        return self._reject(request, REASON_NO_REFERER) - -                    # Note that request.get_host() includes the port -                    good_referer = 'https://%s/' % request.get_host() -                    if not same_origin(referer, good_referer): -                        reason = REASON_BAD_REFERER % (referer, good_referer) -                        logger.warning('Forbidden (%s): %s' % (reason, request.path), -                            extra={ -                                'status_code': 403, -                                'request': request, -                            } -                        ) -                        return self._reject(request, reason) - -                if csrf_token is None: -                    # No CSRF cookie. For POST requests, we insist on a CSRF cookie, -                    # and in this way we can avoid all CSRF attacks, including login -                    # CSRF. -                    logger.warning('Forbidden (%s): %s' % (REASON_NO_CSRF_COOKIE, request.path), -                        extra={ -                            'status_code': 403, -                            'request': request, -                        } -                    ) -                    return self._reject(request, REASON_NO_CSRF_COOKIE) - -                # check non-cookie token for match -                request_csrf_token = "" -                if request.method == "POST": -                    request_csrf_token = request.POST.get('csrfmiddlewaretoken', '') - -                if request_csrf_token == "": -                    # Fall back to X-CSRFToken, to make things easier for AJAX, -                    # and possible for PUT/DELETE -                    request_csrf_token = request.META.get('HTTP_X_CSRFTOKEN', '') - -                if not constant_time_compare(request_csrf_token, csrf_token): -                    logger.warning('Forbidden (%s): %s' % (REASON_BAD_TOKEN, request.path), -                        extra={ -                            'status_code': 403, -                            'request': request, -                        } -                    ) -                    return self._reject(request, REASON_BAD_TOKEN) - -            return self._accept(request) - -# timezone support is new in Django 1.4 -try: -    from django.utils import timezone -except ImportError: -    timezone = None - -# dateparse is ALSO new in Django 1.4 -try: -    from django.utils.dateparse import parse_date, parse_datetime, parse_time -except ImportError: -    import datetime -    import re - -    date_re = re.compile( -        r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})$' -    ) - -    datetime_re = re.compile( -        r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})' -        r'[T ](?P<hour>\d{1,2}):(?P<minute>\d{1,2})' -        r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?' -        r'(?P<tzinfo>Z|[+-]\d{1,2}:\d{1,2})?$' -    ) - -    time_re = re.compile( -        r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})' -        r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?' -    ) - -    def parse_date(value): -        match = date_re.match(value) -        if match: -            kw = dict((k, int(v)) for k, v in match.groupdict().iteritems()) -            return datetime.date(**kw) - -    def parse_time(value): -        match = time_re.match(value) -        if match: -            kw = match.groupdict() -            if kw['microsecond']: -                kw['microsecond'] = kw['microsecond'].ljust(6, '0') -            kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None) -            return datetime.time(**kw) - -    def parse_datetime(value): -        """Parse datetime, but w/o the timezone awareness in 1.4""" -        match = datetime_re.match(value) -        if match: -            kw = match.groupdict() -            if kw['microsecond']: -                kw['microsecond'] = kw['microsecond'].ljust(6, '0') -            kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None) -            return datetime.datetime(**kw) - - -# smart_urlquote is new on Django 1.4 -try: -    from django.utils.html import smart_urlquote -except ImportError: -    import re -    from django.utils.encoding import smart_str -    try: -        from urllib.parse import quote, urlsplit, urlunsplit -    except ImportError:     # Python 2 -        from urllib import quote -        from urlparse import urlsplit, urlunsplit - -    unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})') - -    def smart_urlquote(url): -        "Quotes a URL if it isn't already quoted." -        # Handle IDN before quoting. -        scheme, netloc, path, query, fragment = urlsplit(url) -        try: -            netloc = netloc.encode('idna').decode('ascii')  # IDN -> ACE -        except UnicodeError:  # invalid domain part -            pass -        else: -            url = urlunsplit((scheme, netloc, path, query, fragment)) - -        # An URL is considered unquoted if it contains no % characters or -        # contains a % not followed by two hexadecimal digits. See #9655. -        if '%' not in url or unquoted_percents_re.search(url): -            # See http://bugs.python.org/issue2637 -            url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~') - -        return force_text(url) - - -# RequestFactory only provide `generic` from 1.5 onwards - +# RequestFactory only provides `generic` from 1.5 onwards  from django.test.client import RequestFactory as DjangoRequestFactory  from django.test.client import FakePayload  try:      # In 1.5 the test client uses force_bytes      from django.utils.encoding import force_bytes as force_bytes_or_smart_bytes  except ImportError: -    # In 1.3 and 1.4 the test client just uses smart_str +    # In 1.4 the test client just uses smart_str      from django.utils.encoding import smart_str as force_bytes_or_smart_bytes -  class RequestFactory(DjangoRequestFactory):      def generic(self, method, path,              data='', content_type='application/octet-stream', **extra): @@ -487,6 +149,7 @@ class RequestFactory(DjangoRequestFactory):          r.update(extra)          return self.request(**r) +  # Markdown is optional  try:      import markdown @@ -501,7 +164,6 @@ try:          safe_mode = False          md = markdown.Markdown(extensions=extensions, safe_mode=safe_mode)          return md.convert(text) -  except ImportError:      apply_markdown = None @@ -519,14 +181,16 @@ try:  except ImportError:      etree = None -# OAuth is optional + +# OAuth2 is optional  try:      # Note: The `oauth2` package actually provides oauth1.0a support.  Urg.      import oauth2 as oauth  except ImportError:      oauth = None -# OAuth is optional + +# OAuthProvider is optional  try:      import oauth_provider      from oauth_provider.store import store as oauth_provider_store @@ -548,6 +212,7 @@ except (ImportError, ImproperlyConfigured):      oauth_provider_store = None      check_nonce = None +  # OAuth 2 support is optional  try:      import provider as oauth2_provider @@ -567,7 +232,8 @@ except ImportError:      oauth2_constants = None      provider_now = None -# Handle lazy strings + +# Handle lazy strings across Py2/Py3  from django.utils.functional import Promise  if six.PY3: diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index c69756a4..449ba0a2 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -3,13 +3,14 @@ The most important decorator in this module is `@api_view`, which is used  for writing function-based views with REST framework.  There are also various decorators for setting the API policies on function -based views, as well as the `@action` and `@link` decorators, which are +based views, as well as the `@detail_route` and `@list_route` decorators, which are  used to annotate methods on viewsets that should be included by routers.  """  from __future__ import unicode_literals -from rest_framework.compat import six +from django.utils import six  from rest_framework.views import APIView  import types +import warnings  def api_view(http_method_names): @@ -107,23 +108,59 @@ def permission_classes(permission_classes):      return decorator +def detail_route(methods=['get'], **kwargs): +    """ +    Used to mark a method on a ViewSet that should be routed for detail requests. +    """ +    def decorator(func): +        func.bind_to_methods = methods +        func.detail = True +        func.kwargs = kwargs +        return func +    return decorator + + +def list_route(methods=['get'], **kwargs): +    """ +    Used to mark a method on a ViewSet that should be routed for list requests. +    """ +    def decorator(func): +        func.bind_to_methods = methods +        func.detail = False +        func.kwargs = kwargs +        return func +    return decorator + + +# These are now pending deprecation, in favor of `detail_route` and `list_route`. +  def link(**kwargs):      """ -    Used to mark a method on a ViewSet that should be routed for GET requests. +    Used to mark a method on a ViewSet that should be routed for detail GET requests.      """ +    msg = 'link is pending deprecation. Use detail_route instead.' +    warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +      def decorator(func):          func.bind_to_methods = ['get'] +        func.detail = True          func.kwargs = kwargs          return func +      return decorator  def action(methods=['post'], **kwargs):      """ -    Used to mark a method on a ViewSet that should be routed for POST requests. +    Used to mark a method on a ViewSet that should be routed for detail POST requests.      """ +    msg = 'action is pending deprecation. Use detail_route instead.' +    warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +      def decorator(func):          func.bind_to_methods = methods +        func.detail = True          func.kwargs = kwargs          return func +      return decorator diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 389032bd..ad52d172 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -23,6 +23,7 @@ class APIException(Exception):      def __str__(self):          return self.detail +  class ParseError(APIException):      status_code = status.HTTP_400_BAD_REQUEST      default_detail = 'Malformed request.' diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 68b95682..9d707c9b 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -18,12 +18,14 @@ from django.conf import settings  from django.db.models.fields import BLANK_CHOICE_DASH  from django.http import QueryDict  from django.forms import widgets +from django.utils import six, timezone  from django.utils.encoding import is_protected_type  from django.utils.translation import ugettext_lazy as _  from django.utils.datastructures import SortedDict +from django.utils.dateparse import parse_date, parse_datetime, parse_time  from rest_framework import ISO_8601  from rest_framework.compat import ( -    timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text, +    BytesIO, smart_text,      force_text, is_non_str_iterable  )  from rest_framework.settings import api_settings @@ -61,8 +63,10 @@ def get_component(obj, attr_name):  def readable_datetime_formats(formats): -    format = ', '.join(formats).replace(ISO_8601, -             'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]') +    format = ', '.join(formats).replace( +        ISO_8601, +        'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' +    )      return humanize_strptime(format) @@ -154,7 +158,12 @@ class Field(object):      def widget_html(self):          if not self.widget:              return '' -        return self.widget.render(self._name, self._value) + +        attrs = {} +        if 'id' not in self.widget.attrs: +            attrs['id'] = self._name + +        return self.widget.render(self._name, self._value, attrs=attrs)      def label_tag(self):          return '<label for="%s">%s:</label>' % (self._name, self.label) @@ -164,7 +173,7 @@ class Field(object):          Called to set up a field prior to field_to_native or field_from_native.          parent - The parent serializer. -        model_field - The model field this field corresponds to, if one exists. +        field_name - The name of the field being initialized.          """          self.parent = parent          self.root = parent.root or parent @@ -182,7 +191,7 @@ class Field(object):      def field_to_native(self, obj, field_name):          """ -        Given and object and a field name, returns the value that should be +        Given an object and a field name, returns the value that should be          serialized for that field.          """          if obj is None: @@ -260,13 +269,6 @@ class WritableField(Field):                   validators=[], error_messages=None, widget=None,                   default=None, blank=None): -        # 'blank' is to be deprecated in favor of 'required' -        if blank is not None: -            warnings.warn('The `blank` keyword argument is deprecated. ' -                          'Use the `required` keyword argument instead.', -                          DeprecationWarning, stacklevel=2) -            required = not(blank) -          super(WritableField, self).__init__(source=source, label=label, help_text=help_text)          self.read_only = read_only @@ -289,7 +291,7 @@ class WritableField(Field):          self.validators = self.default_validators + validators          self.default = default if default is not None else self.default -        # Widgets are ony used for HTML forms. +        # Widgets are only used for HTML forms.          widget = widget or self.widget          if isinstance(widget, type):              widget = widget() @@ -425,7 +427,7 @@ class ModelField(WritableField):          } -##### Typed Fields ##### +# Typed Fields  class BooleanField(WritableField):      type_name = 'BooleanField' @@ -460,8 +462,9 @@ class CharField(WritableField):      type_label = 'string'      form_field_class = forms.CharField -    def __init__(self, max_length=None, min_length=None, *args, **kwargs): +    def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs):          self.max_length, self.min_length = max_length, min_length +        self.allow_none = allow_none          super(CharField, self).__init__(*args, **kwargs)          if min_length is not None:              self.validators.append(validators.MinLengthValidator(min_length)) @@ -469,8 +472,12 @@ class CharField(WritableField):              self.validators.append(validators.MaxLengthValidator(max_length))      def from_native(self, value): -        if isinstance(value, six.string_types) or value is None: +        if isinstance(value, six.string_types):              return value + +        if value is None and not self.allow_none: +            return '' +          return smart_text(value) @@ -479,7 +486,7 @@ class URLField(CharField):      type_label = 'url'      def __init__(self, **kwargs): -        if not 'validators' in kwargs: +        if 'validators' not in kwargs:              kwargs['validators'] = [validators.URLValidator()]          super(URLField, self).__init__(**kwargs) @@ -501,7 +508,7 @@ class SlugField(CharField):  class ChoiceField(WritableField):      type_name = 'ChoiceField' -    type_label = 'multiple choice' +    type_label = 'choice'      form_field_class = forms.ChoiceField      widget = widgets.Select      default_error_messages = { @@ -509,12 +516,16 @@ class ChoiceField(WritableField):                              'the available choices.'),      } -    def __init__(self, choices=(), *args, **kwargs): +    def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):          self.empty = kwargs.pop('empty', '')          super(ChoiceField, self).__init__(*args, **kwargs)          self.choices = choices          if not self.required: -            self.choices = BLANK_CHOICE_DASH + self.choices +            if blank_display_value is None: +                blank_choice = BLANK_CHOICE_DASH +            else: +                blank_choice = [('', blank_display_value)] +            self.choices = blank_choice + self.choices      def _get_choices(self):          return self._choices @@ -1018,9 +1029,9 @@ class SerializerMethodField(Field):      A field that gets its value by calling a method on the serializer it's attached to.      """ -    def __init__(self, method_name): +    def __init__(self, method_name, *args, **kwargs):          self.method_name = method_name -        super(SerializerMethodField, self).__init__() +        super(SerializerMethodField, self).__init__(*args, **kwargs)      def field_to_native(self, obj, field_name):          value = getattr(self.parent, self.method_name)(obj) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 96d15eb9..e2080013 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -5,7 +5,8 @@ returned by list views.  from __future__ import unicode_literals  from django.core.exceptions import ImproperlyConfigured  from django.db import models -from rest_framework.compat import django_filters, six, guardian, get_model_name +from django.utils import six +from rest_framework.compat import django_filters, guardian, get_model_name  from rest_framework.settings import api_settings  from functools import reduce  import operator @@ -44,7 +45,7 @@ class DjangoFilterBackend(BaseFilterBackend):          if filter_class:              filter_model = filter_class.Meta.model -            assert issubclass(filter_model, queryset.model), \ +            assert issubclass(queryset.model, filter_model), \                  'FilterSet model %s does not match queryset model %s' % \                  (filter_model, queryset.model) @@ -116,6 +117,10 @@ class OrderingFilter(BaseFilterBackend):      def get_ordering(self, request):          """          Ordering is set by a comma delimited ?ordering=... query parameter. + +        The `ordering` query parameter can be overridden by setting +        the `ordering_param` value on the OrderingFilter or by +        specifying an `ORDERING_PARAM` value in the API settings.          """          params = request.QUERY_PARAMS.get(self.ordering_param)          if params: diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 7bac510f..77deb8e4 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -25,6 +25,7 @@ def strict_positive_int(integer_string, cutoff=None):          ret = min(ret, cutoff)      return ret +  def get_object_or_404(queryset, *filter_args, **filter_kwargs):      """      Same as Django's standard shortcut, but make sure to raise 404 @@ -43,6 +44,10 @@ class GenericAPIView(views.APIView):      # You'll need to either set these attributes,      # or override `get_queryset()`/`get_serializer_class()`. +    # If you are overriding a view method, it is important that you call +    # `get_queryset()` instead of accessing the `queryset` property directly, +    # as `queryset` will get evaluated only once, and those results are cached +    # for all subsequent requests.      queryset = None      serializer_class = None @@ -90,8 +95,8 @@ class GenericAPIView(views.APIView):              'view': self          } -    def get_serializer(self, instance=None, data=None, -                       files=None, many=False, partial=False): +    def get_serializer(self, instance=None, data=None, files=None, many=False, +                       partial=False, allow_add_remove=False):          """          Return the serializer instance that should be used for validating and          deserializing input, and for serializing output. @@ -99,7 +104,9 @@ class GenericAPIView(views.APIView):          serializer_class = self.get_serializer_class()          context = self.get_serializer_context()          return serializer_class(instance, data=data, files=files, -                                many=many, partial=partial, context=context) +                                many=many, partial=partial, +                                allow_add_remove=allow_add_remove, +                                context=context)      def get_pagination_serializer(self, page):          """ @@ -121,11 +128,11 @@ class GenericAPIView(views.APIView):          deprecated_style = False          if page_size is not None:              warnings.warn('The `page_size` parameter to `paginate_queryset()` ' -                          'is due to be deprecated. ' +                          'is deprecated. '                            'Note that the return style of this method is also '                            'changed, and will simply return a page object '                            'when called without a `page_size` argument.', -                          PendingDeprecationWarning, stacklevel=2) +                          DeprecationWarning, stacklevel=2)              deprecated_style = True          else:              # Determine the required page size. @@ -136,10 +143,10 @@ class GenericAPIView(views.APIView):          if not self.allow_empty:              warnings.warn( -                'The `allow_empty` parameter is due to be deprecated. ' +                'The `allow_empty` parameter is deprecated. '                  'To use `allow_empty=False` style behavior, You should override '                  '`get_queryset()` and explicitly raise a 404 on empty querysets.', -                PendingDeprecationWarning, stacklevel=2 +                DeprecationWarning, stacklevel=2              )          paginator = self.paginator_class(queryset, page_size, @@ -156,10 +163,11 @@ class GenericAPIView(views.APIView):                  raise Http404(_("Page is not 'last', nor can it be converted to an int."))          try:              page = paginator.page(page_number) -        except InvalidPage as e: -            raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { -                                'page_number': page_number, -                                'message': str(e) +        except InvalidPage as exc: +            error_format = _('Invalid page (%(page_number)s): %(message)s') +            raise Http404(error_format % { +                'page_number': page_number, +                'message': str(exc)              })          if deprecated_style: @@ -183,22 +191,27 @@ class GenericAPIView(views.APIView):          """          Returns the list of filter backends that this view requires.          """ -        filter_backends = self.filter_backends or [] +        if self.filter_backends is None: +            filter_backends = [] +        else: +            # Note that we are returning a *copy* of the class attribute, +            # so that it is safe for the view to mutate it if needed. +            filter_backends = list(self.filter_backends) +          if not filter_backends and self.filter_backend:              warnings.warn(                  'The `filter_backend` attribute and `FILTER_BACKEND` setting ' -                'are due to be deprecated in favor of a `filter_backends` ' +                'are deprecated in favor of a `filter_backends` '                  'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '                  'a *list* of filter backend classes.', -                PendingDeprecationWarning, stacklevel=2 +                DeprecationWarning, stacklevel=2              )              filter_backends = [self.filter_backend] -        return filter_backends +        return filter_backends -    ######################## -    ### The following methods provide default implementations -    ### that you may want to override for more complex cases. +    # The following methods provide default implementations +    # that you may want to override for more complex cases.      def get_paginate_by(self, queryset=None):          """ @@ -211,8 +224,8 @@ class GenericAPIView(views.APIView):          """          if queryset is not None:              warnings.warn('The `queryset` parameter to `get_paginate_by()` ' -                          'is due to be deprecated.', -                          PendingDeprecationWarning, stacklevel=2) +                          'is deprecated.', +                          DeprecationWarning, stacklevel=2)          if self.paginate_by_param:              try: @@ -256,6 +269,10 @@ class GenericAPIView(views.APIView):          This must be an iterable, and may be a queryset.          Defaults to using `self.queryset`. +        This method should always be used rather than accessing `self.queryset` +        directly, as `self.queryset` gets evaluated only once, and those results +        are cached for all subsequent requests. +          You may want to override this if you need to provide different          querysets depending on the incoming request. @@ -267,8 +284,8 @@ class GenericAPIView(views.APIView):          if self.model is not None:              return self.model._default_manager.all() -        raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" -                                    % self.__class__.__name__) +        error_format = "'%s' must define 'queryset' or 'model'" +        raise ImproperlyConfigured(error_format % self.__class__.__name__)      def get_object(self, queryset=None):          """ @@ -295,16 +312,16 @@ class GenericAPIView(views.APIView):              filter_kwargs = {self.lookup_field: lookup}          elif pk is not None and self.lookup_field == 'pk':              warnings.warn( -                'The `pk_url_kwarg` attribute is due to be deprecated. ' +                'The `pk_url_kwarg` attribute is deprecated. '                  'Use the `lookup_field` attribute instead', -                PendingDeprecationWarning +                DeprecationWarning              )              filter_kwargs = {'pk': pk}          elif slug is not None and self.lookup_field == 'pk':              warnings.warn( -                'The `slug_url_kwarg` attribute is due to be deprecated. ' +                'The `slug_url_kwarg` attribute is deprecated. '                  'Use the `lookup_field` attribute instead', -                PendingDeprecationWarning +                DeprecationWarning              )              filter_kwargs = {self.slug_field: slug}          else: @@ -322,12 +339,11 @@ class GenericAPIView(views.APIView):          return obj -    ######################## -    ### The following are placeholder methods, -    ### and are intended to be overridden. -    ### -    ### The are not called by GenericAPIView directly, -    ### but are used by the mixin methods. +    # The following are placeholder methods, +    # and are intended to be overridden. +    # +    # The are not called by GenericAPIView directly, +    # but are used by the mixin methods.      def pre_save(self, obj):          """ @@ -399,10 +415,8 @@ class GenericAPIView(views.APIView):          return ret -########################################################## -### Concrete view classes that provide method handlers ### -### by composing the mixin classes with the base view. ### -########################################################## +# Concrete view classes that provide method handlers +# by composing the mixin classes with the base view.  class CreateAPIView(mixins.CreateModelMixin,                      GenericAPIView): @@ -517,16 +531,14 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,          return self.destroy(request, *args, **kwargs) -########################## -### Deprecated classes ### -########################## +# Deprecated classes  class MultipleObjectAPIView(GenericAPIView):      def __init__(self, *args, **kwargs):          warnings.warn( -            'Subclassing `MultipleObjectAPIView` is due to be deprecated. ' +            'Subclassing `MultipleObjectAPIView` is deprecated. '              'You should simply subclass `GenericAPIView` instead.', -            PendingDeprecationWarning, stacklevel=2 +            DeprecationWarning, stacklevel=2          )          super(MultipleObjectAPIView, self).__init__(*args, **kwargs) @@ -534,8 +546,8 @@ class MultipleObjectAPIView(GenericAPIView):  class SingleObjectAPIView(GenericAPIView):      def __init__(self, *args, **kwargs):          warnings.warn( -            'Subclassing `SingleObjectAPIView` is due to be deprecated. ' +            'Subclassing `SingleObjectAPIView` is deprecated. '              'You should simply subclass `GenericAPIView` instead.', -            PendingDeprecationWarning, stacklevel=2 +            DeprecationWarning, stacklevel=2          )          super(SingleObjectAPIView, self).__init__(*args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index e1a24dc7..2cc87eef 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -26,14 +26,14 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None)      include = []      if pk: -        # Pending deprecation +        # Deprecated          pk_field = obj._meta.pk          while pk_field.rel:              pk_field = pk_field.rel.to._meta.pk          include.append(pk_field.name)      if slug_field: -        # Pending deprecation +        # Deprecated          include.append(slug_field)      if lookup_field and lookup_field != 'pk': @@ -79,10 +79,10 @@ class ListModelMixin(object):          # `.allow_empty = False`, to raise 404 errors on empty querysets.          if not self.allow_empty and not self.object_list:              warnings.warn( -                'The `allow_empty` parameter is due to be deprecated. ' +                'The `allow_empty` parameter is deprecated. '                  'To use `allow_empty=False` style behavior, You should override '                  '`get_queryset()` and explicitly raise a 404 on empty querysets.', -                PendingDeprecationWarning +                DeprecationWarning              )              class_name = self.__class__.__name__              error_msg = self.empty_error % {'class_name': class_name} diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 4d205c0e..ca7b5397 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -54,8 +54,10 @@ class DefaultContentNegotiation(BaseContentNegotiation):                  for media_type in media_type_set:                      if media_type_matches(renderer.media_type, media_type):                          # Return the most specific media type as accepted. -                        if (_MediaType(renderer.media_type).precedence > -                            _MediaType(media_type).precedence): +                        if ( +                            _MediaType(renderer.media_type).precedence > +                            _MediaType(media_type).precedence +                        ):                              # Eg client requests '*/*'                              # Accepted media type is 'application/json'                              return renderer, renderer.media_type diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index f1b3e38d..aa4fd3f1 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -10,7 +10,8 @@ from django.core.files.uploadhandler import StopFutureHandlers  from django.http import QueryDict  from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser  from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter -from rest_framework.compat import etree, six, yaml +from django.utils import six +from rest_framework.compat import etree, yaml, force_text  from rest_framework.exceptions import ParseError  from rest_framework import renderers  import json @@ -288,7 +289,7 @@ class FileUploadParser(BaseParser):          try:              meta = parser_context['request'].META -            disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) -            return disposition[1]['filename'] +            disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8')) +            return force_text(disposition[1]['filename'])          except (AttributeError, KeyError):              pass diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index f24a5123..6a1a0077 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -2,15 +2,12 @@  Provides a set of pluggable permission policies.  """  from __future__ import unicode_literals -import inspect -import warnings - -SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS'] -  from django.http import Http404  from rest_framework.compat import (get_model_name, oauth2_provider_scope,                                     oauth2_constants) +SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS'] +  class BasePermission(object):      """ @@ -27,13 +24,6 @@ class BasePermission(object):          """          Return `True` if permission is granted, `False` otherwise.          """ -        if len(inspect.getargspec(self.has_permission).args) == 4: -            warnings.warn( -                'The `obj` argument in `has_permission` is deprecated. ' -                'Use `has_object_permission()` instead for object permissions.', -                DeprecationWarning, stacklevel=2 -            ) -            return self.has_permission(request, view, obj)          return True @@ -72,9 +62,11 @@ class IsAuthenticatedOrReadOnly(BasePermission):      """      def has_permission(self, request, view): -        return (request.method in SAFE_METHODS or  -            request.user and  -            request.user.is_authenticated()) +        return ( +            request.method in SAFE_METHODS or +            request.user and +            request.user.is_authenticated() +        )  class DjangoModelPermissions(BasePermission): @@ -132,9 +124,11 @@ class DjangoModelPermissions(BasePermission):          perms = self.get_required_permissions(request.method, model_cls) -        return (request.user and +        return ( +            request.user and              (request.user.is_authenticated() or not self.authenticated_users_only) and -            request.user.has_perms(perms)) +            request.user.has_perms(perms) +        )  class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions): @@ -222,6 +216,8 @@ class TokenHasReadWriteScope(BasePermission):              required = oauth2_constants.READ if read_only else oauth2_constants.WRITE              return oauth2_provider_scope.check(required, request.auth.scope) -        assert False, ('TokenHasReadWriteScope requires either the' -        '`OAuthAuthentication` or `OAuth2Authentication` authentication ' -        'class to be used.') +        assert False, ( +            'TokenHasReadWriteScope requires either the' +            '`OAuthAuthentication` or `OAuth2Authentication` authentication ' +            'class to be used.' +        ) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 308545ce..1acbdce2 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -19,8 +19,7 @@ from rest_framework.compat import smart_text  import warnings -##### Relational fields ##### - +# Relational fields  # Not actually Writable, but subclasses may need to be.  class RelatedField(WritableField): @@ -41,14 +40,6 @@ class RelatedField(WritableField):      many = False      def __init__(self, *args, **kwargs): - -        # 'null' is to be deprecated in favor of 'required' -        if 'null' in kwargs: -            warnings.warn('The `null` keyword argument is deprecated. ' -                          'Use the `required` keyword argument instead.', -                          DeprecationWarning, stacklevel=2) -            kwargs['required'] = not kwargs.pop('null') -          queryset = kwargs.pop('queryset', None)          self.many = kwargs.pop('many', self.many)          if self.many: @@ -59,6 +50,8 @@ class RelatedField(WritableField):          super(RelatedField, self).__init__(*args, **kwargs)          if not self.required: +            # Accessed in ModelChoiceIterator django/forms/models.py:1034 +            # If set adds empty choice.              self.empty_label = BLANK_CHOICE_DASH[0][1]          self.queryset = queryset @@ -72,7 +65,7 @@ class RelatedField(WritableField):              else:  # Reverse                  self.queryset = manager.field.rel.to._default_manager.all() -    ### We need this stuff to make form choices work... +    # We need this stuff to make form choices work...      def prepare_value(self, obj):          return self.to_native(obj) @@ -119,7 +112,7 @@ class RelatedField(WritableField):      choices = property(_get_choices, _set_choices) -    ### Default value handling +    # Default value handling      def get_default_value(self):          default = super(RelatedField, self).get_default_value() @@ -127,7 +120,7 @@ class RelatedField(WritableField):              return []          return default -    ### Regular serializer stuff... +    # Regular serializer stuff...      def field_to_native(self, obj, field_name):          try: @@ -187,7 +180,7 @@ class RelatedField(WritableField):              into[(self.source or field_name)] = self.from_native(value) -### PrimaryKey relationships +# PrimaryKey relationships  class PrimaryKeyRelatedField(RelatedField):      """ @@ -275,8 +268,7 @@ class PrimaryKeyRelatedField(RelatedField):          return self.to_native(pk) -### Slug relationships - +# Slug relationships  class SlugRelatedField(RelatedField):      """ @@ -311,7 +303,7 @@ class SlugRelatedField(RelatedField):              raise ValidationError(msg) -### Hyperlinked relationships +# Hyperlinked relationships  class HyperlinkedRelatedField(RelatedField):      """ @@ -328,7 +320,7 @@ class HyperlinkedRelatedField(RelatedField):          'incorrect_type': _('Incorrect type.  Expected url string, received %s.'),      } -    # These are all pending deprecation +    # These are all deprecated      pk_url_kwarg = 'pk'      slug_field = 'slug'      slug_url_kwarg = None  # Defaults to same as `slug_field` unless overridden @@ -342,16 +334,16 @@ class HyperlinkedRelatedField(RelatedField):          self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)          self.format = kwargs.pop('format', None) -        # These are pending deprecation +        # These are deprecated          if 'pk_url_kwarg' in kwargs: -            msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' -            warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +            msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.' +            warnings.warn(msg, DeprecationWarning, stacklevel=2)          if 'slug_url_kwarg' in kwargs: -            msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' -            warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +            msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.' +            warnings.warn(msg, DeprecationWarning, stacklevel=2)          if 'slug_field' in kwargs: -            msg = 'slug_field is pending deprecation. Use lookup_field instead.' -            warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +            msg = 'slug_field is deprecated. Use lookup_field instead.' +            warnings.warn(msg, DeprecationWarning, stacklevel=2)          self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)          self.slug_field = kwargs.pop('slug_field', self.slug_field) @@ -394,9 +386,9 @@ class HyperlinkedRelatedField(RelatedField):                      # If the lookup succeeds using the default slug params,                      # then `slug_field` is being used implicitly, and we                      # we need to warn about the pending deprecation. -                    msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \ +                    msg = 'Implicit slug field hyperlinked fields are deprecated.' \                            'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' -                    warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +                    warnings.warn(msg, DeprecationWarning, stacklevel=2)                  return ret              except NoReverseMatch:                  pass @@ -430,14 +422,11 @@ class HyperlinkedRelatedField(RelatedField):          request = self.context.get('request', None)          format = self.format or self.context.get('format', None) -        if request is None: -            msg = ( -                "Using `HyperlinkedRelatedField` without including the request " -                "in the serializer context is deprecated. " -                "Add `context={'request': request}` when instantiating " -                "the serializer." -            ) -            warnings.warn(msg, DeprecationWarning, stacklevel=4) +        assert request is not None, ( +            "`HyperlinkedRelatedField` requires the request in the serializer " +            "context. Add `context={'request': request}` when instantiating " +            "the serializer." +        )          # If the object has not yet been saved then we cannot hyperlink to it.          if getattr(obj, 'pk', None) is None: @@ -497,7 +486,7 @@ class HyperlinkedIdentityField(Field):      lookup_field = 'pk'      read_only = True -    # These are all pending deprecation +    # These are all deprecated      pk_url_kwarg = 'pk'      slug_field = 'slug'      slug_url_kwarg = None  # Defaults to same as `slug_field` unless overridden @@ -513,16 +502,16 @@ class HyperlinkedIdentityField(Field):          lookup_field = kwargs.pop('lookup_field', None)          self.lookup_field = lookup_field or self.lookup_field -        # These are pending deprecation +        # These are deprecated          if 'pk_url_kwarg' in kwargs: -            msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' -            warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +            msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.' +            warnings.warn(msg, DeprecationWarning, stacklevel=2)          if 'slug_url_kwarg' in kwargs: -            msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' -            warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +            msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.' +            warnings.warn(msg, DeprecationWarning, stacklevel=2)          if 'slug_field' in kwargs: -            msg = 'slug_field is pending deprecation. Use lookup_field instead.' -            warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +            msg = 'slug_field is deprecated. Use lookup_field instead.' +            warnings.warn(msg, DeprecationWarning, stacklevel=2)          self.slug_field = kwargs.pop('slug_field', self.slug_field)          default_slug_kwarg = self.slug_url_kwarg or self.slug_field @@ -536,11 +525,11 @@ class HyperlinkedIdentityField(Field):          format = self.context.get('format', None)          view_name = self.view_name -        if request is None: -            warnings.warn("Using `HyperlinkedIdentityField` without including the " -                          "request in the serializer context is deprecated. " -                          "Add `context={'request': request}` when instantiating the serializer.", -                          DeprecationWarning, stacklevel=4) +        assert request is not None, ( +            "`HyperlinkedIdentityField` requires the request in the serializer" +            " context. Add `context={'request': request}` when instantiating " +            "the serializer." +        )          # By default use whatever format is given for the current context          # unless the target is a different type to the source. @@ -604,41 +593,3 @@ class HyperlinkedIdentityField(Field):                  pass          raise NoReverseMatch() - - -### Old-style many classes for backwards compat - -class ManyRelatedField(RelatedField): -    def __init__(self, *args, **kwargs): -        warnings.warn('`ManyRelatedField()` is deprecated. ' -                      'Use `RelatedField(many=True)` instead.', -                       DeprecationWarning, stacklevel=2) -        kwargs['many'] = True -        super(ManyRelatedField, self).__init__(*args, **kwargs) - - -class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField): -    def __init__(self, *args, **kwargs): -        warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. ' -                      'Use `PrimaryKeyRelatedField(many=True)` instead.', -                       DeprecationWarning, stacklevel=2) -        kwargs['many'] = True -        super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs) - - -class ManySlugRelatedField(SlugRelatedField): -    def __init__(self, *args, **kwargs): -        warnings.warn('`ManySlugRelatedField()` is deprecated. ' -                      'Use `SlugRelatedField(many=True)` instead.', -                       DeprecationWarning, stacklevel=2) -        kwargs['many'] = True -        super(ManySlugRelatedField, self).__init__(*args, **kwargs) - - -class ManyHyperlinkedRelatedField(HyperlinkedRelatedField): -    def __init__(self, *args, **kwargs): -        warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. ' -                      'Use `HyperlinkedRelatedField(many=True)` instead.', -                       DeprecationWarning, stacklevel=2) -        kwargs['many'] = True -        super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 7a7da561..748ebac9 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -8,7 +8,6 @@ REST framework also provides an HTML renderer the renders the browsable API.  """  from __future__ import unicode_literals -import copy  import json  import django  from django import forms @@ -16,11 +15,9 @@ from django.core.exceptions import ImproperlyConfigured  from django.http.multipartparser import parse_header  from django.template import RequestContext, loader, Template  from django.test.client import encode_multipart +from django.utils import six  from django.utils.xmlutils import SimplerXMLGenerator -from rest_framework.compat import StringIO -from rest_framework.compat import six -from rest_framework.compat import smart_text -from rest_framework.compat import yaml +from rest_framework.compat import StringIO, smart_text, yaml  from rest_framework.exceptions import ParseError  from rest_framework.settings import api_settings  from rest_framework.request import is_form_media_type, override_method @@ -54,35 +51,41 @@ class JSONRenderer(BaseRenderer):      format = 'json'      encoder_class = encoders.JSONEncoder      ensure_ascii = True -    charset = None -    # JSON is a binary encoding, that can be encoded as utf-8, utf-16 or utf-32. + +    # We don't set a charset because JSON is a binary encoding, +    # that can be encoded as utf-8, utf-16 or utf-32.      # See: http://www.ietf.org/rfc/rfc4627.txt      # Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/ +    charset = None + +    def get_indent(self, accepted_media_type, renderer_context): +        if accepted_media_type: +            # If the media type looks like 'application/json; indent=4', +            # then pretty print the result. +            base_media_type, params = parse_header(accepted_media_type.encode('ascii')) +            try: +                return max(min(int(params['indent']), 8), 0) +            except (KeyError, ValueError, TypeError): +                pass + +        # If 'indent' is provided in the context, then pretty print the result. +        # E.g. If we're being called by the BrowsableAPIRenderer. +        return renderer_context.get('indent', None)      def render(self, data, accepted_media_type=None, renderer_context=None):          """ -        Render `data` into JSON. +        Render `data` into JSON, returning a bytestring.          """          if data is None:              return bytes() -        # If 'indent' is provided in the context, then pretty print the result. -        # E.g. If we're being called by the BrowsableAPIRenderer.          renderer_context = renderer_context or {} -        indent = renderer_context.get('indent', None) +        indent = self.get_indent(accepted_media_type, renderer_context) -        if accepted_media_type: -            # If the media type looks like 'application/json; indent=4', -            # then pretty print the result. -            base_media_type, params = parse_header(accepted_media_type.encode('ascii')) -            indent = params.get('indent', indent) -            try: -                indent = max(min(int(indent), 8), 0) -            except (ValueError, TypeError): -                indent = None - -        ret = json.dumps(data, cls=self.encoder_class, -            indent=indent, ensure_ascii=self.ensure_ascii) +        ret = json.dumps( +            data, cls=self.encoder_class, +            indent=indent, ensure_ascii=self.ensure_ascii +        )          # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,          # but if ensure_ascii=False, the return type is underspecified, @@ -193,6 +196,7 @@ class YAMLRenderer(BaseRenderer):      format = 'yaml'      encoder = encoders.SafeDumper      charset = 'utf-8' +    ensure_ascii = True      def render(self, data, accepted_media_type=None, renderer_context=None):          """ @@ -203,7 +207,15 @@ class YAMLRenderer(BaseRenderer):          if data is None:              return '' -        return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder) +        return yaml.dump(data, stream=None, encoding=self.charset, Dumper=self.encoder, allow_unicode=not self.ensure_ascii) + + +class UnicodeYAMLRenderer(YAMLRenderer): +    """ +    Renderer which serializes to YAML. +    Does *not* apply character escaping for non-ascii characters. +    """ +    ensure_ascii = False  class TemplateHTMLRenderer(BaseRenderer): @@ -400,7 +412,7 @@ class BrowsableAPIRenderer(BaseRenderer):          """          Returns True if a form should be shown for this method.          """ -        if not method in view.allowed_methods: +        if method not in view.allowed_methods:              return  # Not a valid method          if not api_settings.FORM_METHOD_OVERRIDE: @@ -440,8 +452,10 @@ class BrowsableAPIRenderer(BaseRenderer):              if method in ('DELETE', 'OPTIONS'):                  return True  # Don't actually need to return a form -            if (not getattr(view, 'get_serializer', None) -                or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)): +            if ( +                not getattr(view, 'get_serializer', None) +                or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes) +            ):                  return              serializer = view.get_serializer(instance=obj, data=data, files=files) @@ -562,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer):              'version': VERSION,              'breadcrumblist': self.get_breadcrumbs(request),              'allowed_methods': view.allowed_methods, -            'available_formats': [renderer.format for renderer in view.renderer_classes], +            'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],              'response_headers': response_headers,              'put_form': self.get_rendered_html_form(view, 'PUT', request), @@ -611,4 +625,3 @@ class MultiPartRenderer(BaseRenderer):      def render(self, data, accepted_media_type=None, renderer_context=None):          return encode_multipart(self.BOUNDARY, data) - diff --git a/rest_framework/request.py b/rest_framework/request.py index 40467c03..27532661 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -42,13 +42,20 @@ class override_method(object):          self.view = view          self.request = request          self.method = method +        self.action = getattr(view, 'action', None)      def __enter__(self):          self.view.request = clone_request(self.request, self.method) +        if self.action is not None: +            # For viewsets we also set the `.action` attribute. +            action_map = getattr(self.view, 'action_map', {}) +            self.view.action = action_map.get(self.method.lower())          return self.view.request      def __exit__(self, *args, **kwarg):          self.view.request = self.request +        if self.action is not None: +            self.view.action = self.action  class Empty(object): @@ -280,16 +287,19 @@ class Request(object):              self._method = self._request.method              # Allow X-HTTP-METHOD-OVERRIDE header -            self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE', -                                         self._method) +            if 'HTTP_X_HTTP_METHOD_OVERRIDE' in self.META: +                self._method = self.META['HTTP_X_HTTP_METHOD_OVERRIDE'].upper()      def _load_stream(self):          """          Return the content body of the request, as a stream.          """          try: -            content_length = int(self.META.get('CONTENT_LENGTH', -                                    self.META.get('HTTP_CONTENT_LENGTH'))) +            content_length = int( +                self.META.get( +                    'CONTENT_LENGTH', self.META.get('HTTP_CONTENT_LENGTH') +                ) +            )          except (ValueError, TypeError):              content_length = 0 @@ -313,9 +323,11 @@ class Request(object):          )          # We only need to use form overloading on form POST requests. -        if (not USE_FORM_OVERLOADING +        if ( +            not USE_FORM_OVERLOADING              or self._request.method != 'POST' -            or not is_form_media_type(self._content_type)): +            or not is_form_media_type(self._content_type) +        ):              return          # At this point we're committed to parsing the request as form data. @@ -323,15 +335,19 @@ class Request(object):          self._files = self._request.FILES          # Method overloading - change the method and remove the param from the content. -        if (self._METHOD_PARAM and -            self._METHOD_PARAM in self._data): +        if ( +            self._METHOD_PARAM and +            self._METHOD_PARAM in self._data +        ):              self._method = self._data[self._METHOD_PARAM].upper()          # Content overloading - modify the content type, and force re-parse. -        if (self._CONTENT_PARAM and +        if ( +            self._CONTENT_PARAM and              self._CONTENTTYPE_PARAM and              self._CONTENT_PARAM in self._data and -            self._CONTENTTYPE_PARAM in self._data): +            self._CONTENTTYPE_PARAM in self._data +        ):              self._content_type = self._data[self._CONTENTTYPE_PARAM]              self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))              self._data, self._files = (Empty, Empty) @@ -387,7 +403,7 @@ class Request(object):                  self._not_authenticated()                  raise -            if not user_auth_tuple is None: +            if user_auth_tuple is not None:                  self._authenticator = authenticator                  self._user, self._auth = user_auth_tuple                  return diff --git a/rest_framework/response.py b/rest_framework/response.py index 1dc6abcf..0a7d313f 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -5,9 +5,10 @@ it is initialized with unrendered data, instead of a pre-rendered string.  The appropriate renderer is called during Django's template response rendering.  """  from __future__ import unicode_literals +import django  from django.core.handlers.wsgi import STATUS_CODE_TEXT  from django.template.response import SimpleTemplateResponse -from rest_framework.compat import six +from django.utils import six  class Response(SimpleTemplateResponse): @@ -15,8 +16,11 @@ class Response(SimpleTemplateResponse):      An HttpResponse that allows its data to be rendered into      arbitrary media types.      """ +    # TODO: remove that once Django 1.3 isn't supported +    if django.VERSION >= (1, 4): +        rendering_attrs = SimpleTemplateResponse.rendering_attrs + ['_closable_objects'] -    def __init__(self, data=None, status=200, +    def __init__(self, data=None, status=None,                   template_name=None, headers=None,                   exception=False, content_type=None):          """ @@ -58,8 +62,10 @@ class Response(SimpleTemplateResponse):          ret = renderer.render(self.data, media_type, context)          if isinstance(ret, six.text_type): -            assert charset, 'renderer returned unicode, and did not specify ' \ -            'a charset value.' +            assert charset, ( +                'renderer returned unicode, and did not specify ' +                'a charset value.' +            )              return bytes(ret.encode(charset))          if not ret: diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 97b35c10..406ebcf7 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -17,15 +17,17 @@ from __future__ import unicode_literals  import itertools  from collections import namedtuple +from django.conf.urls import patterns, url  from django.core.exceptions import ImproperlyConfigured  from rest_framework import views -from rest_framework.compat import patterns, url  from rest_framework.response import Response  from rest_framework.reverse import reverse  from rest_framework.urlpatterns import format_suffix_patterns  Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) +DynamicDetailRoute = namedtuple('DynamicDetailRoute', ['url', 'name', 'initkwargs']) +DynamicListRoute = namedtuple('DynamicListRoute', ['url', 'name', 'initkwargs'])  def replace_methodname(format_string, methodname): @@ -88,6 +90,14 @@ class SimpleRouter(BaseRouter):              name='{basename}-list',              initkwargs={'suffix': 'List'}          ), +        # Dynamically generated list routes. +        # Generated using @list_route decorator +        # on methods of the viewset. +        DynamicListRoute( +            url=r'^{prefix}/{methodname}{trailing_slash}$', +            name='{basename}-{methodnamehyphen}', +            initkwargs={} +        ),          # Detail route.          Route(              url=r'^{prefix}/{lookup}{trailing_slash}$', @@ -100,13 +110,10 @@ class SimpleRouter(BaseRouter):              name='{basename}-detail',              initkwargs={'suffix': 'Instance'}          ), -        # Dynamically generated routes. -        # Generated using @action or @link decorators on methods of the viewset. -        Route( +        # Dynamically generated detail routes. +        # Generated using @detail_route decorator on methods of the viewset. +        DynamicDetailRoute(              url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$', -            mapping={ -                '{httpmethod}': '{methodname}', -            },              name='{basename}-{methodnamehyphen}',              initkwargs={}          ), @@ -139,25 +146,42 @@ class SimpleRouter(BaseRouter):          Returns a list of the Route namedtuple.          """ -        known_actions = flatten([route.mapping.values() for route in self.routes]) +        known_actions = flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]) -        # Determine any `@action` or `@link` decorated methods on the viewset -        dynamic_routes = [] +        # Determine any `@detail_route` or `@list_route` decorated methods on the viewset +        detail_routes = [] +        list_routes = []          for methodname in dir(viewset):              attr = getattr(viewset, methodname)              httpmethods = getattr(attr, 'bind_to_methods', None) +            detail = getattr(attr, 'detail', True)              if httpmethods:                  if methodname in known_actions: -                    raise ImproperlyConfigured('Cannot use @action or @link decorator on ' -                                               'method "%s" as it is an existing route' % methodname) +                    raise ImproperlyConfigured('Cannot use @detail_route or @list_route ' +                                               'decorators on method "%s" ' +                                               'as it is an existing route' % methodname)                  httpmethods = [method.lower() for method in httpmethods] -                dynamic_routes.append((httpmethods, methodname)) +                if detail: +                    detail_routes.append((httpmethods, methodname)) +                else: +                    list_routes.append((httpmethods, methodname))          ret = []          for route in self.routes: -            if route.mapping == {'{httpmethod}': '{methodname}'}: -                # Dynamic routes (@link or @action decorator) -                for httpmethods, methodname in dynamic_routes: +            if isinstance(route, DynamicDetailRoute): +                # Dynamic detail routes (@detail_route decorator) +                for httpmethods, methodname in detail_routes: +                    initkwargs = route.initkwargs.copy() +                    initkwargs.update(getattr(viewset, methodname).kwargs) +                    ret.append(Route( +                        url=replace_methodname(route.url, methodname), +                        mapping=dict((httpmethod, methodname) for httpmethod in httpmethods), +                        name=replace_methodname(route.name, methodname), +                        initkwargs=initkwargs, +                    )) +            elif isinstance(route, DynamicListRoute): +                # Dynamic list routes (@list_route decorator) +                for httpmethods, methodname in list_routes:                      initkwargs = route.initkwargs.copy()                      initkwargs.update(getattr(viewset, methodname).kwargs)                      ret.append(Route( @@ -195,13 +219,16 @@ class SimpleRouter(BaseRouter):          https://github.com/alanjds/drf-nested-routers          """ -        if self.trailing_slash: -            base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/]+)' -        else: -            # Don't consume `.json` style suffixes -            base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/.]+)' +        base_regex = '(?P<{lookup_prefix}{lookup_field}>{lookup_value})' +        # Use `pk` as default field, unset set.  Default regex should not +        # consume `.json` style suffixes and should break at '/' boundaries.          lookup_field = getattr(viewset, 'lookup_field', 'pk') -        return base_regex.format(lookup_field=lookup_field, lookup_prefix=lookup_prefix) +        lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+') +        return base_regex.format( +            lookup_prefix=lookup_prefix, +            lookup_field=lookup_field, +            lookup_value=lookup_value +        )      def get_urls(self):          """ diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py deleted file mode 100755 index ce11b213..00000000 --- a/rest_framework/runtests/runcoverage.py +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python -""" -Useful tool to run the test suite for rest_framework and generate a coverage report. -""" - -# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/ -# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/ -# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py -import os -import sys - -# fix sys path so we don't need to setup PYTHONPATH -sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) -os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings' - -from coverage import coverage - - -def main(): -    """Run the tests for rest_framework and generate a coverage report.""" - -    cov = coverage() -    cov.erase() -    cov.start() - -    from django.conf import settings -    from django.test.utils import get_runner -    TestRunner = get_runner(settings) - -    if hasattr(TestRunner, 'func_name'): -        # Pre 1.2 test runners were just functions, -        # and did not support the 'failfast' option. -        import warnings -        warnings.warn( -            'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.', -            DeprecationWarning -        ) -        failures = TestRunner(['tests']) -    else: -        test_runner = TestRunner() -        failures = test_runner.run_tests(['tests']) -    cov.stop() - -    # Discover the list of all modules that we should test coverage for -    import rest_framework - -    project_dir = os.path.dirname(rest_framework.__file__) -    cov_files = [] - -    for (path, dirs, files) in os.walk(project_dir): -        # Drop tests and runtests directories from the test coverage report -        if os.path.basename(path) in ['tests', 'runtests', 'migrations']: -            continue - -        # Drop the compat and six modules from coverage, since we're not interested in the coverage -        # of modules which are specifically for resolving environment dependant imports. -        # (Because we'll end up getting different coverage reports for it for each environment) -        if 'compat.py' in files: -            files.remove('compat.py') - -        if 'six.py' in files: -            files.remove('six.py') - -        # Same applies to template tags module. -        # This module has to include branching on Django versions, -        # so it's never possible for it to have full coverage. -        if 'rest_framework.py' in files: -            files.remove('rest_framework.py') - -        cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')]) - -    cov.report(cov_files) -    if '--html' in sys.argv: -        cov.html_report(cov_files, directory='coverage') -    sys.exit(failures) - -if __name__ == '__main__': -    main() diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py deleted file mode 100755 index 2daaae4e..00000000 --- a/rest_framework/runtests/runtests.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python - -# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/ -# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/ -# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py -import os -import sys - -# fix sys path so we don't need to setup PYTHONPATH -sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) -os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings' - -import django -from django.conf import settings -from django.test.utils import get_runner - - -def usage(): -    return """ -    Usage: python runtests.py [UnitTestClass].[method] - -    You can pass the Class name of the `UnitTestClass` you want to test. - -    Append a method name if you only want to test a specific method of that class. -    """ - - -def main(): -    try: -        django.setup() -    except AttributeError: -        pass -    TestRunner = get_runner(settings) - -    test_runner = TestRunner() -    if len(sys.argv) == 2: -        test_case = '.' + sys.argv[1] -    elif len(sys.argv) == 1: -        test_case = '' -    else: -        print(usage()) -        sys.exit(1) -    test_module_name = 'rest_framework.tests' -    if django.VERSION[0] == 1 and django.VERSION[1] < 6: -        test_module_name = 'tests' - -    failures = test_runner.run_tests([test_module_name + test_case]) - -    sys.exit(failures) - -if __name__ == '__main__': -    main() diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py deleted file mode 100644 index 3fc0eb2f..00000000 --- a/rest_framework/runtests/settings.py +++ /dev/null @@ -1,169 +0,0 @@ -# Django settings for testproject project. - -DEBUG = True -TEMPLATE_DEBUG = DEBUG -DEBUG_PROPAGATE_EXCEPTIONS = True - -ALLOWED_HOSTS = ['*'] - -ADMINS = ( -    # ('Your Name', 'your_email@domain.com'), -) - -MANAGERS = ADMINS - -DATABASES = { -    'default': { -        'ENGINE': 'django.db.backends.sqlite3',  # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. -        'NAME': 'sqlite.db',                     # Or path to database file if using sqlite3. -        'USER': '',                      # Not used with sqlite3. -        'PASSWORD': '',                  # Not used with sqlite3. -        'HOST': '',                      # Set to empty string for localhost. Not used with sqlite3. -        'PORT': '',                      # Set to empty string for default. Not used with sqlite3. -    } -} - -CACHES = { -    'default': { -        'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', -    } -} - -# Local time zone for this installation. Choices can be found here: -# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name -# although not all choices may be available on all operating systems. -# On Unix systems, a value of None will cause Django to use the same -# timezone as the operating system. -# If running in a Windows environment this must be set to the same as your -# system time zone. -TIME_ZONE = 'Europe/London' - -# Language code for this installation. All choices can be found here: -# http://www.i18nguy.com/unicode/language-identifiers.html -LANGUAGE_CODE = 'en-uk' - -SITE_ID = 1 - -# If you set this to False, Django will make some optimizations so as not -# to load the internationalization machinery. -USE_I18N = True - -# If you set this to False, Django will not format dates, numbers and -# calendars according to the current locale -USE_L10N = True - -# Absolute filesystem path to the directory that will hold user-uploaded files. -# Example: "/home/media/media.lawrence.com/" -MEDIA_ROOT = '' - -# URL that handles the media served from MEDIA_ROOT. Make sure to use a -# trailing slash if there is a path component (optional in other cases). -# Examples: "http://media.lawrence.com", "http://example.com/media/" -MEDIA_URL = '' - -# Make this unique, and don't share it with anybody. -SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy' - -# List of callables that know how to import templates from various sources. -TEMPLATE_LOADERS = ( -    'django.template.loaders.filesystem.Loader', -    'django.template.loaders.app_directories.Loader', -#     'django.template.loaders.eggs.Loader', -) - -MIDDLEWARE_CLASSES = ( -    'django.middleware.common.CommonMiddleware', -    'django.contrib.sessions.middleware.SessionMiddleware', -    'django.middleware.csrf.CsrfViewMiddleware', -    'django.contrib.auth.middleware.AuthenticationMiddleware', -    'django.contrib.messages.middleware.MessageMiddleware', -) - -ROOT_URLCONF = 'urls' - -TEMPLATE_DIRS = ( -    # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates". -    # Always use forward slashes, even on Windows. -    # Don't forget to use absolute paths, not relative paths. -) - -INSTALLED_APPS = ( -    'django.contrib.auth', -    'django.contrib.contenttypes', -    'django.contrib.sessions', -    'django.contrib.sites', -    'django.contrib.messages', -    # Uncomment the next line to enable the admin: -    # 'django.contrib.admin', -    # Uncomment the next line to enable admin documentation: -    # 'django.contrib.admindocs', -    'rest_framework', -    'rest_framework.authtoken', -    'rest_framework.tests', -    'rest_framework.tests.accounts', -    'rest_framework.tests.records', -    'rest_framework.tests.users', -) - -# OAuth is optional and won't work if there is no oauth_provider & oauth2 -try: -    import oauth_provider -    import oauth2 -except ImportError: -    pass -else: -    INSTALLED_APPS += ( -        'oauth_provider', -    ) - -try: -    import provider -except ImportError: -    pass -else: -    INSTALLED_APPS += ( -        'provider', -        'provider.oauth2', -    ) - -# guardian is optional -try: -    import guardian -except ImportError: -    pass -else: -    ANONYMOUS_USER_ID = -1 -    AUTHENTICATION_BACKENDS = ( -        'django.contrib.auth.backends.ModelBackend', # default -        'guardian.backends.ObjectPermissionBackend', -    ) -    INSTALLED_APPS += ( -        'guardian', -    ) - -STATIC_URL = '/static/' - -PASSWORD_HASHERS = ( -    'django.contrib.auth.hashers.SHA1PasswordHasher', -    'django.contrib.auth.hashers.PBKDF2PasswordHasher', -    'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher', -    'django.contrib.auth.hashers.BCryptPasswordHasher', -    'django.contrib.auth.hashers.MD5PasswordHasher', -    'django.contrib.auth.hashers.CryptPasswordHasher', -) - -AUTH_USER_MODEL = 'auth.User' - -import django - -if django.VERSION < (1, 3): -    INSTALLED_APPS += ('staticfiles',) - - -# If we're running on the Jenkins server we want to archive the coverage reports as XML. -import os -if os.environ.get('HUDSON_URL', None): -    TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner' -    TEST_OUTPUT_VERBOSE = True -    TEST_OUTPUT_DESCRIPTIONS = True -    TEST_OUTPUT_DIR = 'xmlrunner' diff --git a/rest_framework/runtests/urls.py b/rest_framework/runtests/urls.py deleted file mode 100644 index ed5baeae..00000000 --- a/rest_framework/runtests/urls.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Blank URLConf just to keep runtests.py happy. -""" -from rest_framework.compat import patterns - -urlpatterns = patterns('', -) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index cb7539e0..be8ad3f2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -16,11 +16,13 @@ import datetime  import inspect  import types  from decimal import Decimal +from django.contrib.contenttypes.generic import GenericForeignKey  from django.core.paginator import Page  from django.db import models  from django.forms import widgets +from django.utils import six  from django.utils.datastructures import SortedDict -from rest_framework.compat import get_concrete_model, six +from django.core.exceptions import ObjectDoesNotExist  from rest_framework.settings import api_settings @@ -31,8 +33,8 @@ from rest_framework.settings import api_settings  # This helps keep the separation between model fields, form fields, and  # serializer fields more explicit. -from rest_framework.relations import * -from rest_framework.fields import * +from rest_framework.relations import *  # NOQA +from rest_framework.fields import *  # NOQA  def _resolve_model(obj): @@ -47,7 +49,7 @@ def _resolve_model(obj):      String representations should have the format:          'appname.ModelName'      """ -    if type(obj) == str and len(obj.split('.')) == 2: +    if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:          app_name, model_name = obj.split('.')          return models.get_model(app_name, model_name)      elif inspect.isclass(obj) and issubclass(obj, models.Model): @@ -180,7 +182,7 @@ class BaseSerializer(WritableField):      _dict_class = SortedDictWithMetadata      def __init__(self, instance=None, data=None, files=None, -                 context=None, partial=False, many=None, +                 context=None, partial=False, many=False,                   allow_add_remove=False, **kwargs):          super(BaseSerializer, self).__init__(**kwargs)          self.opts = self._options_class(self.Meta) @@ -343,7 +345,7 @@ class BaseSerializer(WritableField):          for field_name, field in self.fields.items():              if field.read_only and obj is None: -               continue +                continue              field.initialize(parent=self, field_name=field_name)              key = self.get_field_key(field_name)              value = field.field_to_native(obj, field_name) @@ -410,12 +412,7 @@ class BaseSerializer(WritableField):          if value is None:              return None -        if self.many is not None: -            many = self.many -        else: -            many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type)) - -        if many: +        if self.many:              return [self.to_native(item) for item in value]          return self.to_native(value) @@ -452,9 +449,11 @@ class BaseSerializer(WritableField):                  # If we have a model manager or similar object then we need                  # to iterate through each instance. -                if (self.many and +                if ( +                    self.many and                      not hasattr(obj, '__iter__') and -                    is_simple_callable(getattr(obj, 'all', None))): +                    is_simple_callable(getattr(obj, 'all', None)) +                ):                      obj = obj.all()                  kwargs = { @@ -604,8 +603,10 @@ class BaseSerializer(WritableField):          API schemas for auto-documentation.          """          return SortedDict( -            [(field_name, field.metadata()) -            for field_name, field in six.iteritems(self.fields)] +            [ +                (field_name, field.metadata()) +                for field_name, field in six.iteritems(self.fields) +            ]          ) @@ -659,9 +660,11 @@ class ModelSerializer(Serializer):          """          cls = self.opts.model -        assert cls is not None, \ -                "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ -        opts = get_concrete_model(cls)._meta +        assert cls is not None, ( +            "Serializer class '%s' is missing 'model' Meta option" % +            self.__class__.__name__ +        ) +        opts = cls._meta.concrete_model._meta          ret = SortedDict()          nested = bool(self.opts.depth) @@ -671,9 +674,9 @@ class ModelSerializer(Serializer):              # If model is a child via multitable inheritance, use parent's pk              pk_field = pk_field.rel.to._meta.pk -        field = self.get_pk_field(pk_field) -        if field: -            ret[pk_field.name] = field +        serializer_pk_field = self.get_pk_field(pk_field) +        if serializer_pk_field: +            ret[pk_field.name] = serializer_pk_field          # Deal with forward relationships          forward_rels = [field for field in opts.fields if field.serialize] @@ -694,10 +697,10 @@ class ModelSerializer(Serializer):                  if len(inspect.getargspec(self.get_nested_field).args) == 2:                      warnings.warn(                          'The `get_nested_field(model_field)` call signature ' -                        'is due to be deprecated. ' +                        'is deprecated. '                          'Use `get_nested_field(model_field, related_model, '                          'to_many) instead', -                        PendingDeprecationWarning +                        DeprecationWarning                      )                      field = self.get_nested_field(model_field)                  else: @@ -706,10 +709,10 @@ class ModelSerializer(Serializer):                  if len(inspect.getargspec(self.get_nested_field).args) == 3:                      warnings.warn(                          'The `get_related_field(model_field, to_many)` call ' -                        'signature is due to be deprecated. ' +                        'signature is deprecated. '                          'Use `get_related_field(model_field, related_model, '                          'to_many) instead', -                        PendingDeprecationWarning +                        DeprecationWarning                      )                      field = self.get_related_field(model_field, to_many=to_many)                  else: @@ -742,9 +745,11 @@ class ModelSerializer(Serializer):              is_m2m = isinstance(relation.field,                                  models.fields.related.ManyToManyField) -            if (is_m2m and +            if ( +                is_m2m and                  hasattr(relation.field.rel, 'through') and -                not relation.field.rel.through._meta.auto_created): +                not relation.field.rel.through._meta.auto_created +            ):                  has_through_model = True              if nested: @@ -757,9 +762,9 @@ class ModelSerializer(Serializer):                      field.read_only = True                  ret[accessor_name] = field -         +          # Ensure that 'read_only_fields' is an iterable -        assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'  +        assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'          # Add the `read_only` flag to any fields that have been specified          # in the `read_only_fields` option @@ -774,10 +779,10 @@ class ModelSerializer(Serializer):                  "on serializer '%s'." %                  (field_name, self.__class__.__name__))              ret[field_name].read_only = True -         +          # Ensure that 'write_only_fields' is an iterable -        assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'  -         +        assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' +          for field_name in self.opts.write_only_fields:              assert field_name not in self.base_fields.keys(), (                  "field '%s' on serializer '%s' specified in " @@ -788,7 +793,7 @@ class ModelSerializer(Serializer):                  "Non-existant field '%s' specified in `write_only_fields` "                  "on serializer '%s'." %                  (field_name, self.__class__.__name__)) -            ret[field_name].write_only = True             +            ret[field_name].write_only = True          return ret @@ -827,6 +832,19 @@ class ModelSerializer(Serializer):          if model_field:              kwargs['required'] = not(model_field.null or model_field.blank) +            if model_field.help_text is not None: +                kwargs['help_text'] = model_field.help_text +            if model_field.verbose_name is not None: +                kwargs['label'] = model_field.verbose_name + +            if not model_field.editable: +                kwargs['read_only'] = True + +            if model_field.verbose_name is not None: +                kwargs['label'] = model_field.verbose_name + +            if model_field.help_text is not None: +                kwargs['help_text'] = model_field.help_text          return PrimaryKeyRelatedField(**kwargs) @@ -866,6 +884,10 @@ class ModelSerializer(Serializer):                  issubclass(model_field.__class__, models.PositiveSmallIntegerField):              kwargs['min_value'] = 0 +        if model_field.null and \ +                issubclass(model_field.__class__, (models.CharField, models.TextField)): +            kwargs['allow_none'] = True +          attribute_dict = {              models.CharField: ['max_length'],              models.CommaSeparatedIntegerField: ['max_length'], @@ -892,15 +914,17 @@ class ModelSerializer(Serializer):          Return a list of field names to exclude from model validation.          """          cls = self.opts.model -        opts = get_concrete_model(cls)._meta +        opts = cls._meta.concrete_model._meta          exclusions = [field.name for field in opts.fields + opts.many_to_many]          for field_name, field in self.fields.items():              field_name = field.source or field_name -            if field_name in exclusions \ -                and not field.read_only \ -                and (field.required or hasattr(instance, field_name)) \ -                and not isinstance(field, Serializer): +            if ( +                field_name in exclusions +                and not field.read_only +                and (field.required or hasattr(instance, field_name)) +                and not isinstance(field, Serializer) +            ):                  exclusions.remove(field_name)          return exclusions @@ -943,6 +967,8 @@ class ModelSerializer(Serializer):          # Forward m2m relations          for field in meta.many_to_many + meta.virtual_fields: +            if isinstance(field, GenericForeignKey): +                continue              if field.name in attrs:                  m2m_data[field.name] = attrs.pop(field.name) @@ -952,17 +978,15 @@ class ModelSerializer(Serializer):              if isinstance(self.fields.get(field_name, None), Serializer):                  nested_forward_relations[field_name] = attrs[field_name] -        # Update an existing instance... -        if instance is not None: -            for key, val in attrs.items(): -                try: -                    setattr(instance, key, val) -                except ValueError: -                    self._errors[key] = self.error_messages['required'] +        # Create an empty instance of the model +        if instance is None: +            instance = self.opts.model() -        # ...or create a new instance -        else: -            instance = self.opts.model(**attrs) +        for key, val in attrs.items(): +            try: +                setattr(instance, key, val) +            except ValueError: +                self._errors[key] = [self.error_messages['required']]          # Any relations that cannot be set until we've          # saved the model get hidden away on these @@ -1087,6 +1111,10 @@ class HyperlinkedModelSerializer(ModelSerializer):          if model_field:              kwargs['required'] = not(model_field.null or model_field.blank) +            if model_field.help_text is not None: +                kwargs['help_text'] = model_field.help_text +            if model_field.verbose_name is not None: +                kwargs['label'] = model_field.verbose_name          if self.opts.lookup_field:              kwargs['lookup_field'] = self.opts.lookup_field diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 38753c96..644751f8 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -18,12 +18,9 @@ REST framework settings, checking for user settings first, then falling  back to the defaults.  """  from __future__ import unicode_literals -  from django.conf import settings -from django.utils import importlib - +from django.utils import importlib, six  from rest_framework import ISO_8601 -from rest_framework.compat import six  USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) @@ -46,16 +43,12 @@ DEFAULTS = {      'DEFAULT_PERMISSION_CLASSES': (          'rest_framework.permissions.AllowAny',      ), -    'DEFAULT_THROTTLE_CLASSES': ( -    ), -    'DEFAULT_CONTENT_NEGOTIATION_CLASS': -        'rest_framework.negotiation.DefaultContentNegotiation', +    'DEFAULT_THROTTLE_CLASSES': (), +    'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',      # Genric view behavior -    'DEFAULT_MODEL_SERIALIZER_CLASS': -        'rest_framework.serializers.ModelSerializer', -    'DEFAULT_PAGINATION_SERIALIZER_CLASS': -        'rest_framework.pagination.PaginationSerializer', +    'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer', +    'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',      'DEFAULT_FILTER_BACKENDS': (),      # Throttling @@ -63,6 +56,7 @@ DEFAULTS = {          'user': None,          'anon': None,      }, +    'NUM_PROXIES': None,      # Pagination      'PAGINATE_BY': None, @@ -119,6 +113,7 @@ DEFAULTS = {      # Pending deprecation      'FILTER_BACKEND': None, +  } diff --git a/rest_framework/six.py b/rest_framework/six.py deleted file mode 100644 index 9e382312..00000000 --- a/rest_framework/six.py +++ /dev/null @@ -1,389 +0,0 @@ -"""Utilities for writing code that runs on Python 2 and 3""" - -import operator -import sys -import types - -__author__ = "Benjamin Peterson <benjamin@python.org>" -__version__ = "1.2.0" - - -# True if we are running on Python 3. -PY3 = sys.version_info[0] == 3 - -if PY3: -    string_types = str, -    integer_types = int, -    class_types = type, -    text_type = str -    binary_type = bytes - -    MAXSIZE = sys.maxsize -else: -    string_types = basestring, -    integer_types = (int, long) -    class_types = (type, types.ClassType) -    text_type = unicode -    binary_type = str - -    if sys.platform == "java": -        # Jython always uses 32 bits. -        MAXSIZE = int((1 << 31) - 1) -    else: -        # It's possible to have sizeof(long) != sizeof(Py_ssize_t). -        class X(object): -            def __len__(self): -                return 1 << 31 -        try: -            len(X()) -        except OverflowError: -            # 32-bit -            MAXSIZE = int((1 << 31) - 1) -        else: -            # 64-bit -            MAXSIZE = int((1 << 63) - 1) -            del X - - -def _add_doc(func, doc): -    """Add documentation to a function.""" -    func.__doc__ = doc - - -def _import_module(name): -    """Import module, returning the module after the last dot.""" -    __import__(name) -    return sys.modules[name] - - -class _LazyDescr(object): - -    def __init__(self, name): -        self.name = name - -    def __get__(self, obj, tp): -        result = self._resolve() -        setattr(obj, self.name, result) -        # This is a bit ugly, but it avoids running this again. -        delattr(tp, self.name) -        return result - - -class MovedModule(_LazyDescr): - -    def __init__(self, name, old, new=None): -        super(MovedModule, self).__init__(name) -        if PY3: -            if new is None: -                new = name -            self.mod = new -        else: -            self.mod = old - -    def _resolve(self): -        return _import_module(self.mod) - - -class MovedAttribute(_LazyDescr): - -    def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): -        super(MovedAttribute, self).__init__(name) -        if PY3: -            if new_mod is None: -                new_mod = name -            self.mod = new_mod -            if new_attr is None: -                if old_attr is None: -                    new_attr = name -                else: -                    new_attr = old_attr -            self.attr = new_attr -        else: -            self.mod = old_mod -            if old_attr is None: -                old_attr = name -            self.attr = old_attr - -    def _resolve(self): -        module = _import_module(self.mod) -        return getattr(module, self.attr) - - - -class _MovedItems(types.ModuleType): -    """Lazy loading of moved objects""" - - -_moved_attributes = [ -    MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), -    MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), -    MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), -    MovedAttribute("map", "itertools", "builtins", "imap", "map"), -    MovedAttribute("reload_module", "__builtin__", "imp", "reload"), -    MovedAttribute("reduce", "__builtin__", "functools"), -    MovedAttribute("StringIO", "StringIO", "io"), -    MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), -    MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), - -    MovedModule("builtins", "__builtin__"), -    MovedModule("configparser", "ConfigParser"), -    MovedModule("copyreg", "copy_reg"), -    MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), -    MovedModule("http_cookies", "Cookie", "http.cookies"), -    MovedModule("html_entities", "htmlentitydefs", "html.entities"), -    MovedModule("html_parser", "HTMLParser", "html.parser"), -    MovedModule("http_client", "httplib", "http.client"), -    MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), -    MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), -    MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), -    MovedModule("cPickle", "cPickle", "pickle"), -    MovedModule("queue", "Queue"), -    MovedModule("reprlib", "repr"), -    MovedModule("socketserver", "SocketServer"), -    MovedModule("tkinter", "Tkinter"), -    MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), -    MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), -    MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), -    MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), -    MovedModule("tkinter_tix", "Tix", "tkinter.tix"), -    MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), -    MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), -    MovedModule("tkinter_colorchooser", "tkColorChooser", -                "tkinter.colorchooser"), -    MovedModule("tkinter_commondialog", "tkCommonDialog", -                "tkinter.commondialog"), -    MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), -    MovedModule("tkinter_font", "tkFont", "tkinter.font"), -    MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), -    MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", -                "tkinter.simpledialog"), -    MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), -    MovedModule("winreg", "_winreg"), -] -for attr in _moved_attributes: -    setattr(_MovedItems, attr.name, attr) -del attr - -moves = sys.modules["django.utils.six.moves"] = _MovedItems("moves") - - -def add_move(move): -    """Add an item to six.moves.""" -    setattr(_MovedItems, move.name, move) - - -def remove_move(name): -    """Remove item from six.moves.""" -    try: -        delattr(_MovedItems, name) -    except AttributeError: -        try: -            del moves.__dict__[name] -        except KeyError: -            raise AttributeError("no such move, %r" % (name,)) - - -if PY3: -    _meth_func = "__func__" -    _meth_self = "__self__" - -    _func_code = "__code__" -    _func_defaults = "__defaults__" - -    _iterkeys = "keys" -    _itervalues = "values" -    _iteritems = "items" -else: -    _meth_func = "im_func" -    _meth_self = "im_self" - -    _func_code = "func_code" -    _func_defaults = "func_defaults" - -    _iterkeys = "iterkeys" -    _itervalues = "itervalues" -    _iteritems = "iteritems" - - -try: -    advance_iterator = next -except NameError: -    def advance_iterator(it): -        return it.next() -next = advance_iterator - - -if PY3: -    def get_unbound_function(unbound): -        return unbound - -    Iterator = object - -    def callable(obj): -        return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) -else: -    def get_unbound_function(unbound): -        return unbound.im_func - -    class Iterator(object): - -        def next(self): -            return type(self).__next__(self) - -    callable = callable -_add_doc(get_unbound_function, -         """Get the function out of a possibly unbound function""") - - -get_method_function = operator.attrgetter(_meth_func) -get_method_self = operator.attrgetter(_meth_self) -get_function_code = operator.attrgetter(_func_code) -get_function_defaults = operator.attrgetter(_func_defaults) - - -def iterkeys(d): -    """Return an iterator over the keys of a dictionary.""" -    return iter(getattr(d, _iterkeys)()) - -def itervalues(d): -    """Return an iterator over the values of a dictionary.""" -    return iter(getattr(d, _itervalues)()) - -def iteritems(d): -    """Return an iterator over the (key, value) pairs of a dictionary.""" -    return iter(getattr(d, _iteritems)()) - - -if PY3: -    def b(s): -        return s.encode("latin-1") -    def u(s): -        return s -    if sys.version_info[1] <= 1: -        def int2byte(i): -            return bytes((i,)) -    else: -        # This is about 2x faster than the implementation above on 3.2+ -        int2byte = operator.methodcaller("to_bytes", 1, "big") -    import io -    StringIO = io.StringIO -    BytesIO = io.BytesIO -else: -    def b(s): -        return s -    def u(s): -        return unicode(s, "unicode_escape") -    int2byte = chr -    import StringIO -    StringIO = BytesIO = StringIO.StringIO -_add_doc(b, """Byte literal""") -_add_doc(u, """Text literal""") - - -if PY3: -    import builtins -    exec_ = getattr(builtins, "exec") - - -    def reraise(tp, value, tb=None): -        if value.__traceback__ is not tb: -            raise value.with_traceback(tb) -        raise value - - -    print_ = getattr(builtins, "print") -    del builtins - -else: -    def exec_(code, globs=None, locs=None): -        """Execute code in a namespace.""" -        if globs is None: -            frame = sys._getframe(1) -            globs = frame.f_globals -            if locs is None: -                locs = frame.f_locals -            del frame -        elif locs is None: -            locs = globs -        exec("""exec code in globs, locs""") - - -    exec_("""def reraise(tp, value, tb=None): -    raise tp, value, tb -""") - - -    def print_(*args, **kwargs): -        """The new-style print function.""" -        fp = kwargs.pop("file", sys.stdout) -        if fp is None: -            return -        def write(data): -            if not isinstance(data, basestring): -                data = str(data) -            fp.write(data) -        want_unicode = False -        sep = kwargs.pop("sep", None) -        if sep is not None: -            if isinstance(sep, unicode): -                want_unicode = True -            elif not isinstance(sep, str): -                raise TypeError("sep must be None or a string") -        end = kwargs.pop("end", None) -        if end is not None: -            if isinstance(end, unicode): -                want_unicode = True -            elif not isinstance(end, str): -                raise TypeError("end must be None or a string") -        if kwargs: -            raise TypeError("invalid keyword arguments to print()") -        if not want_unicode: -            for arg in args: -                if isinstance(arg, unicode): -                    want_unicode = True -                    break -        if want_unicode: -            newline = unicode("\n") -            space = unicode(" ") -        else: -            newline = "\n" -            space = " " -        if sep is None: -            sep = space -        if end is None: -            end = newline -        for i, arg in enumerate(args): -            if i: -                write(sep) -            write(arg) -        write(end) - -_add_doc(reraise, """Reraise an exception.""") - - -def with_metaclass(meta, base=object): -    """Create a base class with a metaclass.""" -    return meta("NewBase", (base,), {}) - - -### Additional customizations for Django ### - -if PY3: -    _iterlists = "lists" -    _assertRaisesRegex = "assertRaisesRegex" -else: -    _iterlists = "iterlists" -    _assertRaisesRegex = "assertRaisesRegexp" - - -def iterlists(d): -    """Return an iterator over the values of a MultiValueDict.""" -    return getattr(d, _iterlists)() - - -def assertRaisesRegex(self, *args, **kwargs): -    return getattr(self, _assertRaisesRegex)(*args, **kwargs) - - -add_move(MovedModule("_dummy_thread", "dummy_thread")) -add_move(MovedModule("_thread", "thread")) diff --git a/rest_framework/status.py b/rest_framework/status.py index 76435371..90a75508 100644 --- a/rest_framework/status.py +++ b/rest_framework/status.py @@ -10,15 +10,19 @@ from __future__ import unicode_literals  def is_informational(code):      return code >= 100 and code <= 199 +  def is_success(code):      return code >= 200 and code <= 299 +  def is_redirect(code):      return code >= 300 and code <= 399 +  def is_client_error(code):      return code >= 400 and code <= 499 +  def is_server_error(code):      return code >= 500 and code <= 599 diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 7067ee2f..b6e9ca5c 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -1,4 +1,5 @@  {% load url from future %} +{% load staticfiles %}  {% load rest_framework %}  <!DOCTYPE html>  <html> @@ -24,6 +25,7 @@      {% endblock %}      </head> +  {% block body %}    <body class="{% block bodyclass %}{% endblock %} container">      <div class="wrapper"> @@ -93,7 +95,7 @@          {% endif %}          {% if options_form %} -            <form class="button-form" action="{{ request.get_full_path }}" method="POST" class="pull-right"> +            <form class="button-form" action="{{ request.get_full_path }}" method="POST">                  {% csrf_token %}                  <input type="hidden" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="OPTIONS" />                  <button class="btn btn-primary js-tooltip" title="Make an OPTIONS request on the {{ name }} resource">OPTIONS</button> @@ -101,7 +103,7 @@          {% endif %}          {% if delete_form %} -            <form class="button-form" action="{{ request.get_full_path }}" method="POST" class="pull-right"> +            <form class="button-form" action="{{ request.get_full_path }}" method="POST">                  {% csrf_token %}                  <input type="hidden" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="DELETE" />                  <button class="btn btn-danger js-tooltip" title="Make a DELETE request on the {{ name }} resource">DELETE</button> @@ -230,4 +232,5 @@      <script src="{% static "rest_framework/js/default.js" %}"></script>      {% endblock %}    </body> +  {% endblock %}  </html> diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html index be9a0072..43860e53 100644 --- a/rest_framework/templates/rest_framework/login_base.html +++ b/rest_framework/templates/rest_framework/login_base.html @@ -1,17 +1,9 @@ +{% extends "rest_framework/base.html" %}  {% load url from future %} +{% load staticfiles %}  {% load rest_framework %} -<html> - -    <head> -        {% block style %} -        {% block bootstrap_theme %} -            <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/> -            <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> -        {% endblock %} -        <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> -        {% endblock %} -    </head> +    {% block body %}      <body class="container">          <div class="container-fluid" style="margin-top: 30px"> @@ -50,4 +42,4 @@              </div><!-- /.row-fluid -->          </div><!-- /.container-fluid -->      </body> -</html> +    {% endblock %} diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index beb8c5b0..b80a7d77 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -2,98 +2,17 @@ from __future__ import unicode_literals, absolute_import  from django import template  from django.core.urlresolvers import reverse, NoReverseMatch  from django.http import QueryDict +from django.utils import six  from django.utils.encoding import iri_to_uri  from django.utils.html import escape  from django.utils.safestring import SafeData, mark_safe -from rest_framework.compat import urlparse, force_text, six, smart_urlquote +from django.utils.html import smart_urlquote +from rest_framework.compat import urlparse, force_text  import re  register = template.Library() -# Note we don't use 'load staticfiles', because we need a 1.3 compatible -# version, so instead we include the `static` template tag ourselves. - -# When 1.3 becomes unsupported by REST framework, we can instead start to -# use the {% load staticfiles %} tag, remove the following code, -# and add a dependency that `django.contrib.staticfiles` must be installed. - -# Note: We can't put this into the `compat` module because the compat import -# from rest_framework.compat import ... -# conflicts with this rest_framework template tag module. - -try:  # Django 1.5+ -    from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode - -    @register.tag('static') -    def do_static(parser, token): -        return StaticFilesNode.handle_token(parser, token) - -except ImportError: -    try:  # Django 1.4 -        from django.contrib.staticfiles.storage import staticfiles_storage - -        @register.simple_tag -        def static(path): -            """ -            A template tag that returns the URL to a file -            using staticfiles' storage backend -            """ -            return staticfiles_storage.url(path) - -    except ImportError:  # Django 1.3 -        from urlparse import urljoin -        from django import template -        from django.templatetags.static import PrefixNode - -        class StaticNode(template.Node): -            def __init__(self, varname=None, path=None): -                if path is None: -                    raise template.TemplateSyntaxError( -                        "Static template nodes must be given a path to return.") -                self.path = path -                self.varname = varname - -            def url(self, context): -                path = self.path.resolve(context) -                return self.handle_simple(path) - -            def render(self, context): -                url = self.url(context) -                if self.varname is None: -                    return url -                context[self.varname] = url -                return '' - -            @classmethod -            def handle_simple(cls, path): -                return urljoin(PrefixNode.handle_simple("STATIC_URL"), path) - -            @classmethod -            def handle_token(cls, parser, token): -                """ -                Class method to parse prefix node and return a Node. -                """ -                bits = token.split_contents() - -                if len(bits) < 2: -                    raise template.TemplateSyntaxError( -                        "'%s' takes at least one argument (path to file)" % bits[0]) - -                path = parser.compile_filter(bits[1]) - -                if len(bits) >= 2 and bits[-2] == 'as': -                    varname = bits[3] -                else: -                    varname = None - -                return cls(varname, path) - -        @register.tag('static') -        def do_static_13(parser, token): -            return StaticNode.handle_token(parser, token) - -  def replace_query_param(url, key, val):      """      Given a URL and a key/val pair, set or replace an item in the query @@ -122,7 +41,7 @@ def optional_login(request):      except NoReverseMatch:          return '' -    snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, request.path) +    snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, escape(request.path))      return snippet @@ -136,7 +55,7 @@ def optional_logout(request):      except NoReverseMatch:          return '' -    snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, request.path) +    snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, escape(request.path))      return snippet @@ -180,7 +99,7 @@ def add_class(value, css_class):  # Bunch of stuff cloned from urlize -TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"]  WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'),                          ('"', '"'), ("'", "'")]  word_split_re = re.compile(r'(\s+)') @@ -234,8 +153,10 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru                      middle = middle[len(opening):]                      lead = lead + opening                  # Keep parentheses at the end only if they're balanced. -                if (middle.endswith(closing) -                    and middle.count(closing) == middle.count(opening) + 1): +                if ( +                    middle.endswith(closing) +                    and middle.count(closing) == middle.count(opening) + 1 +                ):                      middle = middle[:-len(closing)]                      trail = closing + trail @@ -246,7 +167,7 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru                  url = smart_urlquote_wrapper(middle)              elif simple_url_2_re.match(middle):                  url = smart_urlquote_wrapper('http://%s' % middle) -            elif not ':' in middle and simple_email_re.match(middle): +            elif ':' not in middle and simple_email_re.match(middle):                  local, domain = middle.rsplit('@', 1)                  try:                      domain = domain.encode('idna').decode('ascii') diff --git a/rest_framework/test.py b/rest_framework/test.py index df5a5b3b..f89a6dcd 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -8,10 +8,11 @@ from django.conf import settings  from django.test.client import Client as DjangoClient  from django.test.client import ClientHandler  from django.test import testcases +from django.utils import six  from django.utils.http import urlencode  from rest_framework.settings import api_settings  from rest_framework.compat import RequestFactory as DjangoRequestFactory -from rest_framework.compat import force_bytes_or_smart_bytes, six +from rest_framework.compat import force_bytes_or_smart_bytes  def force_authenticate(request, user=None, token=None): @@ -36,7 +37,7 @@ class APIRequestFactory(DjangoRequestFactory):          """          if not data: -            return ('', None) +            return ('', content_type)          assert format is None or content_type is None, (              'You may not set both `format` and `content_type`.' @@ -49,9 +50,10 @@ class APIRequestFactory(DjangoRequestFactory):          else:              format = format or self.default_format -            assert format in self.renderer_classes, ("Invalid format '{0}'. " -                "Available formats are {1}.  Set TEST_REQUEST_RENDERER_CLASSES " -                "to enable extra request formats.".format( +            assert format in self.renderer_classes, ( +                "Invalid format '{0}'. Available formats are {1}. " +                "Set TEST_REQUEST_RENDERER_CLASSES to enable " +                "extra request formats.".format(                      format,                      ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])                  ) @@ -154,6 +156,10 @@ class APIClient(APIRequestFactory, DjangoClient):          kwargs.update(self._credentials)          return super(APIClient, self).request(**kwargs) +    def logout(self): +        self._credentials = {} +        return super(APIClient, self).logout() +  class APITransactionTestCase(testcases.TransactionTestCase):      client_class = APIClient diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/rest_framework/tests/__init__.py +++ /dev/null diff --git a/rest_framework/tests/accounts/__init__.py b/rest_framework/tests/accounts/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/rest_framework/tests/accounts/__init__.py +++ /dev/null diff --git a/rest_framework/tests/accounts/models.py b/rest_framework/tests/accounts/models.py deleted file mode 100644 index 525e601b..00000000 --- a/rest_framework/tests/accounts/models.py +++ /dev/null @@ -1,8 +0,0 @@ -from django.db import models - -from rest_framework.tests.users.models import User - - -class Account(models.Model): -    owner = models.ForeignKey(User, related_name='accounts_owned') -    admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered') diff --git a/rest_framework/tests/accounts/serializers.py b/rest_framework/tests/accounts/serializers.py deleted file mode 100644 index a27b9ca6..00000000 --- a/rest_framework/tests/accounts/serializers.py +++ /dev/null @@ -1,11 +0,0 @@ -from rest_framework import serializers - -from rest_framework.tests.accounts.models import Account -from rest_framework.tests.users.serializers import UserSerializer - - -class AccountSerializer(serializers.ModelSerializer): -    admins = UserSerializer(many=True) - -    class Meta: -        model = Account diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py deleted file mode 100644 index b46d7f54..00000000 --- a/rest_framework/tests/description.py +++ /dev/null @@ -1,26 +0,0 @@ -# -- coding: utf-8 -- - -# Apparently there is a python 2.6 issue where docstrings of imported view classes -# do not retain their encoding information even if a module has a proper -# encoding declaration at the top of its source file. Therefore for tests -# to catch unicode related errors, a mock view has to be declared in a separate -# module. - -from rest_framework.views import APIView - - -# test strings snatched from http://www.columbia.edu/~fdc/utf8/, -# http://winrus.com/utf8-jap.htm and memory -UTF8_TEST_DOCSTRING = ( -    'zażółć gęślą jaźń' -    'Sîne klâwen durh die wolken sint geslagen' -    'Τη γλώσσα μου έδωσαν ελληνική' -    'யாமறிந்த மொழிகளிலே தமிழ்மொழி' -    'На берегу пустынных волн' -    'てすと' -    'アイウエオカキクケコサシスセソタチツテ' -) - - -class ViewWithNonASCIICharactersInDocstring(APIView): -    __doc__ = UTF8_TEST_DOCSTRING diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/rest_framework/tests/extras/__init__.py +++ /dev/null diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py deleted file mode 100644 index 68263d94..00000000 --- a/rest_framework/tests/extras/bad_import.py +++ /dev/null @@ -1 +0,0 @@ -raise ValueError diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py deleted file mode 100644 index 6c8f2342..00000000 --- a/rest_framework/tests/models.py +++ /dev/null @@ -1,177 +0,0 @@ -from __future__ import unicode_literals -from django.db import models -from django.utils.translation import ugettext_lazy as _ -from rest_framework import serializers - - -def foobar(): -    return 'foobar' - - -class CustomField(models.CharField): - -    def __init__(self, *args, **kwargs): -        kwargs['max_length'] = 12 -        super(CustomField, self).__init__(*args, **kwargs) - - -class RESTFrameworkModel(models.Model): -    """ -    Base for test models that sets app_label, so they play nicely. -    """ -    class Meta: -        app_label = 'tests' -        abstract = True - - -class HasPositiveIntegerAsChoice(RESTFrameworkModel): -    some_choices = ((1, 'A'), (2, 'B'), (3, 'C')) -    some_integer = models.PositiveIntegerField(choices=some_choices) - - -class Anchor(RESTFrameworkModel): -    text = models.CharField(max_length=100, default='anchor') - - -class BasicModel(RESTFrameworkModel): -    text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description.")) - - -class SlugBasedModel(RESTFrameworkModel): -    text = models.CharField(max_length=100) -    slug = models.SlugField(max_length=32) - - -class DefaultValueModel(RESTFrameworkModel): -    text = models.CharField(default='foobar', max_length=100) -    extra = models.CharField(blank=True, null=True, max_length=100) - - -class CallableDefaultValueModel(RESTFrameworkModel): -    text = models.CharField(default=foobar, max_length=100) - - -class ManyToManyModel(RESTFrameworkModel): -    rel = models.ManyToManyField(Anchor, help_text='Some help text.') - - -class ReadOnlyManyToManyModel(RESTFrameworkModel): -    text = models.CharField(max_length=100, default='anchor') -    rel = models.ManyToManyField(Anchor) - - -# Model for regression test for #285 - -class Comment(RESTFrameworkModel): -    email = models.EmailField() -    content = models.CharField(max_length=200) -    created = models.DateTimeField(auto_now_add=True) - - -class ActionItem(RESTFrameworkModel): -    title = models.CharField(max_length=200) -    started = models.NullBooleanField(default=False) -    done = models.BooleanField(default=False) -    info = CustomField(default='---', max_length=12) - - -# Models for reverse relations -class Person(RESTFrameworkModel): -    name = models.CharField(max_length=10) -    age = models.IntegerField(null=True, blank=True) - -    @property -    def info(self): -        return { -            'name': self.name, -            'age': self.age, -        } - - -class BlogPost(RESTFrameworkModel): -    title = models.CharField(max_length=100) -    writer = models.ForeignKey(Person, null=True, blank=True) - -    def get_first_comment(self): -        return self.blogpostcomment_set.all()[0] - - -class BlogPostComment(RESTFrameworkModel): -    text = models.TextField() -    blog_post = models.ForeignKey(BlogPost) - - -class Album(RESTFrameworkModel): -    title = models.CharField(max_length=100, unique=True) -    ref = models.CharField(max_length=10, unique=True, null=True, blank=True) - -class Photo(RESTFrameworkModel): -    description = models.TextField() -    album = models.ForeignKey(Album) - - -# Model for issue #324 -class BlankFieldModel(RESTFrameworkModel): -    title = models.CharField(max_length=100, blank=True, null=False) - - -# Model for issue #380 -class OptionalRelationModel(RESTFrameworkModel): -    other = models.ForeignKey('OptionalRelationModel', blank=True, null=True) - - -# Model for RegexField -class Book(RESTFrameworkModel): -    isbn = models.CharField(max_length=13) - - -# Models for relations tests -# ManyToMany -class ManyToManyTarget(RESTFrameworkModel): -    name = models.CharField(max_length=100) - - -class ManyToManySource(RESTFrameworkModel): -    name = models.CharField(max_length=100) -    targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') - - -# ForeignKey -class ForeignKeyTarget(RESTFrameworkModel): -    name = models.CharField(max_length=100) - - -class ForeignKeySource(RESTFrameworkModel): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(ForeignKeyTarget, related_name='sources') - - -# Nullable ForeignKey -class NullableForeignKeySource(RESTFrameworkModel): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, -                               related_name='nullable_sources') - - -# OneToOne -class OneToOneTarget(RESTFrameworkModel): -    name = models.CharField(max_length=100) - - -class NullableOneToOneSource(RESTFrameworkModel): -    name = models.CharField(max_length=100) -    target = models.OneToOneField(OneToOneTarget, null=True, blank=True, -                                  related_name='nullable_source') - - -# Serializer used to test BasicModel -class BasicModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = BasicModel - - -# Models to test filters -class FilterableItem(models.Model): -    text = models.CharField(max_length=100) -    decimal = models.DecimalField(max_digits=4, decimal_places=2) -    date = models.DateField() diff --git a/rest_framework/tests/records/__init__.py b/rest_framework/tests/records/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/rest_framework/tests/records/__init__.py +++ /dev/null diff --git a/rest_framework/tests/records/models.py b/rest_framework/tests/records/models.py deleted file mode 100644 index 76954807..00000000 --- a/rest_framework/tests/records/models.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.db import models - - -class Record(models.Model): -    account = models.ForeignKey('accounts.Account', blank=True, null=True) -    owner = models.ForeignKey('users.User', blank=True, null=True) diff --git a/rest_framework/tests/serializers.py b/rest_framework/tests/serializers.py deleted file mode 100644 index cc943c7d..00000000 --- a/rest_framework/tests/serializers.py +++ /dev/null @@ -1,8 +0,0 @@ -from rest_framework import serializers - -from rest_framework.tests.models import NullableForeignKeySource - - -class NullableFKSourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = NullableForeignKeySource diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py deleted file mode 100644 index c37d2a51..00000000 --- a/rest_framework/tests/test_authentication.py +++ /dev/null @@ -1,663 +0,0 @@ -from __future__ import unicode_literals -from django.contrib.auth.models import User -from django.http import HttpResponse -from django.test import TestCase -from django.utils import unittest -from django.utils.http import urlencode -from rest_framework import HTTP_HEADER_ENCODING -from rest_framework import exceptions -from rest_framework import permissions -from rest_framework import renderers -from rest_framework.response import Response -from rest_framework import status -from rest_framework.authentication import ( -    BaseAuthentication, -    TokenAuthentication, -    BasicAuthentication, -    SessionAuthentication, -    OAuthAuthentication, -    OAuth2Authentication -) -from rest_framework.authtoken.models import Token -from rest_framework.compat import patterns, url, include -from rest_framework.compat import oauth2_provider, oauth2_provider_scope -from rest_framework.compat import oauth, oauth_provider -from rest_framework.test import APIRequestFactory, APIClient -from rest_framework.views import APIView -import base64 -import time -import datetime - -factory = APIRequestFactory() - - -class MockView(APIView): -    permission_classes = (permissions.IsAuthenticated,) - -    def get(self, request): -        return HttpResponse({'a': 1, 'b': 2, 'c': 3}) - -    def post(self, request): -        return HttpResponse({'a': 1, 'b': 2, 'c': 3}) - -    def put(self, request): -        return HttpResponse({'a': 1, 'b': 2, 'c': 3}) - - -urlpatterns = patterns('', -    (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), -    (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), -    (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), -    (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), -    (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), -    (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], -        permission_classes=[permissions.TokenHasReadWriteScope])) -) - -class OAuth2AuthenticationDebug(OAuth2Authentication): -    allow_query_params_token = True - -if oauth2_provider is not None: -    urlpatterns += patterns('', -        url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), -        url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), -        url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), -        url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], -            permission_classes=[permissions.TokenHasReadWriteScope])), -    ) - - -class BasicAuthTests(TestCase): -    """Basic authentication""" -    urls = 'rest_framework.tests.test_authentication' - -    def setUp(self): -        self.csrf_client = APIClient(enforce_csrf_checks=True) -        self.username = 'john' -        self.email = 'lennon@thebeatles.com' -        self.password = 'password' -        self.user = User.objects.create_user(self.username, self.email, self.password) - -    def test_post_form_passing_basic_auth(self): -        """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" -        credentials = ('%s:%s' % (self.username, self.password)) -        base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) -        auth = 'Basic %s' % base64_credentials -        response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -    def test_post_json_passing_basic_auth(self): -        """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" -        credentials = ('%s:%s' % (self.username, self.password)) -        base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) -        auth = 'Basic %s' % base64_credentials -        response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -    def test_post_form_failing_basic_auth(self): -        """Ensure POSTing form over basic auth without correct credentials fails""" -        response = self.csrf_client.post('/basic/', {'example': 'example'}) -        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - -    def test_post_json_failing_basic_auth(self): -        """Ensure POSTing json over basic auth without correct credentials fails""" -        response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json') -        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) -        self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') - - -class SessionAuthTests(TestCase): -    """User session authentication""" -    urls = 'rest_framework.tests.test_authentication' - -    def setUp(self): -        self.csrf_client = APIClient(enforce_csrf_checks=True) -        self.non_csrf_client = APIClient(enforce_csrf_checks=False) -        self.username = 'john' -        self.email = 'lennon@thebeatles.com' -        self.password = 'password' -        self.user = User.objects.create_user(self.username, self.email, self.password) - -    def tearDown(self): -        self.csrf_client.logout() - -    def test_post_form_session_auth_failing_csrf(self): -        """ -        Ensure POSTing form over session authentication without CSRF token fails. -        """ -        self.csrf_client.login(username=self.username, password=self.password) -        response = self.csrf_client.post('/session/', {'example': 'example'}) -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - -    def test_post_form_session_auth_passing(self): -        """ -        Ensure POSTing form over session authentication with logged in user and CSRF token passes. -        """ -        self.non_csrf_client.login(username=self.username, password=self.password) -        response = self.non_csrf_client.post('/session/', {'example': 'example'}) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -    def test_put_form_session_auth_passing(self): -        """ -        Ensure PUTting form over session authentication with logged in user and CSRF token passes. -        """ -        self.non_csrf_client.login(username=self.username, password=self.password) -        response = self.non_csrf_client.put('/session/', {'example': 'example'}) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -    def test_post_form_session_auth_failing(self): -        """ -        Ensure POSTing form over session authentication without logged in user fails. -        """ -        response = self.csrf_client.post('/session/', {'example': 'example'}) -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - -class TokenAuthTests(TestCase): -    """Token authentication""" -    urls = 'rest_framework.tests.test_authentication' - -    def setUp(self): -        self.csrf_client = APIClient(enforce_csrf_checks=True) -        self.username = 'john' -        self.email = 'lennon@thebeatles.com' -        self.password = 'password' -        self.user = User.objects.create_user(self.username, self.email, self.password) - -        self.key = 'abcd1234' -        self.token = Token.objects.create(key=self.key, user=self.user) - -    def test_post_form_passing_token_auth(self): -        """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" -        auth = 'Token ' + self.key -        response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -    def test_post_json_passing_token_auth(self): -        """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" -        auth = "Token " + self.key -        response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -    def test_post_form_failing_token_auth(self): -        """Ensure POSTing form over token auth without correct credentials fails""" -        response = self.csrf_client.post('/token/', {'example': 'example'}) -        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - -    def test_post_json_failing_token_auth(self): -        """Ensure POSTing json over token auth without correct credentials fails""" -        response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') -        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - -    def test_token_has_auto_assigned_key_if_none_provided(self): -        """Ensure creating a token with no key will auto-assign a key""" -        self.token.delete() -        token = Token.objects.create(user=self.user) -        self.assertTrue(bool(token.key)) - -    def test_token_login_json(self): -        """Ensure token login view using JSON POST works.""" -        client = APIClient(enforce_csrf_checks=True) -        response = client.post('/auth-token/', -                               {'username': self.username, 'password': self.password}, format='json') -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['token'], self.key) - -    def test_token_login_json_bad_creds(self): -        """Ensure token login view using JSON POST fails if bad credentials are used.""" -        client = APIClient(enforce_csrf_checks=True) -        response = client.post('/auth-token/', -                               {'username': self.username, 'password': "badpass"}, format='json') -        self.assertEqual(response.status_code, 400) - -    def test_token_login_json_missing_fields(self): -        """Ensure token login view using JSON POST fails if missing fields.""" -        client = APIClient(enforce_csrf_checks=True) -        response = client.post('/auth-token/', -                               {'username': self.username}, format='json') -        self.assertEqual(response.status_code, 400) - -    def test_token_login_form(self): -        """Ensure token login view using form POST works.""" -        client = APIClient(enforce_csrf_checks=True) -        response = client.post('/auth-token/', -                               {'username': self.username, 'password': self.password}) -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['token'], self.key) - - -class IncorrectCredentialsTests(TestCase): -    def test_incorrect_credentials(self): -        """ -        If a request contains bad authentication credentials, then -        authentication should run and error, even if no permissions -        are set on the view. -        """ -        class IncorrectCredentialsAuth(BaseAuthentication): -            def authenticate(self, request): -                raise exceptions.AuthenticationFailed('Bad credentials') - -        request = factory.get('/') -        view = MockView.as_view( -            authentication_classes=(IncorrectCredentialsAuth,), -            permission_classes=() -        ) -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) -        self.assertEqual(response.data, {'detail': 'Bad credentials'}) - - -class OAuthTests(TestCase): -    """OAuth 1.0a authentication""" -    urls = 'rest_framework.tests.test_authentication' - -    def setUp(self): -        # these imports are here because oauth is optional and hiding them in try..except block or compat -        # could obscure problems if something breaks -        from oauth_provider.models import Consumer, Scope -        from oauth_provider.models import Token as OAuthToken -        from oauth_provider import consts - -        self.consts = consts - -        self.csrf_client = APIClient(enforce_csrf_checks=True) -        self.username = 'john' -        self.email = 'lennon@thebeatles.com' -        self.password = 'password' -        self.user = User.objects.create_user(self.username, self.email, self.password) - -        self.CONSUMER_KEY = 'consumer_key' -        self.CONSUMER_SECRET = 'consumer_secret' -        self.TOKEN_KEY = "token_key" -        self.TOKEN_SECRET = "token_secret" - -        self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, -            name='example', user=self.user, status=self.consts.ACCEPTED) - -        self.scope = Scope.objects.create(name="resource name", url="api/") -        self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope, -            token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True -        ) - -    def _create_authorization_header(self): -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': self.token.key, -            'oauth_consumer_key': self.consumer.key -        } - -        req = oauth.Request(method="GET", url="http://example.com", parameters=params) - -        signature_method = oauth.SignatureMethod_PLAINTEXT() -        req.sign_request(signature_method, self.consumer, self.token) - -        return req.to_header()["Authorization"] - -    def _create_authorization_url_parameters(self): -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': self.token.key, -            'oauth_consumer_key': self.consumer.key -        } - -        req = oauth.Request(method="GET", url="http://example.com", parameters=params) - -        signature_method = oauth.SignatureMethod_PLAINTEXT() -        req.sign_request(signature_method, self.consumer, self.token) -        return dict(req) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_passing_oauth(self): -        """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_repeated_nonce_failing_oauth(self): -        """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails""" -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -        # simulate reply attack auth header containes already used (nonce, timestamp) pair -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_token_removed_failing_oauth(self): -        """Ensure POSTing when there is no OAuth access token in db fails""" -        self.token.delete() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_consumer_status_not_accepted_failing_oauth(self): -        """Ensure POSTing when consumer status is anything other than ACCEPTED fails""" -        for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED): -            self.consumer.status = consumer_status -            self.consumer.save() - -            auth = self._create_authorization_header() -            response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -            self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_with_request_token_failing_oauth(self): -        """Ensure POSTing with unauthorized request token instead of access token fails""" -        self.token.token_type = self.token.REQUEST -        self.token.save() - -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_with_urlencoded_parameters(self): -        """Ensure POSTing with x-www-form-urlencoded auth parameters passes""" -        params = self._create_authorization_url_parameters() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_get_form_with_url_parameters(self): -        """Ensure GETing with auth in url parameters passes""" -        params = self._create_authorization_url_parameters() -        response = self.csrf_client.get('/oauth/', params) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_hmac_sha1_signature_passes(self): -        """Ensure POSTing using HMAC_SHA1 signature method passes""" -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': self.token.key, -            'oauth_consumer_key': self.consumer.key -        } - -        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) - -        signature_method = oauth.SignatureMethod_HMAC_SHA1() -        req.sign_request(signature_method, self.consumer, self.token) -        auth = req.to_header()["Authorization"] - -        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_get_form_with_readonly_resource_passing_auth(self): -        """Ensure POSTing with a readonly scope instead of a write scope fails""" -        read_only_access_token = self.token -        read_only_access_token.scope.is_readonly = True -        read_only_access_token.scope.save() -        params = self._create_authorization_url_parameters() -        response = self.csrf_client.get('/oauth-with-scope/', params) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_with_readonly_resource_failing_auth(self): -        """Ensure POSTing with a readonly resource instead of a write scope fails""" -        read_only_access_token = self.token -        read_only_access_token.scope.is_readonly = True -        read_only_access_token.scope.save() -        params = self._create_authorization_url_parameters() -        response = self.csrf_client.post('/oauth-with-scope/', params) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_with_write_resource_passing_auth(self): -        """Ensure POSTing with a write resource succeed""" -        read_write_access_token = self.token -        read_write_access_token.scope.is_readonly = False -        read_write_access_token.scope.save() -        params = self._create_authorization_url_parameters() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_bad_consumer_key(self): -        """Ensure POSTing using HMAC_SHA1 signature method passes""" -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': self.token.key, -            'oauth_consumer_key': 'badconsumerkey' -        } - -        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) - -        signature_method = oauth.SignatureMethod_HMAC_SHA1() -        req.sign_request(signature_method, self.consumer, self.token) -        auth = req.to_header()["Authorization"] - -        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_bad_token_key(self): -        """Ensure POSTing using HMAC_SHA1 signature method passes""" -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': 'badtokenkey', -            'oauth_consumer_key': self.consumer.key -        } - -        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) - -        signature_method = oauth.SignatureMethod_HMAC_SHA1() -        req.sign_request(signature_method, self.consumer, self.token) -        auth = req.to_header()["Authorization"] - -        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - - -class OAuth2Tests(TestCase): -    """OAuth 2.0 authentication""" -    urls = 'rest_framework.tests.test_authentication' - -    def setUp(self): -        self.csrf_client = APIClient(enforce_csrf_checks=True) -        self.username = 'john' -        self.email = 'lennon@thebeatles.com' -        self.password = 'password' -        self.user = User.objects.create_user(self.username, self.email, self.password) - -        self.CLIENT_ID = 'client_key' -        self.CLIENT_SECRET = 'client_secret' -        self.ACCESS_TOKEN = "access_token" -        self.REFRESH_TOKEN = "refresh_token" - -        self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( -                client_id=self.CLIENT_ID, -                client_secret=self.CLIENT_SECRET, -                redirect_uri='', -                client_type=0, -                name='example', -                user=None, -            ) - -        self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( -                token=self.ACCESS_TOKEN, -                client=self.oauth2_client, -                user=self.user, -            ) -        self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( -                user=self.user, -                access_token=self.access_token, -                client=self.oauth2_client -            ) - -    def _create_authorization_header(self, token=None): -        return "Bearer {0}".format(token or self.access_token.token) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_with_wrong_authorization_header_token_type_failing(self): -        """Ensure that a wrong token type lead to the correct HTTP error status code""" -        auth = "Wrong token-type-obsviously" -        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) -        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_with_wrong_authorization_header_token_format_failing(self): -        """Ensure that a wrong token format lead to the correct HTTP error status code""" -        auth = "Bearer wrong token format" -        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) -        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_with_wrong_authorization_header_token_failing(self): -        """Ensure that a wrong token lead to the correct HTTP error status code""" -        auth = "Bearer wrong-token" -        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) -        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_passing_auth(self): -        """Ensure GETing form over OAuth with correct client credentials succeed""" -        auth = self._create_authorization_header() -        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_passing_auth_url_transport(self): -        """Ensure GETing form over OAuth with correct client credentials in form data succeed""" -        response = self.csrf_client.post('/oauth2-test/', -                data={'access_token': self.access_token.token}) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_passing_auth_url_transport(self): -        """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" -        query = urlencode({'access_token': self.access_token.token}) -        response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_failing_auth_url_transport(self): -        """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" -        query = urlencode({'access_token': self.access_token.token}) -        response = self.csrf_client.get('/oauth2-test/?%s' % query) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_passing_auth(self): -        """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_token_removed_failing_auth(self): -        """Ensure POSTing when there is no OAuth access token in db fails""" -        self.access_token.delete() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_with_refresh_token_failing_auth(self): -        """Ensure POSTing with refresh token instead of access token fails""" -        auth = self._create_authorization_header(token=self.refresh_token.token) -        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_with_expired_access_token_failing_auth(self): -        """Ensure POSTing with expired access token fails with an 'Invalid token' error""" -        self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10)  # 10 seconds late -        self.access_token.save() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) -        self.assertIn('Invalid token', response.content) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_with_invalid_scope_failing_auth(self): -        """Ensure POSTing with a readonly scope instead of a write scope fails""" -        read_only_access_token = self.access_token -        read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] -        read_only_access_token.save() -        auth = self._create_authorization_header(token=read_only_access_token.token) -        response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) -        response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_with_valid_scope_passing_auth(self): -        """Ensure POSTing with a write scope succeed""" -        read_write_access_token = self.access_token -        read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] -        read_write_access_token.save() -        auth = self._create_authorization_header(token=read_write_access_token.token) -        response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - - -class FailingAuthAccessedInRenderer(TestCase): -    def setUp(self): -        class AuthAccessingRenderer(renderers.BaseRenderer): -            media_type = 'text/plain' -            format = 'txt' - -            def render(self, data, media_type=None, renderer_context=None): -                request = renderer_context['request'] -                if request.user.is_authenticated(): -                    return b'authenticated' -                return b'not authenticated' - -        class FailingAuth(BaseAuthentication): -            def authenticate(self, request): -                raise exceptions.AuthenticationFailed('authentication failed') - -        class ExampleView(APIView): -            authentication_classes = (FailingAuth,) -            renderer_classes = (AuthAccessingRenderer,) - -            def get(self, request): -                return Response({'foo': 'bar'}) - -        self.view = ExampleView.as_view() - -    def test_failing_auth_accessed_in_renderer(self): -        """ -        When authentication fails the renderer should still be able to access -        `request.user` without raising an exception. Particularly relevant -        to HTML responses that might reasonably access `request.user`. -        """ -        request = factory.get('/') -        response = self.view(request) -        content = response.render().content -        self.assertEqual(content, b'not authenticated') diff --git a/rest_framework/tests/test_breadcrumbs.py b/rest_framework/tests/test_breadcrumbs.py deleted file mode 100644 index 41ddf2ce..00000000 --- a/rest_framework/tests/test_breadcrumbs.py +++ /dev/null @@ -1,73 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework.compat import patterns, url -from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework.views import APIView - - -class Root(APIView): -    pass - - -class ResourceRoot(APIView): -    pass - - -class ResourceInstance(APIView): -    pass - - -class NestedResourceRoot(APIView): -    pass - - -class NestedResourceInstance(APIView): -    pass - -urlpatterns = patterns('', -    url(r'^$', Root.as_view()), -    url(r'^resource/$', ResourceRoot.as_view()), -    url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()), -    url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()), -    url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()), -) - - -class BreadcrumbTests(TestCase): -    """Tests the breadcrumb functionality used by the HTML renderer.""" - -    urls = 'rest_framework.tests.test_breadcrumbs' - -    def test_root_breadcrumbs(self): -        url = '/' -        self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) - -    def test_resource_root_breadcrumbs(self): -        url = '/resource/' -        self.assertEqual(get_breadcrumbs(url), [('Root', '/'), -                                            ('Resource Root', '/resource/')]) - -    def test_resource_instance_breadcrumbs(self): -        url = '/resource/123' -        self.assertEqual(get_breadcrumbs(url), [('Root', '/'), -                                            ('Resource Root', '/resource/'), -                                            ('Resource Instance', '/resource/123')]) - -    def test_nested_resource_breadcrumbs(self): -        url = '/resource/123/' -        self.assertEqual(get_breadcrumbs(url), [('Root', '/'), -                                            ('Resource Root', '/resource/'), -                                            ('Resource Instance', '/resource/123'), -                                            ('Nested Resource Root', '/resource/123/')]) - -    def test_nested_resource_instance_breadcrumbs(self): -        url = '/resource/123/abc' -        self.assertEqual(get_breadcrumbs(url), [('Root', '/'), -                                            ('Resource Root', '/resource/'), -                                            ('Resource Instance', '/resource/123'), -                                            ('Nested Resource Root', '/resource/123/'), -                                            ('Nested Resource Instance', '/resource/123/abc')]) - -    def test_broken_url_breadcrumbs_handled_gracefully(self): -        url = '/foobar' -        self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) diff --git a/rest_framework/tests/test_decorators.py b/rest_framework/tests/test_decorators.py deleted file mode 100644 index 195f0ba3..00000000 --- a/rest_framework/tests/test_decorators.py +++ /dev/null @@ -1,157 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework import status -from rest_framework.authentication import BasicAuthentication -from rest_framework.parsers import JSONParser -from rest_framework.permissions import IsAuthenticated -from rest_framework.response import Response -from rest_framework.renderers import JSONRenderer -from rest_framework.test import APIRequestFactory -from rest_framework.throttling import UserRateThrottle -from rest_framework.views import APIView -from rest_framework.decorators import ( -    api_view, -    renderer_classes, -    parser_classes, -    authentication_classes, -    throttle_classes, -    permission_classes, -) - - -class DecoratorTestCase(TestCase): - -    def setUp(self): -        self.factory = APIRequestFactory() - -    def _finalize_response(self, request, response, *args, **kwargs): -        response.request = request -        return APIView.finalize_response(self, request, response, *args, **kwargs) - -    def test_api_view_incorrect(self): -        """ -        If @api_view is not applied correct, we should raise an assertion. -        """ - -        @api_view -        def view(request): -            return Response() - -        request = self.factory.get('/') -        self.assertRaises(AssertionError, view, request) - -    def test_api_view_incorrect_arguments(self): -        """ -        If @api_view is missing arguments, we should raise an assertion. -        """ - -        with self.assertRaises(AssertionError): -            @api_view('GET') -            def view(request): -                return Response() - -    def test_calling_method(self): - -        @api_view(['GET']) -        def view(request): -            return Response({}) - -        request = self.factory.get('/') -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -        request = self.factory.post('/') -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - -    def test_calling_put_method(self): - -        @api_view(['GET', 'PUT']) -        def view(request): -            return Response({}) - -        request = self.factory.put('/') -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -        request = self.factory.post('/') -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - -    def test_calling_patch_method(self): - -        @api_view(['GET', 'PATCH']) -        def view(request): -            return Response({}) - -        request = self.factory.patch('/') -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -        request = self.factory.post('/') -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - -    def test_renderer_classes(self): - -        @api_view(['GET']) -        @renderer_classes([JSONRenderer]) -        def view(request): -            return Response({}) - -        request = self.factory.get('/') -        response = view(request) -        self.assertTrue(isinstance(response.accepted_renderer, JSONRenderer)) - -    def test_parser_classes(self): - -        @api_view(['GET']) -        @parser_classes([JSONParser]) -        def view(request): -            self.assertEqual(len(request.parsers), 1) -            self.assertTrue(isinstance(request.parsers[0], -                                       JSONParser)) -            return Response({}) - -        request = self.factory.get('/') -        view(request) - -    def test_authentication_classes(self): - -        @api_view(['GET']) -        @authentication_classes([BasicAuthentication]) -        def view(request): -            self.assertEqual(len(request.authenticators), 1) -            self.assertTrue(isinstance(request.authenticators[0], -                                       BasicAuthentication)) -            return Response({}) - -        request = self.factory.get('/') -        view(request) - -    def test_permission_classes(self): - -        @api_view(['GET']) -        @permission_classes([IsAuthenticated]) -        def view(request): -            return Response({}) - -        request = self.factory.get('/') -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - -    def test_throttle_classes(self): -        class OncePerDayUserThrottle(UserRateThrottle): -            rate = '1/day' - -        @api_view(['GET']) -        @throttle_classes([OncePerDayUserThrottle]) -        def view(request): -            return Response({}) - -        request = self.factory.get('/') -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -        response = view(request) -        self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py deleted file mode 100644 index 4c03c1de..00000000 --- a/rest_framework/tests/test_description.py +++ /dev/null @@ -1,108 +0,0 @@ -# -- coding: utf-8 -- - -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework.compat import apply_markdown, smart_text -from rest_framework.views import APIView -from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring -from rest_framework.tests.description import UTF8_TEST_DOCSTRING - -# We check that docstrings get nicely un-indented. -DESCRIPTION = """an example docstring -==================== - -* list -* list - -another header --------------- - -    code block - -indented - -# hash style header #""" - -# If markdown is installed we also test it's working -# (and that our wrapped forces '=' to h2 and '-' to h3) - -# We support markdown < 2.1 and markdown >= 2.1 -MARKED_DOWN_lt_21 = """<h2>an example docstring</h2> -<ul> -<li>list</li> -<li>list</li> -</ul> -<h3>another header</h3> -<pre><code>code block -</code></pre> -<p>indented</p> -<h2 id="hash_style_header">hash style header</h2>""" - -MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2> -<ul> -<li>list</li> -<li>list</li> -</ul> -<h3 id="another-header">another header</h3> -<pre><code>code block -</code></pre> -<p>indented</p> -<h2 id="hash-style-header">hash style header</h2>""" - - -class TestViewNamesAndDescriptions(TestCase): -    def test_view_name_uses_class_name(self): -        """ -        Ensure view names are based on the class name. -        """ -        class MockView(APIView): -            pass -        self.assertEqual(MockView().get_view_name(), 'Mock') - -    def test_view_description_uses_docstring(self): -        """Ensure view descriptions are based on the docstring.""" -        class MockView(APIView): -            """an example docstring -            ==================== - -            * list -            * list - -            another header -            -------------- - -                code block - -            indented - -            # hash style header #""" - -        self.assertEqual(MockView().get_view_description(), DESCRIPTION) - -    def test_view_description_supports_unicode(self): -        """ -        Unicode in docstrings should be respected. -        """ - -        self.assertEqual( -            ViewWithNonASCIICharactersInDocstring().get_view_description(), -            smart_text(UTF8_TEST_DOCSTRING) -        ) - -    def test_view_description_can_be_empty(self): -        """ -        Ensure that if a view has no docstring, -        then it's description is the empty string. -        """ -        class MockView(APIView): -            pass -        self.assertEqual(MockView().get_view_description(), '') - -    def test_markdown(self): -        """ -        Ensure markdown to HTML works as expected. -        """ -        if apply_markdown: -            gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21 -            lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21 -            self.assertTrue(gte_21_match or lt_21_match) diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py deleted file mode 100644 index e127feef..00000000 --- a/rest_framework/tests/test_fields.py +++ /dev/null @@ -1,984 +0,0 @@ -""" -General serializer field tests. -""" -from __future__ import unicode_literals - -import datetime -from decimal import Decimal -from uuid import uuid4 -from django.core import validators -from django.db import models -from django.test import TestCase -from django.utils.datastructures import SortedDict -from rest_framework import serializers -from rest_framework.tests.models import RESTFrameworkModel - - -class TimestampedModel(models.Model): -    added = models.DateTimeField(auto_now_add=True) -    updated = models.DateTimeField(auto_now=True) - - -class CharPrimaryKeyModel(models.Model): -    id = models.CharField(max_length=20, primary_key=True) - - -class TimestampedModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = TimestampedModel - - -class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = CharPrimaryKeyModel - - -class TimeFieldModel(models.Model): -    clock = models.TimeField() - - -class TimeFieldModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = TimeFieldModel - - -SAMPLE_CHOICES = [ -    ('red', 'Red'), -    ('green', 'Green'), -    ('blue', 'Blue'), -] - - -class ChoiceFieldModel(models.Model): -    choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255) - - -class ChoiceFieldModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = ChoiceFieldModel - - -class ChoiceFieldModelWithNull(models.Model): -    choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255) - - -class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer): -    class Meta: -        model = ChoiceFieldModelWithNull - - -class BasicFieldTests(TestCase): -    def test_auto_now_fields_read_only(self): -        """ -        auto_now and auto_now_add fields should be read_only by default. -        """ -        serializer = TimestampedModelSerializer() -        self.assertEqual(serializer.fields['added'].read_only, True) - -    def test_auto_pk_fields_read_only(self): -        """ -        AutoField fields should be read_only by default. -        """ -        serializer = TimestampedModelSerializer() -        self.assertEqual(serializer.fields['id'].read_only, True) - -    def test_non_auto_pk_fields_not_read_only(self): -        """ -        PK fields other than AutoField fields should not be read_only by default. -        """ -        serializer = CharPrimaryKeyModelSerializer() -        self.assertEqual(serializer.fields['id'].read_only, False) - -    def test_dict_field_ordering(self): -        """ -        Field should preserve dictionary ordering, if it exists. -        See: https://github.com/tomchristie/django-rest-framework/issues/832 -        """ -        ret = SortedDict() -        ret['c'] = 1 -        ret['b'] = 1 -        ret['a'] = 1 -        ret['z'] = 1 -        field = serializers.Field() -        keys = list(field.to_native(ret).keys()) -        self.assertEqual(keys, ['c', 'b', 'a', 'z']) - - -class DateFieldTest(TestCase): -    """ -    Tests for the DateFieldTest from_native() and to_native() behavior -    """ - -    def test_from_native_string(self): -        """ -        Make sure from_native() accepts default iso input formats. -        """ -        f = serializers.DateField() -        result_1 = f.from_native('1984-07-31') - -        self.assertEqual(datetime.date(1984, 7, 31), result_1) - -    def test_from_native_datetime_date(self): -        """ -        Make sure from_native() accepts a datetime.date instance. -        """ -        f = serializers.DateField() -        result_1 = f.from_native(datetime.date(1984, 7, 31)) - -        self.assertEqual(result_1, datetime.date(1984, 7, 31)) - -    def test_from_native_custom_format(self): -        """ -        Make sure from_native() accepts custom input formats. -        """ -        f = serializers.DateField(input_formats=['%Y -- %d']) -        result = f.from_native('1984 -- 31') - -        self.assertEqual(datetime.date(1984, 1, 31), result) - -    def test_from_native_invalid_default_on_custom_format(self): -        """ -        Make sure from_native() don't accept default formats if custom format is preset -        """ -        f = serializers.DateField(input_formats=['%Y -- %d']) - -        try: -            f.from_native('1984-07-31') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_from_native_empty(self): -        """ -        Make sure from_native() returns None on empty param. -        """ -        f = serializers.DateField() -        result = f.from_native('') - -        self.assertEqual(result, None) - -    def test_from_native_none(self): -        """ -        Make sure from_native() returns None on None param. -        """ -        f = serializers.DateField() -        result = f.from_native(None) - -        self.assertEqual(result, None) - -    def test_from_native_invalid_date(self): -        """ -        Make sure from_native() raises a ValidationError on passing an invalid date. -        """ -        f = serializers.DateField() - -        try: -            f.from_native('1984-13-31') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_from_native_invalid_format(self): -        """ -        Make sure from_native() raises a ValidationError on passing an invalid format. -        """ -        f = serializers.DateField() - -        try: -            f.from_native('1984 -- 31') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_to_native(self): -        """ -        Make sure to_native() returns datetime as default. -        """ -        f = serializers.DateField() - -        result_1 = f.to_native(datetime.date(1984, 7, 31)) - -        self.assertEqual(datetime.date(1984, 7, 31), result_1) - -    def test_to_native_iso(self): -        """ -        Make sure to_native() with 'iso-8601' returns iso formated date. -        """ -        f = serializers.DateField(format='iso-8601') - -        result_1 = f.to_native(datetime.date(1984, 7, 31)) - -        self.assertEqual('1984-07-31', result_1) - -    def test_to_native_custom_format(self): -        """ -        Make sure to_native() returns correct custom format. -        """ -        f = serializers.DateField(format="%Y - %m.%d") - -        result_1 = f.to_native(datetime.date(1984, 7, 31)) - -        self.assertEqual('1984 - 07.31', result_1) - -    def test_to_native_none(self): -        """ -        Make sure from_native() returns None on None param. -        """ -        f = serializers.DateField(required=False) -        self.assertEqual(None, f.to_native(None)) - - -class DateTimeFieldTest(TestCase): -    """ -    Tests for the DateTimeField from_native() and to_native() behavior -    """ - -    def test_from_native_string(self): -        """ -        Make sure from_native() accepts default iso input formats. -        """ -        f = serializers.DateTimeField() -        result_1 = f.from_native('1984-07-31 04:31') -        result_2 = f.from_native('1984-07-31 04:31:59') -        result_3 = f.from_native('1984-07-31 04:31:59.000200') - -        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1) -        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2) -        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3) - -    def test_from_native_datetime_datetime(self): -        """ -        Make sure from_native() accepts a datetime.datetime instance. -        """ -        f = serializers.DateTimeField() -        result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31)) -        result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) -        result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) - -        self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31)) -        self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59)) -        self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) - -    def test_from_native_custom_format(self): -        """ -        Make sure from_native() accepts custom input formats. -        """ -        f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) -        result = f.from_native('1984 -- 04:59') - -        self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result) - -    def test_from_native_invalid_default_on_custom_format(self): -        """ -        Make sure from_native() don't accept default formats if custom format is preset -        """ -        f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) - -        try: -            f.from_native('1984-07-31 04:31:59') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_from_native_empty(self): -        """ -        Make sure from_native() returns None on empty param. -        """ -        f = serializers.DateTimeField() -        result = f.from_native('') - -        self.assertEqual(result, None) - -    def test_from_native_none(self): -        """ -        Make sure from_native() returns None on None param. -        """ -        f = serializers.DateTimeField() -        result = f.from_native(None) - -        self.assertEqual(result, None) - -    def test_from_native_invalid_datetime(self): -        """ -        Make sure from_native() raises a ValidationError on passing an invalid datetime. -        """ -        f = serializers.DateTimeField() - -        try: -            f.from_native('04:61:59') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " -                                          "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_from_native_invalid_format(self): -        """ -        Make sure from_native() raises a ValidationError on passing an invalid format. -        """ -        f = serializers.DateTimeField() - -        try: -            f.from_native('04 -- 31') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " -                                          "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_to_native(self): -        """ -        Make sure to_native() returns isoformat as default. -        """ -        f = serializers.DateTimeField() - -        result_1 = f.to_native(datetime.datetime(1984, 7, 31)) -        result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) -        result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) -        result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) - -        self.assertEqual(datetime.datetime(1984, 7, 31), result_1) -        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2) -        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3) -        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4) - -    def test_to_native_iso(self): -        """ -        Make sure to_native() with format=iso-8601 returns iso formatted datetime. -        """ -        f = serializers.DateTimeField(format='iso-8601') - -        result_1 = f.to_native(datetime.datetime(1984, 7, 31)) -        result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) -        result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) -        result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) - -        self.assertEqual('1984-07-31T00:00:00', result_1) -        self.assertEqual('1984-07-31T04:31:00', result_2) -        self.assertEqual('1984-07-31T04:31:59', result_3) -        self.assertEqual('1984-07-31T04:31:59.000200', result_4) - -    def test_to_native_custom_format(self): -        """ -        Make sure to_native() returns correct custom format. -        """ -        f = serializers.DateTimeField(format="%Y - %H:%M") - -        result_1 = f.to_native(datetime.datetime(1984, 7, 31)) -        result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) -        result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) -        result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) - -        self.assertEqual('1984 - 00:00', result_1) -        self.assertEqual('1984 - 04:31', result_2) -        self.assertEqual('1984 - 04:31', result_3) -        self.assertEqual('1984 - 04:31', result_4) - -    def test_to_native_none(self): -        """ -        Make sure from_native() returns None on None param. -        """ -        f = serializers.DateTimeField(required=False) -        self.assertEqual(None, f.to_native(None)) - - -class TimeFieldTest(TestCase): -    """ -    Tests for the TimeField from_native() and to_native() behavior -    """ - -    def test_from_native_string(self): -        """ -        Make sure from_native() accepts default iso input formats. -        """ -        f = serializers.TimeField() -        result_1 = f.from_native('04:31') -        result_2 = f.from_native('04:31:59') -        result_3 = f.from_native('04:31:59.000200') - -        self.assertEqual(datetime.time(4, 31), result_1) -        self.assertEqual(datetime.time(4, 31, 59), result_2) -        self.assertEqual(datetime.time(4, 31, 59, 200), result_3) - -    def test_from_native_datetime_time(self): -        """ -        Make sure from_native() accepts a datetime.time instance. -        """ -        f = serializers.TimeField() -        result_1 = f.from_native(datetime.time(4, 31)) -        result_2 = f.from_native(datetime.time(4, 31, 59)) -        result_3 = f.from_native(datetime.time(4, 31, 59, 200)) - -        self.assertEqual(result_1, datetime.time(4, 31)) -        self.assertEqual(result_2, datetime.time(4, 31, 59)) -        self.assertEqual(result_3, datetime.time(4, 31, 59, 200)) - -    def test_from_native_custom_format(self): -        """ -        Make sure from_native() accepts custom input formats. -        """ -        f = serializers.TimeField(input_formats=['%H -- %M']) -        result = f.from_native('04 -- 31') - -        self.assertEqual(datetime.time(4, 31), result) - -    def test_from_native_invalid_default_on_custom_format(self): -        """ -        Make sure from_native() don't accept default formats if custom format is preset -        """ -        f = serializers.TimeField(input_formats=['%H -- %M']) - -        try: -            f.from_native('04:31:59') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_from_native_empty(self): -        """ -        Make sure from_native() returns None on empty param. -        """ -        f = serializers.TimeField() -        result = f.from_native('') - -        self.assertEqual(result, None) - -    def test_from_native_none(self): -        """ -        Make sure from_native() returns None on None param. -        """ -        f = serializers.TimeField() -        result = f.from_native(None) - -        self.assertEqual(result, None) - -    def test_from_native_invalid_time(self): -        """ -        Make sure from_native() raises a ValidationError on passing an invalid time. -        """ -        f = serializers.TimeField() - -        try: -            f.from_native('04:61:59') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " -                                          "hh:mm[:ss[.uuuuuu]]"]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_from_native_invalid_format(self): -        """ -        Make sure from_native() raises a ValidationError on passing an invalid format. -        """ -        f = serializers.TimeField() - -        try: -            f.from_native('04 -- 31') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " -                                          "hh:mm[:ss[.uuuuuu]]"]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_to_native(self): -        """ -        Make sure to_native() returns time object as default. -        """ -        f = serializers.TimeField() -        result_1 = f.to_native(datetime.time(4, 31)) -        result_2 = f.to_native(datetime.time(4, 31, 59)) -        result_3 = f.to_native(datetime.time(4, 31, 59, 200)) - -        self.assertEqual(datetime.time(4, 31), result_1) -        self.assertEqual(datetime.time(4, 31, 59), result_2) -        self.assertEqual(datetime.time(4, 31, 59, 200), result_3) - -    def test_to_native_iso(self): -        """ -        Make sure to_native() with format='iso-8601' returns iso formatted time. -        """ -        f = serializers.TimeField(format='iso-8601') -        result_1 = f.to_native(datetime.time(4, 31)) -        result_2 = f.to_native(datetime.time(4, 31, 59)) -        result_3 = f.to_native(datetime.time(4, 31, 59, 200)) - -        self.assertEqual('04:31:00', result_1) -        self.assertEqual('04:31:59', result_2) -        self.assertEqual('04:31:59.000200', result_3) - -    def test_to_native_custom_format(self): -        """ -        Make sure to_native() returns correct custom format. -        """ -        f = serializers.TimeField(format="%H - %S [%f]") -        result_1 = f.to_native(datetime.time(4, 31)) -        result_2 = f.to_native(datetime.time(4, 31, 59)) -        result_3 = f.to_native(datetime.time(4, 31, 59, 200)) - -        self.assertEqual('04 - 00 [000000]', result_1) -        self.assertEqual('04 - 59 [000000]', result_2) -        self.assertEqual('04 - 59 [000200]', result_3) - - -class DecimalFieldTest(TestCase): -    """ -    Tests for the DecimalField from_native() and to_native() behavior -    """ - -    def test_from_native_string(self): -        """ -        Make sure from_native() accepts string values -        """ -        f = serializers.DecimalField() -        result_1 = f.from_native('9000') -        result_2 = f.from_native('1.00000001') - -        self.assertEqual(Decimal('9000'), result_1) -        self.assertEqual(Decimal('1.00000001'), result_2) - -    def test_from_native_invalid_string(self): -        """ -        Make sure from_native() raises ValidationError on passing invalid string -        """ -        f = serializers.DecimalField() - -        try: -            f.from_native('123.45.6') -        except validators.ValidationError as e: -            self.assertEqual(e.messages, ["Enter a number."]) -        else: -            self.fail("ValidationError was not properly raised") - -    def test_from_native_integer(self): -        """ -        Make sure from_native() accepts integer values -        """ -        f = serializers.DecimalField() -        result = f.from_native(9000) - -        self.assertEqual(Decimal('9000'), result) - -    def test_from_native_float(self): -        """ -        Make sure from_native() accepts float values -        """ -        f = serializers.DecimalField() -        result = f.from_native(1.00000001) - -        self.assertEqual(Decimal('1.00000001'), result) - -    def test_from_native_empty(self): -        """ -        Make sure from_native() returns None on empty param. -        """ -        f = serializers.DecimalField() -        result = f.from_native('') - -        self.assertEqual(result, None) - -    def test_from_native_none(self): -        """ -        Make sure from_native() returns None on None param. -        """ -        f = serializers.DecimalField() -        result = f.from_native(None) - -        self.assertEqual(result, None) - -    def test_to_native(self): -        """ -        Make sure to_native() returns Decimal as string. -        """ -        f = serializers.DecimalField() - -        result_1 = f.to_native(Decimal('9000')) -        result_2 = f.to_native(Decimal('1.00000001')) - -        self.assertEqual(Decimal('9000'), result_1) -        self.assertEqual(Decimal('1.00000001'), result_2) - -    def test_to_native_none(self): -        """ -        Make sure from_native() returns None on None param. -        """ -        f = serializers.DecimalField(required=False) -        self.assertEqual(None, f.to_native(None)) - -    def test_valid_serialization(self): -        """ -        Make sure the serializer works correctly -        """ -        class DecimalSerializer(serializers.Serializer): -            decimal_field = serializers.DecimalField(max_value=9010, -                                                     min_value=9000, -                                                     max_digits=6, -                                                     decimal_places=2) - -        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) -        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) -        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) - -        self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) -        self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) -        self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) - -    def test_raise_max_value(self): -        """ -        Make sure max_value violations raises ValidationError -        """ -        class DecimalSerializer(serializers.Serializer): -            decimal_field = serializers.DecimalField(max_value=100) - -        s = DecimalSerializer(data={'decimal_field': '123'}) - -        self.assertFalse(s.is_valid()) -        self.assertEqual(s.errors,  {'decimal_field': ['Ensure this value is less than or equal to 100.']}) - -    def test_raise_min_value(self): -        """ -        Make sure min_value violations raises ValidationError -        """ -        class DecimalSerializer(serializers.Serializer): -            decimal_field = serializers.DecimalField(min_value=100) - -        s = DecimalSerializer(data={'decimal_field': '99'}) - -        self.assertFalse(s.is_valid()) -        self.assertEqual(s.errors,  {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) - -    def test_raise_max_digits(self): -        """ -        Make sure max_digits violations raises ValidationError -        """ -        class DecimalSerializer(serializers.Serializer): -            decimal_field = serializers.DecimalField(max_digits=5) - -        s = DecimalSerializer(data={'decimal_field': '123.456'}) - -        self.assertFalse(s.is_valid()) -        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) - -    def test_raise_max_decimal_places(self): -        """ -        Make sure max_decimal_places violations raises ValidationError -        """ -        class DecimalSerializer(serializers.Serializer): -            decimal_field = serializers.DecimalField(decimal_places=3) - -        s = DecimalSerializer(data={'decimal_field': '123.4567'}) - -        self.assertFalse(s.is_valid()) -        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) - -    def test_raise_max_whole_digits(self): -        """ -        Make sure max_whole_digits violations raises ValidationError -        """ -        class DecimalSerializer(serializers.Serializer): -            decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) - -        s = DecimalSerializer(data={'decimal_field': '12345.6'}) - -        self.assertFalse(s.is_valid()) -        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) - - -class ChoiceFieldTests(TestCase): -    """ -    Tests for the ChoiceField options generator -    """ -    def test_choices_required(self): -        """ -        Make sure proper choices are rendered if field is required -        """ -        f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES) -        self.assertEqual(f.choices, SAMPLE_CHOICES) - -    def test_choices_not_required(self): -        """ -        Make sure proper choices (plus blank) are rendered if the field isn't required -        """ -        f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) -        self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES) - -    def test_invalid_choice_model(self): -        s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'}) -        self.assertFalse(s.is_valid()) -        self.assertEqual(s.errors,  {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']}) -        self.assertEqual(s.data['choice'], '') - -    def test_empty_choice_model(self): -        """ -        Test that the 'empty' value is correctly passed and used depending on -        the 'null' property on the model field. -        """ -        s = ChoiceFieldModelSerializer(data={'choice': ''}) -        self.assertTrue(s.is_valid()) -        self.assertEqual(s.data['choice'], '') - -        s = ChoiceFieldModelWithNullSerializer(data={'choice': ''}) -        self.assertTrue(s.is_valid()) -        self.assertEqual(s.data['choice'], None) - -    def test_from_native_empty(self): -        """ -        Make sure from_native() returns an empty string on empty param by default. -        """ -        f = serializers.ChoiceField(choices=SAMPLE_CHOICES) -        self.assertEqual(f.from_native(''), '') -        self.assertEqual(f.from_native(None), '') - -    def test_from_native_empty_override(self): -        """ -        Make sure you can override from_native() behavior regarding empty values. -        """ -        f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None) -        self.assertEqual(f.from_native(''), None) -        self.assertEqual(f.from_native(None), None) - -    def test_metadata_choices(self): -        """ -        Make sure proper choices are included in the field's metadata. -        """ -        choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES] -        f = serializers.ChoiceField(choices=SAMPLE_CHOICES) -        self.assertEqual(f.metadata()['choices'], choices) - -    def test_metadata_choices_not_required(self): -        """ -        Make sure proper choices are included in the field's metadata. -        """ -        choices = [{'value': v, 'display_name': n} -                   for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES] -        f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) -        self.assertEqual(f.metadata()['choices'], choices) - - -class EmailFieldTests(TestCase): -    """ -    Tests for EmailField attribute values -    """ - -    class EmailFieldModel(RESTFrameworkModel): -        email_field = models.EmailField(blank=True) - -    class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel): -        email_field = models.EmailField(max_length=150, blank=True) - -    def test_default_model_value(self): -        class EmailFieldSerializer(serializers.ModelSerializer): -            class Meta: -                model = self.EmailFieldModel - -        serializer = EmailFieldSerializer(data={}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75) - -    def test_given_model_value(self): -        class EmailFieldSerializer(serializers.ModelSerializer): -            class Meta: -                model = self.EmailFieldWithGivenMaxLengthModel - -        serializer = EmailFieldSerializer(data={}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150) - -    def test_given_serializer_value(self): -        class EmailFieldSerializer(serializers.ModelSerializer): -            email_field = serializers.EmailField(source='email_field', max_length=20, required=False) - -            class Meta: -                model = self.EmailFieldModel - -        serializer = EmailFieldSerializer(data={}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20) - - -class SlugFieldTests(TestCase): -    """ -    Tests for SlugField attribute values -    """ - -    class SlugFieldModel(RESTFrameworkModel): -        slug_field = models.SlugField(blank=True) - -    class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel): -        slug_field = models.SlugField(max_length=84, blank=True) - -    def test_default_model_value(self): -        class SlugFieldSerializer(serializers.ModelSerializer): -            class Meta: -                model = self.SlugFieldModel - -        serializer = SlugFieldSerializer(data={}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50) - -    def test_given_model_value(self): -        class SlugFieldSerializer(serializers.ModelSerializer): -            class Meta: -                model = self.SlugFieldWithGivenMaxLengthModel - -        serializer = SlugFieldSerializer(data={}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84) - -    def test_given_serializer_value(self): -        class SlugFieldSerializer(serializers.ModelSerializer): -            slug_field = serializers.SlugField(source='slug_field', -                                               max_length=20, required=False) - -            class Meta: -                model = self.SlugFieldModel - -        serializer = SlugFieldSerializer(data={}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['slug_field'], -                                 'max_length'), 20) - -    def test_invalid_slug(self): -        """ -        Make sure an invalid slug raises ValidationError -        """ -        class SlugFieldSerializer(serializers.ModelSerializer): -            slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True) - -            class Meta: -                model = self.SlugFieldModel - -        s = SlugFieldSerializer(data={'slug_field': 'a b'}) - -        self.assertEqual(s.is_valid(), False) -        self.assertEqual(s.errors,  {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]}) - - -class URLFieldTests(TestCase): -    """ -    Tests for URLField attribute values. - -    (Includes test for #1210, checking that validators can be overridden.) -    """ - -    class URLFieldModel(RESTFrameworkModel): -        url_field = models.URLField(blank=True) - -    class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel): -        url_field = models.URLField(max_length=128, blank=True) - -    def test_default_model_value(self): -        class URLFieldSerializer(serializers.ModelSerializer): -            class Meta: -                model = self.URLFieldModel - -        serializer = URLFieldSerializer(data={}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['url_field'], -                                 'max_length'), 200) - -    def test_given_model_value(self): -        class URLFieldSerializer(serializers.ModelSerializer): -            class Meta: -                model = self.URLFieldWithGivenMaxLengthModel - -        serializer = URLFieldSerializer(data={}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['url_field'], -                                 'max_length'), 128) - -    def test_given_serializer_value(self): -        class URLFieldSerializer(serializers.ModelSerializer): -            url_field = serializers.URLField(source='url_field', -                                             max_length=20, required=False) - -            class Meta: -                model = self.URLFieldWithGivenMaxLengthModel - -        serializer = URLFieldSerializer(data={}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(getattr(serializer.fields['url_field'], -                         'max_length'), 20) - -    def test_validators_can_be_overridden(self): -        url_field = serializers.URLField(validators=[]) -        validators = url_field.validators -        self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators') - - -class FieldMetadata(TestCase): -    def setUp(self): -        self.required_field = serializers.Field() -        self.required_field.label = uuid4().hex -        self.required_field.required = True - -        self.optional_field = serializers.Field() -        self.optional_field.label = uuid4().hex -        self.optional_field.required = False - -    def test_required(self): -        self.assertEqual(self.required_field.metadata()['required'], True) - -    def test_optional(self): -        self.assertEqual(self.optional_field.metadata()['required'], False) - -    def test_label(self): -        for field in (self.required_field, self.optional_field): -            self.assertEqual(field.metadata()['label'], field.label) - - -class FieldCallableDefault(TestCase): -    def setUp(self): -        self.simple_callable = lambda: 'foo bar' - -    def test_default_can_be_simple_callable(self): -        """ -        Ensure that the 'default' argument can also be a simple callable. -        """ -        field = serializers.WritableField(default=self.simple_callable) -        into = {} -        field.field_from_native({}, {}, 'field', into) -        self.assertEqual(into, {'field': 'foo bar'}) - - -class CustomIntegerField(TestCase): -    """ -        Test that custom fields apply min_value and max_value constraints -    """ -    def test_custom_fields_can_be_validated_for_value(self): - -        class MoneyField(models.PositiveIntegerField): -            pass - -        class EntryModel(models.Model): -            bank = MoneyField(validators=[validators.MaxValueValidator(100)]) - -        class EntrySerializer(serializers.ModelSerializer): -            class Meta: -                model = EntryModel - -        entry = EntryModel(bank=1) - -        serializer = EntrySerializer(entry, data={"bank": 11}) -        self.assertTrue(serializer.is_valid()) - -        serializer = EntrySerializer(entry, data={"bank": -1}) -        self.assertFalse(serializer.is_valid()) - -        serializer = EntrySerializer(entry, data={"bank": 101}) -        self.assertFalse(serializer.is_valid()) - - -class BooleanField(TestCase): -    """ -        Tests for BooleanField -    """ -    def test_boolean_required(self): -        class BooleanRequiredSerializer(serializers.Serializer): -            bool_field = serializers.BooleanField(required=True) - -        self.assertFalse(BooleanRequiredSerializer(data={}).is_valid()) diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py deleted file mode 100644 index 78f4cf42..00000000 --- a/rest_framework/tests/test_files.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework import serializers -from rest_framework.compat import BytesIO -from rest_framework.compat import six -import datetime - - -class UploadedFile(object): -    def __init__(self, file=None, created=None): -        self.file = file -        self.created = created or datetime.datetime.now() - - -class UploadedFileSerializer(serializers.Serializer): -    file = serializers.FileField(required=False) -    created = serializers.DateTimeField() - -    def restore_object(self, attrs, instance=None): -        if instance: -            instance.file = attrs['file'] -            instance.created = attrs['created'] -            return instance -        return UploadedFile(**attrs) - - -class FileSerializerTests(TestCase): -    def test_create(self): -        now = datetime.datetime.now() -        file = BytesIO(six.b('stuff')) -        file.name = 'stuff.txt' -        file.size = len(file.getvalue()) -        serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) -        uploaded_file = UploadedFile(file=file, created=now) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.object.created, uploaded_file.created) -        self.assertEqual(serializer.object.file, uploaded_file.file) -        self.assertFalse(serializer.object is uploaded_file) - -    def test_creation_failure(self): -        """ -        Passing files=None should result in an ValidationError - -        Regression test for: -        https://github.com/tomchristie/django-rest-framework/issues/542 -        """ -        now = datetime.datetime.now() - -        serializer = UploadedFileSerializer(data={'created': now}) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.object.created, now) -        self.assertIsNone(serializer.object.file) - -    def test_remove_with_empty_string(self): -        """ -        Passing empty string as data should cause file to be removed - -        Test for: -        https://github.com/tomchristie/django-rest-framework/issues/937 -        """ -        now = datetime.datetime.now() -        file = BytesIO(six.b('stuff')) -        file.name = 'stuff.txt' -        file.size = len(file.getvalue()) - -        uploaded_file = UploadedFile(file=file, created=now) - -        serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.object.created, uploaded_file.created) -        self.assertIsNone(serializer.object.file) - -    def test_validation_error_with_non_file(self): -        """ -        Passing non-files should raise a validation error. -        """ -        now = datetime.datetime.now() -        errmsg = 'No file was submitted. Check the encoding type on the form.' - -        serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'file': [errmsg]}) - -    def test_validation_with_no_data(self): -        """ -        Validation should still function when no data dictionary is provided. -        """ -        now = datetime.datetime.now() -        file = BytesIO(six.b('stuff')) -        file.name = 'stuff.txt' -        file.size = len(file.getvalue()) -        uploaded_file = UploadedFile(file=file, created=now) - -        serializer = UploadedFileSerializer(files={'file': file}) -        self.assertFalse(serializer.is_valid()) diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py deleted file mode 100644 index 23226bbc..00000000 --- a/rest_framework/tests/test_filters.py +++ /dev/null @@ -1,661 +0,0 @@ -from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.db import models -from django.core.urlresolvers import reverse -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, serializers, status, filters -from rest_framework.compat import django_filters, patterns, url -from rest_framework.settings import api_settings -from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import BasicModel -from .models import FilterableItem -from .utils import temporary_setting - -factory = APIRequestFactory() - - -if django_filters: -    # Basic filter on a list view. -    class FilterFieldsRootView(generics.ListCreateAPIView): -        model = FilterableItem -        filter_fields = ['decimal', 'date'] -        filter_backends = (filters.DjangoFilterBackend,) - -    # These class are used to test a filter class. -    class SeveralFieldsFilter(django_filters.FilterSet): -        text = django_filters.CharFilter(lookup_type='icontains') -        decimal = django_filters.NumberFilter(lookup_type='lt') -        date = django_filters.DateFilter(lookup_type='gt') - -        class Meta: -            model = FilterableItem -            fields = ['text', 'decimal', 'date'] - -    class FilterClassRootView(generics.ListCreateAPIView): -        model = FilterableItem -        filter_class = SeveralFieldsFilter -        filter_backends = (filters.DjangoFilterBackend,) - -    # These classes are used to test a misconfigured filter class. -    class MisconfiguredFilter(django_filters.FilterSet): -        text = django_filters.CharFilter(lookup_type='icontains') - -        class Meta: -            model = BasicModel -            fields = ['text'] - -    class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): -        model = FilterableItem -        filter_class = MisconfiguredFilter -        filter_backends = (filters.DjangoFilterBackend,) - -    class FilterClassDetailView(generics.RetrieveAPIView): -        model = FilterableItem -        filter_class = SeveralFieldsFilter -        filter_backends = (filters.DjangoFilterBackend,) - -    # Regression test for #814 -    class FilterableItemSerializer(serializers.ModelSerializer): -        class Meta: -            model = FilterableItem - -    class FilterFieldsQuerysetView(generics.ListCreateAPIView): -        queryset = FilterableItem.objects.all() -        serializer_class = FilterableItemSerializer -        filter_fields = ['decimal', 'date'] -        filter_backends = (filters.DjangoFilterBackend,) - -    class GetQuerysetView(generics.ListCreateAPIView): -        serializer_class = FilterableItemSerializer -        filter_class = SeveralFieldsFilter -        filter_backends = (filters.DjangoFilterBackend,) - -        def get_queryset(self): -            return FilterableItem.objects.all() - -    urlpatterns = patterns('', -        url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), -        url(r'^$', FilterClassRootView.as_view(), name='root-view'), -        url(r'^get-queryset/$', GetQuerysetView.as_view(), -            name='get-queryset-view'), -    ) - - -class CommonFilteringTestCase(TestCase): -    def _serialize_object(self, obj): -        return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - -    def setUp(self): -        """ -        Create 10 FilterableItem instances. -        """ -        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) -        for i in range(10): -            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. -            decimal = base_data[1] + i -            date = base_data[2] - datetime.timedelta(days=i * 2) -            FilterableItem(text=text, decimal=decimal, date=date).save() - -        self.objects = FilterableItem.objects -        self.data = [ -            self._serialize_object(obj) -            for obj in self.objects.all() -        ] - - -class IntegrationTestFiltering(CommonFilteringTestCase): -    """ -    Integration tests for filtered list views. -    """ - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_get_filtered_fields_root_view(self): -        """ -        GET requests to paginated ListCreateAPIView should return paginated results. -        """ -        view = FilterFieldsRootView.as_view() - -        # Basic test with no filter. -        request = factory.get('/') -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data) - -        # Tests that the decimal filter works. -        search_decimal = Decimal('2.25') -        request = factory.get('/', {'decimal': '%s' % search_decimal}) -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if f['decimal'] == search_decimal] -        self.assertEqual(response.data, expected_data) - -        # Tests that the date filter works. -        search_date = datetime.date(2012, 9, 22) -        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-09-22' -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if f['date'] == search_date] -        self.assertEqual(response.data, expected_data) - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_filter_with_queryset(self): -        """ -        Regression test for #814. -        """ -        view = FilterFieldsQuerysetView.as_view() - -        # Tests that the decimal filter works. -        search_decimal = Decimal('2.25') -        request = factory.get('/', {'decimal': '%s' % search_decimal}) -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if f['decimal'] == search_decimal] -        self.assertEqual(response.data, expected_data) - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_filter_with_get_queryset_only(self): -        """ -        Regression test for #834. -        """ -        view = GetQuerysetView.as_view() -        request = factory.get('/get-queryset/') -        view(request).render() -        # Used to raise "issubclass() arg 2 must be a class or tuple of classes" -        # here when neither `model' nor `queryset' was specified. - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_get_filtered_class_root_view(self): -        """ -        GET requests to filtered ListCreateAPIView that have a filter_class set -        should return filtered results. -        """ -        view = FilterClassRootView.as_view() - -        # Basic test with no filter. -        request = factory.get('/') -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data) - -        # Tests that the decimal filter set with 'lt' in the filter class works. -        search_decimal = Decimal('4.25') -        request = factory.get('/', {'decimal': '%s' % search_decimal}) -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if f['decimal'] < search_decimal] -        self.assertEqual(response.data, expected_data) - -        # Tests that the date filter set with 'gt' in the filter class works. -        search_date = datetime.date(2012, 10, 2) -        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-10-02' -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if f['date'] > search_date] -        self.assertEqual(response.data, expected_data) - -        # Tests that the text filter set with 'icontains' in the filter class works. -        search_text = 'ff' -        request = factory.get('/', {'text': '%s' % search_text}) -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if search_text in f['text'].lower()] -        self.assertEqual(response.data, expected_data) - -        # Tests that multiple filters works. -        search_decimal = Decimal('5.25') -        search_date = datetime.date(2012, 10, 2) -        request = factory.get('/', { -            'decimal': '%s' % (search_decimal,), -            'date': '%s' % (search_date,) -        }) -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if f['date'] > search_date and -                         f['decimal'] < search_decimal] -        self.assertEqual(response.data, expected_data) - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_incorrectly_configured_filter(self): -        """ -        An error should be displayed when the filter class is misconfigured. -        """ -        view = IncorrectlyConfiguredRootView.as_view() - -        request = factory.get('/') -        self.assertRaises(AssertionError, view, request) - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_unknown_filter(self): -        """ -        GET requests with filters that aren't configured should return 200. -        """ -        view = FilterFieldsRootView.as_view() - -        search_integer = 10 -        request = factory.get('/', {'integer': '%s' % search_integer}) -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) - - -class IntegrationTestDetailFiltering(CommonFilteringTestCase): -    """ -    Integration tests for filtered detail views. -    """ -    urls = 'rest_framework.tests.test_filters' - -    def _get_url(self, item): -        return reverse('detail-view', kwargs=dict(pk=item.pk)) - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_get_filtered_detail_view(self): -        """ -        GET requests to filtered RetrieveAPIView that have a filter_class set -        should return filtered results. -        """ -        item = self.objects.all()[0] -        data = self._serialize_object(item) - -        # Basic test with no filter. -        response = self.client.get(self._get_url(item)) -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, data) - -        # Tests that the decimal filter set that should fail. -        search_decimal = Decimal('4.25') -        high_item = self.objects.filter(decimal__gt=search_decimal)[0] -        response = self.client.get( -            '{url}'.format(url=self._get_url(high_item)), -            {'decimal': '{param}'.format(param=search_decimal)}) -        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - -        # Tests that the decimal filter set that should succeed. -        search_decimal = Decimal('4.25') -        low_item = self.objects.filter(decimal__lt=search_decimal)[0] -        low_item_data = self._serialize_object(low_item) -        response = self.client.get( -            '{url}'.format(url=self._get_url(low_item)), -            {'decimal': '{param}'.format(param=search_decimal)}) -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, low_item_data) - -        # Tests that multiple filters works. -        search_decimal = Decimal('5.25') -        search_date = datetime.date(2012, 10, 2) -        valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] -        valid_item_data = self._serialize_object(valid_item) -        response = self.client.get( -            '{url}'.format(url=self._get_url(valid_item)), { -                'decimal': '{decimal}'.format(decimal=search_decimal), -                'date': '{date}'.format(date=search_date) -            }) -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, valid_item_data) - - -class SearchFilterModel(models.Model): -    title = models.CharField(max_length=20) -    text = models.CharField(max_length=100) - - -class SearchFilterTests(TestCase): -    def setUp(self): -        # Sequence of title/text is: -        # -        # z   abc -        # zz  bcd -        # zzz cde -        # ... -        for idx in range(10): -            title = 'z' * (idx + 1) -            text = ( -                chr(idx + ord('a')) + -                chr(idx + ord('b')) + -                chr(idx + ord('c')) -            ) -            SearchFilterModel(title=title, text=text).save() - -    def test_search(self): -        class SearchListView(generics.ListAPIView): -            model = SearchFilterModel -            filter_backends = (filters.SearchFilter,) -            search_fields = ('title', 'text') - -        view = SearchListView.as_view() -        request = factory.get('/', {'search': 'b'}) -        response = view(request) -        self.assertEqual( -            response.data, -            [ -                {'id': 1, 'title': 'z', 'text': 'abc'}, -                {'id': 2, 'title': 'zz', 'text': 'bcd'} -            ] -        ) - -    def test_exact_search(self): -        class SearchListView(generics.ListAPIView): -            model = SearchFilterModel -            filter_backends = (filters.SearchFilter,) -            search_fields = ('=title', 'text') - -        view = SearchListView.as_view() -        request = factory.get('/', {'search': 'zzz'}) -        response = view(request) -        self.assertEqual( -            response.data, -            [ -                {'id': 3, 'title': 'zzz', 'text': 'cde'} -            ] -        ) - -    def test_startswith_search(self): -        class SearchListView(generics.ListAPIView): -            model = SearchFilterModel -            filter_backends = (filters.SearchFilter,) -            search_fields = ('title', '^text') - -        view = SearchListView.as_view() -        request = factory.get('/', {'search': 'b'}) -        response = view(request) -        self.assertEqual( -            response.data, -            [ -                {'id': 2, 'title': 'zz', 'text': 'bcd'} -            ] -        ) - -    def test_search_with_nonstandard_search_param(self): -        with temporary_setting('SEARCH_PARAM', 'query', module=filters): -            class SearchListView(generics.ListAPIView): -                model = SearchFilterModel -                filter_backends = (filters.SearchFilter,) -                search_fields = ('title', 'text') - -            view = SearchListView.as_view() -            request = factory.get('/', {'query': 'b'}) -            response = view(request) -            self.assertEqual( -                response.data, -                [ -                    {'id': 1, 'title': 'z', 'text': 'abc'}, -                    {'id': 2, 'title': 'zz', 'text': 'bcd'} -                ] -            ) - - -class OrdringFilterModel(models.Model): -    title = models.CharField(max_length=20) -    text = models.CharField(max_length=100) - - -class OrderingFilterRelatedModel(models.Model): -    related_object = models.ForeignKey(OrdringFilterModel, -                                       related_name="relateds") - - -class OrderingFilterTests(TestCase): -    def setUp(self): -        # Sequence of title/text is: -        # -        # zyx abc -        # yxw bcd -        # xwv cde -        for idx in range(3): -            title = ( -                chr(ord('z') - idx) + -                chr(ord('y') - idx) + -                chr(ord('x') - idx) -            ) -            text = ( -                chr(idx + ord('a')) + -                chr(idx + ord('b')) + -                chr(idx + ord('c')) -            ) -            OrdringFilterModel(title=title, text=text).save() - -    def test_ordering(self): -        class OrderingListView(generics.ListAPIView): -            model = OrdringFilterModel -            filter_backends = (filters.OrderingFilter,) -            ordering = ('title',) -            ordering_fields = ('text',) - -        view = OrderingListView.as_view() -        request = factory.get('/', {'ordering': 'text'}) -        response = view(request) -        self.assertEqual( -            response.data, -            [ -                {'id': 1, 'title': 'zyx', 'text': 'abc'}, -                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, -                {'id': 3, 'title': 'xwv', 'text': 'cde'}, -            ] -        ) - -    def test_reverse_ordering(self): -        class OrderingListView(generics.ListAPIView): -            model = OrdringFilterModel -            filter_backends = (filters.OrderingFilter,) -            ordering = ('title',) -            ordering_fields = ('text',) - -        view = OrderingListView.as_view() -        request = factory.get('/', {'ordering': '-text'}) -        response = view(request) -        self.assertEqual( -            response.data, -            [ -                {'id': 3, 'title': 'xwv', 'text': 'cde'}, -                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, -                {'id': 1, 'title': 'zyx', 'text': 'abc'}, -            ] -        ) - -    def test_incorrectfield_ordering(self): -        class OrderingListView(generics.ListAPIView): -            model = OrdringFilterModel -            filter_backends = (filters.OrderingFilter,) -            ordering = ('title',) -            ordering_fields = ('text',) - -        view = OrderingListView.as_view() -        request = factory.get('/', {'ordering': 'foobar'}) -        response = view(request) -        self.assertEqual( -            response.data, -            [ -                {'id': 3, 'title': 'xwv', 'text': 'cde'}, -                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, -                {'id': 1, 'title': 'zyx', 'text': 'abc'}, -            ] -        ) - -    def test_default_ordering(self): -        class OrderingListView(generics.ListAPIView): -            model = OrdringFilterModel -            filter_backends = (filters.OrderingFilter,) -            ordering = ('title',) -            oredering_fields = ('text',) - -        view = OrderingListView.as_view() -        request = factory.get('') -        response = view(request) -        self.assertEqual( -            response.data, -            [ -                {'id': 3, 'title': 'xwv', 'text': 'cde'}, -                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, -                {'id': 1, 'title': 'zyx', 'text': 'abc'}, -            ] -        ) - -    def test_default_ordering_using_string(self): -        class OrderingListView(generics.ListAPIView): -            model = OrdringFilterModel -            filter_backends = (filters.OrderingFilter,) -            ordering = 'title' -            ordering_fields = ('text',) - -        view = OrderingListView.as_view() -        request = factory.get('') -        response = view(request) -        self.assertEqual( -            response.data, -            [ -                {'id': 3, 'title': 'xwv', 'text': 'cde'}, -                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, -                {'id': 1, 'title': 'zyx', 'text': 'abc'}, -            ] -        ) - -    def test_ordering_by_aggregate_field(self): -        # create some related models to aggregate order by -        num_objs = [2, 5, 3] -        for obj, num_relateds in zip(OrdringFilterModel.objects.all(), -                                     num_objs): -            for _ in range(num_relateds): -                new_related = OrderingFilterRelatedModel( -                    related_object=obj -                ) -                new_related.save() - -        class OrderingListView(generics.ListAPIView): -            model = OrdringFilterModel -            filter_backends = (filters.OrderingFilter,) -            ordering = 'title' -            ordering_fields = '__all__' -            queryset = OrdringFilterModel.objects.all().annotate( -                models.Count("relateds")) - -        view = OrderingListView.as_view() -        request = factory.get('/', {'ordering': 'relateds__count'}) -        response = view(request) -        self.assertEqual( -            response.data, -            [ -                {'id': 1, 'title': 'zyx', 'text': 'abc'}, -                {'id': 3, 'title': 'xwv', 'text': 'cde'}, -                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, -            ] -        ) - -    def test_ordering_with_nonstandard_ordering_param(self): -        with temporary_setting('ORDERING_PARAM', 'order', filters): -            class OrderingListView(generics.ListAPIView): -                model = OrdringFilterModel -                filter_backends = (filters.OrderingFilter,) -                ordering = ('title',) -                ordering_fields = ('text',) - -            view = OrderingListView.as_view() -            request = factory.get('/', {'order': 'text'}) -            response = view(request) -            self.assertEqual( -                response.data, -                [ -                    {'id': 1, 'title': 'zyx', 'text': 'abc'}, -                    {'id': 2, 'title': 'yxw', 'text': 'bcd'}, -                    {'id': 3, 'title': 'xwv', 'text': 'cde'}, -                ] -            ) - - -class SensitiveOrderingFilterModel(models.Model): -    username = models.CharField(max_length=20) -    password = models.CharField(max_length=100) - - -# Three different styles of serializer. -# All should allow ordering by username, but not by password. -class SensitiveDataSerializer1(serializers.ModelSerializer): -    username = serializers.CharField() - -    class Meta: -        model = SensitiveOrderingFilterModel -        fields = ('id', 'username') - - -class SensitiveDataSerializer2(serializers.ModelSerializer): -    username = serializers.CharField() -    password = serializers.CharField(write_only=True) - -    class Meta: -        model = SensitiveOrderingFilterModel -        fields = ('id', 'username', 'password') - - -class SensitiveDataSerializer3(serializers.ModelSerializer): -    user = serializers.CharField(source='username') - -    class Meta: -        model = SensitiveOrderingFilterModel -        fields = ('id', 'user') - - -class SensitiveOrderingFilterTests(TestCase): -    def setUp(self): -        for idx in range(3): -            username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] -            password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] -            SensitiveOrderingFilterModel(username=username, password=password).save() - -    def test_order_by_serializer_fields(self): -        for serializer_cls in [ -            SensitiveDataSerializer1, -            SensitiveDataSerializer2, -            SensitiveDataSerializer3 -        ]: -            class OrderingListView(generics.ListAPIView): -                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') -                filter_backends = (filters.OrderingFilter,) -                serializer_class = serializer_cls - -            view = OrderingListView.as_view() -            request = factory.get('/', {'ordering': '-username'}) -            response = view(request) - -            if serializer_cls == SensitiveDataSerializer3: -                username_field = 'user' -            else: -                username_field = 'username' - -            # Note: Inverse username ordering correctly applied. -            self.assertEqual( -                response.data, -                [ -                    {'id': 3, username_field: 'userC'}, -                    {'id': 2, username_field: 'userB'}, -                    {'id': 1, username_field: 'userA'}, -                ] -            ) - -    def test_cannot_order_by_non_serializer_fields(self): -        for serializer_cls in [ -            SensitiveDataSerializer1, -            SensitiveDataSerializer2, -            SensitiveDataSerializer3 -        ]: -            class OrderingListView(generics.ListAPIView): -                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') -                filter_backends = (filters.OrderingFilter,) -                serializer_class = serializer_cls - -            view = OrderingListView.as_view() -            request = factory.get('/', {'ordering': 'password'}) -            response = view(request) - -            if serializer_cls == SensitiveDataSerializer3: -                username_field = 'user' -            else: -                username_field = 'username' - -            # Note: The passwords are not in order.  Default ordering is used. -            self.assertEqual( -                response.data, -                [ -                    {'id': 1, username_field: 'userA'}, # PassB -                    {'id': 2, username_field: 'userB'}, # PassC -                    {'id': 3, username_field: 'userC'}, # PassA -                ] -            ) diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py deleted file mode 100644 index fa09c9e6..00000000 --- a/rest_framework/tests/test_genericrelations.py +++ /dev/null @@ -1,133 +0,0 @@ -from __future__ import unicode_literals -from django.contrib.contenttypes.models import ContentType -from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey -from django.db import models -from django.test import TestCase -from rest_framework import serializers -from rest_framework.compat import python_2_unicode_compatible - - -@python_2_unicode_compatible -class Tag(models.Model): -    """ -    Tags have a descriptive slug, and are attached to an arbitrary object. -    """ -    tag = models.SlugField() -    content_type = models.ForeignKey(ContentType) -    object_id = models.PositiveIntegerField() -    tagged_item = GenericForeignKey('content_type', 'object_id') - -    def __str__(self): -        return self.tag - - -@python_2_unicode_compatible -class Bookmark(models.Model): -    """ -    A URL bookmark that may have multiple tags attached. -    """ -    url = models.URLField() -    tags = GenericRelation(Tag) - -    def __str__(self): -        return 'Bookmark: %s' % self.url - - -@python_2_unicode_compatible -class Note(models.Model): -    """ -    A textual note that may have multiple tags attached. -    """ -    text = models.TextField() -    tags = GenericRelation(Tag) - -    def __str__(self): -        return 'Note: %s' % self.text - - -class TestGenericRelations(TestCase): -    def setUp(self): -        self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') -        Tag.objects.create(tagged_item=self.bookmark, tag='django') -        Tag.objects.create(tagged_item=self.bookmark, tag='python') -        self.note = Note.objects.create(text='Remember the milk') -        Tag.objects.create(tagged_item=self.note, tag='reminder') - -    def test_generic_relation(self): -        """ -        Test a relationship that spans a GenericRelation field. -        IE. A reverse generic relationship. -        """ - -        class BookmarkSerializer(serializers.ModelSerializer): -            tags = serializers.RelatedField(many=True) - -            class Meta: -                model = Bookmark -                exclude = ('id',) - -        serializer = BookmarkSerializer(self.bookmark) -        expected = { -            'tags': ['django', 'python'], -            'url': 'https://www.djangoproject.com/' -        } -        self.assertEqual(serializer.data, expected) - -    def test_generic_nested_relation(self): -        """ -        Test saving a GenericRelation field via a nested serializer. -        """ - -        class TagSerializer(serializers.ModelSerializer): -            class Meta: -                model = Tag -                exclude = ('content_type', 'object_id') - -        class BookmarkSerializer(serializers.ModelSerializer): -            tags = TagSerializer() - -            class Meta: -                model = Bookmark -                exclude = ('id',) - -        data = { -            'url': 'https://docs.djangoproject.com/', -            'tags': [ -                {'tag': 'contenttypes'}, -                {'tag': 'genericrelations'}, -            ] -        } -        serializer = BookmarkSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        serializer.save() -        self.assertEqual(serializer.object.tags.count(), 2) - -    def test_generic_fk(self): -        """ -        Test a relationship that spans a GenericForeignKey field. -        IE. A forward generic relationship. -        """ - -        class TagSerializer(serializers.ModelSerializer): -            tagged_item = serializers.RelatedField() - -            class Meta: -                model = Tag -                exclude = ('id', 'content_type', 'object_id') - -        serializer = TagSerializer(Tag.objects.all(), many=True) -        expected = [ -        { -            'tag': 'django', -            'tagged_item': 'Bookmark: https://www.djangoproject.com/' -        }, -        { -            'tag': 'python', -            'tagged_item': 'Bookmark: https://www.djangoproject.com/' -        }, -        { -            'tag': 'reminder', -            'tagged_item': 'Note: Remember the milk' -        } -        ] -        self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py deleted file mode 100644 index 996bd5b0..00000000 --- a/rest_framework/tests/test_generics.py +++ /dev/null @@ -1,609 +0,0 @@ -from __future__ import unicode_literals -from django.db import models -from django.shortcuts import get_object_or_404 -from django.test import TestCase -from rest_framework import generics, renderers, serializers, status -from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel -from rest_framework.compat import six - -factory = APIRequestFactory() - - -class RootView(generics.ListCreateAPIView): -    """ -    Example description for OPTIONS. -    """ -    model = BasicModel - - -class InstanceView(generics.RetrieveUpdateDestroyAPIView): -    """ -    Example description for OPTIONS. -    """ -    model = BasicModel - -    def get_queryset(self): -        queryset = super(InstanceView, self).get_queryset() -        return queryset.exclude(text='filtered out') - - -class SlugSerializer(serializers.ModelSerializer): -    slug = serializers.Field()  # read only - -    class Meta: -        model = SlugBasedModel -        exclude = ('id',) - - -class SlugBasedInstanceView(InstanceView): -    """ -    A model with a slug-field. -    """ -    model = SlugBasedModel -    serializer_class = SlugSerializer -    lookup_field = 'slug' - - -class TestRootView(TestCase): -    def setUp(self): -        """ -        Create 3 BasicModel instances. -        """ -        items = ['foo', 'bar', 'baz'] -        for item in items: -            BasicModel(text=item).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = RootView.as_view() - -    def test_get_root_view(self): -        """ -        GET requests to ListCreateAPIView should return list of objects. -        """ -        request = factory.get('/') -        with self.assertNumQueries(1): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data) - -    def test_post_root_view(self): -        """ -        POST requests to ListCreateAPIView should create a new object. -        """ -        data = {'text': 'foobar'} -        request = factory.post('/', data, format='json') -        with self.assertNumQueries(1): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_201_CREATED) -        self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) -        created = self.objects.get(id=4) -        self.assertEqual(created.text, 'foobar') - -    def test_put_root_view(self): -        """ -        PUT requests to ListCreateAPIView should not be allowed -        """ -        data = {'text': 'foobar'} -        request = factory.put('/', data, format='json') -        with self.assertNumQueries(0): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) -        self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) - -    def test_delete_root_view(self): -        """ -        DELETE requests to ListCreateAPIView should not be allowed -        """ -        request = factory.delete('/') -        with self.assertNumQueries(0): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) -        self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) - -    def test_options_root_view(self): -        """ -        OPTIONS requests to ListCreateAPIView should return metadata -        """ -        request = factory.options('/') -        with self.assertNumQueries(0): -            response = self.view(request).render() -        expected = { -            'parses': [ -                'application/json', -                'application/x-www-form-urlencoded', -                'multipart/form-data' -            ], -            'renders': [ -                'application/json', -                'text/html' -            ], -            'name': 'Root', -            'description': 'Example description for OPTIONS.', -            'actions': { -                'POST': { -                    'text': { -                        'max_length': 100, -                        'read_only': False, -                        'required': True, -                        'type': 'string', -                        "label": "Text comes here", -                        "help_text": "Text description." -                    }, -                    'id': { -                        'read_only': True, -                        'required': False, -                        'type': 'integer', -                        'label': 'ID', -                    }, -                } -            } -        } -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, expected) - -    def test_post_cannot_set_id(self): -        """ -        POST requests to create a new object should not be able to set the id. -        """ -        data = {'id': 999, 'text': 'foobar'} -        request = factory.post('/', data, format='json') -        with self.assertNumQueries(1): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_201_CREATED) -        self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) -        created = self.objects.get(id=4) -        self.assertEqual(created.text, 'foobar') - - -class TestInstanceView(TestCase): -    def setUp(self): -        """ -        Create 3 BasicModel intances. -        """ -        items = ['foo', 'bar', 'baz', 'filtered out'] -        for item in items: -            BasicModel(text=item).save() -        self.objects = BasicModel.objects.exclude(text='filtered out') -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = InstanceView.as_view() -        self.slug_based_view = SlugBasedInstanceView.as_view() - -    def test_get_instance_view(self): -        """ -        GET requests to RetrieveUpdateDestroyAPIView should return a single object. -        """ -        request = factory.get('/1') -        with self.assertNumQueries(1): -            response = self.view(request, pk=1).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data[0]) - -    def test_post_instance_view(self): -        """ -        POST requests to RetrieveUpdateDestroyAPIView should not be allowed -        """ -        data = {'text': 'foobar'} -        request = factory.post('/', data, format='json') -        with self.assertNumQueries(0): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) -        self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) - -    def test_put_instance_view(self): -        """ -        PUT requests to RetrieveUpdateDestroyAPIView should update an object. -        """ -        data = {'text': 'foobar'} -        request = factory.put('/1', data, format='json') -        with self.assertNumQueries(2): -            response = self.view(request, pk='1').render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) -        updated = self.objects.get(id=1) -        self.assertEqual(updated.text, 'foobar') - -    def test_patch_instance_view(self): -        """ -        PATCH requests to RetrieveUpdateDestroyAPIView should update an object. -        """ -        data = {'text': 'foobar'} -        request = factory.patch('/1', data, format='json') - -        with self.assertNumQueries(2): -            response = self.view(request, pk=1).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) -        updated = self.objects.get(id=1) -        self.assertEqual(updated.text, 'foobar') - -    def test_delete_instance_view(self): -        """ -        DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. -        """ -        request = factory.delete('/1') -        with self.assertNumQueries(2): -            response = self.view(request, pk=1).render() -        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) -        self.assertEqual(response.content, six.b('')) -        ids = [obj.id for obj in self.objects.all()] -        self.assertEqual(ids, [2, 3]) - -    def test_options_instance_view(self): -        """ -        OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata -        """ -        request = factory.options('/1') -        with self.assertNumQueries(1): -            response = self.view(request, pk=1).render() -        expected = { -            'parses': [ -                'application/json', -                'application/x-www-form-urlencoded', -                'multipart/form-data' -            ], -            'renders': [ -                'application/json', -                'text/html' -            ], -            'name': 'Instance', -            'description': 'Example description for OPTIONS.', -            'actions': { -                'PUT': { -                    'text': { -                        'max_length': 100, -                        'read_only': False, -                        'required': True, -                        'type': 'string', -                        'label': 'Text comes here', -                        'help_text': 'Text description.' -                    }, -                    'id': { -                        'read_only': True, -                        'required': False, -                        'type': 'integer', -                        'label': 'ID', -                    }, -                } -            } -        } -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, expected) - -    def test_options_before_instance_create(self): -        """ -        OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata -        before the instance has been created -        """ -        request = factory.options('/999') -        with self.assertNumQueries(1): -            response = self.view(request, pk=999).render() -        expected = { -            'parses': [ -                'application/json', -                'application/x-www-form-urlencoded', -                'multipart/form-data' -            ], -            'renders': [ -                'application/json', -                'text/html' -            ], -            'name': 'Instance', -            'description': 'Example description for OPTIONS.', -            'actions': { -                'PUT': { -                    'text': { -                        'max_length': 100, -                        'read_only': False, -                        'required': True, -                        'type': 'string', -                        'label': 'Text comes here', -                        'help_text': 'Text description.' -                    }, -                    'id': { -                        'read_only': True, -                        'required': False, -                        'type': 'integer', -                        'label': 'ID', -                    }, -                } -            } -        } -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, expected) - -    def test_get_instance_view_incorrect_arg(self): -        """ -        GET requests with an incorrect pk type, should raise 404, not 500. -        Regression test for #890. -        """ -        request = factory.get('/a') -        with self.assertNumQueries(0): -            response = self.view(request, pk='a').render() -        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - -    def test_put_cannot_set_id(self): -        """ -        PUT requests to create a new object should not be able to set the id. -        """ -        data = {'id': 999, 'text': 'foobar'} -        request = factory.put('/1', data, format='json') -        with self.assertNumQueries(2): -            response = self.view(request, pk=1).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) -        updated = self.objects.get(id=1) -        self.assertEqual(updated.text, 'foobar') - -    def test_put_to_deleted_instance(self): -        """ -        PUT requests to RetrieveUpdateDestroyAPIView should create an object -        if it does not currently exist. -        """ -        self.objects.get(id=1).delete() -        data = {'text': 'foobar'} -        request = factory.put('/1', data, format='json') -        with self.assertNumQueries(3): -            response = self.view(request, pk=1).render() -        self.assertEqual(response.status_code, status.HTTP_201_CREATED) -        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) -        updated = self.objects.get(id=1) -        self.assertEqual(updated.text, 'foobar') - -    def test_put_to_filtered_out_instance(self): -        """ -        PUT requests to an URL of instance which is filtered out should not be -        able to create new objects. -        """ -        data = {'text': 'foo'} -        filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk -        request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') -        response = self.view(request, pk=filtered_out_pk).render() -        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - -    def test_put_as_create_on_id_based_url(self): -        """ -        PUT requests to RetrieveUpdateDestroyAPIView should create an object -        at the requested url if it doesn't exist. -        """ -        data = {'text': 'foobar'} -        # pk fields can not be created on demand, only the database can set the pk for a new object -        request = factory.put('/5', data, format='json') -        with self.assertNumQueries(3): -            response = self.view(request, pk=5).render() -        self.assertEqual(response.status_code, status.HTTP_201_CREATED) -        new_obj = self.objects.get(pk=5) -        self.assertEqual(new_obj.text, 'foobar') - -    def test_put_as_create_on_slug_based_url(self): -        """ -        PUT requests to RetrieveUpdateDestroyAPIView should create an object -        at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. -        """ -        data = {'text': 'foobar'} -        request = factory.put('/test_slug', data, format='json') -        with self.assertNumQueries(2): -            response = self.slug_based_view(request, slug='test_slug').render() -        self.assertEqual(response.status_code, status.HTTP_201_CREATED) -        self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'}) -        new_obj = SlugBasedModel.objects.get(slug='test_slug') -        self.assertEqual(new_obj.text, 'foobar') - -    def test_patch_cannot_create_an_object(self): -        """ -        PATCH requests should not be able to create objects. -        """ -        data = {'text': 'foobar'} -        request = factory.patch('/999', data, format='json') -        with self.assertNumQueries(1): -            response = self.view(request, pk=999).render() -        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) -        self.assertFalse(self.objects.filter(id=999).exists()) - - -class TestOverriddenGetObject(TestCase): -    """ -    Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the -    queryset/model mechanism but instead overrides get_object() -    """ -    def setUp(self): -        """ -        Create 3 BasicModel intances. -        """ -        items = ['foo', 'bar', 'baz'] -        for item in items: -            BasicModel(text=item).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] - -        class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): -            """ -            Example detail view for override of get_object(). -            """ -            model = BasicModel - -            def get_object(self): -                pk = int(self.kwargs['pk']) -                return get_object_or_404(BasicModel.objects.all(), id=pk) - -        self.view = OverriddenGetObjectView.as_view() - -    def test_overridden_get_object_view(self): -        """ -        GET requests to RetrieveUpdateDestroyAPIView should return a single object. -        """ -        request = factory.get('/1') -        with self.assertNumQueries(1): -            response = self.view(request, pk=1).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data[0]) - - -# Regression test for #285 - -class CommentSerializer(serializers.ModelSerializer): -    class Meta: -        model = Comment -        exclude = ('created',) - - -class CommentView(generics.ListCreateAPIView): -    serializer_class = CommentSerializer -    model = Comment - - -class TestCreateModelWithAutoNowAddField(TestCase): -    def setUp(self): -        self.objects = Comment.objects -        self.view = CommentView.as_view() - -    def test_create_model_with_auto_now_add_field(self): -        """ -        Regression test for #285 - -        https://github.com/tomchristie/django-rest-framework/issues/285 -        """ -        data = {'email': 'foobar@example.com', 'content': 'foobar'} -        request = factory.post('/', data, format='json') -        response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_201_CREATED) -        created = self.objects.get(id=1) -        self.assertEqual(created.content, 'foobar') - - -# Test for particularly ugly regression with m2m in browsable API -class ClassB(models.Model): -    name = models.CharField(max_length=255) - - -class ClassA(models.Model): -    name = models.CharField(max_length=255) -    childs = models.ManyToManyField(ClassB, blank=True, null=True) - - -class ClassASerializer(serializers.ModelSerializer): -    childs = serializers.PrimaryKeyRelatedField(many=True, source='childs') - -    class Meta: -        model = ClassA - - -class ExampleView(generics.ListCreateAPIView): -    serializer_class = ClassASerializer -    model = ClassA - - -class TestM2MBrowseableAPI(TestCase): -    def test_m2m_in_browseable_api(self): -        """ -        Test for particularly ugly regression with m2m in browsable API -        """ -        request = factory.get('/', HTTP_ACCEPT='text/html') -        view = ExampleView().as_view() -        response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) - - -class InclusiveFilterBackend(object): -    def filter_queryset(self, request, queryset, view): -        return queryset.filter(text='foo') - - -class ExclusiveFilterBackend(object): -    def filter_queryset(self, request, queryset, view): -        return queryset.filter(text='other') - - -class TwoFieldModel(models.Model): -    field_a = models.CharField(max_length=100) -    field_b = models.CharField(max_length=100) - - -class DynamicSerializerView(generics.ListCreateAPIView): -    model = TwoFieldModel -    renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) - -    def get_serializer_class(self): -        if self.request.method == 'POST': -            class DynamicSerializer(serializers.ModelSerializer): -                class Meta: -                    model = TwoFieldModel -                    fields = ('field_b',) -            return DynamicSerializer -        return super(DynamicSerializerView, self).get_serializer_class() - - -class TestFilterBackendAppliedToViews(TestCase): - -    def setUp(self): -        """ -        Create 3 BasicModel instances to filter on. -        """ -        items = ['foo', 'bar', 'baz'] -        for item in items: -            BasicModel(text=item).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] - -    def test_get_root_view_filters_by_name_with_filter_backend(self): -        """ -        GET requests to ListCreateAPIView should return filtered list. -        """ -        root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) -        request = factory.get('/') -        response = root_view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(len(response.data), 1) -        self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) - -    def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): -        """ -        GET requests to ListCreateAPIView should return empty list when all models are filtered out. -        """ -        root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) -        request = factory.get('/') -        response = root_view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, []) - -    def test_get_instance_view_filters_out_name_with_filter_backend(self): -        """ -        GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. -        """ -        instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) -        request = factory.get('/1') -        response = instance_view(request, pk=1).render() -        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) -        self.assertEqual(response.data, {'detail': 'Not found'}) - -    def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): -        """ -        GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded -        """ -        instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) -        request = factory.get('/1') -        response = instance_view(request, pk=1).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) - -    def test_dynamic_serializer_form_in_browsable_api(self): -        """ -        GET requests to ListCreateAPIView should return filtered list. -        """ -        view = DynamicSerializerView.as_view() -        request = factory.get('/') -        response = view(request).render() -        self.assertContains(response, 'field_b') -        self.assertNotContains(response, 'field_a') diff --git a/rest_framework/tests/test_htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py deleted file mode 100644 index 514d9e2b..00000000 --- a/rest_framework/tests/test_htmlrenderer.py +++ /dev/null @@ -1,120 +0,0 @@ -from __future__ import unicode_literals -from django.core.exceptions import PermissionDenied -from django.http import Http404 -from django.test import TestCase -from django.template import TemplateDoesNotExist, Template -import django.template.loader -from rest_framework import status -from rest_framework.compat import patterns, url -from rest_framework.decorators import api_view, renderer_classes -from rest_framework.renderers import TemplateHTMLRenderer -from rest_framework.response import Response -from rest_framework.compat import six - - -@api_view(('GET',)) -@renderer_classes((TemplateHTMLRenderer,)) -def example(request): -    """ -    A view that can returns an HTML representation. -    """ -    data = {'object': 'foobar'} -    return Response(data, template_name='example.html') - - -@api_view(('GET',)) -@renderer_classes((TemplateHTMLRenderer,)) -def permission_denied(request): -    raise PermissionDenied() - - -@api_view(('GET',)) -@renderer_classes((TemplateHTMLRenderer,)) -def not_found(request): -    raise Http404() - - -urlpatterns = patterns('', -    url(r'^$', example), -    url(r'^permission_denied$', permission_denied), -    url(r'^not_found$', not_found), -) - - -class TemplateHTMLRendererTests(TestCase): -    urls = 'rest_framework.tests.test_htmlrenderer' - -    def setUp(self): -        """ -        Monkeypatch get_template -        """ -        self.get_template = django.template.loader.get_template - -        def get_template(template_name, dirs=None): -            if template_name == 'example.html': -                return Template("example: {{ object }}") -            raise TemplateDoesNotExist(template_name) - -        django.template.loader.get_template = get_template - -    def tearDown(self): -        """ -        Revert monkeypatching -        """ -        django.template.loader.get_template = self.get_template - -    def test_simple_html_view(self): -        response = self.client.get('/') -        self.assertContains(response, "example: foobar") -        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') - -    def test_not_found_html_view(self): -        response = self.client.get('/not_found') -        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) -        self.assertEqual(response.content, six.b("404 Not Found")) -        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') - -    def test_permission_denied_html_view(self): -        response = self.client.get('/permission_denied') -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) -        self.assertEqual(response.content, six.b("403 Forbidden")) -        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') - - -class TemplateHTMLRendererExceptionTests(TestCase): -    urls = 'rest_framework.tests.test_htmlrenderer' - -    def setUp(self): -        """ -        Monkeypatch get_template -        """ -        self.get_template = django.template.loader.get_template - -        def get_template(template_name): -            if template_name == '404.html': -                return Template("404: {{ detail }}") -            if template_name == '403.html': -                return Template("403: {{ detail }}") -            raise TemplateDoesNotExist(template_name) - -        django.template.loader.get_template = get_template - -    def tearDown(self): -        """ -        Revert monkeypatching -        """ -        django.template.loader.get_template = self.get_template - -    def test_not_found_html_view_with_template(self): -        response = self.client.get('/not_found') -        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) -        self.assertTrue(response.content in ( -            six.b("404: Not found"), six.b("404 Not Found"))) -        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') - -    def test_permission_denied_html_view_with_template(self): -        response = self.client.get('/permission_denied') -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) -        self.assertTrue(response.content in ( -            six.b("403: Permission denied"), six.b("403 Forbidden"))) -        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py deleted file mode 100644 index 83d46043..00000000 --- a/rest_framework/tests/test_hyperlinkedserializers.py +++ /dev/null @@ -1,379 +0,0 @@ -from __future__ import unicode_literals -import json -from django.test import TestCase -from rest_framework import generics, status, serializers -from rest_framework.compat import patterns, url -from rest_framework.settings import api_settings -from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import ( -    Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, -    Album, Photo, OptionalRelationModel -) - -factory = APIRequestFactory() - - -class BlogPostCommentSerializer(serializers.ModelSerializer): -    url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') -    text = serializers.CharField() -    blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') - -    class Meta: -        model = BlogPostComment -        fields = ('text', 'blog_post_url', 'url') - - -class PhotoSerializer(serializers.Serializer): -    description = serializers.CharField() -    album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title') - -    def restore_object(self, attrs, instance=None): -        return Photo(**attrs) - - -class AlbumSerializer(serializers.ModelSerializer): -    url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') - -    class Meta: -        model = Album -        fields = ('title', 'url') - - -class BasicList(generics.ListCreateAPIView): -    model = BasicModel -    model_serializer_class = serializers.HyperlinkedModelSerializer - - -class BasicDetail(generics.RetrieveUpdateDestroyAPIView): -    model = BasicModel -    model_serializer_class = serializers.HyperlinkedModelSerializer - - -class AnchorDetail(generics.RetrieveAPIView): -    model = Anchor -    model_serializer_class = serializers.HyperlinkedModelSerializer - - -class ManyToManyList(generics.ListAPIView): -    model = ManyToManyModel -    model_serializer_class = serializers.HyperlinkedModelSerializer - - -class ManyToManyDetail(generics.RetrieveAPIView): -    model = ManyToManyModel -    model_serializer_class = serializers.HyperlinkedModelSerializer - - -class BlogPostCommentListCreate(generics.ListCreateAPIView): -    model = BlogPostComment -    serializer_class = BlogPostCommentSerializer - - -class BlogPostCommentDetail(generics.RetrieveAPIView): -    model = BlogPostComment -    serializer_class = BlogPostCommentSerializer - - -class BlogPostDetail(generics.RetrieveAPIView): -    model = BlogPost - - -class PhotoListCreate(generics.ListCreateAPIView): -    model = Photo -    model_serializer_class = PhotoSerializer - - -class AlbumDetail(generics.RetrieveAPIView): -    model = Album -    serializer_class = AlbumSerializer -    lookup_field = 'title' - - -class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): -    model = OptionalRelationModel -    model_serializer_class = serializers.HyperlinkedModelSerializer - - -urlpatterns = patterns('', -    url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), -    url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), -    url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), -    url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), -    url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), -    url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), -    url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), -    url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), -    url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), -    url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), -    url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'), -) - - -class TestBasicHyperlinkedView(TestCase): -    urls = 'rest_framework.tests.test_hyperlinkedserializers' - -    def setUp(self): -        """ -        Create 3 BasicModel instances. -        """ -        items = ['foo', 'bar', 'baz'] -        for item in items: -            BasicModel(text=item).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.list_view = BasicList.as_view() -        self.detail_view = BasicDetail.as_view() - -    def test_get_list_view(self): -        """ -        GET requests to ListCreateAPIView should return list of objects. -        """ -        request = factory.get('/basic/') -        response = self.list_view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data) - -    def test_get_detail_view(self): -        """ -        GET requests to ListCreateAPIView should return list of objects. -        """ -        request = factory.get('/basic/1') -        response = self.detail_view(request, pk=1).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data[0]) - - -class TestManyToManyHyperlinkedView(TestCase): -    urls = 'rest_framework.tests.test_hyperlinkedserializers' - -    def setUp(self): -        """ -        Create 3 BasicModel instances. -        """ -        items = ['foo', 'bar', 'baz'] -        anchors = [] -        for item in items: -            anchor = Anchor(text=item) -            anchor.save() -            anchors.append(anchor) - -        manytomany = ManyToManyModel() -        manytomany.save() -        manytomany.rel.add(*anchors) - -        self.data = [{ -            'url': 'http://testserver/manytomany/1/', -            'rel': [ -                'http://testserver/anchor/1/', -                'http://testserver/anchor/2/', -                'http://testserver/anchor/3/', -            ] -        }] -        self.list_view = ManyToManyList.as_view() -        self.detail_view = ManyToManyDetail.as_view() - -    def test_get_list_view(self): -        """ -        GET requests to ListCreateAPIView should return list of objects. -        """ -        request = factory.get('/manytomany/') -        response = self.list_view(request) -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data) - -    def test_get_detail_view(self): -        """ -        GET requests to ListCreateAPIView should return list of objects. -        """ -        request = factory.get('/manytomany/1/') -        response = self.detail_view(request, pk=1) -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data[0]) - - -class TestHyperlinkedIdentityFieldLookup(TestCase): -    urls = 'rest_framework.tests.test_hyperlinkedserializers' - -    def setUp(self): -        """ -        Create 3 Album instances. -        """ -        titles = ['foo', 'bar', 'baz'] -        for title in titles: -            album = Album(title=title) -            album.save() -        self.detail_view = AlbumDetail.as_view() -        self.data = { -            'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, -            'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, -            'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} -        } - -    def test_lookup_field(self): -        """ -        GET requests to AlbumDetail view should return serialized Albums -        with a url field keyed by `title`. -        """ -        for album in Album.objects.all(): -            request = factory.get('/albums/{0}/'.format(album.title)) -            response = self.detail_view(request, title=album.title) -            self.assertEqual(response.status_code, status.HTTP_200_OK) -            self.assertEqual(response.data, self.data[album.title]) - - -class TestCreateWithForeignKeys(TestCase): -    urls = 'rest_framework.tests.test_hyperlinkedserializers' - -    def setUp(self): -        """ -        Create a blog post -        """ -        self.post = BlogPost.objects.create(title="Test post") -        self.create_view = BlogPostCommentListCreate.as_view() - -    def test_create_comment(self): - -        data = { -            'text': 'A test comment', -            'blog_post_url': 'http://testserver/posts/1/' -        } - -        request = factory.post('/comments/', data=data) -        response = self.create_view(request) -        self.assertEqual(response.status_code, status.HTTP_201_CREATED) -        self.assertEqual(response['Location'], 'http://testserver/comments/1/') -        self.assertEqual(self.post.blogpostcomment_set.count(), 1) -        self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') - - -class TestCreateWithForeignKeysAndCustomSlug(TestCase): -    urls = 'rest_framework.tests.test_hyperlinkedserializers' - -    def setUp(self): -        """ -        Create an Album -        """ -        self.post = Album.objects.create(title='test-album') -        self.list_create_view = PhotoListCreate.as_view() - -    def test_create_photo(self): - -        data = { -            'description': 'A test photo', -            'album_url': 'http://testserver/albums/test-album/' -        } - -        request = factory.post('/photos/', data=data) -        response = self.list_create_view(request) -        self.assertEqual(response.status_code, status.HTTP_201_CREATED) -        self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') -        self.assertEqual(self.post.photo_set.count(), 1) -        self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') - - -class TestOptionalRelationHyperlinkedView(TestCase): -    urls = 'rest_framework.tests.test_hyperlinkedserializers' - -    def setUp(self): -        """ -        Create 1 OptionalRelationModel instances. -        """ -        OptionalRelationModel().save() -        self.objects = OptionalRelationModel.objects -        self.detail_view = OptionalRelationDetail.as_view() -        self.data = {"url": "http://testserver/optionalrelation/1/", "other": None} - -    def test_get_detail_view(self): -        """ -        GET requests to RetrieveAPIView with optional relations should return None -        for non existing relations. -        """ -        request = factory.get('/optionalrelationmodel-detail/1') -        response = self.detail_view(request, pk=1) -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data, self.data) - -    def test_put_detail_view(self): -        """ -        PUT requests to RetrieveUpdateDestroyAPIView with optional relations -        should accept None for non existing relations. -        """ -        response = self.client.put('/optionalrelation/1/', -                                   data=json.dumps(self.data), -                                   content_type='application/json') -        self.assertEqual(response.status_code, status.HTTP_200_OK) - - -class TestOverriddenURLField(TestCase): -    def setUp(self): -        class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer): -            url = serializers.SerializerMethodField('get_url') - -            class Meta: -                model = BlogPost -                fields = ('title', 'url') - -            def get_url(self, obj): -                return 'foo bar' - -        self.Serializer = OverriddenURLSerializer -        self.obj = BlogPost.objects.create(title='New blog post') - -    def test_overridden_url_field(self): -        """ -        The 'url' field should respect overriding. -        Regression test for #936. -        """ -        serializer = self.Serializer(self.obj) -        self.assertEqual( -            serializer.data, -            {'title': 'New blog post', 'url': 'foo bar'} -        ) - - -class TestURLFieldNameBySettings(TestCase): -    urls = 'rest_framework.tests.test_hyperlinkedserializers' - -    def setUp(self): -        self.saved_url_field_name = api_settings.URL_FIELD_NAME -        api_settings.URL_FIELD_NAME = 'global_url_field' - -        class Serializer(serializers.HyperlinkedModelSerializer): - -            class Meta: -                model = BlogPost -                fields = ('title', api_settings.URL_FIELD_NAME) - -        self.Serializer = Serializer -        self.obj = BlogPost.objects.create(title="New blog post") - -    def tearDown(self): -        api_settings.URL_FIELD_NAME = self.saved_url_field_name - -    def test_overridden_url_field_name(self): -        request = factory.get('/posts/') -        serializer = self.Serializer(self.obj, context={'request': request}) -        self.assertIn(api_settings.URL_FIELD_NAME, serializer.data) - - -class TestURLFieldNameByOptions(TestCase): -    urls = 'rest_framework.tests.test_hyperlinkedserializers' - -    def setUp(self): -        class Serializer(serializers.HyperlinkedModelSerializer): - -            class Meta: -                model = BlogPost -                fields = ('title', 'serializer_url_field') -                url_field_name = 'serializer_url_field' - -        self.Serializer = Serializer -        self.obj = BlogPost.objects.create(title="New blog post") - -    def test_overridden_url_field_name(self): -        request = factory.get('/posts/') -        serializer = self.Serializer(self.obj, context={'request': request}) -        self.assertIn(self.Serializer.Meta.url_field_name, serializer.data) diff --git a/rest_framework/tests/test_multitable_inheritance.py b/rest_framework/tests/test_multitable_inheritance.py deleted file mode 100644 index 00c15327..00000000 --- a/rest_framework/tests/test_multitable_inheritance.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import unicode_literals -from django.db import models -from django.test import TestCase -from rest_framework import serializers -from rest_framework.tests.models import RESTFrameworkModel - - -# Models -class ParentModel(RESTFrameworkModel): -    name1 = models.CharField(max_length=100) - - -class ChildModel(ParentModel): -    name2 = models.CharField(max_length=100) - - -class AssociatedModel(RESTFrameworkModel): -    ref = models.OneToOneField(ParentModel, primary_key=True) -    name = models.CharField(max_length=100) - - -# Serializers -class DerivedModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = ChildModel - - -class AssociatedModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = AssociatedModel - - -# Tests -class IneritedModelSerializationTests(TestCase): - -    def test_multitable_inherited_model_fields_as_expected(self): -        """ -        Assert that the parent pointer field is not included in the fields -        serialized fields -        """ -        child = ChildModel(name1='parent name', name2='child name') -        serializer = DerivedModelSerializer(child) -        self.assertEqual(set(serializer.data.keys()), -                         set(['name1', 'name2', 'id'])) - -    def test_onetoone_primary_key_model_fields_as_expected(self): -        """ -        Assert that a model with a onetoone field that is the primary key is -        not treated like a derived model -        """ -        parent = ParentModel(name1='parent name') -        associate = AssociatedModel(name='hello', ref=parent) -        serializer = AssociatedModelSerializer(associate) -        self.assertEqual(set(serializer.data.keys()), -                         set(['name', 'ref'])) - -    def test_data_is_valid_without_parent_ptr(self): -        """ -        Assert that the pointer to the parent table is not a required field -        for input data -        """ -        data = { -            'name1': 'parent name', -            'name2': 'child name', -        } -        serializer = DerivedModelSerializer(data=data) -        self.assertEqual(serializer.is_valid(), True) diff --git a/rest_framework/tests/test_negotiation.py b/rest_framework/tests/test_negotiation.py deleted file mode 100644 index 04b89eb6..00000000 --- a/rest_framework/tests/test_negotiation.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework.negotiation import DefaultContentNegotiation -from rest_framework.request import Request -from rest_framework.renderers import BaseRenderer -from rest_framework.test import APIRequestFactory - - -factory = APIRequestFactory() - - -class MockJSONRenderer(BaseRenderer): -    media_type = 'application/json' - - -class MockHTMLRenderer(BaseRenderer): -    media_type = 'text/html' - - -class NoCharsetSpecifiedRenderer(BaseRenderer): -    media_type = 'my/media' - - -class TestAcceptedMediaType(TestCase): -    def setUp(self): -        self.renderers = [MockJSONRenderer(), MockHTMLRenderer()] -        self.negotiator = DefaultContentNegotiation() - -    def select_renderer(self, request): -        return self.negotiator.select_renderer(request, self.renderers) - -    def test_client_without_accept_use_renderer(self): -        request = Request(factory.get('/')) -        accepted_renderer, accepted_media_type = self.select_renderer(request) -        self.assertEqual(accepted_media_type, 'application/json') - -    def test_client_underspecifies_accept_use_renderer(self): -        request = Request(factory.get('/', HTTP_ACCEPT='*/*')) -        accepted_renderer, accepted_media_type = self.select_renderer(request) -        self.assertEqual(accepted_media_type, 'application/json') - -    def test_client_overspecifies_accept_use_client(self): -        request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8')) -        accepted_renderer, accepted_media_type = self.select_renderer(request) -        self.assertEqual(accepted_media_type, 'application/json; indent=8') diff --git a/rest_framework/tests/test_nullable_fields.py b/rest_framework/tests/test_nullable_fields.py deleted file mode 100644 index 6ee55c00..00000000 --- a/rest_framework/tests/test_nullable_fields.py +++ /dev/null @@ -1,30 +0,0 @@ -from django.core.urlresolvers import reverse - -from rest_framework.compat import patterns, url -from rest_framework.test import APITestCase -from rest_framework.tests.models import NullableForeignKeySource -from rest_framework.tests.serializers import NullableFKSourceSerializer -from rest_framework.tests.views import NullableFKSourceDetail - - -urlpatterns = patterns( -    '', -    url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), -) - - -class NullableForeignKeyTests(APITestCase): -    """ -    DRF should be able to handle nullable foreign keys when a test -    Client POST/PUT request is made with its own serialized object. -    """ -    urls = 'rest_framework.tests.test_nullable_fields' - -    def test_updating_object_with_null_fk(self): -        obj = NullableForeignKeySource(name='example', target=None) -        obj.save() -        serialized_data = NullableFKSourceSerializer(obj).data - -        response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) - -        self.assertEqual(response.data, serialized_data) diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py deleted file mode 100644 index 24c1ba39..00000000 --- a/rest_framework/tests/test_pagination.py +++ /dev/null @@ -1,521 +0,0 @@ -from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.db import models -from django.core.paginator import Paginator -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, status, pagination, filters, serializers -from rest_framework.compat import django_filters -from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import BasicModel -from .models import FilterableItem - -factory = APIRequestFactory() - -# Helper function to split arguments out of an url -def split_arguments_from_url(url): -    if '?' not in url: -        return url - -    path, args = url.split('?') -    args = dict(r.split('=') for r in args.split('&')) -    return path, args - - -class RootView(generics.ListCreateAPIView): -    """ -    Example description for OPTIONS. -    """ -    model = BasicModel -    paginate_by = 10 - - -class DefaultPageSizeKwargView(generics.ListAPIView): -    """ -    View for testing default paginate_by_param usage -    """ -    model = BasicModel - - -class PaginateByParamView(generics.ListAPIView): -    """ -    View for testing custom paginate_by_param usage -    """ -    model = BasicModel -    paginate_by_param = 'page_size' - - -class MaxPaginateByView(generics.ListAPIView): -    """ -    View for testing custom max_paginate_by usage -    """ -    model = BasicModel -    paginate_by = 3 -    max_paginate_by = 5 -    paginate_by_param = 'page_size' - - -class IntegrationTestPagination(TestCase): -    """ -    Integration tests for paginated list views. -    """ - -    def setUp(self): -        """ -        Create 26 BasicModel instances. -        """ -        for char in 'abcdefghijklmnopqrstuvwxyz': -            BasicModel(text=char * 3).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = RootView.as_view() - -    def test_get_paginated_root_view(self): -        """ -        GET requests to paginated ListCreateAPIView should return paginated results. -        """ -        request = factory.get('/') -        # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[10:20]) -        self.assertNotEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[20:]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - - -class IntegrationTestPaginationAndFiltering(TestCase): - -    def setUp(self): -        """ -        Create 50 FilterableItem instances. -        """ -        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) -        for i in range(26): -            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. -            decimal = base_data[1] + i -            date = base_data[2] - datetime.timedelta(days=i * 2) -            FilterableItem(text=text, decimal=decimal, date=date).save() - -        self.objects = FilterableItem.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} -            for obj in self.objects.all() -        ] - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_get_django_filter_paginated_filtered_root_view(self): -        """ -        GET requests to paginated filtered ListCreateAPIView should return -        paginated results. The next and previous links should preserve the -        filtered parameters. -        """ -        class DecimalFilter(django_filters.FilterSet): -            decimal = django_filters.NumberFilter(lookup_type='lt') - -            class Meta: -                model = FilterableItem -                fields = ['text', 'decimal', 'date'] - -        class FilterFieldsRootView(generics.ListCreateAPIView): -            model = FilterableItem -            paginate_by = 10 -            filter_class = DecimalFilter -            filter_backends = (filters.DjangoFilterBackend,) - -        view = FilterFieldsRootView.as_view() - -        EXPECTED_NUM_QUERIES = 2 - -        request = factory.get('/', {'decimal': '15.20'}) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[10:15]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['previous'])) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -    def test_get_basic_paginated_filtered_root_view(self): -        """ -        Same as `test_get_django_filter_paginated_filtered_root_view`, -        except using a custom filter backend instead of the django-filter -        backend, -        """ - -        class DecimalFilterBackend(filters.BaseFilterBackend): -            def filter_queryset(self, request, queryset, view): -                return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) - -        class BasicFilterFieldsRootView(generics.ListCreateAPIView): -            model = FilterableItem -            paginate_by = 10 -            filter_backends = (DecimalFilterBackend,) - -        view = BasicFilterFieldsRootView.as_view() - -        request = factory.get('/', {'decimal': '15.20'}) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[10:15]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['previous'])) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - - -class PassOnContextPaginationSerializer(pagination.PaginationSerializer): -    class Meta: -        object_serializer_class = serializers.Serializer - - -class UnitTestPagination(TestCase): -    """ -    Unit tests for pagination of primitive objects. -    """ - -    def setUp(self): -        self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] -        paginator = Paginator(self.objects, 10) -        self.first_page = paginator.page(1) -        self.last_page = paginator.page(3) - -    def test_native_pagination(self): -        serializer = pagination.PaginationSerializer(self.first_page) -        self.assertEqual(serializer.data['count'], 26) -        self.assertEqual(serializer.data['next'], '?page=2') -        self.assertEqual(serializer.data['previous'], None) -        self.assertEqual(serializer.data['results'], self.objects[:10]) - -        serializer = pagination.PaginationSerializer(self.last_page) -        self.assertEqual(serializer.data['count'], 26) -        self.assertEqual(serializer.data['next'], None) -        self.assertEqual(serializer.data['previous'], '?page=2') -        self.assertEqual(serializer.data['results'], self.objects[20:]) - -    def test_context_available_in_result(self): -        """ -        Ensure context gets passed through to the object serializer. -        """ -        serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) -        serializer.data -        results = serializer.fields[serializer.results_field] -        self.assertEqual(serializer.context, results.context) - - -class TestUnpaginated(TestCase): -    """ -    Tests for list views without pagination. -    """ - -    def setUp(self): -        """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -        {'id': obj.id, 'text': obj.text} -        for obj in self.objects.all() -        ] -        self.view = DefaultPageSizeKwargView.as_view() - -    def test_unpaginated(self): -        """ -        Tests the default page size for this view. -        no page size --> no limit --> no meta data -        """ -        request = factory.get('/') -        response = self.view(request) -        self.assertEqual(response.data, self.data) - - -class TestCustomPaginateByParam(TestCase): -    """ -    Tests for list views with default page size kwarg -    """ - -    def setUp(self): -        """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -        {'id': obj.id, 'text': obj.text} -        for obj in self.objects.all() -        ] -        self.view = PaginateByParamView.as_view() - -    def test_default_page_size(self): -        """ -        Tests the default page size for this view. -        no page size --> no limit --> no meta data -        """ -        request = factory.get('/') -        response = self.view(request).render() -        self.assertEqual(response.data, self.data) - -    def test_paginate_by_param(self): -        """ -        If paginate_by_param is set, the new kwarg should limit per view requests. -        """ -        request = factory.get('/', {'page_size': 5}) -        response = self.view(request).render() -        self.assertEqual(response.data['count'], 13) -        self.assertEqual(response.data['results'], self.data[:5]) - - -class TestMaxPaginateByParam(TestCase): -    """ -    Tests for list views with max_paginate_by kwarg -    """ - -    def setUp(self): -        """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = MaxPaginateByView.as_view() - -    def test_max_paginate_by(self): -        """ -        If max_paginate_by is set, it should limit page size for the view. -        """ -        request = factory.get('/', data={'page_size': 10}) -        response = self.view(request).render() -        self.assertEqual(response.data['count'], 13) -        self.assertEqual(response.data['results'], self.data[:5]) - -    def test_max_paginate_by_without_page_size_param(self): -        """ -        If max_paginate_by is set, but client does not specifiy page_size, -        standard `paginate_by` behavior should be used. -        """ -        request = factory.get('/') -        response = self.view(request).render() -        self.assertEqual(response.data['results'], self.data[:3]) - - -### Tests for context in pagination serializers - -class CustomField(serializers.Field): -    def to_native(self, value): -        if not 'view' in self.context: -            raise RuntimeError("context isn't getting passed into custom field") -        return "value" - - -class BasicModelSerializer(serializers.Serializer): -    text = CustomField() - -    def __init__(self, *args, **kwargs): -        super(BasicModelSerializer, self).__init__(*args, **kwargs) -        if not 'view' in self.context: -            raise RuntimeError("context isn't getting passed into serializer init") - - -class TestContextPassedToCustomField(TestCase): -    def setUp(self): -        BasicModel.objects.create(text='ala ma kota') - -    def test_with_pagination(self): -        class ListView(generics.ListCreateAPIView): -            model = BasicModel -            serializer_class = BasicModelSerializer -            paginate_by = 1 - -        self.view = ListView.as_view() -        request = factory.get('/') -        response = self.view(request).render() - -        self.assertEqual(response.status_code, status.HTTP_200_OK) - - -### Tests for custom pagination serializers - -class LinksSerializer(serializers.Serializer): -    next = pagination.NextPageField(source='*') -    prev = pagination.PreviousPageField(source='*') - - -class CustomPaginationSerializer(pagination.BasePaginationSerializer): -    links = LinksSerializer(source='*')  # Takes the page object as the source -    total_results = serializers.Field(source='paginator.count') - -    results_field = 'objects' - - -class TestCustomPaginationSerializer(TestCase): -    def setUp(self): -        objects = ['john', 'paul', 'george', 'ringo'] -        paginator = Paginator(objects, 2) -        self.page = paginator.page(1) - -    def test_custom_pagination_serializer(self): -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=self.page, -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': 'http://testserver/foobar?page=2', -                'prev': None -            }, -            'total_results': 4, -            'objects': ['john', 'paul'] -        } -        self.assertEqual(serializer.data, expected) - - -class NonIntegerPage(object): - -    def __init__(self, paginator, object_list, prev_token, token, next_token): -        self.paginator = paginator -        self.object_list = object_list -        self.prev_token = prev_token -        self.token = token -        self.next_token = next_token - -    def has_next(self): -        return not not self.next_token - -    def next_page_number(self): -        return self.next_token - -    def has_previous(self): -        return not not self.prev_token - -    def previous_page_number(self): -        return self.prev_token - - -class NonIntegerPaginator(object): - -    def __init__(self, object_list, per_page): -        self.object_list = object_list -        self.per_page = per_page - -    def count(self): -        # pretend like we don't know how many pages we have -        return None - -    def page(self, token=None): -        if token: -            try: -                first = self.object_list.index(token) -            except ValueError: -                first = 0 -        else: -            first = 0 -        n = len(self.object_list) -        last = min(first + self.per_page, n) -        prev_token = self.object_list[last - (2 * self.per_page)] if first else None -        next_token = self.object_list[last] if last < n else None -        return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) - - -class TestNonIntegerPagination(TestCase): - - -    def test_custom_pagination_serializer(self): -        objects = ['john', 'paul', 'george', 'ringo'] -        paginator = NonIntegerPaginator(objects, 2) - -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=paginator.page(), -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': 'http://testserver/foobar?page={0}'.format(objects[2]), -                'prev': None -            }, -            'total_results': None, -            'objects': objects[:2] -        } -        self.assertEqual(serializer.data, expected) - -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=paginator.page('george'), -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': None, -                'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), -            }, -            'total_results': None, -            'objects': objects[2:] -        } -        self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/test_parsers.py b/rest_framework/tests/test_parsers.py deleted file mode 100644 index 7699e10c..00000000 --- a/rest_framework/tests/test_parsers.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import unicode_literals -from rest_framework.compat import StringIO -from django import forms -from django.core.files.uploadhandler import MemoryFileUploadHandler -from django.test import TestCase -from django.utils import unittest -from rest_framework.compat import etree -from rest_framework.parsers import FormParser, FileUploadParser -from rest_framework.parsers import XMLParser -import datetime - - -class Form(forms.Form): -    field1 = forms.CharField(max_length=3) -    field2 = forms.CharField() - - -class TestFormParser(TestCase): -    def setUp(self): -        self.string = "field1=abc&field2=defghijk" - -    def test_parse(self): -        """ Make sure the `QueryDict` works OK """ -        parser = FormParser() - -        stream = StringIO(self.string) -        data = parser.parse(stream) - -        self.assertEqual(Form(data).is_valid(), True) - - -class TestXMLParser(TestCase): -    def setUp(self): -        self._input = StringIO( -            '<?xml version="1.0" encoding="utf-8"?>' -            '<root>' -            '<field_a>121.0</field_a>' -            '<field_b>dasd</field_b>' -            '<field_c></field_c>' -            '<field_d>2011-12-25 12:45:00</field_d>' -            '</root>' -        ) -        self._data = { -            'field_a': 121, -            'field_b': 'dasd', -            'field_c': None, -            'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00) -        } -        self._complex_data_input = StringIO( -            '<?xml version="1.0" encoding="utf-8"?>' -            '<root>' -            '<creation_date>2011-12-25 12:45:00</creation_date>' -            '<sub_data_list>' -            '<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>' -            '<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>' -            '</sub_data_list>' -            '<name>name</name>' -            '</root>' -        ) -        self._complex_data = { -            "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), -            "name": "name", -            "sub_data_list": [ -                { -                    "sub_id": 1, -                    "sub_name": "first" -                }, -                { -                    "sub_id": 2, -                    "sub_name": "second" -                } -            ] -        } - -    @unittest.skipUnless(etree, 'defusedxml not installed') -    def test_parse(self): -        parser = XMLParser() -        data = parser.parse(self._input) -        self.assertEqual(data, self._data) - -    @unittest.skipUnless(etree, 'defusedxml not installed') -    def test_complex_data_parse(self): -        parser = XMLParser() -        data = parser.parse(self._complex_data_input) -        self.assertEqual(data, self._complex_data) - - -class TestFileUploadParser(TestCase): -    def setUp(self): -        class MockRequest(object): -            pass -        from io import BytesIO -        self.stream = BytesIO( -            "Test text file".encode('utf-8') -        ) -        request = MockRequest() -        request.upload_handlers = (MemoryFileUploadHandler(),) -        request.META = { -            'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'), -            'HTTP_CONTENT_LENGTH': 14, -        } -        self.parser_context = {'request': request, 'kwargs': {}} - -    def test_parse(self): -        """ Make sure the `QueryDict` works OK """ -        parser = FileUploadParser() -        self.stream.seek(0) -        data_and_files = parser.parse(self.stream, None, self.parser_context) -        file_obj = data_and_files.files['file'] -        self.assertEqual(file_obj._size, 14) - -    def test_get_filename(self): -        parser = FileUploadParser() -        filename = parser.get_filename(self.stream, None, self.parser_context) -        self.assertEqual(filename, 'file.txt'.encode('utf-8')) diff --git a/rest_framework/tests/test_permissions.py b/rest_framework/tests/test_permissions.py deleted file mode 100644 index 6e3a6303..00000000 --- a/rest_framework/tests/test_permissions.py +++ /dev/null @@ -1,300 +0,0 @@ -from __future__ import unicode_literals -from django.contrib.auth.models import User, Permission, Group -from django.db import models -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING -from rest_framework.compat import guardian, get_model_name -from rest_framework.filters import DjangoObjectPermissionsFilter -from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import BasicModel -import base64 - -factory = APIRequestFactory() - -class RootView(generics.ListCreateAPIView): -    model = BasicModel -    authentication_classes = [authentication.BasicAuthentication] -    permission_classes = [permissions.DjangoModelPermissions] - - -class InstanceView(generics.RetrieveUpdateDestroyAPIView): -    model = BasicModel -    authentication_classes = [authentication.BasicAuthentication] -    permission_classes = [permissions.DjangoModelPermissions] - -root_view = RootView.as_view() -instance_view = InstanceView.as_view() - - -def basic_auth_header(username, password): -    credentials = ('%s:%s' % (username, password)) -    base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) -    return 'Basic %s' % base64_credentials - - -class ModelPermissionsIntegrationTests(TestCase): -    def setUp(self): -        User.objects.create_user('disallowed', 'disallowed@example.com', 'password') -        user = User.objects.create_user('permitted', 'permitted@example.com', 'password') -        user.user_permissions = [ -            Permission.objects.get(codename='add_basicmodel'), -            Permission.objects.get(codename='change_basicmodel'), -            Permission.objects.get(codename='delete_basicmodel') -        ] -        user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') -        user.user_permissions = [ -            Permission.objects.get(codename='change_basicmodel'), -        ] - -        self.permitted_credentials = basic_auth_header('permitted', 'password') -        self.disallowed_credentials = basic_auth_header('disallowed', 'password') -        self.updateonly_credentials = basic_auth_header('updateonly', 'password') - -        BasicModel(text='foo').save() - -    def test_has_create_permissions(self): -        request = factory.post('/', {'text': 'foobar'}, format='json', -                               HTTP_AUTHORIZATION=self.permitted_credentials) -        response = root_view(request, pk=1) -        self.assertEqual(response.status_code, status.HTTP_201_CREATED) - -    def test_has_put_permissions(self): -        request = factory.put('/1', {'text': 'foobar'}, format='json', -                              HTTP_AUTHORIZATION=self.permitted_credentials) -        response = instance_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -    def test_has_delete_permissions(self): -        request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) -        response = instance_view(request, pk=1) -        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - -    def test_does_not_have_create_permissions(self): -        request = factory.post('/', {'text': 'foobar'}, format='json', -                               HTTP_AUTHORIZATION=self.disallowed_credentials) -        response = root_view(request, pk=1) -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - -    def test_does_not_have_put_permissions(self): -        request = factory.put('/1', {'text': 'foobar'}, format='json', -                              HTTP_AUTHORIZATION=self.disallowed_credentials) -        response = instance_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - -    def test_does_not_have_delete_permissions(self): -        request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) -        response = instance_view(request, pk=1) -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - -    def test_has_put_as_create_permissions(self): -        # User only has update permissions - should be able to update an entity. -        request = factory.put('/1', {'text': 'foobar'}, format='json', -                              HTTP_AUTHORIZATION=self.updateonly_credentials) -        response = instance_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -        # But if PUTing to a new entity, permission should be denied. -        request = factory.put('/2', {'text': 'foobar'}, format='json', -                              HTTP_AUTHORIZATION=self.updateonly_credentials) -        response = instance_view(request, pk='2') -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - -    def test_options_permitted(self): -        request = factory.options('/', -                               HTTP_AUTHORIZATION=self.permitted_credentials) -        response = root_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertIn('actions', response.data) -        self.assertEqual(list(response.data['actions'].keys()), ['POST']) - -        request = factory.options('/1', -                               HTTP_AUTHORIZATION=self.permitted_credentials) -        response = instance_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertIn('actions', response.data) -        self.assertEqual(list(response.data['actions'].keys()), ['PUT']) - -    def test_options_disallowed(self): -        request = factory.options('/', -                               HTTP_AUTHORIZATION=self.disallowed_credentials) -        response = root_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertNotIn('actions', response.data) - -        request = factory.options('/1', -                               HTTP_AUTHORIZATION=self.disallowed_credentials) -        response = instance_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertNotIn('actions', response.data) - -    def test_options_updateonly(self): -        request = factory.options('/', -                               HTTP_AUTHORIZATION=self.updateonly_credentials) -        response = root_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertNotIn('actions', response.data) - -        request = factory.options('/1', -                               HTTP_AUTHORIZATION=self.updateonly_credentials) -        response = instance_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertIn('actions', response.data) -        self.assertEqual(list(response.data['actions'].keys()), ['PUT']) - - -class BasicPermModel(models.Model): -    text = models.CharField(max_length=100) - -    class Meta: -        app_label = 'tests' -        permissions = ( -            ('view_basicpermmodel', 'Can view basic perm model'), -            # add, change, delete built in to django -        ) - -# Custom object-level permission, that includes 'view' permissions -class ViewObjectPermissions(permissions.DjangoObjectPermissions): -    perms_map = { -        'GET': ['%(app_label)s.view_%(model_name)s'], -        'OPTIONS': ['%(app_label)s.view_%(model_name)s'], -        'HEAD': ['%(app_label)s.view_%(model_name)s'], -        'POST': ['%(app_label)s.add_%(model_name)s'], -        'PUT': ['%(app_label)s.change_%(model_name)s'], -        'PATCH': ['%(app_label)s.change_%(model_name)s'], -        'DELETE': ['%(app_label)s.delete_%(model_name)s'], -    } - - -class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView): -    model = BasicPermModel -    authentication_classes = [authentication.BasicAuthentication] -    permission_classes = [ViewObjectPermissions] - -object_permissions_view = ObjectPermissionInstanceView.as_view() - - -class ObjectPermissionListView(generics.ListAPIView): -    model = BasicPermModel -    authentication_classes = [authentication.BasicAuthentication] -    permission_classes = [ViewObjectPermissions] - -object_permissions_list_view = ObjectPermissionListView.as_view() - - -@unittest.skipUnless(guardian, 'django-guardian not installed') -class ObjectPermissionsIntegrationTests(TestCase): -    """ -    Integration tests for the object level permissions API. -    """ -    @classmethod -    def setUpClass(cls): -        from guardian.shortcuts import assign_perm - -        # create users -        create = User.objects.create_user -        users = { -            'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), -            'readonly': create('readonly', 'readonly@example.com', 'password'), -            'writeonly': create('writeonly', 'writeonly@example.com', 'password'), -            'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), -        } - -        # give everyone model level permissions, as we are not testing those -        everyone = Group.objects.create(name='everyone') -        model_name = get_model_name(BasicPermModel) -        app_label = BasicPermModel._meta.app_label -        f = '{0}_{1}'.format -        perms = { -            'view':   f('view', model_name), -            'change': f('change', model_name), -            'delete': f('delete', model_name) -        } -        for perm in perms.values(): -            perm = '{0}.{1}'.format(app_label, perm) -            assign_perm(perm, everyone) -        everyone.user_set.add(*users.values()) - -        cls.perms = perms -        cls.users = users - -    def setUp(self): -        from guardian.shortcuts import assign_perm -        perms = self.perms -        users = self.users - -        # appropriate object level permissions -        readers = Group.objects.create(name='readers') -        writers = Group.objects.create(name='writers') -        deleters = Group.objects.create(name='deleters') - -        model = BasicPermModel.objects.create(text='foo') -         -        assign_perm(perms['view'], readers, model) -        assign_perm(perms['change'], writers, model) -        assign_perm(perms['delete'], deleters, model) - -        readers.user_set.add(users['fullaccess'], users['readonly']) -        writers.user_set.add(users['fullaccess'], users['writeonly']) -        deleters.user_set.add(users['fullaccess'], users['deleteonly']) - -        self.credentials = {} -        for user in users.values(): -            self.credentials[user.username] = basic_auth_header(user.username, 'password') - -    # Delete -    def test_can_delete_permissions(self): -        request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly']) -        response = object_permissions_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - -    def test_cannot_delete_permissions(self): -        request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) -        response = object_permissions_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - -    # Update -    def test_can_update_permissions(self): -        request = factory.patch('/1', {'text': 'foobar'}, format='json', -            HTTP_AUTHORIZATION=self.credentials['writeonly']) -        response = object_permissions_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data.get('text'), 'foobar') - -    def test_cannot_update_permissions(self): -        request = factory.patch('/1', {'text': 'foobar'}, format='json', -            HTTP_AUTHORIZATION=self.credentials['deleteonly']) -        response = object_permissions_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - -    def test_cannot_update_permissions_non_existing(self): -        request = factory.patch('/999', {'text': 'foobar'}, format='json', -            HTTP_AUTHORIZATION=self.credentials['deleteonly']) -        response = object_permissions_view(request, pk='999') -        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - -    # Read -    def test_can_read_permissions(self): -        request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) -        response = object_permissions_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -    def test_cannot_read_permissions(self): -        request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly']) -        response = object_permissions_view(request, pk='1') -        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - -    # Read list -    def test_can_read_list_permissions(self): -        request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly']) -        object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) -        response = object_permissions_list_view(request) -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data[0].get('id'), 1) - -    def test_cannot_read_list_permissions(self): -        request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly']) -        object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) -        response = object_permissions_list_view(request) -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertListEqual(response.data, []) diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py deleted file mode 100644 index f52e0e1e..00000000 --- a/rest_framework/tests/test_relations.py +++ /dev/null @@ -1,120 +0,0 @@ -""" -General tests for relational fields. -""" -from __future__ import unicode_literals -from django.db import models -from django.test import TestCase -from rest_framework import serializers -from rest_framework.tests.models import BlogPost - - -class NullModel(models.Model): -    pass - - -class FieldTests(TestCase): -    def test_pk_related_field_with_empty_string(self): -        """ -        Regression test for #446 - -        https://github.com/tomchristie/django-rest-framework/issues/446 -        """ -        field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) -        self.assertRaises(serializers.ValidationError, field.from_native, '') -        self.assertRaises(serializers.ValidationError, field.from_native, []) - -    def test_hyperlinked_related_field_with_empty_string(self): -        field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') -        self.assertRaises(serializers.ValidationError, field.from_native, '') -        self.assertRaises(serializers.ValidationError, field.from_native, []) - -    def test_slug_related_field_with_empty_string(self): -        field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') -        self.assertRaises(serializers.ValidationError, field.from_native, '') -        self.assertRaises(serializers.ValidationError, field.from_native, []) - - -class TestManyRelatedMixin(TestCase): -    def test_missing_many_to_many_related_field(self): -        ''' -        Regression test for #632 - -        https://github.com/tomchristie/django-rest-framework/pull/632 -        ''' -        field = serializers.RelatedField(many=True, read_only=False) - -        into = {} -        field.field_from_native({}, None, 'field_name', into) -        self.assertEqual(into['field_name'], []) - - -# Regression tests for #694 (`source` attribute on related fields) - -class RelatedFieldSourceTests(TestCase): -    def test_related_manager_source(self): -        """ -        Relational fields should be able to use manager-returning methods as their source. -        """ -        BlogPost.objects.create(title='blah') -        field = serializers.RelatedField(many=True, source='get_blogposts_manager') - -        class ClassWithManagerMethod(object): -            def get_blogposts_manager(self): -                return BlogPost.objects - -        obj = ClassWithManagerMethod() -        value = field.field_to_native(obj, 'field_name') -        self.assertEqual(value, ['BlogPost object']) - -    def test_related_queryset_source(self): -        """ -        Relational fields should be able to use queryset-returning methods as their source. -        """ -        BlogPost.objects.create(title='blah') -        field = serializers.RelatedField(many=True, source='get_blogposts_queryset') - -        class ClassWithQuerysetMethod(object): -            def get_blogposts_queryset(self): -                return BlogPost.objects.all() - -        obj = ClassWithQuerysetMethod() -        value = field.field_to_native(obj, 'field_name') -        self.assertEqual(value, ['BlogPost object']) - -    def test_dotted_source(self): -        """ -        Source argument should support dotted.source notation. -        """ -        BlogPost.objects.create(title='blah') -        field = serializers.RelatedField(many=True, source='a.b.c') - -        class ClassWithQuerysetMethod(object): -            a = { -                'b': { -                    'c': BlogPost.objects.all() -                } -            } - -        obj = ClassWithQuerysetMethod() -        value = field.field_to_native(obj, 'field_name') -        self.assertEqual(value, ['BlogPost object']) - -    # Regression for #1129 -    def test_exception_for_incorect_fk(self): -        """ -        Check that the exception message are correct if the source field -        doesn't exist. -        """ -        from rest_framework.tests.models import ManyToManySource -        class Meta: -            model = ManyToManySource -        attrs = { -            'name': serializers.SlugRelatedField( -                slug_field='name', source='banzai'), -            'Meta': Meta, -        } - -        TestSerializer = type(str('TestSerializer'), -            (serializers.ModelSerializer,), attrs) -        with self.assertRaises(AttributeError): -            TestSerializer(data={'name': 'foo'}) diff --git a/rest_framework/tests/test_relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py deleted file mode 100644 index 3c4d39af..00000000 --- a/rest_framework/tests/test_relations_hyperlink.py +++ /dev/null @@ -1,524 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework import serializers -from rest_framework.compat import patterns, url -from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import ( -    BlogPost, -    ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, -    NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource -) - -factory = APIRequestFactory() -request = factory.get('/')  # Just to ensure we have a request in the serializer context - - -def dummy_view(request, pk): -    pass - -urlpatterns = patterns('', -    url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), -    url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), -    url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), -    url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), -    url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), -    url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), -    url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), -    url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), -) - - -# ManyToMany -class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): -    class Meta: -        model = ManyToManyTarget -        fields = ('url', 'name', 'sources') - - -class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): -    class Meta: -        model = ManyToManySource -        fields = ('url', 'name', 'targets') - - -# ForeignKey -class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): -    class Meta: -        model = ForeignKeyTarget -        fields = ('url', 'name', 'sources') - - -class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): -    class Meta: -        model = ForeignKeySource -        fields = ('url', 'name', 'target') - - -# Nullable ForeignKey -class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): -    class Meta: -        model = NullableForeignKeySource -        fields = ('url', 'name', 'target') - - -# Nullable OneToOne -class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): -    class Meta: -        model = OneToOneTarget -        fields = ('url', 'name', 'nullable_source') - - -# TODO: Add test that .data cannot be accessed prior to .is_valid - -class HyperlinkedManyToManyTests(TestCase): -    urls = 'rest_framework.tests.test_relations_hyperlink' - -    def setUp(self): -        for idx in range(1, 4): -            target = ManyToManyTarget(name='target-%d' % idx) -            target.save() -            source = ManyToManySource(name='source-%d' % idx) -            source.save() -            for target in ManyToManyTarget.objects.all(): -                source.targets.add(target) - -    def test_many_to_many_retrieve(self): -        queryset = ManyToManySource.objects.all() -        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -                {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, -                {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, -                {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_many_to_many_retrieve(self): -        queryset = ManyToManyTarget.objects.all() -        serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, -            {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, -            {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_many_to_many_update(self): -        data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} -        instance = ManyToManySource.objects.get(pk=1) -        serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        serializer.save() -        self.assertEqual(serializer.data, data) - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = ManyToManySource.objects.all() -        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -                {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, -                {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, -                {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_many_to_many_update(self): -        data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} -        instance = ManyToManyTarget.objects.get(pk=1) -        serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        serializer.save() -        self.assertEqual(serializer.data, data) - -        # Ensure target 1 is updated, and everything else is as expected -        queryset = ManyToManyTarget.objects.all() -        serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, -            {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, -            {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} - -        ] -        self.assertEqual(serializer.data, expected) - -    def test_many_to_many_create(self): -        data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} -        serializer = ManyToManySourceSerializer(data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is added, and everything else is as expected -        queryset = ManyToManySource.objects.all() -        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, -            {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, -            {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, -            {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_many_to_many_create(self): -        data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} -        serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'target-4') - -        # Ensure target 4 is added, and everything else is as expected -        queryset = ManyToManyTarget.objects.all() -        serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, -            {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, -            {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, -            {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} -        ] -        self.assertEqual(serializer.data, expected) - - -class HyperlinkedForeignKeyTests(TestCase): -    urls = 'rest_framework.tests.test_relations_hyperlink' - -    def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() -        new_target = ForeignKeyTarget(name='target-2') -        new_target.save() -        for idx in range(1, 4): -            source = ForeignKeySource(name='source-%d' % idx, target=target) -            source.save() - -    def test_foreign_key_retrieve(self): -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_foreign_key_retrieve(self): -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, -            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update(self): -        data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} -        instance = ForeignKeySource.objects.get(pk=1) -        serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, data) -        serializer.save() - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, -            {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_incorrect_type(self): -        data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} -        instance = ForeignKeySource.objects.get(pk=1) -        serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'target': ['Incorrect type.  Expected url string, received int.']}) - -    def test_reverse_foreign_key_update(self): -        data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} -        instance = ForeignKeyTarget.objects.get(pk=2) -        serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        # We shouldn't have saved anything to the db yet since save -        # hasn't been called. -        queryset = ForeignKeyTarget.objects.all() -        new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, -            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, -        ] -        self.assertEqual(new_serializer.data, expected) - -        serializer.save() -        self.assertEqual(serializer.data, data) - -        # Ensure target 2 is update, and everything else is as expected -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, -            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_create(self): -        data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} -        serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_foreign_key_create(self): -        data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} -        serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'target-3') - -        # Ensure target 4 is added, and everything else is as expected -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, -            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, -            {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_with_invalid_null(self): -        data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} -        instance = ForeignKeySource.objects.get(pk=1) -        serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'target': ['This field is required.']}) - - -class HyperlinkedNullableForeignKeyTests(TestCase): -    urls = 'rest_framework.tests.test_relations_hyperlink' - -    def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() -        for idx in range(1, 4): -            if idx == 3: -                target = None -            source = NullableForeignKeySource(name='source-%d' % idx, target=target) -            source.save() - -    def test_foreign_key_retrieve_with_null(self): -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_create_with_valid_null(self): -        data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} -        serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is created, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, -            {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_create_with_valid_emptystring(self): -        """ -        The emptystring should be interpreted as null in the context -        of relationships. -        """ -        data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} -        expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} -        serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, expected_data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is created, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, -            {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_with_valid_null(self): -        data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} -        instance = NullableForeignKeySource.objects.get(pk=1) -        serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, data) -        serializer.save() - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, -            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_with_valid_emptystring(self): -        """ -        The emptystring should be interpreted as null in the context -        of relationships. -        """ -        data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''} -        expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} -        instance = NullableForeignKeySource.objects.get(pk=1) -        serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, expected_data) -        serializer.save() - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, -            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, -            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, -        ] -        self.assertEqual(serializer.data, expected) - -    # reverse foreign keys MUST be read_only -    # In the general case they do not provide .remove() or .clear() -    # and cannot be arbitrarily set. - -    # def test_reverse_foreign_key_update(self): -    #     data = {'id': 1, 'name': 'target-1', 'sources': [1]} -    #     instance = ForeignKeyTarget.objects.get(pk=1) -    #     serializer = ForeignKeyTargetSerializer(instance, data=data) -    #     self.assertTrue(serializer.is_valid()) -    #     self.assertEqual(serializer.data, data) -    #     serializer.save() - -    #     # Ensure target 1 is updated, and everything else is as expected -    #     queryset = ForeignKeyTarget.objects.all() -    #     serializer = ForeignKeyTargetSerializer(queryset, many=True) -    #     expected = [ -    #         {'id': 1, 'name': 'target-1', 'sources': [1]}, -    #         {'id': 2, 'name': 'target-2', 'sources': []}, -    #     ] -    #     self.assertEqual(serializer.data, expected) - - -class HyperlinkedNullableOneToOneTests(TestCase): -    urls = 'rest_framework.tests.test_relations_hyperlink' - -    def setUp(self): -        target = OneToOneTarget(name='target-1') -        target.save() -        new_target = OneToOneTarget(name='target-2') -        new_target.save() -        source = NullableOneToOneSource(name='source-1', target=target) -        source.save() - -    def test_reverse_foreign_key_retrieve_with_null(self): -        queryset = OneToOneTarget.objects.all() -        serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) -        expected = [ -            {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, -            {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, -        ] -        self.assertEqual(serializer.data, expected) - - -# Regression tests for #694 (`source` attribute on related fields) - -class HyperlinkedRelatedFieldSourceTests(TestCase): -    urls = 'rest_framework.tests.test_relations_hyperlink' - -    def test_related_manager_source(self): -        """ -        Relational fields should be able to use manager-returning methods as their source. -        """ -        BlogPost.objects.create(title='blah') -        field = serializers.HyperlinkedRelatedField( -            many=True, -            source='get_blogposts_manager', -            view_name='dummy-url', -        ) -        field.context = {'request': request} - -        class ClassWithManagerMethod(object): -            def get_blogposts_manager(self): -                return BlogPost.objects - -        obj = ClassWithManagerMethod() -        value = field.field_to_native(obj, 'field_name') -        self.assertEqual(value, ['http://testserver/dummyurl/1/']) - -    def test_related_queryset_source(self): -        """ -        Relational fields should be able to use queryset-returning methods as their source. -        """ -        BlogPost.objects.create(title='blah') -        field = serializers.HyperlinkedRelatedField( -            many=True, -            source='get_blogposts_queryset', -            view_name='dummy-url', -        ) -        field.context = {'request': request} - -        class ClassWithQuerysetMethod(object): -            def get_blogposts_queryset(self): -                return BlogPost.objects.all() - -        obj = ClassWithQuerysetMethod() -        value = field.field_to_native(obj, 'field_name') -        self.assertEqual(value, ['http://testserver/dummyurl/1/']) - -    def test_dotted_source(self): -        """ -        Source argument should support dotted.source notation. -        """ -        BlogPost.objects.create(title='blah') -        field = serializers.HyperlinkedRelatedField( -            many=True, -            source='a.b.c', -            view_name='dummy-url', -        ) -        field.context = {'request': request} - -        class ClassWithQuerysetMethod(object): -            a = { -                'b': { -                    'c': BlogPost.objects.all() -                } -            } - -        obj = ClassWithQuerysetMethod() -        value = field.field_to_native(obj, 'field_name') -        self.assertEqual(value, ['http://testserver/dummyurl/1/']) diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py deleted file mode 100644 index 4d9da489..00000000 --- a/rest_framework/tests/test_relations_nested.py +++ /dev/null @@ -1,326 +0,0 @@ -from __future__ import unicode_literals -from django.db import models -from django.test import TestCase -from rest_framework import serializers - -from .models import OneToOneTarget - - -class OneToOneSource(models.Model): -    name = models.CharField(max_length=100) -    target = models.OneToOneField(OneToOneTarget, related_name='source', -                                  null=True, blank=True) - - -class OneToManyTarget(models.Model): -    name = models.CharField(max_length=100) - - -class OneToManySource(models.Model): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(OneToManyTarget, related_name='sources') - - -class ReverseNestedOneToOneTests(TestCase): -    def setUp(self): -        class OneToOneSourceSerializer(serializers.ModelSerializer): -            class Meta: -                model = OneToOneSource -                fields = ('id', 'name') - -        class OneToOneTargetSerializer(serializers.ModelSerializer): -            source = OneToOneSourceSerializer() - -            class Meta: -                model = OneToOneTarget -                fields = ('id', 'name', 'source') - -        self.Serializer = OneToOneTargetSerializer - -        for idx in range(1, 4): -            target = OneToOneTarget(name='target-%d' % idx) -            target.save() -            source = OneToOneSource(name='source-%d' % idx, target=target) -            source.save() - -    def test_one_to_one_retrieve(self): -        queryset = OneToOneTarget.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, -            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, -            {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_one_to_one_create(self): -        data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} -        serializer = self.Serializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'target-4') - -        # Ensure (target 4, target_source 4, source 4) are added, and -        # everything else is as expected. -        queryset = OneToOneTarget.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, -            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, -            {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, -            {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_one_to_one_create_with_invalid_data(self): -        data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} -        serializer = self.Serializer(data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) - -    def test_one_to_one_update(self): -        data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} -        instance = OneToOneTarget.objects.get(pk=3) -        serializer = self.Serializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'target-3-updated') - -        # Ensure (target 3, target_source 3, source 3) are updated, -        # and everything else is as expected. -        queryset = OneToOneTarget.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, -            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, -            {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} -        ] -        self.assertEqual(serializer.data, expected) - - -class ForwardNestedOneToOneTests(TestCase): -    def setUp(self): -        class OneToOneTargetSerializer(serializers.ModelSerializer): -            class Meta: -                model = OneToOneTarget -                fields = ('id', 'name') - -        class OneToOneSourceSerializer(serializers.ModelSerializer): -            target = OneToOneTargetSerializer() - -            class Meta: -                model = OneToOneSource -                fields = ('id', 'name', 'target') - -        self.Serializer = OneToOneSourceSerializer - -        for idx in range(1, 4): -            target = OneToOneTarget(name='target-%d' % idx) -            target.save() -            source = OneToOneSource(name='source-%d' % idx, target=target) -            source.save() - -    def test_one_to_one_retrieve(self): -        queryset = OneToOneSource.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, -            {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_one_to_one_create(self): -        data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} -        serializer = self.Serializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure (target 4, target_source 4, source 4) are added, and -        # everything else is as expected. -        queryset = OneToOneSource.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, -            {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, -            {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_one_to_one_create_with_invalid_data(self): -        data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} -        serializer = self.Serializer(data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) - -    def test_one_to_one_update(self): -        data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} -        instance = OneToOneSource.objects.get(pk=3) -        serializer = self.Serializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-3-updated') - -        # Ensure (target 3, target_source 3, source 3) are updated, -        # and everything else is as expected. -        queryset = OneToOneSource.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, -            {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_one_to_one_update_to_null(self): -        data = {'id': 3, 'name': 'source-3-updated', 'target': None} -        instance = OneToOneSource.objects.get(pk=3) -        serializer = self.Serializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() - -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-3-updated') -        self.assertEqual(obj.target, None) - -        queryset = OneToOneSource.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, -            {'id': 3, 'name': 'source-3-updated', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    # TODO: Nullable 1-1 tests -    # def test_one_to_one_delete(self): -    #     data = {'id': 3, 'name': 'target-3', 'target_source': None} -    #     instance = OneToOneTarget.objects.get(pk=3) -    #     serializer = self.Serializer(instance, data=data) -    #     self.assertTrue(serializer.is_valid()) -    #     serializer.save() - -    #     # Ensure (target_source 3, source 3) are deleted, -    #     # and everything else is as expected. -    #     queryset = OneToOneTarget.objects.all() -    #     serializer = self.Serializer(queryset) -    #     expected = [ -    #         {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, -    #         {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, -    #         {'id': 3, 'name': 'target-3', 'source': None} -    #     ] -    #     self.assertEqual(serializer.data, expected) - - -class ReverseNestedOneToManyTests(TestCase): -    def setUp(self): -        class OneToManySourceSerializer(serializers.ModelSerializer): -            class Meta: -                model = OneToManySource -                fields = ('id', 'name') - -        class OneToManyTargetSerializer(serializers.ModelSerializer): -            sources = OneToManySourceSerializer(many=True, allow_add_remove=True) - -            class Meta: -                model = OneToManyTarget -                fields = ('id', 'name', 'sources') - -        self.Serializer = OneToManyTargetSerializer - -        target = OneToManyTarget(name='target-1') -        target.save() -        for idx in range(1, 4): -            source = OneToManySource(name='source-%d' % idx, target=target) -            source.save() - -    def test_one_to_many_retrieve(self): -        queryset = OneToManyTarget.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, -                                                      {'id': 2, 'name': 'source-2'}, -                                                      {'id': 3, 'name': 'source-3'}]}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_one_to_many_create(self): -        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, -                                                         {'id': 2, 'name': 'source-2'}, -                                                         {'id': 3, 'name': 'source-3'}, -                                                         {'id': 4, 'name': 'source-4'}]} -        instance = OneToManyTarget.objects.get(pk=1) -        serializer = self.Serializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'target-1') - -        # Ensure source 4 is added, and everything else is as -        # expected. -        queryset = OneToManyTarget.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, -                                                      {'id': 2, 'name': 'source-2'}, -                                                      {'id': 3, 'name': 'source-3'}, -                                                      {'id': 4, 'name': 'source-4'}]} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_one_to_many_create_with_invalid_data(self): -        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, -                                                         {'id': 2, 'name': 'source-2'}, -                                                         {'id': 3, 'name': 'source-3'}, -                                                         {'id': 4}]} -        serializer = self.Serializer(data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) - -    def test_one_to_many_update(self): -        data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, -                                                                 {'id': 2, 'name': 'source-2'}, -                                                                 {'id': 3, 'name': 'source-3'}]} -        instance = OneToManyTarget.objects.get(pk=1) -        serializer = self.Serializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'target-1-updated') - -        # Ensure (target 1, source 1) are updated, -        # and everything else is as expected. -        queryset = OneToManyTarget.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, -                                                              {'id': 2, 'name': 'source-2'}, -                                                              {'id': 3, 'name': 'source-3'}]} - -        ] -        self.assertEqual(serializer.data, expected) - -    def test_one_to_many_delete(self): -        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, -                                                         {'id': 3, 'name': 'source-3'}]} -        instance = OneToManyTarget.objects.get(pk=1) -        serializer = self.Serializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        serializer.save() - -        # Ensure source 2 is deleted, and everything else is as -        # expected. -        queryset = OneToManyTarget.objects.all() -        serializer = self.Serializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, -                                                      {'id': 3, 'name': 'source-3'}]} - -        ] -        self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/test_relations_pk.py b/rest_framework/tests/test_relations_pk.py deleted file mode 100644 index 3815afdd..00000000 --- a/rest_framework/tests/test_relations_pk.py +++ /dev/null @@ -1,551 +0,0 @@ -from __future__ import unicode_literals -from django.db import models -from django.test import TestCase -from rest_framework import serializers -from rest_framework.tests.models import ( -    BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, -    NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, -) -from rest_framework.compat import six - - -# ManyToMany -class ManyToManyTargetSerializer(serializers.ModelSerializer): -    class Meta: -        model = ManyToManyTarget -        fields = ('id', 'name', 'sources') - - -class ManyToManySourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = ManyToManySource -        fields = ('id', 'name', 'targets') - - -# ForeignKey -class ForeignKeyTargetSerializer(serializers.ModelSerializer): -    class Meta: -        model = ForeignKeyTarget -        fields = ('id', 'name', 'sources') - - -class ForeignKeySourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = ForeignKeySource -        fields = ('id', 'name', 'target') - - -# Nullable ForeignKey -class NullableForeignKeySourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = NullableForeignKeySource -        fields = ('id', 'name', 'target') - - -# Nullable OneToOne -class NullableOneToOneTargetSerializer(serializers.ModelSerializer): -    class Meta: -        model = OneToOneTarget -        fields = ('id', 'name', 'nullable_source') - - -# TODO: Add test that .data cannot be accessed prior to .is_valid - -class PKManyToManyTests(TestCase): -    def setUp(self): -        for idx in range(1, 4): -            target = ManyToManyTarget(name='target-%d' % idx) -            target.save() -            source = ManyToManySource(name='source-%d' % idx) -            source.save() -            for target in ManyToManyTarget.objects.all(): -                source.targets.add(target) - -    def test_many_to_many_retrieve(self): -        queryset = ManyToManySource.objects.all() -        serializer = ManyToManySourceSerializer(queryset, many=True) -        expected = [ -                {'id': 1, 'name': 'source-1', 'targets': [1]}, -                {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, -                {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_many_to_many_retrieve(self): -        queryset = ManyToManyTarget.objects.all() -        serializer = ManyToManyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, -            {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, -            {'id': 3, 'name': 'target-3', 'sources': [3]} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_many_to_many_update(self): -        data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} -        instance = ManyToManySource.objects.get(pk=1) -        serializer = ManyToManySourceSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        serializer.save() -        self.assertEqual(serializer.data, data) - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = ManyToManySource.objects.all() -        serializer = ManyToManySourceSerializer(queryset, many=True) -        expected = [ -                {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, -                {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, -                {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_many_to_many_update(self): -        data = {'id': 1, 'name': 'target-1', 'sources': [1]} -        instance = ManyToManyTarget.objects.get(pk=1) -        serializer = ManyToManyTargetSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        serializer.save() -        self.assertEqual(serializer.data, data) - -        # Ensure target 1 is updated, and everything else is as expected -        queryset = ManyToManyTarget.objects.all() -        serializer = ManyToManyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [1]}, -            {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, -            {'id': 3, 'name': 'target-3', 'sources': [3]} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_many_to_many_create(self): -        data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} -        serializer = ManyToManySourceSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is added, and everything else is as expected -        queryset = ManyToManySource.objects.all() -        serializer = ManyToManySourceSerializer(queryset, many=True) -        self.assertFalse(serializer.fields['targets'].read_only) -        expected = [ -            {'id': 1, 'name': 'source-1', 'targets': [1]}, -            {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, -            {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, -            {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_many_to_many_create(self): -        data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} -        serializer = ManyToManyTargetSerializer(data=data) -        self.assertFalse(serializer.fields['sources'].read_only) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'target-4') - -        # Ensure target 4 is added, and everything else is as expected -        queryset = ManyToManyTarget.objects.all() -        serializer = ManyToManyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, -            {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, -            {'id': 3, 'name': 'target-3', 'sources': [3]}, -            {'id': 4, 'name': 'target-4', 'sources': [1, 3]} -        ] -        self.assertEqual(serializer.data, expected) - - -class PKForeignKeyTests(TestCase): -    def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() -        new_target = ForeignKeyTarget(name='target-2') -        new_target.save() -        for idx in range(1, 4): -            source = ForeignKeySource(name='source-%d' % idx, target=target) -            source.save() - -    def test_foreign_key_retrieve(self): -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 1}, -            {'id': 2, 'name': 'source-2', 'target': 1}, -            {'id': 3, 'name': 'source-3', 'target': 1} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_foreign_key_retrieve(self): -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, -            {'id': 2, 'name': 'target-2', 'sources': []}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update(self): -        data = {'id': 1, 'name': 'source-1', 'target': 2} -        instance = ForeignKeySource.objects.get(pk=1) -        serializer = ForeignKeySourceSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, data) -        serializer.save() - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 2}, -            {'id': 2, 'name': 'source-2', 'target': 1}, -            {'id': 3, 'name': 'source-3', 'target': 1} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_incorrect_type(self): -        data = {'id': 1, 'name': 'source-1', 'target': 'foo'} -        instance = ForeignKeySource.objects.get(pk=1) -        serializer = ForeignKeySourceSerializer(instance, data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'target': ['Incorrect type.  Expected pk value, received %s.' % six.text_type.__name__]}) - -    def test_reverse_foreign_key_update(self): -        data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} -        instance = ForeignKeyTarget.objects.get(pk=2) -        serializer = ForeignKeyTargetSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        # We shouldn't have saved anything to the db yet since save -        # hasn't been called. -        queryset = ForeignKeyTarget.objects.all() -        new_serializer = ForeignKeyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, -            {'id': 2, 'name': 'target-2', 'sources': []}, -        ] -        self.assertEqual(new_serializer.data, expected) - -        serializer.save() -        self.assertEqual(serializer.data, data) - -        # Ensure target 2 is update, and everything else is as expected -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [2]}, -            {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_create(self): -        data = {'id': 4, 'name': 'source-4', 'target': 2} -        serializer = ForeignKeySourceSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is added, and everything else is as expected -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 1}, -            {'id': 2, 'name': 'source-2', 'target': 1}, -            {'id': 3, 'name': 'source-3', 'target': 1}, -            {'id': 4, 'name': 'source-4', 'target': 2}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_foreign_key_create(self): -        data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]} -        serializer = ForeignKeyTargetSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'target-3') - -        # Ensure target 3 is added, and everything else is as expected -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [2]}, -            {'id': 2, 'name': 'target-2', 'sources': []}, -            {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_with_invalid_null(self): -        data = {'id': 1, 'name': 'source-1', 'target': None} -        instance = ForeignKeySource.objects.get(pk=1) -        serializer = ForeignKeySourceSerializer(instance, data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'target': ['This field is required.']}) - -    def test_foreign_key_with_empty(self): -        """ -        Regression test for #1072 - -        https://github.com/tomchristie/django-rest-framework/issues/1072 -        """ -        serializer = NullableForeignKeySourceSerializer() -        self.assertEqual(serializer.data['target'], None) - - -class PKNullableForeignKeyTests(TestCase): -    def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() -        for idx in range(1, 4): -            if idx == 3: -                target = None -            source = NullableForeignKeySource(name='source-%d' % idx, target=target) -            source.save() - -    def test_foreign_key_retrieve_with_null(self): -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 1}, -            {'id': 2, 'name': 'source-2', 'target': 1}, -            {'id': 3, 'name': 'source-3', 'target': None}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_create_with_valid_null(self): -        data = {'id': 4, 'name': 'source-4', 'target': None} -        serializer = NullableForeignKeySourceSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is created, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 1}, -            {'id': 2, 'name': 'source-2', 'target': 1}, -            {'id': 3, 'name': 'source-3', 'target': None}, -            {'id': 4, 'name': 'source-4', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_create_with_valid_emptystring(self): -        """ -        The emptystring should be interpreted as null in the context -        of relationships. -        """ -        data = {'id': 4, 'name': 'source-4', 'target': ''} -        expected_data = {'id': 4, 'name': 'source-4', 'target': None} -        serializer = NullableForeignKeySourceSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, expected_data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is created, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 1}, -            {'id': 2, 'name': 'source-2', 'target': 1}, -            {'id': 3, 'name': 'source-3', 'target': None}, -            {'id': 4, 'name': 'source-4', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_with_valid_null(self): -        data = {'id': 1, 'name': 'source-1', 'target': None} -        instance = NullableForeignKeySource.objects.get(pk=1) -        serializer = NullableForeignKeySourceSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, data) -        serializer.save() - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': None}, -            {'id': 2, 'name': 'source-2', 'target': 1}, -            {'id': 3, 'name': 'source-3', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_with_valid_emptystring(self): -        """ -        The emptystring should be interpreted as null in the context -        of relationships. -        """ -        data = {'id': 1, 'name': 'source-1', 'target': ''} -        expected_data = {'id': 1, 'name': 'source-1', 'target': None} -        instance = NullableForeignKeySource.objects.get(pk=1) -        serializer = NullableForeignKeySourceSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, expected_data) -        serializer.save() - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': None}, -            {'id': 2, 'name': 'source-2', 'target': 1}, -            {'id': 3, 'name': 'source-3', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    # reverse foreign keys MUST be read_only -    # In the general case they do not provide .remove() or .clear() -    # and cannot be arbitrarily set. - -    # def test_reverse_foreign_key_update(self): -    #     data = {'id': 1, 'name': 'target-1', 'sources': [1]} -    #     instance = ForeignKeyTarget.objects.get(pk=1) -    #     serializer = ForeignKeyTargetSerializer(instance, data=data) -    #     self.assertTrue(serializer.is_valid()) -    #     self.assertEqual(serializer.data, data) -    #     serializer.save() - -    #     # Ensure target 1 is updated, and everything else is as expected -    #     queryset = ForeignKeyTarget.objects.all() -    #     serializer = ForeignKeyTargetSerializer(queryset, many=True) -    #     expected = [ -    #         {'id': 1, 'name': 'target-1', 'sources': [1]}, -    #         {'id': 2, 'name': 'target-2', 'sources': []}, -    #     ] -    #     self.assertEqual(serializer.data, expected) - - -class PKNullableOneToOneTests(TestCase): -    def setUp(self): -        target = OneToOneTarget(name='target-1') -        target.save() -        new_target = OneToOneTarget(name='target-2') -        new_target.save() -        source = NullableOneToOneSource(name='source-1', target=new_target) -        source.save() - -    def test_reverse_foreign_key_retrieve_with_null(self): -        queryset = OneToOneTarget.objects.all() -        serializer = NullableOneToOneTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'nullable_source': None}, -            {'id': 2, 'name': 'target-2', 'nullable_source': 1}, -        ] -        self.assertEqual(serializer.data, expected) - - -# The below models and tests ensure that serializer fields corresponding -# to a ManyToManyField field with a user-specified ``through`` model are -# set to read only - - -class ManyToManyThroughTarget(models.Model): -    name = models.CharField(max_length=100) - - -class ManyToManyThrough(models.Model): -    source = models.ForeignKey('ManyToManyThroughSource') -    target = models.ForeignKey(ManyToManyThroughTarget) - - -class ManyToManyThroughSource(models.Model): -    name = models.CharField(max_length=100) -    targets = models.ManyToManyField(ManyToManyThroughTarget, -                                     related_name='sources', -                                     through='ManyToManyThrough') - - -class ManyToManyThroughTargetSerializer(serializers.ModelSerializer): -    class Meta: -        model = ManyToManyThroughTarget -        fields = ('id', 'name', 'sources') - - -class ManyToManyThroughSourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = ManyToManyThroughSource -        fields = ('id', 'name', 'targets') - - -class PKManyToManyThroughTests(TestCase): -    def setUp(self): -        self.source = ManyToManyThroughSource.objects.create( -            name='through-source-1') -        self.target = ManyToManyThroughTarget.objects.create( -            name='through-target-1') - -    def test_many_to_many_create(self): -        data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]} -        serializer = ManyToManyThroughSourceSerializer(data=data) -        self.assertTrue(serializer.fields['targets'].read_only) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(obj.name, 'source-2') -        self.assertEqual(obj.targets.count(), 0) - -    def test_many_to_many_reverse_create(self): -        data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]} -        serializer = ManyToManyThroughTargetSerializer(data=data) -        self.assertTrue(serializer.fields['sources'].read_only) -        self.assertTrue(serializer.is_valid()) -        serializer.save() -        obj = serializer.save() -        self.assertEqual(obj.name, 'target-2') -        self.assertEqual(obj.sources.count(), 0) - - -# Regression tests for #694 (`source` attribute on related fields) - - -class PrimaryKeyRelatedFieldSourceTests(TestCase): -    def test_related_manager_source(self): -        """ -        Relational fields should be able to use manager-returning methods as their source. -        """ -        BlogPost.objects.create(title='blah') -        field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager') - -        class ClassWithManagerMethod(object): -            def get_blogposts_manager(self): -                return BlogPost.objects - -        obj = ClassWithManagerMethod() -        value = field.field_to_native(obj, 'field_name') -        self.assertEqual(value, [1]) - -    def test_related_queryset_source(self): -        """ -        Relational fields should be able to use queryset-returning methods as their source. -        """ -        BlogPost.objects.create(title='blah') -        field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset') - -        class ClassWithQuerysetMethod(object): -            def get_blogposts_queryset(self): -                return BlogPost.objects.all() - -        obj = ClassWithQuerysetMethod() -        value = field.field_to_native(obj, 'field_name') -        self.assertEqual(value, [1]) - -    def test_dotted_source(self): -        """ -        Source argument should support dotted.source notation. -        """ -        BlogPost.objects.create(title='blah') -        field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c') - -        class ClassWithQuerysetMethod(object): -            a = { -                'b': { -                    'c': BlogPost.objects.all() -                } -            } - -        obj = ClassWithQuerysetMethod() -        value = field.field_to_native(obj, 'field_name') -        self.assertEqual(value, [1]) diff --git a/rest_framework/tests/test_relations_slug.py b/rest_framework/tests/test_relations_slug.py deleted file mode 100644 index 435c821c..00000000 --- a/rest_framework/tests/test_relations_slug.py +++ /dev/null @@ -1,257 +0,0 @@ -from django.test import TestCase -from rest_framework import serializers -from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget - - -class ForeignKeyTargetSerializer(serializers.ModelSerializer): -    sources = serializers.SlugRelatedField(many=True, slug_field='name') - -    class Meta: -        model = ForeignKeyTarget - - -class ForeignKeySourceSerializer(serializers.ModelSerializer): -    target = serializers.SlugRelatedField(slug_field='name') - -    class Meta: -        model = ForeignKeySource - - -class NullableForeignKeySourceSerializer(serializers.ModelSerializer): -    target = serializers.SlugRelatedField(slug_field='name', required=False) - -    class Meta: -        model = NullableForeignKeySource - - -# TODO: M2M Tests, FKTests (Non-nullable), One2One -class SlugForeignKeyTests(TestCase): -    def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() -        new_target = ForeignKeyTarget(name='target-2') -        new_target.save() -        for idx in range(1, 4): -            source = ForeignKeySource(name='source-%d' % idx, target=target) -            source.save() - -    def test_foreign_key_retrieve(self): -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, -            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, -            {'id': 3, 'name': 'source-3', 'target': 'target-1'} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_foreign_key_retrieve(self): -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, -            {'id': 2, 'name': 'target-2', 'sources': []}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update(self): -        data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} -        instance = ForeignKeySource.objects.get(pk=1) -        serializer = ForeignKeySourceSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, data) -        serializer.save() - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 'target-2'}, -            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, -            {'id': 3, 'name': 'source-3', 'target': 'target-1'} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_incorrect_type(self): -        data = {'id': 1, 'name': 'source-1', 'target': 123} -        instance = ForeignKeySource.objects.get(pk=1) -        serializer = ForeignKeySourceSerializer(instance, data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) - -    def test_reverse_foreign_key_update(self): -        data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} -        instance = ForeignKeyTarget.objects.get(pk=2) -        serializer = ForeignKeyTargetSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        # We shouldn't have saved anything to the db yet since save -        # hasn't been called. -        queryset = ForeignKeyTarget.objects.all() -        new_serializer = ForeignKeyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, -            {'id': 2, 'name': 'target-2', 'sources': []}, -        ] -        self.assertEqual(new_serializer.data, expected) - -        serializer.save() -        self.assertEqual(serializer.data, data) - -        # Ensure target 2 is update, and everything else is as expected -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, -            {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_create(self): -        data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} -        serializer = ForeignKeySourceSerializer(data=data) -        serializer.is_valid() -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is added, and everything else is as expected -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, -            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, -            {'id': 3, 'name': 'source-3', 'target': 'target-1'}, -            {'id': 4, 'name': 'source-4', 'target': 'target-2'}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_reverse_foreign_key_create(self): -        data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} -        serializer = ForeignKeyTargetSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'target-3') - -        # Ensure target 3 is added, and everything else is as expected -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, -            {'id': 2, 'name': 'target-2', 'sources': []}, -            {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_with_invalid_null(self): -        data = {'id': 1, 'name': 'source-1', 'target': None} -        instance = ForeignKeySource.objects.get(pk=1) -        serializer = ForeignKeySourceSerializer(instance, data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'target': ['This field is required.']}) - - -class SlugNullableForeignKeyTests(TestCase): -    def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() -        for idx in range(1, 4): -            if idx == 3: -                target = None -            source = NullableForeignKeySource(name='source-%d' % idx, target=target) -            source.save() - -    def test_foreign_key_retrieve_with_null(self): -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, -            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, -            {'id': 3, 'name': 'source-3', 'target': None}, -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_create_with_valid_null(self): -        data = {'id': 4, 'name': 'source-4', 'target': None} -        serializer = NullableForeignKeySourceSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is created, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, -            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, -            {'id': 3, 'name': 'source-3', 'target': None}, -            {'id': 4, 'name': 'source-4', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_create_with_valid_emptystring(self): -        """ -        The emptystring should be interpreted as null in the context -        of relationships. -        """ -        data = {'id': 4, 'name': 'source-4', 'target': ''} -        expected_data = {'id': 4, 'name': 'source-4', 'target': None} -        serializer = NullableForeignKeySourceSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        obj = serializer.save() -        self.assertEqual(serializer.data, expected_data) -        self.assertEqual(obj.name, 'source-4') - -        # Ensure source 4 is created, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, -            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, -            {'id': 3, 'name': 'source-3', 'target': None}, -            {'id': 4, 'name': 'source-4', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_with_valid_null(self): -        data = {'id': 1, 'name': 'source-1', 'target': None} -        instance = NullableForeignKeySource.objects.get(pk=1) -        serializer = NullableForeignKeySourceSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, data) -        serializer.save() - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': None}, -            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, -            {'id': 3, 'name': 'source-3', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) - -    def test_foreign_key_update_with_valid_emptystring(self): -        """ -        The emptystring should be interpreted as null in the context -        of relationships. -        """ -        data = {'id': 1, 'name': 'source-1', 'target': ''} -        expected_data = {'id': 1, 'name': 'source-1', 'target': None} -        instance = NullableForeignKeySource.objects.get(pk=1) -        serializer = NullableForeignKeySourceSerializer(instance, data=data) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, expected_data) -        serializer.save() - -        # Ensure source 1 is updated, and everything else is as expected -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) -        expected = [ -            {'id': 1, 'name': 'source-1', 'target': None}, -            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, -            {'id': 3, 'name': 'source-3', 'target': None} -        ] -        self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py deleted file mode 100644 index c7bf772e..00000000 --- a/rest_framework/tests/test_renderers.py +++ /dev/null @@ -1,655 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from decimal import Decimal -from django.core.cache import cache -from django.db import models -from django.test import TestCase -from django.utils import unittest -from django.utils.translation import ugettext_lazy as _ -from rest_framework import status, permissions -from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO -from rest_framework.response import Response -from rest_framework.views import APIView -from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ -    XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer -from rest_framework.parsers import YAMLParser, XMLParser -from rest_framework.settings import api_settings -from rest_framework.test import APIRequestFactory -from collections import MutableMapping -import datetime -import json -import pickle -import re - - -DUMMYSTATUS = status.HTTP_200_OK -DUMMYCONTENT = 'dummycontent' - -RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii') -RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') - - -expected_results = [ -    ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]')  # Generator -] - - -class DummyTestModel(models.Model): -    name = models.CharField(max_length=42, default='') - - -class BasicRendererTests(TestCase): -    def test_expected_results(self): -        for value, renderer_cls, expected in expected_results: -            output = renderer_cls().render(value) -            self.assertEqual(output, expected) - - -class RendererA(BaseRenderer): -    media_type = 'mock/renderera' -    format = "formata" - -    def render(self, data, media_type=None, renderer_context=None): -        return RENDERER_A_SERIALIZER(data) - - -class RendererB(BaseRenderer): -    media_type = 'mock/rendererb' -    format = "formatb" - -    def render(self, data, media_type=None, renderer_context=None): -        return RENDERER_B_SERIALIZER(data) - - -class MockView(APIView): -    renderer_classes = (RendererA, RendererB) - -    def get(self, request, **kwargs): -        response = Response(DUMMYCONTENT, status=DUMMYSTATUS) -        return response - - -class MockGETView(APIView): -    def get(self, request, **kwargs): -        return Response({'foo': ['bar', 'baz']}) - - - -class MockPOSTView(APIView): -    def post(self, request, **kwargs): -        return Response({'foo': request.DATA}) - - -class EmptyGETView(APIView): -    renderer_classes = (JSONRenderer,) - -    def get(self, request, **kwargs): -        return Response(status=status.HTTP_204_NO_CONTENT) - - -class HTMLView(APIView): -    renderer_classes = (BrowsableAPIRenderer, ) - -    def get(self, request, **kwargs): -        return Response('text') - - -class HTMLView1(APIView): -    renderer_classes = (BrowsableAPIRenderer, JSONRenderer) - -    def get(self, request, **kwargs): -        return Response('text') - -urlpatterns = patterns('', -    url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), -    url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), -    url(r'^cache$', MockGETView.as_view()), -    url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])), -    url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])), -    url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])), -    url(r'^html$', HTMLView.as_view()), -    url(r'^html1$', HTMLView1.as_view()), -    url(r'^empty$', EmptyGETView.as_view()), -    url(r'^api', include('rest_framework.urls', namespace='rest_framework')) -) - - -class POSTDeniedPermission(permissions.BasePermission): -    def has_permission(self, request, view): -        return request.method != 'POST' - - -class POSTDeniedView(APIView): -    renderer_classes = (BrowsableAPIRenderer,) -    permission_classes = (POSTDeniedPermission,) - -    def get(self, request): -        return Response() - -    def post(self, request): -        return Response() - -    def put(self, request): -        return Response() - -    def patch(self, request): -        return Response() - - -class DocumentingRendererTests(TestCase): -    def test_only_permitted_forms_are_displayed(self): -        view = POSTDeniedView.as_view() -        request = APIRequestFactory().get('/') -        response = view(request).render() -        self.assertNotContains(response, '>POST<') -        self.assertContains(response, '>PUT<') -        self.assertContains(response, '>PATCH<') - - -class RendererEndToEndTests(TestCase): -    """ -    End-to-end testing of renderers using an RendererMixin on a generic view. -    """ - -    urls = 'rest_framework.tests.test_renderers' - -    def test_default_renderer_serializes_content(self): -        """If the Accept header is not set the default renderer should serialize the response.""" -        resp = self.client.get('/') -        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_head_method_serializes_no_content(self): -        """No response must be included in HEAD requests.""" -        resp = self.client.head('/') -        self.assertEqual(resp.status_code, DUMMYSTATUS) -        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, six.b('')) - -    def test_default_renderer_serializes_content_on_accept_any(self): -        """If the Accept header is set to */* the default renderer should serialize the response.""" -        resp = self.client.get('/', HTTP_ACCEPT='*/*') -        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_serializes_content_default_case(self): -        """If the Accept header is set the specified renderer should serialize the response. -        (In this case we check that works for the default renderer)""" -        resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) -        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_serializes_content_non_default_case(self): -        """If the Accept header is set the specified renderer should serialize the response. -        (In this case we check that works for a non-default renderer)""" -        resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_serializes_content_on_accept_query(self): -        """The '_accept' query string should behave in the same way as the Accept header.""" -        param = '?%s=%s' % ( -            api_settings.URL_ACCEPT_OVERRIDE, -            RendererB.media_type -        ) -        resp = self.client.get('/' + param) -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_unsatisfiable_accept_header_on_request_returns_406_status(self): -        """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" -        resp = self.client.get('/', HTTP_ACCEPT='foo/bar') -        self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) - -    def test_specified_renderer_serializes_content_on_format_query(self): -        """If a 'format' query is specified, the renderer with the matching -        format attribute should serialize the response.""" -        param = '?%s=%s' % ( -            api_settings.URL_FORMAT_OVERRIDE, -            RendererB.format -        ) -        resp = self.client.get('/' + param) -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_serializes_content_on_format_kwargs(self): -        """If a 'format' keyword arg is specified, the renderer with the matching -        format attribute should serialize the response.""" -        resp = self.client.get('/something.formatb') -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): -        """If both a 'format' query and a matching Accept header specified, -        the renderer with the matching format attribute should serialize the response.""" -        param = '?%s=%s' % ( -            api_settings.URL_FORMAT_OVERRIDE, -            RendererB.format -        ) -        resp = self.client.get('/' + param, -                               HTTP_ACCEPT=RendererB.media_type) -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_parse_error_renderers_browsable_api(self): -        """Invalid data should still render the browsable API correctly.""" -        resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') -        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') -        self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) - -    def test_204_no_content_responses_have_no_content_type_set(self): -        """ -        Regression test for #1196 - -        https://github.com/tomchristie/django-rest-framework/issues/1196 -        """ -        resp = self.client.get('/empty') -        self.assertEqual(resp.get('Content-Type', None), None) -        self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) - -    def test_contains_headers_of_api_response(self): -        """ -        Issue #1437 - -        Test we display the headers of the API response and not those from the -        HTML response -        """ -        resp = self.client.get('/html1') -        self.assertContains(resp, '>GET, HEAD, OPTIONS<') -        self.assertContains(resp, '>application/json<') -        self.assertNotContains(resp, '>text/html; charset=utf-8<') - - -_flat_repr = '{"foo": ["bar", "baz"]}' -_indented_repr = '{\n  "foo": [\n    "bar",\n    "baz"\n  ]\n}' - - -def strip_trailing_whitespace(content): -    """ -    Seems to be some inconsistencies re. trailing whitespace with -    different versions of the json lib. -    """ -    return re.sub(' +\n', '\n', content) - - -class JSONRendererTests(TestCase): -    """ -    Tests specific to the JSON Renderer -    """ - -    def test_render_lazy_strings(self): -        """ -        JSONRenderer should deal with lazy translated strings. -        """ -        ret = JSONRenderer().render(_('test')) -        self.assertEqual(ret, b'"test"') - -    def test_render_queryset_values(self): -        o = DummyTestModel.objects.create(name='dummy') -        qs = DummyTestModel.objects.values('id', 'name') -        ret = JSONRenderer().render(qs) -        data = json.loads(ret.decode('utf-8')) -        self.assertEquals(data, [{'id': o.id, 'name': o.name}]) - -    def test_render_queryset_values_list(self): -        o = DummyTestModel.objects.create(name='dummy') -        qs = DummyTestModel.objects.values_list('id', 'name') -        ret = JSONRenderer().render(qs) -        data = json.loads(ret.decode('utf-8')) -        self.assertEquals(data, [[o.id, o.name]]) - -    def test_render_dict_abc_obj(self): -        class Dict(MutableMapping): -            def __init__(self): -                self._dict = dict() -            def __getitem__(self, key): -                return self._dict.__getitem__(key) -            def __setitem__(self, key, value): -                return self._dict.__setitem__(key, value) -            def __delitem__(self, key): -                return self._dict.__delitem__(key) -            def __iter__(self): -                return self._dict.__iter__() -            def __len__(self): -                return self._dict.__len__() -            def keys(self): -                return self._dict.keys() - -        x = Dict() -        x['key'] = 'string value' -        x[2] = 3 -        ret = JSONRenderer().render(x) -        data = json.loads(ret.decode('utf-8')) -        self.assertEquals(data, {'key': 'string value', '2': 3})     - -    def test_render_obj_with_getitem(self): -        class DictLike(object): -            def __init__(self): -                self._dict = {} -            def set(self, value): -                self._dict = dict(value) -            def __getitem__(self, key): -                return self._dict[key] -             -        x = DictLike() -        x.set({'a': 1, 'b': 'string'}) -        with self.assertRaises(TypeError): -            JSONRenderer().render(x) -         -    def test_without_content_type_args(self): -        """ -        Test basic JSON rendering. -        """ -        obj = {'foo': ['bar', 'baz']} -        renderer = JSONRenderer() -        content = renderer.render(obj, 'application/json') -        # Fix failing test case which depends on version of JSON library. -        self.assertEqual(content.decode('utf-8'), _flat_repr) - -    def test_with_content_type_args(self): -        """ -        Test JSON rendering with additional content type arguments supplied. -        """ -        obj = {'foo': ['bar', 'baz']} -        renderer = JSONRenderer() -        content = renderer.render(obj, 'application/json; indent=2') -        self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr) - -    def test_check_ascii(self): -        obj = {'countries': ['United Kingdom', 'France', 'España']} -        renderer = JSONRenderer() -        content = renderer.render(obj, 'application/json') -        self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8')) - - -class UnicodeJSONRendererTests(TestCase): -    """ -    Tests specific for the Unicode JSON Renderer -    """ -    def test_proper_encoding(self): -        obj = {'countries': ['United Kingdom', 'France', 'España']} -        renderer = UnicodeJSONRenderer() -        content = renderer.render(obj, 'application/json') -        self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8')) - - -class JSONPRendererTests(TestCase): -    """ -    Tests specific to the JSONP Renderer -    """ - -    urls = 'rest_framework.tests.test_renderers' - -    def test_without_callback_with_json_renderer(self): -        """ -        Test JSONP rendering with View JSON Renderer. -        """ -        resp = self.client.get('/jsonp/jsonrenderer', -                               HTTP_ACCEPT='application/javascript') -        self.assertEqual(resp.status_code, status.HTTP_200_OK) -        self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') -        self.assertEqual(resp.content, -            ('callback(%s);' % _flat_repr).encode('ascii')) - -    def test_without_callback_without_json_renderer(self): -        """ -        Test JSONP rendering without View JSON Renderer. -        """ -        resp = self.client.get('/jsonp/nojsonrenderer', -                               HTTP_ACCEPT='application/javascript') -        self.assertEqual(resp.status_code, status.HTTP_200_OK) -        self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') -        self.assertEqual(resp.content, -            ('callback(%s);' % _flat_repr).encode('ascii')) - -    def test_with_callback(self): -        """ -        Test JSONP rendering with callback function name. -        """ -        callback_func = 'myjsonpcallback' -        resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, -                               HTTP_ACCEPT='application/javascript') -        self.assertEqual(resp.status_code, status.HTTP_200_OK) -        self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') -        self.assertEqual(resp.content, -            ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')) - - -if yaml: -    _yaml_repr = 'foo: [bar, baz]\n' - -    class YAMLRendererTests(TestCase): -        """ -        Tests specific to the YAML Renderer -        """ - -        def test_render(self): -            """ -            Test basic YAML rendering. -            """ -            obj = {'foo': ['bar', 'baz']} -            renderer = YAMLRenderer() -            content = renderer.render(obj, 'application/yaml') -            self.assertEqual(content, _yaml_repr) - -        def test_render_and_parse(self): -            """ -            Test rendering and then parsing returns the original object. -            IE obj -> render -> parse -> obj. -            """ -            obj = {'foo': ['bar', 'baz']} - -            renderer = YAMLRenderer() -            parser = YAMLParser() - -            content = renderer.render(obj, 'application/yaml') -            data = parser.parse(StringIO(content)) -            self.assertEqual(obj, data) - -        def test_render_decimal(self): -            """ -            Test YAML decimal rendering. -            """ -            renderer = YAMLRenderer() -            content = renderer.render({'field': Decimal('111.2')}, 'application/yaml') -            self.assertYAMLContains(content, "field: '111.2'") - -        def assertYAMLContains(self, content, string): -            self.assertTrue(string in content, '%r not in %r' % (string, content)) - - -class XMLRendererTestCase(TestCase): -    """ -    Tests specific to the XML Renderer -    """ - -    _complex_data = { -        "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), -        "name": "name", -        "sub_data_list": [ -            { -                "sub_id": 1, -                "sub_name": "first" -            }, -            { -                "sub_id": 2, -                "sub_name": "second" -            } -        ] -    } - -    def test_render_string(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': 'astring'}, 'application/xml') -        self.assertXMLContains(content, '<field>astring</field>') - -    def test_render_integer(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': 111}, 'application/xml') -        self.assertXMLContains(content, '<field>111</field>') - -    def test_render_datetime(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({ -            'field': datetime.datetime(2011, 12, 25, 12, 45, 00) -        }, 'application/xml') -        self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>') - -    def test_render_float(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': 123.4}, 'application/xml') -        self.assertXMLContains(content, '<field>123.4</field>') - -    def test_render_decimal(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': Decimal('111.2')}, 'application/xml') -        self.assertXMLContains(content, '<field>111.2</field>') - -    def test_render_none(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': None}, 'application/xml') -        self.assertXMLContains(content, '<field></field>') - -    def test_render_complex_data(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render(self._complex_data, 'application/xml') -        self.assertXMLContains(content, '<sub_name>first</sub_name>') -        self.assertXMLContains(content, '<sub_name>second</sub_name>') - -    @unittest.skipUnless(etree, 'defusedxml not installed') -    def test_render_and_parse_complex_data(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = StringIO(renderer.render(self._complex_data, 'application/xml')) - -        parser = XMLParser() -        complex_data_out = parser.parse(content) -        error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out)) -        self.assertEqual(self._complex_data, complex_data_out, error_msg) - -    def assertXMLContains(self, xml, string): -        self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) -        self.assertTrue(xml.endswith('</root>')) -        self.assertTrue(string in xml, '%r not in %r' % (string, xml)) - - -# Tests for caching issue, #346 -class CacheRenderTest(TestCase): -    """ -    Tests specific to caching responses -    """ - -    urls = 'rest_framework.tests.test_renderers' - -    cache_key = 'just_a_cache_key' - -    @classmethod -    def _get_pickling_errors(cls, obj, seen=None): -        """ Return any errors that would be raised if `obj' is pickled -        Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897 -        """ -        if seen == None: -            seen = [] -        try: -            state = obj.__getstate__() -        except AttributeError: -            return -        if state == None: -            return -        if isinstance(state, tuple): -            if not isinstance(state[0], dict): -                state = state[1] -            else: -                state = state[0].update(state[1]) -        result = {} -        for i in state: -            try: -                pickle.dumps(state[i], protocol=2) -            except pickle.PicklingError: -                if not state[i] in seen: -                    seen.append(state[i]) -                    result[i] = cls._get_pickling_errors(state[i], seen) -        return result - -    def http_resp(self, http_method, url): -        """ -        Simple wrapper for Client http requests -        Removes the `client' and `request' attributes from as they are -        added by django.test.client.Client and not part of caching -        responses outside of tests. -        """ -        method = getattr(self.client, http_method) -        resp = method(url) -        del resp.client, resp.request -        try: -            del resp.wsgi_request -        except AttributeError: -            pass -        return resp - -    def test_obj_pickling(self): -        """ -        Test that responses are properly pickled -        """ -        resp = self.http_resp('get', '/cache') - -        # Make sure that no pickling errors occurred -        self.assertEqual(self._get_pickling_errors(resp), {}) - -        # Unfortunately LocMem backend doesn't raise PickleErrors but returns -        # None instead. -        cache.set(self.cache_key, resp) -        self.assertTrue(cache.get(self.cache_key) is not None) - -    def test_head_caching(self): -        """ -        Test caching of HEAD requests -        """ -        resp = self.http_resp('head', '/cache') -        cache.set(self.cache_key, resp) - -        cached_resp = cache.get(self.cache_key) -        self.assertIsInstance(cached_resp, Response) - -    def test_get_caching(self): -        """ -        Test caching of GET requests -        """ -        resp = self.http_resp('get', '/cache') -        cache.set(self.cache_key, resp) - -        cached_resp = cache.get(self.cache_key) -        self.assertIsInstance(cached_resp, Response) -        self.assertEqual(cached_resp.content, resp.content) diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py deleted file mode 100644 index c0b50f33..00000000 --- a/rest_framework/tests/test_request.py +++ /dev/null @@ -1,347 +0,0 @@ -""" -Tests for content parsing, and form-overloaded content parsing. -""" -from __future__ import unicode_literals -from django.contrib.auth.models import User -from django.contrib.auth import authenticate, login, logout -from django.contrib.sessions.middleware import SessionMiddleware -from django.core.handlers.wsgi import WSGIRequest -from django.test import TestCase -from rest_framework import status -from rest_framework.authentication import SessionAuthentication -from rest_framework.compat import patterns -from rest_framework.parsers import ( -    BaseParser, -    FormParser, -    MultiPartParser, -    JSONParser -) -from rest_framework.request import Request, Empty -from rest_framework.response import Response -from rest_framework.settings import api_settings -from rest_framework.test import APIRequestFactory, APIClient -from rest_framework.views import APIView -from rest_framework.compat import six -from io import BytesIO -import json - - -factory = APIRequestFactory() - - -class PlainTextParser(BaseParser): -    media_type = 'text/plain' - -    def parse(self, stream, media_type=None, parser_context=None): -        """ -        Returns a 2-tuple of `(data, files)`. - -        `data` will simply be a string representing the body of the request. -        `files` will always be `None`. -        """ -        return stream.read() - - -class TestMethodOverloading(TestCase): -    def test_method(self): -        """ -        Request methods should be same as underlying request. -        """ -        request = Request(factory.get('/')) -        self.assertEqual(request.method, 'GET') -        request = Request(factory.post('/')) -        self.assertEqual(request.method, 'POST') - -    def test_overloaded_method(self): -        """ -        POST requests can be overloaded to another method by setting a -        reserved form field -        """ -        request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'})) -        self.assertEqual(request.method, 'DELETE') - -    def test_x_http_method_override_header(self): -        """ -        POST requests can also be overloaded to another method by setting -        the X-HTTP-Method-Override header. -        """ -        request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) -        self.assertEqual(request.method, 'DELETE') - -        request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) -        self.assertEqual(request.method, 'DELETE') - - -class TestContentParsing(TestCase): -    def test_standard_behaviour_determines_no_content_GET(self): -        """ -        Ensure request.DATA returns empty QueryDict for GET request. -        """ -        request = Request(factory.get('/')) -        self.assertEqual(request.DATA, {}) - -    def test_standard_behaviour_determines_no_content_HEAD(self): -        """ -        Ensure request.DATA returns empty QueryDict for HEAD request. -        """ -        request = Request(factory.head('/')) -        self.assertEqual(request.DATA, {}) - -    def test_request_DATA_with_form_content(self): -        """ -        Ensure request.DATA returns content for POST request with form content. -        """ -        data = {'qwerty': 'uiop'} -        request = Request(factory.post('/', data)) -        request.parsers = (FormParser(), MultiPartParser()) -        self.assertEqual(list(request.DATA.items()), list(data.items())) - -    def test_request_DATA_with_text_content(self): -        """ -        Ensure request.DATA returns content for POST request with -        non-form content. -        """ -        content = six.b('qwerty') -        content_type = 'text/plain' -        request = Request(factory.post('/', content, content_type=content_type)) -        request.parsers = (PlainTextParser(),) -        self.assertEqual(request.DATA, content) - -    def test_request_POST_with_form_content(self): -        """ -        Ensure request.POST returns content for POST request with form content. -        """ -        data = {'qwerty': 'uiop'} -        request = Request(factory.post('/', data)) -        request.parsers = (FormParser(), MultiPartParser()) -        self.assertEqual(list(request.POST.items()), list(data.items())) - -    def test_standard_behaviour_determines_form_content_PUT(self): -        """ -        Ensure request.DATA returns content for PUT request with form content. -        """ -        data = {'qwerty': 'uiop'} -        request = Request(factory.put('/', data)) -        request.parsers = (FormParser(), MultiPartParser()) -        self.assertEqual(list(request.DATA.items()), list(data.items())) - -    def test_standard_behaviour_determines_non_form_content_PUT(self): -        """ -        Ensure request.DATA returns content for PUT request with -        non-form content. -        """ -        content = six.b('qwerty') -        content_type = 'text/plain' -        request = Request(factory.put('/', content, content_type=content_type)) -        request.parsers = (PlainTextParser(), ) -        self.assertEqual(request.DATA, content) - -    def test_overloaded_behaviour_allows_content_tunnelling(self): -        """ -        Ensure request.DATA returns content for overloaded POST request. -        """ -        json_data = {'foobar': 'qwerty'} -        content = json.dumps(json_data) -        content_type = 'application/json' -        form_data = { -            api_settings.FORM_CONTENT_OVERRIDE: content, -            api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type -        } -        request = Request(factory.post('/', form_data)) -        request.parsers = (JSONParser(), ) -        self.assertEqual(request.DATA, json_data) - -    def test_form_POST_unicode(self): -        """ -        JSON POST via default web interface with unicode data -        """ -        # Note: environ and other variables here have simplified content compared to real Request -        CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D' -        environ = { -            'REQUEST_METHOD': 'POST', -            'CONTENT_TYPE': 'application/x-www-form-urlencoded', -            'CONTENT_LENGTH': len(CONTENT), -            'wsgi.input': BytesIO(CONTENT), -        } -        wsgi_request = WSGIRequest(environ=environ) -        wsgi_request._load_post_and_files() -        parsers = (JSONParser(), FormParser(), MultiPartParser()) -        parser_context = { -            'encoding': 'utf-8', -            'kwargs': {}, -            'args': (), -        } -        request = Request(wsgi_request, parsers=parsers, parser_context=parser_context) -        method = request.method -        self.assertEqual(method, 'POST') -        self.assertEqual(request._content_type, 'application/json') -        self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}') -        self.assertEqual(request._data, Empty) -        self.assertEqual(request._files, Empty) - -    # def test_accessing_post_after_data_form(self): -    #     """ -    #     Ensures request.POST can be accessed after request.DATA in -    #     form request. -    #     """ -    #     data = {'qwerty': 'uiop'} -    #     request = factory.post('/', data=data) -    #     self.assertEqual(request.DATA.items(), data.items()) -    #     self.assertEqual(request.POST.items(), data.items()) - -    # def test_accessing_post_after_data_for_json(self): -    #     """ -    #     Ensures request.POST can be accessed after request.DATA in -    #     json request. -    #     """ -    #     data = {'qwerty': 'uiop'} -    #     content = json.dumps(data) -    #     content_type = 'application/json' -    #     parsers = (JSONParser, ) - -    #     request = factory.post('/', content, content_type=content_type, -    #                            parsers=parsers) -    #     self.assertEqual(request.DATA.items(), data.items()) -    #     self.assertEqual(request.POST.items(), []) - -    # def test_accessing_post_after_data_for_overloaded_json(self): -    #     """ -    #     Ensures request.POST can be accessed after request.DATA in overloaded -    #     json request. -    #     """ -    #     data = {'qwerty': 'uiop'} -    #     content = json.dumps(data) -    #     content_type = 'application/json' -    #     parsers = (JSONParser, ) -    #     form_data = {Request._CONTENT_PARAM: content, -    #                  Request._CONTENTTYPE_PARAM: content_type} - -    #     request = factory.post('/', form_data, parsers=parsers) -    #     self.assertEqual(request.DATA.items(), data.items()) -    #     self.assertEqual(request.POST.items(), form_data.items()) - -    # def test_accessing_data_after_post_form(self): -    #     """ -    #     Ensures request.DATA can be accessed after request.POST in -    #     form request. -    #     """ -    #     data = {'qwerty': 'uiop'} -    #     parsers = (FormParser, MultiPartParser) -    #     request = factory.post('/', data, parsers=parsers) - -    #     self.assertEqual(request.POST.items(), data.items()) -    #     self.assertEqual(request.DATA.items(), data.items()) - -    # def test_accessing_data_after_post_for_json(self): -    #     """ -    #     Ensures request.DATA can be accessed after request.POST in -    #     json request. -    #     """ -    #     data = {'qwerty': 'uiop'} -    #     content = json.dumps(data) -    #     content_type = 'application/json' -    #     parsers = (JSONParser, ) -    #     request = factory.post('/', content, content_type=content_type, -    #                            parsers=parsers) -    #     self.assertEqual(request.POST.items(), []) -    #     self.assertEqual(request.DATA.items(), data.items()) - -    # def test_accessing_data_after_post_for_overloaded_json(self): -    #     """ -    #     Ensures request.DATA can be accessed after request.POST in overloaded -    #     json request -    #     """ -    #     data = {'qwerty': 'uiop'} -    #     content = json.dumps(data) -    #     content_type = 'application/json' -    #     parsers = (JSONParser, ) -    #     form_data = {Request._CONTENT_PARAM: content, -    #                  Request._CONTENTTYPE_PARAM: content_type} - -    #     request = factory.post('/', form_data, parsers=parsers) -    #     self.assertEqual(request.POST.items(), form_data.items()) -    #     self.assertEqual(request.DATA.items(), data.items()) - - -class MockView(APIView): -    authentication_classes = (SessionAuthentication,) - -    def post(self, request): -        if request.POST.get('example') is not None: -            return Response(status=status.HTTP_200_OK) - -        return Response(status=status.INTERNAL_SERVER_ERROR) - -urlpatterns = patterns('', -    (r'^$', MockView.as_view()), -) - - -class TestContentParsingWithAuthentication(TestCase): -    urls = 'rest_framework.tests.test_request' - -    def setUp(self): -        self.csrf_client = APIClient(enforce_csrf_checks=True) -        self.username = 'john' -        self.email = 'lennon@thebeatles.com' -        self.password = 'password' -        self.user = User.objects.create_user(self.username, self.email, self.password) - -    def test_user_logged_in_authentication_has_POST_when_not_logged_in(self): -        """ -        Ensures request.POST exists after SessionAuthentication when user -        doesn't log in. -        """ -        content = {'example': 'example'} - -        response = self.client.post('/', content) -        self.assertEqual(status.HTTP_200_OK, response.status_code) - -        response = self.csrf_client.post('/', content) -        self.assertEqual(status.HTTP_200_OK, response.status_code) - -    # def test_user_logged_in_authentication_has_post_when_logged_in(self): -    #     """Ensures request.POST exists after UserLoggedInAuthentication when user does log in""" -    #     self.client.login(username='john', password='password') -    #     self.csrf_client.login(username='john', password='password') -    #     content = {'example': 'example'} - -    #     response = self.client.post('/', content) -    #     self.assertEqual(status.OK, response.status_code, "POST data is malformed") - -    #     response = self.csrf_client.post('/', content) -    #     self.assertEqual(status.OK, response.status_code, "POST data is malformed") - - -class TestUserSetter(TestCase): - -    def setUp(self): -        # Pass request object through session middleware so session is -        # available to login and logout functions -        self.request = Request(factory.get('/')) -        SessionMiddleware().process_request(self.request) - -        User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') -        self.user = authenticate(username='ringo', password='yellow') - -    def test_user_can_be_set(self): -        self.request.user = self.user -        self.assertEqual(self.request.user, self.user) - -    def test_user_can_login(self): -        login(self.request, self.user) -        self.assertEqual(self.request.user, self.user) - -    def test_user_can_logout(self): -        self.request.user = self.user -        self.assertFalse(self.request.user.is_anonymous()) -        logout(self.request) -        self.assertTrue(self.request.user.is_anonymous()) - - -class TestAuthSetter(TestCase): - -    def test_auth_can_be_set(self): -        request = Request(factory.get('/')) -        request.auth = 'DUMMY' -        self.assertEqual(request.auth, 'DUMMY') diff --git a/rest_framework/tests/test_response.py b/rest_framework/tests/test_response.py deleted file mode 100644 index eea3c641..00000000 --- a/rest_framework/tests/test_response.py +++ /dev/null @@ -1,278 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework.tests.models import BasicModel, BasicModelSerializer -from rest_framework.compat import patterns, url, include -from rest_framework.response import Response -from rest_framework.views import APIView -from rest_framework import generics -from rest_framework import routers -from rest_framework import status -from rest_framework.renderers import ( -    BaseRenderer, -    JSONRenderer, -    BrowsableAPIRenderer -) -from rest_framework import viewsets -from rest_framework.settings import api_settings -from rest_framework.compat import six - - -class MockPickleRenderer(BaseRenderer): -    media_type = 'application/pickle' - - -class MockJsonRenderer(BaseRenderer): -    media_type = 'application/json' - - -class MockTextMediaRenderer(BaseRenderer): -    media_type = 'text/html' - -DUMMYSTATUS = status.HTTP_200_OK -DUMMYCONTENT = 'dummycontent' - -RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii') -RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') - - -class RendererA(BaseRenderer): -    media_type = 'mock/renderera' -    format = "formata" - -    def render(self, data, media_type=None, renderer_context=None): -        return RENDERER_A_SERIALIZER(data) - - -class RendererB(BaseRenderer): -    media_type = 'mock/rendererb' -    format = "formatb" - -    def render(self, data, media_type=None, renderer_context=None): -        return RENDERER_B_SERIALIZER(data) - - -class RendererC(RendererB): -    media_type = 'mock/rendererc' -    format = 'formatc' -    charset = "rendererc" - - -class MockView(APIView): -    renderer_classes = (RendererA, RendererB, RendererC) - -    def get(self, request, **kwargs): -        return Response(DUMMYCONTENT, status=DUMMYSTATUS) - - -class MockViewSettingContentType(APIView): -    renderer_classes = (RendererA, RendererB, RendererC) - -    def get(self, request, **kwargs): -        return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview') - - -class HTMLView(APIView): -    renderer_classes = (BrowsableAPIRenderer, ) - -    def get(self, request, **kwargs): -        return Response('text') - - -class HTMLView1(APIView): -    renderer_classes = (BrowsableAPIRenderer, JSONRenderer) - -    def get(self, request, **kwargs): -        return Response('text') - - -class HTMLNewModelViewSet(viewsets.ModelViewSet): -    model = BasicModel - - -class HTMLNewModelView(generics.ListCreateAPIView): -    renderer_classes = (BrowsableAPIRenderer,) -    permission_classes = [] -    serializer_class = BasicModelSerializer -    model = BasicModel - - -new_model_viewset_router = routers.DefaultRouter() -new_model_viewset_router.register(r'', HTMLNewModelViewSet) - - -urlpatterns = patterns('', -    url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])), -    url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), -    url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), -    url(r'^html$', HTMLView.as_view()), -    url(r'^html1$', HTMLView1.as_view()), -    url(r'^html_new_model$', HTMLNewModelView.as_view()), -    url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)), -    url(r'^restframework', include('rest_framework.urls', namespace='rest_framework')) -) - - -# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... -class RendererIntegrationTests(TestCase): -    """ -    End-to-end testing of renderers using an ResponseMixin on a generic view. -    """ - -    urls = 'rest_framework.tests.test_response' - -    def test_default_renderer_serializes_content(self): -        """If the Accept header is not set the default renderer should serialize the response.""" -        resp = self.client.get('/') -        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_head_method_serializes_no_content(self): -        """No response must be included in HEAD requests.""" -        resp = self.client.head('/') -        self.assertEqual(resp.status_code, DUMMYSTATUS) -        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, six.b('')) - -    def test_default_renderer_serializes_content_on_accept_any(self): -        """If the Accept header is set to */* the default renderer should serialize the response.""" -        resp = self.client.get('/', HTTP_ACCEPT='*/*') -        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_serializes_content_default_case(self): -        """If the Accept header is set the specified renderer should serialize the response. -        (In this case we check that works for the default renderer)""" -        resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) -        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_serializes_content_non_default_case(self): -        """If the Accept header is set the specified renderer should serialize the response. -        (In this case we check that works for a non-default renderer)""" -        resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_serializes_content_on_accept_query(self): -        """The '_accept' query string should behave in the same way as the Accept header.""" -        param = '?%s=%s' % ( -            api_settings.URL_ACCEPT_OVERRIDE, -            RendererB.media_type -        ) -        resp = self.client.get('/' + param) -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_serializes_content_on_format_query(self): -        """If a 'format' query is specified, the renderer with the matching -        format attribute should serialize the response.""" -        resp = self.client.get('/?format=%s' % RendererB.format) -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_serializes_content_on_format_kwargs(self): -        """If a 'format' keyword arg is specified, the renderer with the matching -        format attribute should serialize the response.""" -        resp = self.client.get('/something.formatb') -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - -    def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): -        """If both a 'format' query and a matching Accept header specified, -        the renderer with the matching format attribute should serialize the response.""" -        resp = self.client.get('/?format=%s' % RendererB.format, -                               HTTP_ACCEPT=RendererB.media_type) -        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') -        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) -        self.assertEqual(resp.status_code, DUMMYSTATUS) - - -class Issue122Tests(TestCase): -    """ -    Tests that covers #122. -    """ -    urls = 'rest_framework.tests.test_response' - -    def test_only_html_renderer(self): -        """ -        Test if no infinite recursion occurs. -        """ -        self.client.get('/html') - -    def test_html_renderer_is_first(self): -        """ -        Test if no infinite recursion occurs. -        """ -        self.client.get('/html1') - - -class Issue467Tests(TestCase): -    """ -    Tests for #467 -    """ - -    urls = 'rest_framework.tests.test_response' - -    def test_form_has_label_and_help_text(self): -        resp = self.client.get('/html_new_model') -        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') -        self.assertContains(resp, 'Text comes here') -        self.assertContains(resp, 'Text description.') - - -class Issue807Tests(TestCase): -    """ -    Covers #807 -    """ - -    urls = 'rest_framework.tests.test_response' - -    def test_does_not_append_charset_by_default(self): -        """ -        Renderers don't include a charset unless set explicitly. -        """ -        headers = {"HTTP_ACCEPT": RendererA.media_type} -        resp = self.client.get('/', **headers) -        expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8') -        self.assertEqual(expected, resp['Content-Type']) - -    def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self): -        """ -        If renderer class has charset attribute declared, it gets appended -        to Response's Content-Type -        """ -        headers = {"HTTP_ACCEPT": RendererC.media_type} -        resp = self.client.get('/', **headers) -        expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset) -        self.assertEqual(expected, resp['Content-Type']) - -    def test_content_type_set_explictly_on_response(self): -        """ -        The content type may be set explictly on the response. -        """ -        headers = {"HTTP_ACCEPT": RendererC.media_type} -        resp = self.client.get('/setbyview', **headers) -        self.assertEqual('setbyview', resp['Content-Type']) - -    def test_viewset_label_help_text(self): -        param = '?%s=%s' % ( -            api_settings.URL_ACCEPT_OVERRIDE, -            'text/html' -        ) -        resp = self.client.get('/html_new_model_viewset/' + param) -        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') -        self.assertContains(resp, 'Text comes here') -        self.assertContains(resp, 'Text description.') - -    def test_form_has_label_and_help_text(self): -        resp = self.client.get('/html_new_model') -        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') -        self.assertContains(resp, 'Text comes here') -        self.assertContains(resp, 'Text description.') diff --git a/rest_framework/tests/test_reverse.py b/rest_framework/tests/test_reverse.py deleted file mode 100644 index 690a30b1..00000000 --- a/rest_framework/tests/test_reverse.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework.compat import patterns, url -from rest_framework.reverse import reverse -from rest_framework.test import APIRequestFactory - -factory = APIRequestFactory() - - -def null_view(request): -    pass - -urlpatterns = patterns('', -    url(r'^view$', null_view, name='view'), -) - - -class ReverseTests(TestCase): -    """ -    Tests for fully qualified URLs when using `reverse`. -    """ -    urls = 'rest_framework.tests.test_reverse' - -    def test_reversed_urls_are_fully_qualified(self): -        request = factory.get('/view') -        url = reverse('view', request=request) -        self.assertEqual(url, 'http://testserver/view') diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py deleted file mode 100644 index e723f7d4..00000000 --- a/rest_framework/tests/test_routers.py +++ /dev/null @@ -1,216 +0,0 @@ -from __future__ import unicode_literals -from django.db import models -from django.test import TestCase -from django.core.exceptions import ImproperlyConfigured -from rest_framework import serializers, viewsets, permissions -from rest_framework.compat import include, patterns, url -from rest_framework.decorators import link, action -from rest_framework.response import Response -from rest_framework.routers import SimpleRouter, DefaultRouter -from rest_framework.test import APIRequestFactory - -factory = APIRequestFactory() - -urlpatterns = patterns('',) - - -class BasicViewSet(viewsets.ViewSet): -    def list(self, request, *args, **kwargs): -        return Response({'method': 'list'}) - -    @action() -    def action1(self, request, *args, **kwargs): -        return Response({'method': 'action1'}) - -    @action() -    def action2(self, request, *args, **kwargs): -        return Response({'method': 'action2'}) - -    @action(methods=['post', 'delete']) -    def action3(self, request, *args, **kwargs): -        return Response({'method': 'action2'}) - -    @link() -    def link1(self, request, *args, **kwargs): -        return Response({'method': 'link1'}) - -    @link() -    def link2(self, request, *args, **kwargs): -        return Response({'method': 'link2'}) - - -class TestSimpleRouter(TestCase): -    def setUp(self): -        self.router = SimpleRouter() - -    def test_link_and_action_decorator(self): -        routes = self.router.get_routes(BasicViewSet) -        decorator_routes = routes[2:] -        # Make sure all these endpoints exist and none have been clobbered -        for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']): -            route = decorator_routes[i] -            # check url listing -            self.assertEqual(route.url, -                             '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) -            # check method to function mapping -            if endpoint == 'action3': -                methods_map = ['post', 'delete'] -            elif endpoint.startswith('action'): -                methods_map = ['post'] -            else: -                methods_map = ['get'] -            for method in methods_map: -                self.assertEqual(route.mapping[method], endpoint) - - -class RouterTestModel(models.Model): -    uuid = models.CharField(max_length=20) -    text = models.CharField(max_length=200) - - -class TestCustomLookupFields(TestCase): -    """ -    Ensure that custom lookup fields are correctly routed. -    """ -    urls = 'rest_framework.tests.test_routers' - -    def setUp(self): -        class NoteSerializer(serializers.HyperlinkedModelSerializer): -            class Meta: -                model = RouterTestModel -                lookup_field = 'uuid' -                fields = ('url', 'uuid', 'text') - -        class NoteViewSet(viewsets.ModelViewSet): -            queryset = RouterTestModel.objects.all() -            serializer_class = NoteSerializer -            lookup_field = 'uuid' - -        RouterTestModel.objects.create(uuid='123', text='foo bar') - -        self.router = SimpleRouter() -        self.router.register(r'notes', NoteViewSet) - -        from rest_framework.tests import test_routers -        urls = getattr(test_routers, 'urlpatterns') -        urls += patterns('', -            url(r'^', include(self.router.urls)), -        ) - -    def test_custom_lookup_field_route(self): -        detail_route = self.router.urls[-1] -        detail_url_pattern = detail_route.regex.pattern -        self.assertIn('<uuid>', detail_url_pattern) - -    def test_retrieve_lookup_field_list_view(self): -        response = self.client.get('/notes/') -        self.assertEqual(response.data, -            [{ -                "url": "http://testserver/notes/123/", -                "uuid": "123", "text": "foo bar" -            }] -        ) - -    def test_retrieve_lookup_field_detail_view(self): -        response = self.client.get('/notes/123/') -        self.assertEqual(response.data, -            { -                "url": "http://testserver/notes/123/", -                "uuid": "123", "text": "foo bar" -            } -        ) - - -class TestTrailingSlashIncluded(TestCase): -    def setUp(self): -        class NoteViewSet(viewsets.ModelViewSet): -            model = RouterTestModel - -        self.router = SimpleRouter() -        self.router.register(r'notes', NoteViewSet) -        self.urls = self.router.urls - -    def test_urls_have_trailing_slash_by_default(self): -        expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$'] -        for idx in range(len(expected)): -            self.assertEqual(expected[idx], self.urls[idx].regex.pattern) - - -class TestTrailingSlashRemoved(TestCase): -    def setUp(self): -        class NoteViewSet(viewsets.ModelViewSet): -            model = RouterTestModel - -        self.router = SimpleRouter(trailing_slash=False) -        self.router.register(r'notes', NoteViewSet) -        self.urls = self.router.urls - -    def test_urls_can_have_trailing_slash_removed(self): -        expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$'] -        for idx in range(len(expected)): -            self.assertEqual(expected[idx], self.urls[idx].regex.pattern) - - -class TestNameableRoot(TestCase): -    def setUp(self): -        class NoteViewSet(viewsets.ModelViewSet): -            model = RouterTestModel -        self.router = DefaultRouter() -        self.router.root_view_name = 'nameable-root' -        self.router.register(r'notes', NoteViewSet) -        self.urls = self.router.urls - -    def test_router_has_custom_name(self): -        expected = 'nameable-root' -        self.assertEqual(expected, self.urls[0].name) - - -class TestActionKeywordArgs(TestCase): -    """ -    Ensure keyword arguments passed in the `@action` decorator -    are properly handled.  Refs #940. -    """ - -    def setUp(self): -        class TestViewSet(viewsets.ModelViewSet): -            permission_classes = [] - -            @action(permission_classes=[permissions.AllowAny]) -            def custom(self, request, *args, **kwargs): -                return Response({ -                    'permission_classes': self.permission_classes -                }) - -        self.router = SimpleRouter() -        self.router.register(r'test', TestViewSet, base_name='test') -        self.view = self.router.urls[-1].callback - -    def test_action_kwargs(self): -        request = factory.post('/test/0/custom/') -        response = self.view(request) -        self.assertEqual( -            response.data, -            {'permission_classes': [permissions.AllowAny]} -        ) - - -class TestActionAppliedToExistingRoute(TestCase): -    """ -    Ensure `@action` decorator raises an except when applied -    to an existing route -    """ - -    def test_exception_raised_when_action_applied_to_existing_route(self): -        class TestViewSet(viewsets.ModelViewSet): - -            @action() -            def retrieve(self, request, *args, **kwargs): -                return Response({ -                    'hello': 'world' -                }) - -        self.router = SimpleRouter() -        self.router.register(r'test', TestViewSet, base_name='test') - -        with self.assertRaises(ImproperlyConfigured): -            self.router.urls diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py deleted file mode 100644 index 3ee2b38a..00000000 --- a/rest_framework/tests/test_serializer.py +++ /dev/null @@ -1,1949 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -from django.db import models -from django.db.models.fields import BLANK_CHOICE_DASH -from django.test import TestCase -from django.utils import unittest -from django.utils.datastructures import MultiValueDict -from django.utils.translation import ugettext_lazy as _ -from rest_framework import serializers, fields, relations -from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, -    BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, -    ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel) -from rest_framework.tests.models import BasicModelSerializer -import datetime -import pickle -try: -    import PIL -except: -    PIL = None - - -if PIL is not None: -    class AMOAFModel(RESTFrameworkModel): -        char_field = models.CharField(max_length=1024, blank=True) -        comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) -        decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) -        email_field = models.EmailField(max_length=1024, blank=True) -        file_field = models.FileField(upload_to='test', max_length=1024, blank=True) -        image_field = models.ImageField(upload_to='test', max_length=1024, blank=True) -        slug_field = models.SlugField(max_length=1024, blank=True) -        url_field = models.URLField(max_length=1024, blank=True) - -    class DVOAFModel(RESTFrameworkModel): -        positive_integer_field = models.PositiveIntegerField(blank=True) -        positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) -        email_field = models.EmailField(blank=True) -        file_field = models.FileField(upload_to='test', blank=True) -        image_field = models.ImageField(upload_to='test', blank=True) -        slug_field = models.SlugField(blank=True) -        url_field = models.URLField(blank=True) - - -class SubComment(object): -    def __init__(self, sub_comment): -        self.sub_comment = sub_comment - - -class Comment(object): -    def __init__(self, email, content, created): -        self.email = email -        self.content = content -        self.created = created or datetime.datetime.now() - -    def __eq__(self, other): -        return all([getattr(self, attr) == getattr(other, attr) -                    for attr in ('email', 'content', 'created')]) - -    def get_sub_comment(self): -        sub_comment = SubComment('And Merry Christmas!') -        return sub_comment - - -class CommentSerializer(serializers.Serializer): -    email = serializers.EmailField() -    content = serializers.CharField(max_length=1000) -    created = serializers.DateTimeField() -    sub_comment = serializers.Field(source='get_sub_comment.sub_comment') - -    def restore_object(self, data, instance=None): -        if instance is None: -            return Comment(**data) -        for key, val in data.items(): -            setattr(instance, key, val) -        return instance - - -class NamesSerializer(serializers.Serializer): -    first = serializers.CharField() -    last = serializers.CharField(required=False, default='') -    initials = serializers.CharField(required=False, default='') - - -class PersonIdentifierSerializer(serializers.Serializer): -    ssn = serializers.CharField() -    names = NamesSerializer(source='names', required=False) - - -class BookSerializer(serializers.ModelSerializer): -    isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'}) - -    class Meta: -        model = Book - - -class ActionItemSerializer(serializers.ModelSerializer): - -    class Meta: -        model = ActionItem - -class ActionItemSerializerOptionalFields(serializers.ModelSerializer): -    """ -    Intended to test that fields with `required=False` are excluded from validation. -    """ -    title = serializers.CharField(required=False) - -    class Meta: -        model = ActionItem -        fields = ('title',) - -class ActionItemSerializerCustomRestore(serializers.ModelSerializer): - -    class Meta: -        model = ActionItem - -    def restore_object(self, data, instance=None): -        if instance is None: -            return ActionItem(**data) -        for key, val in data.items(): -            setattr(instance, key, val) -        return instance - - -class PersonSerializer(serializers.ModelSerializer): -    info = serializers.Field(source='info') - -    class Meta: -        model = Person -        fields = ('name', 'age', 'info') -        read_only_fields = ('age',) - - -class NestedSerializer(serializers.Serializer): -    info = serializers.Field() - - -class ModelSerializerWithNestedSerializer(serializers.ModelSerializer): -    nested = NestedSerializer(source='*') - -    class Meta: -        model = Person - - -class NestedSerializerWithRenamedField(serializers.Serializer): -    renamed_info = serializers.Field(source='info') - - -class ModelSerializerWithNestedSerializerWithRenamedField(serializers.ModelSerializer): -    nested = NestedSerializerWithRenamedField(source='*') - -    class Meta: -        model = Person - - -class PersonSerializerInvalidReadOnly(serializers.ModelSerializer): -    """ -    Testing for #652. -    """ -    info = serializers.Field(source='info') - -    class Meta: -        model = Person -        fields = ('name', 'age', 'info') -        read_only_fields = ('age', 'info') - - -class AlbumsSerializer(serializers.ModelSerializer): - -    class Meta: -        model = Album -        fields = ['title', 'ref']  # lists are also valid options - - -class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): -    class Meta: -        model = HasPositiveIntegerAsChoice -        fields = ['some_integer'] - - -class BasicTests(TestCase): -    def setUp(self): -        self.comment = Comment( -            'tom@example.com', -            'Happy new year!', -            datetime.datetime(2012, 1, 1) -        ) -        self.actionitem = ActionItem(title='Some to do item',) -        self.data = { -            'email': 'tom@example.com', -            'content': 'Happy new year!', -            'created': datetime.datetime(2012, 1, 1), -            'sub_comment': 'This wont change' -        } -        self.expected = { -            'email': 'tom@example.com', -            'content': 'Happy new year!', -            'created': datetime.datetime(2012, 1, 1), -            'sub_comment': 'And Merry Christmas!' -        } -        self.person_data = {'name': 'dwight', 'age': 35} -        self.person = Person(**self.person_data) -        self.person.save() - -    def test_empty(self): -        serializer = CommentSerializer() -        expected = { -            'email': '', -            'content': '', -            'created': None -        } -        self.assertEqual(serializer.data, expected) - -    def test_retrieve(self): -        serializer = CommentSerializer(self.comment) -        self.assertEqual(serializer.data, self.expected) - -    def test_create(self): -        serializer = CommentSerializer(data=self.data) -        expected = self.comment -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, expected) -        self.assertFalse(serializer.object is expected) -        self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') - -    def test_create_nested(self): -        """Test a serializer with nested data.""" -        names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'} -        data = {'ssn': '1234567890', 'names': names} -        serializer = PersonIdentifierSerializer(data=data) - -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, data) -        self.assertFalse(serializer.object is data) -        self.assertEqual(serializer.data['names'], names) - -    def test_create_partial_nested(self): -        """Test a serializer with nested data which has missing fields.""" -        names = {'first': 'John'} -        data = {'ssn': '1234567890', 'names': names} -        serializer = PersonIdentifierSerializer(data=data) - -        expected_names = {'first': 'John', 'last': '', 'initials': ''} -        data['names'] = expected_names - -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, data) -        self.assertFalse(serializer.object is expected_names) -        self.assertEqual(serializer.data['names'], expected_names) - -    def test_null_nested(self): -        """Test a serializer with a nonexistent nested field""" -        data = {'ssn': '1234567890'} -        serializer = PersonIdentifierSerializer(data=data) - -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, data) -        self.assertFalse(serializer.object is data) -        expected = {'ssn': '1234567890', 'names': None} -        self.assertEqual(serializer.data, expected) - -    def test_update(self): -        serializer = CommentSerializer(self.comment, data=self.data) -        expected = self.comment -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, expected) -        self.assertTrue(serializer.object is expected) -        self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') - -    def test_partial_update(self): -        msg = 'Merry New Year!' -        partial_data = {'content': msg} -        serializer = CommentSerializer(self.comment, data=partial_data) -        self.assertEqual(serializer.is_valid(), False) -        serializer = CommentSerializer(self.comment, data=partial_data, partial=True) -        expected = self.comment -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, expected) -        self.assertTrue(serializer.object is expected) -        self.assertEqual(serializer.data['content'], msg) - -    def test_model_fields_as_expected(self): -        """ -        Make sure that the fields returned are the same as defined -        in the Meta data -        """ -        serializer = PersonSerializer(self.person) -        self.assertEqual(set(serializer.data.keys()), -                          set(['name', 'age', 'info'])) - -    def test_field_with_dictionary(self): -        """ -        Make sure that dictionaries from fields are left intact -        """ -        serializer = PersonSerializer(self.person) -        expected = self.person_data -        self.assertEqual(serializer.data['info'], expected) - -    def test_read_only_fields(self): -        """ -        Attempting to update fields set as read_only should have no effect. -        """ -        serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(serializer.errors, {}) -        # Assert age is unchanged (35) -        self.assertEqual(instance.age, self.person_data['age']) - -    def test_invalid_read_only_fields(self): -        """ -        Regression test for #652. -        """ -        self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, []) - -    def test_serializer_data_is_cleared_on_save(self): -        """ -        Check _data attribute is cleared on `save()` - -        Regression test for #1116 -            — id field is not populated if `data` is accessed prior to `save()` -        """ -        serializer = ActionItemSerializer(self.actionitem) -        self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.') -        serializer.save() -        self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.') - -    def test_fields_marked_as_not_required_are_excluded_from_validation(self): -        """ -        Check that fields with `required=False` are included in list of exclusions. -        """ -        serializer = ActionItemSerializerOptionalFields(self.actionitem) -        exclusions = serializer.get_validation_exclusions() -        self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded') - - -class DictStyleSerializer(serializers.Serializer): -    """ -    Note that we don't have any `restore_object` method, so the default -    case of simply returning a dict will apply. -    """ -    email = serializers.EmailField() - - -class DictStyleSerializerTests(TestCase): -    def test_dict_style_deserialize(self): -        """ -        Ensure serializers can deserialize into a dict. -        """ -        data = {'email': 'foo@example.com'} -        serializer = DictStyleSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, data) - -    def test_dict_style_serialize(self): -        """ -        Ensure serializers can serialize dict objects. -        """ -        data = {'email': 'foo@example.com'} -        serializer = DictStyleSerializer(data) -        self.assertEqual(serializer.data, data) - - -class ValidationTests(TestCase): -    def setUp(self): -        self.comment = Comment( -            'tom@example.com', -            'Happy new year!', -            datetime.datetime(2012, 1, 1) -        ) -        self.data = { -            'email': 'tom@example.com', -            'content': 'x' * 1001, -            'created': datetime.datetime(2012, 1, 1) -        } -        self.actionitem = ActionItem(title='Some to do item',) - -    def test_create(self): -        serializer = CommentSerializer(data=self.data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) - -    def test_update(self): -        serializer = CommentSerializer(self.comment, data=self.data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) - -    def test_update_missing_field(self): -        data = { -            'content': 'xxx', -            'created': datetime.datetime(2012, 1, 1) -        } -        serializer = CommentSerializer(self.comment, data=data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'email': ['This field is required.']}) - -    def test_missing_bool_with_default(self): -        """Make sure that a boolean value with a 'False' value is not -        mistaken for not having a default.""" -        data = { -            'title': 'Some action item', -            #No 'done' value. -        } -        serializer = ActionItemSerializer(self.actionitem, data=data) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.errors, {}) - -    def test_cross_field_validation(self): - -        class CommentSerializerWithCrossFieldValidator(CommentSerializer): - -            def validate(self, attrs): -                if attrs["email"] not in attrs["content"]: -                    raise serializers.ValidationError("Email address not in content") -                return attrs - -        data = { -            'email': 'tom@example.com', -            'content': 'A comment from tom@example.com', -            'created': datetime.datetime(2012, 1, 1) -        } - -        serializer = CommentSerializerWithCrossFieldValidator(data=data) -        self.assertTrue(serializer.is_valid()) - -        data['content'] = 'A comment from foo@bar.com' - -        serializer = CommentSerializerWithCrossFieldValidator(data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']}) - -    def test_null_is_true_fields(self): -        """ -        Omitting a value for null-field should validate. -        """ -        serializer = PersonSerializer(data={'name': 'marko'}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.errors, {}) - -    def test_modelserializer_max_length_exceeded(self): -        data = { -            'title': 'x' * 201, -        } -        serializer = ActionItemSerializer(data=data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) - -    def test_modelserializer_max_length_exceeded_with_custom_restore(self): -        """ -        When overriding ModelSerializer.restore_object, validation tests should still apply. -        Regression test for #623. - -        https://github.com/tomchristie/django-rest-framework/pull/623 -        """ -        data = { -            'title': 'x' * 201, -        } -        serializer = ActionItemSerializerCustomRestore(data=data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) - -    def test_default_modelfield_max_length_exceeded(self): -        data = { -            'title': 'Testing "info" field...', -            'info': 'x' * 13, -        } -        serializer = ActionItemSerializer(data=data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']}) - -    def test_datetime_validation_failure(self): -        """ -        Test DateTimeField validation errors on non-str values. -        Regression test for #669. - -        https://github.com/tomchristie/django-rest-framework/issues/669 -        """ -        data = self.data -        data['created'] = 0 - -        serializer = CommentSerializer(data=data) -        self.assertEqual(serializer.is_valid(), False) - -        self.assertIn('created', serializer.errors) - -    def test_missing_model_field_exception_msg(self): -        """ -        Assert that a meaningful exception message is outputted when the model -        field is missing (e.g. when mistyping ``model``). -        """ -        class BrokenModelSerializer(serializers.ModelSerializer): -            class Meta: -                fields = ['some_field'] - -        try: -            BrokenModelSerializer() -        except AssertionError as e: -            self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option") -        except: -            self.fail('Wrong exception type thrown.') - -    def test_writable_star_source_on_nested_serializer(self): -        """ -        Assert that a nested serializer instantiated with source='*' correctly -        expands the data into the outer serializer. -        """ -        serializer = ModelSerializerWithNestedSerializer(data={ -            'name': 'marko', -            'nested': {'info': 'hi'}}, -        ) -        self.assertEqual(serializer.is_valid(), True) - -    def test_writable_star_source_on_nested_serializer_with_parent_object(self): -        class TitleSerializer(serializers.Serializer): -            title = serializers.WritableField(source='title') - -        class AlbumSerializer(serializers.ModelSerializer): -            nested = TitleSerializer(source='*') - -            class Meta: -                model = Album -                fields = ('nested',) - -        class PhotoSerializer(serializers.ModelSerializer): -            album = AlbumSerializer(source='album') - -            class Meta: -                model = Photo -                fields = ('album', ) - -        photo = Photo(album=Album()) - -        data = {'album': {'nested': {'title': 'test'}}} - -        serializer = PhotoSerializer(photo, data=data) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.data, data) - -    def test_writable_star_source_with_inner_source_fields(self): -        """ -        Tests that a serializer with source="*" correctly expands the -        it's fields into the outer serializer even if they have their -        own 'source' parameters. -        """ - -        serializer = ModelSerializerWithNestedSerializerWithRenamedField(data={ -            'name': 'marko', -            'nested': {'renamed_info': 'hi'}}, -        ) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.errors, {}) - - -class CustomValidationTests(TestCase): -    class CommentSerializerWithFieldValidator(CommentSerializer): - -        def validate_email(self, attrs, source): -            attrs[source] -            return attrs - -        def validate_content(self, attrs, source): -            value = attrs[source] -            if "test" not in value: -                raise serializers.ValidationError("Test not in value") -            return attrs - -    def test_field_validation(self): -        data = { -            'email': 'tom@example.com', -            'content': 'A test comment', -            'created': datetime.datetime(2012, 1, 1) -        } - -        serializer = self.CommentSerializerWithFieldValidator(data=data) -        self.assertTrue(serializer.is_valid()) - -        data['content'] = 'This should not validate' - -        serializer = self.CommentSerializerWithFieldValidator(data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'content': ['Test not in value']}) - -    def test_missing_data(self): -        """ -        Make sure that validate_content isn't called if the field is missing -        """ -        incomplete_data = { -            'email': 'tom@example.com', -            'created': datetime.datetime(2012, 1, 1) -        } -        serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'content': ['This field is required.']}) - -    def test_wrong_data(self): -        """ -        Make sure that validate_content isn't called if the field input is wrong -        """ -        wrong_data = { -            'email': 'not an email', -            'content': 'A test comment', -            'created': datetime.datetime(2012, 1, 1) -        } -        serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']}) - -    def test_partial_update(self): -        """ -        Make sure that validate_email isn't called when partial=True and email -        isn't found in data. -        """ -        initial_data = { -            'email': 'tom@example.com', -            'content': 'A test comment', -            'created': datetime.datetime(2012, 1, 1) -        } - -        serializer = self.CommentSerializerWithFieldValidator(data=initial_data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.object - -        new_content = 'An *updated* test comment' -        partial_data = { -            'content': new_content -        } - -        serializer = self.CommentSerializerWithFieldValidator(instance=instance, -                                                              data=partial_data, -                                                              partial=True) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.object -        self.assertEqual(instance.content, new_content) - - -class PositiveIntegerAsChoiceTests(TestCase): -    def test_positive_integer_in_json_is_correctly_parsed(self): -        data = {'some_integer': 1} -        serializer = PositiveIntegerAsChoiceSerializer(data=data) -        self.assertEqual(serializer.is_valid(), True) - - -class ModelValidationTests(TestCase): -    def test_validate_unique(self): -        """ -        Just check if serializers.ModelSerializer handles unique checks via .full_clean() -        """ -        serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'}) -        serializer.is_valid() -        serializer.save() -        second_serializer = AlbumsSerializer(data={'title': 'a'}) -        self.assertFalse(second_serializer.is_valid()) -        self.assertEqual(second_serializer.errors,  {'title': ['Album with this Title already exists.'],}) -        third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}]) -        self.assertFalse(third_serializer.is_valid()) -        self.assertEqual(third_serializer.errors,  [{'ref': ['Album with this Ref already exists.']}, {}]) - -    def test_foreign_key_is_null_with_partial(self): -        """ -        Test ModelSerializer validation with partial=True - -        Specifically test that a null foreign key does not pass validation -        """ -        album = Album(title='test') -        album.save() - -        class PhotoSerializer(serializers.ModelSerializer): -            class Meta: -                model = Photo - -        photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) -        self.assertTrue(photo_serializer.is_valid()) -        photo = photo_serializer.save() - -        # Updating only the album (foreign key) -        photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True) -        self.assertFalse(photo_serializer.is_valid()) -        self.assertTrue('album' in photo_serializer.errors) -        self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required']) - -    def test_foreign_key_with_partial(self): -        """ -        Test ModelSerializer validation with partial=True - -        Specifically test foreign key validation. -        """ - -        album = Album(title='test') -        album.save() - -        class PhotoSerializer(serializers.ModelSerializer): -            class Meta: -                model = Photo - -        photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) -        self.assertTrue(photo_serializer.is_valid()) -        photo = photo_serializer.save() - -        # Updating only the album (foreign key) -        photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True) -        self.assertTrue(photo_serializer.is_valid()) -        self.assertTrue(photo_serializer.save()) - -        # Updating only the description -        photo_serializer = PhotoSerializer(instance=photo, -                                           data={'description': 'new'}, -                                           partial=True) - -        self.assertTrue(photo_serializer.is_valid()) -        self.assertTrue(photo_serializer.save()) - - -class RegexValidationTest(TestCase): -    def test_create_failed(self): -        serializer = BookSerializer(data={'isbn': '1234567890'}) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) - -        serializer = BookSerializer(data={'isbn': '12345678901234'}) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) - -        serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) -        self.assertFalse(serializer.is_valid()) -        self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) - -    def test_create_success(self): -        serializer = BookSerializer(data={'isbn': '1234567890123'}) -        self.assertTrue(serializer.is_valid()) - - -class MetadataTests(TestCase): -    def test_empty(self): -        serializer = CommentSerializer() -        expected = { -            'email': serializers.CharField, -            'content': serializers.CharField, -            'created': serializers.DateTimeField -        } -        for field_name, field in expected.items(): -            self.assertTrue(isinstance(serializer.data.fields[field_name], field)) - - -class ManyToManyTests(TestCase): -    def setUp(self): -        class ManyToManySerializer(serializers.ModelSerializer): -            class Meta: -                model = ManyToManyModel - -        self.serializer_class = ManyToManySerializer - -        # An anchor instance to use for the relationship -        self.anchor = Anchor() -        self.anchor.save() - -        # A model instance with a many to many relationship to the anchor -        self.instance = ManyToManyModel() -        self.instance.save() -        self.instance.rel.add(self.anchor) - -        # A serialized representation of the model instance -        self.data = {'id': 1, 'rel': [self.anchor.id]} - -    def test_retrieve(self): -        """ -        Serialize an instance of a model with a ManyToMany relationship. -        """ -        serializer = self.serializer_class(instance=self.instance) -        expected = self.data -        self.assertEqual(serializer.data, expected) - -    def test_create(self): -        """ -        Create an instance of a model with a ManyToMany relationship. -        """ -        data = {'rel': [self.anchor.id]} -        serializer = self.serializer_class(data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(ManyToManyModel.objects.all()), 2) -        self.assertEqual(instance.pk, 2) -        self.assertEqual(list(instance.rel.all()), [self.anchor]) - -    def test_update(self): -        """ -        Update an instance of a model with a ManyToMany relationship. -        """ -        new_anchor = Anchor() -        new_anchor.save() -        data = {'rel': [self.anchor.id, new_anchor.id]} -        serializer = self.serializer_class(self.instance, data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(ManyToManyModel.objects.all()), 1) -        self.assertEqual(instance.pk, 1) -        self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor]) - -    def test_create_empty_relationship(self): -        """ -        Create an instance of a model with a ManyToMany relationship, -        containing no items. -        """ -        data = {'rel': []} -        serializer = self.serializer_class(data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(ManyToManyModel.objects.all()), 2) -        self.assertEqual(instance.pk, 2) -        self.assertEqual(list(instance.rel.all()), []) - -    def test_update_empty_relationship(self): -        """ -        Update an instance of a model with a ManyToMany relationship, -        containing no items. -        """ -        new_anchor = Anchor() -        new_anchor.save() -        data = {'rel': []} -        serializer = self.serializer_class(self.instance, data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(ManyToManyModel.objects.all()), 1) -        self.assertEqual(instance.pk, 1) -        self.assertEqual(list(instance.rel.all()), []) - -    def test_create_empty_relationship_flat_data(self): -        """ -        Create an instance of a model with a ManyToMany relationship, -        containing no items, using a representation that does not support -        lists (eg form data). -        """ -        data = MultiValueDict() -        data.setlist('rel', ['']) -        serializer = self.serializer_class(data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(ManyToManyModel.objects.all()), 2) -        self.assertEqual(instance.pk, 2) -        self.assertEqual(list(instance.rel.all()), []) - - -class ReadOnlyManyToManyTests(TestCase): -    def setUp(self): -        class ReadOnlyManyToManySerializer(serializers.ModelSerializer): -            rel = serializers.RelatedField(many=True, read_only=True) - -            class Meta: -                model = ReadOnlyManyToManyModel - -        self.serializer_class = ReadOnlyManyToManySerializer - -        # An anchor instance to use for the relationship -        self.anchor = Anchor() -        self.anchor.save() - -        # A model instance with a many to many relationship to the anchor -        self.instance = ReadOnlyManyToManyModel() -        self.instance.save() -        self.instance.rel.add(self.anchor) - -        # A serialized representation of the model instance -        self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'} - -    def test_update(self): -        """ -        Attempt to update an instance of a model with a ManyToMany -        relationship.  Not updated due to read_only=True -        """ -        new_anchor = Anchor() -        new_anchor.save() -        data = {'rel': [self.anchor.id, new_anchor.id]} -        serializer = self.serializer_class(self.instance, data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) -        self.assertEqual(instance.pk, 1) -        # rel is still as original (1 entry) -        self.assertEqual(list(instance.rel.all()), [self.anchor]) - -    def test_update_without_relationship(self): -        """ -        Attempt to update an instance of a model where many to ManyToMany -        relationship is not supplied.  Not updated due to read_only=True -        """ -        new_anchor = Anchor() -        new_anchor.save() -        data = {} -        serializer = self.serializer_class(self.instance, data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) -        self.assertEqual(instance.pk, 1) -        # rel is still as original (1 entry) -        self.assertEqual(list(instance.rel.all()), [self.anchor]) - - -class DefaultValueTests(TestCase): -    def setUp(self): -        class DefaultValueSerializer(serializers.ModelSerializer): -            class Meta: -                model = DefaultValueModel - -        self.serializer_class = DefaultValueSerializer -        self.objects = DefaultValueModel.objects - -    def test_create_using_default(self): -        data = {} -        serializer = self.serializer_class(data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(self.objects.all()), 1) -        self.assertEqual(instance.pk, 1) -        self.assertEqual(instance.text, 'foobar') - -    def test_create_overriding_default(self): -        data = {'text': 'overridden'} -        serializer = self.serializer_class(data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(self.objects.all()), 1) -        self.assertEqual(instance.pk, 1) -        self.assertEqual(instance.text, 'overridden') - -    def test_partial_update_default(self): -        """ Regression test for issue #532 """ -        data = {'text': 'overridden'} -        serializer = self.serializer_class(data=data, partial=True) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() - -        data = {'extra': 'extra_value'} -        serializer = self.serializer_class(instance=instance, data=data, partial=True) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() - -        self.assertEqual(instance.extra, 'extra_value') -        self.assertEqual(instance.text, 'overridden') - - -class WritableFieldDefaultValueTests(TestCase): - -    def setUp(self): -        self.expected = {'default': 'value'} -        self.create_field = fields.WritableField - -    def test_get_default_value_with_noncallable(self): -        field = self.create_field(default=self.expected) -        got = field.get_default_value() -        self.assertEqual(got, self.expected) - -    def test_get_default_value_with_callable(self): -        field = self.create_field(default=lambda : self.expected) -        got = field.get_default_value() -        self.assertEqual(got, self.expected) - -    def test_get_default_value_when_not_required(self): -        field = self.create_field(default=self.expected, required=False) -        got = field.get_default_value() -        self.assertEqual(got, self.expected) - -    def test_get_default_value_returns_None(self): -        field = self.create_field() -        got = field.get_default_value() -        self.assertIsNone(got) - -    def test_get_default_value_returns_non_True_values(self): -        values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause -        for expected in values: -            field = self.create_field(default=expected) -            got = field.get_default_value() -            self.assertEqual(got, expected) - - -class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): - -    def setUp(self): -        self.expected = {'foo': 'bar'} -        self.create_field = relations.RelatedField - -    def test_get_default_value_returns_empty_list(self): -        field = self.create_field(many=True) -        got = field.get_default_value() -        self.assertListEqual(got, []) - -    def test_get_default_value_returns_expected(self): -        expected = [1, 2, 3] -        field = self.create_field(many=True, default=expected) -        got = field.get_default_value() -        self.assertListEqual(got, expected) - - -class CallableDefaultValueTests(TestCase): -    def setUp(self): -        class CallableDefaultValueSerializer(serializers.ModelSerializer): -            class Meta: -                model = CallableDefaultValueModel - -        self.serializer_class = CallableDefaultValueSerializer -        self.objects = CallableDefaultValueModel.objects - -    def test_create_using_default(self): -        data = {} -        serializer = self.serializer_class(data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(self.objects.all()), 1) -        self.assertEqual(instance.pk, 1) -        self.assertEqual(instance.text, 'foobar') - -    def test_create_overriding_default(self): -        data = {'text': 'overridden'} -        serializer = self.serializer_class(data=data) -        self.assertEqual(serializer.is_valid(), True) -        instance = serializer.save() -        self.assertEqual(len(self.objects.all()), 1) -        self.assertEqual(instance.pk, 1) -        self.assertEqual(instance.text, 'overridden') - - -class ManyRelatedTests(TestCase): -    def test_reverse_relations(self): -        post = BlogPost.objects.create(title="Test blog post") -        post.blogpostcomment_set.create(text="I hate this blog post") -        post.blogpostcomment_set.create(text="I love this blog post") - -        class BlogPostCommentSerializer(serializers.Serializer): -            text = serializers.CharField() - -        class BlogPostSerializer(serializers.Serializer): -            title = serializers.CharField() -            comments = BlogPostCommentSerializer(source='blogpostcomment_set') - -        serializer = BlogPostSerializer(instance=post) -        expected = { -            'title': 'Test blog post', -            'comments': [ -                {'text': 'I hate this blog post'}, -                {'text': 'I love this blog post'} -            ] -        } - -        self.assertEqual(serializer.data, expected) - -    def test_include_reverse_relations(self): -        post = BlogPost.objects.create(title="Test blog post") -        post.blogpostcomment_set.create(text="I hate this blog post") -        post.blogpostcomment_set.create(text="I love this blog post") - -        class BlogPostSerializer(serializers.ModelSerializer): -            class Meta: -                model = BlogPost -                fields = ('id', 'title', 'blogpostcomment_set') - -        serializer = BlogPostSerializer(instance=post) -        expected = { -            'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] -        } -        self.assertEqual(serializer.data, expected) - -    def test_depth_include_reverse_relations(self): -        post = BlogPost.objects.create(title="Test blog post") -        post.blogpostcomment_set.create(text="I hate this blog post") -        post.blogpostcomment_set.create(text="I love this blog post") - -        class BlogPostSerializer(serializers.ModelSerializer): -            class Meta: -                model = BlogPost -                fields = ('id', 'title', 'blogpostcomment_set') -                depth = 1 - -        serializer = BlogPostSerializer(instance=post) -        expected = { -            'id': 1, 'title': 'Test blog post', -            'blogpostcomment_set': [ -                {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, -                {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} -            ] -        } -        self.assertEqual(serializer.data, expected) - -    def test_callable_source(self): -        post = BlogPost.objects.create(title="Test blog post") -        post.blogpostcomment_set.create(text="I love this blog post") - -        class BlogPostCommentSerializer(serializers.Serializer): -            text = serializers.CharField() - -        class BlogPostSerializer(serializers.Serializer): -            title = serializers.CharField() -            first_comment = BlogPostCommentSerializer(source='get_first_comment') - -        serializer = BlogPostSerializer(post) - -        expected = { -            'title': 'Test blog post', -            'first_comment': {'text': 'I love this blog post'} -        } -        self.assertEqual(serializer.data, expected) - - -class RelatedTraversalTest(TestCase): -    def test_nested_traversal(self): -        """ -        Source argument should support dotted.source notation. -        """ -        user = Person.objects.create(name="django") -        post = BlogPost.objects.create(title="Test blog post", writer=user) -        post.blogpostcomment_set.create(text="I love this blog post") - -        class PersonSerializer(serializers.ModelSerializer): -            class Meta: -                model = Person -                fields = ("name", "age") - -        class BlogPostCommentSerializer(serializers.ModelSerializer): -            class Meta: -                model = BlogPostComment -                fields = ("text", "post_owner") - -            text = serializers.CharField() -            post_owner = PersonSerializer(source='blog_post.writer') - -        class BlogPostSerializer(serializers.Serializer): -            title = serializers.CharField() -            comments = BlogPostCommentSerializer(source='blogpostcomment_set') - -        serializer = BlogPostSerializer(instance=post) - -        expected = { -            'title': 'Test blog post', -            'comments': [{ -                'text': 'I love this blog post', -                'post_owner': { -                    "name": "django", -                    "age": None -                } -            }] -        } - -        self.assertEqual(serializer.data, expected) - -    def test_nested_traversal_with_none(self): -        """ -        If a component of the dotted.source is None, return None for the field. -        """ -        from rest_framework.tests.models import NullableForeignKeySource -        instance = NullableForeignKeySource.objects.create(name='Source with null FK') - -        class NullableSourceSerializer(serializers.Serializer): -            target_name = serializers.Field(source='target.name') - -        serializer = NullableSourceSerializer(instance=instance) - -        expected = { -            'target_name': None, -        } - -        self.assertEqual(serializer.data, expected) - - -class SerializerMethodFieldTests(TestCase): -    def setUp(self): - -        class BoopSerializer(serializers.Serializer): -            beep = serializers.SerializerMethodField('get_beep') -            boop = serializers.Field() -            boop_count = serializers.SerializerMethodField('get_boop_count') - -            def get_beep(self, obj): -                return 'hello!' - -            def get_boop_count(self, obj): -                return len(obj.boop) - -        self.serializer_class = BoopSerializer - -    def test_serializer_method_field(self): - -        class MyModel(object): -            boop = ['a', 'b', 'c'] - -        source_data = MyModel() - -        serializer = self.serializer_class(source_data) - -        expected = { -            'beep': 'hello!', -            'boop': ['a', 'b', 'c'], -            'boop_count': 3, -        } - -        self.assertEqual(serializer.data, expected) - - -# Test for issue #324 -class BlankFieldTests(TestCase): -    def setUp(self): - -        class BlankFieldModelSerializer(serializers.ModelSerializer): -            class Meta: -                model = BlankFieldModel - -        class BlankFieldSerializer(serializers.Serializer): -            title = serializers.CharField(required=False) - -        class NotBlankFieldModelSerializer(serializers.ModelSerializer): -            class Meta: -                model = BasicModel - -        class NotBlankFieldSerializer(serializers.Serializer): -            title = serializers.CharField() - -        self.model_serializer_class = BlankFieldModelSerializer -        self.serializer_class = BlankFieldSerializer -        self.not_blank_model_serializer_class = NotBlankFieldModelSerializer -        self.not_blank_serializer_class = NotBlankFieldSerializer -        self.data = {'title': ''} - -    def test_create_blank_field(self): -        serializer = self.serializer_class(data=self.data) -        self.assertEqual(serializer.is_valid(), True) - -    def test_create_model_blank_field(self): -        serializer = self.model_serializer_class(data=self.data) -        self.assertEqual(serializer.is_valid(), True) - -    def test_create_model_null_field(self): -        serializer = self.model_serializer_class(data={'title': None}) -        self.assertEqual(serializer.is_valid(), True) - -    def test_create_not_blank_field(self): -        """ -        Test to ensure blank data in a field not marked as blank=True -        is considered invalid in a non-model serializer -        """ -        serializer = self.not_blank_serializer_class(data=self.data) -        self.assertEqual(serializer.is_valid(), False) - -    def test_create_model_not_blank_field(self): -        """ -        Test to ensure blank data in a field not marked as blank=True -        is considered invalid in a model serializer -        """ -        serializer = self.not_blank_model_serializer_class(data=self.data) -        self.assertEqual(serializer.is_valid(), False) - -    def test_create_model_empty_field(self): -        serializer = self.model_serializer_class(data={}) -        self.assertEqual(serializer.is_valid(), True) - - -#test for issue #460 -class SerializerPickleTests(TestCase): -    """ -    Test pickleability of the output of Serializers -    """ -    def test_pickle_simple_model_serializer_data(self): -        """ -        Test simple serializer -        """ -        pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data) - -    def test_pickle_inner_serializer(self): -        """ -        Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will -        have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle. -        See DictWithMetadata.__getstate__ -        """ -        class InnerPersonSerializer(serializers.ModelSerializer): -            class Meta: -                model = Person -                fields = ('name', 'age') -        pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0) - -    def test_getstate_method_should_not_return_none(self): -        """ -        Regression test for #645. -        """ -        data = serializers.DictWithMetadata({1: 1}) -        self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1})) - -    def test_serializer_data_is_pickleable(self): -        """ -        Another regression test for #645. -        """ -        data = serializers.SortedDictWithMetadata({1: 1}) -        repr(pickle.loads(pickle.dumps(data, 0))) - - -# test for issue #725 -class SeveralChoicesModel(models.Model): -    color = models.CharField( -        max_length=10, -        choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')], -        blank=False -    ) -    drink = models.CharField( -        max_length=10, -        choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')], -        blank=False, -        default='beer' -    ) -    os = models.CharField( -        max_length=10, -        choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')], -        blank=True -    ) -    music_genre = models.CharField( -        max_length=10, -        choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')], -        blank=True, -        default='metal' -    ) - - -class SerializerChoiceFields(TestCase): - -    def setUp(self): -        super(SerializerChoiceFields, self).setUp() - -        class SeveralChoicesSerializer(serializers.ModelSerializer): -            class Meta: -                model = SeveralChoicesModel -                fields = ('color', 'drink', 'os', 'music_genre') - -        self.several_choices_serializer = SeveralChoicesSerializer - -    def test_choices_blank_false_not_default(self): -        serializer = self.several_choices_serializer() -        self.assertEqual( -            serializer.fields['color'].choices, -            [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')] -        ) - -    def test_choices_blank_false_with_default(self): -        serializer = self.several_choices_serializer() -        self.assertEqual( -            serializer.fields['drink'].choices, -            [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')] -        ) - -    def test_choices_blank_true_not_default(self): -        serializer = self.several_choices_serializer() -        self.assertEqual( -            serializer.fields['os'].choices, -            BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')] -        ) - -    def test_choices_blank_true_with_default(self): -        serializer = self.several_choices_serializer() -        self.assertEqual( -            serializer.fields['music_genre'].choices, -            BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')] -        ) - - -# Regression tests for #675 -class Ticket(models.Model): -    assigned = models.ForeignKey( -        Person, related_name='assigned_tickets') -    reviewer = models.ForeignKey( -        Person, blank=True, null=True, related_name='reviewed_tickets') - - -class SerializerRelatedChoicesTest(TestCase): - -    def setUp(self): -        super(SerializerRelatedChoicesTest, self).setUp() - -        class RelatedChoicesSerializer(serializers.ModelSerializer): -            class Meta: -                model = Ticket -                fields = ('assigned', 'reviewer') - -        self.related_fields_serializer = RelatedChoicesSerializer - -    def test_empty_queryset_required(self): -        serializer = self.related_fields_serializer() -        self.assertEqual(serializer.fields['assigned'].queryset.count(), 0) -        self.assertEqual( -            [x for x in serializer.fields['assigned'].widget.choices], -            [] -        ) - -    def test_empty_queryset_not_required(self): -        serializer = self.related_fields_serializer() -        self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0) -        self.assertEqual( -            [x for x in serializer.fields['reviewer'].widget.choices], -            [('', '---------')] -        ) - -    def test_with_some_persons_required(self): -        Person.objects.create(name="Lionel Messi") -        Person.objects.create(name="Xavi Hernandez") -        serializer = self.related_fields_serializer() -        self.assertEqual(serializer.fields['assigned'].queryset.count(), 2) -        self.assertEqual( -            [x for x in serializer.fields['assigned'].widget.choices], -            [(1, 'Person object - 1'), (2, 'Person object - 2')] -        ) - -    def test_with_some_persons_not_required(self): -        Person.objects.create(name="Lionel Messi") -        Person.objects.create(name="Xavi Hernandez") -        serializer = self.related_fields_serializer() -        self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2) -        self.assertEqual( -            [x for x in serializer.fields['reviewer'].widget.choices], -            [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')] -        ) - - -class DepthTest(TestCase): -    def test_implicit_nesting(self): - -        writer = Person.objects.create(name="django", age=1) -        post = BlogPost.objects.create(title="Test blog post", writer=writer) -        comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) - -        class BlogPostCommentSerializer(serializers.ModelSerializer): -            class Meta: -                model = BlogPostComment -                depth = 2 - -        serializer = BlogPostCommentSerializer(instance=comment) -        expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', -                    'writer': {'id': 1, 'name': 'django', 'age': 1}}} - -        self.assertEqual(serializer.data, expected) - -    def test_explicit_nesting(self): -        writer = Person.objects.create(name="django", age=1) -        post = BlogPost.objects.create(title="Test blog post", writer=writer) -        comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) - -        class PersonSerializer(serializers.ModelSerializer): -            class Meta: -                model = Person - -        class BlogPostSerializer(serializers.ModelSerializer): -            writer = PersonSerializer() - -            class Meta: -                model = BlogPost - -        class BlogPostCommentSerializer(serializers.ModelSerializer): -            blog_post = BlogPostSerializer() - -            class Meta: -                model = BlogPostComment - -        serializer = BlogPostCommentSerializer(instance=comment) -        expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', -                    'writer': {'id': 1, 'name': 'django', 'age': 1}}} - -        self.assertEqual(serializer.data, expected) - - -class NestedSerializerContextTests(TestCase): - -    def test_nested_serializer_context(self): -        """ -        Regression for #497 - -        https://github.com/tomchristie/django-rest-framework/issues/497 -        """ -        class PhotoSerializer(serializers.ModelSerializer): -            class Meta: -                model = Photo -                fields = ("description", "callable") - -            callable = serializers.SerializerMethodField('_callable') - -            def _callable(self, instance): -                if not 'context_item' in self.context: -                    raise RuntimeError("context isn't getting passed into 2nd level nested serializer") -                return "success" - -        class AlbumSerializer(serializers.ModelSerializer): -            class Meta: -                model = Album -                fields = ("photo_set", "callable") - -            photo_set = PhotoSerializer(source="photo_set") -            callable = serializers.SerializerMethodField("_callable") - -            def _callable(self, instance): -                if not 'context_item' in self.context: -                    raise RuntimeError("context isn't getting passed into 1st level nested serializer") -                return "success" - -        class AlbumCollection(object): -            albums = None - -        class AlbumCollectionSerializer(serializers.Serializer): -            albums = AlbumSerializer(source="albums") - -        album1 = Album.objects.create(title="album 1") -        album2 = Album.objects.create(title="album 2") -        Photo.objects.create(description="Bigfoot", album=album1) -        Photo.objects.create(description="Unicorn", album=album1) -        Photo.objects.create(description="Yeti", album=album2) -        Photo.objects.create(description="Sasquatch", album=album2) -        album_collection = AlbumCollection() -        album_collection.albums = [album1, album2] - -        # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers -        AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data - - -class DeserializeListTestCase(TestCase): - -    def setUp(self): -        self.data = { -            'email': 'nobody@nowhere.com', -            'content': 'This is some test content', -            'created': datetime.datetime(2013, 3, 7), -        } - -    def test_no_errors(self): -        data = [self.data.copy() for x in range(0, 3)] -        serializer = CommentSerializer(data=data, many=True) -        self.assertTrue(serializer.is_valid()) -        self.assertTrue(isinstance(serializer.object, list)) -        self.assertTrue( -            all((isinstance(item, Comment) for item in serializer.object)) -        ) - -    def test_errors_return_as_list(self): -        invalid_item = self.data.copy() -        invalid_item['email'] = '' -        data = [self.data.copy(), invalid_item, self.data.copy()] - -        serializer = CommentSerializer(data=data, many=True) -        self.assertFalse(serializer.is_valid()) -        expected = [{}, {'email': ['This field is required.']}, {}] -        self.assertEqual(serializer.errors, expected) - - -# Test for issue 747 - -class LazyStringModel(object): -    def __init__(self, lazystring): -        self.lazystring = lazystring - - -class LazyStringSerializer(serializers.Serializer): -    lazystring = serializers.Field() - -    def restore_object(self, attrs, instance=None): -        if instance is not None: -            instance.lazystring = attrs.get('lazystring', instance.lazystring) -            return instance -        return LazyStringModel(**attrs) - - -class LazyStringsTestCase(TestCase): -    def setUp(self): -        self.model = LazyStringModel(lazystring=_('lazystring')) - -    def test_lazy_strings_are_translated(self): -        serializer = LazyStringSerializer(self.model) -        self.assertEqual(type(serializer.data['lazystring']), -                         type('lazystring')) - - -# Test for issue #467 - -class FieldLabelTest(TestCase): -    def setUp(self): -        self.serializer_class = BasicModelSerializer - -    def test_label_from_model(self): -        """ -        Validates that label and help_text are correctly copied from the model class. -        """ -        serializer = self.serializer_class() -        text_field = serializer.fields['text'] - -        self.assertEqual('Text comes here', text_field.label) -        self.assertEqual('Text description.', text_field.help_text) - -    def test_field_ctor(self): -        """ -        This is check that ctor supports both label and help_text. -        """ -        self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label) -        self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text) -        self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label) - - -# Test for issue #961 - -class ManyFieldHelpTextTest(TestCase): -    def test_help_text_no_hold_down_control_msg(self): -        """ -        Validate that help_text doesn't contain the 'Hold down "Control" ...' -        message that Django appends to choice fields. -        """ -        rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text) -        self.assertEqual('Some help text.', rel_field.help_text) - - -@unittest.skipUnless(PIL is not None, 'PIL is not installed') -class AttributeMappingOnAutogeneratedFieldsTests(TestCase): - -    def setUp(self): - -        class AMOAFSerializer(serializers.ModelSerializer): -            class Meta: -                model = AMOAFModel - -        self.serializer_class = AMOAFSerializer -        self.fields_attributes = { -            'char_field': [ -                ('max_length', 1024), -            ], -            'comma_separated_integer_field': [ -                ('max_length', 1024), -            ], -            'decimal_field': [ -                ('max_digits', 64), -                ('decimal_places', 32), -            ], -            'email_field': [ -                ('max_length', 1024), -            ], -            'file_field': [ -                ('max_length', 1024), -            ], -            'image_field': [ -                ('max_length', 1024), -            ], -            'slug_field': [ -                ('max_length', 1024), -            ], -            'url_field': [ -                ('max_length', 1024), -            ], -        } - -    def field_test(self, field): -        serializer = self.serializer_class(data={}) -        self.assertEqual(serializer.is_valid(), True) - -        for attribute in self.fields_attributes[field]: -            self.assertEqual( -                getattr(serializer.fields[field], attribute[0]), -                attribute[1] -            ) - -    def test_char_field(self): -        self.field_test('char_field') - -    def test_comma_separated_integer_field(self): -        self.field_test('comma_separated_integer_field') - -    def test_decimal_field(self): -        self.field_test('decimal_field') - -    def test_email_field(self): -        self.field_test('email_field') - -    def test_file_field(self): -        self.field_test('file_field') - -    def test_image_field(self): -        self.field_test('image_field') - -    def test_slug_field(self): -        self.field_test('slug_field') - -    def test_url_field(self): -        self.field_test('url_field') - - -@unittest.skipUnless(PIL is not None, 'PIL is not installed') -class DefaultValuesOnAutogeneratedFieldsTests(TestCase): - -    def setUp(self): - -        class DVOAFSerializer(serializers.ModelSerializer): -            class Meta: -                model = DVOAFModel - -        self.serializer_class = DVOAFSerializer -        self.fields_attributes = { -            'positive_integer_field': [ -                ('min_value', 0), -            ], -            'positive_small_integer_field': [ -                ('min_value', 0), -            ], -            'email_field': [ -                ('max_length', 75), -            ], -            'file_field': [ -                ('max_length', 100), -            ], -            'image_field': [ -                ('max_length', 100), -            ], -            'slug_field': [ -                ('max_length', 50), -            ], -            'url_field': [ -                ('max_length', 200), -            ], -        } - -    def field_test(self, field): -        serializer = self.serializer_class(data={}) -        self.assertEqual(serializer.is_valid(), True) - -        for attribute in self.fields_attributes[field]: -            self.assertEqual( -                getattr(serializer.fields[field], attribute[0]), -                attribute[1] -            ) - -    def test_positive_integer_field(self): -        self.field_test('positive_integer_field') - -    def test_positive_small_integer_field(self): -        self.field_test('positive_small_integer_field') - -    def test_email_field(self): -        self.field_test('email_field') - -    def test_file_field(self): -        self.field_test('file_field') - -    def test_image_field(self): -        self.field_test('image_field') - -    def test_slug_field(self): -        self.field_test('slug_field') - -    def test_url_field(self): -        self.field_test('url_field') - - -class MetadataSerializer(serializers.Serializer): -    field1 = serializers.CharField(3, required=True) -    field2 = serializers.CharField(10, required=False) - - -class MetadataSerializerTestCase(TestCase): -    def setUp(self): -        self.serializer = MetadataSerializer() - -    def test_serializer_metadata(self): -        metadata = self.serializer.metadata() -        expected = { -            'field1': { -                'required': True, -                'max_length': 3, -                'type': 'string', -                'read_only': False -            }, -            'field2': { -                'required': False, -                'max_length': 10, -                'type': 'string', -                'read_only': False -            } -        } -        self.assertEqual(expected, metadata) - - -### Regression test for #840 - -class SimpleModel(models.Model): -    text = models.CharField(max_length=100) - - -class SimpleModelSerializer(serializers.ModelSerializer): -    text = serializers.CharField() -    other = serializers.CharField() - -    class Meta: -        model = SimpleModel - -    def validate_other(self, attrs, source): -        del attrs['other'] -        return attrs - - -class FieldValidationRemovingAttr(TestCase): -    def test_removing_non_model_field_in_validation(self): -        """ -        Removing an attr during field valiation should ensure that it is not -        passed through when restoring the object. - -        This allows additional non-model fields to be supported. - -        Regression test for #840. -        """ -        serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'}) -        self.assertTrue(serializer.is_valid()) -        serializer.save() -        self.assertEqual(serializer.object.text, 'foo') - - -### Regression test for #878 - -class SimpleTargetModel(models.Model): -    text = models.CharField(max_length=100) - - -class SimplePKSourceModelSerializer(serializers.Serializer): -    targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True) -    text = serializers.CharField() - - -class SimpleSlugSourceModelSerializer(serializers.Serializer): -    targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk') -    text = serializers.CharField() - - -class SerializerSupportsManyRelationships(TestCase): -    def setUp(self): -        SimpleTargetModel.objects.create(text='foo') -        SimpleTargetModel.objects.create(text='bar') - -    def test_serializer_supports_pk_many_relationships(self): -        """ -        Regression test for #878. - -        Note that pk behavior has a different code path to usual cases, -        for performance reasons. -        """ -        serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) - -    def test_serializer_supports_slug_many_relationships(self): -        """ -        Regression test for #878. -        """ -        serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) -        self.assertTrue(serializer.is_valid()) -        self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) - - -class TransformMethodsSerializer(serializers.Serializer): -    a = serializers.CharField() -    b_renamed = serializers.CharField(source='b') - -    def transform_a(self, obj, value): -        return value.lower() - -    def transform_b_renamed(self, obj, value): -        if value is not None: -            return 'and ' + value - - -class TestSerializerTransformMethods(TestCase): -    def setUp(self): -        self.s = TransformMethodsSerializer() - -    def test_transform_methods(self): -        self.assertEqual( -            self.s.to_native({'a': 'GREEN EGGS', 'b': 'HAM'}), -            { -                'a': 'green eggs', -                'b_renamed': 'and HAM', -            } -        ) - -    def test_missing_fields(self): -        self.assertEqual( -            self.s.to_native({'a': 'GREEN EGGS'}), -            { -                'a': 'green eggs', -                'b_renamed': None, -            } -        ) - - -class DefaultTrueBooleanModel(models.Model): -    cat = models.BooleanField(default=True) -    dog = models.BooleanField(default=False) - - -class SerializerDefaultTrueBoolean(TestCase): - -    def setUp(self): -        super(SerializerDefaultTrueBoolean, self).setUp() - -        class DefaultTrueBooleanSerializer(serializers.ModelSerializer): -            class Meta: -                model = DefaultTrueBooleanModel -                fields = ('cat', 'dog') - -        self.default_true_boolean_serializer = DefaultTrueBooleanSerializer - -    def test_enabled_as_false(self): -        serializer = self.default_true_boolean_serializer(data={'cat': False, -                                                                'dog': False}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.data['cat'], False) -        self.assertEqual(serializer.data['dog'], False) - -    def test_enabled_as_true(self): -        serializer = self.default_true_boolean_serializer(data={'cat': True, -                                                                'dog': True}) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.data['cat'], True) -        self.assertEqual(serializer.data['dog'], True) - -    def test_enabled_partial(self): -        serializer = self.default_true_boolean_serializer(data={'cat': False}, -                                                          partial=True) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.data['cat'], False) -        self.assertEqual(serializer.data['dog'], False) - - -class BoolenFieldTypeTest(TestCase): -    ''' -    Ensure the various Boolean based model fields are rendered as the proper -    field type - -    ''' - -    def setUp(self): -        ''' -        Setup an ActionItemSerializer for BooleanTesting -        ''' -        data = { -            'title': 'b' * 201, -        } -        self.serializer = ActionItemSerializer(data=data) - -    def test_booleanfield_type(self): -        ''' -        Test that BooleanField is infered from models.BooleanField -        ''' -        bfield = self.serializer.get_fields()['done'] -        self.assertEqual(type(bfield), fields.BooleanField) - -    def test_nullbooleanfield_type(self): -        ''' -        Test that BooleanField is infered from models.NullBooleanField - -        https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8 -        ''' -        bfield = self.serializer.get_fields()['started'] -        self.assertEqual(type(bfield), fields.BooleanField) diff --git a/rest_framework/tests/test_serializer_bulk_update.py b/rest_framework/tests/test_serializer_bulk_update.py deleted file mode 100644 index 8b0ded1a..00000000 --- a/rest_framework/tests/test_serializer_bulk_update.py +++ /dev/null @@ -1,278 +0,0 @@ -""" -Tests to cover bulk create and update using serializers. -""" -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework import serializers - - -class BulkCreateSerializerTests(TestCase): -    """ -    Creating multiple instances using serializers. -    """ - -    def setUp(self): -        class BookSerializer(serializers.Serializer): -            id = serializers.IntegerField() -            title = serializers.CharField(max_length=100) -            author = serializers.CharField(max_length=100) - -        self.BookSerializer = BookSerializer - -    def test_bulk_create_success(self): -        """ -        Correct bulk update serialization should return the input data. -        """ - -        data = [ -            { -                'id': 0, -                'title': 'The electric kool-aid acid test', -                'author': 'Tom Wolfe' -            }, { -                'id': 1, -                'title': 'If this is a man', -                'author': 'Primo Levi' -            }, { -                'id': 2, -                'title': 'The wind-up bird chronicle', -                'author': 'Haruki Murakami' -            } -        ] - -        serializer = self.BookSerializer(data=data, many=True) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, data) - -    def test_bulk_create_errors(self): -        """ -        Correct bulk update serialization should return the input data. -        """ - -        data = [ -            { -                'id': 0, -                'title': 'The electric kool-aid acid test', -                'author': 'Tom Wolfe' -            }, { -                'id': 1, -                'title': 'If this is a man', -                'author': 'Primo Levi' -            }, { -                'id': 'foo', -                'title': 'The wind-up bird chronicle', -                'author': 'Haruki Murakami' -            } -        ] -        expected_errors = [ -            {}, -            {}, -            {'id': ['Enter a whole number.']} -        ] - -        serializer = self.BookSerializer(data=data, many=True) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, expected_errors) - -    def test_invalid_list_datatype(self): -        """ -        Data containing list of incorrect data type should return errors. -        """ -        data = ['foo', 'bar', 'baz'] -        serializer = self.BookSerializer(data=data, many=True) -        self.assertEqual(serializer.is_valid(), False) - -        expected_errors = [ -                {'non_field_errors': ['Invalid data']}, -                {'non_field_errors': ['Invalid data']}, -                {'non_field_errors': ['Invalid data']} -        ] - -        self.assertEqual(serializer.errors, expected_errors) - -    def test_invalid_single_datatype(self): -        """ -        Data containing a single incorrect data type should return errors. -        """ -        data = 123 -        serializer = self.BookSerializer(data=data, many=True) -        self.assertEqual(serializer.is_valid(), False) - -        expected_errors = {'non_field_errors': ['Expected a list of items.']} - -        self.assertEqual(serializer.errors, expected_errors) - -    def test_invalid_single_object(self): -        """ -        Data containing only a single object, instead of a list of objects -        should return errors. -        """ -        data = { -            'id': 0, -            'title': 'The electric kool-aid acid test', -            'author': 'Tom Wolfe' -        } -        serializer = self.BookSerializer(data=data, many=True) -        self.assertEqual(serializer.is_valid(), False) - -        expected_errors = {'non_field_errors': ['Expected a list of items.']} - -        self.assertEqual(serializer.errors, expected_errors) - - -class BulkUpdateSerializerTests(TestCase): -    """ -    Updating multiple instances using serializers. -    """ - -    def setUp(self): -        class Book(object): -            """ -            A data type that can be persisted to a mock storage backend -            with `.save()` and `.delete()`. -            """ -            object_map = {} - -            def __init__(self, id, title, author): -                self.id = id -                self.title = title -                self.author = author - -            def save(self): -                Book.object_map[self.id] = self - -            def delete(self): -                del Book.object_map[self.id] - -        class BookSerializer(serializers.Serializer): -            id = serializers.IntegerField() -            title = serializers.CharField(max_length=100) -            author = serializers.CharField(max_length=100) - -            def restore_object(self, attrs, instance=None): -                if instance: -                    instance.id = attrs['id'] -                    instance.title = attrs['title'] -                    instance.author = attrs['author'] -                    return instance -                return Book(**attrs) - -        self.Book = Book -        self.BookSerializer = BookSerializer - -        data = [ -            { -                'id': 0, -                'title': 'The electric kool-aid acid test', -                'author': 'Tom Wolfe' -            }, { -                'id': 1, -                'title': 'If this is a man', -                'author': 'Primo Levi' -            }, { -                'id': 2, -                'title': 'The wind-up bird chronicle', -                'author': 'Haruki Murakami' -            } -        ] - -        for item in data: -            book = Book(item['id'], item['title'], item['author']) -            book.save() - -    def books(self): -        """ -        Return all the objects in the mock storage backend. -        """ -        return self.Book.object_map.values() - -    def test_bulk_update_success(self): -        """ -        Correct bulk update serialization should return the input data. -        """ -        data = [ -            { -                'id': 0, -                'title': 'The electric kool-aid acid test', -                'author': 'Tom Wolfe' -            }, { -                'id': 2, -                'title': 'Kafka on the shore', -                'author': 'Haruki Murakami' -            } -        ] -        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.data, data) -        serializer.save() -        new_data = self.BookSerializer(self.books(), many=True).data - -        self.assertEqual(data, new_data) - -    def test_bulk_update_and_create(self): -        """ -        Bulk update serialization may also include created items. -        """ -        data = [ -            { -                'id': 0, -                'title': 'The electric kool-aid acid test', -                'author': 'Tom Wolfe' -            }, { -                'id': 3, -                'title': 'Kafka on the shore', -                'author': 'Haruki Murakami' -            } -        ] -        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.data, data) -        serializer.save() -        new_data = self.BookSerializer(self.books(), many=True).data -        self.assertEqual(data, new_data) - -    def test_bulk_update_invalid_create(self): -        """ -        Bulk update serialization without allow_add_remove may not create items. -        """ -        data = [ -            { -                'id': 0, -                'title': 'The electric kool-aid acid test', -                'author': 'Tom Wolfe' -            }, { -                'id': 3, -                'title': 'Kafka on the shore', -                'author': 'Haruki Murakami' -            } -        ] -        expected_errors = [ -            {}, -            {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']} -        ] -        serializer = self.BookSerializer(self.books(), data=data, many=True) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, expected_errors) - -    def test_bulk_update_error(self): -        """ -        Incorrect bulk update serialization should return error data. -        """ -        data = [ -            { -                'id': 0, -                'title': 'The electric kool-aid acid test', -                'author': 'Tom Wolfe' -            }, { -                'id': 'foo', -                'title': 'Kafka on the shore', -                'author': 'Haruki Murakami' -            } -        ] -        expected_errors = [ -            {}, -            {'id': ['Enter a whole number.']} -        ] -        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, expected_errors) diff --git a/rest_framework/tests/test_serializer_empty.py b/rest_framework/tests/test_serializer_empty.py deleted file mode 100644 index 30cff361..00000000 --- a/rest_framework/tests/test_serializer_empty.py +++ /dev/null @@ -1,15 +0,0 @@ -from django.test import TestCase -from rest_framework import serializers - - -class EmptySerializerTestCase(TestCase): -    def test_empty_serializer(self): -        class FooBarSerializer(serializers.Serializer): -            foo = serializers.IntegerField() -            bar = serializers.SerializerMethodField('get_bar') - -            def get_bar(self, obj): -                return 'bar' - -        serializer = FooBarSerializer() -        self.assertEquals(serializer.data, {'foo': 0}) diff --git a/rest_framework/tests/test_serializer_import.py b/rest_framework/tests/test_serializer_import.py deleted file mode 100644 index 9f30a7ff..00000000 --- a/rest_framework/tests/test_serializer_import.py +++ /dev/null @@ -1,19 +0,0 @@ -from django.test import TestCase - -from rest_framework import serializers -from rest_framework.tests.accounts.serializers import AccountSerializer - - -class ImportingModelSerializerTests(TestCase): -    """ -    In some situations like, GH #1225, it is possible, especially in -    testing, to import a serializer who's related models have not yet -    been resolved by Django. `AccountSerializer` is an example of such -    a serializer (imported at the top of this file). -    """ -    def test_import_model_serializer(self): -        """ -        The serializer at the top of this file should have been -        imported successfully, and we should be able to instantiate it. -        """ -        self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer) diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py deleted file mode 100644 index 6d69ffbd..00000000 --- a/rest_framework/tests/test_serializer_nested.py +++ /dev/null @@ -1,347 +0,0 @@ -""" -Tests to cover nested serializers. - -Doesn't cover model serializers. -""" -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework import serializers -from . import models - - -class WritableNestedSerializerBasicTests(TestCase): -    """ -    Tests for deserializing nested entities. -    Basic tests that use serializers that simply restore to dicts. -    """ - -    def setUp(self): -        class TrackSerializer(serializers.Serializer): -            order = serializers.IntegerField() -            title = serializers.CharField(max_length=100) -            duration = serializers.IntegerField() - -        class AlbumSerializer(serializers.Serializer): -            album_name = serializers.CharField(max_length=100) -            artist = serializers.CharField(max_length=100) -            tracks = TrackSerializer(many=True) - -        self.AlbumSerializer = AlbumSerializer - -    def test_nested_validation_success(self): -        """ -        Correct nested serialization should return the input data. -        """ - -        data = { -            'album_name': 'Discovery', -            'artist': 'Daft Punk', -            'tracks': [ -                {'order': 1, 'title': 'One More Time', 'duration': 235}, -                {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, -                {'order': 3, 'title': 'Digital Love', 'duration': 239} -            ] -        } - -        serializer = self.AlbumSerializer(data=data) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, data) - -    def test_nested_validation_error(self): -        """ -        Incorrect nested serialization should return appropriate error data. -        """ - -        data = { -            'album_name': 'Discovery', -            'artist': 'Daft Punk', -            'tracks': [ -                {'order': 1, 'title': 'One More Time', 'duration': 235}, -                {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, -                {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} -            ] -        } -        expected_errors = { -            'tracks': [ -                {}, -                {}, -                {'duration': ['Enter a whole number.']} -            ] -        } - -        serializer = self.AlbumSerializer(data=data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, expected_errors) - -    def test_many_nested_validation_error(self): -        """ -        Incorrect nested serialization should return appropriate error data -        when multiple entities are being deserialized. -        """ - -        data = [ -            { -                'album_name': 'Russian Red', -                'artist': 'I Love Your Glasses', -                'tracks': [ -                    {'order': 1, 'title': 'Cigarettes', 'duration': 121}, -                    {'order': 2, 'title': 'No Past Land', 'duration': 198}, -                    {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} -                ] -            }, -            { -                'album_name': 'Discovery', -                'artist': 'Daft Punk', -                'tracks': [ -                    {'order': 1, 'title': 'One More Time', 'duration': 235}, -                    {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, -                    {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} -                ] -            } -        ] -        expected_errors = [ -            {}, -            { -                'tracks': [ -                    {}, -                    {}, -                    {'duration': ['Enter a whole number.']} -                ] -            } -        ] - -        serializer = self.AlbumSerializer(data=data, many=True) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, expected_errors) - - -class WritableNestedSerializerObjectTests(TestCase): -    """ -    Tests for deserializing nested entities. -    These tests use serializers that restore to concrete objects. -    """ - -    def setUp(self): -        # Couple of concrete objects that we're going to deserialize into -        class Track(object): -            def __init__(self, order, title, duration): -                self.order, self.title, self.duration = order, title, duration - -            def __eq__(self, other): -                return ( -                    self.order == other.order and -                    self.title == other.title and -                    self.duration == other.duration -                ) - -        class Album(object): -            def __init__(self, album_name, artist, tracks): -                self.album_name, self.artist, self.tracks = album_name, artist, tracks - -            def __eq__(self, other): -                return ( -                    self.album_name == other.album_name and -                    self.artist == other.artist and -                    self.tracks == other.tracks -                ) - -        # And their corresponding serializers -        class TrackSerializer(serializers.Serializer): -            order = serializers.IntegerField() -            title = serializers.CharField(max_length=100) -            duration = serializers.IntegerField() - -            def restore_object(self, attrs, instance=None): -                return Track(attrs['order'], attrs['title'], attrs['duration']) - -        class AlbumSerializer(serializers.Serializer): -            album_name = serializers.CharField(max_length=100) -            artist = serializers.CharField(max_length=100) -            tracks = TrackSerializer(many=True) - -            def restore_object(self, attrs, instance=None): -                return Album(attrs['album_name'], attrs['artist'], attrs['tracks']) - -        self.Album, self.Track = Album, Track -        self.AlbumSerializer = AlbumSerializer - -    def test_nested_validation_success(self): -        """ -        Correct nested serialization should return a restored object -        that corresponds to the input data. -        """ - -        data = { -            'album_name': 'Discovery', -            'artist': 'Daft Punk', -            'tracks': [ -                {'order': 1, 'title': 'One More Time', 'duration': 235}, -                {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, -                {'order': 3, 'title': 'Digital Love', 'duration': 239} -            ] -        } -        expected_object = self.Album( -            album_name='Discovery', -            artist='Daft Punk', -            tracks=[ -                self.Track(order=1, title='One More Time', duration=235), -                self.Track(order=2, title='Aerodynamic', duration=184), -                self.Track(order=3, title='Digital Love', duration=239), -            ] -        ) - -        serializer = self.AlbumSerializer(data=data) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, expected_object) - -    def test_many_nested_validation_success(self): -        """ -        Correct nested serialization should return multiple restored objects -        that corresponds to the input data when multiple objects are -        being deserialized. -        """ - -        data = [ -            { -                'album_name': 'Russian Red', -                'artist': 'I Love Your Glasses', -                'tracks': [ -                    {'order': 1, 'title': 'Cigarettes', 'duration': 121}, -                    {'order': 2, 'title': 'No Past Land', 'duration': 198}, -                    {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} -                ] -            }, -            { -                'album_name': 'Discovery', -                'artist': 'Daft Punk', -                'tracks': [ -                    {'order': 1, 'title': 'One More Time', 'duration': 235}, -                    {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, -                    {'order': 3, 'title': 'Digital Love', 'duration': 239} -                ] -            } -        ] -        expected_object = [ -            self.Album( -                album_name='Russian Red', -                artist='I Love Your Glasses', -                tracks=[ -                    self.Track(order=1, title='Cigarettes', duration=121), -                    self.Track(order=2, title='No Past Land', duration=198), -                    self.Track(order=3, title='They Don\'t Believe', duration=191), -                ] -            ), -            self.Album( -                album_name='Discovery', -                artist='Daft Punk', -                tracks=[ -                    self.Track(order=1, title='One More Time', duration=235), -                    self.Track(order=2, title='Aerodynamic', duration=184), -                    self.Track(order=3, title='Digital Love', duration=239), -                ] -            ) -        ] - -        serializer = self.AlbumSerializer(data=data, many=True) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, expected_object) - - -class ForeignKeyNestedSerializerUpdateTests(TestCase): -    def setUp(self): -        class Artist(object): -            def __init__(self, name): -                self.name = name - -            def __eq__(self, other): -                return self.name == other.name - -        class Album(object): -            def __init__(self, name, artist): -                self.name, self.artist = name, artist - -            def __eq__(self, other): -                return self.name == other.name and self.artist == other.artist - -        class ArtistSerializer(serializers.Serializer): -            name = serializers.CharField() - -            def restore_object(self, attrs, instance=None): -                if instance: -                    instance.name = attrs['name'] -                else: -                    instance = Artist(attrs['name']) -                return instance - -        class AlbumSerializer(serializers.Serializer): -            name = serializers.CharField() -            by = ArtistSerializer(source='artist') - -            def restore_object(self, attrs, instance=None): -                if instance: -                    instance.name = attrs['name'] -                    instance.artist = attrs['artist'] -                else: -                    instance = Album(attrs['name'], attrs['artist']) -                return instance - -        self.Artist = Artist -        self.Album = Album -        self.AlbumSerializer = AlbumSerializer - -    def test_create_via_foreign_key_with_source(self): -        """ -        Check that we can both *create* and *update* into objects across -        ForeignKeys that have a `source` specified. -        Regression test for #1170 -        """ -        data = { -            'name': 'Discovery', -            'by': {'name': 'Daft Punk'}, -        } - -        expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery') - -        # create -        serializer = self.AlbumSerializer(data=data) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, expected) - -        # update -        original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters') -        serializer = self.AlbumSerializer(instance=original, data=data) -        self.assertEqual(serializer.is_valid(), True) -        self.assertEqual(serializer.object, expected) - - -class NestedModelSerializerUpdateTests(TestCase): -    def test_second_nested_level(self): -        john = models.Person.objects.create(name="john") - -        post = john.blogpost_set.create(title="Test blog post") -        post.blogpostcomment_set.create(text="I hate this blog post") -        post.blogpostcomment_set.create(text="I love this blog post") - -        class BlogPostCommentSerializer(serializers.ModelSerializer): -            class Meta: -                model = models.BlogPostComment - -        class BlogPostSerializer(serializers.ModelSerializer): -            comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set') -            class Meta: -                model = models.BlogPost -                fields = ('id', 'title', 'comments') - -        class PersonSerializer(serializers.ModelSerializer): -            posts = BlogPostSerializer(many=True, source='blogpost_set') -            class Meta: -                model = models.Person -                fields = ('id', 'name', 'age', 'posts') - -        serialize = PersonSerializer(instance=john) -        deserialize = PersonSerializer(data=serialize.data, instance=john) -        self.assertTrue(deserialize.is_valid()) - -        result = deserialize.object -        result.save() -        self.assertEqual(result.id, john.id) diff --git a/rest_framework/tests/test_serializers.py b/rest_framework/tests/test_serializers.py deleted file mode 100644 index 082a400c..00000000 --- a/rest_framework/tests/test_serializers.py +++ /dev/null @@ -1,28 +0,0 @@ -from django.db import models -from django.test import TestCase - -from rest_framework.serializers import _resolve_model -from rest_framework.tests.models import BasicModel - - -class ResolveModelTests(TestCase): -    """ -    `_resolve_model` should return a Django model class given the -    provided argument is a Django model class itself, or a properly -    formatted string representation of one. -    """ -    def test_resolve_django_model(self): -        resolved_model = _resolve_model(BasicModel) -        self.assertEqual(resolved_model, BasicModel) - -    def test_resolve_string_representation(self): -        resolved_model = _resolve_model('tests.BasicModel') -        self.assertEqual(resolved_model, BasicModel) - -    def test_resolve_non_django_model(self): -        with self.assertRaises(ValueError): -            _resolve_model(TestCase) - -    def test_resolve_improper_string_representation(self): -        with self.assertRaises(ValueError): -            _resolve_model('BasicModel') diff --git a/rest_framework/tests/test_settings.py b/rest_framework/tests/test_settings.py deleted file mode 100644 index 857375c2..00000000 --- a/rest_framework/tests/test_settings.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Tests for the settings module""" -from __future__ import unicode_literals -from django.test import TestCase - -from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS - - -class TestSettings(TestCase): -    """Tests relating to the api settings""" - -    def test_non_import_errors(self): -        """Make sure other errors aren't suppressed.""" -        settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) -        with self.assertRaises(ValueError): -            settings.DEFAULT_MODEL_SERIALIZER_CLASS - -    def test_import_error_message_maintained(self): -        """Make sure real import errors are captured and raised sensibly.""" -        settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) -        with self.assertRaises(ImportError) as cm: -            settings.DEFAULT_MODEL_SERIALIZER_CLASS -        self.assertTrue('ImportError' in str(cm.exception)) diff --git a/rest_framework/tests/test_status.py b/rest_framework/tests/test_status.py deleted file mode 100644 index 7b1bdae3..00000000 --- a/rest_framework/tests/test_status.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework.status import ( -    is_informational, is_success, is_redirect, is_client_error, is_server_error -) - - -class TestStatus(TestCase): -    def test_status_categories(self): -        self.assertFalse(is_informational(99)) -        self.assertTrue(is_informational(100)) -        self.assertTrue(is_informational(199)) -        self.assertFalse(is_informational(200)) - -        self.assertFalse(is_success(199)) -        self.assertTrue(is_success(200)) -        self.assertTrue(is_success(299)) -        self.assertFalse(is_success(300)) - -        self.assertFalse(is_redirect(299)) -        self.assertTrue(is_redirect(300)) -        self.assertTrue(is_redirect(399)) -        self.assertFalse(is_redirect(400)) - -        self.assertFalse(is_client_error(399)) -        self.assertTrue(is_client_error(400)) -        self.assertTrue(is_client_error(499)) -        self.assertFalse(is_client_error(500)) - -        self.assertFalse(is_server_error(499)) -        self.assertTrue(is_server_error(500)) -        self.assertTrue(is_server_error(599)) -        self.assertFalse(is_server_error(600))
\ No newline at end of file diff --git a/rest_framework/tests/test_templatetags.py b/rest_framework/tests/test_templatetags.py deleted file mode 100644 index d4da0c23..00000000 --- a/rest_framework/tests/test_templatetags.py +++ /dev/null @@ -1,51 +0,0 @@ -# encoding: utf-8 -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework.test import APIRequestFactory -from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links - -factory = APIRequestFactory() - - -class TemplateTagTests(TestCase): - -    def test_add_query_param_with_non_latin_charactor(self): -        # Ensure we don't double-escape non-latin characters -        # that are present in the querystring. -        # See #1314. -        request = factory.get("/", {'q': '查询'}) -        json_url = add_query_param(request, "format", "json") -        self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url) -        self.assertIn("format=json", json_url) - - -class Issue1386Tests(TestCase): -    """ -    Covers #1386 -    """ - -    def test_issue_1386(self): -        """ -        Test function urlize_quoted_links with different args -        """ -        correct_urls = [ -            "asdf.com", -            "asdf.net", -            "www.as_df.org", -            "as.d8f.ghj8.gov", -        ] -        for i in correct_urls: -            res = urlize_quoted_links(i) -            self.assertNotEqual(res, i) -            self.assertIn(i, res) - -        incorrect_urls = [ -            "mailto://asdf@fdf.com", -            "asdf.netnet", -        ] -        for i in incorrect_urls: -            res = urlize_quoted_links(i) -            self.assertEqual(i, res) - -        # example from issue #1386, this shouldn't raise an exception -        _ = urlize_quoted_links("asdf:[/p]zxcv.com") diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py deleted file mode 100644 index a55d4b22..00000000 --- a/rest_framework/tests/test_testing.py +++ /dev/null @@ -1,164 +0,0 @@ -# -- coding: utf-8 -- - -from __future__ import unicode_literals -from io import BytesIO - -from django.contrib.auth.models import User -from django.test import TestCase -from rest_framework.compat import patterns, url -from rest_framework.decorators import api_view -from rest_framework.response import Response -from rest_framework.test import APIClient, APIRequestFactory, force_authenticate - - -@api_view(['GET', 'POST']) -def view(request): -    return Response({ -        'auth': request.META.get('HTTP_AUTHORIZATION', b''), -        'user': request.user.username -    }) - - -@api_view(['GET', 'POST']) -def session_view(request): -    active_session = request.session.get('active_session', False) -    request.session['active_session'] = True -    return Response({ -        'active_session': active_session -    }) - - -urlpatterns = patterns('', -    url(r'^view/$', view), -    url(r'^session-view/$', session_view), -) - - -class TestAPITestClient(TestCase): -    urls = 'rest_framework.tests.test_testing' - -    def setUp(self): -        self.client = APIClient() - -    def test_credentials(self): -        """ -        Setting `.credentials()` adds the required headers to each request. -        """ -        self.client.credentials(HTTP_AUTHORIZATION='example') -        for _ in range(0, 3): -            response = self.client.get('/view/') -            self.assertEqual(response.data['auth'], 'example') - -    def test_force_authenticate(self): -        """ -        Setting `.force_authenticate()` forcibly authenticates each request. -        """ -        user = User.objects.create_user('example', 'example@example.com') -        self.client.force_authenticate(user) -        response = self.client.get('/view/') -        self.assertEqual(response.data['user'], 'example') - -    def test_force_authenticate_with_sessions(self): -        """ -        Setting `.force_authenticate()` forcibly authenticates each request. -        """ -        user = User.objects.create_user('example', 'example@example.com') -        self.client.force_authenticate(user) - -        # First request does not yet have an active session -        response = self.client.get('/session-view/') -        self.assertEqual(response.data['active_session'], False) - -        # Subsequant requests have an active session -        response = self.client.get('/session-view/') -        self.assertEqual(response.data['active_session'], True) - -        # Force authenticating as `None` should also logout the user session. -        self.client.force_authenticate(None) -        response = self.client.get('/session-view/') -        self.assertEqual(response.data['active_session'], False) - -    def test_csrf_exempt_by_default(self): -        """ -        By default, the test client is CSRF exempt. -        """ -        User.objects.create_user('example', 'example@example.com', 'password') -        self.client.login(username='example', password='password') -        response = self.client.post('/view/') -        self.assertEqual(response.status_code, 200) - -    def test_explicitly_enforce_csrf_checks(self): -        """ -        The test client can enforce CSRF checks. -        """ -        client = APIClient(enforce_csrf_checks=True) -        User.objects.create_user('example', 'example@example.com', 'password') -        client.login(username='example', password='password') -        response = client.post('/view/') -        expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} -        self.assertEqual(response.status_code, 403) -        self.assertEqual(response.data, expected) - - -class TestAPIRequestFactory(TestCase): -    def test_csrf_exempt_by_default(self): -        """ -        By default, the test client is CSRF exempt. -        """ -        user = User.objects.create_user('example', 'example@example.com', 'password') -        factory = APIRequestFactory() -        request = factory.post('/view/') -        request.user = user -        response = view(request) -        self.assertEqual(response.status_code, 200) - -    def test_explicitly_enforce_csrf_checks(self): -        """ -        The test client can enforce CSRF checks. -        """ -        user = User.objects.create_user('example', 'example@example.com', 'password') -        factory = APIRequestFactory(enforce_csrf_checks=True) -        request = factory.post('/view/') -        request.user = user -        response = view(request) -        expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} -        self.assertEqual(response.status_code, 403) -        self.assertEqual(response.data, expected) - -    def test_invalid_format(self): -        """ -        Attempting to use a format that is not configured will raise an -        assertion error. -        """ -        factory = APIRequestFactory() -        self.assertRaises(AssertionError, factory.post, -            path='/view/', data={'example': 1}, format='xml' -        ) - -    def test_force_authenticate(self): -        """ -        Setting `force_authenticate()` forcibly authenticates the request. -        """ -        user = User.objects.create_user('example', 'example@example.com') -        factory = APIRequestFactory() -        request = factory.get('/view') -        force_authenticate(request, user=user) -        response = view(request) -        self.assertEqual(response.data['user'], 'example') - -    def test_upload_file(self): -        # This is a 1x1 black png -        simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82') -        simple_png.name = 'test.png' -        factory = APIRequestFactory() -        factory.post('/', data={'image': simple_png}) - -    def test_request_factory_url_arguments(self): -        """ -        This is a non regression test against #1461 -        """ -        factory = APIRequestFactory() -        request = factory.get('/view/?demo=test') -        self.assertEqual(dict(request.GET), {'demo': ['test']}) -        request = factory.get('/view/', {'demo': 'test'}) -        self.assertEqual(dict(request.GET), {'demo': ['test']}) diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py deleted file mode 100644 index b5ae02cd..00000000 --- a/rest_framework/tests/test_throttling.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Tests for the throttling implementations in the permissions module. -""" -from __future__ import unicode_literals -from django.test import TestCase -from django.contrib.auth.models import User -from django.core.cache import cache -from rest_framework.test import APIRequestFactory -from rest_framework.views import APIView -from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle -from rest_framework.response import Response - - -class User3SecRateThrottle(UserRateThrottle): -    rate = '3/sec' -    scope = 'seconds' - - -class User3MinRateThrottle(UserRateThrottle): -    rate = '3/min' -    scope = 'minutes' - - -class NonTimeThrottle(BaseThrottle): -    def allow_request(self, request, view): -        if not hasattr(self.__class__, 'called'): -            self.__class__.called = True -            return True -        return False  - - -class MockView(APIView): -    throttle_classes = (User3SecRateThrottle,) - -    def get(self, request): -        return Response('foo') - - -class MockView_MinuteThrottling(APIView): -    throttle_classes = (User3MinRateThrottle,) - -    def get(self, request): -        return Response('foo') - - -class MockView_NonTimeThrottling(APIView): -    throttle_classes = (NonTimeThrottle,) - -    def get(self, request): -        return Response('foo') - - -class ThrottlingTests(TestCase): -    def setUp(self): -        """ -        Reset the cache so that no throttles will be active -        """ -        cache.clear() -        self.factory = APIRequestFactory() - -    def test_requests_are_throttled(self): -        """ -        Ensure request rate is limited -        """ -        request = self.factory.get('/') -        for dummy in range(4): -            response = MockView.as_view()(request) -        self.assertEqual(429, response.status_code) - -    def set_throttle_timer(self, view, value): -        """ -        Explicitly set the timer, overriding time.time() -        """ -        view.throttle_classes[0].timer = lambda self: value - -    def test_request_throttling_expires(self): -        """ -        Ensure request rate is limited for a limited duration only -        """ -        self.set_throttle_timer(MockView, 0) - -        request = self.factory.get('/') -        for dummy in range(4): -            response = MockView.as_view()(request) -        self.assertEqual(429, response.status_code) - -        # Advance the timer by one second -        self.set_throttle_timer(MockView, 1) - -        response = MockView.as_view()(request) -        self.assertEqual(200, response.status_code) - -    def ensure_is_throttled(self, view, expect): -        request = self.factory.get('/') -        request.user = User.objects.create(username='a') -        for dummy in range(3): -            view.as_view()(request) -        request.user = User.objects.create(username='b') -        response = view.as_view()(request) -        self.assertEqual(expect, response.status_code) - -    def test_request_throttling_is_per_user(self): -        """ -        Ensure request rate is only limited per user, not globally for -        PerUserThrottles -        """ -        self.ensure_is_throttled(MockView, 200) - -    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): -        """ -        Ensure the response returns an X-Throttle field with status and next attributes -        set properly. -        """ -        request = self.factory.get('/') -        for timer, expect in expected_headers: -            self.set_throttle_timer(view, timer) -            response = view.as_view()(request) -            if expect is not None: -                self.assertEqual(response['X-Throttle-Wait-Seconds'], expect) -                self.assertEqual(response['Retry-After'], expect) -            else: -                self.assertFalse('X-Throttle-Wait-Seconds' in response) -                self.assertFalse('Retry-After' in response) - -    def test_seconds_fields(self): -        """ -        Ensure for second based throttles. -        """ -        self.ensure_response_header_contains_proper_throttle_field(MockView, -         ((0, None), -          (0, None), -          (0, None), -          (0, '1') -         )) - -    def test_minutes_fields(self): -        """ -        Ensure for minute based throttles. -        """ -        self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, -         ((0, None), -          (0, None), -          (0, None), -          (0, '60') -         )) - -    def test_next_rate_remains_constant_if_followed(self): -        """ -        If a client follows the recommended next request rate, -        the throttling rate should stay constant. -        """ -        self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, -         ((0, None), -          (20, None), -          (40, None), -          (60, None), -          (80, None) -         )) - -    def test_non_time_throttle(self): -        """ -        Ensure for second based throttles. -        """ -        request = self.factory.get('/') - -        self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called')) - -        response = MockView_NonTimeThrottling.as_view()(request) -        self.assertFalse('X-Throttle-Wait-Seconds' in response) -        self.assertFalse('Retry-After' in response) - -        self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) - -        response = MockView_NonTimeThrottling.as_view()(request) -        self.assertFalse('X-Throttle-Wait-Seconds' in response)  -        self.assertFalse('Retry-After' in response) - - -class ScopedRateThrottleTests(TestCase): -    """ -    Tests for ScopedRateThrottle. -    """ - -    def setUp(self): -        class XYScopedRateThrottle(ScopedRateThrottle): -            TIMER_SECONDS = 0 -            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'} -            timer = lambda self: self.TIMER_SECONDS - -        class XView(APIView): -            throttle_classes = (XYScopedRateThrottle,) -            throttle_scope = 'x' - -            def get(self, request): -                return Response('x') - -        class YView(APIView): -            throttle_classes = (XYScopedRateThrottle,) -            throttle_scope = 'y' - -            def get(self, request): -                return Response('y') - -        class UnscopedView(APIView): -            throttle_classes = (XYScopedRateThrottle,) - -            def get(self, request): -                return Response('y') - -        self.throttle_class = XYScopedRateThrottle -        self.factory = APIRequestFactory() -        self.x_view = XView.as_view() -        self.y_view = YView.as_view() -        self.unscoped_view = UnscopedView.as_view() - -    def increment_timer(self, seconds=1): -        self.throttle_class.TIMER_SECONDS += seconds - -    def test_scoped_rate_throttle(self): -        request = self.factory.get('/') - -        # Should be able to hit x view 3 times per minute. -        response = self.x_view(request) -        self.assertEqual(200, response.status_code) - -        self.increment_timer() -        response = self.x_view(request) -        self.assertEqual(200, response.status_code) - -        self.increment_timer() -        response = self.x_view(request) -        self.assertEqual(200, response.status_code) - -        self.increment_timer() -        response = self.x_view(request) -        self.assertEqual(429, response.status_code) - -        # Should be able to hit y view 1 time per minute. -        self.increment_timer() -        response = self.y_view(request) -        self.assertEqual(200, response.status_code) - -        self.increment_timer() -        response = self.y_view(request) -        self.assertEqual(429, response.status_code) - -        # Ensure throttles properly reset by advancing the rest of the minute -        self.increment_timer(55) - -        # Should still be able to hit x view 3 times per minute. -        response = self.x_view(request) -        self.assertEqual(200, response.status_code) - -        self.increment_timer() -        response = self.x_view(request) -        self.assertEqual(200, response.status_code) - -        self.increment_timer() -        response = self.x_view(request) -        self.assertEqual(200, response.status_code) - -        self.increment_timer() -        response = self.x_view(request) -        self.assertEqual(429, response.status_code) - -        # Should still be able to hit y view 1 time per minute. -        self.increment_timer() -        response = self.y_view(request) -        self.assertEqual(200, response.status_code) - -        self.increment_timer() -        response = self.y_view(request) -        self.assertEqual(429, response.status_code) - -    def test_unscoped_view_not_throttled(self): -        request = self.factory.get('/') - -        for idx in range(10): -            self.increment_timer() -            response = self.unscoped_view(request) -            self.assertEqual(200, response.status_code) diff --git a/rest_framework/tests/test_urlpatterns.py b/rest_framework/tests/test_urlpatterns.py deleted file mode 100644 index 8132ec4c..00000000 --- a/rest_framework/tests/test_urlpatterns.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import unicode_literals -from collections import namedtuple -from django.core import urlresolvers -from django.test import TestCase -from rest_framework.test import APIRequestFactory -from rest_framework.compat import patterns, url, include -from rest_framework.urlpatterns import format_suffix_patterns - - -# A container class for test paths for the test case -URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs']) - - -def dummy_view(request, *args, **kwargs): -    pass - - -class FormatSuffixTests(TestCase): -    """ -    Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. -    """ -    def _resolve_urlpatterns(self, urlpatterns, test_paths): -        factory = APIRequestFactory() -        try: -            urlpatterns = format_suffix_patterns(urlpatterns) -        except Exception: -            self.fail("Failed to apply `format_suffix_patterns` on  the supplied urlpatterns") -        resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) -        for test_path in test_paths: -            request = factory.get(test_path.path) -            try: -                callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) -            except Exception: -                self.fail("Failed to resolve URL: %s" % request.path_info) -            self.assertEqual(callback_args, test_path.args) -            self.assertEqual(callback_kwargs, test_path.kwargs) - -    def test_format_suffix(self): -        urlpatterns = patterns( -            '', -            url(r'^test$', dummy_view), -        ) -        test_paths = [ -            URLTestPath('/test', (), {}), -            URLTestPath('/test.api', (), {'format': 'api'}), -            URLTestPath('/test.asdf', (), {'format': 'asdf'}), -        ] -        self._resolve_urlpatterns(urlpatterns, test_paths) - -    def test_default_args(self): -        urlpatterns = patterns( -            '', -            url(r'^test$', dummy_view, {'foo': 'bar'}), -        ) -        test_paths = [ -            URLTestPath('/test', (), {'foo': 'bar', }), -            URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), -            URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}), -        ] -        self._resolve_urlpatterns(urlpatterns, test_paths) - -    def test_included_urls(self): -        nested_patterns = patterns( -            '', -            url(r'^path$', dummy_view) -        ) -        urlpatterns = patterns( -            '', -            url(r'^test/', include(nested_patterns), {'foo': 'bar'}), -        ) -        test_paths = [ -            URLTestPath('/test/path', (), {'foo': 'bar', }), -            URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), -            URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), -        ] -        self._resolve_urlpatterns(urlpatterns, test_paths) diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py deleted file mode 100644 index e13e4078..00000000 --- a/rest_framework/tests/test_validation.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import unicode_literals -from django.core.validators import MaxValueValidator -from django.db import models -from django.test import TestCase -from rest_framework import generics, serializers, status -from rest_framework.test import APIRequestFactory - -factory = APIRequestFactory() - - -# Regression for #666 - -class ValidationModel(models.Model): -    blank_validated_field = models.CharField(max_length=255) - - -class ValidationModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = ValidationModel -        fields = ('blank_validated_field',) -        read_only_fields = ('blank_validated_field',) - - -class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView): -    model = ValidationModel -    serializer_class = ValidationModelSerializer - - -class TestPreSaveValidationExclusions(TestCase): -    def test_pre_save_validation_exclusions(self): -        """ -        Somewhat weird test case to ensure that we don't perform model -        validation on read only fields. -        """ -        obj = ValidationModel.objects.create(blank_validated_field='') -        request = factory.put('/', {}, format='json') -        view = UpdateValidationModel().as_view() -        response = view(request, pk=obj.pk).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) - - -# Regression for #653 - -class ShouldValidateModel(models.Model): -    should_validate_field = models.CharField(max_length=255) - - -class ShouldValidateModelSerializer(serializers.ModelSerializer): -    renamed = serializers.CharField(source='should_validate_field', required=False) - -    def validate_renamed(self, attrs, source): -        value = attrs[source] -        if len(value) < 3: -            raise serializers.ValidationError('Minimum 3 characters.') -        return attrs - -    class Meta: -        model = ShouldValidateModel -        fields = ('renamed',) - - -class TestPreSaveValidationExclusionsSerializer(TestCase): -    def test_renamed_fields_are_model_validated(self): -        """ -        Ensure fields with 'source' applied do get still get model validation. -        """ -        # We've set `required=False` on the serializer, but the model -        # does not have `blank=True`, so this serializer should not validate. -        serializer = ShouldValidateModelSerializer(data={'renamed': ''}) -        self.assertEqual(serializer.is_valid(), False) -        self.assertIn('renamed', serializer.errors) -        self.assertNotIn('should_validate_field', serializer.errors) - - -class TestCustomValidationMethods(TestCase): -    def test_custom_validation_method_is_executed(self): -        serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'}) -        self.assertFalse(serializer.is_valid()) -        self.assertIn('renamed', serializer.errors) - -    def test_custom_validation_method_passing(self): -        serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'}) -        self.assertTrue(serializer.is_valid()) - - -class ValidationSerializer(serializers.Serializer): -    foo = serializers.CharField() - -    def validate_foo(self, attrs, source): -        raise serializers.ValidationError("foo invalid") - -    def validate(self, attrs): -        raise serializers.ValidationError("serializer invalid") - - -class TestAvoidValidation(TestCase): -    """ -    If serializer was initialized with invalid data (None or non dict-like), it -    should avoid validation layer (validate_<field> and validate methods) -    """ -    def test_serializer_errors_has_only_invalid_data_error(self): -        serializer = ValidationSerializer(data='invalid data') -        self.assertFalse(serializer.is_valid()) -        self.assertDictEqual(serializer.errors, -                             {'non_field_errors': ['Invalid data']}) - - -# regression tests for issue: 1493 - -class ValidationMaxValueValidatorModel(models.Model): -    number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)]) - - -class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = ValidationMaxValueValidatorModel - - -class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): -    model = ValidationMaxValueValidatorModel -    serializer_class = ValidationMaxValueValidatorModelSerializer - - -class TestMaxValueValidatorValidation(TestCase): - -    def test_max_value_validation_serializer_success(self): -        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99}) -        self.assertTrue(serializer.is_valid()) - -    def test_max_value_validation_serializer_fails(self): -        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101}) -        self.assertFalse(serializer.is_valid()) -        self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors) - -    def test_max_value_validation_success(self): -        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) -        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json') -        view = UpdateMaxValueValidationModel().as_view() -        response = view(request, pk=obj.pk).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) - -    def test_max_value_validation_fail(self): -        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) -        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json') -        view = UpdateMaxValueValidationModel().as_view() -        response = view(request, pk=obj.pk).render() -        self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}') -        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py deleted file mode 100644 index 65c7e50e..00000000 --- a/rest_framework/tests/test_views.py +++ /dev/null @@ -1,142 +0,0 @@ -from __future__ import unicode_literals - -import copy -from django.test import TestCase -from rest_framework import status -from rest_framework.decorators import api_view -from rest_framework.response import Response -from rest_framework.settings import api_settings -from rest_framework.test import APIRequestFactory -from rest_framework.views import APIView - -factory = APIRequestFactory() - - -class BasicView(APIView): -    def get(self, request, *args, **kwargs): -        return Response({'method': 'GET'}) - -    def post(self, request, *args, **kwargs): -        return Response({'method': 'POST', 'data': request.DATA}) - - -@api_view(['GET', 'POST', 'PUT', 'PATCH']) -def basic_view(request): -    if request.method == 'GET': -        return {'method': 'GET'} -    elif request.method == 'POST': -        return {'method': 'POST', 'data': request.DATA} -    elif request.method == 'PUT': -        return {'method': 'PUT', 'data': request.DATA} -    elif request.method == 'PATCH': -        return {'method': 'PATCH', 'data': request.DATA} - - -class ErrorView(APIView): -    def get(self, request, *args, **kwargs): -        raise Exception - - -@api_view(['GET']) -def error_view(request): -    raise Exception - - -def sanitise_json_error(error_dict): -    """ -    Exact contents of JSON error messages depend on the installed version -    of json. -    """ -    ret = copy.copy(error_dict) -    chop = len('JSON parse error - No JSON object could be decoded') -    ret['detail'] = ret['detail'][:chop] -    return ret - - -class ClassBasedViewIntegrationTests(TestCase): -    def setUp(self): -        self.view = BasicView.as_view() - -    def test_400_parse_error(self): -        request = factory.post('/', 'f00bar', content_type='application/json') -        response = self.view(request) -        expected = { -            'detail': 'JSON parse error - No JSON object could be decoded' -        } -        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) -        self.assertEqual(sanitise_json_error(response.data), expected) - -    def test_400_parse_error_tunneled_content(self): -        content = 'f00bar' -        content_type = 'application/json' -        form_data = { -            api_settings.FORM_CONTENT_OVERRIDE: content, -            api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type -        } -        request = factory.post('/', form_data) -        response = self.view(request) -        expected = { -            'detail': 'JSON parse error - No JSON object could be decoded' -        } -        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) -        self.assertEqual(sanitise_json_error(response.data), expected) - - -class FunctionBasedViewIntegrationTests(TestCase): -    def setUp(self): -        self.view = basic_view - -    def test_400_parse_error(self): -        request = factory.post('/', 'f00bar', content_type='application/json') -        response = self.view(request) -        expected = { -            'detail': 'JSON parse error - No JSON object could be decoded' -        } -        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) -        self.assertEqual(sanitise_json_error(response.data), expected) - -    def test_400_parse_error_tunneled_content(self): -        content = 'f00bar' -        content_type = 'application/json' -        form_data = { -            api_settings.FORM_CONTENT_OVERRIDE: content, -            api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type -        } -        request = factory.post('/', form_data) -        response = self.view(request) -        expected = { -            'detail': 'JSON parse error - No JSON object could be decoded' -        } -        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) -        self.assertEqual(sanitise_json_error(response.data), expected) - - -class TestCustomExceptionHandler(TestCase): -    def setUp(self): -        self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER - -        def exception_handler(exc): -            return Response('Error!', status=status.HTTP_400_BAD_REQUEST) - -        api_settings.EXCEPTION_HANDLER = exception_handler - -    def tearDown(self): -        api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER - -    def test_class_based_view_exception_handler(self): -        view = ErrorView.as_view() - -        request = factory.get('/', content_type='application/json') -        response = view(request) -        expected = 'Error!' -        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) -        self.assertEqual(response.data, expected) - -    def test_function_based_view_exception_handler(self): -        view = error_view - -        request = factory.get('/', content_type='application/json') -        response = view(request) -        expected = 'Error!' -        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) -        self.assertEqual(response.data, expected) diff --git a/rest_framework/tests/test_write_only_fields.py b/rest_framework/tests/test_write_only_fields.py deleted file mode 100644 index aabb18d6..00000000 --- a/rest_framework/tests/test_write_only_fields.py +++ /dev/null @@ -1,42 +0,0 @@ -from django.db import models -from django.test import TestCase -from rest_framework import serializers - - -class ExampleModel(models.Model): -    email = models.EmailField(max_length=100) -    password = models.CharField(max_length=100) - - -class WriteOnlyFieldTests(TestCase): -    def test_write_only_fields(self): -        class ExampleSerializer(serializers.Serializer): -            email = serializers.EmailField() -            password = serializers.CharField(write_only=True) - -        data = { -            'email': 'foo@example.com', -            'password': '123' -        } -        serializer = ExampleSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        self.assertEquals(serializer.object, data) -        self.assertEquals(serializer.data, {'email': 'foo@example.com'}) - -    def test_write_only_fields_meta(self): -        class ExampleSerializer(serializers.ModelSerializer): -            class Meta: -                model = ExampleModel -                fields = ('email', 'password') -                write_only_fields = ('password',) - -        data = { -            'email': 'foo@example.com', -            'password': '123' -        } -        serializer = ExampleSerializer(data=data) -        self.assertTrue(serializer.is_valid()) -        self.assertTrue(isinstance(serializer.object, ExampleModel)) -        self.assertEquals(serializer.object.email, data['email']) -        self.assertEquals(serializer.object.password, data['password']) -        self.assertEquals(serializer.data, {'email': 'foo@example.com'}) diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py deleted file mode 100644 index 554ebd1a..00000000 --- a/rest_framework/tests/tests.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Force import of all modules in this package in order to get the standard test -runner to pick up the tests.  Yowzers. -""" -from __future__ import unicode_literals -import os -import django - -modules = [filename.rsplit('.', 1)[0] -           for filename in os.listdir(os.path.dirname(__file__)) -           if filename.endswith('.py') and not filename.startswith('_')] -__test__ = dict() - -if django.VERSION < (1, 6): -    for module in modules: -        exec("from rest_framework.tests.%s import *" % module) diff --git a/rest_framework/tests/users/__init__.py b/rest_framework/tests/users/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/rest_framework/tests/users/__init__.py +++ /dev/null diff --git a/rest_framework/tests/users/models.py b/rest_framework/tests/users/models.py deleted file mode 100644 index 128bac90..00000000 --- a/rest_framework/tests/users/models.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.db import models - - -class User(models.Model): -    account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users') -    active_record = models.ForeignKey('records.Record', blank=True, null=True) diff --git a/rest_framework/tests/users/serializers.py b/rest_framework/tests/users/serializers.py deleted file mode 100644 index da496554..00000000 --- a/rest_framework/tests/users/serializers.py +++ /dev/null @@ -1,8 +0,0 @@ -from rest_framework import serializers - -from rest_framework.tests.users.models import User - - -class UserSerializer(serializers.ModelSerializer): -    class Meta: -        model = User diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py deleted file mode 100644 index a8f2eb0b..00000000 --- a/rest_framework/tests/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -from contextlib import contextmanager -from rest_framework.compat import six -from rest_framework.settings import api_settings - - -@contextmanager -def temporary_setting(setting, value, module=None): -    """ -    Temporarily change value of setting for test. - -    Optionally reload given module, useful when module uses value of setting on -    import. -    """ -    original_value = getattr(api_settings, setting) -    setattr(api_settings, setting, value) - -    if module is not None: -        six.moves.reload_module(module) - -    yield - -    setattr(api_settings, setting, original_value) - -    if module is not None: -        six.moves.reload_module(module) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py deleted file mode 100644 index 3917b74a..00000000 --- a/rest_framework/tests/views.py +++ /dev/null @@ -1,8 +0,0 @@ -from rest_framework import generics -from rest_framework.tests.models import NullableForeignKeySource -from rest_framework.tests.serializers import NullableFKSourceSerializer - - -class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): -    model = NullableForeignKeySource -    model_serializer_class = NullableFKSourceSerializer diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index efa9fb94..361dbddf 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -18,6 +18,25 @@ class BaseThrottle(object):          """          raise NotImplementedError('.allow_request() must be overridden') +    def get_ident(self, request): +        """ +        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR +        if present and number of proxies is > 0. If not use all of +        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. +        """ +        xff = request.META.get('HTTP_X_FORWARDED_FOR') +        remote_addr = request.META.get('REMOTE_ADDR') +        num_proxies = api_settings.NUM_PROXIES + +        if num_proxies is not None: +            if num_proxies == 0 or xff is None: +                return remote_addr +            addrs = xff.split(',') +            client_addr = addrs[-min(num_proxies, len(xff))] +            return client_addr.strip() + +        return xff if xff else remote_addr +      def wait(self):          """          Optionally, return a recommended number of seconds to wait before @@ -41,7 +60,7 @@ class SimpleRateThrottle(BaseThrottle):      cache = default_cache      timer = time.time -    cache_format = 'throtte_%(scope)s_%(ident)s' +    cache_format = 'throttle_%(scope)s_%(ident)s'      scope = None      THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES @@ -157,10 +176,12 @@ class AnonRateThrottle(SimpleRateThrottle):          ident = request.META.get('HTTP_X_FORWARDED_FOR')          if ident is None:              ident = request.META.get('REMOTE_ADDR') +        else: +            ident = ''.join(ident.split())          return self.cache_format % {              'scope': self.scope, -            'ident': ident +            'ident': self.get_ident(request)          } @@ -178,7 +199,7 @@ class UserRateThrottle(SimpleRateThrottle):          if request.user.is_authenticated():              ident = request.user.id          else: -            ident = request.META.get('REMOTE_ADDR', None) +            ident = self.get_ident(request)          return self.cache_format % {              'scope': self.scope, @@ -226,7 +247,7 @@ class ScopedRateThrottle(SimpleRateThrottle):          if request.user.is_authenticated():              ident = request.user.id          else: -            ident = request.META.get('REMOTE_ADDR', None) +            ident = self.get_ident(request)          return self.cache_format % {              'scope': self.scope, diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 0ff137b0..038e9ee3 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,6 +1,6 @@  from __future__ import unicode_literals +from django.conf.urls import url, include  from django.core.urlresolvers import RegexURLResolver -from rest_framework.compat import url, include  from rest_framework.settings import api_settings diff --git a/rest_framework/urls.py b/rest_framework/urls.py index 9c4719f1..8fa3073e 100644 --- a/rest_framework/urls.py +++ b/rest_framework/urls.py @@ -2,23 +2,25 @@  Login and logout views for the browsable API.  Add these to your root URLconf if you're using the browsable API and -your API requires authentication. - -The urls must be namespaced as 'rest_framework', and you should make sure -your authentication settings include `SessionAuthentication`. +your API requires authentication:      urlpatterns = patterns('',          ...          url(r'^auth', include('rest_framework.urls', namespace='rest_framework'))      ) + +The urls must be namespaced as 'rest_framework', and you should make sure +your authentication settings include `SessionAuthentication`.  """  from __future__ import unicode_literals -from rest_framework.compat import patterns, url +from django.conf.urls import patterns, url +from django.contrib.auth import views  template_name = {'template_name': 'rest_framework/login.html'} -urlpatterns = patterns('django.contrib.auth.views', -    url(r'^login/$', 'login', template_name, name='login'), -    url(r'^logout/$', 'logout', template_name, name='logout'), +urlpatterns = patterns( +    '', +    url(r'^login/$', views.login, template_name, name='login'), +    url(r'^logout/$', views.logout, template_name, name='logout')  ) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index e5fa4194..00ffdfba 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -2,10 +2,11 @@  Helper classes for parsers.  """  from __future__ import unicode_literals +from django.utils import timezone  from django.db.models.query import QuerySet  from django.utils.datastructures import SortedDict  from django.utils.functional import Promise -from rest_framework.compat import timezone, force_text +from rest_framework.compat import force_text  from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata  import datetime  import decimal @@ -97,14 +98,23 @@ else:                      node.flow_style = best_style              return node -    SafeDumper.add_representer(decimal.Decimal, -            SafeDumper.represent_decimal) - -    SafeDumper.add_representer(SortedDict, -            yaml.representer.SafeRepresenter.represent_dict) -    SafeDumper.add_representer(DictWithMetadata, -            yaml.representer.SafeRepresenter.represent_dict) -    SafeDumper.add_representer(SortedDictWithMetadata, -            yaml.representer.SafeRepresenter.represent_dict) -    SafeDumper.add_representer(types.GeneratorType, -            yaml.representer.SafeRepresenter.represent_list) +    SafeDumper.add_representer( +        decimal.Decimal, +        SafeDumper.represent_decimal +    ) +    SafeDumper.add_representer( +        SortedDict, +        yaml.representer.SafeRepresenter.represent_dict +    ) +    SafeDumper.add_representer( +        DictWithMetadata, +        yaml.representer.SafeRepresenter.represent_dict +    ) +    SafeDumper.add_representer( +        SortedDictWithMetadata, +        yaml.representer.SafeRepresenter.represent_dict +    ) +    SafeDumper.add_representer( +        types.GeneratorType, +        yaml.representer.SafeRepresenter.represent_list +    ) diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 4b59ba84..6d53aed1 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -6,8 +6,6 @@ from __future__ import unicode_literals  from django.utils.html import escape  from django.utils.safestring import mark_safe  from rest_framework.compat import apply_markdown -from rest_framework.settings import api_settings -from textwrap import dedent  import re @@ -40,6 +38,7 @@ def dedent(content):      return content.strip() +  def camelcase_to_spaces(content):      """      Translate 'CamelCaseNames' to 'Camel Case Names'. @@ -49,6 +48,7 @@ def camelcase_to_spaces(content):      content = re.sub(camelcase_boundry, ' \\1', content).strip()      return ' '.join(content.split('_')).title() +  def markup_description(description):      """      Apply HTML markup to the given description. diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py index c09c2933..87b3cc6a 100644 --- a/rest_framework/utils/mediatypes.py +++ b/rest_framework/utils/mediatypes.py @@ -57,7 +57,7 @@ class _MediaType(object):              if key != 'q' and other.params.get(key, None) != self.params.get(key, None):                  return False -        if self.sub_type != '*' and other.sub_type != '*'  and other.sub_type != self.sub_type: +        if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:              return False          if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type: @@ -74,12 +74,12 @@ class _MediaType(object):              return 0          elif self.sub_type == '*':              return 1 -        elif not self.params or self.params.keys() == ['q']: +        elif not self.params or list(self.params.keys()) == ['q']:              return 2          return 3      def __str__(self): -        return unicode(self).encode('utf-8') +        return self.__unicode__().encode('utf-8')      def __unicode__(self):          ret = "%s/%s" % (self.main_type, self.sub_type) diff --git a/rest_framework/views.py b/rest_framework/views.py index d6ccb301..23df3443 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -31,6 +31,7 @@ def get_view_name(view_cls, suffix=None):      return name +  def get_view_description(view_cls, html=False):      """      Given a view class, return a textual description to represent the view. @@ -120,7 +121,6 @@ class APIView(View):              headers['Vary'] = 'Accept'          return headers -      def http_method_not_allowed(self, request, *args, **kwargs):          """          If `request.method` does not correspond to a handler method, diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 7eb29f99..bb5b304e 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -127,11 +127,11 @@ class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,  class ModelViewSet(mixins.CreateModelMixin, -                    mixins.RetrieveModelMixin, -                    mixins.UpdateModelMixin, -                    mixins.DestroyModelMixin, -                    mixins.ListModelMixin, -                    GenericViewSet): +                   mixins.RetrieveModelMixin, +                   mixins.UpdateModelMixin, +                   mixins.DestroyModelMixin, +                   mixins.ListModelMixin, +                   GenericViewSet):      """      A viewset that provides default `create()`, `retrieve()`, `update()`,      `partial_update()`, `destroy()` and `list()` actions.  | 
