
Commit 8bceee8

Chun-I Tsai authored and Joey Tsai committed
Rebase and change class name *Sha -> *SHA
1 parent 2f85f9e commit 8bceee8

2 files changed: +22 −13 lines


examples/models/llama/export_llama_lib.py

Lines changed: 15 additions & 6 deletions
@@ -678,15 +678,24 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
             get_custom_quant_ios_dtype,
         )
 
+        atten = builder_exported_to_edge.model.layers[0].attention
+        if args.use_qnn_sha:
+            cache_shape = torch.Size(
+                (atten.max_batch_size, atten.max_seq_len, atten.head_dim)
+            )
+        else:
+            cache_shape = torch.Size(
+                (
+                    atten.max_batch_size,
+                    atten.max_seq_len,
+                    atten.n_kv_heads,
+                    atten.head_dim,
+                )
+            )
         # pyre-ignore
         tag_quant_io(
             builder_exported_to_edge.edge_manager.exported_program().graph_module,
-            partial(
-                get_custom_quant_ios_dtype,  # pyre-ignore
-                builder_exported_to_edge.model.layers[
-                    0
-                ].attention.kv_cache.past_k_caches.shape,
-            ),
+            partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
         )
 
     logging.info("Lowering model using following partitioner(s): ")
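
For context, a minimal sketch of the cache-shape selection introduced above, assuming hypothetical dimension values (the attribute names mirror the diff; the numbers are made up). The single-head (SHA) path keeps one cache per head, so the n_kv_heads dimension is dropped from the per-head cache shape, while the multi-head path keeps all KV heads in a single tensor.

import torch

# Hypothetical dimensions; in export_llama_lib.py they come from the model's attention module.
max_batch_size, max_seq_len, n_kv_heads, head_dim = 1, 128, 8, 64

use_qnn_sha = True  # stand-in for args.use_qnn_sha in the diff above
if use_qnn_sha:
    # Per-head cache: the KV-head dimension is implicit (one cache per head).
    cache_shape = torch.Size((max_batch_size, max_seq_len, head_dim))
else:
    # Fused cache holding every KV head in one tensor.
    cache_shape = torch.Size((max_batch_size, max_seq_len, n_kv_heads, head_dim))

print(cache_shape)  # torch.Size([1, 128, 64]) when use_qnn_sha is True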

examples/models/llama/source_transformation/attention.py

Lines changed: 7 additions & 7 deletions
@@ -12,7 +12,7 @@
 from typing import List, Optional, Tuple
 
 import torch
-from executorch.examples.models.llama2.llama_transformer import Attention
+from executorch.examples.models.llama.llama_transformer import Attention
 from torch import nn
 
 
@@ -28,7 +28,7 @@ def apply_rotary_emb_single(
     return x_out
 
 
-class KVCacheSha(torch.nn.Module):
+class KVCacheSHA(torch.nn.Module):
     def __init__(
         self,
         max_batch_size: int,
@@ -74,7 +74,7 @@ def get_cache(self, head_idx):
         )
 
 
-class SDPASha(torch.nn.Module):
+class SDPASHA(torch.nn.Module):
 
     def __init__(
         self,
@@ -89,7 +89,7 @@ def __init__(
         self.head_dim = head_dim
         self.n_rep = n_rep
         self.dim = dim
-        self.kv_cache = KVCacheSha(
+        self.kv_cache = KVCacheSHA(
             max_batch_size, max_seq_length, n_heads // n_rep, head_dim
         )
         self.scale_factor = math.sqrt(head_dim)
@@ -123,7 +123,7 @@ def forward(
         return torch.cat(output, dim=-1)
 
 
-class AttentionSha(nn.Module):
+class AttentionSHA(nn.Module):
     def __init__(self, attention_mha: nn.Module):
         super().__init__()
         if not attention_mha.use_kv_cache:
@@ -136,7 +136,7 @@ def __init__(self, attention_mha: nn.Module):
         self.max_batch_size = attention_mha.max_batch_size
         self.max_seq_len = attention_mha.max_seq_len
         self.head_dim = attention_mha.dim // self.n_heads
-        self.SDPA = SDPASha(
+        self.SDPA = SDPASHA(
             self.max_batch_size,
             self.max_seq_len,
             self.n_heads,
@@ -212,7 +212,7 @@ def replace_attention_to_attention_sha(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                AttentionSha(child),
+                AttentionSHA(child),
             )
         else:
             replace_attention_to_attention_sha(child)
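
A short usage sketch of the renamed classes, assuming the import path matches the file shown above and that model is a llama Transformer still holding the original multi-head Attention modules; the transform rewrites the module tree in place, swapping each Attention for an AttentionSHA wrapper.

import torch

# Assumed import path, derived from the file location in this commit.
from executorch.examples.models.llama.source_transformation.attention import (
    replace_attention_to_attention_sha,
)


def convert_to_sha(model: torch.nn.Module) -> torch.nn.Module:
    # Recursively replaces every Attention child with AttentionSHA (in place),
    # so the QNN export path sees per-head, single-head attention blocks.
    replace_attention_to_attention_sha(model)
    return model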
