Skip to content

Commit dd84ea2

Browse files
authored
chore: add Add Elements tool to dataset sdk (#885)
Signed-off-by: Grant Linville <[email protected]>
1 parent af2e82f commit dd84ea2

File tree

2 files changed

+89
-12
lines changed

2 files changed

+89
-12
lines changed

pkg/sdkserver/datasets.go

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,19 @@ import (
1111
)
1212

1313
type datasetRequest struct {
14-
Input string `json:"input"`
15-
Workspace string `json:"workspace"`
16-
DatasetToolRepo string `json:"datasetToolRepo"`
14+
Input string `json:"input"`
15+
WorkspaceID string `json:"workspaceID"`
16+
DatasetToolRepo string `json:"datasetToolRepo"`
17+
Env []string `json:"env"`
1718
}
1819

1920
func (r datasetRequest) validate(requireInput bool) error {
20-
if r.Workspace == "" {
21-
return fmt.Errorf("workspace is required")
21+
if r.WorkspaceID == "" {
22+
return fmt.Errorf("workspaceID is required")
2223
} else if requireInput && r.Input == "" {
2324
return fmt.Errorf("input is required")
25+
} else if len(r.Env) == 0 {
26+
return fmt.Errorf("env is required")
2427
}
2528
return nil
2629
}
@@ -30,7 +33,7 @@ func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
3033
Cache: o.Cache,
3134
Monitor: o.Monitor,
3235
Runner: o.Runner,
33-
Workspace: r.Workspace,
36+
Workspace: r.WorkspaceID,
3437
}
3538
return opts
3639
}
@@ -39,7 +42,7 @@ func (r datasetRequest) getToolRepo() string {
3942
if r.DatasetToolRepo != "" {
4043
return r.DatasetToolRepo
4144
}
42-
return "github.com/gptscript-ai/datasets"
45+
return "github.com/otto8-ai/datasets"
4346
}
4447

4548
func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
@@ -71,7 +74,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
7174
return
7275
}
7376

74-
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
77+
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
7578
if err != nil {
7679
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
7780
return
@@ -132,7 +135,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
132135
return
133136
}
134137

135-
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
138+
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
136139
if err != nil {
137140
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
138141
return
@@ -200,7 +203,80 @@ func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) {
200203
return
201204
}
202205

203-
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
206+
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
207+
if err != nil {
208+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
209+
return
210+
}
211+
212+
writeResponse(logger, w, map[string]any{"stdout": result})
213+
}
214+
215+
type addDatasetElementsArgs struct {
216+
DatasetID string `json:"datasetID"`
217+
Elements []struct {
218+
Name string `json:"name"`
219+
Description string `json:"description"`
220+
Contents string `json:"contents"`
221+
}
222+
}
223+
224+
func (a addDatasetElementsArgs) validate() error {
225+
if a.DatasetID == "" {
226+
return fmt.Errorf("datasetID is required")
227+
}
228+
if len(a.Elements) == 0 {
229+
return fmt.Errorf("elements is required")
230+
}
231+
return nil
232+
}
233+
234+
func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) {
235+
logger := gcontext.GetLogger(r.Context())
236+
237+
var req datasetRequest
238+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
239+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
240+
return
241+
}
242+
243+
if err := req.validate(true); err != nil {
244+
writeError(logger, w, http.StatusBadRequest, err)
245+
return
246+
}
247+
248+
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
249+
if err != nil {
250+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
251+
return
252+
}
253+
254+
var args addDatasetElementsArgs
255+
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
256+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
257+
return
258+
}
259+
260+
if err := args.validate(); err != nil {
261+
writeError(logger, w, http.StatusBadRequest, err)
262+
return
263+
}
264+
265+
prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Elements", loader.Options{
266+
Cache: g.Cache,
267+
})
268+
if err != nil {
269+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
270+
return
271+
}
272+
273+
elementsJSON, err := json.Marshal(args.Elements)
274+
if err != nil {
275+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to marshal elements: %w", err))
276+
return
277+
}
278+
279+
result, err := g.Run(r.Context(), prg, req.Env, fmt.Sprintf(`{"datasetID":%q, "elements":%q}`, args.DatasetID, string(elementsJSON)))
204280
if err != nil {
205281
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
206282
return
@@ -259,7 +335,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
259335
return
260336
}
261337

262-
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
338+
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
263339
if err != nil {
264340
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
265341
return
@@ -322,7 +398,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
322398
return
323399
}
324400

325-
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
401+
result, err := g.Run(r.Context(), prg, req.Env, req.Input)
326402
if err != nil {
327403
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
328404
return

pkg/sdkserver/routes.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ func (s *server) addRoutes(mux *http.ServeMux) {
7373
mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements)
7474
mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement)
7575
mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement)
76+
mux.HandleFunc("POST /datasets/add-elements", s.addDatasetElements)
7677

7778
mux.HandleFunc("POST /workspaces/create", s.createWorkspace)
7879
mux.HandleFunc("POST /workspaces/delete", s.deleteWorkspace)

0 commit comments

Comments
 (0)