|
13 | 13 | import executorch.backends.cadence.aot.ops_registrations # noqa
|
14 | 14 | import torch
|
15 | 15 | from executorch.backends.cadence.aot import compiler
|
16 |
| -from executorch.backends.cadence.aot.compiler import ( |
17 |
| - export_to_edge, |
18 |
| - quantize_and_export_to_edge, |
19 |
| -) |
| 16 | +from executorch.backends.cadence.aot.compiler import export_to_edge |
20 | 17 | from executorch.backends.cadence.aot.fuse_ops import (
|
21 | 18 | FuseFullThenReshapePass,
|
22 | 19 | FuseMulScalarIntoDequantPass,
|
@@ -343,94 +340,144 @@ def test_replace_quant_view_dequant_with_requantize(self):
|
343 | 340 | )
|
344 | 341 |
|
345 | 342 | def test_replace_dequant_quant_with_requantize(self):
|
346 |
| - class M(torch.nn.Module): |
347 |
| - def __init__(self): |
348 |
| - super().__init__() |
349 |
| - |
350 |
| - def forward(self, x): |
351 |
| - x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
352 |
| - x, 1.2, 3, 0, 127, torch.int8 |
353 |
| - ) |
354 |
| - x = torch.permute(x, [2, 0, 1, 3]) |
355 |
| - x = torch.ops.quantized_decomposed.quantize_per_tensor( |
356 |
| - x, 4.5, 6, 0, 127, torch.int8 |
357 |
| - ) |
358 |
| - return x |
359 |
| - |
360 |
| - inputs = torch.randn(2, 12, 1, 6).to(torch.int8) |
361 |
| - model = M() |
362 |
| - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module |
363 |
| - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 343 | + builder = GraphBuilder() |
| 344 | + x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) |
| 345 | + dequant = builder.call_operator( |
| 346 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 347 | + args=(x, 1.2, 3, 0, 127, torch.int8), |
| 348 | + ) |
| 349 | + quant = builder.call_operator( |
| 350 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 351 | + args=(dequant, 4.5, 6, 0, 127, torch.int8), |
| 352 | + ) |
| 353 | + builder.output(quant) |
| 354 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 355 | + builder.get_graph_module() |
| 356 | + ).graph_module |
364 | 357 |
|
365 | 358 | self.check_op_counts(
|
366 | 359 | graph_module,
|
367 | 360 | expected_op_counts={
|
368 |
| - # Verify that dequant -> permute -> quant was replaced with permute -> requantize. |
| 361 | + # Verify that dequant -> quant was replaced with requantize. |
369 | 362 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
|
370 | 363 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
|
371 | 364 | exir_ops.edge.cadence.requantize.default: 1,
|
372 | 365 | },
|
373 | 366 | )
|
374 | 367 |
|
375 | 368 | def test_replace_dequant_permute_quant_with_requantize(self):
|
376 |
| - class M(torch.nn.Module): |
377 |
| - def __init__(self): |
378 |
| - super().__init__() |
379 |
| - |
380 |
| - def forward(self, x): |
381 |
| - x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
382 |
| - x, 1.2, 3, 0, 127, torch.int8 |
383 |
| - ) |
384 |
| - x = torch.permute(x, [2, 0, 1, 3]) |
385 |
| - x = torch.ops.quantized_decomposed.quantize_per_tensor( |
386 |
| - x, 4.5, 6, 0, 127, torch.int8 |
387 |
| - ) |
388 |
| - return x |
389 |
| - |
390 |
| - inputs = torch.randn(2, 12, 1, 6).to(torch.int8) |
391 |
| - model = M() |
392 |
| - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module |
393 |
| - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 369 | + builder = GraphBuilder() |
| 370 | + x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) |
| 371 | + dequant = builder.call_operator( |
| 372 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 373 | + args=(x, 1.2, 3, 0, 127, torch.int8), |
| 374 | + ) |
| 375 | + permute = builder.call_operator( |
| 376 | + op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3]) |
| 377 | + ) |
| 378 | + quant = builder.call_operator( |
| 379 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 380 | + args=(permute, 4.5, 6, 0, 127, torch.int8), |
| 381 | + ) |
| 382 | + builder.output(quant) |
| 383 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 384 | + builder.get_graph_module() |
| 385 | + ).graph_module |
394 | 386 |
|
395 | 387 | self.check_op_counts(
|
396 | 388 | graph_module,
|
397 | 389 | expected_op_counts={
|
398 | 390 | # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
|
399 | 391 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
|
400 | 392 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
|
| 393 | + exir_ops.edge.aten.permute_copy.default: 1, |
401 | 394 | exir_ops.edge.cadence.requantize.default: 1,
|
402 | 395 | },
|
403 | 396 | )
|
404 | 397 |
|
405 | 398 | def test_remove_nop_dequant_quant(self):
|
406 |
| - class M(torch.nn.Module): |
407 |
| - def __init__(self): |
408 |
| - super(M, self).__init__() |
409 |
| - self.lin1 = torch.nn.Linear(6, 12, bias=False) |
410 |
| - self.lin2 = torch.nn.Linear(12, 24, bias=False) |
411 |
| - |
412 |
| - def forward(self, x): |
413 |
| - x = self.lin1(x) |
414 |
| - # redundant dequant+quant will be created around this permute |
415 |
| - x = torch.permute(x, [0, 2, 1, 3]) |
416 |
| - x = self.lin2(x) |
417 |
| - return x |
| 399 | + LEADING_DIMS: Final[int] = 12 |
| 400 | + IN_DIM: Final[int] = 6 |
| 401 | + OUT_DIM: Final[int] = 12 |
418 | 402 |
|
419 |
| - inputs = torch.randn(2, 12, 1, 6) |
420 |
| - model = M() |
421 |
| - graph_module = ( |
422 |
| - quantize_and_export_to_edge(model, (inputs,)) |
423 |
| - .exported_program() |
424 |
| - .graph_module |
| 403 | + builder = GraphBuilder() |
| 404 | + x = builder.placeholder( |
| 405 | + "x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32) |
425 | 406 | )
|
426 |
| - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 407 | + quant1 = builder.call_operator( |
| 408 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 409 | + args=(x, 4.5, 6, 0, 127, torch.int8), |
| 410 | + ) |
| 411 | + weights = builder.call_operator( |
| 412 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1) |
| 413 | + ) |
| 414 | + bias = builder.call_operator( |
| 415 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) |
| 416 | + ) |
| 417 | + weight_zero_point = builder.call_operator( |
| 418 | + op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0) |
| 419 | + ) |
| 420 | + out_multiplier = builder.call_operator( |
| 421 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) |
| 422 | + ) |
| 423 | + out_shift = builder.call_operator( |
| 424 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0) |
| 425 | + ) |
| 426 | + linear1 = builder.call_operator( |
| 427 | + op=exir_ops.edge.cadence.quantized_linear.default, |
| 428 | + args=( |
| 429 | + quant1, |
| 430 | + weights, |
| 431 | + bias, |
| 432 | + 0, # src_zero_point |
| 433 | + weight_zero_point, |
| 434 | + out_multiplier, |
| 435 | + out_shift, |
| 436 | + 0, # out_zero_point |
| 437 | + None, |
| 438 | + ), |
| 439 | + ) |
| 440 | + dequant1 = builder.call_operator( |
| 441 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 442 | + args=(linear1, 1.2, 3, 0, 127, torch.int8), |
| 443 | + ) |
| 444 | + permute = builder.call_operator( |
| 445 | + op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0]) |
| 446 | + ) |
| 447 | + quant2 = builder.call_operator( |
| 448 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 449 | + args=(permute, 4.5, 6, 0, 127, torch.int8), |
| 450 | + ) |
| 451 | + linear2 = builder.call_operator( |
| 452 | + op=exir_ops.edge.cadence.quantized_linear.default, |
| 453 | + args=( |
| 454 | + quant2, |
| 455 | + weights, |
| 456 | + bias, |
| 457 | + 0, # src_zero_point |
| 458 | + weight_zero_point, |
| 459 | + out_multiplier, |
| 460 | + out_shift, |
| 461 | + 0, # out_zero_point |
| 462 | + None, |
| 463 | + ), |
| 464 | + ) |
| 465 | + dequant2 = builder.call_operator( |
| 466 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 467 | + args=(linear2, 1.2, 3, 0, 127, torch.int8), |
| 468 | + ) |
| 469 | + builder.output(dequant2) |
| 470 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 471 | + builder.get_graph_module() |
| 472 | + ).graph_module |
427 | 473 | self.check_op_counts(
|
428 | 474 | graph_module,
|
429 | 475 | expected_op_counts={
|
430 |
| - # Verify that one dequant/quant pair was removed |
431 |
| - # Expect 1 quantize ops: 1 input |
| 476 | + # Verify that one dequant/quant pair was removed from chain: |
| 477 | + # quant->linear->dequant->permute->quant->linear->dequant |
| 478 | + # gets converted to: |
| 479 | + # quant->linear->permute->linear->dequant |
432 | 480 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
|
433 |
| - # Expect 1 dequant op at the end (output of second linear) |
434 | 481 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
|
435 | 482 | },
|
436 | 483 | )
|
|
0 commit comments