diff options
138 files changed, 12674 insertions, 3981 deletions
diff --git a/.travis.yml b/.travis.yml index 0dc87837..205feef9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,16 +3,34 @@ language: python python: - "2.6" - "2.7" + - "3.2" + - "3.3" env: - - DJANGO=https://github.com/django/django/zipball/master - - DJANGO=django==1.4.3 --use-mirrors - - DJANGO=django==1.3.5 --use-mirrors + - DJANGO="django==1.5 --use-mirrors" + - DJANGO="django==1.4.3 --use-mirrors" + - DJANGO="django==1.3.5 --use-mirrors" install: - pip install $DJANGO - - pip install django-filter==0.5.4 --use-mirrors + - pip install defusedxml==0.3 + - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211 --use-mirrors; fi" + - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.0 --use-mirrors; fi" + - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.3 --use-mirrors; fi" + - "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4 --use-mirrors; fi" + - "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.6a1 --use-mirrors; fi" - export PYTHONPATH=. script: - python rest_framework/runtests/runtests.py + +matrix: + exclude: + - python: "3.2" + env: DJANGO="django==1.4.3 --use-mirrors" + - python: "3.2" + env: DJANGO="django==1.3.5 --use-mirrors" + - python: "3.3" + env: DJANGO="django==1.4.3 --use-mirrors" + - python: "3.3" + env: DJANGO="django==1.3.5 --use-mirrors" diff --git a/MANIFEST.in b/MANIFEST.in index 00e45086..15c4d0b0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ recursive-include rest_framework/static *.js *.css *.png -recursive-include rest_framework/templates *.txt *.html +recursive-include rest_framework/templates *.html @@ -1,201 +1,110 @@ # Django REST framework -**A toolkit for building well-connected, self-describing web APIs.** - -**Author:** Tom Christie. [Follow me on Twitter][twitter]. - -**Support:** [REST framework discussion group][group]. +**Awesome web-browseable Web APIs.** [![build-status-image]][travis] ---- - -**Full documentation for REST framework is available on [http://django-rest-framework.org][docs].** +**Note**: Full documentation for the project is available at [http://django-rest-framework.org][docs]. -Note that this is the 2.0 version of REST framework. If you are looking for earlier versions please see the [0.4.x branch][0.4] on GitHub. +# Overview ---- +Django REST framework is a powerful and flexible toolkit that makes it easy to build Web APIs. -# Overview +Some reasons you might want to use REST framework: -Django REST framework is a lightweight library that makes it easy to build Web APIs. It is designed as a modular and easy to customize architecture, based on Django's class based views. +* The Web browseable API is a huge useability win for your developers. +* Authentication policies including OAuth1a and OAuth2 out of the box. +* Serialization that supports both ORM and non-ORM data sources. +* Customizable all the way down - just use regular function-based views if you don't need the more powerful features. +* Extensive documentation, and great community support. -Web APIs built using REST framework are fully self-describing and web browseable - a huge useability win for your developers. It also supports a wide range of media types, authentication and permission policies out of the box. +There is a live example API for testing purposes, [available here][sandbox]. -If you are considering using REST framework for your API, we recommend reading the [REST framework 2 announcment][rest-framework-2-announcement] which gives a good overview of the framework and it's capabilities. +**Below**: *Screenshot from the browseable API* -There is also a sandbox API you can use for testing purposes, [available here][sandbox]. +![Screenshot][image] # Requirements -* Python (2.6, 2.7) +* Python (2.6.5+, 2.7, 3.2, 3.3) * Django (1.3, 1.4, 1.5) -**Optional:** - -* [Markdown] - Markdown support for the self describing API. -* [PyYAML] - YAML content type support. -* [django-filter] - Filtering support. - # Installation Install using `pip`... pip install djangorestframework -...or clone the project from github. - - git clone git@github.com:tomchristie/django-rest-framework.git - pip install -r requirements.txt - -# Development - -To build the docs. - - ./mkdocs.py - -To run the tests. - - ./rest_framework/runtests/runtests.py - -# Changelog - -### 2.1.12 - -**Date**: 21st Dec 2012 - -* Bugfix: Fix bug that could occur using ChoiceField. -* Bugfix: Fix exception in browseable API on DELETE. -* Bugfix: Fix issue where pk was was being set to a string if set by URL kwarg. - -## 2.1.11 - -**Date**: 17th Dec 2012 - -* Bugfix: Fix issue with M2M fields in browseable API. - -## 2.1.10 - -**Date**: 17th Dec 2012 - -* Bugfix: Ensure read-only fields don't have model validation applied. -* Bugfix: Fix hyperlinked fields in paginated results. +Add `'rest_framework'` to your `INSTALLED_APPS` setting. -## 2.1.9 + INSTALLED_APPS = ( + ... + 'rest_framework', + ) -**Date**: 11th Dec 2012 +# Example -* Bugfix: Fix broken nested serialization. -* Bugfix: Fix `Meta.fields` only working as tuple not as list. -* Bugfix: Edge case if unnecessarily specifying `required=False` on read only field. +Let's take a look at a quick example of using REST framework to build a simple model-backed API for accessing users and groups. -## 2.1.8 +Here's our project's root `urls.py` module: -**Date**: 8th Dec 2012 + from django.conf.urls.defaults import url, patterns, include + from django.contrib.auth.models import User, Group + from rest_framework import viewsets, routers -* Fix for creating nullable Foreign Keys with `''` as well as `None`. -* Added `null=<bool>` related field option. + # ViewSets define the view behavior. + class UserViewSet(viewsets.ModelViewSet): + model = User -## 2.1.7 + class GroupViewSet(viewsets.ModelViewSet): + model = Group -**Date**: 7th Dec 2012 + + # Routers provide an easy way of automatically determining the URL conf + router = routers.DefaultRouter() + router.register(r'users', UserViewSet) + router.register(r'groups', GroupViewSet) -* Serializers now properly support nullable Foreign Keys. -* Serializer validation now includes model field validation, such as uniqueness constraints. -* Support 'true' and 'false' string values for BooleanField. -* Added pickle support for serialized data. -* Support `source='dotted.notation'` style for nested serializers. -* Make `Request.user` settable. -* Bugfix: Fix `RegexField` to work with `BrowsableAPIRenderer` -## 2.1.6 + # Wire up our API using automatic URL routing. + # Additionally, we include login URLs for the browseable API. + urlpatterns = patterns('', + url(r'^', include(router.urls)), + url(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')) + ) -**Date**: 23rd Nov 2012 +We'd also like to configure a couple of settings for our API. -* Bugfix: Unfix DjangoModelPermissions. (I am a doofus.) +Add the following to your `settings.py` module: -## 2.1.5 + REST_FRAMEWORK = { + # Use hyperlinked styles by default. + # Only used if the `serializer_class` attribute is not set on a view. + 'DEFAULT_MODEL_SERIALIZER_CLASS': + 'rest_framework.serializers.HyperlinkedModelSerializer', -**Date**: 23rd Nov 2012 + # Use Django's standard `django.contrib.auth` permissions, + # or allow read-only access for unauthenticated users. + 'DEFAULT_PERMISSION_CLASSES': [ + 'rest_framework.permissions.DjangoModelPermissionsOrAnonReadOnly' + ] + } -* Bugfix: Fix DjangoModelPermissions. +Don't forget to make sure you've also added `rest_framework` to your `INSTALLED_APPS` setting. -## 2.1.4 +That's it, we're done! -**Date**: 22nd Nov 2012 +# Documentation & Support -* Support for partial updates with serializers. -* Added `RegexField`. -* Added `SerializerMethodField`. -* Serializer performance improvements. -* Added `obtain_token_view` to get tokens when using `TokenAuthentication`. -* Bugfix: Django 1.5 configurable user support for `TokenAuthentication`. +Full documentation for the project is available at [http://django-rest-framework.org][docs]. -## 2.1.3 +For questions and support, use the [REST framework discussion group][group], or `#restframework` on freenode IRC. -**Date**: 16th Nov 2012 - -* Added `FileField` and `ImageField`. For use with `MultiPartParser`. -* Added `URLField` and `SlugField`. -* Support for `read_only_fields` on `ModelSerializer` classes. -* Support for clients overriding the pagination page sizes. Use the `PAGINATE_BY_PARAM` setting or set the `paginate_by_param` attribute on a generic view. -* 201 Responses now return a 'Location' header. -* Bugfix: Serializer fields now respect `max_length`. - -## 2.1.2 - -**Date**: 9th Nov 2012 - -* **Filtering support.** -* Bugfix: Support creation of objects with reverse M2M relations. - -## 2.1.1 - -**Date**: 7th Nov 2012 - -* Support use of HTML exception templates. Eg. `403.html` -* Hyperlinked fields take optional `slug_field`, `slug_url_kwarg` and `pk_url_kwarg` arguments. -* Bugfix: Deal with optional trailing slashs properly when generating breadcrumbs. -* Bugfix: Make textareas same width as other fields in browsable API. -* Private API change: `.get_serializer` now uses same `instance` and `data` ordering as serializer initialization. - -## 2.1.0 - -**Date**: 5th Nov 2012 - -**Warning**: Please read [this thread][2.1.0-notes] regarding the `instance` and `data` keyword args before updating to 2.1.0. - -* **Serializer `instance` and `data` keyword args have their position swapped.** -* `queryset` argument is now optional on writable model fields. -* Hyperlinked related fields optionally take `slug_field` and `slug_field_kwarg` arguments. -* Support Django's cache framework. -* Minor field improvements. (Don't stringify dicts, more robust many-pk fields.) -* Bugfixes (Support choice field in Browseable API) - -## 2.0.2 - -**Date**: 2nd Nov 2012 - -* Fix issues with pk related fields in the browsable API. - -## 2.0.1 - -**Date**: 1st Nov 2012 - -* Add support for relational fields in the browsable API. -* Added SlugRelatedField and ManySlugRelatedField. -* If PUT creates an instance return '201 Created', instead of '200 OK'. - -## 2.0.0 - -**Date**: 30th Oct 2012 - -* Redesign of core components. -* Fix **all of the things**. +You may also want to [follow the author on Twitter][twitter]. # License -Copyright (c) 2011, Tom Christie +Copyright (c) 2011-2013, Tom Christie All rights reserved. Redistribution and use in source and binary forms, with or without @@ -218,7 +127,7 @@ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -[build-status-image]: https://secure.travis-ci.org/tomchristie/django-rest-framework.png?branch=restframework2 +[build-status-image]: https://secure.travis-ci.org/tomchristie/django-rest-framework.png?branch=master [travis]: http://travis-ci.org/tomchristie/django-rest-framework?branch=master [twitter]: https://twitter.com/_tomchristie [group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework @@ -226,10 +135,17 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. [sandbox]: http://restframework.herokuapp.com/ [rest-framework-2-announcement]: http://django-rest-framework.org/topics/rest-framework-2-announcement.html [2.1.0-notes]: https://groups.google.com/d/topic/django-rest-framework/Vv2M0CMY9bg/discussion +[image]: http://django-rest-framework.org/img/quickstart.png + +[tox]: http://testrun.org/tox/latest/ + +[tehjones]: https://twitter.com/tehjones/status/294986071979196416 +[wlonk]: https://twitter.com/wlonk/status/261689665952833536 +[laserllama]: https://twitter.com/laserllama/status/328688333750407168 [docs]: http://django-rest-framework.org/ [urlobject]: https://github.com/zacharyvoase/urlobject [markdown]: http://pypi.python.org/pypi/Markdown/ [pyyaml]: http://pypi.python.org/pypi/PyYAML -[django-filter]: https://github.com/alex/django-filter - +[defusedxml]: https://pypi.python.org/pypi/defusedxml +[django-filter]: http://pypi.python.org/pypi/django-filter diff --git a/docs/api-guide/authentication.md b/docs/api-guide/authentication.md index 43fc15d2..c2f73901 100644..100755 --- a/docs/api-guide/authentication.md +++ b/docs/api-guide/authentication.md @@ -8,25 +8,33 @@ Authentication is the mechanism of associating an incoming request with a set of identifying credentials, such as the user the request came from, or the token that it was signed with. The [permission] and [throttling] policies can then use those credentials to determine if the request should be permitted. -REST framework provides a number of authentication policies out of the box, and also allows you to implement custom policies. +REST framework provides a number of authentication schemes out of the box, and also allows you to implement custom schemes. -Authentication will run the first time either the `request.user` or `request.auth` properties are accessed, and determines how those properties are initialized. +Authentication is always run at the very start of the view, before the permission and throttling checks occur, and before any other code is allowed to proceed. The `request.user` property will typically be set to an instance of the `contrib.auth` package's `User` class. The `request.auth` property is used for any additional authentication information, for example, it may be used to represent an authentication token that the request was signed with. +--- + +**Note:** Don't forget that **authentication by itself won't allow or disallow an incoming request**, it simply identifies the credentials that the request was made with. + +For information on how to setup the permission polices for your API please see the [permissions documentation][permission]. + +--- + ## How authentication is determined -The authentication policy is always defined as a list of classes. REST framework will attempt to authenticate with each class in the list, and will set `request.user` and `request.auth` using the return value of the first class that successfully authenticates. +The authentication schemes are always defined as a list of classes. REST framework will attempt to authenticate with each class in the list, and will set `request.user` and `request.auth` using the return value of the first class that successfully authenticates. If no class authenticates, `request.user` will be set to an instance of `django.contrib.auth.models.AnonymousUser`, and `request.auth` will be set to `None`. The value of `request.user` and `request.auth` for unauthenticated requests can be modified using the `UNAUTHENTICATED_USER` and `UNAUTHENTICATED_TOKEN` settings. -## Setting the authentication policy +## Setting the authentication scheme -The default authentication policy may be set globally, using the `DEFAULT_AUTHENTICATION_CLASSES` setting. For example. +The default authentication schemes may be set globally, using the `DEFAULT_AUTHENTICATION` setting. For example. REST_FRAMEWORK = { 'DEFAULT_AUTHENTICATION_CLASSES': ( @@ -35,7 +43,8 @@ The default authentication policy may be set globally, using the `DEFAULT_AUTHEN ) } -You can also set the authentication policy on a per-view basis, using the `APIView` class based views. +You can also set the authentication scheme on a per-view or per-viewset basis, +using the `APIView` class based views. class ExampleView(APIView): authentication_classes = (SessionAuthentication, BasicAuthentication) @@ -52,7 +61,7 @@ Or, if you're using the `@api_view` decorator with function based views. @api_view(['GET']) @authentication_classes((SessionAuthentication, BasicAuthentication)) - @permissions_classes((IsAuthenticated,)) + @permission_classes((IsAuthenticated,)) def example_view(request, format=None): content = { 'user': unicode(request.user), # `django.contrib.auth.User` instance. @@ -60,24 +69,59 @@ Or, if you're using the `@api_view` decorator with function based views. } return Response(content) +## Unauthorized and Forbidden responses + +When an unauthenticated request is denied permission there are two different error codes that may be appropriate. + +* [HTTP 401 Unauthorized][http401] +* [HTTP 403 Permission Denied][http403] + +HTTP 401 responses must always include a `WWW-Authenticate` header, that instructs the client how to authenticate. HTTP 403 responses do not include the `WWW-Authenticate` header. + +The kind of response that will be used depends on the authentication scheme. Although multiple authentication schemes may be in use, only one scheme may be used to determine the type of response. **The first authentication class set on the view is used when determining the type of response**. + +Note that when a request may successfully authenticate, but still be denied permission to perform the request, in which case a `403 Permission Denied` response will always be used, regardless of the authentication scheme. + +## Apache mod_wsgi specific configuration + +Note that if deploying to [Apache using mod_wsgi][mod_wsgi_official], the authorization header is not passed through to a WSGI application by default, as it is assumed that authentication will be handled by Apache, rather than at an application level. + +If you are deploying to Apache, and using any non-session based authentication, you will need to explicitly configure mod_wsgi to pass the required headers through to the application. This can be done by specifying the `WSGIPassAuthorization` directive in the appropriate context and setting it to `'On'`. + + # this can go in either server config, virtual host, directory or .htaccess + WSGIPassAuthorization On + +--- + # API Reference ## BasicAuthentication -This policy uses [HTTP Basic Authentication][basicauth], signed against a user's username and password. Basic authentication is generally only appropriate for testing. +This authentication scheme uses [HTTP Basic Authentication][basicauth], signed against a user's username and password. Basic authentication is generally only appropriate for testing. If successfully authenticated, `BasicAuthentication` provides the following credentials. * `request.user` will be a Django `User` instance. * `request.auth` will be `None`. -**Note:** If you use `BasicAuthentication` in production you must ensure that your API is only available over `https` only. You should also ensure that your API clients will always re-request the username and password at login, and will never store those details to persistent storage. +Unauthenticated responses that are denied permission will result in an `HTTP 401 Unauthorized` response with an appropriate WWW-Authenticate header. For example: + + WWW-Authenticate: Basic realm="api" + +**Note:** If you use `BasicAuthentication` in production you must ensure that your API is only available over `https`. You should also ensure that your API clients will always re-request the username and password at login, and will never store those details to persistent storage. ## TokenAuthentication -This policy uses a simple token-based HTTP Authentication scheme. Token authentication is appropriate for client-server setups, such as native desktop and mobile clients. +This authentication scheme uses a simple token-based HTTP Authentication scheme. Token authentication is appropriate for client-server setups, such as native desktop and mobile clients. -To use the `TokenAuthentication` policy, include `rest_framework.authtoken` in your `INSTALLED_APPS` setting. +To use the `TokenAuthentication` scheme, include `rest_framework.authtoken` in your `INSTALLED_APPS` setting: + + INSTALLED_APPS = ( + ... + 'rest_framework.authtoken' + ) + +Make sure to run `manage.py syncdb` after changing your settings. You'll also need to create tokens for your users. @@ -93,9 +137,23 @@ For clients to authenticate, the token key should be included in the `Authorizat If successfully authenticated, `TokenAuthentication` provides the following credentials. * `request.user` will be a Django `User` instance. -* `request.auth` will be a `rest_framework.tokenauth.models.BasicToken` instance. +* `request.auth` will be a `rest_framework.authtoken.models.BasicToken` instance. + +Unauthenticated responses that are denied permission will result in an `HTTP 401 Unauthorized` response with an appropriate WWW-Authenticate header. For example: + + WWW-Authenticate: Token + +The `curl` command line tool may be useful for testing token authenticated APIs. For example: + + curl -X GET http://127.0.0.1:8000/api/example/ -H 'Authorization: Token 9944b09199c62bcf9418ad846dd0e4bbdfc6ee4b' -**Note:** If you use `TokenAuthentication` in production you must ensure that your API is only available over `https` only. +--- + +**Note:** If you use `TokenAuthentication` in production you must ensure that your API is only available over `https`. + +--- + +#### Generating Tokens If you want every user to have an automatically generated Token, you can simply catch the User's `post_save` signal. @@ -112,8 +170,7 @@ If you've already created some users, you can generate tokens for all existing u for user in User.objects.all(): Token.objects.get_or_create(user=user) -When using `TokenAuthentication`, you may want to provide a mechanism for clients to obtain a token given the username and password. -REST framework provides a built-in view to provide this behavior. To use it, add the `obtain_auth_token` view to your URLconf: +When using `TokenAuthentication`, you may want to provide a mechanism for clients to obtain a token given the username and password. REST framework provides a built-in view to provide this behavior. To use it, add the `obtain_auth_token` view to your URLconf: urlpatterns += patterns('', url(r'^api-token-auth/', 'rest_framework.authtoken.views.obtain_auth_token') @@ -125,32 +182,186 @@ The `obtain_auth_token` view will return a JSON response when valid `username` a { 'token' : '9944b09199c62bcf9418ad846dd0e4bbdfc6ee4b' } -<!-- -## OAuthAuthentication +Note that the default `obtain_auth_token` view explicitly uses JSON requests and responses, rather than using default renderer and parser classes in your settings. If you need a customized version of the `obtain_auth_token` view, you can do so by overriding the `ObtainAuthToken` view class, and using that in your url conf instead. -This policy uses the [OAuth 2.0][oauth] protocol to authenticate requests. OAuth is appropriate for server-server setups, such as when you want to allow a third-party service to access your API on a user's behalf. +#### Custom user models -If successfully authenticated, `OAuthAuthentication` provides the following credentials. +The `rest_framework.authtoken` app includes a south migration that will create the authtoken table. If you're using a [custom user model][custom-user-model] you'll need to make sure that any initial migration that creates the user table runs before the authtoken table is created. -* `request.user` will be a Django `User` instance. -* `request.auth` will be a `rest_framework.models.OAuthToken` instance. ---> +You can do so by inserting a `needed_by` attribute in your user migration: + + class Migration: + + needed_by = ( + ('authtoken', '0001_initial'), + ) + + def forwards(self): + ... + +For more details, see the [south documentation on dependencies][south-dependencies]. ## SessionAuthentication -This policy uses Django's default session backend for authentication. Session authentication is appropriate for AJAX clients that are running in the same session context as your website. +This authentication scheme uses Django's default session backend for authentication. Session authentication is appropriate for AJAX clients that are running in the same session context as your website. If successfully authenticated, `SessionAuthentication` provides the following credentials. * `request.user` will be a Django `User` instance. * `request.auth` will be `None`. +Unauthenticated responses that are denied permission will result in an `HTTP 403 Forbidden` response. + +If you're using an AJAX style API with SessionAuthentication, you'll need to make sure you include a valid CSRF token for any "unsafe" HTTP method calls, such as `PUT`, `PATCH`, `POST` or `DELETE` requests. See the [Django CSRF documentation][csrf-ajax] for more details. + +## OAuthAuthentication + +This authentication uses [OAuth 1.0a][oauth-1.0a] authentication scheme. OAuth 1.0a provides signature validation which provides a reasonable level of security over plain non-HTTPS connections. However, it may also be considered more complicated than OAuth2, as it requires clients to sign their requests. + +This authentication class depends on the optional `django-oauth-plus` and `oauth2` packages. In order to make it work you must install these packages and add `oauth_provider` to your `INSTALLED_APPS`: + + INSTALLED_APPS = ( + ... + `oauth_provider`, + ) + +Don't forget to run `syncdb` once you've added the package. + + python manage.py syncdb + +#### Getting started with django-oauth-plus + +The OAuthAuthentication class only provides token verification and signature validation for requests. It doesn't provide authorization flow for your clients. You still need to implement your own views for accessing and authorizing tokens. + +The `django-oauth-plus` package provides simple foundation for classic 'three-legged' oauth flow. Please refer to [the documentation][django-oauth-plus] for more details. + +## OAuth2Authentication + +This authentication uses [OAuth 2.0][rfc6749] authentication scheme. OAuth2 is more simple to work with than OAuth1, and provides much better security than simple token authentication. It is an unauthenticated scheme, and requires you to use an HTTPS connection. + +This authentication class depends on the optional [django-oauth2-provider][django-oauth2-provider] project. In order to make it work you must install this package and add `provider` and `provider.oauth2` to your `INSTALLED_APPS`: + + INSTALLED_APPS = ( + ... + 'provider', + 'provider.oauth2', + ) + +You must also include the following in your root `urls.py` module: + + url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), + +Note that the `namespace='oauth2'` argument is required. + +Finally, sync your database. + + python manage.py syncdb + python manage.py migrate + +--- + +**Note:** If you use `OAuth2Authentication` in production you must ensure that your API is only available over `https`. + +--- + +#### Getting started with django-oauth2-provider + +The `OAuth2Authentication` class only provides token verification for requests. It doesn't provide authorization flow for your clients. + +The OAuth 2 authorization flow is taken care by the [django-oauth2-provider][django-oauth2-provider] dependency. A walkthrough is given here, but for more details you should refer to [the documentation][django-oauth2-provider-docs]. + +To get started: + +##### 1. Create a client + +You can create a client, either through the shell, or by using the Django admin. + +Go to the admin panel and create a new `Provider.Client` entry. It will create the `client_id` and `client_secret` properties for you. + +##### 2. Request an access token + +To request an access token, submit a `POST` request to the url `/oauth2/access_token` with the following fields: + +* `client_id` the client id you've just configured at the previous step. +* `client_secret` again configured at the previous step. +* `username` the username with which you want to log in. +* `password` well, that speaks for itself. + +You can use the command line to test that your local configuration is working: + + curl -X POST -d "client_id=YOUR_CLIENT_ID&client_secret=YOUR_CLIENT_SECRET&grant_type=password&username=YOUR_USERNAME&password=YOUR_PASSWORD" http://localhost:8000/oauth2/access_token/ + +You should get a response that looks something like this: + + {"access_token": "<your-access-token>", "scope": "read", "expires_in": 86399, "refresh_token": "<your-refresh-token>"} + +##### 3. Access the API + +The only thing needed to make the `OAuth2Authentication` class work is to insert the `access_token` you've received in the `Authorization` request header. + +The command line to test the authentication looks like: + + curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/ + +--- + # Custom authentication -To implement a custom authentication policy, subclass `BaseAuthentication` and override the `.authenticate(self, request)` method. The method should return a two-tuple of `(user, auth)` if authentication succeeds, or `None` otherwise. +To implement a custom authentication scheme, subclass `BaseAuthentication` and override the `.authenticate(self, request)` method. The method should return a two-tuple of `(user, auth)` if authentication succeeds, or `None` otherwise. + +In some circumstances instead of returning `None`, you may want to raise an `AuthenticationFailed` exception from the `.authenticate()` method. + +Typically the approach you should take is: + +* If authentication is not attempted, return `None`. Any other authentication schemes also in use will still be checked. +* If authentication is attempted but fails, raise a `AuthenticationFailed` exception. An error response will be returned immediately, regardless of any permissions checks, and without checking any other authentication schemes. + +You *may* also override the `.authenticate_header(self, request)` method. If implemented, it should return a string that will be used as the value of the `WWW-Authenticate` header in a `HTTP 401 Unauthorized` response. + +If the `.authenticate_header()` method is not overridden, the authentication scheme will return `HTTP 403 Forbidden` responses when an unauthenticated request is denied access. + +## Example + +The following example will authenticate any incoming request as the user given by the username in a custom request header named 'X_USERNAME'. + + class ExampleAuthentication(authentication.BaseAuthentication): + def authenticate(self, request): + username = request.META.get('X_USERNAME') + if not username: + return None + + try: + user = User.objects.get(username=username) + except User.DoesNotExist: + raise authenticate.AuthenticationFailed('No such user') + + return (user, None) + +--- + +# Third party packages + +The following third party packages are also available. + +## Digest Authentication + +HTTP digest authentication is a widely implemented scheme that was intended to replace HTTP basic authentication, and which provides a simple encrypted authentication mechanism. [Juan Riaza][juanriaza] maintains the [djangorestframework-digestauth][djangorestframework-digestauth] package which provides HTTP digest authentication support for REST framework. [cite]: http://jacobian.org/writing/rest-worst-practices/ +[http401]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4.2 +[http403]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4.4 [basicauth]: http://tools.ietf.org/html/rfc2617 [oauth]: http://oauth.net/2/ [permission]: permissions.md [throttling]: throttling.md +[csrf-ajax]: https://docs.djangoproject.com/en/dev/ref/contrib/csrf/#ajax +[mod_wsgi_official]: http://code.google.com/p/modwsgi/wiki/ConfigurationDirectives#WSGIPassAuthorization +[custom-user-model]: https://docs.djangoproject.com/en/dev/topics/auth/customizing/#specifying-a-custom-user-model +[south-dependencies]: http://south.readthedocs.org/en/latest/dependencies.html +[juanriaza]: https://github.com/juanriaza +[djangorestframework-digestauth]: https://github.com/juanriaza/django-rest-framework-digestauth +[oauth-1.0a]: http://oauth.net/core/1.0a +[django-oauth-plus]: http://code.larlet.fr/django-oauth-plus +[django-oauth2-provider]: https://github.com/caffeinehit/django-oauth2-provider +[django-oauth2-provider-docs]: https://django-oauth2-provider.readthedocs.org/en/latest/ +[rfc6749]: http://tools.ietf.org/html/rfc6749 diff --git a/docs/api-guide/exceptions.md b/docs/api-guide/exceptions.md index ba57fde8..8b3e50f1 100644 --- a/docs/api-guide/exceptions.md +++ b/docs/api-guide/exceptions.md @@ -53,11 +53,27 @@ Raised if the request contains malformed data when accessing `request.DATA` or ` By default this exception results in a response with the HTTP status code "400 Bad Request". +## AuthenticationFailed + +**Signature:** `AuthenticationFailed(detail=None)` + +Raised when an incoming request includes incorrect authentication. + +By default this exception results in a response with the HTTP status code "401 Unauthenticated", but it may also result in a "403 Forbidden" response, depending on the authentication scheme in use. See the [authentication documentation][authentication] for more details. + +## NotAuthenticated + +**Signature:** `NotAuthenticated(detail=None)` + +Raised when an unauthenticated request fails the permission checks. + +By default this exception results in a response with the HTTP status code "401 Unauthenticated", but it may also result in a "403 Forbidden" response, depending on the authentication scheme in use. See the [authentication documentation][authentication] for more details. + ## PermissionDenied **Signature:** `PermissionDenied(detail=None)` -Raised when an incoming request fails the permission checks. +Raised when an authenticated request fails the permission checks. By default this exception results in a response with the HTTP status code "403 Forbidden". @@ -86,3 +102,4 @@ Raised when an incoming request fails the throttling checks. By default this exception results in a response with the HTTP status code "429 Too Many Requests". [cite]: http://www.doughellmann.com/articles/how-tos/python-exception-handling/index.html +[authentication]: authentication.md diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index 50a09701..e117c370 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -2,11 +2,11 @@ # Serializer fields -> Flat is better than nested. +> Each field in a Form class is responsible not only for validating data, but also for "cleaning" it — normalizing it to a consistent format. > -> — [The Zen of Python][cite] +> — [Django documentation][cite] -Serializer fields handle converting between primative values and internal datatypes. They also deal with validating input values, as well as retrieving and setting the values from their parent objects. +Serializer fields handle converting between primitive values and internal datatypes. They also deal with validating input values, as well as retrieving and setting the values from their parent objects. --- @@ -28,7 +28,7 @@ Defaults to the name of the field. ### `read_only` -Set this to `True` to ensure that the field is used when serializing a representation, but is not used when updating an instance dureing deserialization. +Set this to `True` to ensure that the field is used when serializing a representation, but is not used when updating an instance during deserialization. Defaults to `False` @@ -41,7 +41,7 @@ Defaults to `True`. ### `default` -If set, this gives the default value that will be used for the field if none is supplied. If not set the default behaviour is to not populate the attribute at all. +If set, this gives the default value that will be used for the field if none is supplied. If not set the default behavior is to not populate the attribute at all. ### `validators` @@ -96,13 +96,13 @@ Would produce output similar to: 'expired': True } -By default, the `Field` class will perform a basic translation of the source value into primative datatypes, falling back to unicode representations of complex datatypes when necessary. +By default, the `Field` class will perform a basic translation of the source value into primitive datatypes, falling back to unicode representations of complex datatypes when necessary. -You can customize this behaviour by overriding the `.to_native(self, value)` method. +You can customize this behavior by overriding the `.to_native(self, value)` method. ## WritableField -A field that supports both read and write operations. By itself `WriteableField` does not perform any translation of input values into a given type. You won't typically use this field directly, but you may want to override it and implement the `.to_native(self, value)` and `.from_native(self, value)` methods. +A field that supports both read and write operations. By itself `WritableField` does not perform any translation of input values into a given type. You won't typically use this field directly, but you may want to override it and implement the `.to_native(self, value)` and `.from_native(self, value)` methods. ## ModelField @@ -110,6 +110,24 @@ A generic field that can be tied to any arbitrary model field. The `ModelField` **Signature:** `ModelField(model_field=<Django ModelField class>)` +## SerializerMethodField + +This is a read-only field. It gets its value by calling a method on the serializer class it is attached to. It can be used to add any sort of data to the serialized representation of your object. The field's constructor accepts a single argument, which is the name of the method on the serializer to be called. The method should accept a single argument (in addition to `self`), which is the object being serialized. It should return whatever you want to be included in the serialized representation of the object. For example: + + from rest_framework import serializers + from django.contrib.auth.models import User + from django.utils.timezone import now + + class UserSerializer(serializers.ModelSerializer): + + days_since_joined = serializers.SerializerMethodField('get_days_since_joined') + + class Meta: + model = User + + def get_days_since_joined(self, obj): + return (now() - obj.date_joined).days + --- # Typed Fields @@ -163,17 +181,60 @@ Corresponds to `django.forms.fields.RegexField` **Signature:** `RegexField(regex, max_length=None, min_length=None)` +## DateTimeField + +A date and time representation. + +Corresponds to `django.db.models.fields.DateTimeField` + +When using `ModelSerializer` or `HyperlinkedModelSerializer`, note that any model fields with `auto_now=True` or `auto_now_add=True` will use serializer fields that are `read_only=True` by default. + +If you want to override this behavior, you'll need to declare the `DateTimeField` explicitly on the serializer. For example: + + class CommentSerializer(serializers.ModelSerializer): + created = serializers.DateTimeField() + + class Meta: + model = Comment + +Note that by default, datetime representations are deteremined by the renderer in use, although this can be explicitly overridden as detailed below. + +In the case of JSON this means the default datetime representation uses the [ECMA 262 date time string specification][ecma262]. This is a subset of ISO 8601 which uses millisecond precision, and includes the 'Z' suffix for the UTC timezone, for example: `2013-01-29T12:34:56.123Z`. + +**Signature:** `DateTimeField(format=None, input_formats=None)` + +* `format` - A string representing the output format. If not specified, this defaults to `None`, which indicates that python `datetime` objects should be returned by `to_native`. In this case the datetime encoding will be determined by the renderer. +* `input_formats` - A list of strings representing the input formats which may be used to parse the date. If not specified, the `DATETIME_INPUT_FORMATS` setting will be used, which defaults to `['iso-8601']`. + +DateTime format strings may either be [python strftime formats][strftime] which explicitly specifiy the format, or the special string `'iso-8601'`, which indicates that [ISO 8601][iso8601] style datetimes should be used. (eg `'2013-01-29T12:34:56.000000Z'`) + ## DateField A date representation. Corresponds to `django.db.models.fields.DateField` -## DateTimeField +**Signature:** `DateField(format=None, input_formats=None)` -A date and time representation. +* `format` - A string representing the output format. If not specified, this defaults to `None`, which indicates that python `date` objects should be returned by `to_native`. In this case the date encoding will be determined by the renderer. +* `input_formats` - A list of strings representing the input formats which may be used to parse the date. If not specified, the `DATE_INPUT_FORMATS` setting will be used, which defaults to `['iso-8601']`. -Corresponds to `django.db.models.fields.DateTimeField` +Date format strings may either be [python strftime formats][strftime] which explicitly specifiy the format, or the special string `'iso-8601'`, which indicates that [ISO 8601][iso8601] style dates should be used. (eg `'2013-01-29'`) + +## TimeField + +A time representation. + +Optionally takes `format` as parameter to replace the matching pattern. + +Corresponds to `django.db.models.fields.TimeField` + +**Signature:** `TimeField(format=None, input_formats=None)` + +* `format` - A string representing the output format. If not specified, this defaults to `None`, which indicates that python `time` objects should be returned by `to_native`. In this case the time encoding will be determined by the renderer. +* `input_formats` - A list of strings representing the input formats which may be used to parse the date. If not specified, the `TIME_INPUT_FORMATS` setting will be used, which defaults to `['iso-8601']`. + +Time format strings may either be [python strftime formats][strftime] which explicitly specifiy the format, or the special string `'iso-8601'`, which indicates that [ISO 8601][iso8601] style times should be used. (eg `'12:34:56.000000'`) ## IntegerField @@ -187,6 +248,12 @@ A floating point representation. Corresponds to `django.db.models.fields.FloatField`. +## DecimalField + +A decimal representation. + +Corresponds to `django.db.models.fields.DecimalField`. + ## FileField A file representation. Performs Django's standard FileField validation. @@ -211,151 +278,56 @@ Signature and validation is the same as with `FileField`. --- -**Note:** `FileFields` and `ImageFields` are only suitable for use with MultiPartParser, since eg json doesn't support file uploads. -Django's regular [FILE_UPLOAD_HANDLERS] are used for handling uploaded files. +**Note:** `FileFields` and `ImageFields` are only suitable for use with MultiPartParser, since e.g. json doesn't support file uploads. +Django's regular [FILE_UPLOAD_HANDLERS] are used for handling uploaded files. --- -# Relational Fields - -Relational fields are used to represent model relationships. They can be applied to `ForeignKey`, `ManyToManyField` and `OneToOneField` relationships, as well as to reverse relationships, and custom relationships such as `GenericForeignKey`. - -## RelatedField - -This field can be applied to any of the following: - -* A `ForeignKey` field. -* A `OneToOneField` field. -* A reverse OneToOne relationship -* Any other "to-one" relationship. - -By default `RelatedField` will represent the target of the field using it's `__unicode__` method. +# Custom fields -You can customise this behaviour by subclassing `ManyRelatedField`, and overriding the `.to_native(self, value)` method. +If you want to create a custom field, you'll probably want to override either one or both of the `.to_native()` and `.from_native()` methods. These two methods are used to convert between the intial datatype, and a primative, serializable datatype. Primative datatypes may be any of a number, string, date/time/datetime or None. They may also be any list or dictionary like object that only contains other primative objects. -## ManyRelatedField +The `.to_native()` method is called to convert the initial datatype into a primative, serializable datatype. The `from_native()` method is called to restore a primative datatype into it's initial representation. -This field can be applied to any of the following: - -* A `ManyToManyField` field. -* A reverse ManyToMany relationship. -* A reverse ForeignKey relationship -* Any other "to-many" relationship. +## Examples -By default `ManyRelatedField` will represent the targets of the field using their `__unicode__` method. +Let's look at an example of serializing a class that represents an RGB color value: -For example, given the following models: - - class TaggedItem(models.Model): + class Color(object): """ - Tags arbitrary model instances using a generic relation. - - See: https://docs.djangoproject.com/en/dev/ref/contrib/contenttypes/ + A color represented in the RGB colorspace. """ - tag = models.SlugField() - content_type = models.ForeignKey(ContentType) - object_id = models.PositiveIntegerField() - content_object = GenericForeignKey('content_type', 'object_id') - - def __unicode__(self): - return self.tag - - - class Bookmark(models.Model): + def __init__(self, red, green, blue): + assert(red >= 0 and green >= 0 and blue >= 0) + assert(red < 256 and green < 256 and blue < 256) + self.red, self.green, self.blue = red, green, blue + + class ColourField(serializers.WritableField): """ - A bookmark consists of a URL, and 0 or more descriptive tags. + Color objects are serialized into "rgb(#, #, #)" notation. """ - url = models.URLField() - tags = GenericRelation(TaggedItem) - -And a model serializer defined like this: - - class BookmarkSerializer(serializers.ModelSerializer): - tags = serializers.ManyRelatedField(source='tags') - - class Meta: - model = Bookmark - exclude = ('id',) - -Then an example output format for a Bookmark instance would be: - - { - 'tags': [u'django', u'python'], - 'url': u'https://www.djangoproject.com/' - } - -## PrimaryKeyRelatedField / ManyPrimaryKeyRelatedField - -`PrimaryKeyRelatedField` and `ManyPrimaryKeyRelatedField` will represent the target of the relationship using it's primary key. - -By default these fields are read-write, although you can change this behaviour using the `read_only` flag. - -**Arguments**: - -* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. -* `null` - If set to `True`, the field will accept values of `None` or the emptystring for nullable relationships. - -## SlugRelatedField / ManySlugRelatedField - -`SlugRelatedField` and `ManySlugRelatedField` will represent the target of the relationship using a unique slug. - -By default these fields read-write, although you can change this behaviour using the `read_only` flag. - -**Arguments**: - -* `slug_field` - The field on the target that should be used to represent it. This should be a field that uniquely identifies any given instance. For example, `username`. -* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. -* `null` - If set to `True`, the field will accept values of `None` or the emptystring for nullable relationships. - -## HyperlinkedRelatedField / ManyHyperlinkedRelatedField - -`HyperlinkedRelatedField` and `ManyHyperlinkedRelatedField` will represent the target of the relationship using a hyperlink. - -By default, `HyperlinkedRelatedField` is read-write, although you can change this behaviour using the `read_only` flag. - -**Arguments**: - -* `view_name` - The view name that should be used as the target of the relationship. **required**. -* `format` - If using format suffixes, hyperlinked fields will use the same format suffix for the target unless overridden by using the `format` argument. -* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. -* `slug_field` - The field on the target that should be used for the lookup. Default is `'slug'`. -* `pk_url_kwarg` - The named url parameter for the pk field lookup. Default is `pk`. -* `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`. -* `null` - If set to `True`, the field will accept values of `None` or the emptystring for nullable relationships. - -## HyperLinkedIdentityField - -This field can be applied as an identity relationship, such as the `'url'` field on a HyperlinkedModelSerializer. - -This field is always read-only. - -**Arguments**: - -* `view_name` - The view name that should be used as the target of the relationship. **required**. -* `format` - If using format suffixes, hyperlinked fields will use the same format suffix for the target unless overridden by using the `format` argument. -* `slug_field` - The field on the target that should be used for the lookup. Default is `'slug'`. -* `pk_url_kwarg` - The named url parameter for the pk field lookup. Default is `pk`. -* `slug_url_kwarg` - The named url parameter for the slug field lookup. Default is to use the same value as given for `slug_field`. - -# Other Fields - -## SerializerMethodField - -This is a read-only field. It gets its value by calling a method on the serializer class it is attached to. It can be used to add any sort of data to the serialized representation of your object. The field's constructor accepts a single argument, which is the name of the method on the serializer to be called. The method should accept a single argument (in addition to `self`), which is the object being serialized. It should return whatever you want to be included in the serialized representation of the object. For example: - - from rest_framework import serializers - from django.contrib.auth.models import User - from django.utils.timezone import now - - class UserSerializer(serializers.ModelSerializer): - - days_since_joined = serializers.SerializerMethodField('get_days_since_joined') - - class Meta: - model = User - - def get_days_since_joined(self, obj): - return (now() - obj.date_joined).days - -[cite]: http://www.python.org/dev/peps/pep-0020/ + def to_native(self, obj): + return "rgb(%d, %d, %d)" % (obj.red, obj.green, obj.blue) + + def from_native(self, data): + data = data.strip('rgb(').rstrip(')') + red, green, blue = [int(col) for col in data.split(',')] + return Color(red, green, blue) + + +By default field values are treated as mapping to an attribute on the object. If you need to customize how the field value is accessed and set you need to override `.field_to_native()` and/or `.field_from_native()`. + +As an example, let's create a field that can be used represent the class name of the object being serialized: + + class ClassNameField(serializers.Field): + def field_to_native(self, obj, field_name): + """ + Serialize the object's class name. + """ + return obj.__class__ + +[cite]: https://docs.djangoproject.com/en/dev/ref/forms/api/#django.forms.Form.cleaned_data [FILE_UPLOAD_HANDLERS]: https://docs.djangoproject.com/en/dev/ref/settings/#std:setting-FILE_UPLOAD_HANDLERS +[ecma262]: http://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 +[strftime]: http://docs.python.org/2/library/datetime.html#strftime-and-strptime-behavior +[iso8601]: http://www.w3.org/TR/NOTE-datetime diff --git a/docs/api-guide/filtering.md b/docs/api-guide/filtering.md index 53ea7cbc..a710ad7d 100644 --- a/docs/api-guide/filtering.md +++ b/docs/api-guide/filtering.md @@ -8,7 +8,7 @@ The default behavior of REST framework's generic list views is to return the entire queryset for a model manager. Often you will want your API to restrict the items that are returned by the queryset. -The simplest way to filter the queryset of any view that subclasses `MultipleObjectAPIView` is to override the `.get_queryset()` method. +The simplest way to filter the queryset of any view that subclasses `GenericAPIView` is to override the `.get_queryset()` method. Overriding this method allows you to customize the queryset returned by the view in a number of different ways. @@ -21,7 +21,6 @@ You can do so by filtering based on the value of `request.user`. For example: class PurchaseList(generics.ListAPIView) - model = Purchase serializer_class = PurchaseSerializer def get_queryset(self): @@ -44,7 +43,6 @@ For example if your URL config contained an entry like this: You could then write a view that returned a purchase queryset filtered by the username portion of the URL: class PurchaseList(generics.ListAPIView) - model = Purchase serializer_class = PurchaseSerializer def get_queryset(self): @@ -62,7 +60,6 @@ A final example of filtering the initial queryset would be to determine the init We can override `.get_queryset()` to deal with URLs such as `http://example.com/api/purchases?username=denvercoder9`, and filter the queryset only if the `username` parameter is included in the URL: class PurchaseList(generics.ListAPIView) - model = Purchase serializer_class = PurchaseSerializer def get_queryset(self): @@ -80,35 +77,76 @@ We can override `.get_queryset()` to deal with URLs such as `http://example.com/ # Generic Filtering -As well as being able to override the default queryset, REST framework also includes support for generic filtering backends that allow you to easily construct complex filters that can be specified by the client using query parameters. +As well as being able to override the default queryset, REST framework also includes support for generic filtering backends that allow you to easily construct complex searches and filters. -REST framework supports pluggable backends to implement filtering, and provides an implementation which uses the [django-filter] package. +## Setting filter backends -To use REST framework's filtering backend, first install `django-filter`. - - pip install django-filter - -You must also set the filter backend to `DjangoFilterBackend` in your settings: +The default filter backends may be set globally, using the `DEFAULT_FILTER_BACKENDS` setting. For example. REST_FRAMEWORK = { - 'FILTER_BACKEND': 'rest_framework.filters.DjangoFilterBackend' + 'DEFAULT_FILTER_BACKENDS': ('rest_framework.filters.DjangoFilterBackend',) } +You can also set the authentication policy on a per-view, or per-viewset basis, +using the `GenericAPIView` class based views. -## Specifying filter fields + class UserListView(generics.ListAPIView): + queryset = User.objects.all() + serializer = UserSerializer + filter_backends = (filters.DjangoFilterBackend,) -If all you need is simple equality-based filtering, you can set a `filter_fields` attribute on the view, listing the set of fields you wish to filter against. +## Filtering and object lookups - class ProductList(generics.ListAPIView): +Note that if a filter backend is configured for a view, then as well as being used to filter list views, it will also be used to filter the querysets used for returning a single object. + +For instance, given the previous example, and a product with an id of `4675`, the following URL would either return the corresponding object, or return a 404 response, depending on if the filtering conditions were met by the given product instance: + + http://example.com/api/products/4675/?category=clothing&max_price=10.00 + +## Overriding the initial queryset + +Note that you can use both an overridden `.get_queryset()` and generic filtering together, and everything will work as expected. For example, if `Product` had a many-to-many relationship with `User`, named `purchase`, you might want to write a view like this: + + class PurchasedProductsList(generics.ListAPIView): + """ + Return a list of all the products that the authenticated + user has ever purchased, with optional filtering. + """ model = Product serializer_class = ProductSerializer + filter_class = ProductFilter + + def get_queryset(self): + user = self.request.user + return user.purchase_set.all() + +--- + +# API Guide + +## DjangoFilterBackend + +The `DjangoFilterBackend` class supports highly customizable field filtering, using the [django-filter package][django-filter]. + +To use REST framework's `DjangoFilterBackend`, first install `django-filter`. + + pip install django-filter + + +#### Specifying filter fields + +If all you need is simple equality-based filtering, you can set a `filter_fields` attribute on the view, or viewset, listing the set of fields you wish to filter against. + + class ProductList(generics.ListAPIView): + queryset = Product.objects.all() + serializer_class = ProductSerializer filter_fields = ('category', 'in_stock') This will automatically create a `FilterSet` class for the given fields, and will allow you to make requests such as: http://example.com/api/products?category=clothing&in_stock=True -## Specifying a FilterSet +#### Specifying a FilterSet For more advanced filtering requirements you can specify a `FilterSet` class that should be used by the view. For example: @@ -120,7 +158,7 @@ For more advanced filtering requirements you can specify a `FilterSet` class tha fields = ['category', 'in_stock', 'min_price', 'max_price'] class ProductList(generics.ListAPIView): - model = Product + queryset = Product.objects.all() serializer_class = ProductSerializer filter_class = ProductFilter @@ -134,28 +172,74 @@ For more details on using filter sets see the [django-filter documentation][djan **Hints & Tips** -* By default filtering is not enabled. If you want to use `DjangoFilterBackend` remember to make sure it is installed by using the `'FILTER_BACKEND'` setting. +* By default filtering is not enabled. If you want to use `DjangoFilterBackend` remember to make sure it is installed by using the `'DEFAULT_FILTER_BACKENDS'` setting. * When using boolean fields, you should use the values `True` and `False` in the URL query parameters, rather than `0`, `1`, `true` or `false`. (The allowed boolean values are currently hardwired in Django's [NullBooleanSelect implementation][nullbooleanselect].) * `django-filter` supports filtering across relationships, using Django's double-underscore syntax. --- -## Overriding the initial queryset - -Note that you can use both an overridden `.get_queryset()` and generic filtering together, and everything will work as expected. For example, if `Product` had a many-to-many relationship with `User`, named `purchase`, you might want to write a view like this: +## SearchFilter + +The `SearchFilterBackend` class supports simple single query parameter based searching, and is based on the [Django admin's search functionality][search-django-admin]. + +The `SearchFilterBackend` class will only be applied if the view has a `search_fields` attribute set. The `search_fields` attribute should be a list of names of text type fields on the model, such as `CharField` or `TextField`. + + class UserListView(generics.ListAPIView): + queryset = User.objects.all() + serializer = UserSerializer + filter_backends = (filters.SearchFilter,) + search_fields = ('username', 'email') + +This will allow the client to filter the items in the list by making queries such as: + + http://example.com/api/users?search=russell + +You can also perform a related lookup on a ForeignKey or ManyToManyField with the lookup API double-underscore notation: + + search_fields = ('username', 'email', 'profile__profession') + +By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched. + +The search behavior may be restricted by prepending various characters to the `search_fields`. + +* '^' Starts-with search. +* '=' Exact matches. +* '@' Full-text search. (Currently only supported Django's MySQL backend.) + +For example: + + search_fields = ('=username', '=email') + +For more details, see the [Django documentation][search-django-admin]. + +--- + +## OrderingFilter + +The `OrderingFilter` class supports simple query parameter controlled ordering of results. To specify the result order, set a query parameter named `'order'` to the required field name. For example: + + http://example.com/api/users?ordering=username + +The client may also specify reverse orderings by prefixing the field name with '-', like so: + + http://example.com/api/users?ordering=-username + +Multiple orderings may also be specified: + + http://example.com/api/users?ordering=account,username + +If an `ordering` attribute is set on the view, this will be used as the default ordering. + +Typicaly you'd instead control this by setting `order_by` on the initial queryset, but using the `ordering` parameter on the view allows you to specify the ordering in a way that it can then be passed automatically as context to a rendered template. This makes it possible to automatically render column headers differently if they are being used to order the results. + + class UserListView(generics.ListAPIView): + queryset = User.objects.all() + serializer = UserSerializer + filter_backends = (filters.OrderingFilter,) + ordering = ('username',) + +The `ordering` attribute may be either a string or a list/tuple of strings. - class PurchasedProductsList(generics.ListAPIView): - """ - Return a list of all the products that the authenticated - user has ever purchased, with optional filtering. - """ - model = Product - serializer_class = ProductSerializer - filter_class = ProductFilter - - def get_queryset(self): - user = self.request.user - return user.purchase_set.all() --- # Custom generic filtering @@ -164,15 +248,23 @@ You can also provide your own generic filtering backend, or write an installable To do so override `BaseFilterBackend`, and override the `.filter_queryset(self, request, queryset, view)` method. The method should return a new, filtered queryset. -To install the filter backend, set the `'FILTER_BACKEND'` key in your `'REST_FRAMEWORK'` setting, using the dotted import path of the filter backend class. +As well as allowing clients to perform searches and filtering, generic filter backends can be useful for restricting which objects should be visible to any given request or user. -For example: +## Example - REST_FRAMEWORK = { - 'FILTER_BACKEND': 'custom_filters.CustomFilterBackend' - } +For example, you might need to restrict users to only being able to see objects they created. + + class IsOwnerFilterBackend(filters.BaseFilterBackend): + """ + Filter that only allows users to see their own objects. + """ + def filter_queryset(self, request, queryset, view): + return queryset.filter(owner=request.user) + +We could achieve the same behavior by overriding `get_queryset()` on the views, but using a filter backend allows you to more easily add this restriction to multiple views, or to apply it across the entire API. [cite]: https://docs.djangoproject.com/en/dev/topics/db/queries/#retrieving-specific-objects-with-filters [django-filter]: https://github.com/alex/django-filter [django-filter-docs]: https://django-filter.readthedocs.org/en/latest/index.html -[nullbooleanselect]: https://github.com/django/django/blob/master/django/forms/widgets.py
\ No newline at end of file +[nullbooleanselect]: https://github.com/django/django/blob/master/django/forms/widgets.py +[search-django-admin]: https://docs.djangoproject.com/en/dev/ref/contrib/admin/#django.contrib.admin.ModelAdmin.search_fields diff --git a/docs/api-guide/format-suffixes.md b/docs/api-guide/format-suffixes.md index 6d5feba4..dae3dea3 100644 --- a/docs/api-guide/format-suffixes.md +++ b/docs/api-guide/format-suffixes.md @@ -29,18 +29,27 @@ Example: urlpatterns = patterns('blog.views', url(r'^/$', 'api_root'), - url(r'^comment/$', 'comment_root'), - url(r'^comment/(?P<pk>[0-9]+)/$', 'comment_instance') + url(r'^comments/$', 'comment_list'), + url(r'^comments/(?P<pk>[0-9]+)/$', 'comment_detail') ) urlpatterns = format_suffix_patterns(urlpatterns, allowed=['json', 'html']) -When using `format_suffix_patterns`, you must make sure to add the `'format'` keyword argument to the corresponding views. For example. +When using `format_suffix_patterns`, you must make sure to add the `'format'` keyword argument to the corresponding views. For example: - @api_view(('GET',)) - def api_root(request, format=None): + @api_view(('GET', 'POST')) + def comment_list(request, format=None): # do stuff... +Or with class based views: + + class CommentList(APIView): + def get(self, request, format=None): + # do stuff... + + def post(self, request, format=None): + # do stuff... + The name of the kwarg used may be modified by using the `FORMAT_SUFFIX_KWARG` setting. Also note that `format_suffix_patterns` does not support descending into `include` URL patterns. @@ -58,4 +67,4 @@ It is actually a misconception. For example, take the following quote from Roy The quote does not mention Accept headers, but it does make it clear that format suffixes should be considered an acceptable pattern. [cite]: http://tech.groups.yahoo.com/group/rest-discuss/message/5857 -[cite2]: http://tech.groups.yahoo.com/group/rest-discuss/message/14844
\ No newline at end of file +[cite2]: http://tech.groups.yahoo.com/group/rest-discuss/message/14844 diff --git a/docs/api-guide/generic-views.md b/docs/api-guide/generic-views.md index 27c7d3f6..1a060a32 100644..100755 --- a/docs/api-guide/generic-views.md +++ b/docs/api-guide/generic-views.md @@ -18,7 +18,7 @@ If the generic views don't suit the needs of your API, you can drop down to usin Typically when using the generic views, you'll override the view, and set several class attributes. class UserList(generics.ListCreateAPIView): - model = User + queryset = User.objects.all() serializer_class = UserSerializer permission_classes = (IsAdminUser,) paginate_by = 100 @@ -26,17 +26,16 @@ Typically when using the generic views, you'll override the view, and set severa For more complex cases you might also want to override various methods on the view class. For example. class UserList(generics.ListCreateAPIView): - model = User + queryset = User.objects.all() serializer_class = UserSerializer permission_classes = (IsAdminUser,) - def get_paginate_by(self, queryset): + def get_paginate_by(self): """ Use smaller pagination for HTML representations. """ - page_size_param = self.request.QUERY_PARAMS.get('page_size') - if page_size_param: - return int(page_size_param) + if self.request.accepted_renderer.format == 'html': + return 20 return 100 For very simple cases you might want to pass through any class attributes using the `.as_view()` method. For example, your URLconf might include something the following entry. @@ -47,117 +46,127 @@ For very simple cases you might want to pass through any class attributes using # API Reference -The following classes are the concrete generic views. If you're using generic views this is normally the level you'll be working at unless you need heavily customized behavior. +## GenericAPIView -## CreateAPIView +This class extends REST framework's `APIView` class, adding commonly required behavior for standard list and detail views. -Used for **create-only** endpoints. +Each of the concrete generic views provided is built by combining `GenericAPIView`, with one or more mixin classes. -Provides `post` method handlers. +### Attributes -Extends: [GenericAPIView], [CreateModelMixin] +**Basic settings**: -## ListAPIView +The following attributes control the basic view behavior. -Used for **read-only** endpoints to represent a **collection of model instances**. +* `queryset` - The queryset that should be used for returning objects from this view. Typically, you must either set this attribute, or override the `get_queryset()` method. +* `serializer_class` - The serializer class that should be used for validating and deserializing input, and for serializing output. Typically, you must either set this attribute, or override the `get_serializer_class()` method. +* `lookup_field` - The field that should be used to lookup individual model instances. Defaults to `'pk'`. The URL conf should include a keyword argument corresponding to this value. More complex lookup styles can be supported by overriding the `get_object()` method. -Provides a `get` method handler. +**Shortcuts**: -Extends: [MultipleObjectAPIView], [ListModelMixin] +* `model` - This shortcut may be used instead of setting either (or both) of the `queryset`/`serializer_class` attributes, although using the explicit style is generally preferred. If used instead of `serializer_class`, then then `DEFAULT_MODEL_SERIALIZER_CLASS` setting will determine the base serializer class. -## RetrieveAPIView +**Pagination**: -Used for **read-only** endpoints to represent a **single model instance**. +The following attibutes are used to control pagination when used with list views. -Provides a `get` method handler. +* `paginate_by` - The size of pages to use with paginated data. If set to `None` then pagination is turned off. If unset this uses the same value as the `PAGINATE_BY` setting, which defaults to `None`. +* `paginate_by_param` - The name of a query parameter, which can be used by the client to overide the default page size to use for pagination. If unset this uses the same value as the `PAGINATE_BY_PARAM` setting, which defaults to `None`. +* `pagination_serializer_class` - The pagination serializer class to use when determining the style of paginated responses. Defaults to the same value as the `DEFAULT_PAGINATION_SERIALIZER_CLASS` setting. +* `page_kwarg` - The name of a URL kwarg or URL query parameter which can be used by the client to control which page is requested. Defaults to `'page'`. -Extends: [SingleObjectAPIView], [RetrieveModelMixin] +**Filtering**: -## DestroyAPIView +* `filter_backends` - A list of filter backend classes that should be used for filtering the queryset. Defaults to the same value as the `DEFAULT_FILTER_BACKENDS` setting. -Used for **delete-only** endpoints for a **single model instance**. +### Methods -Provides a `delete` method handler. +**Base methods**: -Extends: [SingleObjectAPIView], [DestroyModelMixin] +#### `get_queryset(self)` -## UpdateAPIView +Returns the queryset that should be used for list views, and that should be used as the base for lookups in detail views. Defaults to returning the queryset specified by the `queryset` attribute, or the default queryset for the model if the `model` shortcut is being used. -Used for **update-only** endpoints for a **single model instance**. +May be overridden to provide dynamic behavior such as returning a queryset that is specific to the user making the request. -Provides a `put` method handler. +For example: -Extends: [SingleObjectAPIView], [UpdateModelMixin] + def get_queryset(self): + return self.user.accounts.all() -## ListCreateAPIView +#### `get_object(self)` -Used for **read-write** endpoints to represent a **collection of model instances**. +Returns an object instance that should be used for detail views. Defaults to using the `lookup_field` parameter to filter the base queryset. -Provides `get` and `post` method handlers. +May be overridden to provide more complex behavior such as object lookups based on more than one URL kwarg. -Extends: [MultipleObjectAPIView], [ListModelMixin], [CreateModelMixin] +For example: -## RetrieveDestroyAPIView + def get_object(self): + queryset = self.get_queryset() + filter = {} + for field in self.multiple_lookup_fields: + filter[field] = self.kwargs[field] + return get_object_or_404(queryset, **filter) -Used for **read or delete** endpoints to represent a **single model instance**. +#### `get_serializer_class(self)` -Provides `get` and `delete` method handlers. +Returns the class that should be used for the serializer. Defaults to returning the `serializer_class` attribute, or dynamically generating a serializer class if the `model` shortcut is being used. -Extends: [SingleObjectAPIView], [RetrieveModelMixin], [DestroyModelMixin] +May be override to provide dynamic behavior such as using different serializers for read and write operations, or providing different serializers to different types of uesr. -## RetrieveUpdateDestroyAPIView +For example: -Used for **read-write-delete** endpoints to represent a **single model instance**. + def get_serializer_class(self): + if self.request.user.is_staff: + return FullAccountSerializer + return BasicAccountSerializer -Provides `get`, `put` and `delete` method handlers. +#### `get_paginate_by(self)` -Extends: [SingleObjectAPIView], [RetrieveModelMixin], [UpdateModelMixin], [DestroyModelMixin] +Returns the page size to use with pagination. By default this uses the `paginate_by` attribute, and may be overridden by the cient if the `paginate_by_param` attribute is set. ---- +You may want to override this method to provide more complex behavior such as modifying page sizes based on the media type of the response. -# Base views +For example: -Each of the generic views provided is built by combining one of the base views below, with one or more mixin classes. + def get_paginate_by(self): + self.request.accepted_renderer.format == 'html': + return 20 + return 100 -## GenericAPIView +**Save hooks**: -Extends REST framework's `APIView` class, adding support for serialization of model instances and model querysets. +The following methods are provided as placeholder interfaces. They contain empty implementations and are not called directly by `GenericAPIView`, but they are overridden and used by some of the mixin classes. -**Attributes**: +* `pre_save(self, obj)` - A hook that is called before saving an object. +* `post_save(self, obj, created=False)` - A hook that is called after saving an object. -* `model` - The model that should be used for this view. Used as a fallback for determining the serializer if `serializer_class` is not set, and as a fallback for determining the queryset if `queryset` is not set. Otherwise not required. -* `serializer_class` - The serializer class that should be used for validating and deserializing input, and for serializing output. If unset, this defaults to creating a serializer class using `self.model`, with the `DEFAULT_MODEL_SERIALIZER_CLASS` setting as the base serializer class. +The `pre_save` method in particular is a useful hook for setting attributes that are implicit in the request, but are not part of the request data. For instance, you might set an attribute on the object based on the request user, or based on a URL keyword argument. -## MultipleObjectAPIView + def pre_save(self, obj): + """ + Set the object's owner, based on the incoming request. + """ + obj.owner = self.request.user -Provides a base view for acting on a single object, by combining REST framework's `APIView`, and Django's [MultipleObjectMixin]. +Remember that the `pre_save()` method is not called by `GenericAPIView` itself, but it is called by `create()` and `update()` methods on the `CreateModelMixin` and `UpdateModelMixin` classes. -**See also:** ccbv.co.uk documentation for [MultipleObjectMixin][multiple-object-mixin-classy]. +**Other methods**: -**Attributes**: +You won't typically need to override the following methods, although you might need to call into them if you're writing custom views using `GenericAPIView`. -* `queryset` - The queryset that should be used for returning objects from this view. If unset, defaults to the default queryset manager for `self.model`. -* `paginate_by` - The size of pages to use with paginated data. If set to `None` then pagination is turned off. If unset this uses the same value as the `PAGINATE_BY` setting, which defaults to `None`. -* `paginate_by_param` - The name of a query parameter, which can be used by the client to overide the default page size to use for pagination. If unset this uses the same value as the `PAGINATE_BY_PARAM` setting, which defaults to `None`. - -## SingleObjectAPIView - -Provides a base view for acting on a single object, by combining REST framework's `APIView`, and Django's [SingleObjectMixin]. - -**See also:** ccbv.co.uk documentation for [SingleObjectMixin][single-object-mixin-classy]. - -**Attributes**: - -* `queryset` - The queryset that should be used when retrieving an object from this view. If unset, defaults to the default queryset manager for `self.model`. -* `pk_kwarg` - The URL kwarg that should be used to look up objects by primary key. Defaults to `'pk'`. [Can only be set to non-default on Django 1.4+] -* `slug_url_kwarg` - The URL kwarg that should be used to look up objects by a slug. Defaults to `'slug'`. [Can only be set to non-default on Django 1.4+] -* `slug_field` - The field on the model that should be used to look up objects by a slug. If used, this should typically be set to a field with `unique=True`. Defaults to `'slug'`. +* `get_serializer_context(self)` - Returns a dictionary containing any extra context that should be supplied to the serializer. Defaults to including `'request'`, `'view'` and `'format'` keys. +* `get_serializer(self, instance=None, data=None, files=None, many=False, partial=False)` - Returns a serializer instance. +* `get_pagination_serializer(self, page)` - Returns a serializer instance to use with paginated data. +* `paginate_queryset(self, queryset)` - Paginate a queryset if required, either returning a page object, or `None` if pagination is not configured for this view. +* `filter_queryset(self, queryset)` - Given a queryset, filter it with whichever filter backends are in use, returning a new queryset. --- # Mixins -The mixin classes provide the actions that are used to provide the basic view behaviour. Note that the mixin classes provide action methods rather than defining the handler methods such as `.get()` and `.post()` directly. This allows for more flexible composition of behaviour. +The mixin classes provide the actions that are used to provide the basic view behavior. Note that the mixin classes provide action methods rather than defining the handler methods such as `.get()` and `.post()` directly. This allows for more flexible composition of behavior. ## ListModelMixin @@ -165,9 +174,7 @@ Provides a `.list(request, *args, **kwargs)` method, that implements listing a q If the queryset is populated, this returns a `200 OK` response, with a serialized representation of the queryset as the body of the response. The response data may optionally be paginated. -If the queryset is empty this returns a `200 OK` reponse, unless the `.allow_empty` attribute on the view is set to `False`, in which case it will return a `404 Not Found`. - -Should be mixed in with [MultipleObjectAPIView]. +If the queryset is empty this returns a `200 OK` response, unless the `.allow_empty` attribute on the view is set to `False`, in which case it will return a `404 Not Found`. ## CreateModelMixin @@ -177,45 +184,157 @@ If an object is created this returns a `201 Created` response, with a serialized If the request data provided for creating the object was invalid, a `400 Bad Request` response will be returned, with the error details as the body of the response. -Should be mixed in with any [GenericAPIView]. - ## RetrieveModelMixin Provides a `.retrieve(request, *args, **kwargs)` method, that implements returning an existing model instance in a response. -If an object can be retrieve this returns a `200 OK` response, with a serialized representation of the object as the body of the response. Otherwise it will return a `404 Not Found`. - -Should be mixed in with [SingleObjectAPIView]. +If an object can be retrieved this returns a `200 OK` response, with a serialized representation of the object as the body of the response. Otherwise it will return a `404 Not Found`. ## UpdateModelMixin Provides a `.update(request, *args, **kwargs)` method, that implements updating and saving an existing model instance. +Also provides a `.partial_update(request, *args, **kwargs)` method, which is similar to the `update` method, except that all fields for the update will be optional. This allows support for HTTP `PATCH` requests. + If an object is updated this returns a `200 OK` response, with a serialized representation of the object as the body of the response. If an object is created, for example when making a `DELETE` request followed by a `PUT` request to the same URL, this returns a `201 Created` response, with a serialized representation of the object as the body of the response. If the request data provided for updating the object was invalid, a `400 Bad Request` response will be returned, with the error details as the body of the response. -Should be mixed in with [SingleObjectAPIView]. - ## DestroyModelMixin Provides a `.destroy(request, *args, **kwargs)` method, that implements deletion of an existing model instance. If an object is deleted this returns a `204 No Content` response, otherwise it will return a `404 Not Found`. -Should be mixed in with [SingleObjectAPIView]. +--- + +# Concrete View Classes + +The following classes are the concrete generic views. If you're using generic views this is normally the level you'll be working at unless you need heavily customized behavior. + +## CreateAPIView + +Used for **create-only** endpoints. + +Provides a `post` method handler. + +Extends: [GenericAPIView], [CreateModelMixin] + +## ListAPIView + +Used for **read-only** endpoints to represent a **collection of model instances**. + +Provides a `get` method handler. + +Extends: [GenericAPIView], [ListModelMixin] + +## RetrieveAPIView + +Used for **read-only** endpoints to represent a **single model instance**. + +Provides a `get` method handler. + +Extends: [GenericAPIView], [RetrieveModelMixin] + +## DestroyAPIView + +Used for **delete-only** endpoints for a **single model instance**. + +Provides a `delete` method handler. + +Extends: [GenericAPIView], [DestroyModelMixin] + +## UpdateAPIView + +Used for **update-only** endpoints for a **single model instance**. + +Provides `put` and `patch` method handlers. + +Extends: [GenericAPIView], [UpdateModelMixin] + +## ListCreateAPIView + +Used for **read-write** endpoints to represent a **collection of model instances**. + +Provides `get` and `post` method handlers. + +Extends: [GenericAPIView], [ListModelMixin], [CreateModelMixin] + +## RetrieveUpdateAPIView + +Used for **read or update** endpoints to represent a **single model instance**. + +Provides `get`, `put` and `patch` method handlers. + +Extends: [GenericAPIView], [RetrieveModelMixin], [UpdateModelMixin] + +## RetrieveDestroyAPIView + +Used for **read or delete** endpoints to represent a **single model instance**. + +Provides `get` and `delete` method handlers. + +Extends: [GenericAPIView], [RetrieveModelMixin], [DestroyModelMixin] + +## RetrieveUpdateDestroyAPIView + +Used for **read-write-delete** endpoints to represent a **single model instance**. + +Provides `get`, `put`, `patch` and `delete` method handlers. + +Extends: [GenericAPIView], [RetrieveModelMixin], [UpdateModelMixin], [DestroyModelMixin] + +--- + +# Customizing the generic views + +Often you'll want to use the existing generic views, but use some slightly customized behavior. If you find yourself reusing some bit of customized behavior in multiple places, you might want to refactor the behavior into a common class that you can then just apply to any view or viewset as needed. + +## Creating custom mixins + +For example, if you need to lookup objects based on multiple fields in the URL conf, you could create a mixin class like the following: + + class MultipleFieldLookupMixin(object): + """ + Apply this mixin to any view or viewset to get multiple field filtering + based on a `lookup_fields` attribute, instead of the default single field filtering. + """ + def get_object(self): + queryset = self.get_queryset() # Get the base queryset + queryset = self.filter_queryset(queryset) # Apply any filter backends + filter = {} + for field in self.lookup_fields: + filter[field] = self.kwargs[field] + return get_object_or_404(queryset, **filter) # Lookup the object + +You can then simply apply this mixin to a view or viewset anytime you need to apply the custom behavior. + + class RetrieveUserView(MultipleFieldLookupMixin, generics.RetrieveAPIView): + queryset = User.objects.all() + serializer_class = UserSerializer + lookup_fields = ('account', 'username') + +Using custom mixins is a good option if you have custom behavior that needs to be used + +## Creating custom base classes + +If you are using a mixin across multiple views, you can take this a step further and create your own set of base views that can then be used throughout your project. For example: + + class BaseRetrieveView(MultipleFieldLookupMixin, + generics.RetrieveAPIView): + pass + + class BaseRetrieveUpdateDestroyView(MultipleFieldLookupMixin, + generics.RetrieveUpdateDestroyAPIView): + pass + +Using custom base classes is a good option if you have custom behavior that consistently needs to be repeated across a large number of views throughout your project. [cite]: https://docs.djangoproject.com/en/dev/ref/class-based-views/#base-vs-generic-views -[MultipleObjectMixin]: https://docs.djangoproject.com/en/dev/ref/class-based-views/mixins-multiple-object/ -[SingleObjectMixin]: https://docs.djangoproject.com/en/dev/ref/class-based-views/mixins-single-object/ -[multiple-object-mixin-classy]: http://ccbv.co.uk/projects/Django/1.4/django.views.generic.list/MultipleObjectMixin/ -[single-object-mixin-classy]: http://ccbv.co.uk/projects/Django/1.4/django.views.generic.detail/SingleObjectMixin/ [GenericAPIView]: #genericapiview -[SingleObjectAPIView]: #singleobjectapiview -[MultipleObjectAPIView]: #multipleobjectapiview [ListModelMixin]: #listmodelmixin [CreateModelMixin]: #createmodelmixin [RetrieveModelMixin]: #retrievemodelmixin diff --git a/docs/api-guide/pagination.md b/docs/api-guide/pagination.md index ab335e6e..912ce41b 100644 --- a/docs/api-guide/pagination.md +++ b/docs/api-guide/pagination.md @@ -37,7 +37,7 @@ We could now return that data in a `Response` object, and it would be rendered i ## Paginating QuerySets -Our first example worked because we were using primative objects. If we wanted to paginate a queryset or other complex data, we'd need to specify a serializer to use to serialize the result set itself with. +Our first example worked because we were using primitive objects. If we wanted to paginate a queryset or other complex data, we'd need to specify a serializer to use to serialize the result set itself. We can do this using the `object_serializer_class` attribute on the inner `Meta` class of the pagination serializer. For example. @@ -93,10 +93,13 @@ The default pagination style may be set globally, using the `DEFAULT_PAGINATION_ You can also set the pagination style on a per-view basis, using the `ListAPIView` generic class-based view. class PaginatedListView(ListAPIView): - model = ExampleModel + queryset = ExampleModel.objects.all() + serializer_class = ExampleModelSerializer paginate_by = 10 paginate_by_param = 'page_size' +Note that using a `paginate_by` value of `None` will turn off pagination for the view. + For more complex requirements such as serialization that differs depending on the requested media type you can override the `.get_paginate_by()` and `.get_pagination_serializer_class()` methods. --- @@ -112,8 +115,8 @@ You can also override the name used for the object list field, by setting the `r For example, to nest a pair of links labelled 'prev' and 'next', and set the name for the results field to 'objects', you might use something like this. class LinksSerializer(serializers.Serializer): - next = pagination.NextURLField(source='*') - prev = pagination.PreviousURLField(source='*') + next = pagination.NextPageField(source='*') + prev = pagination.PreviousPageField(source='*') class CustomPaginationSerializer(pagination.BasePaginationSerializer): links = LinksSerializer(source='*') # Takes the page object as the source diff --git a/docs/api-guide/parsers.md b/docs/api-guide/parsers.md index 185b616c..5bd79a31 100644 --- a/docs/api-guide/parsers.md +++ b/docs/api-guide/parsers.md @@ -14,6 +14,16 @@ REST framework includes a number of built in Parser classes, that allow you to a The set of valid parsers for a view is always defined as a list of classes. When either `request.DATA` or `request.FILES` is accessed, REST framework will examine the `Content-Type` header on the incoming request, and determine which parser to use to parse the request content. +--- + +**Note**: When developing client applications always remember to make sure you're setting the `Content-Type` header when sending data in an HTTP request. + +If you don't set the content type, most clients will default to using `'application/x-www-form-urlencoded'`, which may not be what you wanted. + +As an example, if you are sending `json` encoded data using jQuery with the [.ajax() method][jquery-ajax], you should make sure to include the `contentType: 'application/json'` setting. + +--- + ## Setting the parsers The default set of parsers may be set globally, using the `DEFAULT_PARSER_CLASSES` setting. For example, the following settings would allow requests with `YAML` content. @@ -24,7 +34,8 @@ The default set of parsers may be set globally, using the `DEFAULT_PARSER_CLASSE ) } -You can also set the renderers used for an individual view, using the `APIView` class based views. +You can also set the renderers used for an individual view, or viewset, +using the `APIView` class based views. class ExampleView(APIView): """ @@ -59,6 +70,8 @@ Parses `JSON` request content. Parses `YAML` request content. +Requires the `pyyaml` package to be installed. + **.media_type**: `application/yaml` ## XMLParser @@ -69,6 +82,8 @@ Note that the `XML` markup language is typically used as the base language for m If you are considering using `XML` for your API, you may want to consider implementing a custom renderer and parser for your specific requirements, and using an existing domain-specific media-type, or creating your own custom XML-based media-type. +Requires the `defusedxml` package to be installed. + **.media_type**: `application/xml` ## FormParser @@ -87,6 +102,33 @@ You will typically want to use both `FormParser` and `MultiPartParser` together **.media_type**: `multipart/form-data` +## FileUploadParser + +Parses raw file upload content. The `request.DATA` property will be an empty `QueryDict`, and `request.FILES` will be a dictionary with a single key `'file'` containing the uploaded file. + +If the view used with `FileUploadParser` is called with a `filename` URL keyword argument, then that argument will be used as the filename. If it is called without a `filename` URL keyword argument, then the client must set the filename in the `Content-Disposition` HTTP header. For example `Content-Disposition: attachment; filename=upload.jpg`. + +**.media_type**: `*/*` + +##### Notes: + +* The `FileUploadParser` is for usage with native clients that can upload the file as a raw data request. For web-based uploads, or for native clients with multipart upload support, you should use the `MultiPartParser` parser instead. +* Since this parser's `media_type` matches any content type, `FileUploadParser` should generally be the only parser set on an API view. +* `FileUploadParser` respects Django's standard `FILE_UPLOAD_HANDLERS` setting, and the `request.upload_handlers` attribute. See the [Django documentation][upload-handlers] for more details. + +##### Basic usage example: + + class FileUploadView(views.APIView): + parser_classes = (FileUploadParser,) + + def put(self, request, filename, format=None): + file_obj = request.FILES['file'] + # ... + # do some staff with uploaded file + # ... + return Response(status=204) + + --- # Custom parsers @@ -130,33 +172,19 @@ The following is an example plaintext parser that will populate the `request.DAT """ return stream.read() -## Uploading file content - -If your custom parser needs to support file uploads, you may return a `DataAndFiles` object from the `.parse()` method. `DataAndFiles` should be instantiated with two arguments. The first argument will be used to populate the `request.DATA` property, and the second argument will be used to populate the `request.FILES` property. - -For example: +--- - class SimpleFileUploadParser(BaseParser): - """ - A naive raw file upload parser. - """ - media_type = '*/*' # Accept anything +# Third party packages - def parse(self, stream, media_type=None, parser_context=None): - content = stream.read() - name = 'example.dat' - content_type = 'application/octet-stream' - size = len(content) - charset = 'utf-8' +The following third party packages are also available. - # Write a temporary file based on the request content - temp = tempfile.NamedTemporaryFile(delete=False) - temp.write(content) - uploaded = UploadedFile(temp, name, content_type, size, charset) +## MessagePack - # Return the uploaded file - data = {} - files = {name: uploaded} - return DataAndFiles(data, files) +[MessagePack][messagepack] is a fast, efficient binary serialization format. [Juan Riaza][juanriaza] maintains the [djangorestframework-msgpack][djangorestframework-msgpack] package which provides MessagePack renderer and parser support for REST framework. +[jquery-ajax]: http://api.jquery.com/jQuery.ajax/ [cite]: https://groups.google.com/d/topic/django-developers/dxI4qVzrBY4/discussion +[upload-handlers]: https://docs.djangoproject.com/en/dev/topics/http/file-uploads/#upload-handlers +[messagepack]: https://github.com/juanriaza/django-rest-framework-msgpack +[juanriaza]: https://github.com/juanriaza +[djangorestframework-msgpack]: https://github.com/juanriaza/django-rest-framework-msgpack diff --git a/docs/api-guide/permissions.md b/docs/api-guide/permissions.md index fce68f6d..db0d4b26 100644 --- a/docs/api-guide/permissions.md +++ b/docs/api-guide/permissions.md @@ -21,7 +21,12 @@ If any permission check fails an `exceptions.PermissionDenied` exception will be REST framework permissions also support object-level permissioning. Object level permissions are used to determine if a user should be allowed to act on a particular object, which will typically be a model instance. -Object level permissions are run by REST framework's generic views when `.get_object()` is called. As with view level permissions, an `exceptions.PermissionDenied` exception will be raised if the user is not allowed to act on the given object. +Object level permissions are run by REST framework's generic views when `.get_object()` is called. +As with view level permissions, an `exceptions.PermissionDenied` exception will be raised if the user is not allowed to act on the given object. + +If you're writing your own views and want to enforce object level permissions, +you'll need to explicitly call the `.check_object_permissions(request, obj)` method on the view at the point at which you've retrieved the object. +This will either raise a `PermissionDenied` or `NotAuthenticated` exception, or simply return if the view has the appropriate permissions. ## Setting the permission policy @@ -39,7 +44,8 @@ If not specified, this setting defaults to allowing unrestricted access: 'rest_framework.permissions.AllowAny', ) -You can also set the authentication policy on a per-view basis, using the `APIView` class based views. +You can also set the authentication policy on a per-view, or per-viewset basis, +using the `APIView` class based views. class ExampleView(APIView): permission_classes = (IsAuthenticated,) @@ -90,29 +96,104 @@ This permission is suitable if you want to your API to allow read permissions to ## DjangoModelPermissions -This permission class ties into Django's standard `django.contrib.auth` [model permissions][contribauth]. When applied to a view that has a `.model` property, authorization will only be granted if the user has the relevant model permissions assigned. +This permission class ties into Django's standard `django.contrib.auth` [model permissions][contribauth]. When applied to a view that has a `.model` property, authorization will only be granted if the user *is authenticated* and has the *relevant model permissions* assigned. * `POST` requests require the user to have the `add` permission on the model. * `PUT` and `PATCH` requests require the user to have the `change` permission on the model. * `DELETE` requests require the user to have the `delete` permission on the model. - + The default behaviour can also be overridden to support custom model permissions. For example, you might want to include a `view` model permission for `GET` requests. To use custom model permissions, override `DjangoModelPermissions` and set the `.perms_map` property. Refer to the source code for details. -The `DjangoModelPermissions` class also supports object-level permissions. Third-party authorization backends such as [django-guardian][guardian] that provide object-level permissions should work just fine with `DjangoModelPermissions` without any custom configuration required. +## DjangoModelPermissionsOrAnonReadOnly + +Similar to `DjangoModelPermissions`, but also allows unauthenticated users to have read-only access to the API. + +## TokenHasReadWriteScope + +This permission class is intended for use with either of the `OAuthAuthentication` and `OAuth2Authentication` classes, and ties into the scoping that their backends provide. + +Requests with a safe methods of `GET`, `OPTIONS` or `HEAD` will be allowed if the authenticated token has read permission. + +Requests for `POST`, `PUT`, `PATCH` and `DELETE` will be allowed if the authenticated token has write permission. + +This permission class relies on the implementations of the [django-oauth-plus][django-oauth-plus] and [django-oauth2-provider][django-oauth2-provider] libraries, which both provide limited support for controlling the scope of access tokens: + +* `django-oauth-plus`: Tokens are associated with a `Resource` class which has a `name`, `url` and `is_readonly` properties. +* `django-oauth2-provider`: Tokens are associated with a bitwise `scope` attribute, that defaults to providing bitwise values for `read` and/or `write`. + +If you require more advanced scoping for your API, such as restricting tokens to accessing a subset of functionality of your API then you will need to provide a custom permission class. See the source of the `django-oauth-plus` or `django-oauth2-provider` package for more details on scoping token access. --- # Custom permissions -To implement a custom permission, override `BasePermission` and implement the `.has_permission(self, request, view, obj=None)` method. +To implement a custom permission, override `BasePermission` and implement either, or both, of the following methods: + +* `.has_permission(self, request, view)` +* `.has_object_permission(self, request, view, obj)` + +The methods should return `True` if the request should be granted access, and `False` otherwise. + +If you need to test if a request is a read operation or a write operation, you should check the request method against the constant `SAFE_METHODS`, which is a tuple containing `'GET'`, `'OPTIONS'` and `'HEAD'`. For example: + + if request.method in permissions.SAFE_METHODS: + # Check permissions for read-only request + else: + # Check permissions for write request + +--- + +**Note**: In versions 2.0 and 2.1, the signature for the permission checks always included an optional `obj` parameter, like so: `.has_permission(self, request, view, obj=None)`. The method would be called twice, first for the global permission checks, with no object supplied, and second for the object-level check when required. + +As of version 2.2 this signature has now been replaced with two seperate method calls, which is more explict and obvious. The old style signature continues to work, but it's use will result in a `PendingDeprecationWarning`, which is silent by default. In 2.3 this will be escalated to a `DeprecationWarning`, and in 2.4 the old-style signature will be removed. + +For more details see the [2.2 release announcement][2.2-announcement]. + +--- + +## Examples + +The following is an example of a permission class that checks the incoming request's IP address against a blacklist, and denies the request if the IP has been blacklisted. + + class BlacklistPermission(permissions.BasePermission): + """ + Global permission check for blacklisted IPs. + """ + + def has_permission(self, request, view): + ip_addr = request.META['REMOTE_ADDR'] + blacklisted = Blacklist.objects.filter(ip_addr=ip_addr).exists() + return not blacklisted + +As well as global permissions, that are run against all incoming requests, you can also create object-level permissions, that are only run against operations that affect a particular object instance. For example: + + class IsOwnerOrReadOnly(permissions.BasePermission): + """ + Object-level permission to only allow owners of an object to edit it. + Assumes the model instance has an `owner` attribute. + """ + + def has_object_permission(self, request, view, obj): + # Read permissions are allowed to any request, + # so we'll always allow GET, HEAD or OPTIONS requests. + if request.method in permissions.SAFE_METHODS: + return True + + # Instance must have an attribute named `owner`. + return obj.owner == request.user -The method should return `True` if the request should be granted access, and `False` otherwise. +Note that the generic views will check the appropriate object level permissions, but if you're writing your own custom views, you'll need to make sure you check the object level permission checks yourself. You can do so by calling `self.check_object_permissions(request, obj)` from the view once you have the object instance. This call will raise an appropriate `APIException` if any object-level permission checks fail, and will otherwise simply return. +Also note that the generic views will only check the object-level permissions for views that retrieve a single model instance. If you require object-level filtering of list views, you'll need to filter the queryset separately. See the [filtering documentation][filtering] for more details. [cite]: https://developer.apple.com/library/mac/#documentation/security/Conceptual/AuthenticationAndAuthorizationGuide/Authorization/Authorization.html [authentication]: authentication.md [throttling]: throttling.md [contribauth]: https://docs.djangoproject.com/en/1.0/topics/auth/#permissions [guardian]: https://github.com/lukaszb/django-guardian +[django-oauth-plus]: http://code.larlet.fr/django-oauth-plus +[django-oauth2-provider]: https://github.com/caffeinehit/django-oauth2-provider +[2.2-announcement]: ../topics/2.2-announcement.md +[filtering]: filtering.md diff --git a/docs/api-guide/relations.md b/docs/api-guide/relations.md new file mode 100644 index 00000000..155c89de --- /dev/null +++ b/docs/api-guide/relations.md @@ -0,0 +1,440 @@ +<a class="github" href="relations.py"></a> + +# Serializer relations + +> Bad programmers worry about the code. +> Good programmers worry about data structures and their relationships. +> +> — [Linus Torvalds][cite] + + +Relational fields are used to represent model relationships. They can be applied to `ForeignKey`, `ManyToManyField` and `OneToOneField` relationships, as well as to reverse relationships, and custom relationships such as `GenericForeignKey`. + +--- + +**Note:** The relational fields are declared in `relations.py`, but by convention you should import them from the `serializers` module, using `from rest_framework import serializers` and refer to fields as `serializers.<FieldName>`. + +--- + +# API Reference + +In order to explain the various types of relational fields, we'll use a couple of simple models for our examples. Our models will be for music albums, and the tracks listed on each album. + + class Album(models.Model): + album_name = models.CharField(max_length=100) + artist = models.CharField(max_length=100) + + class Track(models.Model): + album = models.ForeignKey(Album, related_name='tracks') + order = models.IntegerField() + title = models.CharField(max_length=100) + duration = models.IntegerField() + + class Meta: + unique_together = ('album', 'order') + order_by = 'order' + + def __unicode__(self): + return '%d: %s' % (self.order, self.title) + +## RelatedField + +`RelatedField` may be used to represent the target of the relationship using it's `__unicode__` method. + +For example, the following serializer. + + class AlbumSerializer(serializers.ModelSerializer): + tracks = RelatedField(many=True) + + class Meta: + model = Album + fields = ('album_name', 'artist', 'tracks') + +Would serialize to the following representation. + + { + 'album_name': 'Things We Lost In The Fire', + 'artist': 'Low' + 'tracks': [ + '1: Sunflower', + '2: Whitetail', + '3: Dinosaur Act', + ... + ] + } + +This field is read only. + +**Arguments**: + +* `many` - If applied to a to-many relationship, you should set this argument to `True`. + +## PrimaryKeyRelatedField + +`PrimaryKeyRelatedField` may be used to represent the target of the relationship using it's primary key. + +For example, the following serializer: + + class AlbumSerializer(serializers.ModelSerializer): + tracks = PrimaryKeyRelatedField(many=True, read_only=True) + + class Meta: + model = Album + fields = ('album_name', 'artist', 'tracks') + +Would serialize to a representation like this: + + { + 'album_name': 'The Roots', + 'artist': 'Undun' + 'tracks': [ + 89, + 90, + 91, + ... + ] + } + +By default this field is read-write, although you can change this behavior using the `read_only` flag. + +**Arguments**: + +* `many` - If applied to a to-many relationship, you should set this argument to `True`. +* `required` - If set to `False`, the field will accept values of `None` or the empty-string for nullable relationships. +* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. + +## HyperlinkedRelatedField + +`HyperlinkedRelatedField` may be used to represent the target of the relationship using a hyperlink. + +For example, the following serializer: + + class AlbumSerializer(serializers.ModelSerializer): + tracks = HyperlinkedRelatedField(many=True, read_only=True, + view_name='track-detail') + + class Meta: + model = Album + fields = ('album_name', 'artist', 'tracks') + +Would serialize to a representation like this: + + { + 'album_name': 'Graceland', + 'artist': 'Paul Simon' + 'tracks': [ + 'http://www.example.com/api/tracks/45/', + 'http://www.example.com/api/tracks/46/', + 'http://www.example.com/api/tracks/47/', + ... + ] + } + +By default this field is read-write, although you can change this behavior using the `read_only` flag. + +**Arguments**: + +* `view_name` - The view name that should be used as the target of the relationship. **required**. +* `many` - If applied to a to-many relationship, you should set this argument to `True`. +* `required` - If set to `False`, the field will accept values of `None` or the empty-string for nullable relationships. +* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. +* `lookup_field` - The field on the target that should be used for the lookup. Should correspond to a URL keyword argument on the referenced view. Default is `'pk'`. +* `format` - If using format suffixes, hyperlinked fields will use the same format suffix for the target unless overridden by using the `format` argument. + +## SlugRelatedField + +`SlugRelatedField` may be used to represent the target of the relationship using a field on the target. + +For example, the following serializer: + + class AlbumSerializer(serializers.ModelSerializer): + tracks = SlugRelatedField(many=True, read_only=True, slug_field='title') + + class Meta: + model = Album + fields = ('album_name', 'artist', 'tracks') + +Would serialize to a representation like this: + + { + 'album_name': 'Dear John', + 'artist': 'Loney Dear' + 'tracks': [ + 'Airport Surroundings', + 'Everything Turns to You', + 'I Was Only Going Out', + ... + ] + } + +By default this field is read-write, although you can change this behavior using the `read_only` flag. + +When using `SlugRelatedField` as a read-write field, you will normally want to ensure that the slug field corresponds to a model field with `unique=True`. + +**Arguments**: + +* `slug_field` - The field on the target that should be used to represent it. This should be a field that uniquely identifies any given instance. For example, `username`. **required** +* `many` - If applied to a to-many relationship, you should set this argument to `True`. +* `required` - If set to `False`, the field will accept values of `None` or the empty-string for nullable relationships. +* `queryset` - By default `ModelSerializer` classes will use the default queryset for the relationship. `Serializer` classes must either set a queryset explicitly, or set `read_only=True`. + +## HyperlinkedIdentityField + +This field can be applied as an identity relationship, such as the `'url'` field on a HyperlinkedModelSerializer. It can also be used for an attribute on the object. For example, the following serializer: + + class AlbumSerializer(serializers.HyperlinkedModelSerializer): + track_listing = HyperlinkedIdentityField(view_name='track-list') + + class Meta: + model = Album + fields = ('album_name', 'artist', 'track_listing') + +Would serialize to a representation like this: + + { + 'album_name': 'The Eraser', + 'artist': 'Thom Yorke' + 'track_listing': 'http://www.example.com/api/track_list/12/', + } + +This field is always read-only. + +**Arguments**: + +* `view_name` - The view name that should be used as the target of the relationship. **required**. +* `lookup_field` - The field on the target that should be used for the lookup. Should correspond to a URL keyword argument on the referenced view. Default is `'pk'`. +* `format` - If using format suffixes, hyperlinked fields will use the same format suffix for the target unless overridden by using the `format` argument. + +--- + +# Nested relationships + +Nested relationships can be expressed by using serializers as fields. + +If the field is used to represent a to-many relationship, you should add the `many=True` flag to the serializer field. + +Note that nested relationships are currently read-only. For read-write relationships, you should use a flat relational style. + +## Example + +For example, the following serializer: + + class TrackSerializer(serializers.ModelSerializer): + class Meta: + model = Track + fields = ('order', 'title') + + class AlbumSerializer(serializers.ModelSerializer): + tracks = TrackSerializer(many=True) + + class Meta: + model = Album + fields = ('album_name', 'artist', 'tracks') + +Would serialize to a nested representation like this: + + { + 'album_name': 'The Grey Album', + 'artist': 'Danger Mouse' + 'tracks': [ + {'order': 1, 'title': 'Public Service Annoucement'}, + {'order': 2, 'title': 'What More Can I Say'}, + {'order': 3, 'title': 'Encore'}, + ... + ], + } + +# Custom relational fields + +To implement a custom relational field, you should override `RelatedField`, and implement the `.to_native(self, value)` method. This method takes the target of the field as the `value` argument, and should return the representation that should be used to serialize the target. + +If you want to implement a read-write relational field, you must also implement the `.from_native(self, data)` method, and add `read_only = False` to the class definition. + +## Example + +For, example, we could define a relational field, to serialize a track to a custom string representation, using it's ordering, title, and duration. + + import time + + class TrackListingField(serializers.RelatedField): + def to_native(self, value): + duration = time.strftime('%M:%S', time.gmtime(value.duration)) + return 'Track %d: %s (%s)' % (value.order, value.name, duration) + + class AlbumSerializer(serializers.ModelSerializer): + tracks = TrackListingField(many=True) + + class Meta: + model = Album + fields = ('album_name', 'artist', 'tracks') + +This custom field would then serialize to the following representation. + + { + 'album_name': 'Sometimes I Wish We Were an Eagle', + 'artist': 'Bill Callahan' + 'tracks': [ + 'Track 1: Jim Cain (04:39)', + 'Track 2: Eid Ma Clack Shaw (04:19)', + 'Track 3: The Wind and the Dove (04:34)', + ... + ] + } + +--- + +# Further notes + +## Reverse relations + +Note that reverse relationships are not automatically included by the `ModelSerializer` and `HyperlinkedModelSerializer` classes. To include a reverse relationship, you must explicitly add it to the fields list. For example: + + class AlbumSerializer(serializers.ModelSerializer): + class Meta: + fields = ('tracks', ...) + +You'll normally want to ensure that you've set an appropriate `related_name` argument on the relationship, that you can use as the field name. For example: + + class Track(models.Model): + album = models.ForeignKey(Album, related_name='tracks') + ... + +If you have not set a related name for the reverse relationship, you'll need to use the automatically generated related name in the `fields` argument. For example: + + class AlbumSerializer(serializers.ModelSerializer): + class Meta: + fields = ('track_set', ...) + +See the Django documentation on [reverse relationships][reverse-relationships] for more details. + +## Generic relationships + +If you want to serialize a generic foreign key, you need to define a custom field, to determine explicitly how you want serialize the targets of the relationship. + +For example, given the following model for a tag, which has a generic relationship with other arbitrary models: + + class TaggedItem(models.Model): + """ + Tags arbitrary model instances using a generic relation. + + See: https://docs.djangoproject.com/en/dev/ref/contrib/contenttypes/ + """ + tag_name = models.SlugField() + content_type = models.ForeignKey(ContentType) + object_id = models.PositiveIntegerField() + tagged_object = GenericForeignKey('content_type', 'object_id') + + def __unicode__(self): + return self.tag + +And the following two models, which may be have associated tags: + + class Bookmark(models.Model): + """ + A bookmark consists of a URL, and 0 or more descriptive tags. + """ + url = models.URLField() + tags = GenericRelation(TaggedItem) + + + class Note(models.Model): + """ + A note consists of some text, and 0 or more descriptive tags. + """ + text = models.CharField(max_length=1000) + tags = GenericRelation(TaggedItem) + +We could define a custom field that could be used to serialize tagged instances, using the type of each instance to determine how it should be serialized. + + class TaggedObjectRelatedField(serializers.RelatedField): + """ + A custom field to use for the `tagged_object` generic relationship. + """ + + def to_native(self, value): + """ + Serialize tagged objects to a simple textual representation. + """ + if isinstance(value, Bookmark): + return 'Bookmark: ' + value.url + elif isinstance(value, Note): + return 'Note: ' + value.text + raise Exception('Unexpected type of tagged object') + +If you need the target of the relationship to have a nested representation, you can use the required serializers inside the `.to_native()` method: + + def to_native(self, value): + """ + Serialize bookmark instances using a bookmark serializer, + and note instances using a note serializer. + """ + if isinstance(value, Bookmark): + serializer = BookmarkSerializer(value) + elif isinstance(value, Note): + serializer = NoteSerializer(value) + else: + raise Exception('Unexpected type of tagged object') + + return serializer.data + +Note that reverse generic keys, expressed using the `GenericRelation` field, can be serialized using the regular relational field types, since the type of the target in the relationship is always known. + +For more information see [the Django documentation on generic relations][generic-relations]. + +## Advanced Hyperlinked fields + +If you have very specific requirements for the style of your hyperlinked relationships you can override `HyperlinkedRelatedField`. + +There are two methods you'll need to override. + +#### get_url(self, obj, view_name, request, format) + +This method should return the URL that corresponds to the given object. + +May raise a `NoReverseMatch` if the `view_name` and `lookup_field` +attributes are not configured to correctly match the URL conf. + +#### get_object(self, queryset, view_name, view_args, view_kwargs) + + +This method should the object that corresponds to the matched URL conf arguments. + +May raise an `ObjectDoesNotExist` exception. + +### Example + +For example, if all your object URLs used both a account and a slug in the the URL to reference the object, you might create a custom field like this: + + class CustomHyperlinkedField(serializers.HyperlinkedRelatedField): + def get_url(self, obj, view_name, request, format): + kwargs = {'account': obj.account, 'slug': obj.slug} + return reverse(view_name, kwargs=kwargs, request=request, format=format) + + def get_object(self, queryset, view_name, view_args, view_kwargs): + account = view_kwargs['account'] + slug = view_kwargs['slug'] + return queryset.get(account=account, slug=sug) + +--- + +## Deprecated APIs + +The following classes have been deprecated, in favor of the `many=<bool>` syntax. +They continue to function, but their usage will raise a `PendingDeprecationWarning`, which is silent by default. + +* `ManyRelatedField` +* `ManyPrimaryKeyRelatedField` +* `ManyHyperlinkedRelatedField` +* `ManySlugRelatedField` + +The `null=<bool>` flag has been deprecated in favor of the `required=<bool>` flag. It will continue to function, but will raise a `PendingDeprecationWarning`. + +In the 2.3 release, these warnings will be escalated to a `DeprecationWarning`, which is loud by default. +In the 2.4 release, these parts of the API will be removed entirely. + +For more details see the [2.2 release announcement][2.2-announcement]. + +[cite]: http://lwn.net/Articles/193245/ +[reverse-relationships]: https://docs.djangoproject.com/en/dev/topics/db/queries/#following-relationships-backward +[generic-relations]: https://docs.djangoproject.com/en/dev/ref/contrib/contenttypes/#id1 +[2.2-announcement]: ../topics/2.2-announcement.md diff --git a/docs/api-guide/renderers.md b/docs/api-guide/renderers.md index 374ff0ab..ed733c65 100644 --- a/docs/api-guide/renderers.md +++ b/docs/api-guide/renderers.md @@ -27,7 +27,8 @@ The default set of renderers may be set globally, using the `DEFAULT_RENDERER_CL ) } -You can also set the renderers used for an individual view, using the `APIView` class based views. +You can also set the renderers used for an individual view, or viewset, +using the `APIView` class based views. class UserCountView(APIView): """ @@ -56,7 +57,7 @@ Or, if you're using the `@api_view` decorator with function based views. It's important when specifying the renderer classes for your API to think about what priority you want to assign to each media type. If a client underspecifies the representations it can accept, such as sending an `Accept: */*` header, or not including an `Accept` header at all, then REST framework will select the first renderer in the list to use for the response. -For example if your API serves JSON responses and the HTML browseable API, you might want to make `JSONRenderer` your default renderer, in order to send `JSON` responses to clients that do not specify an `Accept` header. +For example if your API serves JSON responses and the HTML browsable API, you might want to make `JSONRenderer` your default renderer, in order to send `JSON` responses to clients that do not specify an `Accept` header. If your API includes views that can serve both regular webpages and API responses depending on the request, then you might consider making `TemplateHTMLRenderer` your default renderer, in order to play nicely with older browsers that send [broken accept headers][browser-accept-headers]. @@ -80,7 +81,7 @@ Renders the request data into `JSONP`. The `JSONP` media type provides a mechan The javascript callback function must be set by the client including a `callback` URL query parameter. For example `http://example.com/api/users?callback=jsonpCallback`. If the callback function is not explicitly set by the client it will default to `'callback'`. -**Note**: If you require cross-domain AJAX requests, you may also want to consider using [CORS] as an alternative to `JSONP`. +**Note**: If you require cross-domain AJAX requests, you may want to consider using the more modern approach of [CORS][cors] as an alternative to `JSONP`. See the [CORS documentation][cors-docs] for more details. **.media_type**: `application/javascript` @@ -90,6 +91,8 @@ The javascript callback function must be set by the client including a `callback Renders the request data into `YAML`. +Requires the `pyyaml` package to be installed. + **.media_type**: `application/yaml` **.format**: `'.yaml'` @@ -115,17 +118,17 @@ The TemplateHTMLRenderer will create a `RequestContext`, using the `response.dat The template name is determined by (in order of preference): -1. An explicit `.template_name` attribute set on the response. +1. An explicit `template_name` argument passed to the response. 2. An explicit `.template_name` attribute set on this class. 3. The return result of calling `view.get_template_names()`. An example of a view that uses `TemplateHTMLRenderer`: - class UserInstance(generics.RetrieveUserAPIView): + class UserDetail(generics.RetrieveUserAPIView): """ A view that returns a templated HTML representations of a given user. """ - model = Users + queryset = User.objects.all() renderer_classes = (TemplateHTMLRenderer,) def get(self, request, *args, **kwargs) @@ -164,7 +167,7 @@ See also: `TemplateHTMLRenderer` ## BrowsableAPIRenderer -Renders data into HTML for the Browseable API. This renderer will determine which other renderer would have been given highest priority, and use that to display an API style response within the HTML page. +Renders data into HTML for the Browsable API. This renderer will determine which other renderer would have been given highest priority, and use that to display an API style response within the HTML page. **.media_type**: `text/html` @@ -271,13 +274,32 @@ Exceptions raised and handled by an HTML renderer will attempt to render using o Templates will render with a `RequestContext` which includes the `status_code` and `details` keys. +--- + +# Third party packages + +The following third party packages are also available. + +## MessagePack + +[MessagePack][messagepack] is a fast, efficient binary serialization format. [Juan Riaza][juanriaza] maintains the [djangorestframework-msgpack][djangorestframework-msgpack] package which provides MessagePack renderer and parser support for REST framework. + +## CSV + +Comma-separated values are a plain-text tabular data format, that can be easily imported into spreadsheet applications. [Mjumbe Poe][mjumbewu] maintains the [djangorestframework-csv][djangorestframework-csv] package which provides CSV renderer support for REST framework. [cite]: https://docs.djangoproject.com/en/dev/ref/template-response/#the-rendering-process [conneg]: content-negotiation.md [browser-accept-headers]: http://www.gethifi.com/blog/browser-rest-http-accept-headers -[CORS]: http://en.wikipedia.org/wiki/Cross-origin_resource_sharing +[cors]: http://www.w3.org/TR/cors/ +[cors-docs]: ../topics/ajax-csrf-cors.md [HATEOAS]: http://timelessrepo.com/haters-gonna-hateoas [quote]: http://roy.gbiv.com/untangled/2008/rest-apis-must-be-hypertext-driven [application/vnd.github+json]: http://developer.github.com/v3/media/ [application/vnd.collection+json]: http://www.amundsen.com/media-types/collection/ -[django-error-views]: https://docs.djangoproject.com/en/dev/topics/http/views/#customizing-error-views
\ No newline at end of file +[django-error-views]: https://docs.djangoproject.com/en/dev/topics/http/views/#customizing-error-views +[messagepack]: http://msgpack.org/ +[juanriaza]: https://github.com/juanriaza +[mjumbewu]: https://github.com/mjumbewu +[djangorestframework-msgpack]: https://github.com/juanriaza/django-rest-framework-msgpack +[djangorestframework-csv]: https://github.com/mjumbewu/django-rest-framework-csv diff --git a/docs/api-guide/requests.md b/docs/api-guide/requests.md index 72932f5d..39a34fcf 100644 --- a/docs/api-guide/requests.md +++ b/docs/api-guide/requests.md @@ -83,13 +83,13 @@ You won't typically need to access this property. # Browser enhancements -REST framework supports a few browser enhancements such as browser-based `PUT` and `DELETE` forms. +REST framework supports a few browser enhancements such as browser-based `PUT`, `PATCH` and `DELETE` forms. ## .method `request.method` returns the **uppercased** string representation of the request's HTTP method. -Browser-based `PUT` and `DELETE` forms are transparently supported. +Browser-based `PUT`, `PATCH` and `DELETE` forms are transparently supported. For more information see the [browser enhancements documentation]. diff --git a/docs/api-guide/routers.md b/docs/api-guide/routers.md new file mode 100644 index 00000000..6588d7e5 --- /dev/null +++ b/docs/api-guide/routers.md @@ -0,0 +1,111 @@ +<a class="github" href="routers.py"></a> + +# Routers + +> Resource routing allows you to quickly declare all of the common routes for a given resourceful controller. Instead of declaring separate routes for your index... a resourceful route declares them in a single line of code. +> +> — [Ruby on Rails Documentation][cite] + +Some Web frameworks such as Rails provide functionality for automatically determining how the URLs for an application should be mapped to the logic that deals with handling incoming requests. + +REST framework adds support for automatic URL routing to Django, and provides you with a simple, quick and consistent way of wiring your view logic to a set of URLs. + +## Usage + +Here's an example of a simple URL conf, that uses `DefaultRouter`. + + router = routers.SimpleRouter() + router.register(r'users', UserViewSet) + router.register(r'accounts', AccountViewSet) + urlpatterns = router.urls + +There are two mandatory arguments to the `register()` method: + +* `prefix` - The URL prefix to use for this set of routes. +* `viewset` - The viewset class. + +Optionally, you may also specify an additional argument: + +* `base_name` - The base to use for the URL names that are created. If unset the basename will be automatically generated based on the `model` or `queryset` attribute on the viewset, if it has one. + +The example above would generate the following URL patterns: + +* URL pattern: `^users/$` Name: `'user-list'` +* URL pattern: `^users/{pk}/$` Name: `'user-detail'` +* URL pattern: `^accounts/$` Name: `'account-list'` +* URL pattern: `^accounts/{pk}/$` Name: `'account-detail'` + +### Extra link and actions + +Any methods on the viewset decorated with `@link` or `@action` will also be routed. +For example, a given method like this on the `UserViewSet` class: + + @action(permission_classes=[IsAdminOrIsSelf]) + def set_password(self, request, pk=None): + ... + +The following URL pattern would additionally be generated: + +* URL pattern: `^users/{pk}/set_password/$` Name: `'user-set-password'` + +# API Guide + +## SimpleRouter + +This router includes routes for the standard set of `list`, `create`, `retrieve`, `update`, `partial_update` and `destroy` actions. The viewset can also mark additional methods to be routed, using the `@link` or `@action` decorators. + +<table border=1> + <tr><th>URL Style</th><th>HTTP Method</th><th>Action</th><th>URL Name</th></tr> + <tr><td rowspan=2>{prefix}/</td><td>GET</td><td>list</td><td rowspan=2>{basename}-list</td></tr></tr> + <tr><td>POST</td><td>create</td></tr> + <tr><td rowspan=4>{prefix}/{lookup}/</td><td>GET</td><td>retrieve</td><td rowspan=4>{basename}-detail</td></tr></tr> + <tr><td>PUT</td><td>update</td></tr> + <tr><td>PATCH</td><td>partial_update</td></tr> + <tr><td>DELETE</td><td>destroy</td></tr> + <tr><td rowspan=2>{prefix}/{lookup}/{methodname}/</td><td>GET</td><td>@link decorated method</td><td rowspan=2>{basename}-{methodname}</td></tr> + <tr><td>POST</td><td>@action decorated method</td></tr> +</table> + +## DefaultRouter + +This router is similar to `SimpleRouter` as above, but additionally includes a default API root view, that returns a response containing hyperlinks to all the list views. It also generates routes for optional `.json` style format suffixes. + +<table border=1> + <tr><th>URL Style</th><th>HTTP Method</th><th>Action</th><th>URL Name</th></tr> + <tr><td>[.format]</td><td>GET</td><td>automatically generated root view</td><td>api-root</td></tr></tr> + <tr><td rowspan=2>{prefix}/[.format]</td><td>GET</td><td>list</td><td rowspan=2>{basename}-list</td></tr></tr> + <tr><td>POST</td><td>create</td></tr> + <tr><td rowspan=4>{prefix}/{lookup}/[.format]</td><td>GET</td><td>retrieve</td><td rowspan=4>{basename}-detail</td></tr></tr> + <tr><td>PUT</td><td>update</td></tr> + <tr><td>PATCH</td><td>partial_update</td></tr> + <tr><td>DELETE</td><td>destroy</td></tr> + <tr><td rowspan=2>{prefix}/{lookup}/{methodname}/[.format]</td><td>GET</td><td>@link decorated method</td><td rowspan=2>{basename}-{methodname}</td></tr> + <tr><td>POST</td><td>@action decorated method</td></tr> +</table> + +# Custom Routers + +Implementing a custom router isn't something you'd need to do very often, but it can be useful if you have specfic requirements about how the your URLs for your API are strutured. Doing so allows you to encapsulate the URL structure in a reusable way that ensures you don't have to write your URL patterns explicitly for each new view. + +The simplest way to implement a custom router is to subclass one of the existing router classes. The `.routes` attribute is used to template the URL patterns that will be mapped to each viewset. + +## Example + +The following example will only route to the `list` and `retrieve` actions, and unlike the routers included by REST framework, it does not use the trailing slash convention. + + class ReadOnlyRouter(SimpleRouter): + """ + A router for read-only APIs, which doesn't use trailing suffixes. + """ + routes = [ + (r'^{prefix}$', {'get': 'list'}, '{basename}-list'), + (r'^{prefix}/{lookup}$', {'get': 'retrieve'}, '{basename}-detail') + ] + +## Advanced custom routers + +If you want to provide totally custom behavior, you can override `BaseRouter` and override the `get_urls(self)` method. The method should insect the registered viewsets and return a list of URL patterns. The registered prefix, viewset and basename tuples may be inspected by accessing the `self.registry` attribute. + +You may also want to override the `get_default_base_name(self, viewset)` method, or else always explicitly set the `base_name` argument when registering your viewsets with the router. + +[cite]: http://guides.rubyonrails.org/routing.html diff --git a/docs/api-guide/serializers.md b/docs/api-guide/serializers.md index 19efde3c..c83a0967 100644 --- a/docs/api-guide/serializers.md +++ b/docs/api-guide/serializers.md @@ -4,8 +4,7 @@ > Expanding the usefulness of the serializers is something that we would like to address. However, it's not a trivial problem, and it -will take some serious design work. Any offers to help out in this -area would be gratefully accepted. +will take some serious design work. > > — Russell Keith-Magee, [Django users group][cite] @@ -26,6 +25,7 @@ Let's start by creating a simple object we can use for example purposes: comment = Comment(email='leila@example.com', content='foo bar') We'll declare a serializer that we can use to serialize and deserialize `Comment` objects. + Declaring a serializer looks very similar to declaring a form: class CommentSerializer(serializers.Serializer): @@ -34,14 +34,20 @@ Declaring a serializer looks very similar to declaring a form: created = serializers.DateTimeField() def restore_object(self, attrs, instance=None): + """ + Given a dictionary of deserialized field values, either update + an existing model instance, or create a new model instance. + """ if instance is not None: - instance.title = attrs['title'] - instance.content = attrs['content'] - instance.created = attrs['created'] + instance.title = attrs.get('title', instance.title) + instance.content = attrs.get('content', instance.content) + instance.created = attrs.get('created', instance.created) return instance return Comment(**attrs) -The first part of serializer class defines the fields that get serialized/deserialized. The `restore_object` method defines how fully fledged instances get created when deserializing data. The `restore_object` method is optional, and is only required if we want our serializer to support deserialization. +The first part of serializer class defines the fields that get serialized/deserialized. The `restore_object` method defines how fully fledged instances get created when deserializing data. + +The `restore_object` method is optional, and is only required if we want our serializer to support deserialization into fully fledged object instances. If we don't define this method, then deserializing data will simply return a dictionary of items. ## Serializing objects @@ -53,14 +59,15 @@ We can now use `CommentSerializer` to serialize a comment, or list of comments. At this point we've translated the model instance into python native datatypes. To finalise the serialization process we render the data into `json`. - stream = JSONRenderer().render(data) - stream + json = JSONRenderer().render(serializer.data) + json # '{"email": "leila@example.com", "content": "foo bar", "created": "2012-08-22T16:20:09.822"}' ## Deserializing objects Deserialization is similar. First we parse a stream into python native datatypes... + stream = StringIO(json) data = JSONParser().parse(stream) ...then we restore those native datatypes into a fully populated object instance. @@ -83,9 +90,19 @@ By default, serializers must be passed values for all required fields or they wi ## Validation -When deserializing data, you always need to call `is_valid()` before attempting to access the deserialized object. If any validation errors occur, the `.errors` and `.non_field_errors` properties will contain the resulting error messages. +When deserializing data, you always need to call `is_valid()` before attempting to access the deserialized object. If any validation errors occur, the `.errors` property will contain a dictionary representing the resulting error messages. For example: + + serializer = CommentSerializer(data={'email': 'foobar', 'content': 'baz'}) + serializer.is_valid() + # False + serializer.errors + # {'email': [u'Enter a valid e-mail address.'], 'created': [u'This field is required.']} + +Each key in the dictionary will be the field name, and the values will be lists of strings of any error messages corresponding to that field. The `non_field_errors` key may also be present, and will list any general validation errors. -### Field-level validation +When deserializing a list of items, errors will be returned as a list of dictionaries representing each of the deserialized items. + +#### Field-level validation You can specify custom field-level validation by adding `.validate_<fieldname>` methods to your `Serializer` subclass. These are analagous to `.clean_<fieldname>` methods on Django forms, but accept slightly different arguments. @@ -108,32 +125,65 @@ Your `validate_<fieldname>` methods should either just return the `attrs` dictio raise serializers.ValidationError("Blog post is not about Django") return attrs -### Object-level validation +#### Object-level validation + +To do any other validation that requires access to multiple fields, add a method called `.validate()` to your `Serializer` subclass. This method takes a single argument, which is the `attrs` dictionary. It should raise a `ValidationError` if necessary, or just return `attrs`. For example: + + from rest_framework import serializers -To do any other validation that requires access to multiple fields, add a method called `.validate()` to your `Serializer` subclass. This method takes a single argument, which is the `attrs` dictionary. It should raise a `ValidationError` if necessary, or just return `attrs`. + class EventSerializer(serializers.Serializer): + description = serializers.CharField(max_length=100) + start = serializers.DateTimeField() + finish = serializers.DateTimeField() + + def validate(self, attrs): + """ + Check that the start is before the stop. + """ + if attrs['start'] < attrs['finish']: + raise serializers.ValidationError("finish must occur after start") + return attrs ## Saving object state -Serializers also include a `.save()` method that you can override if you want to provide a method of persisting the state of a deserialized object. The default behavior of the method is to simply call `.save()` on the deserialized object instance. +To save the deserialized objects created by a serializer, call the `.save()` method: + + if serializer.is_valid(): + serializer.save() + +The default behavior of the method is to simply call `.save()` on the deserialized object instance. You can override the default save behaviour by overriding the `.save_object(obj)` method on the serializer class. The generic views provided by REST framework call the `.save()` method when updating or creating entities. ## Dealing with nested objects -The previous example is fine for dealing with objects that only have simple datatypes, but sometimes we also need to be able to represent more complex objects, -where some of the attributes of an object might not be simple datatypes such as strings, dates or integers. +The previous examples are fine for dealing with objects that only have simple datatypes, but sometimes we also need to be able to represent more complex objects, where some of the attributes of an object might not be simple datatypes such as strings, dates or integers. The `Serializer` class is itself a type of `Field`, and can be used to represent relationships where one object type is nested inside another. class UserSerializer(serializers.Serializer): - email = serializers.Field() - username = serializers.Field() + email = serializers.EmailField() + username = serializers.CharField(max_length=100) class CommentSerializer(serializers.Serializer): user = UserSerializer() - title = serializers.Field() - content = serializers.Field() - created = serializers.Field() + content = serializers.CharField(max_length=200) + created = serializers.DateTimeField() + +If a nested representation may optionally accept the `None` value you should pass the `required=False` flag to the nested serializer. + + class CommentSerializer(serializers.Serializer): + user = UserSerializer(required=False) # May be an anonymous user. + content = serializers.CharField(max_length=200) + created = serializers.DateTimeField() + +Similarly if a nested representation should be a list of items, you should the `many=True` flag to the nested serialized. + + class CommentSerializer(serializers.Serializer): + user = UserSerializer(required=False) + edits = EditItemSerializer(many=True) # A nested list of 'edit' items. + content = serializers.CharField(max_length=200) + created = serializers.DateTimeField() --- @@ -141,57 +191,111 @@ The `Serializer` class is itself a type of `Field`, and can be used to represent --- +## Dealing with multiple objects -## Creating custom fields +The `Serializer` class can also handle serializing or deserializing lists of objects. -If you want to create a custom field, you'll probably want to override either one or both of the `.to_native()` and `.from_native()` methods. These two methods are used to convert between the intial datatype, and a primative, serializable datatype. Primative datatypes may be any of a number, string, date/time/datetime or None. They may also be any list or dictionary like object that only contains other primative objects. +#### Serializing multiple objects -The `.to_native()` method is called to convert the initial datatype into a primative, serializable datatype. The `from_native()` method is called to restore a primative datatype into it's initial representation. +To serialize a queryset or list of objects instead of a single object instance, you should pass the `many=True` flag when instantiating the serializer. You can then pass a queryset or list of objects to be serialized. -Let's look at an example of serializing a class that represents an RGB color value: + queryset = Book.objects.all() + serializer = BookSerializer(queryset, many=True) + serializer.data + # [ + # {'id': 0, 'title': 'The electric kool-aid acid test', 'author': 'Tom Wolfe'}, + # {'id': 1, 'title': 'If this is a man', 'author': 'Primo Levi'}, + # {'id': 2, 'title': 'The wind-up bird chronicle', 'author': 'Haruki Murakami'} + # ] - class Color(object): - """ - A color represented in the RGB colorspace. - """ - def __init__(self, red, green, blue): - assert(red >= 0 and green >= 0 and blue >= 0) - assert(red < 256 and green < 256 and blue < 256) - self.red, self.green, self.blue = red, green, blue +#### Deserializing multiple objects for creation - class ColourField(serializers.WritableField): - """ - Color objects are serialized into "rgb(#, #, #)" notation. - """ - def to_native(self, obj): - return "rgb(%d, %d, %d)" % (obj.red, obj.green, obj.blue) - - def from_native(self, data): - data = data.strip('rgb(').rstrip(')') - red, green, blue = [int(col) for col in data.split(',')] - return Color(red, green, blue) - +To deserialize a list of object data, and create multiple object instances in a single pass, you should also set the `many=True` flag, and pass a list of data to be deserialized. -By default field values are treated as mapping to an attribute on the object. If you need to customize how the field value is accessed and set you need to override `.field_to_native()` and/or `.field_from_native()`. +This allows you to write views that create multiple items when a `POST` request is made. -As an example, let's create a field that can be used represent the class name of the object being serialized: +For example: - class ClassNameField(serializers.WritableField): - def field_to_native(self, obj, field_name): - """ - Serialize the object's class name, not an attribute of the object. - """ - return obj.__class__.__name__ + data = [ + {'title': 'The bell jar', 'author': 'Sylvia Plath'}, + {'title': 'For whom the bell tolls', 'author': 'Ernest Hemingway'} + ] + serializer = BookSerializer(data=data, many=True) + serializer.is_valid() + # True + serializer.save() # `.save()` will be called on each deserialized instance + +#### Deserializing multiple objects for update + +You can also deserialize a list of objects as part of a bulk update of multiple existing items. +In this case you need to supply both an existing list or queryset of items, as well as a list of data to update those items with. + +This allows you to write views that update or create multiple items when a `PUT` request is made. + + # Capitalizing the titles of the books + queryset = Book.objects.all() + data = [ + {'id': 3, 'title': 'The Bell Jar', 'author': 'Sylvia Plath'}, + {'id': 4, 'title': 'For Whom the Bell Tolls', 'author': 'Ernest Hemingway'} + ] + serializer = BookSerializer(queryset, data=data, many=True) + serializer.is_valid() + # True + serialize.save() # `.save()` will be called on each updated or newly created instance. + +By default bulk updates will be limited to updating instances that already exist in the provided queryset. + +When performing a bulk update you may want to allow new items to be created, and missing items to be deleted. To do so, pass `allow_add_remove=True` to the serializer. + + serializer = BookSerializer(queryset, data=data, many=True, allow_add_remove=True) + serializer.is_valid() + # True + serializer.save() # `.save()` will be called on updated or newly created instances. + #Â `.delete()` will be called on any other items in the `queryset`. + +Passing `allow_add_remove=True` ensures that any update operations will completely overwrite the existing queryset, rather than simply updating existing objects. - def field_from_native(self, data, field_name, into): +#### How identity is determined when performing bulk updates + +Performing a bulk update is slightly more complicated than performing a bulk creation, because the serializer needs a way to determine how the items in the incoming data should be matched against the existing object instances. + +By default the serializer class will use the `id` key on the incoming data to determine the canonical identity of an object. If you need to change this behavior you should override the `get_identity` method on the `Serializer` class. For example: + + class AccountSerializer(serializers.Serializer): + slug = serializers.CharField(max_length=100) + created = serializers.DateTimeField() + ... # Various other fields + + def get_identity(self, data): """ - We don't want to set anything when we revert this field. + This hook is required for bulk update. + We need to override the default, to use the slug as the identity. + + Note that the data has not yet been validated at this point, + so we need to deal gracefully with incorrect datatypes. """ - pass + try: + return data.get('slug', None) + except AttributeError: + return None + +To map the incoming data items to their corresponding object instances, the `.get_identity()` method will be called both against the incoming data, and against the serialized representation of the existing objects. + +## Including extra context + +There are some cases where you need to provide extra context to the serializer in addition to the object being serialized. One common case is if you're using a serializer that includes hyperlinked relations, which requires the serializer to have access to the current request so that it can properly generate fully qualified URLs. + +You can provide arbitrary additional context by passing a `context` argument when instantiating the serializer. For example: + + serializer = AccountSerializer(account, context={'request': request}) + serializer.data + # {'id': 6, 'owner': u'denvercoder9', 'created': datetime.datetime(2013, 2, 12, 09, 44, 56, 678870), 'details': 'http://example.com/accounts/6/details'} + +The context dictionary can be used within any serializer field logic, such as a custom `.to_native()` method, by accessing the `self.context` attribute. --- -# ModelSerializers +# ModelSerializer Often you'll want serializer classes that map closely to model definitions. The `ModelSerializer` class lets you automatically create a Serializer class with fields that correspond to the Model fields. @@ -200,15 +304,52 @@ The `ModelSerializer` class lets you automatically create a Serializer class wit class Meta: model = Account -**[TODO: Explain model field to serializer field mapping in more detail]** +By default, all the model fields on the class will be mapped to corresponding serializer fields. + +Any relationships such as foreign keys on the model will be mapped to `PrimaryKeyRelatedField`. Other models fields will be mapped to a corresponding serializer field. + +## Specifying which fields should be included + +If you only want a subset of the default fields to be used in a model serializer, you can do so using `fields` or `exclude` options, just as you would with a `ModelForm`. + +For example: + + class AccountSerializer(serializers.ModelSerializer): + class Meta: + model = Account + fields = ('id', 'account_name', 'users', 'created') + +## Specifying nested serialization + +The default `ModelSerializer` uses primary keys for relationships, but you can also easily generate nested representations using the `depth` option: + + class AccountSerializer(serializers.ModelSerializer): + class Meta: + model = Account + fields = ('id', 'account_name', 'users', 'created') + depth = 1 + +The `depth` option should be set to an integer value that indicates the depth of relationships that should be traversed before reverting to a flat representation. + +## Specifying which fields should be read-only + +You may wish to specify multiple fields as read-only. Instead of adding each field explicitly with the `read_only=True` attribute, you may use the `read_only_fields` Meta option, like so: + + class AccountSerializer(serializers.ModelSerializer): + class Meta: + model = Account + fields = ('id', 'account_name', 'users', 'created') + read_only_fields = ('account_name',) + +Model fields which have `editable=False` set, and `AutoField` fields will be set to read-only by default, and do not need to be added to the `read_only_fields` option. ## Specifying fields explicitly You can add extra fields to a `ModelSerializer` or override the default fields by declaring fields on the class, just as you would for a `Serializer` class. class AccountSerializer(serializers.ModelSerializer): - url = CharField(source='get_absolute_url', read_only=True) - group = NaturalKeyField() + url = serializers.CharField(source='get_absolute_url', read_only=True) + groups = serializers.PrimaryKeyRelatedField(many=True) class Meta: model = Account @@ -217,55 +358,74 @@ Extra fields can correspond to any property or callable on the model. ## Relational fields -When serializing model instances, there are a number of different ways you might choose to represent relationships. The default representation is to use the primary keys of the related instances. +When serializing model instances, there are a number of different ways you might choose to represent relationships. The default representation for `ModelSerializer` is to use the primary keys of the related instances. -Alternative representations include serializing using natural keys, serializing complete nested representations, or serializing using a custom representation, such as a URL that uniquely identifies the model instances. +Alternative representations include serializing using hyperlinks, serializing complete nested representations, or serializing with a custom representation. -The `PrimaryKeyRelatedField` and `HyperlinkedRelatedField` fields provide alternative flat representations. +For full details see the [serializer relations][relations] documentation. -The `ModelSerializer` class can itself be used as a field, in order to serialize relationships using nested representations. +--- -The `RelatedField` class may be subclassed to create a custom representation of a relationship. The subclass should override `.to_native()`, and optionally `.from_native()` if deserialization is supported. +# HyperlinkedModelSerializer -All the relational fields may be used for any relationship or reverse relationship on a model. +The `HyperlinkedModelSerializer` class is similar to the `ModelSerializer` class except that it uses hyperlinks to represent relationships, rather than primary keys. -## Specifying which fields should be included +By default the serializer will include a `url` field instead of a primary key field. -If you only want a subset of the default fields to be used in a model serializer, you can do so using `fields` or `exclude` options, just as you would with a `ModelForm`. +The url field will be represented using a `HyperlinkedIdentityField` serializer field, and any relationships on the model will be represented using a `HyperlinkedRelatedField` serializer field. -For example: +You can explicitly include the primary key by adding it to the `fields` option, for example: - class AccountSerializer(serializers.ModelSerializer): + class AccountSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = Account - exclude = ('id',) + fields = ('url', 'id', 'account_name', 'users', 'created') -## Specifiying nested serialization +## How hyperlinked views are determined -The default `ModelSerializer` uses primary keys for relationships, but you can also easily generate nested representations using the `depth` option: +There needs to be a way of determining which views should be used for hyperlinking to model instances. - class AccountSerializer(serializers.ModelSerializer): +By default hyperlinks are expected to correspond to a view name that matches the style `'{model_name}-detail'`, and looks up the instance by a `pk` keyword argument. + +You can change the field that is used for object lookups by setting the `lookup_field` option. The value of this option should correspond both with a kwarg in the URL conf, and with an field on the model. For example: + + class AccountSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = Account - exclude = ('id',) - depth = 1 + fields = ('url', 'account_name', 'users', 'created') + lookup_field = 'slug' + +For more specfic requirements such as specifying a different lookup for each field, you'll want to set the fields on the serializer explicitly. For example: + + class AccountSerializer(serializers.HyperlinkedModelSerializer): + url = serializers.HyperlinkedIdentityField( + view_name='account_detail', + lookup_field='account_name' + ) + users = serializers.HyperlinkedRelatedField( + view_name='user-detail', + lookup_field='username', + many=True, + read_only=True + ) -The `depth` option should be set to an integer value that indicates the depth of relationships that should be traversed before reverting to a flat representation. + class Meta: + model = Account + fields = ('url', 'account_name', 'users', 'created') -## Specifying which fields should be read-only +--- -You may wish to specify multiple fields as read-only. Instead of adding each field explicitely with the `read_only=True` attribute, you may use the `read_only_fields` Meta option, like so: +# Advanced serializer usage - class AccountSerializer(serializers.ModelSerializer): - class Meta: - model = Account - read_only_fields = ('created', 'modified') +You can create customized subclasses of `ModelSerializer` or `HyperlinkedModelSerializer` that use a different set of default fields. + +Doing so should be considered advanced usage, and will only be needed if you have some particular serializer requirements that you often need to repeat. ## Customising the default fields -You can create customized subclasses of `ModelSerializer` that use a different set of default fields for the representation, by overriding various `get_<field_type>_field` methods. +The `field_mapping` attribute is a dictionary that maps model classes to serializer classes. Overriding the attribute will let you set a different set of default serializer classes. -Each of these methods may either return a field or serializer instance, or `None`. +For more advanced customization than simply changing the default serializer class you can override various `get_<field_type>_field` methods. Doing so will allow you to customize the arguments that each serializer field is initialized with. Each of these methods may either return a field or serializer instance, or `None`. ### get_pk_field @@ -275,23 +435,27 @@ Returns the field instance that should be used to represent the pk field. ### get_nested_field -**Signature**: `.get_nested_field(self, model_field)` +**Signature**: `.get_nested_field(self, model_field, related_model, to_many)` Returns the field instance that should be used to represent a related field when `depth` is specified as being non-zero. +Note that the `model_field` argument will be `None` for reverse relationships. The `related_model` argument will be the model class for the target of the field. The `to_many` argument will be a boolean indicating if this is a to-one or to-many relationship. + ### get_related_field -**Signature**: `.get_related_field(self, model_field, to_many=False)` +**Signature**: `.get_related_field(self, model_field, related_model, to_many)` Returns the field instance that should be used to represent a related field when `depth` is not specified, or when nested representations are being used and the depth reaches zero. +Note that the `model_field` argument will be `None` for reverse relationships. The `related_model` argument will be the model class for the target of the field. The `to_many` argument will be a boolean indicating if this is a to-one or to-many relationship. + ### get_field **Signature**: `.get_field(self, model_field)` Returns the field instance that should be used for non-relational, non-pk fields. -### Example: +## Example The following custom model serializer could be used as a base class for model serializers that should always exclude the pk by default. @@ -302,3 +466,4 @@ The following custom model serializer could be used as a base class for model se [cite]: https://groups.google.com/d/topic/django-users/sVFaOfQi4wY/discussion +[relations]: relations.md diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md index 7884d096..b00ab4c1 100644 --- a/docs/api-guide/settings.md +++ b/docs/api-guide/settings.md @@ -34,7 +34,11 @@ The `api_settings` object will check for any user-defined settings, and otherwis # API Reference -## DEFAULT_RENDERER_CLASSES +## API policy settings + +*The following settings control the basic API policies, and are applied to every `APIView` class based view, or `@api_view` function based view.* + +#### DEFAULT_RENDERER_CLASSES A list or tuple of renderer classes, that determines the default set of renderers that may be used when returning a `Response` object. @@ -43,10 +47,9 @@ Default: ( 'rest_framework.renderers.JSONRenderer', 'rest_framework.renderers.BrowsableAPIRenderer', - 'rest_framework.renderers.TemplateHTMLRenderer' ) -## DEFAULT_PARSER_CLASSES +#### DEFAULT_PARSER_CLASSES A list or tuple of parser classes, that determines the default set of parsers used when accessing the `request.DATA` property. @@ -54,10 +57,11 @@ Default: ( 'rest_framework.parsers.JSONParser', - 'rest_framework.parsers.FormParser' + 'rest_framework.parsers.FormParser', + 'rest_framework.parsers.MultiPartParser' ) -## DEFAULT_AUTHENTICATION_CLASSES +#### DEFAULT_AUTHENTICATION_CLASSES A list or tuple of authentication classes, that determines the default set of authenticators used when accessing the `request.user` or `request.auth` properties. @@ -65,10 +69,10 @@ Default: ( 'rest_framework.authentication.SessionAuthentication', - 'rest_framework.authentication.UserBasicAuthentication' + 'rest_framework.authentication.BasicAuthentication' ) -## DEFAULT_PERMISSION_CLASSES +#### DEFAULT_PERMISSION_CLASSES A list or tuple of permission classes, that determines the default set of permissions checked at the start of a view. @@ -78,53 +82,78 @@ Default: 'rest_framework.permissions.AllowAny', ) -## DEFAULT_THROTTLE_CLASSES +#### DEFAULT_THROTTLE_CLASSES A list or tuple of throttle classes, that determines the default set of throttles checked at the start of a view. Default: `()` -## DEFAULT_MODEL_SERIALIZER_CLASS +#### DEFAULT_CONTENT_NEGOTIATION_CLASS + +A content negotiation class, that determines how a renderer is selected for the response, given an incoming request. + +Default: `'rest_framework.negotiation.DefaultContentNegotiation'` + +--- + +## Generic view settings -**TODO** +*The following settings control the behavior of the generic class based views.* -Default: `rest_framework.serializers.ModelSerializer` +#### DEFAULT_MODEL_SERIALIZER_CLASS -## DEFAULT_PAGINATION_SERIALIZER_CLASS +A class that determines the default type of model serializer that should be used by a generic view if `model` is specified, but `serializer_class` is not provided. -**TODO** +Default: `'rest_framework.serializers.ModelSerializer'` + +#### DEFAULT_PAGINATION_SERIALIZER_CLASS + +A class the determines the default serialization style for paginated responses. Default: `rest_framework.pagination.PaginationSerializer` -## FILTER_BACKEND +#### DEFAULT_FILTER_BACKENDS -The filter backend class that should be used for generic filtering. If set to `None` then generic filtering is disabled. +A list of filter backend classes that should be used for generic filtering. +If set to `None` then generic filtering is disabled. -## PAGINATE_BY +#### PAGINATE_BY The default page size to use for pagination. If set to `None`, pagination is disabled by default. Default: `None` -## PAGINATE_BY_KWARG +#### PAGINATE_BY_PARAM The name of a query parameter, which can be used by the client to overide the default page size to use for pagination. If set to `None`, clients may not override the default page size. Default: `None` -## UNAUTHENTICATED_USER +--- + +## Authentication settings + +*The following settings control the behavior of unauthenticated requests.* + +#### UNAUTHENTICATED_USER The class that should be used to initialize `request.user` for unauthenticated requests. Default: `django.contrib.auth.models.AnonymousUser` -## UNAUTHENTICATED_TOKEN +#### UNAUTHENTICATED_TOKEN The class that should be used to initialize `request.auth` for unauthenticated requests. Default: `None` -## FORM_METHOD_OVERRIDE +--- + +## Browser overrides + +*The following settings provide URL or form-based overrides of the default browser behavior.* + +#### FORM_METHOD_OVERRIDE The name of a form field that may be used to override the HTTP method of the form. @@ -132,7 +161,7 @@ If the value of this setting is `None` then form method overloading will be disa Default: `'_method'` -## FORM_CONTENT_OVERRIDE +#### FORM_CONTENT_OVERRIDE The name of a form field that may be used to override the content of the form payload. Must be used together with `FORM_CONTENTTYPE_OVERRIDE`. @@ -140,7 +169,7 @@ If either setting is `None` then form content overloading will be disabled. Default: `'_content'` -## FORM_CONTENTTYPE_OVERRIDE +#### FORM_CONTENTTYPE_OVERRIDE The name of a form field that may be used to override the content type of the form payload. Must be used together with `FORM_CONTENT_OVERRIDE`. @@ -148,7 +177,7 @@ If either setting is `None` then form content overloading will be disabled. Default: `'_content_type'` -## URL_ACCEPT_OVERRIDE +#### URL_ACCEPT_OVERRIDE The name of a URL parameter that may be used to override the HTTP `Accept` header. @@ -156,14 +185,75 @@ If the value of this setting is `None` then URL accept overloading will be disab Default: `'accept'` -## URL_FORMAT_OVERRIDE +#### URL_FORMAT_OVERRIDE + +The name of a URL parameter that may be used to override the default `Accept` header based content negotiation. Default: `'format'` -## FORMAT_SUFFIX_KWARG +--- + +## Date and time formatting + +*The following settings are used to control how date and time representations may be parsed and rendered.* + +#### DATETIME_FORMAT + +A format string that should be used by default for rendering the output of `DateTimeField` serializer fields. If `None`, then `DateTimeField` serializer fields will return python `datetime` objects, and the datetime encoding will be determined by the renderer. + +May be any of `None`, `'iso-8601'` or a python [strftime format][strftime] string. + +Default: `None` + +#### DATETIME_INPUT_FORMATS + +A list of format strings that should be used by default for parsing inputs to `DateTimeField` serializer fields. + +May be a list including the string `'iso-8601'` or python [strftime format][strftime] strings. + +Default: `['iso-8601']` + +#### DATE_FORMAT + +A format string that should be used by default for rendering the output of `DateField` serializer fields. If `None`, then `DateField` serializer fields will return python `date` objects, and the date encoding will be determined by the renderer. + +May be any of `None`, `'iso-8601'` or a python [strftime format][strftime] string. + +Default: `None` + +#### DATE_INPUT_FORMATS + +A list of format strings that should be used by default for parsing inputs to `DateField` serializer fields. + +May be a list including the string `'iso-8601'` or python [strftime format][strftime] strings. + +Default: `['iso-8601']` + +#### TIME_FORMAT + +A format string that should be used by default for rendering the output of `TimeField` serializer fields. If `None`, then `TimeField` serializer fields will return python `time` objects, and the time encoding will be determined by the renderer. + +May be any of `None`, `'iso-8601'` or a python [strftime format][strftime] string. + +Default: `None` + +#### TIME_INPUT_FORMATS + +A list of format strings that should be used by default for parsing inputs to `TimeField` serializer fields. + +May be a list including the string `'iso-8601'` or python [strftime format][strftime] strings. + +Default: `['iso-8601']` + +--- + +## Miscellaneous settings + +#### FORMAT_SUFFIX_KWARG -**TODO** +The name of a parameter in the URL conf that may be used to provide a format suffix. Default: `'format'` [cite]: http://www.python.org/dev/peps/pep-0020/ +[strftime]: http://docs.python.org/2/library/time.html#time.strftime diff --git a/docs/api-guide/throttling.md b/docs/api-guide/throttling.md index b03bc9e0..d6de85ba 100644 --- a/docs/api-guide/throttling.md +++ b/docs/api-guide/throttling.md @@ -6,8 +6,6 @@ > > [Twitter API rate limiting response][cite] -[cite]: https://dev.twitter.com/docs/error-codes-responses - Throttling is similar to [permissions], in that it determines if a request should be authorized. Throttles indicate a temporary state, and are used to control the rate of requests that clients can make to an API. As with permissions, multiple throttles may be used. Your API might have a restrictive throttle for unauthenticated requests, and a less restrictive throttle for authenticated requests. @@ -42,7 +40,8 @@ The default throttling policy may be set globally, using the `DEFAULT_THROTTLE_C The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `minute`, `hour` or `day` as the throttle period. -You can also set the throttling policy on a per-view basis, using the `APIView` class based views. +You can also set the throttling policy on a per-view or per-viewset basis, +using the `APIView` class based views. class ExampleView(APIView): throttle_classes = (UserThrottle,) @@ -63,6 +62,10 @@ Or, if you're using the `@api_view` decorator with function based views. } return Response(content) +## Setting up the cache + +The throttle classes provided by REST framework use Django's cache backend. You should make sure that you've set appropriate [cache settings][cache-setting]. The default value of `LocMemCache` backend should be okay for simple setups. See Django's [cache documentation][cache-docs] for more details. + --- # API Reference @@ -150,8 +153,19 @@ User requests to either `ContactListView` or `ContactDetailView` would be restri # Custom throttles -To create a custom throttle, override `BaseThrottle` and implement `.allow_request(request, view)`. The method should return `True` if the request should be allowed, and `False` otherwise. +To create a custom throttle, override `BaseThrottle` and implement `.allow_request(self, request, view)`. The method should return `True` if the request should be allowed, and `False` otherwise. Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return a recommended number of seconds to wait before attempting the next request, or `None`. The `.wait()` method will only be called if `.allow_request()` has previously returned `False`. +## Example + +The following is an example of a rate throttle, that will randomly throttle 1 in every 10 requests. + + class RandomRateThrottle(throttles.BaseThrottle): + def allow_request(self, request, view): + return random.randint(1, 10) == 1 + +[cite]: https://dev.twitter.com/docs/error-codes-responses [permissions]: permissions.md +[cache-setting]: https://docs.djangoproject.com/en/dev/ref/settings/#caches +[cache-docs]: https://docs.djangoproject.com/en/dev/topics/cache/#setting-up-the-cache diff --git a/docs/api-guide/views.md b/docs/api-guide/views.md index d1e42ec1..8b26b3e3 100644 --- a/docs/api-guide/views.md +++ b/docs/api-guide/views.md @@ -76,16 +76,16 @@ The following methods are used by REST framework to instantiate the various plug The following methods are called before dispatching to the handler method. -### .check_permissions(...) +### .check_permissions(self, request) -### .check_throttles(...) +### .check_throttles(self, request) -### .perform_content_negotiation(...) +### .perform_content_negotiation(self, request, force=False) ## Dispatch methods The following methods are called directly by the view's `.dispatch()` method. -These perform any actions that need to occur before or after calling the handler methods such as `.get()`, `.post()`, `put()` and `.delete()`. +These perform any actions that need to occur before or after calling the handler methods such as `.get()`, `.post()`, `put()`, `patch()` and `.delete()`. ### .initial(self, request, \*args, **kwargs) diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md new file mode 100644 index 00000000..cd92dc58 --- /dev/null +++ b/docs/api-guide/viewsets.md @@ -0,0 +1,219 @@ +<a class="github" href="viewsets.py"></a> + +# ViewSets + +> After routing has determined which controller to use for a request, your controller is responsible for making sense of the request and producing the appropriate output. +> +> — [Ruby on Rails Documentation][cite] + + +Django REST framework allows you to combine the logic for a set of related views in a single class, called a `ViewSet`. In other frameworks you may also find conceptually similar implementations named something like 'Resources' or 'Controllers'. + +A `ViewSet` class is simply **a type of class-based View, that does not provide any method handlers** such as `.get()` or `.post()`, and instead provides actions such as `.list()` and `.create()`. + +The method handlers for a `ViewSet` are only bound to the corresponding actions at the point of finalizing the view, using the `.as_view()` method. + +Typically, rather than explicitly registering the views in a viewset in the urlconf, you'll register the viewset with a router class, that automatically determines the urlconf for you. + +## Example + +Let's define a simple viewset that can be used to list or retrieve all the users in the system. + + class UserViewSet(viewsets.ViewSet): + """ + A simple ViewSet that for listing or retrieving users. + """ + def list(self, request): + queryset = User.objects.all() + serializer = UserSerializer(queryset, many=True) + return Response(serializer.data) + + def retrieve(self, request, pk=None): + queryset = User.objects.all() + user = get_object_or_404(queryset, pk=pk) + serializer = UserSerializer(user) + return Response(serializer.data) + +If we need to, we can bind this viewset into two seperate views, like so: + + user_list = UserViewSet.as_view({'get': 'list'}) + user_detail = UserViewSet.as_view({'get': 'retrieve'}) + +Typically we wouldn't do this, but would instead register the viewset with a router, and allow the urlconf to be automatically generated. + + router = DefaultRouter() + router.register(r'users', UserViewSet) + urlpatterns = router.urls + +Rather than writing your own viewsets, you'll often want to use the existing base classes that provide a default set of behavior. For example: + + class UserViewSet(viewsets.ModelViewSet): + """ + A viewset for viewing and editing user instances. + """ + serializer_class = UserSerializer + queryset = User.objects.all() + +There are two main advantages of using a `ViewSet` class over using a `View` class. + +* Repeated logic can be combined into a single class. In the above example, we only need to specify the `queryset` once, and it'll be used across multiple views. +* By using routers, we no longer need to deal with wiring up the URL conf ourselves. + +Both of these come with a trade-off. Using regular views and URL confs is more explicit and gives you more control. ViewSets are helpful if you want to get up and running quickly, or when you have a large API and you want to enforce a consistent URL configuration throughout. + +## Marking extra methods for routing + +The default routers included with REST framework will provide routes for a standard set of create/retrieve/update/destroy style operations, as shown below: + + class UserViewSet(viewsets.VietSet): + """ + Example empty viewset demonstrating the standard + actions that will be handled by a router class. + + If you're using format suffixes, make sure to also include + the `format=None` keyword argument for each action. + """ + + def list(self, request): + pass + + def create(self, request): + pass + + def retrieve(self, request, pk=None): + pass + + def update(self, request, pk=None): + pass + + def partial_update(self, request, pk=None): + pass + + def destroy(self, request, pk=None): + pass + +If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@link` or `@action` decorators. The `@link` decorator will route `GET` requests, and the `@action` decroator will route `POST` requests. + +For example: + + from django.contrib.auth.models import User + from rest_framework import viewsets + from rest_framework.decorators import action + from myapp.serializers import UserSerializer + + class UserViewSet(viewsets.ModelViewSet): + """ + A viewset that provides the standard actions + """ + queryset = User.objects.all() + serializer_class = UserSerializer + + @action + def set_password(self, request, pk=None): + user = self.get_object() + serializer = PasswordSerializer(data=request.DATA) + if serializer.is_valid(): + user.set_password(serializer.data['password']) + user.save() + return Response({'status': 'password set'}) + else: + return Response(serializer.errors, + status=status.HTTP_400_BAD_REQUEST) + +The `@action` and `@link` decorators can additionally take extra arguments that will be set for the routed view only. For example... + + @action(permission_classes=[IsAdminOrIsSelf]) + def set_password(self, request, pk=None): + ... + +--- + +# API Reference + +## ViewSet + +The `ViewSet` class inherits from `APIView`. You can use any of the standard attributes such as `permission_classes`, `authentication_classes` in order to control the API policy on the viewset. + +The `ViewSet` class does not provide any implementations of actions. In order to use a `ViewSet` class you'll override the class and define the action implementations explicitly. + +## GenericViewSet + +The `GenericViewSet` class inherits from `GenericAPIView`, and provides the default set of `get_object`, `get_queryset` methods and other generic view base behavior, but does not include any actions by default. + +In order to use a `GenericViewSet` class you'll override the class and either mixin the required mixin classes, or define the action implementations explicitly. + +## ModelViewSet + +The `ModelViewSet` class inherits from `GenericAPIView` and includes implementations for various actions, by mixing in the behavior of the various mixin classes. + +The actions provided by the `ModelViewSet` class are `.list()`, `.retrieve()`, `.create()`, `.update()`, and `.destroy()`. + +#### Example + +Because `ModelViewSet` extends `GenericAPIView`, you'll normally need to provide at least the `queryset` and `serializer_class` attributes. For example: + + class AccountViewSet(viewsets.ModelViewSet): + """ + A simple ViewSet for viewing and editing accounts. + """ + queryset = Account.objects.all() + serializer_class = AccountSerializer + permission_classes = [IsAccountAdminOrReadOnly] + +Note that you can use any of the standard attributes or method overrides provided by `GenericAPIView`. For example, to use a `ViewSet` that dynamically determines the queryset it should operate on, you might do something like this: + + class AccountViewSet(viewsets.ModelViewSet): + """ + A simple ViewSet for viewing and editing the accounts + associated with the user. + """ + serializer_class = AccountSerializer + permission_classes = [IsAccountAdminOrReadOnly] + + def get_queryset(self): + return request.user.accounts.all() + +Also note that although this class provides the complete set of create/list/retrieve/update/destroy actions by default, you can restrict the available operations by using the standard permission classes. + +## ReadOnlyModelViewSet + +The `ReadOnlyModelViewSet` class also inherits from `GenericAPIView`. As with `ModelViewSet` it also includes implementations for various actions, but unlike `ModelViewSet` only provides the 'read-only' actions, `.list()` and `.retrieve()`. + +#### Example + +As with `ModelViewSet`, you'll normally need to provide at least the `queryset` and `serializer_class` attributes. For example: + + class AccountViewSet(viewsets.ReadOnlyModelViewSet): + """ + A simple ViewSet for viewing accounts. + """ + queryset = Account.objects.all() + serializer_class = AccountSerializer + +Again, as with `ModelViewSet`, you can use any of the standard attributes and method overrides available to `GenericAPIView`. + +# Custom ViewSet base classes + +You may need to provide custom `ViewSet` classes that do not have the full set of `ModelViewSet` actions, or that customize the behavior in some other way. + +## Example + +To create a base viewset class that provides `create`, `list` and `retrieve` operations, inherit from `GenericViewSet`, and mixin the required actions: + + class CreateListRetrieveViewSet(mixins.CreateMixin, + mixins.ListMixin, + mixins.RetrieveMixin, + viewsets.GenericViewSet): + pass + + """ + A viewset that provides `retrieve`, `update`, and `list` actions. + + To use it, override the class and set the `.queryset` and + `.serializer_class` attributes. + """ + pass + +By creating your own base `ViewSet` classes, you can provide common behavior that can be reused in multiple viewsets across your API. + +[cite]: http://guides.rubyonrails.org/routing.html diff --git a/docs/css/default.css b/docs/css/default.css index 57446ff9..998efa27 100644 --- a/docs/css/default.css +++ b/docs/css/default.css @@ -25,18 +25,29 @@ pre { margin-top: 9px; } +body.index-page #main-content p.badges { + padding-bottom: 1px; +} + /* GitHub 'Star' badge */ -body.index-page #main-content iframe { +body.index-page #main-content iframe.github-star-button { float: right; margin-top: -12px; margin-right: -15px; } +/* Tweet button */ +body.index-page #main-content iframe.twitter-share-button { + float: right; + margin-top: -12px; + margin-right: 8px; +} + /* Travis CI badge */ -body.index-page #main-content p:first-of-type { +body.index-page #main-content img.travis-build-image { float: right; margin-right: 8px; - margin-top: -14px; + margin-top: -11px; margin-bottom: 0px; } @@ -266,3 +277,24 @@ footer a { footer a:hover { color: gray; } + +.btn-inverse { + background-image: -webkit-gradient(linear, 0 0, 0 100%, from(#606060), to(#404040)) !important; + background-image: -webkit-linear-gradient(top, #606060, #404040) !important; +} + +.modal-open .modal,.btn:focus{outline:none;} + +@media (max-width: 650px) { + .repo-link.btn-inverse {display: none;} +} + +td, th { + padding: 0.25em; + background-color: #f7f7f9; + border-color: #e1e1e8; +} + +table { + border-color: white; +} diff --git a/docs/index.md b/docs/index.md index cc0f2a13..7c38efd3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,25 +1,29 @@ -<iframe src="http://ghbtns.com/github-btn.html?user=tomchristie&repo=django-rest-framework&type=watch&count=true" allowtransparency="true" frameborder="0" scrolling="0" width="110px" height="20px"></iframe> -[![Travis build image][travis-build-image]][travis] +<p class="badges"> +<iframe src="http://ghbtns.com/github-btn.html?user=tomchristie&repo=django-rest-framework&type=watch&count=true" class="github-star-button" allowtransparency="true" frameborder="0" scrolling="0" width="110px" height="20px"></iframe> -# Django REST framework - -**A toolkit for building well-connected, self-describing Web APIs.** +<a href="https://twitter.com/share" class="twitter-share-button" data-url="django-rest-framework.org" data-text="Checking out the totally awesome Django REST framework! http://django-rest-framework.org" data-count="none">Tweet</a> +<script>!function(d,s,id){var js,fjs=d.getElementsByTagName(s)[0];if(!d.getElementById(id)){js=d.createElement(s);js.id=id;js.src="http://platform.twitter.com/widgets.js";fjs.parentNode.insertBefore(js,fjs);}}(document,"script","twitter-wjs");</script> ---- +<img alt="Travis build image" src="https://secure.travis-ci.org/tomchristie/django-rest-framework.png?branch=master" class="travis-build-image"> +</p> -**Note**: This documentation is for the 2.0 version of REST framework. If you are looking for earlier versions please see the [0.4.x branch][0.4] on GitHub. +# Django REST framework ---- +**Awesome web-browsable Web APIs.** -Django REST framework is a lightweight library that makes it easy to build Web APIs. It is designed as a modular and easy to customize architecture, based on Django's class based views. +Django REST framework is a powerful and flexible toolkit that makes it easy to build Web APIs. -Web APIs built using REST framework are fully self-describing and web browseable - a huge useability win for your developers. It also supports a wide range of media types, authentication and permission policies out of the box. +Some reasons you might want to use REST framework: -If you are considering using REST framework for your API, we recommend reading the [REST framework 2 announcment][rest-framework-2-announcement] which gives a good overview of the framework and it's capabilities. +* The Web browseable API is a huge useability win for your developers. +* Authentication policies including OAuth1a and OAuth2 out of the box. +* Serialization that supports both ORM and non-ORM data sources. +* Customizable all the way down - just use regular function-based views if you don't need the more powerful features. +* Extensive documentation, and great community support. -There is also a sandbox API you can use for testing purposes, [available here][sandbox]. +There is a live example API for testing purposes, [available here][sandbox]. -**Below**: *Screenshot from the browseable API* +**Below**: *Screenshot from the browsable API* ![Screenshot][image] @@ -27,50 +31,102 @@ There is also a sandbox API you can use for testing purposes, [available here][s REST framework requires the following: -* Python (2.6, 2.7) +* Python (2.6.5+, 2.7, 3.2, 3.3) * Django (1.3, 1.4, 1.5) The following packages are optional: -* [Markdown][markdown] (2.1.0+) - Markdown support for the browseable API. +* [Markdown][markdown] (2.1.0+) - Markdown support for the browsable API. * [PyYAML][yaml] (3.10+) - YAML content-type support. +* [defusedxml][defusedxml] (0.3+) - XML content-type support. * [django-filter][django-filter] (0.5.4+) - Filtering support. +* [django-oauth-plus][django-oauth-plus] (2.0+) and [oauth2][oauth2] (1.5.211+) - OAuth 1.0a support. +* [django-oauth2-provider][django-oauth2-provider] (0.2.3+) - OAuth 2.0 support. + +**Note**: The `oauth2` python package is badly misnamed, and actually provides OAuth 1.0a support. Also note that packages required for both OAuth 1.0a, and OAuth 2.0 are not yet Python 3 compatible. ## Installation Install using `pip`, including any optional packages you want... pip install djangorestframework - pip install markdown # Markdown support for the browseable API. - pip install pyyaml # YAML content-type support. + pip install markdown # Markdown support for the browsable API. pip install django-filter # Filtering support ...or clone the project from github. git clone git@github.com:tomchristie/django-rest-framework.git - cd django-rest-framework - pip install -r requirements.txt - pip install -r optionals.txt -Add `rest_framework` to your `INSTALLED_APPS`. +Add `'rest_framework'` to your `INSTALLED_APPS` setting. INSTALLED_APPS = ( ... 'rest_framework', ) -If you're intending to use the browseable API you'll want to add REST framework's login and logout views. Add the following to your root `urls.py` file. +If you're intending to use the browsable API you'll probably also want to add REST framework's login and logout views. Add the following to your root `urls.py` file. urlpatterns = patterns('', ... url(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')) ) -Note that the URL path can be whatever you want, but you must include `rest_framework.urls` with the `rest_framework` namespace. +Note that the URL path can be whatever you want, but you must include `'rest_framework.urls'` with the `'rest_framework'` namespace. + +## Example + +Let's take a look at a quick example of using REST framework to build a simple model-backed API. + +We'll create a read-write API for accessing users and groups. + +Any global settings for a REST framework API are kept in a single configuration dictionary named `REST_FRAMEWORK`. Start off by adding the following to your `settings.py` module: + + REST_FRAMEWORK = { + # Use hyperlinked styles by default. + # Only used if the `serializer_class` attribute is not set on a view. + 'DEFAULT_MODEL_SERIALIZER_CLASS': + 'rest_framework.serializers.HyperlinkedModelSerializer', + + # Use Django's standard `django.contrib.auth` permissions, + # or allow read-only access for unauthenticated users. + 'DEFAULT_PERMISSION_CLASSES': [ + 'rest_framework.permissions.DjangoModelPermissionsOrAnonReadOnly' + ] + } + +Don't forget to make sure you've also added `rest_framework` to your `INSTALLED_APPS`. + +We're ready to create our API now. +Here's our project's root `urls.py` module: + + from django.conf.urls.defaults import url, patterns, include + from django.contrib.auth.models import User, Group + from rest_framework import viewsets, routers + + # ViewSets define the view behavior. + class UserViewSet(viewsets.ModelViewSet): + model = User + + class GroupViewSet(viewsets.ModelViewSet): + model = Group + + + # Routers provide an easy way of automatically determining the URL conf + router = routers.DefaultRouter() + router.register(r'users', UserViewSet) + router.register(r'groups', GroupViewSet) + + + # Wire up our API using automatic URL routing. + # Additionally, we include login URLs for the browseable API. + urlpatterns = patterns('', + url(r'^', include(router.urls)), + url(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')) + ) ## Quickstart -Can't wait to get started? The [quickstart guide][quickstart] is the fastest way to get up and running with REST framework. +Can't wait to get started? The [quickstart guide][quickstart] is the fastest way to get up and running, and building APIs with REST framework. ## Tutorial @@ -81,6 +137,7 @@ The tutorial will walk you through the building blocks that make up REST framewo * [3 - Class based views][tut-3] * [4 - Authentication & permissions][tut-4] * [5 - Relationships & hyperlinked APIs][tut-5] +* [6 - Viewsets & routers][tut-6] ## API Guide @@ -90,10 +147,13 @@ The API guide is your complete reference manual to all the functionality provide * [Responses][response] * [Views][views] * [Generic views][generic-views] +* [Viewsets][viewsets] +* [Routers][routers] * [Parsers][parsers] * [Renderers][renderers] * [Serializers][serializers] * [Serializer fields][fields] +* [Serializer relations][relations] * [Authentication][authentication] * [Permissions][permissions] * [Throttling][throttling] @@ -110,10 +170,13 @@ The API guide is your complete reference manual to all the functionality provide General guides to using REST framework. +* [AJAX, CSRF & CORS][ajax-csrf-cors] * [Browser enhancements][browser-enhancements] * [The Browsable API][browsableapi] * [REST, Hypermedia & HATEOAS][rest-hypermedia-hateoas] * [2.0 Announcement][rest-framework-2-announcement] +* [2.2 Announcement][2.2-announcement] +* [2.3 Announcement][2.3-announcement] * [Release Notes][release-notes] * [Credits][credits] @@ -129,15 +192,24 @@ Run the tests: ./rest_framework/runtests/runtests.py +To run the tests against all supported configurations, first install [the tox testing tool][tox] globally, using `pip install tox`, then simply run `tox`: + + tox + ## Support -For support please see the [REST framework discussion group][group], or try the `#restframework` channel on `irc.freenode.net`. +For support please see the [REST framework discussion group][group], try the `#restframework` channel on `irc.freenode.net`, or raise a question on [Stack Overflow][stack-overflow], making sure to include the ['django-rest-framework'][django-rest-framework-tag] tag. + +[Paid support is available][paid-support] from [DabApps][dabapps], and can include work on REST framework core, or support with building your REST framework API. Please [contact DabApps][contact-dabapps] if you'd like to discuss commercial support options. -Paid support is also available from [DabApps], and can include work on REST framework core, or support with building your REST framework API. Please contact [Tom Christie][email] if you'd like to discuss commercial support options. +For updates on REST framework development, you may also want to follow [the author][twitter] on Twitter. +<a style="padding-top: 10px" href="https://twitter.com/_tomchristie" class="twitter-follow-button" data-show-count="false">Follow @_tomchristie</a> +<script>!function(d,s,id){var js,fjs=d.getElementsByTagName(s)[0];if(!d.getElementById(id)){js=d.createElement(s);js.id=id;js.src="//platform.twitter.com/widgets.js";fjs.parentNode.insertBefore(js,fjs);}}(document,"script","twitter-wjs");</script> + ## License -Copyright (c) 2011-2012, Tom Christie +Copyright (c) 2011-2013, Tom Christie All rights reserved. Redistribution and use in source and binary forms, with or without @@ -161,11 +233,15 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. [travis]: http://travis-ci.org/tomchristie/django-rest-framework?branch=master -[travis-build-image]: https://secure.travis-ci.org/tomchristie/django-rest-framework.png?branch=restframework2 +[travis-build-image]: https://secure.travis-ci.org/tomchristie/django-rest-framework.png?branch=master [urlobject]: https://github.com/zacharyvoase/urlobject [markdown]: http://pypi.python.org/pypi/Markdown/ [yaml]: http://pypi.python.org/pypi/PyYAML -[django-filter]: https://github.com/alex/django-filter +[defusedxml]: https://pypi.python.org/pypi/defusedxml +[django-filter]: http://pypi.python.org/pypi/django-filter +[oauth2]: https://github.com/simplegeo/python-oauth2 +[django-oauth-plus]: https://bitbucket.org/david/django-oauth-plus/wiki/Home +[django-oauth2-provider]: https://github.com/caffeinehit/django-oauth2-provider [0.4]: https://github.com/tomchristie/django-rest-framework/tree/0.4.X [image]: img/quickstart.png [sandbox]: http://restframework.herokuapp.com/ @@ -176,15 +252,19 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. [tut-3]: tutorial/3-class-based-views.md [tut-4]: tutorial/4-authentication-and-permissions.md [tut-5]: tutorial/5-relationships-and-hyperlinked-apis.md +[tut-6]: tutorial/6-viewsets-and-routers.md [request]: api-guide/requests.md [response]: api-guide/responses.md [views]: api-guide/views.md [generic-views]: api-guide/generic-views.md +[viewsets]: api-guide/viewsets.md +[routers]: api-guide/routers.md [parsers]: api-guide/parsers.md [renderers]: api-guide/renderers.md [serializers]: api-guide/serializers.md [fields]: api-guide/fields.md +[relations]: api-guide/relations.md [authentication]: api-guide/authentication.md [permissions]: api-guide/permissions.md [throttling]: api-guide/throttling.md @@ -197,15 +277,24 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. [status]: api-guide/status-codes.md [settings]: api-guide/settings.md -[csrf]: topics/csrf.md +[ajax-csrf-cors]: topics/ajax-csrf-cors.md [browser-enhancements]: topics/browser-enhancements.md [browsableapi]: topics/browsable-api.md [rest-hypermedia-hateoas]: topics/rest-hypermedia-hateoas.md [contributing]: topics/contributing.md [rest-framework-2-announcement]: topics/rest-framework-2-announcement.md +[2.2-announcement]: topics/2.2-announcement.md +[2.3-announcement]: topics/2.3-announcement.md [release-notes]: topics/release-notes.md [credits]: topics/credits.md +[tox]: http://testrun.org/tox/latest/ + [group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework -[DabApps]: http://dabapps.com -[email]: mailto:tom@tomchristie.com +[stack-overflow]: http://stackoverflow.com/ +[django-rest-framework-tag]: http://stackoverflow.com/questions/tagged/django-rest-framework +[django-tag]: http://stackoverflow.com/questions/tagged/django +[paid-support]: http://dabapps.com/services/build/api-development/ +[dabapps]: http://dabapps.com +[contact-dabapps]: http://dabapps.com/contact/ +[twitter]: https://twitter.com/_tomchristie diff --git a/docs/template.html b/docs/template.html index 676a4807..53656e7d 100644 --- a/docs/template.html +++ b/docs/template.html @@ -2,11 +2,11 @@ <html lang="en"> <head><meta http-equiv="Content-Type" content="text/html; charset=UTF-8"> <meta charset="utf-8"> - <title>Django REST framework</title> + <title>{{ title }}</title> <link href="{{ base_url }}/img/favicon.ico" rel="icon" type="image/x-icon"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> - <meta name="description" content=""> - <meta name="author" content=""> + <meta name="description" content="{{ description }}"> + <meta name="author" content="Tom Christie"> <!-- Le styles --> <link href="{{ base_url }}/css/prettify.css" rel="stylesheet"> @@ -41,6 +41,9 @@ <div class="navbar-inner"> <div class="container-fluid"> <a class="repo-link btn btn-primary btn-small" href="https://github.com/tomchristie/django-rest-framework/tree/master">GitHub</a> + <a class="repo-link btn btn-inverse btn-small {{ next_url_disabled }}" href="{{ next_url }}">Next <i class="icon-arrow-right icon-white"></i></a> + <a class="repo-link btn btn-inverse btn-small {{ prev_url_disabled }}" href="{{ prev_url }}"><i class="icon-arrow-left icon-white"></i> Previous</a> + <a class="repo-link btn btn-inverse btn-small" href="#searchModal" data-toggle="modal"><i class="icon-search icon-white"></i> Search</a> <a class="btn btn-navbar" data-toggle="collapse" data-target=".nav-collapse"> <span class="icon-bar"></span> <span class="icon-bar"></span> @@ -59,6 +62,7 @@ <li><a href="{{ base_url }}/tutorial/3-class-based-views{{ suffix }}">3 - Class based views</a></li> <li><a href="{{ base_url }}/tutorial/4-authentication-and-permissions{{ suffix }}">4 - Authentication and permissions</a></li> <li><a href="{{ base_url }}/tutorial/5-relationships-and-hyperlinked-apis{{ suffix }}">5 - Relationships and hyperlinked APIs</a></li> + <li><a href="{{ base_url }}/tutorial/6-viewsets-and-routers{{ suffix }}">6 - Viewsets and routers</a></li> </ul> </li> <li class="dropdown"> @@ -68,10 +72,13 @@ <li><a href="{{ base_url }}/api-guide/responses{{ suffix }}">Responses</a></li> <li><a href="{{ base_url }}/api-guide/views{{ suffix }}">Views</a></li> <li><a href="{{ base_url }}/api-guide/generic-views{{ suffix }}">Generic views</a></li> + <li><a href="{{ base_url }}/api-guide/viewsets{{ suffix }}">Viewsets</a></li> + <li><a href="{{ base_url }}/api-guide/routers{{ suffix }}">Routers</a></li> <li><a href="{{ base_url }}/api-guide/parsers{{ suffix }}">Parsers</a></li> <li><a href="{{ base_url }}/api-guide/renderers{{ suffix }}">Renderers</a></li> <li><a href="{{ base_url }}/api-guide/serializers{{ suffix }}">Serializers</a></li> <li><a href="{{ base_url }}/api-guide/fields{{ suffix }}">Serializer fields</a></li> + <li><a href="{{ base_url }}/api-guide/relations{{ suffix }}">Serializer relations</a></li> <li><a href="{{ base_url }}/api-guide/authentication{{ suffix }}">Authentication</a></li> <li><a href="{{ base_url }}/api-guide/permissions{{ suffix }}">Permissions</a></li> <li><a href="{{ base_url }}/api-guide/throttling{{ suffix }}">Throttling</a></li> @@ -88,10 +95,13 @@ <li class="dropdown"> <a href="#" class="dropdown-toggle" data-toggle="dropdown">Topics <b class="caret"></b></a> <ul class="dropdown-menu"> + <li><a href="{{ base_url }}/topics/ajax-csrf-cors{{ suffix }}">AJAX, CSRF & CORS</a></li> <li><a href="{{ base_url }}/topics/browser-enhancements{{ suffix }}">Browser enhancements</a></li> <li><a href="{{ base_url }}/topics/browsable-api{{ suffix }}">The Browsable API</a></li> <li><a href="{{ base_url }}/topics/rest-hypermedia-hateoas{{ suffix }}">REST, Hypermedia & HATEOAS</a></li> <li><a href="{{ base_url }}/topics/rest-framework-2-announcement{{ suffix }}">2.0 Announcement</a></li> + <li><a href="{{ base_url }}/topics/2.2-announcement{{ suffix }}">2.2 Announcement</a></li> + <li><a href="{{ base_url }}/topics/2.3-announcement{{ suffix }}">2.3 Announcement</a></li> <li><a href="{{ base_url }}/topics/release-notes{{ suffix }}">Release Notes</a></li> <li><a href="{{ base_url }}/topics/credits{{ suffix }}">Credits</a></li> </ul> @@ -115,6 +125,34 @@ <div class="body-content"> <div class="container-fluid"> + +<!-- Search Modal --> +<div id="searchModal" class="modal hide fade" tabindex="-1" role="dialog" aria-labelledby="myModalLabel" aria-hidden="true"> + <div class="modal-header"> + <button type="button" class="close" data-dismiss="modal" aria-hidden="true">×</button> + <h3 id="myModalLabel">Documentation search</h3> + </div> + <div class="modal-body"> + <!-- Custom google search --> + <script> + (function() { + var cx = '015016005043623903336:rxraeohqk6w'; + var gcse = document.createElement('script'); + gcse.type = 'text/javascript'; + gcse.async = true; + gcse.src = (document.location.protocol == 'https:' ? 'https:' : 'http:') + + '//www.google.com/cse/cse.js?cx=' + cx; + var s = document.getElementsByTagName('script')[0]; + s.parentNode.insertBefore(gcse, s); + })(); + </script> + <gcse:search></gcse:search> + </div> + <div class="modal-footer"> + <button class="btn" data-dismiss="modal" aria-hidden="true">Close</button> + </div> +</div> + <div class="row-fluid"> <div class="span3"> diff --git a/docs/topics/2.2-announcement.md b/docs/topics/2.2-announcement.md new file mode 100644 index 00000000..d7164ce4 --- /dev/null +++ b/docs/topics/2.2-announcement.md @@ -0,0 +1,159 @@ +# REST framework 2.2 announcement + +The 2.2 release represents an important point for REST framework, with the addition of Python 3 support, and the introduction of an official deprecation policy. + +## Python 3 support + +Thanks to some fantastic work from [Xavier Ordoquy][xordoquy], Django REST framework 2.2 now supports Python 3. You'll need to be running Django 1.5, and it's worth keeping in mind that Django's Python 3 support is currently [considered experimental][django-python-3]. + +Django 1.6's Python 3 support is expected to be officially labeled as 'production-ready'. + +If you want to start ensuring that your own projects are Python 3 ready, we can highly recommend Django's [Porting to Python 3][porting-python-3] documentation. + +Django REST framework's Python 2.6 support now requires 2.6.5 or above, in line with [Django 1.5's Python compatibility][python-compat]. + +## Deprecation policy + +We've now introduced an official deprecation policy, which is in line with [Django's deprecation policy][django-deprecation-policy]. This policy will make it easy for you to continue to track the latest, greatest version of REST framework. + +The timeline for deprecation works as follows: + +* Version 2.2 introduces some API changes as detailed in the release notes. It remains fully backwards compatible with 2.1, but will raise `PendingDeprecationWarning` warnings if you use bits of API that are due to be deprecated. These warnings are silent by default, but can be explicitly enabled when you're ready to start migrating any required changes. For example if you start running your tests using `python -Wd manage.py test`, you'll be warned of any API changes you need to make. + +* Version 2.3 will escalate these warnings to `DeprecationWarning`, which is loud by default. + +* Version 2.4 will remove the deprecated bits of API entirely. + +Note that in line with Django's policy, any parts of the framework not mentioned in the documentation should generally be considered private API, and may be subject to change. + +## Community + +As of the 2.2 merge, we've also hit an impressive milestone. The number of committers listed in [the credits][credits], is now at over **one hundred individuals**. Each name on that list represents at least one merged pull request, however large or small. + +Our [mailing list][mailing-list] and #restframework IRC channel are also very active, and we've got a really impressive rate of development both on REST framework itself, and on third party packages such as the great [django-rest-framework-docs][django-rest-framework-docs] package from [Marc Gibbons][marcgibbons]. + +--- + +## API changes + +The 2.2 release makes a few changes to the API, in order to make it more consistent, simple, and easier to use. + +### Cleaner to-many related fields + +The `ManyRelatedField()` style is being deprecated in favor of a new `RelatedField(many=True)` syntax. + +For example, if a user is associated with multiple questions, which we want to represent using a primary key relationship, we might use something like the following: + + class UserSerializer(serializers.HyperlinkedModelSerializer): + questions = serializers.PrimaryKeyRelatedField(many=True) + + class Meta: + fields = ('username', 'questions') + +The new syntax is cleaner and more obvious, and the change will also make the documentation cleaner, simplify the internal API, and make writing custom relational fields easier. + +The change also applies to serializers. If you have a nested serializer, you should start using `many=True` for to-many relationships. For example, a serializer representation of an Album that can contain many Tracks might look something like this: + + class TrackSerializer(serializer.ModelSerializer): + class Meta: + model = Track + fields = ('name', 'duration') + + class AlbumSerializer(serializer.ModelSerializer): + tracks = TrackSerializer(many=True) + + class Meta: + model = Album + fields = ('album_name', 'artist', 'tracks') + +Additionally, the change also applies when serializing or deserializing data. For example to serialize a queryset of models you should now use the `many=True` flag. + + serializer = SnippetSerializer(Snippet.objects.all(), many=True) + serializer.data + +This more explicit behavior on serializing and deserializing data [makes integration with non-ORM backends such as MongoDB easier][564], as instances to be serialized can include the `__iter__` method, without incorrectly triggering list-based serialization, or requiring workarounds. + +The implicit to-many behavior on serializers, and the `ManyRelatedField` style classes will continue to function, but will raise a `PendingDeprecationWarning`, which can be made visible using the `-Wd` flag. + +**Note**: If you need to forcibly turn off the implict "`many=True` for `__iter__` objects" behavior, you can now do so by specifying `many=False`. This will become the default (instead of the current default of `None`) once the deprecation of the implicit behavior is finalised in version 2.4. + +### Cleaner optional relationships + +Serializer relationships for nullable Foreign Keys will change from using the current `null=True` flag, to instead using `required=False`. + +For example, is a user account has an optional foreign key to a company, that you want to express using a hyperlink, you might use the following field in a `Serializer` class: + + current_company = serializers.HyperlinkedRelatedField(required=False) + +This is in line both with the rest of the serializer fields API, and with Django's `Form` and `ModelForm` API. + +Using `required` throughout the serializers API means you won't need to consider if a particular field should take `blank` or `null` arguments instead of `required`, and also means there will be more consistent behavior for how fields are treated when they are not present in the incoming data. + +The `null=True` argument will continue to function, and will imply `required=False`, but will raise a `PendingDeprecationWarning`. + +### Cleaner CharField syntax + +The `CharField` API previously took an optional `blank=True` argument, which was intended to differentiate between null CharField input, and blank CharField input. + +In keeping with Django's CharField API, REST framework's `CharField` will only ever return the empty string, for missing or `None` inputs. The `blank` flag will no longer be in use, and you should instead just use the `required=<bool>` flag. For example: + + extra_details = CharField(required=False) + +The `blank` keyword argument will continue to function, but will raise a `PendingDeprecationWarning`. + +### Simpler object-level permissions + +Custom permissions classes previously used the signatute `.has_permission(self, request, view, obj=None)`. This method would be called twice, firstly for the global permissions check, with the `obj` parameter set to `None`, and again for the object-level permissions check when appropriate, with the `obj` parameter set to the relevant model instance. + +The global permissions check and object-level permissions check are now seperated into two seperate methods, which gives a cleaner, more obvious API. + +* Global permission checks now use the `.has_permission(self, request, view)` signature. +* Object-level permission checks use a new method `.has_object_permission(self, request, view, obj)`. + +For example, the following custom permission class: + + class IsOwner(permissions.BasePermission): + """ + Custom permission to only allow owners of an object to view or edit it. + Model instances are expected to include an `owner` attribute. + """ + + def has_permission(self, request, view, obj=None): + if obj is None: + # Ignore global permissions check + return True + + return obj.owner == request.user + +Now becomes: + + class IsOwner(permissions.BasePermission): + """ + Custom permission to only allow owners of an object to view or edit it. + Model instances are expected to include an `owner` attribute. + """ + + def has_object_permission(self, request, view, obj): + return obj.owner == request.user + +If you're overriding the `BasePermission` class, the old-style signature will continue to function, and will correctly handle both global and object-level permissions checks, but it's use will raise a `PendingDeprecationWarning`. + +Note also that the usage of the internal APIs for permission checking on the `View` class has been cleaned up slightly, and is now documented and subject to the deprecation policy in all future versions. + +### More explicit hyperlink relations behavior + +When using a serializer with a `HyperlinkedRelatedField` or `HyperlinkedIdentityField`, the hyperlinks would previously use absolute URLs if the serializer context included a `'request'` key, and fallback to using relative URLs otherwise. This could lead to non-obvious behavior, as it might not be clear why some serializers generated absolute URLs, and others do not. + +From version 2.2 onwards, serializers with hyperlinked relationships *always* require a `'request'` key to be supplied in the context dictionary. The implicit behavior will continue to function, but it's use will raise a `PendingDeprecationWarning`. + +[xordoquy]: https://github.com/xordoquy +[django-python-3]: https://docs.djangoproject.com/en/dev/faq/install/#can-i-use-django-with-python-3 +[porting-python-3]: https://docs.djangoproject.com/en/dev/topics/python3/ +[python-compat]: https://docs.djangoproject.com/en/dev/releases/1.5/#python-compatibility +[django-deprecation-policy]: https://docs.djangoproject.com/en/dev/internals/release-process/#internal-release-deprecation-policy +[credits]: http://django-rest-framework.org/topics/credits.html +[mailing-list]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework +[django-rest-framework-docs]: https://github.com/marcgibbons/django-rest-framework-docs +[marcgibbons]: https://github.com/marcgibbons/ +[issues]: https://github.com/tomchristie/django-rest-framework/issues +[564]: https://github.com/tomchristie/django-rest-framework/issues/564 diff --git a/docs/topics/2.3-announcement.md b/docs/topics/2.3-announcement.md new file mode 100644 index 00000000..4df9c819 --- /dev/null +++ b/docs/topics/2.3-announcement.md @@ -0,0 +1,264 @@ +# REST framework 2.3 announcement + +REST framework 2.3 makes it even quicker and easier to build your Web APIs. + +## ViewSets and Routers + +The 2.3 release introduces the [ViewSet][viewset] and [Router][router] classes. + +A viewset is simply a type of class based view that allows you to group multiple views into a single common class. + +Routers allow you to automatically determine the URLconf for your viewset classes. + +As an example of just how simple REST framework APIs can now be, here's an API written in a single `urls.py` module: + + """ + A REST framework API for viewing and editing users and groups. + """ + from django.conf.urls.defaults import url, patterns, include + from django.contrib.auth.models import User, Group + from rest_framework import viewsets, routers + + + # ViewSets define the view behavior. + class UserViewSet(viewsets.ModelViewSet): + model = User + + class GroupViewSet(viewsets.ModelViewSet): + model = Group + + + # Routers provide an easy way of automatically determining the URL conf + router = routers.DefaultRouter() + router.register(r'users', UserViewSet) + router.register(r'groups', GroupViewSet) + + + # Wire up our API using automatic URL routing. + # Additionally, we include login URLs for the browseable API. + urlpatterns = patterns('', + url(r'^', include(router.urls)), + url(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')) + ) + +The best place to get started with ViewSets and Routers is to take a look at the [newest section in the tutorial][part-6], which demonstrates their usage. + +## Simpler views + +This release rationalises the API and implementation of the generic views, dropping the dependancy on Django's `SingleObjectMixin` and `MultipleObjectMixin` classes, removing a number of unneeded attributes, and generally making the implementation more obvious and easy to work with. + +This improvement is reflected in improved documentation for the `GenericAPIView` base class, and should make it easier to determine how to override methods on the base class if you need to write customized subclasses. + +## Easier Serializers + +REST framework lets you be totally explict regarding how you want to represent relationships, allowing you to choose between styles such as hyperlinking or primary key relationships. + +The ability to specify exactly how you want to represent relationships is powerful, but it also introduces complexity. In order to keep things more simple, REST framework now allows you to include reverse relationships simply by including the field name in the `fields` metadata of the serializer class. + +For example, in REST framework 2.2, reverse relationships needed to be included explicitly on a serializer class. + + class BlogSerializer(serializers.ModelSerializer): + comments = serializers.PrimaryKeyRelatedField(many=True) + + class Meta: + model = Blog + fields = ('id', 'title', 'created', 'comments') + +As of 2.3, you can simply include the field name, and the appropriate serializer field will automatically be used for the relationship. + + class BlogSerializer(serializers.ModelSerializer): + """ + Don't need to specify the 'comments' field explicitly anymore. + """ + class Meta: + model = Blog + fields = ('id', 'title', 'created', 'comments') + +Similarly, you can now easily include the primary key in hyperlinked relationships, simply by adding the field name to the metadata. + + class BlogSerializer(serializers.HyperlinkedModelSerializer): + """ + This is a hyperlinked serializer, which default to using + a field named 'url' as the primary identifier. + Note that we can now easily also add in the 'id' field. + """ + class Meta: + model = Blog + fields = ('url', 'id', 'title', 'created', 'comments') + +## More flexible filtering + +The `FILTER_BACKEND` setting has moved to pending deprecation, in favor of a `DEFAULT_FILTER_BACKENDS` setting that takes a *list* of filter backend classes, instead of a single filter backend class. + +The generic view `filter_backend` attribute has also been moved to pending deprecation in favor of a `filter_backends` setting. + +Being able to specify multiple filters will allow for more flexible, powerful behavior. New filter classes to handle searching and ordering of results are planned to be released shortly. + +--- + +# API Changes + +## Simplified generic view classes + +The functionality provided by `SingleObjectAPIView` and `MultipleObjectAPIView` base classes has now been moved into the base class `GenericAPIView`. The implementation of this base class is simple enough that providing subclasses for the base classes of detail and list views is somewhat unnecessary. + +Additionally the base generic view no longer inherits from Django's `SingleObjectMixin` or `MultipleObjectMixin` classes, simplifying the implementation, and meaning you don't need to cross-reference across to Django's codebase. + +Using the `SingleObjectAPIView` and `MultipleObjectAPIView` base classes continues to be supported, but will raise a `PendingDeprecationWarning`. You should instead simply use `GenericAPIView` as the base for any generic view subclasses. + +### Removed attributes + +The following attributes and methods, were previously present as part of Django's generic view implementations, but were unneeded and unusedand have now been entirely removed. + +* context_object_name +* get_context_data() +* get_context_object_name() + +The following attributes and methods, which were previously present as part of Django's generic view implementations have also been entirely removed. + +* paginator_class +* get_paginator() +* get_allow_empty() +* get_slug_field() + +There may be cases when removing these bits of API might mean you need to write a little more code if your view has highly customized behavior, but generally we believe that providing a coarser-grained API will make the views easier to work with, and is the right trade-off to make for the vast majority of cases. + +Note that the listed attributes and methods have never been a documented part of the REST framework API, and as such are not covered by the deprecation policy. + +### Simplified methods + +The `get_object` and `get_paginate_by` methods no longer take an optional queryset argument. This makes overridden these methods more obvious, and a little more simple. + +Using an optional queryset with these methods continues to be supported, but will raise a `PendingDeprecationWarning`. + +The `paginate_queryset` method no longer takes a `page_size` argument, or returns a four-tuple of pagination information. Instead it simply takes a queryset argument, and either returns a `page` object with an appropraite page size, or returns `None`, if pagination is not configured for the view. + +Using the `page_size` argument is still supported and will trigger the old-style return type, but will raise a `PendingDeprecationWarning`. + +### Deprecated attributes + +The following attributes are used to control queryset lookup, and have all been moved into a pending deprecation state. + +* pk_url_kwarg = 'pk' +* slug_url_kwarg = 'slug' +* slug_field = 'slug' + +Their usage is replaced with a single attribute: + +* lookup_field = 'pk' + +This attribute is used both as the regex keyword argument in the URL conf, and as the model field to filter against when looking up a model instance. To use non-pk based lookup, simply set the `lookup_field` argument to an alternative field, and ensure that the keyword argument in the url conf matches the field name. + +For example, a view with 'username' based lookup might look like this: + + class UserDetail(generics.RetrieveAPIView): + lookup_field = 'username' + queryset = User.objects.all() + serializer_class = UserSerializer + +And would have the following entry in the urlconf: + + url(r'^users/(?P<username>\w+)/$', UserDetail.as_view()), + +Usage of the old-style attributes continues to be supported, but will raise a `PendingDeprecationWarning`. + +The `allow_empty` attribute is also deprecated. To use `allow_empty=False` style behavior you should explicitly override `get_queryset` and raise an `Http404` on empty querysets. + +For example: + + class DisallowEmptyQuerysetMixin(object): + def get_queryset(self): + queryset = super(DisallowEmptyQuerysetMixin, self).get_queryset() + if not queryset.exists(): + raise Http404 + return queryset + +In our opinion removing lesser-used attributes like `allow_empty` helps us move towards simpler generic view implementations, making them more obvious to use and override, and re-inforcing the preferred style of developers writing their own base classes and mixins for custom behavior rather than relying on the configurability of the generic views. + +## Simpler URL lookups + +The `HyperlinkedRelatedField` class now takes a single optional `lookup_field` argument, that replaces the `pk_url_kwarg`, `slug_url_kwarg`, and `slug_field` arguments. + +For example, you might have a field that references it's relationship by a hyperlink based on a slug field: + + account = HyperlinkedRelatedField(read_only=True, + lookup_field='slug', + view_name='account-detail') + +Usage of the old-style attributes continues to be supported, but will raise a `PendingDeprecationWarning`. + +## FileUploadParser + +2.3 adds a `FileUploadParser` parser class, that supports raw file uploads, in addition to the existing multipart upload support. + +## DecimalField + +2.3 introduces a `DecimalField` serializer field, which returns `Decimal` instances. + +For most cases APIs using model fields will behave as previously, however if you are using a custom renderer, not provided by REST framework, then you may now need to add support for rendering `Decimal` instances to your renderer implmentation. + +## ModelSerializers and reverse relationships + +The support for adding reverse relationships to the `fields` option on a `ModelSerializer` class means that the `get_related_field` and `get_nested_field` method signatures have now changed. + +In the unlikely event that you're providing a custom serializer class, and implementing these methods you should note the new call signature for both methods is now `(self, model_field, related_model, to_many)`. For revese relationships `model_field` will be `None`. + +The old-style signature will continue to function but will raise a `PendingDeprecationWarning`. + +## View names and descriptions + +The mechanics of how the names and descriptions used in the browseable API are generated has been modified and cleaned up somewhat. + +If you've been customizing this behavior, for example perhaps to use `rst` markup for the browseable API, then you'll need to take a look at the implementation to see what updates you need to make. + +Note that the relevant methods have always been private APIs, and the docstrings called them out as intended to be deprecated. + +--- + +# Other notes + +## More explicit style + +The usage of `model` attribute in generic Views is still supported, but it's usage is generally being discouraged throughout the documentation, in favour of the setting the more explict `queryset` and `serializer_class` attributes. + +For example, the following is now the recommended style for using generic views: + + class AccountListView(generics.RetrieveAPIView): + queryset = MyModel.objects.all() + serializer_class = MyModelSerializer + +Using an explict `queryset` and `serializer_class` attributes makes the functioning of the view more clear than using the shortcut `model` attribute. + +It also makes the usage of the `get_queryset()` or `get_serializer_class()` methods more obvious. + + class AccountListView(generics.RetrieveAPIView): + serializer_class = MyModelSerializer + + def get_queryset(self): + """ + Determine the queryset dynamically, depending on the + user making the request. + + Note that overriding this method follows on more obviously now + that an explicit `queryset` attribute is the usual view style. + """ + return self.user.accounts + +## Django 1.3 support + +The 2.3.x release series will be the last series to provide compatiblity with Django 1.3. + +## Version 2.2 API changes + +All API changes in 2.2 that previously raised `PendingDeprecationWarning` will now raise a `DeprecationWarning`, which is loud by default. + +## What comes next? + +* Support for read-write nested serializers is almost complete, and due to be released in the next few weeks. +* Extra filter backends for searching and ordering of results are planned to be added shortly. + +The next few months should see a renewed focus on addressing outstanding tickets. The 2.4 release is currently planned for around August-September. + +[viewset]: ../api-guide/viewsets.md +[router]: ../api-guide/routers.md +[part-6]: ../tutorial/6-viewsets-and-routers.md diff --git a/docs/topics/ajax-csrf-cors.md b/docs/topics/ajax-csrf-cors.md new file mode 100644 index 00000000..f7d12940 --- /dev/null +++ b/docs/topics/ajax-csrf-cors.md @@ -0,0 +1,41 @@ +# Working with AJAX, CSRF & CORS + +> "Take a close look at possible CSRF / XSRF vulnerabilities on your own websites. They're the worst kind of vulnerability — very easy to exploit by attackers, yet not so intuitively easy to understand for software developers, at least until you've been bitten by one." +> +> — [Jeff Atwood][cite] + +## Javascript clients + +If your building a javascript client to interface with your Web API, you'll need to consider if the client can use the same authentication policy that is used by the rest of the website, and also determine if you need to use CSRF tokens or CORS headers. + +AJAX requests that are made within the same context as the API they are interacting with will typically use `SessionAuthentication`. This ensures that once a user has logged in, any AJAX requests made can be authenticated using the same session-based authentication that is used for the rest of the website. + +AJAX requests that are made on a different site from the API they are communicating with will typically need to use a non-session-based authentication scheme, such as `TokenAuthentication`. + +## CSRF protection + +[Cross Site Request Forgery][csrf] protection is a mechanism of guarding against a particular type of attack, which can occur when a user has not logged out of a web site, and continues to have a valid session. In this circumstance a malicious site may be able to perform actions against the target site, within the context of the logged-in session. + +To guard against these type of attacks, you need to do two things: + +1. Ensure that the 'safe' HTTP operations, such as `GET`, `HEAD` and `OPTIONS` cannot be used to alter any server-side state. +2. Ensure that any 'unsafe' HTTP operations, such as `POST`, `PUT`, `PATCH` and `DELETE`, always require a valid CSRF token. + +If you're using `SessionAuthentication` you'll need to include valid CSRF tokens for any `POST`, `PUT`, `PATCH` or `DELETE` operations. + +The Django documentation describes how to [include CSRF tokens in AJAX requests][csrf-ajax]. + +## CORS + +[Cross-Origin Resource Sharing][cors] is a mechanism for allowing clients to interact with APIs that are hosted on a different domain. CORS works by requiring the server to include a specific set of headers that allow a browser to determine if and when cross-domain requests should be allowed. + +The best way to deal with CORS in REST framework is to add the required response headers in middleware. This ensures that CORS is supported transparently, without having to change any behavior in your views. + +[Otto Yiu][ottoyiu] maintains the [django-cors-headers] package, which is known to work correctly with REST framework APIs. + +[cite]: http://www.codinghorror.com/blog/2008/10/preventing-csrf-and-xsrf-attacks.html +[csrf]: https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF) +[csrf-ajax]: https://docs.djangoproject.com/en/dev/ref/contrib/csrf/#ajax +[cors]: http://www.w3.org/TR/cors/ +[ottoyiu]: https://github.com/ottoyiu/ +[django-cors-headers]: https://github.com/ottoyiu/django-cors-headers/ diff --git a/docs/topics/browsable-api.md b/docs/topics/browsable-api.md index 9fe82e69..8ee01824 100644 --- a/docs/topics/browsable-api.md +++ b/docs/topics/browsable-api.md @@ -35,23 +35,20 @@ A suitable replacement theme can be generated using Bootstrap's [Customize Tool] You can also change the navbar variant, which by default is `navbar-inverse`, using the `bootstrap_navbar_variant` block. The empty `{% block bootstrap_navbar_variant %}{% endblock %}` will use the original Bootstrap navbar style. -For more specific CSS tweaks, use the `extra_style` block instead. +For more specific CSS tweaks, use the `style` block instead. ### Blocks All of the blocks available in the browsable API base template that can be used in your `api.html`. -* `blockbots` - `<meta>` tag that blocks crawlers * `bodyclass` - (empty) class attribute for the `<body>` * `bootstrap_theme` - CSS for the Bootstrap theme * `bootstrap_navbar_variant` - CSS class for the navbar * `branding` - section of the navbar, see [Bootstrap components][bcomponentsnav] * `breadcrumbs` - Links showing resource nesting, allowing the user to go back up the resources. It's recommended to preserve these, but they can be overridden using the breadcrumbs block. -* `extrastyle` - (empty) extra CSS for the page -* `extrahead` - (empty) extra markup for the page `<head>` * `footer` - Any copyright notices or similar footer materials can go here (by default right-aligned) -* `global_heading` - (empty) Use to insert content below the header but before the breadcrumbs. +* `style` - CSS stylesheets for the page * `title` - title of the page * `userlinks` - This is a list of links on the right of the header, by default containing login/logout links. To add links instead of replace, use {{ block.super }} to preserve the authentication links. @@ -63,6 +60,17 @@ All of the [Bootstrap components][bcomponents] are available. The browsable API makes use of the Bootstrap tooltips component. Any element with the `js-tooltip` class and a `title` attribute has that title content displayed in a tooltip on hover after a 1000ms delay. +### Login Template + +To add branding and customize the look-and-feel of the auth login template, create a template called `login.html` and add it to your project, eg: `templates/rest_framework/login.html`, that extends the `rest_framework/base_login.html` template. + +You can add your site name or branding by including the branding block: + + {% block branding %} + <h3 style="margin: 0 0 20px;">My Site Name</h3> + {% endblock %} + +You can also customize the style by adding the `bootstrap_theme` or `style` block similar to `api.html`. ### Advanced Customization diff --git a/docs/topics/browser-enhancements.md b/docs/topics/browser-enhancements.md index 6a11f0fa..ce07fe95 100644 --- a/docs/topics/browser-enhancements.md +++ b/docs/topics/browser-enhancements.md @@ -19,6 +19,21 @@ For example, given the following form: `request.method` would return `"DELETE"`. +## HTTP header based method overriding + +REST framework also supports method overriding via the semi-standard `X-HTTP-Method-Override` header. This can be useful if you are working with non-form content such as JSON and are working with an older web server and/or hosting provider that doesn't recognise particular HTTP methods such as `PATCH`. For example [Amazon Web Services ELB][aws_elb]. + +To use it, make a `POST` request, setting the `X-HTTP-Method-Override` header. + +For example, making a `PATCH` request via `POST` in jQuery: + + $.ajax({ + url: '/myresource/', + method: 'POST', + headers: {'X-HTTP-Method-Override': 'PATCH'}, + ... + }); + ## Browser based submission of non-form content Browser-based submission of content types other than form are supported by @@ -62,3 +77,4 @@ as well as how to support content types other than form-encoded data. [rails]: http://guides.rubyonrails.org/form_helpers.html#how-do-forms-with-put-or-delete-methods-work [html5]: http://www.w3.org/TR/html5-diff/#changes-2010-06-24 [put_delete]: http://amundsen.com/examples/put-delete-forms/ +[aws_elb]: https://forums.aws.amazon.com/thread.jspa?messageID=400724 diff --git a/docs/topics/contributing.md b/docs/topics/contributing.md index 7fd61c10..1d1fe892 100644 --- a/docs/topics/contributing.md +++ b/docs/topics/contributing.md @@ -4,12 +4,138 @@ > > — [Tim Berners-Lee][cite] -## Running the tests +There are many ways you can contribute to Django REST framework. We'd like it to be a community-led project, so please get involved and help shape the future of the project. -## Building the docs +# Community -## Managing compatibility issues +If you use and enjoy REST framework please consider [staring the project on GitHub][github], and [upvoting it on Django packages][django-packages]. Doing so helps potential new users see that the project is well used, and help us continue to attract new users. -**Describe compat module** +You might also consider writing a blog post on your experience with using REST framework, writing a tutorial about using the project with a particular javascript framework, or simply sharing the love on Twitter. -[cite]: http://www.w3.org/People/Berners-Lee/FAQ.html
\ No newline at end of file +Other really great ways you can help move the community forward include helping answer questions on the [discussion group][google-group], or setting up an [email alert on StackOverflow][so-filter] so that you get notified of any new questions with the `django-rest-framework` tag. + +When answering questions make sure to help future contributors find their way around by hyperlinking wherever possible to related threads and tickets, and include backlinks from those items if relevant. + +# Issues + +It's really helpful if you make sure you address issues to the correct channel. Usage questions should be directed to the [discussion group][google-group]. Feature requests, bug reports and other issues should be raised on the GitHub [issue tracker][issues]. + +Some tips on good issue reporting: + +* When decribing issues try to phrase your ticket in terms of the *behavior* you think needs changing rather than the *code* you think need changing. +* Search the issue list first for related items, and make sure you're running the latest version of REST framework before reporting an issue. +* If reporting a bug, then try to include a pull request with a failing test case. This'll help us quickly identify if there is a valid issue, and make sure that it gets fixed more quickly if there is one. + + + +* TODO: Triage + +# Development + +* git clone & PYTHONPATH +* Pep8 +* Recommend editor that runs pep8 + +### Pull requests + +* Make pull requests early +* Describe branching + +### Managing compatibility issues + +* Describe compat module + +# Testing + +* Running the tests +* tox + +# Documentation + +The documentation for REST framework is built from the [Markdown][markdown] source files in [the docs directory][docs]. + +There are many great markdown editors that make working with the documentation really easy. The [Mou editor for Mac][mou] is one such editor that comes highly recommended. + +## Building the documentation + +To build the documentation, simply run the `mkdocs.py` script. + + ./mkdocs.py + +This will build the html output into the `html` directory. + +You can build the documentation and open a preview in a browser window by using the `-p` flag. + + ./mkdocs.py -p + +## Language style + +Documentation should be in American English. The tone of the documentation is very important - try to stick to a simple, plain, objective and well-balanced style where possible. + +Some other tips: + +* Keep paragraphs reasonably short. +* Use double spacing after the end of sentences. +* Don't use the abbreviations such as 'e.g..' but instead use long form, such as 'For example'. + +## Markdown style + +There are a couple of conventions you should follow when working on the documentation. + +##### 1. Headers + +Headers should use the hash style. For example: + + ### Some important topic + +The underline style should not be used. **Don't do this:** + + Some important topic + ==================== + +##### 2. Links + +Links should always use the reference style, with the referenced hyperlinks kept at the end of the document. + + Here is a link to [some other thing][other-thing]. + + More text... + + [other-thing]: http://example.com/other/thing + +This style helps keep the documentation source consistent and readable. + +If you are hyperlinking to another REST framework document, you should use a relative link, and link to the `.md` suffix. For example: + + [authentication]: ../api-guide/authentication.md + +Linking in this style means you'll be able to click the hyperlink in your markdown editor to open the referenced document. When the documentation is built, these links will be converted into regular links to HTML pages. + +##### 3. Notes + +If you want to draw attention to a note or warning, use a pair of enclosing lines, like so: + + --- + + **Note:** Make sure you do this thing. + + --- + +# Third party packages + +* Django reusable app + +# Core committers + +* Still use pull reqs +* Credits + +[cite]: http://www.w3.org/People/Berners-Lee/FAQ.html +[github]: https://github.com/tomchristie/django-rest-framework +[django-packages]: https://www.djangopackages.com/grids/g/api/ +[google-group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework +[so-filter]: http://stackexchange.com/filters/66475/rest-framework +[issues]: https://github.com/tomchristie/django-rest-framework/issues?state=open +[markdown]: http://daringfireball.net/projects/markdown/basics +[docs]: https://github.com/tomchristie/django-rest-framework/tree/master/docs +[mou]: http://mouapp.com/ diff --git a/docs/topics/credits.md b/docs/topics/credits.md index 8b8cac1a..8151b4d3 100644 --- a/docs/topics/credits.md +++ b/docs/topics/credits.md @@ -2,9 +2,9 @@ The following people have helped make REST framework great. -* Tom Christie - [tomchristie] +* Tom Christie - [tomchristie] * Marko Tibold - [markotibold] -* Paul Bagwell - [pbgwl] +* Paul Miller - [paulmillr] * Sébastien Piquemal - [sebpiq] * Carmen Wick - [cwick] * Alex Ehlke - [aehlke] @@ -19,7 +19,7 @@ The following people have helped make REST framework great. * Craig Blaszczyk - [jakul] * Garcia Solero - [garciasolero] * Tom Drummond - [devioustree] -* Danilo Bargen - [gwrtheyrn] +* Danilo Bargen - [dbrgn] * Andrew McCloud - [amccloud] * Thomas Steinacher - [thomasst] * Meurig Freeman - [meurig] @@ -81,6 +81,49 @@ The following people have helped make REST framework great. * Szymon Teżewski - [sunscrapers] * Joel Marcotte - [joual] * Trey Hunner - [treyhunner] +* Roman Akinfold - [akinfold] +* Toran Billups - [toranb] +* Sébastien Béal - [sebastibe] +* Andrew Hankinson - [ahankinson] +* Juan Riaza - [juanriaza] +* Michael Mior - [michaelmior] +* Marc Tamlyn - [mjtamlyn] +* Richard Wackerbarth - [wackerbarth] +* Johannes Spielmann - [shezi] +* James Cleveland - [radiosilence] +* Steve Gregory - [steve-gregory] +* Federico Capoano - [nemesisdesign] +* Bruno Renié - [brutasse] +* Kevin Stone - [kevinastone] +* Guglielmo Celata - [guglielmo] +* Mike Tums - [mktums] +* Michael Elovskikh - [wronglink] +* MichaÅ‚ Jaworski - [swistakm] +* Andrea de Marco - [z4r] +* Fernando Rocha - [fernandogrd] +* Xavier Ordoquy - [xordoquy] +* Adam Wentz - [floppya] +* Andreas Pelme - [pelme] +* Ryan Detzel - [ryanrdetzel] +* Omer Katz - [thedrow] +* Wiliam Souza - [waa] +* Jonas Braun - [iekadou] +* Ian Dash - [bitmonkey] +* Bouke Haarsma - [bouke] +* Pierre Dulac - [dulaccc] +* Dave Kuhn - [kuhnza] +* Sitong Peng - [stoneg] +* Victor Shih - [vshih] +* Atle Frenvik Sveen - [atlefren] +* J Paul Reed - [preed] +* Matt Majewski - [forgingdestiny] +* Jerome Chen - [chenjyw] +* Andrew Hughes - [eyepulp] +* Daniel Hepper - [dhepper] +* Hamish Campbell - [hamishcampbell] +* Marlon Bailey - [avinash240] +* James Summerfield - [jsummerfield] +* Andy Freeland - [rouge8] Many thanks to everyone who's contributed to the project. @@ -92,9 +135,9 @@ Project hosting is with [GitHub]. Continuous integration testing is managed with [Travis CI][travis-ci]. -The [live sandbox][sandbox] is hosted on [Heroku]. +The [live sandbox][sandbox] is hosted on [Heroku]. -Various inspiration taken from the [Piston], [Tastypie] and [Dagny] projects. +Various inspiration taken from the [Rails], [Piston], [Tastypie], [Dagny] and [django-viewsets] projects. Development of REST framework 2.0 was sponsored by [DabApps]. @@ -103,16 +146,17 @@ Development of REST framework 2.0 was sponsored by [DabApps]. For usage questions please see the [REST framework discussion group][group]. You can also contact [@_tomchristie][twitter] directly on twitter. - -[email]: mailto:tom@tomchristie.com + [twitter]: http://twitter.com/_tomchristie [bootstrap]: http://twitter.github.com/bootstrap/ [markdown]: http://daringfireball.net/projects/markdown/ [github]: https://github.com/tomchristie/django-rest-framework [travis-ci]: https://secure.travis-ci.org/tomchristie/django-rest-framework +[rails]: http://rubyonrails.org/ [piston]: https://bitbucket.org/jespern/django-piston [tastypie]: https://github.com/toastdriven/django-tastypie [dagny]: https://github.com/zacharyvoase/dagny +[django-viewsets]: https://github.com/BertrandBordage/django-viewsets [dabapps]: http://lab.dabapps.com [sandbox]: http://restframework.herokuapp.com/ [heroku]: http://www.heroku.com/ @@ -120,7 +164,7 @@ You can also contact [@_tomchristie][twitter] directly on twitter. [tomchristie]: https://github.com/tomchristie [markotibold]: https://github.com/markotibold -[pbgwl]: https://github.com/pbgwl +[paulmillr]: https://github.com/paulmillr [sebpiq]: https://github.com/sebpiq [cwick]: https://github.com/cwick [aehlke]: https://github.com/aehlke @@ -135,7 +179,7 @@ You can also contact [@_tomchristie][twitter] directly on twitter. [jakul]: https://github.com/jakul [garciasolero]: https://github.com/garciasolero [devioustree]: https://github.com/devioustree -[gwrtheyrn]: https://github.com/gwrtheyrn +[dbrgn]: https://github.com/dbrgn [amccloud]: https://github.com/amccloud [thomasst]: https://github.com/thomasst [meurig]: https://github.com/meurig @@ -197,3 +241,46 @@ You can also contact [@_tomchristie][twitter] directly on twitter. [sunscrapers]: https://github.com/sunscrapers [joual]: https://github.com/joual [treyhunner]: https://github.com/treyhunner +[akinfold]: https://github.com/akinfold +[toranb]: https://github.com/toranb +[sebastibe]: https://github.com/sebastibe +[ahankinson]: https://github.com/ahankinson +[juanriaza]: https://github.com/juanriaza +[michaelmior]: https://github.com/michaelmior +[mjtamlyn]: https://github.com/mjtamlyn +[wackerbarth]: https://github.com/wackerbarth +[shezi]: https://github.com/shezi +[radiosilence]: https://github.com/radiosilence +[steve-gregory]: https://github.com/steve-gregory +[nemesisdesign]: https://github.com/nemesisdesign +[brutasse]: https://github.com/brutasse +[kevinastone]: https://github.com/kevinastone +[guglielmo]: https://github.com/guglielmo +[mktums]: https://github.com/mktums +[wronglink]: https://github.com/wronglink +[swistakm]: https://github.com/swistakm +[z4r]: https://github.com/z4r +[fernandogrd]: https://github.com/fernandogrd +[xordoquy]: https://github.com/xordoquy +[floppya]: https://github.com/floppya +[pelme]: https://github.com/pelme +[ryanrdetzel]: https://github.com/ryanrdetzel +[thedrow]: https://github.com/thedrow +[waa]: https://github.com/wiliamsouza +[iekadou]: https://github.com/iekadou +[bitmonkey]: https://github.com/bitmonkey +[bouke]: https://github.com/bouke +[dulaccc]: https://github.com/dulaccc +[kuhnza]: https://github.com/kuhnza +[stoneg]: https://github.com/stoneg +[vshih]: https://github.com/vshih +[atlefren]: https://github.com/atlefren +[preed]: https://github.com/preed +[forgingdestiny]: https://github.com/forgingdestiny +[chenjyw]: https://github.com/chenjyw +[eyepulp]: https://github.com/eyepulp +[dhepper]: https://github.com/dhepper +[hamishcampbell]: https://github.com/hamishcampbell +[avinash240]: https://github.com/avinash240 +[jsummerfield]: https://github.com/jsummerfield +[rouge8]: https://github.com/rouge8 diff --git a/docs/topics/csrf.md b/docs/topics/csrf.md deleted file mode 100644 index 043144c1..00000000 --- a/docs/topics/csrf.md +++ /dev/null @@ -1,12 +0,0 @@ -# Working with AJAX and CSRF - -> "Take a close look at possible CSRF / XSRF vulnerabilities on your own websites. They're the worst kind of vulnerability -- very easy to exploit by attackers, yet not so intuitively easy to understand for software developers, at least until you've been bitten by one." -> -> — [Jeff Atwood][cite] - -* Explain need to add CSRF token to AJAX requests. -* Explain deferred CSRF style used by REST framework -* Why you should use Django's standard login/logout views, and not REST framework view - - -[cite]: http://www.codinghorror.com/blog/2008/10/preventing-csrf-and-xsrf-attacks.html diff --git a/docs/topics/migration.md b/docs/topics/migration.md deleted file mode 100644 index 25fc9074..00000000 --- a/docs/topics/migration.md +++ /dev/null @@ -1,89 +0,0 @@ -# 2.0 Migration Guide - -> Move fast and break things -> -> — Mark Zuckerberg, [the Hacker Way][cite]. - -REST framework 2.0 introduces a radical redesign of the core components, and a large number of backwards breaking changes. - -### Serialization redesign. - -REST framework's serialization and deserialization previously used a slightly odd combination of serializers for output, and Django Forms and Model Forms for input. The serialization core has been completely redesigned based on work that was originally intended for Django core. - -2.0's form-like serializers comprehensively address those issues, and are a much more flexible and clean solution to the problems around accepting both form-based and non-form based inputs. - -### Generic views improved. - -When REST framework 0.1 was released the current Django version was 1.2. REST framework included a backport of the Django 1.3's upcoming `View` class, but it didn't take full advantage of the generic view implementations. - -As of 2.0 the generic views in REST framework tie in much more cleanly and obviously with Django's existing codebase, and the mixin architecture is radically simplified. - -### Cleaner request-response cycle. - -REST framework 2.0's request-response cycle is now much less complex. - -* Responses inherit from `SimpleTemplateResponse`, allowing rendering to be delegated to the response, not handled by the view. -* Requests extend the regular `HttpRequest`, allowing authentication and parsing to be delegated to the request, not handled by the view. - -### Renamed attributes & classes. - -Various attributes and classes have been renamed in order to fit in better with Django's conventions. - -## Example: Blog Posts API - -Let's take a look at an example from the REST framework 0.4 documentation... - - from djangorestframework.resources import ModelResource - from djangorestframework.reverse import reverse - from blogpost.models import BlogPost, Comment - - - class BlogPostResource(ModelResource): - """ - A Blog Post has a *title* and *content*, and can be associated - with zero or more comments. - """ - model = BlogPost - fields = ('created', 'title', 'slug', 'content', 'url', 'comments') - ordering = ('-created',) - - def url(self, instance): - return reverse('blog-post', - kwargs={'key': instance.key}, - request=self.request) - - def comments(self, instance): - return reverse('comments', - kwargs={'blogpost': instance.key}, - request=self.request) - - - class CommentResource(ModelResource): - """ - A Comment is associated with a given Blog Post and has a - *username* and *comment*, and optionally a *rating*. - """ - model = Comment - fields = ('username', 'comment', 'created', 'rating', 'url', 'blogpost') - ordering = ('-created',) - - def blogpost(self, instance): - return reverse('blog-post', - kwargs={'key': instance.blogpost.key}, - request=self.request) - -There's a bit of a mix of concerns going on there. We've got some information about how the data should be serialized, such as the `fields` attribute, and some information about how it should be retrieved from the database - the `ordering` attribute. - -Let's start to re-write this for REST framework 2.0. - - from rest_framework import serializers - - class BlogPostSerializer(serializers.HyperlinkedModelSerializer): - model = BlogPost - fields = ('created', 'title', 'slug', 'content', 'url', 'comments') - - class CommentSerializer(serializers.HyperlinkedModelSerializer): - model = Comment - fields = ('username', 'comment', 'created', 'rating', 'url', 'blogpost') - -[cite]: http://www.wired.com/business/2012/02/zuck-letter/ diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 71fa3c03..560dd305 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -4,10 +4,245 @@ > > — Eric S. Raymond, [The Cathedral and the Bazaar][cite]. -## 2.1.x series +## Versioning + +Minor version numbers (0.0.x) are used for changes that are API compatible. You should be able to upgrade between minor point releases without any other code changes. + +Medium version numbers (0.x.0) may include API changes, in line with the [deprecation policy][deprecation-policy]. You should read the release notes carefully before upgrading between medium point releases. + +Major version numbers (x.0.0) are reserved for substantial project milestones. No major point releases are currently planned. + +## Deprecation policy + +REST framework releases follow a formal deprecation policy, which is in line with [Django's deprecation policy][django-deprecation-policy]. + +The timeline for deprecation of a feature present in version 1.0 would work as follows: + +* Version 1.1 would remain **fully backwards compatible** with 1.0, but would raise `PendingDeprecationWarning` warnings if you use the feature that are due to be deprecated. These warnings are **silent by default**, but can be explicitly enabled when you're ready to start migrating any required changes. For example if you start running your tests using `python -Wd manage.py test`, you'll be warned of any API changes you need to make. + +* Version 1.2 would escalate these warnings to `DeprecationWarning`, which is loud by default. + +* Version 1.3 would remove the deprecated bits of API entirely. + +Note that in line with Django's policy, any parts of the framework not mentioned in the documentation should generally be considered private API, and may be subject to change. + +## Upgrading + +To upgrade Django REST framework to the latest version, use pip: + + pip install -U djangorestframework + +You can determine your currently installed version using `pip freeze`: + + pip freeze | grep djangorestframework + +--- + +## 2.3.x series ### Master +* Bugfix: HyperlinkedIdentityField now uses `lookup_field` kwarg. + +### 2.3.2 + +**Date**: 16th May 2013 + +* Added SearchFilter +* Added OrderingFilter +* Added GenericViewSet +* Bugfix: Multiple `@action` and `@link` methods now allowed on viewsets. +* Bugfix: Fix API Root view issue with DjangoModelPermissions + +### 2.3.2 + +**Date**: 8th May 2013 + +* Bugfix: Fix `TIME_FORMAT`, `DATETIME_FORMAT` and `DATE_FORMAT` settings. +* Bugfix: Fix `DjangoFilterBackend` issue, failing when used on view with queryset attribute. + +### 2.3.1 + +**Date**: 7th May 2013 + +* Bugfix: Fix breadcrumb rendering issue. + +### 2.3.0 + +**Date**: 7th May 2013 + +* ViewSets and Routers. +* ModelSerializers support reverse relations in 'fields' option. +* HyperLinkedModelSerializers support 'id' field in 'fields' option. +* Cleaner generic views. +* Support for multiple filter classes. +* FileUploadParser support for raw file uploads. +* DecimalField support. +* Made Login template easier to restyle. +* Bugfix: Fix issue with depth>1 on ModelSerializer. + +**Note**: See the [2.3 announcement][2.3-announcement] for full details. + +--- + +## 2.2.x series + +### 2.2.7 + +**Date**: 17th April 2013 + +* Loud failure when view does not return a `Response` or `HttpResponse`. +* Bugfix: Fix for Django 1.3 compatiblity. +* Bugfix: Allow overridden `get_object()` to work correctly. + +### 2.2.6 + +**Date**: 4th April 2013 + +* OAuth2 authentication no longer requires unneccessary URL parameters in addition to the token. +* URL hyperlinking in browsable API now handles more cases correctly. +* Long HTTP headers in browsable API are broken in multiple lines when possible. +* Bugfix: Fix regression with DjangoFilterBackend not worthing correctly with single object views. +* Bugfix: OAuth should fail hard when invalid token used. +* Bugfix: Fix serializer potentially returning `None` object for models that define `__bool__` or `__len__`. + +### 2.2.5 + +**Date**: 26th March 2013 + +* Serializer support for bulk create and bulk update operations. +* Regression fix: Date and time fields return date/time objects by default. Fixes regressions caused by 2.2.2. See [#743][743] for more details. +* Bugfix: Fix 500 error is OAuth not attempted with OAuthAuthentication class installed. +* `Serializer.save()` now supports arbitrary keyword args which are passed through to the object `.save()` method. Mixins use `force_insert` and `force_update` where appropriate, resulting in one less database query. + +### 2.2.4 + +**Date**: 13th March 2013 + +* OAuth 2 support. +* OAuth 1.0a support. +* Support X-HTTP-Method-Override header. +* Filtering backends are now applied to the querysets for object lookups as well as lists. (Eg you can use a filtering backend to control which objects should 404) +* Deal with error data nicely when deserializing lists of objects. +* Extra override hook to configure `DjangoModelPermissions` for unauthenticated users. +* Bugfix: Fix regression which caused extra database query on paginated list views. +* Bugfix: Fix pk relationship bug for some types of 1-to-1 relations. +* Bugfix: Workaround for Django bug causing case where `Authtoken` could be registered for cascade delete from `User` even if not installed. + +### 2.2.3 + +**Date**: 7th March 2013 + +* Bugfix: Fix None values for for `DateField`, `DateTimeField` and `TimeField`. + +### 2.2.2 + +**Date**: 6th March 2013 + +* Support for custom input and output formats for `DateField`, `DateTimeField` and `TimeField`. +* Cleanup: Request authentication is no longer lazily evaluated, instead authentication is always run, which results in more consistent, obvious behavior. Eg. Supplying bad auth credentials will now always return an error response, even if no permissions are set on the view. +* Bugfix for serializer data being uncacheable with pickle protocol 0. +* Bugfixes for model field validation edge-cases. +* Bugfix for authtoken migration while using a custom user model and south. + +### 2.2.1 + +**Date**: 22nd Feb 2013 + +* Security fix: Use `defusedxml` package to address XML parsing vulnerabilities. +* Raw data tab added to browsable API. (Eg. Allow for JSON input.) +* Added TimeField. +* Serializer fields can be mapped to any method that takes no args, or only takes kwargs which have defaults. +* Unicode support for view names/descriptions in browsable API. +* Bugfix: request.DATA should return an empty `QueryDict` with no data, not `None`. +* Bugfix: Remove unneeded field validation, which caused extra queries. + +**Security note**: Following the [disclosure of security vulnerabilities][defusedxml-announce] in Python's XML parsing libraries, use of the `XMLParser` class now requires the `defusedxml` package to be installed. + +The security vulnerabilities only affect APIs which use the `XMLParser` class, by enabling it in any views, or by having it set in the `DEFAULT_PARSER_CLASSES` setting. Note that the `XMLParser` class is not enabled by default, so this change should affect a minority of users. + +### 2.2.0 + +**Date**: 13th Feb 2013 + +* Python 3 support. +* Added a `post_save()` hook to the generic views. +* Allow serializers to handle dicts as well as objects. +* Deprecate `ManyRelatedField()` syntax in favor of `RelatedField(many=True)` +* Deprecate `null=True` on relations in favor of `required=False`. +* Deprecate `blank=True` on CharFields, just use `required=False`. +* Deprecate optional `obj` argument in permissions checks in favor of `has_object_permission`. +* Deprecate implicit hyperlinked relations behavior. +* Bugfix: Fix broken DjangoModelPermissions. +* Bugfix: Allow serializer output to be cached. +* Bugfix: Fix styling on browsable API login. +* Bugfix: Fix issue with deserializing empty to-many relations. +* Bugfix: Ensure model field validation is still applied for ModelSerializer subclasses with an custom `.restore_object()` method. + +**Note**: See the [2.2 announcement][2.2-announcement] for full details. + +--- + +## 2.1.x series + +### 2.1.17 + +**Date**: 26th Jan 2013 + +* Support proper 401 Unauthorized responses where appropriate, instead of always using 403 Forbidden. +* Support json encoding of timedelta objects. +* `format_suffix_patterns()` now supports `include` style URL patterns. +* Bugfix: Fix issues with custom pagination serializers. +* Bugfix: Nested serializers now accept `source='*'` argument. +* Bugfix: Return proper validation errors when incorrect types supplied for relational fields. +* Bugfix: Support nullable FKs with `SlugRelatedField`. +* Bugfix: Don't call custom validation methods if the field has an error. + +**Note**: If the primary authentication class is `TokenAuthentication` or `BasicAuthentication`, a view will now correctly return 401 responses to unauthenticated access, with an appropriate `WWW-Authenticate` header, instead of 403 responses. + +### 2.1.16 + +**Date**: 14th Jan 2013 + +* Deprecate `django.utils.simplejson` in favor of Python 2.6's built-in json module. +* Bugfix: `auto_now`, `auto_now_add` and other `editable=False` fields now default to read-only. +* Bugfix: PK fields now only default to read-only if they are an AutoField or if `editable=False`. +* Bugfix: Validation errors instead of exceptions when serializers receive incorrect types. +* Bugfix: Validation errors instead of exceptions when related fields receive incorrect types. +* Bugfix: Handle ObjectDoesNotExist exception when serializing null reverse one-to-one + +**Note**: Prior to 2.1.16, The Decimals would render in JSON using floating point if `simplejson` was installed, but otherwise render using string notation. Now that use of `simplejson` has been deprecated, Decimals will consistently render using string notation. See [#582] for more details. + +### 2.1.15 + +**Date**: 3rd Jan 2013 + +* Added `PATCH` support. +* Added `RetrieveUpdateAPIView`. +* Remove unused internal `save_m2m` flag on `ModelSerializer.save()`. +* Tweak behavior of hyperlinked fields with an explicit format suffix. +* Relation changes are now persisted in `.save()` instead of in `.restore_object()`. +* Bugfix: Fix issue with FileField raising exception instead of validation error when files=None. +* Bugfix: Partial updates should not set default values if field is not included. + +### 2.1.14 + +**Date**: 31st Dec 2012 + +* Bugfix: ModelSerializers now include reverse FK fields on creation. +* Bugfix: Model fields with `blank=True` are now `required=False` by default. +* Bugfix: Nested serializers now support nullable relationships. + +**Note**: From 2.1.14 onwards, relational fields move out of the `fields.py` module and into the new `relations.py` module, in order to separate them from regular data type fields, such as `CharField` and `IntegerField`. + +This change will not affect user code, so long as it's following the recommended import style of `from rest_framework import serializers` and referring to fields using the style `serializers.PrimaryKeyRelatedField`. + + +### 2.1.13 + +**Date**: 28th Dec 2012 + +* Support configurable `STATICFILES_STORAGE` storage. * Bugfix: Related fields now respect the required flag, and may be required=False. ### 2.1.12 @@ -15,14 +250,14 @@ **Date**: 21st Dec 2012 * Bugfix: Fix bug that could occur using ChoiceField. -* Bugfix: Fix exception in browseable API on DELETE. +* Bugfix: Fix exception in browsable API on DELETE. * Bugfix: Fix issue where pk was was being set to a string if set by URL kwarg. ### 2.1.11 **Date**: 17th Dec 2012 -* Bugfix: Fix issue with M2M fields in browseable API. +* Bugfix: Fix issue with M2M fields in browsable API. ### 2.1.10 @@ -105,7 +340,7 @@ * Support use of HTML exception templates. Eg. `403.html` * Hyperlinked fields take optional `slug_field`, `slug_url_kwarg` and `pk_url_kwarg` arguments. -* Bugfix: Deal with optional trailing slashs properly when generating breadcrumbs. +* Bugfix: Deal with optional trailing slashes properly when generating breadcrumbs. * Bugfix: Make textareas same width as other fields in browsable API. * Private API change: `.get_serializer` now uses same `instance` and `data` ordering as serializer initialization. @@ -113,16 +348,16 @@ **Date**: 5th Nov 2012 -**Warning**: Please read [this thread][2.1.0-notes] regarding the `instance` and `data` keyword args before updating to 2.1.0. - * **Serializer `instance` and `data` keyword args have their position swapped.** * `queryset` argument is now optional on writable model fields. * Hyperlinked related fields optionally take `slug_field` and `slug_url_kwarg` arguments. * Support Django's cache framework. * Minor field improvements. (Don't stringify dicts, more robust many-pk fields.) -* Bugfix: Support choice field in Browseable API. +* Bugfix: Support choice field in Browsable API. * Bugfix: Related fields with `read_only=True` do not require a `queryset` argument. +**API-incompatible changes**: Please read [this thread][2.1.0-notes] regarding the `instance` and `data` keyword args before updating to 2.1.0. + --- ## 2.0.x series @@ -159,9 +394,9 @@ * Allow views to specify template used by TemplateRenderer * More consistent error responses * Some serializer fixes -* Fix internet explorer ajax behaviour +* Fix internet explorer ajax behavior * Minor xml and yaml fixes -* Improve setup (eg use staticfiles, not the defunct ADMIN_MEDIA_PREFIX) +* Improve setup (e.g. use staticfiles, not the defunct ADMIN_MEDIA_PREFIX) * Sensible absolute URL generation, not using hacky set_script_prefix --- @@ -172,13 +407,13 @@ * Added DjangoModelPermissions class to support `django.contrib.auth` style permissions. * Use `staticfiles` for css files. - - Easier to override. Won't conflict with customised admin styles (eg grappelli) + - Easier to override. Won't conflict with customized admin styles (e.g. grappelli) * Templates are now nicely namespaced. - Allows easier overriding. * Drop implied 'pk' filter if last arg in urlconf is unnamed. - - Too magical. Explict is better than implicit. -* Saner template variable autoescaping. -* Tider setup.py + - Too magical. Explicit is better than implicit. +* Saner template variable auto-escaping. +* Tidier setup.py * Updated for URLObject 2.0 * Bugfixes: - Bug with PerUserThrottling when user contains unicode chars. @@ -266,5 +501,14 @@ * Initial release. [cite]: http://www.catb.org/~esr/writings/cathedral-bazaar/cathedral-bazaar/ar01s04.html +[deprecation-policy]: #deprecation-policy +[django-deprecation-policy]: https://docs.djangoproject.com/en/dev/internals/release-process/#internal-release-deprecation-policy +[defusedxml-announce]: http://blog.python.org/2013/02/announcing-defusedxml-fixes-for-xml.html +[2.2-announcement]: 2.2-announcement.md +[2.3-announcement]: 2.3-announcement.md +[743]: https://github.com/tomchristie/django-rest-framework/pull/743 +[staticfiles14]: https://docs.djangoproject.com/en/1.4/howto/static-files/#with-a-template-tag +[staticfiles13]: https://docs.djangoproject.com/en/1.3/howto/static-files/#with-a-template-tag [2.1.0-notes]: https://groups.google.com/d/topic/django-rest-framework/Vv2M0CMY9bg/discussion [announcement]: rest-framework-2-announcement.md +[#582]: https://github.com/tomchristie/django-rest-framework/issues/582 diff --git a/docs/topics/rest-framework-2-announcement.md b/docs/topics/rest-framework-2-announcement.md index 885d1918..309548d0 100644 --- a/docs/topics/rest-framework-2-announcement.md +++ b/docs/topics/rest-framework-2-announcement.md @@ -62,19 +62,19 @@ REST framework 2 also allows you to work with both function-based and class-base Pretty much every aspect of REST framework has been reworked, with the aim of ironing out some of the design flaws of the previous versions. Each of the components of REST framework are cleanly decoupled, and can be used independantly of each-other, and there are no monolithic resource classes, overcomplicated mixin combinations, or opinionated serialization or URL routing decisions. -## The Browseable API +## The Browsable API Django REST framework's most unique feature is the way it is able to serve up both machine-readable representations, and a fully browsable HTML representation to the same endpoints. -Browseable Web APIs are easier to work with, visualize and debug, and generally makes it easier and more frictionless to inspect and work with. +Browsable Web APIs are easier to work with, visualize and debug, and generally makes it easier and more frictionless to inspect and work with. -With REST framework 2, the browseable API gets a snazzy new bootstrap-based theme that looks great and is even nicer to work with. +With REST framework 2, the browsable API gets a snazzy new bootstrap-based theme that looks great and is even nicer to work with. There are also some functionality improvments - actions such as as `POST` and `DELETE` will only display if the user has the appropriate permissions. -![Browseable API][image] +![Browsable API][image] -**Image above**: An example of the browseable API in REST framework 2 +**Image above**: An example of the browsable API in REST framework 2 ## Documentation diff --git a/docs/topics/rest-hypermedia-hateoas.md b/docs/topics/rest-hypermedia-hateoas.md index 10ab9dfe..43e5a8c6 100644 --- a/docs/topics/rest-hypermedia-hateoas.md +++ b/docs/topics/rest-hypermedia-hateoas.md @@ -26,7 +26,7 @@ REST framework is an agnostic Web API toolkit. It does help guide you towards b ## What REST framework provides. -It is self evident that REST framework makes it possible to build Hypermedia APIs. The browseable API that it offers is built on HTML - the hypermedia language of the web. +It is self evident that REST framework makes it possible to build Hypermedia APIs. The browsable API that it offers is built on HTML - the hypermedia language of the web. REST framework also includes [serialization] and [parser]/[renderer] components that make it easy to build appropriate media types, [hyperlinked relations][fields] for building well-connected systems, and great support for [content negotiation][conneg]. diff --git a/docs/tutorial/1-serialization.md b/docs/tutorial/1-serialization.md index e61fb946..ed54a876 100644 --- a/docs/tutorial/1-serialization.md +++ b/docs/tutorial/1-serialization.md @@ -4,11 +4,11 @@ This tutorial will cover creating a simple pastebin code highlighting Web API. Along the way it will introduce the various components that make up REST framework, and give you a comprehensive understanding of how everything fits together. -The tutorial is fairly in-depth, so you should probably get a cookie and a cup of your favorite brew before getting started.<!-- If you just want a quick overview, you should head over to the [quickstart] documentation instead. --> +The tutorial is fairly in-depth, so you should probably get a cookie and a cup of your favorite brew before getting started. If you just want a quick overview, you should head over to the [quickstart] documentation instead. --- -**Note**: The final code for this tutorial is available in the [tomchristie/rest-framework-tutorial][repo] repository on GitHub. There is also a sandbox version for testing, [available here][sandbox]. +**Note**: The code for this tutorial is available in the [tomchristie/rest-framework-tutorial][repo] repository on GitHub. The completed implementation is also online as a sandbox version for testing, [available here][sandbox]. --- @@ -60,7 +60,7 @@ We'll also need to add our new `snippets` app and the `rest_framework` app to `I INSTALLED_APPS = ( ... 'rest_framework', - 'snippets' + 'snippets', ) We also need to wire up the root urlconf, in the `tutorial/urls.py` file, to include our snippet app's URLs. @@ -73,19 +73,20 @@ Okay, we're ready to roll. ## Creating a model to work with -For the purposes of this tutorial we're going to start by creating a simple `Snippet` model that is used to store code snippets. Go ahead and edit the `snippets` app's `models.py` file. +For the purposes of this tutorial we're going to start by creating a simple `Snippet` model that is used to store code snippets. Go ahead and edit the `snippets` app's `models.py` file. Note: Good programming practices include comments. Although you will find them in our repository version of this tutorial code, we have omitted them here to focus on the code itself. from django.db import models from pygments.lexers import get_all_lexers from pygments.styles import get_all_styles - - LANGUAGE_CHOICES = sorted([(item[1][0], item[0]) for item in get_all_lexers()]) - STYLE_CHOICES = sorted((item, item) for item in list(get_all_styles())) + + LEXERS = [item for item in get_all_lexers() if item[1]] + LANGUAGE_CHOICES = sorted([(item[1][0], item[0]) for item in LEXERS]) + STYLE_CHOICES = sorted((item, item) for item in get_all_styles()) class Snippet(models.Model): created = models.DateTimeField(auto_now_add=True) - title = models.CharField(max_length=100, default='') + title = models.CharField(max_length=100, blank=True, default='') code = models.TextField() linenos = models.BooleanField(default=False) language = models.CharField(choices=LANGUAGE_CHOICES, @@ -108,7 +109,7 @@ The first thing we need to get started on our Web API is provide a way of serial from django.forms import widgets from rest_framework import serializers - from snippets import models + from snippets.models import Snippet, LANGUAGE_CHOICES, STYLE_CHOICES class SnippetSerializer(serializers.Serializer): @@ -118,26 +119,30 @@ The first thing we need to get started on our Web API is provide a way of serial code = serializers.CharField(widget=widgets.Textarea, max_length=100000) linenos = serializers.BooleanField(required=False) - language = serializers.ChoiceField(choices=models.LANGUAGE_CHOICES, + language = serializers.ChoiceField(choices=LANGUAGE_CHOICES, default='python') - style = serializers.ChoiceField(choices=models.STYLE_CHOICES, + style = serializers.ChoiceField(choices=STYLE_CHOICES, default='friendly') def restore_object(self, attrs, instance=None): """ - Create or update a new snippet instance. + Create or update a new snippet instance, given a dictionary + of deserialized field values. + + Note that if we don't define this method, then deserializing + data will simply return a dictionary of items. """ if instance: # Update existing instance - instance.title = attrs['title'] - instance.code = attrs['code'] - instance.linenos = attrs['linenos'] - instance.language = attrs['language'] - instance.style = attrs['style'] + instance.title = attrs.get('title', instance.title) + instance.code = attrs.get('code', instance.code) + instance.linenos = attrs.get('linenos', instance.linenos) + instance.language = attrs.get('language', instance.language) + instance.style = attrs.get('style', instance.style) return instance # Create new instance - return models.Snippet(**attrs) + return Snippet(**attrs) The first part of serializer class defines the fields that get serialized/deserialized. The `restore_object` method defines how fully fledged instances get created when deserializing data. @@ -149,13 +154,16 @@ Before we go any further we'll familiarize ourselves with using our new Serializ python manage.py shell -Okay, once we've got a few imports out of the way, let's create a code snippet to work with. +Okay, once we've got a few imports out of the way, let's create a couple of code snippets to work with. from snippets.models import Snippet from snippets.serializers import SnippetSerializer from rest_framework.renderers import JSONRenderer from rest_framework.parsers import JSONParser + snippet = Snippet(code='foo = "bar"\n') + snippet.save() + snippet = Snippet(code='print "hello, world"\n') snippet.save() @@ -163,13 +171,13 @@ We've now got a few snippet instances to play with. Let's take a look at serial serializer = SnippetSerializer(snippet) serializer.data - # {'pk': 1, 'title': u'', 'code': u'print "hello, world"\n', 'linenos': False, 'language': u'python', 'style': u'friendly'} + # {'pk': 2, 'title': u'', 'code': u'print "hello, world"\n', 'linenos': False, 'language': u'python', 'style': u'friendly'} At this point we've translated the model instance into python native datatypes. To finalize the serialization process we render the data into `json`. content = JSONRenderer().render(serializer.data) content - # '{"pk": 1, "title": "", "code": "print \\"hello, world\\"\\n", "linenos": false, "language": "python", "style": "friendly"}' + # '{"pk": 2, "title": "", "code": "print \\"hello, world\\"\\n", "linenos": false, "language": "python", "style": "friendly"}' Deserialization is similar. First we parse a stream into python native datatypes... @@ -188,9 +196,15 @@ Deserialization is similar. First we parse a stream into python native datatype Notice how similar the API is to working with forms. The similarity should become even more apparent when we start writing views that use our serializer. +We can also serialize querysets instead of model instances. To do so we simply add a `many=True` flag to the serializer arguments. + + serializer = SnippetSerializer(Snippet.objects.all(), many=True) + serializer.data + # [{'pk': 1, 'title': u'', 'code': u'foo = "bar"\n', 'linenos': False, 'language': u'python', 'style': u'friendly'}, {'pk': 2, 'title': u'', 'code': u'print "hello, world"\n', 'linenos': False, 'language': u'python', 'style': u'friendly'}] + ## Using ModelSerializers -Our `SnippetSerializer` class is replicating a lot of information that's also contained in the `Snippet` model. It would be nice if we could keep out code a bit more concise. +Our `SnippetSerializer` class is replicating a lot of information that's also contained in the `Snippet` model. It would be nice if we could keep our code a bit more concise. In the same way that Django provides both `Form` classes and `ModelForm` classes, REST framework includes both `Serializer` classes, and `ModelSerializer` classes. @@ -202,8 +216,6 @@ Open the file `snippets/serializers.py` again, and edit the `SnippetSerializer` model = Snippet fields = ('id', 'title', 'code', 'linenos', 'language', 'style') - - ## Writing regular Django views using our Serializer Let's see how we can write some API views using our new Serializer class. @@ -229,7 +241,6 @@ Edit the `snippet/views.py` file, and add the following. kwargs['content_type'] = 'application/json' super(JSONResponse, self).__init__(content, **kwargs) - The root of our API is going to be a view that supports listing all the existing snippets, or creating a new snippet. @csrf_exempt @@ -239,7 +250,7 @@ The root of our API is going to be a view that supports listing all the existing """ if request.method == 'GET': snippets = Snippet.objects.all() - serializer = SnippetSerializer(snippets) + serializer = SnippetSerializer(snippets, many=True) return JSONResponse(serializer.data) elif request.method == 'POST': @@ -288,16 +299,45 @@ Finally we need to wire these views up. Create the `snippets/urls.py` file: urlpatterns = patterns('snippets.views', url(r'^snippets/$', 'snippet_list'), - url(r'^snippets/(?P<pk>[0-9]+)/$', 'snippet_detail') + url(r'^snippets/(?P<pk>[0-9]+)/$', 'snippet_detail'), ) It's worth noting that there are a couple of edge cases we're not dealing with properly at the moment. If we send malformed `json`, or if a request is made with a method that the view doesn't handle, then we'll end up with a 500 "server error" response. Still, this'll do for now. ## Testing our first attempt at a Web API -**TODO: Describe using runserver and making example requests from console** +Now we can start up a sample server that serves our snippets. + +Quit out of the shell... + + quit() + +...and start up Django's development server. + + python manage.py runserver + + Validating models... + + 0 errors found + Django version 1.4.3, using settings 'tutorial.settings' + Development server is running at http://127.0.0.1:8000/ + Quit the server with CONTROL-C. + +In another terminal window, we can test the server. + +We can get a list of all of the snippets. + + curl http://127.0.0.1:8000/snippets/ + + [{"id": 1, "title": "", "code": "foo = \"bar\"\n", "linenos": false, "language": "python", "style": "friendly"}, {"id": 2, "title": "", "code": "print \"hello, world\"\n", "linenos": false, "language": "python", "style": "friendly"}] + +Or we can get a particular snippet by referencing its id. + + curl http://127.0.0.1:8000/snippets/2/ + + {"id": 2, "title": "", "code": "print \"hello, world\"\n", "linenos": false, "language": "python", "style": "friendly"} -**TODO: Describe opening in a web browser and viewing json output** +Similarly, you can have the same json displayed by visiting these URLs in a web browser. ## Where are we now diff --git a/docs/tutorial/2-requests-and-responses.md b/docs/tutorial/2-requests-and-responses.md index 08cf91cd..260c4d83 100644 --- a/docs/tutorial/2-requests-and-responses.md +++ b/docs/tutorial/2-requests-and-responses.md @@ -8,7 +8,7 @@ Let's introduce a couple of essential building blocks. REST framework introduces a `Request` object that extends the regular `HttpRequest`, and provides more flexible request parsing. The core functionality of the `Request` object is the `request.DATA` attribute, which is similar to `request.POST`, but more useful for working with Web APIs. request.POST # Only handles form data. Only works for 'POST' method. - request.DATA # Handles arbitrary data. Works any HTTP request with content. + request.DATA # Handles arbitrary data. Works for 'POST', 'PUT' and 'PATCH' methods. ## Response objects @@ -31,7 +31,6 @@ These wrappers provide a few bits of functionality such as making sure you recei The wrappers also provide behaviour such as returning `405 Method Not Allowed` responses when appropriate, and handling any `ParseError` exception that occurs when accessing `request.DATA` with malformed input. - ## Pulling it all together Okay, let's go ahead and start using these new components to write a few views. @@ -52,7 +51,7 @@ We don't need our `JSONResponse` class anymore, so go ahead and delete that. On """ if request.method == 'GET': snippets = Snippet.objects.all() - serializer = SnippetSerializer(snippets) + serializer = SnippetSerializer(snippets, many=True) return Response(serializer.data) elif request.method == 'POST': @@ -63,7 +62,6 @@ We don't need our `JSONResponse` class anymore, so go ahead and delete that. On else: return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - Our instance view is an improvement over the previous example. It's a little more concise, and the code now feels very similar to if we were working with the Forms API. We're also using named status codes, which makes the response meanings more obvious. Here is the view for an individual snippet. @@ -77,11 +75,11 @@ Here is the view for an individual snippet. snippet = Snippet.objects.get(pk=pk) except Snippet.DoesNotExist: return Response(status=status.HTTP_404_NOT_FOUND) - + if request.method == 'GET': serializer = SnippetSerializer(snippet) return Response(serializer.data) - + elif request.method == 'PUT': serializer = SnippetSerializer(snippet, data=request.DATA) if serializer.is_valid(): @@ -117,7 +115,7 @@ Now update the `urls.py` file slightly, to append a set of `format_suffix_patter urlpatterns = patterns('snippets.views', url(r'^snippets/$', 'snippet_list'), - url(r'^snippets/(?P<pk>[0-9]+)$', 'snippet_detail') + url(r'^snippets/(?P<pk>[0-9]+)$', 'snippet_detail'), ) urlpatterns = format_suffix_patterns(urlpatterns) @@ -128,16 +126,43 @@ We don't necessarily need to add these extra url patterns in, but it gives us a Go ahead and test the API from the command line, as we did in [tutorial part 1][tut-1]. Everything is working pretty similarly, although we've got some nicer error handling if we send invalid requests. -**TODO: Describe using accept headers, content-type headers, and format suffixed URLs** +We can get a list of all of the snippets, as before. + + curl http://127.0.0.1:8000/snippets/ + + [{"id": 1, "title": "", "code": "foo = \"bar\"\n", "linenos": false, "language": "python", "style": "friendly"}, {"id": 2, "title": "", "code": "print \"hello, world\"\n", "linenos": false, "language": "python", "style": "friendly"}] + +We can control the format of the response that we get back, either by using the `Accept` header: + + curl http://127.0.0.1:8000/snippets/ -H 'Accept: application/json' # Request JSON + curl http://127.0.0.1:8000/snippets/ -H 'Accept: text/html' # Request HTML + +Or by appending a format suffix: + + curl http://127.0.0.1:8000/snippets/.json # JSON suffix + curl http://127.0.0.1:8000/snippets/.api # Browsable API suffix + +Similarly, we can control the format of the request that we send, using the `Content-Type` header. + + # POST using form data + curl -X POST http://127.0.0.1:8000/snippets/ -d "code=print 123" + + {"id": 3, "title": "", "code": "123", "linenos": false, "language": "python", "style": "friendly"} + + # POST using JSON + curl -X POST http://127.0.0.1:8000/snippets/ -d '{"code": "print 456"}' -H "Content-Type: application/json" + + {"id": 4, "title": "", "code": "print 456", "linenos": true, "language": "python", "style": "friendly"} Now go and open the API in a web browser, by visiting [http://127.0.0.1:8000/snippets/][devserver]. ### Browsability -Because the API chooses a return format based on what the client asks for, it will, by default, return an HTML-formatted representation of the resource when that resource is requested by a browser. This allows for the API to be easily browsable and usable by humans. +Because the API chooses the content type of the response based on the client request, it will, by default, return an HTML-formatted representation of the resource when that resource is requested by a web browser. This allows for the API to return a fully web-browsable HTML representation. -See the [browsable api][browseable-api] topic for more information about the browsable API feature and how to customize it. +Having a web-browsable API is a huge usability win, and makes developing and using your API much easier. It also dramatically lowers the barrier-to-entry for other developers wanting to inspect and work with your API. +See the [browsable api][browsable-api] topic for more information about the browsable API feature and how to customize it. ## What's next? @@ -145,6 +170,6 @@ In [tutorial part 3][tut-3], we'll start using class based views, and see how ge [json-url]: http://example.com/api/items/4.json [devserver]: http://127.0.0.1:8000/snippets/ -[browseable-api]: ../topics/browsable-api.md +[browsable-api]: ../topics/browsable-api.md [tut-1]: 1-serialization.md [tut-3]: 3-class-based-views.md diff --git a/docs/tutorial/3-class-based-views.md b/docs/tutorial/3-class-based-views.md index b115b022..70cf2c54 100644 --- a/docs/tutorial/3-class-based-views.md +++ b/docs/tutorial/3-class-based-views.md @@ -20,7 +20,7 @@ We'll start by rewriting the root view as a class based view. All this involves """ def get(self, request, format=None): snippets = Snippet.objects.all() - serializer = SnippetSerializer(snippets) + serializer = SnippetSerializer(snippets, many=True) return Response(serializer.data) def post(self, request, format=None): @@ -70,7 +70,7 @@ We'll also need to refactor our URLconf slightly now we're using class based vie urlpatterns = patterns('', url(r'^snippets/$', views.SnippetList.as_view()), - url(r'^snippets/(?P<pk>[0-9]+)/$', views.SnippetDetail.as_view()) + url(r'^snippets/(?P<pk>[0-9]+)/$', views.SnippetDetail.as_view()), ) urlpatterns = format_suffix_patterns(urlpatterns) @@ -92,8 +92,8 @@ Let's take a look at how we can compose our views by using the mixin classes. class SnippetList(mixins.ListModelMixin, mixins.CreateModelMixin, - generics.MultipleObjectAPIView): - model = Snippet + generics.GenericAPIView): + queryset = Snippet.objects.all() serializer_class = SnippetSerializer def get(self, request, *args, **kwargs): @@ -102,15 +102,15 @@ Let's take a look at how we can compose our views by using the mixin classes. def post(self, request, *args, **kwargs): return self.create(request, *args, **kwargs) -We'll take a moment to examine exactly what's happening here. We're building our view using `MultipleObjectAPIView`, and adding in `ListModelMixin` and `CreateModelMixin`. +We'll take a moment to examine exactly what's happening here. We're building our view using `GenericAPIView`, and adding in `ListModelMixin` and `CreateModelMixin`. The base class provides the core functionality, and the mixin classes provide the `.list()` and `.create()` actions. We're then explicitly binding the `get` and `post` methods to the appropriate actions. Simple enough stuff so far. class SnippetDetail(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, - generics.SingleObjectAPIView): - model = Snippet + generics.GenericAPIView): + queryset = Snippet.objects.all() serializer_class = SnippetSerializer def get(self, request, *args, **kwargs): @@ -122,7 +122,7 @@ The base class provides the core functionality, and the mixin classes provide th def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) -Pretty similar. This time we're using the `SingleObjectAPIView` class to provide the core functionality, and adding in mixins to provide the `.retrieve()`, `.update()` and `.destroy()` actions. +Pretty similar. Again we're using the `GenericAPIView` class to provide the core functionality, and adding in mixins to provide the `.retrieve()`, `.update()` and `.destroy()` actions. ## Using generic class based views @@ -134,12 +134,12 @@ Using the mixin classes we've rewritten the views to use slightly less code than class SnippetList(generics.ListCreateAPIView): - model = Snippet + queryset = Snippet.objects.all() serializer_class = SnippetSerializer class SnippetDetail(generics.RetrieveUpdateDestroyAPIView): - model = Snippet + queryset = Snippet.objects.all() serializer_class = SnippetSerializer Wow, that's pretty concise. We've gotten a huge amount for free, and our code looks like good, clean, idiomatic Django. diff --git a/docs/tutorial/4-authentication-and-permissions.md b/docs/tutorial/4-authentication-and-permissions.md index 9576a7f0..f6c3efb0 100644 --- a/docs/tutorial/4-authentication-and-permissions.md +++ b/docs/tutorial/4-authentication-and-permissions.md @@ -22,14 +22,14 @@ We'd also need to make sure that when the model is saved, that we populate the h We'll need some extra imports: from pygments.lexers import get_lexer_by_name - from pygments.formatters import HtmlFormatter + from pygments.formatters.html import HtmlFormatter from pygments import highlight And now we can add a `.save()` method to our model class: def save(self, *args, **kwargs): """ - Use the `pygments` library to create an highlighted HTML + Use the `pygments` library to create a highlighted HTML representation of the code snippet. """ lexer = get_lexer_by_name(self.language) @@ -54,8 +54,10 @@ You might also want to create a few different users, to use for testing the API. Now that we've got some users to work with, we'd better add representations of those users to our API. Creating a new serializer is easy: + from django.contrib.auth.models import User + class UserSerializer(serializers.ModelSerializer): - snippets = serializers.ManyPrimaryKeyRelatedField() + snippets = serializers.PrimaryKeyRelatedField(many=True) class Meta: model = User @@ -66,18 +68,18 @@ Because `'snippets'` is a *reverse* relationship on the User model, it will not We'll also add a couple of views. We'd like to just use read-only views for the user representations, so we'll use the `ListAPIView` and `RetrieveAPIView` generic class based views. class UserList(generics.ListAPIView): - model = User + queryset = User.objects.all() serializer_class = UserSerializer - class UserInstance(generics.RetrieveAPIView): - model = User + class UserDetail(generics.RetrieveAPIView): + queryset = User.objects.all() serializer_class = UserSerializer Finally we need to add those views into the API, by referencing them from the URL conf. url(r'^users/$', views.UserList.as_view()), - url(r'^users/(?P<pk>[0-9]+)/$', views.UserInstance.as_view()) + url(r'^users/(?P<pk>[0-9]+)/$', views.UserDetail.as_view()), ## Associating Snippets with Users @@ -102,8 +104,6 @@ This field is doing something quite interesting. The `source` argument controls The field we've added is the untyped `Field` class, in contrast to the other typed fields, such as `CharField`, `BooleanField` etc... The untyped `Field` is always read-only, and will be used for serialized representations, but will not be used for updating model instances when they are deserialized. -**TODO: Explain the SessionAuthentication and BasicAuthentication classes, and demonstrate using HTTP basic authentication with curl requests** - ## Adding required permissions to views Now that code snippets are associated with users, we want to make sure that only authenticated users are able to create, update and delete code snippets. @@ -118,23 +118,21 @@ Then, add the following property to **both** the `SnippetList` and `SnippetDetai permission_classes = (permissions.IsAuthenticatedOrReadOnly,) -**TODO: Now that the permissions are restricted, demonstrate using HTTP basic authentication with curl requests** +## Adding login to the Browsable API -## Adding login to the Browseable API +If you open a browser and navigate to the browsable API at the moment, you'll find that you're no longer able to create new code snippets. In order to do so we'd need to be able to login as a user. -If you open a browser and navigate to the browseable API at the moment, you'll find that you're no longer able to create new code snippets. In order to do so we'd need to be able to login as a user. - -We can add a login view for use with the browseable API, by editing our URLconf once more. +We can add a login view for use with the browsable API, by editing our URLconf once more. Add the following import at the top of the file: from django.conf.urls import include -And, at the end of the file, add a pattern to include the login and logout views for the browseable API. +And, at the end of the file, add a pattern to include the login and logout views for the browsable API. urlpatterns += patterns('', url(r'^api-auth/', include('rest_framework.urls', - namespace='rest_framework')) + namespace='rest_framework')), ) The `r'^api-auth/'` part of pattern can actually be whatever URL you want to use. The only restriction is that the included urls must use the `'rest_framework'` namespace. @@ -145,7 +143,7 @@ Once you've created a few code snippets, navigate to the '/users/' endpoint, and ## Object level permissions -Really we'd like all code snippets to be visible to anyone, but also make sure that only the user that created a code snippet is able update or delete it. +Really we'd like all code snippets to be visible to anyone, but also make sure that only the user that created a code snippet is able to update or delete it. To do that we're going to need to create a custom permission. @@ -159,12 +157,9 @@ In the snippets app, create a new file, `permissions.py` Custom permission to only allow owners of an object to edit it. """ - def has_permission(self, request, view, obj=None): - # Skip the check unless this is an object-level test - if obj is None: - return True - - # Read permissions are allowed to any request + def has_object_permission(self, request, view, obj): + # Read permissions are allowed to any request, + # so we'll always allow GET, HEAD or OPTIONS requests. if request.method in permissions.SAFE_METHODS: return True @@ -182,10 +177,31 @@ Make sure to also import the `IsOwnerOrReadOnly` class. Now, if you open a browser again, you find that the 'DELETE' and 'PUT' actions only appear on a snippet instance endpoint if you're logged in as the same user that created the code snippet. +## Authenticating with the API + +Because we now have a set of permissions on the API, we need to authenticate our requests to it if we want to edit any snippets. We haven't set up any [authentication classes][authentication], so the defaults are currently applied, which are `SessionAuthentication` and `BasicAuthentication`. + +When we interact with the API through the web browser, we can login, and the browser session will then provide the required authentication for the requests. + +If we're interacting with the API programmatically we need to explicitly provide the authentication credentials on each request. + +If we try to create a snippet without authenticating, we'll get an error: + + curl -i -X POST http://127.0.0.1:8000/snippets/ -d "code=print 123" + + {"detail": "Authentication credentials were not provided."} + +We can make a successful request by including the username and password of one of the users we created earlier. + + curl -X POST http://127.0.0.1:8000/snippets/ -d "code=print 789" -u tom:password + + {"id": 5, "owner": "tom", "title": "foo", "code": "print 789", "linenos": false, "language": "python", "style": "friendly"} + ## Summary We've now got a fairly fine-grained set of permissions on our Web API, and end points for users of the system and for the code snippets that they have created. -In [part 5][tut-5] of the tutorial we'll look at how we can tie everything together by creating an HTML endpoint for our hightlighted snippets, and improve the cohesion of our API by using hyperlinking for the relationships within the system. +In [part 5][tut-5] of the tutorial we'll look at how we can tie everything together by creating an HTML endpoint for our highlighted snippets, and improve the cohesion of our API by using hyperlinking for the relationships within the system. -[tut-5]: 5-relationships-and-hyperlinked-apis.md
\ No newline at end of file +[authentication]: ../api-guide/authentication.md +[tut-5]: 5-relationships-and-hyperlinked-apis.md diff --git a/docs/tutorial/5-relationships-and-hyperlinked-apis.md b/docs/tutorial/5-relationships-and-hyperlinked-apis.md index 216ca433..cb2e092c 100644 --- a/docs/tutorial/5-relationships-and-hyperlinked-apis.md +++ b/docs/tutorial/5-relationships-and-hyperlinked-apis.md @@ -1,4 +1,4 @@ -# Tutorial 5 - Relationships & Hyperlinked APIs +# Tutorial 5: Relationships & Hyperlinked APIs At the moment relationships within our API are represented by using primary keys. In this part of the tutorial we'll improve the cohesion and discoverability of our API, by instead using hyperlinking for relationships. @@ -15,8 +15,8 @@ Right now we have endpoints for 'snippets' and 'users', but we don't have a sing @api_view(('GET',)) def api_root(request, format=None): return Response({ - 'users': reverse('user-list', request=request), - 'snippets': reverse('snippet-list', request=request) + 'users': reverse('user-list', request=request, format=format), + 'snippets': reverse('snippet-list', request=request, format=format) }) Notice that we're using REST framework's `reverse` function in order to return fully-qualified URLs. @@ -34,8 +34,8 @@ Instead of using a concrete generic view, we'll use the base class for represent from rest_framework import renderers from rest_framework.response import Response - class SnippetHighlight(generics.SingleObjectAPIView): - model = Snippet + class SnippetHighlight(generics.GenericAPIView): + queryset = Snippet.objects.all() renderer_classes = (renderers.StaticHTMLRenderer,) def get(self, request, *args, **kwargs): @@ -70,8 +70,8 @@ The `HyperlinkedModelSerializer` has the following differences from `ModelSerial * It does not include the `pk` field by default. * It includes a `url` field, using `HyperlinkedIdentityField`. -* Relationships use `HyperlinkedRelatedField` and `ManyHyperlinkedRelatedField`, - instead of `PrimaryKeyRelatedField` and `ManyPrimaryKeyRelatedField`. +* Relationships use `HyperlinkedRelatedField`, + instead of `PrimaryKeyRelatedField`. We can easily re-write our existing serializers to use hyperlinking. @@ -86,7 +86,7 @@ We can easily re-write our existing serializers to use hyperlinking. class UserSerializer(serializers.HyperlinkedModelSerializer): - snippets = serializers.ManyHyperlinkedRelatedField(view_name='snippet-detail') + snippets = serializers.HyperlinkedRelatedField(many=True, view_name='snippet-detail') class Meta: model = User @@ -116,21 +116,21 @@ After adding all those names into our URLconf, our final `'urls.py'` file should url(r'^snippets/(?P<pk>[0-9]+)/$', views.SnippetDetail.as_view(), name='snippet-detail'), - url(r'^snippets/(?P<pk>[0-9]+)/highlight/$' + url(r'^snippets/(?P<pk>[0-9]+)/highlight/$', views.SnippetHighlight.as_view(), name='snippet-highlight'), url(r'^users/$', views.UserList.as_view(), name='user-list'), url(r'^users/(?P<pk>[0-9]+)/$', - views.UserInstance.as_view(), + views.UserDetail.as_view(), name='user-detail') )) # Login and logout views for the browsable API urlpatterns += patterns('', url(r'^api-auth/', include('rest_framework.urls', - namespace='rest_framework')) + namespace='rest_framework')), ) ## Adding pagination @@ -143,34 +143,16 @@ We can change the default list style to use pagination, by modifying our `settin 'PAGINATE_BY': 10 } -Note that settings in REST framework are all namespaced into a single dictionary setting, named 'REST_FRAMEWORK', which helps keep them well seperated from your other project settings. +Note that settings in REST framework are all namespaced into a single dictionary setting, named 'REST_FRAMEWORK', which helps keep them well separated from your other project settings. We could also customize the pagination style if we needed too, but in this case we'll just stick with the default. -## Reviewing our work +## Browsing the API -If we open a browser and navigate to the browseable API, you'll find that you can now work your way around the API simply by following links. +If we open a browser and navigate to the browsable API, you'll find that you can now work your way around the API simply by following links. You'll also be able to see the 'highlight' links on the snippet instances, that will take you to the highlighted code HTML representations. -We've now got a complete pastebin Web API, which is fully web browseable, and comes complete with authentication, per-object permissions, and multiple renderer formats. +In [part 6][tut-6] of the tutorial we'll look at how we can use ViewSets and Routers to reduce the amount of code we need to build our API. -We've walked through each step of the design process, and seen how if we need to customize anything we can gradually work our way down to simply using regular Django views. - -You can review the final [tutorial code][repo] on GitHub, or try out a live example in [the sandbox][sandbox]. - -## Onwards and upwards - -We've reached the end of our tutorial. If you want to get more involved in the REST framework project, here's a few places you can start: - -* Contribute on [GitHub][github] by reviewing and submitting issues, and making pull requests. -* Join the [REST framework discussion group][group], and help build the community. -* [Follow the author on Twitter][twitter] and say hi. - -**Now go build awesome things.** - -[repo]: https://github.com/tomchristie/rest-framework-tutorial -[sandbox]: http://restframework.herokuapp.com/ -[github]: https://github.com/tomchristie/django-rest-framework -[group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework -[twitter]: https://twitter.com/_tomchristie
\ No newline at end of file +[tut-6]: 6-viewsets-and-routers.md diff --git a/docs/tutorial/6-viewsets-and-routers.md b/docs/tutorial/6-viewsets-and-routers.md new file mode 100644 index 00000000..277804e2 --- /dev/null +++ b/docs/tutorial/6-viewsets-and-routers.md @@ -0,0 +1,151 @@ +# Tutorial 6 - ViewSets & Routers + +REST framework includes an abstraction for dealing with `ViewSets`, that allows the developer to concentrate on modeling the state and interactions of the API, and leave the URL construction to be handled automatically, based on common conventions. + +`ViewSet` classes are almost the same thing as `View` classes, except that they provide operations such as `read`, or `update`, and not method handlers such as `get` or `put`. + +A `ViewSet` class is only bound to a set of method handlers at the last moment, when it is instantiated into a set of views, typically by using a `Router` class which handles the complexities of defining the URL conf for you. + +## Refactoring to use ViewSets + +Let's take our current set of views, and refactor them into view sets. + +First of all let's refactor our `UserListView` and `UserDetailView` views into a single `UserViewSet`. We can remove the two views, and replace then with a single class: + + class UserViewSet(viewsets.ReadOnlyModelViewSet): + """ + This viewset automatically provides `list` and `detail` actions. + """ + queryset = User.objects.all() + serializer_class = UserSerializer + +Here we've used `ReadOnlyModelViewSet` class to automatically provide the default 'read-only' operations. We're still setting the `queryset` and `serializer_class` attributes exactly as we did when we were using regular views, but we no longer need to provide the same information to two separate classes. + +Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes. We can remove the three views, and again replace them with a single class. + + from rest_framework import viewsets + from rest_framework.decorators import link + + class SnippetViewSet(viewsets.ModelViewSet): + """ + This viewset automatically provides `list`, `create`, `retrieve`, + `update` and `destroy` actions. + + Additionally we also provide an extra `highlight` action. + """ + queryset = Snippet.objects.all() + serializer_class = SnippetSerializer + permission_classes = (permissions.IsAuthenticatedOrReadOnly, + IsOwnerOrReadOnly,) + + @link(renderer_classes=[renderers.StaticHTMLRenderer]) + def highlight(self, request, *args, **kwargs): + snippet = self.get_object() + return Response(snippet.highlighted) + + def pre_save(self, obj): + obj.owner = self.request.user + +This time we've used the `ModelViewSet` class in order to get the complete set of default read and write operations. + +Notice that we've also used the `@link` decorator to create a custom action, named `highlight`. This decorator can be used to add any custom endpoints that don't fit into the standard `create`/`update`/`delete` style. + +Custom actions which use the `@link` decorator will respond to `GET` requests. We could have instead used the `@action` decorator if we wanted an action that responded to `POST` requests. + +## Binding ViewSets to URLs explicitly + +The handler methods only get bound to the actions when we define the URLConf. +To see what's going on under the hood let's first explicitly create a set of views from our ViewSets. + +In the `urls.py` file we bind our `ViewSet` classes into a set of concrete views. + + from snippets.resources import SnippetResource, UserResource + + snippet_list = SnippetViewSet.as_view({ + 'get': 'list', + 'post': 'create' + }) + snippet_detail = SnippetViewSet.as_view({ + 'get': 'retrieve', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy' + }) + snippet_highlight = SnippetViewSet.as_view({ + 'get': 'highlight' + }) + user_list = UserViewSet.as_view({ + 'get': 'list' + }) + user_detail = UserViewSet.as_view({ + 'get': 'retrieve' + }) + +Notice how we're creating multiple views from each `ViewSet` class, by binding the http methods to the required action for each view. + +Now that we've bound our resources into concrete views, that we can register the views with the URL conf as usual. + + urlpatterns = format_suffix_patterns(patterns('snippets.views', + url(r'^$', 'api_root'), + url(r'^snippets/$', snippet_list, name='snippet-list'), + url(r'^snippets/(?P<pk>[0-9]+)/$', snippet_detail, name='snippet-detail'), + url(r'^snippets/(?P<pk>[0-9]+)/highlight/$', snippet_highlight, name='snippet-highlight'), + url(r'^users/$', user_list, name='user-list'), + url(r'^users/(?P<pk>[0-9]+)/$', user_detail, name='user-detail') + )) + +## Using Routers + +Because we're using `ViewSet` classes rather than `View` classes, we actually don't need to design the URL conf ourselves. The conventions for wiring up resources into views and urls can be handled automatically, using a `Router` class. All we need to do is register the appropriate view sets with a router, and let it do the rest. + +Here's our re-wired `urls.py` file. + + from snippets import views + from rest_framework.routers import DefaultRouter + + # Create a router and register our viewsets with it. + router = DefaultRouter() + router.register(r'snippets', views.SnippetViewSet) + router.register(r'users', views.UserViewSet) + + # The API URLs are now determined automatically by the router. + # Additionally, we include the login URLs for the browseable API. + urlpatterns = patterns('', + url(r'^', include(router.urls)), + url(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')) + ) + +Registering the viewsets with the router is similar to providing a urlpattern. We include two arguments - the URL prefix for the views, and the viewset itself. + +The `DefaultRouter` class we're using also automatically creates the API root view for us, so we can now delete the `api_root` method from our `views` module. + +## Trade-offs between views vs viewsets + +Using viewsets can be a really useful abstraction. It helps ensure that URL conventions will be consistent across your API, minimizes the amount of code you need to write, and allows you to concentrate on the interactions and representations your API provides rather than the specifics of the URL conf. + +That doesn't mean it's always the right approach to take. There's a similar set of trade-offs to consider as when using class-based views instead of function based views. Using viewsets is less explicit than building your views individually. + +## Reviewing our work + +With an incredibly small amount of code, we've now got a complete pastebin Web API, which is fully web browseable, and comes complete with authentication, per-object permissions, and multiple renderer formats. + +We've walked through each step of the design process, and seen how if we need to customize anything we can gradually work our way down to simply using regular Django views. + +You can review the final [tutorial code][repo] on GitHub, or try out a live example in [the sandbox][sandbox]. + +## Onwards and upwards + +We've reached the end of our tutorial. If you want to get more involved in the REST framework project, here's a few places you can start: + +* Contribute on [GitHub][github] by reviewing and submitting issues, and making pull requests. +* Join the [REST framework discussion group][group], and help build the community. +* Follow [the author][twitter] on Twitter and say hi. + +**Now go build awesome things.** + + +[repo]: https://github.com/tomchristie/rest-framework-tutorial +[sandbox]: http://restframework.herokuapp.com/ +[github]: https://github.com/tomchristie/django-rest-framework +[group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework +[twitter]: https://twitter.com/_tomchristie diff --git a/docs/tutorial/quickstart.md b/docs/tutorial/quickstart.md index 74084541..52fe3acf 100644 --- a/docs/tutorial/quickstart.md +++ b/docs/tutorial/quickstart.md @@ -8,7 +8,7 @@ Create a new Django project, and start a new app called `quickstart`. Once you' First up we're going to define some serializers in `quickstart/serializers.py` that we'll use for our data representations. - from django.contrib.auth.models import User, Group, Permission + from django.contrib.auth.models import User, Group from rest_framework import serializers @@ -19,109 +19,64 @@ First up we're going to define some serializers in `quickstart/serializers.py` t class GroupSerializer(serializers.HyperlinkedModelSerializer): - permissions = serializers.ManySlugRelatedField( - slug_field='codename', - queryset=Permission.objects.all() - ) - class Meta: model = Group - fields = ('url', 'name', 'permissions') + fields = ('url', 'name') Notice that we're using hyperlinked relations in this case, with `HyperlinkedModelSerializer`. You can also use primary key and various other relationships, but hyperlinking is good RESTful design. -We've also overridden the `permission` field on the `GroupSerializer`. In this case we don't want to use a hyperlinked representation, but instead use the list of permission codenames associated with the group, so we've used a `ManySlugRelatedField`, using the `codename` field for the representation. - ## Views Right, we'd better write some views then. Open `quickstart/views.py` and get typing. from django.contrib.auth.models import User, Group - from rest_framework import generics - from rest_framework.decorators import api_view - from rest_framework.reverse import reverse - from rest_framework.response import Response + from rest_framework import viewsets from quickstart.serializers import UserSerializer, GroupSerializer - @api_view(['GET']) - def api_root(request, format=None): - """ - The entry endpoint of our API. - """ - return Response({ - 'users': reverse('user-list', request=request), - 'groups': reverse('group-list', request=request), - }) - - - class UserList(generics.ListCreateAPIView): - """ - API endpoint that represents a list of users. - """ - model = User - serializer_class = UserSerializer - - - class UserDetail(generics.RetrieveUpdateDestroyAPIView): + class UserViewSet(viewsets.ModelViewSet): """ - API endpoint that represents a single user. + API endpoint that allows users to be viewed or edited. """ - model = User + queryset = User.objects.all() serializer_class = UserSerializer - class GroupList(generics.ListCreateAPIView): + class GroupViewSet(viewsets.ModelViewSet): """ - API endpoint that represents a list of groups. + API endpoint that allows groups to be viewed or edited. """ - model = Group - serializer_class = GroupSerializer - - - class GroupDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint that represents a single group. - """ - model = Group + queryset = Group.objects.all() serializer_class = GroupSerializer -Let's take a moment to look at what we've done here before we move on. We have one function-based view representing the root of the API, and four class-based views which map to our database models, and specify which serializers should be used for representing that data. Pretty simple stuff. +Rather that write multiple views we're grouping together all the common behavior into classes called `ViewSets`. + +We can easily break these down into individual views if we need to, but using viewsets keeps the view logic nicely organized as well as being very concise. ## URLs -Okay, let's wire this baby up. On to `quickstart/urls.py`... +Okay, now let's wire up the API URLs. On to `quickstart/urls.py`... from django.conf.urls import patterns, url, include - from rest_framework.urlpatterns import format_suffix_patterns - from quickstart.views import UserList, UserDetail, GroupList, GroupDetail - - - urlpatterns = patterns('quickstart.views', - url(r'^$', 'api_root'), - url(r'^users/$', UserList.as_view(), name='user-list'), - url(r'^users/(?P<pk>\d+)/$', UserDetail.as_view(), name='user-detail'), - url(r'^groups/$', GroupList.as_view(), name='group-list'), - url(r'^groups/(?P<pk>\d+)/$', GroupDetail.as_view(), name='group-detail'), - ) - - - # Format suffixes - urlpatterns = format_suffix_patterns(urlpatterns, allowed=['json', 'api']) + from rest_framework import routers + from quickstart import views + router = routers.DefaultRouter() + router.register(r'users', views.UserViewSet) + router.register(r'groups', views.GroupViewSet) - # Default login/logout views - urlpatterns += patterns('', + # Wire up our API using automatic URL routing. + # Additionally, we include login URLs for the browseable API. + urlpatterns = patterns('', + url(r'^', include(router.urls)), url(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')) ) -There's a few things worth noting here. - -Firstly the names `user-detail` and `group-detail` are important. We're using the default hyperlinked relationships without explicitly specifying the view names, so we need to use names of the style `{modelname}-detail` to represent the model instance views. +Because we're using viewsets instead of views, we can automatically generate the URL conf for our API, by simply registering the viewsets with a router class. -Secondly, we're modifying the urlpatterns using `format_suffix_patterns`, to append optional `.json` style suffixes to our URLs. +Again, if we need more control over the API URLs we can simply drop down to using regular class based views, and writing the URL conf explicitly. -Finally, we're including default login and logout views for use with the browsable API. That's optional, but useful if your API requires authentication and you want to use the browseable API. +Finally, we're including default login and logout views for use with the browsable API. That's optional, but useful if your API requires authentication and you want to use the browsable API. ## Settings @@ -37,6 +37,64 @@ page = open(os.path.join(docs_dir, 'template.html'), 'r').read() # shutil.rmtree(target) # shutil.copytree(source, target) + +# Hacky, but what the hell, it'll do the job +path_list = [ + 'index.md', + 'tutorial/quickstart.md', + 'tutorial/1-serialization.md', + 'tutorial/2-requests-and-responses.md', + 'tutorial/3-class-based-views.md', + 'tutorial/4-authentication-and-permissions.md', + 'tutorial/5-relationships-and-hyperlinked-apis.md', + 'tutorial/6-viewsets-and-routers.md', + 'api-guide/requests.md', + 'api-guide/responses.md', + 'api-guide/views.md', + 'api-guide/generic-views.md', + 'api-guide/viewsets.md', + 'api-guide/routers.md', + 'api-guide/parsers.md', + 'api-guide/renderers.md', + 'api-guide/serializers.md', + 'api-guide/fields.md', + 'api-guide/relations.md', + 'api-guide/authentication.md', + 'api-guide/permissions.md', + 'api-guide/throttling.md', + 'api-guide/filtering.md', + 'api-guide/pagination.md', + 'api-guide/content-negotiation.md', + 'api-guide/format-suffixes.md', + 'api-guide/reverse.md', + 'api-guide/exceptions.md', + 'api-guide/status-codes.md', + 'api-guide/settings.md', + 'topics/ajax-csrf-cors.md', + 'topics/browser-enhancements.md', + 'topics/browsable-api.md', + 'topics/rest-hypermedia-hateoas.md', + 'topics/contributing.md', + 'topics/rest-framework-2-announcement.md', + 'topics/2.2-announcement.md', + 'topics/2.3-announcement.md', + 'topics/release-notes.md', + 'topics/credits.md', +] + +prev_url_map = {} +next_url_map = {} +for idx in range(len(path_list)): + path = path_list[idx] + rel = '../' * path.count('/') + + if idx > 0: + prev_url_map[path] = rel + path_list[idx - 1][:-3] + suffix + + if idx < len(path_list) - 1: + next_url_map[path] = rel + path_list[idx + 1][:-3] + suffix + + for (dirpath, dirnames, filenames) in os.walk(docs_dir): relative_dir = dirpath.replace(docs_dir, '').lstrip(os.path.sep) build_dir = os.path.join(html_dir, relative_dir) @@ -46,6 +104,7 @@ for (dirpath, dirnames, filenames) in os.walk(docs_dir): for filename in filenames: path = os.path.join(dirpath, filename) + relative_path = os.path.join(relative_dir, filename) if not filename.endswith('.md'): if relative_dir: @@ -57,25 +116,55 @@ for (dirpath, dirnames, filenames) in os.walk(docs_dir): toc = '' text = open(path, 'r').read().decode('utf-8') + main_title = None + description = 'Django, API, REST' for line in text.splitlines(): if line.startswith('# '): title = line[2:].strip() template = main_header + description = description + ', ' + title elif line.startswith('## '): title = line[3:].strip() template = sub_header else: continue + if not main_title: + main_title = title anchor = title.lower().replace(' ', '-').replace(':-', '-').replace("'", '').replace('?', '').replace('.', '') template = template.replace('{{ title }}', title) template = template.replace('{{ anchor }}', anchor) toc += template + '\n' + if filename == 'index.md': + main_title = 'Django REST framework - APIs made easy' + else: + main_title = 'Django REST framework - ' + main_title + + prev_url = prev_url_map.get(relative_path) + next_url = next_url_map.get(relative_path) + content = markdown.markdown(text, ['headerid']) output = page.replace('{{ content }}', content).replace('{{ toc }}', toc).replace('{{ base_url }}', base_url).replace('{{ suffix }}', suffix).replace('{{ index }}', index) + output = output.replace('{{ title }}', main_title) + output = output.replace('{{ description }}', description) output = output.replace('{{ page_id }}', filename[:-3]) + + if prev_url: + output = output.replace('{{ prev_url }}', prev_url) + output = output.replace('{{ prev_url_disabled }}', '') + else: + output = output.replace('{{ prev_url }}', '#') + output = output.replace('{{ prev_url_disabled }}', 'disabled') + + if next_url: + output = output.replace('{{ next_url }}', next_url) + output = output.replace('{{ next_url_disabled }}', '') + else: + output = output.replace('{{ next_url }}', '#') + output = output.replace('{{ next_url_disabled }}', 'disabled') + output = re.sub(r'a href="([^"]*)\.md"', r'a href="\1%s"' % suffix, output) output = re.sub(r'<pre><code>:::bash', r'<pre class="prettyprint lang-bsh">', output) output = re.sub(r'<pre>', r'<pre class="prettyprint lang-py">', output) diff --git a/optionals.txt b/optionals.txt index 1d2358c6..1853f74b 100644 --- a/optionals.txt +++ b/optionals.txt @@ -1,3 +1,7 @@ markdown>=2.1.0 PyYAML>=3.10 +defusedxml>=0.3 django-filter>=0.5.4 +django-oauth-plus>=2.0 +oauth2>=1.5.211 +django-oauth2-provider>=0.2.3 diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 02bc6fc1..0b1e67fb 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,3 +1,9 @@ -__version__ = '2.1.12' +__version__ = '2.3.3' VERSION = __version__ # synonym + +# Header encoding (see RFC5987) +HTTP_HEADER_ENCODING = 'iso-8859-1' + +# Default datetime input and output formats +ISO_8601 = 'iso-8601' diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 30c78ebc..9caca788 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -1,13 +1,30 @@ """ -Provides a set of pluggable authentication policies. +Provides various authentication policies. """ +from __future__ import unicode_literals +import base64 +from datetime import datetime from django.contrib.auth import authenticate -from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError -from rest_framework import exceptions +from django.core.exceptions import ImproperlyConfigured +from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware +from rest_framework.compat import oauth, oauth_provider, oauth_provider_store +from rest_framework.compat import oauth2_provider from rest_framework.authtoken.models import Token -import base64 + + +def get_authorization_header(request): + """ + Return request's 'Authorization:' header, as a bytestring. + + Hide some test client ickyness where the header can be unicode. + """ + auth = request.META.get('HTTP_AUTHORIZATION', b'') + if type(auth) == type(''): + # Work around django test client oddness + auth = auth.encode(HTTP_HEADER_ENCODING) + return auth class BaseAuthentication(object): @@ -21,40 +38,58 @@ class BaseAuthentication(object): """ raise NotImplementedError(".authenticate() must be overridden.") + def authenticate_header(self, request): + """ + Return a string to be used as the value of the `WWW-Authenticate` + header in a `401 Unauthenticated` response, or `None` if the + authentication scheme should return `403 Permission Denied` responses. + """ + pass + class BasicAuthentication(BaseAuthentication): """ HTTP Basic authentication against username/password. """ + www_authenticate_realm = 'api' def authenticate(self, request): """ Returns a `User` if a correct username and password have been supplied using HTTP Basic authentication. Otherwise returns `None`. """ - if 'HTTP_AUTHORIZATION' in request.META: - auth = request.META['HTTP_AUTHORIZATION'].split() - if len(auth) == 2 and auth[0].lower() == "basic": - try: - auth_parts = base64.b64decode(auth[1]).partition(':') - except TypeError: - return None + auth = get_authorization_header(request).split() - try: - userid = smart_unicode(auth_parts[0]) - password = smart_unicode(auth_parts[2]) - except DjangoUnicodeDecodeError: - return None + if not auth or auth[0].lower() != b'basic': + return None - return self.authenticate_credentials(userid, password) + if len(auth) == 1: + msg = 'Invalid basic header. No credentials provided.' + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = 'Invalid basic header. Credentials string should not contain spaces.' + raise exceptions.AuthenticationFailed(msg) + + try: + auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':') + except (TypeError, UnicodeDecodeError): + msg = 'Invalid basic header. Credentials not correctly base64 encoded' + raise exceptions.AuthenticationFailed(msg) + + userid, password = auth_parts[0], auth_parts[2] + return self.authenticate_credentials(userid, password) def authenticate_credentials(self, userid, password): """ Authenticate the userid and password against username and password. """ user = authenticate(username=userid, password=password) - if user is not None and user.is_active: - return (user, None) + if user is None or not user.is_active: + raise exceptions.AuthenticationFailed('Invalid username/password') + return (user, None) + + def authenticate_header(self, request): + return 'Basic realm="%s"' % self.www_authenticate_realm class SessionAuthentication(BaseAuthentication): @@ -74,7 +109,7 @@ class SessionAuthentication(BaseAuthentication): # Unauthenticated, CSRF validation not required if not user or not user.is_active: - return + return None # Enforce CSRF validation for session based authentication. class CSRFCheck(CsrfViewMiddleware): @@ -85,7 +120,7 @@ class SessionAuthentication(BaseAuthentication): reason = CSRFCheck().process_view(http_request, None, (), {}) if reason: # CSRF failed, bail with explicit error message - raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) # CSRF passed with authenticated user return (user, None) @@ -110,16 +145,198 @@ class TokenAuthentication(BaseAuthentication): """ def authenticate(self, request): - auth = request.META.get('HTTP_AUTHORIZATION', '').split() + auth = get_authorization_header(request).split() + + if not auth or auth[0].lower() != b'token': + return None + + if len(auth) == 1: + msg = 'Invalid token header. No credentials provided.' + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = 'Invalid token header. Token string should not contain spaces.' + raise exceptions.AuthenticationFailed(msg) + + return self.authenticate_credentials(auth[1]) + + def authenticate_credentials(self, key): + try: + token = self.model.objects.get(key=key) + except self.model.DoesNotExist: + raise exceptions.AuthenticationFailed('Invalid token') + + if not token.user.is_active: + raise exceptions.AuthenticationFailed('User inactive or deleted') + + return (token.user, token) + + def authenticate_header(self, request): + return 'Token' + + +class OAuthAuthentication(BaseAuthentication): + """ + OAuth 1.0a authentication backend using `django-oauth-plus` and `oauth2`. + + Note: The `oauth2` package actually provides oauth1.0a support. Urg. + We import it from the `compat` module as `oauth`. + """ + www_authenticate_realm = 'api' + + def __init__(self, *args, **kwargs): + super(OAuthAuthentication, self).__init__(*args, **kwargs) + + if oauth is None: + raise ImproperlyConfigured( + "The 'oauth2' package could not be imported." + "It is required for use with the 'OAuthAuthentication' class.") + + if oauth_provider is None: + raise ImproperlyConfigured( + "The 'django-oauth-plus' package could not be imported." + "It is required for use with the 'OAuthAuthentication' class.") + + def authenticate(self, request): + """ + Returns two-tuple of (user, token) if authentication succeeds, + or None otherwise. + """ + try: + oauth_request = oauth_provider.utils.get_oauth_request(request) + except oauth.Error as err: + raise exceptions.AuthenticationFailed(err.message) + + if not oauth_request: + return None + + oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES + + found = any(param for param in oauth_params if param in oauth_request) + missing = list(param for param in oauth_params if param not in oauth_request) + + if not found: + # OAuth authentication was not attempted. + return None + + if missing: + # OAuth was attempted but missing parameters. + msg = 'Missing parameters: %s' % (', '.join(missing)) + raise exceptions.AuthenticationFailed(msg) + + if not self.check_nonce(request, oauth_request): + msg = 'Nonce check failed' + raise exceptions.AuthenticationFailed(msg) - if len(auth) == 2 and auth[0].lower() == "token": - key = auth[1] - try: - token = self.model.objects.get(key=key) - except self.model.DoesNotExist: - return None + try: + consumer_key = oauth_request.get_parameter('oauth_consumer_key') + consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) + except oauth_provider.store.InvalidConsumerError as err: + raise exceptions.AuthenticationFailed(err) - if token.user.is_active: - return (token.user, token) + if consumer.status != oauth_provider.consts.ACCEPTED: + msg = 'Invalid consumer key status: %s' % consumer.get_status_display() + raise exceptions.AuthenticationFailed(msg) -# TODO: OAuthAuthentication + try: + token_param = oauth_request.get_parameter('oauth_token') + token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param) + except oauth_provider.store.InvalidTokenError: + msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token') + raise exceptions.AuthenticationFailed(msg) + + try: + self.validate_token(request, consumer, token) + except oauth.Error as err: + raise exceptions.AuthenticationFailed(err.message) + + user = token.user + + if not user.is_active: + msg = 'User inactive or deleted: %s' % user.username + raise exceptions.AuthenticationFailed(msg) + + return (token.user, token) + + def authenticate_header(self, request): + """ + If permission is denied, return a '401 Unauthorized' response, + with an appropraite 'WWW-Authenticate' header. + """ + return 'OAuth realm="%s"' % self.www_authenticate_realm + + def validate_token(self, request, consumer, token): + """ + Check the token and raise an `oauth.Error` exception if invalid. + """ + oauth_server, oauth_request = oauth_provider.utils.initialize_server_request(request) + oauth_server.verify_request(oauth_request, consumer, token) + + def check_nonce(self, request, oauth_request): + """ + Checks nonce of request, and return True if valid. + """ + return oauth_provider_store.check_nonce(request, oauth_request, oauth_request['oauth_nonce']) + + +class OAuth2Authentication(BaseAuthentication): + """ + OAuth 2 authentication backend using `django-oauth2-provider` + """ + www_authenticate_realm = 'api' + + def __init__(self, *args, **kwargs): + super(OAuth2Authentication, self).__init__(*args, **kwargs) + + if oauth2_provider is None: + raise ImproperlyConfigured( + "The 'django-oauth2-provider' package could not be imported. " + "It is required for use with the 'OAuth2Authentication' class.") + + def authenticate(self, request): + """ + Returns two-tuple of (user, token) if authentication succeeds, + or None otherwise. + """ + + auth = get_authorization_header(request).split() + + if not auth or auth[0].lower() != b'bearer': + return None + + if len(auth) == 1: + msg = 'Invalid bearer header. No credentials provided.' + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = 'Invalid bearer header. Token string should not contain spaces.' + raise exceptions.AuthenticationFailed(msg) + + return self.authenticate_credentials(request, auth[1]) + + def authenticate_credentials(self, request, access_token): + """ + Authenticate the request, given the access token. + """ + + try: + token = oauth2_provider.models.AccessToken.objects.select_related('user') + # TODO: Change to timezone aware datetime when oauth2_provider add + # support to it. + token = token.get(token=access_token, expires__gt=datetime.now()) + except oauth2_provider.models.AccessToken.DoesNotExist: + raise exceptions.AuthenticationFailed('Invalid token') + + user = token.user + + if not user.is_active: + msg = 'User inactive or deleted: %s' % user.username + raise exceptions.AuthenticationFailed(msg) + + return (user, token) + + def authenticate_header(self, request): + """ + Bearer is the only finalized type currently + + Check details on the `OAuth2Authentication.authenticate` method + """ + return 'Bearer realm="%s"' % self.www_authenticate_realm diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py index f4e052e4..d5965e40 100644 --- a/rest_framework/authtoken/migrations/0001_initial.py +++ b/rest_framework/authtoken/migrations/0001_initial.py @@ -4,6 +4,8 @@ from south.db import db from south.v2 import SchemaMigration from django.db import models +from rest_framework.settings import api_settings + try: from django.contrib.auth import get_user_model @@ -45,20 +47,7 @@ class Migration(SchemaMigration): 'name': ('django.db.models.fields.CharField', [], {'max_length': '50'}) }, "%s.%s" % (User._meta.app_label, User._meta.module_name): { - 'Meta': {'object_name': 'User'}, - 'date_joined': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), - 'email': ('django.db.models.fields.EmailField', [], {'max_length': '75', 'blank': 'True'}), - 'first_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}), - 'groups': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Group']", 'symmetrical': 'False', 'blank': 'True'}), - 'id': ('django.db.models.fields.AutoField', [], {'primary_key': 'True'}), - 'is_active': ('django.db.models.fields.BooleanField', [], {'default': 'True'}), - 'is_staff': ('django.db.models.fields.BooleanField', [], {'default': 'False'}), - 'is_superuser': ('django.db.models.fields.BooleanField', [], {'default': 'False'}), - 'last_login': ('django.db.models.fields.DateTimeField', [], {'default': 'datetime.datetime.now'}), - 'last_name': ('django.db.models.fields.CharField', [], {'max_length': '30', 'blank': 'True'}), - 'password': ('django.db.models.fields.CharField', [], {'max_length': '128'}), - 'user_permissions': ('django.db.models.fields.related.ManyToManyField', [], {'to': "orm['auth.Permission']", 'symmetrical': 'False', 'blank': 'True'}), - 'username': ('django.db.models.fields.CharField', [], {'unique': 'True', 'max_length': '30'}) + 'Meta': {'object_name': User._meta.module_name}, }, 'authtoken.token': { 'Meta': {'object_name': 'Token'}, diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 4da2aa62..52c45ad1 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -2,6 +2,7 @@ import uuid import hmac from hashlib import sha1 from rest_framework.compat import User +from django.conf import settings from django.db import models @@ -13,14 +14,22 @@ class Token(models.Model): user = models.OneToOneField(User, related_name='auth_token') created = models.DateTimeField(auto_now_add=True) + class Meta: + # Work around for a bug in Django: + # https://code.djangoproject.com/ticket/19422 + # + # Also see corresponding ticket: + # https://github.com/tomchristie/django-rest-framework/issues/705 + abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS + def save(self, *args, **kwargs): if not self.key: self.key = self.generate_key() return super(Token, self).save(*args, **kwargs) def generate_key(self): - unique = str(uuid.uuid4()) - return hmac.new(unique, digestmod=sha1).hexdigest() + unique = uuid.uuid4() + return hmac.new(unique.bytes, digestmod=sha1).hexdigest() def __unicode__(self): return self.key diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index d318c723..7c03cb76 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -12,10 +12,11 @@ class ObtainAuthToken(APIView): permission_classes = () parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) renderer_classes = (renderers.JSONRenderer,) + serializer_class = AuthTokenSerializer model = Token def post(self, request): - serializer = AuthTokenSerializer(data=request.DATA) + serializer = self.serializer_class(data=request.DATA) if serializer.is_valid(): token, created = Token.objects.get_or_create(user=serializer.object['user']) return Response({'token': token.key}) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 86952fb8..cd39f544 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -3,26 +3,57 @@ The `compat` module provides support for backwards compatibility with older versions of django/python, and compatibility wrappers around optional packages. """ # flake8: noqa +from __future__ import unicode_literals + import django +from django.core.exceptions import ImproperlyConfigured + +# Try to import six from Django, fallback to included `six`. +try: + from django.utils import six +except ImportError: + from rest_framework import six # location of patterns, url, include changes in 1.4 onwards try: from django.conf.urls import patterns, url, include -except: +except ImportError: from django.conf.urls.defaults import patterns, url, include +# Handle django.utils.encoding rename: +# smart_unicode -> smart_text +# force_unicode -> force_text +try: + from django.utils.encoding import smart_text +except ImportError: + from django.utils.encoding import smart_unicode as smart_text +try: + from django.utils.encoding import force_text +except ImportError: + from django.utils.encoding import force_unicode as force_text + + # django-filter is optional try: import django_filters -except: +except ImportError: django_filters = None # cStringIO only if it's available, otherwise StringIO try: - import cStringIO as StringIO + import cStringIO.StringIO as StringIO except ImportError: - import StringIO + StringIO = six.StringIO + +BytesIO = six.BytesIO + + +# urlparse compat import (Required because it changed in python 3.x) +try: + from urllib import parse as urlparse +except ImportError: + import urlparse # Try to import PIL in either of the two ways it can end up installed. @@ -54,12 +85,10 @@ else: try: from django.contrib.auth.models import User except ImportError: - raise ImportError(u"User model is not to be found.") + raise ImportError("User model is not to be found.") -# First implementation of Django class-based views did not include head method -# in base View class - https://code.djangoproject.com/ticket/15668 -if django.VERSION >= (1, 4): +if django.VERSION >= (1, 5): from django.views.generic import View else: from django.views.generic import View as _View @@ -67,6 +96,8 @@ else: from django.utils.functional import update_wrapper class View(_View): + # 1.3 does not include head method in base View class + # See: https://code.djangoproject.com/ticket/15668 @classonlymethod def as_view(cls, **initkwargs): """ @@ -75,11 +106,11 @@ else: # sanitize keyword arguments for key in initkwargs: if key in cls.http_method_names: - raise TypeError(u"You tried to pass in the %s method name as a " - u"keyword argument to %s(). Don't do that." + raise TypeError("You tried to pass in the %s method name as a " + "keyword argument to %s(). Don't do that." % (key, cls.__name__)) if not hasattr(cls, key): - raise TypeError(u"%s() received an invalid keyword %r" % ( + raise TypeError("%s() received an invalid keyword %r" % ( cls.__name__, key)) def view(request, *args, **kwargs): @@ -96,6 +127,16 @@ else: update_wrapper(view, cls.dispatch, assigned=()) return view + # _allowed_methods only present from 1.5 onwards + def _allowed_methods(self): + return [m.upper() for m in self.http_method_names if hasattr(self, m)] + + +# PATCH method is not implemented by Django +if 'patch' not in View.http_method_names: + View.http_method_names = View.http_method_names + ['patch'] + + # PUT, DELETE do not require CSRF until 1.4. They should. Make it better. if django.VERSION >= (1, 4): from django.middleware.csrf import CsrfViewMiddleware @@ -104,7 +145,6 @@ else: import re import random import logging - import urlparse from django.conf import settings from django.core.urlresolvers import get_callable @@ -146,7 +186,8 @@ else: randrange = random.SystemRandom().randrange else: randrange = random.randrange - _MAX_CSRF_KEY = 18446744073709551616L # 2 << 63 + + _MAX_CSRF_KEY = 18446744073709551616 # 2 << 63 REASON_NO_REFERER = "Referer checking failed - no Referer." REASON_BAD_REFERER = "Referer checking failed - %s does not match %s." @@ -313,7 +354,7 @@ except ImportError: # dateparse is ALSO new in Django 1.4 try: - from django.utils.dateparse import parse_date, parse_datetime + from django.utils.dateparse import parse_date, parse_datetime, parse_time except ImportError: import datetime import re @@ -359,6 +400,41 @@ except ImportError: kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None) return datetime.datetime(**kw) + +# smart_urlquote is new on Django 1.4 +try: + from django.utils.html import smart_urlquote +except ImportError: + import re + from django.utils.encoding import smart_str + try: + from urllib.parse import quote, urlsplit, urlunsplit + except ImportError: # Python 2 + from urllib import quote + from urlparse import urlsplit, urlunsplit + + unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})') + + def smart_urlquote(url): + "Quotes a URL if it isn't already quoted." + # Handle IDN before quoting. + scheme, netloc, path, query, fragment = urlsplit(url) + try: + netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE + except UnicodeError: # invalid domain part + pass + else: + url = urlunsplit((scheme, netloc, path, query, fragment)) + + # An URL is considered unquoted if it contains no % characters or + # contains a % not followed by two hexadecimal digits. See #9655. + if '%' not in url or unquoted_percents_re.search(url): + # See http://bugs.python.org/issue2637 + url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~') + + return force_text(url) + + # Markdown is optional try: import markdown @@ -385,8 +461,37 @@ except ImportError: yaml = None -# xml.etree.parse only throws ParseError for python >= 2.7 +# XML is optional +try: + import defusedxml.ElementTree as etree +except ImportError: + etree = None + +# OAuth is optional try: - from xml.etree import ParseError as ETParseError -except ImportError: # python < 2.7 - ETParseError = None + # Note: The `oauth2` package actually provides oauth1.0a support. Urg. + import oauth2 as oauth +except ImportError: + oauth = None + +# OAuth is optional +try: + import oauth_provider + from oauth_provider.store import store as oauth_provider_store +except (ImportError, ImproperlyConfigured): + oauth_provider = None + oauth_provider_store = None + +# OAuth 2 support is optional +try: + import provider.oauth2 as oauth2_provider + from provider.oauth2 import models as oauth2_provider_models + from provider.oauth2 import forms as oauth2_provider_forms + from provider import scope as oauth2_provider_scope + from provider import constants as oauth2_constants +except ImportError: + oauth2_provider = None + oauth2_provider_models = None + oauth2_provider_forms = None + oauth2_provider_scope = None + oauth2_constants = None diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 1b710a03..81e585e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,4 +1,15 @@ +""" +The most imporant decorator in this module is `@api_view`, which is used +for writing function-based views with REST framework. + +There are also various decorators for setting the API policies on function +based views, as well as the `@action` and `@link` decorators, which are +used to annotate methods on viewsets that should be included by routers. +""" +from __future__ import unicode_literals +from rest_framework.compat import six from rest_framework.views import APIView +import types def api_view(http_method_names): @@ -11,7 +22,7 @@ def api_view(http_method_names): def decorator(func): WrappedAPIView = type( - 'WrappedAPIView', + six.PY3 and 'WrappedAPIView' or b'WrappedAPIView', (APIView,), {'__doc__': func.__doc__} ) @@ -23,6 +34,14 @@ def api_view(http_method_names): # pass # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this + # api_view applied without (method_names) + assert not(isinstance(http_method_names, types.FunctionType)), \ + '@api_view missing list of allowed HTTP methods' + + # api_view applied with eg. string instead of list of strings + assert isinstance(http_method_names, (list, tuple)), \ + '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ + allowed_methods = set(http_method_names) | set(('options',)) WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] @@ -86,3 +105,25 @@ def permission_classes(permission_classes): func.permission_classes = permission_classes return func return decorator + + +def link(**kwargs): + """ + Used to mark a method on a ViewSet that should be routed for GET requests. + """ + def decorator(func): + func.bind_to_method = 'get' + func.kwargs = kwargs + return func + return decorator + + +def action(**kwargs): + """ + Used to mark a method on a ViewSet that should be routed for POST requests. + """ + def decorator(func): + func.bind_to_method = 'post' + func.kwargs = kwargs + return func + return decorator diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 89479deb..0c96ecdd 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -4,6 +4,7 @@ Handled exceptions raised by REST framework. In addition Django's built in 403 and 404 exceptions are handled. (`django.http.Http404` and `django.core.exceptions.PermissionDenied`) """ +from __future__ import unicode_literals from rest_framework import status @@ -23,6 +24,22 @@ class ParseError(APIException): self.detail = detail or self.default_detail +class AuthenticationFailed(APIException): + status_code = status.HTTP_401_UNAUTHORIZED + default_detail = 'Incorrect authentication credentials.' + + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + +class NotAuthenticated(APIException): + status_code = status.HTTP_401_UNAUTHORIZED + default_detail = 'Authentication credentials were not provided.' + + def __init__(self, detail=None): + self.detail = detail or self.default_detail + + class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN default_detail = 'You do not have permission to perform this action.' diff --git a/rest_framework/fields.py b/rest_framework/fields.py index ca421ace..3c4e975a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,34 +1,102 @@ +""" +Serializer fields perform validation on incoming data. + +They are very similar to Django's form fields. +""" +from __future__ import unicode_literals + import copy import datetime +from decimal import Decimal, DecimalException import inspect import re import warnings -from io import BytesIO - from django.core import validators -from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.core.urlresolvers import resolve, get_script_prefix +from django.core.exceptions import ValidationError from django.conf import settings from django import forms from django.forms import widgets -from django.forms.models import ModelChoiceIterator -from django.utils.encoding import is_protected_type, smart_unicode +from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ -from rest_framework.reverse import reverse -from rest_framework.compat import parse_date, parse_datetime -from rest_framework.compat import timezone -from urlparse import urlparse + +from rest_framework import ISO_8601 +from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time +from rest_framework.compat import BytesIO +from rest_framework.compat import six +from rest_framework.compat import smart_text +from rest_framework.settings import api_settings def is_simple_callable(obj): """ True if the object is a callable that takes no arguments. """ - return ( - (inspect.isfunction(obj) and not inspect.getargspec(obj)[0]) or - (inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1) - ) + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): + return False + + args, _, _, defaults = inspect.getargspec(obj) + len_args = len(args) if function else len(args) - 1 + len_defaults = len(defaults) if defaults else 0 + return len_args <= len_defaults + + +def get_component(obj, attr_name): + """ + Given an object, and an attribute name, + return that attribute on the object. + """ + if isinstance(obj, dict): + val = obj[attr_name] + else: + val = getattr(obj, attr_name) + + if is_simple_callable(val): + return val() + return val + + +def readable_datetime_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]') + return humanize_strptime(format) + + +def readable_date_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') + return humanize_strptime(format) + + +def readable_time_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') + return humanize_strptime(format) + + +def humanize_strptime(format_string): + # Note that we're missing some of the locale specific mappings that + # don't really make sense. + mapping = { + "%Y": "YYYY", + "%y": "YY", + "%m": "MM", + "%b": "[Jan-Dec]", + "%B": "[January-December]", + "%d": "DD", + "%H": "hh", + "%I": "hh", # Requires '%p' to differentiate from '%H'. + "%M": "mm", + "%S": "ss", + "%f": "uuuuuu", + "%a": "[Mon-Sun]", + "%A": "[Monday-Sunday]", + "%p": "[AM|PM]", + "%z": "[+HHMM|-HHMM]" + } + for key, val in mapping.items(): + format_string = format_string.replace(key, val) + return format_string class Field(object): @@ -36,7 +104,8 @@ class Field(object): creation_counter = 0 empty = '' type_name = None - _use_files = None + partial = False + use_files = False form_field_class = forms.CharField def __init__(self, source=None, label=None, help_text=None): @@ -63,7 +132,8 @@ class Field(object): self.parent = parent self.root = parent.root or parent self.context = self.root.context - if self.root.partial: + self.partial = self.root.partial + if self.partial: self.required = False def field_from_native(self, data, files, field_name, into): @@ -84,14 +154,14 @@ class Field(object): if self.source == '*': return self.to_native(obj) - if self.source: - value = obj - for component in self.source.split('.'): - value = getattr(value, component) - if is_simple_callable(value): - value = value() - else: - value = getattr(obj, field_name) + source = self.source or field_name + value = obj + + for component in source.split('.'): + value = get_component(value, component) + if value is None: + break + return self.to_native(value) def to_native(self, value): @@ -103,11 +173,11 @@ class Field(object): if is_protected_type(value): return value - elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): + elif hasattr(value, '__iter__') and not isinstance(value, (dict, six.string_types)): return [self.to_native(item) for item in value] elif isinstance(value, dict): return dict(map(self.to_native, (k, v)) for k, v in value.items()) - return smart_unicode(value) + return smart_text(value) def attributes(self): """ @@ -134,6 +204,14 @@ class WritableField(Field): read_only=False, required=None, validators=[], error_messages=None, widget=None, default=None, blank=None): + + # 'blank' is to be deprecated in favor of 'required' + if blank is not None: + warnings.warn('The `blank` keyword argument is deprecated. ' + 'Use the `required` keyword argument instead.', + DeprecationWarning, stacklevel=2) + required = not(blank) + super(WritableField, self).__init__(source=source, label=label, help_text=help_text) self.read_only = read_only @@ -151,7 +229,6 @@ class WritableField(Field): self.validators = self.default_validators + validators self.default = default if default is not None else self.default - self.blank = blank # Widgets are ony used for HTML forms. widget = widget or self.widget @@ -190,12 +267,14 @@ class WritableField(Field): return try: - if self._use_files: + if self.use_files: + files = files or {} native = files[field_name] else: native = data[field_name] except KeyError: - if self.default is not None: + if self.default is not None and not self.partial: + # Note: partial updates shouldn't set defaults native = self.default else: if self.required: @@ -225,7 +304,7 @@ class ModelField(WritableField): def __init__(self, *args, **kwargs): try: self.model_field = kwargs.pop('model_field') - except: + except KeyError: raise ValueError("ModelField requires 'model_field' kwarg") self.min_length = kwargs.pop('min_length', @@ -258,443 +337,6 @@ class ModelField(WritableField): "type": self.model_field.get_internal_type() } -##### Relational fields ##### - - -# Not actually Writable, but subclasses may need to be. -class RelatedField(WritableField): - """ - Base class for related model fields. - - If not overridden, this represents a to-one relationship, using the unicode - representation of the target. - """ - widget = widgets.Select - cache_choices = False - empty_label = None - default_read_only = True # TODO: Remove this - - def __init__(self, *args, **kwargs): - self.queryset = kwargs.pop('queryset', None) - self.null = kwargs.pop('null', False) - super(RelatedField, self).__init__(*args, **kwargs) - self.read_only = kwargs.pop('read_only', self.default_read_only) - - def initialize(self, parent, field_name): - super(RelatedField, self).initialize(parent, field_name) - if self.queryset is None and not self.read_only: - try: - manager = getattr(self.parent.opts.model, self.source or field_name) - if hasattr(manager, 'related'): # Forward - self.queryset = manager.related.model._default_manager.all() - else: # Reverse - self.queryset = manager.field.rel.to._default_manager.all() - except: - raise - msg = ('Serializer related fields must include a `queryset`' + - ' argument or set `read_only=True') - raise Exception(msg) - - ### We need this stuff to make form choices work... - - # def __deepcopy__(self, memo): - # result = super(RelatedField, self).__deepcopy__(memo) - # result.queryset = result.queryset - # return result - - def prepare_value(self, obj): - return self.to_native(obj) - - def label_from_instance(self, obj): - """ - Return a readable representation for use with eg. select widgets. - """ - desc = smart_unicode(obj) - ident = smart_unicode(self.to_native(obj)) - if desc == ident: - return desc - return "%s - %s" % (desc, ident) - - def _get_queryset(self): - return self._queryset - - def _set_queryset(self, queryset): - self._queryset = queryset - self.widget.choices = self.choices - - queryset = property(_get_queryset, _set_queryset) - - def _get_choices(self): - # If self._choices is set, then somebody must have manually set - # the property self.choices. In this case, just return self._choices. - if hasattr(self, '_choices'): - return self._choices - - # Otherwise, execute the QuerySet in self.queryset to determine the - # choices dynamically. Return a fresh ModelChoiceIterator that has not been - # consumed. Note that we're instantiating a new ModelChoiceIterator *each* - # time _get_choices() is called (and, thus, each time self.choices is - # accessed) so that we can ensure the QuerySet has not been consumed. This - # construct might look complicated but it allows for lazy evaluation of - # the queryset. - return ModelChoiceIterator(self) - - def _set_choices(self, value): - # Setting choices also sets the choices on the widget. - # choices can be any iterable, but we call list() on it because - # it will be consumed more than once. - self._choices = self.widget.choices = list(value) - - choices = property(_get_choices, _set_choices) - - ### Regular serializer stuff... - - def field_to_native(self, obj, field_name): - value = getattr(obj, self.source or field_name) - return self.to_native(value) - - def field_from_native(self, data, files, field_name, into): - if self.read_only: - return - - try: - value = data[field_name] - except KeyError: - if self.required: - raise ValidationError(self.error_messages['required']) - return - - if value in (None, '') and not self.null: - raise ValidationError('Value may not be null') - elif value in (None, '') and self.null: - into[(self.source or field_name)] = None - else: - into[(self.source or field_name)] = self.from_native(value) - - -class ManyRelatedMixin(object): - """ - Mixin to convert a related field to a many related field. - """ - widget = widgets.SelectMultiple - - def field_to_native(self, obj, field_name): - value = getattr(obj, self.source or field_name) - return [self.to_native(item) for item in value.all()] - - def field_from_native(self, data, files, field_name, into): - if self.read_only: - return - - try: - # Form data - value = data.getlist(self.source or field_name) - except: - # Non-form data - value = data.get(self.source or field_name) - else: - if value == ['']: - value = [] - - into[field_name] = [self.from_native(item) for item in value] - - -class ManyRelatedField(ManyRelatedMixin, RelatedField): - """ - Base class for related model managers. - - If not overridden, this represents a to-many relationship, using the unicode - representations of the target, and is read-only. - """ - pass - - -### PrimaryKey relationships - -class PrimaryKeyRelatedField(RelatedField): - """ - Represents a to-one relationship as a pk value. - """ - default_read_only = False - form_field_class = forms.ChoiceField - - # TODO: Remove these field hacks... - def prepare_value(self, obj): - return self.to_native(obj.pk) - - def label_from_instance(self, obj): - """ - Return a readable representation for use with eg. select widgets. - """ - desc = smart_unicode(obj) - ident = smart_unicode(self.to_native(obj.pk)) - if desc == ident: - return desc - return "%s - %s" % (desc, ident) - - # TODO: Possibly change this to just take `obj`, through prob less performant - def to_native(self, pk): - return pk - - def from_native(self, data): - if self.queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - try: - return self.queryset.get(pk=data) - except ObjectDoesNotExist: - msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) - raise ValidationError(msg) - - def field_to_native(self, obj, field_name): - try: - # Prefer obj.serializable_value for performance reasons - pk = obj.serializable_value(self.source or field_name) - except AttributeError: - # RelatedObject (reverse relationship) - obj = getattr(obj, self.source or field_name) - return self.to_native(obj.pk) - # Forward relationship - return self.to_native(pk) - - -class ManyPrimaryKeyRelatedField(ManyRelatedField): - """ - Represents a to-many relationship as a pk value. - """ - default_read_only = False - form_field_class = forms.MultipleChoiceField - - def prepare_value(self, obj): - return self.to_native(obj.pk) - - def label_from_instance(self, obj): - """ - Return a readable representation for use with eg. select widgets. - """ - desc = smart_unicode(obj) - ident = smart_unicode(self.to_native(obj.pk)) - if desc == ident: - return desc - return "%s - %s" % (desc, ident) - - def to_native(self, pk): - return pk - - def field_to_native(self, obj, field_name): - try: - # Prefer obj.serializable_value for performance reasons - queryset = obj.serializable_value(self.source or field_name) - except AttributeError: - # RelatedManager (reverse relationship) - queryset = getattr(obj, self.source or field_name) - return [self.to_native(item.pk) for item in queryset.all()] - # Forward relationship - return [self.to_native(item.pk) for item in queryset.all()] - - def from_native(self, data): - if self.queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - try: - return self.queryset.get(pk=data) - except ObjectDoesNotExist: - msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) - raise ValidationError(msg) - -### Slug relationships - - -class SlugRelatedField(RelatedField): - default_read_only = False - form_field_class = forms.ChoiceField - - def __init__(self, *args, **kwargs): - self.slug_field = kwargs.pop('slug_field', None) - assert self.slug_field, 'slug_field is required' - super(SlugRelatedField, self).__init__(*args, **kwargs) - - def to_native(self, obj): - return getattr(obj, self.slug_field) - - def from_native(self, data): - if self.queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - try: - return self.queryset.get(**{self.slug_field: data}) - except ObjectDoesNotExist: - raise ValidationError('Object with %s=%s does not exist.' % - (self.slug_field, unicode(data))) - - -class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField): - form_field_class = forms.MultipleChoiceField - - -### Hyperlinked relationships - -class HyperlinkedRelatedField(RelatedField): - """ - Represents a to-one relationship, using hyperlinking. - """ - pk_url_kwarg = 'pk' - slug_field = 'slug' - slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - default_read_only = False - form_field_class = forms.ChoiceField - - def __init__(self, *args, **kwargs): - try: - self.view_name = kwargs.pop('view_name') - except: - raise ValueError("Hyperlinked field requires 'view_name' kwarg") - - self.slug_field = kwargs.pop('slug_field', self.slug_field) - default_slug_kwarg = self.slug_url_kwarg or self.slug_field - self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) - self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - - self.format = kwargs.pop('format', None) - super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) - - def get_slug_field(self): - """ - Get the name of a slug field to be used to look up by slug. - """ - return self.slug_field - - def to_native(self, obj): - view_name = self.view_name - request = self.context.get('request', None) - format = self.format or self.context.get('format', None) - pk = getattr(obj, 'pk', None) - if pk is None: - return - kwargs = {self.pk_url_kwarg: pk} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except: - pass - - slug = getattr(obj, self.slug_field, None) - - if not slug: - raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) - - kwargs = {self.slug_url_kwarg: slug} - try: - return reverse(self.view_name, kwargs=kwargs, request=request, format=format) - except: - pass - - kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} - try: - return reverse(self.view_name, kwargs=kwargs, request=request, format=format) - except: - pass - - raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) - - def from_native(self, value): - # Convert URL -> model instance pk - # TODO: Use values_list - if self.queryset is None: - raise Exception('Writable related fields must include a `queryset` argument') - - if value.startswith('http:') or value.startswith('https:'): - # If needed convert absolute URLs to relative path - value = urlparse(value).path - prefix = get_script_prefix() - if value.startswith(prefix): - value = '/' + value[len(prefix):] - - try: - match = resolve(value) - except: - raise ValidationError('Invalid hyperlink - No URL match') - - if match.url_name != self.view_name: - raise ValidationError('Invalid hyperlink - Incorrect URL match') - - pk = match.kwargs.get(self.pk_url_kwarg, None) - slug = match.kwargs.get(self.slug_url_kwarg, None) - - # Try explicit primary key. - if pk is not None: - queryset = self.queryset.filter(pk=pk) - # Next, try looking up by slug. - elif slug is not None: - slug_field = self.get_slug_field() - queryset = self.queryset.filter(**{slug_field: slug}) - # If none of those are defined, it's an error. - else: - raise ValidationError('Invalid hyperlink') - - try: - obj = queryset.get() - except ObjectDoesNotExist: - raise ValidationError('Invalid hyperlink - object does not exist.') - return obj - - -class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField): - """ - Represents a to-many relationship, using hyperlinking. - """ - form_field_class = forms.MultipleChoiceField - - -class HyperlinkedIdentityField(Field): - """ - Represents the instance, or a property on the instance, using hyperlinking. - """ - pk_url_kwarg = 'pk' - slug_field = 'slug' - slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden - - def __init__(self, *args, **kwargs): - # TODO: Make view_name mandatory, and have the - # HyperlinkedModelSerializer set it on-the-fly - self.view_name = kwargs.pop('view_name', None) - self.format = kwargs.pop('format', None) - - self.slug_field = kwargs.pop('slug_field', self.slug_field) - default_slug_kwarg = self.slug_url_kwarg or self.slug_field - self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) - self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) - - super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) - - def field_to_native(self, obj, field_name): - request = self.context.get('request', None) - format = self.format or self.context.get('format', None) - view_name = self.view_name or self.parent.opts.view_name - kwargs = {self.pk_url_kwarg: obj.pk} - try: - return reverse(view_name, kwargs=kwargs, request=request, format=format) - except: - pass - - slug = getattr(obj, self.slug_field, None) - - if not slug: - raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) - - kwargs = {self.slug_url_kwarg: slug} - try: - return reverse(self.view_name, kwargs=kwargs, request=request, format=format) - except: - pass - - kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} - try: - return reverse(self.view_name, kwargs=kwargs, request=request, format=format) - except: - pass - - raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) - ##### Typed Fields ##### @@ -703,7 +345,7 @@ class BooleanField(WritableField): form_field_class = forms.BooleanField widget = widgets.CheckboxInput default_error_messages = { - 'invalid': _(u"'%s' value must be either True or False."), + 'invalid': _("'%s' value must be either True or False."), } empty = False @@ -732,20 +374,10 @@ class CharField(WritableField): if max_length is not None: self.validators.append(validators.MaxLengthValidator(max_length)) - def validate(self, value): - """ - Validates that the value is supplied (if required). - """ - # if empty string and allow blank - if self.blank and not value: - return - else: - super(CharField, self).validate(value) - def from_native(self, value): - if isinstance(value, basestring) or value is None: + if isinstance(value, six.string_types) or value is None: return value - return smart_unicode(value) + return smart_text(value) class URLField(CharField): @@ -770,7 +402,8 @@ class ChoiceField(WritableField): form_field_class = forms.ChoiceField widget = widgets.Select default_error_messages = { - 'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'), + 'invalid_choice': _('Select a valid choice. %(value)s is not one of ' + 'the available choices.'), } def __init__(self, choices=(), *args, **kwargs): @@ -804,10 +437,10 @@ class ChoiceField(WritableField): if isinstance(v, (list, tuple)): # This is an optgroup, so look inside the group for options for k2, v2 in v: - if value == smart_unicode(k2): + if value == smart_text(k2): return True else: - if value == smart_unicode(k) or value == k: + if value == smart_text(k) or value == k: return True return False @@ -847,7 +480,7 @@ class RegexField(CharField): return self._regex def _set_regex(self, regex): - if isinstance(regex, basestring): + if isinstance(regex, six.string_types): regex = re.compile(regex) self._regex = regex if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: @@ -870,12 +503,16 @@ class DateField(WritableField): form_field_class = forms.DateField default_error_messages = { - 'invalid': _(u"'%s' value has an invalid date format. It must be " - u"in YYYY-MM-DD format."), - 'invalid_date': _(u"'%s' value has the correct format (YYYY-MM-DD) " - u"but it is an invalid date."), + 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), } empty = None + input_formats = api_settings.DATE_INPUT_FORMATS + format = api_settings.DATE_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateField, self).__init__(*args, **kwargs) def from_native(self, value): if value in validators.EMPTY_VALUES: @@ -891,17 +528,37 @@ class DateField(WritableField): if isinstance(value, datetime.date): return value - try: - parsed = parse_date(value) - if parsed is not None: - return parsed - except ValueError: - msg = self.error_messages['invalid_date'] % value - raise ValidationError(msg) + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_date(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.date() - msg = self.error_messages['invalid'] % value + msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats) raise ValidationError(msg) + def to_native(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.date() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) + class DateTimeField(WritableField): type_name = 'DateTimeField' @@ -909,15 +566,16 @@ class DateTimeField(WritableField): form_field_class = forms.DateTimeField default_error_messages = { - 'invalid': _(u"'%s' value has an invalid format. It must be in " - u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."), - 'invalid_date': _(u"'%s' value has the correct format " - u"(YYYY-MM-DD) but it is an invalid date."), - 'invalid_datetime': _(u"'%s' value has the correct format " - u"(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) " - u"but it is an invalid date/time."), + 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), } empty = None + input_formats = api_settings.DATETIME_INPUT_FORMATS + format = api_settings.DATETIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateTimeField, self).__init__(*args, **kwargs) def from_native(self, value): if value in validators.EMPTY_VALUES: @@ -932,32 +590,100 @@ class DateTimeField(WritableField): # local time. This won't work during DST change, but we can't # do much about it, so we let the exceptions percolate up the # call stack. - warnings.warn(u"DateTimeField received a naive datetime (%s)" - u" while time zone support is active." % value, + warnings.warn("DateTimeField received a naive datetime (%s)" + " while time zone support is active." % value, RuntimeWarning) default_timezone = timezone.get_default_timezone() value = timezone.make_aware(value, default_timezone) return value - try: - parsed = parse_datetime(value) - if parsed is not None: - return parsed - except ValueError: - msg = self.error_messages['invalid_datetime'] % value - raise ValidationError(msg) + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_datetime(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed - try: - parsed = parse_date(value) - if parsed is not None: - return datetime.datetime(parsed.year, parsed.month, parsed.day) - except ValueError: - msg = self.error_messages['invalid_date'] % value - raise ValidationError(msg) + msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats) + raise ValidationError(msg) + + def to_native(self, value): + if value is None or self.format is None: + return value - msg = self.error_messages['invalid'] % value + if self.format.lower() == ISO_8601: + ret = value.isoformat() + if ret.endswith('+00:00'): + ret = ret[:-6] + 'Z' + return ret + return value.strftime(self.format) + + +class TimeField(WritableField): + type_name = 'TimeField' + widget = widgets.TimeInput + form_field_class = forms.TimeField + + default_error_messages = { + 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), + } + empty = None + input_formats = api_settings.TIME_INPUT_FORMATS + format = api_settings.TIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(TimeField, self).__init__(*args, **kwargs) + + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + + if isinstance(value, datetime.time): + return value + + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_time(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.time() + + msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats) raise ValidationError(msg) + def to_native(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.time() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) + class IntegerField(WritableField): type_name = 'IntegerField' @@ -1008,8 +734,77 @@ class FloatField(WritableField): raise ValidationError(msg) +class DecimalField(WritableField): + type_name = 'DecimalField' + form_field_class = forms.DecimalField + + default_error_messages = { + 'invalid': _('Enter a number.'), + 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), + 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), + 'max_digits': _('Ensure that there are no more than %s digits in total.'), + 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), + 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') + } + + def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): + self.max_value, self.min_value = max_value, min_value + self.max_digits, self.decimal_places = max_digits, decimal_places + super(DecimalField, self).__init__(*args, **kwargs) + + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def from_native(self, value): + """ + Validates that the input is a decimal number. Returns a Decimal + instance. Returns None for empty values. Ensures that there are no more + than max_digits in the number, and no more than decimal_places digits + after the decimal point. + """ + if value in validators.EMPTY_VALUES: + return None + value = smart_text(value).strip() + try: + value = Decimal(value) + except DecimalException: + raise ValidationError(self.error_messages['invalid']) + return value + + def validate(self, value): + super(DecimalField, self).validate(value) + if value in validators.EMPTY_VALUES: + return + # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, + # since it is never equal to itself. However, NaN is the only value that + # isn't equal to itself, so we can use this to identify NaN + if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): + raise ValidationError(self.error_messages['invalid']) + sign, digittuple, exponent = value.as_tuple() + decimals = abs(exponent) + # digittuple doesn't include any leading zeros. + digits = len(digittuple) + if decimals > digits: + # We have leading zeros up to or past the decimal point. Count + # everything past the decimal point as a digit. We do not count + # 0 before the decimal point as a digit since that would mean + # we would not allow max_digits = decimal_places. + digits = decimals + whole_digits = digits - decimals + + if self.max_digits is not None and digits > self.max_digits: + raise ValidationError(self.error_messages['max_digits'] % self.max_digits) + if self.decimal_places is not None and decimals > self.decimal_places: + raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) + if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): + raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) + return value + + class FileField(WritableField): - _use_files = True + use_files = True type_name = 'FileField' form_field_class = forms.FileField widget = widgets.FileInput @@ -1053,11 +848,12 @@ class FileField(WritableField): class ImageField(FileField): - _use_files = True + use_files = True form_field_class = forms.ImageField default_error_messages = { - 'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."), + 'invalid_image': _("Upload a valid image. The file you uploaded was " + "either not an image or a corrupted image."), } def from_native(self, data): diff --git a/rest_framework/filters.py b/rest_framework/filters.py index bcc87660..c058bc71 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -1,4 +1,12 @@ -from rest_framework.compat import django_filters +""" +Provides generic filtering backends that can be used to filter the results +returned by list views. +""" +from __future__ import unicode_literals +from django.db import models +from rest_framework.compat import django_filters, six +from functools import reduce +import operator FilterSet = django_filters and django_filters.FilterSet or None @@ -24,36 +32,112 @@ class DjangoFilterBackend(BaseFilterBackend): def __init__(self): assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed' - def get_filter_class(self, view): + def get_filter_class(self, view, queryset=None): """ Return the django-filters `FilterSet` used to filter the queryset. """ filter_class = getattr(view, 'filter_class', None) filter_fields = getattr(view, 'filter_fields', None) - view_model = getattr(view, 'model', None) if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, view_model), \ - 'FilterSet model %s does not match view model %s' % \ - (filter_model, view_model) + assert issubclass(filter_model, queryset.model), \ + 'FilterSet model %s does not match queryset model %s' % \ + (filter_model, queryset.model) return filter_class if filter_fields: class AutoFilterSet(self.default_filter_set): class Meta: - model = view_model + model = queryset.model fields = filter_fields return AutoFilterSet return None def filter_queryset(self, request, queryset, view): - filter_class = self.get_filter_class(view) + filter_class = self.get_filter_class(view, queryset) if filter_class: - return filter_class(request.GET, queryset=queryset) + return filter_class(request.QUERY_PARAMS, queryset=queryset).qs + + return queryset + + +class SearchFilter(BaseFilterBackend): + search_param = 'search' # The URL query parameter used for the search. + + def get_search_terms(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.search_param, '') + return params.replace(',', ' ').split() + + def construct_search(self, field_name): + if field_name.startswith('^'): + return "%s__istartswith" % field_name[1:] + elif field_name.startswith('='): + return "%s__iexact" % field_name[1:] + elif field_name.startswith('@'): + return "%s__search" % field_name[1:] + else: + return "%s__icontains" % field_name + + def filter_queryset(self, request, queryset, view): + search_fields = getattr(view, 'search_fields', None) + + if not search_fields: + return queryset + + orm_lookups = [self.construct_search(str(search_field)) + for search_field in search_fields] + + for search_term in self.get_search_terms(request): + or_queries = [models.Q(**{orm_lookup: search_term}) + for orm_lookup in orm_lookups] + queryset = queryset.filter(reduce(operator.or_, or_queries)) + + return queryset + + +class OrderingFilter(BaseFilterBackend): + ordering_param = 'ordering' # The URL query parameter used for the ordering. + + def get_ordering(self, request): + """ + Search terms are set by a ?search=... query parameter, + and may be comma and/or whitespace delimited. + """ + params = request.QUERY_PARAMS.get(self.ordering_param) + if params: + return [param.strip() for param in params.split(',')] + + def get_default_ordering(self, view): + ordering = getattr(view, 'ordering', None) + if isinstance(ordering, six.string_types): + return (ordering,) + return ordering + + def remove_invalid_fields(self, queryset, ordering): + field_names = [field.name for field in queryset.model._meta.fields] + return [term for term in ordering if term.lstrip('-') in field_names] + + def filter_queryset(self, request, queryset, view): + ordering = self.get_ordering(request) + + if ordering: + # Skip any incorrect parameters + ordering = self.remove_invalid_fields(queryset, ordering) + + if not ordering: + # Use 'ordering' attribtue by default + ordering = self.get_default_ordering(view) + + if ordering: + return queryset.order_by(*ordering) return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index dd8dfcf8..05ec93d3 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -1,23 +1,60 @@ """ Generic views that provide commonly needed behaviour. """ +from __future__ import unicode_literals +from django.core.exceptions import ImproperlyConfigured +from django.core.paginator import Paginator, InvalidPage +from django.http import Http404 +from django.shortcuts import get_object_or_404 +from django.utils.translation import ugettext as _ from rest_framework import views, mixins +from rest_framework.exceptions import ConfigurationError from rest_framework.settings import api_settings -from django.views.generic.detail import SingleObjectMixin -from django.views.generic.list import MultipleObjectMixin +import warnings -### Base classes for the generic views ### - class GenericAPIView(views.APIView): """ Base class for all other generic views. """ - model = None + # You'll need to either set these attributes, + # or override `get_queryset()`/`get_serializer_class()`. + queryset = None serializer_class = None + + # This shortcut may be used instead of setting either or both + # of the `queryset`/`serializer_class` attributes, although using + # the explicit style is generally preferred. + model = None + + # If you want to use object lookups other than pk, set this attribute. + # For more complex lookup requirements override `get_object()`. + lookup_field = 'pk' + + # Pagination settings + paginate_by = api_settings.PAGINATE_BY + paginate_by_param = api_settings.PAGINATE_BY_PARAM + pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS + page_kwarg = 'page' + + # The filter backend classes to use for queryset filtering + filter_backends = api_settings.DEFAULT_FILTER_BACKENDS + + # The following attributes may be subject to change, + # and should be considered private API. model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + paginator_class = Paginator + + ###################################### + # These are pending deprecation... + + pk_url_kwarg = 'pk' + slug_url_kwarg = 'slug' + slug_field = 'slug' + allow_empty = True + filter_backend = api_settings.FILTER_BACKEND def get_serializer_context(self): """ @@ -29,54 +66,18 @@ class GenericAPIView(views.APIView): 'view': self } - def get_serializer_class(self): - """ - Return the class to use for the serializer. - - Defaults to using `self.serializer_class`, falls back to constructing a - model serializer class using `self.model_serializer_class`, with - `self.model` as the model. - """ - serializer_class = self.serializer_class - - if serializer_class is None: - class DefaultSerializer(self.model_serializer_class): - class Meta: - model = self.model - serializer_class = DefaultSerializer - - return serializer_class - - def get_serializer(self, instance=None, data=None, files=None): + def get_serializer(self, instance=None, data=None, + files=None, many=False, partial=False): """ Return the serializer instance that should be used for validating and deserializing input, and for serializing output. """ serializer_class = self.get_serializer_class() context = self.get_serializer_context() - return serializer_class(instance, data=data, files=files, context=context) + return serializer_class(instance, data=data, files=files, + many=many, partial=partial, context=context) - -class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): - """ - Base class for generic views onto a queryset. - """ - - paginate_by = api_settings.PAGINATE_BY - paginate_by_param = api_settings.PAGINATE_BY_PARAM - pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS - filter_backend = api_settings.FILTER_BACKEND - - def filter_queryset(self, queryset): - """ - Given a queryset, filter it with whichever filter backend is in use. - """ - if not self.filter_backend: - return queryset - backend = self.filter_backend() - return backend.filter_queryset(self.request, queryset, self) - - def get_pagination_serializer(self, page=None): + def get_pagination_serializer(self, page): """ Return a serializer instance to use with paginated data. """ @@ -88,42 +89,233 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): context = self.get_serializer_context() return pagination_serializer_class(instance=page, context=context) - def get_paginate_by(self, queryset): + def paginate_queryset(self, queryset, page_size=None): + """ + Paginate a queryset if required, either returning a page object, + or `None` if pagination is not configured for this view. + """ + deprecated_style = False + if page_size is not None: + warnings.warn('The `page_size` parameter to `paginate_queryset()` ' + 'is due to be deprecated. ' + 'Note that the return style of this method is also ' + 'changed, and will simply return a page object ' + 'when called without a `page_size` argument.', + PendingDeprecationWarning, stacklevel=2) + deprecated_style = True + else: + # Determine the required page size. + # If pagination is not configured, simply return None. + page_size = self.get_paginate_by() + if not page_size: + return None + + if not self.allow_empty: + warnings.warn( + 'The `allow_empty` parameter is due to be deprecated. ' + 'To use `allow_empty=False` style behavior, You should override ' + '`get_queryset()` and explicitly raise a 404 on empty querysets.', + PendingDeprecationWarning, stacklevel=2 + ) + + paginator = self.paginator_class(queryset, page_size, + allow_empty_first_page=self.allow_empty) + page_kwarg = self.kwargs.get(self.page_kwarg) + page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) + page = page_kwarg or page_query_param or 1 + try: + page_number = int(page) + except ValueError: + if page == 'last': + page_number = paginator.num_pages + else: + raise Http404(_("Page is not 'last', nor can it be converted to an int.")) + try: + page = paginator.page(page_number) + except InvalidPage as e: + raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { + 'page_number': page_number, + 'message': str(e) + }) + + if deprecated_style: + return (paginator, page, page.object_list, page.has_other_pages()) + return page + + def filter_queryset(self, queryset): + """ + Given a queryset, filter it with whichever filter backend is in use. + + You are unlikely to want to override this method, although you may need + to call it either from a list view, or from a custom `get_object` + method if you want to apply the configured filtering backend to the + default queryset. + """ + filter_backends = self.filter_backends or [] + if not filter_backends and self.filter_backend: + warnings.warn( + 'The `filter_backend` attribute and `FILTER_BACKEND` setting ' + 'are due to be deprecated in favor of a `filter_backends` ' + 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take ' + 'a *list* of filter backend classes.', + PendingDeprecationWarning, stacklevel=2 + ) + filter_backends = [self.filter_backend] + + for backend in filter_backends: + queryset = backend().filter_queryset(self.request, queryset, self) + return queryset + + ######################## + ### The following methods provide default implementations + ### that you may want to override for more complex cases. + + def get_paginate_by(self, queryset=None): """ Return the size of pages to use with pagination. + + If `PAGINATE_BY_PARAM` is set it will attempt to get the page size + from a named query parameter in the url, eg. ?page_size=100 + + Otherwise defaults to using `self.paginate_by`. """ + if queryset is not None: + warnings.warn('The `queryset` parameter to `get_paginate_by()` ' + 'is due to be deprecated.', + PendingDeprecationWarning, stacklevel=2) + if self.paginate_by_param: query_params = self.request.QUERY_PARAMS try: return int(query_params[self.paginate_by_param]) except (KeyError, ValueError): pass + return self.paginate_by + def get_serializer_class(self): + """ + Return the class to use for the serializer. + Defaults to using `self.serializer_class`. + + You may want to override this if you need to provide different + serializations depending on the incoming request. -class SingleObjectAPIView(SingleObjectMixin, GenericAPIView): - """ - Base class for generic views onto a model instance. - """ + (Eg. admins get full serialization, others get basic serilization) + """ + serializer_class = self.serializer_class + if serializer_class is not None: + return serializer_class - pk_url_kwarg = 'pk' # Not provided in Django 1.3 - slug_url_kwarg = 'slug' # Not provided in Django 1.3 - slug_field = 'slug' + assert self.model is not None, \ + "'%s' should either include a 'serializer_class' attribute, " \ + "or use the 'model' attribute as a shortcut for " \ + "automatically generating a serializer class." \ + % self.__class__.__name__ + + class DefaultSerializer(self.model_serializer_class): + class Meta: + model = self.model + return DefaultSerializer + + def get_queryset(self): + """ + Get the list of items for this view. + This must be an iterable, and may be a queryset. + Defaults to using `self.queryset`. + + You may want to override this if you need to provide different + querysets depending on the incoming request. + + (Eg. return a list of items that is specific to the user) + """ + if self.queryset is not None: + return self.queryset._clone() + + if self.model is not None: + return self.model._default_manager.all() + + raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" + % self.__class__.__name__) def get_object(self, queryset=None): """ - Override default to add support for object-level permissions. + Returns the object the view is displaying. + + You may want to override this if you need to provide non-standard + queryset lookups. Eg if objects are referenced using multiple + keyword arguments in the url conf. """ - obj = super(SingleObjectAPIView, self).get_object(queryset) - if not self.has_permission(self.request, obj): - self.permission_denied(self.request) + # Determine the base queryset to use. + if queryset is None: + queryset = self.filter_queryset(self.get_queryset()) + else: + pass # Deprecation warning + + # Perform the lookup filtering. + pk = self.kwargs.get(self.pk_url_kwarg, None) + slug = self.kwargs.get(self.slug_url_kwarg, None) + lookup = self.kwargs.get(self.lookup_field, None) + + if lookup is not None: + filter_kwargs = {self.lookup_field: lookup} + elif pk is not None and self.lookup_field == 'pk': + warnings.warn( + 'The `pk_url_kwarg` attribute is due to be deprecated. ' + 'Use the `lookup_field` attribute instead', + PendingDeprecationWarning + ) + filter_kwargs = {'pk': pk} + elif slug is not None and self.lookup_field == 'pk': + warnings.warn( + 'The `slug_url_kwarg` attribute is due to be deprecated. ' + 'Use the `lookup_field` attribute instead', + PendingDeprecationWarning + ) + filter_kwargs = {self.slug_field: slug} + else: + raise ConfigurationError( + 'Expected view %s to be called with a URL keyword argument ' + 'named "%s". Fix your URL conf, or set the `.lookup_field` ' + 'attribute on the view correctly.' % + (self.__class__.__name__, self.lookup_field) + ) + + obj = get_object_or_404(queryset, **filter_kwargs) + + # May raise a permission denied + self.check_object_permissions(self.request, obj) + return obj + ######################## + ### The following are placeholder methods, + ### and are intended to be overridden. + ### + ### The are not called by GenericAPIView directly, + ### but are used by the mixin methods. -### Concrete view classes that provide method handlers ### -### by composing the mixin classes with a base view. ### + def pre_save(self, obj): + """ + Placeholder method for calling before saving an object. + + May be used to set attributes on the object that are implicit + in either the request, or the url. + """ + pass + + def post_save(self, obj, created=False): + """ + Placeholder method for calling after saving an object. + """ + pass +########################################################## +### Concrete view classes that provide method handlers ### +### by composing the mixin classes with the base view. ### +########################################################## + class CreateAPIView(mixins.CreateModelMixin, GenericAPIView): @@ -135,7 +327,7 @@ class CreateAPIView(mixins.CreateModelMixin, class ListAPIView(mixins.ListModelMixin, - MultipleObjectAPIView): + GenericAPIView): """ Concrete view for listing a queryset. """ @@ -144,7 +336,7 @@ class ListAPIView(mixins.ListModelMixin, class RetrieveAPIView(mixins.RetrieveModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving a model instance. """ @@ -153,7 +345,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin, class DestroyAPIView(mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for deleting a model instance. @@ -163,7 +355,7 @@ class DestroyAPIView(mixins.DestroyModelMixin, class UpdateAPIView(mixins.UpdateModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for updating a model instance. @@ -171,10 +363,13 @@ class UpdateAPIView(mixins.UpdateModelMixin, def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, - MultipleObjectAPIView): + GenericAPIView): """ Concrete view for listing a queryset or creating a model instance. """ @@ -185,9 +380,25 @@ class ListCreateAPIView(mixins.ListModelMixin, return self.create(request, *args, **kwargs) +class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + GenericAPIView): + """ + Concrete view for retrieving, updating a model instance. + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def put(self, request, *args, **kwargs): + return self.update(request, *args, **kwargs) + + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + + class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving or deleting a model instance. """ @@ -201,7 +412,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, - SingleObjectAPIView): + GenericAPIView): """ Concrete view for retrieving, updating or deleting a model instance. """ @@ -211,5 +422,32 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) + def patch(self, request, *args, **kwargs): + return self.partial_update(request, *args, **kwargs) + def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) + + +########################## +### Deprecated classes ### +########################## + +class MultipleObjectAPIView(GenericAPIView): + def __init__(self, *args, **kwargs): + warnings.warn( + 'Subclassing `MultipleObjectAPIView` is due to be deprecated. ' + 'You should simply subclass `GenericAPIView` instead.', + PendingDeprecationWarning, stacklevel=2 + ) + super(MultipleObjectAPIView, self).__init__(*args, **kwargs) + + +class SingleObjectAPIView(GenericAPIView): + def __init__(self, *args, **kwargs): + warnings.warn( + 'Subclassing `SingleObjectAPIView` is due to be deprecated. ' + 'You should simply subclass `GenericAPIView` instead.', + PendingDeprecationWarning, stacklevel=2 + ) + super(SingleObjectAPIView, self).__init__(*args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 2700606d..f3cd5868 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -4,23 +4,57 @@ Basic building blocks for generic class based views. We don't bind behaviour to http method handlers yet, which allows mixin classes to be composed in interesting ways. """ +from __future__ import unicode_literals + from django.http import Http404 from rest_framework import status from rest_framework.response import Response +from rest_framework.request import clone_request +import warnings + + +def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None): + """ + Given a model instance, and an optional pk and slug field, + return the full list of all other field names on that model. + + For use when performing full_clean on a model instance, + so we only clean the required fields. + """ + include = [] + + if pk: + # Pending deprecation + pk_field = obj._meta.pk + while pk_field.rel: + pk_field = pk_field.rel.to._meta.pk + include.append(pk_field.name) + + if slug_field: + # Pending deprecation + include.append(slug_field) + + if lookup_field and lookup_field != 'pk': + include.append(lookup_field) + + return [field.name for field in obj._meta.fields if field.name not in include] class CreateModelMixin(object): """ Create a model instance. - Should be mixed in with any `BaseView`. """ def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.DATA, files=request.FILES) + if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + self.object = serializer.save(force_insert=True) + self.post_save(self.object, created=True) headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response(serializer.data, status=status.HTTP_201_CREATED, + headers=headers) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) def get_success_headers(self, data): @@ -29,38 +63,35 @@ class CreateModelMixin(object): except (TypeError, KeyError): return {} - def pre_save(self, obj): - pass - class ListModelMixin(object): """ List a queryset. - Should be mixed in with `MultipleObjectAPIView`. """ - empty_error = u"Empty list and '%(class_name)s.allow_empty' is False." + empty_error = "Empty list and '%(class_name)s.allow_empty' is False." def list(self, request, *args, **kwargs): - queryset = self.get_queryset() - self.object_list = self.filter_queryset(queryset) + self.object_list = self.filter_queryset(self.get_queryset()) # Default is to allow empty querysets. This can be altered by setting # `.allow_empty = False`, to raise 404 errors on empty querysets. - allow_empty = self.get_allow_empty() - if not allow_empty and not self.object_list: + if not self.allow_empty and not self.object_list: + warnings.warn( + 'The `allow_empty` parameter is due to be deprecated. ' + 'To use `allow_empty=False` style behavior, You should override ' + '`get_queryset()` and explicitly raise a 404 on empty querysets.', + PendingDeprecationWarning + ) class_name = self.__class__.__name__ error_msg = self.empty_error % {'class_name': class_name} raise Http404(error_msg) - # Pagination size is set by the `.paginate_by` attribute, - # which may be `None` to disable pagination. - page_size = self.get_paginate_by(self.object_list) - if page_size: - packed = self.paginate_queryset(self.object_list, page_size) - paginator, page, queryset, is_paginated = packed + # Switch between paginated or standard style responses + page = self.paginate_queryset(self.object_list) + if page is not None: serializer = self.get_pagination_serializer(page) else: - serializer = self.get_serializer(self.object_list) + serializer = self.get_serializer(self.object_list, many=True) return Response(serializer.data) @@ -68,7 +99,6 @@ class ListModelMixin(object): class RetrieveModelMixin(object): """ Retrieve a model instance. - Should be mixed in with `SingleObjectBaseView`. """ def retrieve(self, request, *args, **kwargs): self.object = self.get_object() @@ -79,49 +109,74 @@ class RetrieveModelMixin(object): class UpdateModelMixin(object): """ Update a model instance. - Should be mixed in with `SingleObjectBaseView`. """ - def update(self, request, *args, **kwargs): + def get_object_or_none(self): try: - self.object = self.get_object() - created = False + return self.get_object() except Http404: - self.object = None + # If this is a PUT-as-create operation, we need to ensure that + # we have relevant permissions, as if this was a POST request. + # This will either raise a PermissionDenied exception, + # or simply return None + self.check_permissions(clone_request(self.request, 'POST')) + + def update(self, request, *args, **kwargs): + partial = kwargs.pop('partial', False) + self.object = self.get_object_or_none() + + if self.object is None: created = True + save_kwargs = {'force_insert': True} + success_status_code = status.HTTP_201_CREATED + else: + created = False + save_kwargs = {'force_update': True} + success_status_code = status.HTTP_200_OK - serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES) + serializer = self.get_serializer(self.object, data=request.DATA, + files=request.FILES, partial=partial) if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() - status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK - return Response(serializer.data, status=status_code) + self.object = serializer.save(**save_kwargs) + self.post_save(self.object, created=created) + return Response(serializer.data, status=success_status_code) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def partial_update(self, request, *args, **kwargs): + kwargs['partial'] = True + return self.update(request, *args, **kwargs) + def pre_save(self, obj): """ Set any attributes on the object that are implicit in the request. """ # pk and/or slug attributes are implicit in the URL. + lookup = self.kwargs.get(self.lookup_field, None) pk = self.kwargs.get(self.pk_url_kwarg, None) + slug = self.kwargs.get(self.slug_url_kwarg, None) + slug_field = slug and self.slug_field or None + + if lookup: + setattr(obj, self.lookup_field, lookup) + if pk: setattr(obj, 'pk', pk) - slug = self.kwargs.get(self.slug_url_kwarg, None) if slug: - slug_field = self.get_slug_field() setattr(obj, slug_field, slug) # Ensure we clean the attributes so that we don't eg return integer # pk using a string representation, as provided by the url conf kwarg. - obj.full_clean() + if hasattr(obj, 'full_clean'): + exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field) + obj.full_clean(exclude) class DestroyModelMixin(object): """ Destroy a model instance. - Should be mixed in with `SingleObjectBaseView`. """ def destroy(self, request, *args, **kwargs): obj = self.get_object() diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index ee2800a6..4d205c0e 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -1,3 +1,8 @@ +""" +Content negotiation deals with selecting an appropriate renderer given the +incoming request. Typically this will be based on the request's Accept header. +""" +from __future__ import unicode_literals from django.http import Http404 from rest_framework import exceptions from rest_framework.settings import api_settings @@ -33,7 +38,7 @@ class DefaultContentNegotiation(BaseContentNegotiation): """ # Allow URL style format override. eg. "?format=json format_query_param = self.settings.URL_FORMAT_OVERRIDE - format = format_suffix or request.GET.get(format_query_param) + format = format_suffix or request.QUERY_PARAMS.get(format_query_param) if format: renderers = self.filter_renderers(renderers, format) @@ -80,5 +85,5 @@ class DefaultContentNegotiation(BaseContentNegotiation): Allows URL style accept override. eg. "?accept=application/json" """ header = request.META.get('HTTP_ACCEPT', '*/*') - header = request.GET.get(self.settings.URL_ACCEPT_OVERRIDE, header) + header = request.QUERY_PARAMS.get(self.settings.URL_ACCEPT_OVERRIDE, header) return [token.strip() for token in header.split(',')] diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index d241ade7..d51ea929 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,8 +1,11 @@ +""" +Pagination serializers determine the structure of the output that should +be used for paginated responses. +""" +from __future__ import unicode_literals from rest_framework import serializers from rest_framework.templatetags.rest_framework import replace_query_param -# TODO: Support URLconf kwarg-style paging - class NextPageField(serializers.Field): """ @@ -34,6 +37,17 @@ class PreviousPageField(serializers.Field): return replace_query_param(url, self.page_field, page) +class DefaultObjectSerializer(serializers.Field): + """ + If no object serializer is specified, then this serializer will be applied + as the default. + """ + + def __init__(self, source=None, context=None): + # Note: Swallow context kwarg - only required for eg. ModelSerializer. + super(DefaultObjectSerializer, self).__init__(source=source) + + class PaginationSerializerOptions(serializers.SerializerOptions): """ An object that stores the options that may be provided to a @@ -44,7 +58,7 @@ class PaginationSerializerOptions(serializers.SerializerOptions): def __init__(self, meta): super(PaginationSerializerOptions, self).__init__(meta) self.object_serializer_class = getattr(meta, 'object_serializer_class', - serializers.Field) + DefaultObjectSerializer) class BasePaginationSerializer(serializers.Serializer): @@ -62,14 +76,13 @@ class BasePaginationSerializer(serializers.Serializer): super(BasePaginationSerializer, self).__init__(*args, **kwargs) results_field = self.results_field object_serializer = self.opts.object_serializer_class - self.fields[results_field] = object_serializer(source='object_list') - def to_native(self, obj): - """ - Prevent default behaviour of iterating over elements, and serializing - each in turn. - """ - return self.convert_object(obj) + if 'context' in kwargs: + context_kwarg = {'context': kwargs['context']} + else: + context_kwarg = {} + + self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) class PaginationSerializer(BasePaginationSerializer): diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 4841676c..25be2e6a 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -4,15 +4,16 @@ Parsers are used to parse the content of incoming HTTP requests. They give us a generic way of being able to handle various media types on the request, such as form content or json encoded data. """ - +from __future__ import unicode_literals +from django.conf import settings +from django.core.files.uploadhandler import StopFutureHandlers from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser -from django.http.multipartparser import MultiPartParserError -from django.utils import simplejson as json -from rest_framework.compat import yaml, ETParseError +from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter +from rest_framework.compat import yaml, etree from rest_framework.exceptions import ParseError -from xml.etree import ElementTree as ET -from xml.parsers.expat import ExpatError +from rest_framework.compat import six +import json import datetime import decimal @@ -54,10 +55,14 @@ class JSONParser(BaseParser): `data` will be an object which is the parsed content of the response. `files` will always be `None`. """ + parser_context = parser_context or {} + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + try: - return json.load(stream) - except ValueError, exc: - raise ParseError('JSON parse error - %s' % unicode(exc)) + data = stream.read().decode(encoding) + return json.loads(data) + except ValueError as exc: + raise ParseError('JSON parse error - %s' % six.text_type(exc)) class YAMLParser(BaseParser): @@ -74,10 +79,16 @@ class YAMLParser(BaseParser): `data` will be an object which is the parsed content of the response. `files` will always be `None`. """ + assert yaml, 'YAMLParser requires pyyaml to be installed' + + parser_context = parser_context or {} + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + try: - return yaml.safe_load(stream) - except (ValueError, yaml.parser.ParserError), exc: - raise ParseError('YAML parse error - %s' % unicode(exc)) + data = stream.read().decode(encoding) + return yaml.safe_load(data) + except (ValueError, yaml.parser.ParserError) as exc: + raise ParseError('YAML parse error - %s' % six.u(exc)) class FormParser(BaseParser): @@ -94,7 +105,9 @@ class FormParser(BaseParser): `data` will be a :class:`QueryDict` containing all the form parameters. `files` will always be :const:`None`. """ - data = QueryDict(stream.read()) + parser_context = parser_context or {} + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + data = QueryDict(stream.read(), encoding=encoding) return data @@ -114,15 +127,16 @@ class MultiPartParser(BaseParser): """ parser_context = parser_context or {} request = parser_context['request'] + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) meta = request.META upload_handlers = request.upload_handlers try: - parser = DjangoMultiPartParser(meta, stream, upload_handlers) + parser = DjangoMultiPartParser(meta, stream, upload_handlers, encoding) data, files = parser.parse() return DataAndFiles(data, files) - except MultiPartParserError, exc: - raise ParseError('Multipart form parse error - %s' % unicode(exc)) + except MultiPartParserError as exc: + raise ParseError('Multipart form parse error - %s' % six.u(exc)) class XMLParser(BaseParser): @@ -133,10 +147,15 @@ class XMLParser(BaseParser): media_type = 'application/xml' def parse(self, stream, media_type=None, parser_context=None): + assert etree, 'XMLParser requires defusedxml to be installed' + + parser_context = parser_context or {} + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + parser = etree.DefusedXMLParser(encoding=encoding) try: - tree = ET.parse(stream) - except (ExpatError, ETParseError, ValueError), exc: - raise ParseError('XML parse error - %s' % unicode(exc)) + tree = etree.parse(stream, parser=parser, forbid_dtd=True) + except (etree.ParseError, ValueError) as exc: + raise ParseError('XML parse error - %s' % six.u(exc)) data = self._xml_convert(tree.getroot()) return data @@ -146,7 +165,7 @@ class XMLParser(BaseParser): convert the xml `element` into the corresponding python object """ - children = element.getchildren() + children = list(element) if len(children) == 0: return self._type_convert(element.text) @@ -187,3 +206,90 @@ class XMLParser(BaseParser): pass return value + + +class FileUploadParser(BaseParser): + """ + Parser for file upload data. + """ + media_type = '*/*' + + def parse(self, stream, media_type=None, parser_context=None): + """ + Returns a DataAndFiles object. + + `.data` will be None (we expect request body to be a file content). + `.files` will be a `QueryDict` containing one 'file' element. + """ + + parser_context = parser_context or {} + request = parser_context['request'] + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + meta = request.META + upload_handlers = request.upload_handlers + filename = self.get_filename(stream, media_type, parser_context) + + # Note that this code is extracted from Django's handling of + # file uploads in MultiPartParser. + content_type = meta.get('HTTP_CONTENT_TYPE', + meta.get('CONTENT_TYPE', '')) + try: + content_length = int(meta.get('HTTP_CONTENT_LENGTH', + meta.get('CONTENT_LENGTH', 0))) + except (ValueError, TypeError): + content_length = None + + # See if the handler will want to take care of the parsing. + for handler in upload_handlers: + result = handler.handle_raw_input(None, + meta, + content_length, + None, + encoding) + if result is not None: + return DataAndFiles(None, {'file': result[1]}) + + # This is the standard case. + possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] + chunk_size = min([2 ** 31 - 4] + possible_sizes) + chunks = ChunkIter(stream, chunk_size) + counters = [0] * len(upload_handlers) + + for handler in upload_handlers: + try: + handler.new_file(None, filename, content_type, + content_length, encoding) + except StopFutureHandlers: + break + + for chunk in chunks: + for i, handler in enumerate(upload_handlers): + chunk_length = len(chunk) + chunk = handler.receive_data_chunk(chunk, counters[i]) + counters[i] += chunk_length + if chunk is None: + break + + for i, handler in enumerate(upload_handlers): + file_obj = handler.file_complete(counters[i]) + if file_obj: + return DataAndFiles(None, {'file': file_obj}) + raise ParseError("FileUpload parse error - " + "none of upload handlers can handle the stream") + + def get_filename(self, stream, media_type, parser_context): + """ + Detects the uploaded file name. First searches a 'filename' url kwarg. + Then tries to parse Content-Disposition header. + """ + try: + return parser_context['kwargs']['filename'] + except KeyError: + pass + + try: + meta = parser_context['request'].META + disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION']) + return disposition[1]['filename'] + except (AttributeError, KeyError): + pass diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 655b78a3..45fcfd66 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -1,21 +1,38 @@ """ Provides a set of pluggable permission policies. """ - +from __future__ import unicode_literals +import inspect +import warnings SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS'] +from rest_framework.compat import oauth2_provider_scope, oauth2_constants + class BasePermission(object): """ A base class from which all permission classes should inherit. """ - def has_permission(self, request, view, obj=None): + def has_permission(self, request, view): + """ + Return `True` if permission is granted, `False` otherwise. + """ + return True + + def has_object_permission(self, request, view, obj): """ Return `True` if permission is granted, `False` otherwise. """ - raise NotImplementedError(".has_permission() must be overridden.") + if len(inspect.getargspec(self.has_permission).args) == 4: + warnings.warn( + 'The `obj` argument in `has_permission` is deprecated. ' + 'Use `has_object_permission()` instead for object permissions.', + DeprecationWarning, stacklevel=2 + ) + return self.has_permission(request, view, obj) + return True class AllowAny(BasePermission): @@ -25,7 +42,7 @@ class AllowAny(BasePermission): permission_classes list, but it's useful because it makes the intention more explicit. """ - def has_permission(self, request, view, obj=None): + def has_permission(self, request, view): return True @@ -34,7 +51,7 @@ class IsAuthenticated(BasePermission): Allows access only to authenticated users. """ - def has_permission(self, request, view, obj=None): + def has_permission(self, request, view): if request.user and request.user.is_authenticated(): return True return False @@ -45,7 +62,7 @@ class IsAdminUser(BasePermission): Allows access only to admin users. """ - def has_permission(self, request, view, obj=None): + def has_permission(self, request, view): if request.user and request.user.is_staff: return True return False @@ -56,7 +73,7 @@ class IsAuthenticatedOrReadOnly(BasePermission): The request is authenticated as a user, or is a read-only request. """ - def has_permission(self, request, view, obj=None): + def has_permission(self, request, view): if (request.method in SAFE_METHODS or request.user and request.user.is_authenticated()): @@ -72,8 +89,8 @@ class DjangoModelPermissions(BasePermission): It ensures that the user is authenticated, and has the appropriate `add`/`change`/`delete` permissions on the model. - This permission will only be applied against view classes that - provide a `.model` attribute, such as the generic class-based views. + This permission can only be applied against view classes that + provide a `.model` or `.queryset` attribute. """ # Map methods into required permission codes. @@ -89,6 +106,8 @@ class DjangoModelPermissions(BasePermission): 'DELETE': ['%(app_label)s.delete_%(model_name)s'], } + authenticated_users_only = True + def get_required_permissions(self, method, model_cls): """ Given a model and an HTTP method, return the list of permission @@ -100,15 +119,56 @@ class DjangoModelPermissions(BasePermission): } return [perm % kwargs for perm in self.perms_map[method]] - def has_permission(self, request, view, obj=None): + def has_permission(self, request, view): model_cls = getattr(view, 'model', None) - if not model_cls: + queryset = getattr(view, 'queryset', None) + + if model_cls is None and queryset is not None: + model_cls = queryset.model + + # Workaround to ensure DjangoModelPermissions are not applied + # to the root view when using DefaultRouter. + if model_cls is None and getattr(view, '_ignore_model_permissions'): return True + assert model_cls, ('Cannot apply DjangoModelPermissions on a view that' + ' does not have `.model` or `.queryset` property.') + perms = self.get_required_permissions(request.method, model_cls) if (request.user and - request.user.is_authenticated() and - request.user.has_perms(perms, obj)): + (request.user.is_authenticated() or not self.authenticated_users_only) and + request.user.has_perms(perms)): return True return False + + +class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions): + """ + Similar to DjangoModelPermissions, except that anonymous users are + allowed read-only access. + """ + authenticated_users_only = False + + +class TokenHasReadWriteScope(BasePermission): + """ + The request is authenticated as a user and the token used has the right scope + """ + + def has_permission(self, request, view): + token = request.auth + read_only = request.method in SAFE_METHODS + + if not token: + return False + + if hasattr(token, 'resource'): # OAuth 1 + return read_only or not request.auth.resource.is_readonly + elif hasattr(token, 'scope'): # OAuth 2 + required = oauth2_constants.READ if read_only else oauth2_constants.WRITE + return oauth2_provider_scope.check(required, request.auth.scope) + + assert False, ('TokenHasReadWriteScope requires either the' + '`OAuthAuthentication` or `OAuth2Authentication` authentication ' + 'class to be used.') diff --git a/rest_framework/relations.py b/rest_framework/relations.py new file mode 100644 index 00000000..c4b790d4 --- /dev/null +++ b/rest_framework/relations.py @@ -0,0 +1,588 @@ +""" +Serializer fields that deal with relationships. + +These fields allow you to specify the style that should be used to represent +model relationships, including hyperlinks, primary keys, or slugs. +""" +from __future__ import unicode_literals +from django.core.exceptions import ObjectDoesNotExist, ValidationError +from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch +from django import forms +from django.forms import widgets +from django.forms.models import ModelChoiceIterator +from django.utils.translation import ugettext_lazy as _ +from rest_framework.fields import Field, WritableField, get_component +from rest_framework.reverse import reverse +from rest_framework.compat import urlparse +from rest_framework.compat import smart_text +import warnings + + +##### Relational fields ##### + + +# Not actually Writable, but subclasses may need to be. +class RelatedField(WritableField): + """ + Base class for related model fields. + + This represents a relationship using the unicode representation of the target. + """ + widget = widgets.Select + many_widget = widgets.SelectMultiple + form_field_class = forms.ChoiceField + many_form_field_class = forms.MultipleChoiceField + + cache_choices = False + empty_label = None + read_only = True + many = False + + def __init__(self, *args, **kwargs): + + # 'null' is to be deprecated in favor of 'required' + if 'null' in kwargs: + warnings.warn('The `null` keyword argument is deprecated. ' + 'Use the `required` keyword argument instead.', + DeprecationWarning, stacklevel=2) + kwargs['required'] = not kwargs.pop('null') + + self.queryset = kwargs.pop('queryset', None) + self.many = kwargs.pop('many', self.many) + if self.many: + self.widget = self.many_widget + self.form_field_class = self.many_form_field_class + + kwargs['read_only'] = kwargs.pop('read_only', self.read_only) + super(RelatedField, self).__init__(*args, **kwargs) + + def initialize(self, parent, field_name): + super(RelatedField, self).initialize(parent, field_name) + if self.queryset is None and not self.read_only: + try: + manager = getattr(self.parent.opts.model, self.source or field_name) + if hasattr(manager, 'related'): # Forward + self.queryset = manager.related.model._default_manager.all() + else: # Reverse + self.queryset = manager.field.rel.to._default_manager.all() + except Exception: + raise + msg = ('Serializer related fields must include a `queryset`' + + ' argument or set `read_only=True') + raise Exception(msg) + + ### We need this stuff to make form choices work... + + def prepare_value(self, obj): + return self.to_native(obj) + + def label_from_instance(self, obj): + """ + Return a readable representation for use with eg. select widgets. + """ + desc = smart_text(obj) + ident = smart_text(self.to_native(obj)) + if desc == ident: + return desc + return "%s - %s" % (desc, ident) + + def _get_queryset(self): + return self._queryset + + def _set_queryset(self, queryset): + self._queryset = queryset + self.widget.choices = self.choices + + queryset = property(_get_queryset, _set_queryset) + + def _get_choices(self): + # If self._choices is set, then somebody must have manually set + # the property self.choices. In this case, just return self._choices. + if hasattr(self, '_choices'): + return self._choices + + # Otherwise, execute the QuerySet in self.queryset to determine the + # choices dynamically. Return a fresh ModelChoiceIterator that has not been + # consumed. Note that we're instantiating a new ModelChoiceIterator *each* + # time _get_choices() is called (and, thus, each time self.choices is + # accessed) so that we can ensure the QuerySet has not been consumed. This + # construct might look complicated but it allows for lazy evaluation of + # the queryset. + return ModelChoiceIterator(self) + + def _set_choices(self, value): + # Setting choices also sets the choices on the widget. + # choices can be any iterable, but we call list() on it because + # it will be consumed more than once. + self._choices = self.widget.choices = list(value) + + choices = property(_get_choices, _set_choices) + + ### Regular serializer stuff... + + def field_to_native(self, obj, field_name): + try: + if self.source == '*': + return self.to_native(obj) + + source = self.source or field_name + value = obj + + for component in source.split('.'): + value = get_component(value, component) + if value is None: + break + except ObjectDoesNotExist: + return None + + if value is None: + return None + + if self.many: + return [self.to_native(item) for item in value.all()] + return self.to_native(value) + + def field_from_native(self, data, files, field_name, into): + if self.read_only: + return + + try: + if self.many: + try: + # Form data + value = data.getlist(field_name) + if value == [''] or value == []: + raise KeyError + except AttributeError: + # Non-form data + value = data[field_name] + else: + value = data[field_name] + except KeyError: + if self.partial: + return + value = [] if self.many else None + + if value in (None, '') and self.required: + raise ValidationError(self.error_messages['required']) + elif value in (None, ''): + into[(self.source or field_name)] = None + elif self.many: + into[(self.source or field_name)] = [self.from_native(item) for item in value] + else: + into[(self.source or field_name)] = self.from_native(value) + + +### PrimaryKey relationships + +class PrimaryKeyRelatedField(RelatedField): + """ + Represents a relationship as a pk value. + """ + read_only = False + + default_error_messages = { + 'does_not_exist': _("Invalid pk '%s' - object does not exist."), + 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'), + } + + # TODO: Remove these field hacks... + def prepare_value(self, obj): + return self.to_native(obj.pk) + + def label_from_instance(self, obj): + """ + Return a readable representation for use with eg. select widgets. + """ + desc = smart_text(obj) + ident = smart_text(self.to_native(obj.pk)) + if desc == ident: + return desc + return "%s - %s" % (desc, ident) + + # TODO: Possibly change this to just take `obj`, through prob less performant + def to_native(self, pk): + return pk + + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + return self.queryset.get(pk=data) + except ObjectDoesNotExist: + msg = self.error_messages['does_not_exist'] % smart_text(data) + raise ValidationError(msg) + except (TypeError, ValueError): + received = type(data).__name__ + msg = self.error_messages['incorrect_type'] % received + raise ValidationError(msg) + + def field_to_native(self, obj, field_name): + if self.many: + # To-many relationship + try: + # Prefer obj.serializable_value for performance reasons + queryset = obj.serializable_value(self.source or field_name) + except AttributeError: + # RelatedManager (reverse relationship) + queryset = getattr(obj, self.source or field_name) + + # Forward relationship + return [self.to_native(item.pk) for item in queryset.all()] + + # To-one relationship + try: + # Prefer obj.serializable_value for performance reasons + pk = obj.serializable_value(self.source or field_name) + except AttributeError: + # RelatedObject (reverse relationship) + try: + pk = getattr(obj, self.source or field_name).pk + except ObjectDoesNotExist: + return None + + # Forward relationship + return self.to_native(pk) + + +### Slug relationships + + +class SlugRelatedField(RelatedField): + """ + Represents a relationship using a unique field on the target. + """ + read_only = False + + default_error_messages = { + 'does_not_exist': _("Object with %s=%s does not exist."), + 'invalid': _('Invalid value.'), + } + + def __init__(self, *args, **kwargs): + self.slug_field = kwargs.pop('slug_field', None) + assert self.slug_field, 'slug_field is required' + super(SlugRelatedField, self).__init__(*args, **kwargs) + + def to_native(self, obj): + return getattr(obj, self.slug_field) + + def from_native(self, data): + if self.queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + return self.queryset.get(**{self.slug_field: data}) + except ObjectDoesNotExist: + raise ValidationError(self.error_messages['does_not_exist'] % + (self.slug_field, smart_text(data))) + except (TypeError, ValueError): + msg = self.error_messages['invalid'] + raise ValidationError(msg) + + +### Hyperlinked relationships + +class HyperlinkedRelatedField(RelatedField): + """ + Represents a relationship using hyperlinking. + """ + read_only = False + lookup_field = 'pk' + + default_error_messages = { + 'no_match': _('Invalid hyperlink - No URL match'), + 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), + 'configuration_error': _('Invalid hyperlink due to configuration error'), + 'does_not_exist': _("Invalid hyperlink - object does not exist."), + 'incorrect_type': _('Incorrect type. Expected url string, received %s.'), + } + + # These are all pending deprecation + pk_url_kwarg = 'pk' + slug_field = 'slug' + slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden + + def __init__(self, *args, **kwargs): + try: + self.view_name = kwargs.pop('view_name') + except KeyError: + raise ValueError("Hyperlinked field requires 'view_name' kwarg") + + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + self.format = kwargs.pop('format', None) + + # These are pending deprecation + if 'pk_url_kwarg' in kwargs: + msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_url_kwarg' in kwargs: + msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_field' in kwargs: + msg = 'slug_field is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) + self.slug_field = kwargs.pop('slug_field', self.slug_field) + default_slug_kwarg = self.slug_url_kwarg or self.slug_field + self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) + + super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) + + def get_url(self, obj, view_name, request, format): + """ + Given an object, return the URL that hyperlinks to the object. + + May raise a `NoReverseMatch` if the `view_name` and `lookup_field` + attributes are not configured to correctly match the URL conf. + """ + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + + if self.pk_url_kwarg != 'pk': + # Only try pk if it has been explicitly set. + # Otherwise, the default `lookup_field = 'pk'` has us covered. + pk = obj.pk + kwargs = {self.pk_url_kwarg: pk} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + + slug = getattr(obj, self.slug_field, None) + if slug is not None: + # Only try slug if it corresponds to an attribute on the object. + kwargs = {self.slug_url_kwarg: slug} + try: + ret = reverse(view_name, kwargs=kwargs, request=request, format=format) + if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug': + # If the lookup succeeds using the default slug params, + # then `slug_field` is being used implicitly, and we + # we need to warn about the pending deprecation. + msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \ + 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + return ret + except NoReverseMatch: + pass + + raise NoReverseMatch() + + def get_object(self, queryset, view_name, view_args, view_kwargs): + """ + Return the object corresponding to a matched URL. + + Takes the matched URL conf arguments, and the queryset, and should + return an object instance, or raise an `ObjectDoesNotExist` exception. + """ + lookup = view_kwargs.get(self.lookup_field, None) + pk = view_kwargs.get(self.pk_url_kwarg, None) + slug = view_kwargs.get(self.slug_url_kwarg, None) + + if lookup is not None: + filter_kwargs = {self.lookup_field: lookup} + elif pk is not None: + filter_kwargs = {'pk': pk} + elif slug is not None: + filter_kwargs = {self.slug_field: slug} + else: + raise ObjectDoesNotExist() + + return queryset.get(**filter_kwargs) + + def to_native(self, obj): + view_name = self.view_name + request = self.context.get('request', None) + format = self.format or self.context.get('format', None) + + if request is None: + msg = ( + "Using `HyperlinkedRelatedField` without including the request " + "in the serializer context is deprecated. " + "Add `context={'request': request}` when instantiating " + "the serializer." + ) + warnings.warn(msg, DeprecationWarning, stacklevel=4) + + # If the object has not yet been saved then we cannot hyperlink to it. + if getattr(obj, 'pk', None) is None: + return + + # Return the hyperlink, or error if incorrectly configured. + try: + return self.get_url(obj, view_name, request, format) + except NoReverseMatch: + msg = ( + 'Could not resolve URL for hyperlinked relationship using ' + 'view name "%s". You may have failed to include the related ' + 'model in your API, or incorrectly configured the ' + '`lookup_field` attribute on this field.' + ) + raise Exception(msg % view_name) + + def from_native(self, value): + # Convert URL -> model instance pk + # TODO: Use values_list + queryset = self.queryset + if queryset is None: + raise Exception('Writable related fields must include a `queryset` argument') + + try: + http_prefix = value.startswith('http:') or value.startswith('https:') + except AttributeError: + msg = self.error_messages['incorrect_type'] + raise ValidationError(msg % type(value).__name__) + + if http_prefix: + # If needed convert absolute URLs to relative path + value = urlparse.urlparse(value).path + prefix = get_script_prefix() + if value.startswith(prefix): + value = '/' + value[len(prefix):] + + try: + match = resolve(value) + except Exception: + raise ValidationError(self.error_messages['no_match']) + + if match.view_name != self.view_name: + raise ValidationError(self.error_messages['incorrect_match']) + + try: + return self.get_object(queryset, match.view_name, + match.args, match.kwargs) + except (ObjectDoesNotExist, TypeError, ValueError): + raise ValidationError(self.error_messages['does_not_exist']) + + +class HyperlinkedIdentityField(Field): + """ + Represents the instance, or a property on the instance, using hyperlinking. + """ + lookup_field = 'pk' + read_only = True + + # These are all pending deprecation + pk_url_kwarg = 'pk' + slug_field = 'slug' + slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden + + def __init__(self, *args, **kwargs): + # TODO: Make view_name mandatory, and have the + # HyperlinkedModelSerializer set it on-the-fly + self.view_name = kwargs.pop('view_name', None) + # Optionally the format of the target hyperlink may be specified + self.format = kwargs.pop('format', None) + + self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) + + # These are pending deprecation + if 'pk_url_kwarg' in kwargs: + msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_url_kwarg' in kwargs: + msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + if 'slug_field' in kwargs: + msg = 'slug_field is pending deprecation. Use lookup_field instead.' + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + self.slug_field = kwargs.pop('slug_field', self.slug_field) + default_slug_kwarg = self.slug_url_kwarg or self.slug_field + self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg) + self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) + + super(HyperlinkedIdentityField, self).__init__(*args, **kwargs) + + def field_to_native(self, obj, field_name): + request = self.context.get('request', None) + format = self.context.get('format', None) + view_name = self.view_name or self.parent.opts.view_name + lookup_field = getattr(obj, self.lookup_field) + kwargs = {self.lookup_field: lookup_field} + + if request is None: + warnings.warn("Using `HyperlinkedIdentityField` without including the " + "request in the serializer context is deprecated. " + "Add `context={'request': request}` when instantiating the serializer.", + DeprecationWarning, stacklevel=4) + + # By default use whatever format is given for the current context + # unless the target is a different type to the source. + # + # Eg. Consider a HyperlinkedIdentityField pointing from a json + # representation to an html property of that representation... + # + # '/snippets/1/' should link to '/snippets/1/highlight/' + # ...but... + # '/snippets/1/.json' should link to '/snippets/1/highlight/.html' + if format and self.format and self.format != format: + format = self.format + + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + + slug = getattr(obj, self.slug_field, None) + + if not slug: + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + + kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + try: + return reverse(view_name, kwargs=kwargs, request=request, format=format) + except NoReverseMatch: + pass + + raise Exception('Could not resolve URL for field using view name "%s"' % view_name) + + +### Old-style many classes for backwards compat + +class ManyRelatedField(RelatedField): + def __init__(self, *args, **kwargs): + warnings.warn('`ManyRelatedField()` is deprecated. ' + 'Use `RelatedField(many=True)` instead.', + DeprecationWarning, stacklevel=2) + kwargs['many'] = True + super(ManyRelatedField, self).__init__(*args, **kwargs) + + +class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField): + def __init__(self, *args, **kwargs): + warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. ' + 'Use `PrimaryKeyRelatedField(many=True)` instead.', + DeprecationWarning, stacklevel=2) + kwargs['many'] = True + super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs) + + +class ManySlugRelatedField(SlugRelatedField): + def __init__(self, *args, **kwargs): + warnings.warn('`ManySlugRelatedField()` is deprecated. ' + 'Use `SlugRelatedField(many=True)` instead.', + DeprecationWarning, stacklevel=2) + kwargs['many'] = True + super(ManySlugRelatedField, self).__init__(*args, **kwargs) + + +class ManyHyperlinkedRelatedField(HyperlinkedRelatedField): + def __init__(self, *args, **kwargs): + warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. ' + 'Use `HyperlinkedRelatedField(many=True)` instead.', + DeprecationWarning, stacklevel=2) + kwargs['many'] = True + super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index a4ae717d..8361cd40 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -6,21 +6,26 @@ on the response, such as JSON encoded data or HTML output. REST framework also provides an HTML renderer the renders the browsable API. """ +from __future__ import unicode_literals + import copy import string +import json from django import forms from django.http.multipartparser import parse_header from django.template import RequestContext, loader, Template -from django.utils import simplejson as json +from django.utils.xmlutils import SimplerXMLGenerator +from rest_framework.compat import StringIO +from rest_framework.compat import six +from rest_framework.compat import smart_text from rest_framework.compat import yaml from rest_framework.exceptions import ConfigurationError from rest_framework.settings import api_settings from rest_framework.request import clone_request -from rest_framework.utils import dict2xml from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework import VERSION, status -from rest_framework import parsers +from rest_framework.utils.formatting import get_view_name, get_view_description +from rest_framework import exceptions, parsers, status, VERSION class BaseRenderer(object): @@ -53,14 +58,14 @@ class JSONRenderer(BaseRenderer): return '' # If 'indent' is provided in the context, then pretty print the result. - # E.g. If we're being called by the BrowseableAPIRenderer. + # E.g. If we're being called by the BrowsableAPIRenderer. renderer_context = renderer_context or {} indent = renderer_context.get('indent', None) if accepted_media_type: # If the media type looks like 'application/json; indent=4', # then pretty print the result. - base_media_type, params = parse_header(accepted_media_type) + base_media_type, params = parse_header(accepted_media_type.encode('ascii')) indent = params.get('indent', indent) try: indent = max(min(int(indent), 8), 0) @@ -86,7 +91,7 @@ class JSONPRenderer(JSONRenderer): Determine the name of the callback to wrap around the json output. """ request = renderer_context.get('request', None) - params = request and request.GET or {} + params = request and request.QUERY_PARAMS or {} return params.get(self.callback_parameter, self.default_callback) def render(self, data, accepted_media_type=None, renderer_context=None): @@ -100,7 +105,7 @@ class JSONPRenderer(JSONRenderer): callback = self.get_callback(renderer_context) json = super(JSONPRenderer, self).render(data, accepted_media_type, renderer_context) - return u"%s(%s);" % (callback, json) + return "%s(%s);" % (callback, json) class XMLRenderer(BaseRenderer): @@ -117,7 +122,38 @@ class XMLRenderer(BaseRenderer): """ if data is None: return '' - return dict2xml(data) + + stream = StringIO() + + xml = SimplerXMLGenerator(stream, "utf-8") + xml.startDocument() + xml.startElement("root", {}) + + self._to_xml(xml, data) + + xml.endElement("root") + xml.endDocument() + return stream.getvalue() + + def _to_xml(self, xml, data): + if isinstance(data, (list, tuple)): + for item in data: + xml.startElement("list-item", {}) + self._to_xml(xml, item) + xml.endElement("list-item") + + elif isinstance(data, dict): + for key, value in six.iteritems(data): + xml.startElement(key, {}) + self._to_xml(xml, value) + xml.endElement(key) + + elif data is None: + # Don't output any value + pass + + else: + xml.characters(smart_text(data)) class YAMLRenderer(BaseRenderer): @@ -133,6 +169,8 @@ class YAMLRenderer(BaseRenderer): """ Renders *obj* into serialized YAML. """ + assert yaml, 'YAMLRenderer requires pyyaml to be installed' + if data is None: return '' @@ -215,7 +253,7 @@ class TemplateHTMLRenderer(BaseRenderer): try: # Try to find an appropriate error template return self.resolve_template(template_names) - except: + except Exception: # Fall back to using eg '404 Not Found' return Template('%d %s' % (response.status_code, response.status_text.title())) @@ -297,12 +335,10 @@ class BrowsableAPIRenderer(BaseRenderer): if not api_settings.FORM_METHOD_OVERRIDE: return # Cannot use form overloading - request = clone_request(request, method) try: - if not view.has_permission(request, obj): - return # Don't have permission - except: - return # Don't have permission and exception explicitly raise + view.check_permissions(request) + except exceptions.APIException: + return False # Doesn't have permissions return True def serializer_to_form_fields(self, serializer): @@ -333,8 +369,33 @@ class BrowsableAPIRenderer(BaseRenderer): kwargs['label'] = k fields[k] = v.form_field_class(**kwargs) + return fields + def _get_form(self, view, method, request): + # We need to impersonate a request with the correct method, + # so that eg. any dynamic get_serializer_class methods return the + # correct form for each method. + restore = view.request + request = clone_request(request, method) + view.request = request + try: + return self.get_form(view, method, request) + finally: + view.request = restore + + def _get_raw_data_form(self, view, method, request, media_types): + # We need to impersonate a request with the correct method, + # so that eg. any dynamic get_serializer_class methods return the + # correct form for each method. + restore = view.request + request = clone_request(request, method) + view.request = request + try: + return self.get_raw_data_form(view, method, request, media_types) + finally: + view.request = restore + def get_form(self, view, method, request): """ Get a form, possibly bound to either the input or output data. @@ -345,24 +406,23 @@ class BrowsableAPIRenderer(BaseRenderer): if not self.show_form_for_method(view, method, request, obj): return - if method == 'DELETE' or method == 'OPTIONS': + if method in ('DELETE', 'OPTIONS'): return True # Don't actually need to return a form if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes: - media_types = [parser.media_type for parser in view.parser_classes] - return self.get_generic_content_form(media_types) + return serializer = view.get_serializer(instance=obj) fields = self.serializer_to_form_fields(serializer) # Creating an on the fly form see: # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python - OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields) + OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields) data = (obj is not None) and serializer.data or None form_instance = OnTheFlyForm(data) return form_instance - def get_generic_content_form(self, media_types): + def get_raw_data_form(self, view, method, request, media_types): """ Returns a form that allows for arbitrary content types to be tunneled via standard HTML forms. @@ -375,6 +435,11 @@ class BrowsableAPIRenderer(BaseRenderer): and api_settings.FORM_CONTENTTYPE_OVERRIDE): return None + # Check permissions + obj = getattr(view, 'object', None) + if not self.show_form_for_method(view, method, request, obj): + return + content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE content_field = api_settings.FORM_CONTENT_OVERRIDE choices = [(media_type, media_type) for media_type in media_types] @@ -386,7 +451,7 @@ class BrowsableAPIRenderer(BaseRenderer): super(GenericContentForm, self).__init__() self.fields[content_type_field] = forms.ChoiceField( - label='Content Type', + label='Media type', choices=choices, initial=initial ) @@ -398,16 +463,13 @@ class BrowsableAPIRenderer(BaseRenderer): return GenericContentForm() def get_name(self, view): - try: - return view.get_name() - except AttributeError: - return view.__doc__ + return get_view_name(view.__class__, getattr(view, 'suffix', None)) def get_description(self, view): - try: - return view.get_description(html=True) - except AttributeError: - return view.__doc__ + return get_view_description(view.__class__, html=True) + + def get_breadcrumbs(self, request): + return get_breadcrumbs(request.path) def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -422,18 +484,25 @@ class BrowsableAPIRenderer(BaseRenderer): view = renderer_context['view'] request = renderer_context['request'] response = renderer_context['response'] + media_types = [parser.media_type for parser in view.parser_classes] renderer = self.get_default_renderer(view) content = self.get_content(renderer, data, accepted_media_type, renderer_context) - put_form = self.get_form(view, 'PUT', request) - post_form = self.get_form(view, 'POST', request) - delete_form = self.get_form(view, 'DELETE', request) - options_form = self.get_form(view, 'OPTIONS', request) + put_form = self._get_form(view, 'PUT', request) + post_form = self._get_form(view, 'POST', request) + patch_form = self._get_form(view, 'PATCH', request) + delete_form = self._get_form(view, 'DELETE', request) + options_form = self._get_form(view, 'OPTIONS', request) + + raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types) + raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types) + raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types) + raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form name = self.get_name(view) description = self.get_description(view) - breadcrumb_list = get_breadcrumbs(request.path) + breadcrumb_list = self.get_breadcrumbs(request) template = loader.get_template(self.template) context = RequestContext(request, { @@ -447,10 +516,18 @@ class BrowsableAPIRenderer(BaseRenderer): 'breadcrumblist': breadcrumb_list, 'allowed_methods': view.allowed_methods, 'available_formats': [renderer.format for renderer in view.renderer_classes], + 'put_form': put_form, 'post_form': post_form, + 'patch_form': patch_form, 'delete_form': delete_form, 'options_form': options_form, + + 'raw_data_put_form': raw_data_put_form, + 'raw_data_post_form': raw_data_post_form, + 'raw_data_patch_form': raw_data_patch_form, + 'raw_data_put_or_patch_form': raw_data_put_or_patch_form, + 'api_settings': api_settings }) diff --git a/rest_framework/request.py b/rest_framework/request.py index b7133608..a434659c 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -1,18 +1,21 @@ """ -The :mod:`request` module provides a :class:`Request` class used to wrap the standard `request` -object received in all the views. +The Request class is used as a wrapper around the standard request object. The wrapped request then offers a richer API, in particular : - content automatically parsed according to `Content-Type` header, - and available as :meth:`.DATA<Request.DATA>` + and available as `request.DATA` - full support of PUT method, including support for file uploads - form overloading of HTTP method, content type and content """ -from StringIO import StringIO - +from __future__ import unicode_literals +from django.conf import settings +from django.http import QueryDict from django.http.multipartparser import parse_header +from django.utils.datastructures import MultiValueDict +from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions +from rest_framework.compat import BytesIO from rest_framework.settings import api_settings @@ -20,7 +23,7 @@ def is_form_media_type(media_type): """ Return True if the media type is a valid form media type. """ - base_media_type, params = parse_header(media_type) + base_media_type, params = parse_header(media_type.encode(HTTP_HEADER_ENCODING)) return (base_media_type == 'application/x-www-form-urlencoded' or base_media_type == 'multipart/form-data') @@ -42,10 +45,11 @@ def clone_request(request, method): Internal helper method to clone a request, replacing with a different HTTP method. Used for checking permissions against other methods. """ - ret = Request(request._request, - request.parsers, - request.authenticators, - request.parser_context) + ret = Request(request=request._request, + parsers=request.parsers, + authenticators=request.authenticators, + negotiator=request.negotiator, + parser_context=request.parser_context) ret._data = request._data ret._files = request._files ret._content_type = request._content_type @@ -55,6 +59,8 @@ def clone_request(request, method): ret._user = request._user if hasattr(request, '_auth'): ret._auth = request._auth + if hasattr(request, '_authenticator'): + ret._authenticator = request._authenticator return ret @@ -90,6 +96,7 @@ class Request(object): if self.parser_context is None: self.parser_context = {} self.parser_context['request'] = self + self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET def _default_negotiator(self): return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() @@ -166,17 +173,17 @@ class Request(object): by the authentication classes provided to the request. """ if not hasattr(self, '_user'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._user @user.setter def user(self, value): - """ - Sets the user on the current request. This is necessary to maintain - compatilbility with django.contrib.auth where the user proprety is - set in the login and logout functions. - """ - self._user = value + """ + Sets the user on the current request. This is necessary to maintain + compatilbility with django.contrib.auth where the user proprety is + set in the login and logout functions. + """ + self._user = value @property def auth(self): @@ -185,7 +192,7 @@ class Request(object): request, such as an authentication token. """ if not hasattr(self, '_auth'): - self._user, self._auth = self._authenticate() + self._authenticator, self._user, self._auth = self._authenticate() return self._auth @auth.setter @@ -196,6 +203,16 @@ class Request(object): """ self._auth = value + @property + def successful_authenticator(self): + """ + Return the instance of the authentication instance class that was used + to authenticate the request, or `None`. + """ + if not hasattr(self, '_authenticator'): + self._authenticator, self._user, self._auth = self._authenticate() + return self._authenticator + def _load_data_and_files(self): """ Parses the request content into self.DATA and self.FILES. @@ -213,11 +230,17 @@ class Request(object): """ self._content_type = self.META.get('HTTP_CONTENT_TYPE', self.META.get('CONTENT_TYPE', '')) + self._perform_form_overloading() - # if the HTTP method was not overloaded, we take the raw HTTP method + if not _hasattr(self, '_method'): self._method = self._request.method + if self._method == 'POST': + # Allow X-HTTP-METHOD-OVERRIDE header + self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE', + self._method) + def _load_stream(self): """ Return the content body of the request, as a stream. @@ -233,7 +256,7 @@ class Request(object): elif hasattr(self._request, 'read'): self._stream = self._request else: - self._stream = StringIO(self.raw_post_data) + self._stream = BytesIO(self.raw_post_data) def _perform_form_overloading(self): """ @@ -268,7 +291,7 @@ class Request(object): self._CONTENT_PARAM in self._data and self._CONTENTTYPE_PARAM in self._data): self._content_type = self._data[self._CONTENTTYPE_PARAM] - self._stream = StringIO(self._data[self._CONTENT_PARAM]) + self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(HTTP_HEADER_ENCODING)) self._data, self._files = (Empty, Empty) def _parse(self): @@ -281,7 +304,9 @@ class Request(object): media_type = self.content_type if stream is None or media_type is None: - return (None, None) + empty_data = QueryDict('', self._request._encoding) + empty_files = MultiValueDict() + return (empty_data, empty_files) parser = self.negotiator.select_parser(self, self.parsers) @@ -295,25 +320,28 @@ class Request(object): try: return (parsed.data, parsed.files) except AttributeError: - return (parsed, None) + empty_files = MultiValueDict() + return (parsed, empty_files) def _authenticate(self): """ - Attempt to authenticate the request using each authentication instance in turn. - Returns a two-tuple of (user, authtoken). + Attempt to authenticate the request using each authentication instance + in turn. + Returns a three-tuple of (authenticator, user, authtoken). """ for authenticator in self.authenticators: user_auth_tuple = authenticator.authenticate(self) if not user_auth_tuple is None: - return user_auth_tuple + user, auth = user_auth_tuple + return (authenticator, user, auth) return self._not_authenticated() def _not_authenticated(self): """ - Return a two-tuple of (user, authtoken), representing an - unauthenticated request. + Return a three-tuple of (authenticator, user, authtoken), representing + an unauthenticated request. - By default this will be (AnonymousUser, None). + By default this will be (None, AnonymousUser, None). """ if api_settings.UNAUTHENTICATED_USER: user = api_settings.UNAUTHENTICATED_USER() @@ -325,7 +353,7 @@ class Request(object): else: auth = None - return (user, auth) + return (None, user, auth) def __getattr__(self, attr): """ diff --git a/rest_framework/response.py b/rest_framework/response.py index be78c43a..26e4ab37 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -1,5 +1,13 @@ +""" +The Response class in REST framework is similiar to HTTPResponse, except that +it is initialized with unrendered data, instead of a pre-rendered string. + +The appropriate renderer is called during Django's template response rendering. +""" +from __future__ import unicode_literals from django.core.handlers.wsgi import STATUS_CODE_TEXT from django.template.response import SimpleTemplateResponse +from rest_framework.compat import six class Response(SimpleTemplateResponse): @@ -22,9 +30,9 @@ class Response(SimpleTemplateResponse): self.data = data self.template_name = template_name self.exception = exception - + if headers: - for name,value in headers.iteritems(): + for name, value in six.iteritems(headers): self[name] = value @property diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index c9db02f0..a51b07f5 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -1,6 +1,7 @@ """ Provide reverse functions that return fully qualified URLs """ +from __future__ import unicode_literals from django.core.urlresolvers import reverse as django_reverse from django.utils.functional import lazy diff --git a/rest_framework/routers.py b/rest_framework/routers.py new file mode 100644 index 00000000..dba104c3 --- /dev/null +++ b/rest_framework/routers.py @@ -0,0 +1,249 @@ +""" +Routers provide a convenient and consistent way of automatically +determining the URL conf for your API. + +They are used by simply instantiating a Router class, and then registering +all the required ViewSets with that router. + +For example, you might have a `urls.py` that looks something like this: + + router = routers.DefaultRouter() + router.register('users', UserViewSet, 'user') + router.register('accounts', AccountViewSet, 'account') + + urlpatterns = router.urls +""" +from __future__ import unicode_literals + +from collections import namedtuple +from rest_framework import views +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.urlpatterns import format_suffix_patterns + + +Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) + + +def replace_methodname(format_string, methodname): + """ + Partially format a format_string, swapping out any + '{methodname}' or '{methodnamehyphen}' components. + """ + methodnamehyphen = methodname.replace('_', '-') + ret = format_string + ret = ret.replace('{methodname}', methodname) + ret = ret.replace('{methodnamehyphen}', methodnamehyphen) + return ret + + +class BaseRouter(object): + def __init__(self): + self.registry = [] + + def register(self, prefix, viewset, base_name=None): + if base_name is None: + base_name = self.get_default_base_name(viewset) + self.registry.append((prefix, viewset, base_name)) + + def get_default_base_name(self, viewset): + """ + If `base_name` is not specified, attempt to automatically determine + it from the viewset. + """ + raise NotImplemented('get_default_base_name must be overridden') + + def get_urls(self): + """ + Return a list of URL patterns, given the registered viewsets. + """ + raise NotImplemented('get_urls must be overridden') + + @property + def urls(self): + if not hasattr(self, '_urls'): + self._urls = patterns('', *self.get_urls()) + return self._urls + + +class SimpleRouter(BaseRouter): + routes = [ + # List route. + Route( + url=r'^{prefix}/$', + mapping={ + 'get': 'list', + 'post': 'create' + }, + name='{basename}-list', + initkwargs={'suffix': 'List'} + ), + # Detail route. + Route( + url=r'^{prefix}/{lookup}/$', + mapping={ + 'get': 'retrieve', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy' + }, + name='{basename}-detail', + initkwargs={'suffix': 'Instance'} + ), + # Dynamically generated routes. + # Generated using @action or @link decorators on methods of the viewset. + Route( + url=r'^{prefix}/{lookup}/{methodname}/$', + mapping={ + '{httpmethod}': '{methodname}', + }, + name='{basename}-{methodnamehyphen}', + initkwargs={} + ), + ] + + def get_default_base_name(self, viewset): + """ + If `base_name` is not specified, attempt to automatically determine + it from the viewset. + """ + model_cls = getattr(viewset, 'model', None) + queryset = getattr(viewset, 'queryset', None) + if model_cls is None and queryset is not None: + model_cls = queryset.model + + assert model_cls, '`name` not argument not specified, and could ' \ + 'not automatically determine the name from the viewset, as ' \ + 'it does not have a `.model` or `.queryset` attribute.' + + return model_cls._meta.object_name.lower() + + def get_routes(self, viewset): + """ + Augment `self.routes` with any dynamically generated routes. + + Returns a list of the Route namedtuple. + """ + + # Determine any `@action` or `@link` decorated methods on the viewset + dynamic_routes = [] + for methodname in dir(viewset): + attr = getattr(viewset, methodname) + httpmethod = getattr(attr, 'bind_to_method', None) + if httpmethod: + dynamic_routes.append((httpmethod, methodname)) + + ret = [] + for route in self.routes: + if route.mapping == {'{httpmethod}': '{methodname}'}: + # Dynamic routes (@link or @action decorator) + for httpmethod, methodname in dynamic_routes: + initkwargs = route.initkwargs.copy() + initkwargs.update(getattr(viewset, methodname).kwargs) + ret.append(Route( + url=replace_methodname(route.url, methodname), + mapping={httpmethod: methodname}, + name=replace_methodname(route.name, methodname), + initkwargs=initkwargs, + )) + else: + # Standard route + ret.append(route) + + return ret + + def get_method_map(self, viewset, method_map): + """ + Given a viewset, and a mapping of http methods to actions, + return a new mapping which only includes any mappings that + are actually implemented by the viewset. + """ + bound_methods = {} + for method, action in method_map.items(): + if hasattr(viewset, action): + bound_methods[method] = action + return bound_methods + + def get_lookup_regex(self, viewset): + """ + Given a viewset, return the portion of URL regex that is used + to match against a single instance. + """ + base_regex = '(?P<{lookup_field}>[^/]+)' + lookup_field = getattr(viewset, 'lookup_field', 'pk') + return base_regex.format(lookup_field=lookup_field) + + def get_urls(self): + """ + Use the registered viewsets to generate a list of URL patterns. + """ + ret = [] + + for prefix, viewset, basename in self.registry: + lookup = self.get_lookup_regex(viewset) + routes = self.get_routes(viewset) + + for route in routes: + + # Only actions which actually exist on the viewset will be bound + mapping = self.get_method_map(viewset, route.mapping) + if not mapping: + continue + + # Build the url pattern + regex = route.url.format(prefix=prefix, lookup=lookup) + view = viewset.as_view(mapping, **route.initkwargs) + name = route.name.format(basename=basename) + ret.append(url(regex, view, name=name)) + + return ret + + +class DefaultRouter(SimpleRouter): + """ + The default router extends the SimpleRouter, but also adds in a default + API root view, and adds format suffix patterns to the URLs. + """ + include_root_view = True + include_format_suffixes = True + + def get_api_root_view(self): + """ + Return a view to use as the API root. + """ + api_root_dict = {} + list_name = self.routes[0].name + for prefix, viewset, basename in self.registry: + api_root_dict[prefix] = list_name.format(basename=basename) + + class APIRoot(views.APIView): + _ignore_model_permissions = True + + def get(self, request, format=None): + ret = {} + for key, url_name in api_root_dict.items(): + ret[key] = reverse(url_name, request=request, format=format) + return Response(ret) + + return APIRoot.as_view() + + def get_urls(self): + """ + Generate the list of URL patterns, including a default root view + for the API, and appending `.json` style format suffixes. + """ + urls = [] + + if self.include_root_view: + root_url = url(r'^$', self.get_api_root_view(), name='api-root') + urls.append(root_url) + + default_urls = super(DefaultRouter, self).get_urls() + urls.extend(default_urls) + + if self.include_format_suffixes: + urls = format_suffix_patterns(urls) + + return urls diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py index 0ce379eb..ce11b213 100755 --- a/rest_framework/runtests/runcoverage.py +++ b/rest_framework/runtests/runcoverage.py @@ -8,6 +8,9 @@ Useful tool to run the test suite for rest_framework and generate a coverage rep # http://code.djangoproject.com/svn/django/trunk/tests/runtests.py import os import sys + +# fix sys path so we don't need to setup PYTHONPATH +sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings' from coverage import coverage @@ -49,12 +52,21 @@ def main(): if os.path.basename(path) in ['tests', 'runtests', 'migrations']: continue - # Drop the compat module from coverage, since we're not interested in the coverage - # of a module which is specifically for resolving environment dependant imports. + # Drop the compat and six modules from coverage, since we're not interested in the coverage + # of modules which are specifically for resolving environment dependant imports. # (Because we'll end up getting different coverage reports for it for each environment) if 'compat.py' in files: files.remove('compat.py') + if 'six.py' in files: + files.remove('six.py') + + # Same applies to template tags module. + # This module has to include branching on Django versions, + # so it's never possible for it to have full coverage. + if 'rest_framework.py' in files: + files.remove('rest_framework.py') + cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')]) cov.report(cov_files) diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py index 729ef26a..4a333fb3 100755 --- a/rest_framework/runtests/runtests.py +++ b/rest_framework/runtests/runtests.py @@ -5,11 +5,9 @@ # http://code.djangoproject.com/svn/django/trunk/tests/runtests.py import os import sys -""" -Need to fix sys path so following works without specifically messing with PYTHONPATH -python ./rest_framework/runtests/runtests.py -""" -sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) + +# fix sys path so we don't need to setup PYTHONPATH +sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings' from django.conf import settings @@ -35,7 +33,7 @@ def main(): elif len(sys.argv) == 1: test_case = '' else: - print usage() + print(usage()) sys.exit(1) failures = test_runner.run_tests(['tests' + test_case]) diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index dd5d9dc3..9b519f27 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -97,11 +97,41 @@ INSTALLED_APPS = ( # 'django.contrib.admindocs', 'rest_framework', 'rest_framework.authtoken', - 'rest_framework.tests' + 'rest_framework.tests', ) +# OAuth is optional and won't work if there is no oauth_provider & oauth2 +try: + import oauth_provider + import oauth2 +except ImportError: + pass +else: + INSTALLED_APPS += ( + 'oauth_provider', + ) + +try: + import provider +except ImportError: + pass +else: + INSTALLED_APPS += ( + 'provider', + 'provider.oauth2', + ) + STATIC_URL = '/static/' +PASSWORD_HASHERS = ( + 'django.contrib.auth.hashers.SHA1PasswordHasher', + 'django.contrib.auth.hashers.PBKDF2PasswordHasher', + 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher', + 'django.contrib.auth.hashers.BCryptPasswordHasher', + 'django.contrib.auth.hashers.MD5PasswordHasher', + 'django.contrib.auth.hashers.CryptPasswordHasher', +) + import django if django.VERSION < (1, 3): diff --git a/rest_framework/runtests/urls.py b/rest_framework/runtests/urls.py index 4b7da787..ed5baeae 100644 --- a/rest_framework/runtests/urls.py +++ b/rest_framework/runtests/urls.py @@ -1,7 +1,7 @@ """ Blank URLConf just to keep runtests.py happy. """ -from django.conf.urls.defaults import * +from rest_framework.compat import patterns urlpatterns = patterns('', ) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e9bc25e4..942ab399 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1,11 +1,25 @@ +""" +Serializers and ModelSerializers are similar to Forms and ModelForms. +Unlike forms, they are not constrained to dealing with HTML output, and +form encoded input. + +Serialization in REST framework is a two-phase process: + +1. Serializers marshal between complex types like model instances, and +python primatives. +2. The process of marshalling between python primatives and request and +response content is handled by parsers and renderers. +""" +from __future__ import unicode_literals import copy import datetime import types from decimal import Decimal +from django.core.paginator import Page from django.db import models from django.forms import widgets from django.utils.datastructures import SortedDict -from rest_framework.compat import get_concrete_model +from rest_framework.compat import get_concrete_model, six # Note: We do the following so that users of the framework can use this style: # @@ -14,10 +28,29 @@ from rest_framework.compat import get_concrete_model # This helps keep the seperation between model fields, form fields, and # serializer fields more explicit. - +from rest_framework.relations import * from rest_framework.fields import * +class NestedValidationError(ValidationError): + """ + The default ValidationError behavior is to stringify each item in the list + if the messages are a list of error messages. + + In the case of nested serializers, where the parent has many children, + then the child's `serializer.errors` will be a list of dicts. In the case + of a single child, the `serializer.errors` will be a dict. + + We need to override the default behavior to get properly nested error dicts. + """ + + def __init__(self, message): + if isinstance(message, dict): + self.messages = [message] + else: + self.messages = message + + class DictWithMetadata(dict): """ A dict-like object, that can have additional properties attached. @@ -25,20 +58,23 @@ class DictWithMetadata(dict): def __getstate__(self): """ Used by pickle (e.g., caching). - Overriden to remove metadata from the dict, since it shouldn't be pickled - and may in some instances be unpickleable. + Overriden to remove the metadata from the dict, since it shouldn't be + pickled and may in some instances be unpickleable. """ - # return an instance of the first dict in MRO that isn't a DictWithMetadata - for base in self.__class__.__mro__: - if not isinstance(base, DictWithMetadata) and isinstance(base, dict): - return base(self) + return dict(self) -class SortedDictWithMetadata(SortedDict, DictWithMetadata): +class SortedDictWithMetadata(SortedDict): """ A sorted dict-like object, that can have additional properties attached. """ - pass + def __getstate__(self): + """ + Used by pickle (e.g., caching). + Overriden to remove the metadata from the dict, since it shouldn't be + pickle and may in some instances be unpickleable. + """ + return SortedDict(self).__dict__ def _is_protected_type(obj): @@ -63,7 +99,7 @@ def _get_declared_fields(bases, attrs): Note that all fields from the base classes are used. """ fields = [(field_name, attrs.pop(field_name)) - for field_name, obj in attrs.items() + for field_name, obj in list(six.iteritems(attrs)) if isinstance(obj, Field)] fields.sort(key=lambda x: x[1].creation_counter) @@ -72,7 +108,7 @@ def _get_declared_fields(bases, attrs): # in order to maintain the correct order of fields. for base in bases[::-1]: if hasattr(base, 'base_fields'): - fields = base.base_fields.items() + fields + fields = list(base.base_fields.items()) + fields return SortedDict(fields) @@ -93,20 +129,27 @@ class SerializerOptions(object): self.exclude = getattr(meta, 'exclude', ()) -class BaseSerializer(Field): +class BaseSerializer(WritableField): + """ + This is the Serializer implementation. + We need to implement it as `BaseSerializer` due to metaclass magicks. + """ class Meta(object): pass _options_class = SerializerOptions - _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations. + _dict_class = SortedDictWithMetadata def __init__(self, instance=None, data=None, files=None, - context=None, partial=False, **kwargs): + context=None, partial=False, many=None, + allow_add_remove=False, **kwargs): super(BaseSerializer, self).__init__(**kwargs) self.opts = self._options_class(self.Meta) self.parent = None self.root = None self.partial = partial + self.many = many + self.allow_add_remove = allow_add_remove self.context = context or {} @@ -118,6 +161,13 @@ class BaseSerializer(Field): self._data = None self._files = None self._errors = None + self._deleted = None + + if many and instance is not None and not hasattr(instance, '__iter__'): + raise ValueError('instance should be a queryset or other iterable with many=True') + + if allow_add_remove and not many: + raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') ##### # Methods to determine which fields to use when (de)serializing objects. @@ -150,6 +200,7 @@ class BaseSerializer(Field): # If 'fields' is specified, use those fields, in that order. if self.opts.fields: + assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' new = SortedDict() for key in self.opts.fields: new[key] = ret[key] @@ -157,6 +208,7 @@ class BaseSerializer(Field): # Remove anything in 'exclude' if self.opts.exclude: + assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' for key in self.opts.exclude: ret.pop(key, None) @@ -166,18 +218,6 @@ class BaseSerializer(Field): return ret ##### - # Field methods - used when the serializer class is itself used as a field. - - def initialize(self, parent, field_name): - """ - Same behaviour as usual Field, except that we need to keep track - of state so that we can deal with handling maximum depth. - """ - super(BaseSerializer, self).initialize(parent, field_name) - if parent.opts.depth: - self.opts.depth = parent.opts.depth - 1 - - ##### # Methods to convert or revert from objects <--> primitive representations. def get_field_key(self, field_name): @@ -186,28 +226,17 @@ class BaseSerializer(Field): """ return field_name - def convert_object(self, obj): - """ - Core of serialization. - Convert an object into a dictionary of serialized field values. - """ - ret = self._dict_class() - ret.fields = {} - - for field_name, field in self.fields.items(): - field.initialize(parent=self, field_name=field_name) - key = self.get_field_key(field_name) - value = field.field_to_native(obj, field_name) - ret[key] = value - ret.fields[key] = field - return ret - def restore_fields(self, data, files): """ Core of deserialization, together with `restore_object`. Converts a dictionary of data into a dictionary of deserialized fields. """ reverted_data = {} + + if data is not None and not isinstance(data, dict): + self._errors['non_field_errors'] = ['Invalid data'] + return None + for field_name, field in self.fields.items(): field.initialize(parent=self, field_name=field_name) try: @@ -222,6 +251,8 @@ class BaseSerializer(Field): Run `validate_<fieldname>()` and `validate()` methods on the serializer """ for field_name, field in self.fields.items(): + if field_name in self._errors: + continue try: validate_method = getattr(self, 'validate_%s' % field_name, None) if validate_method: @@ -266,18 +297,21 @@ class BaseSerializer(Field): """ Serialize objects -> primitives. """ - if hasattr(obj, '__iter__'): - return [self.convert_object(item) for item in obj] - return self.convert_object(obj) + ret = self._dict_class() + ret.fields = {} + + for field_name, field in self.fields.items(): + field.initialize(parent=self, field_name=field_name) + key = self.get_field_key(field_name) + value = field.field_to_native(obj, field_name) + ret[key] = value + ret.fields[key] = field + return ret def from_native(self, data, files): """ Deserialize primitives -> objects. """ - if hasattr(data, '__iter__') and not isinstance(data, dict): - # TODO: error data when deserializing lists - return (self.from_native(item) for item in data) - self._errors = {} if data is not None or files is not None: attrs = self.restore_fields(data, files) @@ -290,24 +324,91 @@ class BaseSerializer(Field): def field_to_native(self, obj, field_name): """ - Override default so that we can apply ModelSerializer as a nested - field to relationships. + Override default so that the serializer can be used as a nested field + across relationships. + """ + if self.source == '*': + return self.to_native(obj) + + try: + source = self.source or field_name + value = obj + + for component in source.split('.'): + value = get_component(value, component) + if value is None: + break + except ObjectDoesNotExist: + return None + + if is_simple_callable(getattr(value, 'all', None)): + return [self.to_native(item) for item in value.all()] + + if value is None: + return None + + if self.many is not None: + many = self.many + else: + many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type)) + + if many: + return [self.to_native(item) for item in value] + return self.to_native(value) + + def field_from_native(self, data, files, field_name, into): + """ + Override default so that the serializer can be used as a writable + nested field across relationships. """ - if self.source: - for component in self.source.split('.'): - obj = getattr(obj, component) - if is_simple_callable(obj): - obj = obj() + if self.read_only: + return + + try: + value = data[field_name] + except KeyError: + if self.default is not None and not self.partial: + # Note: partial updates shouldn't set defaults + value = copy.deepcopy(self.default) + else: + if self.required: + raise ValidationError(self.error_messages['required']) + return + + # Set the serializer object if it exists + obj = getattr(self.parent.object, field_name) if self.parent.object else None + + if value in (None, ''): + into[(self.source or field_name)] = None else: - obj = getattr(obj, field_name) - if is_simple_callable(obj): - obj = value() + kwargs = { + 'instance': obj, + 'data': value, + 'context': self.context, + 'partial': self.partial, + 'many': self.many + } + serializer = self.__class__(**kwargs) + + if serializer.is_valid(): + into[self.source or field_name] = serializer.object + else: + # Propagate errors up to our parent + raise NestedValidationError(serializer.errors) - # If the object has an "all" method, assume it's a relationship - if is_simple_callable(getattr(obj, 'all', None)): - return [self.to_native(item) for item in obj.all()] + def get_identity(self, data): + """ + This hook is required for bulk update. + It is used to determine the canonical identity of a given object. - return self.to_native(obj) + Note that the data has not been validated at this point, so we need + to make sure that we catch any cases of incorrect datatypes being + passed to this method. + """ + try: + return data.get('id', None) + except AttributeError: + return None @property def errors(self): @@ -316,9 +417,57 @@ class BaseSerializer(Field): setting self.object if no errors occurred. """ if self._errors is None: - obj = self.from_native(self.init_data, self.init_files) + data, files = self.init_data, self.init_files + + if self.many is not None: + many = self.many + else: + many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) + if many: + warnings.warn('Implict list/queryset serialization is deprecated. ' + 'Use the `many=True` flag when instantiating the serializer.', + DeprecationWarning, stacklevel=3) + + if many: + ret = [] + errors = [] + update = self.object is not None + + if update: + # If this is a bulk update we need to map all the objects + # to a canonical identity so we can determine which + # individual object is being updated for each item in the + # incoming data + objects = self.object + identities = [self.get_identity(self.to_native(obj)) for obj in objects] + identity_to_objects = dict(zip(identities, objects)) + + if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): + for item in data: + if update: + # Determine which object we're updating + identity = self.get_identity(item) + self.object = identity_to_objects.pop(identity, None) + if self.object is None and not self.allow_add_remove: + ret.append(None) + errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) + continue + + ret.append(self.from_native(item, None)) + errors.append(self._errors) + + if update: + self._deleted = identity_to_objects.values() + + self._errors = any(errors) and errors or [] + else: + self._errors = {'non_field_errors': ['Expected a list of items.']} + else: + ret = self.from_native(data, files) + if not self._errors: - self.object = obj + self.object = ret + return self._errors def is_valid(self): @@ -326,20 +475,51 @@ class BaseSerializer(Field): @property def data(self): + """ + Returns the serialized data on the serializer. + """ if self._data is None: - self._data = self.to_native(self.object) + obj = self.object + + if self.many is not None: + many = self.many + else: + many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) + if many: + warnings.warn('Implict list/queryset serialization is deprecated. ' + 'Use the `many=True` flag when instantiating the serializer.', + DeprecationWarning, stacklevel=2) + + if many: + self._data = [self.to_native(item) for item in obj] + else: + self._data = self.to_native(obj) + return self._data - def save(self): + def save_object(self, obj, **kwargs): + obj.save(**kwargs) + + def delete_object(self, obj): + obj.delete() + + def save(self, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() + if isinstance(self.object, list): + [self.save_object(item, **kwargs) for item in self.object] + else: + self.save_object(self.object, **kwargs) + + if self.allow_add_remove and self._deleted: + [self.delete_object(item) for item in self._deleted] + return self.object -class Serializer(BaseSerializer): - __metaclass__ = SerializerMetaclass +class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): + pass class ModelSerializerOptions(SerializerOptions): @@ -358,43 +538,125 @@ class ModelSerializer(Serializer): """ _options_class = ModelSerializerOptions + field_mapping = { + models.AutoField: IntegerField, + models.FloatField: FloatField, + models.IntegerField: IntegerField, + models.PositiveIntegerField: IntegerField, + models.SmallIntegerField: IntegerField, + models.PositiveSmallIntegerField: IntegerField, + models.DateTimeField: DateTimeField, + models.DateField: DateField, + models.TimeField: TimeField, + models.DecimalField: DecimalField, + models.EmailField: EmailField, + models.CharField: CharField, + models.URLField: URLField, + models.SlugField: SlugField, + models.TextField: CharField, + models.CommaSeparatedIntegerField: CharField, + models.BooleanField: BooleanField, + models.FileField: FileField, + models.ImageField: ImageField, + } + def get_default_fields(self): """ Return all the fields that should be serialized for the model. """ cls = self.opts.model + assert cls is not None, \ + "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ opts = get_concrete_model(cls)._meta + ret = SortedDict() + nested = bool(self.opts.depth) + + # Deal with adding the primary key field pk_field = opts.pk - while pk_field.rel: + while pk_field.rel and pk_field.rel.parent_link: + # If model is a child via multitable inheritance, use parent's pk pk_field = pk_field.rel.to._meta.pk - fields = [pk_field] - fields += [field for field in opts.fields if field.serialize] - fields += [field for field in opts.many_to_many if field.serialize] - ret = SortedDict() - nested = bool(self.opts.depth) - is_pk = True # First field in the list is the pk - - for model_field in fields: - if is_pk: - field = self.get_pk_field(model_field) - is_pk = False - elif model_field.rel and nested: - field = self.get_nested_field(model_field) - elif model_field.rel: + field = self.get_pk_field(pk_field) + if field: + ret[pk_field.name] = field + + # Deal with forward relationships + forward_rels = [field for field in opts.fields if field.serialize] + forward_rels += [field for field in opts.many_to_many if field.serialize] + + for model_field in forward_rels: + if model_field.rel: to_many = isinstance(model_field, models.fields.related.ManyToManyField) - field = self.get_related_field(model_field, to_many=to_many) + related_model = model_field.rel.to + + if model_field.rel and nested: + if len(inspect.getargspec(self.get_nested_field).args) == 2: + warnings.warn( + 'The `get_nested_field(model_field)` call signature ' + 'is due to be deprecated. ' + 'Use `get_nested_field(model_field, related_model, ' + 'to_many) instead', + PendingDeprecationWarning + ) + field = self.get_nested_field(model_field) + else: + field = self.get_nested_field(model_field, related_model, to_many) + elif model_field.rel: + if len(inspect.getargspec(self.get_nested_field).args) == 3: + warnings.warn( + 'The `get_related_field(model_field, to_many)` call ' + 'signature is due to be deprecated. ' + 'Use `get_related_field(model_field, related_model, ' + 'to_many) instead', + PendingDeprecationWarning + ) + field = self.get_related_field(model_field, to_many=to_many) + else: + field = self.get_related_field(model_field, related_model, to_many) else: field = self.get_field(model_field) if field: ret[model_field.name] = field + # Deal with reverse relationships + if not self.opts.fields: + reverse_rels = [] + else: + # Reverse relationships are only included if they are explicitly + # present in the `fields` option on the serializer + reverse_rels = opts.get_all_related_objects() + reverse_rels += opts.get_all_related_many_to_many_objects() + + for relation in reverse_rels: + accessor_name = relation.get_accessor_name() + if not self.opts.fields or accessor_name not in self.opts.fields: + continue + related_model = relation.model + to_many = relation.field.rel.multiple + + if nested: + field = self.get_nested_field(None, related_model, to_many) + else: + field = self.get_related_field(None, related_model, to_many) + + if field: + ret[accessor_name] = field + + # Add the `read_only` flag to any fields that have bee specified + # in the `read_only_fields` option for field_name in self.opts.read_only_fields: + assert field_name not in self.base_fields.keys(), \ + "field '%s' on serializer '%s' specfied in " \ + "`read_only_fields`, but also added " \ + "as an explict field. Remove it from `read_only_fields`." % \ + (field_name, self.__class__.__name__) assert field_name in ret, \ - "read_only_fields on '%s' included invalid item '%s'" % \ + "Noexistant field '%s' specified in `read_only_fields` " \ + "on serializer '%s'." % \ (self.__class__.__name__, field_name) ret[field_name].read_only = True @@ -404,30 +666,38 @@ class ModelSerializer(Serializer): """ Returns a default instance of the pk field. """ - return Field() + return self.get_field(model_field) - def get_nested_field(self, model_field): + def get_nested_field(self, model_field, related_model, to_many): """ Creates a default instance of a nested relational field. + + Note that model_field will be `None` for reverse relationships. """ class NestedModelSerializer(ModelSerializer): class Meta: - model = model_field.rel.to - return NestedModelSerializer() + model = related_model + depth = self.opts.depth - 1 - def get_related_field(self, model_field, to_many=False): + return NestedModelSerializer(many=to_many) + + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. + + Note that model_field will be `None` for reverse relationships. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) + kwargs = { - 'null': model_field.null, - 'queryset': model_field.rel.to._default_manager + 'queryset': related_model._default_manager, + 'many': to_many } - if to_many: - return ManyPrimaryKeyRelatedField(**kwargs) + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): @@ -435,17 +705,18 @@ class ModelSerializer(Serializer): Creates a default instance of a basic non-relational field. """ kwargs = {} + has_default = model_field.has_default() - kwargs['blank'] = model_field.blank - - if model_field.null: + if model_field.null or model_field.blank or has_default: kwargs['required'] = False - if model_field.has_default(): - kwargs['required'] = False + if isinstance(model_field, models.AutoField) or not model_field.editable: + kwargs['read_only'] = True + + if has_default: kwargs['default'] = model_field.get_default() - if model_field.__class__ == models.TextField: + if issubclass(model_field.__class__, models.TextField): kwargs['widget'] = widgets.Textarea # TODO: TypedChoiceField? @@ -459,26 +730,8 @@ class ModelSerializer(Serializer): if model_field.help_text is not None: kwargs['help_text'] = model_field.help_text - field_mapping = { - models.FloatField: FloatField, - models.IntegerField: IntegerField, - models.PositiveIntegerField: IntegerField, - models.SmallIntegerField: IntegerField, - models.PositiveSmallIntegerField: IntegerField, - models.DateTimeField: DateTimeField, - models.DateField: DateField, - models.EmailField: EmailField, - models.CharField: CharField, - models.URLField: URLField, - models.SlugField: SlugField, - models.TextField: CharField, - models.CommaSeparatedIntegerField: CharField, - models.BooleanField: BooleanField, - models.FileField: FileField, - models.ImageField: ImageField, - } try: - return field_mapping[model_field.__class__](**kwargs) + return self.field_mapping[model_field.__class__](**kwargs) except KeyError: return ModelField(model_field=model_field, **kwargs) @@ -490,54 +743,93 @@ class ModelSerializer(Serializer): opts = get_concrete_model(cls)._meta exclusions = [field.name for field in opts.fields + opts.many_to_many] for field_name, field in self.fields.items(): + field_name = field.source or field_name if field_name in exclusions and not field.read_only: exclusions.remove(field_name) return exclusions + def full_clean(self, instance): + """ + Perform Django's full_clean, and populate the `errors` dictionary + if any validation errors occur. + + Note that we don't perform this inside the `.restore_object()` method, + so that subclasses can override `.restore_object()`, and still get + the full_clean validation checking. + """ + try: + instance.full_clean(exclude=self.get_validation_exclusions()) + except ValidationError as err: + self._errors = err.message_dict + return None + return instance + def restore_object(self, attrs, instance=None): """ Restore the model instance. """ - self.m2m_data = {} + m2m_data = {} + related_data = {} + meta = self.opts.model._meta + + # Reverse fk or one-to-one relations + for (obj, model) in meta.get_all_related_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + related_data[field_name] = attrs.pop(field_name) + # Reverse m2m relations + for (obj, model) in meta.get_all_related_m2m_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + m2m_data[field_name] = attrs.pop(field_name) + + # Forward m2m relations + for field in meta.many_to_many: + if field.name in attrs: + m2m_data[field.name] = attrs.pop(field.name) + + # Update an existing instance... if instance is not None: for key, val in attrs.items(): setattr(instance, key, val) + # ...or create a new instance else: - # Reverse relations - for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): - field_name = obj.field.related_query_name() - if field_name in attrs: - self.m2m_data[field_name] = attrs.pop(field_name) - - # Forward relations - for field in self.opts.model._meta.many_to_many: - if field.name in attrs: - self.m2m_data[field.name] = attrs.pop(field.name) - instance = self.opts.model(**attrs) - try: - instance.full_clean(exclude=self.get_validation_exclusions()) - except ValidationError, err: - self._errors = err.message_dict - return None + # Any relations that cannot be set until we've + # saved the model get hidden away on these + # private attributes, so we can deal with them + # at the point of save. + instance._related_data = related_data + instance._m2m_data = m2m_data return instance - def save(self, save_m2m=True): + def from_native(self, data, files): + """ + Override the default method to also include model field validation. + """ + instance = super(ModelSerializer, self).from_native(data, files) + if not self._errors: + return self.full_clean(instance) + + def save_object(self, obj, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() + obj.save(**kwargs) - if getattr(self, 'm2m_data', None) and save_m2m: - for accessor_name, object_list in self.m2m_data.items(): - setattr(self.object, accessor_name, object_list) - self.m2m_data = {} + if getattr(obj, '_m2m_data', None): + for accessor_name, object_list in obj._m2m_data.items(): + setattr(obj, accessor_name, object_list) + del(obj._m2m_data) - return self.object + if getattr(obj, '_related_data', None): + for accessor_name, related in obj._related_data.items(): + setattr(obj, accessor_name, related) + del(obj._related_data) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): @@ -547,13 +839,17 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): def __init__(self, meta): super(HyperlinkedModelSerializerOptions, self).__init__(meta) self.view_name = getattr(meta, 'view_name', None) + self.lookup_field = getattr(meta, 'lookup_field', None) class HyperlinkedModelSerializer(ModelSerializer): """ + A subclass of ModelSerializer that uses hyperlinked relationships, + instead of primary key relationships. """ _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' + _hyperlink_field_class = HyperlinkedRelatedField url = HyperlinkedIdentityField() @@ -574,20 +870,35 @@ class HyperlinkedModelSerializer(ModelSerializer): return self._default_view_name % format_kwargs def get_pk_field(self, model_field): - return None + if self.opts.fields and model_field.name in self.opts.fields: + return self.get_field(model_field) - def get_related_field(self, model_field, to_many): + def get_related_field(self, model_field, related_model, to_many): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - rel = model_field.rel.to kwargs = { - 'null': model_field.null, - 'queryset': rel._default_manager, - 'view_name': self._get_default_view_name(rel) + 'queryset': related_model._default_manager, + 'view_name': self._get_default_view_name(related_model), + 'many': to_many } - if to_many: - return ManyHyperlinkedRelatedField(**kwargs) - return HyperlinkedRelatedField(**kwargs) + + if model_field: + kwargs['required'] = not(model_field.null or model_field.blank) + + if self.opts.lookup_field: + kwargs['lookup_field'] = self.opts.lookup_field + + return self._hyperlink_field_class(**kwargs) + + def get_identity(self, data): + """ + This hook is required for bulk update. + We need to override the default, to use the url as the identity. + """ + try: + return data.get('url', None) + except AttributeError: + return None diff --git a/rest_framework/settings.py b/rest_framework/settings.py index ee24a4ad..beb511ac 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -17,13 +17,19 @@ This module provides the `api_setting` object, that is used to access REST framework settings, checking for user settings first, then falling back to the defaults. """ +from __future__ import unicode_literals + from django.conf import settings from django.utils import importlib +from rest_framework import ISO_8601 +from rest_framework.compat import six + USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) DEFAULTS = { + # Base API policies 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'rest_framework.renderers.BrowsableAPIRenderer', @@ -45,11 +51,15 @@ DEFAULTS = { 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', + + # Genric view behavior 'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer', 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer', + 'DEFAULT_FILTER_BACKENDS': (), + # Throttling 'DEFAULT_THROTTLE_RATES': { 'user': None, 'anon': None, @@ -59,9 +69,6 @@ DEFAULTS = { 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, - # Filtering - 'FILTER_BACKEND': None, - # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -74,6 +81,25 @@ DEFAULTS = { 'URL_FORMAT_OVERRIDE': 'format', 'FORMAT_SUFFIX_KWARG': 'format', + + # Input and output formats + 'DATE_INPUT_FORMATS': ( + ISO_8601, + ), + 'DATE_FORMAT': None, + + 'DATETIME_INPUT_FORMATS': ( + ISO_8601, + ), + 'DATETIME_FORMAT': None, + + 'TIME_INPUT_FORMATS': ( + ISO_8601, + ), + 'TIME_FORMAT': None, + + # Pending deprecation + 'FILTER_BACKEND': None, } @@ -87,6 +113,7 @@ IMPORT_STRINGS = ( 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', + 'DEFAULT_FILTER_BACKENDS', 'FILTER_BACKEND', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', @@ -98,7 +125,7 @@ def perform_import(val, setting_name): If the given setting is a string import notation, then perform the necessary import or imports. """ - if isinstance(val, basestring): + if isinstance(val, six.string_types): return import_from_string(val, setting_name) elif isinstance(val, (list, tuple)): return [import_from_string(item, setting_name) for item in val] @@ -115,8 +142,8 @@ def import_from_string(val, setting_name): module_path, class_name = '.'.join(parts[:-1]), parts[-1] module = importlib.import_module(module_path) return getattr(module, class_name) - except: - msg = "Could not import '%s' for API setting '%s'" % (val, setting_name) + except ImportError as e: + msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) raise ImportError(msg) diff --git a/rest_framework/six.py b/rest_framework/six.py new file mode 100644 index 00000000..9e382312 --- /dev/null +++ b/rest_framework/six.py @@ -0,0 +1,389 @@ +"""Utilities for writing code that runs on Python 2 and 3""" + +import operator +import sys +import types + +__author__ = "Benjamin Peterson <benjamin@python.org>" +__version__ = "1.2.0" + + +# True if we are running on Python 3. +PY3 = sys.version_info[0] == 3 + +if PY3: + string_types = str, + integer_types = int, + class_types = type, + text_type = str + binary_type = bytes + + MAXSIZE = sys.maxsize +else: + string_types = basestring, + integer_types = (int, long) + class_types = (type, types.ClassType) + text_type = unicode + binary_type = str + + if sys.platform == "java": + # Jython always uses 32 bits. + MAXSIZE = int((1 << 31) - 1) + else: + # It's possible to have sizeof(long) != sizeof(Py_ssize_t). + class X(object): + def __len__(self): + return 1 << 31 + try: + len(X()) + except OverflowError: + # 32-bit + MAXSIZE = int((1 << 31) - 1) + else: + # 64-bit + MAXSIZE = int((1 << 63) - 1) + del X + + +def _add_doc(func, doc): + """Add documentation to a function.""" + func.__doc__ = doc + + +def _import_module(name): + """Import module, returning the module after the last dot.""" + __import__(name) + return sys.modules[name] + + +class _LazyDescr(object): + + def __init__(self, name): + self.name = name + + def __get__(self, obj, tp): + result = self._resolve() + setattr(obj, self.name, result) + # This is a bit ugly, but it avoids running this again. + delattr(tp, self.name) + return result + + +class MovedModule(_LazyDescr): + + def __init__(self, name, old, new=None): + super(MovedModule, self).__init__(name) + if PY3: + if new is None: + new = name + self.mod = new + else: + self.mod = old + + def _resolve(self): + return _import_module(self.mod) + + +class MovedAttribute(_LazyDescr): + + def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): + super(MovedAttribute, self).__init__(name) + if PY3: + if new_mod is None: + new_mod = name + self.mod = new_mod + if new_attr is None: + if old_attr is None: + new_attr = name + else: + new_attr = old_attr + self.attr = new_attr + else: + self.mod = old_mod + if old_attr is None: + old_attr = name + self.attr = old_attr + + def _resolve(self): + module = _import_module(self.mod) + return getattr(module, self.attr) + + + +class _MovedItems(types.ModuleType): + """Lazy loading of moved objects""" + + +_moved_attributes = [ + MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), + MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), + MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), + MovedAttribute("map", "itertools", "builtins", "imap", "map"), + MovedAttribute("reload_module", "__builtin__", "imp", "reload"), + MovedAttribute("reduce", "__builtin__", "functools"), + MovedAttribute("StringIO", "StringIO", "io"), + MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), + + MovedModule("builtins", "__builtin__"), + MovedModule("configparser", "ConfigParser"), + MovedModule("copyreg", "copy_reg"), + MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), + MovedModule("http_cookies", "Cookie", "http.cookies"), + MovedModule("html_entities", "htmlentitydefs", "html.entities"), + MovedModule("html_parser", "HTMLParser", "html.parser"), + MovedModule("http_client", "httplib", "http.client"), + MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), + MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), + MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), + MovedModule("cPickle", "cPickle", "pickle"), + MovedModule("queue", "Queue"), + MovedModule("reprlib", "repr"), + MovedModule("socketserver", "SocketServer"), + MovedModule("tkinter", "Tkinter"), + MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), + MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), + MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), + MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), + MovedModule("tkinter_tix", "Tix", "tkinter.tix"), + MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), + MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), + MovedModule("tkinter_colorchooser", "tkColorChooser", + "tkinter.colorchooser"), + MovedModule("tkinter_commondialog", "tkCommonDialog", + "tkinter.commondialog"), + MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), + MovedModule("tkinter_font", "tkFont", "tkinter.font"), + MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), + MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", + "tkinter.simpledialog"), + MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), + MovedModule("winreg", "_winreg"), +] +for attr in _moved_attributes: + setattr(_MovedItems, attr.name, attr) +del attr + +moves = sys.modules["django.utils.six.moves"] = _MovedItems("moves") + + +def add_move(move): + """Add an item to six.moves.""" + setattr(_MovedItems, move.name, move) + + +def remove_move(name): + """Remove item from six.moves.""" + try: + delattr(_MovedItems, name) + except AttributeError: + try: + del moves.__dict__[name] + except KeyError: + raise AttributeError("no such move, %r" % (name,)) + + +if PY3: + _meth_func = "__func__" + _meth_self = "__self__" + + _func_code = "__code__" + _func_defaults = "__defaults__" + + _iterkeys = "keys" + _itervalues = "values" + _iteritems = "items" +else: + _meth_func = "im_func" + _meth_self = "im_self" + + _func_code = "func_code" + _func_defaults = "func_defaults" + + _iterkeys = "iterkeys" + _itervalues = "itervalues" + _iteritems = "iteritems" + + +try: + advance_iterator = next +except NameError: + def advance_iterator(it): + return it.next() +next = advance_iterator + + +if PY3: + def get_unbound_function(unbound): + return unbound + + Iterator = object + + def callable(obj): + return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) +else: + def get_unbound_function(unbound): + return unbound.im_func + + class Iterator(object): + + def next(self): + return type(self).__next__(self) + + callable = callable +_add_doc(get_unbound_function, + """Get the function out of a possibly unbound function""") + + +get_method_function = operator.attrgetter(_meth_func) +get_method_self = operator.attrgetter(_meth_self) +get_function_code = operator.attrgetter(_func_code) +get_function_defaults = operator.attrgetter(_func_defaults) + + +def iterkeys(d): + """Return an iterator over the keys of a dictionary.""" + return iter(getattr(d, _iterkeys)()) + +def itervalues(d): + """Return an iterator over the values of a dictionary.""" + return iter(getattr(d, _itervalues)()) + +def iteritems(d): + """Return an iterator over the (key, value) pairs of a dictionary.""" + return iter(getattr(d, _iteritems)()) + + +if PY3: + def b(s): + return s.encode("latin-1") + def u(s): + return s + if sys.version_info[1] <= 1: + def int2byte(i): + return bytes((i,)) + else: + # This is about 2x faster than the implementation above on 3.2+ + int2byte = operator.methodcaller("to_bytes", 1, "big") + import io + StringIO = io.StringIO + BytesIO = io.BytesIO +else: + def b(s): + return s + def u(s): + return unicode(s, "unicode_escape") + int2byte = chr + import StringIO + StringIO = BytesIO = StringIO.StringIO +_add_doc(b, """Byte literal""") +_add_doc(u, """Text literal""") + + +if PY3: + import builtins + exec_ = getattr(builtins, "exec") + + + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + + print_ = getattr(builtins, "print") + del builtins + +else: + def exec_(code, globs=None, locs=None): + """Execute code in a namespace.""" + if globs is None: + frame = sys._getframe(1) + globs = frame.f_globals + if locs is None: + locs = frame.f_locals + del frame + elif locs is None: + locs = globs + exec("""exec code in globs, locs""") + + + exec_("""def reraise(tp, value, tb=None): + raise tp, value, tb +""") + + + def print_(*args, **kwargs): + """The new-style print function.""" + fp = kwargs.pop("file", sys.stdout) + if fp is None: + return + def write(data): + if not isinstance(data, basestring): + data = str(data) + fp.write(data) + want_unicode = False + sep = kwargs.pop("sep", None) + if sep is not None: + if isinstance(sep, unicode): + want_unicode = True + elif not isinstance(sep, str): + raise TypeError("sep must be None or a string") + end = kwargs.pop("end", None) + if end is not None: + if isinstance(end, unicode): + want_unicode = True + elif not isinstance(end, str): + raise TypeError("end must be None or a string") + if kwargs: + raise TypeError("invalid keyword arguments to print()") + if not want_unicode: + for arg in args: + if isinstance(arg, unicode): + want_unicode = True + break + if want_unicode: + newline = unicode("\n") + space = unicode(" ") + else: + newline = "\n" + space = " " + if sep is None: + sep = space + if end is None: + end = newline + for i, arg in enumerate(args): + if i: + write(sep) + write(arg) + write(end) + +_add_doc(reraise, """Reraise an exception.""") + + +def with_metaclass(meta, base=object): + """Create a base class with a metaclass.""" + return meta("NewBase", (base,), {}) + + +### Additional customizations for Django ### + +if PY3: + _iterlists = "lists" + _assertRaisesRegex = "assertRaisesRegex" +else: + _iterlists = "iterlists" + _assertRaisesRegex = "assertRaisesRegexp" + + +def iterlists(d): + """Return an iterator over the values of a MultiValueDict.""" + return getattr(d, _iterlists)() + + +def assertRaisesRegex(self, *args, **kwargs): + return getattr(self, _assertRaisesRegex)(*args, **kwargs) + + +add_move(MovedModule("_dummy_thread", "dummy_thread")) +add_move(MovedModule("_thread", "thread")) diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css index b2e41b99..d806267b 100644 --- a/rest_framework/static/rest_framework/css/default.css +++ b/rest_framework/static/rest_framework/css/default.css @@ -150,6 +150,49 @@ html, body { margin: 0 auto -60px; } +.form-switcher { + margin-bottom: 0; +} + +.well { + -webkit-box-shadow: none; + -moz-box-shadow: none; + box-shadow: none; +} + +.well .form-actions { + padding-bottom: 0; + margin-bottom: 0; +} + +.well form { + margin-bottom: 0; +} + +.nav-tabs { + border: 0; +} + +.nav-tabs > li { + float: right; +} + +.nav-tabs li a { + margin-right: 0; +} + +.nav-tabs > .active > a { + background: #f5f5f5; +} + +.nav-tabs > .active > a:hover { + background: #f5f5f5; +} + +.tabbable.first-tab-active .tab-content +{ + border-top-right-radius: 0; +} #footer, #push { height: 60px; /* .push must be the same height as .footer */ diff --git a/rest_framework/static/rest_framework/js/default.js b/rest_framework/static/rest_framework/js/default.js index ecaccc0f..c74829d7 100644 --- a/rest_framework/static/rest_framework/js/default.js +++ b/rest_framework/static/rest_framework/js/default.js @@ -3,3 +3,11 @@ prettyPrint(); $('.js-tooltip').tooltip({ delay: 1000 }); + +$('a[data-toggle="tab"]:first').on('shown', function (e) { + $(e.target).parents('.tabbable').addClass('first-tab-active'); +}); +$('a[data-toggle="tab"]:not(:first)').on('shown', function (e) { + $(e.target).parents('.tabbable').removeClass('first-tab-active'); +}); +$('.form-switcher a:first').tab('show'); diff --git a/rest_framework/status.py b/rest_framework/status.py index a1eb48da..b9f249f9 100644 --- a/rest_framework/status.py +++ b/rest_framework/status.py @@ -4,6 +4,7 @@ Descriptive HTTP status codes, for code readability. See RFC 2616 - http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html And RFC 6585 - http://tools.ietf.org/html/rfc6585 """ +from __future__ import unicode_literals HTTP_100_CONTINUE = 100 HTTP_101_SWITCHING_PROTOCOLS = 101 diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index fb0e19f0..4410f285 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -1,6 +1,5 @@ {% load url from future %} {% load rest_framework %} -{% load static %} <!DOCTYPE html> <html> <head> @@ -14,10 +13,10 @@ <title>{% block title %}Django REST framework{% endblock %}</title> {% block style %} - <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/> - <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/> - <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/prettify.css'/> - <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/> + {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %} + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/prettify.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> {% endblock %} {% endblock %} @@ -113,10 +112,10 @@ <div class="request-info"> <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre> - <div> + </div> <div class="response-info"> <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %} -{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span> +{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span> {% endfor %} </div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %} </div> @@ -124,56 +123,88 @@ {% if response.status_code != 403 %} - {% if post_form %} - <div class="well"> - <form action="{{ request.get_full_path }}" method="POST" {% if post_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> - <fieldset> - {% csrf_token %} - {{ post_form.non_field_errors }} - {% for field in post_form %} - <div class="control-group"> <!--{% if field.errors %}error{% endif %}--> - {{ field.label_tag|add_class:"control-label" }} - <div class="controls"> - {{ field }} - <span class="help-inline">{{ field.help_text }}</span> - <!--{{ field.errors|add_class:"help-block" }}--> + {% if post_form or raw_data_post_form %} + <div {% if post_form %}class="tabbable"{% endif %}> + {% if post_form %} + <ul class="nav nav-tabs form-switcher"> + <li><a href="#object-form" data-toggle="tab">HTML form</a></li> + <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li> + </ul> + {% endif %} + <div class="well tab-content"> + {% if post_form %} + <div class="tab-pane" id="object-form"> + {% with form=post_form %} + <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> + <fieldset> + {% include "rest_framework/form.html" %} + <div class="form-actions"> + <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button> </div> - </div> - {% endfor %} - <div class="form-actions"> - <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button> - </div> - </fieldset> - </form> + </fieldset> + </form> + {% endwith %} + </div> + {% endif %} + <div {% if post_form %}class="tab-pane"{% endif %} id="generic-content-form"> + {% with form=raw_data_post_form %} + <form action="{{ request.get_full_path }}" method="POST" class="form-horizontal"> + <fieldset> + {% include "rest_framework/form.html" %} + <div class="form-actions"> + <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button> + </div> + </fieldset> + </form> + {% endwith %} + </div> + </div> </div> {% endif %} - {% if put_form %} - <div class="well"> - <form action="{{ request.get_full_path }}" method="POST" {% if put_form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> - <fieldset> - <input type="hidden" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" /> - {% csrf_token %} - {{ put_form.non_field_errors }} - {% for field in put_form %} - <div class="control-group"> <!--{% if field.errors %}error{% endif %}--> - {{ field.label_tag|add_class:"control-label" }} - <div class="controls"> - {{ field }} - <span class='help-inline'>{{ field.help_text }}</span> - <!--{{ field.errors|add_class:"help-block" }}--> + {% if put_form or raw_data_put_form or raw_data_patch_form %} + <div {% if put_form %}class="tabbable"{% endif %}> + {% if put_form %} + <ul class="nav nav-tabs form-switcher"> + <li><a href="#object-form" data-toggle="tab">HTML form</a></li> + <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li> + </ul> + {% endif %} + <div class="well tab-content"> + {% if put_form %} + <div class="tab-pane" id="object-form"> + {% with form=put_form %} + <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> + <fieldset> + {% include "rest_framework/form.html" %} + <div class="form-actions"> + <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button> </div> - </div> - {% endfor %} - <div class="form-actions"> - <button class="btn btn-primary js-tooltip" title="Make a PUT request on the {{ name }} resource">PUT</button> - </div> - - </fieldset> - </form> + </fieldset> + </form> + {% endwith %} + </div> + {% endif %} + <div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form"> + {% with form=raw_data_put_or_patch_form %} + <form action="{{ request.get_full_path }}" method="POST" class="form-horizontal"> + <fieldset> + {% include "rest_framework/form.html" %} + <div class="form-actions"> + {% if raw_data_put_form %} + <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button> + {% endif %} + {% if raw_data_patch_form %} + <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PATCH" title="Make a PUT request on the {{ name }} resource">PATCH</button> + {% endif %} + </div> + </fieldset> + </form> + {% endwith %} + </div> + </div> </div> {% endif %} - {% endif %} </div> @@ -195,10 +226,10 @@ {% endblock %} {% block script %} - <script src="{% get_static_prefix %}rest_framework/js/jquery-1.8.1-min.js"></script> - <script src="{% get_static_prefix %}rest_framework/js/bootstrap.min.js"></script> - <script src="{% get_static_prefix %}rest_framework/js/prettify-min.js"></script> - <script src="{% get_static_prefix %}rest_framework/js/default.js"></script> + <script src="{% static "rest_framework/js/jquery-1.8.1-min.js" %}"></script> + <script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script> + <script src="{% static "rest_framework/js/prettify-min.js" %}"></script> + <script src="{% static "rest_framework/js/default.js" %}"></script> {% endblock %} </body> </html> diff --git a/rest_framework/templates/rest_framework/form.html b/rest_framework/templates/rest_framework/form.html new file mode 100644 index 00000000..dc7acc70 --- /dev/null +++ b/rest_framework/templates/rest_framework/form.html @@ -0,0 +1,13 @@ +{% load rest_framework %} +{% csrf_token %} +{{ form.non_field_errors }} +{% for field in form %} + <div class="control-group"> <!--{% if field.errors %}error{% endif %}--> + {{ field.label_tag|add_class:"control-label" }} + <div class="controls"> + {{ field }} + <span class="help-inline">{{ field.help_text }}</span> + <!--{{ field.errors|add_class:"help-block" }}--> + </div> + </div> +{% endfor %} diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index c1271399..b7629327 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -1,53 +1,3 @@ -{% load url from future %} -{% load static %} -<html> +{% extends "rest_framework/login_base.html" %} - <head> - <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/> - <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/> - <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/> - </head> - - <body class="container"> - -<div class="container-fluid" style="margin-top: 30px"> - <div class="row-fluid"> - - <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> - <div class="row-fluid"> - <div> - <h3 style="margin: 0 0 20px;">Django REST framework</h3> - </div> - </div><!-- /row fluid --> - - <div class="row-fluid"> - <div> - <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> - {% csrf_token %} - <div id="div_id_username" class="clearfix control-group"> - <div class="controls" style="height: 30px"> - <Label class="span4" style="margin-top: 3px">Username:</label> - <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> - </div> - </div> - <div id="div_id_password" class="clearfix control-group"> - <div class="controls" style="height: 30px"> - <Label class="span4" style="margin-top: 3px">Password:</label> - <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> - </div> - </div> - <input type="hidden" name="next" value="{{ next }}" /> - <div class="form-actions-no-box"> - <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> - </div> - </form> - </div> - </div><!-- /row fluid --> - </div><!--/span--> - - </div><!-- /.row-fluid --> - </div> - - </div> - </body> -</html> +{# Override this template in your own templates directory to customize #} diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html new file mode 100644 index 00000000..a3e73b6b --- /dev/null +++ b/rest_framework/templates/rest_framework/login_base.html @@ -0,0 +1,51 @@ +{% load url from future %} +{% load rest_framework %} +<html> + + <head> + {% block style %} + {% block bootstrap_theme %}<link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap.min.css" %}"/>{% endblock %} + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/bootstrap-tweaks.css" %}"/> + <link rel="stylesheet" type="text/css" href="{% static "rest_framework/css/default.css" %}"/> + {% endblock %} + </head> + + <body class="container"> + + <div class="container-fluid" style="margin-top: 30px"> + <div class="row-fluid"> + <div class="well" style="width: 320px; margin-left: auto; margin-right: auto"> + <div class="row-fluid"> + <div> + {% block branding %}<h3 style="margin: 0 0 20px;">Django REST framework</h3>{% endblock %} + </div> + </div><!-- /row fluid --> + + <div class="row-fluid"> + <div> + <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post"> + {% csrf_token %} + <div id="div_id_username" class="clearfix control-group"> + <div class="controls"> + <Label class="span4">Username:</label> + <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username"> + </div> + </div> + <div id="div_id_password" class="clearfix control-group"> + <div class="controls"> + <Label class="span4">Password:</label> + <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password"> + </div> + </div> + <input type="hidden" name="next" value="{{ next }}" /> + <div class="form-actions-no-box"> + <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit"> + </div> + </form> + </div> + </div><!-- /.row-fluid --> + </div><!--/.well--> + </div><!-- /.row-fluid --> + </div><!-- /.container-fluid --> + </body> +</html> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 4e0181ee..c86b6456 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -1,48 +1,114 @@ +from __future__ import unicode_literals, absolute_import from django import template -from django.core.urlresolvers import reverse +from django.core.urlresolvers import reverse, NoReverseMatch from django.http import QueryDict -from django.utils.encoding import force_unicode from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe -from urlparse import urlsplit, urlunsplit -import re -import string +from rest_framework.compat import urlparse, force_text, six, smart_urlquote +import re, string register = template.Library() +# Note we don't use 'load staticfiles', because we need a 1.3 compatible +# version, so instead we include the `static` template tag ourselves. + +# When 1.3 becomes unsupported by REST framework, we can instead start to +# use the {% load staticfiles %} tag, remove the following code, +# and add a dependancy that `django.contrib.staticfiles` must be installed. + +# Note: We can't put this into the `compat` module because the compat import +# from rest_framework.compat import ... +# conflicts with this rest_framework template tag module. + +try: # Django 1.5+ + from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode + + @register.tag('static') + def do_static(parser, token): + return StaticFilesNode.handle_token(parser, token) + +except ImportError: + try: # Django 1.4 + from django.contrib.staticfiles.storage import staticfiles_storage + + @register.simple_tag + def static(path): + """ + A template tag that returns the URL to a file + using staticfiles' storage backend + """ + return staticfiles_storage.url(path) + + except ImportError: # Django 1.3 + from urlparse import urljoin + from django import template + from django.templatetags.static import PrefixNode + + class StaticNode(template.Node): + def __init__(self, varname=None, path=None): + if path is None: + raise template.TemplateSyntaxError( + "Static template nodes must be given a path to return.") + self.path = path + self.varname = varname + + def url(self, context): + path = self.path.resolve(context) + return self.handle_simple(path) + + def render(self, context): + url = self.url(context) + if self.varname is None: + return url + context[self.varname] = url + return '' + + @classmethod + def handle_simple(cls, path): + return urljoin(PrefixNode.handle_simple("STATIC_URL"), path) + + @classmethod + def handle_token(cls, parser, token): + """ + Class method to parse prefix node and return a Node. + """ + bits = token.split_contents() + + if len(bits) < 2: + raise template.TemplateSyntaxError( + "'%s' takes at least one argument (path to file)" % bits[0]) + + path = parser.compile_filter(bits[1]) + + if len(bits) >= 2 and bits[-2] == 'as': + varname = bits[3] + else: + varname = None + + return cls(varname, path) + + @register.tag('static') + def do_static_13(parser, token): + return StaticNode.handle_token(parser, token) + + def replace_query_param(url, key, val): """ Given a URL and a key/val pair, set or replace an item in the query parameters of the URL, and return the new URL. """ - (scheme, netloc, path, query, fragment) = urlsplit(url) + (scheme, netloc, path, query, fragment) = urlparse.urlsplit(url) query_dict = QueryDict(query).copy() query_dict[key] = val query = query_dict.urlencode() - return urlunsplit((scheme, netloc, path, query, fragment)) + return urlparse.urlunsplit((scheme, netloc, path, query, fragment)) # Regex for adding classes to html snippets class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') -# Bunch of stuff cloned from urlize -LEADING_PUNCTUATION = ['(', '<', '<', '"', "'"] -TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '>', '"', "'"] -DOTS = ['·', '*', '\xe2\x80\xa2', '•', '•', '•'] -unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)') -word_split_re = re.compile(r'(\s+)') -punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \ - ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]), - '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION]))) -simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') -link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+') -html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE) -hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL) -trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') - - # And the template tags themselves... @register.simple_tag @@ -52,7 +118,7 @@ def optional_login(request): """ try: login_url = reverse('rest_framework:login') - except: + except NoReverseMatch: return '' snippet = "<a href='%s?next=%s'>Log in</a>" % (login_url, request.path) @@ -66,7 +132,7 @@ def optional_logout(request): """ try: logout_url = reverse('rest_framework:logout') - except: + except NoReverseMatch: return '' snippet = "<a href='%s?next=%s'>Log out</a>" % (logout_url, request.path) @@ -96,7 +162,7 @@ def add_class(value, css_class): In the case of REST Framework, the filter is used to add Bootstrap-specific classes to the forms. """ - html = unicode(value) + html = six.text_type(value) match = class_re.search(html) if match: m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class, @@ -110,15 +176,25 @@ def add_class(value, css_class): return value +# Bunch of stuff cloned from urlize +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] +WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), + ('"', '"'), ("'", "'")] +word_split_re = re.compile(r'(\s+)') +simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE) +simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) +simple_email_re = re.compile(r'^\S+@\S+\.\S+$') + + @register.filter def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True): """ Converts any URLs in text into clickable links. - Works on http://, https://, www. links and links ending in .org, .net or - .com. Links can have trailing punctuation (periods, commas, close-parens) - and leading punctuation (opening parens) and it'll still do the right - thing. + Works on http://, https://, www. links, and also on links ending in one of + the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org). + Links can have trailing punctuation (periods, commas, close-parens) and + leading punctuation (opening parens) and it'll still do the right thing. If trim_url_limit is not None, the URLs in link text longer than this limit will truncated to trim_url_limit-3 characters and appended with an elipsis. @@ -130,25 +206,42 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru """ trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x safe_input = isinstance(text, SafeData) - words = word_split_re.split(force_unicode(text)) - nofollow_attr = nofollow and ' rel="nofollow"' or '' + words = word_split_re.split(force_text(text)) for i, word in enumerate(words): match = None if '.' in word or '@' in word or ':' in word: - match = punctuation_re.match(word) - if match: - lead, middle, trail = match.groups() + # Deal with punctuation. + lead, middle, trail = '', word, '' + for punctuation in TRAILING_PUNCTUATION: + if middle.endswith(punctuation): + middle = middle[:-len(punctuation)] + trail = punctuation + trail + for opening, closing in WRAPPING_PUNCTUATION: + if middle.startswith(opening): + middle = middle[len(opening):] + lead = lead + opening + # Keep parentheses at the end only if they're balanced. + if (middle.endswith(closing) + and middle.count(closing) == middle.count(opening) + 1): + middle = middle[:-len(closing)] + trail = closing + trail + # Make URL we want to point to. url = None - if middle.startswith('http://') or middle.startswith('https://'): - url = middle - elif middle.startswith('www.') or ('@' not in middle and \ - middle and middle[0] in string.ascii_letters + string.digits and \ - (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))): - url = 'http://%s' % middle - elif '@' in middle and not ':' in middle and simple_email_re.match(middle): - url = 'mailto:%s' % middle + nofollow_attr = ' rel="nofollow"' if nofollow else '' + if simple_url_re.match(middle): + url = smart_urlquote(middle) + elif simple_url_2_re.match(middle): + url = smart_urlquote('http://%s' % middle) + elif not ':' in middle and simple_email_re.match(middle): + local, domain = middle.rsplit('@', 1) + try: + domain = domain.encode('idna').decode('ascii') + except UnicodeError: + continue + url = 'mailto:%s@%s' % (local, domain) nofollow_attr = '' + # Make link. if url: trimmed = trim_url(middle) @@ -166,4 +259,15 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru words[i] = mark_safe(word) elif autoescape: words[i] = escape(word) - return mark_safe(u''.join(words)) + return ''.join(words) + + +@register.filter +def break_long_headers(header): + """ + Breaks headers longer than 160 characters (~page length) + when possible (are comma separated) + """ + if len(header) > 160 and ',' in header: + header = mark_safe('<br> ' + ', <br>'.join(header.split(','))) + return header diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 838e081b..8e6d3e51 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -1,33 +1,65 @@ +from __future__ import unicode_literals from django.contrib.auth.models import User from django.http import HttpResponse from django.test import Client, TestCase -from django.utils import simplejson as json - +from django.utils import unittest +from rest_framework import HTTP_HEADER_ENCODING +from rest_framework import exceptions from rest_framework import permissions +from rest_framework import status +from rest_framework.authentication import ( + BaseAuthentication, + TokenAuthentication, + BasicAuthentication, + SessionAuthentication, + OAuthAuthentication, + OAuth2Authentication +) from rest_framework.authtoken.models import Token -from rest_framework.authentication import TokenAuthentication -from rest_framework.compat import patterns +from rest_framework.compat import patterns, url, include +from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope +from rest_framework.compat import oauth, oauth_provider +from rest_framework.tests.utils import RequestFactory from rest_framework.views import APIView - +import json import base64 +import time +import datetime + +factory = RequestFactory() class MockView(APIView): permission_classes = (permissions.IsAuthenticated,) + def get(self, request): + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + def post(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) def put(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) -MockView.authentication_classes += (TokenAuthentication,) urlpatterns = patterns('', - (r'^$', MockView.as_view()), + (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), + (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), + (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), + (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), + (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], + permission_classes=[permissions.TokenHasReadWriteScope])) ) +if oauth2_provider is not None: + urlpatterns += patterns('', + url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), + url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), + url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], + permission_classes=[permissions.TokenHasReadWriteScope])), + ) + class BasicAuthTests(TestCase): """Basic authentication""" @@ -42,25 +74,30 @@ class BasicAuthTests(TestCase): def test_post_form_passing_basic_auth(self): """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" - auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + credentials = ('%s:%s' % (self.username, self.password)) + base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) + auth = 'Basic %s' % base64_credentials + response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_json_passing_basic_auth(self): """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" - auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + credentials = ('%s:%s' % (self.username, self.password)) + base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) + auth = 'Basic %s' % base64_credentials + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_basic_auth(self): """Ensure POSTing form over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') class SessionAuthTests(TestCase): @@ -83,31 +120,31 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication without CSRF token fails. """ self.csrf_client.login(username=self.username, password=self.password) - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/session/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_post_form_session_auth_passing(self): """ Ensure POSTing form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 200) + response = self.non_csrf_client.post('/session/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_put_form_session_auth_passing(self): """ Ensure PUTting form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.put('/', {'example': 'example'}) - self.assertEqual(response.status_code, 200) + response = self.non_csrf_client.put('/session/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_session_auth_failing(self): """ Ensure POSTing form over session authentication without logged in user fails. """ - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/session/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) class TokenAuthTests(TestCase): @@ -126,25 +163,25 @@ class TokenAuthTests(TestCase): def test_post_form_passing_token_auth(self): """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" - auth = "Token " + self.key - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + auth = 'Token ' + self.key + response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_token_auth(self): """Ensure POSTing form over token auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_token_has_auto_assigned_key_if_none_provided(self): """Ensure creating a token with no key will auto-assign a key""" @@ -157,8 +194,8 @@ class TokenAuthTests(TestCase): client = Client(enforce_csrf_checks=True) response = client.post('/auth-token/', json.dumps({'username': self.username, 'password': self.password}), 'application/json') - self.assertEqual(response.status_code, 200) - self.assertEqual(json.loads(response.content)['token'], self.key) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) def test_token_login_json_bad_creds(self): """Ensure token login view using JSON POST fails if bad credentials are used.""" @@ -179,5 +216,340 @@ class TokenAuthTests(TestCase): client = Client(enforce_csrf_checks=True) response = client.post('/auth-token/', {'username': self.username, 'password': self.password}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + + +class IncorrectCredentialsTests(TestCase): + def test_incorrect_credentials(self): + """ + If a request contains bad authentication credentials, then + authentication should run and error, even if no permissions + are set on the view. + """ + class IncorrectCredentialsAuth(BaseAuthentication): + def authenticate(self, request): + raise exceptions.AuthenticationFailed('Bad credentials') + + request = factory.get('/') + view = MockView.as_view( + authentication_classes=(IncorrectCredentialsAuth,), + permission_classes=() + ) + response = view(request) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.data, {'detail': 'Bad credentials'}) + + +class OAuthTests(TestCase): + """OAuth 1.0a authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + # these imports are here because oauth is optional and hiding them in try..except block or compat + # could obscure problems if something breaks + from oauth_provider.models import Consumer, Resource + from oauth_provider.models import Token as OAuthToken + from oauth_provider import consts + + self.consts = consts + + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + self.CONSUMER_KEY = 'consumer_key' + self.CONSUMER_SECRET = 'consumer_secret' + self.TOKEN_KEY = "token_key" + self.TOKEN_SECRET = "token_secret" + + self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, + name='example', user=self.user, status=self.consts.ACCEPTED) + + self.resource = Resource.objects.create(name="resource name", url="api/") + self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource, + token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True + ) + + def _create_authorization_header(self): + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="GET", url="http://example.com", parameters=params) + + signature_method = oauth.SignatureMethod_PLAINTEXT() + req.sign_request(signature_method, self.consumer, self.token) + + return req.to_header()["Authorization"] + + def _create_authorization_url_parameters(self): + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="GET", url="http://example.com", parameters=params) + + signature_method = oauth.SignatureMethod_PLAINTEXT() + req.sign_request(signature_method, self.consumer, self.token) + return dict(req) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_passing_oauth(self): + """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_repeated_nonce_failing_oauth(self): + """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails""" + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + # simulate reply attack auth header containes already used (nonce, timestamp) pair + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_token_removed_failing_oauth(self): + """Ensure POSTing when there is no OAuth access token in db fails""" + self.token.delete() + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_consumer_status_not_accepted_failing_oauth(self): + """Ensure POSTing when consumer status is anything other than ACCEPTED fails""" + for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED): + self.consumer.status = consumer_status + self.consumer.save() + + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_request_token_failing_oauth(self): + """Ensure POSTing with unauthorized request token instead of access token fails""" + self.token.token_type = self.token.REQUEST + self.token.save() + + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_urlencoded_parameters(self): + """Ensure POSTing with x-www-form-urlencoded auth parameters passes""" + params = self._create_authorization_url_parameters() + response = self.csrf_client.post('/oauth/', params) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_get_form_with_url_parameters(self): + """Ensure GETing with auth in url parameters passes""" + params = self._create_authorization_url_parameters() + response = self.csrf_client.get('/oauth/', params) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_hmac_sha1_signature_passes(self): + """Ensure POSTing using HMAC_SHA1 signature method passes""" + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + + signature_method = oauth.SignatureMethod_HMAC_SHA1() + req.sign_request(signature_method, self.consumer, self.token) + auth = req.to_header()["Authorization"] + + response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_get_form_with_readonly_resource_passing_auth(self): + """Ensure POSTing with a readonly resource instead of a write scope fails""" + read_only_access_token = self.token + read_only_access_token.resource.is_readonly = True + read_only_access_token.resource.save() + params = self._create_authorization_url_parameters() + response = self.csrf_client.get('/oauth-with-scope/', params) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_readonly_resource_failing_auth(self): + """Ensure POSTing with a readonly resource instead of a write scope fails""" + read_only_access_token = self.token + read_only_access_token.resource.is_readonly = True + read_only_access_token.resource.save() + params = self._create_authorization_url_parameters() + response = self.csrf_client.post('/oauth-with-scope/', params) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_write_resource_passing_auth(self): + """Ensure POSTing with a write resource succeed""" + read_write_access_token = self.token + read_write_access_token.resource.is_readonly = False + read_write_access_token.resource.save() + params = self._create_authorization_url_parameters() + response = self.csrf_client.post('/oauth-with-scope/', params) + self.assertEqual(response.status_code, 200) + + +class OAuth2Tests(TestCase): + """OAuth 2.0 authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + self.CLIENT_ID = 'client_key' + self.CLIENT_SECRET = 'client_secret' + self.ACCESS_TOKEN = "access_token" + self.REFRESH_TOKEN = "refresh_token" + + self.oauth2_client = oauth2_provider_models.Client.objects.create( + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + redirect_uri='', + client_type=0, + name='example', + user=None, + ) + + self.access_token = oauth2_provider_models.AccessToken.objects.create( + token=self.ACCESS_TOKEN, + client=self.oauth2_client, + user=self.user, + ) + self.refresh_token = oauth2_provider_models.RefreshToken.objects.create( + user=self.user, + access_token=self.access_token, + client=self.oauth2_client + ) + + def _create_authorization_header(self, token=None): + return "Bearer {0}".format(token or self.access_token.token) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_type_failing(self): + """Ensure that a wrong token type lead to the correct HTTP error status code""" + auth = "Wrong token-type-obsviously" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_format_failing(self): + """Ensure that a wrong token format lead to the correct HTTP error status code""" + auth = "Bearer wrong token format" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_failing(self): + """Ensure that a wrong token lead to the correct HTTP error status code""" + auth = "Bearer wrong-token" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_auth(self): + """Ensure GETing form over OAuth with correct client credentials succeed""" + auth = self._create_authorization_header() + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_passing_auth(self): + """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_token_removed_failing_auth(self): + """Ensure POSTing when there is no OAuth access token in db fails""" + self.access_token.delete() + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_refresh_token_failing_auth(self): + """Ensure POSTing with refresh token instead of access token fails""" + auth = self._create_authorization_header(token=self.refresh_token.token) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_expired_access_token_failing_auth(self): + """Ensure POSTing with expired access token fails with an 'Invalid token' error""" + self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late + self.access_token.save() + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + self.assertIn('Invalid token', response.content) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_invalid_scope_failing_auth(self): + """Ensure POSTing with a readonly scope instead of a write scope fails""" + read_only_access_token = self.access_token + read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] + read_only_access_token.save() + auth = self._create_authorization_header(token=read_only_access_token.token) + response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_valid_scope_passing_auth(self): + """Ensure POSTing with a write scope succeed""" + read_write_access_token = self.access_token + read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] + read_write_access_token.save() + auth = self._create_authorization_header(token=read_write_access_token.token) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) - self.assertEqual(json.loads(response.content)['token'], self.key) diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py index df891683..d9ed647e 100644 --- a/rest_framework/tests/breadcrumbs.py +++ b/rest_framework/tests/breadcrumbs.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from django.test import TestCase from rest_framework.compat import patterns, url from rest_framework.utils.breadcrumbs import get_breadcrumbs diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 8079c8cb..1016fed3 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -1,5 +1,5 @@ +from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import status from rest_framework.response import Response from rest_framework.renderers import JSONRenderer @@ -17,6 +17,8 @@ from rest_framework.decorators import ( permission_classes, ) +from rest_framework.tests.utils import RequestFactory + class DecoratorTestCase(TestCase): @@ -27,13 +29,27 @@ class DecoratorTestCase(TestCase): response.request = request return APIView.finalize_response(self, request, response, *args, **kwargs) - def test_wrap_view(self): + def test_api_view_incorrect(self): + """ + If @api_view is not applied correct, we should raise an assertion. + """ - @api_view(['GET']) + @api_view def view(request): - return Response({}) + return Response() + + request = self.factory.get('/') + self.assertRaises(AssertionError, view, request) + + def test_api_view_incorrect_arguments(self): + """ + If @api_view is missing arguments, we should raise an assertion. + """ - self.assertTrue(isinstance(view.cls_instance, APIView)) + with self.assertRaises(AssertionError): + @api_view('GET') + def view(request): + return Response() def test_calling_method(self): @@ -43,11 +59,11 @@ class DecoratorTestCase(TestCase): request = self.factory.get('/') response = view(request) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) request = self.factory.post('/') response = view(request) - self.assertEqual(response.status_code, 405) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) def test_calling_put_method(self): @@ -57,11 +73,25 @@ class DecoratorTestCase(TestCase): request = self.factory.put('/') response = view(request) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + request = self.factory.post('/') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + def test_calling_patch_method(self): + + @api_view(['GET', 'PATCH']) + def view(request): + return Response({}) + + request = self.factory.patch('/') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_200_OK) request = self.factory.post('/') response = view(request) - self.assertEqual(response.status_code, 405) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) def test_renderer_classes(self): @@ -109,7 +139,7 @@ class DecoratorTestCase(TestCase): request = self.factory.get('/') response = view(request) - self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_throttle_classes(self): class OncePerDayUserThrottle(UserRateThrottle): @@ -122,7 +152,7 @@ class DecoratorTestCase(TestCase): request = self.factory.get('/') response = view(request) - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) response = view(request) - self.assertEquals(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py index d958b840..52c1a34c 100644 --- a/rest_framework/tests/description.py +++ b/rest_framework/tests/description.py @@ -1,6 +1,10 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals from django.test import TestCase from rest_framework.views import APIView from rest_framework.compat import apply_markdown +from rest_framework.utils.formatting import get_view_name, get_view_description # We check that docstrings get nicely un-indented. DESCRIPTION = """an example docstring @@ -46,22 +50,16 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2> class TestViewNamesAndDescriptions(TestCase): - def test_resource_name_uses_classname_by_default(self): - """Ensure Resource names are based on the classname by default.""" + def test_view_name_uses_class_name(self): + """ + Ensure view names are based on the class name. + """ class MockView(APIView): pass - self.assertEquals(MockView().get_name(), 'Mock') - - def test_resource_name_can_be_set_explicitly(self): - """Ensure Resource names can be set using the 'get_name' method.""" - example = 'Some Other Name' - class MockView(APIView): - def get_name(self): - return example - self.assertEquals(MockView().get_name(), example) + self.assertEqual(get_view_name(MockView), 'Mock') - def test_resource_description_uses_docstring_by_default(self): - """Ensure Resource names are based on the docstring by default.""" + def test_view_description_uses_docstring(self): + """Ensure view descriptions are based on the docstring.""" class MockView(APIView): """an example docstring ==================== @@ -78,35 +76,32 @@ class TestViewNamesAndDescriptions(TestCase): # hash style header #""" - self.assertEquals(MockView().get_description(), DESCRIPTION) + self.assertEqual(get_view_description(MockView), DESCRIPTION) - def test_resource_description_can_be_set_explicitly(self): - """Ensure Resource descriptions can be set using the 'get_description' method.""" - example = 'Some other description' + def test_view_description_supports_unicode(self): + """ + Unicode in docstrings should be respected. + """ class MockView(APIView): - """docstring""" - def get_description(self): - return example - self.assertEquals(MockView().get_description(), example) - - def test_resource_description_does_not_require_docstring(self): - """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method.""" - example = 'Some other description' + """Проверка""" + pass - class MockView(APIView): - def get_description(self): - return example - self.assertEquals(MockView().get_description(), example) + self.assertEqual(get_view_description(MockView), "Проверка") - def test_resource_description_can_be_empty(self): - """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string.""" + def test_view_description_can_be_empty(self): + """ + Ensure that if a view has no docstring, + then it's description is the empty string. + """ class MockView(APIView): pass - self.assertEquals(MockView().get_description(), '') + self.assertEqual(get_view_description(MockView), '') def test_markdown(self): - """Ensure markdown to HTML works as expected""" + """ + Ensure markdown to HTML works as expected. + """ if apply_markdown: gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21 lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21 diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/extras/__init__.py diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py new file mode 100644 index 00000000..68263d94 --- /dev/null +++ b/rest_framework/tests/extras/bad_import.py @@ -0,0 +1 @@ +raise ValueError diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py new file mode 100644 index 00000000..3cdfa0f6 --- /dev/null +++ b/rest_framework/tests/fields.py @@ -0,0 +1,648 @@ +""" +General serializer field tests. +""" +from __future__ import unicode_literals +import datetime +from decimal import Decimal + +from django.db import models +from django.test import TestCase +from django.core import validators + +from rest_framework import serializers +from rest_framework.serializers import Serializer + + +class TimestampedModel(models.Model): + added = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + + +class CharPrimaryKeyModel(models.Model): + id = models.CharField(max_length=20, primary_key=True) + + +class TimestampedModelSerializer(serializers.ModelSerializer): + class Meta: + model = TimestampedModel + + +class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): + class Meta: + model = CharPrimaryKeyModel + + +class TimeFieldModel(models.Model): + clock = models.TimeField() + + +class TimeFieldModelSerializer(serializers.ModelSerializer): + class Meta: + model = TimeFieldModel + + +class BasicFieldTests(TestCase): + def test_auto_now_fields_read_only(self): + """ + auto_now and auto_now_add fields should be read_only by default. + """ + serializer = TimestampedModelSerializer() + self.assertEqual(serializer.fields['added'].read_only, True) + + def test_auto_pk_fields_read_only(self): + """ + AutoField fields should be read_only by default. + """ + serializer = TimestampedModelSerializer() + self.assertEqual(serializer.fields['id'].read_only, True) + + def test_non_auto_pk_fields_not_read_only(self): + """ + PK fields other than AutoField fields should not be read_only by default. + """ + serializer = CharPrimaryKeyModelSerializer() + self.assertEqual(serializer.fields['id'].read_only, False) + + +class DateFieldTest(TestCase): + """ + Tests for the DateFieldTest from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts default iso input formats. + """ + f = serializers.DateField() + result_1 = f.from_native('1984-07-31') + + self.assertEqual(datetime.date(1984, 7, 31), result_1) + + def test_from_native_datetime_date(self): + """ + Make sure from_native() accepts a datetime.date instance. + """ + f = serializers.DateField() + result_1 = f.from_native(datetime.date(1984, 7, 31)) + + self.assertEqual(result_1, datetime.date(1984, 7, 31)) + + def test_from_native_custom_format(self): + """ + Make sure from_native() accepts custom input formats. + """ + f = serializers.DateField(input_formats=['%Y -- %d']) + result = f.from_native('1984 -- 31') + + self.assertEqual(datetime.date(1984, 1, 31), result) + + def test_from_native_invalid_default_on_custom_format(self): + """ + Make sure from_native() don't accept default formats if custom format is preset + """ + f = serializers.DateField(input_formats=['%Y -- %d']) + + try: + f.from_native('1984-07-31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DateField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_from_native_invalid_date(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid date. + """ + f = serializers.DateField() + + try: + f.from_native('1984-13-31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_invalid_format(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid format. + """ + f = serializers.DateField() + + try: + f.from_native('1984 -- 31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_to_native(self): + """ + Make sure to_native() returns datetime as default. + """ + f = serializers.DateField() + + result_1 = f.to_native(datetime.date(1984, 7, 31)) + + self.assertEqual(datetime.date(1984, 7, 31), result_1) + + def test_to_native_iso(self): + """ + Make sure to_native() with 'iso-8601' returns iso formated date. + """ + f = serializers.DateField(format='iso-8601') + + result_1 = f.to_native(datetime.date(1984, 7, 31)) + + self.assertEqual('1984-07-31', result_1) + + def test_to_native_custom_format(self): + """ + Make sure to_native() returns correct custom format. + """ + f = serializers.DateField(format="%Y - %m.%d") + + result_1 = f.to_native(datetime.date(1984, 7, 31)) + + self.assertEqual('1984 - 07.31', result_1) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateField(required=False) + self.assertEqual(None, f.to_native(None)) + + +class DateTimeFieldTest(TestCase): + """ + Tests for the DateTimeField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts default iso input formats. + """ + f = serializers.DateTimeField() + result_1 = f.from_native('1984-07-31 04:31') + result_2 = f.from_native('1984-07-31 04:31:59') + result_3 = f.from_native('1984-07-31 04:31:59.000200') + + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3) + + def test_from_native_datetime_datetime(self): + """ + Make sure from_native() accepts a datetime.datetime instance. + """ + f = serializers.DateTimeField() + result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31)) + self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59)) + self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + def test_from_native_custom_format(self): + """ + Make sure from_native() accepts custom input formats. + """ + f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) + result = f.from_native('1984 -- 04:59') + + self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result) + + def test_from_native_invalid_default_on_custom_format(self): + """ + Make sure from_native() don't accept default formats if custom format is preset + """ + f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) + + try: + f.from_native('1984-07-31 04:31:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DateTimeField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateTimeField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_from_native_invalid_datetime(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid datetime. + """ + f = serializers.DateTimeField() + + try: + f.from_native('04:61:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " + "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_invalid_format(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid format. + """ + f = serializers.DateTimeField() + + try: + f.from_native('04 -- 31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " + "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_to_native(self): + """ + Make sure to_native() returns isoformat as default. + """ + f = serializers.DateTimeField() + + result_1 = f.to_native(datetime.datetime(1984, 7, 31)) + result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual(datetime.datetime(1984, 7, 31), result_1) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4) + + def test_to_native_iso(self): + """ + Make sure to_native() with format=iso-8601 returns iso formatted datetime. + """ + f = serializers.DateTimeField(format='iso-8601') + + result_1 = f.to_native(datetime.datetime(1984, 7, 31)) + result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual('1984-07-31T00:00:00', result_1) + self.assertEqual('1984-07-31T04:31:00', result_2) + self.assertEqual('1984-07-31T04:31:59', result_3) + self.assertEqual('1984-07-31T04:31:59.000200', result_4) + + def test_to_native_custom_format(self): + """ + Make sure to_native() returns correct custom format. + """ + f = serializers.DateTimeField(format="%Y - %H:%M") + + result_1 = f.to_native(datetime.datetime(1984, 7, 31)) + result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual('1984 - 00:00', result_1) + self.assertEqual('1984 - 04:31', result_2) + self.assertEqual('1984 - 04:31', result_3) + self.assertEqual('1984 - 04:31', result_4) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateTimeField(required=False) + self.assertEqual(None, f.to_native(None)) + + +class TimeFieldTest(TestCase): + """ + Tests for the TimeField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts default iso input formats. + """ + f = serializers.TimeField() + result_1 = f.from_native('04:31') + result_2 = f.from_native('04:31:59') + result_3 = f.from_native('04:31:59.000200') + + self.assertEqual(datetime.time(4, 31), result_1) + self.assertEqual(datetime.time(4, 31, 59), result_2) + self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + + def test_from_native_datetime_time(self): + """ + Make sure from_native() accepts a datetime.time instance. + """ + f = serializers.TimeField() + result_1 = f.from_native(datetime.time(4, 31)) + result_2 = f.from_native(datetime.time(4, 31, 59)) + result_3 = f.from_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual(result_1, datetime.time(4, 31)) + self.assertEqual(result_2, datetime.time(4, 31, 59)) + self.assertEqual(result_3, datetime.time(4, 31, 59, 200)) + + def test_from_native_custom_format(self): + """ + Make sure from_native() accepts custom input formats. + """ + f = serializers.TimeField(input_formats=['%H -- %M']) + result = f.from_native('04 -- 31') + + self.assertEqual(datetime.time(4, 31), result) + + def test_from_native_invalid_default_on_custom_format(self): + """ + Make sure from_native() don't accept default formats if custom format is preset + """ + f = serializers.TimeField(input_formats=['%H -- %M']) + + try: + f.from_native('04:31:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.TimeField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.TimeField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_from_native_invalid_time(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid time. + """ + f = serializers.TimeField() + + try: + f.from_native('04:61:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " + "hh:mm[:ss[.uuuuuu]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_invalid_format(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid format. + """ + f = serializers.TimeField() + + try: + f.from_native('04 -- 31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " + "hh:mm[:ss[.uuuuuu]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_to_native(self): + """ + Make sure to_native() returns time object as default. + """ + f = serializers.TimeField() + result_1 = f.to_native(datetime.time(4, 31)) + result_2 = f.to_native(datetime.time(4, 31, 59)) + result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual(datetime.time(4, 31), result_1) + self.assertEqual(datetime.time(4, 31, 59), result_2) + self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + + def test_to_native_iso(self): + """ + Make sure to_native() with format='iso-8601' returns iso formatted time. + """ + f = serializers.TimeField(format='iso-8601') + result_1 = f.to_native(datetime.time(4, 31)) + result_2 = f.to_native(datetime.time(4, 31, 59)) + result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual('04:31:00', result_1) + self.assertEqual('04:31:59', result_2) + self.assertEqual('04:31:59.000200', result_3) + + def test_to_native_custom_format(self): + """ + Make sure to_native() returns correct custom format. + """ + f = serializers.TimeField(format="%H - %S [%f]") + result_1 = f.to_native(datetime.time(4, 31)) + result_2 = f.to_native(datetime.time(4, 31, 59)) + result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual('04 - 00 [000000]', result_1) + self.assertEqual('04 - 59 [000000]', result_2) + self.assertEqual('04 - 59 [000200]', result_3) + + +class DecimalFieldTest(TestCase): + """ + Tests for the DecimalField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts string values + """ + f = serializers.DecimalField() + result_1 = f.from_native('9000') + result_2 = f.from_native('1.00000001') + + self.assertEqual(Decimal('9000'), result_1) + self.assertEqual(Decimal('1.00000001'), result_2) + + def test_from_native_invalid_string(self): + """ + Make sure from_native() raises ValidationError on passing invalid string + """ + f = serializers.DecimalField() + + try: + f.from_native('123.45.6') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Enter a number."]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_integer(self): + """ + Make sure from_native() accepts integer values + """ + f = serializers.DecimalField() + result = f.from_native(9000) + + self.assertEqual(Decimal('9000'), result) + + def test_from_native_float(self): + """ + Make sure from_native() accepts float values + """ + f = serializers.DecimalField() + result = f.from_native(1.00000001) + + self.assertEqual(Decimal('1.00000001'), result) + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DecimalField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DecimalField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_to_native(self): + """ + Make sure to_native() returns Decimal as string. + """ + f = serializers.DecimalField() + + result_1 = f.to_native(Decimal('9000')) + result_2 = f.to_native(Decimal('1.00000001')) + + self.assertEqual(Decimal('9000'), result_1) + self.assertEqual(Decimal('1.00000001'), result_2) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DecimalField(required=False) + self.assertEqual(None, f.to_native(None)) + + def test_valid_serialization(self): + """ + Make sure the serializer works correctly + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_value=9010, + min_value=9000, + max_digits=6, + decimal_places=2) + + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) + self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) + + self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) + self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) + self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) + + def test_raise_max_value(self): + """ + Make sure max_value violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_value=100) + + s = DecimalSerializer(data={'decimal_field': '123'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']}) + + def test_raise_min_value(self): + """ + Make sure min_value violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(min_value=100) + + s = DecimalSerializer(data={'decimal_field': '99'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) + + def test_raise_max_digits(self): + """ + Make sure max_digits violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_digits=5) + + s = DecimalSerializer(data={'decimal_field': '123.456'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) + + def test_raise_max_decimal_places(self): + """ + Make sure max_decimal_places violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(decimal_places=3) + + s = DecimalSerializer(data={'decimal_field': '123.4567'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) + + def test_raise_max_whole_digits(self): + """ + Make sure max_whole_digits violations raises ValidationError + """ + class DecimalSerializer(Serializer): + decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) + + s = DecimalSerializer(data={'decimal_field': '12345.6'}) + + self.assertFalse(s.is_valid()) + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
\ No newline at end of file diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py index 5dd57b7c..487046ac 100644 --- a/rest_framework/tests/files.py +++ b/rest_framework/tests/files.py @@ -1,9 +1,9 @@ -import StringIO -import datetime - +from __future__ import unicode_literals from django.test import TestCase - from rest_framework import serializers +from rest_framework.compat import BytesIO +from rest_framework.compat import six +import datetime class UploadedFile(object): @@ -25,15 +25,27 @@ class UploadedFileSerializer(serializers.Serializer): class FileSerializerTests(TestCase): - def test_create(self): now = datetime.datetime.now() - file = StringIO.StringIO('stuff') + file = BytesIO(six.b('stuff')) file.name = 'stuff.txt' - file.size = file.len + file.size = len(file.getvalue()) serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) uploaded_file = UploadedFile(file=file, created=now) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.object.created, uploaded_file.created) - self.assertEquals(serializer.object.file, uploaded_file.file) + self.assertEqual(serializer.object.created, uploaded_file.created) + self.assertEqual(serializer.object.file, uploaded_file.file) self.assertFalse(serializer.object is uploaded_file) + + def test_creation_failure(self): + """ + Passing files=None should result in an ValidationError + + Regression test for: + https://github.com/tomchristie/django-rest-framework/issues/542 + """ + now = datetime.datetime.now() + + serializer = UploadedFileSerializer(data={'created': now}) + self.assertFalse(serializer.is_valid()) + self.assertIn('file', serializer.errors) diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py new file mode 100644 index 00000000..8ae6d530 --- /dev/null +++ b/rest_framework/tests/filters.py @@ -0,0 +1,474 @@ +from __future__ import unicode_literals +import datetime +from decimal import Decimal +from django.db import models +from django.core.urlresolvers import reverse +from django.test import TestCase +from django.test.client import RequestFactory +from django.utils import unittest +from rest_framework import generics, serializers, status, filters +from rest_framework.compat import django_filters, patterns, url +from rest_framework.tests.models import BasicModel + +factory = RequestFactory() + + +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + +if django_filters: + # Basic filter on a list view. + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_fields = ['decimal', 'date'] + filter_backends = (filters.DjangoFilterBackend,) + + # These class are used to test a filter class. + class SeveralFieldsFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + decimal = django_filters.NumberFilter(lookup_type='lt') + date = django_filters.DateFilter(lookup_type='gt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterClassRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backends = (filters.DjangoFilterBackend,) + + # These classes are used to test a misconfigured filter class. + class MisconfiguredFilter(django_filters.FilterSet): + text = django_filters.CharFilter(lookup_type='icontains') + + class Meta: + model = BasicModel + fields = ['text'] + + class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = MisconfiguredFilter + filter_backends = (filters.DjangoFilterBackend,) + + class FilterClassDetailView(generics.RetrieveAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backends = (filters.DjangoFilterBackend,) + + # Regression test for #814 + class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + + class FilterFieldsQuerysetView(generics.ListCreateAPIView): + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer + filter_fields = ['decimal', 'date'] + filter_backends = (filters.DjangoFilterBackend,) + + class GetQuerysetView(generics.ListCreateAPIView): + serializer_class = FilterableItemSerializer + filter_class = SeveralFieldsFilter + filter_backends = (filters.DjangoFilterBackend,) + + def get_queryset(self): + return FilterableItem.objects.all() + + urlpatterns = patterns('', + url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), + url(r'^$', FilterClassRootView.as_view(), name='root-view'), + url(r'^get-queryset/$', GetQuerysetView.as_view(), + name='get-queryset-view'), + ) + + +class CommonFilteringTestCase(TestCase): + def _serialize_object(self, obj): + return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + + def setUp(self): + """ + Create 10 FilterableItem instances. + """ + base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) + for i in range(10): + text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. + decimal = base_data[1] + i + date = base_data[2] - datetime.timedelta(days=i * 2) + FilterableItem(text=text, decimal=decimal, date=date).save() + + self.objects = FilterableItem.objects + self.data = [ + self._serialize_object(obj) + for obj in self.objects.all() + ] + + +class IntegrationTestFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered list views. + """ + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_fields_root_view(self): + """ + GET requests to paginated ListCreateAPIView should return paginated results. + """ + view = FilterFieldsRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEqual(response.data, expected_data) + + # Tests that the date filter works. + search_date = datetime.date(2012, 9, 22) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] == search_date] + self.assertEqual(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_queryset(self): + """ + Regression test for #814. + """ + view = FilterFieldsQuerysetView.as_view() + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEqual(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_get_queryset_only(self): + """ + Regression test for #834. + """ + view = GetQuerysetView.as_view() + request = factory.get('/get-queryset/') + view(request).render() + # Used to raise "issubclass() arg 2 must be a class or tuple of classes" + # here when neither `model' nor `queryset' was specified. + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_class_root_view(self): + """ + GET requests to filtered ListCreateAPIView that have a filter_class set + should return filtered results. + """ + view = FilterClassRootView.as_view() + + # Basic test with no filter. + request = factory.get('/') + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) + + # Tests that the decimal filter set with 'lt' in the filter class works. + search_decimal = Decimal('4.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] < search_decimal] + self.assertEqual(response.data, expected_data) + + # Tests that the date filter set with 'gt' in the filter class works. + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date] + self.assertEqual(response.data, expected_data) + + # Tests that the text filter set with 'icontains' in the filter class works. + search_text = 'ff' + request = factory.get('/?text=%s' % search_text) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if search_text in f['text'].lower()] + self.assertEqual(response.data, expected_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['date'] > search_date and + f['decimal'] < search_decimal] + self.assertEqual(response.data, expected_data) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_incorrectly_configured_filter(self): + """ + An error should be displayed when the filter class is misconfigured. + """ + view = IncorrectlyConfiguredRootView.as_view() + + request = factory.get('/') + self.assertRaises(AssertionError, view, request) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_unknown_filter(self): + """ + GET requests with filters that aren't configured should return 200. + """ + view = FilterFieldsRootView.as_view() + + search_integer = 10 + request = factory.get('/?integer=%s' % search_integer) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered detail views. + """ + urls = 'rest_framework.tests.filters' + + def _get_url(self, item): + return reverse('detail-view', kwargs=dict(pk=item.pk)) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_detail_view(self): + """ + GET requests to filtered RetrieveAPIView that have a filter_class set + should return filtered results. + """ + item = self.objects.all()[0] + data = self._serialize_object(item) + + # Basic test with no filter. + response = self.client.get(self._get_url(item)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, data) + + # Tests that the decimal filter set that should fail. + search_decimal = Decimal('4.25') + high_item = self.objects.filter(decimal__gt=search_decimal)[0] + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + # Tests that the decimal filter set that should succeed. + search_decimal = Decimal('4.25') + low_item = self.objects.filter(decimal__lt=search_decimal)[0] + low_item_data = self._serialize_object(low_item) + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, low_item_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] + valid_item_data = self._serialize_object(valid_item) + response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, valid_item_data) + + +class SearchFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class SearchFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # z abc + # zz bcd + # zzz cde + # ... + for idx in range(10): + title = 'z' * (idx + 1) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + SearchFilterModel(title=title, text=text).save() + + def test_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + + def test_exact_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('=title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=zzz') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'zzz', 'text': 'cde'} + ] + ) + + def test_startswith_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', '^text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + + +class OrdringFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class OrderingFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # zyx abc + # yxw bcd + # xwv cde + for idx in range(3): + title = ( + chr(ord('z') - idx) + + chr(ord('y') - idx) + + chr(ord('x') - idx) + ) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + OrdringFilterModel(title=title, text=text).save() + + def test_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + ] + ) + + def test_reverse_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=-text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_incorrectfield_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=foobar') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering_using_string(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = 'title' + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py deleted file mode 100644 index af2e6c2e..00000000 --- a/rest_framework/tests/filterset.py +++ /dev/null @@ -1,168 +0,0 @@ -import datetime -from decimal import Decimal -from django.test import TestCase -from django.test.client import RequestFactory -from django.utils import unittest -from rest_framework import generics, status, filters -from rest_framework.compat import django_filters -from rest_framework.tests.models import FilterableItem, BasicModel - -factory = RequestFactory() - - -if django_filters: - # Basic filter on a list view. - class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend - - # These class are used to test a filter class. - class SeveralFieldsFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - decimal = django_filters.NumberFilter(lookup_type='lt') - date = django_filters.DateFilter(lookup_type='gt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterClassRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend - - # These classes are used to test a misconfigured filter class. - class MisconfiguredFilter(django_filters.FilterSet): - text = django_filters.CharFilter(lookup_type='icontains') - - class Meta: - model = BasicModel - fields = ['text'] - - class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): - model = FilterableItem - filter_class = MisconfiguredFilter - filter_backend = filters.DjangoFilterBackend - - -class IntegrationTestFiltering(TestCase): - """ - Integration tests for filtered list views. - """ - - def setUp(self): - """ - Create 10 FilterableItem instances. - """ - base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) - for i in range(10): - text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. - decimal = base_data[1] + i - date = base_data[2] - datetime.timedelta(days=i * 2) - FilterableItem(text=text, decimal=decimal, date=date).save() - - self.objects = FilterableItem.objects - self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - for obj in self.objects.all() - ] - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_filtered_fields_root_view(self): - """ - GET requests to paginated ListCreateAPIView should return paginated results. - """ - view = FilterFieldsRootView.as_view() - - # Basic test with no filter. - request = factory.get('/') - response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) - - # Tests that the decimal filter works. - search_decimal = Decimal('2.25') - request = factory.get('/?decimal=%s' % search_decimal) - response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] == search_decimal] - self.assertEquals(response.data, expected_data) - - # Tests that the date filter works. - search_date = datetime.date(2012, 9, 22) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' - response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] == search_date] - self.assertEquals(response.data, expected_data) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_filtered_class_root_view(self): - """ - GET requests to filtered ListCreateAPIView that have a filter_class set - should return filtered results. - """ - view = FilterClassRootView.as_view() - - # Basic test with no filter. - request = factory.get('/') - response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) - - # Tests that the decimal filter set with 'lt' in the filter class works. - search_decimal = Decimal('4.25') - request = factory.get('/?decimal=%s' % search_decimal) - response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] < search_decimal] - self.assertEquals(response.data, expected_data) - - # Tests that the date filter set with 'gt' in the filter class works. - search_date = datetime.date(2012, 10, 2) - request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' - response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date] - self.assertEquals(response.data, expected_data) - - # Tests that the text filter set with 'icontains' in the filter class works. - search_text = 'ff' - request = factory.get('/?text=%s' % search_text) - response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if search_text in f['text'].lower()] - self.assertEquals(response.data, expected_data) - - # Tests that multiple filters works. - search_decimal = Decimal('5.25') - search_date = datetime.date(2012, 10, 2) - request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) - response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date and - f['decimal'] < search_decimal] - self.assertEquals(response.data, expected_data) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_incorrectly_configured_filter(self): - """ - An error should be displayed when the filter class is misconfigured. - """ - view = IncorrectlyConfiguredRootView.as_view() - - request = factory.get('/') - self.assertRaises(AssertionError, view, request) - - @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_unknown_filter(self): - """ - GET requests with filters that aren't configured should return 200. - """ - view = FilterFieldsRootView.as_view() - - search_integer = 10 - request = factory.get('/?integer=%s' % search_integer) - response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py index bc7378e1..c38bfb9f 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/genericrelations.py @@ -1,25 +1,62 @@ +from __future__ import unicode_literals +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import * + + +class Tag(models.Model): + """ + Tags have a descriptive slug, and are attached to an arbitrary object. + """ + tag = models.SlugField() + content_type = models.ForeignKey(ContentType) + object_id = models.PositiveIntegerField() + tagged_item = GenericForeignKey('content_type', 'object_id') + + def __unicode__(self): + return self.tag + + +class Bookmark(models.Model): + """ + A URL bookmark that may have multiple tags attached. + """ + url = models.URLField() + tags = GenericRelation(Tag) + + def __unicode__(self): + return 'Bookmark: %s' % self.url + + +class Note(models.Model): + """ + A textual note that may have multiple tags attached. + """ + text = models.TextField() + tags = GenericRelation(Tag) + + def __unicode__(self): + return 'Note: %s' % self.text class TestGenericRelations(TestCase): def setUp(self): - bookmark = Bookmark(url='https://www.djangoproject.com/') - bookmark.save() - django = Tag(tag_name='django') - django.save() - python = Tag(tag_name='python') - python.save() - t1 = TaggedItem(content_object=bookmark, tag=django) - t1.save() - t2 = TaggedItem(content_object=bookmark, tag=python) - t2.save() - self.bookmark = bookmark - - def test_reverse_generic_relation(self): + self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') + Tag.objects.create(tagged_item=self.bookmark, tag='django') + Tag.objects.create(tagged_item=self.bookmark, tag='python') + self.note = Note.objects.create(text='Remember the milk') + Tag.objects.create(tagged_item=self.note, tag='reminder') + + def test_generic_relation(self): + """ + Test a relationship that spans a GenericRelation field. + IE. A reverse generic relationship. + """ + class BookmarkSerializer(serializers.ModelSerializer): - tags = serializers.ManyRelatedField(source='tags') + tags = serializers.RelatedField(many=True) class Meta: model = Bookmark @@ -27,7 +64,37 @@ class TestGenericRelations(TestCase): serializer = BookmarkSerializer(self.bookmark) expected = { - 'tags': [u'django', u'python'], - 'url': u'https://www.djangoproject.com/' + 'tags': ['django', 'python'], + 'url': 'https://www.djangoproject.com/' + } + self.assertEqual(serializer.data, expected) + + def test_generic_fk(self): + """ + Test a relationship that spans a GenericForeignKey field. + IE. A forward generic relationship. + """ + + class TagSerializer(serializers.ModelSerializer): + tagged_item = serializers.RelatedField() + + class Meta: + model = Tag + exclude = ('id', 'content_type', 'object_id') + + serializer = TagSerializer(Tag.objects.all(), many=True) + expected = [ + { + 'tag': 'django', + 'tagged_item': 'Bookmark: https://www.djangoproject.com/' + }, + { + 'tag': 'python', + 'tagged_item': 'Bookmark: https://www.djangoproject.com/' + }, + { + 'tag': 'reminder', + 'tagged_item': 'Note: Remember the milk' } - self.assertEquals(serializer.data, expected) + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 7c24d84e..15d87e86 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,10 +1,12 @@ +from __future__ import unicode_literals from django.db import models +from django.shortcuts import get_object_or_404 from django.test import TestCase -from django.test.client import RequestFactory -from django.utils import simplejson as json -from rest_framework import generics, serializers, status +from rest_framework import generics, renderers, serializers, status +from rest_framework.tests.utils import RequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel - +from rest_framework.compat import six +import json factory = RequestFactory() @@ -37,12 +39,13 @@ class SlugBasedInstanceView(InstanceView): """ model = SlugBasedModel serializer_class = SlugSerializer + lookup_field = 'slug' class TestRootView(TestCase): def setUp(self): """ - Create 3 BasicModel intances. + Create 3 BasicModel instances. """ items = ['foo', 'bar', 'baz'] for item in items: @@ -59,9 +62,10 @@ class TestRootView(TestCase): GET requests to ListCreateAPIView should return list of objects. """ request = factory.get('/') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + with self.assertNumQueries(1): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) def test_post_root_view(self): """ @@ -70,11 +74,12 @@ class TestRootView(TestCase): content = {'text': 'foobar'} request = factory.post('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) - self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) + with self.assertNumQueries(1): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) created = self.objects.get(id=4) - self.assertEquals(created.text, 'foobar') + self.assertEqual(created.text, 'foobar') def test_put_root_view(self): """ @@ -83,25 +88,28 @@ class TestRootView(TestCase): content = {'text': 'foobar'} request = factory.put('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEquals(response.data, {"detail": "Method 'PUT' not allowed."}) + with self.assertNumQueries(0): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) def test_delete_root_view(self): """ DELETE requests to ListCreateAPIView should not be allowed """ request = factory.delete('/') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEquals(response.data, {"detail": "Method 'DELETE' not allowed."}) + with self.assertNumQueries(0): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) def test_options_root_view(self): """ OPTIONS requests to ListCreateAPIView should return metadata """ request = factory.options('/') - response = self.view(request).render() + with self.assertNumQueries(0): + response = self.view(request).render() expected = { 'parses': [ 'application/json', @@ -115,8 +123,8 @@ class TestRootView(TestCase): 'name': 'Root', 'description': 'Example description for OPTIONS.' } - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, expected) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, expected) def test_post_cannot_set_id(self): """ @@ -125,11 +133,12 @@ class TestRootView(TestCase): content = {'id': 999, 'text': 'foobar'} request = factory.post('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) - self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) + with self.assertNumQueries(1): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) created = self.objects.get(id=4) - self.assertEquals(created.text, 'foobar') + self.assertEqual(created.text, 'foobar') class TestInstanceView(TestCase): @@ -153,9 +162,10 @@ class TestInstanceView(TestCase): GET requests to RetrieveUpdateDestroyAPIView should return a single object. """ request = factory.get('/1') - response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data[0]) + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) def test_post_instance_view(self): """ @@ -164,9 +174,10 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.post('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEquals(response.data, {"detail": "Method 'POST' not allowed."}) + with self.assertNumQueries(0): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) def test_put_instance_view(self): """ @@ -175,29 +186,47 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk='1').render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) + with self.assertNumQueries(2): + response = self.view(request, pk='1').render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) + updated = self.objects.get(id=1) + self.assertEqual(updated.text, 'foobar') + + def test_patch_instance_view(self): + """ + PATCH requests to RetrieveUpdateDestroyAPIView should update an object. + """ + content = {'text': 'foobar'} + request = factory.patch('/1', json.dumps(content), + content_type='application/json') + + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) - self.assertEquals(updated.text, 'foobar') + self.assertEqual(updated.text, 'foobar') def test_delete_instance_view(self): """ DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. """ request = factory.delete('/1') - response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_204_NO_CONTENT) - self.assertEquals(response.content, '') + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(response.content, six.b('')) ids = [obj.id for obj in self.objects.all()] - self.assertEquals(ids, [2, 3]) + self.assertEqual(ids, [2, 3]) def test_options_instance_view(self): """ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata """ request = factory.options('/') - response = self.view(request).render() + with self.assertNumQueries(0): + response = self.view(request).render() expected = { 'parses': [ 'application/json', @@ -211,8 +240,8 @@ class TestInstanceView(TestCase): 'name': 'Instance', 'description': 'Example description for OPTIONS.' } - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, expected) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, expected) def test_put_cannot_set_id(self): """ @@ -221,11 +250,12 @@ class TestInstanceView(TestCase): content = {'id': 999, 'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) - self.assertEquals(updated.text, 'foobar') + self.assertEqual(updated.text, 'foobar') def test_put_to_deleted_instance(self): """ @@ -236,11 +266,12 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) - self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) + with self.assertNumQueries(3): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) - self.assertEquals(updated.text, 'foobar') + self.assertEqual(updated.text, 'foobar') def test_put_as_create_on_id_based_url(self): """ @@ -248,13 +279,14 @@ class TestInstanceView(TestCase): at the requested url if it doesn't exist. """ content = {'text': 'foobar'} - # pk fields can not be created on demand, only the database can set th pk for a new object + # pk fields can not be created on demand, only the database can set the pk for a new object request = factory.put('/5', json.dumps(content), content_type='application/json') - response = self.view(request, pk=5).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) + with self.assertNumQueries(3): + response = self.view(request, pk=5).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) new_obj = self.objects.get(pk=5) - self.assertEquals(new_obj.text, 'foobar') + self.assertEqual(new_obj.text, 'foobar') def test_put_as_create_on_slug_based_url(self): """ @@ -264,11 +296,53 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/test_slug', json.dumps(content), content_type='application/json') - response = self.slug_based_view(request, slug='test_slug').render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) - self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'}) + with self.assertNumQueries(2): + response = self.slug_based_view(request, slug='test_slug').render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'}) new_obj = SlugBasedModel.objects.get(slug='test_slug') - self.assertEquals(new_obj.text, 'foobar') + self.assertEqual(new_obj.text, 'foobar') + + +class TestOverriddenGetObject(TestCase): + """ + Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the + queryset/model mechanism but instead overrides get_object() + """ + def setUp(self): + """ + Create 3 BasicModel intances. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + + class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): + """ + Example detail view for override of get_object(). + """ + model = BasicModel + + def get_object(self): + pk = int(self.kwargs['pk']) + return get_object_or_404(BasicModel.objects.all(), id=pk) + + self.view = OverriddenGetObjectView.as_view() + + def test_overridden_get_object_view(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should return a single object. + """ + request = factory.get('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) # Regression test for #285 @@ -299,12 +373,12 @@ class TestCreateModelWithAutoNowAddField(TestCase): request = factory.post('/', json.dumps(content), content_type='application/json') response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) created = self.objects.get(id=1) - self.assertEquals(created.content, 'foobar') + self.assertEqual(created.content, 'foobar') -# Test for particularly ugly reression with m2m in browseable API +# Test for particularly ugly regression with m2m in browsable API class ClassB(models.Model): name = models.CharField(max_length=255) @@ -315,7 +389,7 @@ class ClassA(models.Model): class ClassASerializer(serializers.ModelSerializer): - childs = serializers.ManyPrimaryKeyRelatedField(source='childs') + childs = serializers.PrimaryKeyRelatedField(many=True, source='childs') class Meta: model = ClassA @@ -329,9 +403,108 @@ class ExampleView(generics.ListCreateAPIView): class TestM2MBrowseableAPI(TestCase): def test_m2m_in_browseable_api(self): """ - Test for particularly ugly reression with m2m in browseable API + Test for particularly ugly regression with m2m in browsable API """ request = factory.get('/', HTTP_ACCEPT='text/html') view = ExampleView().as_view() response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class InclusiveFilterBackend(object): + def filter_queryset(self, request, queryset, view): + return queryset.filter(text='foo') + + +class ExclusiveFilterBackend(object): + def filter_queryset(self, request, queryset, view): + return queryset.filter(text='other') + + +class TestFilterBackendAppliedToViews(TestCase): + + def setUp(self): + """ + Create 3 BasicModel instances to filter on. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + + def test_get_root_view_filters_by_name_with_filter_backend(self): + """ + GET requests to ListCreateAPIView should return filtered list. + """ + root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) + request = factory.get('/') + response = root_view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) + + def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): + """ + GET requests to ListCreateAPIView should return empty list when all models are filtered out. + """ + root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) + request = factory.get('/') + response = root_view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, []) + + def test_get_instance_view_filters_out_name_with_filter_backend(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. + """ + instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) + request = factory.get('/1') + response = instance_view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.data, {'detail': 'Not found'}) + + def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded + """ + instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) + request = factory.get('/1') + response = instance_view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) + + +class TwoFieldModel(models.Model): + field_a = models.CharField(max_length=100) + field_b = models.CharField(max_length=100) + + +class DynamicSerializerView(generics.ListCreateAPIView): + model = TwoFieldModel + renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + + def get_serializer_class(self): + if self.request.method == 'POST': + class DynamicSerializer(serializers.ModelSerializer): + class Meta: + model = TwoFieldModel + fields = ('field_b',) + return DynamicSerializer + return super(DynamicSerializerView, self).get_serializer_class() + + +class TestFilterBackendAppliedToViews(TestCase): + + def test_dynamic_serializer_form_in_browsable_api(self): + """ + GET requests to ListCreateAPIView should return filtered list. + """ + view = DynamicSerializerView.as_view() + request = factory.get('/') + response = view(request).render() + self.assertContains(response, 'field_b') + self.assertNotContains(response, 'field_a') diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py index 54096206..8f2e2b5a 100644 --- a/rest_framework/tests/htmlrenderer.py +++ b/rest_framework/tests/htmlrenderer.py @@ -1,12 +1,15 @@ +from __future__ import unicode_literals from django.core.exceptions import PermissionDenied from django.http import Http404 from django.test import TestCase from django.template import TemplateDoesNotExist, Template import django.template.loader +from rest_framework import status from rest_framework.compat import patterns, url from rest_framework.decorators import api_view, renderer_classes from rest_framework.renderers import TemplateHTMLRenderer from rest_framework.response import Response +from rest_framework.compat import six @api_view(('GET',)) @@ -63,19 +66,19 @@ class TemplateHTMLRendererTests(TestCase): def test_simple_html_view(self): response = self.client.get('/') self.assertContains(response, "example: foobar") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html') def test_not_found_html_view(self): response = self.client.get('/not_found') - self.assertEquals(response.status_code, 404) - self.assertEquals(response.content, "404 Not Found") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.content, six.b("404 Not Found")) + self.assertEqual(response['Content-Type'], 'text/html') def test_permission_denied_html_view(self): response = self.client.get('/permission_denied') - self.assertEquals(response.status_code, 403) - self.assertEquals(response.content, "403 Forbidden") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.content, six.b("403 Forbidden")) + self.assertEqual(response['Content-Type'], 'text/html') class TemplateHTMLRendererExceptionTests(TestCase): @@ -104,12 +107,12 @@ class TemplateHTMLRendererExceptionTests(TestCase): def test_not_found_html_view_with_template(self): response = self.client.get('/not_found') - self.assertEquals(response.status_code, 404) - self.assertEquals(response.content, "404: Not found") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.content, six.b("404: Not found")) + self.assertEqual(response['Content-Type'], 'text/html') def test_permission_denied_html_view_with_template(self): response = self.client.get('/permission_denied') - self.assertEquals(response.status_code, 403) - self.assertEquals(response.content, "403: Permission denied") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.content, six.b("403: Permission denied")) + self.assertEqual(response['Content-Type'], 'text/html') diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index ee4d8e57..8fc6ba77 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -1,6 +1,7 @@ +from __future__ import unicode_literals +import json from django.test import TestCase from django.test.client import RequestFactory -from django.utils import simplejson as json from rest_framework import generics, status, serializers from rest_framework.compat import patterns, url from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel @@ -26,6 +27,14 @@ class PhotoSerializer(serializers.Serializer): return Photo(**attrs) +class AlbumSerializer(serializers.ModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') + + class Meta: + model = Album + fields = ('title', 'url') + + class BasicList(generics.ListCreateAPIView): model = BasicModel model_serializer_class = serializers.HyperlinkedModelSerializer @@ -72,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView): class AlbumDetail(generics.RetrieveAPIView): model = Album + serializer_class = AlbumSerializer + lookup_field = 'title' class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): @@ -99,7 +110,7 @@ class TestBasicHyperlinkedView(TestCase): def setUp(self): """ - Create 3 BasicModel intances. + Create 3 BasicModel instances. """ items = ['foo', 'bar', 'baz'] for item in items: @@ -118,8 +129,8 @@ class TestBasicHyperlinkedView(TestCase): """ request = factory.get('/basic/') response = self.list_view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) def test_get_detail_view(self): """ @@ -127,8 +138,8 @@ class TestBasicHyperlinkedView(TestCase): """ request = factory.get('/basic/1') response = self.detail_view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data[0]) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) class TestManyToManyHyperlinkedView(TestCase): @@ -136,7 +147,7 @@ class TestManyToManyHyperlinkedView(TestCase): def setUp(self): """ - Create 3 BasicModel intances. + Create 3 BasicModel instances. """ items = ['foo', 'bar', 'baz'] anchors = [] @@ -166,8 +177,8 @@ class TestManyToManyHyperlinkedView(TestCase): """ request = factory.get('/manytomany/') response = self.list_view(request) - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) def test_get_detail_view(self): """ @@ -175,8 +186,38 @@ class TestManyToManyHyperlinkedView(TestCase): """ request = factory.get('/manytomany/1/') response = self.detail_view(request, pk=1) - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data[0]) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) + + +class TestHyperlinkedIdentityFieldLookup(TestCase): + urls = 'rest_framework.tests.hyperlinkedserializers' + + def setUp(self): + """ + Create 3 Album instances. + """ + titles = ['foo', 'bar', 'baz'] + for title in titles: + album = Album(title=title) + album.save() + self.detail_view = AlbumDetail.as_view() + self.data = { + 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, + 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, + 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} + } + + def test_lookup_field(self): + """ + GET requests to AlbumDetail view should return serialized Albums + with a url field keyed by `title`. + """ + for album in Album.objects.all(): + request = factory.get('/albums/{0}/'.format(album.title)) + response = self.detail_view(request, title=album.title) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[album.title]) class TestCreateWithForeignKeys(TestCase): @@ -234,7 +275,7 @@ class TestOptionalRelationHyperlinkedView(TestCase): def setUp(self): """ - Create 1 OptionalRelationModel intances. + Create 1 OptionalRelationModel instances. """ OptionalRelationModel().save() self.objects = OptionalRelationModel.objects @@ -248,8 +289,8 @@ class TestOptionalRelationHyperlinkedView(TestCase): """ request = factory.get('/optionalrelationmodel-detail/1') response = self.detail_view(request, pk=1) - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) def test_put_detail_view(self): """ diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 69fd0b30..3465268b 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -1,36 +1,7 @@ +from __future__ import unicode_literals from django.db import models -from django.contrib.contenttypes.models import ContentType -from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation from django.utils.translation import ugettext_lazy as _ -# from django.contrib.auth.models import Group - - -# class CustomUser(models.Model): -# """ -# A custom user model, which uses a 'through' table for the foreign key -# """ -# username = models.CharField(max_length=255, unique=True) -# groups = models.ManyToManyField( -# to=Group, blank=True, null=True, through='UserGroupMap' -# ) - -# @models.permalink -# def get_absolute_url(self): -# return ('custom_user', (), { -# 'pk': self.id -# }) - - -# class UserGroupMap(models.Model): -# user = models.ForeignKey(to=CustomUser) -# group = models.ForeignKey(to=Group) - -# @models.permalink -# def get_absolute_url(self): -# return ('user_group_map', (), { -# 'pk': self.id -# }) def foobar(): return 'foobar' @@ -72,6 +43,7 @@ class SlugBasedModel(RESTFrameworkModel): class DefaultValueModel(RESTFrameworkModel): text = models.CharField(default='foobar', max_length=100) + extra = models.CharField(blank=True, null=True, max_length=100) class CallableDefaultValueModel(RESTFrameworkModel): @@ -86,34 +58,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): text = models.CharField(max_length=100, default='anchor') rel = models.ManyToManyField(Anchor) -# Models to test generic relations - - -class Tag(RESTFrameworkModel): - tag_name = models.SlugField() - - -class TaggedItem(RESTFrameworkModel): - tag = models.ForeignKey(Tag, related_name='items') - content_type = models.ForeignKey(ContentType) - object_id = models.PositiveIntegerField() - content_object = GenericForeignKey('content_type', 'object_id') - - def __unicode__(self): - return self.tag.tag_name - - -class Bookmark(RESTFrameworkModel): - url = models.URLField() - tags = GenericRelation(TaggedItem) - - -# Model to test filtering. -class FilterableItem(RESTFrameworkModel): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - # Model for regression test for #285 @@ -177,3 +121,42 @@ class OptionalRelationModel(RESTFrameworkModel): # Model for RegexField class Book(RESTFrameworkModel): isbn = models.CharField(max_length=13) + + +# Models for relations tests +# ManyToMany +class ManyToManyTarget(RESTFrameworkModel): + name = models.CharField(max_length=100) + + +class ManyToManySource(RESTFrameworkModel): + name = models.CharField(max_length=100) + targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') + + +# ForeignKey +class ForeignKeyTarget(RESTFrameworkModel): + name = models.CharField(max_length=100) + + +class ForeignKeySource(RESTFrameworkModel): + name = models.CharField(max_length=100) + target = models.ForeignKey(ForeignKeyTarget, related_name='sources') + + +# Nullable ForeignKey +class NullableForeignKeySource(RESTFrameworkModel): + name = models.CharField(max_length=100) + target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, + related_name='nullable_sources') + + +# OneToOne +class OneToOneTarget(RESTFrameworkModel): + name = models.CharField(max_length=100) + + +class NullableOneToOneSource(RESTFrameworkModel): + name = models.CharField(max_length=100) + target = models.OneToOneField(OneToOneTarget, null=True, blank=True, + related_name='nullable_source') diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py deleted file mode 100644 index 1f8468e8..00000000 --- a/rest_framework/tests/modelviews.py +++ /dev/null @@ -1,90 +0,0 @@ -# from django.conf.urls.defaults import patterns, url -# from django.forms import ModelForm -# from django.contrib.auth.models import Group, User -# from rest_framework.resources import ModelResource -# from rest_framework.views import ListOrCreateModelView, InstanceModelView -# from rest_framework.tests.models import CustomUser -# from rest_framework.tests.testcases import TestModelsTestCase - - -# class GroupResource(ModelResource): -# model = Group - - -# class UserForm(ModelForm): -# class Meta: -# model = User -# exclude = ('last_login', 'date_joined') - - -# class UserResource(ModelResource): -# model = User -# form = UserForm - - -# class CustomUserResource(ModelResource): -# model = CustomUser - -# urlpatterns = patterns('', -# url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'), -# url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)), -# url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'), -# url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)), -# url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'), -# url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)), -# ) - - -# class ModelViewTests(TestModelsTestCase): -# """Test the model views rest_framework provides""" -# urls = 'rest_framework.tests.modelviews' - -# def test_creation(self): -# """Ensure that a model object can be created""" -# self.assertEqual(0, Group.objects.count()) - -# response = self.client.post('/groups/', {'name': 'foo'}) - -# self.assertEqual(response.status_code, 201) -# self.assertEqual(1, Group.objects.count()) -# self.assertEqual('foo', Group.objects.all()[0].name) - -# def test_creation_with_m2m_relation(self): -# """Ensure that a model object with a m2m relation can be created""" -# group = Group(name='foo') -# group.save() -# self.assertEqual(0, User.objects.count()) - -# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]}) - -# self.assertEqual(response.status_code, 201) -# self.assertEqual(1, User.objects.count()) - -# user = User.objects.all()[0] -# self.assertEqual('bar', user.username) -# self.assertEqual('baz', user.password) -# self.assertEqual(1, user.groups.count()) - -# group = user.groups.all()[0] -# self.assertEqual('foo', group.name) - -# def test_creation_with_m2m_relation_through(self): -# """ -# Ensure that a model object with a m2m relation can be created where that -# relation uses a through table -# """ -# group = Group(name='foo') -# group.save() -# self.assertEqual(0, User.objects.count()) - -# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]}) - -# self.assertEqual(response.status_code, 201) -# self.assertEqual(1, CustomUser.objects.count()) - -# user = CustomUser.objects.all()[0] -# self.assertEqual('bar', user.username) -# self.assertEqual(1, user.groups.count()) - -# group = user.groups.all()[0] -# self.assertEqual('foo', group.name) diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/multitable_inheritance.py new file mode 100644 index 00000000..00c15327 --- /dev/null +++ b/rest_framework/tests/multitable_inheritance.py @@ -0,0 +1,67 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import RESTFrameworkModel + + +# Models +class ParentModel(RESTFrameworkModel): + name1 = models.CharField(max_length=100) + + +class ChildModel(ParentModel): + name2 = models.CharField(max_length=100) + + +class AssociatedModel(RESTFrameworkModel): + ref = models.OneToOneField(ParentModel, primary_key=True) + name = models.CharField(max_length=100) + + +# Serializers +class DerivedModelSerializer(serializers.ModelSerializer): + class Meta: + model = ChildModel + + +class AssociatedModelSerializer(serializers.ModelSerializer): + class Meta: + model = AssociatedModel + + +# Tests +class IneritedModelSerializationTests(TestCase): + + def test_multitable_inherited_model_fields_as_expected(self): + """ + Assert that the parent pointer field is not included in the fields + serialized fields + """ + child = ChildModel(name1='parent name', name2='child name') + serializer = DerivedModelSerializer(child) + self.assertEqual(set(serializer.data.keys()), + set(['name1', 'name2', 'id'])) + + def test_onetoone_primary_key_model_fields_as_expected(self): + """ + Assert that a model with a onetoone field that is the primary key is + not treated like a derived model + """ + parent = ParentModel(name1='parent name') + associate = AssociatedModel(name='hello', ref=parent) + serializer = AssociatedModelSerializer(associate) + self.assertEqual(set(serializer.data.keys()), + set(['name', 'ref'])) + + def test_data_is_valid_without_parent_ptr(self): + """ + Assert that the pointer to the parent table is not a required field + for input data + """ + data = { + 'name1': 'parent name', + 'name2': 'child name', + } + serializer = DerivedModelSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/negotiation.py index e06354ea..43721b84 100644 --- a/rest_framework/tests/negotiation.py +++ b/rest_framework/tests/negotiation.py @@ -1,6 +1,9 @@ +from __future__ import unicode_literals from django.test import TestCase from django.test.client import RequestFactory from rest_framework.negotiation import DefaultContentNegotiation +from rest_framework.request import Request + factory = RequestFactory() @@ -22,16 +25,16 @@ class TestAcceptedMediaType(TestCase): return self.negotiator.select_renderer(request, self.renderers) def test_client_without_accept_use_renderer(self): - request = factory.get('/') + request = Request(factory.get('/')) accepted_renderer, accepted_media_type = self.select_renderer(request) - self.assertEquals(accepted_media_type, 'application/json') + self.assertEqual(accepted_media_type, 'application/json') def test_client_underspecifies_accept_use_renderer(self): - request = factory.get('/', HTTP_ACCEPT='*/*') + request = Request(factory.get('/', HTTP_ACCEPT='*/*')) accepted_renderer, accepted_media_type = self.select_renderer(request) - self.assertEquals(accepted_media_type, 'application/json') + self.assertEqual(accepted_media_type, 'application/json') def test_client_overspecifies_accept_use_client(self): - request = factory.get('/', HTTP_ACCEPT='application/json; indent=8') + request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8')) accepted_renderer, accepted_media_type = self.select_renderer(request) - self.assertEquals(accepted_media_type, 'application/json; indent=8') + self.assertEqual(accepted_media_type, 'application/json; indent=8') diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 81d297a1..e538a78e 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,16 +1,24 @@ +from __future__ import unicode_literals import datetime from decimal import Decimal +from django.db import models from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters -from rest_framework.tests.models import BasicModel, FilterableItem +from rest_framework.tests.models import BasicModel factory = RequestFactory() +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. @@ -19,21 +27,6 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 -if django_filters: - class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - paginate_by = 10 - filter_class = DecimalFilter - filter_backend = filters.DjangoFilterBackend - - class DefaultPageSizeKwargView(generics.ListAPIView): """ View for testing default paginate_by_param usage @@ -72,28 +65,32 @@ class IntegrationTestPagination(TestCase): GET requests to paginated ListCreateAPIView should return paginated results. """ request = factory.get('/') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 26) - self.assertEquals(response.data['results'], self.data[:10]) - self.assertNotEquals(response.data['next'], None) - self.assertEquals(response.data['previous'], None) + # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 26) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 26) - self.assertEquals(response.data['results'], self.data[10:20]) - self.assertNotEquals(response.data['next'], None) - self.assertNotEquals(response.data['previous'], None) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 26) + self.assertEqual(response.data['results'], self.data[10:20]) + self.assertNotEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 26) - self.assertEquals(response.data['results'], self.data[20:]) - self.assertEquals(response.data['next'], None) - self.assertNotEquals(response.data['previous'], None) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 26) + self.assertEqual(response.data['results'], self.data[20:]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) class IntegrationTestPaginationAndFiltering(TestCase): @@ -111,41 +108,105 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - for obj in self.objects.all() + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + for obj in self.objects.all() ] - self.view = FilterFieldsRootView.as_view() @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_paginated_filtered_root_view(self): + def test_get_django_filter_paginated_filtered_root_view(self): """ GET requests to paginated filtered ListCreateAPIView should return paginated results. The next and previous links should preserve the filtered parameters. """ + class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_class = DecimalFilter + filter_backends = (filters.DjangoFilterBackend,) + + view = FilterFieldsRootView.as_view() + + EXPECTED_NUM_QUERIES = 2 + request = factory.get('/?decimal=15.20') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 15) - self.assertEquals(response.data['results'], self.data[:10]) - self.assertNotEquals(response.data['next'], None) - self.assertEquals(response.data['previous'], None) + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 15) - self.assertEquals(response.data['results'], self.data[10:15]) - self.assertEquals(response.data['next'], None) - self.assertNotEquals(response.data['previous'], None) + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[10:15]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) request = factory.get(response.data['previous']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 15) - self.assertEquals(response.data['results'], self.data[:10]) - self.assertNotEquals(response.data['next'], None) - self.assertEquals(response.data['previous'], None) + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) + + def test_get_basic_paginated_filtered_root_view(self): + """ + Same as `test_get_django_filter_paginated_filtered_root_view`, + except using a custom filter backend instead of the django-filter + backend, + """ + + class DecimalFilterBackend(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) + + class BasicFilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_backends = (DecimalFilterBackend,) + + view = BasicFilterFieldsRootView.as_view() + + request = factory.get('/?decimal=15.20') + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) + + request = factory.get(response.data['next']) + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[10:15]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) + + request = factory.get(response.data['previous']) + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) class PassOnContextPaginationSerializer(pagination.PaginationSerializer): @@ -166,25 +227,25 @@ class UnitTestPagination(TestCase): def test_native_pagination(self): serializer = pagination.PaginationSerializer(self.first_page) - self.assertEquals(serializer.data['count'], 26) - self.assertEquals(serializer.data['next'], '?page=2') - self.assertEquals(serializer.data['previous'], None) - self.assertEquals(serializer.data['results'], self.objects[:10]) + self.assertEqual(serializer.data['count'], 26) + self.assertEqual(serializer.data['next'], '?page=2') + self.assertEqual(serializer.data['previous'], None) + self.assertEqual(serializer.data['results'], self.objects[:10]) serializer = pagination.PaginationSerializer(self.last_page) - self.assertEquals(serializer.data['count'], 26) - self.assertEquals(serializer.data['next'], None) - self.assertEquals(serializer.data['previous'], '?page=2') - self.assertEquals(serializer.data['results'], self.objects[20:]) + self.assertEqual(serializer.data['count'], 26) + self.assertEqual(serializer.data['next'], None) + self.assertEqual(serializer.data['previous'], '?page=2') + self.assertEqual(serializer.data['results'], self.objects[20:]) def test_context_available_in_result(self): """ Ensure context gets passed through to the object serializer. """ - serializer = PassOnContextPaginationSerializer(self.first_page) + serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) serializer.data results = serializer.fields[serializer.results_field] - self.assertTrue(serializer.context is results.context) + self.assertEqual(serializer.context, results.context) class TestUnpaginated(TestCase): @@ -212,7 +273,7 @@ class TestUnpaginated(TestCase): """ request = factory.get('/') response = self.view(request) - self.assertEquals(response.data, self.data) + self.assertEqual(response.data, self.data) class TestCustomPaginateByParam(TestCase): @@ -240,7 +301,7 @@ class TestCustomPaginateByParam(TestCase): """ request = factory.get('/') response = self.view(request).render() - self.assertEquals(response.data, self.data) + self.assertEqual(response.data, self.data) def test_paginate_by_param(self): """ @@ -248,10 +309,12 @@ class TestCustomPaginateByParam(TestCase): """ request = factory.get('/?page_size=5') response = self.view(request).render() - self.assertEquals(response.data['count'], 13) - self.assertEquals(response.data['results'], self.data[:5]) + self.assertEqual(response.data['count'], 13) + self.assertEqual(response.data['results'], self.data[:5]) +### Tests for context in pagination serializers + class CustomField(serializers.Field): def to_native(self, value): if not 'view' in self.context: @@ -262,6 +325,11 @@ class CustomField(serializers.Field): class BasicModelSerializer(serializers.Serializer): text = CustomField() + def __init__(self, *args, **kwargs): + super(BasicModelSerializer, self).__init__(*args, **kwargs) + if not 'view' in self.context: + raise RuntimeError("context isn't getting passed into serializer init") + class TestContextPassedToCustomField(TestCase): def setUp(self): @@ -277,5 +345,41 @@ class TestContextPassedToCustomField(TestCase): request = factory.get('/') response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +### Tests for custom pagination serializers + +class LinksSerializer(serializers.Serializer): + next = pagination.NextPageField(source='*') + prev = pagination.PreviousPageField(source='*') + +class CustomPaginationSerializer(pagination.BasePaginationSerializer): + links = LinksSerializer(source='*') # Takes the page object as the source + total_results = serializers.Field(source='paginator.count') + + results_field = 'objects' + + +class TestCustomPaginationSerializer(TestCase): + def setUp(self): + objects = ['john', 'paul', 'george', 'ringo'] + paginator = Paginator(objects, 2) + self.page = paginator.page(1) + + def test_custom_pagination_serializer(self): + request = RequestFactory().get('/foobar') + serializer = CustomPaginationSerializer( + instance=self.page, + context={'request': request} + ) + expected = { + 'links': { + 'next': 'http://testserver/foobar?page=2', + 'prev': None + }, + 'total_results': 4, + 'objects': ['john', 'paul'] + } + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py index 8ab8a52f..7699e10c 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/parsers.py @@ -1,140 +1,11 @@ -# """ -# .. -# >>> from rest_framework.parsers import FormParser -# >>> from django.test.client import RequestFactory -# >>> from rest_framework.views import View -# >>> from StringIO import StringIO -# >>> from urllib import urlencode -# >>> req = RequestFactory().get('/') -# >>> some_view = View() -# >>> some_view.request = req # Make as if this request had been dispatched -# -# FormParser -# ============ -# -# Data flatening -# ---------------- -# -# Here is some example data, which would eventually be sent along with a post request : -# -# >>> inpt = urlencode([ -# ... ('key1', 'bla1'), -# ... ('key2', 'blo1'), ('key2', 'blo2'), -# ... ]) -# -# Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter : -# -# >>> (data, files) = FormParser(some_view).parse(StringIO(inpt)) -# >>> data == {'key1': 'bla1', 'key2': 'blo1'} -# True -# -# However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` : -# -# >>> class MyFormParser(FormParser): -# ... -# ... def is_a_list(self, key, val_list): -# ... return len(val_list) > 1 -# -# This new parser only flattens the lists of parameters that contain a single value. -# -# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt)) -# >>> data == {'key1': 'bla1', 'key2': ['blo1', 'blo2']} -# True -# -# .. note:: The same functionality is available for :class:`parsers.MultiPartParser`. -# -# Submitting an empty list -# -------------------------- -# -# When submitting an empty select multiple, like this one :: -# -# <select multiple="multiple" name="key2"></select> -# -# The browsers usually strip the parameter completely. A hack to avoid this, and therefore being able to submit an empty select multiple, is to submit a value that tells the server that the list is empty :: -# -# <select multiple="multiple" name="key2"><option value="_empty"></select> -# -# :class:`parsers.FormParser` provides the server-side implementation for this hack. Considering the following posted data : -# -# >>> inpt = urlencode([ -# ... ('key1', 'blo1'), ('key1', '_empty'), -# ... ('key2', '_empty'), -# ... ]) -# -# :class:`parsers.FormParser` strips the values ``_empty`` from all the lists. -# -# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt)) -# >>> data == {'key1': 'blo1'} -# True -# -# Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it. -# -# >>> class MyFormParser(FormParser): -# ... -# ... def is_a_list(self, key, val_list): -# ... return key == 'key2' -# ... -# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt)) -# >>> data == {'key1': 'blo1', 'key2': []} -# True -# -# Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`. -# """ -# import httplib, mimetypes -# from tempfile import TemporaryFile -# from django.test import TestCase -# from django.test.client import RequestFactory -# from rest_framework.parsers import MultiPartParser -# from rest_framework.views import View -# from StringIO import StringIO -# -# def encode_multipart_formdata(fields, files): -# """For testing multipart parser. -# fields is a sequence of (name, value) elements for regular form fields. -# files is a sequence of (name, filename, value) elements for data to be uploaded as files -# Return (content_type, body).""" -# BOUNDARY = '----------ThIs_Is_tHe_bouNdaRY_$' -# CRLF = '\r\n' -# L = [] -# for (key, value) in fields: -# L.append('--' + BOUNDARY) -# L.append('Content-Disposition: form-data; name="%s"' % key) -# L.append('') -# L.append(value) -# for (key, filename, value) in files: -# L.append('--' + BOUNDARY) -# L.append('Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename)) -# L.append('Content-Type: %s' % get_content_type(filename)) -# L.append('') -# L.append(value) -# L.append('--' + BOUNDARY + '--') -# L.append('') -# body = CRLF.join(L) -# content_type = 'multipart/form-data; boundary=%s' % BOUNDARY -# return content_type, body -# -# def get_content_type(filename): -# return mimetypes.guess_type(filename)[0] or 'application/octet-stream' -# -#class TestMultiPartParser(TestCase): -# def setUp(self): -# self.req = RequestFactory() -# self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')], -# [('file1', 'pic.jpg', 'blablabla'), ('file1', 't.txt', 'blobloblo')]) -# -# def test_multipartparser(self): -# """Ensure that MultiPartParser can parse multipart/form-data that contains a mix of several files and parameters.""" -# post_req = RequestFactory().post('/', self.body, content_type=self.content_type) -# view = View() -# view.request = post_req -# (data, files) = MultiPartParser(view).parse(StringIO(self.body)) -# self.assertEqual(data['key1'], 'val1') -# self.assertEqual(files['file1'].read(), 'blablabla') - -from StringIO import StringIO +from __future__ import unicode_literals +from rest_framework.compat import StringIO from django import forms +from django.core.files.uploadhandler import MemoryFileUploadHandler from django.test import TestCase -from rest_framework.parsers import FormParser +from django.utils import unittest +from rest_framework.compat import etree +from rest_framework.parsers import FormParser, FileUploadParser from rest_framework.parsers import XMLParser import datetime @@ -201,12 +72,44 @@ class TestXMLParser(TestCase): ] } + @unittest.skipUnless(etree, 'defusedxml not installed') def test_parse(self): parser = XMLParser() data = parser.parse(self._input) self.assertEqual(data, self._data) + @unittest.skipUnless(etree, 'defusedxml not installed') def test_complex_data_parse(self): parser = XMLParser() data = parser.parse(self._complex_data_input) self.assertEqual(data, self._complex_data) + + +class TestFileUploadParser(TestCase): + def setUp(self): + class MockRequest(object): + pass + from io import BytesIO + self.stream = BytesIO( + "Test text file".encode('utf-8') + ) + request = MockRequest() + request.upload_handlers = (MemoryFileUploadHandler(),) + request.META = { + 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'), + 'HTTP_CONTENT_LENGTH': 14, + } + self.parser_context = {'request': request, 'kwargs': {}} + + def test_parse(self): + """ Make sure the `QueryDict` works OK """ + parser = FileUploadParser() + self.stream.seek(0) + data_and_files = parser.parse(self.stream, None, self.parser_context) + file_obj = data_and_files.files['file'] + self.assertEqual(file_obj._size, 14) + + def test_get_filename(self): + parser = FileUploadParser() + filename = parser.get_filename(self.stream, None, self.parser_context) + self.assertEqual(filename, 'file.txt'.encode('utf-8')) diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/permissions.py new file mode 100644 index 00000000..b3993be5 --- /dev/null +++ b/rest_framework/tests/permissions.py @@ -0,0 +1,153 @@ +from __future__ import unicode_literals +from django.contrib.auth.models import User, Permission +from django.db import models +from django.test import TestCase +from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING +from rest_framework.tests.utils import RequestFactory +import base64 +import json + +factory = RequestFactory() + + +class BasicModel(models.Model): + text = models.CharField(max_length=100) + + +class RootView(generics.ListCreateAPIView): + model = BasicModel + authentication_classes = [authentication.BasicAuthentication] + permission_classes = [permissions.DjangoModelPermissions] + + +class InstanceView(generics.RetrieveUpdateDestroyAPIView): + model = BasicModel + authentication_classes = [authentication.BasicAuthentication] + permission_classes = [permissions.DjangoModelPermissions] + +root_view = RootView.as_view() +instance_view = InstanceView.as_view() + + +def basic_auth_header(username, password): + credentials = ('%s:%s' % (username, password)) + base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) + return 'Basic %s' % base64_credentials + + +class ModelPermissionsIntegrationTests(TestCase): + def setUp(self): + User.objects.create_user('disallowed', 'disallowed@example.com', 'password') + user = User.objects.create_user('permitted', 'permitted@example.com', 'password') + user.user_permissions = [ + Permission.objects.get(codename='add_basicmodel'), + Permission.objects.get(codename='change_basicmodel'), + Permission.objects.get(codename='delete_basicmodel') + ] + user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') + user.user_permissions = [ + Permission.objects.get(codename='change_basicmodel'), + ] + + self.permitted_credentials = basic_auth_header('permitted', 'password') + self.disallowed_credentials = basic_auth_header('disallowed', 'password') + self.updateonly_credentials = basic_auth_header('updateonly', 'password') + + BasicModel(text='foo').save() + + def test_has_create_permissions(self): + request = factory.post('/', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = root_view(request, pk=1) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + def test_has_put_permissions(self): + request = factory.put('/1', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_has_delete_permissions(self): + request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk=1) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + def test_does_not_have_create_permissions(self): + request = factory.post('/', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = root_view(request, pk=1) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_does_not_have_put_permissions(self): + request = factory.put('/1', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_does_not_have_delete_permissions(self): + request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) + response = instance_view(request, pk=1) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_has_put_as_create_permissions(self): + # User only has update permissions - should be able to update an entity. + request = factory.put('/1', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # But if PUTing to a new entity, permission should be denied. + request = factory.put('/2', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = instance_view(request, pk='2') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + +class OwnerModel(models.Model): + text = models.CharField(max_length=100) + owner = models.ForeignKey(User) + + +class IsOwnerPermission(permissions.BasePermission): + def has_object_permission(self, request, view, obj): + return request.user == obj.owner + + +class OwnerInstanceView(generics.RetrieveUpdateDestroyAPIView): + model = OwnerModel + authentication_classes = [authentication.BasicAuthentication] + permission_classes = [IsOwnerPermission] + + +owner_instance_view = OwnerInstanceView.as_view() + + +class ObjectPermissionsIntegrationTests(TestCase): + """ + Integration tests for the object level permissions API. + """ + + def setUp(self): + User.objects.create_user('not_owner', 'not_owner@example.com', 'password') + user = User.objects.create_user('owner', 'owner@example.com', 'password') + + self.not_owner_credentials = basic_auth_header('not_owner', 'password') + self.owner_credentials = basic_auth_header('owner', 'password') + + OwnerModel(text='foo', owner=user).save() + + def test_owner_has_delete_permissions(self): + request = factory.delete('/1', HTTP_AUTHORIZATION=self.owner_credentials) + response = owner_instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + def test_non_owner_does_not_have_delete_permissions(self): + request = factory.delete('/1', HTTP_AUTHORIZATION=self.not_owner_credentials) + response = owner_instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py new file mode 100644 index 00000000..cbf93c65 --- /dev/null +++ b/rest_framework/tests/relations.py @@ -0,0 +1,47 @@ +""" +General tests for relational fields. +""" +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class NullModel(models.Model): + pass + + +class FieldTests(TestCase): + def test_pk_related_field_with_empty_string(self): + """ + Regression test for #446 + + https://github.com/tomchristie/django-rest-framework/issues/446 + """ + field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_hyperlinked_related_field_with_empty_string(self): + field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_slug_related_field_with_empty_string(self): + field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + +class TestManyRelateMixin(TestCase): + def test_missing_many_to_many_related_field(self): + ''' + Regression test for #632 + + https://github.com/tomchristie/django-rest-framework/pull/632 + ''' + field = serializers.RelatedField(many=True, read_only=False) + + into = {} + field.field_from_native({}, None, 'field_name', into) + self.assertEqual(into['field_name'], []) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index 53ce0074..b1eed9a7 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -1,7 +1,15 @@ -from django.db import models +from __future__ import unicode_literals from django.test import TestCase +from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url +from rest_framework.tests.models import ( + ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, + NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +) + +factory = RequestFactory() +request = factory.get('/') # Just to ensure we have a request in the serializer context def dummy_view(request, pk): @@ -13,66 +21,49 @@ urlpatterns = patterns('', url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), + url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), + url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), ) # ManyToMany - -class ManyToManyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ManyToManySource(models.Model): - name = models.CharField(max_length=100) - targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') - - class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') - class Meta: model = ManyToManyTarget + fields = ('url', 'name', 'sources') class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ManyToManySource + fields = ('url', 'name', 'targets') # ForeignKey - -class ForeignKeyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, related_name='sources') - - class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail', read_only=True) - class Meta: model = ForeignKeyTarget + fields = ('url', 'name', 'sources') class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = ForeignKeySource + fields = ('url', 'name', 'target') # Nullable ForeignKey - -class NullableForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, - related_name='nullable_sources') - - class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('url', 'name', 'target') + + +# Nullable OneToOne +class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = OneToOneTarget + fields = ('url', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid @@ -91,98 +82,98 @@ class HyperlinkedManyToManyTests(TestCase): def test_many_to_many_retrieve(self): queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, - {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, - {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} + {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, + {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, + {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_retrieve(self): queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} + {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_many_to_many_update(self): - data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} + data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} instance = ManyToManySource.objects.get(pk=1) - serializer = ManyToManySourceSerializer(instance, data=data) + serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, - {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, - {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} + {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, + {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, + {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_update(self): - data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']} + data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} instance = ManyToManyTarget.objects.get(pk=1) - serializer = ManyToManyTargetSerializer(instance, data=data) + serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure target 1 is updated, and everything else is as expected queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}, - {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} + {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, + {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_many_to_many_create(self): - data = {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} - serializer = ManyToManySourceSerializer(data=data) + data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} + serializer = ManyToManySourceSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is added, and everything else is as expected queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, - {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, - {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, - {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} + {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, + {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, + {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, + {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_create(self): - data = {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} - serializer = ManyToManyTargetSerializer(data=data) + data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} + serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'target-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-4') # Ensure target 4 is added, and everything else is as expected queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}, - {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} + {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) class HyperlinkedForeignKeyTests(TestCase): @@ -199,47 +190,118 @@ class HyperlinkedForeignKeyTests(TestCase): def test_foreign_key_retrieve(self): queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, - {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} + {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, - {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, + {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, + {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update(self): - data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'} + data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ + {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, + {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_incorrect_type(self): + data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) + serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']}) + + def test_reverse_foreign_key_update(self): + data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + # We shouldn't have saved anything to the db yet since save + # hasn't been called. + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ + {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, + {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, + ] + self.assertEqual(new_serializer.data, expected) + serializer.save() + self.assertEqual(serializer.data, data) + + # Ensure target 2 is update, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ + {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, + {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_create(self): + data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} + serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ + {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, + ] + self.assertEqual(serializer.data, expected) + + def test_reverse_foreign_key_create(self): + data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} + serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3') + + # Ensure target 4 is added, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}, - {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} + {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, + {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, + {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_invalid_null(self): - data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None} + data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) + serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + self.assertEqual(serializer.errors, {'target': ['This field is required.']}) class HyperlinkedNullableForeignKeyTests(TestCase): @@ -249,110 +311,143 @@ class HyperlinkedNullableForeignKeyTests(TestCase): target = ForeignKeyTarget(name='target-1') target.save() for idx in range(1, 4): + if idx == 3: + target = None source = NullableForeignKeySource(name='source-%d' % idx, target=target) source.save() + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) + expected = [ + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, + ] + self.assertEqual(serializer.data, expected) + def test_foreign_key_create_with_valid_null(self): - data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data) + data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_create_with_valid_emptystring(self): """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''} - expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data) + data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} + expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, expected_data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, expected_data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_valid_null(self): - data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} + data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data) + serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, - {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_valid_emptystring(self): """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''} - expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} + data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''} + expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data) + serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, expected_data) + self.assertEqual(serializer.data, expected_data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, - {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) # reverse foreign keys MUST be read_only # In the general case they do not provide .remove() or .clear() # and cannot be arbitrarily set. # def test_reverse_foreign_key_update(self): - # data = {'id': 1, 'name': u'target-1', 'sources': [1]} + # data = {'id': 1, 'name': 'target-1', 'sources': [1]} # instance = ForeignKeyTarget.objects.get(pk=1) # serializer = ForeignKeyTargetSerializer(instance, data=data) # self.assertTrue(serializer.is_valid()) - # self.assertEquals(serializer.data, data) + # self.assertEqual(serializer.data, data) # serializer.save() # # Ensure target 1 is updated, and everything else is as expected # queryset = ForeignKeyTarget.objects.all() - # serializer = ForeignKeyTargetSerializer(queryset) + # serializer = ForeignKeyTargetSerializer(queryset, many=True) # expected = [ - # {'id': 1, 'name': u'target-1', 'sources': [1]}, - # {'id': 2, 'name': u'target-2', 'sources': []}, + # {'id': 1, 'name': 'target-1', 'sources': [1]}, + # {'id': 2, 'name': 'target-2', 'sources': []}, # ] - # self.assertEquals(serializer.data, expected) + # self.assertEqual(serializer.data, expected) + + +class HyperlinkedNullableOneToOneTests(TestCase): + urls = 'rest_framework.tests.relations_hyperlink' + + def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) + expected = [ + {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, + {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 3482c252..f6d006b3 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -1,29 +1,35 @@ -from django.db import models +from __future__ import unicode_literals from django.test import TestCase from rest_framework import serializers - - -# ForeignKey - -class ForeignKeyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, related_name='sources') +from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource class ForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') + depth = 1 class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = ForeignKeySourceSerializer() - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') + depth = 1 + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): + class Meta: + model = NullableForeignKeySource + fields = ('id', 'name', 'target') + depth = 1 + + +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') + depth = 1 class ReverseForeignKeyTests(TestCase): @@ -36,16 +42,66 @@ class ReverseForeignKeyTests(TestCase): source = ForeignKeySource(name='source-%d' % idx, target=target) source.save() + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}}, + ] + self.assertEqual(serializer.data, expected) + def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1}, + {'id': 1, 'name': 'target-1', 'sources': [ + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': 1}, ]}, - {'id': 2, 'name': u'target-2', 'sources': [ + {'id': 2, 'name': 'target-2', 'sources': [ ]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) + + +class NestedNullableForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None + source = NullableForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 3, 'name': 'source-3', 'target': None}, + ] + self.assertEqual(serializer.data, expected) + + +class NestedNullableOneToOneTests(TestCase): + def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, + {'id': 2, 'name': 'target-2', 'nullable_source': None}, + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index e3360939..5ce8b567 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -1,65 +1,48 @@ -from django.db import models +from __future__ import unicode_literals from django.test import TestCase from rest_framework import serializers +from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +from rest_framework.compat import six # ManyToMany - -class ManyToManyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ManyToManySource(models.Model): - name = models.CharField(max_length=100) - targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') - - class ManyToManyTargetSerializer(serializers.ModelSerializer): - sources = serializers.ManyPrimaryKeyRelatedField() - class Meta: model = ManyToManyTarget + fields = ('id', 'name', 'sources') class ManyToManySourceSerializer(serializers.ModelSerializer): class Meta: model = ManyToManySource + fields = ('id', 'name', 'targets') # ForeignKey - -class ForeignKeyTarget(models.Model): - name = models.CharField(max_length=100) - - -class ForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, related_name='sources') - - class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = serializers.ManyPrimaryKeyRelatedField(read_only=True) - class Meta: model = ForeignKeyTarget + fields = ('id', 'name', 'sources') class ForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeySource + fields = ('id', 'name', 'target') # Nullable ForeignKey - -class NullableForeignKeySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, - related_name='nullable_sources') - - class NullableForeignKeySourceSerializer(serializers.ModelSerializer): class Meta: model = NullableForeignKeySource + fields = ('id', 'name', 'target') + + +# Nullable OneToOne +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneTarget + fields = ('id', 'name', 'nullable_source') # TODO: Add test that .data cannot be accessed prior to .is_valid @@ -76,97 +59,97 @@ class PKManyToManyTests(TestCase): def test_many_to_many_retrieve(self): queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'targets': [1]}, - {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, - {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} + {'id': 1, 'name': 'source-1', 'targets': [1]}, + {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_retrieve(self): queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, - {'id': 3, 'name': u'target-3', 'sources': [3]} + {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': 'target-3', 'sources': [3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_many_to_many_update(self): - data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]} + data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} instance = ManyToManySource.objects.get(pk=1) serializer = ManyToManySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}, - {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, - {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} + {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, + {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_update(self): - data = {'id': 1, 'name': u'target-1', 'sources': [1]} + data = {'id': 1, 'name': 'target-1', 'sources': [1]} instance = ManyToManyTarget.objects.get(pk=1) serializer = ManyToManyTargetSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure target 1 is updated, and everything else is as expected queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [1]}, - {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, - {'id': 3, 'name': u'target-3', 'sources': [3]} + {'id': 1, 'name': 'target-1', 'sources': [1]}, + {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': 'target-3', 'sources': [3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_many_to_many_create(self): - data = {'id': 4, 'name': u'source-4', 'targets': [1, 3]} + data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} serializer = ManyToManySourceSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is added, and everything else is as expected queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'targets': [1]}, - {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, - {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}, - {'id': 4, 'name': u'source-4', 'targets': [1, 3]}, + {'id': 1, 'name': 'source-1', 'targets': [1]}, + {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, + {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_create(self): - data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]} + data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} serializer = ManyToManyTargetSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'target-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-4') # Ensure target 4 is added, and everything else is as expected queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, - {'id': 3, 'name': u'target-3', 'sources': [3]}, - {'id': 4, 'name': u'target-4', 'sources': [1, 3]} + {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': 'target-3', 'sources': [3]}, + {'id': 4, 'name': 'target-4', 'sources': [1, 3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) class PKForeignKeyTests(TestCase): @@ -181,47 +164,118 @@ class PKForeignKeyTests(TestCase): def test_foreign_key_retrieve(self): queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1} + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': 1} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': u'target-2', 'sources': []}, + {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': 'target-2', 'sources': []}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update(self): - data = {'id': 1, 'name': u'source-1', 'target': 2} + data = {'id': 1, 'name': 'source-1', 'target': 2} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 2}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': 1} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_incorrect_type(self): + data = {'id': 1, 'name': 'source-1', 'target': 'foo'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]}) + + def test_reverse_foreign_key_update(self): + data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + # We shouldn't have saved anything to the db yet since save + # hasn't been called. + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': 'target-2', 'sources': []}, + ] + self.assertEqual(new_serializer.data, expected) + + serializer.save() + self.assertEqual(serializer.data, data) + + # Ensure target 2 is update, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [2]}, + {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_create(self): + data = {'id': 4, 'name': 'source-4', 'target': 2} + serializer = ForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') + + # Ensure source 4 is added, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': 1}, + {'id': 4, 'name': 'source-4', 'target': 2}, + ] + self.assertEqual(serializer.data, expected) + + def test_reverse_foreign_key_create(self): + data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]} + serializer = ForeignKeyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3') + + # Ensure target 3 is added, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 2}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1} + {'id': 1, 'name': 'target-1', 'sources': [2]}, + {'id': 2, 'name': 'target-2', 'sources': []}, + {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_invalid_null(self): - data = {'id': 1, 'name': u'source-1', 'target': None} + data = {'id': 1, 'name': 'source-1', 'target': None} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + self.assertEqual(serializer.errors, {'target': ['This field is required.']}) class PKNullableForeignKeyTests(TestCase): @@ -229,110 +283,141 @@ class PKNullableForeignKeyTests(TestCase): target = ForeignKeyTarget(name='target-1') target.save() for idx in range(1, 4): + if idx == 3: + target = None source = NullableForeignKeySource(name='source-%d' % idx, target=target) source.save() + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None}, + ] + self.assertEqual(serializer.data, expected) + def test_foreign_key_create_with_valid_null(self): - data = {'id': 4, 'name': u'source-4', 'target': None} + data = {'id': 4, 'name': 'source-4', 'target': None} serializer = NullableForeignKeySourceSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1}, - {'id': 4, 'name': u'source-4', 'target': None} + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 4, 'name': 'source-4', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_create_with_valid_emptystring(self): """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'id': 4, 'name': u'source-4', 'target': ''} - expected_data = {'id': 4, 'name': u'source-4', 'target': None} + data = {'id': 4, 'name': 'source-4', 'target': ''} + expected_data = {'id': 4, 'name': 'source-4', 'target': None} serializer = NullableForeignKeySourceSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, expected_data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, expected_data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1}, - {'id': 4, 'name': u'source-4', 'target': None} + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 4, 'name': 'source-4', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_valid_null(self): - data = {'id': 1, 'name': u'source-1', 'target': None} + data = {'id': 1, 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': None}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1} + {'id': 1, 'name': 'source-1', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_valid_emptystring(self): """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'id': 1, 'name': u'source-1', 'target': ''} - expected_data = {'id': 1, 'name': u'source-1', 'target': None} + data = {'id': 1, 'name': 'source-1', 'target': ''} + expected_data = {'id': 1, 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, expected_data) + self.assertEqual(serializer.data, expected_data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': None}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1} + {'id': 1, 'name': 'source-1', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) # reverse foreign keys MUST be read_only # In the general case they do not provide .remove() or .clear() # and cannot be arbitrarily set. # def test_reverse_foreign_key_update(self): - # data = {'id': 1, 'name': u'target-1', 'sources': [1]} + # data = {'id': 1, 'name': 'target-1', 'sources': [1]} # instance = ForeignKeyTarget.objects.get(pk=1) # serializer = ForeignKeyTargetSerializer(instance, data=data) # self.assertTrue(serializer.is_valid()) - # self.assertEquals(serializer.data, data) + # self.assertEqual(serializer.data, data) # serializer.save() # # Ensure target 1 is updated, and everything else is as expected # queryset = ForeignKeyTarget.objects.all() - # serializer = ForeignKeyTargetSerializer(queryset) + # serializer = ForeignKeyTargetSerializer(queryset, many=True) # expected = [ - # {'id': 1, 'name': u'target-1', 'sources': [1]}, - # {'id': 2, 'name': u'target-2', 'sources': []}, + # {'id': 1, 'name': 'target-1', 'sources': [1]}, + # {'id': 2, 'name': 'target-2', 'sources': []}, # ] - # self.assertEquals(serializer.data, expected) + # self.assertEqual(serializer.data, expected) + + +class PKNullableOneToOneTests(TestCase): + def setUp(self): + target = OneToOneTarget(name='target-1') + target.save() + new_target = OneToOneTarget(name='target-2') + new_target.save() + source = NullableOneToOneSource(name='source-1', target=new_target) + source.save() + + def test_reverse_foreign_key_retrieve_with_null(self): + queryset = OneToOneTarget.objects.all() + serializer = NullableOneToOneTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'nullable_source': None}, + {'id': 2, 'name': 'target-2', 'nullable_source': 1}, + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py new file mode 100644 index 00000000..435c821c --- /dev/null +++ b/rest_framework/tests/relations_slug.py @@ -0,0 +1,257 @@ +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): + sources = serializers.SlugRelatedField(many=True, slug_field='name') + + class Meta: + model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): + target = serializers.SlugRelatedField(slug_field='name') + + class Meta: + model = ForeignKeySource + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): + target = serializers.SlugRelatedField(slug_field='name', required=False) + + class Meta: + model = NullableForeignKeySource + + +# TODO: M2M Tests, FKTests (Non-nullable), One2One +class SlugForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': 'target-1'} + ] + self.assertEqual(serializer.data, expected) + + def test_reverse_foreign_key_retrieve(self): + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, + {'id': 2, 'name': 'target-2', 'sources': []}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update(self): + data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-2'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': 'target-1'} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_incorrect_type(self): + data = {'id': 1, 'name': 'source-1', 'target': 123} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) + + def test_reverse_foreign_key_update(self): + data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + # We shouldn't have saved anything to the db yet since save + # hasn't been called. + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, + {'id': 2, 'name': 'target-2', 'sources': []}, + ] + self.assertEqual(new_serializer.data, expected) + + serializer.save() + self.assertEqual(serializer.data, data) + + # Ensure target 2 is update, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, + {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_create(self): + data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} + serializer = ForeignKeySourceSerializer(data=data) + serializer.is_valid() + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') + + # Ensure source 4 is added, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': 'target-1'}, + {'id': 4, 'name': 'source-4', 'target': 'target-2'}, + ] + self.assertEqual(serializer.data, expected) + + def test_reverse_foreign_key_create(self): + data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} + serializer = ForeignKeyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3') + + # Ensure target 3 is added, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, + {'id': 2, 'name': 'target-2', 'sources': []}, + {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_with_invalid_null(self): + data = {'id': 1, 'name': 'source-1', 'target': None} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': ['This field is required.']}) + + +class SlugNullableForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None + source = NullableForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_create_with_valid_null(self): + data = {'id': 4, 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 4, 'name': 'source-4', 'target': None} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_create_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'id': 4, 'name': 'source-4', 'target': ''} + expected_data = {'id': 4, 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, expected_data) + self.assertEqual(obj.name, 'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 4, 'name': 'source-4', 'target': None} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_with_valid_null(self): + data = {'id': 1, 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'id': 1, 'name': 'source-1', 'target': ''} + expected_data = {'id': 1, 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, expected_data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None} + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py index c1b4e624..40bac9cb 100644 --- a/rest_framework/tests/renderers.py +++ b/rest_framework/tests/renderers.py @@ -1,29 +1,28 @@ -import pickle -import re - +from decimal import Decimal from django.core.cache import cache from django.test import TestCase from django.test.client import RequestFactory - +from django.utils import unittest from rest_framework import status, permissions -from rest_framework.compat import yaml, patterns, url, include +from rest_framework.compat import yaml, etree, patterns, url, include from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings - -from StringIO import StringIO +from rest_framework.compat import StringIO +from rest_framework.compat import six import datetime -from decimal import Decimal +import pickle +import re DUMMYSTATUS = status.HTTP_200_OK DUMMYCONTENT = 'dummycontent' -RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x -RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x +RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii') +RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') expected_results = [ @@ -35,7 +34,7 @@ class BasicRendererTests(TestCase): def test_expected_results(self): for value, renderer_cls, expected in expected_results: output = renderer_cls().render(value) - self.assertEquals(output, expected) + self.assertEqual(output, expected) class RendererA(BaseRenderer): @@ -94,7 +93,7 @@ urlpatterns = patterns('', class POSTDeniedPermission(permissions.BasePermission): - def has_permission(self, request, view, obj=None): + def has_permission(self, request, view): return request.method != 'POST' @@ -111,6 +110,9 @@ class POSTDeniedView(APIView): def put(self, request): return Response() + def patch(self, request): + return Response() + class DocumentingRendererTests(TestCase): def test_only_permitted_forms_are_displayed(self): @@ -119,6 +121,7 @@ class DocumentingRendererTests(TestCase): response = view(request).render() self.assertNotContains(response, '>POST<') self.assertContains(response, '>PUT<') + self.assertContains(response, '>PATCH<') class RendererEndToEndTests(TestCase): @@ -131,39 +134,39 @@ class RendererEndToEndTests(TestCase): def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" resp = self.client.get('/') - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_head_method_serializes_no_content(self): """No response must be included in HEAD requests.""" resp = self.client.head('/') - self.assertEquals(resp.status_code, DUMMYSTATUS) - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, '') + self.assertEqual(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, six.b('')) def test_default_renderer_serializes_content_on_accept_any(self): """If the Accept header is set to */* the default renderer should serialize the response.""" resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_default_case(self): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_non_default_case(self): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_accept_query(self): """The '_accept' query string should behave in the same way as the Accept header.""" @@ -172,14 +175,14 @@ class RendererEndToEndTests(TestCase): RendererB.media_type ) resp = self.client.get('/' + param) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_unsatisfiable_accept_header_on_request_returns_406_status(self): """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" resp = self.client.get('/', HTTP_ACCEPT='foo/bar') - self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) + self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) def test_specified_renderer_serializes_content_on_format_query(self): """If a 'format' query is specified, the renderer with the matching @@ -189,17 +192,17 @@ class RendererEndToEndTests(TestCase): RendererB.format ) resp = self.client.get('/' + param) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_format_kwargs(self): """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/something.formatb') - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): """If both a 'format' query and a matching Accept header specified, @@ -210,9 +213,9 @@ class RendererEndToEndTests(TestCase): ) resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) _flat_repr = '{"foo": ["bar", "baz"]}' @@ -240,7 +243,7 @@ class JSONRendererTests(TestCase): renderer = JSONRenderer() content = renderer.render(obj, 'application/json') # Fix failing test case which depends on version of JSON library. - self.assertEquals(content, _flat_repr) + self.assertEqual(content, _flat_repr) def test_with_content_type_args(self): """ @@ -249,7 +252,7 @@ class JSONRendererTests(TestCase): obj = {'foo': ['bar', 'baz']} renderer = JSONRenderer() content = renderer.render(obj, 'application/json; indent=2') - self.assertEquals(strip_trailing_whitespace(content), _indented_repr) + self.assertEqual(strip_trailing_whitespace(content), _indented_repr) class JSONPRendererTests(TestCase): @@ -265,9 +268,10 @@ class JSONPRendererTests(TestCase): """ resp = self.client.get('/jsonp/jsonrenderer', HTTP_ACCEPT='application/javascript') - self.assertEquals(resp.status_code, 200) - self.assertEquals(resp['Content-Type'], 'application/javascript') - self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp.content, + ('callback(%s);' % _flat_repr).encode('ascii')) def test_without_callback_without_json_renderer(self): """ @@ -275,9 +279,10 @@ class JSONPRendererTests(TestCase): """ resp = self.client.get('/jsonp/nojsonrenderer', HTTP_ACCEPT='application/javascript') - self.assertEquals(resp.status_code, 200) - self.assertEquals(resp['Content-Type'], 'application/javascript') - self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp.content, + ('callback(%s);' % _flat_repr).encode('ascii')) def test_with_callback(self): """ @@ -286,9 +291,10 @@ class JSONPRendererTests(TestCase): callback_func = 'myjsonpcallback' resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, HTTP_ACCEPT='application/javascript') - self.assertEquals(resp.status_code, 200) - self.assertEquals(resp['Content-Type'], 'application/javascript') - self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr)) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp.content, + ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')) if yaml: @@ -306,7 +312,7 @@ if yaml: obj = {'foo': ['bar', 'baz']} renderer = YAMLRenderer() content = renderer.render(obj, 'application/yaml') - self.assertEquals(content, _yaml_repr) + self.assertEqual(content, _yaml_repr) def test_render_and_parse(self): """ @@ -320,7 +326,7 @@ if yaml: content = renderer.render(obj, 'application/yaml') data = parser.parse(StringIO(content)) - self.assertEquals(obj, data) + self.assertEqual(obj, data) class XMLRendererTestCase(TestCase): @@ -402,6 +408,7 @@ class XMLRendererTestCase(TestCase): self.assertXMLContains(content, '<sub_name>first</sub_name>') self.assertXMLContains(content, '<sub_name>second</sub_name>') + @unittest.skipUnless(etree, 'defusedxml not installed') def test_render_and_parse_complex_data(self): """ Test XML rendering. diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 1f05ff8f..97e5af20 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -1,12 +1,12 @@ """ Tests for content parsing, and form-overloaded content parsing. """ +from __future__ import unicode_literals from django.contrib.auth.models import User from django.contrib.auth import authenticate, login, logout from django.contrib.sessions.middleware import SessionMiddleware from django.test import TestCase, Client from django.test.client import RequestFactory -from django.utils import simplejson as json from rest_framework import status from rest_framework.authentication import SessionAuthentication from rest_framework.compat import patterns @@ -20,6 +20,8 @@ from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView +from rest_framework.compat import six +import json factory = RequestFactory() @@ -56,21 +58,29 @@ class TestMethodOverloading(TestCase): request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'})) self.assertEqual(request.method, 'DELETE') + def test_x_http_method_override_header(self): + """ + POST requests can also be overloaded to another method by setting + the X-HTTP-Method-Override header. + """ + request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) + self.assertEqual(request.method, 'DELETE') + class TestContentParsing(TestCase): def test_standard_behaviour_determines_no_content_GET(self): """ - Ensure request.DATA returns None for GET request with no content. + Ensure request.DATA returns empty QueryDict for GET request. """ request = Request(factory.get('/')) - self.assertEqual(request.DATA, None) + self.assertEqual(request.DATA, {}) def test_standard_behaviour_determines_no_content_HEAD(self): """ - Ensure request.DATA returns None for HEAD request. + Ensure request.DATA returns empty QueryDict for HEAD request. """ request = Request(factory.head('/')) - self.assertEqual(request.DATA, None) + self.assertEqual(request.DATA, {}) def test_request_DATA_with_form_content(self): """ @@ -79,14 +89,14 @@ class TestContentParsing(TestCase): data = {'qwerty': 'uiop'} request = Request(factory.post('/', data)) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(request.DATA.items(), data.items()) + self.assertEqual(list(request.DATA.items()), list(data.items())) def test_request_DATA_with_text_content(self): """ Ensure request.DATA returns content for POST request with non-form content. """ - content = 'qwerty' + content = six.b('qwerty') content_type = 'text/plain' request = Request(factory.post('/', content, content_type=content_type)) request.parsers = (PlainTextParser(),) @@ -99,7 +109,7 @@ class TestContentParsing(TestCase): data = {'qwerty': 'uiop'} request = Request(factory.post('/', data)) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(request.POST.items(), data.items()) + self.assertEqual(list(request.POST.items()), list(data.items())) def test_standard_behaviour_determines_form_content_PUT(self): """ @@ -117,14 +127,14 @@ class TestContentParsing(TestCase): request = Request(factory.put('/', data)) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(request.DATA.items(), data.items()) + self.assertEqual(list(request.DATA.items()), list(data.items())) def test_standard_behaviour_determines_non_form_content_PUT(self): """ Ensure request.DATA returns content for PUT request with non-form content. """ - content = 'qwerty' + content = six.b('qwerty') content_type = 'text/plain' request = Request(factory.put('/', content, content_type=content_type)) request.parsers = (PlainTextParser(), ) diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py index 875f4d42..aecf83f4 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/response.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from django.test import TestCase from rest_framework.compat import patterns, url, include from rest_framework.response import Response @@ -9,6 +10,7 @@ from rest_framework.renderers import ( BrowsableAPIRenderer ) from rest_framework.settings import api_settings +from rest_framework.compat import six class MockPickleRenderer(BaseRenderer): @@ -22,8 +24,8 @@ class MockJsonRenderer(BaseRenderer): DUMMYSTATUS = status.HTTP_200_OK DUMMYCONTENT = 'dummycontent' -RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x -RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x +RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii') +RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') class RendererA(BaseRenderer): @@ -83,39 +85,39 @@ class RendererIntegrationTests(TestCase): def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" resp = self.client.get('/') - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_head_method_serializes_no_content(self): """No response must be included in HEAD requests.""" resp = self.client.head('/') - self.assertEquals(resp.status_code, DUMMYSTATUS) - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, '') + self.assertEqual(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, six.b('')) def test_default_renderer_serializes_content_on_accept_any(self): """If the Accept header is set to */* the default renderer should serialize the response.""" resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_default_case(self): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_non_default_case(self): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_accept_query(self): """The '_accept' query string should behave in the same way as the Accept header.""" @@ -124,34 +126,34 @@ class RendererIntegrationTests(TestCase): RendererB.media_type ) resp = self.client.get('/' + param) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_format_query(self): """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/?format=%s' % RendererB.format) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_format_kwargs(self): """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/something.formatb') - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): """If both a 'format' query and a matching Accept header specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) class Issue122Tests(TestCase): diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py index 8c86e1fb..cb8d8132 100644 --- a/rest_framework/tests/reverse.py +++ b/rest_framework/tests/reverse.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from django.test import TestCase from django.test.client import RequestFactory from rest_framework.compat import patterns, url @@ -16,7 +17,7 @@ urlpatterns = patterns('', class ReverseTests(TestCase): """ - Tests for fully qualifed URLs when using `reverse`. + Tests for fully qualified URLs when using `reverse`. """ urls = 'rest_framework.tests.reverse' diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py new file mode 100644 index 00000000..4e4765cb --- /dev/null +++ b/rest_framework/tests/routers.py @@ -0,0 +1,55 @@ +from __future__ import unicode_literals +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import status +from rest_framework.response import Response +from rest_framework import viewsets +from rest_framework.decorators import link, action +from rest_framework.routers import SimpleRouter +import copy + +factory = RequestFactory() + + +class BasicViewSet(viewsets.ViewSet): + def list(self, request, *args, **kwargs): + return Response({'method': 'list'}) + + @action() + def action1(self, request, *args, **kwargs): + return Response({'method': 'action1'}) + + @action() + def action2(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @link() + def link1(self, request, *args, **kwargs): + return Response({'method': 'link1'}) + + @link() + def link2(self, request, *args, **kwargs): + return Response({'method': 'link2'}) + + +class TestSimpleRouter(TestCase): + def setUp(self): + self.router = SimpleRouter() + + def test_link_and_action_decorator(self): + routes = self.router.get_routes(BasicViewSet) + decorator_routes = routes[2:] + # Make sure all these endpoints exist and none have been clobbered + for i, endpoint in enumerate(['action1', 'action2', 'link1', 'link2']): + route = decorator_routes[i] + # check url listing + self.assertEqual(route.url, + '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint)) + # check method to function mapping + if endpoint.startswith('action'): + method_map = 'post' + else: + method_map = 'get' + self.assertEqual(route.mapping[method_map], endpoint) + + diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 1c7283ae..b0c7e568 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -1,10 +1,12 @@ -import datetime -import pickle +from __future__ import unicode_literals +from django.utils.datastructures import MultiValueDict from django.test import TestCase from rest_framework import serializers, fields from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, - BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel, + BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) +import datetime +import pickle class SubComment(object): @@ -54,6 +56,19 @@ class ActionItemSerializer(serializers.ModelSerializer): model = ActionItem +class ActionItemSerializerCustomRestore(serializers.ModelSerializer): + + class Meta: + model = ActionItem + + def restore_object(self, data, instance=None): + if instance is None: + return ActionItem(**data) + for key, val in data.items(): + setattr(instance, key, val) + return instance + + class PersonSerializer(serializers.ModelSerializer): info = serializers.Field(source='info') @@ -63,18 +78,36 @@ class PersonSerializer(serializers.ModelSerializer): read_only_fields = ('age',) +class PersonSerializerInvalidReadOnly(serializers.ModelSerializer): + """ + Testing for #652. + """ + info = serializers.Field(source='info') + + class Meta: + model = Person + fields = ('name', 'age', 'info') + read_only_fields = ('age', 'info') + + class AlbumsSerializer(serializers.ModelSerializer): class Meta: model = Album fields = ['title'] # lists are also valid options + class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): class Meta: model = HasPositiveIntegerAsChoice fields = ['some_integer'] +class BrokenModelSerializer(serializers.ModelSerializer): + class Meta: + fields = ['some_field'] + + class BasicTests(TestCase): def setUp(self): self.comment = Comment( @@ -106,39 +139,39 @@ class BasicTests(TestCase): 'created': None, 'sub_comment': '' } - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_retrieve(self): serializer = CommentSerializer(self.comment) - self.assertEquals(serializer.data, self.expected) + self.assertEqual(serializer.data, self.expected) def test_create(self): serializer = CommentSerializer(data=self.data) expected = self.comment - self.assertEquals(serializer.is_valid(), True) - self.assertEquals(serializer.object, expected) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected) self.assertFalse(serializer.object is expected) - self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') + self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') def test_update(self): serializer = CommentSerializer(self.comment, data=self.data) expected = self.comment - self.assertEquals(serializer.is_valid(), True) - self.assertEquals(serializer.object, expected) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected) self.assertTrue(serializer.object is expected) - self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') + self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') def test_partial_update(self): msg = 'Merry New Year!' partial_data = {'content': msg} serializer = CommentSerializer(self.comment, data=partial_data) - self.assertEquals(serializer.is_valid(), False) + self.assertEqual(serializer.is_valid(), False) serializer = CommentSerializer(self.comment, data=partial_data, partial=True) expected = self.comment self.assertEqual(serializer.is_valid(), True) - self.assertEquals(serializer.object, expected) + self.assertEqual(serializer.object, expected) self.assertTrue(serializer.object is expected) - self.assertEquals(serializer.data['content'], msg) + self.assertEqual(serializer.data['content'], msg) def test_model_fields_as_expected(self): """ @@ -146,7 +179,7 @@ class BasicTests(TestCase): in the Meta data """ serializer = PersonSerializer(self.person) - self.assertEquals(set(serializer.data.keys()), + self.assertEqual(set(serializer.data.keys()), set(['name', 'age', 'info'])) def test_field_with_dictionary(self): @@ -155,19 +188,51 @@ class BasicTests(TestCase): """ serializer = PersonSerializer(self.person) expected = self.person_data - self.assertEquals(serializer.data['info'], expected) + self.assertEqual(serializer.data['info'], expected) def test_read_only_fields(self): """ Attempting to update fields set as read_only should have no effect. """ - serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(serializer.errors, {}) + self.assertEqual(serializer.errors, {}) # Assert age is unchanged (35) - self.assertEquals(instance.age, self.person_data['age']) + self.assertEqual(instance.age, self.person_data['age']) + + def test_invalid_read_only_fields(self): + """ + Regression test for #652. + """ + self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, []) + + +class DictStyleSerializer(serializers.Serializer): + """ + Note that we don't have any `restore_object` method, so the default + case of simply returning a dict will apply. + """ + email = serializers.EmailField() + + +class DictStyleSerializerTests(TestCase): + def test_dict_style_deserialize(self): + """ + Ensure serializers can deserialize into a dict. + """ + data = {'email': 'foo@example.com'} + serializer = DictStyleSerializer(data=data) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, data) + + def test_dict_style_serialize(self): + """ + Ensure serializers can serialize dict objects. + """ + data = {'email': 'foo@example.com'} + serializer = DictStyleSerializer(data) + self.assertEqual(serializer.data, data) class ValidationTests(TestCase): @@ -182,18 +247,17 @@ class ValidationTests(TestCase): 'content': 'x' * 1001, 'created': datetime.datetime(2012, 1, 1) } - self.actionitem = ActionItem(title='Some to do item', - ) + self.actionitem = ActionItem(title='Some to do item',) def test_create(self): serializer = CommentSerializer(data=self.data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) def test_update(self): serializer = CommentSerializer(self.comment, data=self.data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) def test_update_missing_field(self): data = { @@ -201,8 +265,8 @@ class ValidationTests(TestCase): 'created': datetime.datetime(2012, 1, 1) } serializer = CommentSerializer(self.comment, data=data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'email': [u'This field is required.']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'email': ['This field is required.']}) def test_missing_bool_with_default(self): """Make sure that a boolean value with a 'False' value is not @@ -212,33 +276,8 @@ class ValidationTests(TestCase): #No 'done' value. } serializer = ActionItemSerializer(self.actionitem, data=data) - self.assertEquals(serializer.is_valid(), True) - self.assertEquals(serializer.errors, {}) - - def test_field_validation(self): - - class CommentSerializerWithFieldValidator(CommentSerializer): - - def validate_content(self, attrs, source): - value = attrs[source] - if "test" not in value: - raise serializers.ValidationError("Test not in value") - return attrs - - data = { - 'email': 'tom@example.com', - 'content': 'A test comment', - 'created': datetime.datetime(2012, 1, 1) - } - - serializer = CommentSerializerWithFieldValidator(data=data) - self.assertTrue(serializer.is_valid()) - - data['content'] = 'This should not validate' - - serializer = CommentSerializerWithFieldValidator(data=data) - self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.errors, {}) def test_cross_field_validation(self): @@ -262,23 +301,37 @@ class ValidationTests(TestCase): serializer = CommentSerializerWithCrossFieldValidator(data=data) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']}) + self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']}) def test_null_is_true_fields(self): """ Omitting a value for null-field should validate. """ serializer = PersonSerializer(data={'name': 'marko'}) - self.assertEquals(serializer.is_valid(), True) - self.assertEquals(serializer.errors, {}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.errors, {}) def test_modelserializer_max_length_exceeded(self): data = { 'title': 'x' * 201, } serializer = ActionItemSerializer(data=data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) + + def test_modelserializer_max_length_exceeded_with_custom_restore(self): + """ + When overriding ModelSerializer.restore_object, validation tests should still apply. + Regression test for #623. + + https://github.com/tomchristie/django-rest-framework/pull/623 + """ + data = { + 'title': 'x' * 201, + } + serializer = ActionItemSerializerCustomRestore(data=data) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) def test_default_modelfield_max_length_exceeded(self): data = { @@ -286,15 +339,98 @@ class ValidationTests(TestCase): 'info': 'x' * 13, } serializer = ActionItemSerializer(data=data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']}) + + def test_datetime_validation_failure(self): + """ + Test DateTimeField validation errors on non-str values. + Regression test for #669. + + https://github.com/tomchristie/django-rest-framework/issues/669 + """ + data = self.data + data['created'] = 0 + + serializer = CommentSerializer(data=data) + self.assertEqual(serializer.is_valid(), False) + + self.assertIn('created', serializer.errors) + + def test_missing_model_field_exception_msg(self): + """ + Assert that a meaningful exception message is outputted when the model + field is missing (e.g. when mistyping ``model``). + """ + try: + serializer = BrokenModelSerializer() + except AssertionError as e: + self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option") + except: + self.fail('Wrong exception type thrown.') + + +class CustomValidationTests(TestCase): + class CommentSerializerWithFieldValidator(CommentSerializer): + + def validate_email(self, attrs, source): + value = attrs[source] + return attrs + + def validate_content(self, attrs, source): + value = attrs[source] + if "test" not in value: + raise serializers.ValidationError("Test not in value") + return attrs + + def test_field_validation(self): + data = { + 'email': 'tom@example.com', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + + serializer = self.CommentSerializerWithFieldValidator(data=data) + self.assertTrue(serializer.is_valid()) + + data['content'] = 'This should not validate' + + serializer = self.CommentSerializerWithFieldValidator(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'content': ['Test not in value']}) + + def test_missing_data(self): + """ + Make sure that validate_content isn't called if the field is missing + """ + incomplete_data = { + 'email': 'tom@example.com', + 'created': datetime.datetime(2012, 1, 1) + } + serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'content': ['This field is required.']}) + + def test_wrong_data(self): + """ + Make sure that validate_content isn't called if the field input is wrong + """ + wrong_data = { + 'email': 'not an email', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'email': ['Enter a valid e-mail address.']}) class PositiveIntegerAsChoiceTests(TestCase): def test_positive_integer_in_json_is_correctly_parsed(self): - data = {'some_integer':1} + data = {'some_integer': 1} serializer = PositiveIntegerAsChoiceSerializer(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) + class ModelValidationTests(TestCase): def test_validate_unique(self): @@ -306,7 +442,7 @@ class ModelValidationTests(TestCase): serializer.save() second_serializer = AlbumsSerializer(data={'title': 'a'}) self.assertFalse(second_serializer.is_valid()) - self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']}) + self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']}) def test_foreign_key_with_partial(self): """ @@ -340,20 +476,19 @@ class ModelValidationTests(TestCase): self.assertTrue(photo_serializer.save()) - class RegexValidationTest(TestCase): def test_create_failed(self): serializer = BookSerializer(data={'isbn': '1234567890'}) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) + self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) serializer = BookSerializer(data={'isbn': '12345678901234'}) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) + self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) + self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) def test_create_success(self): serializer = BookSerializer(data={'isbn': '1234567890123'}) @@ -398,7 +533,7 @@ class ManyToManyTests(TestCase): """ serializer = self.serializer_class(instance=self.instance) expected = self.data - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_create(self): """ @@ -406,11 +541,11 @@ class ManyToManyTests(TestCase): """ data = {'rel': [self.anchor.id]} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 2) - self.assertEquals(instance.pk, 2) - self.assertEquals(list(instance.rel.all()), [self.anchor]) + self.assertEqual(len(ManyToManyModel.objects.all()), 2) + self.assertEqual(instance.pk, 2) + self.assertEqual(list(instance.rel.all()), [self.anchor]) def test_update(self): """ @@ -420,11 +555,11 @@ class ManyToManyTests(TestCase): new_anchor.save() data = {'rel': [self.anchor.id, new_anchor.id]} serializer = self.serializer_class(self.instance, data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(list(instance.rel.all()), [self.anchor, new_anchor]) + self.assertEqual(len(ManyToManyModel.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor]) def test_create_empty_relationship(self): """ @@ -433,11 +568,11 @@ class ManyToManyTests(TestCase): """ data = {'rel': []} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 2) - self.assertEquals(instance.pk, 2) - self.assertEquals(list(instance.rel.all()), []) + self.assertEqual(len(ManyToManyModel.objects.all()), 2) + self.assertEqual(instance.pk, 2) + self.assertEqual(list(instance.rel.all()), []) def test_update_empty_relationship(self): """ @@ -448,11 +583,11 @@ class ManyToManyTests(TestCase): new_anchor.save() data = {'rel': []} serializer = self.serializer_class(self.instance, data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(list(instance.rel.all()), []) + self.assertEqual(len(ManyToManyModel.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(list(instance.rel.all()), []) def test_create_empty_relationship_flat_data(self): """ @@ -460,19 +595,20 @@ class ManyToManyTests(TestCase): containing no items, using a representation that does not support lists (eg form data). """ - data = {'rel': ''} + data = MultiValueDict() + data.setlist('rel', ['']) serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 2) - self.assertEquals(instance.pk, 2) - self.assertEquals(list(instance.rel.all()), []) + self.assertEqual(len(ManyToManyModel.objects.all()), 2) + self.assertEqual(instance.pk, 2) + self.assertEqual(list(instance.rel.all()), []) class ReadOnlyManyToManyTests(TestCase): def setUp(self): class ReadOnlyManyToManySerializer(serializers.ModelSerializer): - rel = serializers.ManyRelatedField(read_only=True) + rel = serializers.RelatedField(many=True, read_only=True) class Meta: model = ReadOnlyManyToManyModel @@ -500,12 +636,12 @@ class ReadOnlyManyToManyTests(TestCase): new_anchor.save() data = {'rel': [self.anchor.id, new_anchor.id]} serializer = self.serializer_class(self.instance, data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) - self.assertEquals(instance.pk, 1) + self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) + self.assertEqual(instance.pk, 1) # rel is still as original (1 entry) - self.assertEquals(list(instance.rel.all()), [self.anchor]) + self.assertEqual(list(instance.rel.all()), [self.anchor]) def test_update_without_relationship(self): """ @@ -516,12 +652,12 @@ class ReadOnlyManyToManyTests(TestCase): new_anchor.save() data = {} serializer = self.serializer_class(self.instance, data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) - self.assertEquals(instance.pk, 1) + self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) + self.assertEqual(instance.pk, 1) # rel is still as original (1 entry) - self.assertEquals(list(instance.rel.all()), [self.anchor]) + self.assertEqual(list(instance.rel.all()), [self.anchor]) class DefaultValueTests(TestCase): @@ -536,20 +672,35 @@ class DefaultValueTests(TestCase): def test_create_using_default(self): data = {} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(self.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(instance.text, 'foobar') + self.assertEqual(len(self.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(instance.text, 'foobar') def test_create_overriding_default(self): data = {'text': 'overridden'} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) + instance = serializer.save() + self.assertEqual(len(self.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(instance.text, 'overridden') + + def test_partial_update_default(self): + """ Regression test for issue #532 """ + data = {'text': 'overridden'} + serializer = self.serializer_class(data=data, partial=True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(self.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(instance.text, 'overridden') + + data = {'extra': 'extra_value'} + serializer = self.serializer_class(instance=instance, data=data, partial=True) + self.assertEqual(serializer.is_valid(), True) + instance = serializer.save() + + self.assertEqual(instance.extra, 'extra_value') + self.assertEqual(instance.text, 'overridden') class CallableDefaultValueTests(TestCase): @@ -564,20 +715,20 @@ class CallableDefaultValueTests(TestCase): def test_create_using_default(self): data = {} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(self.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(instance.text, 'foobar') + self.assertEqual(len(self.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(instance.text, 'foobar') def test_create_overriding_default(self): data = {'text': 'overridden'} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(self.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(instance.text, 'overridden') + self.assertEqual(len(self.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(instance.text, 'overridden') class ManyRelatedTests(TestCase): @@ -604,6 +755,43 @@ class ManyRelatedTests(TestCase): self.assertEqual(serializer.data, expected) + def test_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] + } + self.assertEqual(serializer.data, expected) + + def test_depth_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + depth = 1 + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', + 'blogpostcomment_set': [ + {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, + {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} + ] + } + self.assertEqual(serializer.data, expected) + def test_callable_source(self): post = BlogPost.objects.create(title="Test blog post") post.blogpostcomment_set.create(text="I love this blog post") @@ -626,12 +814,13 @@ class ManyRelatedTests(TestCase): class RelatedTraversalTest(TestCase): def test_nested_traversal(self): + """ + Source argument should support dotted.source notation. + """ user = Person.objects.create(name="django") post = BlogPost.objects.create(title="Test blog post", writer=user) post.blogpostcomment_set.create(text="I love this blog post") - from rest_framework.tests.models import BlogPostComment - class PersonSerializer(serializers.ModelSerializer): class Meta: model = Person @@ -652,11 +841,11 @@ class RelatedTraversalTest(TestCase): serializer = BlogPostSerializer(instance=post) expected = { - 'title': u'Test blog post', + 'title': 'Test blog post', 'comments': [{ - 'text': u'I love this blog post', + 'text': 'I love this blog post', 'post_owner': { - "name": u"django", + "name": "django", "age": None } }] @@ -664,6 +853,41 @@ class RelatedTraversalTest(TestCase): self.assertEqual(serializer.data, expected) + def test_nested_traversal_with_none(self): + """ + If a component of the dotted.source is None, return None for the field. + """ + from rest_framework.tests.models import NullableForeignKeySource + instance = NullableForeignKeySource.objects.create(name='Source with null FK') + + class NullableSourceSerializer(serializers.Serializer): + target_name = serializers.Field(source='target.name') + + serializer = NullableSourceSerializer(instance=instance) + + expected = { + 'target_name': None, + } + + self.assertEqual(serializer.data, expected) + + def test_queryset_nested_traversal(self): + """ + Relational fields should be able to use methods as their source. + """ + BlogPost.objects.create(title='blah') + + class QuerysetMethodSerializer(serializers.Serializer): + blogposts = serializers.RelatedField(many=True, source='get_all_blogposts') + + class ClassWithQuerysetMethod(object): + def get_all_blogposts(self): + return BlogPost.objects + + obj = ClassWithQuerysetMethod() + serializer = QuerysetMethodSerializer(obj) + self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']}) + class SerializerMethodFieldTests(TestCase): def setUp(self): @@ -691,8 +915,8 @@ class SerializerMethodFieldTests(TestCase): serializer = self.serializer_class(source_data) expected = { - 'beep': u'hello!', - 'boop': [u'a', u'b', u'c'], + 'beep': 'hello!', + 'boop': ['a', 'b', 'c'], 'boop_count': 3, } @@ -708,7 +932,7 @@ class BlankFieldTests(TestCase): model = BlankFieldModel class BlankFieldSerializer(serializers.Serializer): - title = serializers.CharField(blank=True) + title = serializers.CharField(required=False) class NotBlankFieldModelSerializer(serializers.ModelSerializer): class Meta: @@ -725,15 +949,15 @@ class BlankFieldTests(TestCase): def test_create_blank_field(self): serializer = self.serializer_class(data=self.data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) def test_create_model_blank_field(self): serializer = self.model_serializer_class(data=self.data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) def test_create_model_null_field(self): serializer = self.model_serializer_class(data={'title': None}) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) def test_create_not_blank_field(self): """ @@ -741,7 +965,7 @@ class BlankFieldTests(TestCase): is considered invalid in a non-model serializer """ serializer = self.not_blank_serializer_class(data=self.data) - self.assertEquals(serializer.is_valid(), False) + self.assertEqual(serializer.is_valid(), False) def test_create_model_not_blank_field(self): """ @@ -749,7 +973,11 @@ class BlankFieldTests(TestCase): is considered invalid in a model serializer """ serializer = self.not_blank_model_serializer_class(data=self.data) - self.assertEquals(serializer.is_valid(), False) + self.assertEqual(serializer.is_valid(), False) + + def test_create_model_empty_field(self): + serializer = self.model_serializer_class(data={}) + self.assertEqual(serializer.is_valid(), True) #test for issue #460 @@ -773,28 +1001,45 @@ class SerializerPickleTests(TestCase): class Meta: model = Person fields = ('name', 'age') - pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data) + pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0) + + def test_getstate_method_should_not_return_none(self): + """ + Regression test for #645. + """ + data = serializers.DictWithMetadata({1: 1}) + self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1})) + + def test_serializer_data_is_pickleable(self): + """ + Another regression test for #645. + """ + data = serializers.SortedDictWithMetadata({1: 1}) + repr(pickle.loads(pickle.dumps(data, 0))) class DepthTest(TestCase): def test_implicit_nesting(self): + writer = Person.objects.create(name="django", age=1) post = BlogPost.objects.create(title="Test blog post", writer=writer) + comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) - class BlogPostSerializer(serializers.ModelSerializer): + class BlogPostCommentSerializer(serializers.ModelSerializer): class Meta: - model = BlogPost - depth = 1 + model = BlogPostComment + depth = 2 - serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': u'Test blog post', - 'writer': {'id': 1, 'name': u'django', 'age': 1}} + serializer = BlogPostCommentSerializer(instance=comment) + expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}}} self.assertEqual(serializer.data, expected) def test_explicit_nesting(self): writer = Person.objects.create(name="django", age=1) post = BlogPost.objects.create(title="Test blog post", writer=writer) + comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) class PersonSerializer(serializers.ModelSerializer): class Meta: @@ -806,9 +1051,15 @@ class DepthTest(TestCase): class Meta: model = BlogPost - serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': u'Test blog post', - 'writer': {'id': 1, 'name': u'django', 'age': 1}} + class BlogPostCommentSerializer(serializers.ModelSerializer): + blog_post = BlogPostSerializer() + + class Meta: + model = BlogPostComment + + serializer = BlogPostCommentSerializer(instance=comment) + expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}}} self.assertEqual(serializer.data, expected) @@ -865,6 +1116,35 @@ class NestedSerializerContextTests(TestCase): AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data +class DeserializeListTestCase(TestCase): + + def setUp(self): + self.data = { + 'email': 'nobody@nowhere.com', + 'content': 'This is some test content', + 'created': datetime.datetime(2013, 3, 7), + } + + def test_no_errors(self): + data = [self.data.copy() for x in range(0, 3)] + serializer = CommentSerializer(data=data, many=True) + self.assertTrue(serializer.is_valid()) + self.assertTrue(isinstance(serializer.object, list)) + self.assertTrue( + all((isinstance(item, Comment) for item in serializer.object)) + ) + + def test_errors_return_as_list(self): + invalid_item = self.data.copy() + invalid_item['email'] = '' + data = [self.data.copy(), invalid_item, self.data.copy()] + + serializer = CommentSerializer(data=data, many=True) + self.assertFalse(serializer.is_valid()) + expected = [{}, {'email': ['This field is required.']}, {}] + self.assertEqual(serializer.errors, expected) + + # Test for issue #467 class FieldLabelTest(TestCase): def setUp(self): @@ -891,3 +1171,4 @@ class FieldLabelTest(TestCase): self.assertEquals(u'Label', fields.Field(label='Label', help_text='Help').label) self.assertEquals(u'Help', fields.CharField(label='Label', help_text='Help').help_text) self.assertEquals(u'Label', fields.ManyHyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help').label) + diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py new file mode 100644 index 00000000..8b0ded1a --- /dev/null +++ b/rest_framework/tests/serializer_bulk_update.py @@ -0,0 +1,278 @@ +""" +Tests to cover bulk create and update using serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class BulkCreateSerializerTests(TestCase): + """ + Creating multiple instances using serializers. + """ + + def setUp(self): + class BookSerializer(serializers.Serializer): + id = serializers.IntegerField() + title = serializers.CharField(max_length=100) + author = serializers.CharField(max_length=100) + + self.BookSerializer = BookSerializer + + def test_bulk_create_success(self): + """ + Correct bulk update serialization should return the input data. + """ + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 2, + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + + def test_bulk_create_errors(self): + """ + Correct bulk update serialization should return the input data. + """ + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 'foo', + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {}, + {'id': ['Enter a whole number.']} + ] + + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + def test_invalid_list_datatype(self): + """ + Data containing list of incorrect data type should return errors. + """ + data = ['foo', 'bar', 'baz'] + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + + expected_errors = [ + {'non_field_errors': ['Invalid data']}, + {'non_field_errors': ['Invalid data']}, + {'non_field_errors': ['Invalid data']} + ] + + self.assertEqual(serializer.errors, expected_errors) + + def test_invalid_single_datatype(self): + """ + Data containing a single incorrect data type should return errors. + """ + data = 123 + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + + expected_errors = {'non_field_errors': ['Expected a list of items.']} + + self.assertEqual(serializer.errors, expected_errors) + + def test_invalid_single_object(self): + """ + Data containing only a single object, instead of a list of objects + should return errors. + """ + data = { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + } + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + + expected_errors = {'non_field_errors': ['Expected a list of items.']} + + self.assertEqual(serializer.errors, expected_errors) + + +class BulkUpdateSerializerTests(TestCase): + """ + Updating multiple instances using serializers. + """ + + def setUp(self): + class Book(object): + """ + A data type that can be persisted to a mock storage backend + with `.save()` and `.delete()`. + """ + object_map = {} + + def __init__(self, id, title, author): + self.id = id + self.title = title + self.author = author + + def save(self): + Book.object_map[self.id] = self + + def delete(self): + del Book.object_map[self.id] + + class BookSerializer(serializers.Serializer): + id = serializers.IntegerField() + title = serializers.CharField(max_length=100) + author = serializers.CharField(max_length=100) + + def restore_object(self, attrs, instance=None): + if instance: + instance.id = attrs['id'] + instance.title = attrs['title'] + instance.author = attrs['author'] + return instance + return Book(**attrs) + + self.Book = Book + self.BookSerializer = BookSerializer + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 2, + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + + for item in data: + book = Book(item['id'], item['title'], item['author']) + book.save() + + def books(self): + """ + Return all the objects in the mock storage backend. + """ + return self.Book.object_map.values() + + def test_bulk_update_success(self): + """ + Correct bulk update serialization should return the input data. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 2, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + serializer.save() + new_data = self.BookSerializer(self.books(), many=True).data + + self.assertEqual(data, new_data) + + def test_bulk_update_and_create(self): + """ + Bulk update serialization may also include created items. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 3, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + serializer.save() + new_data = self.BookSerializer(self.books(), many=True).data + self.assertEqual(data, new_data) + + def test_bulk_update_invalid_create(self): + """ + Bulk update serialization without allow_add_remove may not create items. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 3, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']} + ] + serializer = self.BookSerializer(self.books(), data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + def test_bulk_update_error(self): + """ + Incorrect bulk update serialization should return error data. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 'foo', + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {'id': ['Enter a whole number.']} + ] + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py new file mode 100644 index 00000000..71d0e24b --- /dev/null +++ b/rest_framework/tests/serializer_nested.py @@ -0,0 +1,246 @@ +""" +Tests to cover nested serializers. + +Doesn't cover model serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class WritableNestedSerializerBasicTests(TestCase): + """ + Tests for deserializing nested entities. + Basic tests that use serializers that simply restore to dicts. + """ + + def setUp(self): + class TrackSerializer(serializers.Serializer): + order = serializers.IntegerField() + title = serializers.CharField(max_length=100) + duration = serializers.IntegerField() + + class AlbumSerializer(serializers.Serializer): + album_name = serializers.CharField(max_length=100) + artist = serializers.CharField(max_length=100) + tracks = TrackSerializer(many=True) + + self.AlbumSerializer = AlbumSerializer + + def test_nested_validation_success(self): + """ + Correct nested serialization should return the input data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + + def test_nested_validation_error(self): + """ + Incorrect nested serialization should return appropriate error data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} + ] + } + expected_errors = { + 'tracks': [ + {}, + {}, + {'duration': ['Enter a whole number.']} + ] + } + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + def test_many_nested_validation_error(self): + """ + Incorrect nested serialization should return appropriate error data + when multiple entities are being deserialized. + """ + + data = [ + { + 'album_name': 'Russian Red', + 'artist': 'I Love Your Glasses', + 'tracks': [ + {'order': 1, 'title': 'Cigarettes', 'duration': 121}, + {'order': 2, 'title': 'No Past Land', 'duration': 198}, + {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} + ] + }, + { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} + ] + } + ] + expected_errors = [ + {}, + { + 'tracks': [ + {}, + {}, + {'duration': ['Enter a whole number.']} + ] + } + ] + + serializer = self.AlbumSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + +class WritableNestedSerializerObjectTests(TestCase): + """ + Tests for deserializing nested entities. + These tests use serializers that restore to concrete objects. + """ + + def setUp(self): + # Couple of concrete objects that we're going to deserialize into + class Track(object): + def __init__(self, order, title, duration): + self.order, self.title, self.duration = order, title, duration + + def __eq__(self, other): + return ( + self.order == other.order and + self.title == other.title and + self.duration == other.duration + ) + + class Album(object): + def __init__(self, album_name, artist, tracks): + self.album_name, self.artist, self.tracks = album_name, artist, tracks + + def __eq__(self, other): + return ( + self.album_name == other.album_name and + self.artist == other.artist and + self.tracks == other.tracks + ) + + # And their corresponding serializers + class TrackSerializer(serializers.Serializer): + order = serializers.IntegerField() + title = serializers.CharField(max_length=100) + duration = serializers.IntegerField() + + def restore_object(self, attrs, instance=None): + return Track(attrs['order'], attrs['title'], attrs['duration']) + + class AlbumSerializer(serializers.Serializer): + album_name = serializers.CharField(max_length=100) + artist = serializers.CharField(max_length=100) + tracks = TrackSerializer(many=True) + + def restore_object(self, attrs, instance=None): + return Album(attrs['album_name'], attrs['artist'], attrs['tracks']) + + self.Album, self.Track = Album, Track + self.AlbumSerializer = AlbumSerializer + + def test_nested_validation_success(self): + """ + Correct nested serialization should return a restored object + that corresponds to the input data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + expected_object = self.Album( + album_name='Discovery', + artist='Daft Punk', + tracks=[ + self.Track(order=1, title='One More Time', duration=235), + self.Track(order=2, title='Aerodynamic', duration=184), + self.Track(order=3, title='Digital Love', duration=239), + ] + ) + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected_object) + + def test_many_nested_validation_success(self): + """ + Correct nested serialization should return multiple restored objects + that corresponds to the input data when multiple objects are + being deserialized. + """ + + data = [ + { + 'album_name': 'Russian Red', + 'artist': 'I Love Your Glasses', + 'tracks': [ + {'order': 1, 'title': 'Cigarettes', 'duration': 121}, + {'order': 2, 'title': 'No Past Land', 'duration': 198}, + {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} + ] + }, + { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + ] + expected_object = [ + self.Album( + album_name='Russian Red', + artist='I Love Your Glasses', + tracks=[ + self.Track(order=1, title='Cigarettes', duration=121), + self.Track(order=2, title='No Past Land', duration=198), + self.Track(order=3, title='They Don\'t Believe', duration=191), + ] + ), + self.Album( + album_name='Discovery', + artist='Daft Punk', + tracks=[ + self.Track(order=1, title='One More Time', duration=235), + self.Track(order=2, title='Aerodynamic', duration=184), + self.Track(order=3, title='Digital Love', duration=239), + ] + ) + ] + + serializer = self.AlbumSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected_object) diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py new file mode 100644 index 00000000..857375c2 --- /dev/null +++ b/rest_framework/tests/settings.py @@ -0,0 +1,22 @@ +"""Tests for the settings module""" +from __future__ import unicode_literals +from django.test import TestCase + +from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS + + +class TestSettings(TestCase): + """Tests relating to the api settings""" + + def test_non_import_errors(self): + """Make sure other errors aren't suppressed.""" + settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) + with self.assertRaises(ValueError): + settings.DEFAULT_MODEL_SERIALIZER_CLASS + + def test_import_error_message_maintained(self): + """Make sure real import errors are captured and raised sensibly.""" + settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) + with self.assertRaises(ImportError) as cm: + settings.DEFAULT_MODEL_SERIALIZER_CLASS + self.assertTrue('ImportError' in str(cm.exception)) diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py deleted file mode 100644 index 30df5cef..00000000 --- a/rest_framework/tests/status.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Tests for the status module""" -from django.test import TestCase -from rest_framework import status - - -class TestStatus(TestCase): - """Simple sanity test to check the status module""" - - def test_status(self): - """Ensure the status module is present and correct.""" - self.assertEquals(200, status.HTTP_200_OK) - self.assertEquals(404, status.HTTP_404_NOT_FOUND) diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py index 97f492ff..f8c2579e 100644 --- a/rest_framework/tests/testcases.py +++ b/rest_framework/tests/testcases.py @@ -1,4 +1,5 @@ # http://djangosnippets.org/snippets/1011/ +from __future__ import unicode_literals from django.conf import settings from django.core.management import call_command from django.db.models import loading diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py index adeaf6da..08f88e11 100644 --- a/rest_framework/tests/tests.py +++ b/rest_framework/tests/tests.py @@ -2,6 +2,7 @@ Force import of all modules in this package in order to get the standard test runner to pick up the tests. Yowzers. """ +from __future__ import unicode_literals import os modules = [filename.rsplit('.', 1)[0] diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py index 4b98b941..11cbd8eb 100644 --- a/rest_framework/tests/throttling.py +++ b/rest_framework/tests/throttling.py @@ -1,11 +1,10 @@ """ Tests for the throttling implementations in the permissions module. """ - +from __future__ import unicode_literals from django.test import TestCase from django.contrib.auth.models import User from django.core.cache import cache - from django.test.client import RequestFactory from rest_framework.views import APIView from rest_framework.throttling import UserRateThrottle @@ -104,7 +103,7 @@ class ThrottlingTests(TestCase): self.set_throttle_timer(view, timer) response = view.as_view()(request) if expect is not None: - self.assertEquals(response['X-Throttle-Wait-Seconds'], expect) + self.assertEqual(response['X-Throttle-Wait-Seconds'], expect) else: self.assertFalse('X-Throttle-Wait-Seconds' in response) diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py new file mode 100644 index 00000000..29ed4a96 --- /dev/null +++ b/rest_framework/tests/urlpatterns.py @@ -0,0 +1,76 @@ +from __future__ import unicode_literals +from collections import namedtuple +from django.core import urlresolvers +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework.compat import patterns, url, include +from rest_framework.urlpatterns import format_suffix_patterns + + +# A container class for test paths for the test case +URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs']) + + +def dummy_view(request, *args, **kwargs): + pass + + +class FormatSuffixTests(TestCase): + """ + Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. + """ + def _resolve_urlpatterns(self, urlpatterns, test_paths): + factory = RequestFactory() + try: + urlpatterns = format_suffix_patterns(urlpatterns) + except Exception: + self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") + resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + for test_path in test_paths: + request = factory.get(test_path.path) + try: + callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) + except Exception: + self.fail("Failed to resolve URL: %s" % request.path_info) + self.assertEqual(callback_args, test_path.args) + self.assertEqual(callback_kwargs, test_path.kwargs) + + def test_format_suffix(self): + urlpatterns = patterns( + '', + url(r'^test$', dummy_view), + ) + test_paths = [ + URLTestPath('/test', (), {}), + URLTestPath('/test.api', (), {'format': 'api'}), + URLTestPath('/test.asdf', (), {'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) + + def test_default_args(self): + urlpatterns = patterns( + '', + url(r'^test$', dummy_view, {'foo': 'bar'}), + ) + test_paths = [ + URLTestPath('/test', (), {'foo': 'bar', }), + URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), + URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) + + def test_included_urls(self): + nested_patterns = patterns( + '', + url(r'^path$', dummy_view) + ) + urlpatterns = patterns( + '', + url(r'^test/', include(nested_patterns), {'foo': 'bar'}), + ) + test_paths = [ + URLTestPath('/test/path', (), {'foo': 'bar', }), + URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), + URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py new file mode 100644 index 00000000..8c87917d --- /dev/null +++ b/rest_framework/tests/utils.py @@ -0,0 +1,40 @@ +from __future__ import unicode_literals +from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory +from django.test.client import MULTIPART_CONTENT +from rest_framework.compat import urlparse + + +class RequestFactory(_RequestFactory): + + def __init__(self, **defaults): + super(RequestFactory, self).__init__(**defaults) + + def patch(self, path, data={}, content_type=MULTIPART_CONTENT, + **extra): + "Construct a PATCH request." + + patch_data = self._encode_data(data, content_type) + + parsed = urlparse.urlparse(path) + r = { + 'CONTENT_LENGTH': len(patch_data), + 'CONTENT_TYPE': content_type, + 'PATH_INFO': self._get_path(parsed), + 'QUERY_STRING': parsed[4], + 'REQUEST_METHOD': 'PATCH', + 'wsgi.input': FakePayload(patch_data), + } + r.update(extra) + return self.request(**r) + + +class Client(_Client, RequestFactory): + def patch(self, path, data={}, content_type=MULTIPART_CONTENT, + follow=False, **extra): + """ + Send a resource to the server using PATCH. + """ + response = super(Client, self).patch(path, data=data, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response diff --git a/rest_framework/tests/validation.py b/rest_framework/tests/validation.py new file mode 100644 index 00000000..cbdd6515 --- /dev/null +++ b/rest_framework/tests/validation.py @@ -0,0 +1,65 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import generics, serializers, status +from rest_framework.tests.utils import RequestFactory +import json + +factory = RequestFactory() + + +# Regression for #666 + +class ValidationModel(models.Model): + blank_validated_field = models.CharField(max_length=255) + + +class ValidationModelSerializer(serializers.ModelSerializer): + class Meta: + model = ValidationModel + fields = ('blank_validated_field',) + read_only_fields = ('blank_validated_field',) + + +class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView): + model = ValidationModel + serializer_class = ValidationModelSerializer + + +class TestPreSaveValidationExclusions(TestCase): + def test_pre_save_validation_exclusions(self): + """ + Somewhat weird test case to ensure that we don't perform model + validation on read only fields. + """ + obj = ValidationModel.objects.create(blank_validated_field='') + request = factory.put('/', json.dumps({}), + content_type='application/json') + view = UpdateValidationModel().as_view() + response = view(request, pk=obj.pk).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +# Regression for #653 + +class ShouldValidateModel(models.Model): + should_validate_field = models.CharField(max_length=255) + + +class ShouldValidateModelSerializer(serializers.ModelSerializer): + renamed = serializers.CharField(source='should_validate_field', required=False) + + class Meta: + model = ShouldValidateModel + fields = ('renamed',) + + +class TestPreSaveValidationExclusions(TestCase): + def test_renamed_fields_are_model_validated(self): + """ + Ensure fields with 'source' applied do get still get model validation. + """ + # We've set `required=False` on the serializer, but the model + # does not have `blank=True`, so this serializer should not validate. + serializer = ShouldValidateModelSerializer(data={'renamed': ''}) + self.assertEqual(serializer.is_valid(), False) diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py deleted file mode 100644 index c032985e..00000000 --- a/rest_framework/tests/validators.py +++ /dev/null @@ -1,329 +0,0 @@ -# from django import forms -# from django.db import models -# from django.test import TestCase -# from rest_framework.response import ImmediateResponse -# from rest_framework.views import View - - -# class TestDisabledValidations(TestCase): -# """Tests on FormValidator with validation disabled by setting form to None""" - -# def test_disabled_form_validator_returns_content_unchanged(self): -# """If the view's form attribute is None then FormValidator(view).validate_request(content, None) -# should just return the content unmodified.""" -# class DisabledFormResource(FormResource): -# form = None - -# class MockView(View): -# resource = DisabledFormResource - -# view = MockView() -# content = {'qwerty': 'uiop'} -# self.assertEqual(FormResource(view).validate_request(content, None), content) - -# def test_disabled_form_validator_get_bound_form_returns_none(self): -# """If the view's form attribute is None on then -# FormValidator(view).get_bound_form(content) should just return None.""" -# class DisabledFormResource(FormResource): -# form = None - -# class MockView(View): -# resource = DisabledFormResource - -# view = MockView() -# content = {'qwerty': 'uiop'} -# self.assertEqual(FormResource(view).get_bound_form(content), None) - -# def test_disabled_model_form_validator_returns_content_unchanged(self): -# """If the view's form is None and does not have a Resource with a model set then -# ModelFormValidator(view).validate_request(content, None) should just return the content unmodified.""" - -# class DisabledModelFormView(View): -# resource = ModelResource - -# view = DisabledModelFormView() -# content = {'qwerty': 'uiop'} -# self.assertEqual(ModelResource(view).get_bound_form(content), None) - -# def test_disabled_model_form_validator_get_bound_form_returns_none(self): -# """If the form attribute is None on FormValidatorMixin then get_bound_form(content) should just return None.""" -# class DisabledModelFormView(View): -# resource = ModelResource - -# view = DisabledModelFormView() -# content = {'qwerty': 'uiop'} -# self.assertEqual(ModelResource(view).get_bound_form(content), None) - - -# class TestNonFieldErrors(TestCase): -# """Tests against form validation errors caused by non-field errors. (eg as might be caused by some custom form validation)""" - -# def test_validate_failed_due_to_non_field_error_returns_appropriate_message(self): -# """If validation fails with a non-field error, ensure the response a non-field error""" -# class MockForm(forms.Form): -# field1 = forms.CharField(required=False) -# field2 = forms.CharField(required=False) -# ERROR_TEXT = 'You may not supply both field1 and field2' - -# def clean(self): -# if 'field1' in self.cleaned_data and 'field2' in self.cleaned_data: -# raise forms.ValidationError(self.ERROR_TEXT) -# return self.cleaned_data - -# class MockResource(FormResource): -# form = MockForm - -# class MockView(View): -# pass - -# view = MockView() -# content = {'field1': 'example1', 'field2': 'example2'} -# try: -# MockResource(view).validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'errors': [MockForm.ERROR_TEXT]}) -# else: -# self.fail('ImmediateResponse was not raised') - - -# class TestFormValidation(TestCase): -# """Tests which check basic form validation. -# Also includes the same set of tests with a ModelFormValidator for which the form has been explicitly set. -# (ModelFormValidator should behave as FormValidator if a form is set rather than relying on the default ModelForm)""" -# def setUp(self): -# class MockForm(forms.Form): -# qwerty = forms.CharField(required=True) - -# class MockFormResource(FormResource): -# form = MockForm - -# class MockModelResource(ModelResource): -# form = MockForm - -# class MockFormView(View): -# resource = MockFormResource - -# class MockModelFormView(View): -# resource = MockModelResource - -# self.MockFormResource = MockFormResource -# self.MockModelResource = MockModelResource -# self.MockFormView = MockFormView -# self.MockModelFormView = MockModelFormView - -# def validation_returns_content_unchanged_if_already_valid_and_clean(self, validator): -# """If the content is already valid and clean then validate(content) should just return the content unmodified.""" -# content = {'qwerty': 'uiop'} -# self.assertEqual(validator.validate_request(content, None), content) - -# def validation_failure_raises_response_exception(self, validator): -# """If form validation fails a ResourceException 400 (Bad Request) should be raised.""" -# content = {} -# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) - -# def validation_does_not_allow_extra_fields_by_default(self, validator): -# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. -# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up -# broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'qwerty': 'uiop', 'extra': 'extra'} -# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) - -# def validation_allows_extra_fields_if_explicitly_set(self, validator): -# """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names.""" -# content = {'qwerty': 'uiop', 'extra': 'extra'} -# validator._validate(content, None, allowed_extra_fields=('extra',)) - -# def validation_allows_unknown_fields_if_explicitly_allowed(self, validator): -# """If we set ``unknown_form_fields`` on the form resource, then don't -# raise errors on unexpected request data""" -# content = {'qwerty': 'uiop', 'extra': 'extra'} -# validator.allow_unknown_form_fields = True -# self.assertEqual({'qwerty': u'uiop'}, -# validator.validate_request(content, None), -# "Resource didn't accept unknown fields.") -# validator.allow_unknown_form_fields = False - -# def validation_does_not_require_extra_fields_if_explicitly_set(self, validator): -# """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names.""" -# content = {'qwerty': 'uiop'} -# self.assertEqual(validator._validate(content, None, allowed_extra_fields=('extra',)), content) - -# def validation_failed_due_to_no_content_returns_appropriate_message(self, validator): -# """If validation fails due to no content, ensure the response contains a single non-field error""" -# content = {} -# try: -# validator.validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}}) -# else: -# self.fail('ResourceException was not raised') - -# def validation_failed_due_to_field_error_returns_appropriate_message(self, validator): -# """If validation fails due to a field error, ensure the response contains a single field error""" -# content = {'qwerty': ''} -# try: -# validator.validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}}) -# else: -# self.fail('ResourceException was not raised') - -# def validation_failed_due_to_invalid_field_returns_appropriate_message(self, validator): -# """If validation fails due to an invalid field, ensure the response contains a single field error""" -# content = {'qwerty': 'uiop', 'extra': 'extra'} -# try: -# validator.validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'field_errors': {'extra': ['This field does not exist.']}}) -# else: -# self.fail('ResourceException was not raised') - -# def validation_failed_due_to_multiple_errors_returns_appropriate_message(self, validator): -# """If validation for multiple reasons, ensure the response contains each error""" -# content = {'qwerty': '', 'extra': 'extra'} -# try: -# validator.validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.'], -# 'extra': ['This field does not exist.']}}) -# else: -# self.fail('ResourceException was not raised') - -# # Tests on FormResource - -# def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) - -# def test_form_validation_failure_raises_response_exception(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failure_raises_response_exception(validator) - -# def test_validation_does_not_allow_extra_fields_by_default(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_does_not_allow_extra_fields_by_default(validator) - -# def test_validation_allows_extra_fields_if_explicitly_set(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_allows_extra_fields_if_explicitly_set(validator) - -# def test_validation_allows_unknown_fields_if_explicitly_allowed(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_allows_unknown_fields_if_explicitly_allowed(validator) - -# def test_validation_does_not_require_extra_fields_if_explicitly_set(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_does_not_require_extra_fields_if_explicitly_set(validator) - -# def test_validation_failed_due_to_no_content_returns_appropriate_message(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failed_due_to_no_content_returns_appropriate_message(validator) - -# def test_validation_failed_due_to_field_error_returns_appropriate_message(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failed_due_to_field_error_returns_appropriate_message(validator) - -# def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) - -# def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) - -# # Same tests on ModelResource - -# def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) - -# def test_modelform_validation_failure_raises_response_exception(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failure_raises_response_exception(validator) - -# def test_modelform_validation_does_not_allow_extra_fields_by_default(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_does_not_allow_extra_fields_by_default(validator) - -# def test_modelform_validation_allows_extra_fields_if_explicitly_set(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_allows_extra_fields_if_explicitly_set(validator) - -# def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_does_not_require_extra_fields_if_explicitly_set(validator) - -# def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failed_due_to_no_content_returns_appropriate_message(validator) - -# def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failed_due_to_field_error_returns_appropriate_message(validator) - -# def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) - -# def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) - - -# class TestModelFormValidator(TestCase): -# """Tests specific to ModelFormValidatorMixin""" - -# def setUp(self): -# """Create a validator for a model with two fields and a property.""" -# class MockModel(models.Model): -# qwerty = models.CharField(max_length=256) -# uiop = models.CharField(max_length=256, blank=True) - -# @property -# def read_only(self): -# return 'read only' - -# class MockResource(ModelResource): -# model = MockModel - -# class MockView(View): -# resource = MockResource - -# self.validator = MockResource(MockView) - -# def test_property_fields_are_allowed_on_model_forms(self): -# """Validation on ModelForms may include property fields that exist on the Model to be included in the input.""" -# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'} -# self.assertEqual(self.validator.validate_request(content, None), content) - -# def test_property_fields_are_not_required_on_model_forms(self): -# """Validation on ModelForms does not require property fields that exist on the Model to be included in the input.""" -# content = {'qwerty': 'example', 'uiop': 'example'} -# self.assertEqual(self.validator.validate_request(content, None), content) - -# def test_extra_fields_not_allowed_on_model_forms(self): -# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. -# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up -# broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'} -# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) - -# def test_validate_requires_fields_on_model_forms(self): -# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. -# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up -# broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'read_only': 'read only'} -# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) - -# def test_validate_does_not_require_blankable_fields_on_model_forms(self): -# """Test standard ModelForm validation behaviour - fields with blank=True are not required.""" -# content = {'qwerty': 'example', 'read_only': 'read only'} -# self.validator.validate_request(content, None) - -# def test_model_form_validator_uses_model_forms(self): -# self.assertTrue(isinstance(self.validator.get_bound_form(), forms.ModelForm)) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py index 43365e07..994cf6dc 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/views.py @@ -1,4 +1,4 @@ -import copy +from __future__ import unicode_literals from django.test import TestCase from django.test.client import RequestFactory from rest_framework import status @@ -6,6 +6,7 @@ from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView +import copy factory = RequestFactory() @@ -18,7 +19,7 @@ class BasicView(APIView): return Response({'method': 'POST', 'data': request.DATA}) -@api_view(['GET', 'POST', 'PUT']) +@api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': return {'method': 'GET'} @@ -26,6 +27,8 @@ def basic_view(request): return {'method': 'POST', 'data': request.DATA} elif request.method == 'PUT': return {'method': 'PUT', 'data': request.DATA} + elif request.method == 'PATCH': + return {'method': 'PATCH', 'data': request.DATA} def sanitise_json_error(error_dict): @@ -47,10 +50,10 @@ class ClassBasedViewIntegrationTests(TestCase): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) expected = { - 'detail': u'JSON parse error - No JSON object could be decoded' + 'detail': 'JSON parse error - No JSON object could be decoded' } - self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEquals(sanitise_json_error(response.data), expected) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) def test_400_parse_error_tunneled_content(self): content = 'f00bar' @@ -62,10 +65,10 @@ class ClassBasedViewIntegrationTests(TestCase): request = factory.post('/', form_data) response = self.view(request) expected = { - 'detail': u'JSON parse error - No JSON object could be decoded' + 'detail': 'JSON parse error - No JSON object could be decoded' } - self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEquals(sanitise_json_error(response.data), expected) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) class FunctionBasedViewIntegrationTests(TestCase): @@ -76,10 +79,10 @@ class FunctionBasedViewIntegrationTests(TestCase): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) expected = { - 'detail': u'JSON parse error - No JSON object could be decoded' + 'detail': 'JSON parse error - No JSON object could be decoded' } - self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEquals(sanitise_json_error(response.data), expected) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) def test_400_parse_error_tunneled_content(self): content = 'f00bar' @@ -91,7 +94,7 @@ class FunctionBasedViewIntegrationTests(TestCase): request = factory.post('/', form_data) response = self.view(request) expected = { - 'detail': u'JSON parse error - No JSON object could be decoded' + 'detail': 'JSON parse error - No JSON object could be decoded' } - self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEquals(sanitise_json_error(response.data), expected) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 8fe64248..93ea9816 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -1,7 +1,11 @@ -import time +""" +Provides various throttling policies. +""" +from __future__ import unicode_literals from django.core.cache import cache from rest_framework import exceptions from rest_framework.settings import api_settings +import time class BaseThrottle(object): @@ -27,9 +31,8 @@ class SimpleRateThrottle(BaseThrottle): A simple cache implementation, that only requires `.get_cache_key()` to be overridden. - The rate (requests / seconds) is set by a :attr:`throttle` attribute - on the :class:`.View` class. The attribute is a string of the form 'number of - requests/period'. + The rate (requests / seconds) is set by a `throttle` attribute on the View + class. The attribute is a string of the form 'number_of_requests/period'. Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 0ad926fa..d9143bb4 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,7 +1,38 @@ -from django.conf.urls.defaults import url +from __future__ import unicode_literals +from django.core.urlresolvers import RegexURLResolver +from rest_framework.compat import url, include from rest_framework.settings import api_settings +def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required): + ret = [] + for urlpattern in urlpatterns: + if isinstance(urlpattern, RegexURLResolver): + # Set of included URL patterns + regex = urlpattern.regex.pattern + namespace = urlpattern.namespace + app_name = urlpattern.app_name + kwargs = urlpattern.default_kwargs + # Add in the included patterns, after applying the suffixes + patterns = apply_suffix_patterns(urlpattern.url_patterns, + suffix_pattern, + suffix_required) + ret.append(url(regex, include(patterns, namespace, app_name), kwargs)) + + else: + # Regular URL pattern + regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern + view = urlpattern._callback or urlpattern._callback_str + kwargs = urlpattern.default_args + name = urlpattern.name + # Add in both the existing and the new urlpattern + if not suffix_required: + ret.append(urlpattern) + ret.append(url(regex, view, kwargs, name)) + + return ret + + def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): """ Supplement existing urlpatterns with corresponding patterns that also @@ -28,15 +59,4 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): else: suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg - ret = [] - for urlpattern in urlpatterns: - # Form our complementing '.format' urlpattern - regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern - view = urlpattern._callback or urlpattern._callback_str - kwargs = urlpattern.default_args - name = urlpattern.name - # Add in both the existing and the new urlpattern - if not suffix_required: - ret.append(urlpattern) - ret.append(url(regex, view, kwargs, name)) - return ret + return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required) diff --git a/rest_framework/urls.py b/rest_framework/urls.py index fbe4bc07..9c4719f1 100644 --- a/rest_framework/urls.py +++ b/rest_framework/urls.py @@ -12,6 +12,7 @@ your authentication settings include `SessionAuthentication`. url(r'^auth', include('rest_framework.urls', namespace='rest_framework')) ) """ +from __future__ import unicode_literals from rest_framework.compat import patterns, url diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py index 84fcb5db..e69de29b 100644 --- a/rest_framework/utils/__init__.py +++ b/rest_framework/utils/__init__.py @@ -1,100 +0,0 @@ -from django.utils.encoding import smart_unicode -from django.utils.xmlutils import SimplerXMLGenerator -from rest_framework.compat import StringIO -import re -import xml.etree.ElementTree as ET - - -# From xml2dict -class XML2Dict(object): - - def __init__(self): - pass - - def _parse_node(self, node): - node_tree = {} - # Save attrs and text, hope there will not be a child with same name - if node.text: - node_tree = node.text - for (k, v) in node.attrib.items(): - k, v = self._namespace_split(k, v) - node_tree[k] = v - #Save childrens - for child in node.getchildren(): - tag, tree = self._namespace_split(child.tag, self._parse_node(child)) - if tag not in node_tree: # the first time, so store it in dict - node_tree[tag] = tree - continue - old = node_tree[tag] - if not isinstance(old, list): - node_tree.pop(tag) - node_tree[tag] = [old] # multi times, so change old dict to a list - node_tree[tag].append(tree) # add the new one - - return node_tree - - def _namespace_split(self, tag, value): - """ - Split the tag '{http://cs.sfsu.edu/csc867/myscheduler}patients' - ns = http://cs.sfsu.edu/csc867/myscheduler - name = patients - """ - result = re.compile("\{(.*)\}(.*)").search(tag) - if result: - value.namespace, tag = result.groups() - return (tag, value) - - def parse(self, file): - """parse a xml file to a dict""" - f = open(file, 'r') - return self.fromstring(f.read()) - - def fromstring(self, s): - """parse a string""" - t = ET.fromstring(s) - unused_root_tag, root_tree = self._namespace_split(t.tag, self._parse_node(t)) - return root_tree - - -def xml2dict(input): - return XML2Dict().fromstring(input) - - -# Piston: -class XMLRenderer(): - def _to_xml(self, xml, data): - if isinstance(data, (list, tuple)): - for item in data: - xml.startElement("list-item", {}) - self._to_xml(xml, item) - xml.endElement("list-item") - - elif isinstance(data, dict): - for key, value in data.iteritems(): - xml.startElement(key, {}) - self._to_xml(xml, value) - xml.endElement(key) - - elif data is None: - # Don't output any value - pass - - else: - xml.characters(smart_unicode(data)) - - def dict2xml(self, data): - stream = StringIO.StringIO() - - xml = SimplerXMLGenerator(stream, "utf-8") - xml.startDocument() - xml.startElement("root", {}) - - self._to_xml(xml, data) - - xml.endElement("root") - xml.endDocument() - return stream.getvalue() - - -def dict2xml(input): - return XMLRenderer().dict2xml(input) diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 80e39d46..d51374b0 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,25 +1,37 @@ +from __future__ import unicode_literals from django.core.urlresolvers import resolve, get_script_prefix +from rest_framework.utils.formatting import get_view_name def get_breadcrumbs(url): - """Given a url returns a list of breadcrumbs, which are each a tuple of (name, url).""" + """ + Given a url returns a list of breadcrumbs, which are each a + tuple of (name, url). + """ from rest_framework.views import APIView def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): - """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url.""" + """ + Add tuples of (name, url) to the breadcrumbs list, + progressively chomping off parts of the url. + """ try: (view, unused_args, unused_kwargs) = resolve(url) except Exception: pass else: - # Check if this is a REST framework view, and if so add it to the breadcrumbs - if isinstance(getattr(view, 'cls_instance', None), APIView): + # Check if this is a REST framework view, + # and if so add it to the breadcrumbs + cls = getattr(view, 'cls', None) + if cls is not None and issubclass(cls, APIView): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) + suffix = getattr(view, 'suffix', None) + name = get_view_name(view.cls, suffix) + breadcrumbs_list.insert(0, (name, prefix + url)) seen.append(view) if url == '': @@ -27,11 +39,15 @@ def get_breadcrumbs(url): return breadcrumbs_list elif url.endswith('/'): - # Drop trailing slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen) - - # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs - return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen) + # Drop trailing slash off the end and continue to try to + # resolve more breadcrumbs + url = url.rstrip('/') + return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) + + # Drop trailing non-slash off the end and continue to try to + # resolve more breadcrumbs + url = url[:url.rfind('/') + 1] + return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) prefix = get_script_prefix().rstrip('/') url = url[len(prefix):] diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 2d1fb353..b6de18a8 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -1,18 +1,19 @@ """ Helper classes for parsers. """ -import datetime -import decimal -import types -from django.utils import simplejson as json +from __future__ import unicode_literals from django.utils.datastructures import SortedDict from rest_framework.compat import timezone from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata +import datetime +import decimal +import types +import json class JSONEncoder(json.JSONEncoder): """ - JSONEncoder subclass that knows how to encode date/time, + JSONEncoder subclass that knows how to encode date/time/timedelta, decimal types, and generators. """ def default(self, o): @@ -34,6 +35,8 @@ class JSONEncoder(json.JSONEncoder): if o.microsecond: r = r[:12] return r + elif isinstance(o, datetime.timedelta): + return str(o.total_seconds()) elif isinstance(o, decimal.Decimal): return str(o) elif hasattr(o, '__iter__'): diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py new file mode 100644 index 00000000..ebadb3a6 --- /dev/null +++ b/rest_framework/utils/formatting.py @@ -0,0 +1,80 @@ +""" +Utility functions to return a formatted name and description for a given view. +""" +from __future__ import unicode_literals + +from django.utils.html import escape +from django.utils.safestring import mark_safe +from rest_framework.compat import apply_markdown +import re + + +def _remove_trailing_string(content, trailing): + """ + Strip trailing component `trailing` from `content` if it exists. + Used when generating names from view classes. + """ + if content.endswith(trailing) and content != trailing: + return content[:-len(trailing)] + return content + + +def _remove_leading_indent(content): + """ + Remove leading indent from a block of text. + Used when generating descriptions from docstrings. + """ + whitespace_counts = [len(line) - len(line.lstrip(' ')) + for line in content.splitlines()[1:] if line.lstrip()] + + # unindent the content if needed + if whitespace_counts: + whitespace_pattern = '^' + (' ' * min(whitespace_counts)) + content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) + content = content.strip('\n') + return content + + +def _camelcase_to_spaces(content): + """ + Translate 'CamelCaseNames' to 'Camel Case Names'. + Used when generating names from view classes. + """ + camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' + content = re.sub(camelcase_boundry, ' \\1', content).strip() + return ' '.join(content.split('_')).title() + + +def get_view_name(cls, suffix=None): + """ + Return a formatted name for an `APIView` class or `@api_view` function. + """ + name = cls.__name__ + name = _remove_trailing_string(name, 'View') + name = _remove_trailing_string(name, 'ViewSet') + name = _camelcase_to_spaces(name) + if suffix: + name += ' ' + suffix + return name + + +def get_view_description(cls, html=False): + """ + Return a description for an `APIView` class or `@api_view` function. + """ + description = cls.__doc__ or '' + description = _remove_leading_indent(description) + if html: + return markup_description(description) + return description + + +def markup_description(description): + """ + Apply HTML markup to the given description. + """ + if apply_markdown: + description = apply_markdown(description) + else: + description = escape(description).replace('\n', '<br />') + return mark_safe(description) diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py index ee7f3a54..c09c2933 100644 --- a/rest_framework/utils/mediatypes.py +++ b/rest_framework/utils/mediatypes.py @@ -3,8 +3,9 @@ Handling of media types, as found in HTTP Content-Type and Accept headers. See http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7 """ - +from __future__ import unicode_literals from django.http.multipartparser import parse_header +from rest_framework import HTTP_HEADER_ENCODING def media_type_matches(lhs, rhs): @@ -47,7 +48,7 @@ class _MediaType(object): if media_type_str is None: media_type_str = '' self.orig = media_type_str - self.full_type, self.params = parse_header(media_type_str) + self.full_type, self.params = parse_header(media_type_str.encode(HTTP_HEADER_ENCODING)) self.main_type, sep, self.sub_type = self.full_type.partition('/') def match(self, other): diff --git a/rest_framework/views.py b/rest_framework/views.py index 10bdd5a5..555fa2f4 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,54 +1,16 @@ """ -Provides an APIView class that is used as the base of all class-based views. +Provides an APIView class that is the base of all views in REST framework. """ - -import re +from __future__ import unicode_literals from django.core.exceptions import PermissionDenied -from django.http import Http404 -from django.utils.html import escape -from django.utils.safestring import mark_safe +from django.http import Http404, HttpResponse from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import View, apply_markdown +from rest_framework.compat import View from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings - - -def _remove_trailing_string(content, trailing): - """ - Strip trailing component `trailing` from `content` if it exists. - Used when generating names from view classes. - """ - if content.endswith(trailing) and content != trailing: - return content[:-len(trailing)] - return content - - -def _remove_leading_indent(content): - """ - Remove leading indent from a block of text. - Used when generating descriptions from docstrings. - """ - whitespace_counts = [len(line) - len(line.lstrip(' ')) - for line in content.splitlines()[1:] if line.lstrip()] - - # unindent the content if needed - if whitespace_counts: - whitespace_pattern = '^' + (' ' * min(whitespace_counts)) - content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) - content = content.strip('\n') - return content - - -def _camelcase_to_spaces(content): - """ - Translate 'CamelCaseNames' to 'Camel Case Names'. - Used when generating names from view classes. - """ - camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' - content = re.sub(camelcase_boundry, ' \\1', content).strip() - return ' '.join(content.split('_')).title() +from rest_framework.utils.formatting import get_view_name, get_view_description class APIView(View): @@ -64,22 +26,21 @@ class APIView(View): @classmethod def as_view(cls, **initkwargs): """ - Override the default :meth:`as_view` to store an instance of the view - as an attribute on the callable function. This allows us to discover - information about the view when we do URL reverse lookups. + Store the original class on the view function. + + This allows us to discover information about the view when we do URL + reverse lookups. Used for breadcrumb generation. """ - # TODO: deprecate? view = super(APIView, cls).as_view(**initkwargs) - view.cls_instance = cls(**initkwargs) + view.cls = cls return view @property def allowed_methods(self): """ - Return the list of allowed HTTP methods, uppercased. + Wrap Django's private `_allowed_methods` interface in a public property. """ - return [method.upper() for method in self.http_method_names - if hasattr(self, method)] + return self._allowed_methods() @property def default_response_headers(self): @@ -90,43 +51,10 @@ class APIView(View): 'Vary': 'Accept' } - def get_name(self): - """ - Return the resource or view class name for use as this view's name. - Override to customize. - """ - # TODO: deprecate? - name = self.__class__.__name__ - name = _remove_trailing_string(name, 'View') - return _camelcase_to_spaces(name) - - def get_description(self, html=False): - """ - Return the resource or view docstring for use as this view's description. - Override to customize. - """ - # TODO: deprecate? - description = self.__doc__ or '' - description = _remove_leading_indent(description) - if html: - return self.markup_description(description) - return description - - def markup_description(self, description): - """ - Apply HTML markup to the description of this view. - """ - # TODO: deprecate? - if apply_markdown: - description = apply_markdown(description) - else: - description = escape(description).replace('\n', '<br />') - return mark_safe(description) - def metadata(self, request): return { - 'name': self.get_name(), - 'description': self.get_description(), + 'name': get_view_name(self.__class__), + 'description': get_view_description(self.__class__), 'renders': [renderer.media_type for renderer in self.renderer_classes], 'parses': [parser.media_type for parser in self.parser_classes], } @@ -140,7 +68,8 @@ class APIView(View): def http_method_not_allowed(self, request, *args, **kwargs): """ - Called if `request.method` does not correspond to a handler method. + If `request.method` does not correspond to a handler method, + determine what kind of exception to raise. """ raise exceptions.MethodNotAllowed(request.method) @@ -148,6 +77,8 @@ class APIView(View): """ If request is not permitted, determine what kind of exception to raise. """ + if not self.request.successful_authenticator: + raise exceptions.NotAuthenticated() raise exceptions.PermissionDenied() def throttled(self, request, wait): @@ -156,6 +87,15 @@ class APIView(View): """ raise exceptions.Throttled(wait) + def get_authenticate_header(self, request): + """ + If a request is unauthenticated, determine the WWW-Authenticate + header to use for 401 responses, if any. + """ + authenticators = self.get_authenticators() + if authenticators: + return authenticators[0].authenticate_header(request) + def get_parser_context(self, http_request): """ Returns a dict that is passed through to Parser.parse(), @@ -200,13 +140,13 @@ class APIView(View): def get_parsers(self): """ - Instantiates and returns the list of renderers that this view can use. + Instantiates and returns the list of parsers that this view can use. """ return [parser() for parser in self.parser_classes] def get_authenticators(self): """ - Instantiates and returns the list of renderers that this view can use. + Instantiates and returns the list of authenticators that this view can use. """ return [auth() for auth in self.authentication_classes] @@ -241,23 +181,43 @@ class APIView(View): try: return conneg.select_renderer(request, renderers, self.format_kwarg) - except: + except Exception: if force: return (renderers[0], renderers[0].media_type) raise - def has_permission(self, request, obj=None): + def perform_authentication(self, request): """ - Return `True` if the request should be permitted. + Perform authentication on the incoming request. + + Note that if you override this and simply 'pass', then authentication + will instead be performed lazily, the first time either + `request.user` or `request.auth` is accessed. + """ + request.user + + def check_permissions(self, request): + """ + Check if the request should be permitted. + Raises an appropriate exception if the request is not permitted. + """ + for permission in self.get_permissions(): + if not permission.has_permission(request, self): + self.permission_denied(request) + + def check_object_permissions(self, request, obj): + """ + Check if the request should be permitted for a given object. + Raises an appropriate exception if the request is not permitted. """ for permission in self.get_permissions(): - if not permission.has_permission(request, self, obj): - return False - return True + if not permission.has_object_permission(request, self, obj): + self.permission_denied(request) def check_throttles(self, request): """ Check if request should be throttled. + Raises an appropriate exception if the request is throttled. """ for throttle in self.get_throttles(): if not throttle.allow_request(request, self): @@ -284,8 +244,8 @@ class APIView(View): self.format_kwarg = self.get_format_suffix(**kwargs) # Ensure that the incoming request is permitted - if not self.has_permission(request): - self.permission_denied(request) + self.perform_authentication(request) + self.check_permissions(request) self.check_throttles(request) # Perform content negotiation and store the accepted info on the request @@ -296,6 +256,12 @@ class APIView(View): """ Returns the final response object. """ + # Make the error obvious if a proper response is not returned + assert isinstance(response, HttpResponse), ( + 'Expected a `Response` to be returned from the view, ' + 'but received a `%s`' % type(response) + ) + if isinstance(response, Response): if not getattr(request, 'accepted_renderer', None): neg = self.perform_content_negotiation(request, force=True) @@ -319,6 +285,16 @@ class APIView(View): # Throttle wait header self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + if isinstance(exc, (exceptions.NotAuthenticated, + exceptions.AuthenticationFailed)): + # WWW-Authenticate header for 401 responses, else coerce to 403 + auth_header = self.get_authenticate_header(self.request) + + if auth_header: + self.headers['WWW-Authenticate'] = auth_header + else: + exc.status_code = status.HTTP_403_FORBIDDEN + if isinstance(exc, exceptions.APIException): return Response({'detail': exc.detail}, status=exc.status_code, diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py new file mode 100644 index 00000000..d91323f2 --- /dev/null +++ b/rest_framework/viewsets.py @@ -0,0 +1,139 @@ +""" +ViewSets are essentially just a type of class based view, that doesn't provide +any method handlers, such as `get()`, `post()`, etc... but instead has actions, +such as `list()`, `retrieve()`, `create()`, etc... + +Actions are only bound to methods at the point of instantiating the views. + + user_list = UserViewSet.as_view({'get': 'list'}) + user_detail = UserViewSet.as_view({'get': 'retrieve'}) + +Typically, rather than instantiate views from viewsets directly, you'll +regsiter the viewset with a router and let the URL conf be determined +automatically. + + router = DefaultRouter() + router.register(r'users', UserViewSet, 'user') + urlpatterns = router.urls +""" +from __future__ import unicode_literals + +from functools import update_wrapper +from django.utils.decorators import classonlymethod +from rest_framework import views, generics, mixins + + +class ViewSetMixin(object): + """ + This is the magic. + + Overrides `.as_view()` so that it takes an `actions` keyword that performs + the binding of HTTP methods to actions on the Resource. + + For example, to create a concrete view binding the 'GET' and 'POST' methods + to the 'list' and 'create' actions... + + view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) + """ + + @classonlymethod + def as_view(cls, actions=None, **initkwargs): + """ + Because of the way class based views create a closure around the + instantiated view, we need to totally reimplement `.as_view`, + and slightly modify the view function that is created and returned. + """ + # The suffix initkwarg is reserved for identifing the viewset type + # eg. 'List' or 'Instance'. + cls.suffix = None + + # sanitize keyword arguments + for key in initkwargs: + if key in cls.http_method_names: + raise TypeError("You tried to pass in the %s method name as a " + "keyword argument to %s(). Don't do that." + % (key, cls.__name__)) + if not hasattr(cls, key): + raise TypeError("%s() received an invalid keyword %r" % ( + cls.__name__, key)) + + def view(request, *args, **kwargs): + self = cls(**initkwargs) + # We also store the mapping of request methods to actions, + # so that we can later set the action attribute. + # eg. `self.action = 'list'` on an incoming GET request. + self.action_map = actions + + # Bind methods to actions + # This is the bit that's different to a standard view + for method, action in actions.items(): + handler = getattr(self, action) + setattr(self, method, handler) + + # Patch this in as it's otherwise only present from 1.5 onwards + if hasattr(self, 'get') and not hasattr(self, 'head'): + self.head = self.get + + # And continue as usual + return self.dispatch(request, *args, **kwargs) + + # take name and docstring from class + update_wrapper(view, cls, updated=()) + + # and possible attributes set by decorators + # like csrf_exempt from dispatch + update_wrapper(view, cls.dispatch, assigned=()) + + # We need to set these on the view function, so that breadcrumb + # generation can pick out these bits of information from a + # resolved URL. + view.cls = cls + view.suffix = initkwargs.get('suffix', None) + return view + + def initialize_request(self, request, *args, **kargs): + """ + Set the `.action` attribute on the view, + depending on the request method. + """ + request = super(ViewSetMixin, self).initialize_request(request, *args, **kargs) + self.action = self.action_map.get(request.method.lower()) + return request + + +class ViewSet(ViewSetMixin, views.APIView): + """ + The base ViewSet class does not provide any actions by default. + """ + pass + + +class GenericViewSet(ViewSetMixin, generics.GenericAPIView): + """ + The GenericViewSet class does not provide any actions by default, + but does include the base set of generic view behavior, such as + the `get_object` and `get_queryset` methods. + """ + pass + + +class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, + mixins.ListModelMixin, + GenericViewSet): + """ + A viewset that provides default `list()` and `retrieve()` actions. + """ + pass + + +class ModelViewSet(mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + GenericViewSet): + """ + A viewset that provides default `create()`, `retrieve()`, `update()`, + `partial_update()`, `destroy()` and `list()` actions. + """ + pass @@ -45,9 +45,9 @@ version = get_version('rest_framework') if sys.argv[-1] == 'publish': os.system("python setup.py sdist upload") - print "You probably want to also tag the version now:" - print " git tag -a %s -m 'version %s'" % (version, version) - print " git push --tags" + print("You probably want to also tag the version now:") + print(" git tag -a %s -m 'version %s'" % (version, version)) + print(" git push --tags") sys.exit() @@ -57,21 +57,28 @@ setup( url='http://django-rest-framework.org', download_url='http://pypi.python.org/pypi/rest_framework/', license='BSD', - description='A lightweight REST framework for Django.', + description='Web APIs for Django, made easy.', author='Tom Christie', - author_email='tom@tomchristie.com', + author_email='tom@tomchristie.com', # SEE NOTE BELOW (*) packages=get_packages('rest_framework'), package_data=get_package_data('rest_framework'), test_suite='rest_framework.runtests.runtests.main', install_requires=[], classifiers=[ - 'Development Status :: 4 - Beta', + 'Development Status :: 5 - Production/Stable', 'Environment :: Web Environment', 'Framework :: Django', 'Intended Audience :: Developers', 'License :: OSI Approved :: BSD License', 'Operating System :: OS Independent', 'Programming Language :: Python', + 'Programming Language :: Python :: 3', 'Topic :: Internet :: WWW/HTTP', ] ) + +# (*) Please direct queries to the discussion group, rather than to me directly +# Doing so helps ensure your question is helpful to other users. +# Queries directly to my email are likely to receive a canned response. +# +# Many thanks for your understanding. @@ -1,36 +1,72 @@ [tox] downloadcache = {toxworkdir}/cache/ -envlist = py2.7-django1.5,py2.7-django1.4,py2.7-django1.3,py2.6-django1.5,py2.6-django1.4,py2.6-django1.3 +envlist = py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5,py2.7-django1.4,py2.6-django1.4,py2.7-django1.3,py2.6-django1.3 [testenv] commands = {envpython} rest_framework/runtests/runtests.py -[testenv:py2.7-django1.5] -basepython = python2.7 -deps = https://github.com/django/django/zipball/master - django-filter==0.5.4 +[testenv:py3.3-django1.5] +basepython = python3.3 +deps = django==1.5 + django-filter==0.6a1 + defusedxml==0.3 -[testenv:py2.7-django1.4] -basepython = python2.7 -deps = django==1.4.3 - django-filter==0.5.4 +[testenv:py3.2-django1.5] +basepython = python3.2 +deps = django==1.5 + django-filter==0.6a1 + defusedxml==0.3 -[testenv:py2.7-django1.3] +[testenv:py2.7-django1.5] basepython = python2.7 -deps = django==1.3.5 - django-filter==0.5.4 +deps = django==1.5 + django-filter==0.6a1 + defusedxml==0.3 + django-oauth-plus==2.0 + oauth2==1.5.211 + django-oauth2-provider==0.2.3 [testenv:py2.6-django1.5] basepython = python2.6 -deps = https://github.com/django/django/zipball/master - django-filter==0.5.4 +deps = django==1.5 + django-filter==0.6a1 + defusedxml==0.3 + django-oauth-plus==2.0 + oauth2==1.5.211 + django-oauth2-provider==0.2.3 + +[testenv:py2.7-django1.4] +basepython = python2.7 +deps = django==1.4.3 + django-filter==0.6a1 + defusedxml==0.3 + django-oauth-plus==2.0 + oauth2==1.5.211 + django-oauth2-provider==0.2.3 [testenv:py2.6-django1.4] basepython = python2.6 deps = django==1.4.3 + django-filter==0.6a1 + defusedxml==0.3 + django-oauth-plus==2.0 + oauth2==1.5.211 + django-oauth2-provider==0.2.3 + +[testenv:py2.7-django1.3] +basepython = python2.7 +deps = django==1.3.5 django-filter==0.5.4 + defusedxml==0.3 + django-oauth-plus==2.0 + oauth2==1.5.211 + django-oauth2-provider==0.2.3 [testenv:py2.6-django1.3] basepython = python2.6 deps = django==1.3.5 django-filter==0.5.4 + defusedxml==0.3 + django-oauth-plus==2.0 + oauth2==1.5.211 + django-oauth2-provider==0.2.3 |
