Skip to content

Commit 868dded

Browse files
committed
enhance: add mTLS between gptscript and daemon tools
Signed-off-by: Grant Linville <[email protected]>
1 parent e5fe428 commit 868dded

File tree

7 files changed

+183
-10
lines changed

7 files changed

+183
-10
lines changed

pkg/certs/certs.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package certs
2+
3+
import (
4+
"crypto/ecdsa"
5+
"crypto/elliptic"
6+
"crypto/rand"
7+
"crypto/x509"
8+
"crypto/x509/pkix"
9+
"encoding/pem"
10+
"fmt"
11+
"math/big"
12+
"net"
13+
"time"
14+
)
15+
16+
// CertAndKey contains an x509 certificate (PEM format) and ECDSA private key (also PEM format)
17+
type CertAndKey struct {
18+
Cert []byte
19+
Key []byte
20+
}
21+
22+
func GenerateGPTScriptCert() (CertAndKey, error) {
23+
return GenerateSelfSignedCert("gptscript server")
24+
}
25+
26+
func GenerateSelfSignedCert(name string) (CertAndKey, error) {
27+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
28+
if err != nil {
29+
return CertAndKey{}, fmt.Errorf("failed to generate ECDSA key: %v", err)
30+
}
31+
32+
marshalledPrivateKey, err := x509.MarshalECPrivateKey(privateKey)
33+
if err != nil {
34+
return CertAndKey{}, fmt.Errorf("failed to marshal ECDSA key: %v", err)
35+
}
36+
37+
marshalledPrivateKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: marshalledPrivateKey})
38+
39+
template := &x509.Certificate{
40+
SerialNumber: big.NewInt(time.Now().UnixNano()),
41+
Subject: pkix.Name{
42+
CommonName: name,
43+
},
44+
NotBefore: time.Now(),
45+
NotAfter: time.Now().AddDate(1, 0, 0), // a year from now
46+
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
47+
ExtKeyUsage: []x509.ExtKeyUsage{
48+
x509.ExtKeyUsageServerAuth,
49+
x509.ExtKeyUsageClientAuth,
50+
},
51+
IsCA: false,
52+
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
53+
}
54+
55+
cert, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
56+
if err != nil {
57+
return CertAndKey{}, fmt.Errorf("failed to create certificate: %v", err)
58+
}
59+
60+
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert})
61+
62+
return CertAndKey{Cert: certPEM, Key: marshalledPrivateKeyPEM}, nil
63+
}

pkg/engine/daemon.go

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package engine
22

33
import (
44
"context"
5+
"crypto/tls"
6+
"crypto/x509"
7+
"encoding/base64"
58
"fmt"
69
"io"
710
"math/rand"
@@ -11,11 +14,13 @@ import (
1114
"sync"
1215
"time"
1316

17+
"github.com/gptscript-ai/gptscript/pkg/certs"
1418
"github.com/gptscript-ai/gptscript/pkg/system"
1519
"github.com/gptscript-ai/gptscript/pkg/types"
1620
)
1721

1822
var ports Ports
23+
var certificates Certs
1924

2025
type Ports struct {
2126
daemonPorts map[string]int64
@@ -29,6 +34,11 @@ type Ports struct {
2934
daemonWG sync.WaitGroup
3035
}
3136

37+
type Certs struct {
38+
daemonCerts map[string]certs.CertAndKey
39+
daemonLock sync.Mutex
40+
}
41+
3242
func IsDaemonRunning(url string) bool {
3343
ports.daemonLock.Lock()
3444
defer ports.daemonLock.Unlock()
@@ -117,7 +127,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
117127
tool.Instructions = types.CommandPrefix + instructions
118128

119129
port, ok := ports.daemonPorts[tool.ID]
120-
url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
130+
url := fmt.Sprintf("https://127.0.0.1:%d%s", port, path)
121131
if ok {
122132
return url, nil
123133
}
@@ -133,11 +143,31 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
133143

134144
ctx := ports.daemonCtx
135145
port = nextPort()
136-
url = fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
146+
url = fmt.Sprintf("https://127.0.0.1:%d%s", port, path)
147+
148+
// Generate a certificate for the daemon, unless one already exists.
149+
certificates.daemonLock.Lock()
150+
defer certificates.daemonLock.Unlock()
151+
cert, exists := certificates.daemonCerts[tool.ID]
152+
if !exists {
153+
var err error
154+
cert, err = certs.GenerateSelfSignedCert(tool.ID)
155+
if err != nil {
156+
return "", fmt.Errorf("failed to generate certificate for daemon: %v", err)
157+
}
158+
159+
if certificates.daemonCerts == nil {
160+
certificates.daemonCerts = map[string]certs.CertAndKey{}
161+
}
162+
certificates.daemonCerts[tool.ID] = cert
163+
}
137164

138165
cmd, stop, err := e.newCommand(ctx, []string{
139166
fmt.Sprintf("PORT=%d", port),
167+
fmt.Sprintf("CERT=%s", base64.StdEncoding.EncodeToString(cert.Cert)),
168+
fmt.Sprintf("PRIVATE_KEY=%s", base64.StdEncoding.EncodeToString(cert.Key)),
140169
fmt.Sprintf("GPTSCRIPT_PORT=%d", port),
170+
fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(e.GPTScriptCert.Cert)),
141171
},
142172
tool,
143173
"{}",
@@ -199,8 +229,30 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
199229
ports.daemonWG.Done()
200230
}()
201231

232+
// Build HTTP client for checking the health of the daemon
233+
clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key)
234+
if err != nil {
235+
return "", fmt.Errorf("failed to create client certificate: %v", err)
236+
}
237+
238+
pool := x509.NewCertPool()
239+
if !pool.AppendCertsFromPEM(cert.Cert) {
240+
return "", fmt.Errorf("failed to append daemon certificate for [%s]", tool.ID)
241+
}
242+
243+
httpClient := &http.Client{
244+
Transport: &http.Transport{
245+
TLSClientConfig: &tls.Config{
246+
Certificates: []tls.Certificate{clientCert},
247+
RootCAs: pool,
248+
InsecureSkipVerify: false,
249+
},
250+
},
251+
}
252+
253+
// Check the health of the daemon
202254
for i := 0; i < 120; i++ {
203-
resp, err := http.Get(url)
255+
resp, err := httpClient.Get(url)
204256
if err == nil && resp.StatusCode == http.StatusOK {
205257
go func() {
206258
_, _ = io.ReadAll(resp.Body)

pkg/engine/engine.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88
"sync"
99

10+
"github.com/gptscript-ai/gptscript/pkg/certs"
1011
"github.com/gptscript-ai/gptscript/pkg/counter"
1112
"github.com/gptscript-ai/gptscript/pkg/types"
1213
"github.com/gptscript-ai/gptscript/pkg/version"
@@ -22,6 +23,7 @@ type RuntimeManager interface {
2223
}
2324

2425
type Engine struct {
26+
GPTScriptCert certs.CertAndKey
2527
Model Model
2628
RuntimeManager RuntimeManager
2729
Env []string

pkg/engine/http.go

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package engine
22

33
import (
44
"context"
5+
"crypto/tls"
6+
"crypto/x509"
57
"encoding/json"
68
"fmt"
79
"io"
@@ -40,6 +42,7 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
4042
return nil, err
4143
}
4244

45+
var tlsConfigForDaemonRequest *tls.Config
4346
if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) {
4447
referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix)
4548
referencedToolRefs, ok := tool.ToolMapping[referencedToolName]
@@ -60,6 +63,33 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
6063
}
6164
parsed.Host = toolURLParsed.Host
6265
toolURL = parsed.String()
66+
67+
// Find the certificate corresponding to this daemon tool
68+
certificates.daemonLock.Lock()
69+
daemonCert, exists := certificates.daemonCerts[referencedTool.ID]
70+
certificates.daemonLock.Unlock()
71+
72+
if !exists {
73+
return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID)
74+
}
75+
76+
// Create a pool for the certificate to treat as a CA
77+
pool := x509.NewCertPool()
78+
if !pool.AppendCertsFromPEM(daemonCert.Cert) {
79+
return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID)
80+
}
81+
82+
clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key)
83+
if err != nil {
84+
return nil, fmt.Errorf("failed to create client certificate: %v", err)
85+
}
86+
87+
// Create TLS config for use in the HTTP client later
88+
tlsConfigForDaemonRequest = &tls.Config{
89+
Certificates: []tls.Certificate{clientCert},
90+
RootCAs: pool,
91+
InsecureSkipVerify: false,
92+
}
6393
}
6494

6595
if tool.Blocking {
@@ -112,7 +142,18 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
112142
req.Header.Set("Content-Type", "text/plain")
113143
}
114144

115-
resp, err := http.DefaultClient.Do(req)
145+
var httpClient *http.Client
146+
if tlsConfigForDaemonRequest != nil {
147+
httpClient = &http.Client{
148+
Transport: &http.Transport{
149+
TLSClientConfig: tlsConfigForDaemonRequest,
150+
},
151+
}
152+
} else {
153+
httpClient = http.DefaultClient
154+
}
155+
156+
resp, err := httpClient.Do(req)
116157
if err != nil {
117158
return nil, err
118159
}

pkg/gptscript/gptscript.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/gptscript-ai/gptscript/pkg/builtin"
1414
"github.com/gptscript-ai/gptscript/pkg/cache"
15+
"github.com/gptscript-ai/gptscript/pkg/certs"
1516
"github.com/gptscript-ai/gptscript/pkg/config"
1617
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1718
"github.com/gptscript-ai/gptscript/pkg/credentials"
@@ -107,7 +108,12 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
107108
opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir(), opts.SystemToolsDir)
108109
}
109110

110-
simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env)
111+
gptscriptCert, err := certs.GenerateGPTScriptCert()
112+
if err != nil {
113+
return nil, err
114+
}
115+
116+
simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env, gptscriptCert)
111117
if err != nil {
112118
return nil, err
113119
}
@@ -140,7 +146,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
140146
opts.Runner.MonitorFactory = monitor.NewConsole(opts.Monitor, monitor.Options{DebugMessages: *opts.Quiet})
141147
}
142148

143-
runner, err := runner.New(registry, credStore, opts.Runner)
149+
runner, err := runner.New(registry, credStore, gptscriptCert, opts.Runner)
144150
if err != nil {
145151
return nil, err
146152
}
@@ -285,8 +291,8 @@ type simpleRunner struct {
285291
env []string
286292
}
287293

288-
func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string) (*simpleRunner, error) {
289-
runner, err := runner.New(noopModel{}, credentials.NoopStore{}, runner.Options{
294+
func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string, gptscriptCert certs.CertAndKey) (*simpleRunner, error) {
295+
runner, err := runner.New(noopModel{}, credentials.NoopStore{}, gptscriptCert, runner.Options{
290296
RuntimeManager: rm,
291297
MonitorFactory: simpleMonitorFactory{},
292298
})

pkg/runner/runner.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"time"
1212

1313
"github.com/gptscript-ai/gptscript/pkg/builtin"
14+
"github.com/gptscript-ai/gptscript/pkg/certs"
1415
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1516
"github.com/gptscript-ai/gptscript/pkg/credentials"
1617
"github.com/gptscript-ai/gptscript/pkg/engine"
@@ -95,9 +96,10 @@ type Runner struct {
9596
credOverrides []string
9697
credStore credentials.CredentialStore
9798
sequential bool
99+
gptscriptCert certs.CertAndKey
98100
}
99101

100-
func New(client engine.Model, credStore credentials.CredentialStore, opts ...Options) (*Runner, error) {
102+
func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCert certs.CertAndKey, opts ...Options) (*Runner, error) {
101103
opt := complete(opts...)
102104

103105
runner := &Runner{
@@ -109,6 +111,7 @@ func New(client engine.Model, credStore credentials.CredentialStore, opts ...Opt
109111
credStore: credStore,
110112
sequential: opt.Sequential,
111113
auth: opt.Authorizer,
114+
gptscriptCert: gptscriptCert,
112115
}
113116

114117
if opt.StartPort != 0 {
@@ -411,6 +414,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
411414
RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager),
412415
Progress: progress,
413416
Env: env,
417+
GPTScriptCert: r.gptscriptCert,
414418
}
415419

416420
callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)
@@ -593,6 +597,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
593597
RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager),
594598
Progress: progress,
595599
Env: env,
600+
GPTScriptCert: r.gptscriptCert,
596601
}
597602

598603
var contentInput string

pkg/tests/tester/runner.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"testing"
1010

1111
"github.com/adrg/xdg"
12+
"github.com/gptscript-ai/gptscript/pkg/certs"
1213
"github.com/gptscript-ai/gptscript/pkg/credentials"
1314
"github.com/gptscript-ai/gptscript/pkg/loader"
1415
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
@@ -198,7 +199,10 @@ func NewRunner(t *testing.T) *Runner {
198199

199200
rm := runtimes.Default(cacheDir, "")
200201

201-
run, err := runner.New(c, credentials.NoopStore{}, runner.Options{
202+
gptscriptCert, err := certs.GenerateGPTScriptCert()
203+
require.NoError(t, err)
204+
205+
run, err := runner.New(c, credentials.NoopStore{}, gptscriptCert, runner.Options{
202206
Sequential: true,
203207
RuntimeManager: rm,
204208
})

0 commit comments

Comments
 (0)