Skip to content

Commit ab4d17c

Browse files
authored
Improve handling of literals and optionals (#8143)
1 parent 576f33b commit ab4d17c

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

dspy/adapters/json_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
126126
if match:
127127
completion = match.group(0)
128128
fields = json_repair.loads(completion)
129+
130+
if not isinstance(fields, dict):
131+
raise ValueError(f"Expected a JSON object but parsed a {type(fields)}")
132+
129133
fields = {k: v for k, v in fields.items() if k in signature.output_fields}
130134

131135
# Attempt to cast each value to type signature.output_fields[k].annotation.

dspy/adapters/utils.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,25 @@ def parse_value(value, annotation):
138138
if isinstance(annotation, enum.EnumMeta):
139139
return find_enum_member(annotation, value)
140140

141+
origin = get_origin(annotation)
142+
143+
if origin is Literal:
144+
allowed = get_args(annotation)
145+
if value in allowed:
146+
return value
147+
148+
if isinstance(value, str):
149+
v = value.strip()
150+
if v.startswith(("Literal[", "str[")) and v.endswith("]"):
151+
v = v[v.find("[") + 1 : -1]
152+
if len(v) > 1 and v[0] == v[-1] and v[0] in "\"'":
153+
v = v[1:-1]
154+
155+
if v in allowed:
156+
return v
157+
158+
raise ValueError(f"{value!r} is not one of {allowed!r}")
159+
141160
if not isinstance(value, str):
142161
return TypeAdapter(annotation).validate_python(value)
143162

@@ -147,15 +166,13 @@ def parse_value(value, annotation):
147166
candidate = ast.literal_eval(value)
148167
except (ValueError, SyntaxError):
149168
candidate = value
150-
169+
151170
try:
152171
return TypeAdapter(annotation).validate_python(candidate)
153-
except pydantic.ValidationError as e:
154-
# if the annotation is Optional[str], return just the string value
155-
if annotation.__origin__ is Union and type(None) in get_args(annotation):
156-
if len(get_args(annotation)) == 2 and str in get_args(annotation):
157-
return str(candidate)
158-
raise e
172+
except pydantic.ValidationError:
173+
if origin is Union and type(None) in get_args(annotation) and str in get_args(annotation):
174+
return str(candidate)
175+
raise
159176

160177
def get_annotation_name(annotation):
161178
origin = get_origin(annotation)

0 commit comments

Comments
 (0)