|
22 | 22 | from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
|
23 | 23 | XnnpackDynamicallyQuantizedPartitioner,
|
24 | 24 | )
|
| 25 | + |
| 26 | +from executorch.examples.models.llama2.llama_transformer import Transformer |
25 | 27 | from executorch.exir.backend.backend_details import CompileSpec
|
26 | 28 |
|
27 | 29 | from executorch.sdk.etrecord import generate_etrecord
|
@@ -174,6 +176,32 @@ def check_embedding_byte_registered():
|
174 | 176 | return quantizers
|
175 | 177 |
|
176 | 178 |
|
def materialze_broadcast_of_rope_freq_cis(
    module: torch.nn.Module,
) -> torch.nn.Module:
    """Pre-expand the RoPE sin/cos frequency tables along the head dimension.

    Reshapes ``freqs_cos``/``freqs_sin`` from ``(seq, dim)`` to a contiguous
    ``(seq, n_heads, dim)`` tensor so downstream kernels can take a
    vectorized path instead of broadcasting at runtime.

    Args:
        module: A ``Transformer`` whose ``freqs_cos`` and ``freqs_sin``
            buffers are 2-D and of matching shape.

    Returns:
        The same ``module``, mutated in place, for call-chaining as a
        transform.

    Raises:
        AssertionError: If ``module`` is not a ``Transformer``, the tables
            are not 2-D / size-matched, or q and k/v head counts differ
            (broadcast materialization only makes sense when they match).
    """
    assert isinstance(module, Transformer)
    assert module.freqs_cos.dim() == 2
    dim0 = module.freqs_cos.size(0)
    dim1 = module.freqs_cos.size(1)
    # All layers share head counts; inspect the first layer's attention.
    attention = module.layers[0].attention
    # Bug fix: the previous failure message dereferenced module.attention.*,
    # which does not exist on Transformer and would raise AttributeError
    # instead of reporting the mismatch.
    assert attention.n_local_kv_heads == attention.n_local_heads, (
        "For rope freqs to be materialized for broadcast, q, k, v num heads "
        f"must match. Got n_local_heads={attention.n_local_heads} and "
        f"n_local_kv_heads={attention.n_local_kv_heads}"
    )
    num_heads = attention.n_local_heads
    # Expand (seq, dim) -> (seq, num_heads, dim); contiguous() materializes
    # the broadcast so the kernel sees a dense tensor.
    module.freqs_cos = module.freqs_cos.view(dim0, 1, dim1)
    module.freqs_cos = module.freqs_cos.expand(dim0, num_heads, dim1).contiguous()
    assert module.freqs_sin.dim() == 2
    assert dim0 == module.freqs_sin.size(
        0
    ), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.freqs_sin.size(0)}"
    assert dim1 == module.freqs_sin.size(
        1
    ), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.freqs_sin.size(1)}"
    module.freqs_sin = module.freqs_sin.view(dim0, 1, dim1)
    module.freqs_sin = module.freqs_sin.expand(dim0, num_heads, dim1).contiguous()
    return module
| 203 | + |
| 204 | + |
177 | 205 | def quantize(
|
178 | 206 | model: torch.nn.Module,
|
179 | 207 | qmode: str,
|
@@ -369,6 +397,13 @@ def build_args_parser() -> argparse.ArgumentParser:
|
369 | 397 | parser.add_argument("-V", "--vulkan", action="store_true")
|
370 | 398 | parser.add_argument("--mps", action="store_true")
|
371 | 399 |
|
| 400 | + parser.add_argument( |
| 401 | + "--expand_rope_table", |
| 402 | + default=False, |
| 403 | + action="store_true", |
| 404 | + help="[Temp workaround] Expand sin/cos table in head dim to take vectorized path in optimized kernels.", |
| 405 | + ) |
| 406 | + |
372 | 407 | parser.add_argument(
|
373 | 408 | "--generate_etrecord",
|
374 | 409 | action="store_true",
|
@@ -464,6 +499,9 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
|
464 | 499 | ).quantized_model()
|
465 | 500 | )
|
466 | 501 |
|
| 502 | + if args.expand_rope_table: |
| 503 | + transforms.append(materialze_broadcast_of_rope_freq_cis) |
| 504 | + |
467 | 505 | return (
|
468 | 506 | load_llama_model(
|
469 | 507 | checkpoint=checkpoint_path,
|
|
0 commit comments