Skip to content

feat: add prompt support to the SDK server #404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func Builtin(name string) (types.Tool, bool) {
return SetDefaults(t), ok
}

func SysFind(ctx context.Context, env []string, input string) (string, error) {
func SysFind(_ context.Context, _ []string, input string) (string, error) {
var result []string
var params struct {
Pattern string `json:"pattern,omitempty"`
Expand Down Expand Up @@ -306,7 +306,7 @@ func SysFind(ctx context.Context, env []string, input string) (string, error) {
return strings.Join(result, "\n"), nil
}

func SysExec(ctx context.Context, env []string, input string) (string, error) {
func SysExec(_ context.Context, env []string, input string) (string, error) {
var params struct {
Command string `json:"command,omitempty"`
Directory string `json:"directory,omitempty"`
Expand Down Expand Up @@ -412,7 +412,7 @@ func SysRead(_ context.Context, _ []string, input string) (string, error) {
return string(data), nil
}

func SysWrite(ctx context.Context, _ []string, input string) (string, error) {
func SysWrite(_ context.Context, _ []string, input string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
Content string `json:"content,omitempty"`
Expand Down Expand Up @@ -444,7 +444,7 @@ func SysWrite(ctx context.Context, _ []string, input string) (string, error) {
return fmt.Sprintf("Wrote (%d) bytes to file %s", len(data), file), nil
}

func SysAppend(ctx context.Context, env []string, input string) (string, error) {
func SysAppend(_ context.Context, _ []string, input string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
Content string `json:"content,omitempty"`
Expand Down Expand Up @@ -490,7 +490,7 @@ func fixQueries(u string) string {
return url.String()
}

func SysHTTPGet(ctx context.Context, env []string, input string) (_ string, err error) {
func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err error) {
var params struct {
URL string `json:"url,omitempty"`
}
Expand Down Expand Up @@ -534,7 +534,7 @@ func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string,
})
}

func SysHTTPPost(ctx context.Context, env []string, input string) (_ string, err error) {
func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err error) {
var params struct {
URL string `json:"url,omitempty"`
Content string `json:"content,omitempty"`
Expand Down Expand Up @@ -570,7 +570,7 @@ func SysHTTPPost(ctx context.Context, env []string, input string) (_ string, err
return fmt.Sprintf("Wrote %d to %s", len([]byte(params.Content)), params.URL), nil
}

func SysGetenv(ctx context.Context, env []string, input string) (string, error) {
func SysGetenv(_ context.Context, env []string, input string) (string, error) {
var params struct {
Name string `json:"name,omitempty"`
}
Expand Down Expand Up @@ -636,7 +636,7 @@ func writeHistory(ctx *engine.Context) (result []engine.ChatHistoryCall) {
return
}

func SysChatFinish(ctx context.Context, env []string, input string) (string, error) {
func SysChatFinish(_ context.Context, _ []string, input string) (string, error) {
var params struct {
Message string `json:"return,omitempty"`
}
Expand All @@ -650,7 +650,7 @@ func SysChatFinish(ctx context.Context, env []string, input string) (string, err
}
}

func SysAbort(ctx context.Context, env []string, input string) (string, error) {
func SysAbort(_ context.Context, _ []string, input string) (string, error) {
var params struct {
Message string `json:"message,omitempty"`
}
Expand All @@ -660,7 +660,7 @@ func SysAbort(ctx context.Context, env []string, input string) (string, error) {
return "", fmt.Errorf("ABORT: %s", params.Message)
}

func SysRemove(ctx context.Context, env []string, input string) (string, error) {
func SysRemove(_ context.Context, _ []string, input string) (string, error) {
var params struct {
Location string `json:"location,omitempty"`
}
Expand All @@ -679,7 +679,7 @@ func SysRemove(ctx context.Context, env []string, input string) (string, error)
return fmt.Sprintf("Removed file: %s", params.Location), nil
}

func SysStat(ctx context.Context, env []string, input string) (string, error) {
func SysStat(_ context.Context, _ []string, input string) (string, error) {
var params struct {
Filepath string `json:"filepath,omitempty"`
}
Expand All @@ -699,7 +699,7 @@ func SysStat(ctx context.Context, env []string, input string) (string, error) {
return fmt.Sprintf("%s %s mode: %s, size: %d bytes, modtime: %s", title, params.Filepath, stat.Mode().String(), stat.Size(), stat.ModTime().String()), nil
}

func SysDownload(ctx context.Context, env []string, input string) (_ string, err error) {
func SysDownload(_ context.Context, env []string, input string) (_ string, err error) {
var params struct {
URL string `json:"url,omitempty"`
Location string `json:"location,omitempty"`
Expand Down Expand Up @@ -772,12 +772,8 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err
return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil
}

func sysPromptHTTP(ctx context.Context, url, message string, fields []string, sensitive bool) (_ string, err error) {
data, err := json.Marshal(map[string]any{
"message": message,
"fields": fields,
"sensitive": sensitive,
})
func sysPromptHTTP(ctx context.Context, url string, prompt types.Prompt) (_ string, err error) {
data, err := json.Marshal(prompt)
if err != nil {
return "", err
}
Expand All @@ -792,7 +788,7 @@ func sysPromptHTTP(ctx context.Context, url, message string, fields []string, se
if err != nil {
return "", err
}
resp.Body.Close()
defer resp.Body.Close()

if resp.StatusCode != 200 {
return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode)
Expand All @@ -813,8 +809,13 @@ func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err
}

for _, env := range envs {
if url, ok := strings.CutPrefix(env, "GPTSCRIPT_PROMPT_URL="); ok {
return sysPromptHTTP(ctx, url, params.Message, strings.Split(params.Fields, ","), params.Sensitive == "true")
if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok {
httpPrompt := types.Prompt{
Message: params.Message,
Fields: strings.Split(params.Fields, ","),
Sensitive: params.Sensitive == "true",
}
return sysPromptHTTP(ctx, url, httpPrompt)
}
}

Expand Down Expand Up @@ -844,6 +845,6 @@ func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err
return string(resultsStr), nil
}

func SysTimeNow(ctx context.Context, env []string, input string) (string, error) {
func SysTimeNow(context.Context, []string, string) (string, error) {
return time.Now().Format(time.RFC3339), nil
}
16 changes: 9 additions & 7 deletions pkg/sdkserver/confirm.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ func (s *server) authorize(ctx engine.Context, input string) (runner.AuthorizerR
s.lock.Unlock()
}(ctx.ID)

s.events.C <- gserver.Event{
Event: runner.Event{
Time: time.Now(),
CallContext: ctx.GetCallContext(),
Type: CallConfirm,
s.events.C <- event{
Event: gserver.Event{
Event: runner.Event{
Time: time.Now(),
CallContext: ctx.GetCallContext(),
Type: CallConfirm,
},
Input: input,
RunID: runID,
},
Input: input,
RunID: runID,
}

// Wait for the confirmation to come through.
Expand Down
94 changes: 94 additions & 0 deletions pkg/sdkserver/monitor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package sdkserver

import (
"context"
"sync"
"time"

"github.com/acorn-io/broadcaster"
"github.com/gptscript-ai/gptscript/pkg/runner"
gserver "github.com/gptscript-ai/gptscript/pkg/server"
"github.com/gptscript-ai/gptscript/pkg/types"
)

type SessionFactory struct {
events *broadcaster.Broadcaster[event]
}

func NewSessionFactory(events *broadcaster.Broadcaster[event]) *SessionFactory {
return &SessionFactory{
events: events,
}
}

func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []string, input string) (runner.Monitor, error) {
id := gserver.RunIDFromContext(ctx)

s.events.C <- event{
Event: gserver.Event{
Event: runner.Event{
Time: time.Now(),
Type: runner.EventTypeRunStart,
},
RunID: id,
Program: prg,
},
}

return &Session{
id: id,
prj: prg,
env: env,
input: input,
events: s.events,
}, nil
}

type Session struct {
id string
prj *types.Program
env []string
input string
events *broadcaster.Broadcaster[event]
runLock sync.Mutex
}

func (s *Session) Event(e runner.Event) {
s.runLock.Lock()
defer s.runLock.Unlock()
s.events.C <- event{
Event: gserver.Event{
Event: e,
RunID: s.id,
Input: s.input,
},
}
}

func (s *Session) Stop(output string, err error) {
e := event{
Event: gserver.Event{
Event: runner.Event{
Time: time.Now(),
Type: runner.EventTypeRunFinish,
},
RunID: s.id,
Input: s.input,
Output: output,
},
}
if err != nil {
e.Err = err.Error()
}

s.runLock.Lock()
defer s.runLock.Unlock()
s.events.C <- e
}

func (s *Session) Pause() func() {
s.runLock.Lock()
return func() {
s.runLock.Unlock()
}
}
111 changes: 111 additions & 0 deletions pkg/sdkserver/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package sdkserver

import (
"encoding/json"
"fmt"
"net/http"
"time"

gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/mvl"
"github.com/gptscript-ai/gptscript/pkg/runner"
gserver "github.com/gptscript-ai/gptscript/pkg/server"
"github.com/gptscript-ai/gptscript/pkg/types"
)

func (s *server) promptResponse(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())
id := r.PathValue("id")

s.lock.RLock()
promptChan := s.waitingToPrompt[id]
s.lock.RUnlock()

if promptChan == nil {
writeError(logger, w, http.StatusNotFound, fmt.Errorf("no prompt found with id %q", id))
return
}

var promptResponse map[string]string
if err := json.NewDecoder(r.Body).Decode(&promptResponse); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
return
}

// Don't block here because, if the prompter is no longer waiting on this then it will never unblock.
select {
case promptChan <- promptResponse:
w.WriteHeader(http.StatusAccepted)
default:
w.WriteHeader(http.StatusConflict)
}
}

func (s *server) prompt(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())
id := r.PathValue("id")

s.lock.RLock()
promptChan := s.waitingToPrompt[id]
s.lock.RUnlock()

if promptChan != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("prompt called multiple times for same ID: %s", id))
return
}

var prompt types.Prompt
if err := json.NewDecoder(r.Body).Decode(&prompt); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %v", err))
return
}

s.lock.Lock()
promptChan = make(chan map[string]string)
s.waitingToPrompt[id] = promptChan
s.lock.Unlock()
defer func(id string) {
s.lock.Lock()
delete(s.waitingToPrompt, id)
s.lock.Unlock()
}(id)

s.events.C <- event{
Prompt: types.Prompt{
Message: prompt.Message,
Fields: prompt.Fields,
Sensitive: prompt.Sensitive,
},
Event: gserver.Event{
RunID: id,
Event: runner.Event{
Type: Prompt,
Time: time.Now(),
},
},
}

// Wait for the prompt response to come through.
select {
case <-r.Context().Done():
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("context canceled: %v", r.Context().Err()))
return
case promptResponse := <-promptChan:
writePromptResponse(logger, w, http.StatusOK, promptResponse)
}
}

func writePromptResponse(logger mvl.Logger, w http.ResponseWriter, code int, resp any) {
b, err := json.Marshal(resp)
if err != nil {
logger.Errorf("failed to marshal response: %v", err)
w.WriteHeader(http.StatusInternalServerError)
} else {
w.WriteHeader(code)
}

_, err = w.Write(b)
if err != nil {
logger.Errorf("failed to write response: %v", err)
}
}
Loading