Skip to content

Commit f0ae38a

Browse files
Merge pull request #341 from cloudnautique/use_tool_from_name
fix: tool names should be used instead of filename
2 parents 3b5d2e3 + a3219cb commit f0ae38a

File tree

5 files changed

+42
-31
lines changed

5 files changed

+42
-31
lines changed

pkg/loader/loader.go

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
242242
tool.ToolMapping[targetToolName] = linkedTool.ID
243243
toolNames[targetToolName] = struct{}{}
244244
} else {
245-
toolName, subTool := SplitToolRef(targetToolName)
245+
toolName, subTool := types.SplitToolRef(targetToolName)
246246
resolvedTool, err := resolve(ctx, cache, prg, base, toolName, subTool)
247247
if err != nil {
248248
return types.Tool{}, fmt.Errorf("failed resolving %s at %s: %w", targetToolName, base, err)
@@ -295,7 +295,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty
295295
opt := complete(opts...)
296296

297297
if subToolName == "" {
298-
name, subToolName = SplitToolRef(name)
298+
name, subToolName = types.SplitToolRef(name)
299299
}
300300
prg := types.Program{
301301
Name: name,
@@ -346,24 +346,6 @@ func input(ctx context.Context, cache *cache.Client, base *source, name string)
346346
return nil, fmt.Errorf("can not load tools path=%s name=%s", base.Path, name)
347347
}
348348

349-
func SplitToolRef(targetToolName string) (toolName, subTool string) {
350-
var (
351-
fields = strings.Fields(targetToolName)
352-
idx = slices.Index(fields, "from")
353-
)
354-
355-
defer func() {
356-
toolName, _ = types.SplitArg(toolName)
357-
}()
358-
359-
if idx == -1 {
360-
return strings.TrimSpace(targetToolName), ""
361-
}
362-
363-
return strings.Join(fields[idx+1:], " "),
364-
strings.Join(fields[:idx], " ")
365-
}
366-
367349
func isOpenAPI(data []byte) bool {
368350
var fragment struct {
369351
Paths map[string]any `json:"paths,omitempty"`

pkg/loader/loader_test.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,3 @@ func TestHelloWorld(t *testing.T) {
9999
}
100100
}`, "MODEL", openai.DefaultModel)).Equal(t, toString(prg))
101101
}
102-
103-
func TestParse(t *testing.T) {
104-
tool, subTool := SplitToolRef("a from b with x")
105-
autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool})
106-
107-
tool, subTool = SplitToolRef("a with x")
108-
autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool})
109-
}

pkg/remote/remote.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
4646
return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model)
4747
}
4848

49-
_, modelName := loader.SplitToolRef(messageRequest.Model)
49+
_, modelName := types.SplitToolRef(messageRequest.Model)
5050
messageRequest.Model = modelName
5151
return client.Call(ctx, messageRequest, status)
5252
}
@@ -71,7 +71,7 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
7171
}
7272

7373
func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
74-
toolName, modelNameSuffix := loader.SplitToolRef(modelName)
74+
toolName, modelNameSuffix := types.SplitToolRef(modelName)
7575
if modelNameSuffix == "" {
7676
return false, nil
7777
}

pkg/types/toolname.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package types
33
import (
44
"path/filepath"
55
"regexp"
6+
"slices"
67
"strings"
78

89
"github.com/gptscript-ai/gptscript/pkg/system"
@@ -14,7 +15,13 @@ var (
1415
)
1516

1617
func ToolNormalizer(tool string) string {
17-
parts := strings.Split(tool, "/")
18+
_, subTool := SplitToolRef(tool)
19+
lastTool := tool
20+
if subTool != "" {
21+
lastTool = subTool
22+
}
23+
24+
parts := strings.Split(lastTool, "/")
1825
tool = parts[len(parts)-1]
1926
if strings.HasSuffix(tool, system.Suffix) {
2027
tool = strings.TrimSuffix(tool, filepath.Ext(tool))
@@ -43,6 +50,24 @@ func ToolNormalizer(tool string) string {
4350
return strings.Join(result, "")
4451
}
4552

53+
func SplitToolRef(targetToolName string) (toolName, subTool string) {
54+
var (
55+
fields = strings.Fields(targetToolName)
56+
idx = slices.Index(fields, "from")
57+
)
58+
59+
defer func() {
60+
toolName, _ = SplitArg(toolName)
61+
}()
62+
63+
if idx == -1 {
64+
return strings.TrimSpace(targetToolName), ""
65+
}
66+
67+
return strings.Join(fields[idx+1:], " "),
68+
strings.Join(fields[:idx], " ")
69+
}
70+
4671
func PickToolName(toolName string, existing map[string]struct{}) string {
4772
if toolName == "" {
4873
toolName = "external"

pkg/types/toolname_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,16 @@ func TestToolNormalizer(t *testing.T) {
1010
autogold.Expect("bobTool").Equal(t, ToolNormalizer("bob-tool"))
1111
autogold.Expect("bobTool").Equal(t, ToolNormalizer("bob_tool"))
1212
autogold.Expect("bobTool").Equal(t, ToolNormalizer("BOB tOOL"))
13+
autogold.Expect("barList").Equal(t, ToolNormalizer("bar_list from ./foo.yaml"))
14+
autogold.Expect("barList").Equal(t, ToolNormalizer("bar_list from ./foo.gpt"))
15+
autogold.Expect("write").Equal(t, ToolNormalizer("sys.write"))
16+
autogold.Expect("gpt4VVision").Equal(t, ToolNormalizer("github.com/gptscript-ai/gpt4-v-vision"))
17+
}
18+
19+
func TestParse(t *testing.T) {
20+
tool, subTool := SplitToolRef("a from b with x")
21+
autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool})
22+
23+
tool, subTool = SplitToolRef("a with x")
24+
autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool})
1325
}

0 commit comments

Comments
 (0)