Skip to content

Commit 2dd88fb

Browse files
Support dim order in Arm backend (#5576)
Summary: Add both ahead-of-time and runtime guards to make sure we don't accept inputs with any other memory-format than contiguous format. Change-Id: I9e29badbcf238d458e12f0d62394abae66d421b7 Pull Request resolved: #5576 Reviewed By: mergennachin Differential Revision: D63638653 Pulled By: digantdesai fbshipit-source-id: 9eea86836919eacd5578ef024cd43d039fd40123
1 parent 1c6dbb6 commit 2dd88fb

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

backends/arm/operators/op_placeholder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ def process_inputs(
2828
tosa_graph: ts.TosaSerializer,
2929
):
3030
"""Serialize an input node"""
31+
# inputs need to be in default dim_order (contiguous memory format)
32+
meta = node.meta["val"]
33+
if meta.dim_order() != tuple(range(meta.dim())):
34+
raise RuntimeError(
35+
f"Arm backend only supports contiguous memory format for inputs. "
36+
f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
37+
)
3138
inputs = [TosaArg(node)]
3239
input_shape = inputs[0].shape
3340
input_dim_order = inputs[0].dim_order

backends/arm/runtime/ArmBackendEthosU.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "executorch/runtime/backend/interface.h"
2020
#include "executorch/runtime/core/error.h"
2121
#include "executorch/runtime/core/evalue.h"
22+
#include "executorch/runtime/core/exec_aten/util/dim_order_util.h"
2223
#include "executorch/runtime/core/exec_aten/util/scalar_type_util.h"
2324

2425
using namespace std;
@@ -144,6 +145,15 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
144145
toString(tensor_in.scalar_type()));
145146
return Error::InvalidProgram;
146147
}
148+
supported = is_contiguous_dim_order(
149+
tensor_in.dim_order().data(), tensor_in.dim());
150+
if (!supported) {
151+
ET_LOG(
152+
Error,
153+
"Input %d expected contiguous dim_order, but got non-contiguous dim_order",
154+
i);
155+
return Error::InvalidProgram;
156+
}
147157

148158
// Select a compatible copy routine including checking for input layouts
149159
// which require permutation.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import unittest
7+
8+
import pytest
9+
10+
import torch
11+
from executorch.backends.arm.test import common
12+
13+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
14+
15+
16+
class Conv2D(torch.nn.Module):
17+
18+
def __init__(self):
19+
super().__init__()
20+
self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=(3, 3))
21+
22+
def forward(self, x):
23+
return self.conv2d(x.to(memory_format=torch.channels_last))
24+
25+
def get_inputs(self):
26+
return (torch.randn(1, 2, 20, 20),)
27+
28+
29+
class TestDimOrderGuards(unittest.TestCase):
30+
31+
def test_tosa_MI_pipeline(self):
32+
module = Conv2D()
33+
tester = (
34+
ArmTester(
35+
module,
36+
example_inputs=module.get_inputs(),
37+
compile_spec=common.get_tosa_compile_spec(),
38+
)
39+
.export()
40+
.to_edge()
41+
)
42+
with pytest.raises(RuntimeError):
43+
tester.partition()
44+
45+
def test_tosa_BI_pipeline(self):
46+
module = Conv2D()
47+
tester = (
48+
ArmTester(
49+
module,
50+
example_inputs=module.get_inputs(),
51+
compile_spec=common.get_tosa_compile_spec(),
52+
)
53+
.quantize()
54+
.export()
55+
.to_edge()
56+
)
57+
with pytest.raises(RuntimeError):
58+
tester.partition()

0 commit comments

Comments
 (0)