Commit b75e7d7

Add default dim_order asserts
Differential Revision: D61311560
Pull Request resolved: #4725
1 parent 7b795d7 commit b75e7d7

File tree

2 files changed: +24 -0 lines changed


backends/xnnpack/runtime/XNNExecutor.cpp

Lines changed: 5 additions & 0 deletions

@@ -86,6 +86,11 @@ __ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
     // Reshape runtime inputs
     if (i < input_ids_.size()) {
       size_t num_dims = tensor->dim();
+      ET_CHECK_OR_RETURN_ERROR(
+          is_contiguous_dim_order(tensor->dim_order().data(), tensor->dim()),
+          Internal,
+          "Expecting default dim_order but got a non-default dim_order tensor for external input %u",
+          i);
       size_t dims[XNN_MAX_TENSOR_DIMS];
       ET_CHECK_OR_RETURN_ERROR(
           num_dims <= XNN_MAX_TENSOR_DIMS,

backends/xnnpack/xnnpack_preprocess.py

Lines changed: 19 additions & 0 deletions

@@ -78,6 +78,22 @@ def generate_node_to_external_map(
     return node_to_external_map
 
 
+def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
+    for node in edge_graph_module.graph.nodes:
+        if node.op != "placeholder":
+            continue
+
+        # We expect the default dim order for all tensor-like inputs, i.e. inputs, buffers, and params
+        t = node.meta.get("val", None)
+        if t is not None and getattr(t, "dim_order", None) is not None:
+            default_dim_order = tuple(range(t.dim()))
+            if t.dim_order() != default_dim_order:
+                raise RuntimeError(
+                    f"XNNPACK backend only supports contiguous memory format for inputs. "
+                    f"Expecting dim_order: {default_dim_order}, but got {node.meta['val'].dim_order()} for a placeholder node {node}."
+                )
+
+
 @final
 class XnnpackBackend(BackendDetails):
     @staticmethod
@@ -126,6 +142,9 @@ def preprocess(
 
         node_to_external_map = generate_node_to_external_map(ep, graph_module)
 
+        # Make sure all inputs are contiguous_format, i.e. NCHW / default dim order
+        assert_default_dim_order(graph_module)
+
         # TODO retrace the graph module to lift the new params that may have
         # been added to the graph in passes
