Skip to content

Commit 85a9126

Browse files
Improve the positional args error message for dspy.Predict (#8152)
* improve the kwargs error message * improve unit test
1 parent a572bfa commit 85a9126

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

dspy/predict/predict.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,23 +70,24 @@ def load_state(self, state):
7070

7171
return self
7272

73+
def _get_positional_args_error_message(self):
74+
input_fields = list(self.signature.input_fields.keys())
75+
return (
76+
"Positional arguments are not allowed when calling `dspy.Predict`, must use keyword arguments "
77+
f"that match your signature input fields: '{', '.join(input_fields)}'. For example: "
78+
f"`predict({input_fields[0]}=input_value, ...)`."
79+
)
80+
7381
def __call__(self, *args, **kwargs):
7482
if args:
75-
raise ValueError(
76-
"Positional arguments are not allowed when calling `dspy.Predict`, must use keyword arguments "
77-
"that match your signature input fields. For example: "
78-
"dspy.Predict('question -> answer')(question='What is the capital of France?')"
79-
)
83+
raise ValueError(self._get_positional_args_error_message())
8084

8185
return super().__call__(**kwargs)
8286

8387
async def acall(self, *args, **kwargs):
8488
if args:
85-
raise ValueError(
86-
"Positional arguments are not allowed when calling `dspy.Predict`, must use keyword arguments "
87-
"that match your signature input fields. For example: "
88-
"dspy.Predict('question -> answer').acall(question='What is the capital of France?')"
89-
)
89+
raise ValueError(self._get_positional_args_error_message())
90+
9091
return await super().acall(**kwargs)
9192

9293
def _forward_preprocess(self, **kwargs):

tests/predict/test_predict.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,8 +511,9 @@ def test_positional_arguments():
511511
program = Predict("question -> answer")
512512
with pytest.raises(ValueError) as e:
513513
program("What is the capital of France?")
514-
assert "Positional arguments are not allowed when calling `dspy.Predict`, must use keyword arguments" in str(
515-
e.value
514+
assert str(e.value) == (
515+
"Positional arguments are not allowed when calling `dspy.Predict`, must use keyword arguments that match "
516+
"your signature input fields: 'question'. For example: `predict(question=input_value, ...)`."
516517
)
517518

518519

0 commit comments

Comments
 (0)