
Commit 0804a88

agostini01 authored and matthias-springer committed
[mlir][linalg] Transform PadTensorOp into InitOp, FillOp, GenericOp
Introduces a test pass that rewrites PadTensorOps with static shapes as a sequence of:

```
linalg.init_tensor  // to create output
linalg.fill         // to initialize with padding value
linalg.generic      // to copy the original contents to the padded tensor
```

The pass can be triggered with:

- `--test-linalg-transform-patterns="test-transform-pad-tensor"`

Differential Revision: https://reviews.llvm.org/D102804
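For orientation, a minimal before/after sketch of the rewrite on the `pad_tensor_detailed` case from the new test below (shapes and indexing maps are taken from its FileCheck patterns; the exact `linalg.fill` assembly form is approximate for this revision):

```mlir
// Before: pad a 1x28x28x1 tensor by 2 on each side of the two spatial dims.
%0 = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
  linalg.yield %cst : f32
} : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>

// After: create the padded tensor, fill it with the pad value, then copy
// the source into its interior with a parallel linalg.generic.
%init = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
%fill = linalg.fill(%init, %cst) : tensor<1x32x32x1xf32>, f32 -> tensor<1x32x32x1xf32>
%0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d1 + 2, d2 + 2, d3)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%arg0 : tensor<1x28x28x1xf32>) outs(%fill : tensor<1x32x32x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<1x32x32x1xf32>
```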
1 parent 3d2c906 · commit 0804a88

4 files changed: +149 −0 lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 9 additions & 0 deletions
```diff
@@ -871,6 +871,15 @@ void populateLinalgDistributeTiledLoopPattern(
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
 
+/// PadTensorOp is not canonicalized away yet, so we provide a transformation to
+/// `linalg.generic`.
+struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
 /// it needs a specific pattern to vectorize.
 struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
```

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 65 additions & 0 deletions
```diff
@@ -637,3 +637,68 @@ LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
 
   return failure();
 }
+
+static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
+  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
+}
+
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
+/// with pad_val) and GenericOp (to copy contents).
+LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
+    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
+
+  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
+  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+
+  // Bail on non-static shapes.
+  if (!inputShapedType.hasStaticShape())
+    return failure();
+  if (!resultShapedType.hasStaticShape())
+    return failure();
+
+  // Only support padding with a constant for now, i.e. either:
+  //   1. A BBarg from a different block.
+  //   2. A value defined outside of the current block.
+  Block &block = padOp.region().front();
+  auto yieldOp = cast<YieldOp>(block.getTerminator());
+  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
+  Value padValue = yieldOp.values().front();
+  Operation *definingOp = padValue.getDefiningOp();
+  if (definingOp && definingOp->getBlock() == &block)
+    return failure();
+  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
+    return failure();
+
+  // Create tensor with the padded shape
+  Location loc = padOp.getLoc();
+  SmallVector<Value> indices(resultShapedType.getRank(),
+                             rewriter.create<ConstantIndexOp>(loc, 0));
+  Value initTensor = rewriter.create<InitTensorOp>(
+      loc, resultShapedType.getShape(), resultShapedType.getElementType());
+
+  // Initialize tensor with the pad value
+  Value tmpTensor =
+      rewriter.create<linalg::FillOp>(loc, initTensor, padValue).result();
+
+  // Copy original contents into new tensor
+  // Uses linalg.generic, but could be done with std.subtensor_insert
+  SmallVector<AffineExpr, 4> outputExprs;
+  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
+    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
+                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
+  }
+
+  SmallVector<AffineMap, 2> transferMaps = {
+      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
+      AffineMap::get(resultShapedType.getRank(),
+                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};
+
+  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
+      getNParallelLoopsAttrs(resultShapedType.getRank()),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  return success();
+}
```
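To make the indexing maps concrete: the input map is the identity, and the output map shifts each dimension by the corresponding static low pad, so source element `(d0, ..., dn)` is written to `(d0 + low[0], ..., dn + low[n])` of the padded result. For low padding `[1, 1, 1, 2]` (the first case in the new test below), the two transfer maps come out as:

```mlir
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>                  // input: identity
affine_map<(d0, d1, d2, d3) -> (d0 + 1, d1 + 1, d2 + 1, d3 + 2)>  // output: offset by low pad
```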
New test file — Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@

```mlir
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-transform-pad-tensor" %s | FileCheck --check-prefix=CHECK %s

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 + 1, d1 + 1, d2 + 1, d3 + 2)>
// CHECK-LABEL: func @pad_tensor_with_memrefs
func @pad_tensor_with_memrefs(%arg0: memref<1x28x28x1xf32>) -> memref<2x31x31x3xf32> {
  %cst = constant 0.000000e+00 : f32
  %0 = memref.tensor_load %arg0 : memref<1x28x28x1xf32>
  %1 = linalg.pad_tensor %0 low[1, 1, 1, 2] high[0, 2, 2, 0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
    linalg.yield %cst : f32
  } : tensor<1x28x28x1xf32> to tensor<2x31x31x3xf32>
  %2 = memref.buffer_cast %1 : memref<2x31x31x3xf32>
  return %2 : memref<2x31x31x3xf32>
}

// CHECK: linalg.fill
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]

// -----

// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 + 1, d1 + 2, d2 + 2)>
// CHECK-LABEL: func @pad_tensor_no_memrefs
func @pad_tensor_no_memrefs(%arg0: tensor<1x28x28xf32>) -> tensor<2x32x32xf32> {
  %cst = constant 0.000000e+00 : f32
  %0 = linalg.pad_tensor %arg0 low[1, 2, 2] high[0, 2, 2] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index):  // no predecessors
    linalg.yield %cst : f32
  } : tensor<1x28x28xf32> to tensor<2x32x32xf32>
  return %0 : tensor<2x32x32xf32>
}

// CHECK: linalg.fill
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]]

// -----

// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 + 2, d2 + 2, d3)>
// CHECK-LABEL: func @pad_tensor_detailed
func @pad_tensor_detailed(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
  %cst = constant 0.000000e+00 : f32
  %0 = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
    linalg.yield %cst : f32
  } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
  return %0 : tensor<1x32x32x1xf32>
}

// CHECK: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32>
// CHECK: %[[CTE:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[TMP:.+]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
// CHECK: %[[R1c:.+]] = linalg.fill
// CHECK: %[[R2c:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP4]], #[[$MAP5]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK: ins(%arg0 : tensor<1x28x28x1xf32>) outs(%1 : tensor<1x32x32x1xf32>)
// CHECK: ^bb0(%[[VAL:.+]]: f32, %arg2: f32)
// CHECK: linalg.yield %[[VAL]] : f32
// CHECK: return %[[R2c:.+]]
```

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -87,6 +87,10 @@ struct TestLinalgTransforms
   Option<int> testHoistPadding{*this, "test-hoist-padding",
                                llvm::cl::desc("Test hoist padding"),
                                llvm::cl::init(0)};
+  Option<bool> testTransformPadTensor{
+      *this, "test-transform-pad-tensor",
+      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
+      llvm::cl::init(false)};
   ListOption<int64_t> tileSizesForPadding{
       *this, "tile-sizes-for-padding",
       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
@@ -508,6 +512,12 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
   RewritePatternSet foldPattern(funcOp.getContext());
   foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
@@ -583,6 +593,8 @@ void TestLinalgTransforms::runOnFunction() {
     return applyVectorTransferForwardingPatterns(getFunction());
   if (testGenericToVectorPattern)
     return applyLinalgToVectorPatterns(getFunction());
+  if (testTransformPadTensor)
+    return applyPadTensorToGenericPatterns(getFunction());
   if (testAffineMinSCFCanonicalizationPatterns)
     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
   if (testTileAndPadPattern)
```
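With the option registered, the rewrite can be exercised standalone, as the new test's RUN line does: `mlir-opt --test-linalg-transform-patterns="test-transform-pad-tensor" input.mlir`.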
