Skip to content

Commit 00f239e

Browse files
authored
[MLIR][TOSA] Add --tosa-reduce-transposes pass (#108260)
---------- Motivation: ---------- Some legalization pathways introduce redundant tosa.TRANSPOSE operations that result in avoidable data movement. For example, PyTorch -> TOSA contains a lot of unnecessary transposes due to conversions between NCHW and NHWC. We wish to remove all the ones that we can, since in general it is possible to remove the overwhelming majority. ------------ Changes Made: ------------ - Add the --tosa-reduce-transposes pass - Add TosaElementwiseOperator trait. ------------------- High-Level Overview: ------------------- The pass works through the transpose operators in the program. It begins at some transpose operator with an associated permutations tensor. It traverses upwards through the dependencies of this transpose and verifies that we encounter only operators with the TosaElementwiseOperator trait and terminate in either constants, reshapes, or transposes. We then evaluate whether there are any additional restrictions (the transposes it terminates in must invert the one we began at, and the reshapes must be ones in which we can fold the transpose into), and then we hoist the transpose through the intervening operators, folding it at the constants, reshapes, and transposes. Finally, we ensure that we do not need both the transposed form (the form that had the transpose hoisted through it) and the untransposed form (which it was prior), by analyzing the usages of those dependent operators of a given transpose we are attempting to hoist and replace. If they are such that it would require both forms to be necessary, then we do not replace the hoisted transpose, causing the new chain to be dead. Otherwise, we do and the old chain (untransposed form) becomes dead. Only one chain will ever then be live, resulting in no duplication. We then perform a simple one-pass DCE, so no canonicalization is necessary. -------------- Impact of Pass: -------------- Patching the dense_resource artifacts (from PyTorch) with dense attributes to permit constant folding, we receive the following results. Note that data movement represents total transpose data movement, calculated by noting which dimensions moved during the transpose. /////////// MobilenetV3: /////////// BEFORE total data movement: 11798776 B (11.25 MiB) AFTER total data movement: 2998016 B (2.86 MiB) 74.6% of data movement removed. BEFORE transposes: 82 AFTER transposes: 20 75.6% of transposes removed. //////// ResNet18: //////// BEFORE total data movement: 20596556 B (19.64 MiB) AFTER total data movement: 1003520 B (0.96 MiB) 95.2% of data movement removed. BEFORE transposes: 56 AFTER transposes: 5 91.1% of transposes removed. //////// ResNet50: //////// BEFORE total data movement: 83236172 B (79.3 MiB) AFTER total data movement: 3010560 B (2.87 MiB) 96.4% of data movement removed BEFORE transposes: 120 AFTER transposes: 7 94.2% of transposes removed. ///////// ResNet101: ///////// BEFORE total data movement: 124336460 B (118.58 MiB) AFTER total data movement: 3010560 B (2.87 MiB) 97.6% of data movement removed BEFORE transposes: 239 AFTER transposes: 7 97.1% of transposes removed. ///////// ResNet152: ///////// BEFORE total data movement: 175052108 B (166.94 MiB) AFTER total data movement: 3010560 B (2.87 MiB) 98.3% of data movement removed BEFORE transposes: 358 AFTER transposes: 7 98.0% of transposes removed. //////// Overview: //////// We see that we remove up to 98% of transposes and eliminate up to 98.3% of redundant transpose data movement. In the context of ResNet50, with 120 inferences per second, we reduce dynamic transpose data bandwidth from 9.29 GiB/s to 344.4 MiB/s. ----------- Future Work: ----------- (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across hoisted transposes with different permutation tensors. (2) Expand the class of foldable upstream ReshapeOp we permit beyond N -> 1x1x...x1xNx1x...x1x1. (3) Enchance the pass to permit folding arbitrary transpose pairs, beyond those that form the identity. (4) Add support for more instructions besides TosaElementwiseOperator as the intervening ones (for example, the reduce_* operators). (5) Support hoisting transposes up to an input parameter. Signed-off-by: Arteen Abrishami <[email protected]>
1 parent 9ceb967 commit 00f239e

File tree

6 files changed

+1380
-0
lines changed

6 files changed

+1380
-0
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
206206
input, paddings, pad_value);
207207
}]>;
208208

209+
//===----------------------------------------------------------------------===//
210+
// TOSA Operator Trait.
211+
//===----------------------------------------------------------------------===//
212+
213+
// Permits broadcasting. Elementwise trait is too strict.
214+
def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
215+
let cppNamespace = "mlir::OpTrait::tosa";
216+
}
217+
209218
//===----------------------------------------------------------------------===//
210219
// TOSA Operator Class.
211220
//===----------------------------------------------------------------------===//
@@ -219,6 +228,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
219228
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
220229
["inferReturnTypeComponents"]>,
221230
ResultsBroadcastableShape,
231+
TosaElementwiseOperator,
222232
Pure])> {
223233
let assemblyFormat =
224234
"operands attr-dict `:` functional-type(operands, results)";

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ class MulOperandsAndResultElementType
8484
}
8585
};
8686

87+
/// This class indicates that an op is tosa-elementwise (permits broadcasting,
88+
/// unlike Elementwise trait).
89+
template <typename ConcreteType>
90+
class TosaElementwiseOperator
91+
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};
92+
8793
} // namespace tosa
8894
} // namespace OpTrait
8995

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,25 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
126126
];
127127
}
128128

129+
def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
130+
let summary = "Reduce transposes through other operators";
131+
let description = [{
132+
Pass that identifies and reduces tosa.TRANSPOSE operations through chains
133+
of operators.
134+
135+
The pass traverses dependencies of tosa.TRANSPOSE operations until they
136+
terminate in either a tosa.RESHAPE that we can fold the hoisted
137+
tosa.TRANSPOSE into, a tosa.TRANSPOSE that forms the identity with the
138+
hoisted one, or a tosa.CONST with a dense elements attribute. It then
139+
propagates the hoisted transform upward through the intervening operators
140+
if the support is implemented. Finally, it observes that no duplication
141+
will occur of both the chain that was hoisted through and the new chain
142+
that results, and if so, it replaces the hoisted tosa.TRANSPOSE.
143+
144+
The pass has an important use-case in cleaning up the results of frameworks
145+
that introduce a lot of data-layout transformations when legalizing to TOSA,
146+
a common one being transformations between NHWC and NCHW layouts.
147+
}];
148+
}
149+
129150
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
77
TosaLayerwiseConstantFoldPass.cpp
88
TosaMakeBroadcastable.cpp
99
TosaOptionalDecompositions.cpp
10+
TosaReduceTransposes.cpp
1011
TosaTypeConverters.cpp
1112
TosaValidation.cpp
1213

0 commit comments

Comments
 (0)