diff options
Diffstat (limited to 'rest_framework/authentication.py')
| -rw-r--r-- | rest_framework/authentication.py | 80 | 
1 files changed, 55 insertions, 25 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 30c78ebc..fc169189 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -21,32 +21,46 @@ class BaseAuthentication(object):          """          raise NotImplementedError(".authenticate() must be overridden.") +    def authenticate_header(self, request): +        """ +        Return a string to be used as the value of the `WWW-Authenticate` +        header in a `401 Unauthenticated` response, or `None` if the +        authentication scheme should return `403 Permission Denied` responses. +        """ +        pass +  class BasicAuthentication(BaseAuthentication):      """      HTTP Basic authentication against username/password.      """ +    www_authenticate_realm = 'api'      def authenticate(self, request):          """          Returns a `User` if a correct username and password have been supplied          using HTTP Basic authentication.  Otherwise returns `None`.          """ -        if 'HTTP_AUTHORIZATION' in request.META: -            auth = request.META['HTTP_AUTHORIZATION'].split() -            if len(auth) == 2 and auth[0].lower() == "basic": -                try: -                    auth_parts = base64.b64decode(auth[1]).partition(':') -                except TypeError: -                    return None - -                try: -                    userid = smart_unicode(auth_parts[0]) -                    password = smart_unicode(auth_parts[2]) -                except DjangoUnicodeDecodeError: -                    return None - -                return self.authenticate_credentials(userid, password) +        auth = request.META.get('HTTP_AUTHORIZATION', '').split() + +        if not auth or auth[0].lower() != "basic": +            return None + +        if len(auth) != 2: +            raise exceptions.AuthenticationFailed('Invalid basic header') + +        try: +            auth_parts = base64.b64decode(auth[1]).partition(':') +        except TypeError: +            raise exceptions.AuthenticationFailed('Invalid basic header') + +        try: +            userid = smart_unicode(auth_parts[0]) +            password = smart_unicode(auth_parts[2]) +        except DjangoUnicodeDecodeError: +            raise exceptions.AuthenticationFailed('Invalid basic header') + +        return self.authenticate_credentials(userid, password)      def authenticate_credentials(self, userid, password):          """ @@ -55,6 +69,10 @@ class BasicAuthentication(BaseAuthentication):          user = authenticate(username=userid, password=password)          if user is not None and user.is_active:              return (user, None) +        raise exceptions.AuthenticationFailed('Invalid username/password') + +    def authenticate_header(self, request): +        return 'Basic realm="%s"' % self.www_authenticate_realm  class SessionAuthentication(BaseAuthentication): @@ -74,7 +92,7 @@ class SessionAuthentication(BaseAuthentication):          # Unauthenticated, CSRF validation not required          if not user or not user.is_active: -            return +            return None          # Enforce CSRF validation for session based authentication.          class CSRFCheck(CsrfViewMiddleware): @@ -85,7 +103,7 @@ class SessionAuthentication(BaseAuthentication):          reason = CSRFCheck().process_view(http_request, None, (), {})          if reason:              # CSRF failed, bail with explicit error message -            raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) +            raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)          # CSRF passed with authenticated user          return (user, None) @@ -112,14 +130,26 @@ class TokenAuthentication(BaseAuthentication):      def authenticate(self, request):          auth = request.META.get('HTTP_AUTHORIZATION', '').split() -        if len(auth) == 2 and auth[0].lower() == "token": -            key = auth[1] -            try: -                token = self.model.objects.get(key=key) -            except self.model.DoesNotExist: -                return None +        if not auth or auth[0].lower() != "token": +            return None + +        if len(auth) != 2: +            raise exceptions.AuthenticationFailed('Invalid token header') + +        return self.authenticate_credentials(auth[1]) + +    def authenticate_credentials(self, key): +        try: +            token = self.model.objects.get(key=key) +        except self.model.DoesNotExist: +            raise exceptions.AuthenticationFailed('Invalid token') + +        if token.user.is_active: +            return (token.user, token) +        raise exceptions.AuthenticationFailed('User inactive or deleted') + +    def authenticate_header(self, request): +        return 'Token' -            if token.user.is_active: -                return (token.user, token)  # TODO: OAuthAuthentication  | 
