@@ -33,7 +33,7 @@ def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module)
33
33
"Please install fast-hadamard-transform: pip install fast-hadamard-transform"
34
34
)
35
35
36
- class FeedForwardCustom (nn .Module ):
36
+ class FeedForwardCudaCustom (nn .Module ):
37
37
def __init__ (self , w1 , w2 , w3 ):
38
38
super ().__init__ ()
39
39
self .w1 = w1
@@ -47,7 +47,7 @@ def forward(self, x):
47
47
48
48
for name , child in module .named_children ():
49
49
if isinstance (child , FeedForward ):
50
- setattr (module , name , FeedForwardCustom (child .w1 , child .w2 , child .w3 ))
50
+ setattr (module , name , FeedForwardCudaCustom (child .w1 , child .w2 , child .w3 ))
51
51
else :
52
52
_inject_fast_hadamard_transform_cuda_for_spin_quant (child )
53
53
@@ -59,6 +59,38 @@ def inject_fast_hadamard_transform_cuda_for_spin_quant(
59
59
return module
60
60
61
61
62
def _inject_fast_hadamard_transform_native_for_spin_quant(module: torch.nn.Module):
    """Recursively replace every ``FeedForward`` child of ``module`` with a
    variant that applies a native (``torch.ops.llama``) fast Hadamard
    transform before the output projection.

    SpinQuant needs two Hadamard matrices: R3 and R4. Here we are only
    injecting R4 in the feed-forward layer. R3 needs to be injected as well
    when KV cache quantization is enabled.

    Args:
        module: Root module to rewrite; modified in place.
    """

    class FeedForwardNativeCustom(nn.Module):
        # FeedForward replacement that rotates the SwiGLU activation with the
        # R4 Hadamard transform (ExecuTorch native op) before projecting
        # through w2.
        def __init__(self, w1, w2, w3):
            super().__init__()
            self.w1 = w1
            self.w2 = w2
            self.w3 = w3

        def forward(self, x):
            return self.w2(
                torch.ops.llama.fast_hadamard_transform(F.silu(self.w1(x)) * self.w3(x))
            )

    for name, child in module.named_children():
        if isinstance(child, FeedForward):
            # Swap the child module out, reusing its existing weights.
            setattr(module, name, FeedForwardNativeCustom(child.w1, child.w2, child.w3))
        else:
            # Not a FeedForward: descend into its children.
            _inject_fast_hadamard_transform_native_for_spin_quant(child)
85
+
86
+
87
def inject_fast_hadamard_transform_native_for_spin_quant(
    module: torch.nn.Module,
) -> torch.nn.Module:
    """Inject the native R4 fast Hadamard transform into ``module``'s
    feed-forward layers (in place) and return the module for chaining."""
    _inject_fast_hadamard_transform_native_for_spin_quant(module)
    return module
92
+
93
+
62
94
def _replace_linear_with_linear_8da4w_for_spin_quant (
63
95
module : torch .nn .Module ,
64
96
checkpoint : Any ,
0 commit comments