Skip to content

Commit fdc91cb

Browse files
chore: move prompt to always http based
1 parent 58afa9f commit fdc91cb

File tree

12 files changed

+229
-87
lines changed

12 files changed

+229
-87
lines changed

pkg/builtin/builtin.go

Lines changed: 2 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package builtin
22

33
import (
4-
"bytes"
54
"context"
65
"encoding/json"
76
"errors"
@@ -18,9 +17,9 @@ import (
1817
"strings"
1918
"time"
2019

21-
"github.com/AlecAivazis/survey/v2"
2220
"github.com/BurntSushi/locker"
2321
"github.com/gptscript-ai/gptscript/pkg/engine"
22+
"github.com/gptscript-ai/gptscript/pkg/prompt"
2423
"github.com/gptscript-ai/gptscript/pkg/types"
2524
"github.com/jaytaylor/html2text"
2625
)
@@ -216,7 +215,7 @@ var tools = map[string]types.Tool{
216215
"sensitive", "(true or false) Whether the input should be hidden",
217216
),
218217
},
219-
BuiltinFunc: SysPrompt,
218+
BuiltinFunc: prompt.SysPrompt,
220219
},
221220
},
222221
"sys.chat.history": {
@@ -772,79 +771,6 @@ func SysDownload(_ context.Context, env []string, input string) (_ string, err e
772771
return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil
773772
}
774773

775-
func sysPromptHTTP(ctx context.Context, url string, prompt types.Prompt) (_ string, err error) {
776-
data, err := json.Marshal(prompt)
777-
if err != nil {
778-
return "", err
779-
}
780-
781-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
782-
if err != nil {
783-
return "", err
784-
}
785-
req.Header.Set("Content-Type", "application/json")
786-
787-
resp, err := http.DefaultClient.Do(req)
788-
if err != nil {
789-
return "", err
790-
}
791-
defer resp.Body.Close()
792-
793-
if resp.StatusCode != 200 {
794-
return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode)
795-
}
796-
797-
data, err = io.ReadAll(resp.Body)
798-
return string(data), err
799-
}
800-
801-
func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) {
802-
var params struct {
803-
Message string `json:"message,omitempty"`
804-
Fields string `json:"fields,omitempty"`
805-
Sensitive string `json:"sensitive,omitempty"`
806-
}
807-
if err := json.Unmarshal([]byte(input), &params); err != nil {
808-
return "", err
809-
}
810-
811-
for _, env := range envs {
812-
if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok {
813-
httpPrompt := types.Prompt{
814-
Message: params.Message,
815-
Fields: strings.Split(params.Fields, ","),
816-
Sensitive: params.Sensitive == "true",
817-
}
818-
return sysPromptHTTP(ctx, url, httpPrompt)
819-
}
820-
}
821-
822-
if params.Message != "" {
823-
_, _ = fmt.Fprintln(os.Stderr, params.Message)
824-
}
825-
826-
results := map[string]string{}
827-
for _, f := range strings.Split(params.Fields, ",") {
828-
var value string
829-
if params.Sensitive == "true" {
830-
err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
831-
} else {
832-
err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
833-
}
834-
if err != nil {
835-
return "", err
836-
}
837-
results[f] = value
838-
}
839-
840-
resultsStr, err := json.Marshal(results)
841-
if err != nil {
842-
return "", err
843-
}
844-
845-
return string(resultsStr), nil
846-
}
847-
848774
func SysTimeNow(context.Context, []string, string) (string, error) {
849775
return time.Now().Format(time.RFC3339), nil
850776
}

pkg/engine/cmd.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"strings"
1515

1616
"github.com/google/shlex"
17-
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1817
"github.com/gptscript-ai/gptscript/pkg/counter"
1918
"github.com/gptscript-ai/gptscript/pkg/env"
2019
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -73,12 +72,6 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
7372
cmd.Stderr = io.MultiWriter(all, os.Stderr)
7473
cmd.Stdout = io.MultiWriter(all, output)
7574

76-
if toolCategory == CredentialToolCategory {
77-
pause := context2.GetPauseFuncFromCtx(ctx.Ctx)
78-
unpause := pause()
79-
defer unpause()
80-
}
81-
8275
if err := cmd.Run(); err != nil {
8376
if toolCategory == NoCategory {
8477
return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, all), nil

pkg/gptscript/gptscript.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@ import (
55
"fmt"
66
"os"
77
"path/filepath"
8+
"slices"
89

910
"github.com/gptscript-ai/gptscript/pkg/builtin"
1011
"github.com/gptscript-ai/gptscript/pkg/cache"
12+
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1113
"github.com/gptscript-ai/gptscript/pkg/engine"
1214
"github.com/gptscript-ai/gptscript/pkg/hash"
1315
"github.com/gptscript-ai/gptscript/pkg/llm"
1416
"github.com/gptscript-ai/gptscript/pkg/monitor"
1517
"github.com/gptscript-ai/gptscript/pkg/mvl"
1618
"github.com/gptscript-ai/gptscript/pkg/openai"
19+
"github.com/gptscript-ai/gptscript/pkg/prompt"
1720
"github.com/gptscript-ai/gptscript/pkg/remote"
1821
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
1922
"github.com/gptscript-ai/gptscript/pkg/runner"
@@ -28,6 +31,8 @@ type GPTScript struct {
2831
Cache *cache.Client
2932
WorkspacePath string
3033
DeleteWorkspaceOnClose bool
34+
extraEnv []string
35+
close func()
3136
}
3237

3338
type Options struct {
@@ -96,12 +101,21 @@ func New(opts *Options) (*GPTScript, error) {
96101
return nil, err
97102
}
98103

104+
ctx, closeServer := context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause))
105+
extraEnv, err := prompt.NewServer(ctx, opts.Env)
106+
if err != nil {
107+
closeServer()
108+
return nil, err
109+
}
110+
99111
return &GPTScript{
100112
Registry: registry,
101113
Runner: runner,
102114
Cache: cacheClient,
103115
WorkspacePath: opts.Workspace,
104116
DeleteWorkspaceOnClose: opts.Workspace == "",
117+
extraEnv: extraEnv,
118+
close: closeServer,
105119
}, nil
106120
}
107121

@@ -122,10 +136,10 @@ func (g *GPTScript) getEnv(env []string) ([]string, error) {
122136
if err := os.MkdirAll(g.WorkspacePath, 0700); err != nil {
123137
return nil, err
124138
}
125-
return append([]string{
139+
return slices.Concat(g.extraEnv, []string{
126140
fmt.Sprintf("GPTSCRIPT_WORKSPACE_DIR=%s", g.WorkspacePath),
127141
fmt.Sprintf("GPTSCRIPT_WORKSPACE_ID=%s", hash.ID(g.WorkspacePath)),
128-
}, env...), nil
142+
}, env), nil
129143
}
130144

131145
func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, envs []string, input string) (runner.ChatResponse, error) {
@@ -153,6 +167,8 @@ func (g *GPTScript) Close(closeDaemons bool) {
153167
}
154168
}
155169

170+
g.close()
171+
156172
if closeDaemons {
157173
engine.CloseDaemons()
158174
}

pkg/monitor/display.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type Console struct {
3838
dumpState string
3939
displayProgress bool
4040
printMessages bool
41+
callLock sync.Mutex
4142
}
4243

4344
var (
@@ -47,6 +48,7 @@ var (
4748
func (c *Console) Start(_ context.Context, prg *types.Program, _ []string, input string) (runner.Monitor, error) {
4849
id := counter.Next()
4950
mon := newDisplay(c.dumpState, c.displayProgress, c.printMessages)
51+
mon.callLock = &c.callLock
5052
mon.dump.ID = fmt.Sprint(id)
5153
mon.dump.Program = prg
5254
mon.dump.Input = input
@@ -55,13 +57,20 @@ func (c *Console) Start(_ context.Context, prg *types.Program, _ []string, input
5557
return mon, nil
5658
}
5759

60+
func (c *Console) Pause() func() {
61+
c.callLock.Lock()
62+
return func() {
63+
c.callLock.Unlock()
64+
}
65+
}
66+
5867
type display struct {
5968
dump dump
6069
printMessages bool
6170
livePrinter *livePrinter
6271
dumpState string
6372
callIDMap map[string]string
64-
callLock sync.Mutex
73+
callLock *sync.Mutex
6574
usage types.Usage
6675
}
6776

pkg/monitor/fd.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ func (s *fileFactory) Start(_ context.Context, prg *types.Program, env []string,
7070
return fd, nil
7171
}
7272

73+
func (s *fileFactory) Pause() func() {
74+
return func() {}
75+
}
76+
7377
func (s *fileFactory) close() {
7478
s.lock.Lock()
7579
defer s.lock.Unlock()

pkg/prompt/prompt.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package prompt
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"os"
11+
"strings"
12+
13+
"github.com/AlecAivazis/survey/v2"
14+
context2 "github.com/gptscript-ai/gptscript/pkg/context"
15+
"github.com/gptscript-ai/gptscript/pkg/types"
16+
)
17+
18+
func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.Prompt) (_ string, err error) {
19+
data, err := json.Marshal(prompt)
20+
if err != nil {
21+
return "", err
22+
}
23+
24+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
25+
if err != nil {
26+
return "", err
27+
}
28+
req.Header.Set("Content-Type", "application/json")
29+
30+
for _, env := range envs {
31+
if _, v, ok := strings.Cut(env, types.PromptTokenEnvVar+"="); ok && v != "" {
32+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", v))
33+
break
34+
}
35+
}
36+
37+
resp, err := http.DefaultClient.Do(req)
38+
if err != nil {
39+
return "", err
40+
}
41+
defer resp.Body.Close()
42+
43+
if resp.StatusCode != 200 {
44+
return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode)
45+
}
46+
47+
data, err = io.ReadAll(resp.Body)
48+
return string(data), err
49+
}
50+
51+
func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) {
52+
var params struct {
53+
Message string `json:"message,omitempty"`
54+
Fields string `json:"fields,omitempty"`
55+
Sensitive string `json:"sensitive,omitempty"`
56+
}
57+
if err := json.Unmarshal([]byte(input), &params); err != nil {
58+
return "", err
59+
}
60+
61+
for _, env := range envs {
62+
if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok {
63+
httpPrompt := types.Prompt{
64+
Message: params.Message,
65+
Fields: strings.Split(params.Fields, ","),
66+
Sensitive: params.Sensitive == "true",
67+
}
68+
return sysPromptHTTP(ctx, envs, url, httpPrompt)
69+
}
70+
}
71+
72+
return "", fmt.Errorf("no prompt server found, can not continue")
73+
}
74+
75+
func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) {
76+
defer context2.GetPauseFuncFromCtx(ctx)()()
77+
78+
if req.Message != "" {
79+
_, _ = fmt.Fprintln(os.Stderr, req.Message)
80+
}
81+
82+
results := map[string]string{}
83+
for _, f := range req.Fields {
84+
var value string
85+
if req.Sensitive {
86+
err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
87+
} else {
88+
err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
89+
}
90+
if err != nil {
91+
return "", err
92+
}
93+
results[f] = value
94+
}
95+
96+
resultsStr, err := json.Marshal(results)
97+
if err != nil {
98+
return "", err
99+
}
100+
101+
return string(resultsStr), nil
102+
}

0 commit comments

Comments
 (0)