Skip to content

enhance: add mTLS between gptscript and daemon tools #915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions pkg/certs/certs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package certs

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"time"
)

// CertAndKey contains an x509 certificate (PEM format) and ECDSA private key (also PEM format)
type CertAndKey struct {
Cert []byte
Key []byte
}

func GenerateGPTScriptCert() (CertAndKey, error) {
return GenerateSelfSignedCert("gptscript server")
}

func GenerateSelfSignedCert(name string) (CertAndKey, error) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return CertAndKey{}, fmt.Errorf("failed to generate ECDSA key: %v", err)
}

marshalledPrivateKey, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return CertAndKey{}, fmt.Errorf("failed to marshal ECDSA key: %v", err)
}

marshalledPrivateKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: marshalledPrivateKey})

template := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixNano()),
Subject: pkix.Name{
CommonName: name,
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0), // a year from now
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
IsCA: false,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}

cert, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
if err != nil {
return CertAndKey{}, fmt.Errorf("failed to create certificate: %v", err)
}

certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert})

return CertAndKey{Cert: certPEM, Key: marshalledPrivateKeyPEM}, nil
}
91 changes: 88 additions & 3 deletions pkg/engine/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package engine

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"math/rand"
Expand All @@ -11,11 +14,13 @@ import (
"sync"
"time"

"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
)

var ports Ports
var certificates Certs

type Ports struct {
daemonPorts map[string]int64
Expand All @@ -29,6 +34,35 @@ type Ports struct {
daemonWG sync.WaitGroup
}

type Certs struct {
daemonCerts map[string]certs.CertAndKey
clientCert certs.CertAndKey
lock sync.Mutex
}

func GetClientCert() (certs.CertAndKey, error) {
certificates.lock.Lock()
defer certificates.lock.Unlock()
if len(certificates.clientCert.Cert) == 0 {
cert, err := certs.GenerateGPTScriptCert()
if err != nil {
return certs.CertAndKey{}, fmt.Errorf("failed to generate GPTScript certificate: %v", err)
}
certificates.clientCert = cert
}
return certificates.clientCert, nil
}

func GetDaemonCert(toolID string) ([]byte, error) {
certificates.lock.Lock()
defer certificates.lock.Unlock()
cert, exists := certificates.daemonCerts[toolID]
if !exists {
return nil, fmt.Errorf("daemon certificate for [%s] not found", toolID)
}
return cert.Cert, nil
}

func IsDaemonRunning(url string) bool {
ports.daemonLock.Lock()
defer ports.daemonLock.Unlock()
Expand Down Expand Up @@ -128,7 +162,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
tool.Instructions = types.CommandPrefix + instructions

port, ok := ports.daemonPorts[tool.ID]
url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
url := fmt.Sprintf("https://127.0.0.1:%d%s", port, path)
if ok && ports.daemonsRunning[url] != nil {
return url, nil
}
Expand All @@ -144,11 +178,40 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {

ctx := ports.daemonCtx
port = nextPort()
url = fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
url = fmt.Sprintf("https://127.0.0.1:%d%s", port, path)

// Generate a certificate for the daemon, unless one already exists.
certificates.lock.Lock()
defer certificates.lock.Unlock()
cert, exists := certificates.daemonCerts[tool.ID]
if !exists {
var err error
cert, err = certs.GenerateSelfSignedCert(tool.ID)
if err != nil {
return "", fmt.Errorf("failed to generate certificate for daemon: %v", err)
}

if certificates.daemonCerts == nil {
certificates.daemonCerts = map[string]certs.CertAndKey{}
}
certificates.daemonCerts[tool.ID] = cert
}

// Set the client certificate if there isn't one already.
if len(certificates.clientCert.Cert) == 0 {
gptscriptCert, err := certs.GenerateGPTScriptCert()
if err != nil {
return "", fmt.Errorf("failed to generate GPTScript certificate: %v", err)
}
certificates.clientCert = gptscriptCert
}

cmd, stop, err := e.newCommand(ctx, []string{
fmt.Sprintf("PORT=%d", port),
fmt.Sprintf("CERT=%s", base64.StdEncoding.EncodeToString(cert.Cert)),
fmt.Sprintf("PRIVATE_KEY=%s", base64.StdEncoding.EncodeToString(cert.Key)),
fmt.Sprintf("GPTSCRIPT_PORT=%d", port),
fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(certificates.clientCert.Cert)),
},
tool,
"{}",
Expand Down Expand Up @@ -210,8 +273,30 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
ports.daemonWG.Done()
}()

// Build HTTP client for checking the health of the daemon
tlsClientCert, err := tls.X509KeyPair(certificates.clientCert.Cert, certificates.clientCert.Key)
if err != nil {
return "", fmt.Errorf("failed to create client certificate: %v", err)
}

pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(cert.Cert) {
return "", fmt.Errorf("failed to append daemon certificate for [%s]", tool.ID)
}

httpClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{tlsClientCert},
RootCAs: pool,
InsecureSkipVerify: false,
},
},
}

// Check the health of the daemon
for i := 0; i < 120; i++ {
resp, err := http.Get(url)
resp, err := httpClient.Get(url)
if err == nil && resp.StatusCode == http.StatusOK {
go func() {
_, _ = io.ReadAll(resp.Body)
Expand Down
72 changes: 71 additions & 1 deletion pkg/engine/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package engine

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
Expand All @@ -12,6 +14,7 @@ import (
"slices"
"strings"

"github.com/gptscript-ai/gptscript/pkg/certs"
"github.com/gptscript-ai/gptscript/pkg/types"
)

Expand Down Expand Up @@ -40,6 +43,7 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
return nil, err
}

var tlsConfigForDaemonRequest *tls.Config
if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) {
referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix)
referencedToolRefs, ok := tool.ToolMapping[referencedToolName]
Expand All @@ -60,6 +64,34 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
}
parsed.Host = toolURLParsed.Host
toolURL = parsed.String()

// Find the certificate corresponding to this daemon tool
certificates.lock.Lock()
daemonCert, exists := certificates.daemonCerts[referencedTool.ID]
clientCert := certificates.clientCert
certificates.lock.Unlock()

if !exists {
return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID)
}

tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert)
if err != nil {
return nil, err
}
} else if isLocalhostHTTPS(toolURL) {
// This sometimes happens when talking to a model provider
certificates.lock.Lock()
daemonCert, exists := certificates.daemonCerts[tool.ID]
clientCert := certificates.clientCert
certificates.lock.Unlock()

if exists {
tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert)
if err != nil {
return nil, err
}
}
}

if tool.Blocking {
Expand Down Expand Up @@ -112,7 +144,18 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
req.Header.Set("Content-Type", "text/plain")
}

resp, err := http.DefaultClient.Do(req)
var httpClient *http.Client
if tlsConfigForDaemonRequest != nil {
httpClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfigForDaemonRequest,
},
}
} else {
httpClient = http.DefaultClient
}

resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -143,3 +186,30 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
Result: &s,
}, nil
}

func isLocalhostHTTPS(u string) bool {
parsed, err := url.Parse(u)
if err != nil {
return false
}

return parsed.Scheme == "https" && (parsed.Hostname() == "localhost" || parsed.Hostname() == "127.0.0.1")
}

func getTLSConfig(clientCert certs.CertAndKey, daemonCert []byte) (*tls.Config, error) {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(daemonCert) {
return nil, fmt.Errorf("failed to append daemon certificate")
}

tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key)
if err != nil {
return nil, fmt.Errorf("failed to create client certificate: %v", err)
}

return &tls.Config{
Certificates: []tls.Certificate{tlsClientCert},
RootCAs: pool,
InsecureSkipVerify: false,
}, nil
}
Loading
Loading