aboutsummaryrefslogtreecommitdiffstats
path: root/djangorestframework/permissions.py
blob: bce03cabc8aecd9c6d66c65d12d03764e2140771 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""
The :mod:`permissions` module bundles a set of permission classes that are used
for checking whether a request passes a certain set of constraints. You can assign a permission
class to your view by setting your View's :attr:`permissions` class attribute.
"""

from django.core.cache import cache
from djangorestframework import status
from djangorestframework.response import ErrorResponse
import time

__all__ = (
    # Permission classes.
    'BasePermission',
    'FullAnonAccess',
    'IsAuthenticated',
    'IsAdminUser',
    'IsUserOrIsAnonReadOnly',
    # Throttling classes.
    'PerUserThrottling',
    'PerViewThrottling',
    'PerResourceThrottling',
)


# Shared error responses raised by the permission and throttle checks in this module.
_403_FORBIDDEN_RESPONSE = ErrorResponse(
    content={
        'detail': ('You do not have permission to access this resource. '
                   'You may need to login or otherwise authenticate the request.'),
    },
    status=status.HTTP_403_FORBIDDEN,
)

_503_SERVICE_UNAVAILABLE = ErrorResponse(
    content={'detail': 'request was throttled'},
    status=status.HTTP_503_SERVICE_UNAVAILABLE,
)


class BasePermission(object):
    """
    Common base for every permission class in this module.

    An instance is created with the view being processed; subclasses
    override :meth:`check_permission` to accept or reject the request.
    """

    def __init__(self, view):
        """
        Remember the view this permission check runs against.
        """
        self.view = view

    def check_permission(self, auth):
        """
        Accept the request by returning normally, or reject it by raising
        an :exc:`response.ErrorResponse`.

        The base implementation accepts every request.
        """
        return None


class FullAnonAccess(BasePermission):
    """
    Allows full access: every request passes, whether or not the user
    is authenticated.
    """

    def check_permission(self, user):
        # Nothing to check; returning normally grants access.
        return None


class IsAuthenticated(BasePermission):
    """
    Allows access only to authenticated users.

    Raises the shared 403 response when ``user.is_authenticated()`` is false.
    """

    def check_permission(self, user):
        if user.is_authenticated():
            return
        raise _403_FORBIDDEN_RESPONSE


class IsAdminUser(BasePermission):
    """
    Allows access only to admin users, i.e. users whose ``is_staff`` flag
    is true; everyone else gets the shared 403 response.
    """

    def check_permission(self, user):
        if user.is_staff:
            return
        raise _403_FORBIDDEN_RESPONSE


class IsUserOrIsAnonReadOnly(BasePermission):
    """
    The request is authenticated as a user, or is a read-only request.

    Authenticated users may use any method; anonymous users are limited to
    ``GET`` and ``HEAD`` and get the shared 403 response for anything else.
    """

    def check_permission(self, user):
        # Authentication is checked first, so the view's method is only
        # consulted for anonymous users.
        if user.is_authenticated():
            return
        if self.view.method in ('GET', 'HEAD'):
            return
        raise _403_FORBIDDEN_RESPONSE


class BaseThrottle(BasePermission):
    """
    Rate throttling of requests.

    The rate (requests / seconds) is set by a :attr:`throttle` attribute
    on the :class:`.View` class.  The attribute is a string of the form 'number of
    requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day').
    Only the first character of the period is inspected.  When the view does
    not define the attribute, :attr:`default` ('0/sec') applies, which
    throttles every request.

    Previous request information used for throttling is stored in the cache,
    under the key returned by :meth:`get_cache_key`, as a list of request
    timestamps ordered newest first.
    """

    # Name of the view attribute holding the rate string.
    attr_name = 'throttle'
    # Rate used when the view does not define `attr_name`.
    default = '0/sec'
    # Clock used to timestamp requests.
    timer = time.time

    def get_cache_key(self):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.  The base implementation returns ``None``.
        """
        pass

    def check_permission(self, auth):
        """
        Check the throttling.
        Return `None` or raise an :exc:`.ErrorResponse`.

        Parses the view's rate string into :attr:`num_requests` and
        :attr:`duration` (in seconds), stores `auth` on the instance so that
        :meth:`get_cache_key` implementations can use it, and then runs
        :meth:`check_throttle`.

        A malformed rate string raises ``ValueError`` (bad split or count)
        or ``KeyError`` (unknown period).
        """
        num, period = getattr(self.view, self.attr_name, self.default).split('/')
        self.num_requests = int(num)
        self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        self.auth = auth
        self.check_throttle()

    def check_throttle(self):
        """
        Implement the check to see if the request should be throttled.

        On success calls :meth:`throttle_success`.
        On failure calls :meth:`throttle_failure`.
        """
        self.key = self.get_cache_key()
        self.history = cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration.  The list is newest first, so expired entries
        # are at the end.
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            self.throttle_failure()
        else:
            self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache, and sets a 'status=SUCCESS' X-Throttle header.
        """
        self.history.insert(0, self.now)
        cache.set(self.key, self.history, self.duration)
        header = 'status=SUCCESS; next=%s sec' % self.next()
        self.view.headers['X-Throttle'] = header

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        Sets a 'status=FAILURE' X-Throttle header and raises a
        '503 service unavailable' response.
        """
        header = 'status=FAILURE; next=%s sec' % self.next()
        self.view.headers['X-Throttle'] = header
        raise _503_SERVICE_UNAVAILABLE

    def next(self):
        """
        Returns the recommended next request time in seconds, formatted as
        a string with two decimals.

        The time left until the oldest recorded request leaves the window
        is spread evenly over the requests still available.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        # The cached history can hold more entries than this view allows.
        # For example, PerUserThrottling's key ignores the view, so views
        # with different rates share one history.  It can also happen when a
        # rate is lowered while older history is still cached.  Clamp to at
        # least one so we never divide by zero or recommend a negative delay.
        available_requests = max(self.num_requests - len(self.history) + 1, 1)

        return '%.2f' % (remaining_duration / float(available_requests))


class PerUserThrottling(BaseThrottle):
    """
    Limits the rate of API calls that may be made by a given user.

    Authenticated users are identified by their ``id``; anonymous requests
    fall back to the client's ``REMOTE_ADDR``.  The cache key does not
    include the view, so one user's budget is shared across views.
    """

    def get_cache_key(self):
        user = self.auth
        ident = (user.id if user.is_authenticated()
                 else self.view.request.META.get('REMOTE_ADDR', None))
        return 'throttle_user_%s' % ident


class PerViewThrottling(BaseThrottle):
    """
    Limits the rate of API calls that may be used on a given view.

    All callers of a view share one budget, keyed on the view's class name.
    """

    def get_cache_key(self):
        view_name = self.view.__class__.__name__
        return 'throttle_view_%s' % (view_name,)


class PerResourceThrottling(BaseThrottle):
    """
    Limits the rate of API calls that may be used against all views on
    a given resource.

    The class name of the resource is used as a unique identifier to
    throttle against.
    """

    def get_cache_key(self):
        resource = self.view.resource
        # If `view.resource` holds a resource *class*, `resource.__class__`
        # is its metaclass.  Every resource class would then share a single
        # key named after the metaclass.  Use the class's own name in that
        # case.  Instances (and None) keep the key they always had.
        # NOTE(review): confirm on the View whether `resource` stores a class
        # or an instance; this handles both.
        if isinstance(resource, type):
            name = resource.__name__
        else:
            name = resource.__class__.__name__
        return 'throttle_resource_%s' % name