
[mlir][tosa] Implement dynamic shape support for tosa.max_pool2d lowering #87538


Merged: 1 commit merged into llvm:main on Apr 16, 2024

Conversation

@sabauma (Contributor) commented Apr 3, 2024

The existing lowering for tosa.max_pool2d only supports dynamic dimensions when the dynamic dimension is the batch dimension. This change updates the lowering to support arbitrary dynamic dimensions on the inputs and outputs of the tosa.max_pool2d operation.
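
For illustration, the fully dynamic form that the lowering now accepts looks like the lit test added by this patch (adapted from the new `max_pool_all_dynamic` case in tosa-to-linalg-named.mlir):

```mlir
// Every dimension of the input and result is dynamic, not just the batch
// dimension; the lowering derives each output size from the input at runtime.
func.func @max_pool_all_dynamic(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 5>, pad = array<i64: 0, 0, 2, 2>, stride = array<i64: 1, 1>}
       : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}
```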

This change also fixes a bug in the implementation of implicit broadcasting in the tosa-to-linalg pass, which was introducing uses of constant ops that violated dominance requirements.
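
For reference, the dominance rule at play: a value defined inside an `scf.if` region is visible only within that region, so index constants needed there must either already dominate the `scf.if` or be materialized locally, which is what the fix does by using a region-local index pool. A minimal sketch of the required structure, with illustrative names:

```mlir
func.func @region_local_constant(%cond: i1, %arg0: tensor<?x?xf32>) -> index {
  %dim = scf.if %cond -> (index) {
    // Materialized inside the 'then' region. Caching this constant and reusing
    // it after the scf.if would reference a value that does not dominate the
    // use, and the verifier would reject the IR.
    %c1 = arith.constant 1 : index
    %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    scf.yield %d1 : index
  } else {
    %c0 = arith.constant 0 : index
    %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    scf.yield %d0 : index
  }
  return %dim : index
}
```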

@llvmbot (Member) commented Apr 3, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-execution-engine

Author: Spenser Bauman (sabauma)

Changes



Patch is 23.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87538.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+5-5)
  • (modified) mlir/include/mlir/ExecutionEngine/CRunnerUtils.h (+1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+8-4)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+64-24)
  • (modified) mlir/lib/ExecutionEngine/CRunnerUtils.cpp (+1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+54)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+8-4)
  • (added) mlir/test/Integration/Dialect/Tosa/CPU/test-maxpool-dynamic.mlir (+83)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cff3de0a69af95..3687891fe4b7cf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -130,11 +130,11 @@ def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
 // to not include any remaining unranked tensors.
 def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
 
-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>]>;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
 
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index 812f719e723eb3..1f12958532943f 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -461,6 +461,7 @@ memrefCopy(int64_t elemSize, ::UnrankedMemRefType<char> *src,
 //===----------------------------------------------------------------------===//
 // Small runtime support library for vector.print lowering during codegen.
 //===----------------------------------------------------------------------===//
+extern "C" MLIR_CRUNNERUTILS_EXPORT void printI1(bool i);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412be..d8dd1c93722b09 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -766,11 +766,15 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
 
   // Emit 'then' region of 'scf.if'
   auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+    // It is not safe to cache constants across regions.
+    // New constants could potentially violate dominance requirements.
+    IndexPool localPool;
+
     // Emit 'tensor.empty' op
     SmallVector<OpFoldResult> outputTensorShape;
     for (auto index : llvm::seq<int64_t>(0, rank)) {
       auto size = index == dim ? targetSize
-                               : getOrFoldTensorDim(rewriter, loc, indexPool,
+                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                     operand, index);
       outputTensorShape.push_back(size);
     }
@@ -812,9 +816,9 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                         IndexPool &indexPool, Value operand,
                                         ArrayRef<OpFoldResult> targetShape,
                                         ArrayRef<Value> masterOperands) {
-  size_t rank = operand.getType().cast<RankedTensorType>().getRank();
-  assert(targetShape.size() == rank);
-  assert(masterOperands.size() == rank);
+  int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+  assert((int64_t)targetShape.size() == rank);
+  assert((int64_t)masterOperands.size() == rank);
   for (auto index : llvm::seq<int64_t>(0, rank))
     operand =
         broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 3f39cbf03a9a80..367c90538b5ccf 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
 #include <numeric>
 #include <type_traits>
 
@@ -34,7 +36,7 @@ using namespace mlir::tosa;
 
 static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                             TypedAttr padAttr, OpBuilder &rewriter) {
-  // Input should be padded if necessary.
+  // Input should be padded only if necessary.
   if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
     return input;
 
@@ -47,7 +49,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
   SmallVector<int64_t, 4> paddedShape;
   SmallVector<OpFoldResult, 8> lowIndices;
   SmallVector<OpFoldResult, 8> highIndices;
-  for (int i = 0, s = inputShape.size(); i < s; i++) {
+  for (size_t i : llvm::seq(inputShape.size())) {
     auto lowPad = pad[i * 2];
     auto highPad = pad[i * 2 + 1];
     if (ShapedType::isDynamic(inputShape[i]))
@@ -131,20 +133,19 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
 
 static mlir::Value reifyConstantDim(int64_t attr,
                                     ImplicitLocOpBuilder &builder) {
-  return builder.createOrFold<arith::IndexCastOp>(
-      builder.getIndexType(),
-      builder.create<arith::ConstantOp>(builder.getI64IntegerAttr(attr)));
+  return builder.create<arith::ConstantIndexOp>(attr);
 }
 
 // Calculating the output width/height using the formula:
 // H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
 // W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
 
-static mlir::Value getConvOutputDim(Location loc, Value inputDim,
-                                    int64_t padBeforeAttr, int64_t padAfterAttr,
-                                    Value kernelDim, int64_t strideAttr,
-                                    int64_t dilationAttr, Type inputETy,
-                                    OpBuilder &rewriter) {
+static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
+                                          int64_t padBeforeAttr,
+                                          int64_t padAfterAttr, Value kernelDim,
+                                          int64_t strideAttr,
+                                          int64_t dilationAttr,
+                                          OpBuilder &rewriter) {
   ImplicitLocOpBuilder builder(loc, rewriter);
   auto one = rewriter.create<arith::ConstantOp>(
       loc, IntegerAttr::get(inputDim.getType(), 1));
@@ -171,7 +172,6 @@ static SmallVector<Value> inferDynamicDimsForConv(
     ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
     ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
   ShapedType inputTy = cast<ShapedType>(input.getType());
-  Type inputETy = inputTy.getElementType();
   int64_t inputRank = inputTy.getRank();
 
   SmallVector<Value> dynDims;
@@ -190,8 +190,8 @@ static SmallVector<Value> inferDynamicDimsForConv(
           rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
       // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
       dynDims[inputDim] =
-          getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
-                           stride, dilation, inputETy, rewriter);
+          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
+                                 kernelDynDim, stride, dilation, rewriter);
     }
   }
 
@@ -685,20 +685,61 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 public:
   using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
 
+  // Compute the dynamic output sizes of the maxpool operation.
+  static SmallVector<Value>
+  computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+    TensorType resultTy = op.getType();
+    Location loc = op.getLoc();
+
+    TypedValue<TensorType> input = op.getInput();
+    ArrayRef<int64_t> kernel = op.getKernel();
+    ArrayRef<int64_t> pad = op.getPad();
+    ArrayRef<int64_t> stride = op.getStride();
+
+    SmallVector<Value> dynamicDims;
+
+    // Batch dimension
+    if (resultTy.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+    // Height/width dimensions
+    for (int64_t dim : {1, 2}) {
+      if (!resultTy.isDynamicDim(dim))
+        continue;
+
+      // Index into the attribute arrays
+      int64_t index = dim - 1;
+
+      // Input height/width
+      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
+
+      // Kernel height/width
+      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
+
+      // Output height/width
+      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
+                                         pad[index * 2 + 1], khw, stride[index],
+                                         /*dilationAttr=*/0, rewriter);
+      dynamicDims.push_back(ohw);
+    }
+
+    // Channel dimension
+    if (resultTy.isDynamicDim(3))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
+
+    return dynamicDims;
+  }
+
   LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    Value input = op.getInput();
-    ShapedType inputTy = cast<ShapedType>(input.getType());
+    TypedValue<TensorType> input = op.getInput();
+    ShapedType inputTy = input.getType();
 
-    ShapedType resultTy = cast<ShapedType>(op.getType());
+    ShapedType resultTy = op.getType();
     Type resultETy = inputTy.getElementType();
 
-    auto dynamicDimsOr =
-        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return failure();
-    SmallVector<Value> dynamicDims = *dynamicDimsOr;
+    SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
@@ -721,6 +762,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     pad.resize(2, 0);
     llvm::append_range(pad, op.getPad());
     pad.resize(pad.size() + 2, 0);
+
     Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
 
     Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
@@ -736,9 +778,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
         loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);
 
     Value filledEmptyTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{initialValue},
-                                    ValueRange{emptyTensor})
+        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
             .result();
 
     Value fakeWindowDims =
diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
index 41c619566b55df..b2f86f83b1d3ba 100644
--- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
@@ -49,6 +49,7 @@ void stdSort(uint64_t n, V *p) {
 // By providing elementary printing methods only, this
 // library can remain fully unaware of low-level implementation
 // details of our vectors. Also useful for direct LLVM IR output.
+extern "C" void printI1(bool i) { fprintf(stdout, i ? "true" : "false"); }
 extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); }
 extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); }
 extern "C" void printF32(float f) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index e64903671e599f..739a6c44fc9eb4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -215,6 +216,59 @@ func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
   return
 }
 
+// CHECK-CSE-LABEL: @max_pool_all_dynamic
+func.func @max_pool_all_dynamic(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // Batch size
+  // CHECK-CSE: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-CSE: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x?x?x?xf32>
+
+  // Compute output height
+  // CHECK-CSE: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-CSE: %[[IH:.+]] = tensor.dim %arg0, %[[C1]] : tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IH]], %[[C0]] : index
+  // CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C0]] : index
+  // CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C2]], %[[C1]] : index
+  // CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C0]], %[[SUB_ONE]] : index
+  // CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+  // CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+  // CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+  // CHECK-CSE: %[[HEIGHT:.+]] = arith.addi %[[DIVIDE]], %[[C1]] : index
+
+  // Compute output width
+  // CHECK-CSE: %[[IW:.+]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[C5:.+]] = arith.constant 5 : index
+  // CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IW]], %[[C2]] : index
+  // CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C2]] : index
+  // CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C5]], %[[C1]] : index
+  // CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C0]], %[[SUB_ONE]] : index
+  // CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+  // CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+  // CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+  // CHECK-CSE: %[[WIDTH:.+]] = arith.addi %14, %[[C1]] : index
+
+  // Channel size
+  // CHECK-CSE: %[[C3:.+]] = arith.constant 3 : index
+  // CHECK-CSE: %[[CHANNEL:.+]] = tensor.dim %arg0, %[[C3]] : tensor<?x?x?x?xf32>
+
+  // Pad the input
+  // CHECK-CSE: %[[FLOAT_MIN:.+]] = arith.constant -3.40282347E+38 : f32
+  // CHECK-CSE: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 0, 2, 0] high[0, 0, 2, 0] {
+  // CHECK-CSE:   tensor.yield %[[FLOAT_MIN]] : f32
+
+  // Allocate the output and fill with minimum value
+  // CHECK-CSE: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[HEIGHT]], %[[WIDTH]], %[[CHANNEL]]) : tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[FILL:.+]] = linalg.fill ins(%[[FLOAT_MIN]] : f32) outs(%[[INIT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[FAKE_WINDOW:.+]] = tensor.empty() : tensor<2x5xf32>
+
+  // Compute max pool
+  // CHECK-CSE: %[[OUT:.+]] = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PADDED]], %[[FAKE_WINDOW]] : tensor<?x?x?x?xf32>, tensor<2x5xf32>) outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK-CSE: return %[[OUT]]
+
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 5>, pad = array<i64: 0, 0, 2, 2>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
 // -----
 
 // CHECK-LABEL: @avg_pool_f32
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1fa783f05f04ee..445e8be47678d5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -270,7 +270,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
   // CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?x?xf32>
   // CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
   // CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor<?x?xf32>) {
-  // CHECK:   %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK:   %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+  // CHECK:   %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
   // CHECK:   %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_2]]) : tensor<?x?xf32>
   // CHECK:   %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[VAL_3]] : tensor<?x?xf32>) {
   // CHECK:   ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
@@ -284,7 +285,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
   // CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
   // CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
   // CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<?x?xf32>) {
-  // CHECK:   %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK:   %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+  // CHECK:   %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
   // CHECK:   %[[VAL_10:.*]] = tensor.empty(%[[VAL_9]], %[[MAX_DIM1]]) : tensor<?x?xf32>
   // CHECK:   %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_10]] : tensor<?x?xf32>) {
   // CHECK:   ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
@@ -298,7 +300,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
   // CHECK: %[[VAL_14:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?x?xf32>
   // CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_14]], %[[CONST1]] : index
   // CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_15]] -> (tensor<?x?xf32>) {
-  // CHECK:   %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK:   %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+  // CHECK:   %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
   // CHECK:   %[[VAL_17:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_16]]) : tensor<?x?xf32>
   // CHECK:   %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[VAL_17]] : tensor<?x?xf32>) {
   // CHECK:   ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32):
@@ -312,7 +315,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
   // CHECK: %[[VAL_21:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
   // CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[CONST1]] : index
   // CHECK: %[[ARG1_DIM1_BROADCAST:.*]] = scf.if %[[VAL_22]] -> (tensor<?x?xf32>) {
-  // CHECK:   %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK:   %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+  // CHECK:   %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
   // CHECK:   %[[VAL_24:.*]] = tensor.empty(%[[VAL_23]], %[[MAX_DIM1]]) : tensor<?x?xf32>
   // CHECK:   %[[VAL_25:...
[truncated]

@sabauma force-pushed the tosa-maxpool branch 2 times, most recently from 2383df9 to c5b78f0 on April 4, 2024
@eric-k256 requested a review from sjarus on April 8, 2024
@sabauma (Contributor, Author) commented Apr 9, 2024

For context, this was attempted once here: https://reviews.llvm.org/D133389

It was subsequently reverted here: https://reviews.llvm.org/D134370 due to incorrect codegen in PyTorch. If there are thoughts on how to mitigate the downstream risk, please let me know.

@sjarus (Contributor) commented Apr 15, 2024

Looking at this now @sabauma

[mlir][tosa] Implement dynamic shape support for tosa.max_pool2d lowering

The existing lowering for tosa.max_pool2d only supports dynamic
dimensions when the dynamic dimension is the batch dimension.
This change updates the lowering to support arbitrary dynamic dimensions
on the inputs and outputs of the tosa.max_pool2d operation.

This change also fixes a bug in the implementation of implicit
broadcasting in the tosa-to-linalg pass, which was introducing uses of
constant ops that violated dominance requirements.
@sabauma merged commit 1c076b4 into llvm:main on Apr 16, 2024