
Commit 24b9327

vmpuridbort authored and committed
Set up OpenAI API Python Dataclasses
Introduces basic Python dataclasses to represent requests, responses, and associated objects defined in the OpenAI API specification. This will enable creating a basic server that follows the spec so users can leverage TorchChat to run LLMs on their own hardware with a familiar interface to existing cloud tools.

**Testing**

Lintrunner:

```
lintrunner
Warning: Could not find a lintrunner config at: '.lintrunner.private.toml'. Continuing without using configuration file.

>>> Lint for generate.py:

  Advice (FLAKE8) C901
    'Generator.chat' is too complex (32)
    See https://www.flake8rules.com/rules/C901.html.

        581  |          buffer.clear()
        582  |          # print(, end='', flush=True)
        583  |
    >>> 584  |      def chat(
        585  |          self,
        586  |          generator_args: GeneratorArgs,
        587  |      ):
```

(advice from a prior commit)

Dataclasses are used and tested further in PR #908.
1 parent 4a275c5 commit 24b9327

File tree

1 file changed (+303 −0 lines)

api/api.py

Lines changed: 303 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import time
import uuid
from abc import ABC
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from build.utils import device_sync

from generate import Generator, GeneratorArgs

"""Dataclasses defined around the objects used in the OpenAI API Chat specification.

See https://platform.openai.com/docs/api-reference/chat for the full specification and details.
"""

# Message classes and associated objects - see the types of Messages under
# "Create Chat Completion >>> Request body >>> messages".


@dataclass
class _AbstractMessage(ABC):
    """Base class with common parameters for message types.

    Each message type is associated with a role (one of "system", "user",
    "assistant" or "tool") and contains an optional content field.

    See more details at https://platform.openai.com/docs/guides/text-generation/chat-completions-api .
    """

    role: str
    content: Optional[str] = None


@dataclass
class SystemMessage(_AbstractMessage):
    role: str = "system"
    name: Optional[str] = None


@dataclass
class UserMessage(_AbstractMessage):
    role: str = "user"


@dataclass
class ToolMessage:
    tool_call_id: str
    type: str
    role: str = "tool"


@dataclass
class ToolCallFunction:
    name: str
    arguments: str


@dataclass
class ToolCall:
    id: str
    type: str
    function: ToolCallFunction


@dataclass
class AssistantMessage(_AbstractMessage):
    role: str = "assistant"
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
```
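As a quick illustration (hypothetical usage, not part of this commit), the message classes compose into a conversation simply by sharing the `role`/`content` interface, assuming the definitions above are in scope:

```python
# Hypothetical example: composing a conversation from the message
# dataclasses above (assumes they are importable from api/api.py).
system = SystemMessage(content="You are a helpful assistant.")
user = UserMessage(content="What is the capital of France?")
assistant = AssistantMessage(content="The capital of France is Paris.")

# Each class pins its own role, so a conversation is just an ordered list.
conversation = [system, user, assistant]
assert [m.role for m in conversation] == ["system", "user", "assistant"]
```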
```python
# Completion request and response types.


@dataclass
class CompletionRequest:
    """A full chat completion request.

    See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
    """

    model: str
    prompt: str
    messages: Optional[List[_AbstractMessage]]
    frequency_penalty: float = 0.0
    temperature: float = 0.0
    stop: Optional[List[str]] = None
    echo: bool = False
    guided_decode_json_schema: Optional[str] = None
    guided_decode_json_schema_path: Optional[str] = None
    n: int = 1
    presence_penalty: float = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    max_tokens: Optional[int] = None


@dataclass
class CompletionChoice:
    """A single choice in a chat completion response.

    See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
    """

    finish_reason: str
    index: int
    message: AssistantMessage
    logprobs: Optional[List[Any]]


@dataclass
class UsageStats:
    """Token usage statistics for a chat completion response.

    See the "The chat completion object >>> usage" section of the OpenAI API docs for more details.
    """

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int


@dataclass
class CompletionResponse:
    """A full chat completion response.

    See the "The chat completion object" section of the OpenAI API docs for more details.
    """

    id: str
    choices: List[CompletionChoice]
    created: int
    model: str
    system_fingerprint: str
    usage: UsageStats
    object: str = "chat.completion"
    service_tier: Optional[str] = None


@dataclass
class ChunkDelta:
    """Changes relative to the previous chunk emitted during a chunked completion response.

    See the "The chat completion chunk object >>> choices >>> delta" section of the OpenAI API docs for more details.
    """

    tool_calls: Optional[List[ToolCall]]
    role: Optional[str]
    content: Optional[str]


@dataclass
class CompletionChoiceChunk:
    """A single choice in a chat completion chunk response.

    See the "The chat completion chunk object >>> choices" section of the OpenAI API docs for more details.
    """

    delta: ChunkDelta
    index: int
    finish_reason: Optional[str] = None
    logprobs: Optional[List[Any]] = None


@dataclass
class CompletionResponseChunk:
    """Response chunk emitted during a chunked completion response.

    See the "The chat completion chunk object" section of the OpenAI API docs for more details.
    """

    id: str
    choices: List[CompletionChoiceChunk]
    created: int
    model: str
    system_fingerprint: str
    object: str = "chat.completion.chunk"
    service_tier: Optional[str] = None
    usage: Optional[UsageStats] = None
```
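Because an OpenAI-compatible server ultimately speaks JSON, a natural way to exercise these response dataclasses is to round-trip them through `dataclasses.asdict`. A minimal sketch, assuming the definitions above are in scope (field values are invented for illustration):

```python
import json
import time
from dataclasses import asdict

# Hypothetical example: serializing a CompletionResponse to the JSON shape
# described by the OpenAI API spec. All values here are made up.
response = CompletionResponse(
    id="chatcmpl-123",
    choices=[
        CompletionChoice(
            finish_reason="stop",
            index=0,
            message=AssistantMessage(content="Paris."),
            logprobs=None,
        )
    ],
    created=int(time.time()),
    model="llama3",
    system_fingerprint="fp_demo",
    usage=UsageStats(completion_tokens=2, prompt_tokens=8, total_tokens=10),
)

# asdict recurses into nested dataclasses and lists of dataclasses,
# yielding a plain dict ready for json.dumps.
print(json.dumps(asdict(response), indent=2))
```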
```python
class OpenAiApiGenerator(Generator):
    """A wrapper over the Generator class to interface with the OpenAI API.

    Implements endpoints for completion requests, both chunked and non-chunked,
    using the dataclasses defined above.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the generator and the parameters for maintaining context during generation.

        See the docstring for the Generator class in generate.py for argument details.
        """
        super().__init__(*args, **kwargs)
        self.start_pos = 0
        self.max_seq_length = (
            self.model.config.max_seq_length
            + self.speculative_builder_args.speculate_k
            + 1
            if self.draft_model is not None
            else self.model.config.max_seq_length
        )

    def completion(self, completion_request: CompletionRequest):
        """Handle a chat completion request and yield a chunked response.

        Args:
            completion_request: Request object with prompt and other parameters.

        Yields:
            CompletionResponseChunk objects in response to completion_request as tokens are generated.
        """
        device_sync(device=self.builder_args.device)

        # Initialize counters for chunk responses and encode the prompt.
        id = str(uuid.uuid4())
        idx = 0
        buffer = []
        encoded = self.encode_tokens(
            completion_request.prompt, bos=True, device=self.builder_args.device
        )
        generator_args = GeneratorArgs(
            completion_request.prompt,
            encoded_prompt=encoded,
            chat_mode=False,
        )

        def callback(x, *, done_generating=False):
            return self._callback(
                x,
                buffer=buffer,
                done_generating=done_generating,
            )

        # Process each (token, metrics) tuple yielded by Generator.generate.
        for y, _ in self.generate(
            self.model,
            encoded,
            generator_args.max_new_tokens,
            draft_model=self.draft_model,
            speculate_k=generator_args.speculate_k,
            chat_mode=generator_args.chat_mode,
            callback=callback,
            temperature=generator_args.temperature,
            top_k=generator_args.top_k,
            sequential_prefill=generator_args.sequential_prefill,
            start_pos=self.start_pos,
            max_seq_length=self.max_seq_length,
        ):
            # Decode the torch.Tensor token to a string and append it to the
            # buffer. Separate the sequences with a period token.
            content = "".join(
                self.tokenizer.decode([self.tokenizer.encode(".")[0]] + y.tolist())[1:]
            )

            # Package the sequence into a CompletionResponseChunk and yield it.
            chunk_delta = ChunkDelta(
                role="assistant",
                content=content,
                tool_calls=None,
            )
            choice_chunk = CompletionChoiceChunk(
                delta=chunk_delta,
                index=idx,
            )
            chunk_response = CompletionResponseChunk(
                id=id,
                choices=[choice_chunk],
                created=int(time.time()),
                model=completion_request.model,
                system_fingerprint=str(uuid.UUID(int=uuid.getnode())),
            )
            yield chunk_response
            self.start_pos += y.size(0)
            idx += 1

        # Yield an ending chunk indicating the generation has completed.
        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")

        yield CompletionResponseChunk(
            id=id,
            choices=[end_chunk],
            created=int(time.time()),
            model=completion_request.model,
            system_fingerprint=str(uuid.UUID(int=uuid.getnode())),
        )

    def _callback(self, x, *, buffer, done_generating):
        period_id = self.tokenizer.encode(".")[0]
        buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
        if (
            self.is_llama3_model
            and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
        ):
            buffer.pop()  # drop the eot_id from the output buffer
```
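Finally, a sketch of how a caller might consume the chunked stream from `OpenAiApiGenerator.completion`. This is hypothetical usage, not part of this commit; it assumes `gen` is an already-constructed `OpenAiApiGenerator` (construction takes the usual Generator arguments from generate.py):

```python
# Hypothetical example: streaming the reply to stdout.
# `gen` is assumed to be an already-constructed OpenAiApiGenerator.
request = CompletionRequest(
    model="llama3",
    prompt="Tell me a short story about a robot.",
    messages=None,
)

for chunk in gen.completion(request):
    delta = chunk.choices[0].delta
    # The final chunk carries an empty delta and finish_reason="eos".
    if delta.content is not None:
        print(delta.content, end="", flush=True)
print()
```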
