
Commit 770218c

Add seed, temperature, max_tokens and system_fingerprint parameters to request/response
1 parent e139ad9 commit 770218c

File tree: 4 files changed, +23 −14 lines


README.md

Lines changed: 6 additions & 2 deletions
@@ -200,6 +200,10 @@ streamlit run torchchat.py -- browser llama3.1
 <details>
 <summary>This mode gives a REST API that matches the OpenAI API spec for interacting with a model</summary>
 
+The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.
+Since this feature is under active development, it's possible not every parameter is consumed. See api/api.py for details on
+which request parameters are implemented. If you encounter any issues, please comment on the [tracking Github issue](https://github.com/pytorch/torchchat/issues/973).
+
 To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.
 
 In one terminal, start the server
@@ -213,8 +217,7 @@ python3 torchchat.py server llama3.1
 
 In another terminal, query the server using `curl`. Depending on the model configuration, this query might take a few minutes to respond.
 
-Setting `stream` to "true" in the request emits a response in chunks. Currently, this response
-is plaintext and will not be formatted to the OpenAI API specification. If `stream` is unset or not "true", then the client will await the full response from the server.
+Setting `stream` to "true" in the request emits a response in chunks. If `stream` is unset or not "true", then the client will await the full response from the server.
 
 
 **Example Input + Output**
@@ -227,6 +230,7 @@ curl http://127.0.0.1:5000/v1/chat \
 -d '{
   "model": "llama3.1",
   "stream": "true",
+  "max_tokens": 200,
   "messages": [
     {
       "role": "system",

api/api.py

Lines changed: 11 additions & 8 deletions
@@ -105,20 +105,20 @@ class CompletionRequest:
     logit_bias: Optional[Dict[str, float]] = None  # unimplemented
     logprobs: Optional[bool] = None  # unimplemented
     top_logprobs: Optional[int] = None  # unimplemented
-    max_tokens: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: float = 0  # unimplemented
     response_format: Optional[ResponseFormat] = None  # unimplemented
-    seed: Optional[int] = None  # unimplemented
+    seed: Optional[int] = None
     service_tier: Optional[str] = None  # unimplemented
     stop: Optional[List[str]] = None  # unimplemented
     stream: bool = False
     stream_options: Optional[StreamOptions] = None  # unimplemented
-    temperature: Optional[float] = 1.0  # unimplemented
+    temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0  # unimplemented
-    tools: Optional[List[Any]] = None  # unimplemented
-    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
-    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented - Assistant features
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented - Assistant features
+    parallel_tool_calls: Optional[bool] = None  # unimplemented - Assistant features
     user: Optional[str] = None  # unimplemented

@@ -229,9 +229,8 @@ def __init__(self, *args, **kwargs):
             else self.model.config.max_seq_length
         )
         # The System fingerprint is a unique identifier for the model and its configuration.
-        # Currently, this is not implemented in a
         self.system_fingerprint = (
-            self.builder_args.device + type(self.builder_args.precision).__name__
+            self.builder_args.device + "_" + str(self.builder_args.precision)
         )

     def chunked_completion(self, completion_request: CompletionRequest):
@@ -270,7 +269,11 @@ def chunked_completion(self, completion_request: CompletionRequest):
         )
         generator_args = GeneratorArgs(
             completion_request.messages[-1].get("content"),
+            max_new_tokens=(
+                completion_request.max_tokens if completion_request.max_tokens else 16
+            ),
             encoded_prompt=encoded,
+            temperature=completion_request.temperature,
             chat_mode=False,
         )
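
To make the new behavior concrete, here is a small illustrative sketch (simplified stand-ins, not code from the repository) of how the now-implemented request fields flow into generation after this commit: `max_tokens` becomes `max_new_tokens` with a fallback of 16 when the client omits it, `temperature` is passed through to `GeneratorArgs`, and the system fingerprint is the device name joined to the precision string.

```python
# Illustrative sketch only; Request stands in for CompletionRequest.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Request:
    max_tokens: Optional[int] = None
    temperature: Optional[float] = 1.0


def to_generator_kwargs(req: Request) -> dict:
    """Mirror the mapping added in chunked_completion()."""
    return {
        # max_tokens falls back to 16 when unset
        "max_new_tokens": req.max_tokens if req.max_tokens else 16,
        "temperature": req.temperature,
    }


def system_fingerprint(device: str, precision) -> str:
    """Mirror the fingerprint format: '<device>_<precision>'."""
    return device + "_" + str(precision)


print(to_generator_kwargs(Request(max_tokens=200, temperature=0.7)))
# {'max_new_tokens': 200, 'temperature': 0.7}
print(system_fingerprint("cuda", "torch.bfloat16"))
# cuda_torch.bfloat16
```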

generate.py

Lines changed: 2 additions & 4 deletions
@@ -71,7 +71,7 @@ class GeneratorArgs:
     num_samples: int = 1
     max_new_tokens: int = 200
     top_k: int = 200
-    temperature: int = 0  # deterministic argmax
+    temperature: float = 0  # deterministic argmax if 0.0
     compile: bool = False
     compile_prefill: bool = False
     speculate_k: int = 5
@@ -105,9 +105,7 @@ def validate_build(
     def from_args(cls, args):
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
-        sequential_prefill = (
-            args.sequential_prefill or bool(dso_path) or bool(pte_path)
-        )
+        sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path)

         return cls(
             prompt=getattr(args, "prompt", ""),
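
As a reminder of what the `temperature` field controls, here is a sketch under the usual sampling convention (not the repository's exact sampling code): logits are divided by the temperature before the softmax, and a temperature of 0 is treated as greedy argmax decoding, matching the "deterministic argmax if 0.0" comment.

```python
# Minimal sketch of temperature-scaled sampling; assumes standard PyTorch ops,
# not the exact implementation in generate.py.
import torch


def pick_next_token(logits: torch.Tensor, temperature: float) -> int:
    if temperature == 0.0:
        # deterministic argmax when temperature is 0
        return int(torch.argmax(logits))
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))


logits = torch.tensor([2.0, 1.0, 0.5])
print(pick_next_token(logits, temperature=0.0))  # always 0
print(pick_next_token(logits, temperature=0.8))  # sampled; varies run to run
```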

server.py

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,8 @@
 from dataclasses import asdict
 from typing import Dict, List, Union
 
+import torch
+
 from api.api import CompletionRequest, OpenAiApiGenerator
 from api.models import get_model_info_list, retrieve_model_info
 
@@ -50,6 +52,8 @@ def chat_endpoint():
     """

     print(" === Completion Request ===")
+    if seed := request.args.get("seed"):
+        torch.manual_seed(int(seed))

     # Parse the request in to a CompletionRequest object
     data = request.get_json()
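
A note on how the seed is supplied: server.py reads `seed` from the Flask query string (`request.args`) rather than from the JSON body, and seeds PyTorch's global RNG before generation. The hedged client sketch below (URL and payload values mirror the README example; the seed value is arbitrary) shows the intended effect: two requests with the same seed and a non-zero temperature should produce matching samples, barring backend nondeterminism.

```python
# Hypothetical client usage of the ?seed= query parameter.
import requests

body = {
    "model": "llama3.1",
    "max_tokens": 50,
    "temperature": 0.8,
    "messages": [{"role": "user", "content": "Tell me a joke."}],
}

# Same seed twice -> the server calls torch.manual_seed(42) before each
# generation, so the two responses should match.
for _ in range(2):
    r = requests.post("http://127.0.0.1:5000/v1/chat?seed=42", json=body)
    print(r.text)
```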
