|
28 | 28 | to_edge,
|
29 | 29 | )
|
30 | 30 | from executorch.exir.pass_base import PassResult
|
| 31 | +from torch._inductor.decomposition import remove_decompositions |
31 | 32 | from torch.ao.quantization.pt2e.export_utils import model_is_exported
|
32 | 33 | from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
33 | 34 |
|
@@ -59,15 +60,29 @@ def convert_pt2(
|
59 | 60 | """
|
60 | 61 |
|
61 | 62 | # Export with dynamo
|
62 |
| - model_gm = torch.export.export_for_training(model, inputs).module() |
| 63 | + decomp_table = torch.export.default_decompositions() |
| 64 | + ops_to_keep = [ |
| 65 | + torch.ops.aten.conv1d.default, |
| 66 | + torch.ops.aten.conv2d.default, |
| 67 | + torch.ops.aten.layer_norm.default, |
| 68 | + torch.ops.aten.linear.default, |
| 69 | + torch.ops.aten.matmul.default, |
| 70 | + ] |
| 71 | + # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any |
| 72 | + remove_decompositions(decomp_table, ops_to_keep) |
| 73 | + model_gm = ( |
| 74 | + torch.export.export_for_training(model, inputs) |
| 75 | + .run_decompositions(decomp_table) |
| 76 | + .module() |
| 77 | + ) |
63 | 78 |
|
64 |
| - if model_gm_has_SDPA(model_gm): # pyre-fixme[6] |
| 79 | + if model_gm_has_SDPA(model_gm): |
65 | 80 | # Decompose SDPA
|
66 |
| - DecomposeScaledDotProductAttention(False)(model_gm) # pyre-fixme[6] |
| 81 | + DecomposeScaledDotProductAttention(False)(model_gm) |
67 | 82 |
|
68 | 83 | # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
|
69 | 84 | # for details).
|
70 |
| - result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) # pyre-fixme[6] |
| 85 | + result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) |
71 | 86 | assert result is not None
|
72 | 87 | model_gm = result.graph_module
|
73 | 88 |
|
|
0 commit comments