Skip to content

Commit b51c09c

Browse files
committed
Allow use of the same group claim name for the prohibit login value
Signed-off-by: Andrew Thornton <[email protected]>
1 parent 0e27070 commit b51c09c

File tree

1 file changed

+33
-13
lines changed

1 file changed

+33
-13
lines changed

routers/web/user/auth.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,23 @@ func SignInOAuthCallback(ctx *context.Context) {
656656
handleOAuth2SignIn(ctx, loginSource, u, gothUser)
657657
}
658658

659+
func claimValueToStringSlice(claimValue interface{}) []string {
660+
var groups []string
661+
662+
switch rawGroup := claimValue.(type) {
663+
case []string:
664+
groups = rawGroup
665+
default:
666+
str := fmt.Sprintf("%s", rawGroup)
667+
if strings.Contains(str, ",") {
668+
groups = strings.Split(str, ",")
669+
} else {
670+
groups = []string{str}
671+
}
672+
}
673+
return groups
674+
}
675+
659676
func setUserGroupClaims(loginSource *models.LoginSource, u *models.User, gothUser *goth.User) bool {
660677

661678
source := loginSource.Cfg.(*oauth2.Source)
@@ -668,18 +685,7 @@ func setUserGroupClaims(loginSource *models.LoginSource, u *models.User, gothUse
668685
return false
669686
}
670687

671-
var groups []string
672-
673-
switch rawGroup := groupClaims.(type) {
674-
case []string:
675-
groups = rawGroup
676-
case string:
677-
if strings.Contains(rawGroup, ",") {
678-
groups = strings.Split(rawGroup, ",")
679-
} else {
680-
groups = []string{rawGroup}
681-
}
682-
}
688+
groups := claimValueToStringSlice(groupClaims)
683689

684690
wasAdmin, wasRestricted := u.IsAdmin, u.IsRestricted
685691

@@ -844,9 +850,23 @@ func oAuth2UserLoginCallback(loginSource *models.LoginSource, request *http.Requ
844850

845851
if oauth2Source.RequiredClaimName != "" {
846852
claimInterface, has := gothUser.RawData[oauth2Source.RequiredClaimName]
847-
if !has || (oauth2Source.RequiredClaimValue != "" && claimInterface.(string) != oauth2Source.RequiredClaimValue) {
853+
if !has {
848854
return nil, goth.User{}, models.ErrUserProhibitLogin{Name: gothUser.UserID}
849855
}
856+
857+
if oauth2Source.RequiredClaimValue != "" {
858+
groups := claimValueToStringSlice(claimInterface)
859+
found := false
860+
for _, group := range groups {
861+
if group == oauth2Source.RequiredClaimValue {
862+
found = true
863+
break
864+
}
865+
}
866+
if !found {
867+
return nil, goth.User{}, models.ErrUserProhibitLogin{Name: gothUser.UserID}
868+
}
869+
}
850870
}
851871

852872
user := &models.User{

0 commit comments

Comments
 (0)