Skip to content

Commit dcf0506

Browse files
committed
feat: improve SDK server start up
Additionally, this change includes a way to run the server embeddedly in another process that may use stdin. Signed-off-by: Donnie Adams <[email protected]>
1 parent 3c29ebe commit dcf0506

File tree

2 files changed

+52
-24
lines changed

2 files changed

+52
-24
lines changed

pkg/cli/sdk_server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ func (c *SDKServer) Run(cmd *cobra.Command, _ []string) error {
2929
// Don't use cmd.Context() as we don't want to die on ctrl+c
3030
ctx := context.Background()
3131
if term.IsTerminal(int(os.Stdin.Fd())) {
32-
// Only support CTRL+C if stdin is the terminal. When ran as a SDK it will be a pipe
32+
// Only support CTRL+C if stdin is the terminal. When ran as an SDK it will be a pipe
3333
ctx = cmd.Context()
3434
}
3535

36-
return sdkserver.Start(ctx, sdkserver.Options{
36+
return sdkserver.Run(ctx, sdkserver.Options{
3737
Options: opts,
3838
ListenAddress: c.ListenAddress,
3939
Debug: c.Debug,

pkg/sdkserver/server.go

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"io"
8-
"log/slog"
98
"net"
109
"net/http"
1110
"os"
@@ -29,7 +28,18 @@ type Options struct {
2928
Debug bool
3029
}
3130

32-
func Start(ctx context.Context, opts Options) error {
31+
// Run will start the server and block until the server is shut down.
32+
func Run(ctx context.Context, opts Options) error {
33+
listener, err := newListener(opts)
34+
if err != nil {
35+
return err
36+
}
37+
38+
_, err = io.WriteString(os.Stderr, listener.Addr().String()+"\n")
39+
if err != nil {
40+
return fmt.Errorf("failed to write to address to stderr: %w", err)
41+
}
42+
3343
sigCtx, cancel := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGKILL)
3444
defer cancel()
3545
go func() {
@@ -40,6 +50,32 @@ func Start(ctx context.Context, opts Options) error {
4050
cancel()
4151
}()
4252

53+
return run(sigCtx, listener, opts)
54+
}
55+
56+
// EmbeddedStart allows running the server as an embedded process that may use Stdin for input.
57+
// It returns the address the server is listening on.
58+
func EmbeddedStart(ctx context.Context, opts Options) (string, error) {
59+
listener, err := newListener(opts)
60+
if err != nil {
61+
return "", err
62+
}
63+
64+
go run(ctx, listener, opts)
65+
66+
return listener.Addr().String(), nil
67+
}
68+
69+
func (s *server) close() {
70+
s.client.Close(true)
71+
s.events.Close()
72+
}
73+
74+
func newListener(opts Options) (net.Listener, error) {
75+
return net.Listen("tcp", opts.ListenAddress)
76+
}
77+
78+
func run(ctx context.Context, listener net.Listener, opts Options) error {
4379
if opts.Debug {
4480
mvl.SetDebug()
4581
}
@@ -58,11 +94,6 @@ func Start(ctx context.Context, opts Options) error {
5894
return err
5995
}
6096

61-
listener, err := net.Listen("tcp", opts.ListenAddress)
62-
if err != nil {
63-
return fmt.Errorf("failed to listen on %s: %w", opts.ListenAddress, err)
64-
}
65-
6697
s := &server{
6798
gptscriptOpts: opts.Options,
6899
address: listener.Addr().String(),
@@ -72,11 +103,11 @@ func Start(ctx context.Context, opts Options) error {
72103
waitingToConfirm: make(map[string]chan runner.AuthorizerResponse),
73104
waitingToPrompt: make(map[string]chan map[string]string),
74105
}
75-
defer s.Close()
106+
defer s.close()
76107

77108
s.addRoutes(http.DefaultServeMux)
78109

79-
server := http.Server{
110+
httpServer := &http.Server{
80111
Handler: apply(http.DefaultServeMux,
81112
contentType("application/json"),
82113
addRequestID,
@@ -86,25 +117,22 @@ func Start(ctx context.Context, opts Options) error {
86117
),
87118
}
88119

89-
slog.Info("Starting server", "addr", s.address)
90-
91-
context.AfterFunc(sigCtx, func() {
92-
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
120+
logger := mvl.Package()
121+
done := make(chan struct{})
122+
context.AfterFunc(ctx, func() {
123+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
93124
defer cancel()
94125

95-
slog.Info("Shutting down server")
96-
_ = server.Shutdown(ctx)
97-
slog.Info("Server stopped")
126+
logger.Infof("Shutting down server")
127+
_ = httpServer.Shutdown(ctx)
128+
logger.Infof("Server stopped")
129+
close(done)
98130
})
99131

100-
if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
132+
if err = httpServer.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
101133
return fmt.Errorf("server error: %w", err)
102134
}
103135

136+
<-done
104137
return nil
105138
}
106-
107-
func (s *server) Close() {
108-
s.client.Close(true)
109-
s.events.Close()
110-
}

0 commit comments

Comments
 (0)