Commit ec55576

Author: Varun Puri

Set up OpenAI API Python Dataclasses

1 parent 9d1a86c, commit ec55576

File tree

1 file changed: +218, -0 lines

api/api.py

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
import time
import uuid
from abc import ABC
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from build.utils import device_sync

from generate import Generator, GeneratorArgs


@dataclass
class AbstractMessageType(ABC):
    """Base class for all chat message types; `role` identifies the sender."""

    role: str
    content: Optional[str] = None


@dataclass
class SystemMessageType(AbstractMessageType):
    role: str = "system"
    name: Optional[str] = None


@dataclass
class UserMessageType(AbstractMessageType):
    role: str = "user"


@dataclass
class ToolCall:
    @dataclass
    class ToolCallFunction:
        name: str
        arguments: str

    id: str
    type: str
    function: ToolCallFunction


@dataclass
class AssistantMessageType(AbstractMessageType):
    role: str = "assistant"
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None


@dataclass
class ToolMessage(AbstractMessageType):
    role: str = "tool"
    tool_call_id: Optional[str] = None


@dataclass
class CompletionRequest:
    model: str
    prompt: str
    messages: Optional[List[AbstractMessageType]]
    frequency_penalty: float = 0.0
    temperature: float = 0.0
    stop: Optional[List[str]] = None
    echo: bool = False
    guided_decode_json_schema: Optional[str] = None
    guided_decode_json_schema_path: Optional[str] = None
    n: int = 1
    presence_penalty: float = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    max_tokens: Optional[int] = None


@dataclass
class CompletionChoice:
    finish_reason: str
    index: int
    message: AssistantMessageType
    logprobs: Optional[List[Any]]


@dataclass
class UsageStats:
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int


@dataclass
class CompletionResponse:
    id: str
    choices: List[CompletionChoice]
    created: int
    model: str
    system_fingerprint: str
    usage: UsageStats
    object: str = "chat.completion"
    service_tier: Optional[str] = None


@dataclass
class ChunkDelta:
    """Incremental content carried by a single streamed chunk."""

    tool_calls: Optional[List[ToolCall]]
    role: Optional[str]
    content: Optional[str]


@dataclass
class CompletionChoiceChunk:
    delta: ChunkDelta
    index: int
    finish_reason: Optional[str] = None
    logprobs: Optional[List[Any]] = None


@dataclass
class CompletionResponseChunk:
    id: str
    choices: List[CompletionChoiceChunk]
    created: int
    model: str
    system_fingerprint: str
    object: str = "chat.completion.chunk"
    service_tier: Optional[str] = None
    usage: Optional[UsageStats] = None


class OpenAIAPIGenerator(Generator):
    """Generator subclass that streams OpenAI-style chat-completion chunks."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_pos = 0
        # Reserve extra positions for speculative decoding when a draft model is set.
        self.max_seq_length = (
            self.model.config.max_seq_length
            + self.speculative_builder_args.speculate_k
            + 1
            if self.draft_model is not None
            else self.model.config.max_seq_length
        )

    def completion(self, completion_request: CompletionRequest):
        """Yield a CompletionResponseChunk for each generated token batch."""
        device_sync(device=self.builder_args.device)

        response_id = str(uuid.uuid4())  # shared by every chunk in this stream
        idx = 0
        buffer = []
        encoded = self.encode_tokens(
            completion_request.prompt, bos=True, device=self.builder_args.device
        )
        generator_args = GeneratorArgs(
            completion_request.prompt,
            encoded_prompt=encoded,
            chat_mode=False,
        )

        def callback(x, *, done_generating=False):
            return self._callback(
                x,
                buffer=buffer,
                done_generating=done_generating,
            )

        for y, _ in self.generate(
            self.model,
            encoded,
            generator_args.max_new_tokens,
            draft_model=self.draft_model,
            speculate_k=generator_args.speculate_k,
            chat_mode=generator_args.chat_mode,
            callback=callback,
            temperature=generator_args.temperature,
            top_k=generator_args.top_k,
            sequential_prefill=generator_args.sequential_prefill,
            start_pos=self.start_pos,
            max_seq_length=self.max_seq_length,
        ):
            # Decode with a known prefix token (".") and strip it afterward, so
            # the tokenizer preserves any leading whitespace in the output text.
            content = "".join(
                self.tokenizer.decode([self.tokenizer.encode(".")[0]] + y.tolist())[1:]
            )
            chunk_delta = ChunkDelta(
                role="assistant",
                content=content,
                tool_calls=None,
            )
            choice_chunk = CompletionChoiceChunk(
                delta=chunk_delta,
                index=idx,
            )
            chunk_response = CompletionResponseChunk(
                id=response_id,
                choices=[choice_chunk],
                created=int(time.time()),
                model=completion_request.model,
                # Per-host fingerprint derived from the MAC address.
                system_fingerprint=str(uuid.UUID(int=uuid.getnode())),
            )
            yield chunk_response
            self.start_pos += y.size(0)
            idx += 1

        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")

        yield CompletionResponseChunk(
            id=response_id,
            choices=[end_chunk],
            created=int(time.time()),
            model=completion_request.model,
            system_fingerprint=str(uuid.UUID(int=uuid.getnode())),
        )

    def _callback(self, x, *, buffer, done_generating):
        period_id = self.tokenizer.encode(".")[0]
        buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
        if (
            self.is_llama3_model
            and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
        ):
            buffer.pop()  # drop the eot_id from the output buffer, in place
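
A minimal sketch of how these dataclasses might be driven end to end. It assumes an already-constructed OpenAIAPIGenerator (its constructor arguments come from the surrounding build/generate machinery and are elided here); the model name and sampling values are illustrative only.

import json
from dataclasses import asdict

# Placeholder construction: OpenAIAPIGenerator takes the same arguments as the
# base Generator, which are not shown in this commit.
gen = OpenAIAPIGenerator(...)

request = CompletionRequest(
    model="llama3",          # illustrative model name
    prompt="Hello, world!",
    messages=None,
    temperature=0.8,
    max_tokens=64,
)

# completion() is a Python generator: it yields one CompletionResponseChunk per
# token batch, then a final chunk whose choice has finish_reason="eos".
for chunk in gen.completion(request):
    choice = chunk.choices[0]
    if choice.finish_reason is not None:
        break
    print(choice.delta.content, end="", flush=True)

# Every type above is a plain dataclass, so a chunk converts to the nested
# dict shape of the OpenAI wire format with dataclasses.asdict:
payload = json.dumps(asdict(chunk))

Because the types are plain dataclasses rather than, say, pydantic models, serialization stays a one-liner via asdict while the request and response shapes still mirror the OpenAI chat-completions schema.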
