@@ -2,8 +2,10 @@ package openai
 
 import (
 	"context"
+	"errors"
 	"io"
 	"log/slog"
+	"math"
 	"os"
 	"slices"
 	"sort"
@@ -24,6 +26,7 @@ import (
 const (
 	DefaultModel    = openai.GPT4o
 	BuiltinCredName = "sys.openai"
+	TooLongMessage  = "Error: tool call output is too long"
 )
 
 var (
@@ -317,6 +320,14 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	}
 
 	if messageRequest.Chat {
+		// Check the last message. If it is from a tool call and takes up more than 80% of the token budget on its own, replace its content with an error message.
+		lastMessage := msgs[len(msgs)-1]
+		if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(math.Round(float64(getBudget(messageRequest.MaxTokens))*0.8)) {
+			// Update it in the msgs slice for this call and in messageRequest.Messages for future calls.
+			msgs[len(msgs)-1].Content = TooLongMessage
+			messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
+		}
+
 		msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
 	}
 
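A quick aside on the 80% threshold above: `getBudget` and `countMessage` are defined outside these hunks, so the sketch below only illustrates the arithmetic, with a hypothetical `getBudget` modeled on how it is called (resolving an unset `MaxTokens` to a default window).

```go
// Sketch of the 80% budget check, assuming a hypothetical getBudget that
// falls back to a default context window when MaxTokens is unset. Neither
// the default value nor the helper body comes from this PR.
package main

import (
	"fmt"
	"math"
)

func getBudget(maxTokens int) int {
	if maxTokens <= 0 {
		return 128_000 // assumed default window, for illustration only
	}
	return maxTokens
}

func main() {
	threshold := int(math.Round(float64(getBudget(0)) * 0.8))
	fmt.Println(threshold) // 102400: a tool message counting above this gets replaced
}
```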
@@ -383,6 +394,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		return nil, err
 	} else if !ok {
 		response, err = c.call(ctx, request, id, status)
+
+		// If we got back a context length exceeded error, keep retrying and shrinking the message history until the request fits.
+		var apiError *openai.APIError
+		if err != nil && errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
+			// Decrease maxTokens by 10% to make garbage collection more aggressive.
+			// The retry loop will further decrease maxTokens if needed.
+			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
+			response, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
+		}
+
 		if err != nil {
 			return nil, err
 		}
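`decreaseTenPercent` is referenced here but defined outside these hunks. A minimal sketch of what it plausibly does, reusing the hypothetical `getBudget` from the earlier sketch; treat the body as an assumption, not the PR's actual helper:

```go
// Assumed shape of decreaseTenPercent: normalize the budget first, then
// shave 10% off so dropMessagesOverCount trims more aggressively.
func decreaseTenPercent(maxTokens int) int {
	return int(float64(getBudget(maxTokens)) * 0.9)
}
```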
@@ -421,6 +442,32 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	return &result, nil
 }
 
+func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) ([]openai.ChatCompletionStreamResponse, error) {
+	var (
+		response []openai.ChatCompletionStreamResponse
+		err      error
+	)
+
+	for range 10 { // maximum 10 tries
+		// Try to drop older messages again, with a decreased max tokens.
+		request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
+		response, err = c.call(ctx, request, id, status)
+		if err == nil {
+			break
+		}
+
+		var apiError *openai.APIError
+		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" {
+			// Decrease maxTokens and try again.
+			maxTokens = decreaseTenPercent(maxTokens)
+			continue
+		}
+		return nil, err
+	}
+
+	// If every try hit the context limit, err is still non-nil here.
+	return response, err
+}
+
 func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionStreamResponse) types.CompletionMessage {
 	msg.Usage.CompletionTokens = types.FirstSet(msg.Usage.CompletionTokens, response.Usage.CompletionTokens)
 	msg.Usage.PromptTokens = types.FirstSet(msg.Usage.PromptTokens, response.Usage.PromptTokens)
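One design consequence of the loop above: each retry cuts the budget by another 10%, so ten retries shrink it geometrically to roughly 35% of where the loop started. A small check of that floor (the starting budget is illustrative, not from this PR):

```go
package main

import "fmt"

func main() {
	budget := 128000 // illustrative starting budget
	for range 10 {   // mirrors the loop's maximum of 10 tries
		budget = int(float64(budget) * 0.9) // same 10% cut per attempt
	}
	fmt.Println(budget) // 44629, roughly 35% of the start
}
```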