
Commit f552ee8

Merge branch 'main' into migrate-benchmark-results-v3
2 parents a6889a5 + c726a9b

12 files changed: +394 -96 lines

backends/arm/arm_backend.py

Lines changed: 3 additions & 1 deletion

@@ -135,7 +135,9 @@ def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder":
         self.quantize_io = quantize_io
         return self

-    def set_input_order(self, input_order: str = None) -> "ArmCompileSpecBuilder":
+    def set_input_order(
+        self, input_order: Optional[str] = None
+    ) -> "ArmCompileSpecBuilder":
        """
        Reorder the inputs coming in. This may be required when inputs > 1.
        And while using the U55/U85 CompileSpec.
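
For reference, a minimal usage sketch of the updated signature, assuming the builder is imported from backends/arm/arm_backend.py; the surrounding calls are illustrative only and are not part of this diff.

# Hypothetical sketch: only the Optional[str] typing of input_order comes from this commit.
from typing import Optional
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder  # assumed import path

def make_builder(input_order: Optional[str] = None) -> ArmCompileSpecBuilder:
    builder = ArmCompileSpecBuilder()
    builder.set_quantize_io(True)
    # None (the default) keeps the incoming input order; a non-None string
    # reorders inputs for multi-input models on U55/U85 (see the docstring above).
    builder.set_input_order(input_order)
    return builder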

backends/cadence/fusion_g3/operators/op_add.cpp

Lines changed: 6 additions & 6 deletions

@@ -95,26 +95,26 @@ Tensor& add_out(
   }

   for (int i = 0; i < max_dim; i++) {
-    out_shape[i] = 1;
+    out_shape[i] = 1;
     inp1_shape[i] = 1;
     inp2_shape[i] = 1;
   }
-
-  int offset_out = max_dim - out.dim();
+
+  int offset_out = max_dim - out.dim();
   int offset_inp1 = max_dim - a.dim();
   int offset_inp2 = max_dim - b.dim();
-
+
   for (int i = 0; i < out.dim(); i++) {
     out_shape[i + offset_out] = out.size(i);
   }
   for (int i = 0; i < a.dim(); i++) {
     inp1_shape[i + offset_inp1] = a.size(i);
   }
   for (int i = 0; i < b.dim(); i++) {
-    inp2_shape[i + offset_inp2] = b.size(i);
+    inp2_shape[i + offset_inp2] = b.size(i);
   }

-  if ((compute_type == ScalarType::Int) && (optimized)){
+  if ((compute_type == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     const int* const inp2_data = b.const_data_ptr<int>();
     int* const out_data = out.mutable_data_ptr<int>();
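
The changed lines above are formatting-only, but the surrounding loops carry the actual broadcasting logic: each operand's sizes are right-aligned into a max_dim-long shape padded with leading 1s. A rough Python sketch of that shape-alignment step, with illustrative names that are not part of the ExecuTorch API:

# Right-align a tensor's sizes into a rank-max_dim shape, padding with 1s,
# mirroring the out_shape/inp1_shape/inp2_shape loops in op_add.cpp above.
def right_align_shape(sizes: list[int], max_dim: int) -> list[int]:
    shape = [1] * max_dim
    offset = max_dim - len(sizes)
    for i, size in enumerate(sizes):
        shape[i + offset] = size
    return shape

# Example: aligning a (3, 4) operand against a 4-D output gives (1, 1, 3, 4).
assert right_align_shape([3, 4], max_dim=4) == [1, 1, 3, 4]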

backends/cadence/fusion_g3/operators/op_mul.cpp

Lines changed: 5 additions & 5 deletions

@@ -87,23 +87,23 @@ Tensor& mul_out(
   }

   for (int i = 0; i < max_dim; i++) {
-    out_shape[i] = 1;
+    out_shape[i] = 1;
     inp1_shape[i] = 1;
     inp2_shape[i] = 1;
   }
-
-  int offset_out = max_dim - out.dim();
+
+  int offset_out = max_dim - out.dim();
   int offset_inp1 = max_dim - a.dim();
   int offset_inp2 = max_dim - b.dim();
-
+
   for (int i = 0; i < out.dim(); i++) {
     out_shape[i + offset_out] = out.size(i);
   }
   for (int i = 0; i < a.dim(); i++) {
     inp1_shape[i + offset_inp1] = a.size(i);
   }
   for (int i = 0; i < b.dim(); i++) {
-    inp2_shape[i + offset_inp2] = b.size(i);
+    inp2_shape[i + offset_inp2] = b.size(i);
   }

   if ((compute_type == ScalarType::Int) && (optimized)) {

backends/cadence/hifi/operators/op_maximum.cpp

Lines changed: 0 additions & 1 deletion

@@ -23,7 +23,6 @@ using torch::executor::apply_binary_elementwise_fn;
 using torch::executor::Error;
 using torch::executor::resize_to_broadcast_target_size;

-
 namespace cadence {
 namespace impl {
 namespace HiFi {

backends/cadence/hifi/operators/op_pow.cpp

Lines changed: 0 additions & 1 deletion

@@ -351,4 +351,3 @@ Tensor& pow_Scalar_out(
 } // namespace HiFi
 } // namespace impl
 } // namespace cadence
-

backends/cadence/hifi/operators/quantized_linear_out.cpp

Lines changed: 15 additions & 16 deletions

@@ -26,8 +26,7 @@ using ::executorch::aten::Tensor;
 using ::executorch::runtime::getLeadingDims;
 using ::executorch::runtime::KernelRuntimeContext;

-
-// The nnlib kernel to compute quantized linear via matmul.
+// The nnlib kernel to compute quantized linear via matmul.

 void _quantized_linear_asym8u(
     const Tensor& in,

@@ -48,22 +47,22 @@ void _quantized_linear_asym8u(
   const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
   uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
   int32_t ret = xa_nn_matmul_asym8uxasym8u_asym8u(
-      out_data,
-      weight_data,
-      in_data,
-      bias_data,
-      out_dim,
-      in_dim,
-      in_dim,
-      leading_dims,
-      in_dim,
-      out_dim,
-      1,
+      out_data,
+      weight_data,
+      in_data,
+      bias_data,
+      out_dim,
+      in_dim,
+      in_dim,
+      leading_dims,
+      in_dim,
+      out_dim,
+      1,
       -weight_zero_point.const_data_ptr<int32_t>()[0], // mat1_zero_bias
       -in_zero_point, // mat2_zero_bias
-      out_multiplier.const_data_ptr<int32_t>()[0],
-      out_shift.const_data_ptr<int32_t>()[0],
-      out_zero_point);
+      out_multiplier.const_data_ptr<int32_t>()[0],
+      out_shift.const_data_ptr<int32_t>()[0],
+      out_zero_point);
   ET_DCHECK_MSG(ret == 0, "HiFi quantized::linear failed");
 }

examples/models/llama/TARGETS

Lines changed: 14 additions & 0 deletions

@@ -93,6 +93,7 @@ runtime.python_library(
         "source_transformation/sdpa.py",
         "source_transformation/spin_quant.py",
         "source_transformation/vulkan_rope.py",
+        "source_transformation/attention_sink.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama",

@@ -213,3 +214,16 @@ runtime.python_test(
         "//executorch/examples/models/llama:llama_transformer",
     ],
 )
+
+runtime.python_test(
+    name = "attention_sink_test",
+    srcs = [
+        "source_transformation/test_attention_sink.py",
+    ],
+    supports_static_listing = False,
+    deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
+        "//caffe2:torch",
+        ":export_library",
+    ],
+)

examples/models/llama/llama_transformer.py

Lines changed: 85 additions & 54 deletions

@@ -147,6 +147,81 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads


+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.head_dim,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+
+    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        """
+        Get the precomputed frequencies for the given input position and sequence length.
+
+        Args:
+            input_pos (torch.Tensor): The input position tensor.
+            seq_len (int): The sequence length.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length.
+        """
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, use of the .item results in
+                # symints, due to querying the data from tensor.
+                # this path avoids that for mps backend, although probably mps backend
+                # can support dynamic shape?
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        return freqs_cos, freqs_sin
+
+
 class KVCache(nn.Module):
     def __init__(
         self,

@@ -266,7 +341,7 @@ def forward(


 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads

@@ -287,6 +362,8 @@ def __init__(self, args: ModelArgs, layer_id: int):

         self.layer_id = layer_id

+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,

@@ -303,7 +380,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
         self.SDPA = SDPA(

@@ -314,10 +391,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
             max_seq_len=self.max_seq_len,
             enable_dynamic_shape=args.enable_dynamic_shape,
         )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,

@@ -336,7 +409,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

         if self.use_kv_cache:
             assert input_pos is not None

@@ -424,13 +497,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:

@@ -459,33 +532,17 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers

         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.head_dim,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

     def forward(
         self,

@@ -502,33 +559,7 @@ def forward(
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
         seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
+        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)

         for layer in self.layers:
             h = layer(
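
Taken together, the llama_transformer.py changes move RoPE ownership out of Transformer and Attention into a single shared Rope module. A condensed sketch of the resulting call flow; the import path and ModelArgs fields are assumed from this diff rather than verified against the repository:

# Illustrative only: mirrors how Transformer.forward and Attention.forward
# use the shared Rope module after this refactor.
import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope  # assumed path

params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=False)  # fields referenced in the diff
rope = Rope(params)  # owns the freqs_cos/freqs_sin buffers and apply_rotary_emb

# Transformer.forward now asks the shared module for per-step frequencies...
input_pos = torch.tensor([0])  # first decode position when the KV cache is enabled
freqs_cos, freqs_sin = rope.get_freqs(input_pos, seq_len=1)

# ...and each Attention layer applies them through the same instance:
#     q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)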
