
Commit 1c076b4

[mlir][tosa] Implement dynamic shape support for tosa.max_pool2d lowering (#87538)
The existing lowering for tosa.max_pool2d supports dynamic dimensions only when the dynamic dimension is the batch dimension. This change updates the lowering to support arbitrary dynamic dimensions on the inputs and outputs of the tosa.max_pool2d operation. It also fixes a bug in the implementation of implicit broadcasting in the tosa-to-linalg pass, which was introducing uses of constant ops that violated dominance requirements.
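
As a minimal illustration, the IR below (taken from the new test case added in this change) now lowers to linalg even though every NHWC dimension of the input and output is dynamic; previously only a dynamic batch dimension was accepted:

func.func @max_pool_all_dynamic(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  // All four dimensions are dynamic; the lowering now computes the output sizes at runtime.
  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 5>, pad = array<i64: 0, 0, 2, 2>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}

The dynamic output height and width are derived with the same formula already used for convolutions, W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1; for this example (KW = 5, pad 2+2, stride 1, dilation 1) the width reduces to IW.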
1 parent ac1f2de commit 1c076b4

6 files changed: +251 additions, −37 deletions

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 5 additions & 5 deletions
@@ -130,11 +130,11 @@ def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
 // to not include any remaining unranked tensors.
 def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
 
-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>]>;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
 
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 8 additions & 4 deletions
@@ -766,11 +766,15 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
 
   // Emit 'then' region of 'scf.if'
   auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+    // It is not safe to cache constants across regions.
+    // New constants could potentially violate dominance requirements.
+    IndexPool localPool;
+
     // Emit 'tensor.empty' op
     SmallVector<OpFoldResult> outputTensorShape;
     for (auto index : llvm::seq<int64_t>(0, rank)) {
       auto size = index == dim ? targetSize
-                                : getOrFoldTensorDim(rewriter, loc, indexPool,
+                                : getOrFoldTensorDim(rewriter, loc, localPool,
                                                      operand, index);
       outputTensorShape.push_back(size);
     }
@@ -812,9 +816,9 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                         IndexPool &indexPool, Value operand,
                                         ArrayRef<OpFoldResult> targetShape,
                                         ArrayRef<Value> masterOperands) {
-  size_t rank = operand.getType().cast<RankedTensorType>().getRank();
-  assert(targetShape.size() == rank);
-  assert(masterOperands.size() == rank);
+  int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+  assert((int64_t)targetShape.size() == rank);
+  assert((int64_t)masterOperands.size() == rank);
   for (auto index : llvm::seq<int64_t>(0, rank))
     operand =
         broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 64 additions & 24 deletions
@@ -26,6 +26,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
 #include <numeric>
 #include <type_traits>
 
@@ -34,7 +36,7 @@ using namespace mlir::tosa;
 
 static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                             TypedAttr padAttr, OpBuilder &rewriter) {
-  // Input should be padded if necessary.
+  // Input should be padded only if necessary.
   if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
     return input;
 
@@ -47,7 +49,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
   SmallVector<int64_t, 4> paddedShape;
   SmallVector<OpFoldResult, 8> lowIndices;
   SmallVector<OpFoldResult, 8> highIndices;
-  for (int i = 0, s = inputShape.size(); i < s; i++) {
+  for (size_t i : llvm::seq(inputShape.size())) {
     auto lowPad = pad[i * 2];
     auto highPad = pad[i * 2 + 1];
     if (ShapedType::isDynamic(inputShape[i]))
@@ -131,20 +133,19 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
 
 static mlir::Value reifyConstantDim(int64_t attr,
                                     ImplicitLocOpBuilder &builder) {
-  return builder.createOrFold<arith::IndexCastOp>(
-      builder.getIndexType(),
-      builder.create<arith::ConstantOp>(builder.getI64IntegerAttr(attr)));
+  return builder.create<arith::ConstantIndexOp>(attr);
 }
 
 // Calculating the output width/height using the formula:
 // H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
 // W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
 
-static mlir::Value getConvOutputDim(Location loc, Value inputDim,
-                                    int64_t padBeforeAttr, int64_t padAfterAttr,
-                                    Value kernelDim, int64_t strideAttr,
-                                    int64_t dilationAttr, Type inputETy,
-                                    OpBuilder &rewriter) {
+static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
+                                          int64_t padBeforeAttr,
+                                          int64_t padAfterAttr, Value kernelDim,
+                                          int64_t strideAttr,
+                                          int64_t dilationAttr,
+                                          OpBuilder &rewriter) {
   ImplicitLocOpBuilder builder(loc, rewriter);
   auto one = rewriter.create<arith::ConstantOp>(
       loc, IntegerAttr::get(inputDim.getType(), 1));
@@ -171,7 +172,6 @@ static SmallVector<Value> inferDynamicDimsForConv(
     ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
     ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
   ShapedType inputTy = cast<ShapedType>(input.getType());
-  Type inputETy = inputTy.getElementType();
   int64_t inputRank = inputTy.getRank();
 
   SmallVector<Value> dynDims;
@@ -190,8 +190,8 @@ static SmallVector<Value> inferDynamicDimsForConv(
           rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
       // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
       dynDims[inputDim] =
-          getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
-                           stride, dilation, inputETy, rewriter);
+          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
+                                 kernelDynDim, stride, dilation, rewriter);
     }
   }
 
@@ -685,20 +685,61 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 public:
   using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
 
+  // Compute the dynamic output sizes of the maxpool operation.
+  static SmallVector<Value>
+  computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+    TensorType resultTy = op.getType();
+    Location loc = op.getLoc();
+
+    TypedValue<TensorType> input = op.getInput();
+    ArrayRef<int64_t> kernel = op.getKernel();
+    ArrayRef<int64_t> pad = op.getPad();
+    ArrayRef<int64_t> stride = op.getStride();
+
+    SmallVector<Value> dynamicDims;
+
+    // Batch dimension
+    if (resultTy.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+    // Height/width dimensions
+    for (int64_t dim : {1, 2}) {
+      if (!resultTy.isDynamicDim(dim))
+        continue;
+
+      // Index into the attribute arrays
+      int64_t index = dim - 1;
+
+      // Input height/width
+      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
+
+      // Kernel height/width
+      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
+
+      // Output height/width
+      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
+                                         pad[index * 2 + 1], khw, stride[index],
+                                         /*dilationAttr=*/1, rewriter);
+      dynamicDims.push_back(ohw);
+    }
+
+    // Channel dimension
+    if (resultTy.isDynamicDim(3))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
+
+    return dynamicDims;
+  }
+
   LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    Value input = op.getInput();
-    ShapedType inputTy = cast<ShapedType>(input.getType());
+    TypedValue<TensorType> input = op.getInput();
+    ShapedType inputTy = input.getType();
 
-    ShapedType resultTy = cast<ShapedType>(op.getType());
+    ShapedType resultTy = op.getType();
     Type resultETy = inputTy.getElementType();
 
-    auto dynamicDimsOr =
-        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return failure();
-    SmallVector<Value> dynamicDims = *dynamicDimsOr;
+    SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
@@ -721,6 +762,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
     pad.resize(2, 0);
     llvm::append_range(pad, op.getPad());
     pad.resize(pad.size() + 2, 0);
+
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
 
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
@@ -736,9 +778,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
         loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);
 
     Value filledEmptyTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{initialValue},
-                                    ValueRange{emptyTensor})
+        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
            .result();
 
    Value fakeWindowDims =

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 54 additions & 0 deletions
@@ -1,5 +1,6 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -215,6 +216,59 @@ func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
   return
 }
 
+// CHECK-CSE-LABEL: @max_pool_all_dynamic
+func.func @max_pool_all_dynamic(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // Batch size
+  // CHECK-CSE: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-CSE: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x?x?x?xf32>
+
+  // Compute output height
+  // CHECK-CSE: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-CSE: %[[IH:.+]] = tensor.dim %arg0, %[[C1]] : tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IH]], %[[C0]] : index
+  // CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C0]] : index
+  // CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C2]], %[[C1]] : index
+  // CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C1]], %[[SUB_ONE]] : index
+  // CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+  // CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+  // CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+  // CHECK-CSE: %[[HEIGHT:.+]] = arith.addi %[[DIVIDE]], %[[C1]] : index
+
+  // Compute output width
+  // CHECK-CSE: %[[IW:.+]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[C5:.+]] = arith.constant 5 : index
+  // CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IW]], %[[C2]] : index
+  // CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C2]] : index
+  // CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C5]], %[[C1]] : index
+  // CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C1]], %[[SUB_ONE]] : index
+  // CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+  // CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+  // CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+  // CHECK-CSE: %[[WIDTH:.+]] = arith.addi %14, %[[C1]] : index
+
+  // Channel size
+  // CHECK-CSE: %[[C3:.+]] = arith.constant 3 : index
+  // CHECK-CSE: %[[CHANNEL:.+]] = tensor.dim %arg0, %[[C3]] : tensor<?x?x?x?xf32>
+
+  // Pad the input
+  // CHECK-CSE: %[[FLOAT_MIN:.+]] = arith.constant -3.40282347E+38 : f32
+  // CHECK-CSE: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 0, 2, 0] high[0, 0, 2, 0] {
+  // CHECK-CSE: tensor.yield %[[FLOAT_MIN]] : f32
+
+  // Allocate the output and fill with minimum value
+  // CHECK-CSE: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[HEIGHT]], %[[WIDTH]], %[[CHANNEL]]) : tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[FILL:.+]] = linalg.fill ins(%[[FLOAT_MIN]] : f32) outs(%[[INIT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK-CSE: %[[FAKE_WINDOW:.+]] = tensor.empty() : tensor<2x5xf32>
+
+  // Compute max pool
+  // CHECK-CSE: %[[OUT:.+]] = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PADDED]], %[[FAKE_WINDOW]] : tensor<?x?x?x?xf32>, tensor<2x5xf32>) outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK-CSE: return %[[OUT]]
+
+  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 5>, pad = array<i64: 0, 0, 2, 2>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
 // -----
 
 // CHECK-LABEL: @avg_pool_f32

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 8 additions & 4 deletions
@@ -270,7 +270,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 // CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?x?xf32>
 // CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
 // CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor<?x?xf32>) {
-// CHECK: %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<?x?xf32>
+// CHECK: %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_2]]) : tensor<?x?xf32>
 // CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[VAL_3]] : tensor<?x?xf32>) {
 // CHECK: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
@@ -284,7 +285,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 // CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
 // CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
 // CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<?x?xf32>) {
-// CHECK: %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+// CHECK: %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
 // CHECK: %[[VAL_10:.*]] = tensor.empty(%[[VAL_9]], %[[MAX_DIM1]]) : tensor<?x?xf32>
 // CHECK: %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_10]] : tensor<?x?xf32>) {
 // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
@@ -298,7 +300,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 // CHECK: %[[VAL_14:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?x?xf32>
 // CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_14]], %[[CONST1]] : index
 // CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_15]] -> (tensor<?x?xf32>) {
-// CHECK: %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<?x?xf32>
+// CHECK: %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
 // CHECK: %[[VAL_17:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_16]]) : tensor<?x?xf32>
 // CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[VAL_17]] : tensor<?x?xf32>) {
 // CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32):
@@ -312,7 +315,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 // CHECK: %[[VAL_21:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
 // CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[CONST1]] : index
 // CHECK: %[[ARG1_DIM1_BROADCAST:.*]] = scf.if %[[VAL_22]] -> (tensor<?x?xf32>) {
-// CHECK: %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+// CHECK: %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
 // CHECK: %[[VAL_24:.*]] = tensor.empty(%[[VAL_23]], %[[MAX_DIM1]]) : tensor<?x?xf32>
 // CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_24]] : tensor<?x?xf32>) {
 // CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
