Skip to content

Commit b0081ed

Browse files
authored
Set up OpenAI API Python Dataclasses (#907)
Introduces basic Python dataclasses to represent requests, responses, and associated objects defined in the OpenAI API specification. This will enable creating a basic server that follows the spec so users can leverage TorchChat to run LLMs on their own hardware with a familiar interface to existing cloud tools. **Testing** Lintrunner ``` lintrunner Warning: Could not find a lintrunner config at: '.lintrunner.private.toml'. Continuing without using configuration file. >>> Lint for generate.py: Advice (FLAKE8) C901 'Generator.chat' is too complex (32) See https://www.flake8rules.com/rules/C901.html. 581 | buffer.clear() 582 | # print(, end='', flush=True) 583 | >>> 584 | def chat( 585 | self, 586 | generator_args: GeneratorArgs, 587 | ): ``` (advice from prior commit) Dataclasses are used and tested further in PR #908
1 parent ab85b2a commit b0081ed

File tree

1 file changed

+303
-0
lines changed

1 file changed

+303
-0
lines changed

api/api.py

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import time
8+
import uuid
9+
from abc import ABC
10+
from dataclasses import dataclass
11+
from typing import Any, Dict, List, Optional
12+
13+
from build.utils import device_sync
14+
15+
from generate import Generator, GeneratorArgs
16+
17+
"""Dataclasses defined around the objects used the OpenAI API Chat specification.
18+
19+
See https://platform.openai.com/docs/api-reference/chat for the full specification and details.
20+
"""
21+
22+
# Message classes and associated objects - see the types of Messages under "Create Chat Completion >>> Request body >>> messages"
23+
24+
25+
@dataclass
class _AbstractMessage(ABC):
    """Base class with common parameters for message types.

    Each message type is associated with a role (one of "system", "user",
    "assistant" or "tool") and contains an optional content field.

    See more details at https://platform.openai.com/docs/guides/text-generation/chat-completions-api .
    """

    # Author tag for the message; concrete subclasses pin a default value.
    role: str
    # Optional text payload of the message.
    content: Optional[str] = None
37+
38+
39+
@dataclass
class SystemMessage(_AbstractMessage):
    """A message authored by the system (role "system").

    See the message types under "Create Chat Completion >>> Request body >>> messages"
    in the OpenAI API docs.
    """

    role: str = "system"
    # Optional author name; presumably used to disambiguate participants — see the OpenAI spec.
    name: Optional[str] = None
43+
44+
45+
@dataclass
class UserMessage(_AbstractMessage):
    """A message authored by the end user (role "user")."""

    role: str = "user"
48+
49+
50+
@dataclass
class ToolMessage:
    """A message produced by a tool in response to a tool call (role "tool").

    NOTE(review): unlike the other message classes this does not extend
    _AbstractMessage and carries no `content` field — confirm this is intentional
    against the OpenAI spec.
    """

    # Identifier of the tool call this message responds to.
    tool_call_id: str
    # Shadows the builtin `type`; kept to match the API field name.
    type: str
    role: str = "tool"
55+
56+
57+
@dataclass
class ToolCallFunction:
    """The function referenced by a tool call: its name and its arguments."""

    # Name of the function to invoke.
    name: str
    # Arguments string; presumably JSON-encoded per the OpenAI spec — confirm.
    arguments: str
61+
62+
63+
@dataclass
class ToolCall:
    """A single tool invocation requested by the model."""

    # Shadows the builtin `id`; kept to match the API field name.
    id: str
    type: str
    # The function (name + arguments) being invoked.
    function: ToolCallFunction
68+
69+
70+
@dataclass
class AssistantMessage(_AbstractMessage):
    """A message authored by the model (role "assistant"), possibly carrying tool calls."""

    role: str = "assistant"
    name: Optional[str] = None
    # Tool calls requested by the assistant, if any.
    tool_calls: Optional[List[ToolCall]] = None
75+
76+
77+
# Completion request and response types.
78+
79+
80+
@dataclass
class CompletionRequest:
    """A full chat completion request.

    See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
    """

    model: str  # ID of the model to run
    prompt: str  # raw prompt text to complete
    messages: Optional[List[_AbstractMessage]]  # conversation history, if any
    # NOTE: frequency_penalty was previously declared twice; the duplicate
    # annotation is removed. Position and default are unchanged because a
    # repeated annotation only overwrote the first entry in __annotations__.
    frequency_penalty: float = 0.0  # penalize tokens by existing frequency
    temperature: float = 0.0  # sampling temperature
    stop: Optional[List[str]] = None  # sequences that terminate generation
    echo: bool = False  # whether to echo the prompt back in the response
    guided_decode_json_schema: Optional[str] = None  # inline JSON schema for guided decoding (was annotated `str` with a None default)
    guided_decode_json_schema_path: Optional[str] = None  # path to a JSON schema file (same annotation fix)
    n: int = 1  # number of choices to generate
    presence_penalty: float = 0.0  # penalize tokens already present
    logit_bias: Optional[Dict[str, float]] = None  # per-token logit adjustments
    logprobs: Optional[bool] = None  # whether to return log probabilities
    top_logprobs: Optional[int] = None  # number of top logprobs per position
    max_tokens: Optional[int] = None  # cap on generated tokens
103+
104+
105+
@dataclass
class CompletionChoice:
    """A single choice in a chat completion response.

    See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
    """

    # Why generation stopped; see the spec for the set of values.
    finish_reason: str
    # Index of this choice within the response.
    index: int
    # The generated assistant message.
    message: AssistantMessage
    # Log-probability info, if requested; element structure is not pinned down here.
    logprobs: Optional[List[Any]]
116+
117+
118+
@dataclass
class UsageStats:
    """Token usage statistics for a completion response.

    (Previous docstring wrongly described this as "a single choice".)

    See the "The chat completion object >>> usage" section of the OpenAI API docs for more details.
    """

    completion_tokens: int  # tokens generated in the completion
    prompt_tokens: int  # tokens consumed by the prompt
    total_tokens: int  # presumably prompt_tokens + completion_tokens — per the spec
128+
129+
130+
@dataclass
class CompletionResponse:
    """A full chat completion response.

    See the "The chat completion object" section of the OpenAI API docs for more details.
    """

    id: str  # unique response identifier; shadows the builtin, kept to match the API field name
    choices: List[CompletionChoice]  # generated choices
    created: int  # creation time as a Unix timestamp in seconds
    model: str  # model that produced the response
    system_fingerprint: str  # identifies the backend configuration serving the request
    usage: UsageStats  # token accounting for this response
    object: str = "chat.completion"  # constant object-type tag per the spec
    service_tier: Optional[str] = None
145+
146+
147+
@dataclass
class ChunkDelta:
    """Changes between the previous chunk emitted for a chunked completion response.

    See the "The chat completion chunk object >>> choices >>> delta" section of the OpenAI API docs for more details.
    """

    # New tool calls introduced by this chunk, if any.
    tool_calls: Optional[List[ToolCall]]
    # Role of the message being streamed (None on the terminal chunk).
    role: Optional[str]
    # Text appended by this chunk (None on the terminal chunk).
    content: Optional[str]
157+
158+
159+
@dataclass
class CompletionChoiceChunk:
    """A single choice in a chat completion chunk response.

    See the "The chat completion chunk object >>> choices" section of the OpenAI API docs for more details.
    """

    delta: ChunkDelta  # incremental change carried by this chunk
    index: int  # index of the choice this chunk belongs to
    finish_reason: Optional[str] = None  # set only on the terminal chunk ("eos" in this codebase)
    logprobs: Optional[List[Any]] = None  # log-probability info, if requested
170+
171+
172+
@dataclass
class CompletionResponseChunk:
    """Response chunk emitted during a chunked completion response.

    See the "The chat completion chunk object" section of the OpenAI API docs for more details.
    """

    id: str  # identifier shared by all chunks of one response; shadows the builtin, kept to match the API field name
    choices: List[CompletionChoiceChunk]  # choice deltas carried by this chunk
    created: int  # creation time as a Unix timestamp in seconds
    model: str  # model producing the stream
    system_fingerprint: str  # identifies the backend configuration serving the request
    object: str = "chat.completion.chunk"  # constant object-type tag per the spec
    service_tier: Optional[str] = None
    usage: Optional[UsageStats] = None  # token accounting, when included in the stream
187+
188+
189+
class OpenAiApiGenerator(Generator):
    """A wrapper over the Generator class to interface with the OpenAI API.

    Implements endpoints for completion requests, both chunked and non-chunked,
    using the dataclasses defined above.
    """

    def __init__(self, *args, **kwargs):
        """Initialize generator and parameters for maintaining context during generation.

        See the docstring for the Generator class in generate.py for argument details.
        """

        super().__init__(*args, **kwargs)
        # Running position in the generation context, advanced across calls to completion().
        self.start_pos = 0
        # When speculative decoding is active, reserve speculate_k + 1 extra
        # positions for the draft model's lookahead tokens.
        self.max_seq_length = (
            self.model.config.max_seq_length
            + self.speculative_builder_args.speculate_k
            + 1
            if self.draft_model is not None
            else self.model.config.max_seq_length
        )

    def completion(self, completion_request: CompletionRequest):
        """Handle a chat completion request and yield a chunked response.

        Args:
            completion_request: Request object with prompt and other parameters.

        Yields:
            CompletionResponseChunk objects in response to completion_request as tokens are generated.

        """
        device_sync(device=self.builder_args.device)

        # Initialize counters for chunk responses and encode the prompt.
        # (Renamed from `id` to avoid shadowing the builtin.)
        response_id = str(uuid.uuid4())
        idx = 0
        buffer = []
        encoded = self.encode_tokens(
            completion_request.prompt, bos=True, device=self.builder_args.device
        )
        generator_args = GeneratorArgs(
            completion_request.prompt,
            encoded_prompt=encoded,
            chat_mode=False,
        )

        def callback(x, *, done_generating=False):
            return self._callback(
                x,
                buffer=buffer,
                done_generating=done_generating,
            )

        # Host-derived fingerprint, converted once: the dataclass declares
        # system_fingerprint as str, but a uuid.UUID was previously passed.
        system_fingerprint = str(uuid.UUID(int=uuid.getnode()))

        # Process each token, metrics tuple yielded by Generator.generate.
        for y, _ in self.generate(
            self.model,
            encoded,
            generator_args.max_new_tokens,
            draft_model=self.draft_model,
            speculate_k=generator_args.speculate_k,
            chat_mode=generator_args.chat_mode,
            callback=callback,
            temperature=generator_args.temperature,
            top_k=generator_args.top_k,
            sequential_prefill=generator_args.sequential_prefill,
            start_pos=self.start_pos,
            max_seq_length=self.max_seq_length,
        ):
            # Decode the torch.Tensor token to a string. A period token is
            # prepended before decoding and its text sliced off afterwards —
            # presumably to preserve the piece's leading whitespace; confirm
            # against the tokenizer's behavior.
            content = "".join(
                self.tokenizer.decode([self.tokenizer.encode(".")[0]] + y.tolist())[1:]
            )

            # Package the sequence into a CompletionResponseChunk and yield it.
            chunk_delta = ChunkDelta(
                role="assistant",
                content=content,
                tool_calls=None,
            )
            choice_chunk = CompletionChoiceChunk(
                delta=chunk_delta,
                index=idx,
            )
            chunk_response = CompletionResponseChunk(
                id=response_id,
                choices=[choice_chunk],
                created=int(time.time()),
                model=completion_request.model,
                system_fingerprint=system_fingerprint,
            )
            yield chunk_response
            self.start_pos += y.size(0)
            idx += 1

        # Yield an ending chunk indicating the generation has completed.
        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")

        yield CompletionResponseChunk(
            id=response_id,
            choices=[end_chunk],
            created=int(time.time()),
            model=completion_request.model,
            system_fingerprint=system_fingerprint,
        )

    def _callback(self, x, *, buffer, done_generating):
        """Decode token tensor x and append its text to the shared output buffer.

        The llama3 end-of-turn marker is removed from the buffer so it never
        reaches user-visible output.
        """
        period_id = self.tokenizer.encode(".")[0]
        buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
        if (
            self.is_llama3_model
            and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
        ):
            # BUGFIX: must mutate the caller's list in place. The previous
            # `buffer = buffer[:-1]` rebound the local name to a copy and left
            # the eot_id text in the shared buffer.
            del buffer[-1]  # drop the eot_id from the output buffer

0 commit comments

Comments
 (0)