Skip to content

Commit edaffeb

Browse files
committed
Cloned from CL 389610703 by 'g4 patch'.
Original change by pifon@pifon:tfrt_clean:6896:citc on 2021/08/09 05:30:17.
Differential Revision: https://reviews.llvm.org/D107762
1 parent 8cf8349 commit edaffeb

File tree

2 files changed

+172
-4
lines changed

2 files changed

+172
-4
lines changed

mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp

Lines changed: 139 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -213,8 +213,10 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
213213
Location loc = op.getLoc();
214214
SmallVector<Value, 2> newOutputBuffers;
215215

216-
if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
217-
newOutputBuffers, rewriter))) {
216+
if (op->getParentOfType<TiledLoopOp>()) {
217+
newOutputBuffers = adaptor.outputs();
218+
} else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
219+
newOutputBuffers, rewriter))) {
218220
return op.emitOpError()
219221
<< "Failed to allocate buffers for tensor results.";
220222
}
@@ -231,6 +233,14 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
231233
}
232234
};
233235

236+
bool IsBlockArgOfTiledLoop(Value tensor) {
237+
if (auto tensorLoad = tensor.getDefiningOp<memref::TensorLoadOp>())
238+
if (auto blockArgument = tensorLoad.memref().dyn_cast<BlockArgument>())
239+
if (isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp()))
240+
return true;
241+
return false;
242+
}
243+
234244
/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
235245
/// alloc + copy pattern.
236246
/// ```
@@ -253,6 +263,15 @@ class ExtractSliceOpConverter
253263
Value sourceMemref = adaptor.source();
254264
assert(sourceMemref.getType().isa<MemRefType>());
255265

266+
// Block arguments of the tiled_loop can be bufferized inplace.
267+
if (IsBlockArgOfTiledLoop(op.source())) {
268+
Value subView = rewriter.create<memref::SubViewOp>(
269+
op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
270+
op.getMixedStrides());
271+
rewriter.replaceOp(op, subView);
272+
return success();
273+
}
274+
256275
MemRefType subviewMemRefType =
257276
getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
258277
// op.sizes() capture exactly the dynamic alloc operands matching the
@@ -296,7 +315,12 @@ class InsertSliceOpConverter
296315
// For now, be conservative and copy the converted input memref.
297316
// In general, the converted input memref here could be aliased or could
298317
// point into constant memory, so mutating it would lead to miscompilations.
299-
Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
318+
// Block arguments of the tiled_loop can be bufferized inplace.
319+
Value destMemRef;
320+
if (IsBlockArgOfTiledLoop(op.dest()))
321+
destMemRef = adaptor.dest();
322+
else
323+
destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
300324
assert(destMemRef.getType().isa<MemRefType>());
301325

302326
// Take a subview to copy the small memref.
@@ -310,6 +334,64 @@ class InsertSliceOpConverter
310334
}
311335
};
312336

337+
class TiledLoopOpConverter : public OpConversionPattern<TiledLoopOp> {
338+
public:
339+
using OpConversionPattern<TiledLoopOp>::OpConversionPattern;
340+
341+
LogicalResult
342+
matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef<Value> operands,
343+
ConversionPatternRewriter &rewriter) const final {
344+
TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary());
345+
Location loc = tiledLoop.getLoc();
346+
if (tiledLoop.getNumResults() == 0)
347+
return failure();
348+
auto newTiledLoop = rewriter.create<TiledLoopOp>(
349+
loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(),
350+
adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(),
351+
adaptor.distribution_types());
352+
// Clone the region.
353+
BlockAndValueMapping bvm;
354+
bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
355+
356+
OpBuilder innerBuilder =
357+
OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
358+
359+
// Remap input block arguments.
360+
SmallVector<Value, 2> inputs;
361+
for (auto en : llvm::zip(newTiledLoop.getRegionInputArgs(),
362+
tiledLoop.getRegionInputArgs())) {
363+
auto &newInputArg = std::get<0>(en);
364+
if (!newInputArg.getType().isa<ShapedType>()) {
365+
inputs.push_back(std::get<0>(en));
366+
continue;
367+
}
368+
inputs.push_back(
369+
innerBuilder.create<memref::TensorLoadOp>(loc, newInputArg));
370+
}
371+
bvm.map(tiledLoop.getRegionInputArgs(), inputs);
372+
373+
// Remap output block arguments.
374+
SmallVector<Value, 2> outputs;
375+
for (auto en : llvm::zip(newTiledLoop.getRegionOutputArgs(),
376+
tiledLoop.getRegionOutputArgs())) {
377+
auto &newOutputArg = std::get<0>(en);
378+
if (!newOutputArg.getType().isa<ShapedType>()) {
379+
outputs.push_back(std::get<0>(en));
380+
continue;
381+
}
382+
outputs.push_back(
383+
innerBuilder.create<memref::TensorLoadOp>(loc, newOutputArg));
384+
}
385+
bvm.map(tiledLoop.getRegionOutputArgs(), outputs);
386+
387+
for (auto &op : tiledLoop.getBody()->without_terminator())
388+
innerBuilder.clone(op, bvm);
389+
innerBuilder.create<linalg::YieldOp>(loc);
390+
rewriter.replaceOp(tiledLoop, newTiledLoop.outputs());
391+
return success();
392+
}
393+
};
394+
313395
class VectorTransferReadOpConverter
314396
: public OpConversionPattern<vector::TransferReadOp> {
315397
public:
@@ -352,14 +434,66 @@ class VectorTransferWriteOpConverter
352434
};
353435
} // namespace
354436

437+
static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
438+
ValueRange inputs, Location loc) {
439+
assert(inputs.size() == 1);
440+
assert(inputs[0].getType().isa<BaseMemRefType>());
441+
return builder.create<memref::TensorLoadOp>(loc, type, inputs[0]);
442+
}
443+
355444
namespace {
445+
446+
/// A helper type converter class that automatically populates the relevant
447+
/// materializations and type conversions for bufferization.
448+
//
449+
// The default BufferizeTypeConverter defined in "Transforms/Bufferize.h" does
450+
// not properly support memrefs with non-default layout. Whenever a layout of
451+
// memref changes during bufferization, target materialization call back would
452+
// assert that the non-matching type is a tensor.
453+
// There was an attempt to fix this behavior of dialect conversion in a more
454+
// principal way in https://reviews.llvm.org/D93126 but it had to be reverted
455+
// due to test failures outside of MLIR Core. It might make sense to revive this
456+
// PR.
457+
class CustomBufferizeTypeConverter : public BufferizeTypeConverter {
458+
public:
459+
CustomBufferizeTypeConverter() {
460+
// Keep all types unchanged.
461+
addConversion([](Type type) { return type; });
462+
// Convert RankedTensorType to MemRefType.
463+
addConversion([](RankedTensorType type) -> Type {
464+
return MemRefType::get(type.getShape(), type.getElementType());
465+
});
466+
// Convert UnrankedTensorType to UnrankedMemRefType.
467+
addConversion([](UnrankedTensorType type) -> Type {
468+
return UnrankedMemRefType::get(type.getElementType(), 0);
469+
});
470+
addArgumentMaterialization(materializeTensorLoad);
471+
addSourceMaterialization(materializeTensorLoad);
472+
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
473+
ValueRange inputs, Location loc) -> Value {
474+
assert(inputs.size() == 1);
475+
// Target materialization is invoked if the new operand type does not
476+
// match the expected type. A special case is when the new operand type is
477+
// a memref with a specified layout, i.e. non-empty affine map.
478+
// TODO(pifon) : Change how target materialization is invoked in dialect
479+
// conversion.
480+
if (auto memrefType = inputs[0].getType().dyn_cast<MemRefType>()) {
481+
assert(!memrefType.getAffineMaps().empty());
482+
return inputs[0];
483+
}
484+
assert(inputs[0].getType().isa<TensorType>());
485+
return builder.create<memref::BufferCastOp>(loc, type, inputs[0]);
486+
});
487+
}
488+
};
489+
356490
/// Converts Linalg operations that work on tensor-type operands or results to
357491
/// work on buffers.
358492
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
359493
void runOnOperation() override {
360494
MLIRContext &context = getContext();
361495
ConversionTarget target(context);
362-
BufferizeTypeConverter typeConverter;
496+
CustomBufferizeTypeConverter typeConverter;
363497

364498
// Mark all Standard operations legal.
365499
target.addLegalDialect<AffineDialect, math::MathDialect,
@@ -401,6 +535,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
401535
BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
402536
ExtractSliceOpConverter,
403537
InsertSliceOpConverter,
538+
TiledLoopOpConverter,
404539
VectorTransferReadOpConverter,
405540
VectorTransferWriteOpConverter
406541
>(typeConverter, patterns.getContext());

mlir/test/Dialect/Linalg/bufferize.mlir

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -316,3 +316,36 @@ func @vector_transfer(%in: tensor<4xf32>, %out: tensor<4xf32>) {
316316
// CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32>
317317
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32>
318318
}
319+
320+
// -----
321+
322+
// CHECK: func @tiled_dot
323+
func @tiled_dot(%A: tensor<10xf32>, %B: tensor<10xf32>,
324+
%C: tensor<f32>) -> tensor<f32> {
325+
%c0 = constant 0 : index
326+
%c2 = constant 2 : index
327+
%c10 = constant 10 : index
328+
329+
%dot = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
330+
ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>)
331+
outs (%C_ = %C: tensor<f32>)
332+
iterators["reduction"] {
333+
%A_sub = tensor.extract_slice %A_[%i] [%c2] [1]
334+
: tensor<10xf32> to tensor<?xf32>
335+
%B_sub = tensor.extract_slice %B_[%i] [%c2] [1]
336+
: tensor<10xf32> to tensor<?xf32>
337+
%dot_sub = linalg.dot ins(%A_sub, %B_sub : tensor<?xf32>, tensor<?xf32>)
338+
outs(%C_ : tensor<f32>) -> tensor<f32>
339+
linalg.yield %dot_sub : tensor<f32>
340+
}
341+
// CHECK: linalg.tiled_loop
342+
// CHECK-SAME: ins (%[[A:.*]] = %{{.*}}: memref<10xf32>, %[[B:.*]] = %{{.*}}: memref<10xf32>)
343+
// CHECK-SAME: outs (%[[C:.*]] = %{{.*}}: memref<f32>)
344+
// CHECK-NOT: alloc
345+
// CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
346+
// CHECK: %[[SV_B:.*]] = memref.subview %[[B]]
347+
// CHECK: linalg.dot ins(%[[SV_A]], %[[SV_B]]
348+
// CHECK-SAME: outs(%[[C]] : memref<f32>)
349+
// CHECK: linalg.yield
350+
return %dot : tensor<f32>
351+
}

0 commit comments

Comments (0)