[MLIR][TOSA] Add --tosa-reduce-transposes pass #108260


Merged: 1 commit into llvm:main on Sep 14, 2024

Conversation

@arteen1000 (Contributor) commented Sep 11, 2024


Motivation:

Some legalization pathways introduce redundant tosa.TRANSPOSE
operations that result in avoidable data movement. For example,
PyTorch -> TOSA contains a lot of unnecessary transposes due
to conversions between NCHW and NHWC.

We wish to remove all the ones that we can, since in general
it is possible to remove the overwhelming majority.


Changes Made:

  • Add the --tosa-reduce-transposes pass
  • Add TosaElementwiseOperator trait.

High-Level Overview:

The pass works through the transpose operators in the program. It begins at some
transpose operator with an associated permutation tensor. It traverses upwards
through the dependencies of this transpose, verifying that we encounter only
operators with the TosaElementwiseOperator trait and that each dependency chain
terminates in a constant, a reshape, or another transpose.

We then check any additional restrictions (the transposes it terminates in must
invert the one we began at, and the reshapes must be ones into which we can fold
the transpose), and then we hoist the transpose through the intervening operators,
folding it at the constants, reshapes, and transposes. Hoisting through an
elementwise operator is sound because a transpose commutes with it: for example,
transpose(add(x, y)) = add(transpose(x), transpose(y)).

Finally, we ensure that we do not need both the transposed form (the form that
had the transpose hoisted through it) and the untransposed form (the form as it
was prior), by analyzing the uses of the dependent operators of a given
transpose we are attempting to hoist and replace.

If both forms would be necessary, we do not replace the hoisted transpose, and
the new chain becomes dead. Otherwise, we do replace it, and the old chain (the
untransposed form) becomes dead. Only one chain is ever live afterwards, so
there is no duplication.

We then perform a simple one-pass DCE, so no canonicalization is necessary.
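
For orientation, here is a condensed C++ sketch of that flow. It is built
around the helper names declared in this patch (collectFanIn,
convertDependentOps, getGoodReplacements), but the glue code, the
single-perms-group framing, and the free-function form are our
simplifications, not the verbatim implementation.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include <set>
#include <utility>
#include <vector>

using namespace mlir;

// Hedged sketch: handle all candidate transposes that share one perms tensor.
// Assumes the patch's collectFanIn / convertDependentOps / getGoodReplacements
// helpers are in scope.
void processPermsGroup(func::FuncOp func, IRRewriter &rewriter,
                       ArrayRef<int32_t> perms) {
  std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
      transposeInfo;
  DenseMap<Value, Value> valuesMap; // original value -> transposed clone

  func.walk([&](tosa::TransposeOp transpose) {
    SetVector<Operation *> fanIn;
    // Step 1: the fan-in cone must contain only elementwise operators and
    // terminate in constants, reshapes, or transposes.
    if (!collectFanIn(transpose.getInput1().getDefiningOp(), fanIn))
      return;
    // Step 2: hoist the transpose through the cone, building the transposed
    // form of every dependency (folding at consts/reshapes/transposes).
    if (!convertDependentOps(fanIn, valuesMap, rewriter, perms))
      return;
    transposeInfo.push_back({transpose, fanIn});
  });

  // Step 3: commit only the replacements for which the untransposed chain can
  // actually die, so exactly one of the two forms stays live.
  std::set<tosa::TransposeOp> good = getGoodReplacements(perms, transposeInfo);
  for (auto &[candidate, fanInCone] : transposeInfo)
    if (good.count(candidate))
      rewriter.replaceOp(candidate, valuesMap[candidate.getInput1()]);
  // Step 4: a simple one-pass DCE then deletes whichever chain went dead.
}

Once registered, the pass is exercised through the flag defined in Passes.td
(initially --tosa-remove-redundant-transposes in this patch, renamed to
--tosa-reduce-transposes before landing, per the title change below).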


Impact of Pass:

After patching the dense_resource artifacts (from PyTorch) with dense attributes
to permit constant folding, we obtain the following results.

Note that data movement represents the total transpose data movement, counting
only the dimensions that actually move during each transpose.

///////////
MobilenetV3:
///////////

BEFORE total data movement: 11798776 B (11.25 MiB)
AFTER total data movement: 2998016 B (2.86 MiB)
74.6% of data movement removed.

BEFORE transposes: 82
AFTER transposes: 20
75.6% of transposes removed.

////////
ResNet18:
////////

BEFORE total data movement: 20596556 B (19.64 MiB)
AFTER total data movement: 1003520 B (0.96 MiB)
95.2% of data movement removed.

BEFORE transposes: 56
AFTER transposes: 5
91.1% of transposes removed.

////////
ResNet50:
////////

BEFORE total data movement: 83236172 B (79.3 MiB)
AFTER total data movement: 3010560 B (2.87 MiB)
96.4% of data movement removed.

BEFORE transposes: 120
AFTER transposes: 7
94.2% of transposes removed.

/////////
ResNet101:
/////////

BEFORE total data movement: 124336460 B (118.58 MiB)
AFTER total data movement: 3010560 B (2.87 MiB)
97.6% of data movement removed.

BEFORE transposes: 239
AFTER transposes: 7
97.1% of transposes removed.

/////////
ResNet152:
/////////

BEFORE total data movement: 175052108 B (166.94 MiB)
AFTER total data movement: 3010560 B (2.87 MiB)
98.3% of data movement removed.

BEFORE transposes: 358
AFTER transposes: 7
98.0% of transposes removed.

////////
Overview:
////////

We see that we remove up to 98% of transposes and eliminate
up to 98.3% of redundant transpose data movement.

In the context of ResNet50, with 120 inferences per second,
we reduce dynamic transpose data bandwidth from 9.29 GiB/s
to 344.4 MiB/s.
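
As a quick arithmetic check on those figures: 79.3 MiB per inference at 120
inferences/s is roughly 9.29 GiB/s before the pass, and 2.87 MiB per inference
at 120 inferences/s is roughly 344.4 MiB/s after.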


Future Work:

(1) Evaluate tradeoffs with permitting ConstOp to be duplicated across hoisted
transposes with different permutation tensors.

(2) Expand the class of foldable upstream ReshapeOp we permit beyond
N -> 1x1x...x1xNx1x...x1x1 (see the sketch after this list).

(3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
those that form the identity.

(4) Add support for more instructions besides TosaElementwiseOperator as
the intervening ones (for example, the reduce_* operators).

(5) Support hoisting transposes up to an input parameter.
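
For item (2) above, here is a minimal sketch of the shape test that the current
N -> 1x1x...x1xNx1x...x1x1 restriction implies; the helper name is ours, not
the patch's:

#include "llvm/ADT/ArrayRef.h"
#include <cstdint>

// Sketch: true when `shape` is static and has at most one non-unit dimension,
// i.e. the 1x...x1xNx1x...x1 form the pass currently folds through.
static bool isUnitPaddedVectorShape(llvm::ArrayRef<int64_t> shape) {
  int64_t nonUnitDims = 0;
  for (int64_t dim : shape) {
    if (dim < 0)
      return false; // dynamic dimension: bail out conservatively
    if (dim != 1 && ++nonUnitDims > 1)
      return false; // more than one non-unit dimension
  }
  return true;
}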

@llvmbot (Member) commented Sep 11, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Arteen Abrishami (arteen1000)

Changes

Motivation:

Some legalization pathways introduce redundant tosa.TRANSPOSE operations that result in avoidable data movement. For example, PyTorch -> TOSA contains a lot of unnecessary transposes due to conversions between NCHW and NHWC.

We wish to remove all the ones that we can, since in general it is possible to remove upwards of 90% of these transposes in a provable manner.


Changes Made:

  • Add the --tosa-remove-redundant-transposes pass
  • Add TosaElementwiseOperator trait.

High-Level Overview:

The pass begins at a downstream transpose with some perms tensor. It traverses the dependencies upward, accepting only TosaElementwise operators. Dependencies must terminate in nullifying transposes (when composed, they form the identity), reshapes, or consts.

Conceptually, we then "bubble up" the downstream transpose until we hit the sources. For constants, we generate a new constant, composed with the downstream transpose. For nullifying transposes, we "cancel" them. For reshapes, we generally cannot "bubble" through them, so we insert the downstream transpose there.

We then ensure that we do not cause any duplication when "converting" the chain we bubbled up into its transposed form. We do this by analyzing the dependency fan-ins across all transposes with the same perms tensor, in order to ensure that they do not have uses outside this group; such uses would cause the old code section to remain "live" and not be removed by canonicalization.
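
To make that escape analysis concrete, here is a small sketch of the check it
implies; the helper name is ours, and the patch's dependenciesAreValid /
getGoodReplacements pair is more involved than this:

#include "mlir/IR/Operation.h"
#include "llvm/ADT/SetVector.h"

// Sketch: a dependency in a hoisted chain is safe only if every user of it
// stays within the group -- either another collected dependency or one of the
// group's downstream transposes. Any escaping use keeps the old (untransposed)
// chain live, which is exactly the duplication the pass must avoid.
static bool
usesStayInsideGroup(mlir::Operation *op,
                    const llvm::SetVector<mlir::Operation *> &groupDeps,
                    const llvm::SetVector<mlir::Operation *> &groupTransposes) {
  for (mlir::Operation *user : op->getUsers())
    if (!groupDeps.contains(user) && !groupTransposes.contains(user))
      return false; // a use escapes the group
  return true;
}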


Impact of Pass:

For the ResNet18 network, we are able to reduce it to 5 transposes, from 56 -- with the patching of the torch dense_resource artifacts with dense attributes. Otherwise, without that patch, we reduce to 23, since we cannot fold those artifacts.

In the second case (56 -> 23), the instruction count is reduced by exactly 33. Three more transposes would be removed if we omitted the fan-in analysis; however, with the fan-in analysis, we end up with ~15 fewer operations, due to the lack of duplication.

For ResNet50, the results are essentially identical.

For MobilenetV3, we reduce the number of transposes from 82 to 38 without taking care of upstream constants. After also taking care of constants, we reduce it to 20 transposes. The remaining transposes have a use elsewhere, outside of the fan-in cones. The pass alone (after --canonicalize is run on the initial network) is responsible for the removal of 48 of the transposes.

Due to cases where a constant is used purely in its NCHW form, without a transpose to NHWC, and also separately used in a place where the downstream converts to NHWC, we do end up with 7 additional constants; however, due to their small size, this has a minimal memory footprint.


Future Work:

(1)

Evaluate tradeoffs with the duplication of ConstOp, especially across many downstream transposes with different perms, which can result in the same ConstOp being duplicated (but transposed) multiple times.

Observe tradeoffs between a lower memory footprint and potentially converting many fan-ins of downstream transposes with the same perms, which, if not converted, may affect the ability of other inter-dependent fan-ins to convert.

(2)

Restrict the propagation of transposes up their fan-in cone if one of the sources is a ReshapeOp for which the inserted TransposeOp would not lend itself to the TransposeIsReshape canonicalization, which permits the pair to be folded to a single ReshapeOp.

Observe how this restriction may be detrimental to other downstream transpose conversions due to the fan-in cone analysis. Additionally, consider cases where multiple upstream transposes could be removed as a result, and trade that off against how many transposes you would effectively insert if the ReshapeOp/TransposeOp pair can't be folded to a single ReshapeOp.

(3)

Make the pass more general, beyond just allowing upstream transposes to be nullifying. For example,

transpose1 -> ... -> transpose2

where transpose2(transpose1) does not cancel to the identity.

This can be done by propagating the downstream transpose up and inserting it after transpose1, just as is done for reshape. However, in the case of chains like

transpose1 -> ... -> transpose2 -> ... -> transpose3

this could require running the current runOnOperation() function until we converge. We can detect convergence by stopping once every transpose whose fan-in we can successfully collect has the owner of its first operand be either another TransposeOp or a ReshapeOp, since those are what we propagate to and where we leave behind / insert another TransposeOp. Otherwise, we could potentially loop infinitely.

This additionally has the implication that we would not replace any transposes and instead we could have canonicalization handle that.

(4)

Add support for more instructions (for example, those that reduce along an axis) to be one of the intervening operations in the fan-in cones (other than those with the TosaElementwiseOperator trait).

(5)

Support bubbling transposes up to the input parameter. This may not need extensive fan-in analysis, as there is no operation cost associated if the input is used elsewhere.


Patch is 87.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108260.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+10)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+6)
  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h (+2)
  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+16)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp (+761)
  • (added) mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-dynamic-shapes.mlir (+55)
  • (added) mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-pipeline.mlir (+48)
  • (added) mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-static-shapes.mlir (+552)
  • (added) mlir/test/Dialect/Tosa/tosa-remove-redundant-transposes/tosa-remove-redundant-transposes-unimplemented.mlir (+120)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..64bacd0e432fe5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -206,6 +206,15 @@ def Tosa_ExplicitValuePadOpQuantInfoBuilder : OpBuilder<
                                          input, paddings, pad_value);
   }]>;
 
+//===----------------------------------------------------------------------===//
+// TOSA Operator Trait.
+//===----------------------------------------------------------------------===//
+
+// Permits broadcasting. Elementwise trait is too strict.
+def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Class.
 //===----------------------------------------------------------------------===//
@@ -219,6 +228,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
               DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                         ["inferReturnTypeComponents"]>,
               ResultsBroadcastableShape,
+              TosaElementwiseOperator,
               Pure])> {
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 7ed89bff474a2e..8122752a9f3e1a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -84,6 +84,12 @@ class MulOperandsAndResultElementType
   }
 };
 
+/// This class indicates that an op is tosa-elementwise (permits broadcasting,
+/// unlike Elementwise trait)
+template <typename ConcreteType>
+class TosaElementwiseOperator
+    : public TraitBase<ConcreteType, TosaElementwiseOperator> {};
+
 } // namespace tosa
 } // namespace OpTrait
 
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 1f9522b51a4cf5..c0913171f9b17f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
 #include "mlir/Pass/Pass.h"
 
@@ -48,6 +49,7 @@ std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
 std::unique_ptr<Pass> createTosaOptionalDecompositions();
+std::unique_ptr<Pass> createTosaRemoveRedundantTransposes();
 
 struct ValidationOptions {
   /// Validate if operations match for the given profile.
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a0f670de20150f..66d046c7040b4f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -126,4 +126,20 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
    ];
 }
 
+def TosaRemoveRedundantTransposes : Pass<"tosa-remove-redundant-transposes", "func::FuncOp"> {
+  let summary = "Remove redundant transposes";
+  let description = [{
+    Pass that identifies and removes redundant tosa.TRANSPOSE operations.
+    It does so by traversing dependencies of tosa.TRANSPOSE operations until they terminate in either
+    tosa.RESHAPE, a nullifying tosa.TRANSPOSE, or a tosa.CONST. It then propagates the downstream
+    transpose upward through the intervening operators, if it is able to, and replaces the downstream tosa.TRANSPOSE.
+    Results are generally better when run after canonicalization and the resolution of dynamic shapes.
+    Canonicalization is required for dead code elimination after the pass is run.
+    This pass has an important use-case in cleaning up the results of frameworks that introduce a lot
+    of data-layout transformations when legalizing to TOSA, a common one being transformations between NHWC and NCHW
+    layouts.
+  }];
+  let constructor = "tosa::createTosaRemoveRedundantTransposes()";
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index c78a74b874aff1..624038b9b38981 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaLayerwiseConstantFoldPass.cpp
   TosaMakeBroadcastable.cpp
   TosaOptionalDecompositions.cpp
+  TosaRemoveRedundantTransposes.cpp
   TosaTypeConverters.cpp
   TosaValidation.cpp
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
new file mode 100644
index 00000000000000..06d7754e79a023
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaRemoveRedundantTransposes.cpp
@@ -0,0 +1,761 @@
+//===- TosaRemoveRedundantTransposes.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ----------
+// Motivation:
+// ----------
+
+// Some legalization pathways introduce redundant tosa.TRANSPOSE
+// operations that result in avoidable data movement. For example,
+// PyTorch -> TOSA contains a lot of unnecessary transposes due
+// to conversions between NCHW and NHWC.
+
+// We wish to remove all the ones that we can, since in general
+// it is possible to remove upwards of 90% of these transposes
+// in a provable manner.
+//
+// -------------------
+// High-Level Overview:
+// -------------------
+
+// The pass begins at a downstream transpose with some perms tensor.
+// It traverses the dependencies upward, accepting only TosaElementwise
+// operators. Dependencies must terminate in nullifying transposes (when
+// composed, they form the identity), reshapes, or consts.
+
+// Conceptually, we then "bubble up" the downstream transpose until
+// we hit the sources. For constants, we generate a new constant, composed
+// with the downstream transpose. For nullifying transposes, we "cancel"
+// them. For reshapes, we generally cannot "bubble" through them, so we
+// insert the downstream transpose there.
+
+// We then ensure that we do not cause any duplication by replacing usages
+// of the downstream transpose with the converted value of the operand
+// that feeds into it (after this bubble-up process). We do this by analyzing
+// the dependency fan-ins across all transposes with the same perms tensor
+// in order to ensure that they do not have uses outside this group, which
+// would cause the old code section to remain "live", and not removed by
+// canonicalization.
+
+// --------------
+// Impact of Pass:
+// --------------
+
+// For the ResNet18 network, we are able to reduce it to 5 transposes, from
+// 56 -- with the patching of the torch dense_resource artifacts with dense
+// attributes. Otherwise, without that patch, we reduce to 23, since we cannot
+// fold those artifacts.
+
+// In the second case (56 -> 23), instruction count is reduced by exactly 33.
+// There are 3 transposes that would be removed if we omitted the fan-in
+// analysis; however, with fan-in analysis, we end up with ~15 fewer operations,
+// due to the lack of duplication.
+
+// For ResNet50, the results are essentially identical.
+
+// For MobilenetV3, we reduce the number of transposes from 82 to 38 without
+// taking care of upstream constants. After also taking care of constants, we
+// reduce it to 20 transposes. The remaining have a use elsewhere outside
+// of the fan-in cones. The pass alone (after --canonicalize is run on the
+// initial network), is responsible for the removal of 48 of the transposes.
+
+// Due to cases where a constant is used purely in its NCHW form without a
+// transpose to NHWC and also separately used in a place where the downstream
+// converts to NHWC, we do end up with 7 additional constants; however, due to
+// their small size, this has minimal memory footprint.
+
+// -----------
+// Future Work:
+// -----------
+
+// (1)
+
+// Evaluate tradeoffs with the duplication of ConstOp, especially
+// across many downstream transposes with different perms, which can result
+// in the same ConstOp being duplicated (but transposed) multiple times.
+
+// Observe tradeoffs between a lower memory footprint and potentially
+// converting many fan-ins of downstream transposes with the same perms,
+// which, if not converted, may affect the ability of other inter-dependent
+// fan-ins to convert.
+
+// (2)
+
+// Restrict the propagation of transposes up their fan-in cone if one
+// of the sources is a ReshapeOp for which the inserted TransposeOp would
+// not be a TransposeOp that lends itself to the TransposeIsReshape
+// Canonicalization, which permits them to be folded to a single ReshapeOp.
+
+// Observe impact on how this restriction may be detrimental to the
+// conversion of other downstream transpose conversions due to the
+// fan-in cone analysis. Additionally, consider cases where there
+// may be multiple upstream transposes that could be removed as a
+// result of this -- and trade that off with how many you would
+// effectively insert if the ReshapeOp/TransposeOp can't be folded
+// to a single ReshapeOp.
+
+// (3)
+
+// Make the pass more general, beyond just allowing upstream transposes
+// to be nullifying. For example,
+
+// transpose1 -> ... -> transpose2
+
+// where transpose2(transpose1) does not cancel to the identity.
+
+// This can be done by propagating the downstream transpose up
+// and inserting after transpose1, just like how it is done for
+// reshape. However, in the case of chains like
+
+// transpose1 -> ... -> transpose2 -> ... -> transpose3
+
+// this could require running the current runOnOperation() function
+// until we converge. This can be done by stopping when all transposes
+// that we can successfully collect the fan-ins of have the owner
+// of their first operand being either another TransposeOp or a
+// ReshapeOp, since those are what we propagate to and where we leave
+// behind / insert another TransposeOp. Otherwise, we could potentially
+// loop infinitely.
+
+// This additionally has the implication that we would not replace any
+// transposes and instead we could have canonicalization handle that.
+
+// (4)
+
+// Add support for more instructions (for example, those that reduce
+// along an axis) to be one of the intervening operations in the
+// fan-in cones (other than those with TosaElementwiseOperator trait).
+
+// (5)
+
+// Support bubbling transposes up to the input parameter. This may not
+// need extensive fan-in analysis, as there is no operation cost
+// associated if the input is used elsewhere.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <memory>
+#include <set>
+#include <stack>
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAREMOVEREDUNDANTTRANSPOSES
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TOSA Remove Redundant Transposes Pass.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct TosaRemoveRedundantTransposes final
+    : public tosa::impl::TosaRemoveRedundantTransposesBase<
+          TosaRemoveRedundantTransposes> {
+  void runOnOperation() override;
+
+private:
+  // This will collect all the data dependencies for the given Operation
+  // up to and including ConstOp, ReshapeOp, and TransposeOp.
+  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
+  bool convertDependentOps(SetVector<Operation *> &dependentOps,
+                           DenseMap<Value, Value> &valuesMap,
+                           IRRewriter &rewriter,
+                           ArrayRef<int32_t> downstreamPerms);
+
+  // Checks if the two permutations, when applied consecutively, result
+  // in the identity.
+  bool areNullifyingTransposes(ArrayRef<int32_t> perms1,
+                               ArrayRef<int32_t> perms2);
+
+  // This is meant to apply to operations with the TosaElementwiseOperator
+  // trait.
+  std::optional<Value>
+  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
+                     IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+  // This updates valuesMap when we encounter another TransposeOp as a
+  // dependency of the downstream one. %0 = tosa.transpose %arg0 <- applies to
+  // this %1 = tosa.transpose %0 <- when tracking back from this
+  std::optional<Value>
+  buildMappedToValue(tosa::TransposeOp transposeOp,
+                     const DenseMap<Value, Value> &valuesMap,
+                     IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+  // Inserts the downstream TransposeOp after the ReshapeOp, since we generally
+  // cannot propagate through it.
+  std::optional<Value>
+  buildMappedToValue(tosa::ReshapeOp reshapeOp,
+                     const DenseMap<Value, Value> &valuesMap,
+                     IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+  // We may have something like:
+  // %0 = tosa.const
+  // %1 = tosa.transpose
+  // %2 = tosa.add %0, %1
+  // %3 = tosa.transpose %2
+  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
+  // in MobilenetV3.
+  std::optional<Value>
+  buildMappedToValue(tosa::ConstOp constOp,
+                     const DenseMap<Value, Value> &valuesMap,
+                     IRRewriter &rewriter, ArrayRef<int32_t> downstreamPerms);
+
+  // Checks which TransposeOp we should "replace", turning their converted
+  // chains of ops, through which they were propagated, "live", and the old code
+  // "dead." Attempts to avoid doing so when doing so would result in the old
+  // code staying "live," resulting in duplication. Relies on --canonicalize to
+  // remove the dead code that results from performing said replacement.
+  std::set<tosa::TransposeOp> getGoodReplacements(
+      ArrayRef<int32_t> perms,
+      std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+          &transposeInfo);
+
+  // Helper function for getGoodReplacements to check if some TransposeOp's
+  // dependencies are OK.
+  bool dependenciesAreValid(
+      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
+      std::set<tosa::TransposeOp> &validTransposes,
+      std::vector<std::pair<tosa::TransposeOp, SetVector<Operation *>>>
+          &transposeInfo);
+
+  // Applies perms to the DenseElementsAttr.
+  // If it returns std::nullopt, it also triggers pass failure, since verifier
+  // guarantees from TOSA are not in place (and otherwise, if used elsewhere
+  // it should fail).
+  // This is a basic API and may benefit from refactor into the core MLIR APIs.
+  std::optional<DenseElementsAttr>
+  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
+};
+
+std::optional<DenseElementsAttr>
+TosaRemoveRedundantTransposes::transposeDenseAttribute(
+    DenseElementsAttr input, ArrayRef<int32_t> perms) {
+  RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
+  RankedTensorType newType = RankedTensorType::get(
+      tosa::applyTOSAPermutation(oldType.getShape(), perms),
+      oldType.getElementType());
+  size_t rank = oldType.getRank();
+
+  if (input.isSplat())
+    return input.reshape(newType);
+  // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
+  // 0.
+  // If not in place, something is very wrong.
+  if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
+    signalPassFailure();
+    return std::nullopt;
+  }
+
+  // The algorithm is approximately as follows:
+  //   input: perms, input flat array, input tensor type
+  //   (1/2) determine the strides of the input/output as if they were
+  //         strided in row-major order.
+  //   (3) adjust the strides for the input to be in the same order of
+  //       indices as the output is written.
+  //   (4) process dimension by dimension.
+  //   Example: perms 2, 0, 1; input 2x3x4; output 4x2x3.
+  //   For i ... 4, j ... 2, k ... 3:
+  //     output[i][j][k] = input[j][k][i]
+  //     output[6i + 3j + k] = input[12j + 4k + i]
+  //   We adjust the input strides so we read input[i + 12j + 4k], letting us
+  //   process the output layer-by-layer.
+
+  // Step 1/2: Strides for input. We ignore output since row-major and can just
+  // push_back.
+
+  SmallVector<int64_t> originalInputStrides(rank);
+  originalInputStrides[rank - 1] = 1;
+  // index with int64_t to avoid overflow
+  for (int64_t i = rank - 2; i >= 0; i--)
+    originalInputStrides[i] =
+        originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
+
+  // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
+  // output which is done in row-major order.
+
+  SmallVector<int64_t> newInputStrides;
+  newInputStrides.reserve(rank);
+  for (int32_t v : perms)
+    newInputStrides.push_back(originalInputStrides[v]);
+
+  // Step 4: Write out the transposed "flat array" dimension by dimension.
+
+  auto inputArray = input.getValues<Attribute>();
+  SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
+  for (size_t i = 0; i < rank; i++)
+    boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
+
+  SmallVector<Attribute> resultArray;
+  resultArray.reserve(inputArray.size());
+
+  std::function<void(int64_t,
+                     SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
+      processTransposeDim = [&](auto accumulatedIndex, auto it) {
+        if (it == boundsAndStrides.end()) {
+          resultArray.push_back(inputArray[accumulatedIndex]);
+          return;
+        }
+
+        for (int64_t i = 0; i < it->first; i++) {
+          int64_t j = accumulatedIndex + i * it->second;
+          processTransposeDim(j, it + 1);
+        }
+      };
+
+  processTransposeDim(0, boundsAndStrides.begin());
+
+  return DenseElementsAttr::get(newType, resultArray);
+}
+
+// If the function returns true, the SetVector should contain only ConstOp,
+// ReshapeOp, and TransposeOp as the sources of the data dependencies, and
+// operations with the TosaElementwiseOperator trait after that.
+bool TosaRemoveRedundantTransposes::collectFanIn(
+    Operation *op, SetVector<Operation *> &collected) {
+  // Can occur if defined through the parameter to a func.func.
+  if (!op)
+    return false;
+
+  if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
+    return false;
+
+  // Prevent extra work if already seen.
+  if (collected.contains(op))
+    return true;
+
+  // Throw it out so that later code doesn't have to deal with this.
+  if (op->getNumResults() != 1 ||
+      !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
+    return false;
+
+  // We don't wish to traverse up a ReshapeOp,
+  // since generally we can't propagate a TransposeOp through it.
+  // TransposeOp, ReshapeOp, ConstOp will have no in-edges in the data
+  // dependency graph we construct for the downstream TransposeOp.
+  if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
+      !llvm::isa<tosa::ConstOp>(op)) {
+
+    if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
+      return false;
+
+    for (Value operand : op->getOperands()) {
+
+      if (!collectFanIn(operand.getDefiningOp(), collected))
+        return false;
+    }
+  }
+
+  // Insert in topological order.
+  collected.insert(op);
+
+  return true;
+}
+
+// Assume that, due to the verification of TransposeOp, perms arrays are
+// permutations of 0 ... perms.size() - 1.
+bool TosaRemoveRedundantTransposes::areNullifyingTransposes(
+    ...
[truncated]
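
The patch is truncated at this point. For intuition, here is a minimal sketch
of what a nullifying check over two permutation arrays can look like; this is
our illustration, not the truncated upstream body:

#include "llvm/ADT/ArrayRef.h"
#include <cstdint>

// Sketch: two transposes nullify each other when applying one permutation
// after the other restores the original order, i.e. perms2 is the inverse
// of perms1 (perms2[perms1[i]] == i for all i).
static bool nullifyingTransposesSketch(llvm::ArrayRef<int32_t> perms1,
                                       llvm::ArrayRef<int32_t> perms2) {
  if (perms1.size() != perms2.size())
    return false;
  for (size_t i = 0; i < perms1.size(); ++i)
    if (perms2[perms1[i]] != static_cast<int32_t>(i))
      return false;
  return true;
}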

@rsuderman (Contributor) left a comment:

I would recommend merging the test files together. I don't see a lot of value in separating them out from each other.

@arteen1000 force-pushed the tosa-remove-redundant-transposes branch from f44531b to 425f6e5 on September 12, 2024 03:09
@arteen1000 (Contributor, Author):

Made changes per @rsuderman's request.

@arteen1000 force-pushed the tosa-remove-redundant-transposes branch from 425f6e5 to 2d6e521 on September 12, 2024 03:27
@arteen1000 force-pushed the tosa-remove-redundant-transposes branch from 0846498 to aebe42e on September 12, 2024 19:08
@arteen1000 (Contributor, Author):

Modified per @joker-eph's request.

@sjarus (Contributor) commented Sep 12, 2024

@arteen1000 is away from desk and had collected this data. 'Data movement' below refers to transposed data, counting only the axes that are transposed, across all the tosa.transpose() instances in the IR.

These tests were done as part of functional and performance experiments through the TorchToTosa path on Torch-MLIR to validate that the final results are functionally unaffected.

(Per-network numbers omitted here; they match the MobilenetV3 and ResNet tables in the PR description above.)

@arteen1000 force-pushed the tosa-remove-redundant-transposes branch from aebe42e to ee9b2b0 on September 12, 2024 20:23
@sjarus requested a review from joker-eph on September 12, 2024 23:32
@jpienaar (Member):

Funnily enough, I've seen two other teams do almost the same thing recently. @cathyzhyi, you have some context on this pass, in case you could take a look too.

@joker-eph (Collaborator) left a comment:

Just some drive-by comments; I only skimmed the code. Kudos for the great comments all along!

@arteen1000 force-pushed the tosa-remove-redundant-transposes branch from ee9b2b0 to 8fde469 on September 13, 2024 14:31
@eric-k256 (Contributor) left a comment:

Looks good to me. Thanks!

@arteen1000 force-pushed the tosa-remove-redundant-transposes branch from 8fde469 to 509076e on September 13, 2024 20:49
@jpienaar (Member) left a comment:

This looks good as a start. Remember to update the PR description before landing (and do a regular run through clang-tidy and clang-format to catch small changes).

[The commit message of the squashed commit repeats the PR description above verbatim.]

Signed-off-by: Arteen Abrishami <[email protected]>
@arteen1000 force-pushed the tosa-remove-redundant-transposes branch from 509076e to a4ec527 on September 13, 2024 21:08
@arteen1000 changed the title from "[MLIR][TOSA] Add --tosa-remove-redundant-transposes pass" to "[MLIR][TOSA] Add --tosa-reduce-transposes pass" on Sep 13, 2024
@sjarus (Contributor) commented Sep 14, 2024

With the concurrence of Rob, Mehdi, and Jacques, I'm going to proceed and land this. Thanks for this work, @arteen1000!

@sjarus merged commit 00f239e into llvm:main on Sep 14, 2024
8 checks passed
@arteen1000 deleted the tosa-remove-redundant-transposes branch on September 14, 2024 06:30