2 files changed: +35 −0

@@ -108,6 +108,17 @@ def _forward_preprocess(self, **kwargs):
        if (temperature is None or temperature <= 0.15) and num_generations > 1:
            config["temperature"] = 0.7

+ if "prediction" in kwargs :
112
+ if (
113
+ isinstance (kwargs ["prediction" ], dict )
114
+ and kwargs ["prediction" ].get ("type" ) == "content"
115
+ and "content" in kwargs ["prediction" ]
116
+ ):
117
+ # If the `prediction` is the standard predicted outputs format
118
+ # (https://platform.openai.com/docs/guides/predicted-outputs), we remvoe it from input kwargs and add it
119
+ # to the lm kwargs.
120
+ config ["prediction" ] = kwargs .pop ("prediction" )
121
+
        if not all(k in kwargs for k in signature.input_fields):
            present = [k for k in signature.input_fields if k in kwargs]
            missing = [k for k in signature.input_fields if k not in kwargs]
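
A minimal usage sketch of the pass-through added above, mirroring the new test below. It assumes an OpenAI model that supports predicted outputs is configured (an API key is required to actually run it); the model name and `draft_answer` are illustrative, not part of this PR:

import dspy

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))
qa = dspy.Predict("question -> answer")

# An earlier draft that the new answer is expected to largely repeat (illustrative text).
draft_answer = "Paris is the capital of France."

# Because the value matches the predicted-outputs shape ({"type": "content", ...}),
# the new branch pops it from the module kwargs and forwards it to the LM call.
result = qa(
    question="What is the capital of France?",
    prediction={"type": "content", "content": draft_answer},
)
print(result.answer)
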
@@ -572,3 +572,27 @@ async def test_async_predict():
    dspy.settings.configure(lm=DummyLM([{"answer": "Paris"}]))
    result = await program.acall(question="What is the capital of France?")
    assert result.answer == "Paris"
+
+
+def test_predicted_outputs_piped_from_predict_to_lm_call():
+    program = Predict("question -> answer")
+    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))
+
+    with patch("litellm.completion") as mock_completion:
+        program(
+            question="Why did a chicken cross the kitchen?",
+            prediction={"type": "content", "content": "A chicken crossing the kitchen"},
+        )
+
+    assert mock_completion.call_args[1]["prediction"] == {
+        "type": "content",
+        "content": "A chicken crossing the kitchen",
+    }
+
+    # If the signature has `prediction` as an input field and the value is not in the standard
+    # predicted-outputs format, it should not be passed to the LM.
+    program = Predict("question, prediction -> judgement")
+    with patch("litellm.completion") as mock_completion:
+        program(question="Why did a chicken cross the kitchen?", prediction="To get to the other side!")
+
+    assert "prediction" not in mock_completion.call_args[1]