Skip to content

Commit 554db00

Browse files
authored
support channels last inputs in xnnpack (#11118)
### Summary Support channels last inputs in xnnpack ### Test plan Wrote tests taking combinations of channels last/first inputs with ops requiring inputs of both forms.
1 parent 1f52982 commit 554db00

File tree

5 files changed

+160
-21
lines changed

5 files changed

+160
-21
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
5656

5757
# Set of ops that require memory format to be NCHW
5858
memory_sensitive_ops_nchw = {
59-
"output",
6059
exir_ops.edge.aten.squeeze_copy.dim,
6160
exir_ops.edge.aten.unsqueeze_copy.default,
61+
exir_ops.edge.aten.linear.default,
6262
}
6363

6464
# Tag which is added to a node's meta to indicate that it uses NHWC format.
@@ -91,10 +91,18 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
9191
return not self.is_nhwc_node(node)
9292

9393
def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
94-
return node.target in self.memory_sensitive_ops_nhwc
94+
return (
95+
node.target in self.memory_sensitive_ops_nhwc
96+
or node.name == "output"
97+
and not node.args[0][0].meta["val"].is_contiguous()
98+
)
9599

96100
def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
97-
return node.target in self.memory_sensitive_ops_nchw
101+
return (
102+
node.target in self.memory_sensitive_ops_nchw
103+
or node.name == "output"
104+
and node.args[0][0].meta["val"].is_contiguous()
105+
)
98106

99107
def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
100108
# There are two conditions that must be met for a node to be able to
@@ -269,7 +277,10 @@ def input_to_nhwc(
269277
# serializing graph, but don't do anything else here
270278
self.mark_as_nhwc_node(input_node)
271279

272-
if self.is_nhwc_node(input_node):
280+
if input_node.op == "placeholder":
281+
if not input_node.meta["val"][0].is_contiguous():
282+
return
283+
elif self.is_nhwc_node(input_node):
273284
return
274285

275286
if not self.can_be_converted_to_nhwc(input_node):
@@ -333,7 +344,10 @@ def input_to_nchw(
333344
# do anything else here
334345
self.mark_as_nchw_node(input_node)
335346

336-
if self.is_nchw_node(input_node):
347+
if input_node.op == "placeholder":
348+
if input_node.meta["val"].is_contiguous():
349+
return
350+
elif self.is_nchw_node(input_node):
337351
return
338352

339353
if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
@@ -371,7 +385,11 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
371385
# first input to be nhwc. This makes this node's output nhwc too
372386
# Currently, all nodes like this should have all of their other
373387
# inputs as nchw, so fail if this is not true
374-
self.input_to_nhwc(graph_module, node.args[0], node)
388+
if node.name == "output":
389+
self.input_to_nhwc(graph_module, node.args[0][0], node)
390+
else:
391+
self.input_to_nhwc(graph_module, node.args[0], node)
392+
375393
for input_node in node.all_input_nodes[1:]:
376394
if self.is_nhwc_node(input_node):
377395
raise AssertionError(

backends/xnnpack/runtime/XNNExecutor.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,16 @@ ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
106106
err == Error::Ok,
107107
Internal,
108108
"Failed to retrieve dim order from tensor!");
109-
ET_CHECK_OR_RETURN_ERROR(
110-
is_contiguous_dim_order(dim_order, tensor->dim()),
111-
Internal,
112-
"Expecting default dim_order but got a non default dim_order tensor for external input %u",
113-
i);
114109
size_t dims[XNN_MAX_TENSOR_DIMS];
115110
ET_CHECK_OR_RETURN_ERROR(
116111
num_dims <= XNN_MAX_TENSOR_DIMS,
117112
InvalidArgument,
118113
"XNNPACK backend accepts tensors with at most %d dims, but got %zu",
119114
XNN_MAX_TENSOR_DIMS,
120115
num_dims);
121-
for (int d = 0; d < num_dims; ++d) {
122-
dims[d] = tensor->size(d);
116+
117+
for (int j = 0; j < num_dims; ++j) {
118+
dims[j] = tensor->size(static_cast<int>(dim_order[j]));
123119
}
124120
status =
125121
xnn_reshape_external_value(runtime_.get(), ext_id, num_dims, dims);
@@ -220,8 +216,17 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const {
220216

221217
// Convert new output shape into SizesType
222218
SizesType expected_output_size[kTensorDimensionLimit];
223-
for (size_t d = 0; d < num_dim; ++d) {
224-
expected_output_size[d] = static_cast<SizesType>(dims[d]);
219+
executorch::aten::DimOrderType dim_order[kTensorDimensionLimit];
220+
Error errr =
221+
ET_RUNTIME_NAMESPACE::get_dim_order(*out_tensor, dim_order, num_dim);
222+
ET_CHECK_OR_RETURN_ERROR(
223+
errr == Error::Ok,
224+
Internal,
225+
"Failed to retrieve dim order from tensor!");
226+
227+
for (int j = 0; j < num_dim; ++j) {
228+
expected_output_size[static_cast<int>(dim_order[j])] =
229+
static_cast<SizesType>(dims[j]);
225230
}
226231

227232
executorch::aten::ArrayRef<SizesType> output_size{

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,121 @@ def setUp(self):
4343
)
4444
dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
4545

46+
def run_tester(self, module, inputs):
47+
tester = Tester(
48+
module.eval(),
49+
inputs,
50+
)
51+
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
52+
53+
class LinearConv(torch.nn.Module):
54+
def __init__(self):
55+
super().__init__()
56+
self.conv1 = torch.nn.Conv2d(3, 3, 3)
57+
self.linear1 = torch.nn.Linear(4, 3)
58+
59+
def forward(self, x):
60+
y = self.linear1(x)
61+
return self.conv1(y)
62+
63+
LinearConvModule = LinearConv()
64+
65+
class ConvLinearConv(torch.nn.Module):
66+
def __init__(self):
67+
super().__init__()
68+
self.conv1 = torch.nn.Conv2d(3, 3, 3)
69+
self.linear1 = torch.nn.Linear(4, 4)
70+
71+
def forward(self, x):
72+
y = self.conv1(x)
73+
return self.linear1(y)
74+
75+
ConvLinearConvModule = ConvLinearConv()
76+
77+
class Bilinear(torch.nn.Module):
78+
def __init__(self):
79+
super().__init__()
80+
81+
def forward(self, x):
82+
return torch.nn.functional.interpolate(
83+
x, scale_factor=2, mode="bilinear", align_corners=True
84+
)
85+
86+
BilinearModule = Bilinear()
87+
88+
class TwoConvAdd(torch.nn.Module):
89+
def __init__(self):
90+
super().__init__()
91+
self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1)
92+
self.conv2 = torch.nn.Conv2d(5, 16, 3, padding=1)
93+
94+
def forward(self, x1, x2):
95+
y1 = self.conv1(x1)
96+
y2 = self.conv2(x2)
97+
return torch.add(y1, y2)
98+
99+
TwoConvAddModule = TwoConvAdd()
100+
101+
def test_two_conv_add(self):
102+
x1 = torch.randn(1, 3, 8, 8)
103+
x2 = torch.randn(1, 5, 8, 8)
104+
105+
# Test with regular format inputs
106+
self.run_tester(self.TwoConvAddModule, (x1, x2))
107+
108+
# Test with channels_last format inputs
109+
x1_cl = x1.to(memory_format=torch.channels_last)
110+
x2_cl = x2.to(memory_format=torch.channels_last)
111+
self.run_tester(self.TwoConvAddModule, (x1_cl, x2_cl))
112+
113+
# Test with mixed format inputs
114+
self.run_tester(self.TwoConvAddModule, (x1_cl, x2))
115+
self.run_tester(self.TwoConvAddModule, (x1, x2_cl))
116+
117+
# Verify the pass adds the expected number of to_copy operations
118+
(
119+
Tester(self.TwoConvAddModule, (x1, x2))
120+
.export()
121+
.to_edge()
122+
.run_passes(self.PassStage)
123+
.check_count(
124+
{
125+
self.to_copy_name: 3, # 2 for inputs to conv, 1 for outputs from add
126+
}
127+
)
128+
.run_method_and_compare_outputs()
129+
)
130+
131+
def test_conv_linear_dim_order_swaps(self):
132+
self.run_tester(self.LinearConvModule, (torch.randn(1, 3, 6, 4),))
133+
self.run_tester(
134+
self.LinearConvModule,
135+
(torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),),
136+
)
137+
138+
def test_linear_conv_dim_order_swaps(self):
139+
self.run_tester(self.ConvLinearConvModule, (torch.randn(1, 3, 6, 6),))
140+
self.run_tester(
141+
self.ConvLinearConvModule,
142+
(torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),),
143+
)
144+
145+
def test_nhwc_nchw_input_on_nhwc_op(self):
146+
self.run_tester(
147+
self.BilinearModule,
148+
(
149+
torch.arange(8)
150+
.reshape(1, 2, 2, 2)
151+
.to(torch.float32)
152+
.to(memory_format=torch.channels_last),
153+
),
154+
)
155+
156+
self.run_tester(
157+
self.BilinearModule,
158+
(torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),),
159+
)
160+
46161
def test_fp32_channels_last_tagged_reshape_pass(self):
47162
for module, num_reshape in self.modules.items():
48163
(

backends/xnnpack/test/tester/tester.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from executorch.exir.backend.backend_api import validation_disabled
3333
from executorch.exir.backend.partitioner import Partitioner
34+
from executorch.exir.dim_order_utils import get_memory_format
3435
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
3536

3637
from executorch.exir.print_program import pretty_print, print_program
@@ -533,10 +534,13 @@ def fn(x):
533534
# create random tensor inputs with the shapes given above:
534535
random_inputs = []
535536
for arg_idx in range(len(self.example_inputs)):
537+
memFormat = get_memory_format(
538+
list(self.example_inputs[arg_idx].dim_order())
539+
)
536540
random_inputs.append(
537-
torch.randn(input_shapes[arg_idx]).to(
538-
dtype=self.example_inputs[arg_idx].dtype
539-
)
541+
torch.randn(input_shapes[arg_idx])
542+
.to(dtype=self.example_inputs[arg_idx].dtype)
543+
.to(memory_format=memFormat)
540544
)
541545

542546
yield tuple(random_inputs)

backends/xnnpack/xnnpack_preprocess.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,6 @@ def preprocess(
145145

146146
node_to_external_map = generate_node_to_external_map(ep, graph_module)
147147

148-
# Make sure all inputs are contiguous_format or NCHW or default dim order
149-
assert_default_dim_order(graph_module)
150-
151148
# TODO retrace the graph module to lift the new params may have
152149
# been added to the graph in passes
153150

0 commit comments

Comments
 (0)