base.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. """
  2. oauthlib.oauth2.rfc6749
  3. ~~~~~~~~~~~~~~~~~~~~~~~
  4. This module is an implementation of various logic needed
  5. for consuming and providing OAuth 2.0 RFC6749.
  6. """
  7. import functools
  8. import logging
  9. from ..errors import (
  10. FatalClientError, InvalidClientError, InvalidRequestError, OAuth2Error,
  11. ServerError, TemporarilyUnavailableError, UnsupportedTokenTypeError,
  12. )
  13. log = logging.getLogger(__name__)
  14. class BaseEndpoint:
  15. def __init__(self):
  16. self._available = True
  17. self._catch_errors = False
  18. self._valid_request_methods = None
  19. @property
  20. def valid_request_methods(self):
  21. return self._valid_request_methods
  22. @valid_request_methods.setter
  23. def valid_request_methods(self, valid_request_methods):
  24. if valid_request_methods is not None:
  25. valid_request_methods = [x.upper() for x in valid_request_methods]
  26. self._valid_request_methods = valid_request_methods
  27. @property
  28. def available(self):
  29. return self._available
  30. @available.setter
  31. def available(self, available):
  32. self._available = available
  33. @property
  34. def catch_errors(self):
  35. return self._catch_errors
  36. @catch_errors.setter
  37. def catch_errors(self, catch_errors):
  38. self._catch_errors = catch_errors
  39. def _raise_on_missing_token(self, request):
  40. """Raise error on missing token."""
  41. if not request.token:
  42. raise InvalidRequestError(request=request,
  43. description='Missing token parameter.')
  44. def _raise_on_invalid_client(self, request):
  45. """Raise on failed client authentication."""
  46. if self.request_validator.client_authentication_required(request):
  47. if not self.request_validator.authenticate_client(request):
  48. log.debug('Client authentication failed, %r.', request)
  49. raise InvalidClientError(request=request)
  50. elif not self.request_validator.authenticate_client_id(request.client_id, request):
  51. log.debug('Client authentication failed, %r.', request)
  52. raise InvalidClientError(request=request)
  53. def _raise_on_unsupported_token(self, request):
  54. """Raise on unsupported tokens."""
  55. if (request.token_type_hint and
  56. request.token_type_hint in self.valid_token_types and
  57. request.token_type_hint not in self.supported_token_types):
  58. raise UnsupportedTokenTypeError(request=request)
  59. def _raise_on_bad_method(self, request):
  60. if self.valid_request_methods is None:
  61. raise ValueError('Configure "valid_request_methods" property first')
  62. if request.http_method.upper() not in self.valid_request_methods:
  63. raise InvalidRequestError(request=request,
  64. description=('Unsupported request method %s' % request.http_method.upper()))
  65. def _raise_on_bad_post_request(self, request):
  66. """Raise if invalid POST request received
  67. """
  68. if request.http_method.upper() == 'POST':
  69. query_params = request.uri_query or ""
  70. if query_params:
  71. raise InvalidRequestError(request=request,
  72. description=('URL query parameters are not allowed'))
  73. def catch_errors_and_unavailability(f):
  74. @functools.wraps(f)
  75. def wrapper(endpoint, uri, *args, **kwargs):
  76. if not endpoint.available:
  77. e = TemporarilyUnavailableError()
  78. log.info('Endpoint unavailable, ignoring request %s.' % uri)
  79. return {}, e.json, 503
  80. if endpoint.catch_errors:
  81. try:
  82. return f(endpoint, uri, *args, **kwargs)
  83. except OAuth2Error:
  84. raise
  85. except FatalClientError:
  86. raise
  87. except Exception as e:
  88. error = ServerError()
  89. log.warning(
  90. 'Exception caught while processing request, %s.' % e)
  91. return {}, error.json, 500
  92. else:
  93. return f(endpoint, uri, *args, **kwargs)
  94. return wrapper