Commit 7de4d37

Merge branch 'main' into add_torchchat_folder

2 parents: ae0961e + 3ce1cef

4 files changed: +28 -14 lines changed

README.md

Lines changed: 6 additions & 2 deletions

@@ -200,6 +200,10 @@ streamlit run torchchat.py -- browser llama3.1
 <details>
 <summary>This mode gives a REST API that matches the OpenAI API spec for interacting with a model</summary>
 
+The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.
+Since this feature is under active development, not every parameter is consumed. See api/api.py for details on
+which request parameters are implemented. If you encounter any issues, please comment on the [tracking GitHub issue](https://github.com/pytorch/torchchat/issues/973).
+
 To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.
 
 In one terminal, start the server
@@ -213,8 +217,7 @@ python3 torchchat.py server llama3.1
 
 In another terminal, query the server using `curl`. Depending on the model configuration, this query might take a few minutes to respond.
 
-Setting `stream` to "true" in the request emits a response in chunks. Currently, this response
-is plaintext and will not be formatted to the OpenAI API specification. If `stream` is unset or not "true", then the client will await the full response from the server.
+Setting `stream` to "true" in the request emits a response in chunks. If `stream` is unset or not "true", then the client will await the full response from the server.
 
 
 **Example Input + Output**
@@ -227,6 +230,7 @@ curl http://127.0.0.1:5000/v1/chat \
 -d '{
   "model": "llama3.1",
   "stream": "true",
+  "max_tokens": 200,
   "messages": [
     {
       "role": "system",

api/api.py

Lines changed: 16 additions & 8 deletions

@@ -19,6 +19,8 @@
 See https://platform.openai.com/docs/api-reference/chat for the full specification and details.
 """
 
+OPENAI_API_DEFAULT_MAX_TOKENS = 16
+
 # Message classes and associated objects - see the types of Messages under "Create Chat Completion >>> Request body >>> messages"
 
 
@@ -105,20 +107,20 @@ class CompletionRequest:
     logit_bias: Optional[Dict[str, float]] = None  # unimplemented
     logprobs: Optional[bool] = None  # unimplemented
     top_logprobs: Optional[int] = None  # unimplemented
-    max_tokens: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: float = 0  # unimplemented
     response_format: Optional[ResponseFormat] = None  # unimplemented
-    seed: Optional[int] = None  # unimplemented
+    seed: Optional[int] = None
     service_tier: Optional[str] = None  # unimplemented
     stop: Optional[List[str]] = None  # unimplemented
     stream: bool = False
     stream_options: Optional[StreamOptions] = None  # unimplemented
-    temperature: Optional[float] = 1.0  # unimplemented
+    temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0  # unimplemented
-    tools: Optional[List[Any]] = None  # unimplemented
-    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
-    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented - Assistant features
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented - Assistant features
+    parallel_tool_calls: Optional[bool] = None  # unimplemented - Assistant features
     user: Optional[str] = None  # unimplemented
 
 
@@ -229,9 +231,8 @@ def __init__(self, *args, **kwargs):
             else self.model.config.max_seq_length
         )
         # The System fingerprint is a unique identifier for the model and its configuration.
-        # Currently, this is not implemented in a
         self.system_fingerprint = (
-            self.builder_args.device + type(self.builder_args.precision).__name__
+            f"{self.builder_args.device}_{self.builder_args.precision}"
         )
 
     def chunked_completion(self, completion_request: CompletionRequest):
@@ -270,7 +271,13 @@ def chunked_completion(self, completion_request: CompletionRequest):
         )
         generator_args = GeneratorArgs(
             completion_request.messages[-1].get("content"),
+            max_new_tokens=(
+                int(completion_request.max_tokens)
+                if completion_request.max_tokens
+                else OPENAI_API_DEFAULT_MAX_TOKENS
+            ),
             encoded_prompt=encoded,
+            temperature=float(completion_request.temperature),
             chat_mode=False,
         )
 
@@ -295,6 +302,7 @@ def callback(x, *, done_generating=False):
             sequential_prefill=generator_args.sequential_prefill,
             start_pos=self.start_pos,
             max_seq_length=self.max_seq_length,
+            seed=int(completion_request.seed),
         ):
             if y is None:
                 continue
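The `max_tokens` plumbing above reduces to a single rule: honor the client's value when one is supplied, otherwise fall back to the 16-token default the OpenAI spec uses. A standalone sketch of that defaulting logic, with a stripped-down stand-in for the real `CompletionRequest`:

```python
from dataclasses import dataclass
from typing import Optional

# Matches the default added in api/api.py above.
OPENAI_API_DEFAULT_MAX_TOKENS = 16


@dataclass
class CompletionRequest:
    """Stripped-down stand-in for the real request class."""
    max_tokens: Optional[int] = None


def resolve_max_new_tokens(req: CompletionRequest) -> int:
    # Same shape as the expression passed to GeneratorArgs in the diff:
    # honor an explicit client value, otherwise use the spec default.
    return int(req.max_tokens) if req.max_tokens else OPENAI_API_DEFAULT_MAX_TOKENS


assert resolve_max_new_tokens(CompletionRequest()) == 16
assert resolve_max_new_tokens(CompletionRequest(max_tokens=200)) == 200
```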

generate.py

Lines changed: 2 additions & 4 deletions

@@ -71,7 +71,7 @@ class GeneratorArgs:
     num_samples: int = 1
     max_new_tokens: int = 200
     top_k: int = 200
-    temperature: int = 0  # deterministic argmax
+    temperature: float = 0.0  # deterministic argmax if 0.0
     compile: bool = False
     compile_prefill: bool = False
     speculate_k: int = 5
@@ -105,9 +105,7 @@ def validate_build(
     def from_args(cls, args):
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
-        sequential_prefill = (
-            args.sequential_prefill or bool(dso_path) or bool(pte_path)
-        )
+        sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path)
 
         return cls(
             prompt=getattr(args, "prompt", ""),
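The updated comment ("deterministic argmax if 0.0") describes the common sampling convention: a temperature of exactly 0.0 means greedy argmax decoding, while positive values scale the logits before softmax sampling. A minimal illustration of that convention (a sketch, not torchchat's actual sampling code):

```python
import torch


def sample_next_token(logits: torch.Tensor, temperature: float = 0.0) -> int:
    """Pick the next token id from a 1-D logits tensor."""
    if temperature == 0.0:
        # Deterministic argmax: always take the single most likely token.
        return int(torch.argmax(logits))
    # Otherwise scale the logits and sample from the softmax distribution;
    # higher temperatures flatten the distribution and add randomness.
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))


logits = torch.tensor([0.1, 2.0, 0.5])
assert sample_next_token(logits, temperature=0.0) == 1  # always the argmax
```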

server.py

Lines changed: 4 additions & 0 deletions

@@ -9,6 +9,8 @@
 from dataclasses import asdict
 from typing import Dict, List, Union
 
+import torch
+
 from api.api import CompletionRequest, OpenAiApiGenerator
 from api.models import get_model_info_list, retrieve_model_info
 
@@ -50,6 +52,8 @@ def chat_endpoint():
     """
 
     print(" === Completion Request ===")
+    if seed := request.args.get("seed"):
+        torch.manual_seed(int(seed))
 
     # Parse the request in to a CompletionRequest object
     data = request.get_json()
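With the `seed` handling above, a client can pin PyTorch's global RNG per request via a query parameter, making sampled output reproducible. A hypothetical usage sketch, again assuming a local server on port 5000 and the third-party `requests` library:

```python
import requests  # third-party; an assumption, not a torchchat dependency

payload = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "Tell me a joke."}],
}

# The server calls torch.manual_seed(seed) before generating, so two
# requests with the same seed should sample the same tokens.
for _ in range(2):
    resp = requests.post(
        "http://127.0.0.1:5000/v1/chat", params={"seed": 42}, json=payload
    )
    print(resp.text)
```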
