Skip to content

Commit 129829a

Browse files
committed
Add seed, temperature, max_tokens and system_fingerprint parameters to request/response
1 parent e139ad9 commit 129829a

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

api/api.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,20 @@ class CompletionRequest:
105105
logit_bias: Optional[Dict[str, float]] = None # unimplemented
106106
logprobs: Optional[bool] = None # unimplemented
107107
top_logprobs: Optional[int] = None # unimplemented
108-
max_tokens: Optional[int] = None # unimplemented
108+
max_tokens: Optional[int] = None
109109
n: int = 1
110110
presence_penalty: float = 0 # unimplemented
111111
response_format: Optional[ResponseFormat] = None # unimplemented
112-
seed: Optional[int] = None # unimplemented
112+
seed: Optional[int] = None
113113
service_tier: Optional[str] = None # unimplemented
114114
stop: Optional[List[str]] = None # unimplemented
115115
stream: bool = False
116116
stream_options: Optional[StreamOptions] = None # unimplemented
117-
temperature: Optional[float] = 1.0 # unimplemented
117+
temperature: Optional[float] = 1.0
118118
top_p: Optional[float] = 1.0 # unimplemented
119-
tools: Optional[List[Any]] = None # unimplemented
120-
tool_choice: Optional[Union[str, Any]] = None # unimplemented
121-
parallel_tool_calls: Optional[bool] = None # unimplemented
119+
tools: Optional[List[Any]] = None # unimplemented - Assistant features
120+
tool_choice: Optional[Union[str, Any]] = None # unimplemented - Assistant features
121+
parallel_tool_calls: Optional[bool] = None # unimplemented - Assistant features
122122
user: Optional[str] = None # unimplemented
123123

124124

@@ -229,9 +229,8 @@ def __init__(self, *args, **kwargs):
229229
else self.model.config.max_seq_length
230230
)
231231
# The System fingerprint is a unique identifier for the model and its configuration.
232-
# Currently, this is not implemented in a
233232
self.system_fingerprint = (
234-
self.builder_args.device + type(self.builder_args.precision).__name__
233+
self.builder_args.device + "_" + str(self.builder_args.precision)
235234
)
236235

237236
def chunked_completion(self, completion_request: CompletionRequest):
@@ -270,7 +269,11 @@ def chunked_completion(self, completion_request: CompletionRequest):
270269
)
271270
generator_args = GeneratorArgs(
272271
completion_request.messages[-1].get("content"),
272+
max_new_tokens=(
273+
completion_request.max_tokens if completion_request.max_tokens else 16
274+
),
273275
encoded_prompt=encoded,
276+
temperature=completion_request.temperature,
274277
chat_mode=False,
275278
)
276279

generate.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class GeneratorArgs:
7171
num_samples: int = 1
7272
max_new_tokens: int = 200
7373
top_k: int = 200
74-
temperature: int = 0 # deterministic argmax
74+
temperature: float = 0 # deterministic argmax if 0.0
7575
compile: bool = False
7676
compile_prefill: bool = False
7777
speculate_k: int = 5
@@ -105,9 +105,7 @@ def validate_build(
105105
def from_args(cls, args):
106106
dso_path = getattr(args, "dso_path", None)
107107
pte_path = getattr(args, "pte_path", None)
108-
sequential_prefill = (
109-
args.sequential_prefill or bool(dso_path) or bool(pte_path)
110-
)
108+
sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path)
111109

112110
return cls(
113111
prompt=getattr(args, "prompt", ""),

server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from dataclasses import asdict
1010
from typing import Dict, List, Union
1111

12+
import torch
13+
1214
from api.api import CompletionRequest, OpenAiApiGenerator
1315
from api.models import get_model_info_list, retrieve_model_info
1416

@@ -50,6 +52,8 @@ def chat_endpoint():
5052
"""
5153

5254
print(" === Completion Request ===")
55+
if seed := request.args.get("seed"):
56+
torch.manual_seed(int(seed))
5357

5458
# Parse the request in to a CompletionRequest object
5559
data = request.get_json()

0 commit comments

Comments
 (0)