Skip to content

Commit 96f5403

Browse files
committed
Fix
1 parent 3b75276 commit 96f5403

File tree

2 files changed

+22
-66
lines changed

2 files changed

+22
-66
lines changed

components/public-api-server/pkg/oidc/service_test.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import (
1717
"github.com/coreos/go-oidc/v3/oidc"
1818
db "github.com/gitpod-io/gitpod/components/gitpod-db/go"
1919
"github.com/gitpod-io/gitpod/components/gitpod-db/go/dbtest"
20+
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"
21+
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws/jwstest"
2022
"github.com/go-chi/chi/v5"
2123
"github.com/go-chi/chi/v5/middleware"
2224
"github.com/google/uuid"
@@ -223,11 +225,14 @@ func setupOIDCServiceForTests(t *testing.T) (*Service, *gorm.DB) {
223225

224226
dbConn := dbtest.ConnectForTests(t)
225227
cipher := dbtest.CipherSet(t)
226-
stateJWT := newTestStateJWT([]byte("ANY KEY"), 5*time.Minute)
227228

228229
sessionServerAddress := newFakeSessionServer(t)
229230

230-
service := NewService(sessionServerAddress, dbConn, cipher, stateJWT)
231+
keyset := jwstest.GenerateKeySet(t)
232+
signerVerifier, err := jws.NewRSA256(keyset)
233+
require.NoError(t, err)
234+
235+
service := NewService(sessionServerAddress, dbConn, cipher, signerVerifier, 5*time.Minute)
231236
service.skipVerifyIdToken = true
232237
return service, dbConn
233238
}

components/public-api-server/pkg/oidc/state_jwt_test.go

Lines changed: 15 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,76 +5,27 @@
55
package oidc
66

77
import (
8-
"strings"
98
"testing"
109
"time"
1110

1211
"github.com/golang-jwt/jwt/v5"
1312
"github.com/stretchr/testify/require"
1413
)
1514

16-
func Test_Encode(t *testing.T) {
17-
stateJWT := NewStateJWT([]byte("ANY KEY"))
18-
encodedState, err := stateJWT.Encode(StateClaims{
19-
ClientConfigID: "test-id",
20-
ReturnToURL: "test-url",
21-
})
22-
require.NoError(t, err)
23-
// check for header: { "alg": "HS256", "typ": "JWT" }
24-
require.Contains(t, encodedState, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.", "")
25-
}
26-
27-
func Test_Decode(t *testing.T) {
28-
29-
testCases := []struct {
30-
Label string
31-
Key4Encode string
32-
expiresIn time.Duration
33-
Key4Decode string
34-
ExpectedError string
35-
}{
36-
{
37-
Label: "happy path",
38-
Key4Encode: "ANY KEY",
39-
expiresIn: 5 * time.Minute,
40-
Key4Decode: "ANY KEY",
41-
ExpectedError: "",
42-
},
43-
{
44-
Label: "expired state token",
45-
Key4Encode: "ANY KEY",
46-
expiresIn: 0 * time.Second,
47-
Key4Decode: "ANY KEY",
48-
ExpectedError: "token is expired",
15+
func TestNewStateJWT(t *testing.T) {
16+
var (
17+
clientConfigID = "test-id"
18+
returnURL = "test-url"
19+
issuedAt = time.Now()
20+
expiry = issuedAt.Add(5 * time.Minute)
21+
)
22+
token := NewStateJWT(clientConfigID, returnURL, issuedAt, expiry)
23+
require.Equal(t, &StateClaims{
24+
ClientConfigID: clientConfigID,
25+
ReturnToURL: returnURL,
26+
RegisteredClaims: jwt.RegisteredClaims{
27+
ExpiresAt: jwt.NewNumericDate(expiry),
28+
IssuedAt: jwt.NewNumericDate(issuedAt),
4929
},
50-
{
51-
Label: "signature is invalid",
52-
Key4Encode: "OTHER KEY",
53-
expiresIn: 5 * time.Minute,
54-
Key4Decode: "ANY KEY",
55-
ExpectedError: jwt.ErrSignatureInvalid.Error(),
56-
},
57-
}
58-
59-
for _, tc := range testCases {
60-
t.Run(tc.Label, func(t *testing.T) {
61-
encoder := newTestStateJWT([]byte(tc.Key4Encode), tc.expiresIn)
62-
decoder := NewStateJWT([]byte(tc.Key4Decode))
63-
encodedState, err := encoder.Encode(StateClaims{
64-
ClientConfigID: "test-id",
65-
ReturnToURL: "test-url",
66-
})
67-
if err != nil && tc.ExpectedError == "" {
68-
require.FailNowf(t, "Unexpected error on `Encode`.", "Error: %", err)
69-
}
70-
_, err = decoder.Decode(encodedState)
71-
if err != nil && tc.ExpectedError == "" {
72-
require.FailNowf(t, "Unexpected error on `Decode`.", "Error: %", err)
73-
}
74-
if err != nil && !strings.Contains(err.Error(), tc.ExpectedError) {
75-
require.FailNowf(t, "Unmatched error.", "Got error: %", err.Error())
76-
}
77-
})
78-
}
79-
30+
}, token.Claims)
8031
}

0 commit comments

Comments
 (0)