Skip to content

Commit 301b79c

Browse files
Type spec improvements in rabbit_auth_backend_oauth2
1 parent d6366a3 commit 301b79c

File tree

2 files changed

+52
-19
lines changed

2 files changed

+52
-19
lines changed

deps/rabbitmq_auth_backend_oauth2/include/oauth2.hrl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
%% End of Key JWT fields
2424

25+
-type raw_jwt_token() :: binary() | #{binary() => any()}.
26+
-type decoded_jwt_token() :: #{binary() => any()}.
2527

2628
-record(internal_oauth_provider, {
2729
id :: oauth_provider_id(),

deps/rabbitmq_auth_backend_oauth2/src/rabbit_auth_backend_oauth2.erl

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ description() ->
5858

5959
%%--------------------------------------------------------------------
6060

61+
-spec user_login_authentication(rabbit_types:username(), [term()] | map()) ->
62+
{'ok', rabbit_types:auth_user()} |
63+
{'refused', string(), [any()]} |
64+
{'error', any()}.
65+
6166
user_login_authentication(Username, AuthProps) ->
6267
case authenticate(Username, AuthProps) of
6368
{refused, Msg, Args} = AuthResult ->
@@ -67,12 +72,21 @@ user_login_authentication(Username, AuthProps) ->
6772
AuthResult
6873
end.
6974

75+
-spec user_login_authorization(rabbit_types:username(), [term()] | map()) ->
76+
{'ok', any()} |
77+
{'ok', any(), any()} |
78+
{'refused', string(), [any()]} |
79+
{'error', any()}.
80+
7081
user_login_authorization(Username, AuthProps) ->
7182
case authenticate(Username, AuthProps) of
7283
{ok, #auth_user{impl = Impl}} -> {ok, Impl};
7384
Else -> Else
7485
end.
7586

87+
-spec check_vhost_access(AuthUser :: rabbit_types:auth_user(),
88+
VHost :: rabbit_types:vhost(),
89+
AuthzData :: rabbit_types:authz_data()) -> boolean() | {'error', any()}.
7690
check_vhost_access(#auth_user{impl = DecodedTokenFun},
7791
VHost, _AuthzData) ->
7892
with_decoded_token(DecodedTokenFun(),
@@ -136,6 +150,11 @@ expiry_timestamp(#auth_user{impl = DecodedTokenFun}) ->
136150

137151
%%--------------------------------------------------------------------
138152

153+
-spec authenticate(Username, Props) -> Result
154+
when Username :: rabbit_types:username(),
155+
Props :: list() | map(),
156+
Result :: {ok, any()} | {refused, list(), list()} | {refused, {error, any()}}.
157+
139158
authenticate(_, AuthProps0) ->
140159
AuthProps = to_map(AuthProps0),
141160
Token = token_from_context(AuthProps),
@@ -148,31 +167,45 @@ authenticate(_, AuthProps0) ->
148167
{refused, "Authentication using an OAuth 2/JWT token failed: provided token is invalid", []};
149168
{refused, Err} ->
150169
{refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]};
151-
{ok, DecodedToken} ->
152-
Func = fun(Token0) ->
153-
Username = username_from(
154-
ResourceServer#resource_server.preferred_username_claims,
155-
Token0),
156-
Tags = tags_from(Token0),
157-
{ok, #auth_user{username = Username,
158-
tags = Tags,
159-
impl = fun() -> Token0 end}}
160-
end,
161-
case with_decoded_token(DecodedToken, Func) of
170+
{ok, DecodedToken} ->
171+
case with_decoded_token(DecodedToken, fun(In) -> auth_user_from_token(In, ResourceServer) end) of
162172
{error, Err} ->
163173
{refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]};
164174
Else ->
165175
Else
166176
end
167177
end
168178
end.
179+
180+
-type ok_extracted_auth_user() :: {ok, rabbit_types:auth_user()}.
181+
-type auth_user_extraction_fun() :: fun((decoded_jwt_token()) -> any()).
182+
183+
-spec with_decoded_token(Token, Fun) -> Result
184+
when Token :: decoded_jwt_token(),
185+
Fun :: auth_user_extraction_fun(),
186+
Result :: {ok, any()} | {'error', any()}.
169187
with_decoded_token(DecodedToken, Fun) ->
170188
case validate_token_expiry(DecodedToken) of
171189
ok -> Fun(DecodedToken);
172190
{error, Msg} = Err ->
173191
rabbit_log:error(Msg),
174192
Err
175193
end.
194+
195+
%% This is a helper function used with HOFs that may return errors.
196+
-spec auth_user_from_token(Token, ResourceServer) -> Result
197+
when Token :: decoded_jwt_token(),
198+
ResourceServer :: resource_server(),
199+
Result :: ok_extracted_auth_user().
200+
auth_user_from_token(Token0, ResourceServer) ->
201+
Username = username_from(
202+
ResourceServer#resource_server.preferred_username_claims,
203+
Token0),
204+
Tags = tags_from(Token0),
205+
{ok, #auth_user{username = Username,
206+
tags = Tags,
207+
impl = fun() -> Token0 end}}.
208+
176209
ensure_same_username(PreferredUsernameClaims, CurrentDecodedToken, NewDecodedToken) ->
177210
CurUsername = username_from(PreferredUsernameClaims, CurrentDecodedToken),
178211
case {CurUsername, username_from(PreferredUsernameClaims, NewDecodedToken)} of
@@ -188,12 +221,10 @@ validate_token_expiry(#{<<"exp">> := Exp}) when is_integer(Exp) ->
188221
end;
189222
validate_token_expiry(#{}) -> ok.
190223

191-
-spec check_token(binary() | map(), {resource_server(), internal_oauth_provider()}) ->
192-
{'ok', map()} |
193-
{'error', term() }|
194-
{'refused', 'signature_invalid' |
195-
{'error', term()} |
196-
{'invalid_aud', term()}}.
224+
-spec check_token(raw_jwt_token(), {resource_server(), internal_oauth_provider()}) ->
225+
{'ok', decoded_jwt_token()} |
226+
{'error', term() } |
227+
{'refused', 'signature_invalid' | {'error', term()} | {'invalid_aud', term()}}.
197228

198229
check_token(DecodedToken, _) when is_map(DecodedToken) ->
199230
{ok, DecodedToken};
@@ -206,7 +237,7 @@ check_token(Token, {ResourceServer, InternalOAuthProvider}) ->
206237
end.
207238

208239
-spec normalize_token_scope(
209-
ResourceServer :: resource_server(), DecodedToken :: map()) -> map().
240+
ResourceServer :: resource_server(), DecodedToken :: decoded_jwt_token()) -> map().
210241
normalize_token_scope(ResourceServer, Payload) ->
211242
Payload0 = maps:map(fun(K, V) ->
212243
case K of
@@ -395,7 +426,7 @@ resolve_scope_var(Elem, Token, Vhost) ->
395426
end)
396427
end.
397428

398-
-spec tags_from(map()) -> list(atom()).
429+
-spec tags_from(decoded_jwt_token()) -> list(atom()).
399430
tags_from(DecodedToken) ->
400431
Scopes = maps:get(?SCOPE_JWT_FIELD, DecodedToken, []),
401432
TagScopes = filter_matching_scope_prefix_and_drop_it(Scopes, ?TAG_SCOPE_PREFIX),

0 commit comments

Comments
 (0)