Commit e377a5d

[MLIR][Tensor] Remove tensor.dim canonicalization patterns registered on tensor.expand_shape/tensor.collapse_shape (#134219)
These patterns are problematic because their iterative application, which locally resolves the tensor.dim operation, introduces an intermediate floordiv. That loses the information that the original IR performed an exact division, so the iterative algorithm cannot converge towards the simplest form. Such information loss is not acceptable for canonicalization. Resolving the DimOp can instead be achieved through the resolve-ranked-shaped-type-result-dims and resolve-shaped-type-result-dims passes.

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 9eeafc6 commit e377a5d
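For context, a minimal MLIR sketch of the behavior change, reconstructed from the removed FoldDimOfExpandShape pattern and the updated @dim_of_expand_shape test below (value names largely taken from that test):

    func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
      %c2 = arith.constant 2 : index
      %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
          : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
      // Query a dynamic dimension of the expanded result.
      %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
      return %1 : index
    }

Before this change, canonicalization rewrote the tensor.dim into a floor division on the source dimension, e.g. affine.apply affine_map<()[s0] -> (s0 floordiv 40)>()[%dim_of_t] (the %dim_of_t name is illustrative), which drops the fact that the division is exact and keeps the pattern set from converging to a simpler form. After this change the tensor.dim simply stays on the tensor.expand_shape result; the passes named in the commit message can still resolve it when that is desired.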

4 files changed: +22 -132 lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 1 addition & 86 deletions
@@ -1986,90 +1986,6 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
   }
 };
 
-struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
-    if (!expandShapeOp)
-      return failure();
-
-    // Only constant dimension values are supported.
-    std::optional<int64_t> dim = dimOp.getConstantIndex();
-    if (!dim.has_value())
-      return failure();
-
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = expandShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Find reassociation group that contains this result dimension.
-    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
-    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
-    // ExpandShapeOp would be ambiguous.)
-    int64_t product = 1;
-    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
-    for (int64_t d : grp) {
-      if (d != dim) {
-        assert(!resultType.isDynamicDim(d) && "expected static dim");
-        product *= resultType.getDimSize(d);
-      }
-    }
-
-    // result dim size = src dim size / (product(other dims in reassoc group))
-    Value srcDimSz =
-        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
-    AffineExpr expr;
-    bindSymbols(dimOp.getContext(), expr);
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
-        dimOp, expr.floorDiv(product), srcDimSz);
-    return success();
-  }
-};
-
-struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
-    if (!collapseShapeOp)
-      return failure();
-
-    // Only constant dimension values are supported.
-    std::optional<int64_t> dim = dimOp.getConstantIndex();
-    if (!dim.has_value() ||
-        dim.value() >= collapseShapeOp.getResultType().getRank())
-      return failure();
-
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = collapseShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Get reassociation group of the result dimension.
-    ReassociationIndices group =
-        collapseShapeOp.getReassociationIndices()[*dim];
-
-    // result dim size = product(dims in reassoc group)
-    SmallVector<Value> srcDimSizes;
-    SmallVector<AffineExpr> syms;
-    AffineExpr product;
-    for (const auto &it : llvm::enumerate(group)) {
-      srcDimSizes.push_back(rewriter.create<DimOp>(
-          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
-      syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
-      product = product ? product * syms.back() : syms.back();
-    }
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
-                                                       srcDimSizes);
-    return success();
-  }
-};
-
 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
 /// matching constant output_shape operands of the expand. This makes the
 /// `tensor.expand_shape` more static and creates a consumer cast that can be
@@ -2158,8 +2074,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
-      FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
-      FoldDimOfCollapseShape>(context);
+      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
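
For reference, the removed FoldDimOfCollapseShape performed the converse rewrite: a tensor.dim on a tensor.collapse_shape result became the product of the source dims in the corresponding reassociation group. On the @dim_of_collapse_shape test updated below, that produced roughly the following (reconstructed from the old CHECK lines; the %sz name is illustrative):

    %dim1 = tensor.dim %t, %c1 : tensor<?x?x?x7x?xf32>
    %dim2 = tensor.dim %t, %c2 : tensor<?x?x?x7x?xf32>
    %dim4 = tensor.dim %t, %c4 : tensor<?x?x?x7x?xf32>
    // Collapsed size = dim1 * dim2 * dim4 * 7 (7 is the static extent of dimension 3).
    %sz = affine.apply affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>()[%dim1, %dim2, %dim4]

With the pattern removed, the tensor.dim now stays on the collapse_shape result.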

mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Lines changed: 6 additions & 23 deletions
@@ -25,22 +25,18 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-LABEL: func @drop_one_trip_loops
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]]
 // CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]]
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C0]]
-// CHECK: %[[VAL_1:.*]] = affine.apply #[[$MAP4]]()[%[[DIM]], %[[C1]]]
 // CHECK: %[[DIM_1:.*]] = tensor.dim %{{.*}}, %[[C2]]
-// CHECK: %[[VAL_2:.*]] = affine.apply #[[$MAP4]]()[%[[DIM_1]], %[[C1]]]
 // CHECK: %[[DIM_2:.*]] = tensor.dim %{{.*}}, %[[C2]]
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[VAL_1]], 1, %[[VAL_2]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[DIM]], 1, %[[DIM_1]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
 
 // CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
@@ -79,18 +75,15 @@ func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32,
 }
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 * s1) * s2) * s3) * s4)>
 // CHECK-LABEL: func @drop_one_trip_loops_all_ones
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: tensor.collapse_shape %{{.*}} []
 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME: iterator_types = ["parallel"]
 // CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x1x?x1x1xf32>
-// CHECK: %[[SZ:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[C1]], %[[DIM]], %[[C1]], %[[C1]]]
-// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[SZ]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[DIM]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>
 
 // -----
 
@@ -406,7 +399,6 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
 }
 // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)>
 // CHECK: func @unit_dim_for_reduction
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -422,8 +414,7 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
 // CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
 // CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x1x?xf32>
-// CHECK: %[[VAL_3:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[DIM_0]], %[[C1]]]
-// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_3]]] : tensor<?xf32> into tensor<1x?xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[DIM_0]]] : tensor<?xf32> into tensor<1x?xf32>
 // CHECK: return %[[EXPANDED]] : tensor<1x?xf32>
 
 // -----
@@ -482,10 +473,8 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x
 }
 // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK: func @unit_dim_for_reduction_inner
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x1xf32>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
@@ -499,8 +488,7 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x
 // CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
 // CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x1x?x1xf32>
-// CHECK: %[[VAL_3:.+]] = affine.apply #[[$MAP3]]()[%[[DIM_0]], %[[C1]]]
-// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[VAL_3]], 1] : tensor<?xf32> into tensor<?x1xf32>
+// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[DIM_0]], 1] : tensor<?xf32> into tensor<?x1xf32>
 // CHECK: return %[[RESULT_RESHAPE]]
 
 // -----
@@ -1017,7 +1005,6 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
   return %0 : tensor<1x?xf32>
 }
 
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 11)>
 // CHECK-LABEL: func @drop_unit_pad_dynamic_dims
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -1027,8 +1014,7 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
 // CHECK: %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[5] high[6]
 // CHECK: } : tensor<?xf32> to tensor<?xf32>
 // CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?xf32>
-// CHECK: %[[VAL_0:.+]] = affine.apply #[[$MAP]]()[%[[C1]], %[[DIM]]]
-// CHECK: %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[VAL_0]]]
+// CHECK: %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[DIM]]]
 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_1]]] : tensor<?xf32> into tensor<1x?xf32>
 
 // CHECK-SLICES: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 11)>
@@ -1090,20 +1076,17 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 
 // -----
 
-// CHECK: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
 
 // CHECK-LABEL: func @drop_unit_dim_corresponding_to_dynamic_dim
 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x?x1xf32>,
 // CHECK-SAME: %[[ARG1:.*]]: index) -> tensor<?x1x61x1xf32> {
 // CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK: %[[VAL_2:.*]] = arith.constant dense<1.000000e+00> : tensor<f32>
 // CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] : tensor<1x?x?x1xf32> into tensor<?x?xf32>
 // CHECK: %[[VAL_4:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
-// CHECK: %[[VAL_5:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[ARG1]], %[[VAL_1]]]
-// CHECK: %[[VAL_6:.*]] = tensor.empty(%[[VAL_5]]) : tensor<?x61xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
 // CHECK: %[[VAL_7:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[VAL_3]], %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>, tensor<f32>, tensor<?x61xf32>) outs(%[[VAL_6]] : tensor<?x61xf32>) {
 // CHECK: ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
 // CHECK: %[[VAL_12:.*]] = arith.mulf %[[VAL_8]], %[[VAL_9]] : f32
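
The net effect on these tests: the output_shape operands of the re-created tensor.expand_shape ops are now the tensor.dim values themselves, instead of affine.apply results that only multiplied those values by unit extents. A sketch of the first case (@drop_one_trip_loops), with placeholder value names since the CHECK lines only match them generically:

    // Old canonicalized form:
    //   %sz0 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%dim0, %c1]
    //   %sz1 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%dim1, %c1]
    //   ... output_shape [%sz0, 1, %sz1, 1, %dim2] ...
    // New form: the dims feed the output_shape directly.
    %expanded = tensor.expand_shape %src [[0, 1], [2, 3], [4]]
        output_shape [%dim0, 1, %dim1, 1, %dim2] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>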

mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir

Lines changed: 5 additions & 5 deletions
@@ -76,13 +76,13 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
 // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
 // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
 // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
 // CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
 // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
 // CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat
 // CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
-// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
 // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
 // CHECK-NEXT: return %[[RES]]
    %1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
@@ -134,7 +134,7 @@ func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32
 // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
 // CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec
 // CHECK-SAME: ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
-// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
+// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
 // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
 // CHECK-NEXT: return %[[RES]]
    %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
@@ -171,12 +171,12 @@ func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32
 // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
 // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
 // CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
 // CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
 // CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat
 // CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
-// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
+// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
 // CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
 // CHECK-NEXT: return %[[RES]]
    %0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
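
These updates reflect that the dim feeding the expand_shape's output_shape is now taken from the collapsed, rank-reduced init at index 0, rather than being folded back onto the original init at index 1. A sketch for @singleton_batch_vecmat, with types and value names filled in as assumptions from the CHECK captures:

    %c0 = arith.constant 0 : index
    %collapsed_init = tensor.collapse_shape %init [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
    %vecmat = linalg.vecmat ins(%collapsed_lhs, %collapsed_rhs : tensor<?xf32>, tensor<?x?xf32>)
                            outs(%collapsed_init : tensor<?xf32>) -> tensor<?xf32>
    // The dynamic extent is read off the collapsed init; previously the old
    // fold rewrote this into tensor.dim %init, %c1 on the original tensor<1x?xf32>.
    %dim1 = tensor.dim %collapsed_init, %c0 : tensor<?xf32>
    %res = tensor.expand_shape %vecmat [[0, 1]] output_shape [1, %dim1] : tensor<?xf32> into tensor<1x?xf32>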

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 10 additions & 18 deletions
@@ -1105,15 +1105,13 @@ func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -
   %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
   return %expanded : tensor<?x384xf32>
 }
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
 // CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
-// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
 // CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
+// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
-// CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[COLLAPSE]], %[[CONSTANT0]] : tensor<?xf32>
+// CHECK: %[[DIVUI:.+]] = arith.divui %[[DIM]], %[[CONSTANT384]] : index
 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
 // CHECK: return %[[RESULT]]
 
@@ -2137,13 +2135,12 @@ func.func @empty_tensor_canonicalize(%i : index) {
 
 // -----
 
-// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
 // CHECK-LABEL: func @dim_of_expand_shape(
 // CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
-// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
-// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
-// CHECK: return %[[apply]]
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[expanded:.*]] = tensor.expand_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4, 5]] output_shape [%arg1, 1, %arg2, 5, 1, 8] : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
+// CHECK: %[[dim:.*]] = tensor.dim %[[expanded]], %[[c2]] : tensor<?x1x?x5x1x8xf32>
+// CHECK: return %[[dim]]
 func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
   %c2 = arith.constant 2 : index
   %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
@@ -2154,17 +2151,12 @@ func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) ->
 
 // -----
 
-// CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
 // CHECK-LABEL: func @dim_of_collapse_shape(
 // CHECK-SAME: %[[t:.*]]: tensor<?x?x?x7x?xf32>
 // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
-// CHECK-DAG: %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
-// CHECK-DAG: %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
-// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
-// CHECK: return %[[apply]]
+// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4]] : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+// CHECK-DAG: %[[dim:.*]] = tensor.dim %[[collapsed]], %[[c1]]
+// CHECK: return %[[dim]]
 func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
   %c1 = arith.constant 1 : index
   %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
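
If the resolved form is still wanted, the commit message points at the dedicated passes rather than canonicalization. A hypothetical RUN line in the style of these tests (pass names taken from the commit message):

    // RUN: mlir-opt %s -resolve-shaped-type-result-dims | FileCheck %s
    // (or -resolve-ranked-shaped-type-result-dims for the ranked-only variant)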
