Skip to content

Commit 7c6a711

Browse files
committed
[Executorch][llama] Make RoPE freq calculation broadcast for per head
Pull Request resolved: #2353 This is a workaround, and may not even be worth landing, to avoid broadcasting semantics in the mul op and, for that matter, in any binary op. The current implementation of the optimized ops doesn't handle broadcasting and falls back to the portable op implementation. This diff also fixes an issue where (as seen in llama) the two tensors of a binary op are not broadcasting but have a different number of dims, which results in invocation of the unoptimized path, e.g. a = [1, 1, 2048], b = [2048], out = [1, 1, 2048]. In the llama case this is the optimized path when generating one token at a time, but not during pre-fill. Making the optimized op handle broadcasting and support vectorization is not hard, but may take some time. ghstack-source-id: 218210434 @exported-using-ghexport Differential Revision: [D54766067](https://our.internmc.facebook.com/intern/diff/D54766067/)
1 parent 50b864c commit 7c6a711

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

examples/models/llama2/llama_transformer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,20 +109,26 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
109109
)
110110

111111

112-
def precompute_freqs_cis(dim: int, end: int, theta: float):
112+
def precompute_freqs_cis(dim: int, n_heads: int, end: int, theta: float):
113113
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
114114
t = torch.arange(end, device=freqs.device) # pyre-ignore
115115
freqs = torch.outer(t, freqs).float() # pyre-ignore
116116
freqs_cos = torch.cos(freqs)
117117
freqs_sin = torch.sin(freqs)
118+
freqs_cos = freqs_cos.view(end, 1, dim // 2)
119+
freqs_cos = freqs_cos.expand(end, n_heads, dim // 2).contiguous()
120+
freqs_sin = freqs_sin.view(end, 1, dim // 2)
121+
freqs_sin = freqs_sin.expand(end, n_heads, dim // 2).contiguous()
118122
return freqs_cos, freqs_sin
119123

120124

121125
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
122126
ndim = x.ndim
123127
assert 0 <= 1 < ndim
124-
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
125-
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
128+
assert freqs_cis.shape == (x.shape[1], x.shape[2], x.shape[-1])
129+
shape = [
130+
d if (i == 1 or i == 2 or i == ndim - 1) else 1 for i, d in enumerate(x.shape)
131+
]
126132
return freqs_cis.view(shape)
127133

128134

@@ -413,6 +419,7 @@ def __init__(self, params: ModelArgs):
413419

414420
freqs_cos, freqs_sin = precompute_freqs_cis(
415421
params.dim // params.n_heads,
422+
params.n_heads,
416423
params.max_seq_len,
417424
params.rope_freq_base,
418425
)

kernels/optimized/cpu/op_mul.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,29 @@ namespace native {
2020
using Tensor = exec_aten::Tensor;
2121
using ScalarType = exec_aten::ScalarType;
2222

23+
namespace {
24+
25+
// Move to generic util as this is applicable to all binary ops
26+
// Decides whether a binary op on `a` and `b` writing into `out` may take the
// optimized path. Returns true only when all of the following hold:
//   1. a, b and out share the same scalar type;
//   2. that scalar type is not Half;
//   3. a and b have identical sizes, or a, b and out all have the same
//      number of elements.
bool can_use_optimized_path(
27+
const Tensor& a,
28+
const Tensor& b,
29+
const Tensor& out) {
30+
ScalarType a_type = a.scalar_type();
31+
ScalarType b_type = b.scalar_type();
32+
ScalarType out_type = out.scalar_type();
33+
34+
bool can_use_optimized_path = true;
35+
// Condition 1: no dtype promotion or conversion is needed.
can_use_optimized_path =
36+
can_use_optimized_path && ((a_type == b_type) && (a_type == out_type));
37+
// Condition 2: Half is excluded. The b_type test is redundant once
// condition 1 holds (a_type == b_type), but it is harmless.
can_use_optimized_path = can_use_optimized_path &&
38+
(a_type != ScalarType::Half && b_type != ScalarType::Half);
39+
// Condition 3: same sizes, or same element count across a, b and out. The
// element-count branch accepts shapes that differ only in size-1 dims,
// e.g. [1, 1, 2048] vs [2048].
// NOTE(review): it also accepts pairs such as [2, 3] vs [3, 2], which have
// equal numel but are not broadcast-compatible -- presumably callers never
// pass those; confirm. Also note that the caller visible in this diff
// resizes `out` to a.sizes() when this returns true, so `out` takes a's
// rank rather than the broadcast rank -- confirm this is intended.
can_use_optimized_path = can_use_optimized_path &&
40+
(a.sizes().equals(b.sizes()) ||
41+
(a.numel() == b.numel() && a.numel() == out.numel()));
42+
return can_use_optimized_path;
43+
}
44+
} // namespace
45+
2346
Tensor& opt_mul_out(
2447
RuntimeContext& ctx,
2548
const Tensor& a,
@@ -31,8 +54,7 @@ Tensor& opt_mul_out(
3154
ScalarType b_type = b.scalar_type();
3255
ScalarType out_type = out.scalar_type();
3356

34-
if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes()) &&
35-
a_type != ScalarType::Half) {
57+
if (can_use_optimized_path(a, b, out)) {
3658
// Resize for dynamic shape
3759
auto error = resize_tensor(out, a.sizes());
3860
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

0 commit comments

Comments
 (0)