Skip to content

Commit 9a90e5d

Browse files
committed
rename the rotation file to apply_spin_quant_r1_r2
1 parent c5e09d7 commit 9a90e5d

File tree

2 files changed: +8 additions, -5 deletions

examples/models/llama2/export_llama_lib.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,16 @@
4545
from executorch.util.activation_memory_profiler import generate_memory_trace
4646

4747
from ..model_factory import EagerModelFactory
48+
from .source_transformation.apply_spin_quant_r1_r2 import (
49+
fuse_layer_norms,
50+
get_model_with_r1_r2,
51+
)
4852
from .source_transformation.quantize import (
4953
get_quant_embedding_transform,
5054
get_quant_weight_transform,
5155
)
5256
from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
5357
from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
54-
from .source_transformation.rotation import fuse_layer_norms, get_rotate_model
5558
from .source_transformation.sdpa import (
5659
replace_causal_mask,
5760
replace_kv_cache_with_simple_kv_cache,
@@ -434,7 +437,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
434437

435438
if args.optimized_rotation_path:
436439
transforms.append(fuse_layer_norms)
437-
transforms.append(get_rotate_model(args.optimized_rotation_path))
440+
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
438441
return (
439442
_load_llama_model(
440443
modelname=modelname,

examples/models/llama2/source_transformation/rotation.py renamed to examples/models/llama2/source_transformation/apply_spin_quant_r1_r2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,11 @@ def cleanup_memory() -> None:
9393
gc.collect()
9494

9595

96-
def get_rotate_model(optimized_rotation_path: str):
97-
return lambda model: rotate_model(model, optimized_rotation_path)
96+
def get_model_with_r1_r2(optimized_rotation_path: str):
97+
return lambda model: apply_spin_quant_r1_r2(model, optimized_rotation_path)
9898

9999

100-
def rotate_model(model: torch.nn.Module, optimized_rotation_path: str):
100+
def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str):
101101
optimized_rotation = torch.load(optimized_rotation_path, weights_only=True)
102102
R1 = optimized_rotation["R1"].to(torch.float32)
103103
config = model.params

Commit comments (0)