Skip to content

Commit 7ecf824

Browse files
committed
[Executorch][llama] Make RoPE freq calculation broadcast for per head
Pull Request resolved: #2353. This is a workaround, and may not even be worth landing, to avoid broadcasting semantics in the mul op and, for that matter, any binary op. The current implementation of optimized ops doesn't handle broadcasting and falls back to the portable op implementation. This diff also fixes an issue where (as seen in llama) the two tensors of a binary op are not broadcasting but have a different number of dims, which results in invocation of the unoptimized path, e.g. a = [1, 1, 2048], b = [2048], out = [1, 1, 2048]. In the llama case this is the optimized path when generating one token at a time, but not so during pre-fill. Making the optimized op handle broadcasting, and support vectorization, is not hard, but may take some time. ghstack-source-id: 218755098 @exported-using-ghexport Differential Revision: [D54766067](https://our.internmc.facebook.com/intern/diff/D54766067/)
1 parent 03ac8c7 commit 7ecf824

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

examples/models/llama2/llama_transformer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,22 +109,28 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
109109
)
110110

111111

112-
def precompute_freqs_cis(dim: int, end: int, theta: float):
112+
def precompute_freqs_cis(dim: int, n_heads: int, end: int, theta: float):
113113
freqs = 1.0 / (
114114
theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
115115
)
116116
t = torch.arange(end, device=freqs.device) # pyre-ignore
117117
freqs = torch.outer(t, freqs).float() # pyre-ignore
118118
freqs_cos = torch.cos(freqs)
119119
freqs_sin = torch.sin(freqs)
120+
freqs_cos = freqs_cos.view(end, 1, dim // 2)
121+
freqs_cos = freqs_cos.expand(end, n_heads, dim // 2).contiguous()
122+
freqs_sin = freqs_sin.view(end, 1, dim // 2)
123+
freqs_sin = freqs_sin.expand(end, n_heads, dim // 2).contiguous()
120124
return freqs_cos, freqs_sin
121125

122126

123127
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
124128
ndim = x.ndim
125129
assert 0 <= 1 < ndim
126-
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
127-
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
130+
assert freqs_cis.shape == (x.shape[1], x.shape[2], x.shape[-1])
131+
shape = [
132+
d if (i == 1 or i == 2 or i == ndim - 1) else 1 for i, d in enumerate(x.shape)
133+
]
128134
return freqs_cis.view(shape)
129135

130136

@@ -416,6 +422,7 @@ def __init__(self, params: ModelArgs):
416422

417423
freqs_cos, freqs_sin = precompute_freqs_cis(
418424
params.dim // params.n_heads,
425+
params.n_heads,
419426
params.max_seq_len,
420427
params.rope_freq_base,
421428
)

kernels/optimized/cpu/op_mul.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,29 @@ namespace native {
2020
using Tensor = exec_aten::Tensor;
2121
using ScalarType = exec_aten::ScalarType;
2222

23+
namespace {
24+
25+
// Move to generic util as this is applicable to all binary ops
//
// Decides whether a binary op on `a` and `b` writing into `out` may take the
// optimized (non-broadcasting) code path.
//
// Returns true only when all of the following hold:
//   1. `a`, `b` and `out` share the same scalar type;
//   2. neither `a` nor `b` is ScalarType::Half;
//   3. either `a` and `b` have identical sizes, or `a`, `b` and `out` all
//      have the same number of elements.
//
// NOTE(review): the numel-based clause appears intended for inputs that
// differ only in rank, e.g. a = [1, 1, 2048] and b = [2048]. Equal numel
// alone does not establish that the shapes are compatible (e.g. [2, 3] vs
// [3, 2]) -- confirm callers cannot reach this with such shapes.
26+
bool can_use_optimized_path(
27+
const Tensor& a,
28+
const Tensor& b,
29+
const Tensor& out) {
30+
ScalarType a_type = a.scalar_type();
31+
ScalarType b_type = b.scalar_type();
32+
ScalarType out_type = out.scalar_type();
33+
34+
bool can_use_optimized_path = true;
35+
// All three tensors must have the same dtype.
can_use_optimized_path =
36+
can_use_optimized_path && ((a_type == b_type) && (a_type == out_type));
37+
// Half inputs are excluded from the optimized path.
can_use_optimized_path = can_use_optimized_path &&
38+
(a_type != ScalarType::Half && b_type != ScalarType::Half);
39+
// Accept identical shapes, or matching element counts across a, b and out.
can_use_optimized_path = can_use_optimized_path &&
40+
(a.sizes().equals(b.sizes()) ||
41+
(a.numel() == b.numel() && a.numel() == out.numel()));
42+
return can_use_optimized_path;
43+
}
44+
} // namespace
45+
2346
Tensor& opt_mul_out(
2447
RuntimeContext& ctx,
2548
const Tensor& a,
@@ -31,8 +54,7 @@ Tensor& opt_mul_out(
3154
ScalarType b_type = b.scalar_type();
3255
ScalarType out_type = out.scalar_type();
3356

34-
if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes()) &&
35-
a_type != ScalarType::Half) {
57+
if (can_use_optimized_path(a, b, out)) {
3658
// Resize for dynamic shape
3759
auto error = resize_tensor(out, a.sizes());
3860
ET_KERNEL_CHECK_MSG(

0 commit comments

Comments
 (0)