Skip to content

pydantic : replace uses of __annotations__ with get_type_hints #8474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 43 additions & 34 deletions examples/pydantic_models_to_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import copy
from enum import Enum
from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints

from docstring_parser import parse
from pydantic import BaseModel, create_model
Expand Down Expand Up @@ -53,35 +53,38 @@ class PydanticDataType(Enum):


def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str:
if isclass(pydantic_type) and issubclass(pydantic_type, str):
origin_type = get_origin(pydantic_type)
origin_type = pydantic_type if origin_type is None else origin_type

if isclass(origin_type) and issubclass(origin_type, str):
return PydanticDataType.STRING.value
elif isclass(pydantic_type) and issubclass(pydantic_type, bool):
elif isclass(origin_type) and issubclass(origin_type, bool):
return PydanticDataType.BOOLEAN.value
elif isclass(pydantic_type) and issubclass(pydantic_type, int):
elif isclass(origin_type) and issubclass(origin_type, int):
return PydanticDataType.INTEGER.value
elif isclass(pydantic_type) and issubclass(pydantic_type, float):
elif isclass(origin_type) and issubclass(origin_type, float):
return PydanticDataType.FLOAT.value
elif isclass(pydantic_type) and issubclass(pydantic_type, Enum):
elif isclass(origin_type) and issubclass(origin_type, Enum):
return PydanticDataType.ENUM.value

elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
return format_model_and_field_name(pydantic_type.__name__)
elif get_origin(pydantic_type) is list:
elif isclass(origin_type) and issubclass(origin_type, BaseModel):
return format_model_and_field_name(origin_type.__name__)
elif origin_type is list:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-list"
elif get_origin(pydantic_type) is set:
elif origin_type is set:
element_type = get_args(pydantic_type)[0]
return f"{map_pydantic_type_to_gbnf(element_type)}-set"
elif get_origin(pydantic_type) is Union:
elif origin_type is Union:
union_types = get_args(pydantic_type)
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
return f"union-{'-or-'.join(union_rules)}"
elif get_origin(pydantic_type) is Optional:
elif origin_type is Optional:
element_type = get_args(pydantic_type)[0]
return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
elif isclass(pydantic_type):
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
elif get_origin(pydantic_type) is dict:
elif isclass(origin_type):
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}"
elif origin_type is dict:
key_type, value_type = get_args(pydantic_type)
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
else:
Expand Down Expand Up @@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name):
# Modify this comprehension
members = [
f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}'
for name, param_type in cls.__annotations__.items()
for name, param_type in get_type_hints(cls).items()
if name != "self"
]

Expand Down Expand Up @@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type(
field_name = format_model_and_field_name(field_name)
gbnf_type = map_pydantic_type_to_gbnf(field_type)

if isclass(field_type) and issubclass(field_type, BaseModel):
origin_type = get_origin(field_type)
origin_type = field_type if origin_type is None else origin_type

if isclass(origin_type) and issubclass(origin_type, BaseModel):
nested_model_name = format_model_and_field_name(field_type.__name__)
nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules)
rules.extend(nested_model_rules)
gbnf_type, rules = nested_model_name, rules
elif isclass(field_type) and issubclass(field_type, Enum):
elif isclass(origin_type) and issubclass(origin_type, Enum):
enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
rules.append(enum_rule)
gbnf_type, rules = model_name + "-" + field_name, rules
elif get_origin(field_type) == list: # Array
elif origin_type is list: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
Expand All @@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type(
rules.append(array_rule)
gbnf_type, rules = model_name + "-" + field_name, rules

elif get_origin(field_type) == set or field_type == set: # Array
elif origin_type is set: # Array
element_type = get_args(field_type)[0]
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
Expand Down Expand Up @@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type(
gbnf_type = f"{model_name}-{field_name}-optional"
else:
gbnf_type = f"{model_name}-{field_name}-union"
elif isclass(field_type) and issubclass(field_type, str):
elif isclass(origin_type) and issubclass(origin_type, str):
if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None:
triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False)
markdown_string = field_info.json_schema_extra.get("markdown_code_block", False)
Expand All @@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type(
gbnf_type = PydanticDataType.STRING.value

elif (
isclass(field_type)
and issubclass(field_type, float)
isclass(origin_type)
and issubclass(origin_type, float)
and field_info
and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None
Expand All @@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type(
)

elif (
isclass(field_type)
and issubclass(field_type, int)
isclass(origin_type)
and issubclass(origin_type, int)
and field_info
and hasattr(field_info, "json_schema_extra")
and field_info.json_schema_extra is not None
Expand Down Expand Up @@ -462,15 +468,15 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
if not issubclass(model, BaseModel):
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
if hasattr(model, "__annotations__") and model.__annotations__:
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues]
model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
else:
init_signature = inspect.signature(model.__init__)
parameters = init_signature.parameters
model_fields = {name: (param.annotation, param.default) for name, param in parameters.items() if
name != "self"}
else:
# For Pydantic models, use model_fields and check for ellipsis (required fields)
model_fields = model.__annotations__
model_fields = get_type_hints(model)

model_rule_parts = []
nested_rules = []
Expand Down Expand Up @@ -706,7 +712,7 @@ def generate_markdown_documentation(
else:
documentation += f" Fields:\n" # noqa: F541
if isclass(model) and issubclass(model, BaseModel):
for name, field_type in model.__annotations__.items():
for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block":
# continue
if get_origin(field_type) == list:
Expand Down Expand Up @@ -754,14 +760,17 @@ def generate_field_markdown(
field_info = model.model_fields.get(field_name)
field_description = field_info.description if field_info and field_info.description else ""

if get_origin(field_type) == list:
origin_type = get_origin(field_type)
origin_type = field_type if origin_type is None else origin_type

if origin_type == list:
element_type = get_args(field_type)[0]
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
if field_description != "":
field_text += ":\n"
else:
field_text += "\n"
elif get_origin(field_type) == Union:
elif origin_type == Union:
element_types = get_args(field_type)
types = []
for element_type in element_types:
Expand Down Expand Up @@ -792,9 +801,9 @@ def generate_field_markdown(
example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example
field_text += f"{indent} Example: {example_text}\n"

if isclass(field_type) and issubclass(field_type, BaseModel):
if isclass(origin_type) and issubclass(origin_type, BaseModel):
field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items():
for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_markdown(name, type_, field_type, depth + 2)

return field_text
Expand Down Expand Up @@ -855,7 +864,7 @@ def generate_text_documentation(

if isclass(model) and issubclass(model, BaseModel):
documentation_fields = ""
for name, field_type in model.__annotations__.items():
for name, field_type in get_type_hints(model).items():
# if name == "markdown_code_block":
# continue
if get_origin(field_type) == list:
Expand Down Expand Up @@ -948,7 +957,7 @@ def generate_field_text(

if isclass(field_type) and issubclass(field_type, BaseModel):
field_text += f"{indent} Details:\n"
for name, type_ in field_type.__annotations__.items():
for name, type_ in get_type_hints(field_type).items():
field_text += generate_field_text(name, type_, field_type, depth + 2)

return field_text
Expand Down
2 changes: 2 additions & 0 deletions examples/pydantic_models_to_grammar_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def create_completion(prompt, grammar):
response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data)
data = response.json()

assert data.get("error") is None, data

print(data["content"])
return data["content"]

Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-pydantic.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
docstring_parser~=0.15
pydantic~=2.6.3
requests
Loading