WIP: Add VLM transformers backend #3132

Closed · wants to merge 5 commits
170 changes: 138 additions & 32 deletions server/text_generation_server/models/__init__.py
@@ -206,9 +206,17 @@
from text_generation_server.models.transformers_flash_causal_lm import (
TransformersFlashCausalLM,
)
except ImportError:
from text_generation_server.models.transformers_flash_vlm import (
TransformersFlashVlmCausalLM,
TransformersGemma3VlmCausalLM,
)
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Transformers Backend: {e}")
FLASH_TRANSFORMERS_BACKEND = False

# TODO: remove this, it is a temporary override for testing the FLASH_TRANSFORMERS_BACKEND
FLASH_ATTENTION = False


class ModelType(enum.Enum):
DEEPSEEK_V2 = {
@@ -1155,7 +1163,6 @@ def get_model(
)
elif model_type == GEMMA3:
if FLASH_ATTENTION:
# TODO: Use VlmCausalLM when image support is added.
return VlmCausalLM(
model_id=model_id,
model_class=Gemma3ForConditionalGeneration,
@@ -1173,12 +1180,15 @@
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
from transformers import Gemma3ForConditionalGeneration as Gemma3Model

return TransformersGemma3VlmCausalLM.fallback(
model_id,
Gemma3Model,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
)
elif sharded:
@@ -1483,33 +1493,65 @@
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == QWEN2_VL:
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
# TODO: Uncomment when transformers is refactored
# elif FLASH_TRANSFORMERS_BACKEND:
# from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel

# return TransformersQwen2VlmCausalLM.fallback(
# model_id,
# Qwen2VLModel,
# revision,
# quantize=quantize,
# speculator=speculator,
# dtype=torch.bfloat16,
# trust_remote_code=trust_remote_code,
# )
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_VL"))
if model_type == QWEN2_5_VL:
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
)
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
)
# TODO: Uncomment when transformers is refactored
# elif FLASH_TRANSFORMERS_BACKEND:
# return TransformersQwen2VlmCausalLM.fallback(
# model_id,
# Qwen2VLModel,
# revision,
# quantize=quantize,
# speculator=speculator,
# dtype=torch.bfloat16,
# trust_remote_code=trust_remote_code,
# config_class=Qwen2_5_VLConfig,
# processor_class=Qwen2_5_VLProcessor,
# )
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Qwen2_5_VL"))
if model_type == MLLAMA:
if FLASH_ATTENTION:
return MllamaCausalLM(
@@ -1524,6 +1566,20 @@
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
# TODO: Uncomment when transformers is refactored and cross attn is added
# elif FLASH_TRANSFORMERS_BACKEND:
# from transformers import MllamaForConditionalGeneration as MllamaModel

# return TransformersFlashVlmCausalLM.fallback(
# model_id,
# MllamaModel,
# revision,
# quantize=quantize,
# speculator=speculator,
# dtype=torch.bfloat16,
# trust_remote_code=trust_remote_code,
# batch_class=MllamaCausalLMBatch,
# )
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
if model_type == IDEFICS2:
@@ -1542,6 +1598,19 @@
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import Idefics2ForConditionalGeneration as Idefics2Model

return TransformersFlashVlmCausalLM.fallback(
model_id,
Idefics2Model,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS3:
@@ -1560,6 +1629,19 @@
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import Idefics3ForConditionalGeneration as Idefics3Model

return TransformersFlashVlmCausalLM.fallback(
model_id,
Idefics3Model,
revision,
quantize=quantize,
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
processor_kwargs={"size": {"longest_edge": 1456}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == PALIGEMMA:
@@ -1578,9 +1660,21 @@
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel

return TransformersFlashVlmCausalLM.fallback(
model_id,
PaliGemmaModel,
revision,
quantize=quantize,
speculator=speculator,
dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
batch_class=PaliGemmaBatch,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))
if model_type == LLAVA_NEXT:
if FLASH_ATTENTION:
return VlmCausalLM(
@@ -1593,6 +1687,18 @@
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
)
elif FLASH_TRANSFORMERS_BACKEND:
from transformers import LlavaNextForConditionalGeneration as LlavaNextModel

return TransformersFlashVlmCausalLM.fallback(
model_id,
LlavaNextModel,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))

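For context, a minimal sketch (illustrative only, not the actual `get_model` control flow) of the dispatch order that the temporary `FLASH_ATTENTION = False` override exercises: each VLM branch tries the native flash path first and falls back to the new transformers backend only when flash attention is unavailable.

```python
# Illustrative sketch of the per-model dispatch order in get_model; the flag
# names mirror the module-level globals above, the helper function is hypothetical.
FLASH_ATTENTION = False            # temporarily disabled for testing (see TODO above)
FLASH_TRANSFORMERS_BACKEND = True  # set when the transformers imports succeed


def pick_vlm_backend(model_type: str) -> str:
    if FLASH_ATTENTION:
        return "VlmCausalLM"                   # native flash path
    if FLASH_TRANSFORMERS_BACKEND:
        return "TransformersFlashVlmCausalLM"  # fallback added in this PR
    raise NotImplementedError(f"{model_type} requires Flash Attention")


print(pick_vlm_backend("llava_next"))  # -> TransformersFlashVlmCausalLM
```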
9 changes: 0 additions & 9 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -1344,9 +1344,6 @@ def __init__(
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch

def max_past(self) -> int:
return getattr(self.model, "max_past", None)

def init_kv_cache(
self,
num_blocks: int,
Expand Down Expand Up @@ -1792,12 +1789,6 @@ def forward(
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices

if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)

bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
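The removed `max_past` helper and its decode-time clamp capped `max_s` for models with a sliding-window (circular-buffer) KV cache; with this PR, window handling is instead forwarded to the attention call via `window_size_left=sliding_window` (see the next hunk). A minimal sketch of what the removed clamp computed, with hypothetical values (`max_past` and the batch length are assumptions for illustration):

```python
# Illustrative only: the effect of the removed decode-time clamp.
# With a circular-buffer KV cache, positions older than max_past have
# already been overwritten, so the decode pass capped max_s at the window.
max_past = 4096             # hypothetical sliding-window size on the model
max_current_length = 10000  # hypothetical longest sequence in the batch

max_s = min(max_past, max_current_length)
print(max_s)  # 4096
```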
@@ -36,6 +36,7 @@ def tgi_flash_attention_forward(
softcap: Optional[float] = None,
**kwargs, # This is needed to "absorb" other args passed by Transformers modeling
):

kv_cache = kv_cache[module.layer_idx]
query_states = query_states.transpose(1, 2).squeeze(dim=0)
key_states = key_states.transpose(1, 2).squeeze(dim=0)
@@ -72,6 +73,7 @@ def tgi_flash_attention_forward(
max_s,
kv_scales=kv_scales,
softcap=softcap,
window_size_left=sliding_window,
)

attn_output = attn_output.view(-1, num_heads * head_dim)
@@ -157,7 +159,14 @@ def __init__(
self.num_layers = model.config.num_hidden_layers
self.num_heads = model.config.num_attention_heads
self.num_kv_heads = model.config.num_key_value_heads
self.head_size = model.config.hidden_size // model.config.num_attention_heads
# Some models use GQA and different sizes for o_proj
# and q_proj; reading head_dim from the config accounts for that.
if hasattr(model.config, "head_dim"):
self.head_size = model.config.head_dim
else:
self.head_size = (
model.config.hidden_size // model.config.num_attention_heads
)

# Skip it for models in the exception list
if model.config.model_type not in REPLICATED_ATTENTION_MODELS:
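A minimal sketch of the head-size resolution in the `__init__` hunk above, assuming a transformers-style config object (`DummyConfig` and the values below are illustrative): models that define an explicit `head_dim` (common with GQA, or when the attention head size is decoupled from `hidden_size`) report it directly, otherwise the size is derived from the hidden size.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyConfig:
    hidden_size: int
    num_attention_heads: int
    head_dim: Optional[int] = None  # set explicitly by some models


def resolve_head_size(config: DummyConfig) -> int:
    # Mirrors the logic above: prefer an explicit head_dim, otherwise
    # fall back to hidden_size // num_attention_heads.
    if config.head_dim is not None:
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


print(resolve_head_size(DummyConfig(hidden_size=4096, num_attention_heads=32)))               # 128
print(resolve_head_size(DummyConfig(hidden_size=2048, num_attention_heads=8, head_dim=256)))  # 256
```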