Skip to content

Commit a56d121

Browse files
committed
[ET-VK] Include FuseDequantLinearPass() in vulkan_preprocess
Pull Request resolved: #6168 ## Context Include `FuseDequantLinearPass` as a part of `vulkan_preprocess`, so that fusing the quant/dequant nodes added by `VulkanQuantizer` can be done as part of the lowering process. ghstack-source-id: 247613964 @exported-using-ghexport Differential Revision: [D64249613](https://our.internmc.facebook.com/intern/diff/D64249613/)
1 parent d094b09 commit a56d121

File tree

3 files changed

+10
-0
lines changed

3 files changed

+10
-0
lines changed

backends/vulkan/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ runtime.python_library(
2525
"//executorch/backends/transforms:addmm_mm_to_linear",
2626
"//executorch/backends/transforms:fuse_batch_norm_with_conv",
2727
"//executorch/backends/transforms:fuse_conv_with_clamp",
28+
"//executorch/backends/transforms:fuse_dequant_linear",
2829
"//executorch/backends/transforms:fuse_view_copy",
2930
"//executorch/backends/transforms:mean_to_sum_div",
3031
"//executorch/backends/transforms:remove_clone_ops",

backends/vulkan/partitioner/supported_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ def __contains__(self, op):
4545

4646
PRIM_OPS = [
4747
operator.getitem,
48+
# Quantization related ops will be fused via graph passes
49+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
50+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
51+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
52+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
53+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
54+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
4855
]
4956

5057
SUPPORTS_DYNAMIC_SHAPE = [

backends/vulkan/vulkan_preprocess.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
FuseBatchNormWithConvPass,
1414
)
1515
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
16+
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
1617
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
1718
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
1819
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
@@ -59,6 +60,7 @@ def preprocess( # noqa: C901
5960
passes = [
6061
RemoveCloneOpsTransform(),
6162
AddmmToLinearTransform(),
63+
FuseDequantLinearPass(),
6264
FuseViewCopyTransform(),
6365
FuseBatchNormWithConvPass(program),
6466
FuseClampPass(),

0 commit comments

Comments
 (0)