Commit 5e0e947

reduce unnecessary transpose
1 parent b10b67c commit 5e0e947

File tree: 3 files changed (+7, -11 lines)

examples/qualcomm/llama2/llama.py

Lines changed: 4 additions & 3 deletions
@@ -21,6 +21,7 @@
     setup_common_args_and_variables,
     SimpleADB,
 )
+
 from sentencepiece import SentencePieceProcessor
 
 
@@ -55,7 +56,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
     return probs_indices.gather(dim=-1, index=next_token)
 
 with torch.no_grad():
-    while token_list[-1] != sp_model.eos_id() and pos < 32:
+    while token_list[-1] != sp_model.eos_id() and pos < 128:
         logits, k_cache, v_cache, kv_mask = module(
             torch.full((1, 1), token_list[pos]),
             torch.full((1, 1), pos),
@@ -160,11 +161,10 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
         config = ModelArgs(**json.load(f))
         # TODO: support batch inputs if necessary
         config.max_batch_size = 1
-        config.n_layers = 1
 
     state_dict = torch.load(args.checkpoint)
     instance = LlamaModel(config)
-    instance.load_state_dict(state_dict["model"], strict=False)
+    instance.load_state_dict(state_dict["model"])
     inputs = instance.get_example_inputs()
     input_list = create_device_inputs(inputs)
     pte_filename = "llama2_qnn"
@@ -199,6 +199,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
         per_channel_linear=per_channel_linear,
         shared_buffer=args.shared_buffer,
         metadata=instance.get_metadata(),
+        direct_io=True,
     )
 
     if args.compile_only:
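For context, sample_top_p in the hunk headers above is the top-p (nucleus) sampling helper; only its final line, return probs_indices.gather(dim=-1, index=next_token), is visible as diff context. Below is a minimal sketch of a sampler consistent with that signature and return statement, assuming the standard Llama-style implementation; it is not taken from this file.

import torch

def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort probabilities in descending order and keep the smallest prefix
    # whose cumulative mass stays within top_p; renormalize and sample.
    probs_sorted, probs_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(probs_sorted, dim=-1)
    probs_sorted[cumulative - probs_sorted > top_p] = 0.0
    probs_sorted /= probs_sorted.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(probs_sorted, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    return probs_indices.gather(dim=-1, index=next_token)

# Example: sample one token id from a (1, vocab_size) distribution.
probs = torch.softmax(torch.randn(1, 32000), dim=-1)
token = sample_top_p(probs, top_p=0.9)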

examples/qualcomm/llama2/model/static_llama.py

Lines changed: 2 additions & 7 deletions
@@ -64,13 +64,9 @@ def forward(
         v_cache = v_cache.view(bsz, self.max_seq_len, self.n_kv_heads, self.head_dim)
         k = k_cache * (1.0 - mask) + k * mask
         v = v_cache * (1.0 - mask) + v * mask
-        # (bs, n_local_heads, seqlen, head_dim)
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
 
-        attn = q @ k.transpose(-2, -1)
-        attn = attn * self.scale_tensor
-        attn = attn + atten_mask
+        attn = q.transpose(1, 2) @ k.permute(0, 2, 3, 1)
+        attn = attn * self.scale_tensor + atten_mask
         attn = self.attn_softmax(attn)
         y = attn @ v.transpose(1, 2)
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
@@ -182,7 +178,6 @@ def forward(
 
         # update kv cache
         output_k_cache = torch.concat(output_k_cache)
-        output_k_cache = output_k_cache.transpose(1, 2).contiguous()
         output_k_cache = output_k_cache.view(
             self.max_batch_size, self.n_layers, self.max_seq_len, self.dim
         )
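The attention hunk above replaces separate transposes of q and k (plus a second transpose on k inside the matmul) with a single transpose on q and a single permute on k. A quick sanity check that the two formulations agree, using assumed shapes of (bsz, seq_len, n_heads, head_dim) for both operands rather than the module's actual cache views:

import torch

# Assumed shapes for illustration; the real module derives k from the
# (bsz, max_seq_len, n_kv_heads, head_dim) cache view shown in the hunk.
bsz, seq_len, n_heads, head_dim = 1, 8, 4, 16
q = torch.randn(bsz, seq_len, n_heads, head_dim)
k = torch.randn(bsz, seq_len, n_heads, head_dim)

# Old path: transpose q and k to (bsz, heads, seq, head_dim),
# then transpose k again inside the matmul.
old = q.transpose(1, 2) @ k.transpose(1, 2).transpose(-2, -1)

# New path: one permute takes k straight to (bsz, heads, head_dim, seq).
new = q.transpose(1, 2) @ k.permute(0, 2, 3, 1)

assert torch.allclose(old, new)

In eager PyTorch both paths are only stride changes, but the motivation here is presumably the exported graph, where each transpose lowers to a real operator, so expressing the layout change on k as one permute instead of two transposes trims the op count on device.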

examples/qualcomm/scripts/utils.py

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ def build_executorch_binary(
     skip_node_op_set=None,
     quant_dtype: Optional[QuantDtype] = None,
     per_channel_linear=False, # TODO: remove this once QNN fully supports linear
-    direct_io=True, # TODO: temporal workaround for llama
+    direct_io=False, # TODO: temporal workaround for llama
     shared_buffer=False,
     metadata=None,
 ):
