
Commit e924a7c

hhzhang16 and krishung5 authored
feat: generalize VLM embedding extraction (#1388)
Signed-off-by: hhzhang16 <[email protected]> Co-authored-by: Kris Hung <[email protected]>
1 parent c43ebd2 commit e924a7c

File tree

12 files changed: +314, -53 lines


examples/multimodal/README.md

Lines changed: 21 additions & 5 deletions
@@ -18,7 +18,6 @@ limitations under the License.
 # Multimodal Deployment Examples
 
 This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo.
-The examples are based on the [llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model.
 
 ## Use the Latest Release
 
@@ -59,11 +58,15 @@ flowchart LR
   decode_worker --image_url--> encode_worker
   encode_worker --embeddings--> decode_worker
 ```
-```
 
 ```bash
 cd $DYNAMO_HOME/examples/multimodal
-dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
+# Serve a LLaVA 1.5 7B model:
+dynamo serve graphs.agg:Frontend -f ./configs/agg-llava.yaml
+# Serve a Qwen2.5-VL model:
+# dynamo serve graphs.agg:Frontend -f ./configs/agg-qwen.yaml
+# Serve a Phi3V model:
+# dynamo serve graphs.agg:Frontend -f ./configs/agg-phi3v.yaml
 ```
 
 ### Client
@@ -92,10 +95,13 @@ curl http://localhost:8000/v1/chat/completions \
       }
     ],
     "max_tokens": 300,
+    "temperature": 0.0,
     "stream": false
   }'
 ```
 
+If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
+
 You should see a response similar to this:
 ```json
 {"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
@@ -162,6 +168,7 @@ curl http://localhost:8000/v1/chat/completions \
       }
     ],
     "max_tokens": 300,
+    "temperature": 0.0,
     "stream": false
   }'
 ```
@@ -171,6 +178,8 @@ You should see a response similar to this:
 {"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
 ```
 
+***Note***: disaggregation is currently only confirmed to work with LLaVA. Qwen VL and PhiV are not confirmed to be supported.
+
 ## Deployment with Dynamo Operator
 
 These multimodal examples can be deployed to a Kubernetes cluster using [Dynamo Cloud](../../docs/guides/dynamo_deploy/dynamo_cloud.md) and the Dynamo CLI.
@@ -206,8 +215,12 @@ DYNAMO_TAG=$(dynamo build graphs.agg:Frontend | grep "Successfully built" | awk
 
 # Deploy to Kubernetes
 export DEPLOYMENT_NAME=multimodal-agg
-# For aggregated serving:
-dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg.yaml
+# For aggregated serving with LLaVA:
+dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-llava.yaml
+# For aggregated serving with Qwen2.5-VL:
+# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-qwen.yaml
+# For aggregated serving with Phi3V:
+# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-phi3v.yaml
 # For disaggregated serving:
 # export DEPLOYMENT_NAME=multimodal-disagg
 # dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/disagg.yaml
@@ -244,8 +257,11 @@ curl localhost:8000/v1/chat/completions \
       }
     ],
     "max_tokens": 300,
+    "temperature": 0.0,
    "stream": false
   }'
 ```
 
+If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
+
 For more details on managing deployments, testing, and troubleshooting, please refer to the [Operator Deployment Guide](../../docs/guides/dynamo_deploy/operator_deployment.md).
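
For reference, the same request can be issued from Python instead of curl. The sketch below targets the OpenAI-compatible endpoint shown in the README; the image URL and prompt are placeholders, and the message structure follows the standard chat-completions format rather than anything confirmed by this diff.

```python
# Minimal Python client for the chat-completions endpoint above. Swap the model name to
# match the served config (e.g. Qwen/Qwen2.5-VL-7B-Instruct or
# microsoft/Phi-3.5-vision-instruct); the image URL is only a placeholder.
import requests

payload = {
    "model": "llava-hf/llava-1.5-7b-hf",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "http://example.com/bus.jpg"}},
            ],
        }
    ],
    "max_tokens": 300,
    "temperature": 0.0,
    "stream": False,
}

response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(response.json()["choices"][0]["message"]["content"])
```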

examples/multimodal/components/decode_worker.py

Lines changed: 17 additions & 20 deletions
@@ -24,8 +24,8 @@
 from components.disagg_router import PyDisaggregatedRouter
 from components.encode_worker import VllmEncodeWorker
 from components.prefill_worker import VllmPrefillWorker
-from transformers import LlavaForConditionalGeneration
 from utils.logging import check_required_workers
+from utils.model import construct_mm_data, get_vision_embeddings_info
 from utils.nixl import NixlMetadataStore
 from utils.prefill_queue import PrefillQueue
 from utils.protocol import (
@@ -117,6 +117,11 @@ async def async_init(self):
         )
 
         runtime = dynamo_context["runtime"]
+        embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
+            self.engine_args.model, self.engine_args.num_patches
+        )
+        logger.debug(f"Embeddings shape: {embeddings_shape}")
+        self.embedding_size = embeddings_shape[1]
 
         if self.do_remote_prefill:
             metadata = self.engine_client.nixl_metadata
@@ -133,18 +138,7 @@ async def async_init(self):
                 await self.disaggregated_router.async_init()
             else:
                 self.disaggregated_router = None
-
-            model = LlavaForConditionalGeneration.from_pretrained(
-                self.engine_args.model,
-                device_map="auto",
-                torch_dtype=torch.bfloat16,
-            ).eval()
-            vision_tower = model.vision_tower
-            self.embedding_size = (
-                vision_tower.vision_model.embeddings.position_embedding.num_embeddings
-            )
         else:
-            EMBEDDINGS_SHAPE = (1, 577, 4096)
             EMBEDDINGS_DTYPE = torch.float16
             EMBEDDINGS_DEVICE = "cuda"
 
@@ -161,7 +155,7 @@ async def async_init(self):
 
             # Create a longer-lived buffer for receiving the image embeddings.
             embeddings = torch.empty(
-                EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
+                embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
             )
             descriptor = connect.Descriptor(embeddings)
             # Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
@@ -206,13 +200,15 @@ async def generate(self, request: vLLMMultimodalRequest):
                 multi_modal_data,
                 remote_prefill_params,
             ) = await self.remote_prefill(request)
-
         else:
             (
                 prompt_ids,
                 multi_modal_data,
                 remote_prefill_params,
             ) = await self.local_prefill(request)
+        logger.debug(f"Prompt ids: {prompt_ids}")
+        logger.debug(f"Multi modal data: {multi_modal_data}")
+        logger.debug(f"Remote prefill params: {remote_prefill_params}")
 
         # rust HTTP requires Delta streaming
         request.sampling_params.output_kind = RequestOutputKind.DELTA
@@ -227,7 +223,7 @@ async def generate(self, request: vLLMMultimodalRequest):
             remote_prefill_params=remote_prefill_params,
         ):
             logger.debug(
-                f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
+                f"Yielding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
             )
             yield MyRequestOutput(
                 request_id=response.request_id,
@@ -294,7 +290,9 @@ async def local_prefill(self, request: vLLMMultimodalRequest) -> tuple:
                 "Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
             )
             # When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
-            multi_modal_data = {"image": embeddings}
+            multi_modal_data = construct_mm_data(
+                self.engine_args.model, encode_output, embeddings, self.embeddings_dtype
+            )
 
         return prompt_ids, multi_modal_data, remote_prefill_params
 
@@ -353,17 +351,16 @@ async def remote_prefill(self, request: vLLMMultimodalRequest) -> tuple:
         # As a workaround, here we manually insert some placeholder dummy tokens based on the embedding size
         # so that decode worker can pre-allocate the memory with the correct size.
         # The structure of the prompt will be like: "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:".
-        # Since the "<image>" token is included in the prompt, only need to insert (embedding_size - 1) dummy tokens after the image token.
-        IMAGE_TOKEN_ID = 32000
+        # Since the "<image>" token is included in the prompt, only need to insert embedding_size dummy tokens after the image token.
         DUMMY_TOKEN_ID = 0
         # Find the index of the image token in the prompt token ids
         image_token_index = request.engine_prompt["prompt_token_ids"].index(
-            IMAGE_TOKEN_ID
+            self.engine_args.image_token_id
         )
         dummy_token_index = image_token_index + 1
         prompt_ids = (
             request.engine_prompt["prompt_token_ids"][:dummy_token_index]
-            + [DUMMY_TOKEN_ID] * (self.embedding_size - 1)
+            + [DUMMY_TOKEN_ID] * self.embedding_size
            + request.engine_prompt["prompt_token_ids"][dummy_token_index:]
         )
         logger.debug(
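
The `get_vision_embeddings_info` and `construct_mm_data` helpers imported above live in `utils/model.py`, which is not part of this diff. A rough, hypothetical sketch of what such helpers could look like follows; the config attributes read and the per-model `multi_modal_data` keys are assumptions, not the actual implementation.

```python
# Hypothetical sketch of the utils/model.py helpers referenced above (module not shown in
# this diff). Config fields and per-model multi_modal_data keys are assumptions.
from typing import Any

import torch
from transformers import AutoConfig


def get_vision_embeddings_info(
    model_id: str, num_patches: int
) -> tuple[tuple[int, int, int], torch.dtype]:
    """Shape and dtype of the buffer that receives the projected image embeddings."""
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    # After the multimodal projector, image embeddings live in the language model's
    # hidden space (4096 for llava-1.5-7b-hf).
    text_config = getattr(config, "text_config", config)
    dtype = getattr(config, "torch_dtype", None) or torch.float16
    return (1, num_patches, text_config.hidden_size), dtype


def construct_mm_data(
    model_id: str,
    encode_output: Any,
    embeddings: torch.Tensor,
    embeddings_dtype: torch.dtype,
) -> dict[str, Any]:
    """Package precomputed image embeddings as vLLM multi_modal_data for the model."""
    if "qwen" in model_id.lower():
        # Qwen2.5-VL also consumes the patch grid computed by the encode worker.
        return {
            "image": {
                "image_embeds": embeddings.to(embeddings_dtype),
                "image_grid_thw": torch.tensor(encode_output.image_grid_thw),
            }
        }
    # Other models (e.g. Phi3V) may need extra metadata such as encode_output.image_sizes;
    # LLaVA keeps the original behaviour of passing the raw embeddings tensor.
    return {"image": embeddings.to(embeddings_dtype)}
```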

examples/multimodal/components/encode_worker.py

Lines changed: 29 additions & 14 deletions
@@ -26,7 +26,8 @@
 import httpx
 import torch
 from PIL import Image
-from transformers import AutoImageProcessor, LlavaForConditionalGeneration
+from transformers import AutoImageProcessor
+from utils.model import load_vision_model
 from utils.protocol import EncodeRequest, EncodeResponse
 from utils.vllm import parse_vllm_args
 
@@ -66,10 +67,7 @@ def __init__(self) -> None:
         self.image_processor = AutoImageProcessor.from_pretrained(
             self.MODEL_ID, trust_remote_code=True
         )
-
-        self.vision_model = LlavaForConditionalGeneration.from_pretrained(
-            self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
-        ).eval()
+        self.vision_model = load_vision_model(self.MODEL_ID)
 
         self._image_cache: dict[str, Image.Image] = {}
         self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
@@ -167,17 +165,32 @@ async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
 
             logger.debug(f"Processing image for request: {{ id: {request_id} }}")
             image_embeds = self.image_processor(images=image, return_tensors="pt")
+            # Add a batch dimension to everything
+            for item in image_embeds:
+                image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
+            logger.debug(f"Image embeds: {image_embeds}")
+
+            image_grid_thw = (
+                image_embeds["image_grid_thw"].tolist()
+                if "image_grid_thw" in image_embeds
+                else None
+            )
+            image_sizes = (
+                image_embeds["image_sizes"].tolist()
+                if "image_sizes" in image_embeds
+                else [image.size]
+            )
+            logger.debug(
+                f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
+            )
 
             with torch.no_grad():
-                logger.debug(f"Vision model device: {self.vision_model.device}")
-                vision_outputs = self.vision_model.vision_tower(
-                    image_embeds["pixel_values"].to(self.vision_model.device)
-                )
-                logger.debug("Vision model completed.")
-
-                embeddings = vision_outputs.last_hidden_state
-                embeddings = self.vision_model.multi_modal_projector(embeddings)
-
+                embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
+                if isinstance(embeddings, tuple) or isinstance(embeddings, list):
+                    # The result multimodal_embeddings may be a list or tuple of tensors, with each
+                    # tensor corresponding to a multimodal data item (image or video).
+                    # TODO: for multi-image support, this result will contain multiple tensors.
+                    embeddings = embeddings[0].unsqueeze(0)
             logger.debug(
                 f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
             )
@@ -201,6 +214,8 @@ async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
 
             yield EncodeResponse(
                 request_id=request.request_id,
+                image_grid_thw=image_grid_thw,
+                image_sizes=image_sizes,
             ).model_dump_json()
         except Exception as e:
             logger.error(f"Error processing request {request_id}: {e}")
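
The conditional handling of `image_grid_thw` and `image_sizes` above exists because different image processors return different tensors. A quick standalone check is sketched below; the keys noted in the comments are expectations to verify, not guaranteed output.

```python
# Print which tensors each model's image processor emits. The expected keys in the
# comments are assumptions; run the script to confirm them for your transformers version.
from PIL import Image
from transformers import AutoImageProcessor

MODELS = [
    "llava-hf/llava-1.5-7b-hf",           # expect: pixel_values
    "Qwen/Qwen2.5-VL-7B-Instruct",        # expect: pixel_values, image_grid_thw
    "microsoft/Phi-3.5-vision-instruct",  # expect: pixel_values, image_sizes
]

image = Image.new("RGB", (448, 448))
for model_id in MODELS:
    processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
    outputs = processor(images=image, return_tensors="pt")
    print(f"{model_id}: {list(outputs.keys())}")
```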

examples/multimodal/components/prefill_worker.py

Lines changed: 16 additions & 9 deletions
@@ -25,6 +25,7 @@
 from components.encode_worker import VllmEncodeWorker
 from pydantic import BaseModel
 from utils.logging import check_required_workers
+from utils.model import construct_mm_data, get_vision_embeddings_info
 from utils.nixl import NixlMetadataStore
 from utils.prefill_queue import PrefillQueue
 from utils.protocol import EncodeRequest, EncodeResponse
@@ -39,9 +40,6 @@
 
 logger = logging.getLogger(__name__)
 
-# Constants for the shape and dtype of the embeddings tensor.
-EMBEDDINGS_SHAPE = (1, 577, 4096)
-EMBEDDINGS_DTYPE = torch.float16
 EMBEDDINGS_DEVICE = "cuda"
 
 
@@ -113,9 +111,12 @@ async def async_init(self):
         await self._connector.initialize()
 
         # Create a longer-lived buffer for receiving the image embeddings.
+        embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
+            self.engine_args.model, self.engine_args.num_patches
+        )
         embeddings = torch.empty(
-            EMBEDDINGS_SHAPE,
-            dtype=EMBEDDINGS_DTYPE,
+            embeddings_shape,
+            dtype=self.embeddings_dtype,
             device=EMBEDDINGS_DEVICE,
         )
         descriptor = connect.Descriptor(embeddings)
@@ -248,10 +249,11 @@ async def generate(self, request: RemotePrefillRequest):
         # To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
         # some placeholder dummy tokens are inserted based on the embedding size in the worker.py.
         # TODO: make this more flexible/model-dependent
-        IMAGE_TOKEN_ID = 32000
         embedding_size = embeddings.shape[1]
-        padding_size = embedding_size - 1
-        image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
+        padding_size = embedding_size
+        image_token_index = request.prompt_token_ids.index(
+            self.engine_args.image_token_id
+        )
         dummy_token_index = image_token_index + 1
         prompt_token_ids = (
             request.prompt_token_ids[:dummy_token_index]
@@ -262,7 +264,12 @@ async def generate(self, request: RemotePrefillRequest):
             request_id=request_id,
             prompt=TokensPrompt(
                 prompt_token_ids=prompt_token_ids,
-                multi_modal_data={"image": embeddings},
+                multi_modal_data=construct_mm_data(
+                    self.engine_args.model,
+                    encode_output,
+                    embeddings,
+                    self.embeddings_dtype,
+                ),
             ),
             sampling_params=sampling_params,
             remote_prefill_params=remote_prefill_params,
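
The placeholder-padding change above (pad `embedding_size` dummy tokens instead of `embedding_size - 1`, and read the image token id from config) is easier to see with toy numbers; in the illustration below every token id except 32000 is made up.

```python
# Toy illustration of the dummy-token padding used by the prefill and decode workers:
# embedding_size placeholder tokens are inserted right after the <image> token so the
# decode worker can pre-allocate KV-cache space for the image embeddings.
IMAGE_TOKEN_ID = 32000  # LLaVA's <image> token, configured as image-token-id
DUMMY_TOKEN_ID = 0
embedding_size = 4      # stands in for the real patch count (576 for LLaVA)

prompt_token_ids = [1, 3148, IMAGE_TOKEN_ID, 29871, 2]
image_token_index = prompt_token_ids.index(IMAGE_TOKEN_ID)
dummy_token_index = image_token_index + 1
padded = (
    prompt_token_ids[:dummy_token_index]
    + [DUMMY_TOKEN_ID] * embedding_size
    + prompt_token_ids[dummy_token_index:]
)
print(padded)  # [1, 3148, 32000, 0, 0, 0, 0, 29871, 2]
```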

examples/multimodal/components/processor.py

Lines changed: 14 additions & 3 deletions
@@ -188,9 +188,19 @@ async def _generate_responses(
     # The generate endpoint will be used by the frontend to handle incoming requests.
     @endpoint()
     async def generate(self, raw_request: MultiModalRequest):
-        prompt = str(self.engine_args.prompt_template).replace(
-            "<prompt>", raw_request.messages[0].content[0].text
-        )
+        # Ensure the configured template includes the placeholder
+        template = self.engine_args.prompt_template
+        if "<prompt>" not in template:
+            raise ValueError("prompt_template must contain '<prompt>' placeholder")
+
+        # Safely extract user text
+        try:
+            user_text = raw_request.messages[0].content[0].text
+        except (IndexError, AttributeError) as e:
+            raise ValueError(f"Invalid message structure: {e}")
+
+        prompt = template.replace("<prompt>", user_text)
+
         msg = {
             "role": "user",
             "content": prompt,
@@ -201,6 +211,7 @@ async def generate(self, raw_request: MultiModalRequest):
             messages=[msg],
             stream=raw_request.stream,
             max_tokens=raw_request.max_tokens,
+            temperature=raw_request.temperature,
             request_id=str(uuid.uuid4()),
         )
         image_url = None
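
The template handling added above reduces to a placeholder substitution plus validation. A self-contained stand-in is shown below; the template string is only an illustrative LLaVA-style example, not the value actually configured via `engine_args.prompt_template`.

```python
# Stand-in for the prompt_template handling added in processor.py. The template string
# here is illustrative; the deployed value comes from engine_args.prompt_template.
def build_prompt(template: str, user_text: str) -> str:
    if "<prompt>" not in template:
        raise ValueError("prompt_template must contain '<prompt>' placeholder")
    return template.replace("<prompt>", user_text)


print(build_prompt("USER: <image>\n<prompt> ASSISTANT:", "What is in this image?"))
# USER: <image>
# What is in this image? ASSISTANT:
```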

examples/multimodal/configs/agg.yaml renamed to examples/multimodal/configs/agg-llava.yaml

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@ VllmDecodeWorker:
   enforce-eager: true
   max-num-batched-tokens: 16384
   enable-prefix-caching: true
+  image-token-id: 32000
+  num-patches: 576
   router: random
   tensor-parallel-size: 1
   ServiceArgs:
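
Both new values are properties of the LLaVA checkpoint: `image-token-id: 32000` is the id of LLaVA's `<image>` token, and `num-patches: 576` is the vision tower's patch count for the CLIP ViT-L/14-336 encoder (336 / 14 = 24 patches per side, 24 × 24 = 576). The old hard-coded shape `(1, 577, 4096)` additionally counted the CLS position embedding, which is presumably why the previous code subtracted one when padding dummy tokens. The snippet below derives both values from the Hugging Face config; the attribute names are the usual `LlavaConfig` fields, so verify them before relying on this.

```python
# Derive the agg-llava.yaml values from the model config (attribute names assumed to be
# the standard LlavaConfig fields; verify against the actual config).
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
patches_per_side = cfg.vision_config.image_size // cfg.vision_config.patch_size  # 336 // 14 = 24
print(patches_per_side**2)    # 576   -> num-patches
print(cfg.image_token_index)  # 32000 -> image-token-id
```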

0 commit comments
