
Remove memory-format workaround for Arm backend #3981

Closed
83 changes: 6 additions & 77 deletions backends/arm/arm_backend.py
@@ -16,15 +16,13 @@
import serializer.tosa_serializer as ts
from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm.operators.op_output import process_output
from executorch.backends.arm.operators.op_placeholder import process_placeholder
from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import get_quant_node_dtype, is_quant_node
from executorch.backends.arm.tosa_utils import (
dbg_fail,
dbg_tosa_dump,
is_consumer_node_depthwise_conv2d,
is_permute_node_before_addmm,
process_call_function,
)
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -45,6 +43,7 @@ def __init__(self):
self.compiler_flags = []
self.output_format = None
self.path_for_intermediates = None
# TODO MLETORCH-265 Remove permute_nhwc flag
self.permute_nhwc = False
self.quantize_io = False

@@ -217,18 +216,13 @@ def preprocess(  # noqa: C901
artifact_path = None
output_format = ""
compile_flags = []
permute_memory_to_nhwc = False
for spec in compile_spec:
if spec.key == "debug_artifact_path":
artifact_path = spec.value.decode()
if spec.key == "output_format":
output_format = spec.value.decode()
if spec.key == "compile_flags":
compile_flags.append(spec.value.decode())
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
if memory_format == "nhwc":
permute_memory_to_nhwc = True

# Check that the output format is set in the compile spec
if not output_format:
@@ -250,76 +244,11 @@ def preprocess(  # noqa: C901

for node in graph_module.graph.nodes:
if node.op == "call_function":
# Unpack arguments and convert
inputs = []
for arg in node.args:
inputs.append(TosaArg(arg))

# Convert output (this node itself)
output = TosaArg(node)

# TODO: fragile code for temporary fix, not all outputs will be
# rank 4
if permute_memory_to_nhwc and len(output.shape) == 4:
# TODO: remove this if check
# this is added because we need to align the quant node
# output shape before the depthwise_conv2d node. The output
# shape between TOSA conv2d and depthwise_conv2d are different.
if (
node.all_input_nodes[0].op
== "placeholder" # check its parent is a placeholder
and is_quant_node(node)
and is_consumer_node_depthwise_conv2d(node)
):
NHWC_Order = [2, 3, 0, 1]
else:
NHWC_Order = [0, 2, 3, 1]
output.shape = [output.shape[i] for i in NHWC_Order]

# Add output to TOSA graph
tosa_graph.currRegion.currBasicBlock.addTensor(
output.name,
(
inputs[0].shape
if is_permute_node_before_addmm(node)
else output.shape
),
(
map_dtype(get_quant_node_dtype(node))
if is_quant_node(node)
else output.dtype
),
)

# Visiting each Node
if node.target.__name__ in node_visitors:
if node.target.__name__ in [
"aten.add.Tensor",
"aten._native_batch_norm_legit_no_training.default",
]:
node_visitors[node.target.__name__].define_node(
node,
tosa_graph,
inputs,
output,
is_quant_node(node),
permute_memory_to_nhwc,
)
else:
node_visitors[node.target.__name__].define_node(
node, tosa_graph, inputs, output, is_quant_node(node)
)
else:
raise RuntimeError(f"Unknown operator {node.target}")
process_call_function(node, tosa_graph, node_visitors)
elif node.op == "placeholder":
process_placeholder(
node, tosa_graph, edge_program, permute_memory_to_nhwc
)
process_placeholder(node, tosa_graph, edge_program)
elif node.op == "output":
for output in node.args[0]:
tosa_graph.addOutputTensor(
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
)
process_output(node, tosa_graph)
else:
# This will only happen if an unpartitioned graph is passed without
# any checking of compatibility.
10 changes: 3 additions & 7 deletions backends/arm/operators/op_add.py
@@ -16,7 +16,7 @@
build_rescale_from_int32,
build_rescale_to_int32,
)
from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs
from executorch.backends.arm.tosa_utils import broadcast_shapes, getNodeArgs, tosa_shape
Contributor

Are you planning to add more ops? For instance, the Permute op argument needs to be updated based on what input format it gets when you update the previous node's output format.

Collaborator Author

Yes, there are a few open PRs with new ops. The ambition is that we can just update shapes such as permute's argument with tosa_shape(shape, dim_order).
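
For reference, a minimal sketch of what such a tosa_shape helper could look like, inferred from the call sites in this diff (an assumption, not the actual implementation in executorch.backends.arm.tosa_utils):

from typing import Sequence, Tuple

def tosa_shape(shape: Sequence[int], dim_order: Sequence[int]) -> Tuple[int, ...]:
    # Reorder the logical (NCHW) shape into the memory order used when
    # serializing the TOSA graph; an identity dim_order leaves it unchanged.
    return tuple(shape[dim] for dim in dim_order)

# Example: a (1, 8, 16, 16) NCHW tensor serialized in NHWC order becomes (1, 16, 16, 8).
assert tosa_shape((1, 8, 16, 16), (0, 2, 3, 1)) == (1, 16, 16, 8)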

from serializer.tosa_serializer import TosaOp


@@ -34,7 +34,6 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
permute_memory_to_nhwc: bool,
) -> None:
if is_quant_node:
# Single input or not
@@ -54,12 +53,9 @@
inputA_rescale_scale = input_A_scale.number / min_scale
inputB_rescale_scale = input_B_scale.number / min_scale

input_A.shape = tosa_shape(input_A.shape, input_A.dim_order)
input_B.shape = tosa_shape(input_B.shape, input_B.dim_order)
broadcasted_shape = broadcast_shapes(input_A.shape, input_B.shape)
if permute_memory_to_nhwc:
NHWC_Order = [0, 2, 3, 1]
broadcasted_shape = [broadcasted_shape[i] for i in NHWC_Order]
input_A.shape = [input_A.shape[i] for i in NHWC_Order]
input_B.shape = [input_B.shape[i] for i in NHWC_Order]

input_A_rescaled_to_int32 = build_rescale_to_int32(
tosa_graph,
40 changes: 22 additions & 18 deletions backends/arm/operators/op_batch_norm.py
@@ -11,7 +11,7 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import promote_shape
from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape
from serializer.tosa_serializer import TosaOp


@@ -25,12 +25,9 @@ def __init__(self, *args):
# For BatchNorm2D, mean and var are calculated over the channel dimension
# But TOSA doesn't allow subtraction of inputs with different ranks
# Need to augment the shapes to match the ranks with activations
def augment_shape_rank(self, input, permute_memory_to_nhwc):
return (
(1, 1, 1) + input.shape
if permute_memory_to_nhwc
else ((1,) + input.shape + (1, 1))
)
def augment_shape_rank(self, shape, dim_order):
nchw_shape = (1, *shape, 1, 1)
return tosa_shape(nchw_shape, dim_order)
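
# A quick sanity check of the new helper, reusing the tosa_shape sketch above
# (assumed values, not part of this diff): an 8-channel parameter of shape (8,)
# promotes to the rank-4 NCHW shape (1, 8, 1, 1), and an NHWC dim_order
# reorders it channels-last, broadcastable against NHWC activations.
assert tosa_shape((1, 8, 1, 1), (0, 2, 3, 1)) == (1, 1, 1, 8)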

def define_node(
self,
@@ -39,7 +36,6 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
permute_memory_to_nhwc: bool,
) -> None:
# Decompose batch norm into sequence
(activations, weights, bias, running_mean, running_var, momentum, epsilon) = (
@@ -67,13 +63,15 @@ def define_node(
mean_reshaped = promote_shape(
tosa_graph,
running_mean,
self.augment_shape_rank(running_mean, permute_memory_to_nhwc),
self.augment_shape_rank(running_mean.shape, output.dim_order),
input_dtype,
)

# Subtract mean
# %op1 = tosa.SUB(%activations, %running_mean)
op1 = tosa_graph.addIntermediate(output.shape, input_dtype)
op1 = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), input_dtype
)
tosa_graph.addOperator(
TosaOp.Op().SUB,
[activations.name, mean_reshaped.name],
@@ -82,7 +80,9 @@
# Adding eplison to variance
# %op2 = tosa.ADD(%running_var, %epsilon_const)
epsilon_const = tosa_graph.addConst([1], input_dtype, [epsilon.number])
op2 = tosa_graph.addIntermediate(running_var.shape, input_dtype)
op2 = tosa_graph.addIntermediate(
tosa_shape(running_var.shape, running_var.dim_order), input_dtype
)
tosa_graph.addOperator(
TosaOp.Op().ADD,
[running_var.name, epsilon_const.name],
@@ -97,7 +97,7 @@
op3_reshaped = promote_shape(
tosa_graph,
op3,
self.augment_shape_rank(running_var, permute_memory_to_nhwc),
self.augment_shape_rank(running_var.shape, output.dim_order),
input_dtype,
)

@@ -114,7 +114,9 @@
else:
# Multiply shifted activations with reciprocal variance
# %op4 = tosa.MUL(%op1, %op3)
op4 = tosa_graph.addIntermediate(output.shape, input_dtype)
op4 = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), input_dtype
)
attr_mul = ts.TosaSerializerAttribute()
attr_mul.MulAttribute(0)
tosa_graph.addOperator(
@@ -130,7 +132,7 @@
weights_reshaped = promote_shape(
tosa_graph,
weights,
self.augment_shape_rank(weights, permute_memory_to_nhwc),
self.augment_shape_rank(weights.shape, output.dim_order),
input_dtype,
)

@@ -152,7 +154,7 @@
bias_reshaped = promote_shape(
tosa_graph,
bias,
self.augment_shape_rank(bias, permute_memory_to_nhwc),
self.augment_shape_rank(bias.shape, output.dim_order),
input_dtype,
)

@@ -170,12 +172,14 @@
weights_reshaped = promote_shape(
tosa_graph,
weights,
self.augment_shape_rank(weights, permute_memory_to_nhwc),
self.augment_shape_rank(weights.shape, output.dim_order),
input_dtype,
)

# %op5 = tosa.MUL(%op4, %weights)
op5 = tosa_graph.addIntermediate(output.shape, input_dtype)
op5 = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), input_dtype
)
attr_mul = ts.TosaSerializerAttribute()
attr_mul.MulAttribute(0)
tosa_graph.addOperator(
Expand All @@ -189,7 +193,7 @@ def define_node(
bias_reshaped = promote_shape(
tosa_graph,
bias,
self.augment_shape_rank(bias, permute_memory_to_nhwc),
self.augment_shape_rank(bias.shape, output.dim_order),
input_dtype,
)

6 changes: 4 additions & 2 deletions backends/arm/operators/op_conv2d.py
@@ -15,7 +15,7 @@
build_rescale_conv_output,
get_quant_node_args,
)
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape

from serializer.tosa_serializer import TosaOp

@@ -107,7 +107,9 @@
# The output type is int32 when input type is int8.
conv2d_output_name = output.name
if is_quant_node:
conv2d_res = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
conv2d_res = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
)
conv2d_output_name = conv2d_res.name

# Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
7 changes: 5 additions & 2 deletions backends/arm/operators/op_div.py
@@ -1,4 +1,4 @@
# Copyright 2023 Arm Limited and/or its affiliates.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -11,6 +11,7 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp


@@ -30,7 +31,9 @@ def define_node(
is_quant_node: bool,
) -> None:
# FP32 Div is implemented as output=x/y -> output=x*1/y e.g. MUL(x,RECIPROCAL(y))
recip = tosa_graph.addIntermediate(inputs[1].shape, inputs[1].dtype)
recip = tosa_graph.addIntermediate(
tosa_shape(inputs[1].shape, inputs[1].dim_order), inputs[1].dtype
)
tosa_graph.addOperator(TosaOp.Op().RECIPROCAL, [inputs[1].name], [recip.name])

attr = ts.TosaSerializerAttribute()
17 changes: 17 additions & 0 deletions backends/arm/operators/op_output.py
@@ -0,0 +1,17 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import serializer.tosa_serializer as ts
import torch


def process_output(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
):
for output in node.args[0]:
tosa_graph.addOutputTensor(
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
)