|
13 | 13 | import executorch.backends.cadence.aot.ops_registrations # noqa
|
14 | 14 | import torch
|
15 | 15 | from executorch.backends.cadence.aot import compiler
|
16 |
| -from executorch.backends.cadence.aot.compiler import ( |
17 |
| - export_to_edge, |
18 |
| - quantize_and_export_to_edge, |
19 |
| -) |
20 | 16 | from executorch.backends.cadence.aot.fuse_ops import (
|
21 | 17 | FuseFullThenReshapePass,
|
22 | 18 | FuseMulScalarIntoDequantPass,
|
@@ -336,94 +332,144 @@ def test_replace_quant_view_dequant_with_requantize(self):
|
336 | 332 | )
|
337 | 333 |
|
338 | 334 | def test_replace_dequant_quant_with_requantize(self):
|
339 |
| - class M(torch.nn.Module): |
340 |
| - def __init__(self): |
341 |
| - super().__init__() |
342 |
| - |
343 |
| - def forward(self, x): |
344 |
| - x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
345 |
| - x, 1.2, 3, 0, 127, torch.int8 |
346 |
| - ) |
347 |
| - x = torch.permute(x, [2, 0, 1, 3]) |
348 |
| - x = torch.ops.quantized_decomposed.quantize_per_tensor( |
349 |
| - x, 4.5, 6, 0, 127, torch.int8 |
350 |
| - ) |
351 |
| - return x |
352 |
| - |
353 |
| - inputs = torch.randn(2, 12, 1, 6).to(torch.int8) |
354 |
| - model = M() |
355 |
| - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module |
356 |
| - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 335 | + builder = GraphBuilder() |
| 336 | + x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) |
| 337 | + dequant = builder.call_operator( |
| 338 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 339 | + args=(x, 1.2, 3, 0, 127, torch.int8), |
| 340 | + ) |
| 341 | + quant = builder.call_operator( |
| 342 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 343 | + args=(dequant, 4.5, 6, 0, 127, torch.int8), |
| 344 | + ) |
| 345 | + builder.output(quant) |
| 346 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 347 | + builder.get_graph_module() |
| 348 | + ).graph_module |
357 | 349 |
|
358 | 350 | self.check_op_counts(
|
359 | 351 | graph_module,
|
360 | 352 | expected_op_counts={
|
361 |
| - # Verify that dequant -> permute -> quant was replaced with permute -> requantize. |
| 353 | + # Verify that dequant -> quant was replaced with requantize. |
362 | 354 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
|
363 | 355 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
|
364 | 356 | exir_ops.edge.cadence.requantize.default: 1,
|
365 | 357 | },
|
366 | 358 | )
|
367 | 359 |
|
368 | 360 | def test_replace_dequant_permute_quant_with_requantize(self):
|
369 |
| - class M(torch.nn.Module): |
370 |
| - def __init__(self): |
371 |
| - super().__init__() |
372 |
| - |
373 |
| - def forward(self, x): |
374 |
| - x = torch.ops.quantized_decomposed.dequantize_per_tensor( |
375 |
| - x, 1.2, 3, 0, 127, torch.int8 |
376 |
| - ) |
377 |
| - x = torch.permute(x, [2, 0, 1, 3]) |
378 |
| - x = torch.ops.quantized_decomposed.quantize_per_tensor( |
379 |
| - x, 4.5, 6, 0, 127, torch.int8 |
380 |
| - ) |
381 |
| - return x |
382 |
| - |
383 |
| - inputs = torch.randn(2, 12, 1, 6).to(torch.int8) |
384 |
| - model = M() |
385 |
| - graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module |
386 |
| - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 361 | + builder = GraphBuilder() |
| 362 | + x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) |
| 363 | + dequant = builder.call_operator( |
| 364 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 365 | + args=(x, 1.2, 3, 0, 127, torch.int8), |
| 366 | + ) |
| 367 | + permute = builder.call_operator( |
| 368 | + op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3]) |
| 369 | + ) |
| 370 | + quant = builder.call_operator( |
| 371 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 372 | + args=(permute, 4.5, 6, 0, 127, torch.int8), |
| 373 | + ) |
| 374 | + builder.output(quant) |
| 375 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 376 | + builder.get_graph_module() |
| 377 | + ).graph_module |
387 | 378 |
|
388 | 379 | self.check_op_counts(
|
389 | 380 | graph_module,
|
390 | 381 | expected_op_counts={
|
391 | 382 | # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
|
392 | 383 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
|
393 | 384 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
|
| 385 | + exir_ops.edge.aten.permute_copy.default: 1, |
394 | 386 | exir_ops.edge.cadence.requantize.default: 1,
|
395 | 387 | },
|
396 | 388 | )
|
397 | 389 |
|
398 | 390 | def test_remove_nop_dequant_quant(self):
|
399 |
| - class M(torch.nn.Module): |
400 |
| - def __init__(self): |
401 |
| - super(M, self).__init__() |
402 |
| - self.lin1 = torch.nn.Linear(6, 12, bias=False) |
403 |
| - self.lin2 = torch.nn.Linear(12, 24, bias=False) |
| 391 | + LEADING_DIMS: Final[int] = 12 |
| 392 | + IN_DIM: Final[int] = 6 |
| 393 | + OUT_DIM: Final[int] = 12 |
404 | 394 |
|
405 |
| - def forward(self, x): |
406 |
| - x = self.lin1(x) |
407 |
| - # redundant dequant+quant will be created around this permute |
408 |
| - x = torch.permute(x, [0, 2, 1, 3]) |
409 |
| - x = self.lin2(x) |
410 |
| - return x |
411 |
| - |
412 |
| - inputs = torch.randn(2, 12, 1, 6) |
413 |
| - model = M() |
414 |
| - graph_module = ( |
415 |
| - quantize_and_export_to_edge(model, (inputs,)) |
416 |
| - .exported_program() |
417 |
| - .graph_module |
| 395 | + builder = GraphBuilder() |
| 396 | + x = builder.placeholder( |
| 397 | + "x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32) |
| 398 | + ) |
| 399 | + quant1 = builder.call_operator( |
| 400 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 401 | + args=(x, 4.5, 6, 0, 127, torch.int8), |
| 402 | + ) |
| 403 | + weights = builder.call_operator( |
| 404 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1) |
| 405 | + ) |
| 406 | + bias = builder.call_operator( |
| 407 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) |
| 408 | + ) |
| 409 | + weight_zero_point = builder.call_operator( |
| 410 | + op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0) |
| 411 | + ) |
| 412 | + out_multiplier = builder.call_operator( |
| 413 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) |
| 414 | + ) |
| 415 | + out_shift = builder.call_operator( |
| 416 | + op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0) |
418 | 417 | )
|
419 |
| - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module |
| 418 | + linear1 = builder.call_operator( |
| 419 | + op=exir_ops.edge.cadence.quantized_linear.default, |
| 420 | + args=( |
| 421 | + quant1, |
| 422 | + weights, |
| 423 | + bias, |
| 424 | + 0, # src_zero_point |
| 425 | + weight_zero_point, |
| 426 | + out_multiplier, |
| 427 | + out_shift, |
| 428 | + 0, # out_zero_point |
| 429 | + None, |
| 430 | + ), |
| 431 | + ) |
| 432 | + dequant1 = builder.call_operator( |
| 433 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 434 | + args=(linear1, 1.2, 3, 0, 127, torch.int8), |
| 435 | + ) |
| 436 | + permute = builder.call_operator( |
| 437 | + op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0]) |
| 438 | + ) |
| 439 | + quant2 = builder.call_operator( |
| 440 | + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
| 441 | + args=(permute, 4.5, 6, 0, 127, torch.int8), |
| 442 | + ) |
| 443 | + linear2 = builder.call_operator( |
| 444 | + op=exir_ops.edge.cadence.quantized_linear.default, |
| 445 | + args=( |
| 446 | + quant2, |
| 447 | + weights, |
| 448 | + bias, |
| 449 | + 0, # src_zero_point |
| 450 | + weight_zero_point, |
| 451 | + out_multiplier, |
| 452 | + out_shift, |
| 453 | + 0, # out_zero_point |
| 454 | + None, |
| 455 | + ), |
| 456 | + ) |
| 457 | + dequant2 = builder.call_operator( |
| 458 | + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 459 | + args=(linear2, 1.2, 3, 0, 127, torch.int8), |
| 460 | + ) |
| 461 | + builder.output(dequant2) |
| 462 | + graph_module = FuseQuantDequantToRequantizePass()( |
| 463 | + builder.get_graph_module() |
| 464 | + ).graph_module |
420 | 465 | self.check_op_counts(
|
421 | 466 | graph_module,
|
422 | 467 | expected_op_counts={
|
423 |
| - # Verify that one dequant/quant pair was removed |
424 |
| - # Expect 1 quantize ops: 1 input |
| 468 | + # Verify that one dequant/quant pair was removed from chain: |
| 469 | + # quant->linear->dequant->permute->quant->linear->dequant |
| 470 | + # gets converted to: |
| 471 | + # quant->linear->permute->linear->dequant |
425 | 472 | exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
|
426 |
| - # Expect 1 dequant op at the end (output of second linear) |
427 | 473 | exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
|
428 | 474 | },
|
429 | 475 | )
|
|
0 commit comments