Skip to content

bug: speed up the openapi loading #340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ require (
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc
golang.org/x/sync v0.7.0
golang.org/x/term v0.19.0
gopkg.in/yaml.v3 v3.0.1
)

require (
Expand Down Expand Up @@ -81,6 +80,7 @@ require (
golang.org/x/sys v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.20.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.5.1 // indirect
mvdan.cc/gofumpt v0.6.0 // indirect
)
2 changes: 1 addition & 1 deletion pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {

if prg.IsChat() || r.ForceChat {
return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) {
return r.readProgram(ctx, gptScript, args)
return prg, nil
}, os.Environ(), toolInput)
}

Expand Down
11 changes: 9 additions & 2 deletions pkg/hash/sha256.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@ func ID(parts ...string) string {

func Digest(obj any) string {
hash := sha256.New()
if err := gob.NewEncoder(hash).Encode(obj); err != nil {
panic(err)
switch v := obj.(type) {
case []byte:
hash.Write(v)
case string:
hash.Write([]byte(v))
default:
if err := gob.NewEncoder(hash).Encode(obj); err != nil {
panic(err)
}
}
return hex.EncodeToString(hash.Sum(nil))
}
12 changes: 12 additions & 0 deletions pkg/loader/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import (
"net/http"
"os"
"path/filepath"
"regexp"
"strings"

"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/loader"
"github.com/gptscript-ai/gptscript/pkg/mvl"
"github.com/gptscript-ai/gptscript/pkg/repos/git"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
Expand All @@ -26,6 +28,7 @@ const (

var (
githubAuthToken = os.Getenv("GITHUB_AUTH_TOKEN")
log = mvl.Package()
)

func init() {
Expand All @@ -37,7 +40,14 @@ func getCommitLsRemote(ctx context.Context, account, repo, ref string) (string,
return git.LsRemote(ctx, url, ref)
}

// regexp to match a git commit id
var commitRegexp = regexp.MustCompile("^[a-f0-9]{40}$")

func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
if commitRegexp.MatchString(ref) {
return ref, nil
}

url := fmt.Sprintf(githubCommitURL, account, repo, ref)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
Expand Down Expand Up @@ -69,6 +79,8 @@ func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
return "", fmt.Errorf("failed to decode GitHub commit of %s/%s at %s: %w", account, repo, url, err)
}

log.Debugf("loaded github commit of %s/%s at %s as %q", account, repo, url, commit.SHA)

if commit.SHA == "" {
return "", fmt.Errorf("failed to find commit in response of %s, got empty string", url)
}
Expand Down
86 changes: 65 additions & 21 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
"github.com/gptscript-ai/gptscript/pkg/assemble"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/hash"
"github.com/gptscript-ai/gptscript/pkg/parser"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
"gopkg.in/yaml.v3"
)

const CacheTimeout = time.Hour
Expand Down Expand Up @@ -120,24 +120,50 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types
return tool, nil
}

func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T {
var (
openAPICacheKey = hash.Digest(data)
openAPIDocument, ok = prg.OpenAPICache[openAPICacheKey].(*openapi3.T)
err error
)

if ok {
return openAPIDocument
}

if prg.OpenAPICache == nil {
prg.OpenAPICache = map[string]any{}
}

openAPIDocument, err = openapi3.NewLoader().LoadFromData(data)
if err != nil || openAPIDocument.Paths.Len() == 0 {
openAPIDocument = nil
}

prg.OpenAPICache[openAPICacheKey] = openAPIDocument
return openAPIDocument
}

func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) (types.Tool, error) {
data := base.Content

if bytes.HasPrefix(data, assemble.Header) {
return loadProgram(data, prg, targetToolName)
}

var tools []types.Tool
if isOpenAPI(data) {
if t, err := openapi3.NewLoader().LoadFromData(data); err == nil {
if base.Remote {
tools, err = getOpenAPITools(t, base.Location)
} else {
tools, err = getOpenAPITools(t, "")
}
if err != nil {
return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err)
}
var (
tools []types.Tool
)

if openAPIDocument := loadOpenAPI(prg, data); openAPIDocument != nil {
var err error
if base.Remote {
tools, err = getOpenAPITools(openAPIDocument, base.Location)
} else {
tools, err = getOpenAPITools(openAPIDocument, "")
}
if err != nil {
return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err)
}
}

Expand Down Expand Up @@ -263,6 +289,12 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
}

func ProgramFromSource(ctx context.Context, content, subToolName string, opts ...Options) (types.Program, error) {
if log.IsDebug() {
start := time.Now()
defer func() {
log.Debugf("loaded program from source took %v", time.Since(start))
}()
}
opt := complete(opts...)

prg := types.Program{
Expand Down Expand Up @@ -292,6 +324,13 @@ func complete(opts ...Options) (result Options) {
}

func Program(ctx context.Context, name, subToolName string, opts ...Options) (types.Program, error) {
if log.IsDebug() {
start := time.Now()
defer func() {
log.Debugf("loaded program %s source took %v", name, time.Since(start))
}()
}

opt := complete(opts...)

if subToolName == "" {
Expand Down Expand Up @@ -346,15 +385,20 @@ func input(ctx context.Context, cache *cache.Client, base *source, name string)
return nil, fmt.Errorf("can not load tools path=%s name=%s", base.Path, name)
}

func isOpenAPI(data []byte) bool {
var fragment struct {
Paths map[string]any `json:"paths,omitempty"`
}
func SplitToolRef(targetToolName string) (toolName, subTool string) {
var (
fields = strings.Fields(targetToolName)
idx = slices.Index(fields, "from")
)

if err := json.Unmarshal(data, &fragment); err != nil {
if err := yaml.Unmarshal(data, &fragment); err != nil {
return false
}
defer func() {
toolName, _ = types.SplitArg(toolName)
}()

if idx == -1 {
return strings.TrimSpace(targetToolName), ""
}
return len(fragment.Paths) > 0

return strings.Join(fields[idx+1:], " "),
strings.Join(fields[:idx], " ")
}
7 changes: 7 additions & 0 deletions pkg/loader/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"slices"
"sort"
"strings"
"time"

"github.com/getkin/kin-openapi/openapi3"
"github.com/gptscript-ai/gptscript/pkg/engine"
Expand All @@ -18,6 +19,12 @@ import (
// The tool's Instructions will be in the format "#!sys.openapi '{JSON Instructions}'",
// where the JSON Instructions are a JSON-serialized engine.OpenAPIInstructions struct.
func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) {
if log.IsDebug() {
start := time.Now()
defer func() {
log.Debugf("loaded openapi tools in %v", time.Since(start))
}()
}
// Determine the default server.
if len(t.Servers) == 0 {
if defaultHost != "" {
Expand Down
3 changes: 3 additions & 0 deletions pkg/repos/git/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
)

func newGitCommand(ctx context.Context, args ...string) *debugcmd.WrappedCmd {
if log.IsDebug() {
log.Debugf("running git command: %s", strings.Join(args, " "))
}
cmd := debugcmd.New(ctx, "git", args...)
return cmd
}
Expand Down
7 changes: 4 additions & 3 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ func (e *ErrToolNotFound) Error() string {
type ToolSet map[string]Tool

type Program struct {
Name string `json:"name,omitempty"`
EntryToolID string `json:"entryToolId,omitempty"`
ToolSet ToolSet `json:"toolSet,omitempty"`
Name string `json:"name,omitempty"`
EntryToolID string `json:"entryToolId,omitempty"`
ToolSet ToolSet `json:"toolSet,omitempty"`
OpenAPICache map[string]any `json:"-"`
}

func (p Program) IsChat() bool {
Expand Down