
Commit 08733f0

kimishpatel authored and facebook-github-bot committed
Make RoPE freq calculation broadcast for per head (#2353)
Summary: Pull Request resolved: #2353

This is a workaround, and may not even be worth landing, to avoid broadcasting semantics in the mul op (and, for that matter, any binary op). The current implementation of the optimized ops doesn't handle broadcasting and falls back to the portable op implementation.

This diff also fixes an issue where (as seen in llama) the two tensors of a binary op do not broadcast but have a different number of dims, which results in invocation of the unoptimized path, e.g. a = [1, 1, 2048], b = [2048], out = [1, 1, 2048]. In the llama case this is the optimized path when generating one token at a time, but not during pre-fill.

Making the optimized op handle broadcasting, and support vectorization, is not hard but may take some time.

ghstack-source-id: 219444233
exported-using-ghexport

Reviewed By: digantdesai, kirklandsign

Differential Revision: D54766067

fbshipit-source-id: 0b7318959994b93388832940a98e25b9cc360978
1 parent 9c20929 commit 08733f0
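
To make the shape issue in the summary concrete, here is a minimal standalone sketch (not part of this commit; shapes taken from the decode case described above). It contrasts operands whose sizes differ only by leading singleton dims but have equal element counts, which this change routes to the optimized kernel, with true broadcasting, which still falls back:

import torch

# Decode (one token at a time): a = [1, 1, 2048], b = [2048], out = [1, 1, 2048].
# Sizes differ only in the number of leading singleton dims; numel matches, so no
# real broadcasting is needed and the optimized mul kernel can handle it.
a = torch.randn(1, 1, 2048)
b = torch.randn(2048)
same_numel = a.numel() == b.numel()          # True  -> optimized path after this diff
same_sizes = list(a.shape) == list(b.shape)  # False -> used to force the portable path

# Eager PyTorch broadcasts either way; which ExecuTorch kernel runs is the detail
# this commit changes. Expanding the RoPE table (--expand_rope_table) removes the
# remaining true-broadcast case during pre-fill.
out = a * b
print(same_numel, same_sizes, out.shape)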

File tree

3 files changed: +74 -5 lines changed


examples/models/llama2/export_llama_lib.py

Lines changed: 38 additions & 0 deletions
@@ -22,6 +22,8 @@
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
 )
+
+from executorch.examples.models.llama2.llama_transformer import Transformer
 from executorch.exir.backend.backend_details import CompileSpec

 from executorch.sdk.etrecord import generate_etrecord
@@ -174,6 +176,32 @@ def check_embedding_byte_registered():
     return quantizers


+def materialze_broadcast_of_rope_freq_cis(
+    module: torch.nn.Module,
+):
+    assert isinstance(module, Transformer)
+    assert module.freqs_cos.dim() == 2
+    dim0 = module.freqs_cos.size(0)
+    dim1 = module.freqs_cos.size(1)
+    assert (
+        module.layers[0].attention.n_local_kv_heads
+        == module.layers[0].attention.n_local_heads
+    ), f"For rope freqs to be materialzed for broadcast q, k, v num heads must match. For q got {module.attention.n_kv_heads} for k got {module.attention.n_local_heads} and v got {module.attention.n_local_kv_heads}"
+    num_heads = module.layers[0].attention.n_local_heads
+    module.freqs_cos = module.freqs_cos.view(dim0, 1, dim1)
+    module.freqs_cos = module.freqs_cos.expand(dim0, num_heads, dim1).contiguous()
+    assert module.freqs_sin.dim() == 2
+    assert dim0 == module.freqs_sin.size(
+        0
+    ), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.freqs_sin.size(0)}"
+    assert dim1 == module.freqs_sin.size(
+        1
+    ), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.freqs_sin.size(1)}"
+    module.freqs_sin = module.freqs_sin.view(dim0, 1, dim1)
+    module.freqs_sin = module.freqs_sin.expand(dim0, num_heads, dim1).contiguous()
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
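
The transform above turns each [seq_len, head_dim // 2] RoPE table into a [seq_len, n_heads, head_dim // 2] table, trading memory for a same-sized multiplicand. A rough standalone sketch of the same reshaping on plain tensors (the sizes here are illustrative, not read from a llama checkpoint):

import torch

seq_len, n_heads, half_head_dim = 2048, 32, 64

# Original RoPE tables: one row of cos/sin frequencies per position.
freqs_cos = torch.randn(seq_len, half_head_dim)
freqs_sin = torch.randn(seq_len, half_head_dim)

# Materialize the broadcast over heads, mirroring materialze_broadcast_of_rope_freq_cis:
# view adds a singleton head dim, expand repeats it, contiguous() copies the data so the
# downstream mul sees two operands with the same number of elements instead of a broadcast.
freqs_cos = freqs_cos.view(seq_len, 1, half_head_dim).expand(seq_len, n_heads, half_head_dim).contiguous()
freqs_sin = freqs_sin.view(seq_len, 1, half_head_dim).expand(seq_len, n_heads, half_head_dim).contiguous()

print(freqs_cos.shape)  # torch.Size([2048, 32, 64])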
@@ -369,6 +397,13 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument("-V", "--vulkan", action="store_true")
     parser.add_argument("--mps", action="store_true")

+    parser.add_argument(
+        "--expand_rope_table",
+        default=False,
+        action="store_true",
+        help="[Temp workaround] Expand sin/cos table in head dim to take vectorized path in optimized kernels.",
+    )
+
     parser.add_argument(
         "--generate_etrecord",
         action="store_true",
@@ -464,6 +499,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
             ).quantized_model()
         )

+    if args.expand_rope_table:
+        transforms.append(materialze_broadcast_of_rope_freq_cis)
+
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
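
The flag simply appends materialze_broadcast_of_rope_freq_cis to the list of module transforms; the actual application happens inside LlamaEdgeManager / load_llama_model, which this diff does not show. A hypothetical, minimal mirror of that plumbing, with a stand-in transform and model:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--expand_rope_table",
    default=False,
    action="store_true",
    help="[Temp workaround] Expand sin/cos table in head dim to take vectorized path in optimized kernels.",
)
args = parser.parse_args(["--expand_rope_table"])

transforms = []  # list of callables: torch.nn.Module -> torch.nn.Module
if args.expand_rope_table:
    # In the real code this is materialze_broadcast_of_rope_freq_cis; a no-op stand-in here.
    transforms.append(lambda module: module)

model = torch.nn.Linear(4, 4)  # placeholder for the loaded llama Transformer
for transform in transforms:
    model = transform(model)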

examples/models/llama2/llama_transformer.py

Lines changed: 12 additions & 3 deletions
@@ -122,9 +122,18 @@ def precompute_freqs_cis(dim: int, end: int, theta: float):


 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
     ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    freqs_cis_ndim = freqs_cis.ndim
+    if freqs_cis_ndim == 3:
+        # freqs_cis: (seq_len, n_heads, head_dim // 2)
+        assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
+        shape = [
+            d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
+            for i, d in enumerate(x.shape)
+        ]
+    else:
+        # freqs_cis: (seq_len, head_dim // 2)
+        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(shape)

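With the new 3-D branch, a pre-expanded table reshapes against the per-head activations just as the original 2-D table did; only the head dim stops being an implicit broadcast. A quick standalone check of both branches (the function is copied from the diff above; the [bsz, seq_len, n_heads, head_dim // 2] layout is the usual llama shape and is assumed here for illustration):

import torch

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # Standalone copy of the updated function above.
    ndim = x.ndim
    if freqs_cis.ndim == 3:
        assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
        shape = [
            d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
            for i, d in enumerate(x.shape)
        ]
    else:
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

bsz, seq_len, n_heads, half_head_dim = 1, 128, 32, 64
x = torch.randn(bsz, seq_len, n_heads, half_head_dim)

# 2-D table: head dim stays an implicit broadcast -> view is [1, 128, 1, 64].
print(reshape_for_broadcast(torch.randn(seq_len, half_head_dim), x).shape)

# 3-D (expanded) table: head dim is materialized -> view is [1, 128, 32, 64].
print(reshape_for_broadcast(torch.randn(seq_len, n_heads, half_head_dim), x).shape)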

kernels/optimized/cpu/op_mul.cpp

Lines changed: 24 additions & 2 deletions
@@ -20,6 +20,29 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using ScalarType = exec_aten::ScalarType;

+namespace {
+
+// Move to generic util as this is applicable to all binary ops
+bool can_use_optimized_path(
+    const Tensor& a,
+    const Tensor& b,
+    const Tensor& out) {
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = b.scalar_type();
+  ScalarType out_type = out.scalar_type();
+
+  bool can_use_optimized_path = true;
+  can_use_optimized_path =
+      can_use_optimized_path && ((a_type == b_type) && (a_type == out_type));
+  can_use_optimized_path = can_use_optimized_path &&
+      (a_type != ScalarType::Half && b_type != ScalarType::Half);
+  can_use_optimized_path = can_use_optimized_path &&
+      (a.sizes().equals(b.sizes()) ||
+       (a.numel() == b.numel() && a.numel() == out.numel()));
+  return can_use_optimized_path;
+}
+} // namespace
+
 Tensor& opt_mul_out(
     RuntimeContext& ctx,
     const Tensor& a,
@@ -31,8 +54,7 @@ Tensor& opt_mul_out(
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();

-  if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes()) &&
-      a_type != ScalarType::Half) {
+  if (can_use_optimized_path(a, b, out)) {
     // Resize for dynamic shape
     auto error = resize_tensor(out, a.sizes());
     ET_KERNEL_CHECK_MSG(
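The relaxed check keeps the dtype requirements (all three tensors share a non-Half dtype) but now also accepts operands with equal element counts even when their reported sizes differ, which covers the llama [1, 1, 2048] * [2048] case. A rough Python rendering of the same predicate for experimentation (shapes and dtypes only; the real check of course lives in the C++ kernel above):

import torch

def can_use_optimized_path(a, b, out):
    # Mirrors the C++ helper: same dtype everywhere, no half precision, and either
    # identical sizes or equal numel across a, b, and out (no true broadcast needed).
    same_dtype = a.dtype == b.dtype == out.dtype
    no_half = a.dtype != torch.half and b.dtype != torch.half
    compatible_sizes = a.shape == b.shape or (
        a.numel() == b.numel() and a.numel() == out.numel()
    )
    return same_dtype and no_half and compatible_sizes

a, b = torch.randn(1, 1, 2048), torch.randn(2048)
out = torch.empty(1, 1, 2048)
print(can_use_optimized_path(a, b, out))   # True: decode-time mul stays on the optimized path

b_prefill = torch.randn(128, 2048)         # true broadcast against [1, 1, 2048]
print(can_use_optimized_path(a, b_prefill, torch.empty(1, 128, 2048)))  # False: portable fallback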
