@@ -248,32 +248,28 @@ def forward(self, x):
248
248
)
249
249
250
250
def test_force_quant_dequant_fusion (self ):
251
- class M (torch .nn .Module ):
252
- def __init__ (self ):
253
- super ().__init__ ()
254
-
255
- def forward (self , x ):
256
- x = torch .ops .quantized_decomposed .quantize_per_tensor (
257
- x , 1.2 , 3 , 0 , 127 , torch .int8
258
- )
259
- x = torch .permute (x , [2 , 0 , 1 , 3 ])
260
- x = torch .ops .quantized_decomposed .dequantize_per_tensor (
261
- x , 4.5 , 6 , 0 , 127 , torch .int8
262
- )
263
- return x
264
-
265
- inputs = torch .randn (2 , 12 , 1 , 6 )
266
- model = M ()
267
- graph_module = export_to_edge (model , (inputs ,)).exported_program ().graph_module
268
-
269
- graph_module = FuseQuantDequantToRequantizePass (
251
+ builder = GraphBuilder ()
252
+ x = builder .placeholder ("x" , torch .randn (2 , 12 , 1 , 6 , dtype = torch .float32 ))
253
+ quant = builder .call_operator (
254
+ op = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
255
+ args = (x , 1.2 , 3 , 0 , 127 , torch .int8 ),
256
+ )
257
+ permute = builder .call_operator (
258
+ op = exir_ops .edge .aten .permute_copy .default , args = (quant , [2 , 0 , 1 , 3 ])
259
+ )
260
+ dequant = builder .call_operator (
261
+ op = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
262
+ args = (permute , 4.5 , 6 , 0 , 127 , torch .int8 ),
263
+ )
264
+ builder .output (dequant )
265
+ original_graph = builder .get_graph_module ()
266
+ converted_graph = FuseQuantDequantToRequantizePass (
270
267
force_quant_dequant_fusion = True
271
- )(graph_module ).graph_module
268
+ )(original_graph ).graph_module
272
269
self .check_op_counts (
273
- graph_module ,
270
+ converted_graph ,
274
271
expected_op_counts = {
275
- # Verify that no dequant/quant pair was replaced with requantize.
276
- # quantize -> permute -> dequantize should not be replaced with requantize.
272
+ # Verify that dequant/quant pair was replaced with requantize.
277
273
exir_ops .edge .quantized_decomposed .quantize_per_tensor .default : 0 ,
278
274
exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default : 0 ,
279
275
exir_ops .edge .cadence .requantize .default : 1 ,
0 commit comments