Skip to content

Commit 38ef1ce

Browse files
committed
feat: ensure closing a Run works in exec and request contexts
Signed-off-by: Donnie Adams <[email protected]>
1 parent bc1986e commit 38ef1ce

File tree

2 files changed

+50
-15
lines changed

2 files changed

+50
-15
lines changed

client_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,30 @@ func TestListModels(t *testing.T) {
5757
}
5858
}
5959

60+
func TestAbortRun(t *testing.T) {
61+
tool := &ToolDef{Instructions: "What is the capital of the united states?"}
62+
63+
run, err := client.Evaluate(context.Background(), Opts{DisableCache: true, IncludeEvents: true}, tool)
64+
if err != nil {
65+
t.Errorf("Error executing tool: %v", err)
66+
}
67+
68+
// Abort the run after the first event.
69+
<-run.Events()
70+
71+
if err := run.Close(); err != nil {
72+
t.Errorf("Error aborting run: %v", err)
73+
}
74+
75+
if run.State() != Error {
76+
t.Errorf("Unexpected run state: %s", run.State())
77+
}
78+
79+
if run.Err() == nil {
80+
t.Error("Expected error but got nil")
81+
}
82+
}
83+
6084
func TestSimpleEvaluate(t *testing.T) {
6185
tool := &ToolDef{Instructions: "What is the capital of the united states?"}
6286

run.go

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ import (
1717
"sync"
1818
)
1919

20+
var abortRunError = errors.New("run aborted")
21+
2022
type Run struct {
2123
url, binPath, requestPath, toolPath, content string
2224
opts Opts
2325
state RunState
2426
chatState string
25-
cmd *exec.Cmd
27+
cancel context.CancelCauseFunc
2628
err error
2729
stdout, stderr io.Reader
2830
wait func() error
@@ -61,6 +63,11 @@ func (r *Run) State() RunState {
6163
return r.state
6264
}
6365

66+
// Err returns the error that caused the gptscript to fail, if any.
67+
func (r *Run) Err() error {
68+
return r.err
69+
}
70+
6471
// ErrorOutput returns the stderr output of the gptscript.
6572
// Should only be called after Bytes or Text has returned an error.
6673
func (r *Run) ErrorOutput() string {
@@ -75,20 +82,20 @@ func (r *Run) Events() <-chan Event {
7582
// Close will stop the gptscript run, if it is running.
7683
func (r *Run) Close() error {
7784
// If the command was not started, then report error.
78-
if r.cmd == nil || r.cmd.Process == nil {
85+
if r.cancel == nil {
7986
return fmt.Errorf("run not started")
8087
}
8188

82-
// If the command has already exited, then nothing to do.
83-
if r.cmd.ProcessState != nil {
89+
r.cancel(abortRunError)
90+
if r.wait == nil {
8491
return nil
8592
}
8693

87-
if err := r.cmd.Process.Signal(os.Kill); err != nil {
94+
if err := r.wait(); !errors.Is(err, abortRunError) && !errors.Is(err, context.Canceled) && !errors.As(err, new(*exec.ExitError)) {
8895
return err
8996
}
9097

91-
return r.wait()
98+
return nil
9299
}
93100

94101
// RawOutput returns the raw output of the gptscript. Most users should use Text or Bytes instead.
@@ -169,26 +176,26 @@ func (r *Run) exec(ctx context.Context, extraArgs ...string) error {
169176
args = append(args, r.toolPath)
170177
}
171178

172-
cancelCtx, cancel := context.WithCancel(ctx)
179+
cancelCtx, cancel := context.WithCancelCause(ctx)
180+
r.cancel = cancel
173181
c, stdout, stderr, err := setupForkCommand(cancelCtx, r.binPath, r.content, r.opts.Input, args, eventsWrite)
174182
if err != nil {
175-
cancel()
183+
r.err = fmt.Errorf("failed to setup gptscript: %w", err)
184+
r.cancel(r.err)
176185
_ = eventsRead.Close()
177186
r.state = Error
178-
r.err = fmt.Errorf("failed to setup gptscript: %w", err)
179187
return r.err
180188
}
181189

182190
if err = c.Start(); err != nil {
183-
cancel()
191+
r.err = fmt.Errorf("failed to start gptscript: %w", err)
192+
r.cancel(r.err)
184193
_ = eventsRead.Close()
185194
r.state = Error
186-
r.err = fmt.Errorf("failed to start gptscript: %w", err)
187195
return r.err
188196
}
189197

190198
r.state = Running
191-
r.cmd = c
192199
r.stdout = stdout
193200
r.stderr = stderr
194201
r.events = make(chan Event, 100)
@@ -197,14 +204,15 @@ func (r *Run) exec(ctx context.Context, extraArgs ...string) error {
197204
r.wait = func() error {
198205
err := c.Wait()
199206
_ = eventsRead.Close()
200-
cancel()
201207
if err != nil {
202208
r.state = Error
203-
r.err = fmt.Errorf("failed to wait for gptscript: %w", err)
209+
r.err = fmt.Errorf("failed to wait for gptscript: error: %w, stderr: %s", err, string(r.errput))
210+
r.cancel(r.err)
204211
} else {
205212
if r.state == Running {
206213
r.state = Finished
207214
}
215+
r.cancel(nil)
208216
}
209217
return r.err
210218
}
@@ -255,7 +263,7 @@ func (r *Run) readEvents(ctx context.Context, events io.Reader) {
255263

256264
if err != nil && !errors.Is(err, io.EOF) {
257265
slog.Debug("failed to read events", "error", err)
258-
r.err = fmt.Errorf("failed to read events: %w", err)
266+
r.err = fmt.Errorf("failed to read events: error: %w, stderr: %s", err, string(r.errput))
259267
}
260268
}
261269

@@ -332,6 +340,7 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
332340
cancelCtx, cancel = context.WithCancelCause(ctx)
333341
)
334342

343+
r.cancel = cancel
335344
defer func() {
336345
if err != nil {
337346
cancel(err)
@@ -465,6 +474,8 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
465474
if err := context.Cause(cancelCtx); !errors.Is(err, context.Canceled) && r.err == nil {
466475
r.state = Error
467476
r.err = err
477+
} else if r.state != Continue {
478+
r.state = Finished
468479
}
469480
return r.err
470481
}

0 commit comments

Comments
 (0)