diff --git a/tests/conftest.py b/tests/conftest.py index 270d266fe..32b623e4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,16 +23,6 @@ def session_db(session_db_factory): return session_db_factory("https://op.example.com") -@pytest.fixture -def mitm_server(session_db_factory): - from tests.mitmsrv import MITMServer - - def fac(name): - return MITMServer(name, session_db_factory=session_db_factory) - - return fac - - @pytest.fixture def provider(session_db): issuer = "https://op.example.com" diff --git a/tests/mitmsrv.py b/tests/mitmsrv.py deleted file mode 100644 index 7a4d360bb..000000000 --- a/tests/mitmsrv.py +++ /dev/null @@ -1,347 +0,0 @@ -from typing import Any # noqa -from typing import Dict # noqa -from typing import Union # noqa -from urllib.parse import parse_qs -from urllib.parse import urlparse - -from jwkest import jws -from jwkest.jws import alg2keytype - -from oic import rndstr -from oic.oauth2.message import AccessTokenRequest # noqa -from oic.oauth2.message import RefreshAccessTokenRequest # noqa -from oic.oauth2.message import by_schema -from oic.oic import Server -from oic.oic.message import AccessTokenResponse -from oic.oic.message import AuthorizationResponse -from oic.oic.message import EndSessionResponse -from oic.oic.message import OpenIDSchema -from oic.oic.message import ProviderConfigurationResponse -from oic.oic.message import RegistrationResponse -from oic.oic.message import TokenErrorResponse -from oic.utils.sdb import AuthnEvent -from oic.utils.time_util import utc_time_sans_frac -from oic.utils.webfinger import WebFinger - -__author__ = "rohe0002" - - -class Response: - headers = None # type: Dict[str, str] - text = None # type: str - - def __init__(self, base=None): - self.status_code = 200 - if base: - for key, val in base.items(): - self.__setitem__(key, val) - - def __setitem__(self, key, value): - setattr(self, key, value) - - def __getitem__(self, item): - return getattr(self, item) - - -ENDPOINT = { - "authorization_endpoint": "/authorization", - "token_endpoint": "/token", - "user_info_endpoint": "/userinfo", - "check_session_endpoint": "/check_session", - "refresh_session_endpoint": "/refresh_session", - "end_session_endpoint": "/end_session", - "registration_endpoint": "/registration", - "discovery_endpoint": "/discovery", - "register_endpoint": "/register", -} - - -class MITMServer(Server): - def __init__(self, name="", session_db_factory=None): - Server.__init__(self) - self.sdb = session_db_factory(name) - self.name = name - self.client = {} # type: Dict[str, Dict[str, Any]] - self.registration_expires_in = 3600 - self.host = "" - self.webfinger = WebFinger() - self.userinfo_signed_response_alg = "" - - def http_request(self, path, method="GET", **kwargs): - part = urlparse(path) - path = part[2] - query = part[4] - self.host = "%s://%s" % (part.scheme, part.netloc) - - response = Response - response.status_code = 500 - response.text = "" - - if path == ENDPOINT["authorization_endpoint"]: - assert method == "GET" - response = self.authorization_endpoint(query) - elif path == ENDPOINT["token_endpoint"]: - assert method == "POST" - response = self.token_endpoint(kwargs["data"]) - elif path == ENDPOINT["user_info_endpoint"]: - assert method == "POST" - response = self.userinfo_endpoint(kwargs["data"]) - elif path == ENDPOINT["refresh_session_endpoint"]: - assert method == "GET" - response = self.refresh_session_endpoint(query) - elif path == ENDPOINT["check_session_endpoint"]: - assert method == "GET" - response = self.check_session_endpoint(query) - elif path == ENDPOINT["end_session_endpoint"]: - assert method == "GET" - response = self.end_session_endpoint(query) - elif path == ENDPOINT["registration_endpoint"]: - if method == "POST": - response = self.registration_endpoint(kwargs["data"]) - elif path == "/.well-known/webfinger": - assert method == "GET" - qdict = parse_qs(query) - response.status_code = 200 - response.text = self.webfinger.response( - qdict["resource"][0], "%s/" % self.name - ) - elif path == "/.well-known/openid-configuration": - assert method == "GET" - response = self.openid_conf() - - return response - - def authorization_endpoint(self, query): - req = self.parse_authorization_request(query=query) - aevent = AuthnEvent("user", "salt", authn_info="acr") - sid = self.sdb.create_authz_session(aevent, areq=req) - self.sdb.do_sub(sid, "client_salt") - _info = self.sdb[sid] - - if "code" in req["response_type"]: - if "token" in req["response_type"]: - grant = _info["code"] - _dict = self.sdb.upgrade_to_token(grant) - _dict["oauth_state"] = ("authz",) - - _dict = by_schema(AuthorizationResponse(), **_dict) - resp = AuthorizationResponse(**_dict) - else: - _state = req["state"] - resp = AuthorizationResponse(state=_state, code=_info["code"]) - - else: # "implicit" in req.response_type: - grant = _info["code"] - params = AccessTokenResponse.c_param.keys() - - if "token" in req["response_type"]: - _dict = dict( - [ - (k, v) - for k, v in self.sdb.upgrade_to_token(grant).items() - if k in params - ] - ) - try: - del _dict["refresh_token"] - except KeyError: - pass - else: - _dict = {"state": req["state"]} - - if "id_token" in req["response_type"]: - _idt = self.make_id_token(_info, issuer=self.name) - alg = "RS256" - ckey = self.keyjar.get_signing_key(alg2keytype(alg), _info["client_id"]) - _signed_jwt = _idt.to_jwt(key=ckey, algorithm=alg) - p = _signed_jwt.split(".") - p[2] = "aaa" - _dict["id_token"] = ".".join(p) - - resp = AuthorizationResponse(**_dict) - - location = resp.request(req["redirect_uri"]) - response = Response() - response.headers = {"location": location} - response.status_code = 302 - response.text = "" - return response - - def token_endpoint(self, data): - if "grant_type=refresh_token" in data: - req = self.parse_refresh_token_request( - body=data - ) # type: Union[AccessTokenRequest, RefreshAccessTokenRequest] - _info = self.sdb.refresh_token(req["refresh_token"]) - elif "grant_type=authorization_code" in data: - req = self.parse_token_request(body=data) - _info = self.sdb.upgrade_to_token(req["code"]) - else: - response = TokenErrorResponse(error="unsupported_grant_type") - return response, "" - - resp = AccessTokenResponse(**by_schema(AccessTokenResponse, **_info)) - response2 = Response() - response2.headers = {"content-type": "application/json"} - response2.text = resp.to_json() - - return response2 - - def userinfo_endpoint(self, data): - - self.parse_user_info_request(data) - _info = { - "sub": "melgar", - "name": "Melody Gardot", - "nickname": "Mel", - "email": "mel@example.com", - "verified": True, - } - - resp = OpenIDSchema(**_info) - response = Response() - - if self.userinfo_signed_response_alg: - alg = self.userinfo_signed_response_alg - response.headers = {"content-type": "application/jwt"} - key = self.keyjar.get_signing_key(alg2keytype(alg), "", alg=alg) - response.text = resp.to_jwt(key, alg) - else: - response.headers = {"content-type": "application/json"} - response.text = resp.to_json() - - return response - - def registration_endpoint(self, data): - try: - req = self.parse_registration_request(data, "json") - except ValueError: - req = self.parse_registration_request(data) - - client_secret = rndstr() - expires = utc_time_sans_frac() + self.registration_expires_in - kwargs = {} # type: Dict[str, str] - if "client_id" not in req: - client_id = rndstr(10) - registration_access_token = rndstr(20) - _client_info = req.to_dict() - kwargs.update(_client_info) - _client_info.update( - { - "client_secret": client_secret, - "info": req.to_dict(), - "expires": expires, - "registration_access_token": registration_access_token, - "registration_client_uri": "register_endpoint", - } - ) - self.client[client_id] = _client_info - kwargs["registration_access_token"] = registration_access_token - kwargs["registration_client_uri"] = "register_endpoint" - try: - del kwargs["operation"] - except KeyError: - pass - else: - client_id = req.client_id - _cinfo = self.client[req.client_id] - _cinfo["info"].update(req.to_dict()) - _cinfo["client_secret"] = client_secret - _cinfo["expires"] = expires - - resp = RegistrationResponse( - client_id=client_id, - client_secret=client_secret, - client_secret_expires_at=expires, - **kwargs - ) - - response = Response() - response.headers = {"content-type": "application/json"} - response.text = resp.to_json() - - return response - - def check_session_endpoint(self, query): - try: - idtoken = self.parse_check_session_request(query=query) - except Exception: - raise - - response = Response() - response.text = idtoken.to_json() - response.headers = {"content-type": "application/json"} - return response - - def refresh_session_endpoint(self, query): - try: - self.parse_refresh_session_request(query=query) - except Exception: - raise - - resp = RegistrationResponse(client_id="anonymous", client_secret="hemligt") - - response = Response() - response.headers = {"content-type": "application/json"} - response.text = resp.to_json() - return response - - def end_session_endpoint(self, query): - try: - req = self.parse_end_session_request(query=query) - except Exception: - raise - - # redirect back - resp = EndSessionResponse(state=req["state"]) - - url = resp.request(req["redirect_url"]) - - response = Response() - response.headers = {"location": url} - response.status_code = 302 # redirect - response.text = "" - return response - - @staticmethod - def add_credentials(user, passwd): - return - - def openid_conf(self): - endpoint = {} - for point, path in ENDPOINT.items(): - endpoint[point] = "%s%s" % (self.host, path) - - signing_algs = jws.SIGNER_ALGS.keys() - resp = ProviderConfigurationResponse( - issuer=self.name, - scopes_supported=["openid", "profile", "email", "address"], - identifiers_supported=["public", "PPID"], - flows_supported=[ - "code", - "token", - "code token", - "id_token", - "code id_token", - "token id_token", - ], - subject_types_supported=["pairwise", "public"], - response_types_supported=[ - "code", - "token", - "id_token", - "code token", - "code id_token", - "token id_token", - "code token id_token", - ], - jwks_uri="http://example.com/oidc/jwks", - id_token_signing_alg_values_supported=signing_algs, - grant_types_supported=["authorization_code", "implicit"], - **endpoint - ) - - response = Response() - response.headers = {"content-type": "application/json"} - response.text = resp.to_json() - return response diff --git a/tests/test_oic_consumer.py b/tests/test_oic_consumer.py index 45d0873b9..a012ee5ba 100644 --- a/tests/test_oic_consumer.py +++ b/tests/test_oic_consumer.py @@ -835,11 +835,7 @@ def test_faulty_id_token_in_access_token_response(self): with pytest.raises(ValueError): c.parse_response(AccessTokenResponse, _json, sformat="json") - def test_faulty_idtoken_from_accesstoken_endpoint(self, mitm_server): - mfos = mitm_server("http://localhost:8088") - mfos.keyjar = SRVKEYS - # FIXME: Drop the MITM server in favor of responses - self.consumer.http_request = mfos.http_request # type: ignore + def test_faulty_idtoken_from_accesstoken_endpoint(self): _state = "state0" self.consumer.consumer_config["response_type"] = ["id_token"] @@ -849,7 +845,23 @@ def test_faulty_idtoken_from_accesstoken_endpoint(self, mitm_server): "scope": ["openid"], } - result = self.consumer.do_authorization_request(state=_state, request_args=args) + location = ( + "https://example.com/cb?state=state0&id_token=eyJhbGciOiJSUzI1NiJ9" + ".eyJpc3MiOiAiaHR0cDovL2xvY2FsaG9zdDo4MDg4IiwgInN1YiI6ICJhNWRkMjRiMmYwOGE2ODZmZDM4NmMyMmM" + "zZmY4ZWUyODFlZjJmYmZmMWZkZTcwMDg2NjhjZGEzZGVjZmE0NjY5IiwgImF1ZCI6IFsiY2xpZW50XzEiXSwgImV" + "4cCI6IDE1NzIwOTk5NjAsICJhY3IiOiAiMiIsICJpYXQiOiAxNTcyMDEzNTYwLCAibm9uY2UiOiAibmdFTGZVdmN" + "PMWoyaXNWcXkwQWNwM0NOYlZnMGdFRDEifQ.aaa" + ) + with responses.RequestsMock() as rsps: + rsps.add( + responses.GET, + "https://example.com/authorization", + status=302, + headers={"location": location}, + ) + result = self.consumer.do_authorization_request( + state=_state, request_args=args + ) self.consumer._backup("state0") assert result.status_code == 302