
Commit 9d6ee53

Add initial lowering of aten.convolution to tosa.conv2d support
1 parent 17fee78 commit 9d6ee53

4 files changed, +268 -28 lines changed

backends/arm/arm_backend.py

Lines changed: 87 additions & 13 deletions
@@ -227,6 +227,14 @@ def getQuantNodeArgs(node):
 
 @final
 class ArmBackend(BackendDetails):
+    # Class variable initialization
+    ssa_num = -1
+
+    @staticmethod
+    def getSSAnum():
+        ArmBackend.ssa_num += 1
+        return ArmBackend.ssa_num
+
     @staticmethod
     def preprocess(  # noqa: C901
         edge_program: ExportedProgram,
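
Note: the new getSSAnum() helper is just a monotonically increasing class-level counter; the conv lowering below uses it to give synthesized constants unique names. A minimal stand-alone sketch of that naming behaviour (the Counter class here is hypothetical, not part of the backend):

class Counter:
    ssa_num = -1

    @staticmethod
    def next():
        # Same idea as ArmBackend.getSSAnum(): bump and return the class counter.
        Counter.ssa_num += 1
        return Counter.ssa_num

names = ["const_bias_" + str(Counter.next()) for _ in range(3)]
print(names)  # ['const_bias_0', 'const_bias_1', 'const_bias_2']
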
@@ -476,10 +484,13 @@ def preprocess( # noqa: C901
             elif exir_ops.edge.aten.convolution.default == node.target:
                 input, weight, bias, stride, pad, dilation, _, _, group = inputs
 
+                # Currently only int8 is supported in quantized types.
+                actual_out_type = ts.DType.INT8 if is_quant_node else outp.dtype
+
                 ## Transpose input tensor to NHWC_Order for TOSA
                 NHWC_Order = [0, 2, 3, 1]
                 input_transposed = transpose_helper(
-                    tosa_fb, input, NHWC_Order, outp.dtype
+                    tosa_fb, input, NHWC_Order, actual_out_type
                 )
 
                 ## CONV2DOp
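
Note: PyTorch convolution tensors are laid out NCHW, while TOSA CONV2D operates on NHWC, which is why the lowering permutes with [0, 2, 3, 1] before the op (and back with [0, 3, 1, 2] afterwards); in the quantized case the transposed tensors are also forced to INT8. A rough plain-torch sketch of the layout change, independent of the TOSA serializer:

import torch

x_nchw = torch.randn(1, 3, 256, 256)       # PyTorch layout: N, C, H, W
NHWC_Order = [0, 2, 3, 1]                   # same permutation used by transpose_helper above
x_nhwc = x_nchw.permute(NHWC_Order).contiguous()
print(x_nhwc.shape)                         # torch.Size([1, 256, 256, 3])
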
@@ -492,6 +503,21 @@ def preprocess( # noqa: C901
                 dilation_attr = dilation.special
                 attr.ConvAttribute(pad_attr, stride_attr, dilation_attr, 0, 0)
 
+                if len(node.all_input_nodes) == 3:
+                    input_node, weight_node, _ = node.all_input_nodes
+                else:
+                    input_node, weight_node = node.all_input_nodes
+
+                # Create a zero bias tensor if none is present
+                out_channels = weight.shape[0]
+                bias_name = "const_bias_" + str(ArmBackend.getSSAnum())
+                bias = tosa_fb.addConst(
+                    [out_channels],
+                    ts.DType.INT32 if is_quant_node else outp.dtype,
+                    [0] * out_channels,
+                    name=bias_name,
+                )
+
                 if group.number > 1:
                     # Transpose weight to [KH, KW, C, M]
                     weight_HWCM_Order = [2, 3, 0, 1]
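
Note: TOSA CONV2D always takes a bias operand, but torch.nn.Conv2d(..., bias=False) produces an aten.convolution node with only an input and a weight, so the branch above synthesizes an all-zero bias with one entry per output channel, using INT32 for the quantized path to match the convolution accumulator. A tiny sketch of the synthesized values (the dtype strings are illustrative stand-ins for ts.DType):

def make_zero_bias(out_channels: int, is_quant_node: bool):
    # One zero per output channel; int32 for quantized conv, float otherwise.
    dtype = "INT32" if is_quant_node else "FP32"
    return dtype, [0] * out_channels

print(make_zero_bias(4, True))  # ('INT32', [0, 0, 0, 0])
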
@@ -523,14 +549,17 @@ def preprocess( # noqa: C901
                     # Transpose weight to [OC, H, W, IC]
                     weight_CHWC_Order = [0, 2, 3, 1]
                     weight_transposed = transpose_helper(
-                        tosa_fb, weight, weight_CHWC_Order, outp.dtype
+                        tosa_fb, weight, weight_CHWC_Order, actual_out_type
                     )
 
                     ## TOSA output shape is [NHWO]
                     NHWO_Order = [0, 2, 3, 1]
                     out_shape_TOSA_CONV2D = [outp.shape[i] for i in NHWO_Order]
+
+                    # The output type is int32 when input type is int8.
                     conv2d_res = tosa_fb.addIntermediate(
-                        out_shape_TOSA_CONV2D, outp.dtype
+                        out_shape_TOSA_CONV2D,
+                        ts.DType.INT32 if is_quant_node else outp.dtype,
                     )
                     tosa_fb.addOperator(
                         TosaOp.Op().CONV2D,
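
Note: the intermediate CONV2D result is created as INT32 for quantized graphs because int8 multiply-accumulate overflows int8 almost immediately; the accumulator (like the zero bias above) stays in int32 until the rescale added in the next hunk. A quick numpy illustration of why the wider accumulator is needed:

import numpy as np

a = np.full(64, 127, dtype=np.int8)   # int8 activations at their maximum value
w = np.full(64, 127, dtype=np.int8)   # int8 weights at their maximum value
acc = np.sum(a.astype(np.int32) * w.astype(np.int32))
print(acc)                            # 1032256, far outside the int8 range [-128, 127]
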
@@ -547,12 +576,32 @@ def preprocess( # noqa: C901
                 NOHW_Order = [0, 3, 1, 2]
                 attr_output_transpose = ts.TosaSerializerAttribute()
                 attr_output_transpose.TransposeAttribute(NOHW_Order)
+
+                # For quantized convolution, rescale the output value back to the same
+                # integer value domain of the next op. Otherwise return float32 output.
+                if is_quant_node:
+                    # Get scale_factor from input, weight, and output.
+                    output_node = list(node.users)[0]
+                    _, input_scale, _, _, _, _ = getNodeArgs(input_node)
+                    _, weight_scale, _, _, _, _ = getNodeArgs(weight_node)
+                    _, output_scale, _, _, _, _ = getNodeArgs(output_node)
+
+                    conv2d_res = tosa_quant_utils.buildRescaleOpConvOutput(
+                        tosa_fb,
+                        conv2d_res,
+                        actual_out_type,
+                        input_scale,
+                        weight_scale,
+                        output_scale,
+                    )
+
                 tosa_fb.addOperator(
                     TosaOp.Op().TRANSPOSE,
                     [conv2d_res.name],
                     [outp.name],
                     attr_output_transpose,
                 )
+
             elif exir_ops.edge.aten.div.Tensor == node.target:
                 # Div is implemented as x/y = x*1/y
                 recip = tosa_fb.addIntermediate(inputs[1].shape, inputs[1].dtype)
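
Note: buildRescaleOpConvOutput itself is not part of this diff, but the usual way to fold the three quantization scales into a rescale of the int32 accumulator is to apply the effective multiplier (input_scale * weight_scale) / output_scale and saturate back to int8. A rough numpy sketch of that arithmetic, assuming all zero-points are zero (how the helper encodes this as a TOSA RESCALE is not shown here):

import numpy as np

def rescale_conv_output(acc_int32, input_scale, weight_scale, output_scale):
    # Effective multiplier that maps the int32 accumulator domain onto the
    # int8 domain expected by the next operator.
    multiplier = (input_scale * weight_scale) / output_scale
    rescaled = np.round(acc_int32.astype(np.float64) * multiplier)
    return np.clip(rescaled, -128, 127).astype(np.int8)

acc = np.array([12345, -6789, 40], dtype=np.int32)
print(rescale_conv_output(acc, input_scale=0.02, weight_scale=0.005, output_scale=0.1))
# -> [12 -7  0]
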
@@ -802,7 +851,7 @@ def preprocess( # noqa: C901
                 p_data = edge_program.state_dict[parameter_name]
 
                 assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
-                weight_values = p_data.detach().numpy()
+                parameter_values = p_data.detach().numpy()
 
                 # Check if they're for quantized nodes
                 consumer_node = list(node.users)[0]
@@ -811,14 +860,14 @@ def preprocess( # noqa: C901
                         consumer_node
                     )
 
-                    weight_values_quantized = (
-                        (weight_values / weight_node_scale.number)
+                    parameter_values_quantized = (
+                        (parameter_values / weight_node_scale.number)
                         + weight_node_zp.number
                     ).astype(np.int8)
                     tosa_fb.addConst(
                         inputs[0].shape,
                         ts.DType.INT8,
-                        weight_values_quantized,
+                        parameter_values_quantized,
                         name=out,
                     )
                 elif (
@@ -837,30 +886,55 @@ def preprocess( # noqa: C901
                         weight_node
                     )
 
-                    weight_values_quantized = (
-                        weight_values / (input_node_scale * weight_node_scale)
+                    parameter_values_quantized = (
+                        parameter_values / (input_node_scale * weight_node_scale)
                     ).astype(np.int32)
 
                     tosa_fb.addConst(
                         inputs[0].shape,
                         ts.DType.INT32,
-                        weight_values_quantized,
+                        parameter_values_quantized,
+                        name=out,
+                    )
+                elif (
+                    consumer_node.target == exir_ops.edge.aten.convolution.default
+                    and list(consumer_node.users)[0].target == tosa_quant_utils.q_op
+                ):
+                    (
+                        input_node,
+                        weight_node,
+                        bias_node,
+                    ) = consumer_node.all_input_nodes
+
+                    input_node_scale, _ = getQuantNodeArgs(input_node)
+                    weight_node_scale, _ = getQuantNodeArgs(weight_node)
+
+                    bias_scales = input_node_scale * weight_node_scale
+                    parameter_values_quantized = (
+                        parameter_values / bias_scales
+                    ).astype(np.int32)
+
+                    tosa_fb.addConst(
+                        inputs[0].shape,
+                        ts.DType.INT32,
+                        parameter_values_quantized,
                         name=out,
                     )
                 else:
                     tosa_fb.addConst(
-                        inputs[0].shape, inputs[0].dtype, weight_values, name=out
+                        inputs[0].shape, inputs[0].dtype, parameter_values, name=out
                     )
+
             elif out in edge_program.graph_signature.inputs_to_buffers:
                 parameter_name = edge_program.graph_signature.inputs_to_buffers[
                     node.name
                 ]
                 p_data = edge_program.state_dict[parameter_name]
 
                 assert isinstance(p_data, torch.Tensor), "Expect Attr to be tensor"
-                weight_values = p_data.detach().numpy()
+                parameter_values = p_data.detach().numpy()
                 tosa_fb.addConst(
-                    inputs[0].shape, inputs[0].dtype, weight_values, name=out
+                    inputs[0].shape, inputs[0].dtype, parameter_values, name=out
                 )
             else:
                 tensor = ts.TosaSerializerTensor(
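
Note: the new placeholder branch mirrors the existing addmm handling. Weights that feed a quantize op are stored as int8 using their own (scale, zero_point); a bias that feeds a convolution whose output is quantized is stored as int32 with the combined scale input_scale * weight_scale and no zero-point, which is what the int32 accumulator created earlier expects. A compact numpy sketch of both conversions with made-up values (scales chosen so the divisions are exact):

import numpy as np

weights = np.array([1.5, -2.0, 0.5], dtype=np.float32)
bias = np.array([0.375, -0.75], dtype=np.float32)

# Weight quantization: per-tensor scale and zero-point, stored as int8.
weight_scale, weight_zp = 0.5, 3
weight_q = (weights / weight_scale + weight_zp).astype(np.int8)

# Bias quantization: scale is the product of input and weight scales, stored as int32.
input_scale = 0.25
bias_q = (bias / (input_scale * weight_scale)).astype(np.int32)

print(weight_q)  # [ 6 -1  4]
print(bias_q)    # [ 3 -6]
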

backends/arm/test/test_models.py

Lines changed: 145 additions & 13 deletions
@@ -9,10 +9,16 @@
 
 from enum import Enum
 
+import numpy as np
+
 import torch
 
 TestList = {}
 
+# Seed the RNG with a fixed value so that each test gets the same random data on every run
+seed = 42
+rng = np.random.default_rng(seed)
+
 
 def register_test(cls):
     TestList[cls.__name__] = cls()
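
Note: creating the numpy Generator with a fixed seed at import time means every random tensor built by the test models below is identical from run to run, so the quantized-vs-float comparison described in the docstring stays deterministic. A tiny illustration of that property:

import numpy as np

seed = 42
a = np.random.default_rng(seed).integers(low=10, high=20, size=(2, 2))
b = np.random.default_rng(seed).integers(low=10, high=20, size=(2, 2))
print(np.array_equal(a, b))  # True: same seed, same draws
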
@@ -103,42 +109,163 @@ class simple_linear(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
-        torch.manual_seed(42)
+        torch.manual_seed(seed)
         self.fc = torch.nn.Linear(20, 30)
 
     def forward(self, x):
         x = self.fc(x)
         return x
 
+"""Currently we compare the quantized result directly with the floating point result. To avoid a
+noticeable precision difference caused by the wide random numerical distribution, generate a small
+random value range for the convolution tests for now."""
+
 @register_test
-class simple_conv2d(torch.nn.Module):
+class simple_conv2d_2x2_3x1x40x40_non_bias(torch.nn.Module):
+    data = torch.from_numpy(
+        np.float32(rng.integers(low=10, high=20, size=(3, 1, 40, 40)))
+    )
     inputs = {
-        TosaProfile.BI: (
-            torch.ones(
-                1,
-                3,
-                256,
-                256,
-            ),
-        ),
-        TosaProfile.MI: (torch.ones(1, 3, 256, 256),),
+        TosaProfile.BI: (data,),
+        TosaProfile.MI: (data,),
+    }
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(
+            in_channels=1, out_channels=3, kernel_size=2, stride=1, bias=False
+        )
+        with torch.no_grad():
+            self.conv2d.weight.copy_(
+                torch.from_numpy(
+                    np.float32(rng.integers(low=1, high=10, size=(1, 1, 2, 2)))
+                )
+            )
+
+    def forward(self, x):
+        x = self.conv2d(x)
+        return x
+
+@register_test
+class simple_conv2d_3x3_1x3x256x256_st1(torch.nn.Module):
+    data = torch.ones(1, 3, 256, 256)
+    inputs = {
+        TosaProfile.BI: (data,),
+        TosaProfile.MI: (data,),
     }
 
     def __init__(self):
         super().__init__()
         self.conv2d = torch.nn.Conv2d(
             in_channels=3, out_channels=10, kernel_size=3, stride=1
         )
+        with torch.no_grad():
+            self.conv2d.weight.copy_(
+                torch.from_numpy(
+                    np.float32(rng.integers(low=1, high=4, size=(10, 3, 3, 3)))
+                )
+            )
+            self.conv2d.bias.copy_(
+                torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(10))))
+            )
+
+    def forward(self, x):
+        x = self.conv2d(x)
+        return x
+
+@register_test
+class simple_conv2d_1x1_1x2x128x128_st1(torch.nn.Module):
+    data = torch.from_numpy(
+        np.float32(rng.integers(low=10, high=20, size=(1, 2, 128, 128)))
+    )
+    inputs = {
+        TosaProfile.BI: (data,),
+        TosaProfile.MI: (data,),
+    }
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(
+            in_channels=2, out_channels=1, kernel_size=1, stride=1
+        )
+        with torch.no_grad():
+            self.conv2d.weight.copy_(
+                torch.from_numpy(
+                    np.float32(rng.integers(low=1, high=4, size=(1, 2, 1, 1)))
+                )
+            )
+            self.conv2d.bias.copy_(
+                torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(1))))
+            )
+
+    def forward(self, x):
+        x = self.conv2d(x)
+        return x
+
+@register_test
+class simple_conv2d_2x2_1x1x14x14_st2(torch.nn.Module):
+    data = torch.from_numpy(
+        np.float32(rng.integers(low=10, high=20, size=(1, 1, 14, 14)))
+    )
+    inputs = {
+        TosaProfile.BI: (data,),
+        TosaProfile.MI: (data,),
+    }
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(
+            in_channels=1, out_channels=1, kernel_size=2, stride=2
+        )
+        with torch.no_grad():
+            self.conv2d.weight.copy_(
+                torch.from_numpy(
+                    np.float32(rng.integers(low=1, high=4, size=(1, 1, 2, 2)))
+                )
+            )
+            self.conv2d.bias.copy_(
+                torch.from_numpy(np.float32(rng.integers(low=1, high=4, size=(1))))
+            )
+
+    def forward(self, x):
+        x = self.conv2d(x)
+        return x
+
+@register_test
+class simple_conv2d_5x5_3x2x128x128_st1(torch.nn.Module):
+    data = torch.from_numpy(
+        np.float32(rng.integers(low=10, high=20, size=(3, 2, 128, 128)))
+    )
+    inputs = {
+        TosaProfile.BI: (data,),
+        TosaProfile.MI: (data,),
+    }
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(
+            in_channels=2, out_channels=3, kernel_size=5, stride=1
+        )
+        with torch.no_grad():
+            self.conv2d.weight.copy_(
+                torch.from_numpy(
+                    np.float32(rng.integers(low=1, high=10, size=(1, 1, 5, 5)))
+                )
+            )
+            self.conv2d.bias.copy_(torch.ones(3, dtype=torch.float))
 
     def forward(self, x):
         x = self.conv2d(x)
         return x
 
 @register_test
 class block_two_conv2d(torch.nn.Module):
+    data = torch.from_numpy(
+        np.float32(rng.integers(low=10, high=20, size=(1, 3, 256, 256)))
+    )
     inputs = {
-        TosaProfile.BI: (torch.ones(1, 3, 256, 256),),
-        TosaProfile.MI: (torch.ones(1, 3, 256, 256),),
+        TosaProfile.BI: (data,),
+        TosaProfile.MI: (data,),
     }
 
     def __init__(self):
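
Note: every class above registers an instance of itself in TestList via @register_test, carrying both the module and its per-profile inputs. The harness that consumes TestList is outside this diff; a plausible minimal driver, written only for illustration, might look like this (TosaProfile and TestList come from the file above):

import torch

def run_fp_reference(test_list, profile):
    # Hypothetical driver: run each registered model's eager-mode forward on
    # the inputs it declares for the given TOSA profile.
    results = {}
    for name, model in test_list.items():   # register_test stores instances, not classes
        model.eval()
        with torch.no_grad():
            results[name] = model(*model.inputs[profile])
    return results

# e.g. inside the test module: fp_outputs = run_fp_reference(TestList, TosaProfile.MI)
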
@@ -149,6 +276,11 @@ def __init__(self):
         self.conv2d_2 = torch.nn.Conv2d(
             in_channels=10, out_channels=15, kernel_size=5, stride=1
         )
+        with torch.no_grad():
+            self.conv2d.weight.copy_(torch.ones(10, 3, 5, 5, dtype=torch.float))
+            self.conv2d.bias.copy_(torch.ones(10))
+            self.conv2d_2.weight.copy_(torch.ones(15, 10, 5, 5, dtype=torch.float))
+            self.conv2d_2.bias.copy_(torch.ones(15))
 
     def forward(self, x):
         x = self.conv2d(x)
