
Commit 2f946ea
[mlir] Change the pattern for TiledLoopOp bufferization.

This version does not affect the patterns for Extract/InsertSliceOp and LinalgOps.

Differential Revision: https://reviews.llvm.org/D107858
1 parent b821086 commit 2f946ea
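
At a glance: instead of patching `tiled_loop` special cases into the Extract/InsertSliceOp and LinalgOp patterns, the loop converter now rewrites its own body. A minimal before/after sketch, condensed from the `@tiled_dot` test below (buffer names like `%A_buf` are illustrative, not pass output):

```mlir
// Before bufferization: the loop carries tensors and yields one.
%dot = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
    ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>)
    outs (%C_ = %C: tensor<f32>) {
  // ... slices and linalg.dot on tensors ...
  linalg.yield %dot_sub : tensor<f32>
}

// After bufferization: the loop carries memrefs, writes through its
// output argument, and yields nothing.
linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
    ins (%A_ = %A_buf: memref<10xf32>, %B_ = %B_buf: memref<10xf32>)
    outs (%C_ = %C_buf: memref<f32>) {
  // ... subviews, linalg.dot on memrefs, copies ...
  linalg.yield
}
```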

2 files changed: +154 −70 lines
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp (93 additions, 62 deletions)

```diff
@@ -213,10 +213,8 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
     Location loc = op.getLoc();
     SmallVector<Value, 2> newOutputBuffers;
 
-    if (op->getParentOfType<TiledLoopOp>()) {
-      newOutputBuffers = adaptor.outputs();
-    } else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
-                                                newOutputBuffers, rewriter))) {
+    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
+                                         newOutputBuffers, rewriter))) {
       return op.emitOpError()
              << "Failed to allocate buffers for tensor results.";
     }
```
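
With the `tiled_loop` special case gone, `BufferizeAnyLinalgOp` treats a Linalg op inside the loop like any other: `allocateBuffersForResults` gives each tensor result a fresh buffer, with copies in and out. Roughly what the updated `@tiled_dot` CHECK lines below expect (a hand-assembled sketch; `#strided` stands for the inferred subview layout):

```mlir
%sv_A = memref.subview %A[%i] [2] [1] : memref<10xf32> to memref<2xf32, #strided>
%sv_B = memref.subview %B[%i] [2] [1] : memref<10xf32> to memref<2xf32, #strided>
%tmp  = memref.alloc() : memref<f32>              // fresh buffer for the result
linalg.copy(%C, %tmp) : memref<f32>, memref<f32>
linalg.dot ins(%sv_A, %sv_B : memref<2xf32, #strided>, memref<2xf32, #strided>)
           outs(%tmp : memref<f32>)
linalg.copy(%tmp, %C) : memref<f32>, memref<f32>  // write back into the output
```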
```diff
@@ -233,14 +231,6 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
   }
 };
 
-bool IsBlockArgOfTiledLoop(Value tensor) {
-  if (auto tensorLoad = tensor.getDefiningOp<memref::TensorLoadOp>())
-    if (auto blockArgument = tensorLoad.memref().dyn_cast<BlockArgument>())
-      if (isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp()))
-        return true;
-  return false;
-}
-
 /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
 /// alloc + copy pattern.
 /// ```
```
```diff
@@ -263,15 +253,6 @@ class ExtractSliceOpConverter
     Value sourceMemref = adaptor.source();
     assert(sourceMemref.getType().isa<MemRefType>());
 
-    // Block arguments of the tiled_loop can be bufferized inplace.
-    if (IsBlockArgOfTiledLoop(op.source())) {
-      Value subView = rewriter.create<memref::SubViewOp>(
-          op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
-          op.getMixedStrides());
-      rewriter.replaceOp(op, subView);
-      return success();
-    }
-
     MemRefType subviewMemRefType =
         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
     // op.sizes() capture exactly the dynamic alloc operands matching the
```
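
For reference, the lowering that now applies unconditionally is the alloc + copy pattern named in the converter's doc comment; spelled out as IR, it is roughly (sketch, illustrative types):

```mlir
// extract_slice %t[offsets][sizes][strides] -> %st becomes:
%subview = memref.subview %t[%off] [%size] [1]
    : memref<10xf32> to memref<?xf32, #strided>
%alloc = memref.alloc(%size) : memref<?xf32>
linalg.copy(%subview, %alloc) : memref<?xf32, #strided>, memref<?xf32>
// %alloc replaces %st.
```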
```diff
@@ -315,12 +296,7 @@ class InsertSliceOpConverter
     // For now, be conservative and copy the converted input memref.
     // In general, the converted input memref here could be aliased or could
     // point into constant memory, so mutating it would lead to miscompilations.
-    // Block arguments of the tiled_loop can be bufferized inplace.
-    Value destMemRef;
-    if (IsBlockArgOfTiledLoop(op.dest()))
-      destMemRef = adaptor.dest();
-    else
-      destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
+    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
     assert(destMemRef.getType().isa<MemRefType>());
 
     // Take a subview to copy the small memref.
```
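
`InsertSliceOpConverter` likewise keeps only the conservative path: clone the destination, then write the source into a subview of the clone (sketch, illustrative types):

```mlir
// insert_slice %src into %dst[offsets][sizes][strides] becomes:
%clone = memref.alloc() : memref<10xf32>            // cloneMemref(dst)
linalg.copy(%dst, %clone) : memref<10xf32>, memref<10xf32>
%sv = memref.subview %clone[%off] [%size] [1]
    : memref<10xf32> to memref<?xf32, #strided>
linalg.copy(%src, %sv) : memref<?xf32>, memref<?xf32, #strided>
// %clone replaces the insert_slice result.
```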
```diff
@@ -334,60 +310,115 @@ class InsertSliceOpConverter
   }
 };
 
+bool isBlockArgOfTiledLoop(Value tensor) {
+  if (auto blockArgument = tensor.dyn_cast<BlockArgument>())
+    return isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp());
+  return false;
+}
+
+SmallVector<Value, 3> convertOperands(ValueRange operands,
+                                      BlockAndValueMapping &bvm) {
+  SmallVector<Value, 3> newOperands;
+  newOperands.reserve(operands.size());
+  for (auto operand : operands)
+    newOperands.push_back(bvm.lookupOrDefault(operand));
+  return newOperands;
+}
+
 class TiledLoopOpConverter : public OpConversionPattern<TiledLoopOp> {
 public:
   using OpConversionPattern<TiledLoopOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef<Value> operands,
+  matchAndRewrite(TiledLoopOp loop, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary());
-    Location loc = tiledLoop.getLoc();
-    if (tiledLoop.getNumResults() == 0)
+    TiledLoopOp::Adaptor adaptor(operands, loop->getAttrDictionary());
+    if (loop.getNumResults() == 0)
       return failure();
-    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+
+    Location loc = loop.getLoc();
+    auto newLoop = rewriter.create<TiledLoopOp>(
         loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(),
        adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(),
         adaptor.distribution_types());
+
     // Clone the region.
     BlockAndValueMapping bvm;
-    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
+    bvm.map(loop.getRegionInputArgs(), newLoop.getRegionInputArgs());
+    bvm.map(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs());
 
     OpBuilder innerBuilder =
-        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
-
-    // Remap input block arguments.
-    SmallVector<Value, 2> inputs;
-    for (auto en : llvm::zip(newTiledLoop.getRegionInputArgs(),
-                             tiledLoop.getRegionInputArgs())) {
-      auto &newInputArg = std::get<0>(en);
-      if (!newInputArg.getType().isa<ShapedType>()) {
-        inputs.push_back(std::get<0>(en));
-        continue;
+        OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
+
+    for (auto &op : loop.getBody()->getOperations()) {
+      Location loc = op.getLoc();
+      if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
+        if (isBlockArgOfTiledLoop(extractSlice.source())) {
+          auto newOperands = convertOperands(extractSlice.getOperands(), bvm);
+          auto srcMemRefType =
+              bvm.lookup(extractSlice.source()).getType().cast<MemRefType>();
+          auto dstMemRefType =
+              memref::SubViewOp::inferResultType(
+                  srcMemRefType,
+                  extractFromI64ArrayAttr(extractSlice.static_offsets()),
+                  extractFromI64ArrayAttr(extractSlice.static_sizes()),
+                  extractFromI64ArrayAttr(extractSlice.static_strides()))
+                  .cast<MemRefType>();
+
+          Value subView = innerBuilder.create<memref::SubViewOp>(
+              loc, TypeRange{dstMemRefType}, newOperands,
+              extractSlice->getAttrs());
+          bvm.map(extractSlice.getResult(), subView);
+          continue;
+        }
       }
-      inputs.push_back(
-          innerBuilder.create<memref::TensorLoadOp>(loc, newInputArg));
-    }
-    bvm.map(tiledLoop.getRegionInputArgs(), inputs);
-
-    // Remap output block arguments.
-    SmallVector<Value, 2> outputs;
-    for (auto en : llvm::zip(newTiledLoop.getRegionOutputArgs(),
-                             tiledLoop.getRegionOutputArgs())) {
-      auto &newOutputArg = std::get<0>(en);
-      if (!newOutputArg.getType().isa<ShapedType>()) {
-        outputs.push_back(std::get<0>(en));
+      if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
+        if (isBlockArgOfTiledLoop(insertSlice.dest())) {
+          continue;
+        }
+      }
+      if (auto yield = dyn_cast<linalg::YieldOp>(op)) {
+        for (OpOperand &operand : yield->getOpOperands()) {
+          if (auto insert =
+                  operand.get().getDefiningOp<tensor::InsertSliceOp>()) {
+
+            auto dstMemRefType = memref::SubViewOp::inferResultType(
+                getTypeConverter()
+                    ->convertType(insert.source().getType())
+                    .cast<MemRefType>(),
+                extractFromI64ArrayAttr(insert.static_offsets()),
+                extractFromI64ArrayAttr(insert.static_sizes()),
+                extractFromI64ArrayAttr(insert.static_strides()));
+
+            Value subView = innerBuilder.create<memref::SubViewOp>(
+                loc, dstMemRefType, bvm.lookup(insert.dest()),
+                convertOperands(insert.offsets(), bvm),
+                convertOperands(insert.sizes(), bvm),
+                convertOperands(insert.strides(), bvm), insert.static_offsets(),
+                insert.static_sizes(), insert.static_strides());
+
+            Value cast = innerBuilder.create<memref::BufferCastOp>(
+                loc,
+                getTypeConverter()
+                    ->convertType(insert.source().getType())
+                    .cast<MemRefType>(),
+                bvm.lookup(insert.source()));
+
+            innerBuilder.create<linalg::CopyOp>(loc, cast, subView);
+            continue;
+          }
+          auto dst = newLoop.getRegionOutputArgs()[operand.getOperandNumber()];
+          Value cast = innerBuilder.create<memref::BufferCastOp>(
+              loc, dst.getType(), bvm.lookup(operand.get()));
+          innerBuilder.create<linalg::CopyOp>(loc, cast, dst);
+        }
         continue;
       }
-      outputs.push_back(
-          innerBuilder.create<memref::TensorLoadOp>(loc, newOutputArg));
-    }
-    bvm.map(tiledLoop.getRegionOutputArgs(), outputs);
-
-    for (auto &op : tiledLoop.getBody()->without_terminator())
       innerBuilder.clone(op, bvm);
+    }
     innerBuilder.create<linalg::YieldOp>(loc);
-    rewriter.replaceOp(tiledLoop, newTiledLoop.outputs());
+    rewriter.replaceOp(loop, newLoop.outputs());
     return success();
   }
 };
```
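
Walking the body, the new converter special-cases three ops: a `tensor.extract_slice` whose source is a loop block argument becomes an in-place `memref.subview`; a `tensor.insert_slice` whose destination is a loop block argument is dropped; and each `linalg.yield` operand turns into a copy into the output buffer (through the yielded `insert_slice` when there is one). Everything else is cloned via `bvm`. For the `@tiled_add` body added in the test below, that yields roughly (sketch; names and layouts illustrative):

```mlir
// %C_sub = tensor.extract_slice %C_[%i] [%c2] [1] ... becomes:
%sv_C = memref.subview %C_[%i] [%c2] [1]
    : memref<10xf32> to memref<?xf32, #strided>

// The yielded insert_slice becomes buffer_cast + copy into that subview:
%cast = memref.buffer_cast %sum_sub : memref<?xf32>
linalg.copy(%cast, %sv_C) : memref<?xf32>, memref<?xf32, #strided>
linalg.yield   // no operands remain after bufferization
```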

mlir/test/Dialect/Linalg/bufferize.mlir (61 additions, 8 deletions)
```diff
@@ -339,13 +339,66 @@ func @tiled_dot(%A: tensor<10xf32>, %B: tensor<10xf32>,
     linalg.yield %dot_sub : tensor<f32>
   }
 // CHECK: linalg.tiled_loop
-// CHECK-SAME: ins (%[[A:.*]] = %{{.*}}: memref<10xf32>, %[[B:.*]] = %{{.*}}: memref<10xf32>)
-// CHECK-SAME: outs (%[[C:.*]] = %{{.*}}: memref<f32>)
-// CHECK-NOT: alloc
-// CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
-// CHECK: %[[SV_B:.*]] = memref.subview %[[B]]
-// CHECK: linalg.dot ins(%[[SV_A]], %[[SV_B]]
-// CHECK-SAME: outs(%[[C]] : memref<f32>)
-// CHECK: linalg.yield
+// CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
+// CHECK-SAME: %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
+// CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<f32>)
+
+// CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]]
+// CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]]
+// CHECK-NEXT: %[[TMP:.*]] = memref.alloc
+// CHECK-NEXT: linalg.copy(%[[C]], %[[TMP]])
+// CHECK-NEXT: linalg.dot ins(%[[SV_A]], %[[SV_B]]
+// CHECK-SAME: outs(%[[TMP]] : memref<f32>)
+// CHECK-NEXT: linalg.copy(%[[TMP]], %[[C]])
+// CHECK-NEXT: linalg.yield
   return %dot : tensor<f32>
 }
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0)>
+
+func @tiled_add(%A: tensor<10xf32>, %B: tensor<10xf32>,
+                %C: tensor<10xf32>) -> tensor<10xf32> {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c10 = constant 10 : index
+
+  %sum = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
+      ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>)
+      outs (%C_ = %C: tensor<10xf32>) {
+    %A_sub = tensor.extract_slice %A_[%i] [%c2] [1]
+      : tensor<10xf32> to tensor<?xf32>
+    %B_sub = tensor.extract_slice %B_[%i] [%c2] [1]
+      : tensor<10xf32> to tensor<?xf32>
+    %C_sub = tensor.extract_slice %C_[%i] [%c2] [1]
+      : tensor<10xf32> to tensor<?xf32>
+    %sum_sub = linalg.generic {
+      indexing_maps = [#map0, #map0, #map0],
+      iterator_types = ["parallel"]
+    } ins(%A_sub, %B_sub : tensor<?xf32>, tensor<?xf32>)
+      outs(%C_sub : tensor<?xf32>) {
+      ^bb0(%a: f32, %b: f32, %c: f32):
+        %0 = std.addf %a, %b : f32
+        linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    %update = tensor.insert_slice %sum_sub into %C_[%i] [%c2] [1]
+      : tensor<?xf32> into tensor<10xf32>
+    linalg.yield %update : tensor<10xf32>
+  }
+// CHECK: linalg.tiled_loop
+// CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
+// CHECK-SAME: %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
+// CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>)
+
+// CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]]
+// CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]]
+// CHECK-NEXT: %[[TMP:.*]] = memref.alloc
+// CHECK-NEXT: linalg.generic
+// CHECK-SAME: ins(%[[SV_A]], %[[SV_B]]
+// CHECK-SAME: outs(%[[TMP]] : memref<2xf32>)
+// CHECK: %[[SV_C:.*]] = memref.subview %[[C]]
+// CHECK-NEXT: linalg.copy(%[[TMP]], %[[SV_C]])
+// CHECK-NEXT: linalg.yield
+  return %sum : tensor<10xf32>
+}
```
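
Reassembled by hand from these CHECK/CHECK-NEXT lines, the bufferized `@tiled_add` loop should look roughly like the following (not verbatim FileCheck input; `%0`/`%1`/`%2` and `#strided` are illustrative):

```mlir
linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
    ins (%A_ = %0: memref<10xf32>, %B_ = %1: memref<10xf32>)
    outs (%C_ = %2: memref<10xf32>) {
  %sv_A = memref.subview %A_[%i] [2] [1] : memref<10xf32> to memref<2xf32, #strided>
  %sv_B = memref.subview %B_[%i] [2] [1] : memref<10xf32> to memref<2xf32, #strided>
  %tmp = memref.alloc() : memref<2xf32>
  // The elementwise addition now runs on memrefs, writing into %tmp.
  linalg.generic {indexing_maps = [#map0, #map0, #map0],
                  iterator_types = ["parallel"]}
      ins(%sv_A, %sv_B : memref<2xf32, #strided>, memref<2xf32, #strided>)
      outs(%tmp : memref<2xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %s = std.addf %a, %b : f32
      linalg.yield %s : f32
  }
  %sv_C = memref.subview %C_[%i] [2] [1] : memref<10xf32> to memref<2xf32, #strided>
  linalg.copy(%tmp, %sv_C) : memref<2xf32>, memref<2xf32, #strided>
  linalg.yield
}
```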
