
Openai api completion params [seed, temperature, max_tokens, system_fingerprint] #1016


Merged · 1 commit · Aug 7, 2024
README.md: 8 changes (6 additions, 2 deletions)
@@ -200,6 +200,10 @@ streamlit run torchchat.py -- browser llama3.1
<details>
<summary>This mode gives a REST API that matches the OpenAI API spec for interacting with a model</summary>

+ The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.
+ Since this feature is under active development, it's possible that not every parameter is consumed. See api/api.py for details on
+ which request parameters are implemented. If you encounter any issues, please comment on the [tracking GitHub issue](https://github.com/pytorch/torchchat/issues/973).
+
To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.

In one terminal, start the server
@@ -213,8 +217,7 @@ python3 torchchat.py server llama3.1

In another terminal, query the server using `curl`. Depending on the model configuration, this query might take a few minutes to respond.

- Setting `stream` to "true" in the request emits a response in chunks. Currently, this response
- is plaintext and will not be formatted to the OpenAI API specification. If `stream` is unset or not "true", then the client will await the full response from the server.
+ Setting `stream` to "true" in the request emits a response in chunks. If `stream` is unset or not "true", then the client will await the full response from the server.


**Example Input + Output**
@@ -227,6 +230,7 @@ curl http://127.0.0.1:5000/v1/chat \
-d '{
"model": "llama3.1",
"stream": "true",
"max_tokens": 200,
"messages": [
{
"role": "system",
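For an end-to-end test of the newly wired parameters, here is a minimal Python client sketch equivalent to the curl example above. It assumes the server from the README is running locally and that the `requests` package is available; since the endpoint is still being aligned to the OpenAI spec, the exact response shape is not guaranteed.

```python
# Minimal client sketch; assumes a local server started with
# `python3 torchchat.py server llama3.1` (see README above).
import requests

payload = {
    "model": "llama3.1",
    "stream": False,      # "true" streams chunked output instead
    "max_tokens": 200,    # if omitted, the server falls back to its default
    "temperature": 0.0,   # 0.0 requests deterministic argmax decoding
    "seed": 42,           # fixed seed for reproducible sampling
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name two primary colors."},
    ],
}

# Generation can take minutes depending on the model configuration.
resp = requests.post("http://127.0.0.1:5000/v1/chat", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json())
```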
api/api.py: 24 changes (16 additions, 8 deletions)
@@ -19,6 +19,8 @@
See https://platform.openai.com/docs/api-reference/chat for the full specification and details.
"""

+ OPENAI_API_DEFAULT_MAX_TOKENS = 16
+
# Message classes and associated objects - see the types of Messages under "Create Chat Completion >>> Request body >>> messages"


@@ -105,20 +107,20 @@ class CompletionRequest:
logit_bias: Optional[Dict[str, float]] = None # unimplemented
logprobs: Optional[bool] = None # unimplemented
top_logprobs: Optional[int] = None # unimplemented
- max_tokens: Optional[int] = None # unimplemented
+ max_tokens: Optional[int] = None
n: int = 1
presence_penalty: float = 0 # unimplemented
response_format: Optional[ResponseFormat] = None # unimplemented
- seed: Optional[int] = None # unimplemented
+ seed: Optional[int] = None
service_tier: Optional[str] = None # unimplemented
stop: Optional[List[str]] = None # unimplemented
stream: bool = False
stream_options: Optional[StreamOptions] = None # unimplemented
- temperature: Optional[float] = 1.0 # unimplemented
+ temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0 # unimplemented
- tools: Optional[List[Any]] = None # unimplemented
- tool_choice: Optional[Union[str, Any]] = None # unimplemented
- parallel_tool_calls: Optional[bool] = None # unimplemented
+ tools: Optional[List[Any]] = None # unimplemented - Assistant features
+ tool_choice: Optional[Union[str, Any]] = None # unimplemented - Assistant features
+ parallel_tool_calls: Optional[bool] = None # unimplemented - Assistant features
user: Optional[str] = None # unimplemented


@@ -229,9 +231,8 @@ def __init__(self, *args, **kwargs):
else self.model.config.max_seq_length
)
# The System fingerprint is a unique identifier for the model and its configuration.
- # Currently, this is not implemented in a
self.system_fingerprint = (
- self.builder_args.device + type(self.builder_args.precision).__name__
+ f"{self.builder_args.device}_{self.builder_args.precision}"
)

def chunked_completion(self, completion_request: CompletionRequest):
@@ -270,7 +271,13 @@ def chunked_completion(self, completion_request: CompletionRequest):
)
generator_args = GeneratorArgs(
completion_request.messages[-1].get("content"),
+ max_new_tokens=(
+     int(completion_request.max_tokens)
+     if completion_request.max_tokens
+     else OPENAI_API_DEFAULT_MAX_TOKENS
+ ),
encoded_prompt=encoded,
+ temperature=float(completion_request.temperature),
chat_mode=False,
)

@@ -295,6 +302,7 @@ def callback(x, *, done_generating=False):
sequential_prefill=generator_args.sequential_prefill,
start_pos=self.start_pos,
max_seq_length=self.max_seq_length,
+ seed=int(completion_request.seed) if completion_request.seed else None,  # guard the optional seed; int(None) would raise
):
if y is None:
continue
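To make the new defaults concrete, this standalone sketch replays the fallback and fingerprint logic from the hunks above. The function name and the device/precision values are illustrative assumptions; only the expressions themselves come from the diff.

```python
from typing import Optional

OPENAI_API_DEFAULT_MAX_TOKENS = 16  # module-level default introduced above

def resolve_max_new_tokens(max_tokens: Optional[int]) -> int:
    # Mirrors the max_new_tokens expression in chunked_completion:
    # an explicit client value wins, otherwise use the default.
    return int(max_tokens) if max_tokens else OPENAI_API_DEFAULT_MAX_TOKENS

assert resolve_max_new_tokens(None) == 16
assert resolve_max_new_tokens(200) == 200

# The system fingerprint is now f"{device}_{precision}"; with made-up
# values it renders like this:
device, precision = "cuda", "torch.float16"
print(f"{device}_{precision}")  # -> cuda_torch.float16
```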
generate.py: 6 changes (2 additions, 4 deletions)
@@ -71,7 +71,7 @@ class GeneratorArgs:
num_samples: int = 1
max_new_tokens: int = 200
top_k: int = 200
- temperature: int = 0 # deterministic argmax
+ temperature: float = 0.0 # deterministic argmax if 0.0
compile: bool = False
compile_prefill: bool = False
speculate_k: int = 5
@@ -105,9 +105,7 @@ def validate_build(
def from_args(cls, args):
dso_path = getattr(args, "dso_path", None)
pte_path = getattr(args, "pte_path", None)
- sequential_prefill = (
-     args.sequential_prefill or bool(dso_path) or bool(pte_path)
- )
+ sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path)

return cls(
prompt=getattr(args, "prompt", ""),
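The `temperature` field's switch from `int` to `float` matches the usual sampling convention: 0.0 short-circuits to argmax, while positive values scale the logits before softmax sampling. The sketch below shows that standard technique; torchchat's actual sampler is outside this diff, so treat it as an illustration rather than the project's implementation.

```python
import torch

def sample_token(logits: torch.Tensor, temperature: float = 0.0) -> int:
    # temperature == 0.0: deterministic argmax, per the field comment above.
    if temperature == 0.0:
        return int(torch.argmax(logits))
    # Otherwise scale logits by 1/temperature and sample from the softmax.
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))

logits = torch.tensor([2.0, 1.0, 0.5])
print(sample_token(logits))       # always 0
print(sample_token(logits, 1.0))  # stochastic
```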
server.py: 4 changes (4 additions, 0 deletions)
@@ -9,6 +9,8 @@
from dataclasses import asdict
from typing import Dict, List, Union

+ import torch
+
from api.api import CompletionRequest, OpenAiApiGenerator
from api.models import get_model_info_list, retrieve_model_info

@@ -50,6 +52,8 @@ def chat_endpoint():
"""

print(" === Completion Request ===")
+ if seed := request.args.get("seed"):
+     torch.manual_seed(int(seed))

# Parse the request in to a CompletionRequest object
data = request.get_json()
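Note that `chat_endpoint` reads the seed from the query string (`request.args`), not from the JSON body, and seeds PyTorch's global RNG. The sketch below demonstrates the reproducibility guarantee `torch.manual_seed` provides, which is what this change relies on.

```python
import torch

def draw() -> torch.Tensor:
    # Stand-in for stochastic decoding: one sample from a softmax.
    return torch.multinomial(torch.softmax(torch.randn(8), dim=-1), num_samples=1)

torch.manual_seed(42)
a = draw()
torch.manual_seed(42)
b = draw()
assert torch.equal(a, b)  # same seed -> same sample
```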