Commit c0677c3

Pull request pytorch#124: Implementation of 'aten::hardtanh' operator conversion
Merge in AITEC/executorch from feature/nxg10272/EIEX-239-add-conversion-for-hardtanh-aka-relu6 to main-nxp

* commit '3bc4775a02cd1c47e322402d7fd6eca495e47ec6':
  Add tests, implementation, and integration of 'aten::hardtanh' operator conversion
2 parents 2a72e6c + 3bc4775 commit c0677c3
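For context on the branch name ("hardtanh, aka relu6"): a hardtanh clamp with bounds (0, 6) is exactly ReLU6, which is what lets this commit lower supported hardtanh nodes onto TFLite's ReLU-family operators. A minimal sketch of that equivalence (illustrative only, not code from this commit):

import torch

x = torch.randn(4, 8) * 10  # activations spanning the clamp range

# hardtanh(x, min=0, max=6) and relu6(x) compute the same clamp: min(max(x, 0), 6).
assert torch.equal(torch.nn.ReLU6()(x), torch.nn.Hardtanh(min_val=0., max_val=6.)(x))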

9 files changed (+177 lines, -9 lines)

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
     exir_ops.edge.aten.mm.default: MMConverter,
     exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter,
     exir_ops.edge.aten.relu.default: ReLUConverter,
+    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,
     exir_ops.edge.aten.add.Tensor: AddTensorConverter,

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -30,9 +30,11 @@
     CloneConverter
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.abs_converter import \
     AbsConverter
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import \
+    HardTanhConverter
 __all__ = [
     "AddMMConverter", "ConvolutionConverter", "MMConverter", "PermuteCopyConverter", "SoftmaxConverter",
     "ViewCopyConverter", "QDQDequantizeConverter", "QDQQuantizeConverter", "ConstantPadNDConverter", "ReLUConverter",
     "MaxPool2dConverter", "AvgPool2dConverter", "AddTensorConverter", "MeanDimConverter", "AdaptiveAvgPool2dConverter",
-    "CloneConverter", "AbsConverter"
+    "CloneConverter", "AbsConverter", "HardTanhConverter"
 ]
backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py (new file)

Lines changed: 41 additions & 0 deletions

# Copyright (c) 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.fx import Node
from torch.nn import Parameter

from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter, Target
from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import BuiltinOperator


class HardTanhConverter(NodeConverter):
    supported_targets = [Target.RT700]

    # Maps possible input parameters of HardTanh to equivalent ReLU-based operators supported by TFLite.
    supported_modes_map = {
        ( 0., 6.): BuiltinOperator.RELU6,
        (-1., 1.): BuiltinOperator.RELU_N1_TO_1,
        ( 0., 1.): BuiltinOperator.RELU_0_TO_1,
        ( 0., float('inf')): BuiltinOperator.RELU,
    }

    @staticmethod
    def _is_supported_in_IR(node: Node, parameters_mapping: dict[str, Parameter]) -> bool:
        _, min_value, max_value = node.args
        return (min_value, max_value) in HardTanhConverter.supported_modes_map.keys()

    def convert(self, node: Node):
        """Convert 'aten::hardtanh' to its supported ReLU equivalent."""
        self.assert_convertible(node)

        t_op = self._create_tflite_op_with_io_tensors(node)

        _, min_value, max_value = node.args

        op = self.supported_modes_map[(min_value, max_value)]
        t_op.opcode_index = self.builder.op_code_index_for_op_type(op)

        self.builder.append_operators([t_op])
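The (min, max) pairs in supported_modes_map act as an allow-list: _is_supported_in_IR only accepts hardtanh nodes whose bounds match one of the four ReLU variants, so anything else stays out of the delegated partition. A standalone sketch of that gate, with string placeholders standing in for the TFLite BuiltinOperator values:

supported_modes_map = {
    ( 0., 6.): 'RELU6',
    (-1., 1.): 'RELU_N1_TO_1',
    ( 0., 1.): 'RELU_0_TO_1',
    ( 0., float('inf')): 'RELU',
}

def is_supported(min_value: float, max_value: float) -> bool:
    # Same membership test the converter performs on the node's min/max args.
    return (min_value, max_value) in supported_modes_map

assert is_supported(0., 6.)      # hardtanh(0, 6) == ReLU6, convertible
assert not is_supported(0., 5.)  # arbitrary bounds, not convertible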

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -187,6 +187,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
     exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter,
     exir_ops.edge.aten.mm.default: MMConverter,
     exir_ops.edge.aten.relu.default: ReLUConverter,
+    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,
     exir_ops.edge.aten.add.Tensor: AddTensorConverter,

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 22 additions & 0 deletions
@@ -169,6 +169,26 @@ def partition_types(self):
     return [torch.ops.aten.relu_.default]
 
 
+class HardTanhPattern(SharedSpecPattern):
+    """
+    Quantizer for the HardTanh operator. A shared quantization spec is selected, as activation functions usually
+    follow the computation layer.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.hardtanh.default]
+
+
+class HardTanhInPlacePattern(SharedSpecPattern):
+    """
+    Quantizer for the HardTanh operator with param inplace=True. A shared quantization spec is selected, as
+    activation functions usually follow the computation layer.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.hardtanh_.default]
+
+
 class ReshapePattern(SharedSpecPattern):
     """
     Quantizer for Reshape operator.

@@ -317,6 +337,8 @@ def __init__(self):
             CadenceAtenQuantizer(PermutePattern(), static_qconfig),
             CadenceAtenQuantizer(PadPattern(), static_qconfig),
             CadenceAtenQuantizer(ReluPattern(), static_qconfig),
+            CadenceAtenQuantizer(HardTanhPattern(), static_qconfig),
+            CadenceAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
             CadenceAtenQuantizer(ReluInPlacePattern(), static_qconfig),
             CadenceAtenQuantizer(AvgPoolPattern(), static_qconfig),
             CadenceAtenQuantizer(ViewPattern(), static_qconfig),
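Both new patterns are SharedSpecPattern subclasses, so (per their docstrings) the hardtanh activation shares its quantization spec with the preceding computation layer instead of receiving an independent one; the two classes differ only in which aten target they partition. The functional and in-place variants really are distinct ops, as this illustrative snippet (not part of the commit) shows:

import torch

x = torch.randn(2, 3)

# Out-of-place variant, the target matched by HardTanhPattern.
y = torch.ops.aten.hardtanh.default(x, -1., 1.)

# In-place variant, the target matched by HardTanhInPlacePattern.
torch.ops.aten.hardtanh_.default(x, -1., 1.)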

backends/nxp/tests/executors.py

Lines changed: 2 additions & 2 deletions
@@ -267,8 +267,8 @@ def convert_run_compare(edge_program: ExportedProgram, input_data, rtol=1.e-5, a
     return tflite_executor, edge_program_executor
 
 
-def graph_contains_op(graph: Graph, op: object) -> bool:
-    return any(map(lambda node: node.target == op, graph.nodes))
+def graph_contains_any_of_ops(graph: Graph, ops: list) -> bool:
+    return any(map(lambda node: node.target in ops, graph.nodes))
 
 
 class OverrideSupportedTargets:
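Generalizing the helper to take a list of ops lets one assertion cover both the functional and in-place hardtanh variants, which is how the new tests below use it. A standalone sketch of the helper on a plain FX-traced graph (TinyHardTanh is an illustrative module, not part of the commit):

import torch

from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops


class TinyHardTanh(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.hardtanh(x, 0., 6.)


gm = torch.fx.symbolic_trace(TinyHardTanh())

# With plain FX tracing the node targets are the functional ops themselves, so the
# helper answers "does any of these ops appear in the graph?".
assert graph_contains_any_of_ops(graph=gm.graph, ops=[torch.nn.functional.hardtanh])
assert not graph_contains_any_of_ops(graph=gm.graph, ops=[torch.nn.functional.relu])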

backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 
 from executorch.backends.nxp.backend.edge_program_converter import EdgeProgramToIRConverter
 from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
-from executorch.backends.nxp.tests.executors import convert_run_compare, graph_contains_op, ToChannelLastPreprocess, \
+from executorch.backends.nxp.tests.executors import convert_run_compare, graph_contains_any_of_ops, ToChannelLastPreprocess, \
     ToChannelFirstPreprocess
 from executorch.exir.dialects._ops import ops as exir_ops
 

@@ -51,7 +51,7 @@ def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)):
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]
 
-    assert not graph_contains_op(graph=quantized_program.graph, op=exir_ops.edge.aten.abs.default)
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=[exir_ops.edge.aten.abs.default])
 
     input_data = (np.random.random(input_shape) * 50).astype(np.int8)
     convert_run_compare(exported_program,

@@ -72,7 +72,7 @@ def test_abs_only(mocker, input_shape: tuple[int] = (1, 10)):
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]
 
-    assert not graph_contains_op(graph=quantized_program.graph, op=exir_ops.edge.aten.abs.default)
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=[exir_ops.edge.aten.abs.default])
 
     input_data = (np.random.random(input_shape) * 50).astype(np.int8)
     convert_run_compare(exported_program,

backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,7 @@
 
 from executorch.backends.nxp.backend.edge_program_converter import EdgeProgramToIRConverter
 from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
-from executorch.backends.nxp.tests.executors import convert_run_compare, graph_contains_op, ToChannelLastPreprocess, \
+from executorch.backends.nxp.tests.executors import convert_run_compare, graph_contains_any_of_ops, ToChannelLastPreprocess, \
     ToChannelFirstPreprocess
 from executorch.exir.dialects._ops import ops as exir_ops
 

@@ -75,7 +75,7 @@ def test_conv_dropout_quant(mocker, inplace_dropout: bool, input_shape: tuple[in
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]
 
-    assert not graph_contains_op(graph=quantized_program.graph, op=exir_ops.edge.aten.clone.default)
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default])
 
     input_data = (np.random.random(input_shape) * 50).astype(np.int8)
     convert_run_compare(exported_program,

@@ -97,7 +97,7 @@ def test_clone_pool_view_copy_quant(mocker, inplace_dropout: bool, input_shape:
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]
 
-    assert not graph_contains_op(graph=quantized_program.graph, op=exir_ops.edge.aten.clone.default)
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default])
 
     input_data = (np.random.random(input_shape) * 50).astype(np.int8)
     convert_run_compare(exported_program,
backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py (new file)

Lines changed: 101 additions & 0 deletions

import numpy as np
import pytest
import torch
from torch.export import ExportedProgram

from executorch.backends.nxp.backend.edge_program_converter import EdgeProgramToIRConverter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import \
    HardTanhConverter
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import convert_run_compare, ToChannelLastPreprocess, \
    ToChannelFirstPreprocess, graph_contains_any_of_ops
from executorch.exir.dialects._ops import ops as exir_ops


@pytest.fixture(autouse=True)
def reseed_model_per_test_run():
    torch.manual_seed(23)
    np.random.seed(23)


class Relu6ConvBlock(torch.nn.Module):
    def __init__(self, conv_in_channels: int = 3, inplace: bool = False):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4)),
            torch.nn.ReLU6(inplace=inplace)
        )

    def forward(self, x):
        return self.block(x)


class CustomHardTanhBlock(torch.nn.Module):
    def __init__(self,
                 conv_in_channels: int = 3,
                 min_act_val: float = -1.,
                 max_act_val: float = 1.,
                 inplace: bool = False):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4)),
            torch.nn.Hardtanh(min_val=min_act_val, max_val=max_act_val, inplace=inplace)
        )

    def forward(self, x):
        return self.block(x)


@pytest.mark.parametrize('input_shape', [(1, 3, 128, 128), (1, 3, 256, 256)])
@pytest.mark.parametrize('inplace', [True, False])
def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool):
    model = Relu6ConvBlock(conv_in_channels=input_shape[1], inplace=inplace)

    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()

    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default]
    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops)

    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
    convert_run_compare(exported_program,
                        tfl_model=tflite_flatbuffers_model,
                        tflite_input_preprocess=ToChannelLastPreprocess(),
                        tflite_output_preprocess=ToChannelFirstPreprocess(),
                        input_data=input_data,
                        atol=1.)


@pytest.mark.parametrize('input_shape', [(1, 3, 128, 128), (1, 3, 256, 256)])
@pytest.mark.parametrize('activation_range', list(HardTanhConverter.supported_modes_map.keys()))
@pytest.mark.parametrize('inplace', [True, False])
def test_custom_hardtanh_quant(mocker, input_shape: tuple[int], activation_range: tuple[int, int], inplace: bool):
    min_val, max_val = activation_range
    model = CustomHardTanhBlock(
        conv_in_channels=input_shape[1],
        min_act_val=min_val,
        max_act_val=max_val,
        inplace=inplace
    )

    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()

    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    ops = [exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardtanh_.default]
    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=ops)

    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
    convert_run_compare(exported_program,
                        tfl_model=tflite_flatbuffers_model,
                        tflite_input_preprocess=ToChannelLastPreprocess(),
                        tflite_output_preprocess=ToChannelFirstPreprocess(),
                        input_data=input_data,
                        atol=1.)
