Skip to content

Commit d70f919

Browse files
chore: add ability to pass args to input/output filters
1 parent 6ec5178 commit d70f919

File tree

5 files changed

+149
-11
lines changed

5 files changed

+149
-11
lines changed

pkg/runner/input.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@ func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []stri
1818
data := map[string]any{}
1919
_ = json.Unmarshal([]byte(input), &data)
2020
data["input"] = input
21-
inputData, err := json.Marshal(data)
21+
22+
inputArgs, err := argsForFilters(callCtx.Program, inputToolRef, &State{
23+
StartInput: &input,
24+
}, data)
2225
if err != nil {
2326
return "", fmt.Errorf("failed to marshal input: %w", err)
2427
}
2528

26-
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, string(inputData), "", engine.InputToolCategory)
29+
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, inputArgs, "", engine.InputToolCategory)
2730
if err != nil {
2831
return "", err
2932
}

pkg/runner/output.go

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,48 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7+
"maps"
8+
"strings"
79

810
"github.com/gptscript-ai/gptscript/pkg/engine"
911
"github.com/gptscript-ai/gptscript/pkg/types"
1012
)
1113

12-
func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, state *State, retErr error) (*State, error) {
14+
func argsForFilters(prg *types.Program, tool types.ToolReference, startState *State, filterDefinedInput map[string]any) (string, error) {
15+
startInput := ""
16+
if startState.ResumeInput != nil {
17+
startInput = *startState.ResumeInput
18+
} else if startState.StartInput != nil {
19+
startInput = *startState.StartInput
20+
}
21+
22+
parsedArgs, err := getToolRefInput(prg, tool, startInput)
23+
if err != nil {
24+
return "", err
25+
}
26+
27+
argData := map[string]any{}
28+
if strings.HasPrefix(parsedArgs, "{") {
29+
if err := json.Unmarshal([]byte(parsedArgs), &argData); err != nil {
30+
return "", fmt.Errorf("failed to unmarshal parsedArgs for filter: %w", err)
31+
}
32+
} else if _, hasInput := filterDefinedInput["input"]; parsedArgs != "" && !hasInput {
33+
argData["input"] = parsedArgs
34+
}
35+
36+
resultData := map[string]any{}
37+
maps.Copy(resultData, filterDefinedInput)
38+
maps.Copy(resultData, argData)
39+
40+
result, err := json.Marshal(resultData)
41+
if err != nil {
42+
return "", fmt.Errorf("failed to marshal resultData for filter: %w", err)
43+
}
44+
45+
return string(result), nil
46+
}
47+
48+
func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, startState, state *State, retErr error) (*State, error) {
1349
outputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeOutput)
1450
if err != nil {
1551
return nil, err
@@ -40,7 +76,7 @@ func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []str
4076
}
4177

4278
for _, outputToolRef := range outputToolRefs {
43-
inputData, err := json.Marshal(map[string]any{
79+
inputData, err := argsForFilters(callCtx.Program, outputToolRef, startState, map[string]any{
4480
"output": output,
4581
"continuation": continuation,
4682
"chat": callCtx.Tool.Chat,

pkg/runner/runner.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,9 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
269269
outputMap := map[string]interface{}{}
270270

271271
_ = json.Unmarshal([]byte(input), &inputMap)
272+
for k, v := range inputMap {
273+
inputMap[strings.ToLower(k)] = v
274+
}
272275

273276
fields := strings.Fields(ref.Arg)
274277

@@ -291,7 +294,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
291294
key := strings.TrimPrefix(field, "$")
292295
key = strings.TrimPrefix(key, "{")
293296
key = strings.TrimSuffix(key, "}")
294-
val = inputMap[key]
297+
val = inputMap[strings.ToLower(key)]
295298
} else {
296299
val = field
297300
}
@@ -425,6 +428,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
425428
msg = "Tool call request has been denied"
426429
}
427430
return &State{
431+
StartInput: &input,
428432
Continuation: &engine.Return{
429433
Result: &msg,
430434
},
@@ -438,6 +442,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
438442
}
439443

440444
return &State{
445+
StartInput: &input,
441446
Continuation: ret,
442447
}, nil
443448
}
@@ -447,6 +452,8 @@ type State struct {
447452
ContinuationToolID string `json:"continuationToolID,omitempty"`
448453
Result *string `json:"result,omitempty"`
449454

455+
StartInput *string `json:"startInput,omitempty"`
456+
450457
ResumeInput *string `json:"resumeInput,omitempty"`
451458
SubCalls []SubCallResult `json:"subCalls,omitempty"`
452459
SubCallID string `json:"subCallID,omitempty"`
@@ -485,14 +492,9 @@ func (s State) ContinuationContent() (string, error) {
485492
return "", fmt.Errorf("illegal state: no result message found in chat response")
486493
}
487494

488-
type Needed struct {
489-
Content string `json:"content,omitempty"`
490-
Input string `json:"input,omitempty"`
491-
}
492-
493495
func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (retState *State, retErr error) {
494496
defer func() {
495-
retState, retErr = r.handleOutput(callCtx, monitor, env, retState, retErr)
497+
retState, retErr = r.handleOutput(callCtx, monitor, env, state, retState, retErr)
496498
}()
497499

498500
if state.Continuation == nil {

pkg/tests/runner2_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package tests
22

33
import (
44
"context"
5+
"encoding/json"
56
"testing"
67

78
"github.com/gptscript-ai/gptscript/pkg/loader"
89
"github.com/gptscript-ai/gptscript/pkg/tests/tester"
10+
"github.com/hexops/autogold/v2"
911
"github.com/stretchr/testify/require"
1012
)
1113

@@ -111,3 +113,92 @@ echo '{"env": {"CRED2": "that also worked"}}'
111113
resp, err := r.Chat(context.Background(), nil, prg, nil, "")
112114
r.AssertStep(t, resp, err)
113115
}
116+
117+
func TestFilterArgs(t *testing.T) {
118+
r := tester.NewRunner(t)
119+
prg, err := loader.ProgramFromSource(context.Background(), `
120+
inputfilters: input with ${Foo}
121+
inputfilters: input with foo
122+
inputfilters: input with *
123+
outputfilters: output with *
124+
outputfilters: output with foo
125+
outputfilters: output with ${Foo}
126+
params: Foo: a description
127+
128+
#!/bin/bash
129+
echo ${FOO}
130+
131+
---
132+
name: input
133+
params: notfoo: a description
134+
135+
#!/bin/bash
136+
echo "${GPTSCRIPT_INPUT}"
137+
138+
---
139+
name: output
140+
params: notfoo: a description
141+
142+
#!/bin/bash
143+
echo "${GPTSCRIPT_INPUT}"
144+
`, "")
145+
require.NoError(t, err)
146+
147+
resp, err := r.Chat(context.Background(), nil, prg, nil, `{"foo":"baz", "start": true}`)
148+
r.AssertStep(t, resp, err)
149+
150+
data := map[string]any{}
151+
err = json.Unmarshal([]byte(resp.Content), &data)
152+
require.NoError(t, err)
153+
154+
autogold.Expect(map[string]interface{}{
155+
"chat": false,
156+
"continuation": false,
157+
"notfoo": "baz",
158+
"output": `{"chat":false,"continuation":false,"notfoo":"foo","output":"{\"chat\":false,\"continuation\":false,\"foo\":\"baz\",\"input\":\"{\\\"foo\\\":\\\"baz\\\",\\\"input\\\":\\\"{\\\\\\\"foo\\\\\\\":\\\\\\\"baz\\\\\\\", \\\\\\\"start\\\\\\\": true}\\\",\\\"notfoo\\\":\\\"baz\\\",\\\"start\\\":true}\\n\",\"notfoo\":\"foo\",\"output\":\"baz\\n\",\"start\":true}\n"}
159+
`,
160+
}).Equal(t, data)
161+
162+
val := data["output"].(string)
163+
data = map[string]any{}
164+
err = json.Unmarshal([]byte(val), &data)
165+
require.NoError(t, err)
166+
autogold.Expect(map[string]interface{}{
167+
"chat": false,
168+
"continuation": false,
169+
"notfoo": "foo",
170+
"output": `{"chat":false,"continuation":false,"foo":"baz","input":"{\"foo\":\"baz\",\"input\":\"{\\\"foo\\\":\\\"baz\\\", \\\"start\\\": true}\",\"notfoo\":\"baz\",\"start\":true}\n","notfoo":"foo","output":"baz\n","start":true}
171+
`,
172+
}).Equal(t, data)
173+
174+
val = data["output"].(string)
175+
data = map[string]any{}
176+
err = json.Unmarshal([]byte(val), &data)
177+
require.NoError(t, err)
178+
autogold.Expect(map[string]interface{}{
179+
"chat": false,
180+
"continuation": false,
181+
"foo": "baz", "input": `{"foo":"baz","input":"{\"foo\":\"baz\", \"start\": true}","notfoo":"baz","start":true}
182+
`,
183+
"notfoo": "foo",
184+
"output": "baz\n",
185+
"start": true,
186+
}).Equal(t, data)
187+
188+
val = data["input"].(string)
189+
data = map[string]any{}
190+
err = json.Unmarshal([]byte(val), &data)
191+
require.NoError(t, err)
192+
autogold.Expect(map[string]interface{}{
193+
"foo": "baz",
194+
"input": `{"foo":"baz", "start": true}`,
195+
"notfoo": "baz",
196+
"start": true,
197+
}).Equal(t, data)
198+
199+
val = data["input"].(string)
200+
data = map[string]any{}
201+
err = json.Unmarshal([]byte(val), &data)
202+
require.NoError(t, err)
203+
autogold.Expect(map[string]interface{}{"foo": "baz", "start": true}).Equal(t, data)
204+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
`{
2+
"done": true,
3+
"content": "{\"chat\":false,\"continuation\":false,\"notfoo\":\"baz\",\"output\":\"{\\\"chat\\\":false,\\\"continuation\\\":false,\\\"notfoo\\\":\\\"foo\\\",\\\"output\\\":\\\"{\\\\\\\"chat\\\\\\\":false,\\\\\\\"continuation\\\\\\\":false,\\\\\\\"foo\\\\\\\":\\\\\\\"baz\\\\\\\",\\\\\\\"input\\\\\\\":\\\\\\\"{\\\\\\\\\\\\\\\"foo\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"input\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"{\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"foo\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\", \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"start\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\": true}\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"notfoo\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"start\\\\\\\\\\\\\\\":true}\\\\\\\\n\\\\\\\",\\\\\\\"notfoo\\\\\\\":\\\\\\\"foo\\\\\\\",\\\\\\\"output\\\\\\\":\\\\\\\"baz\\\\\\\\n\\\\\\\",\\\\\\\"start\\\\\\\":true}\\\\n\\\"}\\n\"}\n",
4+
"toolID": "",
5+
"state": null
6+
}`

0 commit comments

Comments
 (0)