Skip to content

[Executorch][llama] Make RoPE freq calculation broadcast for per head #2353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackDynamicallyQuantizedPartitioner,
)

from executorch.examples.models.llama2.llama_transformer import Transformer
from executorch.exir.backend.backend_details import CompileSpec

from executorch.sdk.etrecord import generate_etrecord
Expand Down Expand Up @@ -174,6 +176,32 @@ def check_embedding_byte_registered():
return quantizers


def materialze_broadcast_of_rope_freq_cis(
module: torch.nn.Module,
):
assert isinstance(module, Transformer)
assert module.freqs_cos.dim() == 2
dim0 = module.freqs_cos.size(0)
dim1 = module.freqs_cos.size(1)
assert (
module.layers[0].attention.n_local_kv_heads
== module.layers[0].attention.n_local_heads
), f"For rope freqs to be materialzed for broadcast q, k, v num heads must match. For q got {module.attention.n_kv_heads} for k got {module.attention.n_local_heads} and v got {module.attention.n_local_kv_heads}"
num_heads = module.layers[0].attention.n_local_heads
module.freqs_cos = module.freqs_cos.view(dim0, 1, dim1)
module.freqs_cos = module.freqs_cos.expand(dim0, num_heads, dim1).contiguous()
assert module.freqs_sin.dim() == 2
assert dim0 == module.freqs_sin.size(
0
), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.freqs_sin.size(0)}"
assert dim1 == module.freqs_sin.size(
1
), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.freqs_sin.size(1)}"
module.freqs_sin = module.freqs_sin.view(dim0, 1, dim1)
module.freqs_sin = module.freqs_sin.expand(dim0, num_heads, dim1).contiguous()
return module


def quantize(
model: torch.nn.Module,
qmode: str,
Expand Down Expand Up @@ -369,6 +397,13 @@ def build_args_parser() -> argparse.ArgumentParser:
parser.add_argument("-V", "--vulkan", action="store_true")
parser.add_argument("--mps", action="store_true")

parser.add_argument(
"--expand_rope_table",
default=False,
action="store_true",
help="[Temp workaround] Expand sin/cos table in head dim to take vectorized path in optimized kernels.",
)

parser.add_argument(
"--generate_etrecord",
action="store_true",
Expand Down Expand Up @@ -464,6 +499,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
).quantized_model()
)

if args.expand_rope_table:
transforms.append(materialze_broadcast_of_rope_freq_cis)

return (
load_llama_model(
checkpoint=checkpoint_path,
Expand Down
15 changes: 12 additions & 3 deletions examples/models/llama2/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,18 @@ def precompute_freqs_cis(dim: int, end: int, theta: float):

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
freqs_cis_ndim = freqs_cis.ndim
if freqs_cis_ndim == 3:
# freqs_cis: (seq_len, n_heads, head_dim // 2)
assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
shape = [
d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
for i, d in enumerate(x.shape)
]
else:
# freqs_cis: (seq_len, head_dim // 2)
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(shape)


Expand Down
26 changes: 24 additions & 2 deletions kernels/optimized/cpu/op_mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,29 @@ namespace native {
using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

namespace {

// Move to generic util as this is applicable to all binary ops
bool can_use_optimized_path(
const Tensor& a,
const Tensor& b,
const Tensor& out) {
ScalarType a_type = a.scalar_type();
ScalarType b_type = b.scalar_type();
ScalarType out_type = out.scalar_type();

bool can_use_optimized_path = true;
can_use_optimized_path =
can_use_optimized_path && ((a_type == b_type) && (a_type == out_type));
can_use_optimized_path = can_use_optimized_path &&
(a_type != ScalarType::Half && b_type != ScalarType::Half);
can_use_optimized_path = can_use_optimized_path &&
(a.sizes().equals(b.sizes()) ||
(a.numel() == b.numel() && a.numel() == out.numel()));
return can_use_optimized_path;
}
} // namespace

Tensor& opt_mul_out(
RuntimeContext& ctx,
const Tensor& a,
Expand All @@ -31,8 +54,7 @@ Tensor& opt_mul_out(
ScalarType b_type = b.scalar_type();
ScalarType out_type = out.scalar_type();

if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes()) &&
a_type != ScalarType::Half) {
if (can_use_optimized_path(a, b, out)) {
// Resize for dynamic shape
auto error = resize_tensor(out, a.sizes());
ET_KERNEL_CHECK_MSG(
Expand Down