Perf opt #3256

Open · wants to merge 4 commits into main

@@ -11,6 +11,7 @@
attention,
paged_attention,
paged_attention_mla,
set_block_mapping,
)


@@ -22,6 +23,7 @@
"get_kv_scales",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
"SUPPORTS_WINDOWING",
"KVCache",
"KVCompressCache",

@@ -8,6 +8,7 @@
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
from text_generation_server.models.globals import BLOCK_SIZE
import math

SUPPORTS_WINDOWING = False

@@ -106,6 +107,21 @@ def attention(
return attn_output


def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
block_mapping = torch.nn.functional.one_hot(
hpu_attention_meta.block_groups, num_classes=batch_size
)
dtype = hpu_attention_meta.block_usage.dtype
device = hpu_attention_meta.block_usage.device
mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
hpu_attention_meta = hpu_attention_meta._replace(
attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
)
return hpu_attention_meta


def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
@@ -176,4 +192,10 @@ def paged_attention_mla(
return output.view(batch_size, head_num, -1)


__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"paged_attention_mla",
"set_block_mapping",
]
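
The core of the change is the new `set_block_mapping` helper above: once per forward pass it builds a one-hot mapping from each KV-cache block to the sequence that owns it, plus an attention bias that masks the unused slots of every block, and writes both back into the `HPUPagedAttentionMetadata`, presumably so the per-layer paged-attention calls can reuse them instead of recomputing them. The sketch below is not part of the PR: it mirrors the helper's logic against a toy stand-in for `HPUPagedAttentionMetadata` (only the two fields the helper reads) and an assumed `BLOCK_SIZE` of 4, just to make the shapes concrete.

```python
import math
from typing import NamedTuple, Optional

import torch

BLOCK_SIZE = 4  # assumed toy value; the real one comes from text_generation_server.models.globals


class ToyMeta(NamedTuple):
    # Stand-in for HPUPagedAttentionMetadata with only the fields set_block_mapping touches.
    block_groups: torch.Tensor   # [num_blocks] index of the sequence owning each block
    block_usage: torch.Tensor    # [num_blocks] number of filled slots in each block
    block_mapping: Optional[torch.Tensor] = None
    attn_bias: Optional[torch.Tensor] = None


def set_block_mapping(meta: ToyMeta, batch_size: int) -> ToyMeta:
    # One-hot [num_blocks, batch_size] matrix mapping every KV-cache block to its sequence.
    block_mapping = torch.nn.functional.one_hot(meta.block_groups, num_classes=batch_size)
    dtype = meta.block_usage.dtype
    device = meta.block_usage.device
    # Slots past block_usage get a -inf bias so they contribute nothing to attention.
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
    mask = mask >= meta.block_usage.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
    return meta._replace(attn_bias=attn_bias, block_mapping=block_mapping.to(dtype))


# Two sequences: blocks 0 and 1 belong to sequence 0, block 2 to sequence 1.
meta = ToyMeta(
    block_groups=torch.tensor([0, 0, 1]),
    block_usage=torch.tensor([4.0, 2.0, 3.0]),
)
meta = set_block_mapping(meta, batch_size=2)
print(meta.block_mapping.shape)  # torch.Size([3, 2])
print(meta.attn_bias)            # -inf in the slots each block does not use
```

The model files below then call the helper once at the top of `forward`, guarded by `hpu_attention_meta is not None`, before any attention layer runs.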

@@ -28,6 +28,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -415,6 +416,10 @@ def forward(
seqlen: torch.Tensor,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)

# Get rotary cos and sin for this forward

@@ -26,6 +26,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -678,6 +679,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)

# Get rotary cos and sin for this forward

@@ -33,6 +33,7 @@
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -569,6 +570,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)

# Get rotary cos and sin for this forward

@@ -34,6 +34,7 @@
Seqlen,
attention,
paged_attention_mla,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -645,6 +646,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)

# Get rotary cos and sin for this forward

@@ -28,6 +28,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -466,6 +467,10 @@ def forward(
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds

# Get rotary cos and sin for this forward

@@ -28,6 +28,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -388,6 +389,10 @@ def forward(
adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds

# Get rotary cos and sin for this forward

@@ -27,6 +27,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -383,6 +384,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds

residual = None

@@ -28,6 +28,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -324,6 +325,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.wte(input_ids)

# Get rotary cos and sin for this forward

@@ -43,6 +43,7 @@
from text_generation_server.layers.attention import (
KVCache,
paged_attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -548,6 +549,10 @@ def forward(
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)

hidden_states = inputs_embeds
bs = seqlen.input_lengths.shape[0]

@@ -35,6 +35,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -549,6 +550,11 @@ def forward(
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None,
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)

hidden_states = inputs_embeds

# Get rotary cos and sin for this forward

@@ -30,6 +30,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -396,6 +397,10 @@ def forward(
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
adapter_data: Optional[torch.Tensor] = None,
):
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer

@@ -37,6 +37,7 @@
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
@@ -446,6 +447,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)

# Get rotary cos and sin for this forward
@@ -505,7 +510,6 @@ def forward(
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:

hidden_states = self.model(
input_ids,
position_ids,

@@ -29,6 +29,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -354,6 +355,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_in(input_ids)

# Get rotary cos and sin for this forward

@@ -9,6 +9,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -347,6 +348,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, input_ids.shape[0]
)
hidden_states = self.embed_tokens(input_ids)

# Get rotary cos and sin for this forward

@@ -8,6 +8,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -288,6 +289,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds

cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -359,7 +364,6 @@ def forward(
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:

inputs_embeds = self.embed_tokens(input_ids)

hidden_states = self.model(

@@ -18,6 +18,7 @@
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
@@ -266,7 +267,10 @@ def forward(
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:

if hpu_attention_meta is not None:
hpu_attention_meta = set_block_mapping(
hpu_attention_meta, inputs_embeds.shape[0]
)
hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
@@ -334,7 +338,6 @@ def forward(
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:

inputs_embeds = self.embed_tokens(input_ids)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = self.model(