
Commit 2e7aa93 (1 parent: 0fa04b6)

[BugFix]: Move DimOp canonicalization from memref to tensor.
File tree: 6 files changed (+106, -78 lines)
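
At its core, the canonicalization this commit moves folds a `tensor.dim` of a `tensor.reshape` into a direct read of the reshape's shape operand. A minimal sketch of the rewrite (the value names `%t`, `%shape`, and `%idx` are illustrative, not taken from the patch):

```mlir
// Before: ask the reshaped tensor for one of its dimensions.
%reshaped = tensor.reshape %t(%shape)
    : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%d = tensor.dim %reshaped, %idx : tensor<*xf32>

// After: read the dimension straight out of the shape operand.
%d = tensor.extract %shape[%idx] : tensor<?xindex>
```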

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Lines changed: 0 additions & 1 deletion

@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     Speculation::Speculatability getSpeculatability();
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Lines changed: 0 additions & 1 deletion

@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
   MLIRContext *context = enclosingOp->getContext();
   RewritePatternSet patterns(context);
   patterns.add<LinalgRewritePattern<LoopType>>(context);
-  memref::DimOp::getCanonicalizationPatterns(patterns, context);
   tensor::DimOp::getCanonicalizationPatterns(patterns, context);
   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldAffineOp>(context);

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Lines changed: 0 additions & 33 deletions

@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dim,
-                                PatternRewriter &rewriter) const override {
-    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
-    if (!reshape)
-      return failure();
-
-    // Place the load directly after the reshape to ensure that the shape memref
-    // was not mutated.
-    rewriter.setInsertionPointAfter(reshape);
-    Location loc = dim.getLoc();
-    Value load =
-        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
-    if (load.getType() != dim.getType())
-      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
-    rewriter.replaceOp(dim, load);
-    return success();
-  }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<DimOfMemRefReshape>(context);
-}
-
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
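
The deleted memref pattern had to pin the generated `memref.load` directly after the reshape because a memref's shape operand can be mutated: a store between the reshape and the dim must not be visible to the folded load. A sketch of the hazard this guarded against, adapted from the test removed further down (value names are illustrative):

```mlir
%r = memref.reshape %m(%shp) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
// The folded load had to be inserted here, before the store below
// overwrites the shape value that %d is supposed to observe.
memref.store %c3, %shp[%c3] : memref<?xindex>
%d = memref.dim %r, %c3 : memref<*xf32>
```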

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Lines changed: 26 additions & 1 deletion

@@ -824,11 +824,36 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to an extract into the reshape's
+/// shape operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable we don't need to worry about where to place
+    // the load call.
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value load =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (load.getType() != dim.getType())
+      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+    rewriter.replaceOp(dim, load);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
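
Compared with the memref version above, the new tensor pattern inserts at the `tensor.dim` rather than at the reshape: tensors are immutable, so the shape operand cannot change between the two points, and inserting at the dim guarantees the index operand still dominates the new `tensor.extract`. The `@dim_of_reshape_undominated` test below exercises this; roughly, the rewrite there is:

```mlir
// Before: the index %0 only becomes available after the reshape.
%reshape = tensor.reshape %arg0(%arg1)
    : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%0 = arith.muli %arg2, %c4 : index
%dim = tensor.dim %reshape, %0 : tensor<*xf32>

// After: the extract is created where the dim was, so %0 dominates it
// and the reshape becomes dead.
%0 = arith.muli %arg2, %c4 : index
%dim = tensor.extract %arg1[%0] : tensor<?xindex>
```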

mlir/test/Dialect/MemRef/canonicalize.mlir
Lines changed: 0 additions & 42 deletions

@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
 
 // -----
 
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-//  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:   %[[SHP:[0-9a-z]+]]: memref<?xindex>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   memref.store
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
-  // Update the shape to test that the load ends up in the right place.
-  memref.store %c3, %arg1[%c3] : memref<?xindex>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-//  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:   %[[SHP:[0-9a-z]+]]: memref<?xi32>
-//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   %[[CAST:.*]] = arith.index_cast %[[DIM]]
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-    -> index {
-  %c3 = arith.constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>

mlir/test/Dialect/Tensor/canonicalize.mlir
Lines changed: 80 additions & 0 deletions

@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK:         return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+//  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+//  CHECK-SAME:   %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+//   CHECK-NOT:   tensor.store
+//   CHECK-NOT:   tensor.dim
+//   CHECK-NOT:   tensor.reshape
+//       CHECK:   return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Update the shape to test that the extract ends up in the right place.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+//       CHECK:   tensor.extract
+//  CHECK-NEXT:   %[[CAST:.*]] = arith.index_cast
+//   CHECK-NOT:   tensor.dim
+//   CHECK-NOT:   tensor.reshape
+//       CHECK:   return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx] even when the dim is inside a loop
+// CHECK-LABEL: func @dim_of_reshape_for(
+//       CHECK:   scf.for
+//  CHECK-NEXT:   tensor.extract
+//   CHECK-NOT:   tensor.dim
+//   CHECK-NOT:   tensor.reshape
+func.func @dim_of_reshape_for(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+    %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+    %3 = arith.muli %arg3, %2 : index
+    scf.yield %3 : index
+  }
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx] even when the index is computed after the reshape
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+//       CHECK:   arith.muli
+//  CHECK-NEXT:   tensor.extract
+//   CHECK-NOT:   tensor.dim
+//   CHECK-NOT:   tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+  %c4 = arith.constant 4 : index
+  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %0 = arith.muli %arg2, %c4 : index
+  %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+  return %dim : index
+}
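
For reference, the i32 shape case goes through the pattern's `arith.index_cast` branch; `@dim_of_reshape_i32` should canonicalize to roughly the following (SSA names are illustrative):

```mlir
func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>) -> index {
  %c3 = arith.constant 3 : index
  // The shape element is i32, so the extracted value is cast to index.
  %extracted = tensor.extract %arg1[%c3] : tensor<?xi32>
  %0 = arith.index_cast %extracted : i32 to index
  return %0 : index
}
```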
