|
15 | 15 | # Import passes
|
16 | 16 | import executorch.exir.memory_planning # noqa
|
17 | 17 | import torch
|
18 |
| -from executorch.exir import EdgeCompileConfig, memory, to_edge |
| 18 | +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge |
19 | 19 | from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
|
20 | 20 | from executorch.exir.dialects.edge._ops import EdgeOpOverload
|
21 | 21 | from executorch.exir.emit import emit_program
|
|
50 | 50 | from functorch.experimental import control_flow
|
51 | 51 |
|
52 | 52 | from torch import nn
|
| 53 | + |
| 54 | +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e |
| 55 | +from torch.ao.quantization.quantizer.xnnpack_quantizer import ( |
| 56 | + get_symmetric_quantization_config, |
| 57 | + XNNPACKQuantizer, |
| 58 | +) |
53 | 59 | from torch.export import export
|
54 | 60 | from torch.fx import GraphModule, subgraph_rewriter
|
55 | 61 | from torch.fx.experimental.proxy_tensor import make_fx
|
@@ -1244,3 +1250,173 @@ def forward(self, x):
|
1244 | 1250 | # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
|
1245 | 1251 | # return (copy__default, aten_add_tensor)
|
1246 | 1252 | self.assertEqual(count_copies(gm), 1)
|
| 1253 | + |
| 1254 | + def test_remove_quantized_op_noop_pass(self) -> None: |
| 1255 | + class TestAddSliceNoop(torch.nn.Module): |
| 1256 | + def __init__(self): |
| 1257 | + super().__init__() |
| 1258 | + |
| 1259 | + def forward(self, x): |
| 1260 | + x = x + x |
| 1261 | + x = x + x[:] |
| 1262 | + return x |
| 1263 | + |
| 1264 | + class TestAddSliceNotNoop(torch.nn.Module): |
| 1265 | + def __init__(self): |
| 1266 | + super().__init__() |
| 1267 | + |
| 1268 | + def forward(self, x): |
| 1269 | + x = x + x |
| 1270 | + x = x + x[:1] |
| 1271 | + return x |
| 1272 | + |
| 1273 | + def count_dq_nodes(gm: torch.fx.GraphModule) -> int: |
| 1274 | + return sum( |
| 1275 | + ( |
| 1276 | + node.target |
| 1277 | + in ( |
| 1278 | + torch.ops.quantized_decomposed.dequantize_per_tensor.default, |
| 1279 | + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 1280 | + ) |
| 1281 | + ) |
| 1282 | + for node in gm.graph.nodes |
| 1283 | + ) |
| 1284 | + |
| 1285 | + def count_q_nodes(gm: torch.fx.GraphModule) -> int: |
| 1286 | + return sum( |
| 1287 | + ( |
| 1288 | + node.target |
| 1289 | + in ( |
| 1290 | + torch.ops.quantized_decomposed.quantize_per_tensor.default, |
| 1291 | + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 1292 | + ) |
| 1293 | + ) |
| 1294 | + for node in gm.graph.nodes |
| 1295 | + ) |
| 1296 | + |
| 1297 | + def quantize_model( |
| 1298 | + m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor] |
| 1299 | + ) -> Tuple[EdgeProgramManager, int, int]: |
| 1300 | + # program capture |
| 1301 | + m = torch._export.capture_pre_autograd_graph( |
| 1302 | + m_eager, |
| 1303 | + example_inputs, |
| 1304 | + ) |
| 1305 | + |
| 1306 | + quantizer = XNNPACKQuantizer() |
| 1307 | + quantization_config = get_symmetric_quantization_config() |
| 1308 | + quantizer.set_global(quantization_config) |
| 1309 | + m = prepare_pt2e(m, quantizer) |
| 1310 | + m = convert_pt2e(m, fold_quantize=True) |
| 1311 | + ep = torch.export.export(m, example_inputs) |
| 1312 | + dq_nodes_pre = count_dq_nodes(ep.graph_module) |
| 1313 | + q_nodes_pre = count_q_nodes(ep.graph_module) |
| 1314 | + edge = to_edge( |
| 1315 | + ep, compile_config=EdgeCompileConfig(_check_ir_validity=False) |
| 1316 | + ) |
| 1317 | + return edge, dq_nodes_pre, q_nodes_pre |
| 1318 | + |
| 1319 | + example_inputs = (torch.randn(9, 8),) |
| 1320 | + model = TestAddSliceNoop() |
| 1321 | + m_eager = model.eval() |
| 1322 | + edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs) |
| 1323 | + |
| 1324 | + dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module) |
| 1325 | + q_nodes_post = count_q_nodes(edge.exported_program().graph_module) |
| 1326 | + # One dq and one q node around the slice copy should have been removed. |
| 1327 | + self.assertEqual(dq_nodes_pre - dq_nodes_post, 1) |
| 1328 | + self.assertEqual(q_nodes_pre - q_nodes_post, 1) |
| 1329 | + |
| 1330 | + # Check that the slice_copy is removed by the RemoveNoopPass. |
| 1331 | + for node in edge.exported_program().graph_module.graph.nodes: |
| 1332 | + self.assertFalse("slice" in str(node.target)) |
| 1333 | + |
| 1334 | + model = TestAddSliceNotNoop() |
| 1335 | + m_eager = model.eval() |
| 1336 | + edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs) |
| 1337 | + |
| 1338 | + dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module) |
| 1339 | + q_nodes_post = count_q_nodes(edge.exported_program().graph_module) |
| 1340 | + # One dq and one q node around the slice copy should have been removed. |
| 1341 | + self.assertEqual(dq_nodes_pre, dq_nodes_post) |
| 1342 | + self.assertEqual(q_nodes_pre, q_nodes_post) |
| 1343 | + |
| 1344 | + # Check that the slice_copy is not removed by the RemoveNoopPass. |
| 1345 | + self.assertTrue( |
| 1346 | + any( |
| 1347 | + "slice" in str(node.target) |
| 1348 | + for node in edge.exported_program().graph_module.graph.nodes |
| 1349 | + ) |
| 1350 | + ) |
| 1351 | + |
| 1352 | + def test_dq_q_no_op_pass(self) -> None: |
| 1353 | + class TestDqQ(torch.nn.Module): |
| 1354 | + def __init__(self): |
| 1355 | + super().__init__() |
| 1356 | + |
| 1357 | + def forward(self, x): |
| 1358 | + dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default( |
| 1359 | + x, 1.0, 0, -128, 127, torch.int8 |
| 1360 | + ) |
| 1361 | + q = torch.ops.quantized_decomposed.quantize_per_tensor.default( |
| 1362 | + dq, 1.0, 0, -128, 127, torch.int8 |
| 1363 | + ) |
| 1364 | + return q |
| 1365 | + |
| 1366 | + model = TestDqQ() |
| 1367 | + m_eager = model.eval() |
| 1368 | + ep = torch.export.export(m_eager, (torch.randn(9, 8),)) |
| 1369 | + edge = to_edge(ep) |
| 1370 | + # Check that the dq and q nodes are not touched by the RemoveNoopPass. |
| 1371 | + self.assertTrue( |
| 1372 | + any( |
| 1373 | + "dequantize" in str(node.target) |
| 1374 | + for node in edge.exported_program().graph_module.graph.nodes |
| 1375 | + ) |
| 1376 | + ) |
| 1377 | + self.assertTrue( |
| 1378 | + any( |
| 1379 | + "quantize" in str(node.target) |
| 1380 | + for node in edge.exported_program().graph_module.graph.nodes |
| 1381 | + ) |
| 1382 | + ) |
| 1383 | + |
| 1384 | + def test_dq_q_different_qparams(self) -> None: |
| 1385 | + class TestDqQDifferentQParam(torch.nn.Module): |
| 1386 | + def __init__(self): |
| 1387 | + super().__init__() |
| 1388 | + |
| 1389 | + def forward(self, x): |
| 1390 | + dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default( |
| 1391 | + x, 1.0, 0, -128, 127, torch.int8 |
| 1392 | + ) |
| 1393 | + slice_copy_output = torch.ops.aten.slice_copy.Tensor(dq, 0, 0) |
| 1394 | + q = torch.ops.quantized_decomposed.quantize_per_tensor.default( |
| 1395 | + slice_copy_output, 1.0, 0, -127, 127, torch.int8 |
| 1396 | + ) |
| 1397 | + return q |
| 1398 | + |
| 1399 | + model = TestDqQDifferentQParam() |
| 1400 | + m_eager = model.eval() |
| 1401 | + ep = torch.export.export(m_eager, (torch.randn(9, 8),)) |
| 1402 | + edge = to_edge(ep) |
| 1403 | + print(edge.exported_program().graph_module.graph) |
| 1404 | + # Check that the dq and q nodes are not touched by the RemoveNoopPass. |
| 1405 | + self.assertTrue( |
| 1406 | + any( |
| 1407 | + "dequantize" in str(node.target) |
| 1408 | + for node in edge.exported_program().graph_module.graph.nodes |
| 1409 | + ) |
| 1410 | + ) |
| 1411 | + self.assertTrue( |
| 1412 | + any( |
| 1413 | + "quantize" in str(node.target) |
| 1414 | + for node in edge.exported_program().graph_module.graph.nodes |
| 1415 | + ) |
| 1416 | + ) |
| 1417 | + self.assertFalse( |
| 1418 | + any( |
| 1419 | + "slice" in str(node.target) |
| 1420 | + for node in edge.exported_program().graph_module.graph.nodes |
| 1421 | + ) |
| 1422 | + ) |
0 commit comments