Skip to content

Commit 47549ba

Browse files
committed
Add initial lowering of aten.convolution to tosa.conv2d support
1 parent 17fee78 commit 47549ba

File tree

4 files changed

+274
-25
lines changed

4 files changed

+274
-25
lines changed

backends/arm/arm_backend.py

Lines changed: 92 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ def preprocess( # noqa: C901
246246
if path is None:
247247
path = tempfile.mkdtemp(prefix="arm_tosa_")
248248

249+
# Verify if this is a quantized model ahead so that the tensor data type of
250+
# tosa operations during lowering can be easier determined.
251+
is_quantized_model = tosa_quant_utils.isQuantizedModel(edge_program.graph)
252+
249253
# Converted output for this subgraph, serializer needs path early as it emits
250254
# const data directly. Path created and data written only in debug builds.
251255
tosa_fb = ts.TosaSerializer(path)
@@ -476,10 +480,15 @@ def preprocess( # noqa: C901
476480
elif exir_ops.edge.aten.convolution.default == node.target:
477481
input, weight, bias, stride, pad, dilation, _, _, group = inputs
478482

483+
# Currently only int8 is supported in quantized types.
484+
actual_out_type = (
485+
ts.DType.INT8 if is_quantized_model else outp.dtype
486+
)
487+
479488
## Transpose input tensor to NHWC_Order for TOSA
480489
NHWC_Order = [0, 2, 3, 1]
481490
input_transposed = transpose_helper(
482-
tosa_fb, input, NHWC_Order, outp.dtype
491+
tosa_fb, input, NHWC_Order, actual_out_type
483492
)
484493

485494
## CONV2DOp
@@ -493,6 +502,11 @@ def preprocess( # noqa: C901
493502
attr.ConvAttribute(pad_attr, stride_attr, dilation_attr, 0, 0)
494503

495504
if group.number > 1:
505+
if is_quant_node:
506+
raise AssertionError(
507+
"quantized depthwise conv2d is not supported for now"
508+
)
509+
496510
# Transpose weight to [KH, KW, C, M]
497511
weight_HWCM_Order = [2, 3, 0, 1]
498512
weight_transposed = transpose_helper(
@@ -523,14 +537,17 @@ def preprocess( # noqa: C901
523537
# Transpose weight to [OC, H, W, IC]
524538
weight_CHWC_Order = [0, 2, 3, 1]
525539
weight_transposed = transpose_helper(
526-
tosa_fb, weight, weight_CHWC_Order, outp.dtype
540+
tosa_fb, weight, weight_CHWC_Order, actual_out_type
527541
)
528542

529543
## TOSA output shape is [NHWO]
530544
NHWO_Order = [0, 2, 3, 1]
531545
out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order]
546+
547+
# The output type is int32 when input type is int8.
532548
conv2d_res = tosa_fb.addIntermediate(
533-
out_shape_TOSA_CONV2D, outp.dtype
549+
out_shape_TOSA_CONV2D,
550+
ts.DType.INT32 if is_quant_node else outp.dtype,
534551
)
535552
tosa_fb.addOperator(
536553
TosaOp.Op().CONV2D,
@@ -547,12 +564,45 @@ def preprocess( # noqa: C901
547564
NOHW_Order = [0, 3, 1, 2]
548565
attr_output_transpose = ts.TosaSerializerAttribute()
549566
attr_output_transpose.TransposeAttribute(NOHW_Order)
550-
tosa_fb.addOperator(
551-
TosaOp.Op().TRANSPOSE,
552-
[conv2d_res.name],
553-
[outp.name],
554-
attr_output_transpose,
555-
)
567+
568+
if len(node.all_input_nodes) == 3:
569+
input_node, weight_node, bias_node = node.all_input_nodes
570+
else:
571+
raise AssertionError(
572+
"non-biased conv2d is not supported for now"
573+
)
574+
575+
output_node = list(node.users)[0]
576+
577+
# For quantized convolution, rescale the output value back to the same
578+
# integer value domain of the next op. Otherwise return float32 output.
579+
if is_quant_node:
580+
# Get scale_factor from input, weight, and output.
581+
_, input_scale, _, _, _, _ = getNodeArgs(input_node)
582+
_, weight_scale, _, _, _, _ = getNodeArgs(weight_node)
583+
_, output_scale, _, _, _, _ = getNodeArgs(output_node)
584+
rescaled_conv2d_res = tosa_quant_utils.buildRescaleOpConvOutput(
585+
tosa_fb,
586+
conv2d_res,
587+
actual_out_type,
588+
input_scale,
589+
weight_scale,
590+
output_scale,
591+
)
592+
tosa_fb.addOperator(
593+
TosaOp.Op().TRANSPOSE,
594+
[rescaled_conv2d_res.name],
595+
[outp.name],
596+
attr_output_transpose,
597+
)
598+
else:
599+
tosa_fb.addOperator(
600+
TosaOp.Op().TRANSPOSE,
601+
[conv2d_res.name],
602+
[outp.name],
603+
attr_output_transpose,
604+
)
605+
556606
elif exir_ops.edge.aten.div.Tensor == node.target:
557607
# Div is implemented as x/y = x*1/y
558608
recip = tosa_fb.addIntermediate(inputs[1].shape, inputs[1].dtype)
@@ -802,7 +852,7 @@ def preprocess( # noqa: C901
802852
p_data = edge_program.state_dict[parameter_name]
803853

804854
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
805-
weight_values = p_data.detach().numpy()
855+
ph_values = p_data.detach().numpy()
806856

807857
# Check if they're for quantized nodes
808858
consumer_node = list(node.users)[0]
@@ -811,14 +861,14 @@ def preprocess( # noqa: C901
811861
consumer_node
812862
)
813863

814-
weight_values_quantized = (
815-
(weight_values / weight_node_scale.number)
864+
ph_values_quantized = (
865+
(ph_values / weight_node_scale.number)
816866
+ weight_node_zp.number
817867
).astype(np.int8)
818868
tosa_fb.addConst(
819869
inputs[0].shape,
820870
ts.DType.INT8,
821-
weight_values_quantized,
871+
ph_values_quantized,
822872
name=out,
823873
)
824874
elif (
@@ -837,30 +887,53 @@ def preprocess( # noqa: C901
837887
weight_node
838888
)
839889

840-
weight_values_quantized = (
841-
weight_values / (input_node_scale * weight_node_scale)
890+
ph_values_quantized = (
891+
ph_values / (input_node_scale * weight_node_scale)
842892
).astype(np.int32)
843893

844894
tosa_fb.addConst(
845895
inputs[0].shape,
846896
ts.DType.INT32,
847-
weight_values_quantized,
897+
ph_values_quantized,
898+
name=out,
899+
)
900+
elif (
901+
consumer_node.target == exir_ops.edge.aten.convolution.default
902+
and list(consumer_node.users)[0].target == tosa_quant_utils.q_op
903+
):
904+
(
905+
input_node,
906+
weight_node,
907+
bias_node,
908+
) = consumer_node.all_input_nodes
909+
910+
input_node_scale, _ = getQuantNodeArgs(input_node)
911+
weight_node_scale, _ = getQuantNodeArgs(weight_node)
912+
913+
bias_scales = input_node_scale * weight_node_scale
914+
ph_values_quantized = (ph_values / bias_scales).astype(np.int32)
915+
916+
tosa_fb.addConst(
917+
inputs[0].shape,
918+
ts.DType.INT32,
919+
ph_values_quantized,
848920
name=out,
849921
)
850922
else:
851923
tosa_fb.addConst(
852-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
924+
inputs[0].shape, inputs[0].dtype, ph_values, name=out
853925
)
926+
854927
elif out in edge_program.graph_signature.inputs_to_buffers:
855928
parameter_name = edge_program.graph_signature.inputs_to_buffers[
856929
node.name
857930
]
858931
p_data = edge_program.state_dict[parameter_name]
859932

860933
assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
861-
weight_values = p_data.detach().numpy()
934+
ph_values = p_data.detach().numpy()
862935
tosa_fb.addConst(
863-
inputs[0].shape, inputs[0].dtype, weight_values, name=out
936+
inputs[0].shape, inputs[0].dtype, ph_values, name=out
864937
)
865938
else:
866939
tensor = ts.TosaSerializerTensor(

backends/arm/test/test_models.py

Lines changed: 136 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@
99

1010
from enum import Enum
1111

12+
import numpy as np
13+
1214
import torch
1315

1416
TestList = {}
1517

18+
# Seed the RNG a convenient number so that we get the same random tests for each test each time
19+
seed = 42
20+
rng = np.random.default_rng(seed)
21+
1622

1723
def register_test(cls):
1824
TestList[cls.__name__] = cls()
@@ -103,15 +109,19 @@ class simple_linear(torch.nn.Module):
103109

104110
def __init__(self):
105111
super().__init__()
106-
torch.manual_seed(42)
112+
torch.manual_seed(seed)
107113
self.fc = torch.nn.Linear(20, 30)
108114

109115
def forward(self, x):
110116
x = self.fc(x)
111117
return x
112118

119+
"""Currenly we compare the quantized result directly with the floating point result, to avoid a noticable
120+
precision difference due to wide random numerical distribution, generate small random value range for
121+
convolution testing instead for now"""
122+
113123
@register_test
114-
class simple_conv2d(torch.nn.Module):
124+
class simple_conv2d_3x3_1x3x256x256_st1(torch.nn.Module):
115125
inputs = {
116126
TosaProfile.BI: (
117127
torch.ones(
@@ -129,6 +139,115 @@ def __init__(self):
129139
self.conv2d = torch.nn.Conv2d(
130140
in_channels=3, out_channels=10, kernel_size=3, stride=1
131141
)
142+
with torch.no_grad():
143+
self.conv2d.weight.copy_(
144+
torch.from_numpy(
145+
np.float32(rng.integers(low=1, high=4, size=(10, 3, 3, 3)))
146+
)
147+
)
148+
self.conv2d.bias.copy_(
149+
torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(10))))
150+
)
151+
152+
def forward(self, x):
153+
x = self.conv2d(x)
154+
return x
155+
156+
@register_test
157+
class simple_conv2d_1x1_1x2x128x128_st1(torch.nn.Module):
158+
inputs = {
159+
TosaProfile.BI: (
160+
torch.from_numpy(
161+
np.float32(rng.integers(low=10, high=20, size=(1, 2, 128, 128)))
162+
),
163+
),
164+
TosaProfile.MI: (
165+
torch.from_numpy(
166+
np.float32(rng.integers(low=10, high=20, size=(1, 2, 128, 128)))
167+
),
168+
),
169+
}
170+
171+
def __init__(self):
172+
super().__init__()
173+
self.conv2d = torch.nn.Conv2d(
174+
in_channels=2, out_channels=1, kernel_size=1, stride=1
175+
)
176+
with torch.no_grad():
177+
self.conv2d.weight.copy_(
178+
torch.from_numpy(
179+
np.float32(rng.integers(low=1, high=4, size=(1, 2, 1, 1)))
180+
)
181+
)
182+
self.conv2d.bias.copy_(
183+
torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(1))))
184+
)
185+
186+
def forward(self, x):
187+
x = self.conv2d(x)
188+
return x
189+
190+
@register_test
191+
class simple_conv2d_2x2_1x1x14x14_st2(torch.nn.Module):
192+
inputs = {
193+
TosaProfile.BI: (
194+
torch.from_numpy(
195+
np.float32(rng.integers(low=10, high=20, size=(1, 1, 14, 14)))
196+
),
197+
),
198+
TosaProfile.MI: (
199+
torch.from_numpy(
200+
np.float32(rng.integers(low=10, high=20, size=(1, 1, 14, 14)))
201+
),
202+
),
203+
}
204+
205+
def __init__(self):
206+
super().__init__()
207+
self.conv2d = torch.nn.Conv2d(
208+
in_channels=1, out_channels=1, kernel_size=2, stride=2
209+
)
210+
with torch.no_grad():
211+
self.conv2d.weight.copy_(
212+
torch.from_numpy(
213+
np.float32(rng.integers(low=1, high=4, size=(1, 1, 2, 2)))
214+
)
215+
)
216+
self.conv2d.bias.copy_(
217+
torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(1))))
218+
)
219+
220+
def forward(self, x):
221+
x = self.conv2d(x)
222+
return x
223+
224+
@register_test
225+
class simple_conv2d_5x5_3x2x128x128_st1(torch.nn.Module):
226+
inputs = {
227+
TosaProfile.BI: (
228+
torch.from_numpy(
229+
np.float32(rng.integers(low=10, high=20, size=(3, 2, 128, 128)))
230+
),
231+
),
232+
TosaProfile.MI: (
233+
torch.from_numpy(
234+
np.float32(rng.integers(low=10, high=20, size=(3, 2, 128, 128)))
235+
),
236+
),
237+
}
238+
239+
def __init__(self):
240+
super().__init__()
241+
self.conv2d = torch.nn.Conv2d(
242+
in_channels=2, out_channels=3, kernel_size=5, stride=1
243+
)
244+
with torch.no_grad():
245+
self.conv2d.weight.copy_(
246+
torch.from_numpy(
247+
np.float32(rng.integers(low=1, high=10, size=(1, 1, 5, 5)))
248+
)
249+
)
250+
self.conv2d.bias.copy_(torch.ones(3, dtype=torch.float))
132251

133252
def forward(self, x):
134253
x = self.conv2d(x)
@@ -137,8 +256,16 @@ def forward(self, x):
137256
@register_test
138257
class block_two_conv2d(torch.nn.Module):
139258
inputs = {
140-
TosaProfile.BI: (torch.ones(1, 3, 256, 256),),
141-
TosaProfile.MI: (torch.ones(1, 3, 256, 256),),
259+
TosaProfile.BI: (
260+
torch.from_numpy(
261+
np.float32(rng.integers(low=10, high=20, size=(1, 3, 256, 256)))
262+
),
263+
),
264+
TosaProfile.MI: (
265+
torch.from_numpy(
266+
np.float32(rng.integers(low=10, high=20, size=(1, 3, 256, 256)))
267+
),
268+
),
142269
}
143270

144271
def __init__(self):
@@ -149,6 +276,11 @@ def __init__(self):
149276
self.conv2d_2 = torch.nn.Conv2d(
150277
in_channels=10, out_channels=15, kernel_size=5, stride=1
151278
)
279+
with torch.no_grad():
280+
self.conv2d.weight.copy_(torch.ones(10, 3, 5, 5, dtype=torch.float))
281+
self.conv2d.bias.copy_(torch.ones(10))
282+
self.conv2d_2.weight.copy_(torch.ones(15, 10, 5, 5, dtype=torch.float))
283+
self.conv2d_2.bias.copy_(torch.ones(15))
152284

153285
def forward(self, x):
154286
x = self.conv2d(x)

0 commit comments

Comments
 (0)