
Fixes in 'tosa.reshape' lowering and folder #85798

Merged
rafaelubalmw merged 6 commits into llvm:main on Mar 26, 2024

Conversation

rafaelubalmw
Contributor

This pull request addresses missing features in the lowering conversion pattern for the tosa.reshape op and a bug in its canonicalizer.

  • Example of a valid use of tosa.reshape previously not supported:
func.func @main(%input: tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32> {
  %0 = tosa.reshape %input {new_shape = array<i64: 1, 3, 2, 1>} : (tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32>
  return %0 : tensor<1x3x2x1xf32>
}
  • The new lowering is based on the use of tensor.reshape instead of a combination of tensor.collapse_shape + tensor.expand_shape.

  • When no -1 placeholder is present in the new_shape attribute, the target shape is encoded with an arith.constant op and the reshape occurs with a tensor.reshape op.

  • When a -1 placeholder is used in new_shape and the corresponding dimension in the result type is dynamic, the missing dimension size is inferred by computing the total input tensor size (tensor.collapse_shape + tensor.dim) and dividing it by the product of all other target dimension sizes (arith.divui). A sketch of this case is shown after this list.

  • When a -1 placeholder is used in new_shape and the corresponding dimension in the result type is static, the missing dimension size is grabbed from the result type.

  • Fixed a bug in the canonicalization pattern tosa::ReshapeOp::fold(). The input and result types being equal is not a sufficient condition for folding: when those types contain 2 or more dynamic dimensions, a productive reshape may still occur, since the dynamic extents can be redistributed at runtime. Unit tests now cover 1 (existing test) and 2 (new test) dynamic dimensions.
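
For illustration, here is a rough sketch of the IR this lowering strategy produces for the dynamic-placeholder case above (function name and shapes are assumed for the example; the exact expected IR appears in the updated tests in the diff below):

func.func @example(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
  // Total input size: collapse to 1D and read the dynamic extent.
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %collapsed = tensor.collapse_shape %arg0 [[0]] : tensor<?xf32> into tensor<?xf32>
  %size = tensor.dim %collapsed, %c0 : tensor<?xf32>
  // Placeholder size = total input size / product of the other target sizes.
  %dyn = arith.divui %size, %c2 : index
  // Assemble the target shape and perform the reshape.
  %shape = tensor.from_elements %c2, %dyn : tensor<2xindex>
  %0 = tensor.reshape %arg0(%shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}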

@llvmbot
Member

llvmbot commented Mar 19, 2024

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Rafael Ubal (rafaelubalmw)

Patch is 25.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85798.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+78-195)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+4-1)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+144-23)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+8)
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 505d85f211111c..62ed41ebda4f50 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -19,217 +19,98 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#include <numeric>
+
 using namespace mlir;
 using namespace tosa;
 
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
-
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
-
-  unsigned currLhsDim = 0, currRhsDim = 0;
-  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
-    int64_t rhsSize = rhsShape[currRhsDim];
-    int64_t lhsSize = lhsShape[currLhsDim];
-    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
-           currRhsDim < rhsShape.size()) {
-      if (lhsSize < rhsSize) {
-        currLhsDim++;
-        if (currLhsDim < lhsShape.size()) {
-          lhsSize *= lhsShape[currLhsDim];
-        }
-      } else {
-        currRhsDim++;
-        if (currRhsDim < rhsShape.size()) {
-          rhsSize *= rhsShape[currRhsDim];
-        }
-      }
-    }
-    if (lhsSize == rhsSize) {
-      intermediateShape.push_back(lhsSize);
-    }
-    currRhsDim++;
-    currLhsDim++;
-  }
-
-  // If the iterators didn't reach the end and their leftover dimensions are not
-  // equal to 1 an intermediate shape was not found.
-  while (currLhsDim < lhsShape.size()) {
-    if (lhsShape[currLhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  while (currRhsDim < rhsShape.size()) {
-    if (rhsShape[currRhsDim++] != 1) {
-      return false;
-    }
-  }
-
-  return true;
+static Value getIndexConstant(OpBuilder& builder, Location loc, int64_t index) {
+  return builder.create<arith::ConstantIndexOp>(loc, index);
 }
 
-static bool createReassociationMapsForCollapse(
-    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
-    ArrayRef<int64_t> dstShape,
-    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
-
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
-    SmallVector<AffineExpr, 2> exprs;
-    for (int i = 0, s = srcShape.size(); i < s; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    reassociationMap = {exprs};
-    return true;
-  }
-
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
-  }
-
-  reassociationMap.resize(dstShape.size());
-  unsigned currSrcDim = 0, currDstDim = 0;
-  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
-    int64_t dstSize = dstShape[currDstDim];
-    int64_t srcSize = srcShape[currSrcDim];
-    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      srcSize *= srcShape[currSrcDim];
-    }
-    if (srcSize == dstSize) {
-      reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
-      // If the next dim in collapsedShape is not 1, treat subsequent dims in
-      // expandedShape which are 1 to be collapsed.
-      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
-        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
-          reassociationMap[currDstDim].push_back(
-              rewriter.getAffineDimExpr(currSrcDim++));
-        }
-      }
-    }
-    currDstDim++;
+// Return the total size of the given input tensor.
+static Value getTensorSize(OpBuilder& builder, Location loc, TypedValue<TensorType> input) {
+  // If the input tensor is statically shaped, return its size as a constant.
+  if (input.getType().hasStaticShape()) {
+    auto shape = input.getType().getShape();
+    auto size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies());
+    return getIndexConstant(builder, loc, size);
   }
 
-  // If both iterators didn't reach the end, we have leftover dimentions which
-  // implies that we have a mismatch in shape.
-  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+  // When the input tensor has at least one dynamic dimension, collapse it into
+  // a 1D tensor and get its size.
+  auto rank = input.getType().getRank();
+  auto elementType = input.getType().getElementType();
+  auto collapsedType = RankedTensorType::get({ShapedType::kDynamic}, elementType);
+  auto reassociationIndices = SmallVector<ReassociationIndices>{
+    llvm::to_vector(llvm::seq<int64_t>(rank))
+  };
+  auto collapsed = builder.create<tensor::CollapseShapeOp>(
+      loc, collapsedType, input, reassociationIndices);
+  return builder.create<tensor::DimOp>(loc, collapsed, 0);
 }
 
-namespace {
-Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
-                     ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-
-  bool isDynamic = !operandTy.hasStaticShape();
-
-  if (isDynamic && resultTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot collapse dynamic dims to more than one dimension");
-    return {};
-  }
-
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                          resultTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
-    return {};
-  }
-
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot collapse into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
-                                                  reassociationMap);
+// Compute the dimension size of the result tensor corresponding to the
+// placeholder value set to -1 in the 'new_shape' attribute of a 'tosa.reshape'
+// op. Argument 'index' indicates the position of the -1 placeholder.
+static Value getReshapePlaceholderDimSize(OpBuilder &builder,
+                                          tosa::ReshapeOp reshape,
+                                          int64_t index) {
+  auto loc = reshape.getLoc();
+  auto input = reshape.getInput1();
+  auto newShape = reshape.getNewShape();
+  auto resultType = reshape.getResult().getType();
+
+  // If the corresponding dimension in the result type is static, take the
+  // dimension size from there.
+  assert(newShape[index] == -1);
+  if (!resultType.isDynamicDim(index))
+    return getIndexConstant(builder, loc, resultType.getDimSize(index));
+
+  // Calculate the product of all dimensions in the new shape. We expect to have
+  // exactly one size set to -1, so we can discard this component by just
+  // negating the final product.
+  auto newSizeLiteral = -std::accumulate(newShape.begin(), newShape.end(), 1,
+                                         std::multiplies<int64_t>());
+  assert(newSizeLiteral >= 0);
+  auto newSize = builder.create<arith::ConstantIndexOp>(loc, newSizeLiteral);
+
+  // Avoid a division by zero. If any of the given dimension sizes was set to
+  // zero, set the placeholder size to zero, too.
+  if (newSizeLiteral == 0)
+    return newSize;
+
+  // The size of the placeholder dimension is the size of the input tensor
+  // divided by all non-placeholder dimension sizes.
+  auto inputSize = getTensorSize(builder, loc, input);
+  return builder.createOrFold<arith::DivUIOp>(loc, inputSize, newSize);
 }
 
-Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
-                   ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-
-  bool isDynamic = !operandTy.hasStaticShape();
-
-  if (isDynamic && operandTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot expand dynamic dims from more than one dimension");
-    return {};
-  }
-
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                          operandTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to expand into an incompatible shape");
-    return {};
-  }
-
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic) ||
-      intermediateShape != operandTy.getShape()) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot expand into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
-                                                reassociationMap);
-}
+namespace {
 
-class ReshapeConverterCollapseExpand
-    : public OpConversionPattern<tosa::ReshapeOp> {
+class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot identify an intermediate shape between "
-                   "the given two shapes");
+    auto loc = reshape.getLoc();
+    auto input = reshape.getInput1();
+
+    // Create list of values for new shape
+    SmallVector<Value> newShapeVector(reshape.getNewShape().size());
+    for (auto [index, size] : llvm::enumerate(reshape.getNewShape())) {
+      newShapeVector[index] = size == -1 ?
+          getReshapePlaceholderDimSize(rewriter, reshape, index) :
+          getIndexConstant(rewriter, loc, size);
     }
-    auto intermediateTy = RankedTensorType::get(
-        intermediateShape, reshape.getType().getElementType());
 
-    Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
-                                    adaptor.getInput1());
-    if (!collapse)
-      return failure();
-
-    Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
-    if (!expand)
-      return failure();
-
-    rewriter.replaceOp(reshape, expand);
+    // Reshape tensor
+    auto newShapeTensor = rewriter.createOrFold<tensor::FromElementsOp>(
+        loc, newShapeVector);
+    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(
+        reshape, reshape.getResult().getType(), input, newShapeTensor);
     return success();
   }
 };
@@ -416,8 +297,10 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<SliceConverter, PadConverter, ConcatConverter>(
-      patterns->getContext());
-
-  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
+  patterns->add<
+    ConcatConverter,
+    PadConverter,
+    ReshapeConverter,
+    SliceConverter
+  >(patterns->getContext());
 }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 4c50aaecfe9488..d23c9fe824c94a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -795,7 +795,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (!inputTy || !outputTy)
     return {};
 
-  if (inputTy == outputTy)
+  // Fold when the input and output types are the same. This is only safe when
+  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
+  // there may still be a productive reshape.
+  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
     return getInput1();
 
   // reshape(reshape(x)) -> reshape(x)
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index daaa68a7260b71..e1fd7838293b6a 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -1,11 +1,15 @@
 // RUN: mlir-opt --split-input-file --tosa-to-tensor %s -o -| FileCheck %s
 
+// -----
+
 // CHECK-LABEL: @test_reshape_downrank
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK: %[[SHAPE:.+]] = arith.constant dense<6> : tensor<1xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<2x3xf32>, tensor<1xindex>) -> tensor<6xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6>} : (tensor<2x3xf32>) -> tensor<6xf32>
-  // CHECK: return [[RESHAPE]]
+
+  // CHECK: return %[[RESHAPE]] : tensor<6xf32>
   return %0 : tensor<6xf32>
 }
 
@@ -14,9 +18,16 @@ func.func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
 // CHECK-LABEL: @test_reshape_downrank_dyn
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+
+  // CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{[\[]}}[0, 1]] : tensor<2x?xf32> into tensor<?xf32>
+  // CHECK-DAG: %[[SIZE:.+]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+
+  // CHECK-DAG: %[[SHAPE:.+]] = tensor.from_elements %[[SIZE]] : tensor<1xindex>
+  // CHECK-DAG: %[[RESHAPED:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<2x?xf32>, tensor<1xindex>) -> tensor<?xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<2x?xf32>) -> tensor<?xf32>
-  // CHECK: return [[RESHAPE]]
+
+  // CHECK: return %[[RESHAPED]] : tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
@@ -25,9 +36,10 @@ func.func @test_reshape_downrank_dyn(%arg0: tensor<2x?xf32>) -> tensor<?xf32> {
 // CHECK-LABEL: @test_reshape_uprank
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK: %[[SHAPE:.+]] = arith.constant dense<[2, 3]> : tensor<2xindex>
+  // CHECK: %[[RESHAPE:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<6xf32>, tensor<2xindex>) -> tensor<2x3xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
-  // CHECK: return [[RESHAPE]]
+  // CHECK: return %[[RESHAPE]] : tensor<2x3xf32>
   return %0 : tensor<2x3xf32>
 }
 
@@ -36,57 +48,166 @@ func.func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
 // CHECK-LABEL: @test_reshape_uprank_dyn
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
 func.func @test_reshape_uprank_dyn(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
-  // CHECK: [[RESHAPE:%.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]]
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C2_0:.+]] = arith.constant 2 : index
+
+  // CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{[\[]}}[0]] : tensor<?xf32> into tensor<?xf32>
+  // CHECK-DAG: %[[SIZE:.+]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+  // CHECK-DAG: %[[PLACEHOLDER:.+]] = arith.divui %[[SIZE]], %[[C2_0]] : index
+
+  // CHECK: %[[SHAPE:.+]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]] : tensor<2xindex>
+  // CHECK: %[[RESHAPED:.+]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?xf32>, tensor<2xindex>) -> tensor<2x?xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
-  // CHECK: return [[RESHAPE]]
+  
+  // CHECK: return %[[RESHAPED]] : tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @test_reshape_samerank
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xf32>)
 func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
-  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[2, 3]> : tensor<2xindex>
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<3x2xf32>, tensor<2xindex>) -> tensor<2x3xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
-  // CHECK-NEXT: return %[[RESHAPE2]]
+  
+  // CHECK: return %[[RESHAPED]] : tensor<2x3xf32>
   return %0 : tensor<2x3xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @test_reshape_samerank_dyn
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x2xf32>)
 func.func @test_reshape_samerank_dyn(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
-  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
-  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]]
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index
+
+  // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<?x2xf32> into tensor<?xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+  // CHECK-DAG: %[[PLACEHOLDER:.*]] = arith.divui %[[SIZE]], %[[C2_0]] : index
+
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor.from_elements %[[C2]], %[[PLACEHOLDER]] : tensor<2xindex>
+  // CHECK-DAG: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<?x2xf32>, tensor<2xindex>) -> tensor<2x?xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
-  // CHECK-NEXT: return %[[RESHAPE2]]
+
+  // CHECK: return %[[RESHAPED]] : tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank_6D
+// CHECK-LABEL: @test_reshape_downrank_6d
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
-  // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4, 5]]
+func.func @test_reshape_downrank_6d(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
+  // CHECK: %[[SHAPE:.*]] = arith.constant dense<[6, 5, 77]> : tensor<3xindex>
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[ARG0]](%[[SHAPE]]) : (tensor<1x2x3x5x7x11xf32>, tensor<3xindex>) -> tensor<6x5x77xf32>
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 6, 5, 77>} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  
+  // CHECK: return %[[RESHAPED]] : tensor<6x5x77xf32>
   return %0 : tensor<6x5x77xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_reshape_downrank_6D_dyn
+// CHECK-LABEL: @test_reshape_downrank_6d_dyn
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
-  // CHECK: tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3, 4, 5]]
-  // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2]]
+func.func @test_reshape_downrank_6d_dy...
[truncated]


github-actions bot commented Mar 19, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 3a2c70b3713a856ea416d92abdddb7893fca308b 9500552edb844d1e36e5fcd1ee4fb0f94925f0c9 -- mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 11ba98ddf3..bac3c8dda7 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -35,10 +35,10 @@ TensorType inferReshapeInputType(TypedValue<TensorType> input,
     return input.getType();
 
   // The input type must be cast into a tensor with the same rank and all static
-  // dimensions set to 1. This prevents the generation of a tensor.collapse_shape
-  // op that converts a dynamically shaped tensor into a 0D tensor. While such
-  // construct is not incorrect on its own, bufferization cannot properly handle
-  // it at the moment, so we avoid it.
+  // dimensions set to 1. This prevents the generation of a
+  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
+  // 0D tensor. While such construct is not incorrect on its own, bufferization
+  // cannot properly handle it at the moment, so we avoid it.
   SmallVector<int64_t> shape(input.getType().getRank(), 1);
   return input.getType().clone(shape);
 }
@@ -55,41 +55,43 @@ TensorType inferReshapeExpandedType(TensorType inputType,
   // Check if the input is static, and if so, get its total size
   bool inputIsStatic = inputType.hasStaticShape();
   int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
- 
+
   // Compute result shape
   bool resultIsStatic = true;
-  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
-    // If this is not a placeholder, do not change it
-    if (size >= 0)
-      return size;
-
-    // If we do not know the total size of the tensor, keep this dimension
-    // dynamic in the result shape.
-    if (!inputIsStatic) {
-      resultIsStatic = false;
-      return ShapedType::kDynamic;
-    }
+  auto resultShape =
+      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+        // If this is not a placeholder, do not change it
+        if (size >= 0)
+          return size;
+
+        // If we do not know the total size of the tensor, keep this dimension
+        // dynamic in the result shape.
+        if (!inputIsStatic) {
+          resultIsStatic = false;
+          return ShapedType::kDynamic;
+        }
 
-    // Calculate the product of all elements in 'newShape' except for the -1
-    // placeholder, which we discard by negating the result.
-    int64_t totalSizeNoPlaceholder = -std::accumulate(
-        newShape.begin(), newShape.end(), 1, std::multiplies());
+        // Calculate the product of all elements in 'newShape' except for the -1
+        // placeholder, which we discard by negating the result.
+        int64_t totalSizeNoPlaceholder = -std::accumulate(
+            newShape.begin(), newShape.end(), 1, std::multiplies());
 
-    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
-    if (totalSizeNoPlaceholder == 0)
-      return 0;
+        // If there is a 0 component in 'newShape', resolve the placeholder as
+        // 0.
+        if (totalSizeNoPlaceholder == 0)
+          return 0;
 
-    // Resolve the placeholder as the quotient between the total tensor size and
-    // the product of all other sizes.
-    return totalSize / totalSizeNoPlaceholder;
-  });
+        // Resolve the placeholder as the quotient between the total tensor size
+        // and the product of all other sizes.
+        return totalSize / totalSizeNoPlaceholder;
+      });
 
   // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
   // shaped input from being reshaped into a statically shaped result. We may
   // simply turn the first result dimension dynamic to address this.
   if (!inputIsStatic && resultIsStatic)
     resultShape[0] = ShapedType::kDynamic;
-  
+
   // The 'tensor.expand_shape' op also forbids a statically shaped input from
   // being reshaped into a dynamically shaped result, but the placeholder
   // inference algorithm above guarantees that this will never be the case.
@@ -108,7 +110,8 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
   if (lhsShape.empty() || rhsShape.empty())
     return lhsType.clone(ArrayRef<int64_t>{});
 
-  if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
+  if (ShapedType::isDynamicShape(lhsShape) ||
+      ShapedType::isDynamicShape(rhsShape))
     return lhsType.clone({ShapedType::kDynamic});
 
   SmallVector<int64_t> intermediateShape;
@@ -145,19 +148,21 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
   for (; currRhsDim < rhsShape.size(); currRhsDim++) {
     assert(rhsShape[currRhsDim] == 1);
   }
-  
+
   return lhsType.clone(intermediateShape);
 }
 
 SmallVector<ReassociationExprs>
-createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
+createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
+                                  Type dstType) {
   auto srcShape = cast<TensorType>(srcType).getShape();
   auto dstShape = cast<TensorType>(dstType).getShape();
 
   if (srcShape.empty() || dstShape.empty())
     return {};
 
-  if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
+  if (ShapedType::isDynamicShape(srcShape) ||
+      ShapedType::isDynamicShape(dstShape)) {
     assert(dstShape.size() == 1);
     SmallVector<AffineExpr, 2> exprs;
     for (auto i : llvm::seq<int64_t>(srcShape.size()))
@@ -235,14 +240,16 @@ public:
     auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
 
     // Cast input if needed
-    auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
+    auto castInput =
+        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
 
     // Emit collaspe-expand pair
     auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
     auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
 
     // Cast to final result type if needed
-    auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
+    auto result =
+        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
     rewriter.replaceOp(reshape, result);
     return success();
   }
@@ -430,10 +437,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<
-    ConcatConverter,
-    PadConverter,
-    ReshapeConverter,
-    SliceConverter
-  >(patterns->getContext());
+  patterns
+      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
+          patterns->getContext());
 }

@sabauma
Contributor

sabauma commented Mar 19, 2024

I think it is worth considering the downstream implications of changing the lowering from tensor.collapse_shape/tensor.expand_shape to tensor.reshape. Some existing optimizations rely on the fact that tensor.expand_shape and tensor.collapse_shape are easier to reason about. The --linalg-fuse-elementwise-ops pass and its associated rewrite patterns come to mind as a good example.

For instance, the following example is able to produce a single fused linalg.generic with the existing lowering, but will no longer fuse with this change.

// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(tosa-to-tensor,tosa-to-linalg,linalg-fuse-elementwise-ops))"

func.func @main(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<10x10xf32> {
  %add1 = tosa.add %arg0, %arg1 : (tensor<100xf32>, tensor<100xf32>) -> (tensor<100xf32>)

  %2 = "tosa.reshape"(%add1) {new_shape = array<i64: 10, 10>} : (tensor<100xf32>) -> tensor<10x10xf32>
  %3 = "tosa.reshape"(%arg1) {new_shape = array<i64: 10, 10>} : (tensor<100xf32>) -> tensor<10x10xf32>

  %add2 = tosa.add %2, %3 : (tensor<10x10xf32>, tensor<10x10xf32>) -> (tensor<10x10xf32>)
  return %add2 : tensor<10x10xf32>
}

There may be other fallout that other reviewers are aware of. Maybe @eric-k256 or @rsuderman could chime in with more examples, or know those who would know.

@krzysz00
Contributor

I'll note that I am one of those downstreams relying on collapse/expand pairs for tosa.reshape, but I didn't raise it because I figured it'd be resolvable.

If there isn't already a pass that rewrites tensor.reshape to tensor.collapse_shape and tensor.expand_shape where possible, we might want to add one, since that could then be slid into the Tosa to Linalg pipeline to preserve the existing behavior.
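
For a statically shaped case, such a rewrite would conceptually replace the shape-tensor form with a reassociation-based op, roughly as follows (a hypothetical sketch of the suggested pass, not an existing upstream pattern; op syntax as of the time of this PR):

// Before: reshape through an explicit shape tensor.
%shape = arith.constant dense<[10, 10]> : tensor<2xindex>
%r = tensor.reshape %arg0(%shape) : (tensor<100xf32>, tensor<2xindex>) -> tensor<10x10xf32>

// After: expressed with a reassociation-based op that fusion patterns understand.
%e = tensor.expand_shape %arg0 [[0, 1]] : tensor<100xf32> into tensor<10x10xf32>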

@eric-k256 eric-k256 requested a review from sjarus March 19, 2024 21:01
@rafaelubalmw
Contributor Author

rafaelubalmw commented Mar 21, 2024

@sabauma @krzysz00: Thank you very much for your insightful feedback. The advantages of sticking to the collapse-expand pattern seemed a good enough reason to revert my lowering strategy and instead focus on enhancing it to support all tosa.reshape cases. The new version differs almost completely from my original PR, so I'd appreciate if you could take another look. Here's a brief description:

  • I modified existing auxiliary functions to avoid double-checking invariants guaranteed by the tosa.reshape op verifier, such as target shape compatibility. At conversion time, these invariants are checked with asserts rather than pattern match failures.

  • Previously unsupported cases were caused by failures to guarantee type consistency in the emitted tensor.expand_shape op (i.e., input and output shapes must be both static or both dynamic). The result type for this op is now constructed manually, and an additional tensor.cast op is emitted if this type differs from the tosa.reshape result type (see the sketch after this list).

  • An extended set of tests is intended to cover the relevant conversion paths. Tests are named using the pattern test_reshape_<rank>_{up|down|same}_{s2s|s2d|d2s|d2d}_{explicit|auto}[_empty][_identity], where:

    • <rank> is the input rank (e.g., 3d, 6d)
    • {up|down|same} indicates whether the reshape increases, decreases, or retains the input rank.
    • {s2s|s2d|d2s|d2d} indicates whether reshape converts a statically shaped input to a statically shaped result (s2s), a statically shaped input to a dynamically shaped result (s2d), etc.
    • {explicit|auto} is used to indicate that all values in the new_shape attribute are >=0 (explicit) or that a -1 placeholder value is used (auto).
    • empty is used to indicate that new_shape includes a component set to 0.
    • identity is used when the input and result shapes are the same.
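
As a rough sketch of this revised collapse-expand strategy applied to the motivating example from the PR description (shapes taken from that example; op syntax shown as of the time of this PR, and the emitted IR may differ in details such as value names):

func.func @main(%input: tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32> {
  // Collapse all input dimensions into a single dynamic dimension.
  %collapsed = tensor.collapse_shape %input [[0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
  // Expand to the target rank. The first dimension stays dynamic because
  // tensor.expand_shape cannot turn a dynamic input into a fully static result.
  %expanded = tensor.expand_shape %collapsed [[0, 1, 2, 3]] : tensor<?xf32> into tensor<?x3x2x1xf32>
  // Cast to the exact tosa.reshape result type.
  %result = tensor.cast %expanded : tensor<?x3x2x1xf32> to tensor<1x3x2x1xf32>
  return %result : tensor<1x3x2x1xf32>
}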

Other considerations regarding Krzysztof's feedback, to the extent that it is still applicable:

  • I created tensor types using clone(), as suggested.

  • I avoided the use of auto for basic types. I agree it helps to spell these out, especially when it comes to integer sizes. I did leave all other uses of auto (SmallVector, Location, Value, ...) as it seems consistent with other occurrences in the repo.


✅ With the latest revision this PR passed the Python code formatter.

@rafaelubalmw
Contributor Author

As observed by @sabauma, 0D tensors were not properly handled. I updated the PR as follows:

  • I added additional checks for this special case.

  • I added the following unit tests:

    • reshape_0d_up_s2s_explicit
    • reshape_0d_up_s2s_auto
    • reshape_0d_up_s2d_explicit
    • reshape_0d_up_s2d_auto
    • reshape_1d_down_s2s_explicit
    • reshape_1d_down_d2s_explicit
    • reshape_4d_down_d2s_explicit
  • The new code led me to the observation that bufferization does not properly handle a tensor.collapse_shape op that produces a 0D tensor from a dynamically shaped one; it looks like memref.collapse_shape does not allow it. While the proper way to address this would be to relax the memref.collapse_shape restriction and verify correct bufferization for this case, I'll leave that as possible future work. For now, I simply avoid this scenario by casting the input if necessary (see inferReshapeInputType() and the sketch after this list).

  • I renamed all the type inference functions and slightly modified their interfaces for symmetry.
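
To illustrate the 0D workaround (a sketch with assumed shapes; the tensor.cast is what inferReshapeInputType() introduces so the collapse never goes from a dynamically shaped tensor straight to a 0D one):

func.func @reshape_to_0d(%arg0: tensor<?x?xf32>) -> tensor<f32> {
  // Cast the dynamically shaped input to an all-ones static shape first.
  %cast = tensor.cast %arg0 : tensor<?x?xf32> to tensor<1x1xf32>
  // Collapsing a fully static unit-shaped tensor to 0D bufferizes correctly.
  %collapsed = tensor.collapse_shape %cast [] : tensor<1x1xf32> into tensor<f32>
  return %collapsed : tensor<f32>
}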

Who knew there was so much to consider for a simple tensor reshape... :) Please let me know what you think.

Contributor

@sabauma sabauma left a comment

Thanks for the followup. Looks good to me.

// Fold when the input and output types are the same. This is only safe when
// there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
// there may still be a productive reshape.
if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
Contributor

As mentioned in the earlier comment, this restriction will be relaxed. That would probably require dynamic cases to be predicated on something like tensor.dim for later resolution.

However, for the purposes of this PR this code looks fine.

@rafaelubalmw rafaelubalmw merged commit 26d896f into llvm:main Mar 26, 2024
@rafaelubalmw rafaelubalmw deleted the tosa-reshape-lowering branch March 26, 2024 14:53
@bjacob
Contributor

bjacob commented Apr 2, 2024

@rafaelubalmw @sabauma @sjarus, I have run into a regression (#87396) while integrating this change into a downstream project (IREE).

bjacob added a commit to iree-org/llvm-project that referenced this pull request Apr 2, 2024
bjacob added a commit to iree-org/iree that referenced this pull request Apr 3, 2024
IREE-side changes to adapt to MLIR changes:
1. `initializeOptions` changes to adapt to
llvm/llvm-project#87289
2. `enableFastMathMode` removal:
llvm/llvm-project#86578.
3. Bazel changes to adapt to
llvm/llvm-project#86819

IREE-side fixes for preexisting bugs revealed by an MLIR change:
1. `mlp_tosa` test fix: the shapes were inconsistent and only accidentally
worked until MLIR started catching the mismatch in
llvm/llvm-project#85798. See the diagnostic in
llvm/llvm-project#87396.
FYI @MaheshRavishankar.

IREE-side fixes accidentally lumped into this:
1. The `iree_copts.cmake` change: It just happens that my bleeding-edge
Clang was updated and started diagnosing some code relying on C++20
semantics. Filed #16946 as TODO.

---------

Co-authored-by: Scott Todd <[email protected]>