Skip to content

Type spec improvements in rabbit_auth_backend_oauth2 #12851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deps/rabbitmq_auth_backend_oauth2/include/oauth2.hrl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

%% End of Key JWT fields

-type raw_jwt_token() :: binary() | #{binary() => any()}.
-type decoded_jwt_token() :: #{binary() => any()}.

-record(internal_oauth_provider, {
id :: oauth_provider_id(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,11 @@
-endif.

%%
%% App environment
%% Types
%%


%% a term defined for Rich Authorization Request tokens to identify a RabbitMQ permission
%% verify server_server_id aud field is on the aud field
%% a term used by the IdentityServer community
%% scope aliases map "role names" to a set of scopes

-type ok_extracted_auth_user() :: {ok, rabbit_types:auth_user()}.
-type auth_user_extraction_fun() :: fun((decoded_jwt_token()) -> any()).

%%
%% API
Expand All @@ -58,6 +54,11 @@ description() ->

%%--------------------------------------------------------------------

-spec user_login_authentication(rabbit_types:username(), [term()] | map()) ->
{'ok', rabbit_types:auth_user()} |
{'refused', string(), [any()]} |
{'error', any()}.

user_login_authentication(Username, AuthProps) ->
case authenticate(Username, AuthProps) of
{refused, Msg, Args} = AuthResult ->
Expand All @@ -67,12 +68,21 @@ user_login_authentication(Username, AuthProps) ->
AuthResult
end.

-spec user_login_authorization(rabbit_types:username(), [term()] | map()) ->
{'ok', any()} |
{'ok', any(), any()} |
{'refused', string(), [any()]} |
{'error', any()}.

user_login_authorization(Username, AuthProps) ->
case authenticate(Username, AuthProps) of
{ok, #auth_user{impl = Impl}} -> {ok, Impl};
Else -> Else
end.

-spec check_vhost_access(AuthUser :: rabbit_types:auth_user(),
VHost :: rabbit_types:vhost(),
AuthzData :: rabbit_types:authz_data()) -> boolean() | {'error', any()}.
check_vhost_access(#auth_user{impl = DecodedTokenFun},
VHost, _AuthzData) ->
with_decoded_token(DecodedTokenFun(),
Expand Down Expand Up @@ -136,6 +146,11 @@ expiry_timestamp(#auth_user{impl = DecodedTokenFun}) ->

%%--------------------------------------------------------------------

-spec authenticate(Username, Props) -> Result
when Username :: rabbit_types:username(),
Props :: list() | map(),
Result :: {ok, any()} | {refused, list(), list()} | {refused, {error, any()}}.

authenticate(_, AuthProps0) ->
AuthProps = to_map(AuthProps0),
Token = token_from_context(AuthProps),
Expand All @@ -148,31 +163,42 @@ authenticate(_, AuthProps0) ->
{refused, "Authentication using an OAuth 2/JWT token failed: provided token is invalid", []};
{refused, Err} ->
{refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]};
{ok, DecodedToken} ->
Func = fun(Token0) ->
Username = username_from(
ResourceServer#resource_server.preferred_username_claims,
Token0),
Tags = tags_from(Token0),
{ok, #auth_user{username = Username,
tags = Tags,
impl = fun() -> Token0 end}}
end,
case with_decoded_token(DecodedToken, Func) of
{ok, DecodedToken} ->
case with_decoded_token(DecodedToken, fun(In) -> auth_user_from_token(In, ResourceServer) end) of
{error, Err} ->
{refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]};
Else ->
Else
end
end
end.

-spec with_decoded_token(Token, Fun) -> Result
when Token :: decoded_jwt_token(),
Fun :: auth_user_extraction_fun(),
Result :: {ok, any()} | {'error', any()}.
with_decoded_token(DecodedToken, Fun) ->
case validate_token_expiry(DecodedToken) of
ok -> Fun(DecodedToken);
{error, Msg} = Err ->
rabbit_log:error(Msg),
Err
end.

%% This is a helper function used with HOFs that may return errors.
-spec auth_user_from_token(Token, ResourceServer) -> Result
when Token :: decoded_jwt_token(),
ResourceServer :: resource_server(),
Result :: ok_extracted_auth_user().
auth_user_from_token(Token0, ResourceServer) ->
Username = username_from(
ResourceServer#resource_server.preferred_username_claims,
Token0),
Tags = tags_from(Token0),
{ok, #auth_user{username = Username,
tags = Tags,
impl = fun() -> Token0 end}}.

ensure_same_username(PreferredUsernameClaims, CurrentDecodedToken, NewDecodedToken) ->
CurUsername = username_from(PreferredUsernameClaims, CurrentDecodedToken),
case {CurUsername, username_from(PreferredUsernameClaims, NewDecodedToken)} of
Expand All @@ -188,12 +214,10 @@ validate_token_expiry(#{<<"exp">> := Exp}) when is_integer(Exp) ->
end;
validate_token_expiry(#{}) -> ok.

-spec check_token(binary() | map(), {resource_server(), internal_oauth_provider()}) ->
{'ok', map()} |
{'error', term() }|
{'refused', 'signature_invalid' |
{'error', term()} |
{'invalid_aud', term()}}.
-spec check_token(raw_jwt_token(), {resource_server(), internal_oauth_provider()}) ->
{'ok', decoded_jwt_token()} |
{'error', term() } |
{'refused', 'signature_invalid' | {'error', term()} | {'invalid_aud', term()}}.

check_token(DecodedToken, _) when is_map(DecodedToken) ->
{ok, DecodedToken};
Expand All @@ -206,7 +230,7 @@ check_token(Token, {ResourceServer, InternalOAuthProvider}) ->
end.

-spec normalize_token_scope(
ResourceServer :: resource_server(), DecodedToken :: map()) -> map().
ResourceServer :: resource_server(), DecodedToken :: decoded_jwt_token()) -> map().
normalize_token_scope(ResourceServer, Payload) ->
Payload0 = maps:map(fun(K, V) ->
case K of
Expand Down Expand Up @@ -395,7 +419,7 @@ resolve_scope_var(Elem, Token, Vhost) ->
end)
end.

-spec tags_from(map()) -> list(atom()).
-spec tags_from(decoded_jwt_token()) -> list(atom()).
tags_from(DecodedToken) ->
Scopes = maps:get(?SCOPE_JWT_FIELD, DecodedToken, []),
TagScopes = filter_matching_scope_prefix_and_drop_it(Scopes, ?TAG_SCOPE_PREFIX),
Expand Down
Loading