From e94a1897f61720253a4dd5454eb00d74af6fb4a2 Mon Sep 17 00:00:00 2001 From: Tomas Pazderka Date: Fri, 12 Apr 2019 13:17:07 +0200 Subject: [PATCH] Use black for code formatting --- .gitignore | 1 + Makefile | 8 + pylama.ini | 3 +- setup.py | 2 +- src/oic/__init__.py | 17 +- src/oic/exception.py | 2 +- src/oic/extension/__init__.py | 2 +- src/oic/extension/client.py | 513 ++++++++----- src/oic/extension/device_flow.py | 76 +- src/oic/extension/heart.py | 46 +- src/oic/extension/message.py | 201 +++--- src/oic/extension/pop.py | 48 +- src/oic/extension/popjwt.py | 34 +- src/oic/extension/proof_of_possesion.py | 64 +- src/oic/extension/provider.py | 499 +++++++------ src/oic/extension/signed_http_req.py | 49 +- src/oic/extension/sts.py | 32 +- src/oic/extension/token.py | 87 ++- src/oic/oauth2/__init__.py | 647 +++++++++++------ src/oic/oauth2/base.py | 19 +- src/oic/oauth2/consumer.py | 58 +- src/oic/oauth2/exception.py | 2 +- src/oic/oauth2/grant.py | 2 +- src/oic/oauth2/message.py | 243 ++++--- src/oic/oauth2/provider.py | 293 +++++--- src/oic/oauth2/util.py | 90 +-- src/oic/oic/__init__.py | 916 +++++++++++++++--------- src/oic/oic/claims_provider.py | 125 ++-- src/oic/oic/consumer.py | 73 +- src/oic/oic/message.py | 354 +++++---- src/oic/oic/provider.py | 757 ++++++++++++-------- src/oic/utils/__init__.py | 7 +- src/oic/utils/aes.py | 29 +- src/oic/utils/authn/__init__.py | 2 +- src/oic/utils/authn/authn_context.py | 26 +- src/oic/utils/authn/client.py | 120 ++-- src/oic/utils/authn/client_saml.py | 6 +- src/oic/utils/authn/javascript_login.py | 6 +- src/oic/utils/authn/ldap_member.py | 5 +- src/oic/utils/authn/ldapc.py | 32 +- src/oic/utils/authn/multi_auth.py | 23 +- src/oic/utils/authn/saml.py | 130 ++-- src/oic/utils/authn/user.py | 60 +- src/oic/utils/authn/user_cas.py | 56 +- src/oic/utils/authz.py | 10 +- src/oic/utils/claims.py | 2 +- src/oic/utils/client_management.py | 94 ++- src/oic/utils/clientdb.py | 26 +- src/oic/utils/http_util.py | 176 +++-- src/oic/utils/jwt.py | 63 +- src/oic/utils/keyio.py | 211 +++--- src/oic/utils/restrict.py | 28 +- src/oic/utils/rp/__init__.py | 210 +++--- src/oic/utils/rp/oauth2.py | 168 +++-- src/oic/utils/sanitize.py | 29 +- src/oic/utils/sdb.py | 211 +++--- src/oic/utils/shelve_wrapper.py | 2 +- src/oic/utils/stateless.py | 40 +- src/oic/utils/template_render.py | 24 +- src/oic/utils/time_util.py | 92 ++- src/oic/utils/token_handler.py | 63 +- src/oic/utils/userinfo/__init__.py | 2 +- src/oic/utils/userinfo/aa_info.py | 16 +- src/oic/utils/userinfo/distaggr.py | 14 +- src/oic/utils/userinfo/ldap_info.py | 26 +- src/oic/utils/webfinger.py | 52 +- tox.ini | 5 +- 67 files changed, 4432 insertions(+), 2897 deletions(-) diff --git a/.gitignore b/.gitignore index f61373fec..a1c9f166a 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,7 @@ foo.* .coverage .cache/ .pytest_cache +.mypy_cache # Dynamically created doc folders doc/_build diff --git a/Makefile b/Makefile index f2fd36c2b..623a38e03 100644 --- a/Makefile +++ b/Makefile @@ -15,6 +15,7 @@ help: @echo " install to install the python dependencies for development" @echo " test to run the tests" @echo " isort to sort imports" + @echo " blacken to format the code" .PHONY: help clean: @@ -47,6 +48,13 @@ check-isort: @pipenv run isort --recursive --diff --check-only $(OICDIR) $(TESTDIR) .PHONY: isort check-isort +blacken: + @pipenv run black src/ + +check-black: + @pipenv run black src/ --check +.PHONY: blacken check-black + check-pylama: @pipenv run pylama $(OICDIR) $(TESTDIR) .PHONY: check-pylama diff --git a/pylama.ini b/pylama.ini index 9b5c7aa96..4c5db4379 100644 --- a/pylama.ini +++ b/pylama.ini @@ -2,7 +2,8 @@ linters = pyflakes,eradicate,pycodestyle,mccabe,pep257 # D10X - Ignore complains about missing docstrings - we want to enforce style but do not want to add all docstrings # D203/D204 and D212/D213 are mutually exclusive, pick one -ignore = D100,D101,D102,D103,D104,D105,D106,D107,D203,D212 +# E203 is not PEP8 compliant in pycodestyle +ignore = D100,D101,D102,D103,D104,D105,D106,D107,D203,D212,E203 [pylama:pycodestyle] max_line_length = 120 diff --git a/setup.py b/setup.py index d137c7abd..402ffdfb3 100755 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ def run_tests(self): 'develop': ["cherrypy==3.2.4", "pyOpenSSL"], 'testing': tests_requires, 'docs': ['Sphinx', 'sphinx-autobuild', 'alabaster'], - 'quality': ['pylama', 'isort', 'eradicate', 'mypy'], + 'quality': ['pylama', 'isort', 'eradicate', 'mypy', 'black'], 'ldap_authn': ['pyldap'], }, install_requires=[ diff --git a/src/oic/__init__.py b/src/oic/__init__.py index f3883784f..78535c454 100644 --- a/src/oic/__init__.py +++ b/src/oic/__init__.py @@ -7,6 +7,7 @@ from secrets import choice except ImportError: import random + try: # Python 2.4+ if available on the platform _sysrand = random.SystemRandom() @@ -14,22 +15,20 @@ except AttributeError: # Fallback, really bad import warnings + choice = random.choice warnings.warn( "No good random number generator available on this platform. " "Security tokens will be weak and guessable.", - RuntimeWarning) + RuntimeWarning, + ) -__author__ = 'Roland Hedberg' -__version__ = '0.15.1' +__author__ = "Roland Hedberg" +__version__ = "0.15.1" OIDCONF_PATTERN = "%s/.well-known/openid-configuration" -CC_METHOD = { - 'S256': hashlib.sha256, - 'S384': hashlib.sha384, - 'S512': hashlib.sha512, -} +CC_METHOD = {"S256": hashlib.sha256, "S384": hashlib.sha384, "S512": hashlib.sha512} def rndstr(size=16): @@ -43,7 +42,7 @@ def rndstr(size=16): return "".join([choice(_basech) for _ in range(size)]) -BASECH = string.ascii_letters + string.digits + '-._~' +BASECH = string.ascii_letters + string.digits + "-._~" def unreserved(size=64): diff --git a/src/oic/exception.py b/src/oic/exception.py index 8fdef9be6..92dab0e6c 100644 --- a/src/oic/exception.py +++ b/src/oic/exception.py @@ -1,4 +1,4 @@ -__author__ = 'rohe0002' +__author__ = "rohe0002" class PyoidcError(Exception): diff --git a/src/oic/extension/__init__.py b/src/oic/extension/__init__.py index 16f3a7968..c09b22136 100644 --- a/src/oic/extension/__init__.py +++ b/src/oic/extension/__init__.py @@ -1 +1 @@ -__author__ = 'roland' +__author__ = "roland" diff --git a/src/oic/extension/client.py b/src/oic/extension/client.py index 89a729ce1..528a0c393 100644 --- a/src/oic/extension/client.py +++ b/src/oic/extension/client.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -__author__ = 'roland' +__author__ = "roland" # ----------------------------------------------------------------------------- @@ -26,10 +26,10 @@ RESPONSE2ERROR = { "ClientInfoResponse": [ClientRegistrationError], - "ClientUpdateRequest": [ClientRegistrationError] + "ClientUpdateRequest": [ClientRegistrationError], } -BASECH = string.ascii_letters + string.digits + '-._~' +BASECH = string.ascii_letters + string.digits + "-._~" def unreserved(size=64): @@ -42,59 +42,70 @@ def unreserved(size=64): return "".join([random.choice(BASECH) for _ in range(size)]) -CC_METHOD = { - 'S256': hashlib.sha256, - 'S384': hashlib.sha384, - 'S512': hashlib.sha512, -} +CC_METHOD = {"S256": hashlib.sha256, "S384": hashlib.sha384, "S512": hashlib.sha512} class Client(oauth2.Client): - def __init__(self, client_id=None, - client_authn_method=None, keyjar=None, verify_ssl=True, - config=None, message_factory=ExtensionMessageFactory): - super().__init__(client_id=client_id, client_authn_method=client_authn_method, - keyjar=keyjar, verify_ssl=verify_ssl, - config=config, message_factory=message_factory) + def __init__( + self, + client_id=None, + client_authn_method=None, + keyjar=None, + verify_ssl=True, + config=None, + message_factory=ExtensionMessageFactory, + ): + super().__init__( + client_id=client_id, + client_authn_method=client_authn_method, + keyjar=keyjar, + verify_ssl=verify_ssl, + config=config, + message_factory=message_factory, + ) self.allow = {} - self.request2endpoint.update({ - "RegistrationRequest": "registration_endpoint", - "ClientUpdateRequest": "clientinfo_endpoint", - 'TokenIntrospectionRequest': 'introspection_endpoint', - 'TokenRevocationRequest': 'revocation_endpoint' - }) + self.request2endpoint.update( + { + "RegistrationRequest": "registration_endpoint", + "ClientUpdateRequest": "clientinfo_endpoint", + "TokenIntrospectionRequest": "introspection_endpoint", + "TokenRevocationRequest": "revocation_endpoint", + } + ) self.registration_response = None - def construct_RegistrationRequest(self, request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_RegistrationRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('registration_endpoint') + request = self.message_factory.get_request_type("registration_endpoint") if request_args is None: request_args = {} return self.construct_request(request, request_args, extra_args) - def construct_ClientUpdateRequest(self, request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_ClientUpdateRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('update_endpoint') + request = self.message_factory.get_request_type("update_endpoint") if request_args is None: request_args = {} return self.construct_request(request, request_args, extra_args) def _token_interaction_setup(self, request_args=None, **kwargs): - if request_args is None or 'token' not in request_args: + if request_args is None or "token" not in request_args: token = self.get_token(**kwargs) try: - _token_type_hint = kwargs['token_type_hint'] + _token_type_hint = kwargs["token_type_hint"] except KeyError: - _token_type_hint = 'access_token' + _token_type_hint = "access_token" - request_args = {'token_type_hint': _token_type_hint, - 'token': getattr(token, _token_type_hint)} + request_args = { + "token_type_hint": _token_type_hint, + "token": getattr(token, _token_type_hint), + } if "client_id" not in request_args: request_args["client_id"] = self.client_id @@ -103,217 +114,355 @@ def _token_interaction_setup(self, request_args=None, **kwargs): return request_args - def construct_TokenIntrospectionRequest(self, - request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_TokenIntrospectionRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('introspection_endpoint') + request = self.message_factory.get_request_type("introspection_endpoint") request_args = self._token_interaction_setup(request_args, **kwargs) return self.construct_request(request, request_args, extra_args) - def construct_TokenRevocationRequest(self, - request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_TokenRevocationRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('revocation_endpoint') + request = self.message_factory.get_request_type("revocation_endpoint") request_args = self._token_interaction_setup(request_args, **kwargs) return self.construct_request(request, request_args, extra_args) - def do_op(self, request, body_type='', method='GET', request_args=None, - extra_args=None, http_args=None, response_cls=None, **kwargs): - - url, body, ht_args, _ = self.request_info(request, method, - request_args, extra_args, - **kwargs) + def do_op( + self, + request, + body_type="", + method="GET", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ): + + url, body, ht_args, _ = self.request_info( + request, method, request_args, extra_args, **kwargs + ) if http_args is None: http_args = ht_args else: http_args.update(http_args) - resp = self.request_and_return(url, response_cls, method, body, - body_type, http_args=http_args) + resp = self.request_and_return( + url, response_cls, method, body, body_type, http_args=http_args + ) return resp - def do_client_registration(self, request=None, - body_type="", method="GET", - request_args=None, extra_args=None, - http_args=None, - response_cls=None, - **kwargs): + def do_client_registration( + self, + request=None, + body_type="", + method="GET", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ): if request is not None: - warnings.warn('Passing `request` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('registration_endpoint') + request = self.message_factory.get_request_type("registration_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('registration_endpoint') - return self.do_op(request=request, body_type=body_type, method=method, - request_args=request_args, extra_args=extra_args, - http_args=http_args, response_cls=response_cls, - **kwargs) - - def do_client_read_request(self, request=None, - body_type="", method="GET", - request_args=None, extra_args=None, - http_args=None, - response_cls=None, - **kwargs): + response_cls = self.message_factory.get_response_type( + "registration_endpoint" + ) + return self.do_op( + request=request, + body_type=body_type, + method=method, + request_args=request_args, + extra_args=extra_args, + http_args=http_args, + response_cls=response_cls, + **kwargs + ) + + def do_client_read_request( + self, + request=None, + body_type="", + method="GET", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ): if request is not None: - warnings.warn('Passing `request` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('update_endpoint') + request = self.message_factory.get_request_type("update_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('update_endpoint') - return self.do_op(request=request, body_type=body_type, method=method, - request_args=request_args, extra_args=extra_args, - http_args=http_args, response_cls=response_cls, - **kwargs) - - def do_client_update_request(self, request=None, - body_type="", method="PUT", - request_args=None, extra_args=None, - http_args=None, - response_cls=None, - **kwargs): + response_cls = self.message_factory.get_response_type("update_endpoint") + return self.do_op( + request=request, + body_type=body_type, + method=method, + request_args=request_args, + extra_args=extra_args, + http_args=http_args, + response_cls=response_cls, + **kwargs + ) + + def do_client_update_request( + self, + request=None, + body_type="", + method="PUT", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ): if request is not None: - warnings.warn('Passing `request` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('update_endpoint') + request = self.message_factory.get_request_type("update_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('update_endpoint') - return self.do_op(request=request, body_type=body_type, method=method, - request_args=request_args, extra_args=extra_args, - http_args=http_args, response_cls=response_cls, - **kwargs) - - def do_client_delete_request(self, request=None, - body_type="", method="DELETE", - request_args=None, extra_args=None, - http_args=None, - response_cls=None, - **kwargs): + response_cls = self.message_factory.get_response_type("update_endpoint") + return self.do_op( + request=request, + body_type=body_type, + method=method, + request_args=request_args, + extra_args=extra_args, + http_args=http_args, + response_cls=response_cls, + **kwargs + ) + + def do_client_delete_request( + self, + request=None, + body_type="", + method="DELETE", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ): if request is not None: - warnings.warn('Passing `request` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('delete_endpoint') + request = self.message_factory.get_request_type("delete_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('delete_endpoint') - return self.do_op(request=request, body_type=body_type, method=method, - request_args=request_args, extra_args=extra_args, - http_args=http_args, response_cls=response_cls, - **kwargs) + response_cls = self.message_factory.get_response_type("delete_endpoint") + return self.do_op( + request=request, + body_type=body_type, + method=method, + request_args=request_args, + extra_args=extra_args, + http_args=http_args, + response_cls=response_cls, + **kwargs + ) def do_token_introspection( - self, request=None, body_type="json", - method="POST", request_args=None, extra_args=None, - http_args=None, response_cls=None, **kwargs): + self, + request=None, + body_type="json", + method="POST", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ): if request is not None: - warnings.warn('Passing `request` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('introspection_endpoint') + request = self.message_factory.get_request_type("introspection_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('introspection_endpoint') - return self.do_op(request=request, body_type=body_type, method=method, - request_args=request_args, extra_args=extra_args, - http_args=http_args, response_cls=response_cls, - **kwargs) + response_cls = self.message_factory.get_response_type( + "introspection_endpoint" + ) + return self.do_op( + request=request, + body_type=body_type, + method=method, + request_args=request_args, + extra_args=extra_args, + http_args=http_args, + response_cls=response_cls, + **kwargs + ) def do_token_revocation( - self, request=None, body_type="", - method="POST", request_args=None, extra_args=None, - http_args=None, response_cls=None, **kwargs): + self, + request=None, + body_type="", + method="POST", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ): if request is not None: - warnings.warn('Passing `request` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('revocation_endpoint') + request = self.message_factory.get_request_type("revocation_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('revocation_endpoint') - return self.do_op(request=request, body_type=body_type, method=method, - request_args=request_args, extra_args=extra_args, - http_args=http_args, response_cls=response_cls, - **kwargs) + response_cls = self.message_factory.get_response_type("revocation_endpoint") + return self.do_op( + request=request, + body_type=body_type, + method=method, + request_args=request_args, + extra_args=extra_args, + http_args=http_args, + response_cls=response_cls, + **kwargs + ) def add_code_challenge(self): try: - cv_len = self.config['code_challenge']['length'] + cv_len = self.config["code_challenge"]["length"] except KeyError: cv_len = 64 # Use default code_verifier = unreserved(cv_len) - _cv = code_verifier.encode('ascii') + _cv = code_verifier.encode("ascii") try: - _method = self.config['code_challenge']['method'] + _method = self.config["code_challenge"]["method"] except KeyError: - _method = 'S256' + _method = "S256" try: _h = CC_METHOD[_method](_cv).digest() - code_challenge = b64e(_h).decode('ascii') + code_challenge = b64e(_h).decode("ascii") except KeyError: - raise Unsupported( - 'PKCE Transformation method:{}'.format(_method)) + raise Unsupported("PKCE Transformation method:{}".format(_method)) # TODO store code_verifier - return {"code_challenge": code_challenge, - "code_challenge_method": _method}, code_verifier + return ( + {"code_challenge": code_challenge, "code_challenge_method": _method}, + code_verifier, + ) def do_authorization_request( - self, request=None, state="", body_type="", - method="GET", request_args=None, extra_args=None, http_args=None, - response_cls=None, **kwargs): + self, + request=None, + state="", + body_type="", + method="GET", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ): if request is not None: - warnings.warn('Passing `request` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('authorization_endpoint') + request = self.message_factory.get_request_type("authorization_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated, please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated, please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('authorization_endpoint') - if 'code_challenge' in self.config and self.config['code_challenge']: + response_cls = self.message_factory.get_response_type( + "authorization_endpoint" + ) + if "code_challenge" in self.config and self.config["code_challenge"]: _args, code_verifier = self.add_code_challenge() request_args.update(_args) - oauth2.Client.do_authorization_request(self, - request=request, state=state, - body_type=body_type, - method=method, - request_args=request_args, - extra_args=extra_args, - http_args=http_args, - response_cls=response_cls, - **kwargs) + oauth2.Client.do_authorization_request( + self, + request=request, + state=state, + body_type=body_type, + method=method, + request_args=request_args, + extra_args=extra_args, + http_args=http_args, + response_cls=response_cls, + **kwargs + ) def store_registration_info(self, reginfo): self.registration_response = reginfo @@ -323,7 +472,9 @@ def store_registration_info(self, reginfo): def handle_registration_info(self, response): if response.status_code in SUCCESSFUL: - resp = self.message_factory.get_response_type('registration_endpoint')().deserialize(response.text, "json") + resp = self.message_factory.get_response_type( + "registration_endpoint" + )().deserialize(response.text, "json") self.store_response(resp, response.text) self.store_registration_info(resp) else: @@ -332,8 +483,7 @@ def handle_registration_info(self, response): resp.verify() self.store_response(resp, response.text) except Exception: - raise PyoidcError( - 'Registration failed: {}'.format(response.text)) + raise PyoidcError("Registration failed: {}".format(response.text)) return resp @@ -349,16 +499,17 @@ def register(self, url, **kwargs): headers = {"content-type": "application/json"} - rsp = self.http_request(url, "POST", data=req.to_json(), - headers=headers) + rsp = self.http_request(url, "POST", data=req.to_json(), headers=headers) return self.handle_registration_info(rsp) def parse_authz_response(self, query): - aresp = self.parse_response(self.message_factory.get_response_type('authorization_endpoint'), - info=query, - sformat="urlencoded", - keyjar=self.keyjar) + aresp = self.parse_response( + self.message_factory.get_response_type("authorization_endpoint"), + info=query, + sformat="urlencoded", + keyjar=self.keyjar, + ) if aresp.type() == "ErrorResponse": logger.info("ErrorResponse: %s" % sanitize(aresp)) raise AuthzError(aresp.error) diff --git a/src/oic/extension/device_flow.py b/src/oic/extension/device_flow.py index a62eaa32a..87d1d44c2 100644 --- a/src/oic/extension/device_flow.py +++ b/src/oic/extension/device_flow.py @@ -11,27 +11,27 @@ class AuthorizationRequest(Message): c_param = { - 'response_type': SINGLE_REQUIRED_STRING, - 'client_id': SINGLE_REQUIRED_STRING, - 'scope': SINGLE_OPTIONAL_STRING, + "response_type": SINGLE_REQUIRED_STRING, + "client_id": SINGLE_REQUIRED_STRING, + "scope": SINGLE_OPTIONAL_STRING, } class AuthorizationResponse(Message): c_param = { - 'device_code': SINGLE_REQUIRED_STRING, - 'user_code': SINGLE_REQUIRED_STRING, - 'verification_uri': SINGLE_REQUIRED_STRING, - 'expires_in': SINGLE_OPTIONAL_INT, - 'interval': SINGLE_OPTIONAL_INT, + "device_code": SINGLE_REQUIRED_STRING, + "user_code": SINGLE_REQUIRED_STRING, + "verification_uri": SINGLE_REQUIRED_STRING, + "expires_in": SINGLE_OPTIONAL_INT, + "interval": SINGLE_OPTIONAL_INT, } class TokenRequest(Message): c_param = { - 'grant_type': SINGLE_REQUIRED_STRING, - 'device_code': SINGLE_REQUIRED_STRING, - 'client_id': SINGLE_REQUIRED_STRING, + "grant_type": SINGLE_REQUIRED_STRING, + "device_code": SINGLE_REQUIRED_STRING, + "client_id": SINGLE_REQUIRED_STRING, } @@ -54,23 +54,24 @@ def device_endpoint(self, request, authn=None): self.device2user[device_code] = user_code self.user_auth[user_code] = False - self.client_id2device[_req['client_id']] = device_code - self.device_code_expire_at[ - device_code] = time_sans_frac() + self.device_code_life_time + self.client_id2device[_req["client_id"]] = device_code + self.device_code_expire_at[device_code] = ( + time_sans_frac() + self.device_code_life_time + ) def token_endpoint(self, request, authn=None): _req = TokenRequest(**request) - _dc = _req['device_code'] + _dc = _req["device_code"] if time_sans_frac() > self.device_code_expire_at[_dc]: - return self.host.error_code(error='expired_token') + return self.host.error_code(error="expired_token") _uc = self.device2user[_dc] if self.user_auth[_uc]: # User is authenticated pass else: - return self.host.error_code(error='authorization_pending') + return self.host.error_code(error="authorization_pending") def device_auth(self, user_code): self.user_auth[user_code] = True @@ -79,34 +80,41 @@ def device_auth(self, user_code): class DeviceFlowClient(SingleClient): def __init__(self, host): SingleClient.__init__(self, host) - self.requests = {'authorization': self.authorization_request, - 'token': self.authorization_request} - - def authorization_request(self, scope=''): - req = AuthorizationRequest(client_id=self.host.client_id, - response_type='device_code') + self.requests = { + "authorization": self.authorization_request, + "token": self.authorization_request, + } + + def authorization_request(self, scope=""): + req = AuthorizationRequest( + client_id=self.host.client_id, response_type="device_code" + ) if scope: - req['scope'] = scope + req["scope"] = scope http_response = self.host.http_request( - self.host.provider_info['device_endpoint'], 'POST', - req.to_urlencoded()) + self.host.provider_info["device_endpoint"], "POST", req.to_urlencoded() + ) - response = self.host.parse_request_response(AuthorizationResponse, - http_response, 'json') + response = self.host.parse_request_response( + AuthorizationResponse, http_response, "json" + ) return response - def token_request(self, device_code=''): + def token_request(self, device_code=""): req = TokenRequest( grant_type="urn:ietf:params:oauth:grant-type:device_code", - device_code=device_code, client_id=self.host.client_id) + device_code=device_code, + client_id=self.host.client_id, + ) http_response = self.host.http_request( - self.host.provider_info['token_endpoint'], 'POST', - req.to_urlencoded()) + self.host.provider_info["token_endpoint"], "POST", req.to_urlencoded() + ) - response = self.host.parse_request_response(AccessTokenResponse, - http_response, 'json') + response = self.host.parse_request_response( + AccessTokenResponse, http_response, "json" + ) return response diff --git a/src/oic/extension/heart.py b/src/oic/extension/heart.py index 4a16a34cd..cd2483d0a 100644 --- a/src/oic/extension/heart.py +++ b/src/oic/extension/heart.py @@ -6,20 +6,22 @@ from oic.oic.message import JasonWebToken from oic.utils.keyio import KeyBundle -__author__ = 'roland' +__author__ = "roland" class PrivateKeyJWT(JasonWebToken): c_param = JasonWebToken.c_param.copy() - c_param.update({ - 'aud': SINGLE_REQUIRED_STRING, - "iss": SINGLE_REQUIRED_STRING, - "sub": SINGLE_REQUIRED_STRING, - "aud": SINGLE_REQUIRED_STRING, - "exp": SINGLE_REQUIRED_INT, - "iat": SINGLE_REQUIRED_INT, - "jti": SINGLE_REQUIRED_STRING, - }) + c_param.update( + { + "aud": SINGLE_REQUIRED_STRING, + "iss": SINGLE_REQUIRED_STRING, + "sub": SINGLE_REQUIRED_STRING, + "aud": SINGLE_REQUIRED_STRING, + "exp": SINGLE_REQUIRED_INT, + "iat": SINGLE_REQUIRED_INT, + "jti": SINGLE_REQUIRED_STRING, + } + ) def verify_url(url): @@ -34,11 +36,11 @@ def verify_url(url): :param url: :return: """ - if url.startswith('http://localhost'): + if url.startswith("http://localhost"): return True else: p = urlparse(url) - if p.scheme == 'http': + if p.scheme == "http": return False return True @@ -46,14 +48,16 @@ def verify_url(url): class HeartSoftwareStatement(JasonWebToken): c_param = JasonWebToken.c_param.copy() - c_param.update({ - 'redirect_uris': REQUIRED_LIST_OF_STRINGS, - 'grant_types': SINGLE_REQUIRED_STRING, - 'jwks_uri': SINGLE_REQUIRED_STRING, - 'jwks': SINGLE_REQUIRED_STRING, - 'client_name': SINGLE_REQUIRED_STRING, - 'client_uri': SINGLE_REQUIRED_STRING - }) + c_param.update( + { + "redirect_uris": REQUIRED_LIST_OF_STRINGS, + "grant_types": SINGLE_REQUIRED_STRING, + "jwks_uri": SINGLE_REQUIRED_STRING, + "jwks": SINGLE_REQUIRED_STRING, + "client_name": SINGLE_REQUIRED_STRING, + "client_uri": SINGLE_REQUIRED_STRING, + } + ) c_allowed_values = {"grant_types": ["authorization_code", "implicit"]} def verify(self, **kwargs): @@ -65,7 +69,7 @@ def verify(self, **kwargs): else: # will raise an exception if syntax error KeyBundle(_keys) - for param in ['jwks_uri', 'client_uri']: + for param in ["jwks_uri", "client_uri"]: verify_url(self[param]) JasonWebToken.verify(self, **kwargs) diff --git a/src/oic/extension/message.py b/src/oic/extension/message.py index 29337af16..c350d1d77 100644 --- a/src/oic/extension/message.py +++ b/src/oic/extension/message.py @@ -22,18 +22,18 @@ from oic.utils.http_util import SUCCESSFUL from oic.utils.jwt import JWT -__author__ = 'roland' +__author__ = "roland" # RFC 7662 class TokenIntrospectionRequest(Message): c_param = { - 'token': SINGLE_REQUIRED_STRING, - 'token_type_hint': SINGLE_OPTIONAL_STRING, + "token": SINGLE_REQUIRED_STRING, + "token_type_hint": SINGLE_OPTIONAL_STRING, # The ones below are part of authentication information - 'client_id': SINGLE_OPTIONAL_STRING, - 'client_assertion_type': SINGLE_OPTIONAL_STRING, - 'client_assertion': SINGLE_OPTIONAL_STRING + "client_id": SINGLE_OPTIONAL_STRING, + "client_assertion_type": SINGLE_OPTIONAL_STRING, + "client_assertion": SINGLE_OPTIONAL_STRING, } @@ -42,82 +42,82 @@ class TokenIntrospectionRequest(Message): class TokenIntrospectionResponse(Message): c_param = { - 'active': SINGLE_REQUIRED_BOOLEAN, - 'scope': OPTIONAL_LIST_OF_SP_SEP_STRINGS, - 'client_id': SINGLE_OPTIONAL_STRING, - 'username': SINGLE_OPTIONAL_STRING, - 'token_type': SINGLE_OPTIONAL_STRING, - 'exp': SINGLE_OPTIONAL_INT, - 'iat': SINGLE_OPTIONAL_INT, - 'nbf': SINGLE_OPTIONAL_INT, - 'sub': SINGLE_OPTIONAL_STRING, - 'aud': OPTIONAL_LIST_OF_STRINGS, - 'iss': SINGLE_OPTIONAL_STRING, - 'jti': SINGLE_OPTIONAL_STRING + "active": SINGLE_REQUIRED_BOOLEAN, + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, + "client_id": SINGLE_OPTIONAL_STRING, + "username": SINGLE_OPTIONAL_STRING, + "token_type": SINGLE_OPTIONAL_STRING, + "exp": SINGLE_OPTIONAL_INT, + "iat": SINGLE_OPTIONAL_INT, + "nbf": SINGLE_OPTIONAL_INT, + "sub": SINGLE_OPTIONAL_STRING, + "aud": OPTIONAL_LIST_OF_STRINGS, + "iss": SINGLE_OPTIONAL_STRING, + "jti": SINGLE_OPTIONAL_STRING, } # RFC 7009 class TokenRevocationRequest(Message): c_param = { - 'token': SINGLE_REQUIRED_STRING, - 'token_type_hint': SINGLE_OPTIONAL_STRING, - 'client_id': SINGLE_OPTIONAL_STRING, - 'client_assertion_type': SINGLE_OPTIONAL_STRING, - 'client_assertion': SINGLE_OPTIONAL_STRING + "token": SINGLE_REQUIRED_STRING, + "token_type_hint": SINGLE_OPTIONAL_STRING, + "client_id": SINGLE_OPTIONAL_STRING, + "client_assertion_type": SINGLE_OPTIONAL_STRING, + "client_assertion": SINGLE_OPTIONAL_STRING, } class SoftwareStatement(JasonWebToken): c_param = JasonWebToken.c_param.copy() - c_param.update({ - "software_id": SINGLE_OPTIONAL_STRING, - 'client_name': SINGLE_OPTIONAL_STRING, - 'client_uri': SINGLE_OPTIONAL_STRING - }) + c_param.update( + { + "software_id": SINGLE_OPTIONAL_STRING, + "client_name": SINGLE_OPTIONAL_STRING, + "client_uri": SINGLE_OPTIONAL_STRING, + } + ) class StateJWT(JasonWebToken): c_param = JasonWebToken.c_param.copy() - c_param.update({ - 'aud': SINGLE_OPTIONAL_STRING, - 'rfp': SINGLE_REQUIRED_STRING, - 'kid': SINGLE_OPTIONAL_STRING, - 'target_link__uri': SINGLE_OPTIONAL_STRING, - 'as': SINGLE_OPTIONAL_STRING, - 'at_hash': SINGLE_OPTIONAL_STRING, - 'c_hash': SINGLE_OPTIONAL_STRING - }) + c_param.update( + { + "aud": SINGLE_OPTIONAL_STRING, + "rfp": SINGLE_REQUIRED_STRING, + "kid": SINGLE_OPTIONAL_STRING, + "target_link__uri": SINGLE_OPTIONAL_STRING, + "as": SINGLE_OPTIONAL_STRING, + "at_hash": SINGLE_OPTIONAL_STRING, + "c_hash": SINGLE_OPTIONAL_STRING, + } + ) class ServerMetadata(Message): c_param = { - 'issuer': SINGLE_REQUIRED_STRING, - 'authorization_endpoint': SINGLE_OPTIONAL_STRING, - 'token_endpoint': SINGLE_OPTIONAL_STRING, - 'jwks_uri': SINGLE_REQUIRED_STRING, - 'registration_endpoint': SINGLE_OPTIONAL_STRING, - 'scopes_supported': OPTIONAL_LIST_OF_STRINGS, - 'response_types_supported': REQUIRED_LIST_OF_STRINGS, - 'response_modes_supported': OPTIONAL_LIST_OF_STRINGS, - 'grant_types_supported': OPTIONAL_LIST_OF_STRINGS, - 'token_endpoint_auth_methods_supported': OPTIONAL_LIST_OF_STRINGS, - 'token_endpoint_auth_signing_alg_values_supported': - OPTIONAL_LIST_OF_STRINGS, - 'service_documentation': SINGLE_OPTIONAL_STRING, - 'ui_locales_supported': OPTIONAL_LIST_OF_STRINGS, - 'op_policy_uri': SINGLE_OPTIONAL_STRING, - 'op_tos_uri': SINGLE_OPTIONAL_STRING, - 'revocation_endpoint': SINGLE_OPTIONAL_STRING, - 'revocation_endpoint_auth_methods_supported': OPTIONAL_LIST_OF_STRINGS, - 'revocation_endpoint_auth_signing_alg_values_supported': - OPTIONAL_LIST_OF_STRINGS, - 'introspection_endpoint': SINGLE_OPTIONAL_STRING, - 'introspection_endpoint_auth_methods_supported': - OPTIONAL_LIST_OF_STRINGS, - 'introspection_endpoint_auth_signing_alg_values_supported': - OPTIONAL_LIST_OF_STRINGS, - 'code_challenge_methods_supported': OPTIONAL_LIST_OF_STRINGS + "issuer": SINGLE_REQUIRED_STRING, + "authorization_endpoint": SINGLE_OPTIONAL_STRING, + "token_endpoint": SINGLE_OPTIONAL_STRING, + "jwks_uri": SINGLE_REQUIRED_STRING, + "registration_endpoint": SINGLE_OPTIONAL_STRING, + "scopes_supported": OPTIONAL_LIST_OF_STRINGS, + "response_types_supported": REQUIRED_LIST_OF_STRINGS, + "response_modes_supported": OPTIONAL_LIST_OF_STRINGS, + "grant_types_supported": OPTIONAL_LIST_OF_STRINGS, + "token_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, + "token_endpoint_auth_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "service_documentation": SINGLE_OPTIONAL_STRING, + "ui_locales_supported": OPTIONAL_LIST_OF_STRINGS, + "op_policy_uri": SINGLE_OPTIONAL_STRING, + "op_tos_uri": SINGLE_OPTIONAL_STRING, + "revocation_endpoint": SINGLE_OPTIONAL_STRING, + "revocation_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, + "revocation_endpoint_auth_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "introspection_endpoint": SINGLE_OPTIONAL_STRING, + "introspection_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, + "introspection_endpoint_auth_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "code_challenge_methods_supported": OPTIONAL_LIST_OF_STRINGS, } @@ -137,25 +137,26 @@ class RegistrationRequest(Message): "jwks_uri": SINGLE_OPTIONAL_STRING, "software_id": SINGLE_OPTIONAL_STRING, "software_version": SINGLE_OPTIONAL_STRING, - 'software_statement': OPTIONAL_LIST_OF_STRINGS + "software_statement": OPTIONAL_LIST_OF_STRINGS, } def verify(self, **kwargs): - if "initiate_login_uri" in self and not self["initiate_login_uri"].startswith("https:"): - raise RegistrationError('initiate_login_uri is not https') + if "initiate_login_uri" in self and not self["initiate_login_uri"].startswith( + "https:" + ): + raise RegistrationError("initiate_login_uri is not https") if "redirect_uris" in self: for uri in self["redirect_uris"]: if urlparse(uri).fragment: - raise InvalidRedirectUri( - "redirect_uri contains fragment: %s" % uri) + raise InvalidRedirectUri("redirect_uri contains fragment: %s" % uri) for uri in ["client_uri", "logo_uri", "tos_uri", "policy_uri"]: if uri in self: try: - resp = requests.request("GET", str(self[uri]), - allow_redirects=True, - verify=False) + resp = requests.request( + "GET", str(self[uri]), allow_redirects=True, verify=False + ) except requests.ConnectionError: raise MissingPage(self[uri]) @@ -163,47 +164,57 @@ def verify(self, **kwargs): raise MissingPage(self[uri]) try: - ss = self['software_statement'] + ss = self["software_statement"] except KeyError: pass else: _ss = [] for _s in ss: - _ss.append(unpack_software_statement(_s, '', kwargs['keyjar'])) - self['__software_statement'] = _ss + _ss.append(unpack_software_statement(_s, "", kwargs["keyjar"])) + self["__software_statement"] = _ss return super(RegistrationRequest, self).verify(**kwargs) class ClientInfoResponse(RegistrationRequest): c_param = RegistrationRequest.c_param.copy() - c_param.update({ - "client_id": SINGLE_REQUIRED_STRING, - "client_secret": SINGLE_OPTIONAL_STRING, - "client_id_issued_at": SINGLE_OPTIONAL_INT, - "client_secret_expires_at": SINGLE_OPTIONAL_INT, - "registration_access_token": SINGLE_REQUIRED_STRING, - "registration_client_uri": SINGLE_REQUIRED_STRING - }) + c_param.update( + { + "client_id": SINGLE_REQUIRED_STRING, + "client_secret": SINGLE_OPTIONAL_STRING, + "client_id_issued_at": SINGLE_OPTIONAL_INT, + "client_secret_expires_at": SINGLE_OPTIONAL_INT, + "registration_access_token": SINGLE_REQUIRED_STRING, + "registration_client_uri": SINGLE_REQUIRED_STRING, + } + ) class ClientRegistrationError(ErrorResponse): c_param = ErrorResponse.c_param.copy() c_param.update({"state": SINGLE_OPTIONAL_STRING}) c_allowed_values = ErrorResponse.c_allowed_values.copy() - c_allowed_values.update({"error": ["invalid_redirect_uri", - "invalid_client_metadata", - "invalid_client_id"]}) + c_allowed_values.update( + { + "error": [ + "invalid_redirect_uri", + "invalid_client_metadata", + "invalid_client_id", + ] + } + ) class ClientUpdateRequest(RegistrationRequest): c_param = RegistrationRequest.c_param.copy() - c_param.update({ - "client_id": SINGLE_REQUIRED_STRING, - "client_secret": SINGLE_OPTIONAL_STRING, - 'client_assertion_type': SINGLE_OPTIONAL_STRING, - 'client_assertion': SINGLE_OPTIONAL_STRING - }) + c_param.update( + { + "client_id": SINGLE_REQUIRED_STRING, + "client_secret": SINGLE_OPTIONAL_STRING, + "client_assertion_type": SINGLE_OPTIONAL_STRING, + "client_assertion": SINGLE_OPTIONAL_STRING, + } + ) MSG = { @@ -215,7 +226,7 @@ class ClientUpdateRequest(RegistrationRequest): "TokenIntrospectionRequest": TokenIntrospectionRequest, "TokenIntrospectionResponse": TokenIntrospectionResponse, "SoftwareStatement": SoftwareStatement, - 'StateJWT': StateJWT + "StateJWT": StateJWT, } @@ -237,7 +248,7 @@ def factory(msgtype): def make_software_statement(keyjar, iss, **kwargs): params = list(inspect.signature(JWT.__init__).parameters.keys()) - params.remove('self') + params.remove("self") args = {} for param in params: @@ -260,7 +271,9 @@ def unpack_software_statement(software_statement, iss, keyjar): class ExtensionMessageFactory(OauthMessageFactory): """Message factory for Extension code.""" - introspection_endpoint = MessageTuple(TokenIntrospectionRequest, TokenIntrospectionResponse) + introspection_endpoint = MessageTuple( + TokenIntrospectionRequest, TokenIntrospectionResponse + ) revocation_endpoint = MessageTuple(TokenRevocationRequest, Message) registration_endpoint = MessageTuple(RegistrationRequest, ClientInfoResponse) update_endpoint = MessageTuple(ClientUpdateRequest, ClientInfoResponse) diff --git a/src/oic/extension/pop.py b/src/oic/extension/pop.py index 790f1653f..ada95a31b 100644 --- a/src/oic/extension/pop.py +++ b/src/oic/extension/pop.py @@ -13,20 +13,19 @@ from oic.utils.jwt import JWT from oic.utils.keyio import KeyBundle -__author__ = 'roland' +__author__ = "roland" -def sign_http_args(method, url, headers, body=''): +def sign_http_args(method, url, headers, body=""): p = urlparse(url) - kwargs = {'path': p.path, 'host': p.netloc, 'headers': headers, - 'method': method} + kwargs = {"path": p.path, "host": p.netloc, "headers": headers, "method": method} if body: - kwargs['body'] = body + kwargs["body"] = body query_params = compact(parse_qs(p.query)) - kwargs['query_params'] = query_params + kwargs["query_params"] = query_params return kwargs @@ -37,23 +36,22 @@ def __init__(self, key, alg): def __call__(self, method, url, **kwargs): try: - body = kwargs['body'] + body = kwargs["body"] except KeyError: body = None try: - headers = kwargs['headers'] + headers = kwargs["headers"] except KeyError: headers = {} _kwargs = sign_http_args(method, url, headers, body) shr = SignedHttpRequest(self.key) - kwargs['Authorization'] = 'pop {}'.format(shr.sign(alg=self.alg, - **_kwargs)) + kwargs["Authorization"] = "pop {}".format(shr.sign(alg=self.alg, **_kwargs)) return kwargs class PoPClient(object): - def __init__(self, key_size=2048, sign_alg='RS256'): + def __init__(self, key_size=2048, sign_alg="RS256"): self.key_size = key_size self.state2key = {} self.token2key = {} @@ -73,7 +71,7 @@ def update(self, msg, state, key_size=0): key = RSAKey(key=RSA.generate(key_size)) self.state2key[state] = key - msg['key'] = json.dumps(key.serialize()) + msg["key"] = json.dumps(key.serialize()) return msg def handle_access_token_response(self, resp): @@ -82,7 +80,7 @@ def handle_access_token_response(self, resp): :param resp: AccessTokenResponse instance """ - self.token2key[resp['access_token']] = self.state2key[resp['state']] + self.token2key[resp["access_token"]] = self.state2key[resp["state"]] class PoPAS(object): @@ -96,8 +94,7 @@ def store_key(self, key): kb.do_keys([key]) # Store key with thumbprint as key - key_thumbprint = b64e(kb.keys()[0].thumbprint('SHA-256')).decode( - 'utf8') + key_thumbprint = b64e(kb.keys()[0].thumbprint("SHA-256")).decode("utf8") self.thumbprint2key[key_thumbprint] = key return key_thumbprint @@ -105,15 +102,14 @@ def create_access_token(self, key_thumbprint): # creating the access_token jwt_constructor = JWT(self.keyjar, iss=self.me) # Audience is myself - return jwt_constructor.pack( - kid='abc', cnf={'kid': key_thumbprint}, aud=self.me) + return jwt_constructor.pack(kid="abc", cnf={"kid": key_thumbprint}, aud=self.me) def token_introspection(self, token): jwt_constructor = JWT(self.keyjar, iss=self.me) res = jwt_constructor.unpack(token) tir = TokenIntrospectionResponse(active=True) - tir['key'] = json.dumps(self.thumbprint2key[res['cnf']['kid']]) + tir["key"] = json.dumps(self.thumbprint2key[res["cnf"]["kid"]]) return tir @@ -129,14 +125,18 @@ def store_key(self, access_token, tir): :param access_token: The token that was introspected :param tir: TokenIntrospectionResponse instance """ - key = load_jwks(json.dumps({'keys': [json.loads(tir['key'])]})) + key = load_jwks(json.dumps({"keys": [json.loads(tir["key"])]})) self.token2key[access_token] = key - def eval_signed_http_request(self, pop_token, access_token, method, url, - headers, body=''): + def eval_signed_http_request( + self, pop_token, access_token, method, url, headers, body="" + ): kwargs = sign_http_args(method, url, headers, body) shr = SignedHttpRequest(self.token2key[access_token][0]) - return shr.verify(signature=pop_token, - strict_query_params_verification=True, - strict_headers_verification=True, **kwargs) + return shr.verify( + signature=pop_token, + strict_query_params_verification=True, + strict_headers_verification=True, + **kwargs + ) diff --git a/src/oic/extension/popjwt.py b/src/oic/extension/popjwt.py index 072e50340..6863eaa2f 100644 --- a/src/oic/extension/popjwt.py +++ b/src/oic/extension/popjwt.py @@ -4,19 +4,18 @@ from oic.oic.message import JasonWebToken from oic.utils.time_util import utc_time_sans_frac -__author__ = 'roland' +__author__ = "roland" class PJWT(JasonWebToken): c_param = JasonWebToken.c_param.copy() - c_param.update({ - 'cnf': REQUIRED_MESSAGE - }) + c_param.update({"cnf": REQUIRED_MESSAGE}) class PopJWT(object): - def __init__(self, iss='', aud='', lifetime=3600, in_a_while=0, sub='', - jwe=None, keys=None): + def __init__( + self, iss="", aud="", lifetime=3600, in_a_while=0, sub="", jwe=None, keys=None + ): """ Initialize the class. @@ -38,17 +37,17 @@ def __init__(self, iss='', aud='', lifetime=3600, in_a_while=0, sub='', def _init_jwt(self): kwargs = {} - for para in ['iss', 'aud', 'sub']: + for para in ["iss", "aud", "sub"]: _val = getattr(self, para) if _val: kwargs[para] = _val _iat = utc_time_sans_frac() - kwargs['iat'] = _iat + kwargs["iat"] = _iat if self.lifetime: - kwargs['exp'] = _iat + self.lifetime + kwargs["exp"] = _iat + self.lifetime if self.in_a_while: - kwargs['nbf'] = _iat + self.in_a_while + kwargs["nbf"] = _iat + self.in_a_while return PJWT(**kwargs) @@ -60,10 +59,10 @@ def pack_jwk(self, jwk): :return: """ pjwt = self._init_jwt() - pjwt['cnf'] = {'jwk': jwk} + pjwt["cnf"] = {"jwk": jwk} return pjwt - def pack_jwe(self, jwe=None, jwk=None, kid=''): + def pack_jwe(self, jwe=None, jwk=None, kid=""): """ Pack JWE. @@ -75,16 +74,15 @@ def pack_jwe(self, jwe=None, jwk=None, kid=''): pjwt = self._init_jwt() if jwe: - pjwt['cnf'] = {'jwe': jwe} + pjwt["cnf"] = {"jwe": jwe} elif jwk: self.jwe.msg = json.dumps(jwk) - pjwt['cnf'] = {'jwe': self.jwe.encrypt(keys=self.keys.keys(), - kid=kid)} + pjwt["cnf"] = {"jwe": self.jwe.encrypt(keys=self.keys.keys(), kid=kid)} return pjwt def pack_kid(self, kid): pjwt = self._init_jwt() - pjwt['cnf'] = {'kid': kid} + pjwt["cnf"] = {"kid": kid} return pjwt def unpack(self, jwt, jwe=None): @@ -98,7 +96,7 @@ def unpack(self, jwt, jwe=None): _pjwt = PJWT().from_json(jwt) try: - _jwe = _pjwt['cnf']['jwe'] + _jwe = _pjwt["cnf"]["jwe"] except KeyError: pass else: @@ -106,6 +104,6 @@ def unpack(self, jwt, jwe=None): jwe = self.jwe msg = jwe.decrypt(_jwe, self.keys.keys()) - _pjwt['cnf']['jwk'] = json.loads(msg.decode('utf8')) + _pjwt["cnf"]["jwk"] = json.loads(msg.decode("utf8")) return _pjwt diff --git a/src/oic/extension/proof_of_possesion.py b/src/oic/extension/proof_of_possesion.py index 0475ce9e2..fb488c33b 100644 --- a/src/oic/extension/proof_of_possesion.py +++ b/src/oic/extension/proof_of_possesion.py @@ -17,7 +17,7 @@ from oic.utils.http_util import Response from oic.utils.http_util import get_post -__author__ = 'regu0004' +__author__ = "regu0004" class NonPoPTokenError(Exception): @@ -31,26 +31,28 @@ def __init__(self, *args, **kwargs): # mapping from signed pop token to access token in db self.access_tokens = {} - def token_endpoint(self, dtype='urlencoded', **kwargs): + def token_endpoint(self, dtype="urlencoded", **kwargs): atr = AccessTokenRequest().deserialize(kwargs["request"], dtype) resp = super(PoPProvider, self).token_endpoint(**kwargs) if "token_type" not in atr or atr["token_type"] != "pop": return resp - client_public_key = base64.urlsafe_b64decode( - atr["key"].encode("utf-8")).decode("utf-8") + client_public_key = base64.urlsafe_b64decode(atr["key"].encode("utf-8")).decode( + "utf-8" + ) pop_key = json.loads(client_public_key) atr = AccessTokenResponse().deserialize(resp.message, method="json") data = self.sdb.read(atr["access_token"]) - jwt = {"iss": self.baseurl, - "aud": self.baseurl, - "exp": data["token_expires_at"], - "nbf": int(time.time()), - "cnf": {"jwk": pop_key}} - _jws = JWS(jwt, alg="RS256").sign_compact( - self.keyjar.get_signing_key(owner="")) + jwt = { + "iss": self.baseurl, + "aud": self.baseurl, + "exp": data["token_expires_at"], + "nbf": int(time.time()), + "cnf": {"jwk": pop_key}, + } + _jws = JWS(jwt, alg="RS256").sign_compact(self.keyjar.get_signing_key(owner="")) self.access_tokens[_jws] = data["access_token"] atr["access_token"] = _jws @@ -62,31 +64,36 @@ def userinfo_endpoint(self, request, **kwargs): shr = SignedHttpRequest(self._get_client_public_key(access_token)) http_signature = self._parse_signature(request) try: - shr.verify(http_signature, - method=request["method"], - host=request["host"], path=request["path"], - query_params=request["query"], - headers=request["headers"], - body=request["body"], - strict_query_param_verification=True, - strict_headers_verification=False) + shr.verify( + http_signature, + method=request["method"], + host=request["host"], + path=request["path"], + query_params=request["query"], + headers=request["headers"], + body=request["body"], + strict_query_param_verification=True, + strict_headers_verification=False, + ) except ValidationError: - return self._error_response("access_denied", - descr="Could not verify proof of " - "possession") + return self._error_response( + "access_denied", descr="Could not verify proof of " "possession" + ) return self._do_user_info(self.access_tokens[access_token], **kwargs) def _get_client_public_key(self, access_token): _jws = jws.factory(access_token) if _jws: - data = _jws.verify_compact(access_token, - self.keyjar.get_verify_key(owner="")) + data = _jws.verify_compact( + access_token, self.keyjar.get_verify_key(owner="") + ) try: return keyrep(data["cnf"]["jwk"]) except KeyError: raise NonPoPTokenError( - "Could not extract public key as JWK from access token") + "Could not extract public key as JWK from access token" + ) raise NonPoPTokenError("Unsigned access token, maybe not PoP?") @@ -125,10 +132,13 @@ def _parse_access_token(self, request, **kwargs): return request["query"]["access_token"] elif "access_token" in request["body"]: return parse_qs(request["body"])["access_token"][0] - elif "Authorization" in request["headers"] and request["headers"]["Authorization"]: + elif ( + "Authorization" in request["headers"] + and request["headers"]["Authorization"] + ): auth_header = request["headers"]["Authorization"] if auth_header.startswith("pop "): - return auth_header[len("pop "):] + return auth_header[len("pop ") :] return None diff --git a/src/oic/extension/provider.py b/src/oic/extension/provider.py index 485b16216..a5a1a8c2d 100644 --- a/src/oic/extension/provider.py +++ b/src/oic/extension/provider.py @@ -53,62 +53,109 @@ from oic.utils.token_handler import NotAllowed from oic.utils.token_handler import TokenHandler -__author__ = 'roland' +__author__ = "roland" logger = logging.getLogger(__name__) CAPABILITIES = { "response_types_supported": ["code", "token"], - "response_modes_supported": ['query', 'fragment', 'form_post'], + "response_modes_supported": ["query", "fragment", "form_post"], "grant_types_supported": [ - "authorization_code", "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer"], + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + ], } -AUTH_METHODS_SUPPORTED = ["client_secret_post", "client_secret_basic", - "client_secret_jwt", "private_key_jwt"] +AUTH_METHODS_SUPPORTED = [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", +] class ClientInfoEndpoint(Endpoint): etype = "clientinfo" - url = 'clientinfo' + url = "clientinfo" class RevocationEndpoint(Endpoint): etype = "revocation" - url = 'revocation' + url = "revocation" class IntrospectionEndpoint(Endpoint): etype = "introspection" - url = 'introspection' + url = "introspection" class Provider(provider.Provider): """A OAuth2 RP that knows all the OAuth2 extensions I've implemented.""" - def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn, - symkey=None, urlmap=None, iv=0, default_scope="", - ca_bundle=None, seed=b"", client_authn_methods=None, - authn_at_registration="", client_info_url="", - secret_lifetime=86400, jwks_uri='', keyjar=None, - capabilities=None, verify_ssl=True, baseurl='', hostname='', - config=None, behavior=None, lifetime_policy=None, message_factory=ExtensionMessageFactory, **kwargs): + def __init__( + self, + name, + sdb, + cdb, + authn_broker, + authz, + client_authn, + symkey=None, + urlmap=None, + iv=0, + default_scope="", + ca_bundle=None, + seed=b"", + client_authn_methods=None, + authn_at_registration="", + client_info_url="", + secret_lifetime=86400, + jwks_uri="", + keyjar=None, + capabilities=None, + verify_ssl=True, + baseurl="", + hostname="", + config=None, + behavior=None, + lifetime_policy=None, + message_factory=ExtensionMessageFactory, + **kwargs + ): if not name.endswith("/"): name += "/" try: - args = {'server_cls': kwargs['server_cls']} + args = {"server_cls": kwargs["server_cls"]} except KeyError: args = {} - super().__init__(name, sdb, cdb, authn_broker, authz, - client_authn, symkey, urlmap, iv, - default_scope, ca_bundle, message_factory=message_factory, **args) - - self.endp.extend([RegistrationEndpoint, ClientInfoEndpoint, - RevocationEndpoint, IntrospectionEndpoint]) + super().__init__( + name, + sdb, + cdb, + authn_broker, + authz, + client_authn, + symkey, + urlmap, + iv, + default_scope, + ca_bundle, + message_factory=message_factory, + **args + ) + + self.endp.extend( + [ + RegistrationEndpoint, + ClientInfoEndpoint, + RevocationEndpoint, + IntrospectionEndpoint, + ] + ) # dictionary of client authentication methods self.client_authn_methods = client_authn_methods @@ -123,16 +170,15 @@ def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn, self.jwks_uri = jwks_uri self.verify_ssl = verify_ssl try: - self.scopes = kwargs['scopes'] + self.scopes = kwargs["scopes"] except KeyError: - self.scopes = ['offline_access'] + self.scopes = ["offline_access"] self.keyjar = keyjar if self.keyjar is None: self.keyjar = KeyJar(verify_ssl=self.verify_ssl) if capabilities: - self.capabilities = self.provider_features( - provider_config=capabilities) + self.capabilities = self.provider_features(provider_config=capabilities) else: self.capabilities = self.provider_features() self.baseurl = baseurl or name @@ -140,31 +186,32 @@ def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn, self.kid = {"sig": {}, "enc": {}} self.config = config or {} self.behavior = behavior or {} - self.token_policy = {'access_token': {}, 'refresh_token': {}} + self.token_policy = {"access_token": {}, "refresh_token": {}} if lifetime_policy is None: self.lifetime_policy = { - 'access_token': { - 'code': 600, - 'token': 120, - 'implicit': 120, - 'authorization_code': 600, - 'client_credentials': 600, - 'password': 600 + "access_token": { + "code": 600, + "token": 120, + "implicit": 120, + "authorization_code": 600, + "client_credentials": 600, + "password": 600, + }, + "refresh_token": { + "code": 3600, + "token": 3600, + "implicit": 3600, + "authorization_code": 3600, + "client_credentials": 3600, + "password": 3600, }, - 'refresh_token': { - 'code': 3600, - 'token': 3600, - 'implicit': 3600, - 'authorization_code': 3600, - 'client_credentials': 3600, - 'password': 3600 - } } else: self.lifetime_policy = lifetime_policy - self.token_handler = TokenHandler(self.baseurl, self.token_policy, - keyjar=self.keyjar) + self.token_handler = TokenHandler( + self.baseurl, self.token_policy, keyjar=self.keyjar + ) @staticmethod def _uris_to_tuples(uris): @@ -201,16 +248,19 @@ def load_keys(self, request, client_id, client_secret): logger.error(msg.format(sanitize(request.to_dict()))) logger.error("%s", err) err = ClientRegistrationError( - error="invalid_configuration_parameter", - error_description="%s" % err) - return Response(err.to_json(), content="application/json", - status_code="400 Bad Request") + error="invalid_configuration_parameter", error_description="%s" % err + ) + return Response( + err.to_json(), content="application/json", status_code="400 Bad Request" + ) # Add the client_secret as a symmetric key to the keyjar - _kc = KeyBundle([{"kty": "oct", "key": client_secret, - "use": "ver"}, - {"kty": "oct", "key": client_secret, - "use": "sig"}]) + _kc = KeyBundle( + [ + {"kty": "oct", "key": client_secret, "use": "ver"}, + {"kty": "oct", "key": client_secret, "use": "sig"}, + ] + ) try: self.keyjar[client_id].append(_kc) except KeyError: @@ -225,9 +275,9 @@ def verify_correct(cinfo, restrictions): raise RestrictionError(res) def set_token_policy(self, cid, cinfo): - for ttyp in ['access_token', 'refresh_token']: + for ttyp in ["access_token", "refresh_token"]: pol = {} - for rgtyp in ['response_type', 'grant_type']: + for rgtyp in ["response_type", "grant_type"]: try: rtyp = cinfo[rgtyp] except KeyError: @@ -267,17 +317,19 @@ def create_new_client(self, request, restrictions): if ClientInfoEndpoint in self.endp: _cinfo["registration_access_token"] = rndstr(32) _cinfo["registration_client_uri"] = "%s%s%s?client_id=%s" % ( - self.name, self.client_info_url, ClientInfoEndpoint.etype, - _id) + self.name, + self.client_info_url, + ClientInfoEndpoint.etype, + _id, + ) if "redirect_uris" in request: - _cinfo["redirect_uris"] = self._uris_to_tuples( - request["redirect_uris"]) + _cinfo["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"]) self.load_keys(request, _id, _cinfo["client_secret"]) try: - _behav = self.behavior['client_registration'] + _behav = self.behavior["client_registration"] except KeyError: pass else: @@ -296,45 +348,41 @@ def match_client_request(self, request): match = False p = set(val.split(" ")) for cv in self.capabilities[_prov]: - if p == set(cv.split(' ')): + if p == set(cv.split(" ")): match = True break if not match: - raise CapabilitiesMisMatch( - 'Not allowed {}'.format(_pref)) + raise CapabilitiesMisMatch("Not allowed {}".format(_pref)) else: if isinstance(request[_pref], str): if request[_pref] not in self.capabilities[_prov]: - raise CapabilitiesMisMatch( - 'Not allowed {}'.format(_pref)) + raise CapabilitiesMisMatch("Not allowed {}".format(_pref)) else: if not set(request[_pref]).issubset( - set(self.capabilities[_prov])): - raise CapabilitiesMisMatch( - 'Not allowed {}'.format(_pref)) + set(self.capabilities[_prov]) + ): + raise CapabilitiesMisMatch("Not allowed {}".format(_pref)) def client_info(self, client_id): _cinfo = self.cdb[client_id].copy() if not valid_client_info(_cinfo): err = ErrorResponse( - error="invalid_client", - error_description="Invalid client secret") + error="invalid_client", error_description="Invalid client secret" + ) return BadRequest(err.to_json(), content="application/json") try: - _cinfo["redirect_uris"] = self._tuples_to_uris( - _cinfo["redirect_uris"]) + _cinfo["redirect_uris"] = self._tuples_to_uris(_cinfo["redirect_uris"]) except KeyError: pass - msg = self.server.message_factory.get_response_type('update_endpoint')(**_cinfo) + msg = self.server.message_factory.get_response_type("update_endpoint")(**_cinfo) return Response(msg.to_json(), content="application/json") def client_info_update(self, client_id, request): _cinfo = self.cdb[client_id].copy() try: - _cinfo["redirect_uris"] = self._tuples_to_uris( - _cinfo["redirect_uris"]) + _cinfo["redirect_uris"] = self._tuples_to_uris(_cinfo["redirect_uris"]) except KeyError: pass @@ -347,15 +395,18 @@ def client_info_update(self, client_id, request): _cinfo[key] = value for key in list(_cinfo.keys()): - if key in ["client_id_issued_at", "client_secret_expires_at", - "registration_access_token", "registration_client_uri"]: + if key in [ + "client_id_issued_at", + "client_secret_expires_at", + "registration_access_token", + "registration_client_uri", + ]: continue if key not in request: del _cinfo[key] if "redirect_uris" in request: - _cinfo["redirect_uris"] = self._uris_to_tuples( - request["redirect_uris"]) + _cinfo["redirect_uris"] = self._uris_to_tuples(request["redirect_uris"]) self.cdb[client_id] = _cinfo @@ -369,8 +420,7 @@ def verify_client(self, environ, areq, authn_method, client_id=""): :return: """ if not client_id: - client_id = get_client_id(self.cdb, areq, - environ["HTTP_AUTHORIZATION"]) + client_id = get_client_id(self.cdb, areq, environ["HTTP_AUTHORIZATION"]) try: method = self.client_authn_methods[authn_method] @@ -390,43 +440,49 @@ def registration_endpoint(self, **kwargs): :param kwargs: extra keyword arguments :return: A Response instance """ - _request = self.server.message_factory.get_request_type('registration_endpoint')().deserialize( - kwargs['request'], "json") + _request = self.server.message_factory.get_request_type( + "registration_endpoint" + )().deserialize(kwargs["request"], "json") try: _request.verify(keyjar=self.keyjar) except InvalidRedirectUri as err: - msg = ClientRegistrationError(error="invalid_redirect_uri", - error_description="%s" % err) + msg = ClientRegistrationError( + error="invalid_redirect_uri", error_description="%s" % err + ) return BadRequest(msg.to_json(), content="application/json") except (MissingPage, VerificationError) as err: - msg = ClientRegistrationError(error="invalid_client_metadata", - error_description="%s" % err) + msg = ClientRegistrationError( + error="invalid_client_metadata", error_description="%s" % err + ) return BadRequest(msg.to_json(), content="application/json") # If authentication is necessary at registration if self.authn_at_registration: try: - self.verify_client(kwargs['environ'], _request, - self.authn_at_registration) + self.verify_client( + kwargs["environ"], _request, self.authn_at_registration + ) except (AuthnFailure, UnknownAssertionType): return Unauthorized() client_restrictions = {} - if 'parsed_software_statement' in _request: - for ss in _request['parsed_software_statement']: + if "parsed_software_statement" in _request: + for ss in _request["parsed_software_statement"]: client_restrictions.update(self.consume_software_statement(ss)) - del _request['software_statement'] - del _request['parsed_software_statement'] + del _request["software_statement"] + del _request["parsed_software_statement"] try: client_id = self.create_new_client(_request, client_restrictions) except CapabilitiesMisMatch as err: - msg = ClientRegistrationError(error="invalid_client_metadata", - error_description="%s" % err) + msg = ClientRegistrationError( + error="invalid_client_metadata", error_description="%s" % err + ) return BadRequest(msg.to_json(), content="application/json") except RestrictionError as err: - msg = ClientRegistrationError(error="invalid_client_metadata", - error_description="%s" % err) + msg = ClientRegistrationError( + error="invalid_client_metadata", error_description="%s" % err + ) return BadRequest(msg.to_json(), content="application/json") return self.client_info(client_id) @@ -439,7 +495,7 @@ def client_info_endpoint(self, method="GET", **kwargs): :param kwargs: keyword arguments :return: A Response instance """ - _query = compact(parse_qs(kwargs['query'])) + _query = compact(parse_qs(kwargs["query"])) try: _id = _query["client_id"] except KeyError: @@ -450,8 +506,9 @@ def client_info_endpoint(self, method="GET", **kwargs): # authenticated client try: - self.verify_client(kwargs['environ'], kwargs['request'], - "bearer_header", client_id=_id) + self.verify_client( + kwargs["environ"], kwargs["request"], "bearer_header", client_id=_id + ) except (AuthnFailure, UnknownAssertionType): return Unauthorized() @@ -459,20 +516,23 @@ def client_info_endpoint(self, method="GET", **kwargs): return self.client_info(_id) elif method == "PUT": try: - _request = self.server.message_factory.get_request_type('update_endpoint')().from_json( - kwargs['request']) + _request = self.server.message_factory.get_request_type( + "update_endpoint" + )().from_json(kwargs["request"]) except ValueError as err: return BadRequest(str(err)) try: _request.verify() except InvalidRedirectUri as err: - msg = ClientRegistrationError(error="invalid_redirect_uri", - error_description="%s" % err) + msg = ClientRegistrationError( + error="invalid_redirect_uri", error_description="%s" % err + ) return BadRequest(msg.to_json(), content="application/json") except (MissingPage, VerificationError) as err: - msg = ClientRegistrationError(error="invalid_client_metadata", - error_description="%s" % err) + msg = ClientRegistrationError( + error="invalid_client_metadata", error_description="%s" % err + ) return BadRequest(msg.to_json(), content="application/json") try: @@ -499,7 +559,7 @@ def provider_features(self, pcr_class=ServerMetadata, provider_config=None): _provider_info["scopes_supported"] = self.scopes sign_algs = list(jws.SIGNER_ALGS.keys()) - sign_algs.remove('none') + sign_algs.remove("none") sign_algs = sorted(sign_algs, key=cmp_to_key(sort_sign_alg)) _pat1 = "{}_endpoint_auth_signing_alg_values_supported" @@ -513,8 +573,7 @@ def provider_features(self, pcr_class=ServerMetadata, provider_config=None): return _provider_info - def create_providerinfo(self, pcr_class=ASConfigurationResponse, - setup=None): + def create_providerinfo(self, pcr_class=ASConfigurationResponse, setup=None): """ Dynamically create the provider info response. @@ -525,8 +584,9 @@ def create_providerinfo(self, pcr_class=ASConfigurationResponse, return super().create_providerinfo(pcr_class=pcr_class, setup=setup) @staticmethod - def verify_code_challenge(code_verifier, code_challenge, - code_challenge_method='S256'): + def verify_code_challenge( + code_verifier, code_challenge, code_challenge_method="S256" + ): """ Verify a PKCE (RFC7636) code challenge. @@ -534,30 +594,32 @@ def verify_code_challenge(code_verifier, code_challenge, :param code_challenge: The transformed verifier used as challenge :return: """ - _h = CC_METHOD[code_challenge_method]( - code_verifier.encode('ascii')).digest() + _h = CC_METHOD[code_challenge_method](code_verifier.encode("ascii")).digest() _cc = b64e(_h) - if _cc.decode('ascii') != code_challenge: - logger.error('PCKE Code Challenge check failed') - err = TokenErrorResponse(error="invalid_request", - error_description="PCKE check failed") - return Response(err.to_json(), content="application/json", - status_code=401) + if _cc.decode("ascii") != code_challenge: + logger.error("PCKE Code Challenge check failed") + err = TokenErrorResponse( + error="invalid_request", error_description="PCKE check failed" + ) + return Response(err.to_json(), content="application/json", status_code=401) return True - def do_access_token_response(self, access_token, atinfo, state, - refresh_token=None): - _tinfo = {'access_token': access_token, 'expires_in': atinfo['exp'], - 'token_type': 'bearer', 'state': state} + def do_access_token_response(self, access_token, atinfo, state, refresh_token=None): + _tinfo = { + "access_token": access_token, + "expires_in": atinfo["exp"], + "token_type": "bearer", + "state": state, + } try: - _tinfo['scope'] = atinfo['scope'] + _tinfo["scope"] = atinfo["scope"] except KeyError: pass if refresh_token: - _tinfo['refresh_token'] = refresh_token + _tinfo["refresh_token"] = refresh_token - atr_class = self.server.message_factory.get_response_type('token_endpoint') + atr_class = self.server.message_factory.get_response_type("token_endpoint") return atr_class(**by_schema(atr_class, **_tinfo)) def code_grant_type(self, areq): @@ -565,27 +627,29 @@ def code_grant_type(self, areq): try: _info = self.sdb[areq["code"]] except KeyError: - err = TokenErrorResponse(error="invalid_grant", - error_description="Unknown access grant") - return Response(err.to_json(), content="application/json", - status="401 Unauthorized") - - authzreq = json.loads(_info['authzreq']) - if 'code_verifier' in areq: + err = TokenErrorResponse( + error="invalid_grant", error_description="Unknown access grant" + ) + return Response( + err.to_json(), content="application/json", status="401 Unauthorized" + ) + + authzreq = json.loads(_info["authzreq"]) + if "code_verifier" in areq: try: - _method = authzreq['code_challenge_method'] + _method = authzreq["code_challenge_method"] except KeyError: - _method = 'S256' + _method = "S256" - resp = self.verify_code_challenge(areq['code_verifier'], - authzreq['code_challenge'], - _method) + resp = self.verify_code_challenge( + areq["code_verifier"], authzreq["code_challenge"], _method + ) if resp: return resp - if 'state' in areq: - if self.sdb[areq['code']]['state'] != areq['state']: - logger.error('State value mismatch') + if "state" in areq: + if self.sdb[areq["code"]]["state"] != areq["state"]: + logger.error("State value mismatch") err = TokenErrorResponse(error="unauthorized_client") return Unauthorized(err.to_json(), content="application/json") @@ -596,45 +660,51 @@ def code_grant_type(self, areq): # If redirect_uri was in the initial authorization request # verify that the one given here is the correct one. if "redirect_uri" in _info and areq["redirect_uri"] != _info["redirect_uri"]: - logger.error('Redirect_uri mismatch') + logger.error("Redirect_uri mismatch") err = TokenErrorResponse(error="unauthorized_client") return Unauthorized(err.to_json(), content="application/json") issue_refresh = False - if 'scope' in authzreq and 'offline_access' in authzreq['scope']: - if authzreq['response_type'] == 'code': + if "scope" in authzreq and "offline_access" in authzreq["scope"]: + if authzreq["response_type"] == "code": issue_refresh = True try: - _tinfo = self.sdb.upgrade_to_token(areq["code"], - issue_refresh=issue_refresh) + _tinfo = self.sdb.upgrade_to_token( + areq["code"], issue_refresh=issue_refresh + ) except AccessCodeUsed: - err = TokenErrorResponse(error="invalid_grant", - error_description="Access grant used") - return Response(err.to_json(), content="application/json", - status="401 Unauthorized") + err = TokenErrorResponse( + error="invalid_grant", error_description="Access grant used" + ) + return Response( + err.to_json(), content="application/json", status="401 Unauthorized" + ) logger.debug("_tinfo: %s" % _tinfo) - atr_class = self.server.message_factory.get_response_type('token_endpoint') + atr_class = self.server.message_factory.get_response_type("token_endpoint") atr = atr_class(**by_schema(atr_class, **_tinfo)) logger.debug("AccessTokenResponse: %s" % atr) - return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) + return Response( + atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS + ) def client_credentials_grant_type(self, areq): - _at = self.token_handler.get_access_token(areq['client_id'], - scope=areq['scope'], - grant_type='client_credentials') + _at = self.token_handler.get_access_token( + areq["client_id"], scope=areq["scope"], grant_type="client_credentials" + ) _info = self.token_handler.token_factory.get_info(_at) try: _rt = self.token_handler.get_refresh_token( - self.baseurl, _info['access_token'], 'client_credentials') + self.baseurl, _info["access_token"], "client_credentials" + ) except NotAllowed: - atr = self.do_access_token_response(_at, _info, areq['state']) + atr = self.do_access_token_response(_at, _info, areq["state"]) else: - atr = self.do_access_token_response(_at, _info, areq['state'], _rt) + atr = self.do_access_token_response(_at, _info, areq["state"], _rt) return Response(atr.to_json(), content="application/json") @@ -648,28 +718,36 @@ def password_grant_type(self, areq): try: authn, authn_class_ref = self.pick_auth(areq, "any") except IndexError: - err = TokenErrorResponse(error='invalid_grant') - return Unauthorized(err.to_json(), content='application/json') - identity, _ts = authn.authenticated_as(username=areq['username'], password=areq['password']) + err = TokenErrorResponse(error="invalid_grant") + return Unauthorized(err.to_json(), content="application/json") + identity, _ts = authn.authenticated_as( + username=areq["username"], password=areq["password"] + ) if identity is None: - err = TokenErrorResponse(error='invalid_grant') - return Unauthorized(err.to_json(), content='application/json') + err = TokenErrorResponse(error="invalid_grant") + return Unauthorized(err.to_json(), content="application/json") # We are returning a token - areq['response_type'] = ['token'] - authn_event = AuthnEvent(identity["uid"], identity.get('salt', ''), - authn_info=authn_class_ref, - time_stamp=_ts) - sid = self.setup_session(areq, authn_event, self.cdb[areq['client_id']]) - _at = self.sdb.upgrade_to_token(self.sdb[sid]['code'], issue_refresh=True) - atr_class = self.server.message_factory.get_response_type('token_endpoint') + areq["response_type"] = ["token"] + authn_event = AuthnEvent( + identity["uid"], + identity.get("salt", ""), + authn_info=authn_class_ref, + time_stamp=_ts, + ) + sid = self.setup_session(areq, authn_event, self.cdb[areq["client_id"]]) + _at = self.sdb.upgrade_to_token(self.sdb[sid]["code"], issue_refresh=True) + atr_class = self.server.message_factory.get_response_type("token_endpoint") atr = atr_class(**by_schema(atr_class, **_at)) - return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) + return Response( + atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS + ) def refresh_token_grant_type(self, areq): at = self.token_handler.refresh_access_token( - self.baseurl, areq['access_token'], 'refresh_token') + self.baseurl, areq["access_token"], "refresh_token" + ) - atr_class = self.server.message_factory.get_response_type('token_endpoint') + atr_class = self.server.message_factory.get_response_type("token_endpoint") atr = atr_class(**by_schema(atr_class, **at)) return Response(atr.to_json(), content="application/json") @@ -678,16 +756,16 @@ def token_access(endpoint, client_id, token_info): # simple rules: if client_id in azp or aud it's allow to introspect # to revoke it has to be in azr allow = False - if endpoint == 'revocation_endpoint': - if 'azr' in token_info and client_id == token_info['azr']: + if endpoint == "revocation_endpoint": + if "azr" in token_info and client_id == token_info["azr"]: allow = True - elif len(token_info['aud']) == 1 and token_info['aud'] == [client_id]: + elif len(token_info["aud"]) == 1 and token_info["aud"] == [client_id]: allow = True else: # has to be introspection endpoint - if 'azr' in token_info and client_id == token_info['azr']: + if "azr" in token_info and client_id == token_info["azr"]: allow = True - elif 'aud' in token_info: - if client_id in token_info['aud']: + elif "aud" in token_info: + if client_id in token_info["aud"]: allow = True return allow @@ -703,31 +781,34 @@ def get_token_info(self, authn, req, endpoint): client_id = self.client_authn(self, req, authn) except FailedAuthentication as err: logger.error(err) - err = TokenErrorResponse(error="unauthorized_client", - error_description="%s" % err) - return Response(err.to_json(), content="application/json", - status="401 Unauthorized") + err = TokenErrorResponse( + error="unauthorized_client", error_description="%s" % err + ) + return Response( + err.to_json(), content="application/json", status="401 Unauthorized" + ) - logger.debug('{}: {} requesting {}'.format(endpoint, client_id, - req.to_dict())) + logger.debug("{}: {} requesting {}".format(endpoint, client_id, req.to_dict())) try: - token_type = req['token_type_hint'] + token_type = req["token_type_hint"] except KeyError: try: - _info = self.sdb.token_factory['access_token'].get_info(req['token']) + _info = self.sdb.token_factory["access_token"].get_info(req["token"]) except Exception: try: - _info = self.sdb.token_factory['refresh_token'].get_info(req['token']) + _info = self.sdb.token_factory["refresh_token"].get_info( + req["token"] + ) except Exception: return self._return_inactive() else: - token_type = 'refresh_token' + token_type = "refresh_token" else: - token_type = 'access_token' + token_type = "access_token" else: try: - _info = self.sdb.token_factory[token_type].get_info(req['token']) + _info = self.sdb.token_factory[token_type].get_info(req["token"]) except Exception: return self._return_inactive() @@ -737,10 +818,12 @@ def get_token_info(self, authn, req, endpoint): return client_id, token_type, _info def _return_inactive(self): - ir = self.server.message_factory.get_response_type('introspection_endpoint')(active=False) + ir = self.server.message_factory.get_response_type("introspection_endpoint")( + active=False + ) return Response(ir.to_json(), content="application/json") - def revocation_endpoint(self, authn='', request=None, **kwargs): + def revocation_endpoint(self, authn="", request=None, **kwargs): """ Implement RFC7009 allows a client to invalidate an access or refresh token. @@ -749,25 +832,27 @@ def revocation_endpoint(self, authn='', request=None, **kwargs): :param kwargs: :return: """ - trr = self.server.message_factory.get_request_type('revocation_endpoint')().deserialize(request, "urlencoded") + trr = self.server.message_factory.get_request_type( + "revocation_endpoint" + )().deserialize(request, "urlencoded") - resp = self.get_token_info(authn, trr, 'revocation_endpoint') + resp = self.get_token_info(authn, trr, "revocation_endpoint") if isinstance(resp, Response): return resp else: client_id, token_type, _info = resp - logger.info('{} token revocation: {}'.format(client_id, trr.to_dict())) + logger.info("{} token revocation: {}".format(client_id, trr.to_dict())) try: - self.sdb.token_factory[token_type].invalidate(trr['token']) + self.sdb.token_factory[token_type].invalidate(trr["token"]) except KeyError: return BadRequest() else: - return Response('OK') + return Response("OK") - def introspection_endpoint(self, authn='', request=None, **kwargs): + def introspection_endpoint(self, authn="", request=None, **kwargs): """ Implement RFC7662. @@ -776,22 +861,22 @@ def introspection_endpoint(self, authn='', request=None, **kwargs): :param kwargs: :return: """ - tir = self.server.message_factory.get_request_type('introspection_endpoint')().deserialize(request, - "urlencoded") + tir = self.server.message_factory.get_request_type( + "introspection_endpoint" + )().deserialize(request, "urlencoded") - resp = self.get_token_info(authn, tir, 'introspection_endpoint') + resp = self.get_token_info(authn, tir, "introspection_endpoint") if isinstance(resp, Response): return resp else: client_id, token_type, _info = resp - logger.info('{} token introspection: {}'.format(client_id, - tir.to_dict())) + logger.info("{} token introspection: {}".format(client_id, tir.to_dict())) - ir = self.server.message_factory.get_response_type('introspection_endpoint')( - active=self.sdb.token_factory[token_type].is_valid(_info), - **_info.to_dict()) + ir = self.server.message_factory.get_response_type("introspection_endpoint")( + active=self.sdb.token_factory[token_type].is_valid(_info), **_info.to_dict() + ) ir.weed() diff --git a/src/oic/extension/signed_http_req.py b/src/oic/extension/signed_http_req.py index e1ce889d8..b1e4ecc6f 100644 --- a/src/oic/extension/signed_http_req.py +++ b/src/oic/extension/signed_http_req.py @@ -52,8 +52,7 @@ def _serialize_params(params, str_format, hash_size): return [_keys, _hash] -def _verify_params(params, req, str_format, hash_size, strict_verification, - key): +def _verify_params(params, req, str_format, hash_size, strict_verification, key): key_order, req_hash = req if strict_verification and len(key_order) != len(params): @@ -73,18 +72,18 @@ def _upper(s): SIMPLE_OPER = { - "method": ('m', _upper), - "host": ('u', None), - "path": ('p', None), - "time_stamp": ('ts', int), + "method": ("m", _upper), + "host": ("u", None), + "path": ("p", None), + "time_stamp": ("ts", int), } QUERY_PARAM_FORMAT = "{}={}" REQUEST_HEADER_FORMAT = "{}: {}" PARAM_ARGS = { - 'query_params': ('q', QUERY_PARAM_FORMAT), - 'headers': ('h', REQUEST_HEADER_FORMAT) + "query_params": ("q", QUERY_PARAM_FORMAT), + "headers": ("h", REQUEST_HEADER_FORMAT), } @@ -107,13 +106,12 @@ def sign(self, alg, **kwargs): for arg, (key, format) in PARAM_ARGS.items(): try: - http_json[key] = _serialize_params(kwargs[arg], format, - hash_size) + http_json[key] = _serialize_params(kwargs[arg], format, hash_size) except KeyError: pass try: - http_json['b'] = b64_hash(kwargs['body'], hash_size) + http_json["b"] = b64_hash(kwargs["body"], hash_size) except KeyError: pass @@ -141,7 +139,7 @@ def verify(self, signature, **kwargs): hash_size = get_hash_size(_header["alg"]) for arg, (key, func) in SIMPLE_OPER.items(): - if arg == 'time_stamp': + if arg == "time_stamp": continue try: if func is None: @@ -154,26 +152,33 @@ def verify(self, signature, **kwargs): for arg, (key, format) in PARAM_ARGS.items(): try: - _attr = 'strict_{}_verification'.format(arg) + _attr = "strict_{}_verification".format(arg) _strict_verify = kwargs[_attr] except KeyError: _strict_verify = False try: - _verify_params(kwargs[arg], unpacked_req[key], format, - hash_size, _strict_verify, key) + _verify_params( + kwargs[arg], + unpacked_req[key], + format, + hash_size, + _strict_verify, + key, + ) except KeyError: pass - if 'b' not in unpacked_req and 'body' not in kwargs: + if "b" not in unpacked_req and "body" not in kwargs: pass - elif 'b' in unpacked_req and 'body' in kwargs: - _equals(b64_hash(kwargs.get("body", ""), hash_size), - unpacked_req.get("b", "")) + elif "b" in unpacked_req and "body" in kwargs: + _equals( + b64_hash(kwargs.get("body", ""), hash_size), unpacked_req.get("b", "") + ) else: - if 'b' in unpacked_req: - raise ValidationError('Body sent but not received!!') + if "b" in unpacked_req: + raise ValidationError("Body sent but not received!!") else: - raise ValidationError('Body received but not sent!!') + raise ValidationError("Body received but not sent!!") return unpacked_req diff --git a/src/oic/extension/sts.py b/src/oic/extension/sts.py index 6833c4c10..160cccdb7 100644 --- a/src/oic/extension/sts.py +++ b/src/oic/extension/sts.py @@ -18,26 +18,26 @@ from oic.oic.message import SINGLE_REQUIRED_INT from oic.oic.message import msg_ser -__author__ = 'roland' +__author__ = "roland" class TokenExchangeRequest(Message): c_param = { - 'grant_type': SINGLE_REQUIRED_STRING, - 'resource': SINGLE_OPTIONAL_STRING, - 'audience': SINGLE_OPTIONAL_STRING, - 'scope': OPTIONAL_LIST_OF_SP_SEP_STRINGS, - 'requested_token_type': SINGLE_OPTIONAL_STRING, - 'subject_token': SINGLE_REQUIRED_STRING, - 'subject_token_type': SINGLE_REQUIRED_STRING, - 'actor_token': SINGLE_OPTIONAL_STRING, - 'actor_token_type': SINGLE_OPTIONAL_STRING, - 'want_composite': SINGLE_OPTIONAL_STRING + "grant_type": SINGLE_REQUIRED_STRING, + "resource": SINGLE_OPTIONAL_STRING, + "audience": SINGLE_OPTIONAL_STRING, + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, + "requested_token_type": SINGLE_OPTIONAL_STRING, + "subject_token": SINGLE_REQUIRED_STRING, + "subject_token_type": SINGLE_REQUIRED_STRING, + "actor_token": SINGLE_OPTIONAL_STRING, + "actor_token_type": SINGLE_OPTIONAL_STRING, + "want_composite": SINGLE_OPTIONAL_STRING, } def verify(self, **kwargs): - if 'actor_token' in self: - if not 'actor_token_type': + if "actor_token" in self: + if not "actor_token_type": return False @@ -48,7 +48,7 @@ class TokenExchangeResponse(Message): "token_type": SINGLE_REQUIRED_STRING, "expires_in": SINGLE_OPTIONAL_INT, "refresh_token": SINGLE_OPTIONAL_STRING, - "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, } @@ -72,6 +72,6 @@ class STS(Message): "exp": SINGLE_REQUIRED_INT, "nbf": SINGLE_REQUIRED_INT, "sub": SINGLE_REQUIRED_STRING, - 'act': SINGLE_OPTIONAL_STS, - 'scp': OPTIONAL_LIST_OF_STRINGS + "act": SINGLE_OPTIONAL_STS, + "scp": OPTIONAL_LIST_OF_STRINGS, } diff --git a/src/oic/extension/token.py b/src/oic/extension/token.py index 93084dd13..4b22b34ca 100644 --- a/src/oic/extension/token.py +++ b/src/oic/extension/token.py @@ -9,7 +9,7 @@ from oic.utils.sdb import Token from oic.utils.time_util import utc_time_sans_frac -__author__ = 'roland' +__author__ = "roland" class TokenAssertion(Message): @@ -17,25 +17,24 @@ class TokenAssertion(Message): "iss": SINGLE_REQUIRED_STRING, "azp": SINGLE_REQUIRED_STRING, "sub": SINGLE_REQUIRED_STRING, - 'kid': SINGLE_REQUIRED_STRING, + "kid": SINGLE_REQUIRED_STRING, "exp": SINGLE_REQUIRED_INT, - 'jti': SINGLE_REQUIRED_STRING, + "jti": SINGLE_REQUIRED_STRING, "aud": OPTIONAL_LIST_OF_STRINGS, # Array of strings or string } class JWTToken(Token, JWT): - usage = 'authorization_grant' + usage = "authorization_grant" - def __init__(self, typ, keyjar, lt_pattern=None, extra_claims=None, - **kwargs): + def __init__(self, typ, keyjar, lt_pattern=None, extra_claims=None, **kwargs): self.type = typ JWT.__init__(self, keyjar, msgtype=TokenAssertion, **kwargs) Token.__init__(self, typ, **kwargs) self.lt_pattern = lt_pattern or {} self.db = {} - self.session_info = {'': 600} - self.exp_args = ['sinfo'] + self.session_info = {"": 600} + self.exp_args = ["sinfo"] self.extra_claims = extra_claims or {} def __call__(self, sid, *args, **kwargs): @@ -45,80 +44,80 @@ def __call__(self, sid, *args, **kwargs): :return: """ try: - _sinfo = kwargs['sinfo'] + _sinfo = kwargs["sinfo"] except KeyError: exp = self.do_exp(**kwargs) - _tid = kwargs['target_id'] + _tid = kwargs["target_id"] else: - if 'lifetime' in kwargs: - _sinfo['lifetime'] = kwargs['lifetime'] + if "lifetime" in kwargs: + _sinfo["lifetime"] = kwargs["lifetime"] exp = self.do_exp(**_sinfo) - _tid = _sinfo['client_id'] - if 'scope' not in kwargs: + _tid = _sinfo["client_id"] + if "scope" not in kwargs: _scope = None try: - _scope = _sinfo['scope'] + _scope = _sinfo["scope"] except KeyError: - ar = json.loads(_sinfo['authzreq']) + ar = json.loads(_sinfo["authzreq"]) try: - _scope = ar['scope'] + _scope = ar["scope"] except KeyError: pass if _scope: - kwargs['scope'] = ' ' .join(_scope) + kwargs["scope"] = " ".join(_scope) - if self.usage == 'authorization_grant': + if self.usage == "authorization_grant": try: - kwargs['sub'] = _sinfo['sub'] + kwargs["sub"] = _sinfo["sub"] except KeyError: pass - del kwargs['sinfo'] + del kwargs["sinfo"] - if 'aud' in kwargs: - if _tid not in kwargs['aud']: - kwargs['aud'].append(_tid) + if "aud" in kwargs: + if _tid not in kwargs["aud"]: + kwargs["aud"].append(_tid) else: - kwargs['aud'] = [_tid] + kwargs["aud"] = [_tid] - if self.usage == 'client_authentication': + if self.usage == "client_authentication": try: - kwargs['sub'] = _tid + kwargs["sub"] = _tid except KeyError: pass else: - if 'azp' not in kwargs: - kwargs['azp'] = _tid + if "azp" not in kwargs: + kwargs["azp"] = _tid - for param in ['lifetime', 'grant_type', 'response_type', 'target_id']: + for param in ["lifetime", "grant_type", "response_type", "target_id"]: try: del kwargs[param] except KeyError: pass try: - kwargs['kid'] = self.extra_claims['kid'] + kwargs["kid"] = self.extra_claims["kid"] except KeyError: pass - _jti = '{}-{}'.format(self.type, uuid.uuid4().hex) + _jti = "{}-{}".format(self.type, uuid.uuid4().hex) _jwt = self.pack(jti=_jti, exp=exp, **kwargs) self.db[_jti] = sid return _jwt def do_exp(self, **kwargs): try: - lifetime = kwargs['lifetime'] + lifetime = kwargs["lifetime"] except KeyError: try: - rt = ' '.join(kwargs['response_type']) + rt = " ".join(kwargs["response_type"]) except KeyError: - rt = ' '.join(kwargs['grant_type']) + rt = " ".join(kwargs["grant_type"]) try: lifetime = self.lt_pattern[rt] except KeyError: - lifetime = self.lt_pattern[''] + lifetime = self.lt_pattern[""] return utc_time_sans_frac() + lifetime @@ -130,7 +129,7 @@ def type_and_key(self, token): :return: tuple of token type and session id """ msg = self.unpack(token) - return self.type, self.db[msg['jti']] + return self.type, self.db[msg["jti"]] def get_key(self, token): """ @@ -140,7 +139,7 @@ def get_key(self, token): :return: The session id """ msg = self.unpack(token) - return self.db[msg['jti']] + return self.db[msg["jti"]] def get_type(self, token): """ @@ -155,22 +154,22 @@ def get_type(self, token): def invalidate(self, token): info = self.unpack(token) try: - del self.db[info['jti']] + del self.db[info["jti"]] except KeyError: return False return True def is_valid(self, info): - if info['jti'] in self.db: - if info['exp'] >= utc_time_sans_frac(): + if info["jti"] in self.db: + if info["exp"] >= utc_time_sans_frac(): return True return False def expires_at(self, token): info = self.unpack(token) - return info['exp'] + return info["exp"] def valid(self, token): info = self.unpack(token) @@ -181,8 +180,8 @@ def get_info(self, token): class Authorization_Grant(JWTToken): - usage = 'authorization_grant' + usage = "authorization_grant" class Client_Authentication(JWTToken): - usage = 'client_authentication' + usage = "client_authentication" diff --git a/src/oic/oauth2/__init__.py b/src/oic/oauth2/__init__.py index be06c6ea2..27fe711d6 100644 --- a/src/oic/oauth2/__init__.py +++ b/src/oic/oauth2/__init__.py @@ -52,7 +52,7 @@ from oic.utils.keyio import KeyJar from oic.utils.time_util import utc_time_sans_frac -__author__ = 'rohe0002' +__author__ = "rohe0002" logger = logging.getLogger(__name__) @@ -69,17 +69,17 @@ # ROPCAccessTokenRequest: "authorization_endpoint", # CCAccessTokenRequest: "authorization_endpoint", "RefreshAccessTokenRequest": "token_endpoint", - "TokenRevocationRequest": "token_endpoint"} + "TokenRevocationRequest": "token_endpoint", +} RESPONSE2ERROR = { "AuthorizationResponse": [AuthorizationErrorResponse, TokenErrorResponse], - "AccessTokenResponse": [TokenErrorResponse] + "AccessTokenResponse": [TokenErrorResponse], } # type: Dict[str, List] -ENDPOINTS = ["authorization_endpoint", "token_endpoint", - "token_revocation_endpoint"] +ENDPOINTS = ["authorization_endpoint", "token_endpoint", "token_revocation_endpoint"] -ENCODINGS = Literal['json', 'urlencoded'] +ENCODINGS = Literal["json", "urlencoded"] class ExpiredToken(PyoidcError): @@ -88,10 +88,13 @@ class ExpiredToken(PyoidcError): # ============================================================================= + def error_response(error, descr=None, status_code=400): logger.error("%s" % sanitize(error)) response = ErrorResponse(error=error, error_description=descr) - return Response(response.to_json(), content="application/json", status_code=status_code) + return Response( + response.to_json(), content="application/json", status_code=status_code + ) def none_response(**kwargs): @@ -111,8 +114,7 @@ def authz_error(error, descr=None): return Response(response.to_json(), content="application/json", status_code=400) -def redirect_authz_error(error, redirect_uri, descr=None, state="", - return_type=None): +def redirect_authz_error(error, redirect_uri, descr=None, state="", return_type=None): err = AuthorizationErrorResponse(error=error) if descr: err["error_description"] = descr @@ -135,10 +137,11 @@ def exception_to_error_mesg(excep): else: resp = BadRequest() else: - err = ErrorResponse(error='service_error', - error_description='{}:{}'.format( - excep.__class__.__name__, excep.args)) - resp = BadRequest(err.to_json(), content='application/json') + err = ErrorResponse( + error="service_error", + error_description="{}:{}".format(excep.__class__.__name__, excep.args), + ) + resp = BadRequest(err.to_json(), content="application/json") return resp @@ -151,15 +154,24 @@ def compact(qsdict): res[key] = val return res + # ============================================================================= class Client(PBase): _endpoints = ENDPOINTS - def __init__(self, client_id=None, client_authn_method=None, - keyjar=None, verify_ssl=True, config=None, client_cert=None, - timeout=5, message_factory: Type[MessageFactory] = OauthMessageFactory): + def __init__( + self, + client_id=None, + client_authn_method=None, + keyjar=None, + verify_ssl=True, + config=None, + client_cert=None, + timeout=5, + message_factory: Type[MessageFactory] = OauthMessageFactory, + ): """ Initialize the instance. @@ -176,8 +188,13 @@ def __init__(self, client_id=None, client_authn_method=None, :param: message_factory: Factory for message classes, should inherit from OauthMessageFactory :return: Client instance """ - PBase.__init__(self, verify_ssl=verify_ssl, keyjar=keyjar, - client_cert=client_cert, timeout=timeout) + PBase.__init__( + self, + verify_ssl=verify_ssl, + keyjar=keyjar, + client_cert=client_cert, + timeout=timeout, + ) self.client_id = client_id self.client_authn_method = client_authn_method @@ -201,7 +218,7 @@ def __init__(self, client_id=None, client_authn_method=None, self.token_class = Token self.provider_info = ASConfigurationResponse() # type: Message - self._c_secret = '' # type: str + self._c_secret = "" # type: str self.kid = {"sig": {}, "enc": {}} # type: Dict[str, Dict] self.authz_req = {} # type: Dict[str, Message] @@ -209,9 +226,9 @@ def __init__(self, client_id=None, client_authn_method=None, # configuration information location self.config = config or {} try: - self.issuer = self.config['issuer'] + self.issuer = self.config["issuer"] except KeyError: - self.issuer = '' + self.issuer = "" self.allow = {} # type: Dict[str, Any] def store_response(self, clinst, text): @@ -327,7 +344,9 @@ def clean_tokens(self) -> None: if token.replaced or not token.is_valid(): grant.delete_token(token) - def construct_request(self, request: Type[Message], request_args=None, extra_args=None): + def construct_request( + self, request: Type[Message], request_args=None, extra_args=None + ): if request_args is None: request_args = {} @@ -338,17 +357,26 @@ def construct_request(self, request: Type[Message], request_args=None, extra_arg logger.debug("request: %s" % sanitize(request)) return request(**kwargs) - def construct_Message(self, request: Type[Message] = Message, request_args=None, - extra_args=None, **kwargs) -> Message: + def construct_Message( + self, + request: Type[Message] = Message, + request_args=None, + extra_args=None, + **kwargs + ) -> Message: return self.construct_request(request, request_args, extra_args) - def construct_AuthorizationRequest(self, request: Type[AuthorizationRequest] = None, - request_args=None, extra_args=None, - **kwargs) -> AuthorizationRequest: + def construct_AuthorizationRequest( + self, + request: Type[AuthorizationRequest] = None, + request_args=None, + extra_args=None, + **kwargs + ) -> AuthorizationRequest: if request is None: - request = self.message_factory.get_request_type('authorization_endpoint') + request = self.message_factory.get_request_type("authorization_endpoint") if request_args is not None: try: # change default new = request_args["redirect_uri"] @@ -366,27 +394,31 @@ def construct_AuthorizationRequest(self, request: Type[AuthorizationRequest] = N return self.construct_request(request, request_args, extra_args) - def construct_AccessTokenRequest(self, - request: Type[AccessTokenRequest] = None, - request_args=None, extra_args=None, - **kwargs) -> AccessTokenRequest: + def construct_AccessTokenRequest( + self, + request: Type[AccessTokenRequest] = None, + request_args=None, + extra_args=None, + **kwargs + ) -> AccessTokenRequest: if request is None: - request = self.message_factory.get_request_type('token_endpoint') + request = self.message_factory.get_request_type("token_endpoint") if request_args is None: request_args = {} if request is not ROPCAccessTokenRequest: grant = self.get_grant(**kwargs) if not grant.is_valid(): - raise GrantExpired("Authorization Code to old %s > %s" % ( - utc_time_sans_frac(), - grant.grant_expiration_time)) + raise GrantExpired( + "Authorization Code to old %s > %s" + % (utc_time_sans_frac(), grant.grant_expiration_time) + ) request_args["code"] = grant.code try: - request_args['state'] = kwargs['state'] + request_args["state"] = kwargs["state"] except KeyError: pass @@ -399,13 +431,16 @@ def construct_AccessTokenRequest(self, request_args["client_id"] = self.client_id return self.construct_request(request, request_args, extra_args) - def construct_RefreshAccessTokenRequest(self, - request: Type[RefreshAccessTokenRequest] = None, - request_args=None, extra_args=None, - **kwargs) -> RefreshAccessTokenRequest: + def construct_RefreshAccessTokenRequest( + self, + request: Type[RefreshAccessTokenRequest] = None, + request_args=None, + extra_args=None, + **kwargs + ) -> RefreshAccessTokenRequest: if request is None: - request = self.message_factory.get_request_type('refresh_endpoint') + request = self.message_factory.get_request_type("refresh_endpoint") if request_args is None: request_args = {} @@ -420,12 +455,16 @@ def construct_RefreshAccessTokenRequest(self, return self.construct_request(request, request_args, extra_args) - def construct_ResourceRequest(self, request: Type[ResourceRequest] = None, - request_args=None, extra_args=None, - **kwargs) -> ResourceRequest: + def construct_ResourceRequest( + self, + request: Type[ResourceRequest] = None, + request_args=None, + extra_args=None, + **kwargs + ) -> ResourceRequest: if request is None: - request = self.message_factory.get_request_type('resource_endpoint') + request = self.message_factory.get_request_type("resource_endpoint") if request_args is None: request_args = {} @@ -434,8 +473,14 @@ def construct_ResourceRequest(self, request: Type[ResourceRequest] = None, request_args["access_token"] = token.access_token return self.construct_request(request, request_args, extra_args) - def uri_and_body(self, reqmsg: Type[Message], cis: Message, method="POST", request_args=None, - **kwargs) -> Tuple[str, str, Dict, Message]: + def uri_and_body( + self, + reqmsg: Type[Message], + cis: Message, + method="POST", + request_args=None, + **kwargs + ) -> Tuple[str, str, Dict, Message]: if "endpoint" in kwargs and kwargs["endpoint"]: uri = kwargs["endpoint"] else: @@ -449,31 +494,37 @@ def uri_and_body(self, reqmsg: Type[Message], cis: Message, method="POST", reque return uri, body, h_args, cis - def request_info(self, request: Type[Message], method="POST", request_args=None, - extra_args=None, lax=False, **kwargs) -> Tuple[str, str, Dict, Message]: + def request_info( + self, + request: Type[Message], + method="POST", + request_args=None, + extra_args=None, + lax=False, + **kwargs + ) -> Tuple[str, str, Dict, Message]: if request_args is None: request_args = {} try: cls = getattr(self, "construct_%s" % request.__name__) - cis = cls(request_args=request_args, extra_args=extra_args, - **kwargs) + cis = cls(request_args=request_args, extra_args=extra_args, **kwargs) except AttributeError: cis = self.construct_request(request, request_args, extra_args) if self.events: - self.events.store('Protocol request', cis) + self.events.store("Protocol request", cis) - if 'nonce' in cis and 'state' in cis: - self.state2nonce[cis['state']] = cis['nonce'] + if "nonce" in cis and "state" in cis: + self.state2nonce[cis["state"]] = cis["nonce"] cis.lax = lax if "authn_method" in kwargs: - h_arg = self.init_authentication_method(cis, - request_args=request_args, - **kwargs) + h_arg = self.init_authentication_method( + cis, request_args=request_args, **kwargs + ) else: h_arg = None @@ -483,17 +534,20 @@ def request_info(self, request: Type[Message], method="POST", request_args=None, else: kwargs["headers"] = h_arg["headers"] - return self.uri_and_body(request, cis, method, request_args, - **kwargs) + return self.uri_and_body(request, cis, method, request_args, **kwargs) - def authorization_request_info(self, request_args=None, extra_args=None, - **kwargs): - return self.request_info(self.message_factory.get_request_type('authorization_endpoint'), "GET", - request_args, extra_args, **kwargs) + def authorization_request_info(self, request_args=None, extra_args=None, **kwargs): + return self.request_info( + self.message_factory.get_request_type("authorization_endpoint"), + "GET", + request_args, + extra_args, + **kwargs + ) @staticmethod def get_urlinfo(info: str) -> str: - if '?' in info or '#' in info: + if "?" in info or "#" in info: parts = urlparse(info) scheme, netloc, path, params, query, fragment = parts[:6] # either query of fragment @@ -503,8 +557,14 @@ def get_urlinfo(info: str) -> str: info = fragment return info - def parse_response(self, response: Type[Message], info: str = "", sformat: ENCODINGS = "json", state: str = "", - **kwargs) -> Message: + def parse_response( + self, + response: Type[Message], + info: str = "", + sformat: ENCODINGS = "json", + state: str = "", + **kwargs + ) -> Message: """ Parse a response. @@ -525,7 +585,7 @@ def parse_response(self, response: Type[Message], info: str = "", sformat: ENCOD msg = 'Initial response parsing => "{}"' logger.debug(msg.format(sanitize(resp.to_dict()))) if self.events: - self.events.store('Response', resp.to_dict()) + self.events.store("Response", resp.to_dict()) if "error" in resp and not isinstance(resp, ErrorResponse): resp = None @@ -550,10 +610,10 @@ def parse_response(self, response: Type[Message], info: str = "", sformat: ENCOD else: kwargs["client_id"] = self.client_id try: - kwargs['iss'] = self.provider_info['issuer'] + kwargs["iss"] = self.provider_info["issuer"] except (KeyError, AttributeError): if self.issuer: - kwargs['iss'] = self.issuer + kwargs["iss"] = self.issuer if "key" not in kwargs and "keyjar" not in kwargs: kwargs["keyjar"] = self.keyjar @@ -562,7 +622,7 @@ def parse_response(self, response: Type[Message], info: str = "", sformat: ENCOD verf = resp.verify(**kwargs) if not verf: - logger.error('Verification of the response failed') + logger.error("Verification of the response failed") raise PyoidcError("Verification of the response failed") if resp.type() == "AuthorizationResponse" and "scope" not in resp: try: @@ -571,7 +631,7 @@ def parse_response(self, response: Type[Message], info: str = "", sformat: ENCOD pass if not resp: - logger.error('Missing or faulty response') + logger.error("Missing or faulty response") raise ResponseError("Missing or faulty response") self.store_response(resp, info) @@ -592,8 +652,9 @@ def parse_response(self, response: Type[Message], info: str = "", sformat: ENCOD return resp - def init_authentication_method(self, cis, authn_method, request_args=None, - http_args=None, **kwargs): + def init_authentication_method( + self, cis, authn_method, request_args=None, http_args=None, **kwargs + ): if http_args is None: http_args = {} @@ -602,45 +663,45 @@ def init_authentication_method(self, cis, authn_method, request_args=None, if authn_method: return self.client_authn_method[authn_method](self).construct( - cis, request_args, http_args, **kwargs) + cis, request_args, http_args, **kwargs + ) else: return http_args - def parse_request_response(self, reqresp, response, body_type, state="", - **kwargs): + def parse_request_response(self, reqresp, response, body_type, state="", **kwargs): if reqresp.status_code in SUCCESSFUL: body_type = verify_header(reqresp, body_type) elif reqresp.status_code in [302, 303]: # redirect return reqresp elif reqresp.status_code == 500: - logger.error("(%d) %s" % (reqresp.status_code, - sanitize(reqresp.text))) + logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text))) raise ParseError("ERROR: Something went wrong: %s" % reqresp.text) elif reqresp.status_code in [400, 401]: # expecting an error response if issubclass(response, ErrorResponse): pass else: - logger.error("(%d) %s" % (reqresp.status_code, - sanitize(reqresp.text))) - raise HttpError("HTTP ERROR: %s [%s] on %s" % ( - reqresp.text, reqresp.status_code, reqresp.url)) + logger.error("(%d) %s" % (reqresp.status_code, sanitize(reqresp.text))) + raise HttpError( + "HTTP ERROR: %s [%s] on %s" + % (reqresp.text, reqresp.status_code, reqresp.url) + ) if response: - if body_type == 'txt': + if body_type == "txt": # no meaning trying to parse unstructured text return reqresp.text - return self.parse_response(response, reqresp.text, body_type, - state, **kwargs) + return self.parse_response( + response, reqresp.text, body_type, state, **kwargs + ) # could be an error response if reqresp.status_code in [200, 400, 401]: - if body_type == 'txt': - body_type = 'urlencoded' + if body_type == "txt": + body_type = "urlencoded" try: - err = ErrorResponse().deserialize(reqresp.message, - method=body_type) + err = ErrorResponse().deserialize(reqresp.message, method=body_type) try: err.verify() except PyoidcError: @@ -652,9 +713,17 @@ def parse_request_response(self, reqresp, response, body_type, state="", return reqresp - def request_and_return(self, url: str, response: Type[Message] = None, method="GET", body=None, - body_type: ENCODINGS = "json", state: str = "", http_args=None, - **kwargs): + def request_and_return( + self, + url: str, + response: Type[Message] = None, + method="GET", + body=None, + body_type: ENCODINGS = "json", + state: str = "", + http_args=None, + **kwargs + ): """ Perform a request and return the response. @@ -678,27 +747,40 @@ def request_and_return(self, url: str, response: Type[Message] = None, method="G if "keyjar" not in kwargs: kwargs["keyjar"] = self.keyjar - return self.parse_request_response(resp, response, body_type, state, - **kwargs) - - def do_authorization_request(self, request=None, - state="", body_type="", method="GET", - request_args=None, extra_args=None, - http_args=None, - response_cls=None, - **kwargs) -> AuthorizationResponse: + return self.parse_request_response(resp, response, body_type, state, **kwargs) + + def do_authorization_request( + self, + request=None, + state="", + body_type="", + method="GET", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ) -> AuthorizationResponse: if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: # TODO: This can be moved to the call once we remove the kwarg - request = self.message_factory.get_request_type('authorization_endpoint') + request = self.message_factory.get_request_type("authorization_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: # TODO: This can be moved to the call once we remove the kwarg - response_cls = self.message_factory.get_response_type('authorization_endpoint') + response_cls = self.message_factory.get_response_type( + "authorization_endpoint" + ) if state: try: @@ -706,10 +788,10 @@ def do_authorization_request(self, request=None, except TypeError: request_args = {"state": state} - kwargs['authn_endpoint'] = 'authorization' - url, body, ht_args, csi = self.request_info(request, method, - request_args, extra_args, - **kwargs) + kwargs["authn_endpoint"] = "authorization" + url, body, ht_args, csi = self.request_info( + request, method, request_args, extra_args, **kwargs + ) try: self.authz_req[request_args["state"]] = csi @@ -726,9 +808,16 @@ def do_authorization_request(self, request=None, except KeyError: algs = {} - resp = self.request_and_return(url, response_cls, method, body, - body_type, state=state, - http_args=http_args, algs=algs) + resp = self.request_and_return( + url, + response_cls, + method, + body, + body_type, + state=state, + http_args=http_args, + algs=algs, + ) if isinstance(resp, Message): # FIXME: The Message classes do not have classical attrs @@ -737,33 +826,51 @@ def do_authorization_request(self, request=None, return resp - def do_access_token_request(self, request=None, - scope: str = "", state: str = "", body_type: ENCODINGS = "json", - method="POST", request_args=None, - extra_args=None, http_args=None, - response_cls=None, - authn_method="", **kwargs) -> AccessTokenResponse: + def do_access_token_request( + self, + request=None, + scope: str = "", + state: str = "", + body_type: ENCODINGS = "json", + method="POST", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + authn_method="", + **kwargs + ) -> AccessTokenResponse: if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: # TODO: This can be moved to the call once we remove the kwarg - request = self.message_factory.get_request_type('token_endpoint') + request = self.message_factory.get_request_type("token_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: # TODO: This can be moved to the call once we remove the kwarg - response_cls = self.message_factory.get_response_type('token_endpoint') + response_cls = self.message_factory.get_response_type("token_endpoint") - kwargs['authn_endpoint'] = 'token' + kwargs["authn_endpoint"] = "token" # method is default POST - url, body, ht_args, csi = self.request_info(request, method=method, - request_args=request_args, - extra_args=extra_args, - scope=scope, state=state, - authn_method=authn_method, - **kwargs) + url, body, ht_args, csi = self.request_info( + request, + method=method, + request_args=request_args, + extra_args=extra_args, + scope=scope, + state=state, + authn_method=authn_method, + **kwargs + ) if http_args is None: http_args = ht_args @@ -771,79 +878,118 @@ def do_access_token_request(self, request=None, http_args.update(ht_args) if self.events is not None: - self.events.store('request_url', url) - self.events.store('request_http_args', http_args) - self.events.store('Request', body) + self.events.store("request_url", url) + self.events.store("request_http_args", http_args) + self.events.store("Request", body) - logger.debug(" URL: %s, Body: %s" % (url, - sanitize(body))) + logger.debug(" URL: %s, Body: %s" % (url, sanitize(body))) logger.debug(" response_cls: %s" % response_cls) - return self.request_and_return(url, response_cls, method, body, - body_type, state=state, - http_args=http_args, **kwargs) - - def do_access_token_refresh(self, request=None, - state: str = "", body_type: ENCODINGS = "json", method="POST", - request_args=None, extra_args=None, - http_args=None, - response_cls=None, - authn_method="", **kwargs) -> AccessTokenResponse: + return self.request_and_return( + url, + response_cls, + method, + body, + body_type, + state=state, + http_args=http_args, + **kwargs + ) + + def do_access_token_refresh( + self, + request=None, + state: str = "", + body_type: ENCODINGS = "json", + method="POST", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + authn_method="", + **kwargs + ) -> AccessTokenResponse: if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: # TODO: This can be moved to the call once we remove the kwarg - request = self.message_factory.get_request_type('refresh_endpoint') + request = self.message_factory.get_request_type("refresh_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: # TODO: This can be moved to the call once we remove the kwarg - response_cls = self.message_factory.get_response_type('refresh_endpoint') + response_cls = self.message_factory.get_response_type("refresh_endpoint") token = self.get_token(also_expired=True, state=state, **kwargs) - kwargs['authn_endpoint'] = 'refresh' - url, body, ht_args, csi = self.request_info(request, method=method, - request_args=request_args, - extra_args=extra_args, - token=token, - authn_method=authn_method) + kwargs["authn_endpoint"] = "refresh" + url, body, ht_args, csi = self.request_info( + request, + method=method, + request_args=request_args, + extra_args=extra_args, + token=token, + authn_method=authn_method, + ) if http_args is None: http_args = ht_args else: http_args.update(ht_args) - response = self.request_and_return(url, response_cls, method, body, - body_type, state=state, - http_args=http_args) + response = self.request_and_return( + url, response_cls, method, body, body_type, state=state, http_args=http_args + ) if token.replaced: grant = self.get_grant(state) grant.delete_token(token) return response - def do_any(self, request: Type[Message], endpoint="", scope="", state="", body_type="json", - method="POST", request_args=None, extra_args=None, - http_args=None, response: Type[Message] = None, authn_method="") -> Message: - - url, body, ht_args, _ = self.request_info(request, method=method, - request_args=request_args, - extra_args=extra_args, - scope=scope, state=state, - authn_method=authn_method, - endpoint=endpoint) + def do_any( + self, + request: Type[Message], + endpoint="", + scope="", + state="", + body_type="json", + method="POST", + request_args=None, + extra_args=None, + http_args=None, + response: Type[Message] = None, + authn_method="", + ) -> Message: + + url, body, ht_args, _ = self.request_info( + request, + method=method, + request_args=request_args, + extra_args=extra_args, + scope=scope, + state=state, + authn_method=authn_method, + endpoint=endpoint, + ) if http_args is None: http_args = ht_args else: http_args.update(ht_args) - return self.request_and_return(url, response, method, body, body_type, - state=state, http_args=http_args) + return self.request_and_return( + url, response, method, body, body_type, state=state, http_args=http_args + ) - def fetch_protected_resource(self, uri, method="GET", headers=None, - state="", **kwargs): + def fetch_protected_resource( + self, uri, method="GET", headers=None, state="", **kwargs + ): if "token" in kwargs and kwargs["token"]: token = kwargs["token"] @@ -862,11 +1008,13 @@ def fetch_protected_resource(self, uri, method="GET", headers=None, if "authn_method" in kwargs: http_args = self.init_authentication_method( - request_args=request_args, **kwargs) + request_args=request_args, **kwargs + ) else: # If nothing defined this is the default - http_args = self.client_authn_method[ - "bearer_header"](self).construct(request_args=request_args) + http_args = self.client_authn_method["bearer_header"](self).construct( + request_args=request_args + ) headers.update(http_args["headers"]) @@ -880,32 +1028,38 @@ def add_code_challenge(self): :return: """ try: - cv_len = self.config['code_challenge']['length'] + cv_len = self.config["code_challenge"]["length"] except KeyError: cv_len = 64 # Use default code_verifier = unreserved(cv_len) - _cv = code_verifier.encode('ascii') + _cv = code_verifier.encode("ascii") try: - _method = self.config['code_challenge']['method'] + _method = self.config["code_challenge"]["method"] except KeyError: - _method = 'S256' + _method = "S256" try: _h = CC_METHOD[_method](_cv).digest() - code_challenge = b64e(_h).decode('ascii') + code_challenge = b64e(_h).decode("ascii") except KeyError: - raise Unsupported( - 'PKCE Transformation method:{}'.format(_method)) + raise Unsupported("PKCE Transformation method:{}".format(_method)) # TODO store code_verifier - return {"code_challenge": code_challenge, - "code_challenge_method": _method}, code_verifier - - def handle_provider_config(self, pcr: ASConfigurationResponse, issuer: str, keys: bool = True, - endpoints: bool = True) -> None: + return ( + {"code_challenge": code_challenge, "code_challenge_method": _method}, + code_verifier, + ) + + def handle_provider_config( + self, + pcr: ASConfigurationResponse, + issuer: str, + keys: bool = True, + endpoints: bool = True, + ) -> None: """ Deal with Provider Config Response. @@ -928,7 +1082,10 @@ def handle_provider_config(self, pcr: ASConfigurationResponse, issuer: str, keys _issuer = issuer if not self.allow.get("issuer_mismatch", False) and _issuer != _pcr_issuer: - raise PyoidcError("provider info issuer mismatch '%s' != '%s'" % (_issuer, _pcr_issuer)) + raise PyoidcError( + "provider info issuer mismatch '%s' != '%s'" + % (_issuer, _pcr_issuer) + ) self.provider_info = pcr else: @@ -947,14 +1104,24 @@ def handle_provider_config(self, pcr: ASConfigurationResponse, issuer: str, keys self.keyjar.load_keys(pcr, _pcr_issuer) - def provider_config(self, issuer: str, keys: bool = True, endpoints: bool = True, - response_cls: Type[ASConfigurationResponse] = None, - serv_pattern: str = OIDCONF_PATTERN) -> ASConfigurationResponse: + def provider_config( + self, + issuer: str, + keys: bool = True, + endpoints: bool = True, + response_cls: Type[ASConfigurationResponse] = None, + serv_pattern: str = OIDCONF_PATTERN, + ) -> ASConfigurationResponse: if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('configuration_endpoint') + response_cls = self.message_factory.get_response_type( + "configuration_endpoint" + ) if issuer.endswith("/"): _issuer = issuer[:-1] else: @@ -983,10 +1150,21 @@ def provider_config(self, issuer: str, keys: bool = True, endpoints: bool = True class Server(PBase): """OAuth Server class.""" - def __init__(self, verify_ssl: bool = True, keyjar: KeyJar = None, client_cert: Union[str, Tuple[str, str]] = None, - timeout: int = 5, message_factory: Type[MessageFactory] = OauthMessageFactory): + def __init__( + self, + verify_ssl: bool = True, + keyjar: KeyJar = None, + client_cert: Union[str, Tuple[str, str]] = None, + timeout: int = 5, + message_factory: Type[MessageFactory] = OauthMessageFactory, + ): """Initialize the server.""" - super().__init__(verify_ssl=verify_ssl, keyjar=keyjar, client_cert=client_cert, timeout=timeout) + super().__init__( + verify_ssl=verify_ssl, + keyjar=keyjar, + client_cert=client_cert, + timeout=timeout, + ) self.message_factory = message_factory @staticmethod @@ -1002,47 +1180,68 @@ def parse_url_request(request, url=None, query=None): req.verify() return req - def parse_authorization_request(self, request: Type[AuthorizationRequest] = None, - url: str = None, query: dict = None) -> AuthorizationRequest: + def parse_authorization_request( + self, + request: Type[AuthorizationRequest] = None, + url: str = None, + query: dict = None, + ) -> AuthorizationRequest: if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('authorization_endpoint') + request = self.message_factory.get_request_type("authorization_endpoint") return self.parse_url_request(request, url, query) - def parse_jwt_request(self, request: Type[Message] = AuthorizationRequest, txt: str = "", - keyjar: KeyJar = None, verify: bool = True, **kwargs) -> Message: + def parse_jwt_request( + self, + request: Type[Message] = AuthorizationRequest, + txt: str = "", + keyjar: KeyJar = None, + verify: bool = True, + **kwargs + ) -> Message: if not keyjar: keyjar = self.keyjar - areq = request().deserialize(txt, "jwt", keyjar=keyjar, - verify=verify, **kwargs) + areq = request().deserialize(txt, "jwt", keyjar=keyjar, verify=verify, **kwargs) if verify: areq.verify() return areq - def parse_body_request(self, request: Type[Message] = AccessTokenRequest, - body: str = None): + def parse_body_request( + self, request: Type[Message] = AccessTokenRequest, body: str = None + ): req = request().deserialize(body, "urlencoded") req.verify() return req - def parse_token_request(self, request: Type[AccessTokenRequest] = None, - body: str = None) -> AccessTokenRequest: + def parse_token_request( + self, request: Type[AccessTokenRequest] = None, body: str = None + ) -> AccessTokenRequest: if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('token_endpoint') + request = self.message_factory.get_request_type("token_endpoint") return self.parse_body_request(request, body) - def parse_refresh_token_request(self, request: Type[RefreshAccessTokenRequest] = None, - body: str = None) -> RefreshAccessTokenRequest: + def parse_refresh_token_request( + self, request: Type[RefreshAccessTokenRequest] = None, body: str = None + ) -> RefreshAccessTokenRequest: if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory`.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory`.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('refresh_endpoint') + request = self.message_factory.get_request_type("refresh_endpoint") return self.parse_body_request(request, body) diff --git a/src/oic/oauth2/base.py b/src/oic/oauth2/base.py index d91034b41..fb43ef2ec 100644 --- a/src/oic/oauth2/base.py +++ b/src/oic/oauth2/base.py @@ -11,7 +11,7 @@ from oic.utils.keyio import KeyJar from oic.utils.sanitize import sanitize -__author__ = 'roland' +__author__ = "roland" logger = logging.getLogger(__name__) @@ -93,12 +93,13 @@ def http_request(self, url, method="GET", **kwargs): r = requests.request(method, url, **_kwargs) except Exception as err: logger.error( - "http_request failed: %s, url: %s, htargs: %s, method: %s" % ( - err, url, sanitize(_kwargs), method)) + "http_request failed: %s, url: %s, htargs: %s, method: %s" + % (err, url, sanitize(_kwargs), method) + ) raise if self.events is not None: - self.events.store('HTTP response', r, ref=url) + self.events.store("HTTP response", r, ref=url) try: _cookie = r.headers["set-cookie"] @@ -116,11 +117,13 @@ def http_request(self, url, method="GET", **kwargs): def send(self, url, method="GET", **kwargs): return self.http_request(url, method, **kwargs) - def load_cookies_from_file(self, filename, ignore_discard=False, - ignore_expires=False): + def load_cookies_from_file( + self, filename, ignore_discard=False, ignore_expires=False + ): self.cookiejar.load(filename, ignore_discard, ignore_expires) - def save_cookies_to_file(self, filename, ignore_discard=False, - ignore_expires=False): + def save_cookies_to_file( + self, filename, ignore_discard=False, ignore_expires=False + ): self.cookiejar.save(filename, ignore_discard, ignore_expires) diff --git a/src/oic/oauth2/consumer.py b/src/oic/oauth2/consumer.py index d2411d37b..ede461f10 100644 --- a/src/oic/oauth2/consumer.py +++ b/src/oic/oauth2/consumer.py @@ -16,11 +16,16 @@ from oic.utils import http_util from oic.utils.sanitize import sanitize -__author__ = 'rohe0002' +__author__ = "rohe0002" -ENDPOINTS = ["authorization_endpoint", "token_endpoint", "userinfo_endpoint", - "check_id_endpoint", "registration_endpoint", - "token_revokation_endpoint"] +ENDPOINTS = [ + "authorization_endpoint", + "token_endpoint", + "userinfo_endpoint", + "check_id_endpoint", + "registration_endpoint", + "token_revokation_endpoint", +] logger = logging.getLogger(__name__) @@ -92,9 +97,17 @@ class MissingAuthenticationInfo(PyoidcError): class Consumer(Client): """An OAuth2 consumer implementation.""" - def __init__(self, session_db, client_config=None, - server_info=None, authz_page="", response_type="", - scope="", flow_type="", password=None): + def __init__( + self, + session_db, + client_config=None, + server_info=None, + authz_page="", + response_type="", + scope="", + flow_type="", + password=None, + ): """ Initialize a Consumer instance. @@ -215,8 +228,11 @@ def begin(self, baseurl, request, response_type="", **kwargs): self.response_type = response_type = "code" location = self.request_info( - AuthorizationRequest, method="GET", scope=self.scope, - request_args={"state": sid, "response_type": response_type})[0] + AuthorizationRequest, + method="GET", + scope=self.scope, + request_args={"state": sid, "response_type": response_type}, + )[0] logger.debug("Redirecting to: %s" % (sanitize(location),)) @@ -235,8 +251,9 @@ def handle_authorization_response(self, query="", **kwargs): if "code" in self.response_type: # Might be an error response try: - aresp = self.parse_response(AuthorizationResponse, - info=query, sformat="urlencoded") + aresp = self.parse_response( + AuthorizationResponse, info=query, sformat="urlencoded" + ) except Exception as err: logger.error("%s" % err) raise @@ -254,9 +271,9 @@ def handle_authorization_response(self, query="", **kwargs): return aresp else: # implicit flow - atr = self.parse_response(AccessTokenResponse, - info=query, sformat="urlencoded", - extended=True) + atr = self.parse_response( + AccessTokenResponse, info=query, sformat="urlencoded", extended=True + ) if isinstance(atr, Message): if atr.type().endswith("ErrorResponse"): @@ -293,8 +310,10 @@ def client_auth_info(self): extra_args = {} elif self.client_secret: http_args = {} - request_args = {"client_secret": self.client_secret, - "client_id": self.client_id} + request_args = { + "client_secret": self.client_secret, + "client_id": self.client_id, + } extra_args = {"auth_method": "bearer_body"} else: raise MissingAuthenticationInfo("Nothing to authenticate with") @@ -304,10 +323,9 @@ def client_auth_info(self): def get_access_token_request(self, state, **kwargs): request_args, http_args, extra_args = self.client_auth_info() - url, body, ht_args, _ = self.request_info(AccessTokenRequest, - request_args=request_args, - state=state, - **extra_args) + url, body, ht_args, _ = self.request_info( + AccessTokenRequest, request_args=request_args, state=state, **extra_args + ) if not http_args: http_args = ht_args diff --git a/src/oic/oauth2/exception.py b/src/oic/oauth2/exception.py index 1c28aa941..46fb6050e 100644 --- a/src/oic/oauth2/exception.py +++ b/src/oic/oauth2/exception.py @@ -1,6 +1,6 @@ from oic.exception import PyoidcError -__author__ = 'roland' +__author__ = "roland" class HttpError(PyoidcError): diff --git a/src/oic/oauth2/grant.py b/src/oic/oauth2/grant.py index 360b67ee7..38767566b 100644 --- a/src/oic/oauth2/grant.py +++ b/src/oic/oauth2/grant.py @@ -4,7 +4,7 @@ from oic.oauth2.message import AuthorizationResponse from oic.utils.time_util import utc_time_sans_frac -__author__ = 'roland' +__author__ = "roland" class Token(object): diff --git a/src/oic/oauth2/message.py b/src/oic/oauth2/message.py index 1b959054d..1d3dba796 100644 --- a/src/oic/oauth2/message.py +++ b/src/oic/oauth2/message.py @@ -163,7 +163,7 @@ def _extract_cparam(key, _spec): The key can be direct attribute or lang typed attribute. If ParamDefinition is not found, tries to return "*" attribute, if it exists, otherwise returns None. """ - for _key in (key, key.split('#')[0], '*'): + for _key in (key, key.split("#")[0], "*"): if _key in _spec: return _spec[_key] return None @@ -178,8 +178,7 @@ def to_urlencoded(self, lev=0): if not self.lax: for attribute, cparam in _spec.items(): if cparam.required and attribute not in self._dict: - raise MissingRequiredAttribute("%s" % attribute, - "%s" % self) + raise MissingRequiredAttribute("%s" % attribute, "%s" % self) params = [] @@ -199,11 +198,10 @@ def to_urlencoded(self, lev=0): params.append((key, val.encode("utf-8"))) elif isinstance(val, list): if _ser: - params.append((key, str(_ser(val, sformat="urlencoded", - lev=lev)))) + params.append((key, str(_ser(val, sformat="urlencoded", lev=lev)))) else: for item in val: - params.append((key, str(item).encode('utf-8'))) + params.append((key, str(item).encode("utf-8"))) elif isinstance(val, Message): try: _val = json.dumps(_ser(val, sformat="dict", lev=lev + 1)) @@ -284,7 +282,7 @@ def from_urlencoded(self, urlencoded, **kwargs): except KeyError: raise ParameterError(key) else: - raise TooManyValues('{}'.format(key)) + raise TooManyValues("{}".format(key)) return self @@ -309,7 +307,9 @@ def to_dict(self, lev=0): if isinstance(val, Message): _res[key] = val.to_dict(lev + 1) - elif isinstance(val, list) and isinstance(next(iter(val or []), None), Message): + elif isinstance(val, list) and isinstance( + next(iter(val or []), None), Message + ): _res[key] = [v.to_dict(lev) for v in val] else: _res[key] = val @@ -326,11 +326,13 @@ def from_dict(self, dictionary, **kwargs): _spec = self.c_param for key, val in dictionary.items(): - if val in ('', ['']): + if val in ("", [""]): continue cparam = self._extract_cparam(key, _spec) if cparam is not None: - self._add_value(key, cparam.type, key, val, cparam.deserializer, cparam.null_allowed) + self._add_value( + key, cparam.type, key, val, cparam.deserializer, cparam.null_allowed + ) else: self._dict[key] = val return self @@ -349,7 +351,9 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed): if vtyp is bool: self._dict[skey] = val else: - raise ValueError('"{}", wrong type of value for "{}"'.format(val, skey)) + raise ValueError( + '"{}", wrong type of value for "{}"'.format(val, skey) + ) elif isinstance(val, vtyp): # Not necessary to do anything self._dict[skey] = val else: @@ -362,11 +366,15 @@ def _add_value(self, skey, vtyp, key, val, _deser, null_allowed): try: self._dict[skey] = int(val) except (ValueError, TypeError): - raise ValueError('"{}", wrong type of value for "{}"'.format(val, skey)) + raise ValueError( + '"{}", wrong type of value for "{}"'.format(val, skey) + ) else: return elif vtyp is bool: - raise ValueError('"{}", wrong type of value for "{}"'.format(val, skey)) + raise ValueError( + '"{}", wrong type of value for "{}"'.format(val, skey) + ) if isinstance(val, str): self._dict[skey] = val @@ -420,7 +428,9 @@ def _add_value_list(self, skey, vtype, key, val, _deser, null_allowed): else: for v in val: if not isinstance(v, vtype): - raise DecodeError(ERRTXT % (key, "type != %s (%s)" % (vtype, type(v)))) + raise DecodeError( + ERRTXT % (key, "type != %s (%s)" % (vtype, type(v))) + ) self._dict[skey] = val return if isinstance(val, dict): @@ -454,15 +464,15 @@ def to_jwt(self, key=None, algorithm="", lev=0): _jws = JWS(self.to_json(lev), alg=algorithm) return _jws.sign_compact(key) - def _add_key(self, keyjar, issuer, key, key_type='', kid='', - no_kid_issuer=None): + def _add_key(self, keyjar, issuer, key, key_type="", kid="", no_kid_issuer=None): if issuer not in keyjar: logger.error('Issuer "{}" not in keyjar'.format(issuer)) return - logger.debug('Key set summary for {}: {}'.format( - issuer, key_summary(keyjar, issuer))) + logger.debug( + "Key set summary for {}: {}".format(issuer, key_summary(keyjar, issuer)) + ) if kid: _key = keyjar.get_key_by_kid(kid, issuer) @@ -502,9 +512,9 @@ def get_verify_keys(self, keyjar, key, jso, header, jwt, **kwargs): :return: list of usable keys """ try: - _kid = header['kid'] + _kid = header["kid"] except KeyError: - _kid = '' + _kid = "" try: _iss = jso["iss"] @@ -517,8 +527,7 @@ def get_verify_keys(self, keyjar, key, jso, header, jwt, **kwargs): # This is really questionable try: if kwargs["trusting"]: - keyjar.add(jso["iss"], - header["jku"]) + keyjar.add(jso["iss"], header["jku"]) except KeyError: pass @@ -535,18 +544,17 @@ def get_verify_keys(self, keyjar, key, jso, header, jwt, **kwargs): pass try: - nki = kwargs['no_kid_issuer'] + nki = kwargs["no_kid_issuer"] except KeyError: nki = {} try: - _key_type = alg2keytype(header['alg']) + _key_type = alg2keytype(header["alg"]) except KeyError: - _key_type = '' + _key_type = "" try: - self._add_key(keyjar, kwargs["opponent_id"], key, _key_type, _kid, - nki) + self._add_key(keyjar, kwargs["opponent_id"], key, _key_type, _kid, nki) except KeyError: pass @@ -583,9 +591,13 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): if "algs" in kwargs and "encalg" in kwargs["algs"]: if kwargs["algs"]["encalg"] != _jw["alg"]: - raise WrongEncryptionAlgorithm("%s != %s" % (_jw["alg"], kwargs["algs"]["encalg"])) + raise WrongEncryptionAlgorithm( + "%s != %s" % (_jw["alg"], kwargs["algs"]["encalg"]) + ) if kwargs["algs"]["encenc"] != _jw["enc"]: - raise WrongEncryptionAlgorithm("%s != %s" % (_jw["enc"], kwargs["algs"]["encenc"])) + raise WrongEncryptionAlgorithm( + "%s != %s" % (_jw["enc"], kwargs["algs"]["encenc"]) + ) if keyjar: dkeys = keyjar.get_decrypt_key(owner="") if "sender" in kwargs: @@ -595,9 +607,9 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): else: dkeys = [] - logger.debug('Decrypt class: {}'.format(_jw.__class__)) + logger.debug("Decrypt class: {}".format(_jw.__class__)) _res = _jw.decrypt(txt, dkeys) - logger.debug('decrypted message:{}'.format(_res)) + logger.debug("decrypted message:{}".format(_res)) if isinstance(_res, tuple): txt = as_unicode(_res[0]) elif isinstance(_res, list) and len(_res) == 2: @@ -611,7 +623,9 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): if "algs" in kwargs and "sign" in kwargs["algs"]: _alg = _jw.jwt.headers["alg"] if kwargs["algs"]["sign"] != _alg: - raise WrongSigningAlgorithm("%s != %s" % (_alg, kwargs["algs"]["sign"])) + raise WrongSigningAlgorithm( + "%s != %s" % (_alg, kwargs["algs"]["sign"]) + ) try: _jwt = JWT().unpack(txt) jso = _jwt.payload() @@ -631,13 +645,13 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): pass elif verify: if keyjar: - key = self.get_verify_keys(keyjar, key, jso, _header, - _jw, **kwargs) + key = self.get_verify_keys( + keyjar, key, jso, _header, _jw, **kwargs + ) if "alg" in _header and _header["alg"] != "none": if not key: - raise MissingSigningKey( - "alg=%s" % _header["alg"]) + raise MissingSigningKey("alg=%s" % _header["alg"]) logger.debug("Found signing key.") try: @@ -645,8 +659,9 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): except NoSuitableSigningKeys: if keyjar: update_keyjar(keyjar) - key = self.get_verify_keys(keyjar, key, jso, - _header, _jw, **kwargs) + key = self.get_verify_keys( + keyjar, key, jso, _header, _jw, **kwargs + ) _jw.verify_compact(txt, key) except Exception: raise @@ -659,7 +674,7 @@ def from_jwt(self, txt, key=None, verify=True, keyjar=None, **kwargs): return self.from_dict(jso) def __str__(self): - return '{}'.format(self.to_dict()) + return "{}".format(self.to_dict()) def _type_check(self, typ, _allowed, val, na=False): if typ is str: @@ -711,7 +726,9 @@ def verify(self, **kwargs): else: raise NotAllowedValue(val) else: - self._type_check(cparam.type, _allowed[attribute], val, cparam.null_allowed) + self._type_check( + cparam.type, _allowed[attribute], val, cparam.null_allowed + ) return True @@ -755,7 +772,14 @@ def request(self, location, fragment_enc=False): def __setitem__(self, key, value): try: cparam = self.c_param[key] - self._add_value(str(key), cparam.type, key, value, cparam.deserializer, cparam.null_allowed) + self._add_value( + str(key), + cparam.type, + key, + value, + cparam.deserializer, + cparam.null_allowed, + ) except KeyError: self._dict[key] = value @@ -777,8 +801,9 @@ def __len__(self): return len(self._dict) def extra(self): - return dict([(key, val) for key, val in - self._dict.items() if key not in self.c_param]) + return dict( + [(key, val) for key, val in self._dict.items() if key not in self.c_param] + ) def only_extras(self): extras = [key for key in self._dict.keys() if key in self.c_param] @@ -862,6 +887,7 @@ def add_non_standard(msg1, msg2): # ============================================================================= + def list_serializer(vals, sformat="urlencoded", lev=0): if isinstance(vals, str) or not isinstance(vals, list): raise ValueError("Expected list: %s" % vals) @@ -911,18 +937,34 @@ def json_deserializer(txt, sformat="urlencoded"): VDESER = 3 VNULLALLOWED = 4 -ParamDefinition = namedtuple('ParamDefinition', ['type', 'required', 'serializer', 'deserializer', 'null_allowed']) +ParamDefinition = namedtuple( + "ParamDefinition", + ["type", "required", "serializer", "deserializer", "null_allowed"], +) SINGLE_REQUIRED_STRING = ParamDefinition(str, True, None, None, False) SINGLE_OPTIONAL_STRING = ParamDefinition(str, False, None, None, False) SINGLE_OPTIONAL_INT = ParamDefinition(int, False, None, None, False) -OPTIONAL_LIST_OF_STRINGS = ParamDefinition([str], False, list_serializer, list_deserializer, False) -REQUIRED_LIST_OF_STRINGS = ParamDefinition([str], True, list_serializer, list_deserializer, False) -OPTIONAL_LIST_OF_SP_SEP_STRINGS = ParamDefinition([str], False, sp_sep_list_serializer, sp_sep_list_deserializer, False) -REQUIRED_LIST_OF_SP_SEP_STRINGS = ParamDefinition([str], True, sp_sep_list_serializer, sp_sep_list_deserializer, False) -SINGLE_OPTIONAL_JSON = ParamDefinition(str, False, json_serializer, json_deserializer, False) - -REQUIRED = [SINGLE_REQUIRED_STRING, REQUIRED_LIST_OF_STRINGS, - REQUIRED_LIST_OF_SP_SEP_STRINGS] +OPTIONAL_LIST_OF_STRINGS = ParamDefinition( + [str], False, list_serializer, list_deserializer, False +) +REQUIRED_LIST_OF_STRINGS = ParamDefinition( + [str], True, list_serializer, list_deserializer, False +) +OPTIONAL_LIST_OF_SP_SEP_STRINGS = ParamDefinition( + [str], False, sp_sep_list_serializer, sp_sep_list_deserializer, False +) +REQUIRED_LIST_OF_SP_SEP_STRINGS = ParamDefinition( + [str], True, sp_sep_list_serializer, sp_sep_list_deserializer, False +) +SINGLE_OPTIONAL_JSON = ParamDefinition( + str, False, json_serializer, json_deserializer, False +) + +REQUIRED = [ + SINGLE_REQUIRED_STRING, + REQUIRED_LIST_OF_STRINGS, + REQUIRED_LIST_OF_SP_SEP_STRINGS, +] # @@ -931,28 +973,43 @@ def json_deserializer(txt, sformat="urlencoded"): class ErrorResponse(Message): - c_param = {"error": SINGLE_REQUIRED_STRING, - "error_description": SINGLE_OPTIONAL_STRING, - "error_uri": SINGLE_OPTIONAL_STRING} + c_param = { + "error": SINGLE_REQUIRED_STRING, + "error_description": SINGLE_OPTIONAL_STRING, + "error_uri": SINGLE_OPTIONAL_STRING, + } class AuthorizationErrorResponse(ErrorResponse): c_param = ErrorResponse.c_param.copy() c_param.update({"state": SINGLE_OPTIONAL_STRING}) c_allowed_values = ErrorResponse.c_allowed_values.copy() - c_allowed_values.update({"error": ["invalid_request", - "unauthorized_client", - "access_denied", - "unsupported_response_type", - "invalid_scope", "server_error", - "temporarily_unavailable"]}) + c_allowed_values.update( + { + "error": [ + "invalid_request", + "unauthorized_client", + "access_denied", + "unsupported_response_type", + "invalid_scope", + "server_error", + "temporarily_unavailable", + ] + } + ) class TokenErrorResponse(ErrorResponse): - c_allowed_values = {"error": ["invalid_request", "invalid_client", - "invalid_grant", "unauthorized_client", - "unsupported_grant_type", - "invalid_scope"]} + c_allowed_values = { + "error": [ + "invalid_request", + "invalid_client", + "invalid_grant", + "unauthorized_client", + "unsupported_grant_type", + "invalid_scope", + ] + } class AccessTokenRequest(Message): @@ -962,7 +1019,7 @@ class AccessTokenRequest(Message): "redirect_uri": SINGLE_REQUIRED_STRING, "client_id": SINGLE_OPTIONAL_STRING, "client_secret": SINGLE_OPTIONAL_STRING, - 'state': SINGLE_OPTIONAL_STRING + "state": SINGLE_OPTIONAL_STRING, } c_default = {"grant_type": "authorization_code"} @@ -981,27 +1038,27 @@ class AuthorizationResponse(Message): c_param = { "code": SINGLE_REQUIRED_STRING, "state": SINGLE_OPTIONAL_STRING, - 'iss': SINGLE_OPTIONAL_STRING, - 'client_id': SINGLE_OPTIONAL_STRING + "iss": SINGLE_OPTIONAL_STRING, + "client_id": SINGLE_OPTIONAL_STRING, } def verify(self, **kwargs): super(AuthorizationResponse, self).verify(**kwargs) - if 'client_id' in self: + if "client_id" in self: try: - if self['client_id'] != kwargs['client_id']: - raise VerificationError('client_id mismatch') + if self["client_id"] != kwargs["client_id"]: + raise VerificationError("client_id mismatch") except KeyError: - logger.info('No client_id to verify against') + logger.info("No client_id to verify against") pass - if 'iss' in self: + if "iss" in self: try: # Issuer URL for the authorization server issuing the response. - if self['iss'] != kwargs['iss']: - raise VerificationError('Issuer mismatch') + if self["iss"] != kwargs["iss"]: + raise VerificationError("Issuer mismatch") except KeyError: - logger.info('No issuer set in the Client config') + logger.info("No issuer set in the Client config") pass return True @@ -1014,14 +1071,12 @@ class AccessTokenResponse(Message): "expires_in": SINGLE_OPTIONAL_INT, "refresh_token": SINGLE_OPTIONAL_STRING, "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, - "state": SINGLE_OPTIONAL_STRING + "state": SINGLE_OPTIONAL_STRING, } class NoneResponse(Message): - c_param = { - "state": SINGLE_OPTIONAL_STRING - } + c_param = {"state": SINGLE_OPTIONAL_STRING} class ROPCAccessTokenRequest(Message): @@ -1029,14 +1084,14 @@ class ROPCAccessTokenRequest(Message): "grant_type": SINGLE_REQUIRED_STRING, "username": SINGLE_OPTIONAL_STRING, "password": SINGLE_OPTIONAL_STRING, - "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, } class CCAccessTokenRequest(Message): c_param = { "grant_type": SINGLE_REQUIRED_STRING, - "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, } c_default = {"grant_type": "client_credentials"} c_allowed_values = {"grant_type": ["client_credentials"]} @@ -1048,7 +1103,7 @@ class RefreshAccessTokenRequest(Message): "refresh_token": SINGLE_REQUIRED_STRING, "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, "client_id": SINGLE_OPTIONAL_STRING, - "client_secret": SINGLE_OPTIONAL_STRING + "client_secret": SINGLE_OPTIONAL_STRING, } c_default = {"grant_type": "refresh_token"} c_allowed_values = {"grant_type": ["refresh_token"]} @@ -1070,14 +1125,13 @@ class ASConfigurationResponse(Message): "response_modes_supported": OPTIONAL_LIST_OF_STRINGS, "grant_types_supported": REQUIRED_LIST_OF_STRINGS, "token_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, - "token_endpoint_auth_signing_alg_values_supported": - OPTIONAL_LIST_OF_STRINGS, + "token_endpoint_auth_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, "service_documentation": SINGLE_OPTIONAL_STRING, "ui_locales_supported": OPTIONAL_LIST_OF_STRINGS, "op_policy_uri": SINGLE_OPTIONAL_STRING, "op_tos_uri": SINGLE_OPTIONAL_STRING, - 'revocation_endpoint': SINGLE_OPTIONAL_STRING, - 'introspection_endpoint': SINGLE_OPTIONAL_STRING, + "revocation_endpoint": SINGLE_OPTIONAL_STRING, + "introspection_endpoint": SINGLE_OPTIONAL_STRING, } c_default = {"version": "3.0"} @@ -1096,22 +1150,25 @@ class ASConfigurationResponse(Message): "CCAccessTokenRequest": CCAccessTokenRequest, "RefreshAccessTokenRequest": RefreshAccessTokenRequest, "ResourceRequest": ResourceRequest, - 'ASConfigurationResponse': ASConfigurationResponse + "ASConfigurationResponse": ASConfigurationResponse, } def factory(msgtype): - warnings.warn('`factory` is deprecated. Use `OauthMessageFactory` instead.', DeprecationWarning) + warnings.warn( + "`factory` is deprecated. Use `OauthMessageFactory` instead.", + DeprecationWarning, + ) try: return MSG[msgtype] except KeyError: raise FormatError("Unknown message type: %s" % msgtype) -MessageTuple = namedtuple('MessageTuple', ['request_cls', 'response_cls']) +MessageTuple = namedtuple("MessageTuple", ["request_cls", "response_cls"]) -class MessageFactory(): +class MessageFactory: """Factory for holding message types.""" @classmethod @@ -1120,7 +1177,7 @@ def get_request_type(cls, endpoint: str): try: return getattr(cls, endpoint).request_cls except AttributeError: - raise MessageException('Unknown endpoint.') + raise MessageException("Unknown endpoint.") @classmethod def get_response_type(cls, endpoint: str): @@ -1128,7 +1185,7 @@ def get_response_type(cls, endpoint: str): try: return getattr(cls, endpoint).response_cls except AttributeError: - raise MessageException('Unknown endpoint.') + raise MessageException("Unknown endpoint.") class OauthMessageFactory(MessageFactory): diff --git a/src/oic/oauth2/provider.py b/src/oic/oauth2/provider.py index 92dabdc80..57ecf7eca 100644 --- a/src/oic/oauth2/provider.py +++ b/src/oic/oauth2/provider.py @@ -51,7 +51,7 @@ from oic.utils.sdb import AccessCodeUsed from oic.utils.sdb import AuthnEvent -__author__ = 'rohe0002' +__author__ = "rohe0002" logger = logging.getLogger(__name__) @@ -85,7 +85,7 @@ class TokenEndpoint(Endpoint): def endpoint_ava(endp, baseurl): - key = '{}_endpoint'.format(endp.etype) + key = "{}_endpoint".format(endp.etype) val = urljoin(baseurl, endp.url) return {key: val} @@ -99,9 +99,9 @@ def code_response(**kwargs): pass aresp["code"] = kwargs["scode"] # TODO Add 'iss' and 'client_id' - if kwargs['myself']: - aresp['iss'] = kwargs['myself'] - aresp['client_id'] = _areq['client_id'] + if kwargs["myself"]: + aresp["iss"] = kwargs["myself"] + aresp["client_id"] = _areq["client_id"] add_non_standard(_areq, aresp) return aresp @@ -154,18 +154,40 @@ def re_authenticate(areq, authn): class Provider(object): endp = [AuthorizationEndpoint, TokenEndpoint] - def __init__(self, name, sdb, cdb, authn_broker, authz, client_authn, - symkey=None, urlmap=None, iv=0, default_scope="", - verify_ssl=True, default_acr="", keyjar=None, - baseurl='', server_cls=Server, client_cert=None, message_factory=OauthMessageFactory): + def __init__( + self, + name, + sdb, + cdb, + authn_broker, + authz, + client_authn, + symkey=None, + urlmap=None, + iv=0, + default_scope="", + verify_ssl=True, + default_acr="", + keyjar=None, + baseurl="", + server_cls=Server, + client_cert=None, + message_factory=OauthMessageFactory, + ): self.name = name self.sdb = sdb if not isinstance(cdb, BaseClientDatabase): - warnings.warn('ClientDatabase should be an instance of ' - 'oic.utils.clientdb.BaseClientDatabase to ensure proper API.') + warnings.warn( + "ClientDatabase should be an instance of " + "oic.utils.clientdb.BaseClientDatabase to ensure proper API." + ) self.cdb = cdb - self.server = server_cls(verify_ssl=verify_ssl, client_cert=client_cert, keyjar=keyjar, - message_factory=message_factory) + self.server = server_cls( + verify_ssl=verify_ssl, + client_cert=client_cert, + keyjar=keyjar, + message_factory=message_factory, + ) self.authn_broker = authn_broker if authn_broker is None: @@ -286,15 +308,14 @@ def _verify_redirect_uri(self, areq): try: cid = areq["client_id"] except KeyError: - logger.error('No client id found') - raise UnknownClient('No client_id provided') + logger.error("No client id found") + raise UnknownClient("No client_id provided") else: logger.info("Unknown client: %s" % cid) raise UnknownClient(areq["client_id"]) else: logger.info("Registered redirect_uris: %s" % sanitize(_cinfo)) - raise RedirectURIError( - "Faulty redirect_uri: %s" % areq["redirect_uri"]) + raise RedirectURIError("Faulty redirect_uri: %s" % areq["redirect_uri"]) def get_redirect_uri(self, areq): """ @@ -304,12 +325,13 @@ def get_redirect_uri(self, areq): :return: Tuple of (redirect_uri, Response instance) Response instance is not None of matching redirect_uri failed """ - if 'redirect_uri' in areq: + if "redirect_uri" in areq: self._verify_redirect_uri(areq) uri = areq["redirect_uri"] else: raise ParameterError( - "Missing redirect_uri and more than one or none registered") + "Missing redirect_uri and more than one or none registered" + ) return uri @@ -336,8 +358,9 @@ def pick_auth(self, areq, comparision_type=""): for acr in areq["acr_values"]: res = self.authn_broker.pick(acr, comparision_type) - logger.debug("Picked AuthN broker for ACR %s: %s" % ( - str(acr), str(res))) + logger.debug( + "Picked AuthN broker for ACR %s: %s" % (str(acr), str(res)) + ) if res: # Return the best guess by pick. return res[0] @@ -349,16 +372,17 @@ def pick_auth(self, areq, comparision_type=""): else: for acr in acrs: res = self.authn_broker.pick(acr, comparision_type) - logger.debug("Picked AuthN broker for ACR %s: %s" % ( - str(acr), str(res))) + logger.debug( + "Picked AuthN broker for ACR %s: %s" % (str(acr), str(res)) + ) if res: # Return the best guess by pick. return res[0] except KeyError as exc: logger.debug( - "An error occured while picking the authN broker: %s" % str( - exc)) + "An error occured while picking the authN broker: %s" % str(exc) + ) # return the best I have return None, None @@ -374,18 +398,23 @@ def auth_init(self, request, request_class=None): :return: """ if request_class is not None: - warnings.warn('Passing `request_class` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `request_class` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: - request_class = self.server.message_factory.get_request_type('authorization_endpoint') + request_class = self.server.message_factory.get_request_type( + "authorization_endpoint" + ) logger.debug("Request: '%s'" % sanitize(request)) # Same serialization used for GET and POST try: areq = self.server.parse_authorization_request( - request=request_class, query=request) - except (MissingRequiredValue, MissingRequiredAttribute, - AuthzError) as err: + request=request_class, query=request + ) + except (MissingRequiredValue, MissingRequiredAttribute, AuthzError) as err: logger.debug("%s" % err) areq = request_class() areq.lax = True @@ -404,10 +433,11 @@ def auth_init(self, request, request_class=None): try: _state = areq["state"] except KeyError: - _state = '' + _state = "" - return redirect_authz_error("invalid_request", redirect_uri, - "%s" % err, _state, _rtype) + return redirect_authz_error( + "invalid_request", redirect_uri, "%s" % err, _state, _rtype + ) except KeyError: areq = request_class().deserialize(request, "urlencoded") # verify the redirect_uri @@ -419,9 +449,8 @@ def auth_init(self, request, request_class=None): message = traceback.format_exception(*sys.exc_info()) logger.error(message) logger.debug("Bad request: %s (%s)" % (err, err.__class__.__name__)) - err = ErrorResponse(error='invalid_request', - error_description=str(err)) - return BadRequest(err.to_json(), content='application/json') + err = ErrorResponse(error="invalid_request", error_description=str(err)) + return BadRequest(err.to_json(), content="application/json") if not areq: logger.debug("No AuthzRequest") @@ -433,33 +462,36 @@ def auth_init(self, request, request_class=None): areq = self.filter_request(areq) if self.events: - self.events.store('Protocol request', areq) + self.events.store("Protocol request", areq) try: - _cinfo = self.cdb[areq['client_id']] + _cinfo = self.cdb[areq["client_id"]] except KeyError: logger.error( - 'Client ID ({}) not in client database'.format( - areq['client_id'])) - return error_response('unauthorized_client', 'unknown client') + "Client ID ({}) not in client database".format(areq["client_id"]) + ) + return error_response("unauthorized_client", "unknown client") else: try: - _registered = [set(rt.split(' ')) for rt in - _cinfo['response_types']] + _registered = [set(rt.split(" ")) for rt in _cinfo["response_types"]] except KeyError: # If no response_type is registered by the client then we'll # code which it the default according to the OIDC spec. - _registered = [{'code'}] + _registered = [{"code"}] _wanted = set(areq["response_type"]) if _wanted not in _registered: - return error_response("invalid_request", "Trying to use unregistered response_typ") + return error_response( + "invalid_request", "Trying to use unregistered response_typ" + ) logger.debug("AuthzRequest: %s" % (sanitize(areq.to_dict()),)) try: redirect_uri = self.get_redirect_uri(areq) except (RedirectURIError, ParameterError, UnknownClient) as err: - return error_response("invalid_request", "{}:{}".format(err.__class__.__name__, err)) + return error_response( + "invalid_request", "{}:{}".format(err.__class__.__name__, err) + ) try: keyjar = self.keyjar @@ -469,10 +501,8 @@ def auth_init(self, request, request_class=None): try: # verify that the request message is correct areq.verify(keyjar=keyjar, opponent_id=areq["client_id"]) - except (MissingRequiredAttribute, ValueError, - MissingRequiredValue) as err: - return redirect_authz_error("invalid_request", redirect_uri, - "%s" % err) + except (MissingRequiredAttribute, ValueError, MissingRequiredValue) as err: + return redirect_authz_error("invalid_request", redirect_uri, "%s" % err) return {"areq": areq, "redirect_uri": redirect_uri} @@ -514,8 +544,9 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): tup = (None, None) for acr in acrs: res = self.authn_broker.pick(acr, "exact") - logger.debug("Picked AuthN broker for ACR %s: %s" % ( - str(acr), str(res))) + logger.debug( + "Picked AuthN broker for ACR %s: %s" % (str(acr), str(res)) + ) if res: # Return the best guess by pick. tup = res[0] break @@ -528,8 +559,9 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): authn, authn_class_ref = self.pick_auth(areq, "any") if authn is None: - return redirect_authz_error("access_denied", redirect_uri, - return_type=areq["response_type"]) + return redirect_authz_error( + "access_denied", redirect_uri, return_type=areq["response_type"] + ) try: try: @@ -543,7 +575,8 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): _max_age = max_age(areq) identity, _ts = authn.authenticated_as( - cookie, authorization=_auth_info, max_age=_max_age) + cookie, authorization=_auth_info, max_age=_max_age + ) except (NoSuchAuthentication, TamperAllert): identity = None _ts = 0 @@ -565,7 +598,7 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): authn_args["query"] = request if "req_user" in kwargs: - authn_args["as_user"] = kwargs["req_user"], + authn_args["as_user"] = (kwargs["req_user"],) for attr in ["policy_uri", "logo_uri", "tos_uri"]: try: @@ -584,8 +617,8 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): if "prompt" in areq and "none" in areq["prompt"]: # Need to authenticate but not allowed return redirect_authz_error( - "login_required", redirect_uri, - return_type=areq["response_type"]) + "login_required", redirect_uri, return_type=areq["response_type"] + ) else: return authn(**authn_args) else: @@ -597,26 +630,30 @@ def do_auth(self, areq, redirect_uri, cinfo, request, cookie, **kwargs): user = identity["uid"] if "req_user" in kwargs: sids_for_sub = self.sdb.get_sids_by_sub(kwargs["req_user"]) - if sids_for_sub and user != \ - self.sdb.get_authentication_event( - sids_for_sub[-1]).uid: + if ( + sids_for_sub + and user + != self.sdb.get_authentication_event(sids_for_sub[-1]).uid + ): logger.debug("Wanted to be someone else!") if "prompt" in areq and "none" in areq["prompt"]: # Need to authenticate but not allowed - return redirect_authz_error("login_required", - redirect_uri) + return redirect_authz_error("login_required", redirect_uri) else: return authn(**authn_args) - authn_event = AuthnEvent(identity["uid"], identity.get('salt', ''), - authn_info=authn_class_ref, - time_stamp=_ts) + authn_event = AuthnEvent( + identity["uid"], + identity.get("salt", ""), + authn_info=authn_class_ref, + time_stamp=_ts, + ) return {"authn_event": authn_event, "identity": identity, "user": user} def setup_session(self, areq, authn_event, cinfo): sid = self.sdb.create_authz_session(authn_event, areq) - self.sdb.do_sub(sid, '') + self.sdb.do_sub(sid, "") return sid def authorization_endpoint(self, request="", cookie="", **kwargs): @@ -632,8 +669,9 @@ def authorization_endpoint(self, request="", cookie="", **kwargs): _cid = info["areq"]["client_id"] cinfo = self.cdb[_cid] - authnres = self.do_auth(info["areq"], info["redirect_uri"], - cinfo, request, cookie, **kwargs) + authnres = self.do_auth( + info["areq"], info["redirect_uri"], cinfo, request, cookie, **kwargs + ) if isinstance(authnres, Response): return authnres @@ -641,11 +679,9 @@ def authorization_endpoint(self, request="", cookie="", **kwargs): logger.debug("- authenticated -") logger.debug("AREQ keys: %s" % info["areq"].keys()) - sid = self.setup_session(info["areq"], authnres["authn_event"], - cinfo) + sid = self.setup_session(info["areq"], authnres["authn_event"], cinfo) - return self.authz_part2(authnres["user"], info["areq"], sid, - cookie=cookie) + return self.authz_part2(authnres["user"], info["areq"], sid, cookie=cookie) def aresp_check(self, aresp, areq): return "" @@ -653,8 +689,9 @@ def aresp_check(self, aresp, areq): def create_authn_response(self, areq, sid): rtype = areq["response_type"][0] _func = self.response_type_map[rtype] - aresp = _func(areq=areq, scode=self.sdb[sid]["code"], sdb=self.sdb, - myself=self.baseurl) + aresp = _func( + areq=areq, scode=self.sdb[sid]["code"], sdb=self.sdb, myself=self.baseurl + ) if rtype == "code": fragment_enc = False @@ -666,10 +703,10 @@ def create_authn_response(self, areq, sid): def response_mode(self, areq, fragment_enc, **kwargs): resp_mode = areq["response_mode"] - if resp_mode == 'fragment' and not fragment_enc: + if resp_mode == "fragment" and not fragment_enc: # Can't be done raise InvalidRequest("wrong response_mode") - elif resp_mode == 'query' and fragment_enc: + elif resp_mode == "query" and fragment_enc: # Can't be done raise InvalidRequest("wrong response_mode") return None @@ -691,8 +728,8 @@ def authz_part2(self, user, areq, sid, **kwargs): aresp, headers, redirect_uri, fragment_enc = result # Mix-Up mitigation - aresp['iss'] = self.baseurl - aresp['client_id'] = areq['client_id'] + aresp["iss"] = self.baseurl + aresp["client_id"] = areq["client_id"] # Just do whatever is the default location = aresp.request(redirect_uri, fragment_enc) @@ -705,7 +742,7 @@ def _complete_authz(self, user, areq, sid, **kwargs): # Do the authorization try: - permission = self.authz(user, client_id=areq['client_id']) + permission = self.authz(user, client_id=areq["client_id"]) self.sdb.update(sid, "permission", permission) except Exception: raise @@ -742,14 +779,18 @@ def _complete_authz(self, user, areq, sid, **kwargs): except KeyError: _kaka = None - c_val = "{}{}{}".format(user, DELIM, areq['client_id']) + c_val = "{}{}{}".format(user, DELIM, areq["client_id"]) cookie_header = None if _kaka is not None: if self.cookie_name not in _kaka: # Don't overwrite - cookie_header = self.cookie_func(c_val, typ="sso", cookie_name=self.sso_cookie_name, ttl=self.sso_ttl) + cookie_header = self.cookie_func( + c_val, typ="sso", cookie_name=self.sso_cookie_name, ttl=self.sso_ttl + ) else: - cookie_header = self.cookie_func(c_val, typ="sso", cookie_name=self.sso_cookie_name, ttl=self.sso_ttl) + cookie_header = self.cookie_func( + c_val, typ="sso", cookie_name=self.sso_cookie_name, ttl=self.sso_ttl + ) if cookie_header is not None: headers.append(cookie_header) @@ -759,9 +800,13 @@ def _complete_authz(self, user, areq, sid, **kwargs): if "response_mode" in areq: try: - resp = self.response_mode(areq, fragment_enc, aresp=aresp, - redirect_uri=redirect_uri, - headers=headers) + resp = self.response_mode( + areq, + fragment_enc, + aresp=aresp, + redirect_uri=redirect_uri, + headers=headers, + ) except InvalidRequest as err: return error_response("invalid_request", str(err)) else: @@ -774,7 +819,7 @@ def token_scope_check(self, areq, info): """Not implemented here.""" return None - def token_endpoint(self, request='', authn='', dtype='urlencoded', **kwargs): + def token_endpoint(self, request="", authn="", dtype="urlencoded", **kwargs): """ Provide clients with access tokens. @@ -785,57 +830,69 @@ def token_endpoint(self, request='', authn='', dtype='urlencoded', **kwargs): logger.debug("- token -") logger.debug("token_request: %s" % sanitize(request)) - areq = self.server.message_factory.get_request_type('token_endpoint')().deserialize(request, dtype) + areq = self.server.message_factory.get_request_type( + "token_endpoint" + )().deserialize(request, dtype) # Verify client authentication try: client_id = self.client_authn(self, areq, authn) except (FailedAuthentication, AuthnFailure) as err: logger.error(err) - err = TokenErrorResponse(error="unauthorized_client", error_description="%s" % err) + err = TokenErrorResponse( + error="unauthorized_client", error_description="%s" % err + ) return Unauthorized(err.to_json(), content="application/json") logger.debug("AccessTokenRequest: %s" % sanitize(areq)) # `code` is not mandatory for all requests - if 'code' in areq: + if "code" in areq: try: _info = self.sdb[areq["code"]] except KeyError: - logger.error('Code not present in SessionDB') - error = TokenErrorResponse(error="unauthorized_client", - error_description='Invalid code.') + logger.error("Code not present in SessionDB") + error = TokenErrorResponse( + error="unauthorized_client", error_description="Invalid code." + ) return Unauthorized(error.to_json(), content="application/json") resp = self.token_scope_check(areq, _info) if resp: return resp # If redirect_uri was in the initial authorization request verify that they match - if "redirect_uri" in _info and areq["redirect_uri"] != _info["redirect_uri"]: - logger.error('Redirect_uri mismatch') - error = TokenErrorResponse(error="unauthorized_client", - error_description='Redirect_uris do not match.') + if ( + "redirect_uri" in _info + and areq["redirect_uri"] != _info["redirect_uri"] + ): + logger.error("Redirect_uri mismatch") + error = TokenErrorResponse( + error="unauthorized_client", + error_description="Redirect_uris do not match.", + ) return Unauthorized(error.to_json(), content="application/json") - if 'state' in areq: - if _info['state'] != areq['state']: - logger.error('State value mismatch') - error = TokenErrorResponse(error="unauthorized_client", - error_description='State values do not match.') + if "state" in areq: + if _info["state"] != areq["state"]: + logger.error("State value mismatch") + error = TokenErrorResponse( + error="unauthorized_client", + error_description="State values do not match.", + ) return Unauthorized(error.to_json(), content="application/json") # Propagate the client_id further - areq.setdefault('client_id', client_id) + areq.setdefault("client_id", client_id) grant_type = areq["grant_type"] if grant_type == "authorization_code": return self.code_grant_type(areq) elif grant_type == "refresh_token": return self.refresh_token_grant_type(areq) - elif grant_type == 'client_credentials': + elif grant_type == "client_credentials": return self.client_credentials_grant_type(areq) - elif grant_type == 'password': + elif grant_type == "password": return self.password_grant_type(areq) else: - raise UnSupported('grant_type: {}'.format(grant_type)) + raise UnSupported("grant_type: {}".format(grant_type)) def code_grant_type(self, areq): """ @@ -846,7 +903,9 @@ def code_grant_type(self, areq): try: _tinfo = self.sdb.upgrade_to_token(areq["code"], issue_refresh=True) except AccessCodeUsed: - error = TokenErrorResponse(error="invalid_grant", error_description="Access grant used") + error = TokenErrorResponse( + error="invalid_grant", error_description="Access grant used" + ) return Unauthorized(error.to_json(), content="application/json") logger.debug("_tinfo: %s" % sanitize(_tinfo)) @@ -855,7 +914,9 @@ def code_grant_type(self, areq): logger.debug("AccessTokenResponse: %s" % sanitize(atr)) - return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) + return Response( + atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS + ) def refresh_token_grant_type(self, areq): """ @@ -864,7 +925,7 @@ def refresh_token_grant_type(self, areq): RFC6749 section 6 """ # This is not implemented here, please see oic.extension.provider. - return error_response('unsupported_grant_type', descr='Unsupported grant_type') + return error_response("unsupported_grant_type", descr="Unsupported grant_type") def client_credentials_grant_type(self, areq): """ @@ -873,7 +934,7 @@ def client_credentials_grant_type(self, areq): RFC6749 section 4.4 """ # This is not implemented here, please see oic.extension.provider. - return error_response('unsupported_grant_type', descr='Unsupported grant_type') + return error_response("unsupported_grant_type", descr="Unsupported grant_type") def password_grant_type(self, areq): """ @@ -882,14 +943,14 @@ def password_grant_type(self, areq): RFC6749 section 4.3 """ # This is not implemented here, please see oic.extension.provider. - return error_response('unsupported_grant_type', descr='Unsupported grant_type') + return error_response("unsupported_grant_type", descr="Unsupported grant_type") def verify_endpoint(self, request="", cookie=None, **kwargs): _req = parse_qs(request) try: areq = parse_qs(_req["query"][0]) except KeyError: - return BadRequest('Could not verify endpoint') + return BadRequest("Could not verify endpoint") authn, acr = self.pick_auth(areq=areq) kwargs["cookie"] = cookie @@ -899,12 +960,10 @@ def write_session_cookie(self, value): return make_cookie(self.session_cookie_name, value, self.seed, path="/") def delete_session_cookie(self): - return make_cookie(self.session_cookie_name, "", b"", path="/", - expire=-1) + return make_cookie(self.session_cookie_name, "", b"", path="/", expire=-1) def _compute_session_state(self, state, salt, client_id, redirect_uri): parsed_uri = urlparse(redirect_uri) rp_origin_url = "{uri.scheme}://{uri.netloc}".format(uri=parsed_uri) session_str = client_id + " " + rp_origin_url + " " + state + " " + salt - return hashlib.sha256( - session_str.encode("utf-8")).hexdigest() + "." + salt + return hashlib.sha256(session_str.encode("utf-8")).hexdigest() + "." + salt diff --git a/src/oic/oauth2/util.py b/src/oic/oauth2/util.py index c27cc8fbd..1a9412c5f 100644 --- a/src/oic/oauth2/util.py +++ b/src/oic/oauth2/util.py @@ -11,9 +11,9 @@ logger = logging.getLogger(__name__) -__author__ = 'roland' +__author__ = "roland" -URL_ENCODED = 'application/x-www-form-urlencoded' +URL_ENCODED = "application/x-www-form-urlencoded" JSON_ENCODED = "application/json" DEFAULT_POST_CONTENT_TYPE = URL_ENCODED @@ -21,30 +21,33 @@ PAIRS = { "port": "port_specified", "domain": "domain_specified", - "path": "path_specified" + "path": "path_specified", } -ATTRS = {"version": None, - "name": "", - "value": None, - "port": None, - "port_specified": False, - "domain": "", - "domain_specified": False, - "domain_initial_dot": False, - "path": "", - "path_specified": False, - "secure": False, - "expires": None, - "discard": True, - "comment": None, - "comment_url": None, - "rest": "", - "rfc2109": True} - - -def get_or_post(uri, method, req, content_type=DEFAULT_POST_CONTENT_TYPE, - accept=None, **kwargs): +ATTRS = { + "version": None, + "name": "", + "value": None, + "port": None, + "port_specified": False, + "domain": "", + "domain_specified": False, + "domain_initial_dot": False, + "path": "", + "path_specified": False, + "secure": False, + "expires": None, + "discard": True, + "comment": None, + "comment_url": None, + "rest": "", + "rfc2109": True, +} + + +def get_or_post( + uri, method, req, content_type=DEFAULT_POST_CONTENT_TYPE, accept=None, **kwargs +): """ Construct HTTP request. @@ -64,8 +67,9 @@ def get_or_post(uri, method, req, content_type=DEFAULT_POST_CONTENT_TYPE, _req.update(parse_qs(comp.query)) _query = str(_req.to_urlencoded()) - path = urlunsplit((comp.scheme, comp.netloc, comp.path, - _query, comp.fragment)) + path = urlunsplit( + (comp.scheme, comp.netloc, comp.path, _query, comp.fragment) + ) else: path = uri body = None @@ -76,8 +80,7 @@ def get_or_post(uri, method, req, content_type=DEFAULT_POST_CONTENT_TYPE, elif content_type == JSON_ENCODED: body = req.to_json() else: - raise UnSupported( - "Unsupported content type: '%s'" % content_type) + raise UnSupported("Unsupported content type: '%s'" % content_type) header_ext = {"Content-Type": content_type} if accept: @@ -130,8 +133,9 @@ def set_cookie(cookiejar, kaka): except TimeFormatError: # Ignore cookie logger.info( - "Time format error on %s parameter in received cookie" % ( - sanitize(attr),)) + "Time format error on %s parameter in received cookie" + % (sanitize(attr),) + ) continue for att, spec in PAIRS.items(): @@ -143,9 +147,11 @@ def set_cookie(cookiejar, kaka): if morsel["max-age"] == 0: try: - cookiejar.clear(domain=std_attr["domain"], - path=std_attr["path"], - name=std_attr["name"]) + cookiejar.clear( + domain=std_attr["domain"], + path=std_attr["path"], + name=std_attr["name"], + ) except ValueError: pass else: @@ -179,27 +185,31 @@ def verify_header(reqresp, body_type): if body_type == "": _ctype = reqresp.headers["content-type"] if match_to_("application/json", _ctype): - body_type = 'json' + body_type = "json" elif match_to_("application/jwt", _ctype): body_type = "jwt" elif match_to_(URL_ENCODED, _ctype): - body_type = 'urlencoded' + body_type = "urlencoded" else: - body_type = 'txt' # reasonable default ?? + body_type = "txt" # reasonable default ?? elif body_type == "json": if not match_to_("application/json", reqresp.headers["content-type"]): if match_to_("application/jwt", reqresp.headers["content-type"]): body_type = "jwt" else: - raise ValueError("content-type: %s" % (reqresp.headers["content-type"],)) + raise ValueError( + "content-type: %s" % (reqresp.headers["content-type"],) + ) elif body_type == "jwt": if not match_to_("application/jwt", reqresp.headers["content-type"]): - raise ValueError("Wrong content-type in header, got: {} expected " - "'application/jwt'".format(reqresp.headers["content-type"])) + raise ValueError( + "Wrong content-type in header, got: {} expected " + "'application/jwt'".format(reqresp.headers["content-type"]) + ) elif body_type == "urlencoded": if not match_to_(DEFAULT_POST_CONTENT_TYPE, reqresp.headers["content-type"]): if not match_to_("text/plain", reqresp.headers["content-type"]): - raise ValueError('Wrong content-type') + raise ValueError("Wrong content-type") else: raise ValueError("Unknown return format: %s" % body_type) diff --git a/src/oic/oic/__init__.py b/src/oic/oic/__init__.py index bf3241e05..481572cd1 100644 --- a/src/oic/oic/__init__.py +++ b/src/oic/oic/__init__.py @@ -69,22 +69,26 @@ from oic.utils.webfinger import OIC_ISSUER from oic.utils.webfinger import WebFinger -__author__ = 'rohe0002' +__author__ = "rohe0002" logger = logging.getLogger(__name__) -ENDPOINTS = ["authorization_endpoint", "token_endpoint", - "userinfo_endpoint", "refresh_session_endpoint", - "end_session_endpoint", "registration_endpoint", - "check_id_endpoint"] +ENDPOINTS = [ + "authorization_endpoint", + "token_endpoint", + "userinfo_endpoint", + "refresh_session_endpoint", + "end_session_endpoint", + "registration_endpoint", + "check_id_endpoint", +] RESPONSE2ERROR = { - "AuthorizationResponse": [AuthorizationErrorResponse, - TokenErrorResponse], + "AuthorizationResponse": [AuthorizationErrorResponse, TokenErrorResponse], "AccessTokenResponse": [TokenErrorResponse], "IdToken": [ErrorResponse], "RegistrationResponse": [ClientRegistrationErrorResponse], - "OpenIDSchema": [UserInfoErrorResponse] + "OpenIDSchema": [UserInfoErrorResponse], } # type: Dict[str, List] REQUEST2ENDPOINT = { @@ -101,8 +105,8 @@ "RotateSecret": "registration_endpoint", # --- "ResourceRequest": "resource_endpoint", - 'TokenIntrospectionRequest': 'introspection_endpoint', - 'TokenRevocationRequest': 'revocation_endpoint', + "TokenIntrospectionRequest": "introspection_endpoint", + "TokenRevocationRequest": "revocation_endpoint", "ROPCAccessTokenRequest": "token_endpoint", } @@ -113,15 +117,15 @@ # This should probably be part of the configuration MAX_AUTHENTICATION_AGE = 86400 -DEF_SIGN_ALG = {"id_token": "RS256", - "openid_request_object": "RS256", - "client_secret_jwt": "HS256", - "private_key_jwt": "RS256"} +DEF_SIGN_ALG = { + "id_token": "RS256", + "openid_request_object": "RS256", + "client_secret_jwt": "HS256", + "private_key_jwt": "RS256", +} # ----------------------------------------------------------------------------- -ACR_LISTS = [ - ["0", "1", "2", "3", "4"], -] +ACR_LISTS = [["0", "1", "2", "3", "4"]] def verify_acr_level(req, level): @@ -149,9 +153,14 @@ def deser_id_token(inst, txt=""): # ----------------------------------------------------------------------------- -def make_openid_request(arq, keys=None, userinfo_claims=None, - idtoken_claims=None, request_object_signing_alg=None, - **kwargs): +def make_openid_request( + arq, + keys=None, + userinfo_claims=None, + idtoken_claims=None, + request_object_signing_alg=None, + **kwargs +): """ Construct the specification of what I want returned. @@ -213,27 +222,20 @@ def add_token(self, resp): PREFERENCE2PROVIDER = { "request_object_signing_alg": "request_object_signing_alg_values_supported", - "request_object_encryption_alg": - "request_object_encryption_alg_values_supported", - "request_object_encryption_enc": - "request_object_encryption_enc_values_supported", + "request_object_encryption_alg": "request_object_encryption_alg_values_supported", + "request_object_encryption_enc": "request_object_encryption_enc_values_supported", "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", - "userinfo_encrypted_response_alg": - "userinfo_encryption_alg_values_supported", - "userinfo_encrypted_response_enc": - "userinfo_encryption_enc_values_supported", + "userinfo_encrypted_response_alg": "userinfo_encryption_alg_values_supported", + "userinfo_encrypted_response_enc": "userinfo_encryption_enc_values_supported", "id_token_signed_response_alg": "id_token_signing_alg_values_supported", - "id_token_encrypted_response_alg": - "id_token_encryption_alg_values_supported", - "id_token_encrypted_response_enc": - "id_token_encryption_enc_values_supported", + "id_token_encrypted_response_alg": "id_token_encryption_alg_values_supported", + "id_token_encrypted_response_enc": "id_token_encryption_enc_values_supported", "default_acr_values": "acr_values_supported", "subject_type": "subject_types_supported", "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", - "token_endpoint_auth_signing_alg": - "token_endpoint_auth_signing_alg_values_supported", + "token_endpoint_auth_signing_alg": "token_endpoint_auth_signing_alg_values_supported", "response_types": "response_types_supported", - 'grant_types': 'grant_types_supported' + "grant_types": "grant_types_supported", } PROVIDER2PREFERENCE = dict([(v, k) for k, v in PREFERENCE2PROVIDER.items()]) @@ -250,29 +252,28 @@ def add_token(self, resp): } rt2gt = { - 'code': ['authorization_code'], - 'id_token': ['implicit'], - 'id_token token': ['implicit'], - 'code id_token': ['authorization_code', 'implicit'], - 'code token': ['authorization_code', 'implicit'], - 'code id_token token': ['authorization_code', 'implicit'] + "code": ["authorization_code"], + "id_token": ["implicit"], + "id_token token": ["implicit"], + "code id_token": ["authorization_code", "implicit"], + "code token": ["authorization_code", "implicit"], + "code id_token token": ["authorization_code", "implicit"], } def response_types_to_grant_types(resp_types, **kwargs): _res = set() - if 'grant_types' in kwargs: - _res.update(set(kwargs['grant_types'])) + if "grant_types" in kwargs: + _res.update(set(kwargs["grant_types"])) for response_type in resp_types: - _rt = response_type.split(' ') + _rt = response_type.split(" ") _rt.sort() try: _gt = rt2gt[" ".join(_rt)] except KeyError: - raise ValueError( - 'No such response type combination: {}'.format(resp_types)) + raise ValueError("No such response type combination: {}".format(resp_types)) else: _res.update(set(_gt)) @@ -302,7 +303,7 @@ def claims_match(value, claimspec): elif key == "values": if value in val: matched = True - elif key == 'essential': + elif key == "essential": # Whether it's essential or not doesn't change anything here continue @@ -310,7 +311,7 @@ def claims_match(value, claimspec): break if matched is False: - if list(claimspec.keys()) == ['essential']: + if list(claimspec.keys()) == ["essential"]: return True return matched @@ -319,19 +320,33 @@ def claims_match(value, claimspec): class Client(oauth2.Client): _endpoints = ENDPOINTS - def __init__(self, client_id=None, - client_prefs=None, client_authn_method=None, keyjar=None, - verify_ssl=True, config=None, client_cert=None, - requests_dir='requests', message_factory: Type[MessageFactory] = OIDCMessageFactory): - - oauth2.Client.__init__(self, client_id, - client_authn_method=client_authn_method, - keyjar=keyjar, verify_ssl=verify_ssl, - config=config, client_cert=client_cert, message_factory=message_factory) + def __init__( + self, + client_id=None, + client_prefs=None, + client_authn_method=None, + keyjar=None, + verify_ssl=True, + config=None, + client_cert=None, + requests_dir="requests", + message_factory: Type[MessageFactory] = OIDCMessageFactory, + ): + + oauth2.Client.__init__( + self, + client_id, + client_authn_method=client_authn_method, + keyjar=keyjar, + verify_ssl=verify_ssl, + config=config, + client_cert=client_cert, + message_factory=message_factory, + ) self.file_store = "./file/" self.file_uri = "http://localhost/" - self.base_url = '' + self.base_url = "" # OpenID connect specific endpoints for endpoint in ENDPOINTS: @@ -346,7 +361,9 @@ def __init__(self, client_id=None, self.grant_class = Grant self.token_class = Token self.provider_info = Message() - self.registration_response = RegistrationResponse() # type: RegistrationResponse + self.registration_response = ( + RegistrationResponse() + ) # type: RegistrationResponse self.client_prefs = client_prefs or {} self.behaviour = {} # type: Dict[str, Any] @@ -406,7 +423,8 @@ def request_object_encryption(self, msg, **kwargs): encenc = self.behaviour["request_object_encryption_enc"] except KeyError: raise MissingRequiredAttribute( - "No request_object_encryption_enc specified") + "No request_object_encryption_enc specified" + ) _jwe = JWE(msg, alg=encalg, enc=encenc) _kty = jwe.alg2keytype(encalg) @@ -420,8 +438,7 @@ def request_object_encryption(self, msg, **kwargs): raise MissingRequiredAttribute("No target specified") if _kid: - _keys = self.keyjar.get_encrypt_key(_kty, owner=kwargs["target"], - kid=_kid) + _keys = self.keyjar.get_encrypt_key(_kty, owner=kwargs["target"], kid=_kid) _jwe["kid"] = _kid else: _keys = self.keyjar.get_encrypt_key(_kty, owner=kwargs["target"]) @@ -448,11 +465,11 @@ def filename_from_webname(self, webname): os.makedirs(_filedir) assert webname.startswith(self.base_url) - return webname[len(self.base_url):] + return webname[len(self.base_url) :] - def construct_AuthorizationRequest(self, request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_AuthorizationRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request_args is not None: if "nonce" not in request_args: @@ -465,7 +482,7 @@ def construct_AuthorizationRequest(self, request=None, else: # Never wrong to specify a nonce request_args = {"nonce": rndstr(32)} - request_param = kwargs.get('request_param') + request_param = kwargs.get("request_param") if "request_method" in kwargs: if kwargs["request_method"] == "file": request_param = "request_uri" @@ -473,8 +490,9 @@ def construct_AuthorizationRequest(self, request=None, request_param = "request" del kwargs["request_method"] - areq = super().construct_AuthorizationRequest(request=request, request_args=request_args, extra_args=extra_args, - **kwargs) + areq = super().construct_AuthorizationRequest( + request=request, request_args=request_args, extra_args=extra_args, **kwargs + ) if request_param: alg = None @@ -512,7 +530,7 @@ def construct_AuthorizationRequest(self, request=None, areq["request"] = _req else: try: - _webname = self.registration_response['request_uris'][0] + _webname = self.registration_response["request_uris"][0] filename = self.filename_from_webname(_webname) except KeyError: filename, _webname = self.construct_redirect_uri(**kwargs) @@ -523,12 +541,12 @@ def construct_AuthorizationRequest(self, request=None, return areq - def construct_UserInfoRequest(self, request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_UserInfoRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('userinfo_endpoint') + request = self.message_factory.get_request_type("userinfo_endpoint") if request_args is None: request_args = {} @@ -545,23 +563,21 @@ def construct_UserInfoRequest(self, request=None, return self.construct_request(request, request_args, extra_args) - def construct_RegistrationRequest(self, request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_RegistrationRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('registration_endpoint') + request = self.message_factory.get_request_type("registration_endpoint") return self.construct_request(request, request_args, extra_args) - def construct_RefreshSessionRequest(self, - request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_RefreshSessionRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('refreshsession_endpoint') + request = self.message_factory.get_request_type("refreshsession_endpoint") return self.construct_request(request, request_args, extra_args) - def _id_token_based(self, request, request_args=None, extra_args=None, - **kwargs): + def _id_token_based(self, request, request_args=None, extra_args=None, **kwargs): if request_args is None: request_args = {} @@ -582,29 +598,30 @@ def _id_token_based(self, request, request_args=None, extra_args=None, return self.construct_request(request, request_args, extra_args) - def construct_CheckSessionRequest(self, request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_CheckSessionRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('checksession_endpoint') + request = self.message_factory.get_request_type("checksession_endpoint") return self._id_token_based(request, request_args, extra_args, **kwargs) - def construct_CheckIDRequest(self, request=None, - request_args=None, - extra_args=None, **kwargs): + def construct_CheckIDRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('checkid_endpoint') + request = self.message_factory.get_request_type("checkid_endpoint") # access_token is where the id_token will be placed - return self._id_token_based(request, request_args, extra_args, - prop="access_token", **kwargs) + return self._id_token_based( + request, request_args, extra_args, prop="access_token", **kwargs + ) - def construct_EndSessionRequest(self, request=None, - request_args=None, extra_args=None, - **kwargs): + def construct_EndSessionRequest( + self, request=None, request_args=None, extra_args=None, **kwargs + ): if request is None: - request = self.message_factory.get_request_type('endsession_endpoint') + request = self.message_factory.get_request_type("endsession_endpoint") if request_args is None: request_args = {} @@ -613,73 +630,131 @@ def construct_EndSessionRequest(self, request=None, elif "state" in request_args: kwargs["state"] = request_args["state"] - return self._id_token_based(request, request_args, extra_args, - **kwargs) + return self._id_token_based(request, request_args, extra_args, **kwargs) - def do_authorization_request(self, request=None, - state="", body_type="", method="GET", - request_args=None, extra_args=None, - http_args=None, response_cls=None, **kwargs): + def do_authorization_request( + self, + request=None, + state="", + body_type="", + method="GET", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + **kwargs + ): if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) algs = self.sign_enc_algs("id_token") - if 'code_challenge' in self.config: + if "code_challenge" in self.config: _args, code_verifier = self.add_code_challenge() request_args.update(_args) - return super().do_authorization_request(request=request, state=state, body_type=body_type, method=method, - request_args=request_args, extra_args=extra_args, http_args=http_args, - response_cls=response_cls, algs=algs) - - def do_access_token_request(self, request=None, - scope="", state="", body_type="json", - method="POST", request_args=None, - extra_args=None, http_args=None, - response_cls=None, - authn_method="client_secret_basic", **kwargs): + return super().do_authorization_request( + request=request, + state=state, + body_type=body_type, + method=method, + request_args=request_args, + extra_args=extra_args, + http_args=http_args, + response_cls=response_cls, + algs=algs, + ) + + def do_access_token_request( + self, + request=None, + scope="", + state="", + body_type="json", + method="POST", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + authn_method="client_secret_basic", + **kwargs + ): if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) - atr = super().do_access_token_request(request=request, scope=scope, state=state, body_type=body_type, - method=method, request_args=request_args, extra_args=extra_args, - http_args=http_args, response_cls=response_cls, authn_method=authn_method, - **kwargs) + warnings.warn( + "Passing `response_cls` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) + atr = super().do_access_token_request( + request=request, + scope=scope, + state=state, + body_type=body_type, + method=method, + request_args=request_args, + extra_args=extra_args, + http_args=http_args, + response_cls=response_cls, + authn_method=authn_method, + **kwargs + ) try: - _idt = atr['id_token'] + _idt = atr["id_token"] except KeyError: pass else: try: - if self.state2nonce[state] != _idt['nonce']: + if self.state2nonce[state] != _idt["nonce"]: raise ParameterError('Someone has messed with "nonce"') except KeyError: pass return atr - def do_registration_request(self, request=None, - scope="", state="", body_type="json", - method="POST", request_args=None, - extra_args=None, http_args=None, - response_cls=None): + def do_registration_request( + self, + request=None, + scope="", + state="", + body_type="json", + method="POST", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + ): if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', DeprecationWarning, - stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('registration_endpoint') - url, body, ht_args, csi = self.request_info(request, method=method, - request_args=request_args, - extra_args=extra_args, - scope=scope, state=state) + request = self.message_factory.get_request_type("registration_endpoint") + url, body, ht_args, csi = self.request_info( + request, + method=method, + request_args=request_args, + extra_args=extra_args, + scope=scope, + state=state, + ) if http_args is None: http_args = ht_args @@ -687,106 +762,164 @@ def do_registration_request(self, request=None, http_args.update(http_args) if response_cls is None: - response_cls = self.message_factory.get_response_type('registration_endpoint') + response_cls = self.message_factory.get_response_type( + "registration_endpoint" + ) else: - warnings.warn('Passing `response_cls` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) - response = self.request_and_return(url, response_cls, method, body, - body_type, state=state, - http_args=http_args) + warnings.warn( + "Passing `response_cls` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) + response = self.request_and_return( + url, response_cls, method, body, body_type, state=state, http_args=http_args + ) return response - def do_check_session_request(self, request=None, - scope="", - state="", body_type="json", method="GET", - request_args=None, extra_args=None, - http_args=None, - response_cls=None): + def do_check_session_request( + self, + request=None, + scope="", + state="", + body_type="json", + method="GET", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + ): if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('checksession_endpoint') + request = self.message_factory.get_request_type("checksession_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('checksession_endpoint') - - url, body, ht_args, csi = self.request_info(request, method=method, - request_args=request_args, - extra_args=extra_args, - scope=scope, state=state) + response_cls = self.message_factory.get_response_type( + "checksession_endpoint" + ) + + url, body, ht_args, csi = self.request_info( + request, + method=method, + request_args=request_args, + extra_args=extra_args, + scope=scope, + state=state, + ) if http_args is None: http_args = ht_args else: http_args.update(http_args) - return self.request_and_return(url, response_cls, method, body, - body_type, state=state, - http_args=http_args) - - def do_check_id_request(self, request=None, scope="", - state="", body_type="json", method="GET", - request_args=None, extra_args=None, - http_args=None, - response_cls=None): + return self.request_and_return( + url, response_cls, method, body, body_type, state=state, http_args=http_args + ) + + def do_check_id_request( + self, + request=None, + scope="", + state="", + body_type="json", + method="GET", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + ): if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('checkid_endpoint') + request = self.message_factory.get_request_type("checkid_endpoint") if response_cls is not None: - warnings.warn('Passing `response_cls` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `response_cls` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('checkid_endpoint') + response_cls = self.message_factory.get_response_type("checkid_endpoint") - url, body, ht_args, csi = self.request_info(request, method=method, - request_args=request_args, - extra_args=extra_args, - scope=scope, state=state) + url, body, ht_args, csi = self.request_info( + request, + method=method, + request_args=request_args, + extra_args=extra_args, + scope=scope, + state=state, + ) if http_args is None: http_args = ht_args else: http_args.update(http_args) - return self.request_and_return(url, response_cls, method, body, - body_type, state=state, - http_args=http_args) - - def do_end_session_request(self, request=None, scope="", - state="", body_type="", method="GET", - request_args=None, extra_args=None, - http_args=None, response_cls=None): + return self.request_and_return( + url, response_cls, method, body, body_type, state=state, http_args=http_args + ) + + def do_end_session_request( + self, + request=None, + scope="", + state="", + body_type="", + method="GET", + request_args=None, + extra_args=None, + http_args=None, + response_cls=None, + ): if request is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: - request = self.message_factory.get_request_type('endsession_endpoint') + request = self.message_factory.get_request_type("endsession_endpoint") if response_cls is not None: - warnings.warn('Passing `request` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `request` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) else: - response_cls = self.message_factory.get_response_type('endsession_endpoint') - url, body, ht_args, _ = self.request_info(request, method=method, - request_args=request_args, - extra_args=extra_args, - scope=scope, state=state) + response_cls = self.message_factory.get_response_type("endsession_endpoint") + url, body, ht_args, _ = self.request_info( + request, + method=method, + request_args=request_args, + extra_args=extra_args, + scope=scope, + state=state, + ) if http_args is None: http_args = ht_args else: http_args.update(http_args) - return self.request_and_return(url, response_cls, method, body, - body_type, state=state, - http_args=http_args) + return self.request_and_return( + url, response_cls, method, body, body_type, state=state, http_args=http_args + ) def user_info_request(self, method="GET", state="", scope="", **kwargs): - uir = self.message_factory.get_request_type('userinfo_endpoint')() + uir = self.message_factory.get_request_type("userinfo_endpoint")() logger.debug("[user_info_request]: kwargs:%s" % (sanitize(kwargs),)) token = None if "token" in kwargs: @@ -807,8 +940,11 @@ def user_info_request(self, method="GET", state="", scope="", **kwargs): if token.is_valid(): uir["access_token"] = token.access_token - if token.token_type and token.token_type.lower() == "bearer" \ - and method == "GET": + if ( + token.token_type + and token.token_type.lower() == "bearer" + and method == "GET" + ): kwargs["behavior"] = "use_authorization_header" else: # raise oauth2.OldAccessToken @@ -832,7 +968,7 @@ def user_info_request(self, method="GET", state="", scope="", **kwargs): if "behavior" in kwargs: _behav = kwargs["behavior"] _token = uir["access_token"] - _ttype = '' + _ttype = "" try: _ttype = kwargs["token_type"] except KeyError: @@ -842,17 +978,16 @@ def user_info_request(self, method="GET", state="", scope="", **kwargs): except AttributeError: raise MissingParameter("Unspecified token type") - if 'as_query_parameter' == _behav: - method = 'GET' + if "as_query_parameter" == _behav: + method = "GET" elif token: # use_authorization_header, token_in_message_body if "use_authorization_header" in _behav: token_header = "{type} {token}".format( - type=_ttype.capitalize(), - token=_token) + type=_ttype.capitalize(), token=_token + ) if "headers" in kwargs: - kwargs["headers"].update( - {"Authorization": token_header}) + kwargs["headers"].update({"Authorization": token_header}) else: kwargs["headers"] = {"Authorization": token_header} @@ -866,20 +1001,24 @@ def user_info_request(self, method="GET", state="", scope="", **kwargs): return path, body, method, h_args - def do_user_info_request(self, method="POST", state="", scope="openid", - request="openid", **kwargs): + def do_user_info_request( + self, method="POST", state="", scope="openid", request="openid", **kwargs + ): kwargs["request"] = request - path, body, method, h_args = self.user_info_request(method, state, - scope, **kwargs) + path, body, method, h_args = self.user_info_request( + method, state, scope, **kwargs + ) - logger.debug("[do_user_info_request] PATH:%s BODY:%s H_ARGS: %s" % ( - sanitize(path), sanitize(body), sanitize(h_args))) + logger.debug( + "[do_user_info_request] PATH:%s BODY:%s H_ARGS: %s" + % (sanitize(path), sanitize(body), sanitize(h_args)) + ) if self.events: - self.events.store('Request', {'body': body}) - self.events.store('request_url', path) - self.events.store('request_http_args', h_args) + self.events.store("Request", {"body": body}) + self.events.store("request_url", path) + self.events.store("request_http_args", h_args) try: resp = self.http_request(path, method, data=body, **h_args) @@ -904,8 +1043,9 @@ def do_user_info_request(self, method="POST", state="", scope="openid", self.store_response(res, resp.text) return res else: - raise PyoidcError("ERROR: Something went wrong [%s]: %s" % ( - resp.status_code, resp.text)) + raise PyoidcError( + "ERROR: Something went wrong [%s]: %s" % (resp.status_code, resp.text) + ) try: _schema = kwargs["user_info_schema"] @@ -918,26 +1058,33 @@ def do_user_info_request(self, method="POST", state="", scope="openid", if sformat == "json": res = _schema().from_json(txt=_txt) else: - verify = kwargs.get('verify', True) - res = _schema().from_jwt(_txt, keyjar=self.keyjar, sender=self.provider_info["issuer"], verify=verify) - - if 'error' in res: # Error response + verify = kwargs.get("verify", True) + res = _schema().from_jwt( + _txt, + keyjar=self.keyjar, + sender=self.provider_info["issuer"], + verify=verify, + ) + + if "error" in res: # Error response res = UserInfoErrorResponse(**res.to_dict()) if state: # Verify userinfo sub claim against what's returned in the ID Token idt = self.grant[state].get_id_token() if idt: - if idt['sub'] != res['sub']: + if idt["sub"] != res["sub"]: raise SubMismatch( - 'Sub identifier not the same in userinfo and Id Token') + "Sub identifier not the same in userinfo and Id Token" + ) self.store_response(res, _txt) return res - def get_userinfo_claims(self, access_token, endpoint, method="POST", - schema_class=OpenIDSchema, **kwargs): + def get_userinfo_claims( + self, access_token, endpoint, method="POST", schema_class=OpenIDSchema, **kwargs + ): uir = UserInfoRequest(access_token=access_token) @@ -947,8 +1094,7 @@ def get_userinfo_claims(self, access_token, endpoint, method="POST", http_args = self.init_authentication_method(**kwargs) else: # If nothing defined this is the default - http_args = self.init_authentication_method(uir, "bearer_header", - **kwargs) + http_args = self.init_authentication_method(uir, "bearer_header", **kwargs) h_args.update(http_args) path, body, kwargs = get_or_post(endpoint, method, uir, **kwargs) @@ -964,8 +1110,8 @@ def get_userinfo_claims(self, access_token, endpoint, method="POST", raise PyoidcError("ERROR: Something went wrong: %s" % resp.text) else: raise PyoidcError( - "ERROR: Something went wrong [%s]: %s" % (resp.status_code, - resp.text)) + "ERROR: Something went wrong [%s]: %s" % (resp.status_code, resp.text) + ) res = schema_class().from_json(txt=resp.text) self.store_response(res, resp.text) @@ -976,15 +1122,19 @@ def unpack_aggregated_claims(self, userinfo): for csrc, spec in userinfo["_claim_sources"].items(): if "JWT" in spec: aggregated_claims = Message().from_jwt( - spec["JWT"].encode("utf-8"), - keyjar=self.keyjar, sender=csrc) - claims = [value for value, src in - userinfo["_claim_names"].items() if src == csrc] + spec["JWT"].encode("utf-8"), keyjar=self.keyjar, sender=csrc + ) + claims = [ + value + for value, src in userinfo["_claim_names"].items() + if src == csrc + ] if set(claims) != set(list(aggregated_claims.keys())): logger.warning( "Claims from claim source doesn't match what's in " - "the userinfo") + "the userinfo" + ) for key, vals in aggregated_claims.items(): userinfo[key] = vals @@ -995,25 +1145,43 @@ def fetch_distributed_claims(self, userinfo, callback=None): for csrc, spec in userinfo["_claim_sources"].items(): if "endpoint" in spec: if not spec["endpoint"].startswith("https://"): - logger.warning("Fetching distributed claims from an untrusted source: %s", spec["endpoint"]) + logger.warning( + "Fetching distributed claims from an untrusted source: %s", + spec["endpoint"], + ) if "access_token" in spec: - _uinfo = self.do_user_info_request(method='GET', token=spec["access_token"], - userinfo_endpoint=spec["endpoint"], verify=False) + _uinfo = self.do_user_info_request( + method="GET", + token=spec["access_token"], + userinfo_endpoint=spec["endpoint"], + verify=False, + ) else: if callback: - _uinfo = self.do_user_info_request(method='GET', token=callback(spec['endpoint']), - userinfo_endpoint=spec["endpoint"], verify=False) + _uinfo = self.do_user_info_request( + method="GET", + token=callback(spec["endpoint"]), + userinfo_endpoint=spec["endpoint"], + verify=False, + ) else: - _uinfo = self.do_user_info_request(method='GET', userinfo_endpoint=spec["endpoint"], - verify=False) - - claims = [value for value, src in - userinfo["_claim_names"].items() if src == csrc] + _uinfo = self.do_user_info_request( + method="GET", + userinfo_endpoint=spec["endpoint"], + verify=False, + ) + + claims = [ + value + for value, src in userinfo["_claim_names"].items() + if src == csrc + ] if set(claims) != set(list(_uinfo.keys())): logger.warning( "Claims from claim source doesn't match what's in " - "the userinfo") + "the userinfo" + ) for key, vals in _uinfo.items(): userinfo[key] = vals @@ -1066,7 +1234,7 @@ def match_preferences(self, pcr=None, issuer=None): if not pcr: pcr = self.provider_info - regreq = self.message_factory.get_request_type('registration_endpoint') + regreq = self.message_factory.get_request_type("registration_endpoint") for _pref, _prov in PREFERENCE2PROVIDER.items(): try: @@ -1104,8 +1272,7 @@ def match_preferences(self, pcr=None, issuer=None): break if _pref not in self.behaviour: - raise ConfigurationError( - "OP couldn't match preference:%s" % _pref, pcr) + raise ConfigurationError("OP couldn't match preference:%s" % _pref, pcr) for key, val in self.client_prefs.items(): if key in self.behaviour: @@ -1126,7 +1293,8 @@ def store_registration_info(self, reginfo): self.registration_response = reginfo if "token_endpoint_auth_method" not in self.registration_response: self.registration_response[ - "token_endpoint_auth_method"] = "client_secret_basic" + "token_endpoint_auth_method" + ] = "client_secret_basic" self.client_id = reginfo["client_id"] try: self.client_secret = reginfo["client_secret"] @@ -1138,16 +1306,17 @@ def store_registration_info(self, reginfo): except KeyError: pass try: - self.registration_access_token = reginfo[ - "registration_access_token"] + self.registration_access_token = reginfo["registration_access_token"] except KeyError: pass def handle_registration_info(self, response): - err_msg = 'Got error response: {}' - unk_msg = 'Unknown response: {}' + err_msg = "Got error response: {}" + unk_msg = "Unknown response: {}" if response.status_code in [200, 201]: - resp = self.message_factory.get_response_type('registration_endpoint')().deserialize(response.text, "json") + resp = self.message_factory.get_response_type( + "registration_endpoint" + )().deserialize(response.text, "json") # Some implementations sends back a 200 with an error message inside try: resp.verify() @@ -1159,7 +1328,7 @@ def handle_registration_info(self, response): if resp.verify(): logger.error(err_msg.format(sanitize(resp.to_json()))) if self.events: - self.events.store('protocol response', resp) + self.events.store("protocol response", resp) raise RegistrationError(resp.to_dict()) else: # Something else logger.error(unk_msg.format(sanitize(response.text))) @@ -1178,7 +1347,7 @@ def handle_registration_info(self, response): if resp.verify(): logger.error(err_msg.format(sanitize(resp.to_json()))) if self.events: - self.events.store('protocol response', resp) + self.events.store("protocol response", resp) raise RegistrationError(resp.to_dict()) else: # Something else logger.error(unk_msg.format(sanitize(response.text))) @@ -1213,9 +1382,9 @@ def generate_request_uris(self, request_dir): :return: A list of uris """ m = hashlib.sha256() - m.update(as_bytes(self.provider_info['issuer'])) + m.update(as_bytes(self.provider_info["issuer"])) m.update(as_bytes(self.base_url)) - return '{}{}/{}'.format(self.base_url, request_dir, m.hexdigest()) + return "{}{}/{}".format(self.base_url, request_dir, m.hexdigest()) def create_registration_request(self, **kwargs): """ @@ -1224,7 +1393,7 @@ def create_registration_request(self, **kwargs): :param kwargs: parameters to the registration request :return: """ - req = self.message_factory.get_request_type('registration_endpoint')() + req = self.message_factory.get_request_type("registration_endpoint")() for prop in req.parameters(): try: @@ -1237,9 +1406,7 @@ def create_registration_request(self, **kwargs): if "post_logout_redirect_uris" not in req: try: - req[ - "post_logout_redirect_uris"] = \ - self.post_logout_redirect_uris + req["post_logout_redirect_uris"] = self.post_logout_redirect_uris except AttributeError: pass @@ -1250,15 +1417,15 @@ def create_registration_request(self, **kwargs): raise MissingRequiredAttribute("redirect_uris", req) try: - if self.provider_info['require_request_uri_registration'] is True: - req['request_uris'] = self.generate_request_uris( - self.requests_dir) + if self.provider_info["require_request_uri_registration"] is True: + req["request_uris"] = self.generate_request_uris(self.requests_dir) except KeyError: pass - if 'response_types' in req: - req['grant_types'] = response_types_to_grant_types( - req['response_types'], **kwargs) + if "response_types" in req: + req["grant_types"] = response_types_to_grant_types( + req["response_types"], **kwargs + ) return req @@ -1276,14 +1443,15 @@ def register(self, url, registration_token=None, **kwargs): logger.debug("[registration_request]: kwargs:%s" % (sanitize(kwargs),)) if self.events: - self.events.store('Protocol request', req) + self.events.store("Protocol request", req) headers = {"content-type": "application/json"} if registration_token is not None: - headers["Authorization"] = b"Bearer " + b64encode(registration_token.encode()) + headers["Authorization"] = b"Bearer " + b64encode( + registration_token.encode() + ) - rsp = self.http_request(url, "POST", data=req.to_json(), - headers=headers) + rsp = self.http_request(url, "POST", data=req.to_json(), headers=headers) return self.handle_registration_info(rsp) @@ -1314,8 +1482,9 @@ def sign_enc_algs(self, typ): resp[key] = DEF_SIGN_ALG["id_token"] return resp - def _verify_id_token(self, id_token, nonce="", acr_values=None, auth_time=0, - max_age=0): + def _verify_id_token( + self, id_token, nonce="", acr_values=None, auth_time=0, max_age=0 + ): """ Verify IdToken. @@ -1347,16 +1516,19 @@ def _verify_id_token(self, id_token, nonce="", acr_values=None, auth_time=0, if _now > id_token["exp"]: raise OtherError("Passed best before date") - if self.id_token_max_age and _now > int(id_token["iat"]) + self.id_token_max_age: + if ( + self.id_token_max_age + and _now > int(id_token["iat"]) + self.id_token_max_age + ): raise OtherError("I think this ID token is to old") - if nonce and nonce != id_token['nonce']: + if nonce and nonce != id_token["nonce"]: raise OtherError("nonce mismatch") - if acr_values and id_token['acr'] not in acr_values: + if acr_values and id_token["acr"] not in acr_values: raise OtherError("acr mismatch") - if max_age and _now > int(id_token['auth_time'] + max_age): + if max_age and _now > int(id_token["auth_time"] + max_age): raise AuthnToOld("To old authentication") if auth_time: @@ -1382,11 +1554,22 @@ def verify_id_token(self, id_token, authn_req): class Server(oauth2.Server): """OIC Server class.""" - def __init__(self, verify_ssl: bool = True, keyjar: KeyJar = None, client_cert: Union[str, Tuple[str, str]] = None, - timeout: int = 5, message_factory: Type[MessageFactory] = OIDCMessageFactory): + def __init__( + self, + verify_ssl: bool = True, + keyjar: KeyJar = None, + client_cert: Union[str, Tuple[str, str]] = None, + timeout: int = 5, + message_factory: Type[MessageFactory] = OIDCMessageFactory, + ): """Initialize the server.""" - super().__init__(verify_ssl=verify_ssl, keyjar=keyjar, client_cert=client_cert, timeout=timeout, - message_factory=message_factory) + super().__init__( + verify_ssl=verify_ssl, + keyjar=keyjar, + client_cert=client_cert, + timeout=timeout, + message_factory=message_factory, + ) @staticmethod def _parse_urlencoded(url=None, query=None): @@ -1400,7 +1583,7 @@ def parse_token_request(self, request=AccessTokenRequest, body=None): """Overridden to use OIC Message type.""" return super().parse_token_request(request=request, body=body) - def handle_request_uri(self, request_uri, verify=True, sender=''): + def handle_request_uri(self, request_uri, verify=True, sender=""): """ Handle request URI. @@ -1411,37 +1594,40 @@ def handle_request_uri(self, request_uri, verify=True, sender=''): :return: """ # Do a HTTP get - logger.debug('Get request from request_uri: {}'.format(request_uri)) + logger.debug("Get request from request_uri: {}".format(request_uri)) try: http_req = self.http_request(request_uri) except ConnectionError: - logger.error('Connection Error') + logger.error("Connection Error") return authz_error("invalid_request_uri") if not http_req: - logger.error('Nothing returned') + logger.error("Nothing returned") return authz_error("invalid_request_uri") elif http_req.status_code >= 400: - logger.error('HTTP error {}:{}'.format(http_req.status_code, - http_req.text)) - raise AuthzError('invalid_request') + logger.error("HTTP error {}:{}".format(http_req.status_code, http_req.text)) + raise AuthzError("invalid_request") # http_req.text is a signed JWT try: - logger.debug('request txt: {}'.format(http_req.text)) - req = self.parse_jwt_request(txt=http_req.text, verify=verify, - sender=sender) + logger.debug("request txt: {}".format(http_req.text)) + req = self.parse_jwt_request( + txt=http_req.text, verify=verify, sender=sender + ) except Exception as err: logger.error( - '{}:{} encountered while parsing fetched request'.format( - err.__class__, err)) + "{}:{} encountered while parsing fetched request".format( + err.__class__, err + ) + ) raise AuthzError("invalid_openid_request_object") - logger.debug('Fetched request: {}'.format(req)) + logger.debug("Fetched request: {}".format(req)) return req - def parse_authorization_request(self, request=AuthorizationRequest, - url=None, query=None, keys=None): + def parse_authorization_request( + self, request=AuthorizationRequest, url=None, query=None, keys=None + ): if url: parts = urlparse(url) scheme, netloc, path, params, query, fragment = parts[:6] @@ -1449,34 +1635,37 @@ def parse_authorization_request(self, request=AuthorizationRequest, if isinstance(query, dict): sformat = "dict" else: - sformat = 'urlencoded' + sformat = "urlencoded" _req = self._parse_request(request, query, sformat, verify=False) if self.events: - self.events.store('Request', _req) + self.events.store("Request", _req) _req_req = {} try: - _request = _req['request'] + _request = _req["request"] except KeyError: try: - _url = _req['request_uri'] + _url = _req["request_uri"] except KeyError: pass else: - _req_req = self.handle_request_uri(_url, verify=False, - sender=_req['client_id']) + _req_req = self.handle_request_uri( + _url, verify=False, sender=_req["client_id"] + ) else: if isinstance(_request, Message): _req_req = _request else: try: - _req_req = self.parse_jwt_request(request, txt=_request, - verify=False) + _req_req = self.parse_jwt_request( + request, txt=_request, verify=False + ) except Exception: - _req_req = self._parse_request(request, _request, - 'urlencoded', verify=False) + _req_req = self._parse_request( + request, _request, "urlencoded", verify=False + ) else: # remove JWT attributes for attr in JasonWebToken.c_param: try: @@ -1489,35 +1678,48 @@ def parse_authorization_request(self, request=AuthorizationRequest, if _req_req: if self.events: - self.events.store('Signed Request', _req_req) + self.events.store("Signed Request", _req_req) for key, val in _req.items(): - if key in ['request', 'request_uri']: + if key in ["request", "request_uri"]: continue if key not in _req_req: _req_req[key] = val _req = _req_req if self.events: - self.events.store('Combined Request', _req) + self.events.store("Combined Request", _req) try: _req.verify(keyjar=self.keyjar) except Exception as err: if self.events: - self.events.store('Exception', err) + self.events.store("Exception", err) logger.error(err) raise return _req - def parse_jwt_request(self, request=AuthorizationRequest, txt="", - keyjar=None, verify=True, sender='', **kwargs): + def parse_jwt_request( + self, + request=AuthorizationRequest, + txt="", + keyjar=None, + verify=True, + sender="", + **kwargs + ): """Overridden to use OIC Message type.""" - if 'keys' in kwargs: - keyjar = kwargs['keys'] - warnings.warn('`keys` was renamed to `keyjar`, please update your code.', DeprecationWarning, stacklevel=2) - return super().parse_jwt_request(request=request, txt=txt, keyjar=keyjar, verify=verify, sender=sender) + if "keys" in kwargs: + keyjar = kwargs["keys"] + warnings.warn( + "`keys` was renamed to `keyjar`, please update your code.", + DeprecationWarning, + stacklevel=2, + ) + return super().parse_jwt_request( + request=request, txt=txt, keyjar=keyjar, verify=verify, sender=sender + ) def parse_refresh_token_request(self, request=RefreshAccessTokenRequest, body=None): """Overridden to use OIC Message type.""" @@ -1533,25 +1735,24 @@ def parse_check_id_request(self, url=None, query=None): assert "access_token" in param # ignore the rest return deser_id_token(self, param["access_token"][0]) - def _parse_request(self, request_cls, data, sformat, client_id=None, - verify=True): + def _parse_request(self, request_cls, data, sformat, client_id=None, verify=True): if sformat == "json": request = request_cls().from_json(data) elif sformat == "jwt": - request = request_cls().from_jwt(data, keyjar=self.keyjar, - sender=client_id) + request = request_cls().from_jwt(data, keyjar=self.keyjar, sender=client_id) elif sformat == "urlencoded": - if '?' in data: + if "?" in data: parts = urlparse(data) scheme, netloc, path, params, query, fragment = parts[:6] else: query = data request = request_cls().from_urlencoded(query) - elif sformat == 'dict': + elif sformat == "dict": request = request_cls(**data) else: - raise ParseError("Unknown package format: '{}'".format(sformat), - request_cls) + raise ParseError( + "Unknown package format: '{}'".format(sformat), request_cls + ) # get the verification keys if client_id: @@ -1560,12 +1761,12 @@ def _parse_request(self, request_cls, data, sformat, client_id=None, else: try: keys = self.keyjar.verify_keys(request["client_id"]) - sender = request['client_id'] + sender = request["client_id"] except KeyError: keys = None - sender = '' + sender = "" - logger.debug("Found {} verify keys".format(len(keys or ''))) + logger.debug("Found {} verify keys".format(len(keys or ""))) if verify: request.verify(key=keys, keyjar=self.keyjar, sender=sender) return request @@ -1589,8 +1790,7 @@ def parse_registration_request(self, data, sformat="urlencoded"): return self._parse_request(RegistrationRequest, data, sformat) def parse_end_session_request(self, query, sformat="urlencoded"): - esr = self._parse_request(EndSessionRequest, query, - sformat) + esr = self._parse_request(EndSessionRequest, query, sformat) # if there is a id_token in there it is as a string esr["id_token"] = deser_id_token(self, esr["id_token"]) return esr @@ -1648,9 +1848,19 @@ def id_token_claims(self, session): itc = self.update_claims(session, "oidreq", "id_token", itc) return itc - def make_id_token(self, session, loa="2", issuer="", - alg="RS256", code=None, access_token=None, - user_info=None, auth_time=0, exp=None, extra_claims=None): + def make_id_token( + self, + session, + loa="2", + issuer="", + alg="RS256", + code=None, + access_token=None, + user_info=None, + auth_time=0, + exp=None, + extra_claims=None, + ): """ Create ID Token. @@ -1696,8 +1906,7 @@ def make_id_token(self, session, loa="2", issuer="", _args = user_info # Make sure that there are no name clashes - for key in ["iss", "sub", "aud", "exp", "acr", "nonce", - "auth_time"]: + for key in ["iss", "sub", "aud", "exp", "acr", "nonce", "auth_time"]: try: del _args[key] except KeyError: @@ -1710,14 +1919,17 @@ def make_id_token(self, session, loa="2", issuer="", if code: _args["c_hash"] = jws.left_hash(code.encode("utf-8"), halg) if access_token: - _args["at_hash"] = jws.left_hash(access_token.encode("utf-8"), - halg) - - idt = IdToken(iss=issuer, sub=session["sub"], - aud=session["client_id"], - exp=time_util.epoch_in_a_while(**inawhile), acr=loa, - iat=time_util.utc_time_sans_frac(), - **_args) + _args["at_hash"] = jws.left_hash(access_token.encode("utf-8"), halg) + + idt = IdToken( + iss=issuer, + sub=session["sub"], + aud=session["client_id"], + exp=time_util.epoch_in_a_while(**inawhile), + acr=loa, + iat=time_util.utc_time_sans_frac(), + **_args + ) for key, val in extra.items(): idt[key] = val diff --git a/src/oic/oic/claims_provider.py b/src/oic/oic/claims_provider.py index 0db5c07ae..086c74af0 100644 --- a/src/oic/oic/claims_provider.py +++ b/src/oic/oic/claims_provider.py @@ -17,23 +17,27 @@ from oic.utils.keyio import KeyJar from oic.utils.sanitize import sanitize -__author__ = 'rohe0002' +__author__ = "rohe0002" logger = logging.getLogger(__name__) class UserClaimsRequest(Message): - c_param = {"sub": SINGLE_REQUIRED_STRING, - "client_id": SINGLE_REQUIRED_STRING, - "client_secret": SINGLE_REQUIRED_STRING, - "claims_names": REQUIRED_LIST_OF_STRINGS} + c_param = { + "sub": SINGLE_REQUIRED_STRING, + "client_id": SINGLE_REQUIRED_STRING, + "client_secret": SINGLE_REQUIRED_STRING, + "claims_names": REQUIRED_LIST_OF_STRINGS, + } class UserClaimsResponse(Message): - c_param = {"claims_names": REQUIRED_LIST_OF_STRINGS, - "jwt": SINGLE_OPTIONAL_STRING, - "endpoint": SINGLE_OPTIONAL_STRING, - "access_token": SINGLE_OPTIONAL_STRING} + c_param = { + "claims_names": REQUIRED_LIST_OF_STRINGS, + "jwt": SINGLE_OPTIONAL_STRING, + "endpoint": SINGLE_OPTIONAL_STRING, + "access_token": SINGLE_OPTIONAL_STRING, + } class UserInfoClaimsRequest(Message): @@ -49,12 +53,34 @@ def parse_userinfo_claims_request(self, info, sformat="urlencoded"): class ClaimsServer(Provider): - def __init__(self, name, sdb, cdb, userinfo, client_authn, urlmap=None, - keyjar=None, hostname="", dist_claims_mode=None, - verify_ssl=True): - Provider.__init__(self, name, sdb, cdb, None, userinfo, None, - client_authn, None, urlmap, keyjar, hostname, - verify_ssl=verify_ssl) + def __init__( + self, + name, + sdb, + cdb, + userinfo, + client_authn, + urlmap=None, + keyjar=None, + hostname="", + dist_claims_mode=None, + verify_ssl=True, + ): + Provider.__init__( + self, + name, + sdb, + cdb, + None, + userinfo, + None, + client_authn, + None, + urlmap, + keyjar, + hostname, + verify_ssl=verify_ssl, + ) if keyjar is None: keyjar = KeyJar(verify_ssl=verify_ssl) @@ -73,9 +99,10 @@ def __init__(self, name, sdb, cdb, userinfo, client_authn, urlmap=None, def _aggregation(self, info): jwt_key = self.keyjar.get_signing_key() - cresp = UserClaimsResponse(jwt=info.to_jwt(key=jwt_key, - algorithm="RS256"), - claims_names=list(info.keys())) + cresp = UserClaimsResponse( + jwt=info.to_jwt(key=jwt_key, algorithm="RS256"), + claims_names=list(info.keys()), + ) logger.info("RESPONSE: %s" % (sanitize(cresp.to_dict()),)) return cresp @@ -84,9 +111,11 @@ def _distributed(self, info): # store the user info so it can be accessed later access_token = rndstr() self.info_store[access_token] = info - return UserClaimsResponse(endpoint=self.claims_userinfo_endpoint, - access_token=access_token, - claims_names=info.keys()) + return UserClaimsResponse( + endpoint=self.claims_userinfo_endpoint, + access_token=access_token, + claims_names=info.keys(), + ) def do_aggregation(self, info, uid): return self.dist_claims_mode.aggregate(uid, info) @@ -104,8 +133,7 @@ def claims_endpoint(self, request, http_authz, *args): _log_info("Failed to verify client due to: %s" % err) if "claims_names" in ucreq: - args = dict([(n, {"optional": True}) for n in - ucreq["claims_names"]]) + args = dict([(n, {"optional": True}) for n in ucreq["claims_names"]]) uic = Claims(**args) else: uic = None @@ -113,8 +141,9 @@ def claims_endpoint(self, request, http_authz, *args): _log_info("User info claims: %s" % sanitize(uic)) # oicsrv, userdb, subject, client_id="", user_info_claims=None - info = self.userinfo(ucreq["sub"], user_info_claims=uic, - client_id=ucreq["client_id"]) + info = self.userinfo( + ucreq["sub"], user_info_claims=uic, client_id=ucreq["client_id"] + ) _log_info("User info: %s" % sanitize(info)) @@ -137,7 +166,7 @@ def claims_info_endpoint(self, request, authn): ucreq = self.srvmethod.parse_userinfo_claims_request(request) # Access_token is mandatory in UserInfoClaimsRequest - uiresp = OpenIDSchema(**self.info_store[ucreq['access_token']]) + uiresp = OpenIDSchema(**self.info_store[ucreq["access_token"]]) _log_info("returning: %s" % sanitize(uiresp.to_dict())) return Response(uiresp.to_json(), content="application/json") @@ -153,21 +182,26 @@ def __init__(self, client_id=None, verify_ssl=True): self.response2error = RESPONSE2ERROR.copy() self.response2error["UserClaimsResponse"] = ["ErrorResponse"] - def construct_UserClaimsRequest(self, request=UserClaimsRequest, - request_args=None, extra_args=None, - **kwargs): + def construct_UserClaimsRequest( + self, request=UserClaimsRequest, request_args=None, extra_args=None, **kwargs + ): return self.construct_request(request, request_args, extra_args) - def do_claims_request(self, request=UserClaimsRequest, - request_resp=UserClaimsResponse, - body_type="json", - method="POST", request_args=None, extra_args=None, - http_args=None): - - url, body, ht_args, _ = self.request_info(request, method=method, - request_args=request_args, - extra_args=extra_args) + def do_claims_request( + self, + request=UserClaimsRequest, + request_resp=UserClaimsResponse, + body_type="json", + method="POST", + request_args=None, + extra_args=None, + http_args=None, + ): + + url, body, ht_args, _ = self.request_info( + request, method=method, request_args=request_args, extra_args=extra_args + ) if http_args is None: http_args = ht_args @@ -177,11 +211,16 @@ def do_claims_request(self, request=UserClaimsRequest, # http_args = self.init_authentication_method(csi, "bearer_header", # request_args) - return self.request_and_return(url, request_resp, method, body, - body_type, extended=False, - http_args=http_args, - key=self.keyjar.verify_keys( - self.keyjar.match_owner(url))) + return self.request_and_return( + url, + request_resp, + method, + body, + body_type, + extended=False, + http_args=http_args, + key=self.keyjar.verify_keys(self.keyjar.match_owner(url)), + ) class UserClaimsEndpoint(Endpoint): diff --git a/src/oic/oic/consumer.py b/src/oic/oic/consumer.py index eae999b62..a3298bfc0 100644 --- a/src/oic/oic/consumer.py +++ b/src/oic/oic/consumer.py @@ -20,7 +20,7 @@ from oic.utils import http_util from oic.utils.sanitize import sanitize -__author__ = 'rohe0002' +__author__ = "rohe0002" logger = logging.getLogger(__name__) @@ -97,15 +97,22 @@ def clean_response(aresp): "request_object_encryption_enc", "default_max_age", "require_auth_time", - "default_acr_values" + "default_acr_values", ] class Consumer(Client): """An OpenID Connect consumer implementation.""" - def __init__(self, session_db, consumer_config, client_config=None, - server_info=None, debug=False, client_prefs=None): + def __init__( + self, + session_db, + consumer_config, + client_config=None, + server_info=None, + debug=False, + client_prefs=None, + ): """ Initialize a Consumer instance. @@ -172,8 +179,7 @@ def restore(self, sid): setattr(self, key, val) def dictionary(self): - return dict([(k, v) for k, v in - self.__dict__.items() if k not in IGNORE]) + return dict([(k, v) for k, v in self.__dict__.items() if k not in IGNORE]) def _backup(self, sid): """ @@ -183,8 +189,7 @@ def _backup(self, sid): """ self.sdb[sid] = self.dictionary() - def begin(self, scope="", response_type="", use_nonce=False, path="", - **kwargs): + def begin(self, scope="", response_type="", use_nonce=False, path="", **kwargs): """ Begin the OIDC flow. @@ -237,7 +242,7 @@ def begin(self, scope="", response_type="", use_nonce=False, path="", # OPTIONAL on code flow. if "token" in response_type or use_nonce: args["nonce"] = rndstr(12) - self.state2nonce[sid] = args['nonce'] + self.state2nonce[sid] = args["nonce"] if "max_age" in self.consumer_config: args["max_age"] = self.consumer_config["max_age"] @@ -245,21 +250,23 @@ def begin(self, scope="", response_type="", use_nonce=False, path="", _claims = None if "user_info" in self.consumer_config: _claims = ClaimsRequest( - userinfo=Claims(**self.consumer_config["user_info"])) + userinfo=Claims(**self.consumer_config["user_info"]) + ) if "id_token" in self.consumer_config: if _claims: _claims["id_token"] = Claims(**self.consumer_config["id_token"]) else: _claims = ClaimsRequest( - id_token=Claims(**self.consumer_config["id_token"])) + id_token=Claims(**self.consumer_config["id_token"]) + ) if _claims: args["claims"] = _claims if "request_method" in self.consumer_config: areq = self.construct_AuthorizationRequest( - request_args=args, extra_args=None, - request_param="request") + request_args=args, extra_args=None, request_param="request" + ) if self.consumer_config["request_method"] == "file": id_request = areq["request"] @@ -282,8 +289,9 @@ def begin(self, scope="", response_type="", use_nonce=False, path="", if "userinfo_claims" in args: # can only be carried in an IDRequest raise PyoidcError("Need a request method") - areq = self.construct_AuthorizationRequest(AuthorizationRequest, - request_args=args) + areq = self.construct_AuthorizationRequest( + AuthorizationRequest, request_args=args + ) location = areq.request(self.authorization_endpoint) @@ -296,13 +304,12 @@ def _parse_authz(self, query="", **kwargs): _log_info = logger.info # Might be an error response _log_info("Expect Authorization Response") - aresp = self.parse_response(AuthorizationResponse, - info=query, - sformat="urlencoded", - keyjar=self.keyjar) + aresp = self.parse_response( + AuthorizationResponse, info=query, sformat="urlencoded", keyjar=self.keyjar + ) if isinstance(aresp, ErrorResponse): _log_info("ErrorResponse: %s" % sanitize(aresp)) - raise AuthzError(aresp.get('error'), aresp) + raise AuthzError(aresp.get("error"), aresp) _log_info("Aresp: %s" % sanitize(aresp)) @@ -359,9 +366,13 @@ def parse_authz(self, query="", **kwargs): return aresp, atr, idt elif "token" in self.consumer_config["response_type"]: # implicit flow _log_info("Expect Access Token Response") - atr = self.parse_response(AccessTokenResponse, info=query, - sformat="urlencoded", - keyjar=self.keyjar, **kwargs) + atr = self.parse_response( + AccessTokenResponse, + info=query, + sformat="urlencoded", + keyjar=self.keyjar, + **kwargs + ) if isinstance(atr, ErrorResponse): raise TokenError(atr.get("error"), atr) @@ -389,15 +400,19 @@ def complete(self, state): elif self.client_secret: logger.info("request_body auth") http_args = {} - args.update({"client_secret": self.client_secret, - "client_id": self.client_id, - "secret_type": self.secret_type}) + args.update( + { + "client_secret": self.client_secret, + "client_id": self.client_id, + "secret_type": self.secret_type, + } + ) else: raise PyoidcError("Nothing to authenticate with") - resp = self.do_access_token_request(state=state, - request_args=args, - http_args=http_args) + resp = self.do_access_token_request( + state=state, request_args=args, http_args=http_args + ) logger.info("Access Token Response: %s" % sanitize(resp)) diff --git a/src/oic/oic/message.py b/src/oic/oic/message.py index 6b7fe3ae2..637a8b73a 100644 --- a/src/oic/oic/message.py +++ b/src/oic/oic/message.py @@ -37,7 +37,7 @@ from oic.oauth2.message import SchemeError from oic.utils import time_util -__author__ = 'rohe0002' +__author__ = "rohe0002" logger = logging.getLogger(__name__) @@ -230,12 +230,18 @@ def claims_request_deser(val, sformat="json"): OPTIONAL_ADDRESS = ParamDefinition(Message, False, msg_ser, address_deser, False) OPTIONAL_LOGICAL = ParamDefinition(bool, False, None, None, False) -OPTIONAL_MULTIPLE_Claims = ParamDefinition(Message, False, claims_ser, claims_deser, False) +OPTIONAL_MULTIPLE_Claims = ParamDefinition( + Message, False, claims_ser, claims_deser, False +) SINGLE_OPTIONAL_IDTOKEN = ParamDefinition(str, False, msg_ser, None, False) -SINGLE_OPTIONAL_REGISTRATION_REQUEST = ParamDefinition(Message, False, msg_ser, registration_request_deser, False) -SINGLE_OPTIONAL_CLAIMSREQ = ParamDefinition(Message, False, msg_ser_json, claims_request_deser, False) +SINGLE_OPTIONAL_REGISTRATION_REQUEST = ParamDefinition( + Message, False, msg_ser, registration_request_deser, False +) +SINGLE_OPTIONAL_CLAIMSREQ = ParamDefinition( + Message, False, msg_ser_json, claims_request_deser, False +) OPTIONAL_MESSAGE = ParamDefinition(Message, False, msg_ser, message_deser, False) REQUIRED_MESSAGE = ParamDefinition(Message, True, msg_ser, message_deser, False) @@ -244,7 +250,7 @@ def claims_request_deser(val, sformat="json"): SCOPE_CHARSET = [] -for char in ['\x21', ('\x23', '\x5b'), ('\x5d', '\x7E')]: +for char in ["\x21", ("\x23", "\x5b"), ("\x5d", "\x7E")]: if isinstance(char, tuple): c = char[0] while c <= char[1]: @@ -275,22 +281,22 @@ def verify_id_token(instance, check_hash=False, **kwargs): _jwe = JWE_factory(_jws) if _jwe is not None: try: - _jws = _jwe.decrypt(keys=kwargs['keyjar'].get_decrypt_key()) + _jws = _jwe.decrypt(keys=kwargs["keyjar"].get_decrypt_key()) except JWEException as err: raise VerificationError("Could not decrypt id_token", err) _packer = JWT() _body = _packer.unpack(_jws).payload() - if 'keyjar' in kwargs: + if "keyjar" in kwargs: try: - if _body['iss'] not in kwargs['keyjar']: - raise ValueError('Unknown issuer') + if _body["iss"] not in kwargs["keyjar"]: + raise ValueError("Unknown issuer") except KeyError: - raise MissingRequiredAttribute('iss') + raise MissingRequiredAttribute("iss") if _jwe is not None: # Use the original encrypted token to set correct headers - idt = IdToken().from_jwt(str(instance['id_token']), **args) + idt = IdToken().from_jwt(str(instance["id_token"]), **args) else: idt = IdToken().from_jwt(_jws, **args) if not idt.verify(**kwargs): @@ -319,6 +325,7 @@ def verify_id_token(instance, check_hash=False, **kwargs): # ----------------------------------------------------------------------------- + class RefreshAccessTokenRequest(message.RefreshAccessTokenRequest): pass @@ -341,21 +348,20 @@ def verify(self, **kwargs): class UserInfoRequest(Message): - c_param = { - "access_token": SINGLE_OPTIONAL_STRING, - } + c_param = {"access_token": SINGLE_OPTIONAL_STRING} -class AuthorizationResponse(message.AuthorizationResponse, - message.AccessTokenResponse): +class AuthorizationResponse(message.AuthorizationResponse, message.AccessTokenResponse): c_param = message.AuthorizationResponse.c_param.copy() c_param.update(message.AccessTokenResponse.c_param) - c_param.update({ - "code": SINGLE_OPTIONAL_STRING, - "access_token": SINGLE_OPTIONAL_STRING, - "token_type": SINGLE_OPTIONAL_STRING, - "id_token": SINGLE_OPTIONAL_IDTOKEN, - }) + c_param.update( + { + "code": SINGLE_OPTIONAL_STRING, + "access_token": SINGLE_OPTIONAL_STRING, + "token_type": SINGLE_OPTIONAL_STRING, + "id_token": SINGLE_OPTIONAL_IDTOKEN, + } + ) def verify(self, **kwargs): super(AuthorizationResponse, self).verify(**kwargs) @@ -377,17 +383,20 @@ def verify(self, **kwargs): class AuthorizationErrorResponse(message.AuthorizationErrorResponse): - c_allowed_values = message.AuthorizationErrorResponse.c_allowed_values \ - .copy() - c_allowed_values["error"].extend(["interaction_required", - "login_required", - "session_selection_required", - "consent_required", - "invalid_request_uri", - "invalid_request_object", - "registration_not_supported", - "request_not_supported", - "request_uri_not_supported"]) + c_allowed_values = message.AuthorizationErrorResponse.c_allowed_values.copy() + c_allowed_values["error"].extend( + [ + "interaction_required", + "login_required", + "session_selection_required", + "consent_required", + "invalid_request_uri", + "invalid_request_object", + "registration_not_supported", + "request_not_supported", + "request_uri_not_supported", + ] + ) class AuthorizationRequest(message.AuthorizationRequest): @@ -413,10 +422,12 @@ class AuthorizationRequest(message.AuthorizationRequest): } ) c_allowed_values = message.AuthorizationRequest.c_allowed_values.copy() - c_allowed_values.update({ - "display": ["page", "popup", "touch", "wap"], - "prompt": ["none", "login", "consent", "select_account"] - }) + c_allowed_values.update( + { + "display": ["page", "popup", "touch", "wap"], + "prompt": ["none", "login", "consent", "select_account"], + } + ) def verify(self, **kwargs): """ @@ -477,55 +488,63 @@ def verify(self, **kwargs): if "prompt" in self: if "none" in self["prompt"] and len(self["prompt"]) > 1: - raise InvalidRequest("prompt none combined with other value", - self) + raise InvalidRequest("prompt none combined with other value", self) return True class AccessTokenRequest(message.AccessTokenRequest): c_param = message.AccessTokenRequest.c_param.copy() - c_param.update({"client_assertion_type": SINGLE_OPTIONAL_STRING, - "client_assertion": SINGLE_OPTIONAL_STRING}) + c_param.update( + { + "client_assertion_type": SINGLE_OPTIONAL_STRING, + "client_assertion": SINGLE_OPTIONAL_STRING, + } + ) c_default = {"grant_type": "authorization_code"} c_allowed_values = { "client_assertion_type": [ - "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"], + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + ] } class AddressClaim(Message): - c_param = {"formatted": SINGLE_OPTIONAL_STRING, - "street_address": SINGLE_OPTIONAL_STRING, - "locality": SINGLE_OPTIONAL_STRING, - "region": SINGLE_OPTIONAL_STRING, - "postal_code": SINGLE_OPTIONAL_STRING, - "country": SINGLE_OPTIONAL_STRING} + c_param = { + "formatted": SINGLE_OPTIONAL_STRING, + "street_address": SINGLE_OPTIONAL_STRING, + "locality": SINGLE_OPTIONAL_STRING, + "region": SINGLE_OPTIONAL_STRING, + "postal_code": SINGLE_OPTIONAL_STRING, + "country": SINGLE_OPTIONAL_STRING, + } class OpenIDSchema(Message): - c_param = {"sub": SINGLE_REQUIRED_STRING, - "name": SINGLE_OPTIONAL_STRING, - "given_name": SINGLE_OPTIONAL_STRING, - "family_name": SINGLE_OPTIONAL_STRING, - "middle_name": SINGLE_OPTIONAL_STRING, - "nickname": SINGLE_OPTIONAL_STRING, - "preferred_username": SINGLE_OPTIONAL_STRING, - "profile": SINGLE_OPTIONAL_STRING, - "picture": SINGLE_OPTIONAL_STRING, - "website": SINGLE_OPTIONAL_STRING, - "email": SINGLE_OPTIONAL_STRING, - "email_verified": SINGLE_OPTIONAL_BOOLEAN, - "gender": SINGLE_OPTIONAL_STRING, - "birthdate": SINGLE_OPTIONAL_STRING, - "zoneinfo": SINGLE_OPTIONAL_STRING, - "locale": SINGLE_OPTIONAL_STRING, - "phone_number": SINGLE_OPTIONAL_STRING, - "phone_number_verified": SINGLE_OPTIONAL_BOOLEAN, - "address": OPTIONAL_ADDRESS, - "updated_at": SINGLE_OPTIONAL_INT, - "_claim_names": OPTIONAL_MESSAGE, - "_claim_sources": OPTIONAL_MESSAGE} + c_param = { + "sub": SINGLE_REQUIRED_STRING, + "name": SINGLE_OPTIONAL_STRING, + "given_name": SINGLE_OPTIONAL_STRING, + "family_name": SINGLE_OPTIONAL_STRING, + "middle_name": SINGLE_OPTIONAL_STRING, + "nickname": SINGLE_OPTIONAL_STRING, + "preferred_username": SINGLE_OPTIONAL_STRING, + "profile": SINGLE_OPTIONAL_STRING, + "picture": SINGLE_OPTIONAL_STRING, + "website": SINGLE_OPTIONAL_STRING, + "email": SINGLE_OPTIONAL_STRING, + "email_verified": SINGLE_OPTIONAL_BOOLEAN, + "gender": SINGLE_OPTIONAL_STRING, + "birthdate": SINGLE_OPTIONAL_STRING, + "zoneinfo": SINGLE_OPTIONAL_STRING, + "locale": SINGLE_OPTIONAL_STRING, + "phone_number": SINGLE_OPTIONAL_STRING, + "phone_number_verified": SINGLE_OPTIONAL_BOOLEAN, + "address": OPTIONAL_ADDRESS, + "updated_at": SINGLE_OPTIONAL_INT, + "_claim_names": OPTIONAL_MESSAGE, + "_claim_sources": OPTIONAL_MESSAGE, + } def verify(self, **kwargs): super(OpenIDSchema, self).verify(**kwargs) @@ -584,18 +603,24 @@ class RegistrationRequest(Message): "post_logout_redirect_uris": OPTIONAL_LIST_OF_STRINGS, } c_default = {"application_type": "web", "response_types": ["code"]} - c_allowed_values = {"application_type": ["native", "web"], - "subject_type": ["public", "pairwise"]} + c_allowed_values = { + "application_type": ["native", "web"], + "subject_type": ["public", "pairwise"], + } def verify(self, **kwargs): super(RegistrationRequest, self).verify(**kwargs) - if "initiate_login_uri" in self and not self["initiate_login_uri"].startswith("https:"): + if "initiate_login_uri" in self and not self["initiate_login_uri"].startswith( + "https:" + ): raise AssertionError() - for param in ["request_object_encryption", - "id_token_encrypted_response", - "userinfo_encrypted_response"]: + for param in [ + "request_object_encryption", + "id_token_encrypted_response", + "userinfo_encrypted_response", + ]: alg_param = "%s_alg" % param enc_param = "%s_enc" % param if alg_param in self: @@ -606,7 +631,10 @@ def verify(self, **kwargs): if enc_param in self and alg_param not in self: raise AssertionError() - if "token_endpoint_auth_signing_alg" in self and self["token_endpoint_auth_signing_alg"] == "none": + if ( + "token_endpoint_auth_signing_alg" in self + and self["token_endpoint_auth_signing_alg"] == "none" + ): raise AssertionError() return True @@ -639,44 +667,53 @@ def verify(self, **kwargs): has_reg_uri = "registration_client_uri" in self has_reg_at = "registration_access_token" in self if has_reg_uri != has_reg_at: - raise VerificationError(( - "Only one of registration_client_uri" - " and registration_access_token present"), self) + raise VerificationError( + ( + "Only one of registration_client_uri" + " and registration_access_token present" + ), + self, + ) return True class ClientRegistrationErrorResponse(message.ErrorResponse): - c_allowed_values = {"error": ["invalid_redirect_uri", - "invalid_client_metadata", - "invalid_configuration_parameter"]} + c_allowed_values = { + "error": [ + "invalid_redirect_uri", + "invalid_client_metadata", + "invalid_configuration_parameter", + ] + } class IdToken(OpenIDSchema): c_param = OpenIDSchema.c_param.copy() - c_param.update({ - "iss": SINGLE_REQUIRED_STRING, - "sub": SINGLE_REQUIRED_STRING, - "aud": REQUIRED_LIST_OF_STRINGS, # Array of strings or string - "exp": SINGLE_REQUIRED_INT, - "iat": SINGLE_REQUIRED_INT, - "auth_time": SINGLE_OPTIONAL_INT, - "nonce": SINGLE_OPTIONAL_STRING, - "at_hash": SINGLE_OPTIONAL_STRING, - "c_hash": SINGLE_OPTIONAL_STRING, - "acr": SINGLE_OPTIONAL_STRING, - "amr": OPTIONAL_LIST_OF_STRINGS, - "azp": SINGLE_OPTIONAL_STRING, - "sub_jwk": SINGLE_OPTIONAL_STRING - }) + c_param.update( + { + "iss": SINGLE_REQUIRED_STRING, + "sub": SINGLE_REQUIRED_STRING, + "aud": REQUIRED_LIST_OF_STRINGS, # Array of strings or string + "exp": SINGLE_REQUIRED_INT, + "iat": SINGLE_REQUIRED_INT, + "auth_time": SINGLE_OPTIONAL_INT, + "nonce": SINGLE_OPTIONAL_STRING, + "at_hash": SINGLE_OPTIONAL_STRING, + "c_hash": SINGLE_OPTIONAL_STRING, + "acr": SINGLE_OPTIONAL_STRING, + "amr": OPTIONAL_LIST_OF_STRINGS, + "azp": SINGLE_OPTIONAL_STRING, + "sub_jwk": SINGLE_OPTIONAL_STRING, + } + ) def verify(self, **kwargs): super(IdToken, self).verify(**kwargs) try: - if kwargs['iss'] != self['iss']: - raise IssuerMismatch( - '{} != {}'.format(kwargs['iss'], self['iss'])) + if kwargs["iss"] != self["iss"]: + raise IssuerMismatch("{} != {}".format(kwargs["iss"], self["iss"])) except KeyError: pass @@ -685,8 +722,9 @@ def verify(self, **kwargs): # check that I'm among the recipients if kwargs["client_id"] not in self["aud"]: raise NotForMe( - "{} not in aud:{}".format(kwargs["client_id"], - self["aud"]), self) + "{} not in aud:{}".format(kwargs["client_id"], self["aud"]), + self, + ) # Then azp has to be present and be one of the aud values if len(self["aud"]) > 1: @@ -699,49 +737,50 @@ def verify(self, **kwargs): if "client_id" in kwargs: if kwargs["client_id"] != self["azp"]: raise NotForMe( - "{} != azp:{}".format(kwargs["client_id"], - self["azp"]), self) + "{} != azp:{}".format(kwargs["client_id"], self["azp"]), self + ) _now = time_util.utc_time_sans_frac() try: - _skew = kwargs['skew'] + _skew = kwargs["skew"] except KeyError: _skew = 0 try: - _exp = self['exp'] + _exp = self["exp"] except KeyError: - raise MissingRequiredAttribute('exp') + raise MissingRequiredAttribute("exp") else: if (_now - _skew) > _exp: - raise EXPError('Invalid expiration time') + raise EXPError("Invalid expiration time") try: - _storage_time = kwargs['nonce_storage_time'] + _storage_time = kwargs["nonce_storage_time"] except KeyError: _storage_time = NONCE_STORAGE_TIME try: - _iat = self['iat'] + _iat = self["iat"] except KeyError: - raise MissingRequiredAttribute('iat') + raise MissingRequiredAttribute("iat") else: if (_iat + _storage_time) < (_now - _skew): - raise IATError('Issued too long ago') + raise IATError("Issued too long ago") return True class RefreshSessionRequest(Message): - c_param = {"id_token": SINGLE_REQUIRED_STRING, - "redirect_url": SINGLE_REQUIRED_STRING, - "state": SINGLE_REQUIRED_STRING} + c_param = { + "id_token": SINGLE_REQUIRED_STRING, + "redirect_url": SINGLE_REQUIRED_STRING, + "state": SINGLE_REQUIRED_STRING, + } class RefreshSessionResponse(Message): - c_param = {"id_token": SINGLE_REQUIRED_STRING, - "state": SINGLE_REQUIRED_STRING} + c_param = {"id_token": SINGLE_REQUIRED_STRING, "state": SINGLE_REQUIRED_STRING} class CheckSessionRequest(Message): @@ -756,7 +795,7 @@ class EndSessionRequest(Message): c_param = { "id_token_hint": SINGLE_OPTIONAL_STRING, "post_logout_redirect_uri": SINGLE_OPTIONAL_STRING, - "state": SINGLE_OPTIONAL_STRING + "state": SINGLE_OPTIONAL_STRING, } @@ -771,7 +810,7 @@ class Claims(Message): class ClaimsRequest(Message): c_param = { "userinfo": OPTIONAL_MULTIPLE_Claims, - "id_token": OPTIONAL_MULTIPLE_Claims + "id_token": OPTIONAL_MULTIPLE_Claims, } @@ -800,13 +839,10 @@ class ProviderConfigurationResponse(Message): "userinfo_encryption_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, "userinfo_encryption_enc_values_supported": OPTIONAL_LIST_OF_STRINGS, "request_object_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, - "request_object_encryption_alg_values_supported": - OPTIONAL_LIST_OF_STRINGS, - "request_object_encryption_enc_values_supported": - OPTIONAL_LIST_OF_STRINGS, + "request_object_encryption_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, + "request_object_encryption_enc_values_supported": OPTIONAL_LIST_OF_STRINGS, "token_endpoint_auth_methods_supported": OPTIONAL_LIST_OF_STRINGS, - "token_endpoint_auth_signing_alg_values_supported": - OPTIONAL_LIST_OF_STRINGS, + "token_endpoint_auth_signing_alg_values_supported": OPTIONAL_LIST_OF_STRINGS, "display_values_supported": OPTIONAL_LIST_OF_STRINGS, "claim_types_supported": OPTIONAL_LIST_OF_STRINGS, "claims_supported": OPTIONAL_LIST_OF_STRINGS, @@ -822,14 +858,15 @@ class ProviderConfigurationResponse(Message): "check_session_iframe": SINGLE_OPTIONAL_STRING, "end_session_endpoint": SINGLE_OPTIONAL_STRING, } - c_default = {"version": "3.0", - "token_endpoint_auth_methods_supported": [ - "client_secret_basic"], - "claims_parameter_supported": False, - "request_parameter_supported": False, - "request_uri_parameter_supported": True, - "require_request_uri_registration": False, - "grant_types_supported": ["authorization_code", "implicit"]} + c_default = { + "version": "3.0", + "token_endpoint_auth_methods_supported": ["client_secret_basic"], + "claims_parameter_supported": False, + "request_parameter_supported": False, + "request_uri_parameter_supported": True, + "require_request_uri_registration": False, + "grant_types_supported": ["authorization_code", "implicit"], + } def verify(self, **kwargs): super(ProviderConfigurationResponse, self).verify(**kwargs) @@ -847,8 +884,10 @@ def verify(self, **kwargs): if parts.query or parts.fragment: raise AssertionError() - if any("code" in rt for rt in self[ - "response_types_supported"]) and "token_endpoint" not in self: + if ( + any("code" in rt for rt in self["response_types_supported"]) + and "token_endpoint" not in self + ): raise MissingRequiredAttribute("token_endpoint") return True @@ -892,13 +931,18 @@ def jwt_deser(val, sformat="json"): class UserInfoErrorResponse(message.ErrorResponse): - c_allowed_values = {"error": ["invalid_schema", "invalid_request", - "invalid_token", "insufficient_scope"]} + c_allowed_values = { + "error": [ + "invalid_schema", + "invalid_request", + "invalid_token", + "insufficient_scope", + ] + } class DiscoveryRequest(Message): - c_param = {"principal": SINGLE_REQUIRED_STRING, - "service": SINGLE_REQUIRED_STRING} + c_param = {"principal": SINGLE_REQUIRED_STRING, "service": SINGLE_REQUIRED_STRING} class DiscoveryResponse(Message): @@ -911,14 +955,26 @@ class ResourceRequest(Message): SCOPE2CLAIMS = { "openid": ["sub"], - "profile": ["name", "given_name", "family_name", "middle_name", - "nickname", "profile", "picture", "website", "gender", - "birthdate", "zoneinfo", "locale", "updated_at", - "preferred_username"], + "profile": [ + "name", + "given_name", + "family_name", + "middle_name", + "nickname", + "profile", + "picture", + "website", + "gender", + "birthdate", + "zoneinfo", + "locale", + "updated_at", + "preferred_username", + ], "email": ["email", "email_verified"], "address": ["address"], "phone": ["phone_number", "phone_number_verified"], - "offline_access": [] + "offline_access": [], } MSG = { @@ -954,7 +1010,9 @@ class ResourceRequest(Message): def factory(msgtype): - warnings.warn('`factory` is deprecated. Use `OIDCMessageFactory` instead.', DeprecationWarning) + warnings.warn( + "`factory` is deprecated. Use `OIDCMessageFactory` instead.", DeprecationWarning + ) for _, obj in inspect.getmembers(sys.modules[__name__]): if inspect.isclass(obj) and issubclass(obj, Message): try: @@ -982,5 +1040,7 @@ class OIDCMessageFactory(MessageFactory): checkid_endpoint = MessageTuple(CheckIDRequest, IdToken) checksession_endpoint = MessageTuple(CheckSessionRequest, IdToken) endsession_endpoint = MessageTuple(EndSessionRequest, EndSessionResponse) - refreshsession_endpoint = MessageTuple(RefreshSessionRequest, RefreshSessionResponse) + refreshsession_endpoint = MessageTuple( + RefreshSessionRequest, RefreshSessionResponse + ) discovery_endpoint = MessageTuple(DiscoveryRequest, DiscoveryResponse) diff --git a/src/oic/oic/provider.py b/src/oic/oic/provider.py index 56e6a3045..c8301c46d 100644 --- a/src/oic/oic/provider.py +++ b/src/oic/oic/provider.py @@ -82,7 +82,7 @@ from oic.utils.template_render import render_template from oic.utils.time_util import utc_time_sans_frac -__author__ = 'rohe0002' +__author__ = "rohe0002" logger = logging.getLogger(__name__) @@ -154,43 +154,56 @@ def construct_uri(item): class AuthorizationEndpoint(Endpoint): etype = "authorization" - url = 'authorization' + url = "authorization" class TokenEndpoint(Endpoint): etype = "token" - url = 'token' + url = "token" class UserinfoEndpoint(Endpoint): etype = "userinfo" - url = 'userinfo' + url = "userinfo" class RegistrationEndpoint(Endpoint): etype = "registration" - url = 'registration' + url = "registration" class EndSessionEndpoint(Endpoint): etype = "end_session" - url = 'end_session' + url = "end_session" RESPONSE_TYPES_SUPPORTED = [ - ["code"], ["token"], ["id_token"], ["code", "token"], ["code", "id_token"], - ["id_token", "token"], ["code", "token", "id_token"], ['none']] + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] CAPABILITIES = { "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "token_endpoint_auth_methods_supported": [ - "client_secret_post", "client_secret_basic", - "client_secret_jwt", "private_key_jwt"], - "response_modes_supported": ['query', 'fragment', 'form_post'], + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "response_modes_supported": ["query", "fragment", "form_post"], "subject_types_supported": ["public", "pairwise"], "grant_types_supported": [ - "authorization_code", "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", "refresh_token"], + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, @@ -199,26 +212,57 @@ class EndSessionEndpoint(Endpoint): class Provider(AProvider): - - def __init__(self, name, sdb, cdb, authn_broker, userinfo, authz, - client_authn, symkey=None, urlmap=None, keyjar=None, - hostname="", template_lookup=None, template=None, - verify_ssl=True, capabilities=None, schema=OpenIDSchema, - jwks_uri='', jwks_name='', baseurl=None, client_cert=None, - extra_claims=None, template_renderer=render_template, extra_scope_dict=None, - message_factory=OIDCMessageFactory): - - AProvider.__init__(self, name, sdb, cdb, authn_broker, authz, - client_authn, symkey, urlmap, - verify_ssl=verify_ssl, client_cert=client_cert, message_factory=message_factory) + def __init__( + self, + name, + sdb, + cdb, + authn_broker, + userinfo, + authz, + client_authn, + symkey=None, + urlmap=None, + keyjar=None, + hostname="", + template_lookup=None, + template=None, + verify_ssl=True, + capabilities=None, + schema=OpenIDSchema, + jwks_uri="", + jwks_name="", + baseurl=None, + client_cert=None, + extra_claims=None, + template_renderer=render_template, + extra_scope_dict=None, + message_factory=OIDCMessageFactory, + ): + + AProvider.__init__( + self, + name, + sdb, + cdb, + authn_broker, + authz, + client_authn, + symkey, + urlmap, + verify_ssl=verify_ssl, + client_cert=client_cert, + message_factory=message_factory, + ) # Should be a OIC Server not an OAuth2 server - self.server = Server(keyjar=keyjar, verify_ssl=verify_ssl, message_factory=message_factory) + self.server = Server( + keyjar=keyjar, verify_ssl=verify_ssl, message_factory=message_factory + ) # Same keyjar self.keyjar = self.server.keyjar - self.endp.extend([UserinfoEndpoint, RegistrationEndpoint, - EndSessionEndpoint]) + self.endp.extend([UserinfoEndpoint, RegistrationEndpoint, EndSessionEndpoint]) self.userinfo = userinfo self.template_renderer = template_renderer @@ -245,10 +289,12 @@ def __init__(self, name, sdb, cdb, authn_broker, userinfo, authz, self.register_endpoint = None for endp in self.endp: - if endp.etype == 'registration': + if endp.etype == "registration": endpoint = urljoin(self.baseurl, endp.url) - warnings.warn("Using `register_endpoint` is deprecated, please use `registration_endpoint` instead.", - DeprecationWarning) + warnings.warn( + "Using `register_endpoint` is deprecated, please use `registration_endpoint` instead.", + DeprecationWarning, + ) self.register_endpoint = endpoint break @@ -279,7 +325,7 @@ def build_jwx_def(self): for _typ in ["signing_alg", "encryption_alg", "encryption_enc"]: self.jwx_def[_typ] = {} for item in ["id_token", "userinfo"]: - cap_param = '{}_{}_values_supported'.format(item, _typ) + cap_param = "{}_{}_values_supported".format(item, _typ) try: self.jwx_def[_typ][item] = self.capabilities[cap_param][0] except KeyError: @@ -329,9 +375,19 @@ def set_mode(self, mode): if val.endswith("encryption_enc_values_supported"): self.capabilities[val] = [_enc_enc] - def id_token_as_signed_jwt(self, session, loa="2", alg="", code=None, - access_token=None, user_info=None, auth_time=0, - exp=None, extra_claims=None, **kwargs): + def id_token_as_signed_jwt( + self, + session, + loa="2", + alg="", + code=None, + access_token=None, + user_info=None, + auth_time=0, + exp=None, + extra_claims=None, + **kwargs + ): if alg == "": alg = self.jwx_def["signing_alg"]["id_token"] @@ -341,15 +397,24 @@ def id_token_as_signed_jwt(self, session, loa="2", alg="", code=None, else: alg = "none" - _idt = self.server.make_id_token(session, loa, self.name, alg, code, - access_token, user_info, auth_time, - exp, extra_claims) + _idt = self.server.make_id_token( + session, + loa, + self.name, + alg, + code, + access_token, + user_info, + auth_time, + exp, + extra_claims, + ) try: - ckey = kwargs['keys'] + ckey = kwargs["keys"] except KeyError: try: - _keyjar = kwargs['keyjar'] + _keyjar = kwargs["keyjar"] except KeyError: _keyjar = self.keyjar @@ -357,15 +422,13 @@ def id_token_as_signed_jwt(self, session, loa="2", alg="", code=None, # My signing key if its RS*, can use client secret if HS* if alg.startswith("HS"): logger.debug("client_id: %s" % session["client_id"]) - ckey = _keyjar.get_signing_key(alg2keytype(alg), - session["client_id"]) + ckey = _keyjar.get_signing_key(alg2keytype(alg), session["client_id"]) if not ckey: # create a new key _secret = self.cdb[session["client_id"]]["client_secret"] ckey = [SYMKey(key=_secret)] else: if "" in self.keyjar: - ckey = _keyjar.get_signing_key(alg2keytype(alg), "", - alg=alg) + ckey = _keyjar.get_signing_key(alg2keytype(alg), "", alg=alg) else: ckey = None @@ -374,8 +437,7 @@ def id_token_as_signed_jwt(self, session, loa="2", alg="", code=None, return _signed_jwt def _parse_openid_request(self, request, **kwargs): - return OpenIDRequest().from_jwt(request, keyjar=self.keyjar, - **kwargs) + return OpenIDRequest().from_jwt(request, keyjar=self.keyjar, **kwargs) def _parse_id_token(self, id_token, redirect_uri): try: @@ -385,8 +447,7 @@ def _parse_id_token(self, id_token, redirect_uri): logger.error("Exception: %s" % (err.__class__.__name__,)) id_token = IdToken().from_jwt(id_token, verify=False) logger.error("IdToken: %s" % id_token.to_dict()) - return redirect_authz_error("invalid_id_token_object", - redirect_uri) + return redirect_authz_error("invalid_id_token_object", redirect_uri) @staticmethod def get_sector_id(redirect_uri, client_info): @@ -499,9 +560,11 @@ def let_user_verify_logout(self, uid, esr, cookie, redirect_uri): "post_logout_redirect_uri": esr["post_logout_redirect_uri"], "key": self.sdb.get_verify_logout(uid), "redirect": redirect, - "action": "/" + EndSessionEndpoint("").etype + "action": "/" + EndSessionEndpoint("").etype, } - return Response(self.template_renderer('verify_logout', context), headers=headers) + return Response( + self.template_renderer("verify_logout", context), headers=headers + ) def _get_sids_from_cookie(self, cookie): """Get cookie_dealer, client_id and sids from cookie.""" @@ -514,16 +577,18 @@ def _get_sids_from_cookie(self, cookie): _cval = cookie_dealer.get_cookie_value(cookie, self.sso_cookie_name) if _cval: (value, _ts, typ) = _cval - if typ == 'sso': + if typ == "sso": uid, client_id = value.split(DELIM) try: sids = self.sdb.uid2sid[uid] except (KeyError, IndexError): - raise SubMismatch('Mismatch uid') + raise SubMismatch("Mismatch uid") return cookie_dealer, client_id, sids def end_session_endpoint(self, request="", cookie=None, **kwargs): - esr = self.server.message_factory.get_request_type('endsession_endpoint')().from_urlencoded(request) + esr = self.server.message_factory.get_request_type( + "endsession_endpoint" + )().from_urlencoded(request) logger.debug("End session request: {}", sanitize(esr.to_dict())) @@ -532,26 +597,26 @@ def end_session_endpoint(self, request="", cookie=None, **kwargs): try: cookie_dealer, client_id, sids = self._get_sids_from_cookie(cookie) except SubMismatch as error: - return error_response('invalid_request', '%s' % error) + return error_response("invalid_request", "%s" % error) if "id_token_hint" in esr: - id_token_hint = IdToken().from_jwt(esr["id_token_hint"], - keyjar=self.keyjar, - verify=True) - far_away = 86400*30 # 30 days + id_token_hint = IdToken().from_jwt( + esr["id_token_hint"], keyjar=self.keyjar, verify=True + ) + far_away = 86400 * 30 # 30 days if client_id: - args = {'client_id': client_id} + args = {"client_id": client_id} else: args = {} try: - id_token_hint.verify(iss=self.baseurl, skew=far_away, - nonce_storage_time=far_away, **args) + id_token_hint.verify( + iss=self.baseurl, skew=far_away, nonce_storage_time=far_away, **args + ) except (VerificationError, NotForMe) as err: - logger.warning( - 'Verification error on id_token_hint: {}'.format(err)) - return error_response('invalid_request', "Bad Id Token hint") + logger.warning("Verification error on id_token_hint: {}".format(err)) + return error_response("invalid_request", "Bad Id Token hint") sub = id_token_hint["sub"] @@ -559,11 +624,11 @@ def end_session_endpoint(self, request="", cookie=None, **kwargs): match = False # verify that 'sub' are bound to 'user' for sid in sids: - if self.sdb[sid]['sub'] == sub: + if self.sdb[sid]["sub"] == sub: match = True break if not match: - return error_response('invalid_request', "Wrong user") + return error_response("invalid_request", "Wrong user") else: try: sids = self.sdb.get_sids_by_sub(sub) @@ -571,36 +636,37 @@ def end_session_endpoint(self, request="", cookie=None, **kwargs): pass if not client_id: - if len(id_token_hint['aud']) == 1: - client_id = id_token_hint['aud'][0] + if len(id_token_hint["aud"]) == 1: + client_id = id_token_hint["aud"][0] else: - client_id = id_token_hint['azp'] + client_id = id_token_hint["azp"] if not client_id: - return error_response('invalid_request', "Could not find client ID") + return error_response("invalid_request", "Could not find client ID") if client_id not in self.cdb: - return error_response('invalid_request', "Unknown client") + return error_response("invalid_request", "Unknown client") match = False for sid in sids: - if self.sdb[sid]['client_id'] == client_id: + if self.sdb[sid]["client_id"] == client_id: match = True break if not match: - return error_response('invalid_request', "Unmatched client") + return error_response("invalid_request", "Unmatched client") redirect_uri = None if "post_logout_redirect_uri" in esr: redirect_uri = self.verify_post_logout_redirect_uri(esr, client_id) if not redirect_uri: msg = "Post logout redirect URI verification failed!" - return error_response('invalid_request', msg) + return error_response("invalid_request", msg) else: # If only one registered use that one if len(self.cdb[client_id]["post_logout_redirect_uris"]) == 1: _base, _query = self.cdb[client_id]["post_logout_redirect_uris"][0] if _query: query_string = urlencode( - [(key, v) for key in _query for v in _query[key]]) + [(key, v) for key in _query for v in _query[key]] + ) redirect_uri = "%s?%s" % (_base, query_string) else: redirect_uri = _base @@ -616,16 +682,16 @@ def end_session_endpoint(self, request="", cookie=None, **kwargs): if redirect_uri is not None: try: - _state = esr['state'] + _state = esr["state"] except KeyError: redirect_uri = str(redirect_uri) else: - if '?' in redirect_uri: + if "?" in redirect_uri: redirect_uri += "&" else: redirect_uri += "?" - redirect_uri += urlencode({'state': _state}) + redirect_uri += urlencode({"state": _state}) return SeeOther(redirect_uri, headers=headers) @@ -648,7 +714,7 @@ def verify_endpoint(self, request="", cookie=None, **kwargs): _req = compact(parse_qs(request)) try: - areq = Message().from_urlencoded(_req['query']) + areq = Message().from_urlencoded(_req["query"]) except KeyError: areq = _req @@ -676,7 +742,7 @@ def setup_session(self, areq, authn_event, cinfo): except KeyError: pass - self.sdb.do_sub(sid, cinfo['client_salt'], **kwargs) + self.sdb.do_sub(sid, cinfo["client_salt"], **kwargs) return sid def match_sp_sep(self, first, second): @@ -691,38 +757,39 @@ def filter_request(self, req): before = req.to_dict() - if 'claims' in req: - if _cap['claims_parameter_supported']: - if _cap['claims_supported']: - for part in ['userinfo', 'id_token']: - if part in req['claims']: - _keys = list(req['claims'][part].keys()) + if "claims" in req: + if _cap["claims_parameter_supported"]: + if _cap["claims_supported"]: + for part in ["userinfo", "id_token"]: + if part in req["claims"]: + _keys = list(req["claims"][part].keys()) for c in _keys: - if c not in _cap['claims_supported']: - del req['claims'][part][c] + if c not in _cap["claims_supported"]: + del req["claims"][part][c] else: - del req['claims'] + del req["claims"] - if 'scope' in req: - _scopes = [s for s in req['scope'] if s in _cap['scopes_supported']] - req['scope'] = _scopes + if "scope" in req: + _scopes = [s for s in req["scope"] if s in _cap["scopes_supported"]] + req["scope"] = _scopes - if 'request' in req: - if _cap['request_parameter_supported'] is False: - raise InvalidRequest('Contains unsupported request parameter') + if "request" in req: + if _cap["request_parameter_supported"] is False: + raise InvalidRequest("Contains unsupported request parameter") - if 'request_uri' in req: - if _cap['request_uri_parameter_supported'] is False: - raise InvalidRequest('Contains unsupported request parameter') + if "request_uri" in req: + if _cap["request_uri_parameter_supported"] is False: + raise InvalidRequest("Contains unsupported request parameter") - if 'response_mode' in req: - if req['response_mode'] not in _cap['response_modes_supported']: - raise InvalidRequest('Contains unsupported response mode') + if "response_mode" in req: + if req["response_mode"] not in _cap["response_modes_supported"]: + raise InvalidRequest("Contains unsupported response mode") - if 'response_type' in req: - if not self.match_sp_sep([" ".join(req['response_type'])], - _cap['response_types_supported']): - raise InvalidRequest('Contains unsupported response type') + if "response_type" in req: + if not self.match_sp_sep( + [" ".join(req["response_type"])], _cap["response_types_supported"] + ): + raise InvalidRequest("Contains unsupported response type") if before != req.to_dict(): msg = "Request modified from %s to %s" @@ -733,12 +800,15 @@ def filter_request(self, req): def auth_init(self, request, request_class=None): """Overriden since the filter_request can throw an InvalidRequest.""" if request_class is not None: - warnings.warn('Passing `request_class` is deprecated. Please use `message_factory` instead.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `request_class` is deprecated. Please use `message_factory` instead.", + DeprecationWarning, + stacklevel=2, + ) try: return super().auth_init(request, request_class) except InvalidRequest as err: - return error_response('invalid_request', '%s' % err) + return error_response("invalid_request", "%s" % err) def authorization_endpoint(self, request="", cookie=None, **kwargs): """ @@ -750,7 +820,7 @@ def authorization_endpoint(self, request="", cookie=None, **kwargs): if isinstance(info, Response): return info - areq = info['areq'] + areq = info["areq"] logger.info("authorization_request: %s" % (sanitize(areq.to_dict()),)) _cid = areq["client_id"] @@ -767,13 +837,13 @@ def authorization_endpoint(self, request="", cookie=None, **kwargs): # Is the authentication event to be regarded as valid ? if authn_event.valid(): sid = self.setup_session(areq, authn_event, cinfo) - return self.authz_part2(authn_event.uid, areq, sid, - cookie=cookie) + return self.authz_part2(authn_event.uid, areq, sid, cookie=cookie) kwargs["req_user"] = req_user - authnres = self.do_auth(info["areq"], info["redirect_uri"], - cinfo, request, cookie, **kwargs) + authnres = self.do_auth( + info["areq"], info["redirect_uri"], cinfo, request, cookie, **kwargs + ) if isinstance(authnres, Response): return authnres @@ -802,13 +872,13 @@ def authz_part2(self, user, areq, sid, **kwargs): # as per the mix-up draft don't add iss and client_id if they are # already in the id_token. - if 'id_token' not in aresp: - aresp['iss'] = self.name + if "id_token" not in aresp: + aresp["iss"] = self.name - aresp['client_id'] = areq['client_id'] + aresp["client_id"] = areq["client_id"] if self.events: - self.events.store('Protocol response', aresp) + self.events.store("Protocol response", aresp) response = sanitize(aresp.to_dict()) logger.info("authorization response: %s", response) @@ -844,14 +914,16 @@ def recuperate_keys(self, cid: str, client_info: Dict[str, str]) -> None: self.keyjar.issuer_keys[cid] = [] # Add client secret as a symmetric key - self.keyjar.add_symmetric(cid, client_info['client_secret'], usage=['enc', 'sig']) + self.keyjar.add_symmetric( + cid, client_info["client_secret"], usage=["enc", "sig"] + ) # Try to renew from jwks or jwks_uri - if client_info.get('jwks_uri') is not None: + if client_info.get("jwks_uri") is not None: self.keyjar.add(cid, client_info["jwks_uri"]) - elif client_info.get('jwks') is not None: - self.keyjar.import_jwks(client_info['jwks'], cid) + elif client_info.get("jwks") is not None: + self.keyjar.import_jwks(client_info["jwks"], cid) else: - logger.warning('No keys to recover.') + logger.warning("No keys to recover.") def encrypt(self, payload, client_info, cid, val_type="id_token", cty=""): """ @@ -867,7 +939,7 @@ def encrypt(self, payload, client_info, cid, val_type="id_token", cty=""): try: alg = client_info["%s_encrypted_response_alg" % val_type] except KeyError: - logger.warning('%s NOT defined means no encryption', val_type) + logger.warning("%s NOT defined means no encryption", val_type) return payload else: try: @@ -875,7 +947,7 @@ def encrypt(self, payload, client_info, cid, val_type="id_token", cty=""): except KeyError as err: # if not defined-> A128CBC-HS256 (default) logger.warning("undefined parameter: %s", err) logger.info("using default") - enc = 'A128CBC-HS256' + enc = "A128CBC-HS256" logger.debug("alg=%s, enc=%s, val_type=%s" % (alg, enc, val_type)) if cid not in self.keyjar: @@ -889,8 +961,9 @@ def encrypt(self, payload, client_info, cid, val_type="id_token", cty=""): _jwe = JWE(payload, **kwargs) return _jwe.encrypt(keys, context="public") - def sign_encrypt_id_token(self, sinfo, client_info, areq, code=None, - access_token=None, user_info=None): + def sign_encrypt_id_token( + self, sinfo, client_info, areq, code=None, access_token=None, user_info=None + ): """ Sign and or encrypt a IDToken. @@ -915,14 +988,20 @@ def sign_encrypt_id_token(self, sinfo, client_info, areq, code=None, _authn_event = AuthnEvent.from_json(sinfo["authn_event"]) id_token = self.id_token_as_signed_jwt( - sinfo, loa=_authn_event.authn_info, alg=alg, code=code, - access_token=access_token, user_info=user_info, - auth_time=_authn_event.authn_time) + sinfo, + loa=_authn_event.authn_info, + alg=alg, + code=code, + access_token=access_token, + user_info=user_info, + auth_time=_authn_event.authn_time, + ) # Then encrypt if "id_token_encrypted_response_alg" in client_info: - id_token = self.encrypt(id_token, client_info, areq["client_id"], - "id_token", "JWT") + id_token = self.encrypt( + id_token, client_info, areq["client_id"], "id_token", "JWT" + ) return id_token @@ -938,9 +1017,9 @@ def code_grant_type(self, areq): client_info = self.cdb[str(areq["client_id"])] try: - _access_code = areq["code"].replace(' ', '+') + _access_code = areq["code"].replace(" ", "+") except KeyError: # Missing code parameter - absolutely fatal - return error_response('invalid_request', descr='Missing code') + return error_response("invalid_request", descr="Missing code") # assert that the code is valid if self.sdb.is_revoked(_access_code): @@ -954,14 +1033,14 @@ def code_grant_type(self, areq): # If redirect_uri was in the initial authorization request verify that it is here as well # Mismatch would raise in oic.oauth2.provider.Provider.token_endpoint - if "redirect_uri" in _info and 'redirect_uri' not in areq: - return error_response('invalid_request', descr='Missing redirect_uri') + if "redirect_uri" in _info and "redirect_uri" not in areq: + return error_response("invalid_request", descr="Missing redirect_uri") _log_debug("All checks OK") issue_refresh = False - permissions = _info.get('permission', ['offline_access']) or ['offline_access'] - if 'offline_access' in _info['scope'] and 'offline_access' in permissions: + permissions = _info.get("permission", ["offline_access"]) or ["offline_access"] + if "offline_access" in _info["scope"] and "offline_access" in permissions: issue_refresh = True try: @@ -975,10 +1054,14 @@ def code_grant_type(self, areq): if "openid" in _info["scope"]: userinfo = self.userinfo_in_id_token_claims(_info) try: - _idtoken = self.sign_encrypt_id_token(_info, client_info, areq, user_info=userinfo) + _idtoken = self.sign_encrypt_id_token( + _info, client_info, areq, user_info=userinfo + ) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) - return error_response("invalid_request", descr="Could not sign/encrypt id_token") + return error_response( + "invalid_request", descr="Could not sign/encrypt id_token" + ) _sdb.update_by_token(_access_code, "id_token", _idtoken) @@ -987,12 +1070,14 @@ def code_grant_type(self, areq): _log_debug("_tinfo: %s" % sanitize(_tinfo)) - response_cls = self.server.message_factory.get_response_type('token_endpoint') + response_cls = self.server.message_factory.get_response_type("token_endpoint") atr = response_cls(**by_schema(response_cls, **_tinfo)) logger.info("access_token_response: %s" % sanitize(atr.to_dict())) - return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) + return Response( + atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS + ) def refresh_token_grant_type(self, areq): """ @@ -1003,7 +1088,7 @@ def refresh_token_grant_type(self, areq): _sdb = self.sdb _log_debug = logger.debug - client_id = str(areq['client_id']) + client_id = str(areq["client_id"]) client_info = self.cdb[client_id] rtoken = areq["refresh_token"] @@ -1017,22 +1102,28 @@ def refresh_token_grant_type(self, areq): if "openid" in _info["scope"] and "authn_event" in _info: userinfo = self.userinfo_in_id_token_claims(_info) try: - _idtoken = self.sign_encrypt_id_token(_info, client_info, areq, user_info=userinfo) + _idtoken = self.sign_encrypt_id_token( + _info, client_info, areq, user_info=userinfo + ) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) - return error_response("invalid_request", descr="Could not sign/encrypt id_token") + return error_response( + "invalid_request", descr="Could not sign/encrypt id_token" + ) - sid = _sdb.access_token.get_key(_info['access_token']) + sid = _sdb.access_token.get_key(_info["access_token"]) _sdb.update(sid, "id_token", _idtoken) _log_debug("_info: %s" % sanitize(_info)) - response_cls = self.server.message_factory.get_response_type('token_endpoint') + response_cls = self.server.message_factory.get_response_type("token_endpoint") atr = response_cls(**by_schema(response_cls, **_info)) logger.info("access_token_response: %s" % sanitize(atr.to_dict())) - return Response(atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS) + return Response( + atr.to_json(), content="application/json", headers=OAUTH2_NOCACHE_HEADERS + ) def client_credentials_grant_type(self, areq): """ @@ -1041,7 +1132,7 @@ def client_credentials_grant_type(self, areq): RFC6749 section 4.4 """ # Not supported in OpenID Connect - return error_response('unsupported_grant_type', descr='Unsupported grant_type') + return error_response("unsupported_grant_type", descr="Unsupported grant_type") def password_grant_type(self, areq): """ @@ -1050,7 +1141,7 @@ def password_grant_type(self, areq): RFC6749 section 4.3 """ # Not supported in OpenID Connect - return error_response('unsupported_grant_type', descr='Unsupported grant_type') + return error_response("unsupported_grant_type", descr="Unsupported grant_type") def _collect_user_info(self, session, userinfo_claims=None): """ @@ -1068,32 +1159,29 @@ def _collect_user_info(self, session, userinfo_claims=None): # Get only keys allowed by user and update the dict if such info # is stored in session - perm_set = session.get('permission') + perm_set = session.get("permission") if perm_set: uic = {key: uic[key] for key in uic if key in perm_set} if "oidreq" in session: - uic = self.server.update_claims(session, "oidreq", "userinfo", - uic) + uic = self.server.update_claims(session, "oidreq", "userinfo", uic) else: - uic = self.server.update_claims(session, "authzreq", "userinfo", - uic) + uic = self.server.update_claims(session, "authzreq", "userinfo", uic) if uic: userinfo_claims = Claims(**uic) else: userinfo_claims = None - logger.debug( - "userinfo_claim: %s" % sanitize(userinfo_claims.to_dict())) + logger.debug("userinfo_claim: %s" % sanitize(userinfo_claims.to_dict())) logger.debug("Session info: %s" % sanitize(session)) if "authn_event" in session: uid = AuthnEvent.from_json(session["authn_event"]).uid else: - uid = session['uid'] + uid = session["uid"] - info = self.userinfo(uid, session['client_id'], userinfo_claims) + info = self.userinfo(uid, session["client_id"], userinfo_claims) if "sub" in userinfo_claims: if not claims_match(session["sub"], userinfo_claims["sub"]): @@ -1104,8 +1192,7 @@ def _collect_user_info(self, session, userinfo_claims=None): logger.debug("user_info_response: {}".format(info)) except UnicodeEncodeError: try: - logger.debug( - "user_info_response: {}".format(info.encode('utf-8'))) + logger.debug("user_info_response: {}".format(info.encode("utf-8"))) except Exception: pass @@ -1129,21 +1216,21 @@ def signed_userinfo(self, client_info, userinfo, session): key = [] else: if algo.startswith("HS"): - key = self.keyjar.get_signing_key(alg2keytype(algo), - client_info["client_id"], - alg=algo) + key = self.keyjar.get_signing_key( + alg2keytype(algo), client_info["client_id"], alg=algo + ) else: # Use my key for signing - key = self.keyjar.get_signing_key(alg2keytype(algo), "", - alg=algo) + key = self.keyjar.get_signing_key(alg2keytype(algo), "", alg=algo) if not key: return error_response("invalid_request", descr="Missing signing key") jinfo = userinfo.to_jwt(key, algo) if "userinfo_encrypted_response_alg" in client_info: # encrypt with clients public key - jinfo = self.encrypt(jinfo, client_info, session["client_id"], - "userinfo", "JWT") + jinfo = self.encrypt( + jinfo, client_info, session["client_id"], "userinfo", "JWT" + ) return jinfo def userinfo_endpoint(self, request="", **kwargs): @@ -1152,29 +1239,28 @@ def userinfo_endpoint(self, request="", **kwargs): :param request: The request in a string format or as a dictionary """ - logger.debug('userinfo_endpoint: request={}, kwargs={}'.format( - request, kwargs)) + logger.debug("userinfo_endpoint: request={}, kwargs={}".format(request, kwargs)) try: _token = self._parse_access_token(request, **kwargs) except ParameterError: - return error_response('invalid_request', descr='Token is malformed') + return error_response("invalid_request", descr="Token is malformed") return self._do_user_info(_token, **kwargs) def _parse_access_token(self, request, **kwargs): if not request or "access_token" not in request: - _token = kwargs.get("authn", '') or '' + _token = kwargs.get("authn", "") or "" if not _token.startswith("Bearer "): raise ParameterError("Token is missing or malformed") - _token = _token[len("Bearer "):] + _token = _token[len("Bearer ") :] logger.debug("Bearer token {} chars".format(len(_token))) else: - args = {'data': request} + args = {"data": request} if isinstance(request, dict): - args['sformat'] = 'dict' + args["sformat"] = "dict" uireq = self.server.parse_user_info_request(**args) logger.debug("user_info_request: %s" % sanitize(uireq)) - _token = uireq["access_token"].replace(' ', '+') + _token = uireq["access_token"].replace(" ", "+") return _token @@ -1189,19 +1275,25 @@ def _do_user_info(self, token, **kwargs): try: typ, key = _sdb.access_token.type_and_key(token) except Exception: - return error_response("invalid_token", descr="Invalid Token", status_code=401) + return error_response( + "invalid_token", descr="Invalid Token", status_code=401 + ) _log_debug("access_token type: '%s'" % (typ,)) if typ != "T": - logger.error('Wrong token type: {}'.format(typ)) + logger.error("Wrong token type: {}".format(typ)) raise FailedAuthentication("Wrong type of token") if _sdb.access_token.is_expired(token): - return error_response('invalid_token', descr='Token is expired', status_code=401) + return error_response( + "invalid_token", descr="Token is expired", status_code=401 + ) if _sdb.is_revoked(key): - return error_response("invalid_token", descr="Token is revoked", status_code=401) + return error_response( + "invalid_token", descr="Token is revoked", status_code=401 + ) session = _sdb[key] # Scope can translate to userinfo_claims @@ -1219,14 +1311,18 @@ def _do_user_info(self, token, **kwargs): content_type = "application/jwt" elif "userinfo_encrypted_response_alg" in _cinfo: jinfo = info.to_json() - jinfo = self.encrypt(jinfo, _cinfo, session["client_id"], - "userinfo", "") + jinfo = self.encrypt( + jinfo, _cinfo, session["client_id"], "userinfo", "" + ) content_type = "application/jwt" else: jinfo = info.to_json() content_type = "application/json" except NotSupportedAlgorithm as err: - return error_response("invalid_request", descr="Not supported algorithm: {}".format(err.args[0])) + return error_response( + "invalid_request", + descr="Not supported algorithm: {}".format(err.args[0]), + ) except JWEException: return error_response("invalid_request", descr="Could not encrypt") @@ -1266,8 +1362,7 @@ def match_client_request(self, request): for _pref, _prov in PREFERENCE2PROVIDER.items(): if _pref in request: if _pref == "response_types": - if not self.match_sp_sep(request[_pref], - self.capabilities[_prov]): + if not self.match_sp_sep(request[_pref], self.capabilities[_prov]): raise CapabilitiesMisMatch(_pref) else: if isinstance(request[_pref], str): @@ -1275,7 +1370,8 @@ def match_client_request(self, request): raise CapabilitiesMisMatch(_pref) else: if not set(request[_pref]).issubset( - set(self.capabilities[_prov])): + set(self.capabilities[_prov]) + ): raise CapabilitiesMisMatch(_pref) def do_client_registration(self, request, client_id, ignore=None): @@ -1296,11 +1392,14 @@ def do_client_registration(self, request, client_id, ignore=None): err = ClientRegistrationErrorResponse( error="invalid_configuration_parameter", error_description="post_logout_redirect_uris " - "contains " - "fragment") - return Response(err.to_json(), - content="application/json", - status="400 Bad Request") + "contains " + "fragment", + ) + return Response( + err.to_json(), + content="application/json", + status="400 Bad Request", + ) base, query = splitquery(uri) if query: plruri.append((base, parse_qs(query))) @@ -1314,13 +1413,17 @@ def do_client_registration(self, request, client_id, ignore=None): _cinfo["redirect_uris"] = ruri except InvalidRedirectURIError as e: err = ClientRegistrationErrorResponse( - error="invalid_redirect_uri", - error_description=str(e)) - return Response(err.to_json(), content="application/json", status_code=400) + error="invalid_redirect_uri", error_description=str(e) + ) + return Response( + err.to_json(), content="application/json", status_code=400 + ) if "sector_identifier_uri" in request: try: - _cinfo["si_redirects"], _cinfo["sector_id"] = self._verify_sector_identifier(request) + _cinfo["si_redirects"], _cinfo[ + "sector_id" + ] = self._verify_sector_identifier(request) except InvalidSectorIdentifier as err: return error_response("invalid_configuration_parameter", descr=str(err)) elif "redirect_uris" in request: @@ -1334,8 +1437,10 @@ def do_client_registration(self, request, client_id, ignore=None): host = _host else: if host != _host: - return error_response("invalid_configuration_parameter", - descr="'sector_identifier_uri' must be registered") + return error_response( + "invalid_configuration_parameter", + descr="'sector_identifier_uri' must be registered", + ) for item in ["policy_uri", "logo_uri", "tos_uri"]: if item in request: @@ -1344,11 +1449,11 @@ def do_client_registration(self, request, client_id, ignore=None): else: return error_response( "invalid_configuration_parameter", - descr="%s pointed to illegal URL" % item) + descr="%s pointed to illegal URL" % item, + ) # Do I have the necessary keys - for item in ["id_token_signed_response_alg", - "userinfo_signed_response_alg"]: + for item in ["id_token_signed_response_alg", "userinfo_signed_response_alg"]: if item in request: if request[item] in self.capabilities[PREFERENCE2PROVIDER[item]]: ktyp = jws.alg2keytype(request[item]) @@ -1367,15 +1472,15 @@ def do_client_registration(self, request, client_id, ignore=None): except KeyError: pass except Exception as err: - logger.error( - "Failed to load client keys: %s" % sanitize(request.to_dict())) + logger.error("Failed to load client keys: %s" % sanitize(request.to_dict())) logger.error("%s", err) - logger.debug('Verify SSL: {}'.format(self.keyjar.verify_ssl)) + logger.debug("Verify SSL: {}".format(self.keyjar.verify_ssl)) err = ClientRegistrationErrorResponse( - error="invalid_configuration_parameter", - error_description="%s" % err) - return Response(err.to_json(), content="application/json", - status="400 Bad Request") + error="invalid_configuration_parameter", error_description="%s" % err + ) + return Response( + err.to_json(), content="application/json", status="400 Bad Request" + ) return _cinfo @@ -1401,14 +1506,19 @@ def verify_redirect_uris(registration_request): for uri in registration_request["redirect_uris"]: p = urlparse(uri) if client_type == "native": - if p.scheme not in ['http', 'https']: # Custom scheme + if p.scheme not in ["http", "https"]: # Custom scheme pass - elif p.scheme == "http" and p.hostname in ["localhost", - "127.0.0.1"]: + elif p.scheme == "http" and p.hostname in ["localhost", "127.0.0.1"]: pass else: - logger.error("InvalidRedirectURI: scheme:%s, hostname:%s", p.scheme, p.hostname) - raise InvalidRedirectURIError("Redirect_uri must use custom scheme or http and localhost") + logger.error( + "InvalidRedirectURI: scheme:%s, hostname:%s", + p.scheme, + p.hostname, + ) + raise InvalidRedirectURIError( + "Redirect_uri must use custom scheme or http and localhost" + ) elif must_https and p.scheme != "https": raise InvalidRedirectURIError("None https redirect_uri not allowed") elif p.fragment: @@ -1445,13 +1555,17 @@ def _verify_sector_identifier(self, request): try: si_redirects = json.loads(res.text) except ValueError: - raise InvalidSectorIdentifier("Error deserializing sector_identifier_uri content") + raise InvalidSectorIdentifier( + "Error deserializing sector_identifier_uri content" + ) if "redirect_uris" in request: logger.debug("redirect_uris: %s", request["redirect_uris"]) for uri in request["redirect_uris"]: if uri not in si_redirects: - raise InvalidSectorIdentifier("redirect_uri missing from sector_identifiers") + raise InvalidSectorIdentifier( + "redirect_uri missing from sector_identifiers" + ) return si_redirects, si_url @@ -1465,8 +1579,8 @@ def comb_uri(args): for base, query_dict in args[param]: if query_dict: query_string = urlencode( - [(key, v) for key in query_dict for v in - query_dict[key]]) + [(key, v) for key in query_dict for v in query_dict[key]] + ) val.append("%s?%s" % (base, query_string)) else: val.append(base) @@ -1476,7 +1590,9 @@ def comb_uri(args): def create_registration(self, authn=None, request=None, **kwargs): logger.debug("@registration_endpoint: <<%s>>" % sanitize(request)) - request_cls = self.server.message_factory.get_request_type('registration_endpoint') + request_cls = self.server.message_factory.get_request_type( + "registration_endpoint" + ) try: request = request_cls().deserialize(request, "json") except ValueError: @@ -1488,8 +1604,11 @@ def create_registration(self, authn=None, request=None, **kwargs): if isinstance(result, Response): return result - return Created(result.to_json(), content="application/json", - headers=[("Cache-Control", "no-store")]) + return Created( + result.to_json(), + content="application/json", + headers=[("Cache-Control", "no-store")], + ) @staticmethod def client_secret_expiration_time(): @@ -1507,13 +1626,17 @@ def client_registration_setup(self, request): if "type" not in request: return error_response("invalid_type", descr="%s" % err) else: - return error_response("invalid_configuration_parameter", descr="%s" % err) + return error_response( + "invalid_configuration_parameter", descr="%s" % err + ) request.rm_blanks() try: self.match_client_request(request) except CapabilitiesMisMatch as err: - return error_response("invalid_request", descr="Don't support proposed %s" % err) + return error_response( + "invalid_request", descr="Don't support proposed %s" % err + ) # create new id och secret client_id = rndstr(12) @@ -1525,7 +1648,7 @@ def client_registration_setup(self, request): _rat = rndstr(32) reg_enp = "" for endp in self.endp: - if endp.etype == 'registration': + if endp.etype == "registration": reg_enp = urljoin(self.baseurl, endp.url) break @@ -1536,19 +1659,21 @@ def client_registration_setup(self, request): "registration_client_uri": "%s?client_id=%s" % (reg_enp, client_id), "client_secret_expires_at": self.client_secret_expiration_time(), "client_id_issued_at": utc_time_sans_frac(), - "client_salt": rndstr(8) + "client_salt": rndstr(8), } - _cinfo = self.do_client_registration(request, client_id, - ignore=["redirect_uris", - "policy_uri", "logo_uri", - "tos_uri"]) + _cinfo = self.do_client_registration( + request, + client_id, + ignore=["redirect_uris", "policy_uri", "logo_uri", "tos_uri"], + ) if isinstance(_cinfo, Response): return _cinfo - response_cls = self.server.message_factory.get_response_type('registration_endpoint') - args = dict([(k, v) for k, v in _cinfo.items() - if k in response_cls.c_param]) + response_cls = self.server.message_factory.get_response_type( + "registration_endpoint" + ) + args = dict([(k, v) for k, v in _cinfo.items() if k in response_cls.c_param]) self.comb_uri(args) response = response_cls(**args) @@ -1568,16 +1693,16 @@ def client_registration_setup(self, request): return response - def registration_endpoint(self, request, authn=None, method='POST', **kwargs): - if method.lower() == 'post': + def registration_endpoint(self, request, authn=None, method="POST", **kwargs): + if method.lower() == "post": return self.create_registration(authn, request, **kwargs) - elif method.lower() == 'get': + elif method.lower() == "get": return self.read_registration(authn, request, **kwargs) - elif method.lower() == 'put': + elif method.lower() == "put": return self.alter_registration(authn, request, **kwargs) - elif method.lower() == 'delete': + elif method.lower() == "delete": return self.delete_registration(authn, request, **kwargs) - return error_response('Unsupported method', descr='Unsupported HTTP method') + return error_response("Unsupported method", descr="Unsupported HTTP method") def read_registration(self, authn, request, **kwargs): """ @@ -1591,15 +1716,13 @@ def read_registration(self, authn, request, **kwargs): :param kwargs: Any other arguments :return: """ - logger.debug("authn: %s, request: %s" % (sanitize(authn), - sanitize(request)) - ) + logger.debug("authn: %s, request: %s" % (sanitize(authn), sanitize(request))) # verify the access token, has to be key into the client information # database. if not authn.startswith("Bearer "): - return error_response('invalid_request') - token = authn[len("Bearer "):] + return error_response("invalid_request") + token = authn[len("Bearer ") :] # Get client_id from request _info = parse_qs(request) @@ -1608,20 +1731,30 @@ def read_registration(self, authn, request, **kwargs): cdb_entry = self.cdb.get(client_id) if cdb_entry is None: return Unauthorized() - reg_token = cdb_entry.get('registration_access_token', '') + reg_token = cdb_entry.get("registration_access_token", "") if not safe_str_cmp(reg_token, token): return Unauthorized() logger.debug("Client '%s' reads client info" % client_id) - response_cls = self.server.message_factory.get_response_type('registration_endpoint') - args = dict([(k, v) for k, v in self.cdb[client_id].items() - if k in response_cls.c_param]) + response_cls = self.server.message_factory.get_response_type( + "registration_endpoint" + ) + args = dict( + [ + (k, v) + for k, v in self.cdb[client_id].items() + if k in response_cls.c_param + ] + ) self.comb_uri(args) response = response_cls(**args) - return Response(response.to_json(), content="application/json", - headers=[("Cache-Control", "no-store")]) + return Response( + response.to_json(), + content="application/json", + headers=[("Cache-Control", "no-store")], + ) def alter_registration(self, authn, request, **kwargs): """ @@ -1631,8 +1764,11 @@ def alter_registration(self, authn, request, **kwargs): :param request: Query part of the request :return: Response with updated client info """ - return error_response('Unsupported operation', descr='Altering of the registration is not supported', - status_code=403) + return error_response( + "Unsupported operation", + descr="Altering of the registration is not supported", + status_code=403, + ) def delete_registration(self, authn, request, **kwargs): """ @@ -1642,11 +1778,13 @@ def delete_registration(self, authn, request, **kwargs): :param request: Query part of the request :return: Response with updated client info """ - return error_response('Unsupported operation', descr='Deletion of the registration is not supported', - status_code=403) + return error_response( + "Unsupported operation", + descr="Deletion of the registration is not supported", + status_code=403, + ) - def create_providerinfo(self, pcr_class=None, - setup=None): + def create_providerinfo(self, pcr_class=None, setup=None): """ Dynamically create the provider info response. @@ -1655,21 +1793,28 @@ def create_providerinfo(self, pcr_class=None, :return: """ if pcr_class is not None: - warnings.warn('Passing `pcr_class` is deprecated. Please use `message_factory.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `pcr_class` is deprecated. Please use `message_factory.", + DeprecationWarning, + stacklevel=2, + ) else: - pcr_class = self.server.message_factory.get_response_type('configuration_endpoint') + pcr_class = self.server.message_factory.get_response_type( + "configuration_endpoint" + ) _provider_info = copy.deepcopy(self.capabilities.to_dict()) if self.jwks_uri and self.keyjar: _provider_info["jwks_uri"] = self.jwks_uri for endp in self.endp: - if not self.baseurl.endswith('/'): - baseurl = self.baseurl + '/' + if not self.baseurl.endswith("/"): + baseurl = self.baseurl + "/" else: baseurl = self.baseurl - _provider_info['{}_endpoint'.format(endp.etype)] = urljoin(baseurl, endp.url) + _provider_info["{}_endpoint".format(endp.etype)] = urljoin( + baseurl, endp.url + ) if setup and isinstance(setup, dict): for key in pcr_class.c_param.keys(): @@ -1689,10 +1834,15 @@ def provider_features(self, pcr_class=None): :return: ProviderConfigurationResponse instance """ if pcr_class is not None: - warnings.warn('Passing `pcr_class` is deprecated. Please use `message_factory.', - DeprecationWarning, stacklevel=2) + warnings.warn( + "Passing `pcr_class` is deprecated. Please use `message_factory.", + DeprecationWarning, + stacklevel=2, + ) else: - pcr_class = self.server.message_factory.get_response_type('configuration_endpoint') + pcr_class = self.server.message_factory.get_response_type( + "configuration_endpoint" + ) _provider_info = pcr_class(**CAPABILITIES) # Parse scopes @@ -1723,9 +1873,8 @@ def provider_features(self, pcr_class=None): # Remove 'none' for token_endpoint_auth_signing_alg_values_supported # since it is not allowed sign_algs = sign_algs[:] - sign_algs.remove('none') - _provider_info[ - "token_endpoint_auth_signing_alg_values_supported"] = sign_algs + sign_algs.remove("none") + _provider_info["token_endpoint_auth_signing_alg_values_supported"] = sign_algs algs = jwe.SUPPORTED["alg"] for typ in ["userinfo", "id_token", "request_object"]: @@ -1762,10 +1911,10 @@ def verify_capabilities(self, capabilities): else: not_supported[key] = val except KeyError: - not_supported[key] = '' + not_supported[key] = "" elif isinstance(val, bool): if not _pinfo[key] and val: - not_supported[key] = '' + not_supported[key] = "" elif isinstance(val, list): for v in val: try: @@ -1777,12 +1926,14 @@ def verify_capabilities(self, capabilities): except KeyError: not_supported[key] = [v] except KeyError: - not_supported[key] = '' + not_supported[key] = "" if not_supported: logger.error( "Server doesn't support the following features: {}".format( - not_supported)) + not_supported + ) + ) return False return True @@ -1796,22 +1947,24 @@ def providerinfo_endpoint(self, handle="", **kwargs): msg = "provider_info_response: {}" _log_info(msg.format(sanitize(_response.to_dict()))) if self.events: - self.events.store('Protocol response', _response) + self.events.store("Protocol response", _response) headers = [("Cache-Control", "no-store"), ("x-ffo", "bar")] if handle: (key, timestamp) = handle if key.startswith(STR) and key.endswith(STR): - cookie = self.cookie_func(key, self.cookie_name, "pinfo", - self.sso_ttl) + cookie = self.cookie_func( + key, self.cookie_name, "pinfo", self.sso_ttl + ) headers.append(cookie) - resp = Response(_response.to_json(), content="application/json", - headers=headers) + resp = Response( + _response.to_json(), content="application/json", headers=headers + ) except Exception: message = traceback.format_exception(*sys.exc_info()) logger.error(message) - resp = error_response('service_error', message) + resp = error_response("service_error", message) return resp @@ -1820,8 +1973,9 @@ def discovery_endpoint(self, request, handle=None, **kwargs): _log_debug("@discovery_endpoint") - request = self.server.message_factory.get_request_type('discovery_endpoint')().deserialize(request, - "urlencoded") + request = self.server.message_factory.get_request_type( + "discovery_endpoint" + )().deserialize(request, "urlencoded") _log_debug("discovery_request:%s" % (sanitize(request.to_dict()),)) if request["service"] != SWD_ISSUER: @@ -1829,19 +1983,21 @@ def discovery_endpoint(self, request, handle=None, **kwargs): # verify that the principal is one of mine - _response = self.server.message_factory.get_response_type('discovery_endpoint')(locations=[self.baseurl]) + _response = self.server.message_factory.get_response_type("discovery_endpoint")( + locations=[self.baseurl] + ) _log_debug("discovery_response:%s" % (sanitize(_response.to_dict()),)) headers = [("Cache-Control", "no-store")] (key, timestamp) = handle if key.startswith(STR) and key.endswith(STR): - cookie = self.cookie_func(key, self.cookie_name, "disc", - self.sso_ttl) + cookie = self.cookie_func(key, self.cookie_name, "disc", self.sso_ttl) headers.append(cookie) - return Response(_response.to_json(), content="application/json", - headers=headers) + return Response( + _response.to_json(), content="application/json", headers=headers + ) def aresp_check(self, aresp, areq): # Use of the nonce is REQUIRED for all requests where an ID Token is @@ -1855,15 +2011,19 @@ def response_mode(self, areq, fragment_enc, **kwargs): if resp is None and areq["response_mode"] == "form_post": context = { - 'action': kwargs['redirect_uri'], - 'inputs': kwargs['aresp'].to_dict(), + "action": kwargs["redirect_uri"], + "inputs": kwargs["aresp"].to_dict(), } - return Response(self.template_renderer('form_post', context), headers=kwargs["headers"]) + return Response( + self.template_renderer("form_post", context), headers=kwargs["headers"] + ) return None def create_authn_response(self, areq, sid): # create the response - aresp = self.server.message_factory.get_response_type('authorization_endpoint')() + aresp = self.server.message_factory.get_response_type( + "authorization_endpoint" + )() try: aresp["state"] = areq["state"] except KeyError: @@ -1889,7 +2049,7 @@ def create_authn_response(self, areq, sid): _code = aresp["code"] = self.sdb[sid]["code"] rtype.remove("code") else: - self.sdb.update(sid, 'code', None) + self.sdb.update(sid, "code", None) _code = None if "token" in rtype: @@ -1921,21 +2081,23 @@ def create_authn_response(self, areq, sid): hargs = {} rt_set = set(areq["response_type"]) - if {'code', 'id_token', 'token'}.issubset(rt_set): + if {"code", "id_token", "token"}.issubset(rt_set): hargs = {"code": _code, "access_token": _access_token} - elif {'code', 'id_token'}.issubset(rt_set): + elif {"code", "id_token"}.issubset(rt_set): hargs = {"code": _code} - elif {'id_token', 'token'}.issubset(rt_set): + elif {"id_token", "token"}.issubset(rt_set): hargs = {"access_token": _access_token} # or 'code id_token' try: id_token = self.sign_encrypt_id_token( - _sinfo, client_info, areq, user_info=user_info, - **hargs) + _sinfo, client_info, areq, user_info=user_info, **hargs + ) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) - return error_response("invalid_request", descr="Could not sign/encrypt id_token") + return error_response( + "invalid_request", descr="Could not sign/encrypt id_token" + ) aresp["id_token"] = id_token _sinfo["id_token"] = id_token @@ -1956,8 +2118,15 @@ def key_setup(self, local_path, vault="keys", sig=None, enc=None): :param enc: Key for encryption :return: A URL the RP can use to download the key. """ - self.jwks_uri = key_export(self.baseurl, local_path, vault, self.keyjar, - fqdn=self.hostname, sig=sig, enc=enc) + self.jwks_uri = key_export( + self.baseurl, + local_path, + vault, + self.keyjar, + fqdn=self.hostname, + sig=sig, + enc=enc, + ) def endsession_endpoint(self, request="", **kwargs): """ diff --git a/src/oic/utils/__init__.py b/src/oic/utils/__init__.py index 97377e664..364fcb331 100644 --- a/src/oic/utils/__init__.py +++ b/src/oic/utils/__init__.py @@ -1,7 +1,7 @@ import sys import traceback -__author__ = 'rohe0002' +__author__ = "rohe0002" def tobytes(value): @@ -21,12 +21,11 @@ def exception_trace(tag, exc, log=None): log.error("[%s] ExcList: %s", tag, "".join(message)) log.error("[%s] Exception: %s", tag, exc) else: - print("[{0}] ExcList: {1}".format(tag, "".join(message)), - file=sys.stderr) + print("[{0}] ExcList: {1}".format(tag, "".join(message)), file=sys.stderr) print("[{0}] Exception: {1}".format(tag, exc), file=sys.stderr) -SORT_ORDER = {'RS': 0, 'ES': 1, 'HS': 2, 'PS': 3, 'no': 4} +SORT_ORDER = {"RS": 0, "ES": 1, "HS": 2, "PS": 3, "no": 4} def sort_sign_alg(alg1, alg2): diff --git a/src/oic/utils/aes.py b/src/oic/utils/aes.py index 93ce2a9e5..04f84f100 100644 --- a/src/oic/utils/aes.py +++ b/src/oic/utils/aes.py @@ -8,13 +8,9 @@ from oic.utils import tobytes -__author__ = 'rolandh' +__author__ = "rolandh" -POSTFIX_MODE = { - "cbc": AES.MODE_CBC, - "cfb": AES.MODE_CFB, - "ecb": AES.MODE_CFB, -} +POSTFIX_MODE = {"cbc": AES.MODE_CBC, "cfb": AES.MODE_CFB, "ecb": AES.MODE_CFB} BLOCK_SIZE = 16 @@ -50,8 +46,15 @@ def build_cipher(key, iv, alg="aes_128_cbc"): raise AESError("Unsupported chaining mode") -def encrypt(key, msg, iv=None, alg="aes_128_cbc", padding="PKCS#7", - b64enc=True, block_size=BLOCK_SIZE): +def encrypt( + key, + msg, + iv=None, + alg="aes_128_cbc", + padding="PKCS#7", + b64enc=True, + block_size=BLOCK_SIZE, +): """ Encrypt message. @@ -73,7 +76,7 @@ def encrypt(key, msg, iv=None, alg="aes_128_cbc", padding="PKCS#7", if _block_size: plen = _block_size - (len(msg) % _block_size) c = chr(plen) - msg += (c * plen) + msg += c * plen cipher, iv = build_cipher(tobytes(key), iv, alg) cmsg = iv + cipher.encrypt(tobytes(msg)) @@ -97,13 +100,13 @@ def decrypt(key, msg, iv=None, padding="PKCS#7", b64dec=True): else: data = msg - _iv = data[:AES.block_size] + _iv = data[: AES.block_size] if iv: assert iv == _iv cipher, iv = build_cipher(key, iv) - res = cipher.decrypt(data)[AES.block_size:] + res = cipher.decrypt(data)[AES.block_size :] if padding in ["PKCS#5", "PKCS#7"]: - res = res[:-res[-1]] + res = res[: -res[-1]] return res.decode("utf-8") @@ -150,7 +153,7 @@ def add_associated_data(self, data): :type data: bytes """ if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") self.kernel.update(data) def encrypt_and_tag(self, cleardata): diff --git a/src/oic/utils/authn/__init__.py b/src/oic/utils/authn/__init__.py index 3b031d2bf..169408711 100644 --- a/src/oic/utils/authn/__init__.py +++ b/src/oic/utils/authn/__init__.py @@ -1 +1 @@ -__author__ = 'rolandh' +__author__ = "rolandh" diff --git a/src/oic/utils/authn/authn_context.py b/src/oic/utils/authn/authn_context.py index a23936ef5..da5ed3a31 100644 --- a/src/oic/utils/authn/authn_context.py +++ b/src/oic/utils/authn/authn_context.py @@ -2,20 +2,23 @@ from oic.utils.http_util import extract_from_request -__author__ = 'rolandh' +__author__ = "rolandh" UNSPECIFIED = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified" -INTERNETPROTOCOLPASSWORD = \ - 'urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword' -MOBILETWOFACTORCONTRACT = \ - 'urn:oasis:names:tc:SAML:2.0:ac:classes:MobileTwoFactorContract' -PASSWORDPROTECTEDTRANSPORT = \ - 'urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport' -PASSWORD = 'urn:oasis:names:tc:SAML:2.0:ac:classes:Password' -TLSCLIENT = 'urn:oasis:names:tc:SAML:2.0:ac:classes:TLSClient' +INTERNETPROTOCOLPASSWORD = ( + "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword" +) +MOBILETWOFACTORCONTRACT = ( + "urn:oasis:names:tc:SAML:2.0:ac:classes:MobileTwoFactorContract" +) +PASSWORDPROTECTEDTRANSPORT = ( + "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport" +) +PASSWORD = "urn:oasis:names:tc:SAML:2.0:ac:classes:Password" +TLSCLIENT = "urn:oasis:names:tc:SAML:2.0:ac:classes:TLSClient" TIMESYNCTOKEN = "urn:oasis:names:tc:SAML:2.0:ac:classes:TimeSyncToken" -CMP_TYPE = ['exact', 'minimum', 'maximum', 'better'] +CMP_TYPE = ["exact", "minimum", "maximum", "better"] class AuthnBroker(object): @@ -54,7 +57,7 @@ def add(self, acr, method, level=0, authn_authority=""): "ref": acr, "method": method, "level": level, - "authn_auth": authn_authority + "authn_auth": authn_authority, } self.next += 1 @@ -203,6 +206,7 @@ def make_auth_verify(callback, next_module_instance=None): setup_multi_auth (in multi_auth.py) :return: function encapsulating the specified callback which properly handles a multi auth chain. """ + # This has to be here ... def auth_verify(environ, start_response, logger=None): kwargs = extract_from_request(environ) diff --git a/src/oic/utils/authn/client.py b/src/oic/utils/authn/client.py index 12fe06d42..4ca65fae2 100644 --- a/src/oic/utils/authn/client.py +++ b/src/oic/utils/authn/client.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -__author__ = 'rolandh' +__author__ = "rolandh" class AuthnFailure(Exception): @@ -43,10 +43,15 @@ class UnknownAuthnMethod(Exception): def assertion_jwt(cli, keys, audience, algorithm, lifetime=600): _now = utc_time_sans_frac() - at = AuthnToken(iss=cli.client_id, sub=cli.client_id, - aud=audience, jti=rndstr(32), - exp=_now + lifetime, iat=_now) - logger.debug('AuthnToken: {}'.format(at.to_dict())) + at = AuthnToken( + iss=cli.client_id, + sub=cli.client_id, + aud=audience, + jti=rndstr(32), + exp=_now + lifetime, + iat=_now, + ) + logger.debug("AuthnToken: {}".format(at.to_dict())) return at.to_jwt(key=keys, algorithm=algorithm) @@ -126,8 +131,10 @@ def construct(self, cis, request_args=None, http_args=None, **kwargs): except KeyError: pass - if (("client_id" not in cis.c_param.keys()) or - cis.c_param["client_id"][VREQUIRED]) is False: + if ( + ("client_id" not in cis.c_param.keys()) + or cis.c_param["client_id"][VREQUIRED] + ) is False: try: del cis["client_id"] except KeyError: @@ -168,8 +175,7 @@ def construct(self, cis, request_args=None, http_args=None, **kwargs): class BearerHeader(ClientAuthnMethod): - def construct(self, cis=None, request_args=None, http_args=None, - **kwargs): + def construct(self, cis=None, request_args=None, http_args=None, **kwargs): """ More complicated logic then I would have liked it to be. @@ -277,8 +283,7 @@ def choose_algorithm(self, entity, **kwargs): return algorithm def get_signing_key(self, algorithm): - return self.cli.keyjar.get_signing_key(alg2keytype(algorithm), - alg=algorithm) + return self.cli.keyjar.get_signing_key(alg2keytype(algorithm), alg=algorithm) def get_key_by_kid(self, kid, algorithm): _key = self.cli.keyjar.get_key_by_kid(kid) @@ -305,27 +310,29 @@ def construct(self, cis, request_args=None, http_args=None, **kwargs): # audience is the OP endpoint # OR OP identifier algorithm = None - if kwargs['authn_endpoint'] in ['token', 'refresh']: + if kwargs["authn_endpoint"] in ["token", "refresh"]: try: algorithm = self.cli.registration_info[ - 'token_endpoint_auth_signing_alg'] + "token_endpoint_auth_signing_alg" + ] except (KeyError, AttributeError): pass - audience = self.cli.provider_info['token_endpoint'] + audience = self.cli.provider_info["token_endpoint"] else: - audience = self.cli.provider_info['issuer'] + audience = self.cli.provider_info["issuer"] if not algorithm: algorithm = self.choose_algorithm(**kwargs) ktype = alg2keytype(algorithm) try: - if 'kid' in kwargs: + if "kid" in kwargs: signing_key = [self.get_key_by_kid(kwargs["kid"], algorithm)] elif ktype in self.cli.kid["sig"]: try: - signing_key = [self.get_key_by_kid( - self.cli.kid["sig"][ktype], algorithm)] + signing_key = [ + self.get_key_by_kid(self.cli.kid["sig"][ktype], algorithm) + ] except KeyError: signing_key = self.get_signing_key(algorithm) else: @@ -334,24 +341,24 @@ def construct(self, cis, request_args=None, http_args=None, **kwargs): logger.error("%s" % sanitize(err)) raise - if 'client_assertion' in kwargs: - cis["client_assertion"] = kwargs['client_assertion'] - if 'client_assertion_type' in kwargs: - cis['client_assertion_type'] = kwargs['client_assertion_type'] + if "client_assertion" in kwargs: + cis["client_assertion"] = kwargs["client_assertion"] + if "client_assertion_type" in kwargs: + cis["client_assertion_type"] = kwargs["client_assertion_type"] else: cis["client_assertion_type"] = JWT_BEARER - elif 'client_assertion' in cis: - if 'client_assertion_type' not in cis: + elif "client_assertion" in cis: + if "client_assertion_type" not in cis: cis["client_assertion_type"] = JWT_BEARER else: try: - _args = {'lifetime': kwargs['lifetime']} + _args = {"lifetime": kwargs["lifetime"]} except KeyError: _args = {} - cis["client_assertion"] = assertion_jwt(self.cli, signing_key, - audience, algorithm, - **_args) + cis["client_assertion"] = assertion_jwt( + self.cli, signing_key, audience, algorithm, **_args + ) cis["client_assertion_type"] = JWT_BEARER @@ -371,18 +378,18 @@ def construct(self, cis, request_args=None, http_args=None, **kwargs): def verify(self, areq, **kwargs): try: try: - argv = {'sender': areq['client_id']} + argv = {"sender": areq["client_id"]} except KeyError: argv = {} - bjwt = AuthnToken().from_jwt(areq["client_assertion"], - keyjar=self.cli.keyjar, - **argv) + bjwt = AuthnToken().from_jwt( + areq["client_assertion"], keyjar=self.cli.keyjar, **argv + ) except (Invalid, MissingKey) as err: logger.info("%s" % sanitize(err)) raise AuthnFailure("Could not verify client_assertion.") logger.debug("authntoken: %s" % sanitize(bjwt.to_dict())) - areq['parsed_client_assertion'] = bjwt + areq["parsed_client_assertion"] = bjwt try: cid = kwargs["client_id"] @@ -400,10 +407,10 @@ def verify(self, areq, **kwargs): logger.debug("audience: %s, baseurl: %s" % (_aud, self.cli.baseurl)) # figure out authn method - if alg2keytype(bjwt.jws_header['alg']) == 'oct': # Symmetric key - authn_method = 'client_secret_jwt' + if alg2keytype(bjwt.jws_header["alg"]) == "oct": # Symmetric key + authn_method = "client_secret_jwt" else: - authn_method = 'private_key_jwt' + authn_method = "private_key_jwt" if isinstance(_aud, str): if not str(_aud).startswith(self.cli.baseurl): @@ -431,8 +438,7 @@ def choose_algorithm(self, entity="client_secret_jwt", **kwargs): return JWSAuthnMethod.choose_algorithm(self, entity, **kwargs) def get_signing_key(self, algorithm): - return self.cli.keyjar.get_signing_key(alg2keytype(algorithm), - alg=algorithm) + return self.cli.keyjar.get_signing_key(alg2keytype(algorithm), alg=algorithm) class PrivateKeyJWT(JWSAuthnMethod): @@ -442,8 +448,9 @@ def choose_algorithm(self, entity="private_key_jwt", **kwargs): return JWSAuthnMethod.choose_algorithm(self, entity, **kwargs) def get_signing_key(self, algorithm): - return self.cli.keyjar.get_signing_key(alg2keytype(algorithm), "", - alg=algorithm) + return self.cli.keyjar.get_signing_key( + alg2keytype(algorithm), "", alg=algorithm + ) CLIENT_AUTHN_METHOD = { @@ -459,7 +466,7 @@ def get_signing_key(self, algorithm): def valid_client_info(cinfo): - eta = cinfo.get('client_secret_expires_at', 0) + eta = cinfo.get("client_secret_expires_at", 0) if eta != 0 and eta < utc_time_sans_frac(): return False return True @@ -482,7 +489,9 @@ def get_client_id(cdb, req, authn): raise FailedAuthentication("Missing client_id") elif authn.startswith("Basic "): logger.debug("Basic auth") - (_id, _secret) = base64.b64decode(authn[6:].encode("utf-8")).decode("utf-8").split(":") + (_id, _secret) = ( + base64.b64decode(authn[6:].encode("utf-8")).decode("utf-8").split(":") + ) # Either as string or encoded if _id not in cdb: _bid = as_bytes(_id) @@ -523,22 +532,23 @@ def verify_client(inst, areq, authn, type_method=TYPE_METHOD): """ if authn: # HTTP Basic auth (client_secret_basic) cid = get_client_id(inst.cdb, areq, authn) - auth_method = 'client_secret_basic' + auth_method = "client_secret_basic" elif "client_secret" in areq: # client_secret_post client_id = get_client_id(inst.cdb, areq, authn) logger.debug("Verified Client ID: %s" % client_id) cid = ClientSecretBasic(inst).verify(areq, client_id) - auth_method = 'client_secret_post' + auth_method = "client_secret_post" elif "client_assertion" in areq: # client_secret_jwt or private_key_jwt - check_key_availability(inst, areq['client_assertion']) + check_key_availability(inst, areq["client_assertion"]) for typ, method in type_method: if areq["client_assertion_type"] == typ: cid, auth_method = method(inst).verify(areq) break else: - logger.error('UnknownAssertionType: {}'.format( - areq["client_assertion_type"])) + logger.error( + "UnknownAssertionType: {}".format(areq["client_assertion_type"]) + ) raise UnknownAssertionType(areq["client_assertion_type"], areq) else: logger.error("Missing client authentication.") @@ -546,22 +556,24 @@ def verify_client(inst, areq, authn, type_method=TYPE_METHOD): if isinstance(areq, AccessTokenRequest): try: - _method = inst.cdb[cid]['token_endpoint_auth_method'] + _method = inst.cdb[cid]["token_endpoint_auth_method"] except KeyError: - _method = 'client_secret_basic' + _method = "client_secret_basic" if _method != auth_method: - logger.error("Wrong authentication method used: {} != {}".format( - auth_method, _method)) + logger.error( + "Wrong authentication method used: {} != {}".format( + auth_method, _method + ) + ) raise FailedAuthentication("Wrong authentication method used") # store which authn method was used where try: - inst.cdb[cid]['auth_method'][areq.__class__.__name__] = auth_method + inst.cdb[cid]["auth_method"][areq.__class__.__name__] = auth_method except KeyError: try: - inst.cdb[cid]['auth_method'] = { - areq.__class__.__name__: auth_method} + inst.cdb[cid]["auth_method"] = {areq.__class__.__name__: auth_method} except KeyError: pass diff --git a/src/oic/utils/authn/client_saml.py b/src/oic/utils/authn/client_saml.py index 6a82704a4..5776ff62f 100644 --- a/src/oic/utils/authn/client_saml.py +++ b/src/oic/utils/authn/client_saml.py @@ -3,16 +3,16 @@ from oic.utils.authn.client import CLIENT_AUTHN_METHOD from oic.utils.authn.client import ClientAuthnMethod -__author__ = 'rolandh' +__author__ = "rolandh" -SAML2_BEARER_ASSERTION_TYPE = \ - "urn:ietf:params:oauth:client-assertion-type:saml2-bearer" +SAML2_BEARER_ASSERTION_TYPE = "urn:ietf:params:oauth:client-assertion-type:saml2-bearer" try: from saml2.saml import assertion_from_string except ImportError: pass else: + class SAML2AuthnMethod(ClientAuthnMethod): """Authenticating clients using the SAML2 assertion profile.""" diff --git a/src/oic/utils/authn/javascript_login.py b/src/oic/utils/authn/javascript_login.py index 3b2407d19..18c5755db 100644 --- a/src/oic/utils/authn/javascript_login.py +++ b/src/oic/utils/authn/javascript_login.py @@ -5,7 +5,7 @@ from oic.utils.http_util import SeeOther from oic.utils.http_util import Unauthorized -__author__ = 'danielevertsson' +__author__ = "danielevertsson" class JavascriptFormMako(UsernamePasswordMako): @@ -35,7 +35,7 @@ def verify(self, request, **kwargs): logger.debug("passwd: %s" % self.passwd) # verify username and password try: - assert _dict['login_parameter'][0] == 'logged_in' + assert _dict["login_parameter"][0] == "logged_in" except (AssertionError, KeyError): resp = Unauthorized("You are not authorized. Javascript not executed") return resp, False @@ -44,7 +44,7 @@ def verify(self, request, **kwargs): try: _qp = _dict["query"][0] except KeyError: - _qp = self.get_multi_auth_cookie(kwargs['cookie']) + _qp = self.get_multi_auth_cookie(kwargs["cookie"]) try: return_to = self.generate_return_url(kwargs["return_to"], _qp) except KeyError: diff --git a/src/oic/utils/authn/ldap_member.py b/src/oic/utils/authn/ldap_member.py index 72bc1a802..a9dce3105 100644 --- a/src/oic/utils/authn/ldap_member.py +++ b/src/oic/utils/authn/ldap_member.py @@ -2,7 +2,7 @@ from oic.utils.userinfo.ldap_info import UserInfoLDAP -__author__ = 'haho0032' +__author__ = "haho0032" logger = logging.getLogger(__name__) @@ -19,6 +19,5 @@ def __call__(self, userid, **kwargs): for field in result[self.verify_attr]: if field in self.verify_attr_valid: return True - logger.warning(userid + "tries to use the service with the values " + - result) + logger.warning(userid + "tries to use the service with the values " + result) return False diff --git a/src/oic/utils/authn/ldapc.py b/src/oic/utils/authn/ldapc.py index a6f718298..54a34fbda 100644 --- a/src/oic/utils/authn/ldapc.py +++ b/src/oic/utils/authn/ldapc.py @@ -1,7 +1,7 @@ try: import ldap except ImportError: - raise ImportError('This module can be used only with pyldap installed.') + raise ImportError("This module can be used only with pyldap installed.") from oic.exception import PyoidcError from oic.utils.authn.user import UsernamePasswordMako @@ -9,7 +9,7 @@ SCOPE_MAP = { "base": ldap.SCOPE_BASE, "onelevel": ldap.SCOPE_ONELEVEL, - "subtree": ldap.SCOPE_SUBTREE + "subtree": ldap.SCOPE_SUBTREE, } @@ -18,9 +18,18 @@ class LDAPCError(PyoidcError): class LDAPAuthn(UsernamePasswordMako): - def __init__(self, srv, ldapsrv, return_to, pattern, mako_template, - template_lookup, ldap_user="", ldap_pwd="", - verification_endpoints=["verify"]): + def __init__( + self, + srv, + ldapsrv, + return_to, + pattern, + mako_template, + template_lookup, + ldap_user="", + ldap_pwd="", + verification_endpoints=["verify"], + ): """ Authenticate user against LDAP. @@ -38,8 +47,14 @@ def __init__(self, srv, ldapsrv, return_to, pattern, mako_template, :param ldap_pwd: The password for the ldap_user """ UsernamePasswordMako.__init__( - self, srv, mako_template, template_lookup, None, return_to, - verification_endpoints=verification_endpoints) + self, + srv, + mako_template, + template_lookup, + None, + return_to, + verification_endpoints=verification_endpoints, + ) self.ldap = ldap.initialize(ldapsrv) self.ldap.protocol_version = 3 @@ -64,7 +79,8 @@ def _verify(self, pwd, user): else: args = { "filterstr": self.pattern["filterstr"] % user, - "base": self.pattern["base"]} + "base": self.pattern["base"], + } if "scope" not in args: args["scope"] = ldap.SCOPE_SUBTREE else: diff --git a/src/oic/utils/authn/multi_auth.py b/src/oic/utils/authn/multi_auth.py index 8b92c2afd..8f60f4a6f 100644 --- a/src/oic/utils/authn/multi_auth.py +++ b/src/oic/utils/authn/multi_auth.py @@ -1,7 +1,7 @@ from oic.utils.authn.authn_context import make_auth_verify from oic.utils.authn.user import UserAuthnMethod -__author__ = 'danielevertsson' +__author__ = "danielevertsson" class MultiAuthnMethod(UserAuthnMethod): @@ -16,8 +16,9 @@ def __init__(self, auth_module): self.auth_module = auth_module def __call__(self, **kwargs): - cookie = self.create_cookie(kwargs['query'], "query", - UserAuthnMethod.MULTI_AUTH_COOKIE) + cookie = self.create_cookie( + kwargs["query"], "query", UserAuthnMethod.MULTI_AUTH_COOKIE + ) resp = self.auth_module(**kwargs) resp.headers.append(cookie) return resp @@ -43,8 +44,12 @@ def setup_multi_auth(auth_broker, urls, auth_modules): if i < len(auth_modules) - 1: next_module_instance = auth_modules[i + 1][0] - urls.append((callback_regexp, make_auth_verify(module_instance.verify, - next_module_instance))) + urls.append( + ( + callback_regexp, + make_auth_verify(module_instance.verify, next_module_instance), + ) + ) return multi_auth @@ -63,12 +68,12 @@ def __init__(self, authn_instance, end_point_index): self.end_point_index = end_point_index def __call__(self, **kwargs): - return self.authn_instance(end_point_index=self.end_point_index, - **kwargs) + return self.authn_instance(end_point_index=self.end_point_index, **kwargs) def verify(self, **kwargs): - return self.authn_instance.verify(end_point_index=self.end_point_index, - **kwargs) + return self.authn_instance.verify( + end_point_index=self.end_point_index, **kwargs + ) @property def srv(self): diff --git a/src/oic/utils/authn/saml.py b/src/oic/utils/authn/saml.py index 10f43ffb1..b880d8668 100644 --- a/src/oic/utils/authn/saml.py +++ b/src/oic/utils/authn/saml.py @@ -1,7 +1,7 @@ try: import saml2 except ImportError: - raise ImportError('This module can be used only with saml2 installed.') + raise ImportError("This module can be used only with saml2 installed.") import base64 @@ -41,9 +41,19 @@ class SAMLAuthnMethod(UserAuthnMethod): CONST_SAML_COOKIE = "samlauthc" CONST_HASIDP = "hasidp" - def __init__(self, srv, lookup, userdb, spconf, url, return_to, - cache=None, - bindings=None, userinfo=None, samlcache=None): + def __init__( + self, + srv, + lookup, + userdb, + spconf, + url, + return_to, + cache=None, + bindings=None, + userinfo=None, + samlcache=None, + ): """ Construct the class. @@ -64,8 +74,11 @@ def __init__(self, srv, lookup, userdb, spconf, url, return_to, if bindings: self.bindings = bindings else: - self.bindings = [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST, - BINDING_HTTP_ARTIFACT] + self.bindings = [ + BINDING_HTTP_REDIRECT, + BINDING_HTTP_POST, + BINDING_HTTP_ARTIFACT, + ] # TODO Why does this exist? self.verification_endpoint = "" # Configurations for the SP handler. @@ -73,9 +86,7 @@ def __init__(self, srv, lookup, userdb, spconf, url, return_to, config = SPConfig().load(self.sp_conf.CONFIG) self.sp = Saml2Client(config=config) mte = lookup.get_template("unauthorized.mako") - argv = { - "message": "You are not authorized!", - } + argv = {"message": "You are not authorized!"} self.not_authorized = mte.render(**argv) self.samlcache = self.sp_conf.SAML_CACHE @@ -117,8 +128,7 @@ def verify(self, request, cookie, path, requrl, end_point_index=None, **kwargs): binding = endp[1] break - saml_cookie, _ts, _typ = self.getCookieValue(cookie, - self.CONST_SAML_COOKIE) + saml_cookie, _ts, _typ = self.getCookieValue(cookie, self.CONST_SAML_COOKIE) data = json.loads(saml_cookie) rp_query_cookie = self.get_multi_auth_cookie(cookie) @@ -128,12 +138,14 @@ def verify(self, request, cookie, path, requrl, end_point_index=None, **kwargs): if not query: query = base64.b64decode(data[self.CONST_QUERY]).decode("ascii") - if data[self.CONST_HASIDP] == 'False': + if data[self.CONST_HASIDP] == "False": (done, response) = self._pick_idp(request, end_point_index) if done == 0: entity_id = response # Do the AuthnRequest - resp = self._redirect_to_auth(self.sp, entity_id, query, end_point_index) + resp = self._redirect_to_auth( + self.sp, entity_id, query, end_point_index + ) return resp, False return response, False @@ -143,8 +155,8 @@ def verify(self, request, cookie, path, requrl, end_point_index=None, **kwargs): try: response = self.sp.parse_authn_request_response( - request["SAMLResponse"][0], binding, - self.cache_outstanding_queries) + request["SAMLResponse"][0], binding, self.cache_outstanding_queries + ) except UnknownPrincipal as excp: logger.error("UnknownPrincipal: %s" % (excp,)) return Unauthorized(self.not_authorized), False @@ -183,9 +195,8 @@ def verify(self, request, cookie, path, requrl, end_point_index=None, **kwargs): self.samlcache["AA_ENTITYID"] = response.entity_id self.setup_userdb(uid, response.ava) - return_to = create_return_url(self.return_to, uid, - **{self.query_param: "true"}) - if '?' in return_to: + return_to = create_return_url(self.return_to, uid, **{self.query_param: "true"}) + if "?" in return_to: return_to += "&" else: return_to += "?" @@ -256,9 +267,16 @@ def _pick_idp(self, query, end_point_index): if not idp_entity_id: cookie = self.create_cookie( - '{"' + self.CONST_QUERY + '": "' + base64.b64encode(query) + - '" , "' + self.CONST_HASIDP + '": "False" }', - self.CONST_SAML_COOKIE, self.CONST_SAML_COOKIE) + '{"' + + self.CONST_QUERY + + '": "' + + base64.b64encode(query) + + '" , "' + + self.CONST_HASIDP + + '": "False" }', + self.CONST_SAML_COOKIE, + self.CONST_SAML_COOKIE, + ) if self.sp_conf.WAYF: if query: try: @@ -277,15 +295,18 @@ def _pick_idp(self, query, end_point_index): eid = _cli.config.entityid disco_end_point_index = end_point_index["disco_end_point_index"] - ret = _cli.config.getattr("endpoints", "sp")[ - "discovery_response"][disco_end_point_index][0] + ret = _cli.config.getattr("endpoints", "sp")["discovery_response"][ + disco_end_point_index + ][0] ret += "?sid=%s" % sid_ loc = _cli.create_discovery_service_request( - self.sp_conf.DISCOSRV, eid, **{"return": ret}) + self.sp_conf.DISCOSRV, eid, **{"return": ret} + ) return -1, SeeOther(loc, headers=[cookie]) elif not len(idps): raise ServiceErrorException( - 'Misconfiguration for the SAML Service Provider!') + "Misconfiguration for the SAML Service Provider!" + ) else: return -1, NotImplemented("No WAYF or DS present!") return 0, idp_entity_id @@ -293,46 +314,57 @@ def _pick_idp(self, query, end_point_index): def _wayf_redirect(self, cookie): sid_ = sid() self.cache_outstanding_queries[sid_] = self.verification_endpoint - return -1, SeeOther(headers=[ - ('Location', "%s?%s" % (self.sp_conf.WAYF, sid_)), cookie]) + return ( + -1, + SeeOther( + headers=[("Location", "%s?%s" % (self.sp_conf.WAYF, sid_)), cookie] + ), + ) def _redirect_to_auth(self, _cli, entity_id, query, end_point_index, vorg_name=""): try: binding, destination = _cli.pick_binding( - "single_sign_on_service", self.bindings, "idpsso", - entity_id=entity_id) - logger.debug("binding: %s, destination: %s" % (binding, - destination)) + "single_sign_on_service", self.bindings, "idpsso", entity_id=entity_id + ) + logger.debug("binding: %s, destination: %s" % (binding, destination)) extensions = None kwargs = {} if end_point_index: - kwargs["assertion_consumer_service_index"] = str(end_point_index[binding]) + kwargs["assertion_consumer_service_index"] = str( + end_point_index[binding] + ) if _cli.authn_requests_signed: _sid = saml2.s_utils.sid(_cli.seed) req_id, msg_str = _cli.create_authn_request( - destination, vorg=vorg_name, - sign=_cli.authn_requests_signed, message_id=_sid, - extensions=extensions, **kwargs) + destination, + vorg=vorg_name, + sign=_cli.authn_requests_signed, + message_id=_sid, + extensions=extensions, + **kwargs + ) _sid = req_id else: - req_id, req = _cli.create_authn_request(destination, - vorg=vorg_name, - sign=False, **kwargs) + req_id, req = _cli.create_authn_request( + destination, vorg=vorg_name, sign=False, **kwargs + ) msg_str = "%s" % req _sid = req_id _rstate = rndstr() - ht_args = _cli.apply_binding(binding, msg_str, destination, - relay_state=_rstate) + ht_args = _cli.apply_binding( + binding, msg_str, destination, relay_state=_rstate + ) logger.debug("ht_args: %s" % ht_args) except Exception as exc: logger.exception(exc) raise ServiceErrorException( - "Failed to construct the AuthnRequest: %s" % exc) + "Failed to construct the AuthnRequest: %s" % exc + ) # remember the request self.cache_outstanding_queries[_sid] = self.return_to @@ -340,9 +372,16 @@ def _redirect_to_auth(self, _cli, entity_id, query, end_point_index, vorg_name=" def response(self, binding, http_args, query): cookie = self.create_cookie( - '{"' + self.CONST_QUERY + '": "' + base64.b64encode(query.encode("ascii")).decode("ascii") + - '" , "' + self.CONST_HASIDP + '": "True" }', - self.CONST_SAML_COOKIE, self.CONST_SAML_COOKIE) + '{"' + + self.CONST_QUERY + + '": "' + + base64.b64encode(query.encode("ascii")).decode("ascii") + + '" , "' + + self.CONST_HASIDP + + '": "True" }', + self.CONST_SAML_COOKIE, + self.CONST_SAML_COOKIE, + ) if binding == BINDING_HTTP_ARTIFACT: resp = SeeOther() elif binding == BINDING_HTTP_REDIRECT: @@ -354,7 +393,6 @@ def response(self, binding, http_args, query): raise ServiceErrorException("Parameter error") else: http_args["headers"].append(cookie) - resp = Response(http_args["data"], - headers=http_args["headers"]) + resp = Response(http_args["data"], headers=http_args["headers"]) return resp diff --git a/src/oic/utils/authn/user.py b/src/oic/utils/authn/user.py index b31d87815..ca2d67815 100644 --- a/src/oic/utils/authn/user.py +++ b/src/oic/utils/authn/user.py @@ -21,7 +21,7 @@ from oic.utils.http_util import Unauthorized from oic.utils.sanitize import sanitize -__author__ = 'rolandh' +__author__ = "rolandh" logger = logging.getLogger(__name__) @@ -32,14 +32,15 @@ "login_title": "Username", "passwd_title": "Password", "submit_text": "Submit", - "client_policy_title": "Client Policy"}, + "client_policy_title": "Client Policy", + }, "se": { "title": "Logga in", "login_title": u"Användarnamn", "passwd_title": u"Lösenord", "submit_text": u"Sänd", - "client_policy_title": "Klientens sekretesspolicy" - } + "client_policy_title": "Klientens sekretesspolicy", + }, } @@ -93,15 +94,17 @@ def authenticated_as(self, cookie=None, **kwargs): _now = int(time.time()) if _now > (int(_ts) + int(self.cookie_ttl * 60)): logger.debug("Authentication timed out") - raise ToOld("%d > (%d + %d)" % (_now, int(_ts), - int(self.cookie_ttl * 60))) + raise ToOld( + "%d > (%d + %d)" % (_now, int(_ts), int(self.cookie_ttl * 60)) + ) else: if "max_age" in kwargs and kwargs["max_age"]: _now = int(time.time()) if _now > (int(_ts) + int(kwargs["max_age"])): logger.debug("Authentication too old") - raise ToOld("%d > (%d + %d)" % ( - _now, int(_ts), int(kwargs["max_age"]))) + raise ToOld( + "%d > (%d + %d)" % (_now, int(_ts), int(kwargs["max_age"])) + ) return {"uid": uid}, _ts @@ -135,8 +138,7 @@ def verify(self, **kwargs): raise NotImplementedError def get_multi_auth_cookie(self, cookie): - rp_query_cookie = self.getCookieValue(cookie, - UserAuthnMethod.MULTI_AUTH_COOKIE) + rp_query_cookie = self.getCookieValue(cookie, UserAuthnMethod.MULTI_AUTH_COOKIE) if rp_query_cookie: return rp_query_cookie[0] @@ -203,12 +205,25 @@ class UsernamePasswordMako(UserAuthnMethod): Works in a WSGI environment using Mako as template system. """ - param_map = {"as_user": "login", "acr_values": "acr", - "policy_uri": "policy_uri", "logo_uri": "logo_uri", - "tos_uri": "tos_uri", "query": "query"} + param_map = { + "as_user": "login", + "acr_values": "acr", + "policy_uri": "policy_uri", + "logo_uri": "logo_uri", + "tos_uri": "tos_uri", + "query": "query", + } - def __init__(self, srv, mako_template, template_lookup, pwd, return_to="", - templ_arg_func=None, verification_endpoints=None): + def __init__( + self, + srv, + mako_template, + template_lookup, + pwd, + return_to="", + templ_arg_func=None, + verification_endpoints=None, + ): """ Initialize the class. @@ -318,7 +333,7 @@ def verify(self, request, **kwargs): try: _qp = _dict["query"] except KeyError: - _qp = self.get_multi_auth_cookie(kwargs['cookie']) + _qp = self.get_multi_auth_cookie(kwargs["cookie"]) except (AssertionError, KeyError) as err: logger.debug("Password verification failed: {}".format(err)) resp = Unauthorized("Unknown user or wrong password") @@ -327,7 +342,7 @@ def verify(self, request, **kwargs): try: _qp = _dict["query"] except KeyError: - _qp = self.get_multi_auth_cookie(kwargs['cookie']) + _qp = self.get_multi_auth_cookie(kwargs["cookie"]) logger.debug("Password verification succeeded.") headers = [self.create_cookie(_dict["login"], "upm")] @@ -335,8 +350,9 @@ def verify(self, request, **kwargs): return_to = self.generate_return_url(kwargs["return_to"], _qp) except KeyError: try: - return_to = self.generate_return_url(self.return_to, _qp, - kwargs["path"]) + return_to = self.generate_return_url( + self.return_to, _qp, kwargs["path"] + ) except KeyError: return_to = self.generate_return_url(self.return_to, _qp) @@ -351,7 +367,6 @@ def done(self, areq): class BasicAuthn(UserAuthnMethod): - def __init__(self, srv, pwd, ttl=5): UserAuthnMethod.__init__(self, srv, ttl) self.passwd = pwd @@ -360,9 +375,9 @@ def verify_password(self, user, password): if user in self.passwd: _pwd = self.passwd[user] if not hmac.compare_digest(_pwd.encode(), password.encode()): - raise FailedAuthentication('Wrong user/password combination') + raise FailedAuthentication("Wrong user/password combination") else: - raise FailedAuthentication('Wrong user/password combination') + raise FailedAuthentication("Wrong user/password combination") def authenticated_as(self, cookie=None, authorization="", **kwargs): """ @@ -386,7 +401,6 @@ def authenticated_as(self, cookie=None, authorization="", **kwargs): class SymKeyAuthn(UserAuthnMethod): - def __init__(self, srv, ttl, symkey): UserAuthnMethod.__init__(self, srv, ttl) diff --git a/src/oic/utils/authn/user_cas.py b/src/oic/utils/authn/user_cas.py index 0dc787e2e..9fcf098cf 100644 --- a/src/oic/utils/authn/user_cas.py +++ b/src/oic/utils/authn/user_cas.py @@ -40,8 +40,7 @@ class CasAuthnMethod(UserAuthnMethod): # The name for the CAS cookie, containing query parameters and nonce. CONST_CAS_COOKIE = "cascookie" - def __init__(self, srv, cas_server, service_url, return_to, - extra_validation=None): + def __init__(self, srv, cas_server, service_url, return_to, extra_validation=None): """ Construct the class. @@ -69,20 +68,26 @@ def create_redirect(self, query): """ try: req = parse_qs(query) - acr = req['acr_values'][0] + acr = req["acr_values"][0] except KeyError: acr = None nonce = uuid.uuid4().get_urn() - service_url = urlencode( - {self.CONST_SERVICE: self.get_service_url(nonce, acr)}) + service_url = urlencode({self.CONST_SERVICE: self.get_service_url(nonce, acr)}) cas_url = self.cas_server + self.CONST_CASLOGIN + service_url cookie = self.create_cookie( - '{"' + self.CONST_NONCE + '": "' + base64.b64encode( - nonce) + '", "' + - self.CONST_QUERY + '": "' + base64.b64encode(query) + '"}', + '{"' + + self.CONST_NONCE + + '": "' + + base64.b64encode(nonce) + + '", "' + + self.CONST_QUERY + + '": "' + + base64.b64encode(query) + + '"}', self.CONST_CAS_COOKIE, - self.CONST_CAS_COOKIE) + self.CONST_CAS_COOKIE, + ) return SeeOther(cas_url, headers=[cookie]) def handle_callback(self, ticket, service_url): @@ -95,8 +100,7 @@ def handle_callback(self, ticket, service_url): :return: Uid if the login was successful otherwise None. """ data = {self.CONST_TICKET: ticket, self.CONST_SERVICE: service_url} - resp = requests.get(self.cas_server + self.CONST_CAS_VERIFY_TICKET, - params=data) + resp = requests.get(self.cas_server + self.CONST_CAS_VERIFY_TICKET, params=data) root = ET.fromstring(resp.content) for l1 in root: if self.CONST_AUTHSUCCESS in l1.tag: @@ -123,7 +127,15 @@ def get_service_url(self, nonce, acr): """ if acr is None: acr = "" - return self.service_url + "?" + self.CONST_NONCE + "=" + nonce + "&acr_values=" + acr + return ( + self.service_url + + "?" + + self.CONST_NONCE + + "=" + + nonce + + "&acr_values=" + + acr + ) def verify(self, request, cookie, **kwargs): """ @@ -145,27 +157,26 @@ def verify(self, request, cookie, **kwargs): else: raise ValueError("Wrong type of input") try: - cas_cookie, _ts, _typ = self.getCookieValue(cookie, - self.CONST_CAS_COOKIE) + cas_cookie, _ts, _typ = self.getCookieValue(cookie, self.CONST_CAS_COOKIE) data = json.loads(cas_cookie) nonce = base64.b64decode(data[self.CONST_NONCE]) if nonce != _dict[self.CONST_NONCE][0]: - logger.warning( - 'Someone tried to login without a correct nonce!') + logger.warning("Someone tried to login without a correct nonce!") return Unauthorized("You are not authorized!") acr = None try: acr = _dict["acr_values"][0] except KeyError: pass - uid = self.handle_callback(_dict[self.CONST_TICKET], - self.get_service_url(nonce, acr)) + uid = self.handle_callback( + _dict[self.CONST_TICKET], self.get_service_url(nonce, acr) + ) if uid is None or uid == "": - logger.info('Someone tried to login, but was denied by CAS!') + logger.info("Someone tried to login, but was denied by CAS!") return Unauthorized("You are not authorized!") cookie = self.create_cookie(uid, "casm") return_to = self.generate_return_url(self.return_to, uid) - if '?' in return_to: + if "?" in return_to: return_to += "&" else: return_to += "?" @@ -173,6 +184,7 @@ def verify(self, request, cookie, **kwargs): return SeeOther(return_to, headers=[cookie]) except Exception: # FIXME: This should catch specific exception thrown from methods in the block - logger.fatal('Metod verify in user_cas.py had a fatal exception.', - exc_info=True) + logger.fatal( + "Metod verify in user_cas.py had a fatal exception.", exc_info=True + ) return Unauthorized("You are not authorized!") diff --git a/src/oic/utils/authz.py b/src/oic/utils/authz.py index 90815717e..48400d77f 100644 --- a/src/oic/utils/authz.py +++ b/src/oic/utils/authz.py @@ -33,15 +33,17 @@ def permissions(self, cookie=None, **kwargs): _now = int(time.time()) if _now > (int(_ts) + int(self.cookie_ttl * 60)): logger.debug("Authentication timed out") - raise ToOld("%d > (%d + %d)" % (_now, int(_ts), - int(self.cookie_ttl * 60))) + raise ToOld( + "%d > (%d + %d)" % (_now, int(_ts), int(self.cookie_ttl * 60)) + ) else: if "max_age" in kwargs and kwargs["max_age"]: _now = int(time.time()) if _now > (int(_ts) + int(kwargs["max_age"])): logger.debug("Authentication too old") - raise ToOld("%d > (%d + %d)" % ( - _now, int(_ts), int(kwargs["max_age"]))) + raise ToOld( + "%d > (%d + %d)" % (_now, int(_ts), int(kwargs["max_age"])) + ) return self.permdb[uid] diff --git a/src/oic/utils/claims.py b/src/oic/utils/claims.py index 24e1faa9e..25e4400f8 100644 --- a/src/oic/utils/claims.py +++ b/src/oic/utils/claims.py @@ -1,4 +1,4 @@ -__author__ = 'rolandh' +__author__ = "rolandh" class ClaimsMode(object): diff --git a/src/oic/utils/client_management.py b/src/oic/utils/client_management.py index 70ee1fd16..f27c6d22b 100755 --- a/src/oic/utils/client_management.py +++ b/src/oic/utils/client_management.py @@ -14,7 +14,7 @@ from oic.oic.provider import secret from oic.utils.clientdb import BaseClientDatabase -__author__ = 'rolandh' +__author__ = "rolandh" def unpack_redirect_uri(redirect_uris): @@ -58,14 +58,12 @@ def keys(self): def items(self): return self.cdb.items() - def create(self, redirect_uris=None, policy_uri="", logo_uri="", - jwks_uri=""): + def create(self, redirect_uris=None, policy_uri="", logo_uri="", jwks_uri=""): if redirect_uris is None: - print( - 'Enter redirect_uris one at the time, end with a blank line: ') + print("Enter redirect_uris one at the time, end with a blank line: ") redirect_uris = [] while True: - redirect_uri = input('?: ') + redirect_uri = input("?: ") if redirect_uri: redirect_uris.append(redirect_uri) else: @@ -93,7 +91,7 @@ def create(self, redirect_uris=None, policy_uri="", logo_uri="", if logo_uri: info["logo_uri"] = logo_uri if jwks_uri: - info['jwks_uri'] = jwks_uri + info["jwks_uri"] = jwks_uri self.cdb[client_id] = info @@ -132,34 +130,70 @@ def dump(self, filename): else: res.append([key, val]) - fp = open(filename, 'w') + fp = open(filename, "w") json.dump(res, fp) fp.close() def run(): parser = argparse.ArgumentParser() - parser.add_argument('-l', '--list', dest='list', action='store_true', - help="List all client_ids") - parser.add_argument('-d', '--delete', dest='delete', action='store_true', - help="Delete the entity with the given client_id") - parser.add_argument('-c', '--create', dest='create', action='store_true', - help=("Create a new client, returns the stored " - "information")) - parser.add_argument('-s', '--show', dest='show', action='store_true', - help=("Show information connected to a specific" - "client_id")) - parser.add_argument('-i', '--client-id', dest='client_id', - help="A client_id on which to do an action") - parser.add_argument('-r', '--replace', dest='replace', - help=("Information that should replace what's there" - "about a specific client_id")) - parser.add_argument('-I', '--input-file', dest='input_file', - help="Import client information from a file") - parser.add_argument('-D', '--output-file', dest='output_file', - help="Dump client information to a file") - parser.add_argument('-R', '--reset', dest="reset", action='store_true', - help="Reset the database == removing all registrations") + parser.add_argument( + "-l", "--list", dest="list", action="store_true", help="List all client_ids" + ) + parser.add_argument( + "-d", + "--delete", + dest="delete", + action="store_true", + help="Delete the entity with the given client_id", + ) + parser.add_argument( + "-c", + "--create", + dest="create", + action="store_true", + help=("Create a new client, returns the stored " "information"), + ) + parser.add_argument( + "-s", + "--show", + dest="show", + action="store_true", + help=("Show information connected to a specific" "client_id"), + ) + parser.add_argument( + "-i", + "--client-id", + dest="client_id", + help="A client_id on which to do an action", + ) + parser.add_argument( + "-r", + "--replace", + dest="replace", + help=( + "Information that should replace what's there" "about a specific client_id" + ), + ) + parser.add_argument( + "-I", + "--input-file", + dest="input_file", + help="Import client information from a file", + ) + parser.add_argument( + "-D", + "--output-file", + dest="output_file", + help="Dump client information to a file", + ) + parser.add_argument( + "-R", + "--reset", + dest="reset", + action="store_true", + help="Reset the database == removing all registrations", + ) parser.add_argument(dest="filename") args = parser.parse_args() @@ -188,5 +222,5 @@ def run(): cdb.dump(args.output_file) -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/src/oic/utils/clientdb.py b/src/oic/utils/clientdb.py index 02cf1123f..634474415 100644 --- a/src/oic/utils/clientdb.py +++ b/src/oic/utils/clientdb.py @@ -77,42 +77,48 @@ class MDQClient(BaseClientDatabase): def __init__(self, url): """Set the remote storage url.""" self.url = url - self.headers = {'Accept': 'application/json', 'Accept-Encoding': 'gzip'} + self.headers = {"Accept": "application/json", "Accept-Encoding": "gzip"} def __getitem__(self, item): """Retrieve a single entity.""" - mdx_url = urljoin(self.url, 'entities/{}'.format(quote(item, safe=''))) + mdx_url = urljoin(self.url, "entities/{}".format(quote(item, safe=""))) response = requests.get(mdx_url, headers=self.headers) if response.status_code == 200: return response.json() else: - raise NoClientInfoReceivedError("{} {}".format(response.status_code, response.reason)) + raise NoClientInfoReceivedError( + "{} {}".format(response.status_code, response.reason) + ) def __setitem__(self, item, value): """Remote management is readonly.""" - raise RuntimeError('MDQClient is readonly.') + raise RuntimeError("MDQClient is readonly.") def __delitem__(self, item): """Remote management is readonly.""" - raise RuntimeError('MDQClient is readonly.') + raise RuntimeError("MDQClient is readonly.") def keys(self): """Get all registered entitites.""" - mdx_url = urljoin(self.url, 'entities') + mdx_url = urljoin(self.url, "entities") response = requests.get(mdx_url, headers=self.headers) if response.status_code == 200: - return [item['client_id'] for item in response.json()] + return [item["client_id"] for item in response.json()] else: - raise NoClientInfoReceivedError("{} {}".format(response.status_code, response.reason)) + raise NoClientInfoReceivedError( + "{} {}".format(response.status_code, response.reason) + ) def items(self): """Geting all registered entities.""" - mdx_url = urljoin(self.url, 'entities') + mdx_url = urljoin(self.url, "entities") response = requests.get(mdx_url, headers=self.headers) if response.status_code == 200: return response.json() else: - raise NoClientInfoReceivedError("{} {}".format(response.status_code, response.reason)) + raise NoClientInfoReceivedError( + "{} {}".format(response.status_code, response.reason) + ) # Dictionary can be used as a ClientDatabase diff --git a/src/oic/utils/http_util.py b/src/oic/utils/http_util.py index 7864b5cad..c2b22f342 100644 --- a/src/oic/utils/http_util.py +++ b/src/oic/utils/http_util.py @@ -18,7 +18,7 @@ from oic.utils.aes import AEAD from oic.utils.aes import AESError -__author__ = 'rohe0002' +__author__ = "rohe0002" logger = logging.getLogger(__name__) @@ -27,24 +27,21 @@ CORS_HEADERS = [ ("Access-Control-Allow-Origin", "*"), ("Access-Control-Allow-Methods", "GET"), - ("Access-Control-Allow-Headers", "Authorization") + ("Access-Control-Allow-Headers", "Authorization"), ] -OAUTH2_NOCACHE_HEADERS = [ - ('Pragma', 'no-cache'), - ('Cache-Control', 'no-store'), -] +OAUTH2_NOCACHE_HEADERS = [("Pragma", "no-cache"), ("Cache-Control", "no-store")] class Response(object): - _template = '' + _template = "" _status_code = 200 - _content_type = 'text/html' + _content_type = "text/html" _mako_template = None _mako_lookup = None def __init__(self, message=None, **kwargs): - self.status_code = kwargs.get('status_code', self._status_code) + self.status_code = kwargs.get("status_code", self._status_code) self.response = kwargs.get("response", self._response) self.template = kwargs.get("template", self._template) self.mako_template = kwargs.get("mako_template", self._mako_template) @@ -59,7 +56,7 @@ def __init__(self, message=None, **kwargs): self.headers.append(("Content-type", _content_type)) def _start_response(self, start_response): - name = client.responses.get(self.status_code, 'UNKNOWN') + name = client.responses.get(self.status_code, "UNKNOWN") start_response("{} {}".format(self.status_code, name), self.headers) def __call__(self, environ, start_response, **kwargs): @@ -70,15 +67,15 @@ def _response(self, message="", **argv): # Have to be more specific, this might be a bit to much. if message: try: - if '', '</script>') + if "", "</script>" + ) except TypeError: - if b'', b'</script>') + if b"", b"</script>" + ) if self.template: if ("Content-type", "application/json") in self.headers: @@ -90,9 +87,9 @@ def _response(self, message="", **argv): mte = self.mako_lookup.get_template(self.mako_template) return [mte.render(**argv)] else: - if [x for x in self._c_types() if x.startswith('image/')]: + if [x for x in self._c_types() if x.startswith("image/")]: return [message] - elif [x for x in self._c_types() if x == 'application/x-gzip']: + elif [x for x in self._c_types() if x == "application/x-gzip"]: return [message] try: @@ -101,7 +98,11 @@ def _response(self, message="", **argv): return [message] def info(self): - return {'status_code': self.status_code, 'headers': self.headers, 'message': self.message} + return { + "status_code": self.status_code, + "headers": self.headers, + "message": self.message, + } def add_header(self, ava): self.headers.append(ava) @@ -130,27 +131,31 @@ class NoContent(Response): class Redirect(Response): - _template = '\nRedirecting to %s\n' \ - '\nYou are being redirected to %s\n' \ - '\n' + _template = ( + "\nRedirecting to %s\n" + '\nYou are being redirected to %s\n' + "\n" + ) _status_code = 302 def __call__(self, environ, start_response, **kwargs): location = self.message - self.headers.append(('location', location)) + self.headers.append(("location", location)) self._start_response(start_response) return self.response((location, location, location)) class SeeOther(Response): - _template = '\nRedirecting to %s\n' \ - '\nYou are being redirected to %s\n' \ - '\n' + _template = ( + "\nRedirecting to %s\n" + '\nYou are being redirected to %s\n' + "\n" + ) _status_code = 303 def __call__(self, environ, start_response, **kwargs): location = self.message - self.headers.append(('location', location)) + self.headers.append(("location", location)) self._start_response(start_response) return self.response((location, location, location)) @@ -220,7 +225,7 @@ def extract(environ, empty=False, err=False): :param empty: Stops on empty fields (default: Fault) :param err: Stops on errors in fields (default: Fault) """ - formdata = cgi.parse(environ['wsgi.input'], environ, empty, err) + formdata = cgi.parse(environ["wsgi.input"], environ, empty, err) # Remove single entries from lists for key, value in formdata.iteritems(): if len(value) == 1: @@ -235,28 +240,29 @@ def geturl(environ, query=True, path=True): :param query: Is QUERY_STRING included in URI (default: True) :param path: Is path included in URI (default: True) """ - url = [environ['wsgi.url_scheme'] + '://'] - if environ.get('HTTP_HOST'): - url.append(environ['HTTP_HOST']) + url = [environ["wsgi.url_scheme"] + "://"] + if environ.get("HTTP_HOST"): + url.append(environ["HTTP_HOST"]) else: - url.append(environ['SERVER_NAME']) - if environ['wsgi.url_scheme'] == 'https': - if environ['SERVER_PORT'] != '443': - url.append(':' + environ['SERVER_PORT']) + url.append(environ["SERVER_NAME"]) + if environ["wsgi.url_scheme"] == "https": + if environ["SERVER_PORT"] != "443": + url.append(":" + environ["SERVER_PORT"]) else: - if environ['SERVER_PORT'] != '80': - url.append(':' + environ['SERVER_PORT']) + if environ["SERVER_PORT"] != "80": + url.append(":" + environ["SERVER_PORT"]) if path: url.append(getpath(environ)) - if query and environ.get('QUERY_STRING'): - url.append('?' + environ['QUERY_STRING']) - return ''.join(url) + if query and environ.get("QUERY_STRING"): + url.append("?" + environ["QUERY_STRING"]) + return "".join(url) def getpath(environ): """Build a path.""" - return ''.join([quote(environ.get('SCRIPT_NAME', '')), - quote(environ.get('PATH_INFO', ''))]) + return "".join( + [quote(environ.get("SCRIPT_NAME", "")), quote(environ.get("PATH_INFO", ""))] + ) def _expiration(timeout, time_format=None): @@ -282,7 +288,7 @@ def cookie_signature(key, *parts): for part in parts: if part: if isinstance(part, str): - sha1.update(part.encode('utf-8')) + sha1.update(part.encode("utf-8")) else: sha1.update(part) return str(sha1.hexdigest()) @@ -304,7 +310,7 @@ def verify_cookie_signature(sig, key, *parts): return hmac.compare_digest(sig, cookie_signature(key, *parts)) -def _make_hashed_key(parts, hashfunc='sha256'): +def _make_hashed_key(parts, hashfunc="sha256"): """ Construct a key via hashing the parts. @@ -314,14 +320,24 @@ def _make_hashed_key(parts, hashfunc='sha256'): h = hashlib.new(hashfunc) for part in parts: if isinstance(part, str): - part = part.encode('utf-8') + part = part.encode("utf-8") if part: h.update(part) return h.digest() -def make_cookie(name, load, seed, expire=0, domain="", path="", timestamp="", - enc_key=None, secure=True, httponly=True): +def make_cookie( + name, + load, + seed, + expire=0, + domain="", + path="", + timestamp="", + enc_key=None, + secure=True, + httponly=True, +): """ Create and return a cookie. @@ -377,16 +393,20 @@ def make_cookie(name, load, seed, expire=0, domain="", path="", timestamp="", crypt.add_associated_data(bytes_timestamp) ciphertext, tag = crypt.encrypt_and_tag(bytes_load) - cookie_payload = [bytes_timestamp, - base64.b64encode(iv), - base64.b64encode(ciphertext), - base64.b64encode(tag)] + cookie_payload = [ + bytes_timestamp, + base64.b64encode(iv), + base64.b64encode(ciphertext), + base64.b64encode(tag), + ] else: cookie_payload = [ - bytes_load, bytes_timestamp, - cookie_signature(seed, load, timestamp).encode('utf-8')] + bytes_load, + bytes_timestamp, + cookie_signature(seed, load, timestamp).encode("utf-8"), + ] - cookie[name] = (b"|".join(cookie_payload)).decode('utf-8') + cookie[name] = (b"|".join(cookie_payload)).decode("utf-8") if path: cookie[name]["path"] = path if domain: @@ -394,9 +414,9 @@ def make_cookie(name, load, seed, expire=0, domain="", path="", timestamp="", if expire: cookie[name]["expires"] = _expiration(expire, "%a, %d-%b-%Y %H:%M:%S GMT") if secure: - cookie[name]['secure'] = secure + cookie[name]["secure"] = secure if httponly: - cookie[name]['httponly'] = httponly + cookie[name]["httponly"] = httponly return tuple(cookie.output().split(": ", 1)) @@ -423,7 +443,7 @@ def parse_cookie(name, seed, kaka, enc_key=None): return None if isinstance(seed, str): - seed = seed.encode('utf-8') + seed = seed.encode("utf-8") parts = cookie_parts(name, kaka) if parts is None: @@ -447,12 +467,12 @@ def parse_cookie(name, seed, kaka, enc_key=None): crypt = AEAD(key, iv) # timestamp does not need to be encrypted, just MAC'ed, # so we add it to 'Associated Data' only. - crypt.add_associated_data(timestamp.encode('utf-8')) + crypt.add_associated_data(timestamp.encode("utf-8")) try: cleartext = crypt.decrypt_and_verify(ciphertext, tag) except AESError: raise InvalidCookieSign() - return cleartext.decode('utf-8'), timestamp + return cleartext.decode("utf-8"), timestamp return None @@ -471,14 +491,14 @@ def cookie_parts(name, kaka): def get_post(environ): # the environment variable CONTENT_LENGTH may be empty or missing try: - request_body_size = int(environ.get('CONTENT_LENGTH', 0)) + request_body_size = int(environ.get("CONTENT_LENGTH", 0)) except ValueError: request_body_size = 0 # When the method is POST the query string will be sent # in the HTTP request body which is passed by the WSGI server # in the file like wsgi.input environment variable. - text = environ['wsgi.input'].read(request_body_size) + text = environ["wsgi.input"].read(request_body_size) try: text = text.decode("utf-8") except AttributeError: @@ -548,7 +568,6 @@ def wsgi_wrapper(environ, start_response, func, **kwargs): class CookieDealer(object): - @property def srv(self): return self._srv @@ -569,17 +588,16 @@ def init_srv(self, srv): return self.srv = srv - symkey = getattr(self.srv, 'symkey', None) + symkey = getattr(self.srv, "symkey", None) if symkey is not None and symkey == "": msg = "CookieDealer.srv.symkey cannot be an empty value" raise ImproperlyConfigured(msg) - if not getattr(srv, 'seed', None): - setattr(srv, 'seed', rndstr().encode("utf-8")) + if not getattr(srv, "seed", None): + setattr(srv, "seed", rndstr().encode("utf-8")) def delete_cookie(self, cookie_name=None): - return self.create_cookie("", "", cookie_name=cookie_name, ttl=-1, - kill=True) + return self.create_cookie("", "", cookie_name=cookie_name, ttl=-1, kill=True) def create_cookie(self, value, typ, cookie_name=None, ttl=-1, kill=False): if kill: @@ -607,8 +625,18 @@ def create_cookie(self, value, typ, cookie_name=None, ttl=-1, kill=False): except TypeError: _msg = "::".join([value[0], timestamp, typ]) - cookie = make_cookie(cookie_name, _msg, self.srv.seed, expire=ttl, domain=cookie_domain, path=cookie_path, - timestamp=timestamp, enc_key=self.srv.symkey, secure=self.secure, httponly=self.httponly) + cookie = make_cookie( + cookie_name, + _msg, + self.srv.seed, + expire=ttl, + domain=cookie_domain, + path=cookie_path, + timestamp=timestamp, + enc_key=self.srv.symkey, + secure=self.secure, + httponly=self.httponly, + ) return cookie def getCookieValue(self, cookie=None, cookie_name=None): @@ -626,9 +654,9 @@ def get_cookie_value(self, cookie=None, cookie_name=None): return None else: try: - info, timestamp = parse_cookie(cookie_name, - self.srv.seed, cookie, - self.srv.symkey) + info, timestamp = parse_cookie( + cookie_name, self.srv.seed, cookie, self.srv.symkey + ) except (TypeError, AssertionError): return None else: diff --git a/src/oic/utils/jwt.py b/src/oic/utils/jwt.py index bc7ae10e2..a95983e0b 100644 --- a/src/oic/utils/jwt.py +++ b/src/oic/utils/jwt.py @@ -9,13 +9,21 @@ from oic.oic.message import JasonWebToken from oic.utils.time_util import utc_time_sans_frac -__author__ = 'roland' +__author__ = "roland" class JWT(object): - def __init__(self, keyjar, iss='', lifetime=0, sign_alg='RS256', - msgtype=JasonWebToken, encrypt=False, enc_enc="A128CBC-HS256", - enc_alg="RSA1_5"): + def __init__( + self, + keyjar, + iss="", + lifetime=0, + sign_alg="RS256", + msgtype=JasonWebToken, + encrypt=False, + enc_enc="A128CBC-HS256", + enc_alg="RSA1_5", + ): self.iss = iss self.lifetime = lifetime self.sign_alg = sign_alg @@ -25,8 +33,8 @@ def __init__(self, keyjar, iss='', lifetime=0, sign_alg='RS256', self.enc_alg = enc_alg self.enc_enc = enc_enc - def _encrypt(self, payload, cty='JWT'): - keys = self.keyjar.get_encrypt_key(owner='') + def _encrypt(self, payload, cty="JWT"): + keys = self.keyjar.get_encrypt_key(owner="") kwargs = {"alg": self.enc_alg, "enc": self.enc_enc} if cty: @@ -37,33 +45,34 @@ def _encrypt(self, payload, cty='JWT'): return _jwe.encrypt(keys, context="public") def pack_init(self): - argv = {'iss': self.iss, 'iat': utc_time_sans_frac()} - argv['exp'] = argv['iat'] + self.lifetime + argv = {"iss": self.iss, "iat": utc_time_sans_frac()} + argv["exp"] = argv["iat"] + self.lifetime return argv - def pack_key(self, owner='', kid=''): - keys = self.keyjar.get_signing_key(jws.alg2keytype(self.sign_alg), - owner=owner, kid=kid) + def pack_key(self, owner="", kid=""): + keys = self.keyjar.get_signing_key( + jws.alg2keytype(self.sign_alg), owner=owner, kid=kid + ) if not keys: - raise NoSuitableSigningKeys('kid={}'.format(kid)) + raise NoSuitableSigningKeys("kid={}".format(kid)) return keys[0] # Might be more then one if kid == '' - def pack(self, kid='', owner='', cls_instance=None, **kwargs): + def pack(self, kid="", owner="", cls_instance=None, **kwargs): _args = self.pack_init() - if self.sign_alg != 'none': + if self.sign_alg != "none": _key = self.pack_key(owner, kid) - _args['kid'] = _key.kid + _args["kid"] = _key.kid else: _key = None try: - _encrypt = kwargs['encrypt'] + _encrypt = kwargs["encrypt"] except KeyError: _encrypt = self.encrypt else: - del kwargs['encrypt'] + del kwargs["encrypt"] _args.update(kwargs) @@ -73,13 +82,13 @@ def pack(self, kid='', owner='', cls_instance=None, **kwargs): else: _jwt = self.message_type(**_args) - if 'jti' in self.message_type.c_param: + if "jti" in self.message_type.c_param: try: - _jti = kwargs['jti'] + _jti = kwargs["jti"] except KeyError: _jti = uuid.uuid4().hex - _jwt['jti'] = _jti + _jwt["jti"] = _jti _jws = _jwt.to_jwt([_key], self.sign_alg) if _encrypt: @@ -88,17 +97,19 @@ def pack(self, kid='', owner='', cls_instance=None, **kwargs): return _jws def _verify(self, rj, token): - _msg = json.loads(rj.jwt.part[1].decode('utf8')) - if _msg['iss'] == self.iss: - owner = '' + _msg = json.loads(rj.jwt.part[1].decode("utf8")) + if _msg["iss"] == self.iss: + owner = "" else: - owner = _msg['iss'] + owner = _msg["iss"] - keys = self.keyjar.get_verify_key(jws.alg2keytype(rj.jwt.headers['alg']), owner=owner) + keys = self.keyjar.get_verify_key( + jws.alg2keytype(rj.jwt.headers["alg"]), owner=owner + ) return rj.verify_compact(token, keys) def _decrypt(self, rj, token): - keys = self.keyjar.get_verify_key(owner='') + keys = self.keyjar.get_verify_key(owner="") msg = rj.decrypt(token, keys) _rj = jws.factory(msg) if not _rj: diff --git a/src/oic/utils/keyio.py b/src/oic/utils/keyio.py index fa8267bed..a778b6fe9 100644 --- a/src/oic/utils/keyio.py +++ b/src/oic/utils/keyio.py @@ -24,7 +24,7 @@ from oic.exception import MessageException from oic.exception import PyoidcError -__author__ = 'rohe0002' +__author__ = "rohe0002" KEYLOADERR = "Failed to load %s key from '%s' (%s)" REMOTE_FAILED = "Remote key update from '{}' failed, HTTP status {}" @@ -33,9 +33,9 @@ logger = logging.getLogger(__name__) -def raise_exception(excep, descr, error='service_error'): - _err = json.dumps({'error': error, 'error_description': descr}) - raise excep(_err, 'application/json') +def raise_exception(excep, descr, error="service_error"): + _err = json.dumps({"error": error, "error_description": descr}) + raise excep(_err, "application/json") class KeyIOError(PyoidcError): @@ -54,16 +54,21 @@ class JWKSError(KeyIOError): pass -K2C = { - "RSA": RSAKey, - "EC": ECKey, - "oct": SYMKey, -} +K2C = {"RSA": RSAKey, "EC": ECKey, "oct": SYMKey} class KeyBundle(object): - def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True, - fileformat="jwk", keytype="RSA", keyusage=None, timeout=5): + def __init__( + self, + keys=None, + source="", + cache_time=300, + verify_ssl=True, + fileformat="jwk", + keytype="RSA", + keyusage=None, + timeout=5, + ): """ Initialize the KeyBundle. @@ -123,7 +128,7 @@ def do_keys(self, keys): """ for inst in keys: if not isinstance(inst, dict): - raise JWKSError('Illegal JWK') + raise JWKSError("Illegal JWK") typ = inst["kty"] flag = 0 @@ -133,16 +138,16 @@ def do_keys(self, keys): except KeyError: continue except TypeError: - raise JWKSError('Inappropriate JWKS argument type') + raise JWKSError("Inappropriate JWKS argument type") except JWKException as err: - logger.warning('Loading a key failed: %s', err) + logger.warning("Loading a key failed: %s", err) else: if _key not in self._keys: self._keys.append(_key) flag = 1 break if not flag: - logger.warning('Unknown key type: %s', typ) + logger.warning("Unknown key type: %s", typ) def do_local_jwk(self, filename): try: @@ -151,8 +156,8 @@ def do_local_jwk(self, filename): except KeyError: logger.error("Now 'keys' keyword in JWKS") raise_exception( - UpdateFailed, - "Local key update from '{}' failed.".format(filename)) + UpdateFailed, "Local key update from '{}' failed.".format(filename) + ) else: self.last_updated = time.time() @@ -176,12 +181,11 @@ def do_remote(self): args["headers"] = {"If-None-Match": self.etag} try: - logging.debug('KeyBundle fetch keys from: {}'.format(self.source)) + logging.debug("KeyBundle fetch keys from: {}".format(self.source)) r = requests.get(self.source, **args) except Exception as err: logger.error(err) - raise_exception(UpdateFailed, - REMOTE_FAILED.format(self.source, str(err))) + raise_exception(UpdateFailed, REMOTE_FAILED.format(self.source, str(err))) if r.status_code == 304: # file has not changed self.time_out = time.time() + self.cache_time @@ -197,8 +201,7 @@ def do_remote(self): self.time_out = time.time() + self.cache_time self.imp_jwks = self._parse_remote_response(r) - if not isinstance(self.imp_jwks, - dict) or "keys" not in self.imp_jwks: + if not isinstance(self.imp_jwks, dict) or "keys" not in self.imp_jwks: raise_exception(UpdateFailed, MALFORMED.format(self.source)) logger.debug("Loaded JWKS: %s from %s" % (r.text, self.source)) @@ -213,8 +216,9 @@ def do_remote(self): except KeyError: pass else: - raise_exception(UpdateFailed, - REMOTE_FAILED.format(self.source, r.status_code)) + raise_exception( + UpdateFailed, REMOTE_FAILED.format(self.source, r.status_code) + ) self.last_updated = time.time() return True @@ -229,8 +233,12 @@ def _parse_remote_response(self, response): """ # Check if the content type is the right one. try: - if not response.headers["Content-Type"].lower().startswith('application/json'): - logger.warning('Wrong Content_type') + if ( + not response.headers["Content-Type"] + .lower() + .startswith("application/json") + ): + logger.warning("Wrong Content_type") except KeyError: pass @@ -303,8 +311,9 @@ def remove_key(self, typ, val=None): :param val: The key itself """ if val: - self._keys = [k for k in self._keys if - not (k.kty == typ and k.key == val.key)] + self._keys = [ + k for k in self._keys if not (k.kty == typ and k.key == val.key) + ] else: self._keys = [k for k in self._keys if not k.kty == typ] @@ -406,16 +415,21 @@ def dump_jwks(kbl, target, private=False): """ keys = [] for kb in kbl: - keys.extend([k.serialize(private) for k in kb.keys() if - k.kty != 'oct' and not k.inactive_since]) + keys.extend( + [ + k.serialize(private) + for k in kb.keys() + if k.kty != "oct" and not k.inactive_since + ] + ) res = {"keys": keys} try: - f = open(target, 'w') + f = open(target, "w") except IOError: (head, tail) = os.path.split(target) os.makedirs(head) - f = open(target, 'w') + f = open(target, "w") _txt = json.dumps(res) f.write(_txt) @@ -425,8 +439,9 @@ def dump_jwks(kbl, target, private=False): class KeyJar(object): """A keyjar contains a number of KeyBundles.""" - def __init__(self, verify_ssl=True, keybundle_cls=KeyBundle, - remove_after=3600, timeout=5): + def __init__( + self, verify_ssl=True, keybundle_cls=KeyBundle, remove_after=3600, timeout=5 + ): """ Initialize the class. @@ -445,7 +460,7 @@ def __init__(self, verify_ssl=True, keybundle_cls=KeyBundle, def __repr__(self): issuers = list(self.issuer_keys.keys()) - return ''.format(issuers) + return "".format(issuers) def add_if_unique(self, issuer, use, keys): if use in self.issuer_keys[issuer] and self.issuer_keys[issuer][use]: @@ -472,10 +487,13 @@ def add(self, issuer, url, **kwargs): raise KeyError("No jwks_uri") if "/localhost:" in url or "/localhost/" in url: - kc = self.keybundle_cls(source=url, verify_ssl=False, timeout=self.timeout, **kwargs) + kc = self.keybundle_cls( + source=url, verify_ssl=False, timeout=self.timeout, **kwargs + ) else: - kc = self.keybundle_cls(source=url, verify_ssl=self.verify_ssl, timeout=self.timeout, - **kwargs) + kc = self.keybundle_cls( + source=url, verify_ssl=self.verify_ssl, timeout=self.timeout, **kwargs + ) try: self.issuer_keys[issuer].append(kc) @@ -491,13 +509,13 @@ def add_symmetric(self, issuer, key, usage=None): _key = b64e(as_bytes(key)) if usage is None: self.issuer_keys[issuer].append( - self.keybundle_cls([{"kty": "oct", "k": _key}])) + self.keybundle_cls([{"kty": "oct", "k": _key}]) + ) else: for use in usage: self.issuer_keys[issuer].append( - self.keybundle_cls([{"kty": "oct", - "k": _key, - "use": use}])) + self.keybundle_cls([{"kty": "oct", "k": _key, "use": use}]) + ) def add_kb(self, issuer, kb): try: @@ -569,7 +587,7 @@ def get(self, key_use, key_type="", issuer="", kid=None, **kwargs): lst.append(key) continue # Verification can be performed by both `sig` and `ver` keys - if key_use == 'ver' and key.use in ('sig', 'ver'): + if key_use == "ver" and key.use in ("sig", "ver"): lst.append(key) # if elliptic curve have to check I have a key of the right curve @@ -581,9 +599,9 @@ def get(self, key_use, key_type="", issuer="", kid=None, **kwargs): _lst.append(key) lst = _lst - if use == 'enc' and key_type == 'oct' and issuer != '': + if use == "enc" and key_type == "oct" and issuer != "": # Add my symmetric keys - for kb in self.issuer_keys['']: + for kb in self.issuer_keys[""]: for key in kb.get(key_type): if key.inactive_since: continue @@ -655,7 +673,9 @@ def __getitem__(self, issuer): except KeyError: logger.debug( "Issuer '{}' not found, available key issuers: {}".format( - issuer, list(self.issuer_keys.keys()))) + issuer, list(self.issuer_keys.keys()) + ) + ) raise def remove_key(self, issuer, key_type, key): @@ -726,7 +746,10 @@ def load_keys(self, pcr, issuer, replace=False): try: _keys = pcr["jwks"]["keys"] self.issuer_keys[issuer].append( - self.keybundle_cls(_keys, verify_ssl=self.verify_ssl, timeout=self.timeout)) + self.keybundle_cls( + _keys, verify_ssl=self.verify_ssl, timeout=self.timeout + ) + ) except KeyError: pass @@ -757,8 +780,9 @@ def dump_issuer_keys(self, issuer): def export_jwks(self, private=False, issuer=""): keys = [] for kb in self.issuer_keys[issuer]: - keys.extend([k.serialize(private) for k in kb.keys() if - k.inactive_since == 0]) + keys.extend( + [k.serialize(private) for k in kb.keys() if k.inactive_since == 0] + ) return {"keys": keys} def import_jwks(self, jwks, issuer): @@ -771,14 +795,20 @@ def import_jwks(self, jwks, issuer): try: _keys = jwks["keys"] except KeyError: - raise ValueError('Not a proper JWKS') + raise ValueError("Not a proper JWKS") else: try: self.issuer_keys[issuer].append( - self.keybundle_cls(_keys, verify_ssl=self.verify_ssl, timeout=self.timeout)) + self.keybundle_cls( + _keys, verify_ssl=self.verify_ssl, timeout=self.timeout + ) + ) except KeyError: - self.issuer_keys[issuer] = [self.keybundle_cls( - _keys, verify_ssl=self.verify_ssl, timeout=self.timeout)] + self.issuer_keys[issuer] = [ + self.keybundle_cls( + _keys, verify_ssl=self.verify_ssl, timeout=self.timeout + ) + ] def add_keyjar(self, keyjar): for iss, kblist in keyjar.items(): @@ -795,8 +825,11 @@ def dump(self): def restore(self, info): for issuer, keys in info.items(): - self.issuer_keys[issuer] = [self.keybundle_cls( - keys, verify_ssl=self.verify_ssl, timeout=self.timeout)] + self.issuer_keys[issuer] = [ + self.keybundle_cls( + keys, verify_ssl=self.verify_ssl, timeout=self.timeout + ) + ] def copy(self): copy_keyjar = KeyJar(verify_ssl=self.verify_ssl, timeout=self.timeout) @@ -905,12 +938,11 @@ def key_setup(vault, **kwargs): _args = kwargs[usage] if _args["alg"].upper() == "RSA": try: - _key = rsa_load('%s%s' % (vault_path, "pyoidc")) + _key = rsa_load("%s%s" % (vault_path, "pyoidc")) except Exception: - with open(os.devnull, 'w') as devnull: + with open(os.devnull, "w") as devnull: with RedirectStdStreams(stdout=devnull, stderr=devnull): - _key = create_and_store_rsa_key_pair( - path=vault_path) + _key = create_and_store_rsa_key_pair(path=vault_path) k = RSAKey(key=_key, use=usage) k.add_kid() @@ -955,8 +987,7 @@ def key_export(baseurl, local_path, vault, keyjar, **kwargs): with open(_export_filename, "w") as f: f.write(str(kb)) - _url = "%s://%s%s" % (part.scheme, part.netloc, - _export_filename[1:]) + _url = "%s://%s%s" % (part.scheme, part.netloc, _export_filename[1:]) return _url @@ -978,12 +1009,12 @@ def create_and_store_rsa_key_pair(name="pyoidc", path=".", size=2048): os.makedirs(path, exist_ok=True) if name: - with open(os.path.join(path, name), 'wb') as f: - f.write(key.exportKey('PEM')) + with open(os.path.join(path, name), "wb") as f: + f.write(key.exportKey("PEM")) _pub_key = key.publickey() - with open(os.path.join(path, '{}.pub'.format(name)), 'wb') as f: - f.write(_pub_key.exportKey('PEM')) + with open(os.path.join(path, "{}.pub".format(name)), "wb") as f: + f.write(_pub_key.exportKey("PEM")) return key @@ -1068,8 +1099,9 @@ def keyjar_init(instance, key_conf, kid_template=""): :param kid_template: A template by which to build the kids :return: a JWKS """ - jwks, keyjar, kdd = build_keyjar(key_conf, kid_template, instance.keyjar, - instance.kid) + jwks, keyjar, kdd = build_keyjar( + key_conf, kid_template, instance.keyjar, instance.kid + ) instance.keyjar = keyjar instance.kid = kdd @@ -1077,13 +1109,13 @@ def keyjar_init(instance, key_conf, kid_template=""): def _new_rsa_key(spec): - if 'name' not in spec: - if '/' in spec['key']: - (head, tail) = os.path.split(spec['key']) - spec['path'] = head - spec['name'] = tail + if "name" not in spec: + if "/" in spec["key"]: + (head, tail) = os.path.split(spec["key"]) + spec["path"] = head + spec["name"] = tail else: - spec['name'] = spec['key'] + spec["name"] = spec["key"] return rsa_init(spec) @@ -1118,12 +1150,16 @@ def build_keyjar(key_conf, kid_template="", keyjar=None, kidd=None): if typ == "RSA": if "key" in spec: - error_to_catch = getattr(builtins, 'FileNotFoundError', - getattr(builtins, 'IOError')) + error_to_catch = getattr( + builtins, "FileNotFoundError", getattr(builtins, "IOError") + ) try: - kb = KeyBundle(source="file://%s" % spec["key"], - fileformat="der", - keytype=typ, keyusage=spec["use"]) + kb = KeyBundle( + source="file://%s" % spec["key"], + fileformat="der", + keytype=typ, + keyusage=spec["use"], + ) except error_to_catch: kb = _new_rsa_key(spec) except Exception: @@ -1141,8 +1177,7 @@ def build_keyjar(key_conf, kid_template="", keyjar=None, kidd=None): k.add_kid() kidd[k.use][k.kty] = k.kid - jwks["keys"].extend( - [k.serialize() for k in kb.keys() if k.kty != 'oct']) + jwks["keys"].extend([k.serialize() for k in kb.keys() if k.kty != "oct"]) keyjar.add_kb("", kb) @@ -1159,18 +1194,16 @@ def key_summary(keyjar, issuer): try: kbl = keyjar[issuer] except KeyError: - return '' + return "" else: key_list = [] for kb in kbl: for key in kb.keys(): if key.inactive_since: - key_list.append( - '*{}:{}:{}'.format(key.kty, key.use, key.kid)) + key_list.append("*{}:{}:{}".format(key.kty, key.use, key.kid)) else: - key_list.append( - '{}:{}:{}'.format(key.kty, key.use, key.kid)) - return ', '.join(key_list) + key_list.append("{}:{}:{}".format(key.kty, key.use, key.kid)) + return ", ".join(key_list) def check_key_availability(inst, jwt): @@ -1186,8 +1219,8 @@ def check_key_availability(inst, jwt): """ _rj = jws.factory(jwt) payload = json.loads(as_unicode(_rj.jwt.part[1])) - _cid = payload['iss'] + _cid = payload["iss"] if _cid not in inst.keyjar: cinfo = inst.cdb[_cid] - inst.keyjar.add_symmetric(_cid, cinfo['client_secret'], ['enc', 'sig']) - inst.keyjar.add(_cid, cinfo['jwks_uri']) + inst.keyjar.add_symmetric(_cid, cinfo["client_secret"], ["enc", "sig"]) + inst.keyjar.add(_cid, cinfo["jwks_uri"]) diff --git a/src/oic/utils/restrict.py b/src/oic/utils/restrict.py index f6826fcec..fe5f24193 100644 --- a/src/oic/utils/restrict.py +++ b/src/oic/utils/restrict.py @@ -2,36 +2,36 @@ import json import sys -__author__ = 'roland' +__author__ = "roland" def single(restriction, cinfo): for s in restriction: try: if len(cinfo[s]) != 1: - return 'Too Many {}'.format(s) + return "Too Many {}".format(s) except KeyError: pass - return '' + return "" def map_grant_type2response_type(restriction, cinfo): - if 'grant_types' in cinfo and 'response_types' in cinfo: + if "grant_types" in cinfo and "response_types" in cinfo: for g, r in restriction.items(): - if g in cinfo['grant_types'] and r in cinfo['response_types']: + if g in cinfo["grant_types"] and r in cinfo["response_types"]: pass - elif g in cinfo['grant_types'] or r in cinfo['response_types']: + elif g in cinfo["grant_types"] or r in cinfo["response_types"]: return "grant_type didn't match response_type" - return '' + return "" def map(restriction, cinfo): for fname, spec in restriction.items(): - func = factory('map_' + fname) + func = factory("map_" + fname) resp = func(spec, cinfo) if resp: return resp - return '' + return "" def allow(restriction, cinfo): @@ -43,14 +43,14 @@ def allow(restriction, cinfo): if isinstance(_cparam, str): if _cparam not in args: - return 'Not allowed to register with {}={}'.format(param, - _cparam) + return "Not allowed to register with {}={}".format(param, _cparam) else: if not set(_cparam).issubset(args): - return 'Not allowed to register with {}={}'.format( - param, json.dumps(_cparam)) + return "Not allowed to register with {}={}".format( + param, json.dumps(_cparam) + ) - return '' + return "" def assign(restriction, cinfo): diff --git a/src/oic/utils/rp/__init__.py b/src/oic/utils/rp/__init__.py index cf9e2f32d..8df247971 100644 --- a/src/oic/utils/rp/__init__.py +++ b/src/oic/utils/rp/__init__.py @@ -18,7 +18,7 @@ from oic.utils.http_util import Redirect from oic.utils.sanitize import sanitize -__author__ = 'roland' +__author__ = "roland" logger = logging.getLogger(__name__) @@ -29,16 +29,30 @@ class OIDCError(Exception): class Client(oic.Client): - def __init__(self, client_id=None, - client_prefs=None, client_authn_method=None, keyjar=None, - verify_ssl=True, behaviour=None, config=None, jwks_uri='', - kid=None): - oic.Client.__init__(self, client_id, client_prefs, - client_authn_method, keyjar, verify_ssl, - config=config) + def __init__( + self, + client_id=None, + client_prefs=None, + client_authn_method=None, + keyjar=None, + verify_ssl=True, + behaviour=None, + config=None, + jwks_uri="", + kid=None, + ): + oic.Client.__init__( + self, + client_id, + client_prefs, + client_authn_method, + keyjar, + verify_ssl, + config=config, + ) if behaviour: self.behaviour = behaviour - self.userinfo_request_method = '' + self.userinfo_request_method = "" self.allow_sign_alg_none = False self.authz_req = {} self.get_userinfo = True @@ -52,12 +66,12 @@ def create_authn_request(self, session, acr_value=None, **kwargs): "response_type": self.behaviour["response_type"], "scope": self.behaviour["scope"], "state": session["state"], - "redirect_uri": self.registration_response["redirect_uris"][0] + "redirect_uri": self.registration_response["redirect_uris"][0], } if self.oidc: session["nonce"] = rndstr(32) - request_args['nonce'] = session['nonce'] + request_args["nonce"] = session["nonce"] if acr_value is not None: request_args["acr_values"] = acr_value @@ -66,11 +80,11 @@ def create_authn_request(self, session, acr_value=None, **kwargs): cis = self.construct_AuthorizationRequest(request_args=request_args) logger.debug("request: %s" % sanitize(cis)) - url, body, ht_args, cis = self.uri_and_body(AuthorizationRequest, cis, - method="GET", - request_args=request_args) + url, body, ht_args, cis = self.uri_and_body( + AuthorizationRequest, cis, method="GET", request_args=request_args + ) - self.authz_req[request_args['state']] = cis + self.authz_req[request_args["state"]] = cis logger.debug("body: %s" % sanitize(body)) logger.info("URL: %s" % sanitize(url)) logger.debug("ht_args: %s" % sanitize(ht_args)) @@ -95,7 +109,7 @@ def _err(self, txt): logger.error(sanitize(txt)) raise OIDCError(txt) - def callback(self, response, session, format='dict'): + def callback(self, response, session, format="dict"): """ Call when an AuthN response has been received from the OP. @@ -103,8 +117,9 @@ def callback(self, response, session, format='dict'): :return: """ try: - authresp = self.parse_response(AuthorizationResponse, response, - sformat=format, keyjar=self.keyjar) + authresp = self.parse_response( + AuthorizationResponse, response, sformat=format, keyjar=self.keyjar + ) except ResponseError: msg = "Could not parse response: '{}'" logger.error(msg.format(sanitize(response))) @@ -120,11 +135,11 @@ def callback(self, response, session, format='dict'): _state = authresp["state"] try: - _id_token = authresp['id_token'] + _id_token = authresp["id_token"] except KeyError: _id_token = None else: - if _id_token['nonce'] != self.authz_req[_state]['nonce']: + if _id_token["nonce"] != self.authz_req[_state]["nonce"]: self._err("Received nonce not the same as expected.") if self.behaviour["response_type"] == "code": @@ -132,14 +147,13 @@ def callback(self, response, session, format='dict'): try: args = { "code": authresp["code"], - "redirect_uri": self.registration_response[ - "redirect_uris"][0], + "redirect_uri": self.registration_response["redirect_uris"][0], "client_id": self.client_id, "client_secret": self.client_secret, } try: - args['scope'] = response['scope'] + args["scope"] = response["scope"] except KeyError: pass @@ -147,43 +161,45 @@ def callback(self, response, session, format='dict'): state=authresp["state"], request_args=args, authn_method=self.registration_response[ - "token_endpoint_auth_method"]) - msg = 'Access token response: {}' + "token_endpoint_auth_method" + ], + ) + msg = "Access token response: {}" logger.info(msg.format(sanitize(atresp))) except Exception as err: logger.error("%s" % err) raise if isinstance(atresp, ErrorResponse): - msg = 'Error response: {}' + msg = "Error response: {}" self._err(msg.format(sanitize(atresp.to_dict()))) - _token = atresp['access_token'] + _token = atresp["access_token"] try: - _id_token = atresp['id_token'] + _id_token = atresp["id_token"] except KeyError: pass else: - _token = authresp['access_token'] + _token = authresp["access_token"] if not self.oidc: - return {'access_token': _token} + return {"access_token": _token} if _id_token is None: self._err("Invalid response: no IdToken") - if _id_token['iss'] != self.provider_info['issuer']: + if _id_token["iss"] != self.provider_info["issuer"]: self._err("Issuer mismatch") - if _id_token['nonce'] != self.authz_req[_state]['nonce']: + if _id_token["nonce"] != self.authz_req[_state]["nonce"]: self._err("Nonce mismatch") if not self.allow_sign_alg_none: - if _id_token.jws_header['alg'] == 'none': + if _id_token.jws_header["alg"] == "none": self._err('Do not allow "none" signature algorithm') - user_id = '{}:{}'.format(_id_token['iss'], _id_token['sub']) + user_id = "{}:{}".format(_id_token["iss"], _id_token["sub"]) if self.get_userinfo: if self.userinfo_request_method: @@ -192,15 +208,14 @@ def callback(self, response, session, format='dict'): kwargs = {} if self.has_access_token(state=authresp["state"]): - inforesp = self.do_user_info_request(state=authresp["state"], - **kwargs) + inforesp = self.do_user_info_request(state=authresp["state"], **kwargs) if isinstance(inforesp, ErrorResponse): self._err("Invalid response %s." % inforesp["error"]) userinfo = inforesp.to_dict() - if _id_token['sub'] != userinfo['sub']: + if _id_token["sub"] != userinfo["sub"]: self._err("Invalid response: userid mismatch") logger.debug("UserInfo: %s" % sanitize(inforesp)) @@ -217,16 +232,18 @@ def callback(self, response, session, format='dict'): except KeyError: pass - return {'user_id': user_id, 'userinfo': userinfo, - 'id_token': _id_token, 'access_token': _token} + return { + "user_id": user_id, + "userinfo": userinfo, + "id_token": _id_token, + "access_token": _token, + } else: - return {'user_id': user_id, 'id_token': _id_token, - 'access_token': _token} + return {"user_id": user_id, "id_token": _id_token, "access_token": _token} class OIDCClients(object): - def __init__(self, config, base_url, seed='', jwks_info=None, - verify_ssl=True): + def __init__(self, config, base_url, seed="", jwks_info=None, verify_ssl=True): """ Initialize the client. @@ -237,7 +254,7 @@ def __init__(self, config, base_url, seed='', jwks_info=None, self.client_cls = Client self.config = config self.seed = seed or rndstr(16) - self.seed = self.seed.encode('utf8') + self.seed = self.seed.encode("utf8") self.path = {} self.base_url = base_url self.jwks_info = jwks_info @@ -272,15 +289,17 @@ def create_client(self, userid="", **kwargs): """ _key_set = set(list(kwargs.keys())) try: - _verify_ssl = kwargs['verify_ssl'] + _verify_ssl = kwargs["verify_ssl"] except KeyError: _verify_ssl = self.verify_ssl else: - _key_set.discard('verify_ssl') + _key_set.discard("verify_ssl") - client = self.client_cls(client_authn_method=CLIENT_AUTHN_METHOD, - behaviour=kwargs["behaviour"], - verify_ssl=_verify_ssl) + client = self.client_cls( + client_authn_method=CLIENT_AUTHN_METHOD, + behaviour=kwargs["behaviour"], + verify_ssl=_verify_ssl, + ) try: client.userinfo_request_method = kwargs["userinfo_request_method"] @@ -309,87 +328,104 @@ def create_client(self, userid="", **kwargs): # Gather OP information client.provider_config(issuer) # register the client - client.register(client.provider_info["registration_endpoint"], - **kwargs["client_info"]) - self.get_path(kwargs['client_info']['redirect_uris'], issuer) + client.register( + client.provider_info["registration_endpoint"], **kwargs["client_info"] + ) + self.get_path(kwargs["client_info"]["redirect_uris"], issuer) elif _key_set == set(["client_info", "srv_discovery_url"]): # Ship the webfinger part # Gather OP information client.provider_config(kwargs["srv_discovery_url"]) # register the client - client.register(client.provider_info["registration_endpoint"], - **kwargs["client_info"]) - self.get_path(kwargs['client_info']['redirect_uris'], - kwargs["srv_discovery_url"]) + client.register( + client.provider_info["registration_endpoint"], **kwargs["client_info"] + ) + self.get_path( + kwargs["client_info"]["redirect_uris"], kwargs["srv_discovery_url"] + ) elif _key_set == set(["provider_info", "client_info"]): client.handle_provider_config( ProviderConfigurationResponse(**kwargs["provider_info"]), - kwargs["provider_info"]["issuer"]) - client.register(client.provider_info["registration_endpoint"], - **kwargs["client_info"]) - - self.get_path(kwargs['client_info']['redirect_uris'], - kwargs["provider_info"]["issuer"]) + kwargs["provider_info"]["issuer"], + ) + client.register( + client.provider_info["registration_endpoint"], **kwargs["client_info"] + ) + + self.get_path( + kwargs["client_info"]["redirect_uris"], + kwargs["provider_info"]["issuer"], + ) elif _key_set == set(["provider_info", "client_registration"]): client.handle_provider_config( ProviderConfigurationResponse(**kwargs["provider_info"]), - kwargs["provider_info"]["issuer"]) - client.store_registration_info(RegistrationResponse( - **kwargs["client_registration"])) - self.get_path(kwargs['client_info']['redirect_uris'], - kwargs["provider_info"]["issuer"]) + kwargs["provider_info"]["issuer"], + ) + client.store_registration_info( + RegistrationResponse(**kwargs["client_registration"]) + ) + self.get_path( + kwargs["client_info"]["redirect_uris"], + kwargs["provider_info"]["issuer"], + ) elif _key_set == set(["srv_discovery_url", "client_registration"]): client.provider_config(kwargs["srv_discovery_url"]) - client.store_registration_info(RegistrationResponse( - **kwargs["client_registration"])) - self.get_path(kwargs['client_registration']['redirect_uris'], - kwargs["srv_discovery_url"]) + client.store_registration_info( + RegistrationResponse(**kwargs["client_registration"]) + ) + self.get_path( + kwargs["client_registration"]["redirect_uris"], + kwargs["srv_discovery_url"], + ) else: raise Exception("Configuration error ?") return client - def dynamic_client(self, userid='', issuer=''): - client = self.client_cls(client_authn_method=CLIENT_AUTHN_METHOD, - verify_ssl=self.verify_ssl, **self.jwks_info) + def dynamic_client(self, userid="", issuer=""): + client = self.client_cls( + client_authn_method=CLIENT_AUTHN_METHOD, + verify_ssl=self.verify_ssl, + **self.jwks_info + ) if userid: issuer = client.wf.discovery_query(userid) if not issuer: - raise OIDCError('Missing issuer') + raise OIDCError("Missing issuer") - logger.info('issuer: {}'.format(issuer)) + logger.info("issuer: {}".format(issuer)) if issuer in self.client: return self.client[issuer] else: # Gather OP information _pcr = client.provider_config(issuer) - logger.info('Provider info: {}'.format(sanitize(_pcr.to_dict()))) + logger.info("Provider info: {}".format(sanitize(_pcr.to_dict()))) # register the client _cinfo = self.config.CLIENTS[""]["client_info"] reg_args = copy.copy(_cinfo) h = hashlib.sha256(self.seed) - h.update(issuer.encode('utf8')) # issuer has to be bytes + h.update(issuer.encode("utf8")) # issuer has to be bytes base_urls = _cinfo["redirect_uris"] - reg_args['redirect_uris'] = [ - u.format(base=self.base_url, iss=h.hexdigest()) - for u in base_urls] + reg_args["redirect_uris"] = [ + u.format(base=self.base_url, iss=h.hexdigest()) for u in base_urls + ] try: - reg_args['post_logout_redirect_uris'] = [ + reg_args["post_logout_redirect_uris"] = [ u.format(base=self.base_url, iss=h.hexdigest()) - for u in reg_args['post_logout_redirect_uris'] - ] + for u in reg_args["post_logout_redirect_uris"] + ] except KeyError: pass - self.get_path(reg_args['redirect_uris'], issuer) + self.get_path(reg_args["redirect_uris"], issuer) if client.jwks_uri: - reg_args['jwks_uri'] = client.jwks_uri + reg_args["jwks_uri"] = client.jwks_uri rr = client.register(_pcr["registration_endpoint"], **reg_args) - msg = 'Registration response: {}' + msg = "Registration response: {}" logger.info(msg.format(sanitize(rr.to_dict()))) try: diff --git a/src/oic/utils/rp/oauth2.py b/src/oic/utils/rp/oauth2.py index 3d92da0c6..0a774834d 100644 --- a/src/oic/utils/rp/oauth2.py +++ b/src/oic/utils/rp/oauth2.py @@ -18,7 +18,7 @@ from oic.utils.sanitize import sanitize from oic.utils.webfinger import WebFinger -__author__ = 'roland' +__author__ = "roland" logger = logging.getLogger(__name__) @@ -29,14 +29,22 @@ class OAuth2Error(Exception): class OAuthClient(client.Client): - def __init__(self, client_id=None, - client_prefs=None, client_authn_method=None, keyjar=None, - verify_ssl=True, behaviour=None, jwks_uri='', - kid=None): - client.Client.__init__(self, client_id, client_authn_method, - keyjar=keyjar, verify_ssl=verify_ssl) + def __init__( + self, + client_id=None, + client_prefs=None, + client_authn_method=None, + keyjar=None, + verify_ssl=True, + behaviour=None, + jwks_uri="", + kid=None, + ): + client.Client.__init__( + self, client_id, client_authn_method, keyjar=keyjar, verify_ssl=verify_ssl + ) self.behaviour = behaviour or {} - self.userinfo_request_method = '' + self.userinfo_request_method = "" self.allow_sign_alg_none = False self.authz_req = {} self.get_userinfo = True @@ -51,7 +59,7 @@ def create_authn_request(self, session, acr_value=None, **kwargs): request_args = { "response_type": self.behaviour["response_type"], "state": session["state"], - "redirect_uri": self.registration_response["redirect_uris"][0] + "redirect_uri": self.registration_response["redirect_uris"][0], } try: @@ -63,11 +71,11 @@ def create_authn_request(self, session, acr_value=None, **kwargs): cis = self.construct_AuthorizationRequest(request_args=request_args) logger.debug("request: %s" % sanitize(cis)) - url, body, ht_args, cis = self.uri_and_body(AuthorizationRequest, cis, - method="GET", - request_args=request_args) + url, body, ht_args, cis = self.uri_and_body( + AuthorizationRequest, cis, method="GET", request_args=request_args + ) - self.authz_req[request_args['state']] = cis + self.authz_req[request_args["state"]] = cis logger.debug("body: %s" % sanitize(body)) logger.info("URL: %s" % sanitize(url)) logger.debug("ht_args: %s" % sanitize(ht_args)) @@ -92,21 +100,22 @@ def _err(self, txt): logger.error(sanitize(txt)) raise OAuth2Error(txt) - def callback(self, response, session, format='dict'): + def callback(self, response, session, format="dict"): """ Call when an AuthN response has been received from the OP. :param response: The URL returned by the OP :return: """ - if self.behaviour["response_type"] == 'code': + if self.behaviour["response_type"] == "code": respcls = AuthorizationResponse else: respcls = AccessTokenResponse try: - authresp = self.parse_response(respcls, response, - sformat=format, keyjar=self.keyjar) + authresp = self.parse_response( + respcls, response, sformat=format, keyjar=self.keyjar + ) except ResponseError: msg = "Could not parse response: '{}'" logger.error(msg.format(sanitize(response))) @@ -128,14 +137,13 @@ def callback(self, response, session, format='dict'): try: args = { "code": authresp["code"], - "redirect_uri": self.registration_response[ - "redirect_uris"][0], + "redirect_uri": self.registration_response["redirect_uris"][0], "client_id": self.client_id, "client_secret": self.client_secret, } try: - args['scope'] = response['scope'] + args["scope"] = response["scope"] except KeyError: pass @@ -143,25 +151,26 @@ def callback(self, response, session, format='dict'): state=authresp["state"], request_args=args, authn_method=self.registration_response[ - "token_endpoint_auth_method"]) - logger.info('Access token response: {}'.format(sanitize(atresp))) + "token_endpoint_auth_method" + ], + ) + logger.info("Access token response: {}".format(sanitize(atresp))) except Exception as err: logger.error("%s" % err) raise if isinstance(atresp, ErrorResponse): - self._err('Error response: {}'.format(atresp.to_dict())) + self._err("Error response: {}".format(atresp.to_dict())) - _token = atresp['access_token'] + _token = atresp["access_token"] else: - _token = authresp['access_token'] + _token = authresp["access_token"] - return {'access_token': _token} + return {"access_token": _token} class OAuthClients(object): - def __init__(self, config, base_url, seed='', jwks_info=None, - verify_ssl=True): + def __init__(self, config, base_url, seed="", jwks_info=None, verify_ssl=True): """ Initialize the client. @@ -172,7 +181,7 @@ def __init__(self, config, base_url, seed='', jwks_info=None, self.client_cls = OAuthClient self.config = config self.seed = seed or rndstr(16) - self.seed = self.seed.encode('utf8') + self.seed = self.seed.encode("utf8") self.path = {} self.base_url = base_url self.jwks_info = jwks_info @@ -210,15 +219,17 @@ def create_client(self, **kwargs): """ _key_set = set(list(kwargs.keys())) try: - _verify_ssl = kwargs['verify_ssl'] + _verify_ssl = kwargs["verify_ssl"] except KeyError: _verify_ssl = self.verify_ssl else: - _key_set.discard('verify_ssl') + _key_set.discard("verify_ssl") - _client = self.client_cls(client_authn_method=CLIENT_AUTHN_METHOD, - behaviour=kwargs["behaviour"], - verify_ssl=_verify_ssl) + _client = self.client_cls( + client_authn_method=CLIENT_AUTHN_METHOD, + behaviour=kwargs["behaviour"], + verify_ssl=_verify_ssl, + ) # The behaviour parameter is not significant for the election process _key_set.discard("behaviour") @@ -236,42 +247,57 @@ def create_client(self, **kwargs): # Gather OP information _client.provider_config(kwargs["srv_discovery_url"]) # register the client - _client.register(_client.provider_info["registration_endpoint"], - **kwargs["client_info"]) - self.get_path(kwargs['client_info']['redirect_uris'], - kwargs["srv_discovery_url"]) + _client.register( + _client.provider_info["registration_endpoint"], **kwargs["client_info"] + ) + self.get_path( + kwargs["client_info"]["redirect_uris"], kwargs["srv_discovery_url"] + ) elif _key_set == {"provider_info", "client_info"}: _client.handle_provider_config( ProviderConfigurationResponse(**kwargs["provider_info"]), - kwargs["provider_info"]["issuer"]) - _client.register(_client.provider_info["registration_endpoint"], - **kwargs["client_info"]) - - self.get_path(kwargs['client_info']['redirect_uris'], - kwargs["provider_info"]["issuer"]) + kwargs["provider_info"]["issuer"], + ) + _client.register( + _client.provider_info["registration_endpoint"], **kwargs["client_info"] + ) + + self.get_path( + kwargs["client_info"]["redirect_uris"], + kwargs["provider_info"]["issuer"], + ) elif _key_set == {"provider_info", "client_registration"}: _client.handle_provider_config( ProviderConfigurationResponse(**kwargs["provider_info"]), - kwargs["provider_info"]["issuer"]) - _client.store_registration_info(ClientInfoResponse( - **kwargs["client_registration"])) - self.get_path(kwargs['client_info']['redirect_uris'], - kwargs["provider_info"]["issuer"]) + kwargs["provider_info"]["issuer"], + ) + _client.store_registration_info( + ClientInfoResponse(**kwargs["client_registration"]) + ) + self.get_path( + kwargs["client_info"]["redirect_uris"], + kwargs["provider_info"]["issuer"], + ) elif _key_set == {"srv_discovery_url", "client_registration"}: _client.provider_config(kwargs["srv_discovery_url"]) - _client.store_registration_info(ClientInfoResponse( - **kwargs["client_registration"])) - self.get_path(kwargs['client_registration']['redirect_uris'], - kwargs["srv_discovery_url"]) + _client.store_registration_info( + ClientInfoResponse(**kwargs["client_registration"]) + ) + self.get_path( + kwargs["client_registration"]["redirect_uris"], + kwargs["srv_discovery_url"], + ) else: raise Exception("Configuration error ?") return client - def dynamic_client(self, issuer='', userid=''): - client = self.client_cls(client_authn_method=CLIENT_AUTHN_METHOD, - verify_ssl=self.verify_ssl, - **self.jwks_info) + def dynamic_client(self, issuer="", userid=""): + client = self.client_cls( + client_authn_method=CLIENT_AUTHN_METHOD, + verify_ssl=self.verify_ssl, + **self.jwks_info + ) if userid: try: issuer = client.wf.discovery_query(userid) @@ -280,41 +306,41 @@ def dynamic_client(self, issuer='', userid=''): issuer = wf.discovery_query(userid) if not issuer: - raise OAuth2Error('Missing issuer') + raise OAuth2Error("Missing issuer") - logger.info('issuer: {}'.format(issuer)) + logger.info("issuer: {}".format(issuer)) if issuer in self.client: return self.client[issuer] else: # Gather OP information _pcr = client.provider_config(issuer) - logger.info('Provider info: {}'.format(sanitize(_pcr.to_dict()))) - issuer = _pcr['issuer'] # So no hickup later about trailing '/' + logger.info("Provider info: {}".format(sanitize(_pcr.to_dict()))) + issuer = _pcr["issuer"] # So no hickup later about trailing '/' # register the client _cinfo = self.config.CLIENTS[""]["client_info"] reg_args = copy.copy(_cinfo) h = hashlib.sha256(self.seed) - h.update(issuer.encode('utf8')) # issuer has to be bytes + h.update(issuer.encode("utf8")) # issuer has to be bytes base_urls = _cinfo["redirect_uris"] - reg_args['redirect_uris'] = [ - u.format(base=self.base_url, iss=h.hexdigest()) - for u in base_urls] + reg_args["redirect_uris"] = [ + u.format(base=self.base_url, iss=h.hexdigest()) for u in base_urls + ] try: - reg_args['post_logout_redirect_uris'] = [ + reg_args["post_logout_redirect_uris"] = [ u.format(base=self.base_url, iss=h.hexdigest()) - for u in reg_args['post_logout_redirect_uris'] - ] + for u in reg_args["post_logout_redirect_uris"] + ] except KeyError: pass - self.get_path(reg_args['redirect_uris'], issuer) + self.get_path(reg_args["redirect_uris"], issuer) if client.jwks_uri: - reg_args['jwks_uri'] = client.jwks_uri + reg_args["jwks_uri"] = client.jwks_uri rr = client.register(_pcr["registration_endpoint"], **reg_args) - msg = 'Registration response: {}' + msg = "Registration response: {}" logger.info(msg.format(sanitize(rr.to_dict()))) try: diff --git a/src/oic/utils/sanitize.py b/src/oic/utils/sanitize.py index 8fe579134..94cfcc9ed 100644 --- a/src/oic/utils/sanitize.py +++ b/src/oic/utils/sanitize.py @@ -2,12 +2,19 @@ from collections import Mapping from textwrap import dedent -SENSITIVE_THINGS = {'password', 'passwd', 'client_secret', 'code', - 'authorization', 'access_token', 'refresh_token'} - -REPLACEMENT = '' - -SANITIZE_PATTERN = r''' +SENSITIVE_THINGS = { + "password", + "passwd", + "client_secret", + "code", + "authorization", + "access_token", + "refresh_token", +} + +REPLACEMENT = "" + +SANITIZE_PATTERN = r""" (?' -''' +""" -SANITIZE_PATTERN = dedent(SANITIZE_PATTERN.format('|'.join(SENSITIVE_THINGS))) +SANITIZE_PATTERN = dedent(SANITIZE_PATTERN.format("|".join(SENSITIVE_THINGS))) SANITIZE_REGEX = re.compile(SANITIZE_PATTERN, re.VERBOSE | re.IGNORECASE | re.UNICODE) @@ -39,10 +46,8 @@ def sanitize(potentially_sensitive): if isinstance(potentially_sensitive, Mapping): # Makes new dict so we don't modify the original # Also case-insensitive--possibly important for HTTP headers. - return dict( - redacted(k.lower(), v) for k, v in potentially_sensitive.items()) + return dict(redacted(k.lower(), v) for k, v in potentially_sensitive.items()) else: if not isinstance(potentially_sensitive, str): potentially_sensitive = str(potentially_sensitive) - return SANITIZE_REGEX.sub(r'\1{}'.format(REPLACEMENT), - potentially_sensitive) + return SANITIZE_REGEX.sub(r"\1{}".format(REPLACEMENT), potentially_sensitive) diff --git a/src/oic/utils/sdb.py b/src/oic/utils/sdb.py index 9825f0772..38a07e537 100644 --- a/src/oic/utils/sdb.py +++ b/src/oic/utils/sdb.py @@ -16,7 +16,7 @@ from oic.utils.time_util import time_sans_frac from oic.utils.time_util import utc_time_sans_frac -__author__ = 'rohe0002' +__author__ = "rohe0002" logger = logging.getLogger(__name__) @@ -24,17 +24,17 @@ def lv_pack(*args): s = [] for a in args: - s.append('{}:{}'.format(len(a), a)) - return ''.join(s) + s.append("{}:{}".format(len(a), a)) + return "".join(s) def lv_unpack(txt): txt = txt.strip() res = [] while txt: - l, v = txt.split(':', 1) - res.append(v[:int(l)]) - txt = v[int(l):] + l, v = txt.split(":", 1) + res.append(v[: int(l)]) + txt = v[int(l) :] return res @@ -56,21 +56,22 @@ class UnknownToken(Exception): def pairwise_id(sub, sector_identifier, seed): return hashlib.sha256( - ("%s%s%s" % (sub, sector_identifier, seed)).encode("utf-8")).hexdigest() + ("%s%s%s" % (sub, sector_identifier, seed)).encode("utf-8") + ).hexdigest() class Crypt(object): - def __init__(self, password, mode=None): self.key = base64.urlsafe_b64encode( - hashlib.sha256(password.encode("utf-8")).digest()) + hashlib.sha256(password.encode("utf-8")).digest() + ) self.core = Fernet(self.key) def encrypt(self, text): # Padding to blocksize of AES text = tobytes(text) if len(text) % 16: - text += b' ' * (16 - len(text) % 16) + text += b" " * (16 - len(text) % 16) return self.core.encrypt(tobytes(text)) def decrypt(self, ciphertext): @@ -150,11 +151,11 @@ def valid(self, token): class DefaultToken(Token): - def __init__(self, secret, password, typ='', **kwargs): + def __init__(self, secret, password, typ="", **kwargs): Token.__init__(self, typ, **kwargs) self.crypt = Crypt(password) - def __call__(self, sid='', ttype='', **kwargs): + def __call__(self, sid="", ttype="", **kwargs): """ Return a token. @@ -166,10 +167,10 @@ def __call__(self, sid='', ttype='', **kwargs): if not ttype and self.type: ttype = self.type else: - ttype = 'A' + ttype = "A" - tmp = '' - rnd = '' + tmp = "" + rnd = "" while rnd == tmp: # Don't use the same random value again rnd = rndstr(32) # Ultimate length multiple of 16 @@ -187,8 +188,8 @@ def key(self, user="", areq=None): :param areq: The authorization request :return: A hash """ - csum = hashlib.new('sha224') - csum.update(rndstr(32).encode('utf-8')) + csum = hashlib.new("sha224") + csum.update(rndstr(32).encode("utf-8")) return csum.hexdigest() # 56 bytes long, 224 bits def _split_token(self, token): @@ -242,8 +243,16 @@ def expires_at(self, token): class AuthnEvent(object): - def __init__(self, uid, salt, valid=3600, authn_info=None, - time_stamp=0, authn_time=None, valid_until=None): + def __init__( + self, + uid, + salt, + valid=3600, + authn_info=None, + time_stamp=0, + authn_time=None, + valid_until=None, + ): """ Create a representation of an authentication event. @@ -317,18 +326,26 @@ def create_token(self, client_id, uid, scopes, sub, authzreq, sid): :param sid: Session ID :return: Refresh token """ - refresh_token = 'Refresh_{}'.format(rndstr(5 * 16)) - self.store(refresh_token, - {'client_id': client_id, 'uid': uid, 'scope': scopes, - 'sub': sub, 'authzreq': authzreq, 'sid': sid}) + refresh_token = "Refresh_{}".format(rndstr(5 * 16)) + self.store( + refresh_token, + { + "client_id": client_id, + "uid": uid, + "scope": scopes, + "sub": sub, + "authzreq": authzreq, + "sid": sid, + }, + ) return refresh_token def verify_token(self, client_id, refresh_token): """Verify if the refresh token belongs to client_id.""" - if not refresh_token.startswith('Refresh_'): + if not refresh_token.startswith("Refresh_"): raise WrongTokenType try: - stored_cid = self.get(refresh_token).get('client_id') + stored_cid = self.get(refresh_token).get("client_id") except KeyError: return False return client_id == stored_cid @@ -358,9 +375,15 @@ def remove(self, token): self._db.pop(token) -def create_session_db(base_url, secret, password, db=None, - token_expires_in=3600, grant_expires_in=600, - refresh_token_expires_in=86400): +def create_session_db( + base_url, + secret, + password, + db=None, + token_expires_in=3600, + grant_expires_in=600, + refresh_token_expires_in=86400, +): """ Construct SessionDB instance. @@ -377,14 +400,13 @@ def create_session_db(base_url, secret, password, db=None, :return: A constructed `SessionDB` object. """ - code_factory = DefaultToken(secret, password, typ='A', - lifetime=grant_expires_in) - token_factory = DefaultToken(secret, password, typ='T', - lifetime=token_expires_in) + code_factory = DefaultToken(secret, password, typ="A", lifetime=grant_expires_in) + token_factory = DefaultToken(secret, password, typ="T", lifetime=token_expires_in) db = {} if db is None else db return SessionDB( - base_url, db, + base_url, + db, refresh_db=None, code_factory=code_factory, token_factory=token_factory, @@ -394,10 +416,16 @@ def create_session_db(base_url, secret, password, db=None, class SessionDB(object): - def __init__(self, base_url, db, refresh_db=None, - refresh_token_expires_in=86400, - token_factory=None, code_factory=None, - refresh_token_factory=None): + def __init__( + self, + base_url, + db, + refresh_db=None, + refresh_token_expires_in=86400, + token_factory=None, + code_factory=None, + refresh_token_factory=None, + ): self.base_url = base_url self._db = db @@ -405,12 +433,9 @@ def __init__(self, base_url, db, refresh_db=None, # TODO: uid2sid should have a persistence option too. self.uid2sid = {} - self.token_factory = { - 'code': code_factory, - 'access_token': token_factory, - } + self.token_factory = {"code": code_factory, "access_token": token_factory} - self.token_factory_order = ['code', 'access_token'] + self.token_factory_order = ["code", "access_token"] # TODO: This should simply be a factory like all the others too, # even for the default case. @@ -418,16 +443,17 @@ def __init__(self, base_url, db, refresh_db=None, if refresh_token_factory: if refresh_db: raise ImproperlyConfigured( - "Only use one of refresh_db or refresh_token_factory") + "Only use one of refresh_db or refresh_token_factory" + ) self._refresh_db = None - self.token_factory['refresh_token'] = refresh_token_factory - self.token_factory_order.append('refresh_token') + self.token_factory["refresh_token"] = refresh_token_factory + self.token_factory_order.append("refresh_token") elif refresh_db: self._refresh_db = refresh_db else: self._refresh_db = DictRefreshDB() - self.access_token = self.token_factory['access_token'] + self.access_token = self.token_factory["access_token"] self.token = self.access_token def _get_token_key(self, item, order=None): @@ -534,10 +560,10 @@ def do_sub(self, sid, client_salt, sector_id="", subject_type="public"): if subject_type == "public": sub = hashlib.sha256( - "{}{}".format(uid, user_salt).encode("utf-8")).hexdigest() + "{}{}".format(uid, user_salt).encode("utf-8") + ).hexdigest() else: - sub = pairwise_id(uid, sector_id, - "{}{}".format(client_salt, user_salt)) + sub = pairwise_id(uid, sector_id, "{}{}".format(client_salt, user_salt)) # since sub can be public, there can be more then one session # that uses the same subject identifier @@ -547,12 +573,11 @@ def do_sub(self, sid, client_salt, sector_id="", subject_type="public"): self.uid2sid[uid] = [sid] logger.debug("uid2sid: %s" % self.uid2sid) - self.update(sid, 'sub', sub) + self.update(sid, "sub", sub) return sub - def create_authz_session(self, aevent, areq, id_token=None, oidreq=None, - **kwargs): + def create_authz_session(self, aevent, areq, id_token=None, oidreq=None, **kwargs): """ Create session holding info about the Authorization event. @@ -562,8 +587,8 @@ def create_authz_session(self, aevent, areq, id_token=None, oidreq=None, :param oidreq: An OpenIDRequest instance :return: The session identifier, which is the database key """ - sid = self.token_factory['code'].key(user=aevent.uid, areq=areq) - access_grant = self.token_factory['code'](sid=sid) + sid = self.token_factory["code"].key(user=aevent.uid, areq=areq) + access_grant = self.token_factory["code"](sid=sid) _dic = { "oauth_state": "authz", @@ -571,9 +596,9 @@ def create_authz_session(self, aevent, areq, id_token=None, oidreq=None, "code_used": False, "authzreq": areq.to_json(), "client_id": areq["client_id"], - 'response_type': areq['response_type'], + "response_type": areq["response_type"], "revoked": False, - "authn_event": aevent.to_json() + "authn_event": aevent.to_json(), } _dic.update(kwargs) @@ -614,8 +639,15 @@ def get_token(self, sid): elif self._db[sid]["oauth_state"] == "token": return self._db[sid]["access_token"] - def upgrade_to_token(self, token=None, issue_refresh=False, id_token="", - oidreq=None, key=None, access_grant=""): + def upgrade_to_token( + self, + token=None, + issue_refresh=False, + id_token="", + oidreq=None, + key=None, + access_grant="", + ): """ Promote session to token. @@ -628,7 +660,7 @@ def upgrade_to_token(self, token=None, issue_refresh=False, id_token="", """ if token: try: - (_, key) = self.token_factory['code'].type_and_key(token) + (_, key) = self.token_factory["code"].type_and_key(token) except Exception: raise WrongTokenType("Not a grant token") @@ -653,8 +685,8 @@ def upgrade_to_token(self, token=None, issue_refresh=False, id_token="", dic["oidreq"] = oidreq if issue_refresh: - if 'authn_event' in dic: - authn_event = AuthnEvent.from_json(dic['authn_event']) + if "authn_event" in dic: + authn_event = AuthnEvent.from_json(dic["authn_event"]) else: authn_event = None if authn_event: @@ -664,11 +696,15 @@ def upgrade_to_token(self, token=None, issue_refresh=False, id_token="", if self._refresh_db: refresh_token = self._refresh_db.create_token( - dic['client_id'], uid, dic.get('scope'), dic['sub'], - dic['authzreq'], key) + dic["client_id"], + uid, + dic.get("scope"), + dic["sub"], + dic["authzreq"], + key, + ) else: - refresh_token = self.token_factory['refresh_token'](key, - sinfo=dic) + refresh_token = self.token_factory["refresh_token"](key, sinfo=dic) dic["refresh_token"] = refresh_token self._db[key] = dic return dic @@ -689,12 +725,12 @@ def refresh_token(self, rtoken, client_id): # Valid refresh token _info = self._refresh_db.get(rtoken) try: - sid = _info['sid'] + sid = _info["sid"] except KeyError: - areq = json.loads(_info['authzreq']) - sid = self.token_factory['code'].key(user=_info['uid'], areq=areq) + areq = json.loads(_info["authzreq"]) + sid = self.token_factory["code"].key(user=_info["uid"], areq=areq) dic = _info - dic['response_type'] = areq['response_type'].split(' ') + dic["response_type"] = areq["response_type"].split(" ") else: try: dic = self._db[sid] @@ -711,8 +747,8 @@ def refresh_token(self, rtoken, client_id): self.access_token.invalidate(at) else: raise ExpiredToken() - elif self.token_factory['refresh_token'].valid(rtoken): - sid = self.token_factory['refresh_token'].get_key(rtoken) + elif self.token_factory["refresh_token"].valid(rtoken): + sid = self.token_factory["refresh_token"].get_key(rtoken) dic = self._db[sid] access_token = self.access_token(sid=sid, sinfo=dic) @@ -742,7 +778,7 @@ def is_valid(self, token, client_id=None): :param token: Access or refresh token :param client_id: Client ID, needed only for Refresh token """ - if token.startswith('Refresh_'): + if token.startswith("Refresh_"): return self._refresh_db.verify_token(client_id, token) try: @@ -784,7 +820,7 @@ def revoke_token(self, token): """ _, sid = self._get_token_type_and_key(token) - self.update(sid, 'revoked', True) + self.update(sid, "revoked", True) return True def revoke_refresh_token(self, rtoken): @@ -796,7 +832,7 @@ def revoke_refresh_token(self, rtoken): if self._refresh_db: self._refresh_db.revoke_token(rtoken) else: - self.token_factory['refresh_token'].invalidate(rtoken) + self.token_factory["refresh_token"].invalidate(rtoken) return True @@ -809,7 +845,7 @@ def revoke_all_tokens(self, token): _, sid = self._get_token_type_and_key(token) try: - rtoken = self._db[sid]['refresh_token'] + rtoken = self._db[sid]["refresh_token"] except KeyError: pass else: @@ -823,8 +859,7 @@ def get_client_id_for_session(self, sid): return _dict["client_id"] def get_client_ids_for_uid(self, uid): - return [self.get_client_id_for_session(sid) for sid in - self.uid2sid[uid]] + return [self.get_client_id_for_session(sid) for sid in self.uid2sid[uid]] def get_verified_Logout(self, uid): _dict = self._db[self.uid2sid[uid]] @@ -844,7 +879,7 @@ def is_revoke_uid(self, uid): return self._db[self.uid2sid[uid]]["revoked"] def revoke_uid(self, uid): - self.update(self.uid2sid[uid], 'revoked', True) + self.update(self.uid2sid[uid], "revoked", True) def get_sids_from_uid(self, uid): """ @@ -862,15 +897,23 @@ def get_sids_by_sub(self, sub): def duplicate(self, sinfo): _dic = copy.copy(sinfo) areq = AuthorizationRequest().from_json(_dic["authzreq"]) - sid = self.token_factory['code'].key(user=_dic["sub"], areq=areq) + sid = self.token_factory["code"].key(user=_dic["sub"], areq=areq) - _dic["code"] = self.token_factory['code'](sid=sid, sinfo=sinfo) + _dic["code"] = self.token_factory["code"](sid=sid, sinfo=sinfo) _dic["code_used"] = False - for key in ["access_token", "access_token_scope", "oauth_state", - "token_type", "token_expires_at", "expires_in", - "client_id_issued_at", "id_token", "oidreq", - "refresh_token"]: + for key in [ + "access_token", + "access_token_scope", + "oauth_state", + "token_type", + "token_expires_at", + "expires_in", + "client_id_issued_at", + "id_token", + "oidreq", + "refresh_token", + ]: try: del _dic[key] except KeyError: diff --git a/src/oic/utils/shelve_wrapper.py b/src/oic/utils/shelve_wrapper.py index 089dc0b8f..f94fbee20 100644 --- a/src/oic/utils/shelve_wrapper.py +++ b/src/oic/utils/shelve_wrapper.py @@ -1,6 +1,6 @@ import shelve -__author__ = 'danielevertsson' +__author__ = "danielevertsson" class ShelfWrapper(object): diff --git a/src/oic/utils/stateless.py b/src/oic/utils/stateless.py index e07e681b6..7228bc794 100644 --- a/src/oic/utils/stateless.py +++ b/src/oic/utils/stateless.py @@ -6,7 +6,7 @@ from oic.oic.message import SINGLE_REQUIRED_INT from oic.utils.time_util import epoch_in_a_while -__author__ = 'roland' +__author__ = "roland" class Content(Message): @@ -16,19 +16,29 @@ class Content(Message): "auz": SINGLE_OPTIONAL_STRING, # Authorization information "aud": SINGLE_OPTIONAL_STRING, # The intended receiver "val": SINGLE_REQUIRED_INT, # Valid until - "ref": SINGLE_OPTIONAL_STRING # Refresh token + "ref": SINGLE_OPTIONAL_STRING, # Refresh token } c_allowed_values = {"type": ["code", "access", "refresh"]} class StateLess(object): - def __init__(self, keys, enc_alg, enc_method, grant_validity=300, - access_validity=600, refresh_validity=0): + def __init__( + self, + keys, + enc_alg, + enc_method, + grant_validity=300, + access_validity=600, + refresh_validity=0, + ): self.keys = keys self.alg = enc_alg self.enc = enc_method - self.validity = {"grant": grant_validity, "access": access_validity, - "refresh": refresh_validity} + self.validity = { + "grant": grant_validity, + "access": access_validity, + "refresh": refresh_validity, + } self.used_grants = [] self.revoked = [] @@ -52,8 +62,12 @@ def create_authz_session(self, sub, areq, **kwargs): :param areq: The AuthorizationRequest instance :return: The session identifier, which is the database key """ - _cont = Content(typ="code", sub=sub, aud=areq["redirect_uri"], - val=epoch_in_a_while(self.validity["grant"])) + _cont = Content( + typ="code", + sub=sub, + aud=areq["redirect_uri"], + val=epoch_in_a_while(self.validity["grant"]), + ) return _cont @@ -61,8 +75,12 @@ def upgrade_to_token(self, cont, issue_refresh=False): cont["typ"] = "access" cont["val"] = epoch_in_a_while(self.validity["access"]) if issue_refresh: - _c = Content(sub=cont["sub"], aud=cont["aud"], typ="refresh", - val=epoch_in_a_while(self.validity["refresh"])) + _c = Content( + sub=cont["sub"], + aud=cont["aud"], + typ="refresh", + val=epoch_in_a_while(self.validity["refresh"]), + ) cont["ref"] = _c.to_jwe(self.keys, self.enc, self.alg) return cont @@ -70,7 +88,7 @@ def upgrade_to_token(self, cont, issue_refresh=False): def refresh_token(self, rtoken): # assert that it is a refresh token _cont = Content().from_jwe(rtoken, self.keys) - if _cont['typ'] != 'refresh': + if _cont["typ"] != "refresh": raise Exception("Not a refresh token") def is_expired(self, token): diff --git a/src/oic/utils/template_render.py b/src/oic/utils/template_render.py index 5e6e3523a..0d3a31429 100644 --- a/src/oic/utils/template_render.py +++ b/src/oic/utils/template_render.py @@ -29,7 +29,9 @@ def inputs(form_args): """Create list of input elements.""" element = [] for name, value in form_args.items(): - element.append(''.format(name, value)) + element.append( + ''.format(name, value) + ) return "\n".join(element) @@ -43,14 +45,16 @@ def render_template(template_name, context): Templates are defined as strings in this module. """ - if 'action' not in context: - raise TemplateException('Missing action in context.') - if template_name == 'form_post': - context['html_inputs'] = inputs(context.get('inputs', {})) + if "action" not in context: + raise TemplateException("Missing action in context.") + if template_name == "form_post": + context["html_inputs"] = inputs(context.get("inputs", {})) return FORM_POST.format(**context) - elif template_name == 'verify_logout': - form_args = {'id_token_hint': context.get('id_token_hint', ''), - 'post_logout_redirect_uri': context.get('post_logout_redirect_uri', '')} - context['html_inputs'] = inputs(form_args) + elif template_name == "verify_logout": + form_args = { + "id_token_hint": context.get("id_token_hint", ""), + "post_logout_redirect_uri": context.get("post_logout_redirect_uri", ""), + } + context["html_inputs"] = inputs(form_args) return VERIFY_LOGOUT.format(**context) - raise TemplateException('Unknown template name.') + raise TemplateException("Unknown template name.") diff --git a/src/oic/utils/time_util.py b/src/oic/utils/time_util.py index 628fdc503..5f0f71a5c 100644 --- a/src/oic/utils/time_util.py +++ b/src/oic/utils/time_util.py @@ -20,7 +20,9 @@ from datetime import timedelta TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" -TIME_FORMAT_WITH_FRAGMENT = re.compile(r"^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$") +TIME_FORMAT_WITH_FRAGMENT = re.compile( + r"^(\d{4,4}-\d{2,2}-\d{2,2}T\d{2,2}:\d{2,2}:\d{2,2})\.\d*Z$" +) class TimeUtilError(Exception): @@ -62,25 +64,25 @@ def maximum_day_in_month_for(year, month): ("T", None), ("H", "tm_hour"), ("M", "tm_min"), - ("S", "tm_sec") + ("S", "tm_sec"), ] def parse_duration(duration): # (-)PnYnMnDTnHnMnS index = 0 - if duration[0] == '-': - sign = '-' + if duration[0] == "-": + sign = "-" index += 1 else: - sign = '+' + sign = "+" assert duration[index] == "P" index += 1 dic = dict([(typ, 0) for (code, typ) in D_FORMAT]) for code, typ in D_FORMAT: - if duration[index] == '-': + if duration[index] == "-": raise TimeUtilError("Negation not allowed on individual items") if code == "T": if duration[index] == "T": @@ -93,16 +95,17 @@ def parse_duration(duration): try: mod = duration[index:].index(code) try: - dic[typ] = int(duration[index:index + mod]) + dic[typ] = int(duration[index : index + mod]) except ValueError: if code == "S": try: - dic[typ] = float(duration[index:index + mod]) + dic[typ] = float(duration[index : index + mod]) except ValueError: raise TimeUtilError("Not a float") else: raise TimeUtilError( - "Fractions not allow on anything byt seconds") + "Fractions not allow on anything byt seconds" + ) index = mod + index + 1 except ValueError: dic[typ] = 0 @@ -116,7 +119,7 @@ def parse_duration(duration): def add_duration(tid, duration): (sign, dur) = parse_duration(duration) - if sign == '+': + if sign == "+": # Months temp = tid.tm_mon + dur["tm_mon"] month = modulo(temp, 1, 13) @@ -155,8 +158,9 @@ def add_duration(tid, duration): month = modulo(temp, 1, 13) year += f_quotient(temp, 1, 13) - return time.localtime(time.mktime((year, month, days, hour, minutes, - secs, 0, 0, -1))) + return time.localtime( + time.mktime((year, month, days, hour, minutes, secs, 0, 0, -1)) + ) else: pass @@ -164,8 +168,9 @@ def add_duration(tid, duration): # --------------------------------------------------------------------------- -def time_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0): +def time_in_a_while( + days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 +): """ Return time in a future. @@ -175,13 +180,13 @@ def time_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, :return: UTC time """ - delta = timedelta(days, seconds, microseconds, milliseconds, - minutes, hours, weeks) + delta = timedelta(days, seconds, microseconds, milliseconds, minutes, hours, weeks) return datetime.utcnow() + delta -def time_a_while_ago(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0): +def time_a_while_ago( + days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 +): """ Return time in past. @@ -198,13 +203,20 @@ def time_a_while_ago(days=0, seconds=0, microseconds=0, milliseconds=0, :param time_format: :return: datetime instance """ - delta = timedelta(days, seconds, microseconds, milliseconds, - minutes, hours, weeks) + delta = timedelta(days, seconds, microseconds, milliseconds, minutes, hours, weeks) return datetime.utcnow() - delta -def in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0, time_format=TIME_FORMAT): +def in_a_while( + days=0, + seconds=0, + microseconds=0, + milliseconds=0, + minutes=0, + hours=0, + weeks=0, + time_format=TIME_FORMAT, +): """ Return time in a future. @@ -221,12 +233,21 @@ def in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, if not time_format: time_format = TIME_FORMAT - return time_in_a_while(days, seconds, microseconds, milliseconds, - minutes, hours, weeks).strftime(time_format) - - -def a_while_ago(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0, time_format=TIME_FORMAT): + return time_in_a_while( + days, seconds, microseconds, milliseconds, minutes, hours, weeks + ).strftime(time_format) + + +def a_while_ago( + days=0, + seconds=0, + microseconds=0, + milliseconds=0, + minutes=0, + hours=0, + weeks=0, + time_format=TIME_FORMAT, +): """ Return time in past. @@ -240,8 +261,9 @@ def a_while_ago(days=0, seconds=0, microseconds=0, milliseconds=0, :param time_format: :return: Formatet string """ - return time_a_while_ago(days, seconds, microseconds, milliseconds, - minutes, hours, weeks).strftime(time_format) + return time_a_while_ago( + days, seconds, microseconds, milliseconds, minutes, hours, weeks + ).strftime(time_format) # --------------------------------------------------------------------------- @@ -345,8 +367,9 @@ def time_sans_frac(): return int("%d" % time.time()) -def epoch_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, - minutes=0, hours=0, weeks=0): +def epoch_in_a_while( + days=0, seconds=0, microseconds=0, milliseconds=0, minutes=0, hours=0, weeks=0 +): """ Return the number of seconds since epoch a while from now. @@ -359,6 +382,7 @@ def epoch_in_a_while(days=0, seconds=0, microseconds=0, milliseconds=0, :param weeks: :return: Seconds since epoch (1970-01-01) """ - dt = time_in_a_while(days, seconds, microseconds, milliseconds, minutes, - hours, weeks) + dt = time_in_a_while( + days, seconds, microseconds, milliseconds, minutes, hours, weeks + ) return int((dt - datetime(1970, 1, 1)).total_seconds()) diff --git a/src/oic/utils/token_handler.py b/src/oic/utils/token_handler.py index 9c9df6f0b..26871c1ba 100644 --- a/src/oic/utils/token_handler.py +++ b/src/oic/utils/token_handler.py @@ -1,7 +1,7 @@ from oic import rndstr from oic.extension.token import JWTToken -__author__ = 'roland' +__author__ = "roland" class NotAllowed(Exception): @@ -15,8 +15,15 @@ class TokenHandler(object): Note! the token and refresh token factories both keep their own token databases. """ - def __init__(self, issuer, token_policy, token_factory=None, - refresh_token_factory=None, keyjar=None, sign_alg='RS256'): + def __init__( + self, + issuer, + token_policy, + token_factory=None, + refresh_token_factory=None, + keyjar=None, + sign_alg="RS256", + ): """ Initialize the class. @@ -32,15 +39,16 @@ def __init__(self, issuer, token_policy, token_factory=None, """ self.token_policy = token_policy if token_factory is None: - self.token_factory = JWTToken('T', keyjar=keyjar, iss=issuer, - sign_alg=sign_alg) + self.token_factory = JWTToken( + "T", keyjar=keyjar, iss=issuer, sign_alg=sign_alg + ) else: self.token_factory = token_factory if refresh_token_factory is None: - self.refresh_token_factory = JWTToken('R', keyjar=keyjar, - iss='https://example.com/as', - sign_alg=sign_alg) + self.refresh_token_factory = JWTToken( + "R", keyjar=keyjar, iss="https://example.com/as", sign_alg=sign_alg + ) else: self.refresh_token_factory = refresh_token_factory @@ -55,14 +63,20 @@ def get_access_token(self, target_id, scope, grant_type): """ # No default, either there is an explicit policy or there is not try: - lifetime = self.token_policy['access_token'][target_id][grant_type] + lifetime = self.token_policy["access_token"][target_id][grant_type] except KeyError: raise NotAllowed( - 'Access token for grant_type {} for target_id {} not allowed') + "Access token for grant_type {} for target_id {} not allowed" + ) sid = rndstr(32) - return self.token_factory(sid, target_id=target_id, scope=scope, - grant_type=grant_type, lifetime=lifetime) + return self.token_factory( + sid, + target_id=target_id, + scope=scope, + grant_type=grant_type, + lifetime=lifetime, + ) def refresh_access_token(self, target_id, token, grant_type, **kwargs): """ @@ -83,35 +97,38 @@ def refresh_access_token(self, target_id, token, grant_type, **kwargs): if target_id != info["azr"]: raise NotAllowed("{} can't use this token".format(target_id)) except KeyError: - if target_id not in info['aud']: + if target_id not in info["aud"]: raise NotAllowed("{} can't use this token".format(target_id)) if self.token_factory.is_valid(info): try: - lifetime = self.token_policy['access_token'][target_id][ - grant_type] + lifetime = self.token_policy["access_token"][target_id][grant_type] except KeyError: raise NotAllowed( - 'Issue access token for grant_type {} for target_id {} not allowed') + "Issue access token for grant_type {} for target_id {} not allowed" + ) else: - sid = self.token_factory.db[info['jti']] + sid = self.token_factory.db[info["jti"]] try: - _aud = kwargs['aud'] + _aud = kwargs["aud"] except KeyError: - _aud = info['aud'] + _aud = info["aud"] return self.token_factory( - sid, target_id=target_id, lifetime=lifetime, aud=_aud) + sid, target_id=target_id, lifetime=lifetime, aud=_aud + ) def get_refresh_token(self, target_id, grant_type, sid): try: - lifetime = self.token_policy['refresh_token'][target_id][grant_type] + lifetime = self.token_policy["refresh_token"][target_id][grant_type] except KeyError: raise NotAllowed( - 'Issue access token for grant_type {} for target_id {} not allowed') + "Issue access token for grant_type {} for target_id {} not allowed" + ) else: return self.refresh_token_factory( - sid, target_id=target_id, lifetime=lifetime) + sid, target_id=target_id, lifetime=lifetime + ) def invalidate(self, token): if self.token_factory.valid(token): diff --git a/src/oic/utils/userinfo/__init__.py b/src/oic/utils/userinfo/__init__.py index 0ca39d7f1..dbd7031eb 100644 --- a/src/oic/utils/userinfo/__init__.py +++ b/src/oic/utils/userinfo/__init__.py @@ -1,6 +1,6 @@ import copy -__author__ = 'rolandh' +__author__ = "rolandh" class UserInfo(object): diff --git a/src/oic/utils/userinfo/aa_info.py b/src/oic/utils/userinfo/aa_info.py index de456a155..d4890aade 100644 --- a/src/oic/utils/userinfo/aa_info.py +++ b/src/oic/utils/userinfo/aa_info.py @@ -3,14 +3,18 @@ from oic.utils.userinfo import UserInfo -__author__ = 'danielevertsson' +__author__ = "danielevertsson" try: from saml2.client import Saml2Client except ImportError: + class AaUserInfo(UserInfo): pass + + else: + class AaUserInfo(UserInfo): # type: ignore def __init__(self, spconf, url, db=None): UserInfo.__init__(self, db) @@ -33,14 +37,16 @@ def __call__(self, userid, client_id, user_info_claims=None, **kwargs): entity_id, ava[self.sp_conf.AA_NAMEID_ATTRIBUTE][0], nameid_format=self.sp_conf.AA_NAMEID_FORMAT, - attribute=self.sp_conf.AA_REQUEST_ATTRIBUTES) + attribute=self.sp_conf.AA_REQUEST_ATTRIBUTES, + ) response_dict = response.ava.copy() if self.sp_conf.AA_ATTRIBUTE_SAML_IDP is True: for key, value in ava.items(): - if (self.sp_conf.AA_ATTRIBUTE_SAML_IDP_WHITELIST is None or - key in self.sp_conf.AA_ATTRIBUTE_SAML_IDP_WHITELIST) and \ - key not in response_dict: + if ( + self.sp_conf.AA_ATTRIBUTE_SAML_IDP_WHITELIST is None + or key in self.sp_conf.AA_ATTRIBUTE_SAML_IDP_WHITELIST + ) and key not in response_dict: response_dict[key] = value return response_dict diff --git a/src/oic/utils/userinfo/distaggr.py b/src/oic/utils/userinfo/distaggr.py index fdcf37ad1..e42e5174a 100644 --- a/src/oic/utils/userinfo/distaggr.py +++ b/src/oic/utils/userinfo/distaggr.py @@ -7,7 +7,7 @@ from oic.utils.sanitize import sanitize from oic.utils.userinfo import UserInfo -__author__ = 'rolandh' +__author__ = "rolandh" logger = logging.getLogger(__name__) @@ -52,8 +52,7 @@ def init_claims_clients(self, client_info): def _collect_distributed(self, srv, cc, sub, what, alias=""): try: - resp = cc.do_claims_request(request_args={"sub": sub, - "claims_names": what}) + resp = cc.do_claims_request(request_args={"sub": sub, "claims_names": what}) except Exception: raise @@ -70,8 +69,7 @@ def _collect_distributed(self, srv, cc, sub, what, alias=""): else: result["_claims_sources"][alias] = {"endpoint": resp["endpoint"]} if "access_token" in resp: - result["_claims_sources"][alias]["access_token"] = resp[ - "access_token"] + result["_claims_sources"][alias]["access_token"] = resp["access_token"] return result @@ -118,13 +116,11 @@ def __call__(self, userid, client_id, user_info_claims=None, **kwargs): pass if remaining: - raise MissingAttribute( - "Missing properties '%s'" % remaining) + raise MissingAttribute("Missing properties '%s'" % remaining) for srv, what in cpoints.items(): cc = self.oidcsrv.claims_clients[srv] - logger.debug("srv: %s, what: %s" % (sanitize(srv), - sanitize(what))) + logger.debug("srv: %s, what: %s" % (sanitize(srv), sanitize(what))) _res = self._collect_distributed(srv, cc, userid, what) logger.debug("Got: %s" % sanitize(_res)) for key, val in _res.items(): diff --git a/src/oic/utils/userinfo/ldap_info.py b/src/oic/utils/userinfo/ldap_info.py index aac4b40e6..cebd771e7 100644 --- a/src/oic/utils/userinfo/ldap_info.py +++ b/src/oic/utils/userinfo/ldap_info.py @@ -1,14 +1,14 @@ try: import ldap except ImportError: - raise ImportError('This module can be used only with pyldap installed.') + raise ImportError("This module can be used only with pyldap installed.") import logging from oic.utils.sanitize import sanitize from oic.utils.userinfo import UserInfo -__author__ = 'rolandh' +__author__ = "rolandh" logger = logging.getLogger(__name__) @@ -32,13 +32,24 @@ "phone_number": "telephoneNumber", # phone_number_verified "address": "postalAddress", - "updated_at": "" # Nothing equivalent + "updated_at": "", # Nothing equivalent } class UserInfoLDAP(UserInfo): - def __init__(self, uri, base, filter_pattern, scope=ldap.SCOPE_SUBTREE, - tls=False, user="", passwd="", attr=None, attrsonly=False, attrmap=OPENID2LDAP): + def __init__( + self, + uri, + base, + filter_pattern, + scope=ldap.SCOPE_SUBTREE, + tls=False, + user="", + passwd="", + attr=None, + attrsonly=False, + attrmap=OPENID2LDAP, + ): super(UserInfoLDAP, self).__init__(None) self.ldapuri = uri self.base = base @@ -61,8 +72,9 @@ def bind(self): self.ld.start_tls_s() self.ld.simple_bind_s(self.ldapuser, self.ldappasswd) - def __call__(self, userid, client_id, user_info_claims=None, - first_only=True, **kwargs): + def __call__( + self, userid, client_id, user_info_claims=None, first_only=True, **kwargs + ): _filter = self.filter_pattern % userid logger.debug("CLAIMS: %s" % sanitize(user_info_claims)) _attr = self.attr diff --git a/src/oic/utils/webfinger.py b/src/oic/utils/webfinger.py index b7391b4c5..5e2ede65f 100644 --- a/src/oic/utils/webfinger.py +++ b/src/oic/utils/webfinger.py @@ -11,7 +11,7 @@ from oic.exception import PyoidcError from oic.utils.time_util import in_a_while -__author__ = 'rolandh' +__author__ = "rolandh" logger = logging.getLogger(__name__) @@ -133,8 +133,7 @@ class JRD(Base): "links": {"type": (list, LINK), "required": False}, # Optional } - def __init__(self, dic=None, days=0, seconds=0, minutes=0, hours=0, - weeks=0): + def __init__(self, dic=None, days=0, seconds=0, minutes=0, hours=0, weeks=0): Base.__init__(self, dic) self.expires_in(days, seconds, minutes, hours, weeks) @@ -147,9 +146,13 @@ def expires_in(self, days=0, seconds=0, minutes=0, hours=0, weeks=0): def export(self): res = self.dump() - res["expires"] = in_a_while(days=self._exp_days, seconds=self._exp_secs, - minutes=self._exp_min, hours=self._exp_hour, - weeks=self._exp_week) + res["expires"] = in_a_while( + days=self._exp_days, + seconds=self._exp_secs, + minutes=self._exp_min, + hours=self._exp_hour, + weeks=self._exp_week, + ) return res @@ -180,26 +183,27 @@ def export(self): # [ userinfo "@" ] host [ ":" port ], it is legal to have a user input # identifier like userinfo@host:port, e.g., alice@example.com:8080. + class URINormalizer(object): def has_scheme(self, inp): if "://" in inp: return True else: - authority = inp.replace('/', '#').replace('?', '#').split("#")[0] + authority = inp.replace("/", "#").replace("?", "#").split("#")[0] - if ':' in authority: - _, host_or_port = authority.split(':', 1) + if ":" in authority: + _, host_or_port = authority.split(":", 1) # Assert it's not a port number - if re.match(r'^\d+$', host_or_port): + if re.match(r"^\d+$", host_or_port): return False else: return False return True def acct_scheme_assumed(self, inp): - if '@' in inp: - host = inp.split('@')[-1] - return not (':' in host or '/' in host or '?' in host) + if "@" in inp: + host = inp.split("@")[-1] + return not (":" in host or "/" in host or "?" in host) else: return False @@ -240,10 +244,10 @@ def query(self, resource, rel=None): if part.port is not None: host += ":" + str(part.port) elif resource.startswith("acct:"): - host = resource.split('@')[-1] - host = host.replace('/', '#').replace('?', '#').split("#")[0] + host = resource.split("@")[-1] + host = host.replace("/", "#").replace("?", "#").split("#")[0] elif resource.startswith("device:"): - host = resource.split(':')[1] + host = resource.split(":")[1] else: raise WebFingerError("Unknown schema") @@ -261,9 +265,11 @@ def http_args(self, jrd=None): return None return { - "headers": {"Access-Control-Allow-Origin": "*", - "Content-Type": "application/json; charset=UTF-8"}, - "body": json.dumps(jrd.export()) + "headers": { + "Access-Control-Allow-Origin": "*", + "Content-Type": "application/json; charset=UTF-8", + }, + "body": json.dumps(jrd.export()), } def discovery_query(self, resource): @@ -282,15 +288,15 @@ def discovery_query(self, resource): if rsp.status_code == 200: if self.events: - self.events.store('Response', rsp.text) + self.events.store("Response", rsp.text) self.jrd = self.load(rsp.text) if self.events: - self.events.store('JRD Response', self.jrd) + self.events.store("JRD Response", self.jrd) for link in self.jrd["links"]: if link["rel"] == OIC_ISSUER: - if not link['href'].startswith('https://'): - raise WebFingerError('Must be a HTTPS href') + if not link["href"].startswith("https://"): + raise WebFingerError("Must be a HTTPS href") return link["href"] return None elif rsp.status_code in [302, 301, 307]: diff --git a/tox.ini b/tox.ini index a3b9f1ae5..bd910e207 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py{34,35,36,37},docs,quality +envlist = py{35,36,37},docs,quality [testenv] passenv = CI TRAVIS TRAVIS_* @@ -17,12 +17,15 @@ extras = docs commands = sphinx-build -b html doc/ doc/_build/html -W [testenv:quality] +# Black need python 3.6 +basepython = python3.6 ignore_errors = True extras = quality commands = isort --recursive --diff --check-only src/ tests/ pylama src/ tests/ mypy --config-file mypy.ini src/ tests/ + black src/ --check [pep8] max-line-length=100