Skip to content

Commit 50a2bb9

Browse files
[mlir][tensor] Fold rank-reducing extract_slice with inverse expand_shape
Differential Revision: https://reviews.llvm.org/D139220
1 parent 23ad3da commit 50a2bb9

File tree

5 files changed

+95
-1
lines changed

5 files changed

+95
-1
lines changed

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ FailureOr<Value> replaceExtractSliceWithTiledProducer(
3636
void populateMergeConsecutiveInsertExtractSlicePatterns(
3737
RewritePatternSet &patterns);
3838

39+
/// Populates `patterns` with patterns that fold `tensor.expand_shape` and
40+
/// `tensor.collapse_shape` into other ops.
41+
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
42+
3943
} // namespace tensor
4044
} // namespace mlir
4145

mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
33
Bufferize.cpp
44
ExtractSliceFromReshapeUtils.cpp
55
MergeConsecutiveInsertExtractSlicePatterns.cpp
6+
ReshapePatterns.cpp
67
SplitPaddingPatterns.cpp
78
SwapExtractSliceWithProducerPatterns.cpp
89

@@ -26,4 +27,4 @@ add_mlir_dialect_library(MLIRTensorTransforms
2627
MLIRTensorDialect
2728
MLIRTilingInterface
2829
MLIRTransforms
29-
)
30+
)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
10+
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11+
#include "mlir/IR/PatternMatch.h"
12+
#include "llvm/Support/Debug.h"
13+
14+
#define DEBUG_TYPE "mlir-tensor-split-padding"
15+
16+
using namespace mlir;
17+
using namespace mlir::tensor;
18+
19+
namespace {
20+
/// Fold expand_shape(extract_slice) ops that cancel itself out.
21+
struct FoldExpandOfRankReducingExtract
22+
: public OpRewritePattern<ExpandShapeOp> {
23+
using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
24+
25+
LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
26+
PatternRewriter &rewriter) const override {
27+
RankedTensorType resultType = expandShapeOp.getResultType();
28+
auto extractSliceOp =
29+
expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
30+
if (!extractSliceOp)
31+
return failure();
32+
RankedTensorType srcType = extractSliceOp.getSourceType();
33+
34+
// Only cases where the ExpandShapeOp can be folded away entirely are
35+
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
36+
// has no rank-reduction anymore are supported at the moment.
37+
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
38+
srcType, extractSliceOp.getStaticOffsets(),
39+
extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
40+
if (nonReducingExtractType != resultType)
41+
return failure();
42+
43+
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
44+
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
45+
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
46+
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
47+
expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
48+
mixedStrides);
49+
return success();
50+
}
51+
};
52+
} // namespace
53+
54+
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
55+
RewritePatternSet &patterns) {
56+
patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
57+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-reassociative-reshape-folding %s | FileCheck %s
2+
3+
// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
4+
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
5+
// CHECK-DAG: %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
6+
// CHECK-DAG: %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
7+
// CHECK: return %[[extract1]], %[[extract2]]
8+
func.func @expand_shape_of_rank_reducing_extract(
9+
%t: tensor<?x?x?x?xf32>, %idx: index)
10+
-> (tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>)
11+
{
12+
%0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
13+
: tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
14+
%1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
15+
: tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
16+
%2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
17+
: tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
18+
return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
19+
}

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ struct TestTensorTransforms
6565
"with loop nest"),
6666
llvm::cl::init(false)};
6767

68+
Option<bool> testReassociativeReshapeFolding{
69+
*this, "test-reassociative-reshape-folding",
70+
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
71+
llvm::cl::init(false)};
72+
6873
Option<bool> useForeach{
6974
*this, "use-foreach",
7075
llvm::cl::desc(
@@ -74,6 +79,12 @@ struct TestTensorTransforms
7479
};
7580
} // namespace
7681

82+
static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
83+
RewritePatternSet patterns(rootOp->getContext());
84+
tensor::populateReassociativeReshapeFoldingPatterns(patterns);
85+
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
86+
}
87+
7788
static void applySplitPaddingPatterns(Operation *rootOp) {
7889
RewritePatternSet patterns(rootOp->getContext());
7990
tensor::populateSplitPaddingPatterns(patterns);
@@ -251,6 +262,8 @@ void TestTensorTransforms::runOnOperation() {
251262
applyFoldConstantExtractSlicePatterns(rootOp);
252263
if (testFoldConsecutiveInsertExtractSlice)
253264
applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
265+
if (testReassociativeReshapeFolding)
266+
applyReassociativeReshapeFoldingPatterns(rootOp);
254267
if (testRewriteExtractSliceWithTiledCollapseShape) {
255268
if (failed(
256269
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))

0 commit comments

Comments
 (0)