Skip to content

chore: refactor export context behavior #365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {

if prg.IsChat() || r.ForceChat {
return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) {
return prg, nil
return r.readProgram(ctx, gptScript, args)
}, os.Environ(), toolInput)
}

Expand Down
17 changes: 17 additions & 0 deletions pkg/counter/counter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package counter

import (
"fmt"
"sync/atomic"
"time"
)

var counter = int32(time.Now().Unix())

func Reset(i int32) {
atomic.StoreInt32(&counter, i)
}

func Next() string {
return fmt.Sprint(atomic.AddInt32(&counter, 1))
}
4 changes: 2 additions & 2 deletions pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@ import (
"runtime"
"sort"
"strings"
"sync/atomic"

"github.com/google/shlex"
context2 "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/env"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
)

func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, toolCategory ToolCategory) (cmdOut string, cmdErr error) {
id := fmt.Sprint(atomic.AddInt64(&completionID, 1))
id := counter.Next()

defer func() {
e.Progress <- types.CompletionStatus{
Expand Down
10 changes: 3 additions & 7 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@ import (
"fmt"
"strings"
"sync"
"sync/atomic"

"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
)

var completionID int64

type Model interface {
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
}
Expand Down Expand Up @@ -123,12 +121,10 @@ func (c *Context) MarshalJSON() ([]byte, error) {
return json.Marshal(c.GetCallContext())
}

var execID int32

func NewContext(ctx context.Context, prg *types.Program) Context {
callCtx := Context{
commonContext: commonContext{
ID: fmt.Sprint(atomic.AddInt32(&execID, 1)),
ID: counter.Next(),
Tool: prg.ToolSet[prg.EntryToolID],
},
Ctx: ctx,
Expand All @@ -144,7 +140,7 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string, toolCatego
}

if callID == "" {
callID = fmt.Sprint(atomic.AddInt32(&execID, 1))
callID = counter.Next()
}

return Context{
Expand Down
5 changes: 2 additions & 3 deletions pkg/engine/print.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package engine

import (
"fmt"
"strings"
"sync/atomic"

"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/types"
)

func (e *Engine) runEcho(tool types.Tool) (cmdOut *Return, cmdErr error) {
id := fmt.Sprint(atomic.AddInt64(&completionID, 1))
id := counter.Next()
out := strings.TrimPrefix(tool.Instructions, types.EchoPrefix+"\n")

e.Progress <- types.CompletionStatus{
Expand Down
8 changes: 8 additions & 0 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,14 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have no name"))
}

if i != 0 && tool.Parameters.GlobalModelName != "" {
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global model name"))
}

if i != 0 && len(tool.Parameters.GlobalTools) > 0 {
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global tools"))
}

if targetToolName != "" && strings.EqualFold(tool.Parameters.Name, targetToolName) {
mainTool = tool
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/monitor/display.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/fatih/color"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/runner"
"github.com/gptscript-ai/gptscript/pkg/types"
Expand Down Expand Up @@ -40,12 +41,11 @@ type Console struct {
}

var (
runID int64
prettyIDCounter int64
)

func (c *Console) Start(_ context.Context, prg *types.Program, _ []string, input string) (runner.Monitor, error) {
id := atomic.AddInt64(&runID, 1)
id := counter.Next()
mon := newDisplay(c.dumpState, c.displayProgress, c.printMessages)
mon.dump.ID = fmt.Sprint(id)
mon.dump.Program = prg
Expand Down
11 changes: 5 additions & 6 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
"slices"
"sort"
"strings"
"sync/atomic"

"github.com/getkin/kin-openapi/openapi3"
openai "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/hash"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
Expand All @@ -24,10 +24,9 @@ const (
)

var (
key = os.Getenv("OPENAI_API_KEY")
url = os.Getenv("OPENAI_URL")
azureModel = os.Getenv("OPENAI_AZURE_DEPLOYMENT")
completionID int64
key = os.Getenv("OPENAI_API_KEY")
url = os.Getenv("OPENAI_URL")
azureModel = os.Getenv("OPENAI_AZURE_DEPLOYMENT")
)

type Client struct {
Expand Down Expand Up @@ -332,7 +331,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
})
}

id := fmt.Sprint(atomic.AddInt64(&completionID, 1))
id := counter.Next()
status <- types.CompletionStatus{
CompletionID: id,
Request: request,
Expand Down
4 changes: 2 additions & 2 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
return false, err
}
tool.Parameters.Chat = v
case "export":
case "export", "exporttool", "exports", "exporttools":
tool.Parameters.Export = append(tool.Parameters.Export, csv(value)...)
case "tool", "tools":
tool.Parameters.Tools = append(tool.Parameters.Tools, csv(value)...)
case "globaltool", "globaltools":
tool.Parameters.GlobalTools = append(tool.Parameters.GlobalTools, csv(value)...)
case "exportcontext":
case "exportcontext", "exportcontexts":
tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(value)...)
case "context":
tool.Parameters.Context = append(tool.Parameters.Context, csv(value)...)
Expand Down
9 changes: 9 additions & 0 deletions pkg/tests/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,15 @@ func TestCase2(t *testing.T) {
assert.Equal(t, "TEST RESULT CALL: 1", x)
}

func TestGlobalErr(t *testing.T) {
runner := tester.NewRunner(t)
_, err := runner.Run("", "")
autogold.Expect("line testdata/TestGlobalErr/test.gpt:4: only the first tool in a file can have global model name").Equal(t, err.Error())

_, err = runner.Run("test2.gpt", "")
autogold.Expect("line testdata/TestGlobalErr/test2.gpt:4: only the first tool in a file can have global tools").Equal(t, err.Error())
}

func TestContextArg(t *testing.T) {
runner := tester.NewRunner(t)
x, err := runner.Run("", `{
Expand Down
4 changes: 2 additions & 2 deletions pkg/tests/testdata/TestExportContext/call1.golden
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"Tools": [
{
"function": {
"toolID": "testdata/TestExportContext/test.gpt:21",
"toolID": "testdata/TestExportContext/test.gpt:22",
"name": "subtool",
"parameters": {
"properties": {
Expand All @@ -22,7 +22,7 @@
},
{
"function": {
"toolID": "testdata/TestExportContext/test.gpt:14",
"toolID": "testdata/TestExportContext/test.gpt:15",
"name": "sampletool",
"description": "sample",
"parameters": {
Expand Down
2 changes: 1 addition & 1 deletion pkg/tests/testdata/TestExportContext/test.gpt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This is from tool
---
name: fromcontext
export: sampletool
export context: fromexportcontext

#!/bin/bash
echo this is from context
Expand All @@ -19,7 +20,6 @@ Dummy body

---
name: subtool
export context: fromexportcontext

Dummy body

Expand Down
7 changes: 7 additions & 0 deletions pkg/tests/testdata/TestGlobalErr/test.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
first

---
name: second
global model name: foo

second
7 changes: 7 additions & 0 deletions pkg/tests/testdata/TestGlobalErr/test2.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
first

---
name: second
global tools: asdf

second
45 changes: 45 additions & 0 deletions pkg/types/set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package types

type toolRefKey struct {
name string
toolID string
arg string
}

type toolRefSet struct {
set map[toolRefKey]ToolReference
order []toolRefKey
err error
}

func (t *toolRefSet) List() (result []ToolReference, err error) {
for _, k := range t.order {
result = append(result, t.set[k])
}
return result, t.err
}

func (t *toolRefSet) AddAll(values []ToolReference, err error) {
if t.err != nil {
t.err = err
}
for _, v := range values {
t.Add(v)
}
}

func (t *toolRefSet) Add(value ToolReference) {
key := toolRefKey{
name: value.Named,
toolID: value.ToolID,
arg: value.Arg,
}

if _, ok := t.set[key]; !ok {
if t.set == nil {
t.set = map[toolRefKey]ToolReference{}
}
t.set[key] = value
t.order = append(t.order, key)
}
}
Loading