Skip to content

Commit d9ccc1d

Browse files
authored
[papi] OIDC service signs state with HS256, reusing signing PK - WEB-206 (#17328)
* [papi] OIDC service signs state with RSA256 * Fix * retest * fix * add test
1 parent 0a1ea38 commit d9ccc1d

File tree

8 files changed

+159
-135
lines changed

8 files changed

+159
-135
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package jws
6+
7+
import (
8+
"crypto/x509"
9+
"encoding/pem"
10+
"errors"
11+
"fmt"
12+
13+
"github.com/golang-jwt/jwt/v5"
14+
)
15+
16+
func NewHS256FromKeySet(keyset KeySet) *HS256 {
17+
// We treat the signing private key as our symmetric key, to do that, we first have to convert it to bytes
18+
// For bytes conversion, we encode it as PKCS1 PK, pem format
19+
raw := x509.MarshalPKCS1PrivateKey(keyset.Signing.Private)
20+
key := pem.EncodeToMemory(&pem.Block{
21+
Type: "",
22+
Bytes: raw,
23+
})
24+
25+
return NewHS256(key)
26+
}
27+
28+
func NewHS256(symmetricKey []byte) *HS256 {
29+
return &HS256{
30+
key: symmetricKey,
31+
}
32+
}
33+
34+
type HS256 struct {
35+
key []byte
36+
}
37+
38+
func (s *HS256) Sign(token *jwt.Token) (string, error) {
39+
if token.Method != jwt.SigningMethodHS256 {
40+
return "", errors.New("invalid signing method, token must use HS256")
41+
}
42+
43+
signed, err := token.SignedString(s.key)
44+
if err != nil {
45+
return "", fmt.Errorf("failed to sign jwt: %w", err)
46+
}
47+
48+
return signed, nil
49+
}
50+
51+
func (v *HS256) Verify(token string, claims jwt.Claims, opts ...jwt.ParserOption) (*jwt.Token, error) {
52+
parsed, err := jwt.ParseWithClaims(token, claims, jwt.Keyfunc(func(t *jwt.Token) (interface{}, error) {
53+
return v.key, nil
54+
}), opts...)
55+
56+
if err != nil {
57+
return nil, fmt.Errorf("failed to parse jwt: %w", err)
58+
}
59+
60+
return parsed, nil
61+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
2+
// Licensed under the GNU Affero General Public License (AGPL).
3+
// See License.AGPL.txt in the project root for license information.
4+
5+
package jws_test
6+
7+
import (
8+
"testing"
9+
"time"
10+
11+
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"
12+
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws/jwstest"
13+
"github.com/golang-jwt/jwt/v5"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func TestHS256SignVerify(t *testing.T) {
18+
keyset := jwstest.GenerateKeySet(t)
19+
hs256 := jws.NewHS256FromKeySet(keyset)
20+
21+
claims := &jwt.RegisteredClaims{
22+
Subject: "user-id",
23+
Issuer: "test-issuer",
24+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
25+
IssuedAt: jwt.NewNumericDate(time.Now()),
26+
}
27+
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
28+
29+
signed, err := hs256.Sign(token)
30+
require.NoError(t, err)
31+
32+
verified, err := hs256.Verify(signed, &jwt.RegisteredClaims{})
33+
require.NoError(t, err)
34+
require.Equal(t, claims, verified.Claims)
35+
}

components/public-api-server/pkg/jws/keyset.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
type Key struct {
1818
ID string
1919
Private *rsa.PrivateKey
20+
Raw []byte
2021
// We don't need PublicKey because we can derive the public key from the private key
2122
}
2223

@@ -51,29 +52,30 @@ func NewKeySetFromAuthPKI(pki config.AuthPKIConfiguration) (KeySet, error) {
5152
}
5253

5354
func readKeyPair(keypair config.KeyPair) (Key, error) {
54-
pk, err := readPrivateKeyFromFile(keypair.PrivateKeyPath)
55+
pk, raw, err := readPrivateKeyFromFile(keypair.PrivateKeyPath)
5556
if err != nil {
5657
return Key{}, err
5758
}
5859

5960
return Key{
6061
ID: keypair.ID,
6162
Private: pk,
63+
Raw: raw,
6264
}, nil
6365
}
6466

65-
func readPrivateKeyFromFile(filepath string) (*rsa.PrivateKey, error) {
67+
func readPrivateKeyFromFile(filepath string) (*rsa.PrivateKey, []byte, error) {
6668
bytes, err := ioutil.ReadFile(filepath)
6769
if err != nil {
68-
return nil, fmt.Errorf("failed to read private key from %s: %w", filepath, err)
70+
return nil, nil, fmt.Errorf("failed to read private key from %s: %w", filepath, err)
6971
}
7072

7173
block, _ := pem.Decode(bytes)
7274
parseResult, _ := x509.ParsePKCS8PrivateKey(block.Bytes)
7375
key, ok := parseResult.(*rsa.PrivateKey)
7476
if !ok {
75-
return nil, fmt.Errorf("file %s does not contain RSA Private Key", filepath)
77+
return nil, nil, fmt.Errorf("file %s does not contain RSA Private Key", filepath)
7678
}
7779

78-
return key, nil
80+
return key, bytes, nil
7981
}

components/public-api-server/pkg/oidc/service.go

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ import (
1414
"io"
1515
"io/ioutil"
1616
"net/http"
17+
"time"
1718

1819
"github.com/coreos/go-oidc/v3/oidc"
1920
goidc "github.com/coreos/go-oidc/v3/oidc"
2021
"github.com/gitpod-io/gitpod/common-go/log"
2122
db "github.com/gitpod-io/gitpod/components/gitpod-db/go"
23+
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"
2224
"github.com/google/uuid"
2325
"golang.org/x/oauth2"
2426
"google.golang.org/grpc/codes"
@@ -27,11 +29,13 @@ import (
2729
)
2830

2931
type Service struct {
30-
dbConn *gorm.DB
31-
cipher db.Cipher
32-
stateJWT *StateJWT
32+
dbConn *gorm.DB
33+
cipher db.Cipher
34+
35+
// jwts
36+
stateExpiry time.Duration
37+
signerVerifier jws.SignerVerifier
3338

34-
// verifierByIssuer map[string]*goidc.IDTokenVerifier
3539
sessionServiceAddress string
3640

3741
// TODO(at) remove by enhancing test setups
@@ -57,14 +61,15 @@ type AuthFlowResult struct {
5761
Claims map[string]interface{} `json:"claims"`
5862
}
5963

60-
func NewService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher, stateJWT *StateJWT) *Service {
64+
func NewService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher, signerVerifier jws.SignerVerifier, stateExpiry time.Duration) *Service {
6165
return &Service{
6266
sessionServiceAddress: sessionServiceAddress,
6367

6468
dbConn: dbConn,
6569
cipher: cipher,
6670

67-
stateJWT: stateJWT,
71+
signerVerifier: signerVerifier,
72+
stateExpiry: stateExpiry,
6873
}
6974
}
7075

@@ -98,18 +103,24 @@ func (s *Service) GetStartParams(config *ClientConfig, redirectURL string, retur
98103
}
99104

100105
func (s *Service) encodeStateParam(state StateParam) (string, error) {
101-
encodedState, err := s.stateJWT.Encode(StateClaims{
102-
ClientConfigID: state.ClientConfigID,
103-
ReturnToURL: state.ReturnToURL,
104-
})
105-
return encodedState, err
106+
now := time.Now().UTC()
107+
expiry := now.Add(s.stateExpiry)
108+
token := NewStateJWT(state.ClientConfigID, state.ReturnToURL, now, expiry)
109+
110+
signed, err := s.signerVerifier.Sign(token)
111+
if err != nil {
112+
return "", fmt.Errorf("failed to sign jwt: %w", err)
113+
}
114+
return signed, nil
106115
}
107116

108117
func (s *Service) decodeStateParam(encodedToken string) (StateParam, error) {
109-
claims, err := s.stateJWT.Decode(encodedToken)
118+
claims := &StateClaims{}
119+
_, err := s.signerVerifier.Verify(encodedToken, claims)
110120
if err != nil {
111-
return StateParam{}, err
121+
return StateParam{}, fmt.Errorf("failed to verify state token: %w", err)
112122
}
123+
113124
return StateParam{
114125
ClientConfigID: claims.ClientConfigID,
115126
ReturnToURL: claims.ReturnToURL,

components/public-api-server/pkg/oidc/service_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import (
1717
"github.com/coreos/go-oidc/v3/oidc"
1818
db "github.com/gitpod-io/gitpod/components/gitpod-db/go"
1919
"github.com/gitpod-io/gitpod/components/gitpod-db/go/dbtest"
20+
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"
21+
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws/jwstest"
2022
"github.com/go-chi/chi/v5"
2123
"github.com/go-chi/chi/v5/middleware"
2224
"github.com/google/uuid"
@@ -223,11 +225,13 @@ func setupOIDCServiceForTests(t *testing.T) (*Service, *gorm.DB) {
223225

224226
dbConn := dbtest.ConnectForTests(t)
225227
cipher := dbtest.CipherSet(t)
226-
stateJWT := newTestStateJWT([]byte("ANY KEY"), 5*time.Minute)
227228

228229
sessionServerAddress := newFakeSessionServer(t)
229230

230-
service := NewService(sessionServerAddress, dbConn, cipher, stateJWT)
231+
keyset := jwstest.GenerateKeySet(t)
232+
signerVerifier := jws.NewHS256FromKeySet(keyset)
233+
234+
service := NewService(sessionServerAddress, dbConn, cipher, signerVerifier, 5*time.Minute)
231235
service.skipVerifyIdToken = true
232236
return service, dbConn
233237
}

components/public-api-server/pkg/oidc/state_jwt.go

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,6 @@ import (
1010
"github.com/golang-jwt/jwt/v5"
1111
)
1212

13-
type StateJWT struct {
14-
key []byte
15-
expiresIn time.Duration
16-
}
17-
18-
func NewStateJWT(key []byte) *StateJWT {
19-
return &StateJWT{
20-
key: key,
21-
expiresIn: 5 * time.Minute,
22-
}
23-
}
24-
25-
func newTestStateJWT(key []byte, expiresIn time.Duration) *StateJWT {
26-
thing := NewStateJWT(key)
27-
thing.expiresIn = expiresIn
28-
return thing
29-
}
30-
3113
type StateClaims struct {
3214
// Internal client ID
3315
ClientConfigID string `json:"clientId"`
@@ -36,26 +18,13 @@ type StateClaims struct {
3618
jwt.RegisteredClaims
3719
}
3820

39-
func (s *StateJWT) Encode(claims StateClaims) (string, error) {
40-
41-
expirationTime := time.Now().Add(s.expiresIn)
42-
claims.ExpiresAt = jwt.NewNumericDate(expirationTime)
43-
44-
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
45-
encodedToken, err := token.SignedString(s.key)
46-
47-
return encodedToken, err
48-
}
49-
50-
func (s *StateJWT) Decode(tokenString string) (*StateClaims, error) {
51-
claims := &StateClaims{}
52-
_, err := jwt.ParseWithClaims(
53-
tokenString,
54-
claims,
55-
func(token *jwt.Token) (interface{}, error) {
56-
return []byte(s.key), nil
21+
func NewStateJWT(clientConfigID string, returnURL string, issuedAt, expiry time.Time) *jwt.Token {
22+
return jwt.NewWithClaims(jwt.SigningMethodHS256, &StateClaims{
23+
ClientConfigID: clientConfigID,
24+
ReturnToURL: returnURL,
25+
RegisteredClaims: jwt.RegisteredClaims{
26+
ExpiresAt: jwt.NewNumericDate(expiry),
27+
IssuedAt: jwt.NewNumericDate(issuedAt),
5728
},
58-
)
59-
60-
return claims, err
29+
})
6130
}

0 commit comments

Comments
 (0)