Skip to content

Commit 3033b05

Browse files
authored
Merge pull request #551 from njhale/fix/sdkserver-run-start-events
fix: send only one run start/finish event from sdkserver
2 parents fa97a9e + 5c06cf1 commit 3033b05

File tree

6 files changed

+36
-31
lines changed

6 files changed

+36
-31
lines changed

pkg/engine/engine.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,13 @@ func WithToolCategory(ctx context.Context, toolCategory ToolCategory) context.Co
165165
return context.WithValue(ctx, toolCategoryKey{}, toolCategory)
166166
}
167167

168-
func NewContext(ctx context.Context, prg *types.Program, input string) (Context, error) {
168+
func ToolCategoryFromContext(ctx context.Context) ToolCategory {
169169
category, _ := ctx.Value(toolCategoryKey{}).(ToolCategory)
170+
return category
171+
}
172+
173+
func NewContext(ctx context.Context, prg *types.Program, input string) (Context, error) {
174+
category := ToolCategoryFromContext(ctx)
170175

171176
callCtx := Context{
172177
commonContext: commonContext{

pkg/monitor/display.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ type Console struct {
3838
callLock sync.Mutex
3939
}
4040

41-
var (
42-
prettyIDCounter int64
43-
)
41+
var prettyIDCounter int64
4442

4543
func (c *Console) Start(_ context.Context, prg *types.Program, _ []string, input string) (runner.Monitor, error) {
4644
id := counter.Next()
@@ -290,7 +288,7 @@ func (d *display) Event(event runner.Event) {
290288
d.dump.Calls[currentIndex] = currentCall
291289
}
292290

293-
func (d *display) Stop(output string, err error) {
291+
func (d *display) Stop(_ context.Context, output string, err error) {
294292
d.callLock.Lock()
295293
defer d.callLock.Unlock()
296294

pkg/monitor/fd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ func (f *fd) event(event Event) {
139139
}
140140
}
141141

142-
func (f *fd) Stop(output string, err error) {
142+
func (f *fd) Stop(_ context.Context, output string, err error) {
143143
e := Event{
144144
Event: runner.Event{
145145
Time: time.Now(),

pkg/runner/monitor.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ import (
66
"github.com/gptscript-ai/gptscript/pkg/types"
77
)
88

9-
type noopFactory struct {
10-
}
9+
type noopFactory struct{}
1110

1211
func (n noopFactory) Start(context.Context, *types.Program, []string, string) (Monitor, error) {
1312
return noopMonitor{}, nil
@@ -17,13 +16,12 @@ func (n noopFactory) Pause() func() {
1716
return func() {}
1817
}
1918

20-
type noopMonitor struct {
21-
}
19+
type noopMonitor struct{}
2220

2321
func (n noopMonitor) Event(Event) {
2422
}
2523

26-
func (n noopMonitor) Stop(string, error) {}
24+
func (n noopMonitor) Stop(context.Context, string, error) {}
2725

2826
func (n noopMonitor) Pause() func() {
2927
return func() {}

pkg/runner/runner.go

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ type MonitorFactory interface {
2626
type Monitor interface {
2727
Event(event Event)
2828
Pause() func()
29-
Stop(output string, err error)
29+
Stop(ctx context.Context, output string, err error)
3030
}
3131

3232
type Options struct {
@@ -162,7 +162,7 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
162162
return resp, err
163163
}
164164
defer func() {
165-
monitor.Stop(resp.Content, err)
165+
monitor.Stop(ctx, resp.Content, err)
166166
}()
167167

168168
callCtx, err := engine.NewContext(ctx, &prg, input)
@@ -425,9 +425,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
425425
}
426426
}
427427

428-
var (
429-
newState *State
430-
)
428+
var newState *State
431429
callCtx.InputContext, newState, err = r.getContext(callCtx, state, monitor, env, input)
432430
if err != nil {
433431
return nil, err
@@ -632,9 +630,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
632630
Env: env,
633631
}
634632

635-
var (
636-
contentInput string
637-
)
633+
var contentInput string
638634

639635
if state.Continuation != nil && state.Continuation.State != nil {
640636
contentInput = state.Continuation.State.Input
@@ -745,9 +741,7 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher {
745741
}
746742

747743
func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (_ *State, callResults []SubCallResult, _ error) {
748-
var (
749-
resultLock sync.Mutex
750-
)
744+
var resultLock sync.Mutex
751745

752746
if state.Continuation != nil {
753747
callCtx.LastReturn = state.Continuation

pkg/sdkserver/monitor.go

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"time"
77

88
"github.com/gptscript-ai/broadcaster"
9+
"github.com/gptscript-ai/gptscript/pkg/engine"
910
"github.com/gptscript-ai/gptscript/pkg/runner"
1011
gserver "github.com/gptscript-ai/gptscript/pkg/server"
1112
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -23,16 +24,19 @@ func NewSessionFactory(events *broadcaster.Broadcaster[event]) *SessionFactory {
2324

2425
func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []string, input string) (runner.Monitor, error) {
2526
id := gserver.RunIDFromContext(ctx)
27+
category := engine.ToolCategoryFromContext(ctx)
2628

27-
s.events.C <- event{
28-
Event: gserver.Event{
29-
Event: runner.Event{
30-
Time: time.Now(),
31-
Type: runner.EventTypeRunStart,
29+
if category == engine.NoCategory {
30+
s.events.C <- event{
31+
Event: gserver.Event{
32+
Event: runner.Event{
33+
Time: time.Now(),
34+
Type: runner.EventTypeRunStart,
35+
},
36+
RunID: id,
37+
Program: prg,
3238
},
33-
RunID: id,
34-
Program: prg,
35-
},
39+
}
3640
}
3741

3842
return &Session{
@@ -69,7 +73,13 @@ func (s *Session) Event(e runner.Event) {
6973
}
7074
}
7175

72-
func (s *Session) Stop(output string, err error) {
76+
func (s *Session) Stop(ctx context.Context, output string, err error) {
77+
category := engine.ToolCategoryFromContext(ctx)
78+
79+
if category != engine.NoCategory {
80+
return
81+
}
82+
7383
e := event{
7484
Event: gserver.Event{
7585
Event: runner.Event{

0 commit comments

Comments
 (0)