Skip to content

Commit a218ec3

Browse files
author
c0sogi
committed
prevent memory access error by llama_grammar_free
1 parent b07713c commit a218ec3

File tree

3 files changed

+986
-672
lines changed

3 files changed

+986
-672
lines changed
Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
"""Helper classes for wrapping functions in OpenAI's API"""
2+
3+
from dataclasses import dataclass
4+
from inspect import signature
5+
import json
6+
from re import Pattern, compile
7+
import gc
8+
from typing import (
9+
Any,
10+
Callable,
11+
Dict,
12+
Generic,
13+
Iterable,
14+
List,
15+
Literal,
16+
Optional,
17+
Tuple,
18+
Type,
19+
TypeVar,
20+
Union,
21+
)
22+
23+
from typing_extensions import Annotated, NotRequired, TypedDict, get_args, get_origin
24+
25+
# Python types that map directly onto JSON values, plus the type variables
# used by the generic wrappers below.
JsonTypes = Union[int, float, str, bool, dict, list, None]
ParamType = TypeVar("ParamType", bound=JsonTypes)
ReturnType = TypeVar("ReturnType")


# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE: str = '" "?'

# Grammar rule bodies for the primitive JSON schema types. Every rule ends
# with the `space` rule so optional trailing whitespace is consumed.
PRIMITIVE_RULES: Dict[str, str] = {
    "boolean": '("true" | "false") space',
    "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
    "integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
    # NOTE(review): the line breaks/indentation inside this literal were
    # reconstructed; only the token sequence is known for certain. The layout
    # is presumably insignificant to the grammar parser -- confirm.
    "string": r""" "\"" (
        [^"\\] |
        "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
      )* "\"" space """,
    "null": '"null" space',
}

# Any run of characters that may not appear in a rule name (replaced by "-").
INVALID_RULE_CHARS_RE: "Pattern[str]" = compile(r"[^a-zA-Z0-9-]+")
# Characters that must be escaped inside a quoted grammar literal, and their
# escaped spellings.
GRAMMAR_LITERAL_ESCAPE_RE: "Pattern[str]" = compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES: Dict[str, str] = {"\r": "\\r", "\n": "\\n", '"': '\\"'}

# Literal aliases for the schema "type" values and the schema keys that
# SchemaConverter.visit inspects.
SchemaType = Literal[
    "boolean", "number", "integer", "string", "null", "object", "array"
]
SchemaKey = Literal["type", "oneOf", "anyOf", "const", "enum", "properties", "items"]
55+
56+
57+
class ParameterProperty(TypedDict):
    """Schema entry describing a single function parameter."""

    # JSON type name (e.g. the strings returned by
    # FunctionCallParameter._get_json_type).
    type: str
    # Optional human-readable description of the parameter.
    description: NotRequired[str]
    # Optional list of the only values the parameter may take.
    enum: NotRequired[List[JsonTypes]]
61+
62+
63+
class ParameterDefinition(TypedDict):
    """Object schema grouping all parameters of one function."""

    # Always the literal string "object".
    type: Literal["object"]
    # Maps each parameter name to its schema entry.
    properties: Dict[str, ParameterProperty]
    # Optional list of parameter names that must be supplied.
    required: NotRequired[List[str]]
67+
68+
69+
class FunctionProperty(TypedDict):
    """Dictionary form of a function specification (see FunctionCall.to_dict)."""

    # Name of the function.
    name: str
    # Optional description of what the function does.
    description: NotRequired[str]
    # Optional object schema describing the function's parameters.
    parameters: NotRequired[ParameterDefinition]
73+
74+
75+
@dataclass
class FunctionCallParameter(Generic[ParamType]):
    """One named, JSON-typed parameter of a function exposed through the API.

    Attributes are the parameter's ``name``, its Python ``type`` (which must be
    one of the JSON-compatible types), an optional ``description`` and an
    optional ``enum`` restricting the allowed values.
    """

    name: str
    type: Type[ParamType]
    description: Optional[str] = None
    enum: Optional[List[ParamType]] = None

    def to_dict(self) -> Dict[str, ParameterProperty]:
        """Return ``{name: property}`` where ``property`` is the schema entry.

        ``description`` and ``enum`` are only included when they are truthy.

        Raises:
            ValueError: if ``self.type`` is not a JSON-compatible type.
        """
        spec: ParameterProperty = {
            "type": self._get_json_type(self.type)
        }  # type: ignore
        # Same insertion order as the keys are declared: description, then enum.
        for key in ("description", "enum"):
            value = getattr(self, key)
            if value:
                spec[key] = value  # type: ignore
        return {self.name: spec}

    @staticmethod
    def _get_json_type(python_type: Type[JsonTypes]) -> str:
        """Translate a Python type object into its JSON schema type name.

        Raises:
            ValueError: if ``python_type`` is none of the supported types.
        """
        # Compared by identity, so e.g. ``bool`` is never mistaken for ``int``.
        json_names = (
            (int, "integer"),
            (float, "number"),
            (str, "string"),
            (bool, "boolean"),
            (dict, "object"),
            (list, "array"),
            (type(None), "null"),
            (None, "null"),
        )
        for candidate, json_name in json_names:
            if python_type is candidate:
                return json_name
        raise ValueError(
            f"Invalid type {python_type} for JSON. "
            f"Permitted types are {JsonTypes}"
        )
117+
118+
119+
@dataclass
class FunctionCall:
    """Specification of a callable function: name, description and parameters.

    ``required`` lists the names of parameters that must be supplied; names
    that do not match any entry in ``parameters`` are ignored by ``to_dict``.
    """

    name: str
    description: Optional[str] = None
    parameters: Optional[List[FunctionCallParameter[Any]]] = None
    required: Optional[List[str]] = None

    def to_dict(self) -> FunctionProperty:
        """Return the specification as a plain dictionary.

        ``description`` is included only when truthy. ``parameters`` is
        included only when at least one parameter is defined, and then always
        carries a ``required`` list (possibly empty) in parameter order.
        """
        spec: FunctionProperty = FunctionProperty(name=self.name)
        if self.description:
            spec["description"] = self.description
        if not self.parameters:
            return spec

        wanted = self.required or []
        properties = {}
        required_names = []
        for param in self.parameters:
            properties[param.name] = param.to_dict()[param.name]
            if param.name in wanted:
                required_names.append(param.name)
        spec["parameters"] = {
            "type": "object",
            "properties": properties,
            "required": required_names,
        }
        return spec
148+
149+
150+
class SchemaConverter:
    """Build a grammar (``name ::= rule`` lines) from a JSON-schema-like dict.

    Rules are accumulated in ``self._rules`` while ``visit`` walks the schema;
    ``format_grammar`` renders them. The class methods ``from_function_call``
    and ``from_function`` are convenience entry points that return the grammar
    text directly.
    """

    def __init__(self, prop_order: Dict[str, int]):
        """Create a converter.

        Args:
            prop_order: maps property names to their desired position inside
                generated objects; unlisted properties sort after listed ones.
        """
        self._prop_order = prop_order
        self._rules = {"space": SPACE_RULE}

    def _format_literal(self, literal: Any) -> str:
        """Return ``literal`` as a quoted grammar literal matching its JSON text."""
        # json.dumps emits backslash escapes (\" \\ \n \uXXXX ...). Inside a
        # grammar literal the backslash is itself the escape character, so it
        # must be doubled *before* the quote escapes below add their own
        # backslashes; otherwise the literal is malformed or matches the wrong
        # text.
        encoded = json.dumps(literal).replace("\\", "\\\\")
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or "",
            encoded,
        )
        return f'"{escaped}"'

    def _add_rule(self, name: str, rule: str) -> str:
        """Register ``rule`` under ``name`` and return the key actually used.

        Characters not allowed in rule names are replaced by ``-``. If the
        sanitized name is already taken by a *different* rule, a numeric
        suffix is appended until a free key is found.
        """
        esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
        if esc_name not in self._rules or self._rules[esc_name] == rule:
            key = esc_name
        else:
            i = 0
            while f"{esc_name}{i}" in self._rules:
                i += 1
            key = f"{esc_name}{i}"
        self._rules[key] = rule
        return key

    def visit(self, schema: Dict[SchemaKey, Any], name: str) -> str:
        """Add the rules for ``schema`` (recursively) and return its rule name.

        Args:
            schema: the schema node to convert.
            name: rule-name prefix for this node; the empty string denotes the
                root node, whose rule is named ``root``.

        Raises:
            AssertionError: if the node is not a union/const/enum/object/array
                and its ``type`` is not a known primitive.
        """
        # "type" is optional for union, const and enum nodes, so it must not
        # be looked up with [] before those branches are considered.
        schema_type: Optional[SchemaType] = schema.get("type")
        rule_name: str = name or "root"  # root rule is always named "root"
        prefix = f"{name}-" if name else ""

        if "oneOf" in schema or "anyOf" in schema:
            # Union type: one alternative per sub-schema.
            rule: str = " | ".join(
                self.visit(alt_schema, f"{prefix}{i}")
                for i, alt_schema in enumerate(
                    schema.get("oneOf") or schema["anyOf"]
                )
            )
            return self._add_rule(rule_name, rule)

        if "const" in schema:
            # A single literal value.
            return self._add_rule(rule_name, self._format_literal(schema["const"]))

        if "enum" in schema:
            # A fixed set of literal values.
            rule = " | ".join(self._format_literal(v) for v in schema["enum"])
            return self._add_rule(rule_name, rule)

        if schema_type == "object" and "properties" in schema:
            # TODO: `required` keyword
            prop_order = self._prop_order
            prop_pairs = sorted(
                schema["properties"].items(),
                # sort by position in prop_order (if specified) then by key
                key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
            )

            rule = '"{" space'
            for i, (prop_name, prop_schema) in enumerate(prop_pairs):
                prop_rule_name = self.visit(prop_schema, f"{prefix}{prop_name}")
                if i > 0:
                    rule += ' "," space'
                rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
            rule += ' "}" space'

            return self._add_rule(rule_name, rule)

        if schema_type == "array" and "items" in schema:
            # TODO `prefixItems` keyword
            item_rule_name = self.visit(schema["items"], f"{prefix}item")
            rule = (
                f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space'
            )
            return self._add_rule(rule_name, rule)

        assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
        # Primitive rules are shared under their type name, except at the root.
        return self._add_rule(
            "root" if rule_name == "root" else schema_type,
            PRIMITIVE_RULES[schema_type],
        )

    def format_grammar(self) -> str:
        """Render all collected rules, one ``name ::= rule`` entry per rule."""
        return "\n".join(f"{name} ::= {rule}" for name, rule in self._rules.items())

    @staticmethod
    def parse_function_call_from_function(func: Callable) -> FunctionCall:
        """Derive a ``FunctionCall`` specification from a Python function.

        Each parameter's annotation supplies its type. For ``Annotated``
        annotations, string metadata is concatenated into the description and
        any other iterable metadata is treated as the enum of allowed values.
        Parameters whose (unwrapped) type is not one of ``JsonTypes`` are
        skipped. A parameter is required only if it is neither ``Optional``
        nor given a default value. The function's docstring, with line breaks
        collapsed to single spaces, becomes the description.
        """
        json_types = get_args(JsonTypes)
        function_call_params: List[FunctionCallParameter] = []
        required: List[str] = []
        for name, param in signature(func).parameters.items():
            annotation = param.annotation
            description = ""
            enum: List[Any] = []

            if get_origin(annotation) is Annotated:
                # Annotated[T, *metadata]: T is the type, the rest is metadata.
                _param_args = get_args(annotation)
                _param_type = _param_args[0]

                for metadata in _param_args[1:]:
                    if isinstance(metadata, str):
                        # A string is (part of) the description.
                        description += metadata
                    elif isinstance(metadata, Iterable):
                        # Any other iterable lists the allowed values.
                        enum.extend(metadata)
            else:
                _param_type = annotation

            param_type, optional = _get_type_and_optional(_param_type)
            if param_type not in json_types:
                # Not representable in the schema: omit it entirely, including
                # from `required`.
                continue
            # A parameter with a default can be omitted by the caller, so it
            # is not required even when its type is not Optional.
            if not optional and param.default is param.empty:
                required.append(name)
            function_call_params.append(
                FunctionCallParameter(
                    name=name,
                    type=param_type,
                    description=description or None,
                    enum=enum or None,
                )
            )

        line_break_pattern = compile(r"\n\s*")
        return FunctionCall(
            name=func.__name__,
            description=line_break_pattern.sub(" ", func.__doc__)
            if func.__doc__
            else None,
            parameters=function_call_params,
            required=required or None,
        )

    @classmethod
    def from_function_call(
        cls,
        function_call: FunctionCall,
        prop_order: Optional[Dict[str, int]] = None,
    ) -> str:
        """Return the grammar text for the parameters of ``function_call``.

        Raises:
            AssertionError: if the function call defines no parameters.
        """
        self = cls(prop_order or {})
        parameters = function_call.to_dict().get("parameters")
        assert parameters is not None, "function call must have parameters"
        self.visit(dict(parameters), "")
        return self.format_grammar()

    @classmethod
    def from_function(
        cls,
        function: Callable,
        prop_order: Optional[Dict[str, int]] = None,
    ) -> str:
        """Return the grammar text for the parameters of a Python function."""
        return cls.from_function_call(
            cls.parse_function_call_from_function(function), prop_order
        )
314+
315+
316+
def _get_type_and_optional(t: Type) -> Tuple[Type, bool]:
    """Unwrap a possibly-``Optional`` annotation.

    Args:
        t: a type annotation.

    Returns:
        ``(inner, optional)``. If ``t`` is a union (``Union[...]``,
        ``Optional[...]`` or, on Python 3.10+, ``X | Y``), ``inner`` is its
        first non-``None`` member and ``optional`` tells whether ``None`` is
        one of the members. Otherwise ``(t, False)``.
    """
    import types  # local import keeps this helper self-contained

    # Optional[str] is Union[str, None]; `str | None` reports
    # types.UnionType as its origin on interpreters that provide it.
    union_origins = [Union]
    pep604_union = getattr(types, "UnionType", None)
    if pep604_union is not None:
        union_origins.append(pep604_union)

    origin = get_origin(t)
    if not any(origin is candidate for candidate in union_origins):
        # Not a union, hence not Optional.
        return t, False

    args = get_args(t)  # type: Tuple[Type, ...]
    # None among the members makes it an Optional type.
    optional = type(None) in args
    # A union always has at least one member that is not None.
    first_arg = next(arg for arg in args if arg is not type(None))
    return first_arg, optional
331+
332+
333+
if __name__ == "__main__":
    # Demo: derive a grammar from a Python function signature and use it to
    # constrain a local model's output.
    # Usage: python <this file> [path-to-model]
    import sys

    from llama_cpp import LlamaGrammar, Llama

    # NOTE(review): presumably a workaround for a memory access error when the
    # grammar object is freed (per the commit title) -- confirm before removing.
    gc.disable()

    # Define a python function and parse it into a grammar
    def get_current_weather(
        location: Annotated[
            str,
            "The location to get the current weather for",
        ],
        unit: Annotated[
            str,
            "The unit of temperature to return",
            ["fahrenheit", "celsius"],
        ],
        source: Annotated[
            str,
            "The source of the weather information",
            ["openweathermap", "weatherapi"],
        ] = "openweathermap",
    ):
        """Get the current weather in a given location"""

    # The first command-line argument overrides the built-in default path.
    default_model_path = "C:/Users/sdml/Desktop/orca-mini-3b.ggmlv3.q4_0.bin"
    model_path = sys.argv[1] if len(sys.argv) > 1 else default_model_path
    grammar: str = SchemaConverter.from_function(get_current_weather)
    llama_grammar = LlamaGrammar.from_string(grammar, verbose=False)
    llm = Llama(model_path)
    llm.grammar = llama_grammar
    for city in ("London", "Paris", "New York", "Berlin", "Tokyo"):
        print(llm(prompt=f"### User: What is the weather in {city} today? ### Assistant:")["choices"][0]["text"])  # type: ignore

    # Output:
    # { "location": "London", "source": "openweathermap","unit" : "celsius"}

0 commit comments

Comments
 (0)