Skip to content

feat: add support for wildcard subtool names #389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,13 @@ func (c *Context) ParentID() string {
func (c *Context) GetCallContext() *CallContext {
var toolName string
if c.Parent != nil {
for name, id := range c.Parent.Tool.ToolMapping {
if id == c.Tool.ID {
toolName = name
break
outer:
for name, refs := range c.Parent.Tool.ToolMapping {
for _, ref := range refs {
if ref.ToolID == c.Tool.ID {
toolName = name
break outer
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/engine/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too

if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) {
referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix)
referencedToolID, ok := tool.ToolMapping[referencedToolName]
if !ok {
referencedToolRefs, ok := tool.ToolMapping[referencedToolName]
if !ok || len(referencedToolRefs) != 1 {
return nil, fmt.Errorf("invalid reference [%s] to tool [%s] from [%s], missing \"tools: %s\" parameter", toolURL, referencedToolName, tool.Source, referencedToolName)
}
referencedTool, ok := prg.ToolSet[referencedToolID]
referencedTool, ok := prg.ToolSet[referencedToolRefs[0].ToolID]
if !ok {
return nil, fmt.Errorf("failed to find tool [%s] for [%s]", referencedToolName, parsed.Hostname())
}
Expand Down
82 changes: 54 additions & 28 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,15 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T {
return openAPIDocument
}

func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) (types.Tool, error) {
func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) ([]types.Tool, error) {
data := base.Content

if bytes.HasPrefix(data, assemble.Header) {
return loadProgram(data, prg, targetToolName)
tool, err := loadProgram(data, prg, targetToolName)
if err != nil {
return nil, err
}
return []types.Tool{tool}, nil
}

var (
Expand All @@ -200,7 +204,7 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
tools, err = getOpenAPITools(openAPIDocument, "")
}
if err != nil {
return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err)
return nil, fmt.Errorf("error parsing OpenAPI definition: %w", err)
}
}

Expand All @@ -222,17 +226,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
AssignGlobals: true,
})
if err != nil {
return types.Tool{}, err
return nil, err
}
}

if len(tools) == 0 {
return types.Tool{}, fmt.Errorf("no tools found in %s", base)
return nil, fmt.Errorf("no tools found in %s", base)
}

var (
localTools = types.ToolSet{}
mainTool types.Tool
localTools = types.ToolSet{}
targetTools []types.Tool
)

for i, tool := range tools {
Expand All @@ -243,44 +247,65 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
// Probably a better way to come up with an ID
tool.ID = tool.Source.Location + ":" + tool.Name

if i == 0 {
mainTool = tool
if i == 0 && targetToolName == "" {
targetTools = append(targetTools, tool)
}

if i != 0 && tool.Parameters.Name == "" {
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have no name"))
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have no name"))
}

if i != 0 && tool.Parameters.GlobalModelName != "" {
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global model name"))
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global model name"))
}

if i != 0 && len(tool.Parameters.GlobalTools) > 0 {
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global tools"))
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global tools"))
}

if targetToolName != "" && strings.EqualFold(tool.Parameters.Name, targetToolName) {
mainTool = tool
if targetToolName != "" && tool.Parameters.Name != "" {
if strings.EqualFold(tool.Parameters.Name, targetToolName) {
targetTools = append(targetTools, tool)
} else if strings.Contains(targetToolName, "*") {
match, err := filepath.Match(strings.ToLower(targetToolName), strings.ToLower(tool.Parameters.Name))
if err != nil {
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, err)
}
if match {
targetTools = append(targetTools, tool)
}
}
}

if existing, ok := localTools[strings.ToLower(tool.Parameters.Name)]; ok {
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo,
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo,
fmt.Errorf("duplicate tool name [%s] in %s found at lines %d and %d", tool.Parameters.Name, tool.Source.Location,
tool.Source.LineNo, existing.Source.LineNo))
}

localTools[strings.ToLower(tool.Parameters.Name)] = tool
}

return link(ctx, cache, prg, base, mainTool, localTools)
return linkAll(ctx, cache, prg, base, targetTools, localTools)
}

func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet) (result []types.Tool, _ error) {
for _, tool := range tools {
tool, err := link(ctx, cache, prg, base, tool, localTools)
if err != nil {
return nil, err
}
result = append(result, tool)
}
return
}

func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet) (types.Tool, error) {
if existing, ok := prg.ToolSet[tool.ID]; ok {
return existing, nil
}

tool.ToolMapping = map[string]string{}
tool.ToolMapping = map[string][]types.ToolReference{}
tool.LocalTools = map[string]string{}
toolNames := map[string]struct{}{}

Expand Down Expand Up @@ -310,16 +335,17 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
}
}

tool.ToolMapping[targetToolName] = linkedTool.ID
tool.AddToolMapping(targetToolName, linkedTool)
toolNames[targetToolName] = struct{}{}
} else {
toolName, subTool := types.SplitToolRef(targetToolName)
resolvedTool, err := resolve(ctx, cache, prg, base, toolName, subTool)
resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool)
if err != nil {
return types.Tool{}, fmt.Errorf("failed resolving %s from %s: %w", targetToolName, base, err)
}

tool.ToolMapping[targetToolName] = resolvedTool.ID
for _, resolvedTool := range resolvedTools {
tool.AddToolMapping(targetToolName, resolvedTool)
}
}
}

Expand All @@ -345,14 +371,14 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
prg := types.Program{
ToolSet: types.ToolSet{},
}
tool, err := readTool(ctx, opt.Cache, &prg, &source{
tools, err := readTool(ctx, opt.Cache, &prg, &source{
Content: []byte(content),
Location: "inline",
}, subToolName)
if err != nil {
return types.Program{}, err
}
prg.EntryToolID = tool.ID
prg.EntryToolID = tools[0].ID
return prg, nil
}

Expand Down Expand Up @@ -385,26 +411,26 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty
Name: name,
ToolSet: types.ToolSet{},
}
tool, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName)
tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName)
if err != nil {
return types.Program{}, err
}
prg.EntryToolID = tool.ID
prg.EntryToolID = tools[0].ID
return prg, nil
}

func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) (types.Tool, error) {
func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) ([]types.Tool, error) {
if subTool == "" {
t, ok := builtin.Builtin(name)
if ok {
prg.ToolSet[t.ID] = t
return t, nil
return []types.Tool{t}, nil
}
}

s, err := input(ctx, cache, base, name)
if err != nil {
return types.Tool{}, err
return nil, err
}

return readTool(ctx, cache, prg, s, subTool)
Expand Down
7 changes: 6 additions & 1 deletion pkg/loader/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,12 @@ func TestHelloWorld(t *testing.T) {
"instructions": "call bob",
"id": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/sub/tool.gpt:",
"toolMapping": {
"../bob.gpt": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/bob.gpt:"
"../bob.gpt": [
{
"reference": "../bob.gpt",
"toolID": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/bob.gpt:"
}
]
},
"localTools": {
"": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/sub/tool.gpt:"
Expand Down
8 changes: 4 additions & 4 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -834,12 +834,12 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
// and save it in the store.
if !exists {
credToolID, ok := callCtx.Tool.ToolMapping[credToolName]
if !ok {
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
if !ok || len(credToolRefs) != 1 {
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
}

subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
if err != nil {
return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err)
}
Expand Down Expand Up @@ -874,7 +874,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
}

// Only store the credential if the tool is on GitHub, and the credential is non-empty.
if isGitHubTool(credToolName) && callCtx.Program.ToolSet[credToolID].Source.Repo != nil {
if isGitHubTool(credToolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", credToolName)
} else if err := store.Add(*cred); err != nil {
Expand Down
85 changes: 85 additions & 0 deletions pkg/tests/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,91 @@ func toJSONString(t *testing.T, v interface{}) string {
return string(x)
}

func TestAsterick(t *testing.T) {
r := tester.NewRunner(t)
p, err := r.Load("")
require.NoError(t, err)
autogold.Expect(`{
"name": "testdata/TestAsterick/test.gpt",
"entryToolId": "testdata/TestAsterick/test.gpt:",
"toolSet": {
"testdata/TestAsterick/other.gpt:a": {
"name": "a",
"modelName": "gpt-4o",
"internalPrompt": null,
"instructions": "a",
"id": "testdata/TestAsterick/other.gpt:a",
"localTools": {
"a": "testdata/TestAsterick/other.gpt:a",
"afoo": "testdata/TestAsterick/other.gpt:afoo",
"foo": "testdata/TestAsterick/other.gpt:foo",
"fooa": "testdata/TestAsterick/other.gpt:fooa",
"fooafoo": "testdata/TestAsterick/other.gpt:fooafoo"
},
"source": {
"location": "testdata/TestAsterick/other.gpt",
"lineNo": 10
},
"workingDir": "testdata/TestAsterick"
},
"testdata/TestAsterick/other.gpt:afoo": {
"name": "afoo",
"modelName": "gpt-4o",
"internalPrompt": null,
"instructions": "afoo",
"id": "testdata/TestAsterick/other.gpt:afoo",
"localTools": {
"a": "testdata/TestAsterick/other.gpt:a",
"afoo": "testdata/TestAsterick/other.gpt:afoo",
"foo": "testdata/TestAsterick/other.gpt:foo",
"fooa": "testdata/TestAsterick/other.gpt:fooa",
"fooafoo": "testdata/TestAsterick/other.gpt:fooafoo"
},
"source": {
"location": "testdata/TestAsterick/other.gpt",
"lineNo": 4
},
"workingDir": "testdata/TestAsterick"
},
"testdata/TestAsterick/test.gpt:": {
"modelName": "gpt-4o",
"internalPrompt": null,
"tools": [
"a* from ./other.gpt"
],
"instructions": "Ask Bob how he is doing and let me know exactly what he said.",
"id": "testdata/TestAsterick/test.gpt:",
"toolMapping": {
"a* from ./other.gpt": [
{
"reference": "afoo from ./other.gpt",
"toolID": "testdata/TestAsterick/other.gpt:afoo"
},
{
"reference": "a from ./other.gpt",
"toolID": "testdata/TestAsterick/other.gpt:a"
}
]
},
"localTools": {
"": "testdata/TestAsterick/test.gpt:"
},
"source": {
"location": "testdata/TestAsterick/test.gpt",
"lineNo": 1
},
"workingDir": "testdata/TestAsterick"
}
}
}`).Equal(t, toJSONString(t, p))

r.RespondWith(tester.Result{
Text: "hi",
})
_, err = r.Run("", "")
require.NoError(t, err)
}

func TestDualSubChat(t *testing.T) {
r := tester.NewRunner(t)
r.RespondWith(tester.Result{
Expand Down
Loading