Skip to content

Commit 7e5b10b

Browse files
[mlir][Linalg] Add support for tiling tensor.pad to scf.forall
Also, properly propagate the `nofold` attribute from the original tensor.pad to the new tensor.pad created by tensor::bubbleUpPadSlice. Differential Revision: https://reviews.llvm.org/D148114
1 parent a6d9730 commit 7e5b10b

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,6 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
354354
return getValueOrCreateConstantIndexOp(b, loc, ofr);
355355
}));
356356

357-
Operation *tiledOp = nullptr;
358-
359357
// 1. Create the ForallOp. We don't use the lambda body-builder
360358
// version because we require the use of RewriterBase in the body, so we
361359
// manually move the insertion point to the body below.
@@ -371,6 +369,8 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
371369
// 3. Clone the tileable op and update its destination operands to use the
372370
// output bbArgs of the ForallOp.
373371
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
372+
Operation *tiledOp = nullptr;
373+
SmallVector<Value> tiledValues;
374374
{
375375
// 3.a. RAII guard, inserting within forallOp, before terminator.
376376
OpBuilder::InsertionGuard g(b);
@@ -395,13 +395,12 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
395395
assert(tilingResult->tiledOps.size() == 1 &&
396396
"expected a single produced tiled op");
397397
tiledOp = tilingResult->tiledOps.front();
398+
tiledValues = tilingResult->tiledValues;
398399
}
399400

400401
// 5. Parallel insert back into the result tensor.
401-
auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
402-
assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
403402
for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
404-
tilingInterfaceOp->getResults(), destBbArgs)) {
403+
tiledValues, destBbArgs)) {
405404
// 5.a. Partial subset information is inserted just before the terminator.
406405
OpBuilder::InsertionGuard g(b);
407406
b.setInsertionPoint(forallOp.getTerminator());

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,8 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
617617
// Create pad(extract_slice(x)).
618618
Value newSliceOp = b.create<tensor::ExtractSliceOp>(
619619
loc, padOp.getSource(), newOffsets, newLengths, newStrides);
620-
auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs);
620+
auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs,
621+
/*nofold=*/padOp.getNofold());
621622

622623
// Copy region to new PadOp.
623624
IRMapping bvm;

mlir/test/Dialect/Linalg/transform-op-tile.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,28 @@ func.func @tile_linalg_matmul(
123123
-> tensor<128x128xf32>
124124
return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
125125
}
126+
127+
// -----
128+
129+
// CHECK-LABEL: tile_tensor_pad
130+
func.func @tile_tensor_pad(
131+
%arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)
132+
-> tensor<20x40xf32>
133+
{
134+
// CHECK: scf.forall
135+
// CHECK: scf.if
136+
// CHECK: tensor.generate
137+
// CHECK: else
138+
// CHECK: tensor.pad {{.*}} nofold
139+
%0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
140+
^bb0(%arg9: index, %arg10: index):
141+
tensor.yield %cst : f32
142+
} : tensor<?x?xf32> to tensor<20x40xf32>
143+
return %0 : tensor<20x40xf32>
144+
}
145+
146+
transform.sequence failures(propagate) {
147+
^bb0(%arg1: !pdl.operation):
148+
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
149+
transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
150+
}

0 commit comments

Comments
 (0)