Skip to content

Commit a3a74ac

Browse files
committed
[ET-VK] Include FuseDequantLinearPass() in vulkan_preprocess
## Context Include `FuseDequantLinearPass` as a part of `vulkan_preprocess`, so that fusing the quant/dequant nodes added by `VulkanQuantizer` can be done as part of the lowering process. Differential Revision: [D64249613](https://our.internmc.facebook.com/intern/diff/D64249613/) ghstack-source-id: 247543939 Pull Request resolved: #6168
1 parent 5696b35 commit a3a74ac

File tree

3 files changed

+4
-0
lines changed

3 files changed

+4
-0
lines changed

backends/vulkan/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ runtime.python_library(
2525
"//executorch/backends/transforms:addmm_mm_to_linear",
2626
"//executorch/backends/transforms:fuse_batch_norm_with_conv",
2727
"//executorch/backends/transforms:fuse_conv_with_clamp",
28+
"//executorch/backends/transforms:fuse_dequant_linear",
2829
"//executorch/backends/transforms:fuse_view_copy",
2930
"//executorch/backends/transforms:mean_to_sum_div",
3031
"//executorch/backends/transforms:remove_clone_ops",

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __contains__(self, op):
4545

4646
PRIM_OPS = [
4747
operator.getitem,
48+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
4849
]
4950

5051
SUPPORTS_DYNAMIC_SHAPE = [

backends/vulkan/vulkan_preprocess.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
FuseBatchNormWithConvPass,
1414
)
1515
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
16+
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
1617
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
1718
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
1819
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
@@ -59,6 +60,7 @@ def preprocess( # noqa: C901
5960
passes = [
6061
RemoveCloneOpsTransform(),
6162
AddmmToLinearTransform(),
63+
FuseDequantLinearPass(),
6264
FuseViewCopyTransform(),
6365
FuseBatchNormWithConvPass(program),
6466
FuseClampPass(),

0 commit comments

Comments
 (0)