Skip to content

Commit dfbbc66

Browse files
committed
Merge remote-tracking branch 'oauth2cli/dev' into acf
2 parents 8ea9ba5 + 0e512b1 commit dfbbc66

File tree

2 files changed

+57
-13
lines changed

2 files changed

+57
-13
lines changed

msal/oauth2cli/oauth2.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -475,26 +475,55 @@ def obtain_token_by_auth_code_flow(
475475
* A dict containing "error", optionally "error_description", "error_uri".
476476
(It is either `this <https://tools.ietf.org/html/rfc6749#section-4.1.2.1>`_
477477
or `that <https://tools.ietf.org/html/rfc6749#section-5.2>`_
478+
* Most client-side data error would result in ValueError exception.
479+
So the usage pattern could be without any protocol details::
480+
481+
def authorize(): # A controller in a web app
482+
try:
483+
result = client.obtain_token_by_auth_code_flow(
484+
session.get("flow", {}), auth_resp)
485+
if "error" in result:
486+
return render_template("error.html", result)
487+
store_tokens()
488+
except ValueError: # Usually caused by CSRF
489+
pass # Simply ignore them
490+
return redirect(url_for("index"))
478491
"""
492+
assert isinstance(auth_code_flow, dict) and isinstance(auth_response, dict)
493+
# This is app developer's error which we do NOT want to map to ValueError
494+
if not auth_code_flow.get("state"):
495+
# initiate_auth_code_flow() already guarantees a state to be available.
496+
# This check will also allow a web app to blindly call this method with
497+
# obtain_token_by_auth_code_flow(session.get("flow", {}), auth_resp)
498+
# which further simplifies their usage.
499+
raise ValueError("state missing from auth_code_flow")
479500
if auth_code_flow.get("state") != auth_response.get("state"):
480501
raise ValueError("state mismatch: {} vs {}".format(
481502
auth_code_flow.get("state"), auth_response.get("state")))
482-
if auth_response.get("error"): # It means the first leg encountered error
483-
return auth_response
484503
if scope and set(scope) - set(auth_code_flow.get("scope", [])):
485504
raise ValueError(
486505
"scope must be None or a subset of %s" % auth_code_flow.get("scope"))
487-
assert auth_response.get("code"), "First leg's response should have code"
488-
return self._obtain_token_by_authorization_code(
489-
auth_response["code"],
490-
redirect_uri=auth_code_flow.get("redirect_uri"),
491-
# Required, if "redirect_uri" parameter was included in the
492-
# authorization request, and their values MUST be identical.
493-
scope=scope or auth_code_flow.get("scope"),
494-
# It is both unnecessary and harmless, per RFC 6749.
495-
# We use the same scope already used in auth request uri,
496-
# thus token cache can know what scope the tokens are for.
497-
**kwargs)
506+
if auth_response.get("code"): # i.e. the first leg was successful
507+
return self._obtain_token_by_authorization_code(
508+
auth_response["code"],
509+
redirect_uri=auth_code_flow.get("redirect_uri"),
510+
# Required, if "redirect_uri" parameter was included in the
511+
# authorization request, and their values MUST be identical.
512+
scope=scope or auth_code_flow.get("scope"),
513+
# It is both unnecessary and harmless, per RFC 6749.
514+
# We use the same scope already used in auth request uri,
515+
# thus token cache can know what scope the tokens are for.
516+
**kwargs)
517+
if auth_response.get("error"): # It means the first leg encountered error
518+
# Here we do NOT return original auth_response as-is, to prevent a
519+
# potential {..., "access_token": "attacker's AT"} input being leaked
520+
error = {"error": auth_response["error"]}
521+
if auth_response.get("error_description"):
522+
error["error_description"] = auth_response["error_description"]
523+
if auth_response.get("error_uri"):
524+
error["error_uri"] = auth_response["error_uri"]
525+
return error
526+
raise ValueError('auth_response must contain either "code" or "error"')
498527

499528
@staticmethod
500529
def parse_auth_response(params, state=None):

tests/test_client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,21 @@ def test_auth_code_flow(self):
180180
#TBD: data={"resource": CONFIG.get("resource")}, # MSFT AAD v1 only
181181
self.assertLoosely(result, lambda: self.assertIn('access_token', result))
182182

183+
def test_auth_code_flow_error_response(self):
184+
with self.assertRaisesRegexp(ValueError, "state missing"):
185+
self.client.obtain_token_by_auth_code_flow({}, {"code": "foo"})
186+
with self.assertRaisesRegexp(ValueError, "state mismatch"):
187+
self.client.obtain_token_by_auth_code_flow({"state": "1"}, {"state": "2"})
188+
with self.assertRaisesRegexp(ValueError, "scope"):
189+
self.client.obtain_token_by_auth_code_flow(
190+
{"state": "s", "scope": ["foo"]}, {"state": "s"}, scope=["bar"])
191+
self.assertEqual(
192+
{"error": "foo", "error_uri": "bar"},
193+
self.client.obtain_token_by_auth_code_flow(
194+
{"state": "s"},
195+
{"state": "s", "error": "foo", "error_uri": "bar", "access_token": "fake"}),
196+
"We should not leak malicious input into our output")
197+
183198
@unittest.skipUnless(
184199
CONFIG.get("openid_configuration", {}).get("device_authorization_endpoint"),
185200
"device_authorization_endpoint is missing")

0 commit comments

Comments
 (0)