Skip to content

Commit 1dfb1a8

Browse files
committed
bump
1 parent 74b858f commit 1dfb1a8

File tree

5 files changed

+161
-1
lines changed

5 files changed

+161
-1
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
Using SGLang and Dynamo to serve embedding models!
18+
"""
19+
20+
import asyncio
21+
import logging
22+
import random
23+
import socket
24+
from typing import Any
25+
26+
import sglang as sgl
27+
from utils.protocol import EmbeddingRequest
28+
from utils.sglang import parse_sglang_args
29+
30+
from dynamo.llm import ModelType, register_llm
31+
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
32+
33+
logger = logging.getLogger(__name__)
34+
35+
36+
@service(
37+
dynamo={
38+
"namespace": "dynamo",
39+
},
40+
resources={"gpu": 1},
41+
workers=1,
42+
)
43+
class SGLangEmbeddingWorker:
44+
45+
def __init__(self):
46+
class_name = self.__class__.__name__
47+
self.engine_args = parse_sglang_args(class_name, "")
48+
self.engine = sgl.Engine(server_args=self.engine_args)
49+
50+
logger.info("SGLangEmbeddingWorker initialized")
51+
52+
@async_on_start
53+
async def async_init(self):
54+
runtime = dynamo_context["runtime"]
55+
logger.info("Registering LLM for discovery")
56+
comp_ns, comp_name = SGLangEmbeddingWorker.dynamo_address() # type: ignore
57+
endpoint = runtime.namespace(comp_ns).component(comp_name).endpoint("generate")
58+
await register_llm(
59+
ModelType.Embedding,
60+
endpoint,
61+
self.engine_args.model_path,
62+
self.engine_args.served_model_name,
63+
)
64+
65+
@endpoint()
66+
async def generate(self, request: EmbeddingRequest):
67+
if isinstance(request.input, str):
68+
input = request.input
69+
elif isinstance(request.input, list):
70+
input = [i for i in request.input]
71+
else:
72+
raise ValueError(f"Invalid input type: {type(request.input)}")
73+
74+
g = await self.engine.async_encode(
75+
prompt=input,
76+
)
77+
78+
# Transform response to match OpenAI embedding format
79+
response = self._transform_response(g, request.model)
80+
yield response
81+
82+
def _transform_response(self, ret, model_name):
83+
"""Transform SGLang response to OpenAI embedding format"""
84+
if not isinstance(ret, list):
85+
ret = [ret]
86+
87+
embedding_objects = []
88+
prompt_tokens = 0
89+
90+
for idx, ret_item in enumerate(ret):
91+
embedding_objects.append({
92+
"object": "embedding",
93+
"embedding": ret_item["embedding"],
94+
"index": idx,
95+
})
96+
prompt_tokens += ret_item["meta_info"]["prompt_tokens"]
97+
98+
return {
99+
"object": "list",
100+
"data": embedding_objects,
101+
"model": model_name,
102+
"usage": {
103+
"prompt_tokens": prompt_tokens,
104+
"total_tokens": prompt_tokens,
105+
},
106+
}

examples/sglang/components/frontend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pathlib import Path
1919

2020
from components.worker import SGLangWorker
21+
from components.embedding_worker import SGLangEmbeddingWorker
2122
from fastapi import FastAPI
2223
from pydantic import BaseModel
2324

@@ -57,6 +58,7 @@ class FrontendConfig(BaseModel):
5758
)
5859
class Frontend:
5960
worker = depends(SGLangWorker)
61+
embedding_worker = depends(SGLangEmbeddingWorker)
6062

6163
def __init__(self):
6264
"""Initialize Frontend service with HTTP server and model configuration."""
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Frontend:
2+
served_model_name: e5
3+
endpoint: SGLangEmbeddingWorker.generate
4+
port: 8000
5+
SGLangEmbeddingWorker:
6+
model-path: intfloat/e5-base-v2
7+
served-model-name: e5
8+
is-embedding: true
9+
tp: 1
10+
trust-remote-code: true
11+
is-embedding: true
12+
json-model-override-args: '{"get_embedding": true, "chat_template": ""}'
13+
ServiceArgs:
14+
workers: 1
15+
resources:
16+
gpu: 1

examples/sglang/graphs/embedding.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from components.frontend import Frontend
18+
from components.embedding_worker import SGLangEmbeddingWorker
19+
20+
Frontend.link(SGLangEmbeddingWorker)

examples/sglang/utils/protocol.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from typing import List, Optional
16+
from typing import List, Optional, Union, Literal
1717

1818
from pydantic import BaseModel, Field
1919

@@ -60,3 +60,19 @@ class DisaggPreprocessedRequest(BaseModel):
6060
bootstrap_host: str
6161
bootstrap_port: int
6262
bootstrap_room: int
63+
64+
EmbeddingInput = Union[
65+
str,
66+
List[str],
67+
List[int],
68+
List[List[int]]
69+
]
70+
71+
EncodingFormat = Literal["float", "base64"]
72+
73+
class EmbeddingRequest(BaseModel):
74+
model: str
75+
input: EmbeddingInput
76+
encoding_format: Optional[EncodingFormat] = None
77+
user: Optional[str] = None
78+
dimensions: Optional[int] = None # only supported in text-embedding-3 and later models from OpenAI

0 commit comments

Comments
 (0)