
Commit a556a2d

Lunwen He authored and facebook-github-bot committed
Support SpinQuant to run on ET (#5435)
Summary:
Pull Request resolved: #5435

This PR adds the option to run SpinQuant on ET.

Reviewed By: mergennachin

Differential Revision: D62526665

fbshipit-source-id: ff18110656d5ad90eb79020a1c2f6d235a9001b3
1 parent: 28c9a1d

File tree

2 files changed: +39 -3 lines


examples/models/llama2/export_llama_lib.py

Lines changed: 5 additions & 1 deletion
@@ -790,7 +790,11 @@ def _get_source_transforms( # noqa
 
             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
         elif args.use_spin_quant == "native":
-            raise NotImplementedError("native SpinQuant is not implemented yet.")
+            from .source_transformation.spin_quant import (
+                inject_fast_hadamard_transform_native_for_spin_quant,
+            )
+
+            transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
     if args.embedding_quantize:
         modelname = f"{modelname}_e"
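Each source transform collected here is a callable that takes the model's torch.nn.Module and returns the (possibly modified) module, so the "native" branch simply appends one more such callable to the list. Below is a minimal sketch of how a transform list like this could be applied in order; the apply_source_transforms helper is hypothetical and not part of this commit.

from typing import Callable, List

import torch


def apply_source_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Apply each transform in order; every transform returns the rewritten module.
    for transform in transforms:
        model = transform(model)
    return model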

examples/models/llama2/source_transformation/spin_quant.py

Lines changed: 34 additions & 2 deletions
@@ -33,7 +33,7 @@ def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module)
             "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
         )
 
-    class FeedForwardCustom(nn.Module):
+    class FeedForwardCudaCustom(nn.Module):
         def __init__(self, w1, w2, w3):
             super().__init__()
             self.w1 = w1
@@ -47,7 +47,7 @@ def forward(self, x):
 
     for name, child in module.named_children():
         if isinstance(child, FeedForward):
-            setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3))
+            setattr(module, name, FeedForwardCudaCustom(child.w1, child.w2, child.w3))
         else:
             _inject_fast_hadamard_transform_cuda_for_spin_quant(child)
 
@@ -59,6 +59,38 @@ def inject_fast_hadamard_transform_cuda_for_spin_quant(
     return module
 
 
+def _inject_fast_hadamard_transform_native_for_spin_quant(module: torch.nn.Module):
+    """
+    SpinQuant needs two Hadamard matrices: R3 and R4. Here we are only injecting R4 in the feed forward layer.
+    R3 needs to be injected as well when KV cache quantization is enabled.
+    """
+
+    class FeedForwardNativeCustom(nn.Module):
+        def __init__(self, w1, w2, w3):
+            super().__init__()
+            self.w1 = w1
+            self.w2 = w2
+            self.w3 = w3
+
+        def forward(self, x):
+            return self.w2(
+                torch.ops.llama.fast_hadamard_transform(F.silu(self.w1(x)) * self.w3(x))
+            )
+
+    for name, child in module.named_children():
+        if isinstance(child, FeedForward):
+            setattr(module, name, FeedForwardNativeCustom(child.w1, child.w2, child.w3))
+        else:
+            _inject_fast_hadamard_transform_native_for_spin_quant(child)
+
+
+def inject_fast_hadamard_transform_native_for_spin_quant(
+    module: torch.nn.Module,
+) -> torch.nn.Module:
+    _inject_fast_hadamard_transform_native_for_spin_quant(module)
+    return module
+
+
 def _replace_linear_with_linear_8da4w_for_spin_quant(
     module: torch.nn.Module,
     checkpoint: Any,
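The torch.ops.llama.fast_hadamard_transform call in the new forward is a custom op; the sketch below only illustrates the R4 rotation it is assumed to compute, i.e. a normalized Walsh-Hadamard transform over the hidden dimension of the feed-forward block. The names fast_hadamard_transform_reference and FeedForwardWithR4 are hypothetical and not part of this commit.

import torch
import torch.nn as nn
import torch.nn.functional as F


def fast_hadamard_transform_reference(x: torch.Tensor) -> torch.Tensor:
    # Normalized Walsh-Hadamard transform over the last dimension,
    # which must be a power of two.
    dim = x.shape[-1]
    assert (dim & (dim - 1)) == 0, "last dimension must be a power of two"
    out = x
    h = 1
    while h < dim:
        # Butterfly step: combine pairs of blocks of size h.
        out = out.reshape(*out.shape[:-1], dim // (2 * h), 2, h)
        a, b = out[..., 0, :], out[..., 1, :]
        out = torch.stack((a + b, a - b), dim=-2)
        out = out.reshape(*out.shape[:-3], dim)
        h *= 2
    return out / dim**0.5


class FeedForwardWithR4(nn.Module):
    # Toy SwiGLU feed-forward block with the R4 Hadamard rotation applied to the
    # hidden activations right before the down-projection (w2), mirroring the
    # structure of FeedForwardNativeCustom above.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.silu(self.w1(x)) * self.w3(x)
        return self.w2(fast_hadamard_transform_reference(hidden))


if __name__ == "__main__":
    block = FeedForwardWithR4(dim=64, hidden_dim=128)  # hidden_dim is a power of two
    print(block(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])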
