@@ -862,6 +862,73 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        result = super().call(graph_module)
        return result

+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class FuseMulTensorIntoQuantPass(ExportPass):
+    """
+    Looks for the pattern where aten.mul.Tensor is followed by a quant node.
+    If found, updates the quant scale to reflect the multiplication and
+    removes the mul node.
+    """
+    def attempt_fusion(
+        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
+    ) -> None:
+        if mul_node.target != exir_ops.edge.aten.mul.Tensor:
+            return
+
+        full_nodes = [
+            arg
+            for arg in mul_node.args
+            if isinstance(arg, torch.fx.Node)
+            and arg.target == exir_ops.edge.aten.full.default
+        ]
+
+        if len(full_nodes) != 1 or len(mul_node.users) != 1:
+            return
+
+        full_node = full_nodes[0]
+        mul_user = list(mul_node.users.keys())[0]
+
+        if mul_user.target not in {
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            exir_ops.edge.cadence.quantize_per_tensor.default,
+        }:
+            return
+
+        quant_node = mul_user
+
+        # First create a copy of the current args
+        new_quant_args = list(quant_node.args)
+        assert isinstance(quant_node.args[1], Number)
+        assert isinstance(full_node.args[1], Number)
+        # Folding mul(x, c) into quantize(x, scale) requires dividing the
+        # scale by c, since quantize computes round(x / scale) + zero_point.
+        # pyre-ignore[58]: Unsupported operand /
+        new_scale = quant_node.args[1] / full_node.args[1]
+
+        logging.debug(
+            f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
+        )
+
+        # Replace the input first
+        quant_node.replace_input_with(cast(torch.fx.Node, quant_node.args[0]), cast(torch.fx.Node, mul_node.args[0]))
+
+        # Now update the scale in the args
+        new_quant_args = list(quant_node.args)
+        new_quant_args[1] = new_scale
+        quant_node.args = tuple(new_quant_args)
+
+        # Clean up the mul_node
+        mul_node.args = tuple()
+        mul_node.users = {}
+
+        graph_module.graph.erase_node(mul_node)
+        graph_module.graph.erase_node(full_node)
+        graph_module.recompile()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            self.attempt_fusion(graph_module, node)
+        result = super().call(graph_module)
+        return result
+

@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulTensorIntoDequantPass(ExportPass):
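For reference, the fusion above relies on the identity quantize(x * c, scale) == quantize(x, scale / c) under the standard quantized_decomposed convention q = clamp(round(x / scale) + zero_point, qmin, qmax). The snippet below is a minimal sketch, not part of this change, that checks the identity numerically; the helper `quantize_ref` and all constants are made up for illustration, and power-of-two values are chosen so the floating-point comparison is exact.

```python
import torch


def quantize_ref(
    x: torch.Tensor, scale: float, zero_point: int, qmin: int, qmax: int
) -> torch.Tensor:
    # Hypothetical reference helper (not from the PR): the standard
    # quantize formula round(x / scale) + zero_point, clamped to [qmin, qmax].
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)


x = torch.randn(16)
c = 0.5            # constant produced by the aten.full.default node
scale, zp = 0.125, 0
qmin, qmax = -128, 127

# Original graph: full -> mul -> quantize(scale)
q_before = quantize_ref(x * c, scale, zp, qmin, qmax)
# After fusion: quantize(scale / c) applied directly to the mul's input
q_after = quantize_ref(x, scale / c, zp, qmin, qmax)

assert torch.equal(q_before, q_after)
```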