Skip to content

Commit 3405e7a

Browse files
pytorchbotDannyYuyang-quic
authored andcommitted
Remove Per-Op mode from DQPartitioner (pytorch#9502)
Differential Revision: D71427234
1 parent e0363c6 commit 3405e7a

File tree

4 files changed

+39
-1
lines changed

4 files changed

+39
-1
lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]:
115115
class XnnpackDynamicallyQuantizedPartitioner(XnnpackPartitioner):
116116
def __init__(self):
117117
super().__init__(
118-
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True
118+
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
119119
)
120120

121121

backends/xnnpack/test/ops/test_linear.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,21 @@ def forward(self, x, y):
191191
return a + b
192192

193193

194+
class SharedDQChain(torch.nn.Module):
195+
def __init__(self, input_size, output_size):
196+
super().__init__()
197+
self.linear1_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
198+
self.linear1_bias = torch.nn.Parameter(torch.rand(output_size))
199+
200+
self.linear2_weight = torch.nn.Parameter(torch.rand(output_size, input_size))
201+
self.linear2_bias = torch.nn.Parameter(torch.rand(output_size))
202+
203+
def forward(self, x):
204+
a = torch.nn.functional.linear(x, self.linear1_weight, self.linear1_bias)
205+
b = torch.nn.functional.linear(x, self.linear2_weight, self.linear2_bias)
206+
return a + b
207+
208+
194209
class TestLinear(unittest.TestCase):
195210
"""
196211
Test Class for XNNPACK Linear Operators.
@@ -520,6 +535,23 @@ def get_qnode_checks(quant_node_checks, dialect):
520535
# qtol=bool(quant_config), atol=atol
521536
# )
522537

538+
def test_qd8_f32_per_channel_shared_dq_chain(self):
539+
for use_bias in (False, True):
540+
module = SharedDQChain(
541+
input_size=13,
542+
output_size=17,
543+
)
544+
inputs = (torch.randn(1, 2, 13),)
545+
546+
self._test_dqlinear(
547+
module,
548+
inputs,
549+
dynamic_shapes=None,
550+
is_per_channel=True,
551+
linear_count=2,
552+
uses_bias=use_bias,
553+
)
554+
523555
def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
524556
for uses_bias in (False, True):
525557
module = BaseLinear(

backends/xnnpack/test/tester/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,6 @@ runtime.python_library(
2626
"//executorch/exir/backend:partitioner",
2727
"//executorch/exir/passes:spec_prop_pass",
2828
"//executorch/extension/pybindings:portable_lib", # @manual
29+
"//executorch/backends/transforms:duplicate_dynamic_quant_chain"
2930
],
3031
)

backends/xnnpack/test/tester/tester.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
1616

1717
import torch
18+
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
19+
DuplicateDynamicQuantChainPass,
20+
)
1821
from executorch.backends.xnnpack._passes import XNNPACKPassManager
1922
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
2023
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
@@ -177,6 +180,8 @@ def run(
177180
prepared(*inputs)
178181

179182
converted = convert_pt2e(prepared)
183+
DuplicateDynamicQuantChainPass()(converted)
184+
180185
self.converted_graph = converted
181186

182187
@property

0 commit comments

Comments
 (0)