Skip to content

Commit 2d6e521

Browse files
committed
[MLIR][TOSA] Add --tosa-remove-redundant-transposes pass
---------- Motivation: ---------- Some legalization pathways introduce redundant tosa.TRANSPOSE operations that result in avoidable data movement. For example, PyTorch -> TOSA contains a lot of unnecessary transposes due to conversions between NCHW and NHWC. We wish to remove all the ones that we can, since in general it is possible to remove the overwhelming majority. ------------ Changes Made: ------------ - Add the --tosa-remove-redundant-transposes pass - Add TosaElementwiseOperator trait. ------------------- High-Level Overview: ------------------- The pass begins at a downstream transpose with some perms tensor. It traverses the dependencies upward, accepting only TosaElementwise operators. Dependencies must terminate in nullifying transposes (when composed, they form the identity), reshapes, or consts. Conceptually, we then "bubble up" the downstream transpose until we hit the sources. For constants, we generate a new constants, composed with the downstream transpose. For nullifying transposes, we "cancel" them. For reshapes, we generally cannot "bubble" through them, so we insert the downstream transpose there. We then ensure that we do not cause any duplication by "converting" this chain we bubbled-up into its transposed form. We do this by analyzing the dependency fan-ins across all transposes with the same perms tensor in order to ensure that they do not have uses outside this group, which would cause the old code section to remain "live", and not removed by canonicalization. -------------- Impact of Pass: -------------- Patching the dense_resource artifacts (from PyTorch) with dense attributes to permit constant folding, we receive the following results. Note that data movement represents total transpose data movement, calculated by noting which dimensions moved during the transpose. /////////// MobilenetV3: /////////// BEFORE total data movement: 11798776 B (11.25 MiB) AFTER total data movement: 2998016 B (2.86 MiB) 74.6% of data movement removed. BEFORE transposes: 82 AFTER transposes: 20 75.6% of transposes removed. //////// ResNet18: //////// BEFORE total data movement: 20596556 B (19.64 MiB) AFTER total data movement: 1003520 B (0.96 MiB) 95.2% of data movement removed. BEFORE transposes: 56 AFTER transposes: 5 91.1% of transposes removed. //////// ResNet50: //////// BEFORE total data movement: 83236172 B (79.3 MiB) AFTER total data movement: 3010560 B (2.87 MiB) 96.4% of data movement removed BEFORE transposes: 120 AFTER transposes: 7 94.2% of transposes removed. ///////// ResNet101: ///////// BEFORE total data movement: 124336460 B (118.58 MiB) AFTER total data movement: 3010560 B (2.87 MiB) 97.6% of data movement removed BEFORE transposes: 239 AFTER transposes: 7 97.1% of transposes removed. ///////// ResNet152: ///////// BEFORE total data movement: 175052108 B (166.94 MiB) AFTER total data movement: 3010560 B (2.87 MiB) 98.3% of data movement removed BEFORE transposes: 358 AFTER transposes: 7 98.0% of transposes removed. //////// Overview: //////// We see that we remove up to 98% of transposes and eliminate up to 98.3% of redundant transpose data movement. In the context of ResNet50, with 120 inferences per second, we reduce dynamic transpose data bandwidth from 9.29 GiB/s to 344.4 MiB/s. ----------- Future Work: ----------- (1) Evaluate tradeoffs with the duplication of ConstOp, especially across many downstream transposes with different perms, which can result in the same ConstOp being duplicated (but transposed) multiple times. Observe tradeoffs between a lower memory footprint and potentially converting many fan-ins of downstream transposes with the same perms, which if not converted may affect ability of other inter-dependent fan-in to convert. (2) Restrict the propagation of transposes up their fan-in cone if one of the sources is a ReshapeOp for which the inserted TransposeOp would not be a TransposeOp that lends itself to the TransposeIsReshape Canonicalization, which permits them to be folded to a single ReshapeOp. Observe impact on how this restriction may be detrimental to the conversion of other downstream transpose conversions due to the fan-in cone analysis. Additionally, consider cases where there may be multiple upstream transposes that could be removed as a result of this -- and trade that off with how many you would effectively insert if the ReshapeOp/TransposeOp can't be folded to a single ReshapeOp. (3) Make the pass more general, beyond just allowing upstream transposes to be nullifying. For example, transpose1 -> ... -> transpose2 where transpose2(transpose1) do not cancel to identity. This can be done by propagating the downstream transpose up and inserting after transpose1, just like how it is done for reshape. However, in the case of chains like transpose1 -> ... -> transpose2 -> ... -> transpose3 this could require running the current runOnOperation() function until we converge. This can be done by stopping when all transposes that we can successfully collect the fan-ins of have the owner of their first operand being either another TransposeOp or a ReshapeOp, since those are what we propagate to and where we leave behind / insert another TransposeOp. Otherwise, we would could potentially have infinite looping. This additionally has the implication that we would not replace any transposes and instead we could have canonicalization handle that. (4) Add support for more instructions (for example, those that reduce alongside an axis) to be one of the intervening operations in the fan-in cones (other than those with TosaElementwiseOperator trait). (5) Support bubbling transposes up to the input parameter. May not need extensive fan-in analysis as no operation cost associated if used elsewhere. Signed-off-by: Arteen Abrishami <[email protected]>
1 parent f02c72f commit 2d6e521

File tree

8 files changed

+1590
-0
lines changed

8 files changed

+1590
-0
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
206206
input, paddings, pad_value);
207207
}]>;
208208

209+
//===----------------------------------------------------------------------===//
210+
// TOSA Operator Trait.
211+
//===----------------------------------------------------------------------===//
212+
213+
// Permits broadcasting. Elementwise trait is too strict.
214+
def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
215+
let cppNamespace = "mlir::OpTrait::tosa";
216+
}
217+
209218
//===----------------------------------------------------------------------===//
210219
// TOSA Operator Class.
211220
//===----------------------------------------------------------------------===//
@@ -219,6 +228,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
219228
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
220229
["inferReturnTypeComponents"]>,
221230
ResultsBroadcastableShape,
231+
TosaElementwiseOperator,
222232
Pure])> {
223233
let assemblyFormat =
224234
"operands attr-dict `:` functional-type(operands, results)";

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ class MulOperandsAndResultElementType
8484
}
8585
};
8686

87+
/// This class indicates that an op is tosa-elementwise (permits broadcasting,
88+
/// unlike Elementwise trait)
89+
template <typename ConcreteType>
90+
class TosaElementwiseOperator
91+
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};
92+
8793
} // namespace tosa
8894
} // namespace OpTrait
8995

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
1515

1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
17+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1718
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
1819
#include "mlir/Pass/Pass.h"
1920

@@ -48,6 +49,7 @@ std::unique_ptr<Pass> createTosaInferShapesPass();
4849
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
4950
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
5051
std::unique_ptr<Pass> createTosaOptionalDecompositions();
52+
std::unique_ptr<Pass> createTosaRemoveRedundantTransposes();
5153

5254
struct ValidationOptions {
5355
/// Validate if operations match for the given profile.

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,20 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
126126
];
127127
}
128128

129+
def TosaRemoveRedundantTransposes : Pass<"tosa-remove-redundant-transposes", "func::FuncOp"> {
130+
let summary = "Remove redundant transposes";
131+
let description = [{
132+
Pass that identifies and removes redundant tosa.TRANSPOSE operations.
133+
It does so by traversing dependencies of tosa.TRANSPOSE operations until they terminate in either
134+
tosa.RESHAPE, a nullifying tosa.TRANSPOSE, or a tosa.CONST. It then propagates the downstream
135+
transform upward through the intervening operators if it is able and replaces the downstream tosa.TRANSPOSE.
136+
Results generally better when run after Canonicalization and resolution of dynamic shapes.
137+
Canonicalization is required for dead code elimination after pass is run.
138+
This pass has an important use-case in cleaning up the results of frameworks that introduce a lot
139+
of data-layout transformations when legalizing to TOSA, a common one being transformations between NHWC and NCHW
140+
layouts.
141+
}];
142+
let constructor = "tosa::createTosaRemoveRedundantTransposes()";
143+
}
144+
129145
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
77
TosaLayerwiseConstantFoldPass.cpp
88
TosaMakeBroadcastable.cpp
99
TosaOptionalDecompositions.cpp
10+
TosaRemoveRedundantTransposes.cpp
1011
TosaTypeConverters.cpp
1112
TosaValidation.cpp
1213

0 commit comments

Comments
 (0)