Skip to content

Commit 534e4f1

Browse files
Merge pull request #12851 from rabbitmq/mk-oauth2-type-spec-improvements
Type spec improvements in rabbit_auth_backend_oauth2
2 parents d6366a3 + 719b556 commit 534e4f1

File tree

2 files changed

+52
-26
lines changed

2 files changed

+52
-26
lines changed

deps/rabbitmq_auth_backend_oauth2/include/oauth2.hrl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
%% End of Key JWT fields
2424

25+
-type raw_jwt_token() :: binary() | #{binary() => any()}.
26+
-type decoded_jwt_token() :: #{binary() => any()}.
2527

2628
-record(internal_oauth_provider, {
2729
id :: oauth_provider_id(),

deps/rabbitmq_auth_backend_oauth2/src/rabbit_auth_backend_oauth2.erl

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,11 @@
3838
-endif.
3939

4040
%%
41-
%% App environment
41+
%% Types
4242
%%
4343

44-
45-
%% a term defined for Rich Authorization Request tokens to identify a RabbitMQ permission
46-
%% verify server_server_id aud field is on the aud field
47-
%% a term used by the IdentityServer community
48-
%% scope aliases map "role names" to a set of scopes
49-
44+
-type ok_extracted_auth_user() :: {ok, rabbit_types:auth_user()}.
45+
-type auth_user_extraction_fun() :: fun((decoded_jwt_token()) -> any()).
5046

5147
%%
5248
%% API
@@ -58,6 +54,11 @@ description() ->
5854

5955
%%--------------------------------------------------------------------
6056

57+
-spec user_login_authentication(rabbit_types:username(), [term()] | map()) ->
58+
{'ok', rabbit_types:auth_user()} |
59+
{'refused', string(), [any()]} |
60+
{'error', any()}.
61+
6162
user_login_authentication(Username, AuthProps) ->
6263
case authenticate(Username, AuthProps) of
6364
{refused, Msg, Args} = AuthResult ->
@@ -67,12 +68,21 @@ user_login_authentication(Username, AuthProps) ->
6768
AuthResult
6869
end.
6970

71+
-spec user_login_authorization(rabbit_types:username(), [term()] | map()) ->
72+
{'ok', any()} |
73+
{'ok', any(), any()} |
74+
{'refused', string(), [any()]} |
75+
{'error', any()}.
76+
7077
user_login_authorization(Username, AuthProps) ->
7178
case authenticate(Username, AuthProps) of
7279
{ok, #auth_user{impl = Impl}} -> {ok, Impl};
7380
Else -> Else
7481
end.
7582

83+
-spec check_vhost_access(AuthUser :: rabbit_types:auth_user(),
84+
VHost :: rabbit_types:vhost(),
85+
AuthzData :: rabbit_types:authz_data()) -> boolean() | {'error', any()}.
7686
check_vhost_access(#auth_user{impl = DecodedTokenFun},
7787
VHost, _AuthzData) ->
7888
with_decoded_token(DecodedTokenFun(),
@@ -136,6 +146,11 @@ expiry_timestamp(#auth_user{impl = DecodedTokenFun}) ->
136146

137147
%%--------------------------------------------------------------------
138148

149+
-spec authenticate(Username, Props) -> Result
150+
when Username :: rabbit_types:username(),
151+
Props :: list() | map(),
152+
Result :: {ok, any()} | {refused, list(), list()} | {refused, {error, any()}}.
153+
139154
authenticate(_, AuthProps0) ->
140155
AuthProps = to_map(AuthProps0),
141156
Token = token_from_context(AuthProps),
@@ -148,31 +163,42 @@ authenticate(_, AuthProps0) ->
148163
{refused, "Authentication using an OAuth 2/JWT token failed: provided token is invalid", []};
149164
{refused, Err} ->
150165
{refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]};
151-
{ok, DecodedToken} ->
152-
Func = fun(Token0) ->
153-
Username = username_from(
154-
ResourceServer#resource_server.preferred_username_claims,
155-
Token0),
156-
Tags = tags_from(Token0),
157-
{ok, #auth_user{username = Username,
158-
tags = Tags,
159-
impl = fun() -> Token0 end}}
160-
end,
161-
case with_decoded_token(DecodedToken, Func) of
166+
{ok, DecodedToken} ->
167+
case with_decoded_token(DecodedToken, fun(In) -> auth_user_from_token(In, ResourceServer) end) of
162168
{error, Err} ->
163169
{refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]};
164170
Else ->
165171
Else
166172
end
167173
end
168174
end.
175+
176+
-spec with_decoded_token(Token, Fun) -> Result
177+
when Token :: decoded_jwt_token(),
178+
Fun :: auth_user_extraction_fun(),
179+
Result :: {ok, any()} | {'error', any()}.
169180
with_decoded_token(DecodedToken, Fun) ->
170181
case validate_token_expiry(DecodedToken) of
171182
ok -> Fun(DecodedToken);
172183
{error, Msg} = Err ->
173184
rabbit_log:error(Msg),
174185
Err
175186
end.
187+
188+
%% This is a helper function used with HOFs that may return errors.
189+
-spec auth_user_from_token(Token, ResourceServer) -> Result
190+
when Token :: decoded_jwt_token(),
191+
ResourceServer :: resource_server(),
192+
Result :: ok_extracted_auth_user().
193+
auth_user_from_token(Token0, ResourceServer) ->
194+
Username = username_from(
195+
ResourceServer#resource_server.preferred_username_claims,
196+
Token0),
197+
Tags = tags_from(Token0),
198+
{ok, #auth_user{username = Username,
199+
tags = Tags,
200+
impl = fun() -> Token0 end}}.
201+
176202
ensure_same_username(PreferredUsernameClaims, CurrentDecodedToken, NewDecodedToken) ->
177203
CurUsername = username_from(PreferredUsernameClaims, CurrentDecodedToken),
178204
case {CurUsername, username_from(PreferredUsernameClaims, NewDecodedToken)} of
@@ -188,12 +214,10 @@ validate_token_expiry(#{<<"exp">> := Exp}) when is_integer(Exp) ->
188214
end;
189215
validate_token_expiry(#{}) -> ok.
190216

191-
-spec check_token(binary() | map(), {resource_server(), internal_oauth_provider()}) ->
192-
{'ok', map()} |
193-
{'error', term() }|
194-
{'refused', 'signature_invalid' |
195-
{'error', term()} |
196-
{'invalid_aud', term()}}.
217+
-spec check_token(raw_jwt_token(), {resource_server(), internal_oauth_provider()}) ->
218+
{'ok', decoded_jwt_token()} |
219+
{'error', term() } |
220+
{'refused', 'signature_invalid' | {'error', term()} | {'invalid_aud', term()}}.
197221

198222
check_token(DecodedToken, _) when is_map(DecodedToken) ->
199223
{ok, DecodedToken};
@@ -206,7 +230,7 @@ check_token(Token, {ResourceServer, InternalOAuthProvider}) ->
206230
end.
207231

208232
-spec normalize_token_scope(
209-
ResourceServer :: resource_server(), DecodedToken :: map()) -> map().
233+
ResourceServer :: resource_server(), DecodedToken :: decoded_jwt_token()) -> map().
210234
normalize_token_scope(ResourceServer, Payload) ->
211235
Payload0 = maps:map(fun(K, V) ->
212236
case K of
@@ -395,7 +419,7 @@ resolve_scope_var(Elem, Token, Vhost) ->
395419
end)
396420
end.
397421

398-
-spec tags_from(map()) -> list(atom()).
422+
-spec tags_from(decoded_jwt_token()) -> list(atom()).
399423
tags_from(DecodedToken) ->
400424
Scopes = maps:get(?SCOPE_JWT_FIELD, DecodedToken, []),
401425
TagScopes = filter_matching_scope_prefix_and_drop_it(Scopes, ?TAG_SCOPE_PREFIX),

0 commit comments

Comments
 (0)