Skip to content

Commit 3c53e2a

Browse files
committed
feat: improve SDK server start up
Additionally, this change includes a way to run the server embeddedly in another process that may use stdin. Signed-off-by: Donnie Adams <[email protected]>
1 parent 3c29ebe commit 3c53e2a

File tree

2 files changed

+54
-24
lines changed

2 files changed

+54
-24
lines changed

pkg/cli/sdk_server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ func (c *SDKServer) Run(cmd *cobra.Command, _ []string) error {
2929
// Don't use cmd.Context() as we don't want to die on ctrl+c
3030
ctx := context.Background()
3131
if term.IsTerminal(int(os.Stdin.Fd())) {
32-
// Only support CTRL+C if stdin is the terminal. When ran as a SDK it will be a pipe
32+
// Only support CTRL+C if stdin is the terminal. When ran as an SDK it will be a pipe
3333
ctx = cmd.Context()
3434
}
3535

36-
return sdkserver.Start(ctx, sdkserver.Options{
36+
return sdkserver.Run(ctx, sdkserver.Options{
3737
Options: opts,
3838
ListenAddress: c.ListenAddress,
3939
Debug: c.Debug,

pkg/sdkserver/server.go

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"io"
8-
"log/slog"
98
"net"
109
"net/http"
1110
"os"
@@ -29,7 +28,18 @@ type Options struct {
2928
Debug bool
3029
}
3130

32-
func Start(ctx context.Context, opts Options) error {
31+
// Run will start the server and block until the server is shut down.
32+
func Run(ctx context.Context, opts Options) error {
33+
listener, err := newListener(opts)
34+
if err != nil {
35+
return err
36+
}
37+
38+
_, err = io.WriteString(os.Stderr, listener.Addr().String()+"\n")
39+
if err != nil {
40+
return fmt.Errorf("failed to write to address to stderr: %w", err)
41+
}
42+
3343
sigCtx, cancel := signal.NotifyContext(ctx, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGKILL)
3444
defer cancel()
3545
go func() {
@@ -40,6 +50,34 @@ func Start(ctx context.Context, opts Options) error {
4050
cancel()
4151
}()
4252

53+
return run(sigCtx, listener, opts)
54+
}
55+
56+
// EmbeddedStart allows running the server as an embedded process that may use Stdin for input.
57+
// It returns the address the server is listening on.
58+
func EmbeddedStart(ctx context.Context, opts Options) (string, error) {
59+
listener, err := newListener(opts)
60+
if err != nil {
61+
return "", err
62+
}
63+
64+
go func() {
65+
_ = run(ctx, listener, opts)
66+
}()
67+
68+
return listener.Addr().String(), nil
69+
}
70+
71+
func (s *server) close() {
72+
s.client.Close(true)
73+
s.events.Close()
74+
}
75+
76+
func newListener(opts Options) (net.Listener, error) {
77+
return net.Listen("tcp", opts.ListenAddress)
78+
}
79+
80+
func run(ctx context.Context, listener net.Listener, opts Options) error {
4381
if opts.Debug {
4482
mvl.SetDebug()
4583
}
@@ -58,11 +96,6 @@ func Start(ctx context.Context, opts Options) error {
5896
return err
5997
}
6098

61-
listener, err := net.Listen("tcp", opts.ListenAddress)
62-
if err != nil {
63-
return fmt.Errorf("failed to listen on %s: %w", opts.ListenAddress, err)
64-
}
65-
6699
s := &server{
67100
gptscriptOpts: opts.Options,
68101
address: listener.Addr().String(),
@@ -72,11 +105,11 @@ func Start(ctx context.Context, opts Options) error {
72105
waitingToConfirm: make(map[string]chan runner.AuthorizerResponse),
73106
waitingToPrompt: make(map[string]chan map[string]string),
74107
}
75-
defer s.Close()
108+
defer s.close()
76109

77110
s.addRoutes(http.DefaultServeMux)
78111

79-
server := http.Server{
112+
httpServer := &http.Server{
80113
Handler: apply(http.DefaultServeMux,
81114
contentType("application/json"),
82115
addRequestID,
@@ -86,25 +119,22 @@ func Start(ctx context.Context, opts Options) error {
86119
),
87120
}
88121

89-
slog.Info("Starting server", "addr", s.address)
90-
91-
context.AfterFunc(sigCtx, func() {
92-
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
122+
logger := mvl.Package()
123+
done := make(chan struct{})
124+
context.AfterFunc(ctx, func() {
125+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
93126
defer cancel()
94127

95-
slog.Info("Shutting down server")
96-
_ = server.Shutdown(ctx)
97-
slog.Info("Server stopped")
128+
logger.Infof("Shutting down server")
129+
_ = httpServer.Shutdown(ctx)
130+
logger.Infof("Server stopped")
131+
close(done)
98132
})
99133

100-
if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
134+
if err = httpServer.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
101135
return fmt.Errorf("server error: %w", err)
102136
}
103137

138+
<-done
104139
return nil
105140
}
106-
107-
func (s *server) Close() {
108-
s.client.Close(true)
109-
s.events.Close()
110-
}

0 commit comments

Comments
 (0)