Skip to content

Commit aa4cde6

Browse files
authored
Merge pull request #429 from thedadams/fix-prompt-token
fix: add prompt token to sdkserver
2 parents b9fcb20 + aa83730 commit aa4cde6

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

pkg/sdkserver/prompt.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ func (s *server) promptResponse(w http.ResponseWriter, r *http.Request) {
4343

4444
func (s *server) prompt(w http.ResponseWriter, r *http.Request) {
4545
logger := gcontext.GetLogger(r.Context())
46+
if r.Header.Get("Authorization") != "Bearer "+s.token {
47+
writeError(logger, w, http.StatusUnauthorized, fmt.Errorf("invalid token"))
48+
return
49+
}
50+
4651
id := r.PathValue("id")
4752

4853
s.lock.RLock()

pkg/sdkserver/routes.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ import (
2727
const toolRunTimeout = 15 * time.Minute
2828

2929
type server struct {
30-
address string
31-
client *gptscript.GPTScript
32-
events *broadcaster.Broadcaster[event]
30+
address, token string
31+
client *gptscript.GPTScript
32+
events *broadcaster.Broadcaster[event]
3333

3434
lock sync.RWMutex
3535
waitingToConfirm map[string]chan runner.AuthorizerResponse
@@ -165,9 +165,9 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) {
165165

166166
reqObject.Env = append(os.Environ(), reqObject.Env...)
167167
// Don't overwrite the PromptURLEnvVar if it is already set in the environment.
168-
if !slices.ContainsFunc(reqObject.Env, func(s string) bool { return strings.HasPrefix(s, types.PromptURLEnvVar+"=") }) {
168+
if !slices.ContainsFunc(reqObject.Env, func(s string) bool { return strings.HasPrefix(s, types.PromptTokenEnvVar+"=") }) {
169169
// Append a prompt URL for this run.
170-
reqObject.Env = append(reqObject.Env, fmt.Sprintf("%s=http://%s/prompt/%s", types.PromptURLEnvVar, s.address, runID))
170+
reqObject.Env = append(reqObject.Env, fmt.Sprintf("%s=http://%s/prompt/%s", types.PromptURLEnvVar, s.address, runID), fmt.Sprintf("%s=%s", types.PromptTokenEnvVar, s.token))
171171
}
172172

173173
logger.Debugf("executing tool: %+v", reqObject)

pkg/sdkserver/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"time"
1414

1515
"github.com/acorn-io/broadcaster"
16+
"github.com/google/uuid"
1617
"github.com/gptscript-ai/gptscript/pkg/gptscript"
1718
"github.com/gptscript-ai/gptscript/pkg/mvl"
1819
"github.com/gptscript-ai/gptscript/pkg/runner"
@@ -52,6 +53,7 @@ func Start(ctx context.Context, opts Options) error {
5253

5354
s := &server{
5455
address: opts.ListenAddress,
56+
token: uuid.NewString(),
5557
client: g,
5658
events: events,
5759
waitingToConfirm: make(map[string]chan runner.AuthorizerResponse),

0 commit comments

Comments
 (0)