Skip to content

[papi] OIDC service signs state with HS256, reusing signing PK - WEB-206 #17328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions components/public-api-server/pkg/jws/hs256.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package jws

import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"

"github.com/golang-jwt/jwt/v5"
)

func NewHS256FromKeySet(keyset KeySet) *HS256 {
// We treat the signing private key as our symmetric key, to do that, we first have to convert it to bytes
// For bytes conversion, we encode it as PKCS1 PK, pem format
raw := x509.MarshalPKCS1PrivateKey(keyset.Signing.Private)
key := pem.EncodeToMemory(&pem.Block{
Type: "",
Bytes: raw,
})

return NewHS256(key)
}

func NewHS256(symmetricKey []byte) *HS256 {
return &HS256{
key: symmetricKey,
}
}

type HS256 struct {
key []byte
}

func (s *HS256) Sign(token *jwt.Token) (string, error) {
if token.Method != jwt.SigningMethodHS256 {
return "", errors.New("invalid signing method, token must use HS256")
}

signed, err := token.SignedString(s.key)
if err != nil {
return "", fmt.Errorf("failed to sign jwt: %w", err)
}

return signed, nil
}

func (v *HS256) Verify(token string, claims jwt.Claims, opts ...jwt.ParserOption) (*jwt.Token, error) {
parsed, err := jwt.ParseWithClaims(token, claims, jwt.Keyfunc(func(t *jwt.Token) (interface{}, error) {
return v.key, nil
}), opts...)

if err != nil {
return nil, fmt.Errorf("failed to parse jwt: %w", err)
}

return parsed, nil
}
35 changes: 35 additions & 0 deletions components/public-api-server/pkg/jws/hs256_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package jws_test

import (
"testing"
"time"

"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws/jwstest"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)

func TestHS256SignVerify(t *testing.T) {
keyset := jwstest.GenerateKeySet(t)
hs256 := jws.NewHS256FromKeySet(keyset)

claims := &jwt.RegisteredClaims{
Subject: "user-id",
Issuer: "test-issuer",
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

signed, err := hs256.Sign(token)
require.NoError(t, err)

verified, err := hs256.Verify(signed, &jwt.RegisteredClaims{})
require.NoError(t, err)
require.Equal(t, claims, verified.Claims)
}
12 changes: 7 additions & 5 deletions components/public-api-server/pkg/jws/keyset.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
type Key struct {
ID string
Private *rsa.PrivateKey
Raw []byte
// We don't need PublicKey because we can derive the public key from the private key
}

Expand Down Expand Up @@ -51,29 +52,30 @@ func NewKeySetFromAuthPKI(pki config.AuthPKIConfiguration) (KeySet, error) {
}

func readKeyPair(keypair config.KeyPair) (Key, error) {
pk, err := readPrivateKeyFromFile(keypair.PrivateKeyPath)
pk, raw, err := readPrivateKeyFromFile(keypair.PrivateKeyPath)
if err != nil {
return Key{}, err
}

return Key{
ID: keypair.ID,
Private: pk,
Raw: raw,
}, nil
}

func readPrivateKeyFromFile(filepath string) (*rsa.PrivateKey, error) {
func readPrivateKeyFromFile(filepath string) (*rsa.PrivateKey, []byte, error) {
bytes, err := ioutil.ReadFile(filepath)
if err != nil {
return nil, fmt.Errorf("failed to read private key from %s: %w", filepath, err)
return nil, nil, fmt.Errorf("failed to read private key from %s: %w", filepath, err)
}

block, _ := pem.Decode(bytes)
parseResult, _ := x509.ParsePKCS8PrivateKey(block.Bytes)
key, ok := parseResult.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("file %s does not contain RSA Private Key", filepath)
return nil, nil, fmt.Errorf("file %s does not contain RSA Private Key", filepath)
}

return key, nil
return key, bytes, nil
}
37 changes: 24 additions & 13 deletions components/public-api-server/pkg/oidc/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import (
"io"
"io/ioutil"
"net/http"
"time"

"github.com/coreos/go-oidc/v3/oidc"
goidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/gitpod-io/gitpod/common-go/log"
db "github.com/gitpod-io/gitpod/components/gitpod-db/go"
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"
"github.com/google/uuid"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
Expand All @@ -27,11 +29,13 @@ import (
)

type Service struct {
dbConn *gorm.DB
cipher db.Cipher
stateJWT *StateJWT
dbConn *gorm.DB
cipher db.Cipher

// jwts
stateExpiry time.Duration
signerVerifier jws.SignerVerifier

// verifierByIssuer map[string]*goidc.IDTokenVerifier
sessionServiceAddress string

// TODO(at) remove by enhancing test setups
Expand All @@ -57,14 +61,15 @@ type AuthFlowResult struct {
Claims map[string]interface{} `json:"claims"`
}

func NewService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher, stateJWT *StateJWT) *Service {
func NewService(sessionServiceAddress string, dbConn *gorm.DB, cipher db.Cipher, signerVerifier jws.SignerVerifier, stateExpiry time.Duration) *Service {
return &Service{
sessionServiceAddress: sessionServiceAddress,

dbConn: dbConn,
cipher: cipher,

stateJWT: stateJWT,
signerVerifier: signerVerifier,
stateExpiry: stateExpiry,
}
}

Expand Down Expand Up @@ -98,18 +103,24 @@ func (s *Service) GetStartParams(config *ClientConfig, redirectURL string, retur
}

func (s *Service) encodeStateParam(state StateParam) (string, error) {
encodedState, err := s.stateJWT.Encode(StateClaims{
ClientConfigID: state.ClientConfigID,
ReturnToURL: state.ReturnToURL,
})
return encodedState, err
now := time.Now().UTC()
expiry := now.Add(s.stateExpiry)
token := NewStateJWT(state.ClientConfigID, state.ReturnToURL, now, expiry)

signed, err := s.signerVerifier.Sign(token)
if err != nil {
return "", fmt.Errorf("failed to sign jwt: %w", err)
}
return signed, nil
}

func (s *Service) decodeStateParam(encodedToken string) (StateParam, error) {
claims, err := s.stateJWT.Decode(encodedToken)
claims := &StateClaims{}
_, err := s.signerVerifier.Verify(encodedToken, claims)
if err != nil {
return StateParam{}, err
return StateParam{}, fmt.Errorf("failed to verify state token: %w", err)
}

return StateParam{
ClientConfigID: claims.ClientConfigID,
ReturnToURL: claims.ReturnToURL,
Expand Down
8 changes: 6 additions & 2 deletions components/public-api-server/pkg/oidc/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
db "github.com/gitpod-io/gitpod/components/gitpod-db/go"
"github.com/gitpod-io/gitpod/components/gitpod-db/go/dbtest"
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws"
"github.com/gitpod-io/gitpod/public-api-server/pkg/jws/jwstest"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
Expand Down Expand Up @@ -223,11 +225,13 @@ func setupOIDCServiceForTests(t *testing.T) (*Service, *gorm.DB) {

dbConn := dbtest.ConnectForTests(t)
cipher := dbtest.CipherSet(t)
stateJWT := newTestStateJWT([]byte("ANY KEY"), 5*time.Minute)

sessionServerAddress := newFakeSessionServer(t)

service := NewService(sessionServerAddress, dbConn, cipher, stateJWT)
keyset := jwstest.GenerateKeySet(t)
signerVerifier := jws.NewHS256FromKeySet(keyset)

service := NewService(sessionServerAddress, dbConn, cipher, signerVerifier, 5*time.Minute)
service.skipVerifyIdToken = true
return service, dbConn
}
Expand Down
47 changes: 8 additions & 39 deletions components/public-api-server/pkg/oidc/state_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,6 @@ import (
"github.com/golang-jwt/jwt/v5"
)

type StateJWT struct {
key []byte
expiresIn time.Duration
}

func NewStateJWT(key []byte) *StateJWT {
return &StateJWT{
key: key,
expiresIn: 5 * time.Minute,
}
}

func newTestStateJWT(key []byte, expiresIn time.Duration) *StateJWT {
thing := NewStateJWT(key)
thing.expiresIn = expiresIn
return thing
}

type StateClaims struct {
// Internal client ID
ClientConfigID string `json:"clientId"`
Expand All @@ -36,26 +18,13 @@ type StateClaims struct {
jwt.RegisteredClaims
}

func (s *StateJWT) Encode(claims StateClaims) (string, error) {

expirationTime := time.Now().Add(s.expiresIn)
claims.ExpiresAt = jwt.NewNumericDate(expirationTime)

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
encodedToken, err := token.SignedString(s.key)

return encodedToken, err
}

func (s *StateJWT) Decode(tokenString string) (*StateClaims, error) {
claims := &StateClaims{}
_, err := jwt.ParseWithClaims(
tokenString,
claims,
func(token *jwt.Token) (interface{}, error) {
return []byte(s.key), nil
func NewStateJWT(clientConfigID string, returnURL string, issuedAt, expiry time.Time) *jwt.Token {
return jwt.NewWithClaims(jwt.SigningMethodHS256, &StateClaims{
ClientConfigID: clientConfigID,
ReturnToURL: returnURL,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiry),
IssuedAt: jwt.NewNumericDate(issuedAt),
},
)

return claims, err
})
}
Loading