|
13 | 13 | import executorch.backends.cadence.aot.ops_registrations # noqa
|
14 | 14 | import torch
|
15 | 15 | from executorch.backends.cadence.aot import compiler
|
16 |
| -from executorch.backends.cadence.aot.compiler import ( |
17 |
| - export_to_edge, |
18 |
| - quantize_and_export_to_edge, |
19 |
| -) |
| 16 | +from executorch.backends.cadence.aot.compiler import export_to_edge |
20 | 17 | from executorch.backends.cadence.aot.fuse_ops import (
|
21 | 18 | FuseFullThenReshapePass,
|
22 | 19 | FuseMulScalarIntoDequantPass,
|
@@ -340,94 +337,144 @@ def test_replace_quant_view_dequant_with_requantize(self):
|
340 | 337 | )
|
341 | 338 |
|
342 | 339 | def test_replace_dequant_quant_with_requantize(self):
|
343 |
| - class M(torch.nn.Module): |
344 |
| - def __init__(self): |
345 |
| - super().__init__() |
346 |
| - |
347 |
| - def forward(self, x): |
348 |
| - x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
349 |
| - x, 1.2, 3, 0, 127, torch.int8 |
350 |
| - ) |
351 |
| - x = torch.permute(x, [2, 0, 1, 3]) |
352 |
| - x = torch.ops.quantized_decomposed.quantize_per_tensor( |
353 |
| - x, 4.5, 6, 0, 127, torch.int8 |
354 |
| - ) |
355 |
| - return x |
356 |
| - |
357 |
| - inputs = torch.randn(2, 12, 1, 6).to(torch.int8) |
358 |
| - model = M() |
359 |
| - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module |
360 |
| - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 340 | + builder = GraphBuilder() |
| 341 | + x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) |
| 342 | + dequant = builder.call_operator( |
| 343 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 344 | + args=(x, 1.2, 3, 0, 127, torch.int8), |
| 345 | + ) |
| 346 | + quant = builder.call_operator( |
| 347 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 348 | + args=(dequant, 4.5, 6, 0, 127, torch.int8), |
| 349 | + ) |
| 350 | + builder.output(quant) |
| 351 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 352 | + builder.get_graph_module() |
| 353 | + ).graph_module |
361 | 354 |
|
362 | 355 | self.check_op_counts(
|
363 | 356 | graph_module,
|
364 | 357 | expected_op_counts={
|
365 |
| - # Verify that dequant -> permute -> quant was replaced with permute -> requantize. |
| 358 | + # Verify that dequant -> quant was replaced with requantize. |
366 | 359 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
|
367 | 360 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
|
368 | 361 | exir_ops.edge.cadence.requantize.default: 1,
|
369 | 362 | },
|
370 | 363 | )
|
371 | 364 |
|
372 | 365 | def test_replace_dequant_permute_quant_with_requantize(self):
|
373 |
| - class M(torch.nn.Module): |
374 |
| - def __init__(self): |
375 |
| - super().__init__() |
376 |
| - |
377 |
| - def forward(self, x): |
378 |
| - x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
379 |
| - x, 1.2, 3, 0, 127, torch.int8 |
380 |
| - ) |
381 |
| - x = torch.permute(x, [2, 0, 1, 3]) |
382 |
| - x = torch.ops.quantized_decomposed.quantize_per_tensor( |
383 |
| - x, 4.5, 6, 0, 127, torch.int8 |
384 |
| - ) |
385 |
| - return x |
386 |
| - |
387 |
| - inputs = torch.randn(2, 12, 1, 6).to(torch.int8) |
388 |
| - model = M() |
389 |
| - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module |
390 |
| - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 366 | + builder = GraphBuilder() |
| 367 | + x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) |
| 368 | + dequant = builder.call_operator( |
| 369 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 370 | + args=(x, 1.2, 3, 0, 127, torch.int8), |
| 371 | + ) |
| 372 | + permute = builder.call_operator( |
| 373 | + op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3]) |
| 374 | + ) |
| 375 | + quant = builder.call_operator( |
| 376 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 377 | + args=(permute, 4.5, 6, 0, 127, torch.int8), |
| 378 | + ) |
| 379 | + builder.output(quant) |
| 380 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 381 | + builder.get_graph_module() |
| 382 | + ).graph_module |
391 | 383 |
|
392 | 384 | self.check_op_counts(
|
393 | 385 | graph_module,
|
394 | 386 | expected_op_counts={
|
395 | 387 | # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
|
396 | 388 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
|
397 | 389 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
|
| 390 | + exir_ops.edge.aten.permute_copy.default: 1, |
398 | 391 | exir_ops.edge.cadence.requantize.default: 1,
|
399 | 392 | },
|
400 | 393 | )
|
401 | 394 |
|
402 | 395 | def test_remove_nop_dequant_quant(self):
|
403 |
| - class M(torch.nn.Module): |
404 |
| - def __init__(self): |
405 |
| - super(M, self).__init__() |
406 |
| - self.lin1 = torch.nn.Linear(6, 12, bias=False) |
407 |
| - self.lin2 = torch.nn.Linear(12, 24, bias=False) |
408 |
| - |
409 |
| - def forward(self, x): |
410 |
| - x = self.lin1(x) |
411 |
| - # redundant dequant+quant will be created around this permute |
412 |
| - x = torch.permute(x, [0, 2, 1, 3]) |
413 |
| - x = self.lin2(x) |
414 |
| - return x |
| 396 | + LEADING_DIMS: Final[int] = 12 |
| 397 | + IN_DIM: Final[int] = 6 |
| 398 | + OUT_DIM: Final[int] = 12 |
415 | 399 |
|
416 |
| - inputs = torch.randn(2, 12, 1, 6) |
417 |
| - model = M() |
418 |
| - graph_module = ( |
419 |
| - quantize_and_export_to_edge(model, (inputs,)) |
420 |
| - .exported_program() |
421 |
| - .graph_module |
| 400 | + builder = GraphBuilder() |
| 401 | + x = builder.placeholder( |
| 402 | + "x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32) |
422 | 403 | )
|
423 |
| - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 404 | + quant1 = builder.call_operator( |
| 405 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 406 | + args=(x, 4.5, 6, 0, 127, torch.int8), |
| 407 | + ) |
| 408 | + weights = builder.call_operator( |
| 409 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1) |
| 410 | + ) |
| 411 | + bias = builder.call_operator( |
| 412 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) |
| 413 | + ) |
| 414 | + weight_zero_point = builder.call_operator( |
| 415 | + op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0) |
| 416 | + ) |
| 417 | + out_multiplier = builder.call_operator( |
| 418 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) |
| 419 | + ) |
| 420 | + out_shift = builder.call_operator( |
| 421 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0) |
| 422 | + ) |
| 423 | + linear1 = builder.call_operator( |
| 424 | + op=exir_ops.edge.cadence.quantized_linear.default, |
| 425 | + args=( |
| 426 | + quant1, |
| 427 | + weights, |
| 428 | + bias, |
| 429 | + 0, # src_zero_point |
| 430 | + weight_zero_point, |
| 431 | + out_multiplier, |
| 432 | + out_shift, |
| 433 | + 0, # out_zero_point |
| 434 | + None, |
| 435 | + ), |
| 436 | + ) |
| 437 | + dequant1 = builder.call_operator( |
| 438 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 439 | + args=(linear1, 1.2, 3, 0, 127, torch.int8), |
| 440 | + ) |
| 441 | + permute = builder.call_operator( |
| 442 | + op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0]) |
| 443 | + ) |
| 444 | + quant2 = builder.call_operator( |
| 445 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 446 | + args=(permute, 4.5, 6, 0, 127, torch.int8), |
| 447 | + ) |
| 448 | + linear2 = builder.call_operator( |
| 449 | + op=exir_ops.edge.cadence.quantized_linear.default, |
| 450 | + args=( |
| 451 | + quant2, |
| 452 | + weights, |
| 453 | + bias, |
| 454 | + 0, # src_zero_point |
| 455 | + weight_zero_point, |
| 456 | + out_multiplier, |
| 457 | + out_shift, |
| 458 | + 0, # out_zero_point |
| 459 | + None, |
| 460 | + ), |
| 461 | + ) |
| 462 | + dequant2 = builder.call_operator( |
| 463 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 464 | + args=(linear2, 1.2, 3, 0, 127, torch.int8), |
| 465 | + ) |
| 466 | + builder.output(dequant2) |
| 467 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 468 | + builder.get_graph_module() |
| 469 | + ).graph_module |
424 | 470 | self.check_op_counts(
|
425 | 471 | graph_module,
|
426 | 472 | expected_op_counts={
|
427 |
| - # Verify that one dequant/quant pair was removed |
428 |
| - # Expect 1 quantize ops: 1 input |
| 473 | + # Verify that one dequant/quant pair was removed from chain: |
| 474 | + # quant->linear->dequant->permute->quant->linear->dequant |
| 475 | + # gets converted to: |
| 476 | + # quant->linear->permute->requantize->linear->dequant |
429 | 477 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
|
430 |
| - # Expect 1 dequant op at the end (output of second linear) |
431 | 478 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
|
432 | 479 | },
|
433 | 480 | )
|
|
0 commit comments