@@ -213,8 +213,10 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

-    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
-                                         newOutputBuffers, rewriter))) {
+    if (op->getParentOfType<TiledLoopOp>()) {
+      newOutputBuffers = adaptor.outputs();
+    } else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
+                                                newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }
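The new branch is the heart of the change: when a linalg op already sits inside a `linalg.tiled_loop`, the loop's converted output buffers are forwarded as the op's output buffers instead of calling `allocateBuffersForResults`; outside a loop the alloc path is unchanged. A rough sketch of the state this produces, with illustrative names and the `linalg.tiled_loop`/`linalg.fill` syntax of this era (treat as schematic, not verified pass output):

```mlir
#map = affine_map<(d0)[s0] -> (d0 + s0)>

func @fill_tiles(%out: memref<8xf32>, %cst: f32) {
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c8 = constant 8 : index
  linalg.tiled_loop (%i) = (%c0) to (%c8) step (%c4)
      outs (%out_ = %out: memref<8xf32>)
      iterators["parallel"] {
    // No fresh memref.alloc inside the loop body: the linalg op writes
    // directly into a view of the loop's output argument %out_.
    %tile = memref.subview %out_[%i] [4] [1] : memref<8xf32> to memref<4xf32, #map>
    linalg.fill(%cst, %tile) : f32, memref<4xf32, #map>
    linalg.yield
  }
  return
}
```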
@@ -231,6 +233,14 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
  }
};

+bool IsBlockArgOfTiledLoop(Value tensor) {
+  if (auto tensorLoad = tensor.getDefiningOp<memref::TensorLoadOp>())
+    if (auto blockArgument = tensorLoad.memref().dyn_cast<BlockArgument>())
+      if (isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp()))
+        return true;
+  return false;
+}
+
/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
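`IsBlockArgOfTiledLoop` recognizes exactly the chain that `TiledLoopOpConverter` (added further down) leaves behind: each shaped region argument of the new memref-typed loop gets wrapped in a `memref.tensor_load` so the old tensor-typed body can be cloned on top of it. A tensor thus counts as a block argument of a tiled loop iff it is produced by a `tensor_load` whose memref operand is a region argument owned by a `TiledLoopOp`. Schematically, as mid-conversion IR (abbreviated, not standalone-valid):

```mlir
linalg.tiled_loop (%i) = (%c0) to (%c8) step (%c4)
    outs (%out_ = %out: memref<8xf32>) iterators["parallel"] {
  // Inserted by TiledLoopOpConverter; IsBlockArgOfTiledLoop(%t) is true
  // because %t is a tensor_load of the region argument %out_.
  %t = memref.tensor_load %out_ : memref<8xf32>
  // ... cloned tensor ops consuming %t ...
  linalg.yield
}
```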
@@ -253,6 +263,15 @@ class ExtractSliceOpConverter
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

+    // Block arguments of the tiled_loop can be bufferized in place.
+    if (IsBlockArgOfTiledLoop(op.source())) {
+      Value subView = rewriter.create<memref::SubViewOp>(
+          op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
+          op.getMixedStrides());
+      rewriter.replaceOp(op, subView);
+      return success();
+    }
+
    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() capture exactly the dynamic alloc operands matching the
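For slices of loop-carried arguments the alloc + copy scheme from the comment above is unnecessary: the slice may alias the loop's buffer, so the op collapses to a plain `memref.subview`. Roughly, with illustrative names and a strided layout map written inline:

```mlir
// Before: %t = memref.tensor_load %out_, with %out_ a tiled_loop argument.
%st = tensor.extract_slice %t[%i] [4] [1] : tensor<8xf32> to tensor<4xf32>

// After: a view of the loop's own buffer; no memref.alloc, no linalg.copy.
%sv = memref.subview %out_[%i] [4] [1]
    : memref<8xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
```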
@@ -296,7 +315,12 @@ class InsertSliceOpConverter
    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
-    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
+    // Block arguments of the tiled_loop can be bufferized in place.
+    Value destMemRef;
+    if (IsBlockArgOfTiledLoop(op.dest()))
+      destMemRef = adaptor.dest();
+    else
+      destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
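The same idea on the write side: when the destination is a loop-carried argument, mutating it is precisely the intended semantics, so the conservative `cloneMemref` is skipped and the source tile is copied straight into a view of the loop's output buffer. A sketch with illustrative names:

```mlir
// Before: %t = memref.tensor_load %out_, with %out_ a tiled_loop argument.
%r = tensor.insert_slice %tile into %t[%i] [4] [1]
    : tensor<4xf32> into tensor<8xf32>

// After: write through a subview of %out_ itself, with no defensive clone.
%dst = memref.subview %out_[%i] [4] [1]
    : memref<8xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
linalg.copy(%tile_buf, %dst)
    : memref<4xf32>, memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
```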
@@ -310,6 +334,64 @@ class InsertSliceOpConverter
  }
};

+class TiledLoopOpConverter : public OpConversionPattern<TiledLoopOp> {
+public:
+  using OpConversionPattern<TiledLoopOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary());
+    Location loc = tiledLoop.getLoc();
+    if (tiledLoop.getNumResults() == 0)
+      return failure();
+    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+        loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(),
+        adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(),
+        adaptor.distribution_types());
+    // Clone the region.
+    BlockAndValueMapping bvm;
+    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+
+    OpBuilder innerBuilder =
+        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+
+    // Remap input block arguments.
+    SmallVector<Value, 2> inputs;
+    for (auto en : llvm::zip(newTiledLoop.getRegionInputArgs(),
+                             tiledLoop.getRegionInputArgs())) {
+      auto &newInputArg = std::get<0>(en);
+      if (!newInputArg.getType().isa<ShapedType>()) {
+        inputs.push_back(std::get<0>(en));
+        continue;
+      }
+      inputs.push_back(
+          innerBuilder.create<memref::TensorLoadOp>(loc, newInputArg));
+    }
+    bvm.map(tiledLoop.getRegionInputArgs(), inputs);
+
+    // Remap output block arguments.
+    SmallVector<Value, 2> outputs;
+    for (auto en : llvm::zip(newTiledLoop.getRegionOutputArgs(),
+                             tiledLoop.getRegionOutputArgs())) {
+      auto &newOutputArg = std::get<0>(en);
+      if (!newOutputArg.getType().isa<ShapedType>()) {
+        outputs.push_back(std::get<0>(en));
+        continue;
+      }
+      outputs.push_back(
+          innerBuilder.create<memref::TensorLoadOp>(loc, newOutputArg));
+    }
+    bvm.map(tiledLoop.getRegionOutputArgs(), outputs);
+
+    for (auto &op : tiledLoop.getBody()->without_terminator())
+      innerBuilder.clone(op, bvm);
+    innerBuilder.create<linalg::YieldOp>(loc);
+    rewriter.replaceOp(tiledLoop, newTiledLoop.outputs());
+    return success();
+  }
+};
+
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
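Putting the pieces together: the converter rebuilds the loop with memref operands, wraps each shaped region argument in a `memref.tensor_load`, clones the old body over that mapping, and terminates with an empty `linalg.yield`; the slice patterns above then fold the cloned slice ops into views. Loops with no tensor results already run on buffers, so the pattern bails with `failure()` for them. An end-to-end sketch of what a tiled tensor copy bufferizes to, assuming the usual func/tensor bufferization passes also run (names and exact syntax illustrative, not verified pass output):

```mlir
// Before bufferization:
func @tiled_copy(%in: tensor<8xf32>, %out: tensor<8xf32>) -> tensor<8xf32> {
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c8 = constant 8 : index
  %0 = linalg.tiled_loop (%i) = (%c0) to (%c8) step (%c4)
      ins (%in_ = %in: tensor<8xf32>)
      outs (%out_ = %out: tensor<8xf32>)
      iterators["parallel"] {
    %tile = tensor.extract_slice %in_[%i] [4] [1]
        : tensor<8xf32> to tensor<4xf32>
    %upd = tensor.insert_slice %tile into %out_[%i] [4] [1]
        : tensor<4xf32> into tensor<8xf32>
    linalg.yield %upd : tensor<8xf32>
  }
  return %0 : tensor<8xf32>
}

// After bufferization (roughly):
#map = affine_map<(d0)[s0] -> (d0 + s0)>
func @tiled_copy(%in: memref<8xf32>, %out: memref<8xf32>) {
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c8 = constant 8 : index
  linalg.tiled_loop (%i) = (%c0) to (%c8) step (%c4)
      ins (%in_ = %in: memref<8xf32>)
      outs (%out_ = %out: memref<8xf32>)
      iterators["parallel"] {
    %src = memref.subview %in_[%i] [4] [1] : memref<8xf32> to memref<4xf32, #map>
    %dst = memref.subview %out_[%i] [4] [1] : memref<8xf32> to memref<4xf32, #map>
    linalg.copy(%src, %dst) : memref<4xf32, #map>, memref<4xf32, #map>
    linalg.yield
  }
  return
}
```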
@@ -352,14 +434,66 @@ class VectorTransferWriteOpConverter
};
} // namespace

+static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
+                                   ValueRange inputs, Location loc) {
+  assert(inputs.size() == 1);
+  assert(inputs[0].getType().isa<BaseMemRefType>());
+  return builder.create<memref::TensorLoadOp>(loc, type, inputs[0]);
+}
+
namespace {
+
+/// A helper type converter class that automatically populates the relevant
+/// materializations and type conversions for bufferization.
+//
+// The default BufferizeTypeConverter defined in "Transforms/Bufferize.h" does
+// not properly support memrefs with a non-default layout. Whenever the layout
+// of a memref changes during bufferization, the target materialization
+// callback would assert that the non-matching type is a tensor.
+// There was an attempt to fix this behavior of dialect conversion in a more
+// principled way in https://reviews.llvm.org/D93126, but it had to be
+// reverted due to test failures outside of MLIR Core. It might make sense to
+// revive that patch.
+class CustomBufferizeTypeConverter : public BufferizeTypeConverter {
+public:
+  CustomBufferizeTypeConverter() {
+    // Keep all types unchanged.
+    addConversion([](Type type) { return type; });
+    // Convert RankedTensorType to MemRefType.
+    addConversion([](RankedTensorType type) -> Type {
+      return MemRefType::get(type.getShape(), type.getElementType());
+    });
+    // Convert UnrankedTensorType to UnrankedMemRefType.
+    addConversion([](UnrankedTensorType type) -> Type {
+      return UnrankedMemRefType::get(type.getElementType(), 0);
+    });
+    addArgumentMaterialization(materializeTensorLoad);
+    addSourceMaterialization(materializeTensorLoad);
+    addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
+                                ValueRange inputs, Location loc) -> Value {
+      assert(inputs.size() == 1);
+      // Target materialization is invoked if the new operand type does not
+      // match the expected type. A special case is when the new operand type
+      // is a memref with a specified layout, i.e. a non-empty affine map.
+      // TODO(pifon): Change how target materialization is invoked in dialect
+      // conversion.
+      if (auto memrefType = inputs[0].getType().dyn_cast<MemRefType>()) {
+        assert(!memrefType.getAffineMaps().empty());
+        return inputs[0];
+      }
+      assert(inputs[0].getType().isa<TensorType>());
+      return builder.create<memref::BufferCastOp>(loc, type, inputs[0]);
+    });
+  }
+};
+
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
-    BufferizeTypeConverter typeConverter;
+    CustomBufferizeTypeConverter typeConverter;

    // Mark all Standard operations legal.
    target.addLegalDialect<AffineDialect, math::MathDialect,
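The custom converter matters precisely because of the in-place subview path above: a `memref.subview` result carries a non-trivial layout map, so when dialect conversion compares it against the converted type (a default-layout memref) it invokes target materialization with a memref input, which the stock `BufferizeTypeConverter` would reject as "not a tensor". The custom callback simply forwards such values. A schematic of the case it forwards (illustrative names):

```mlir
// %sv has a non-default layout, but the consumer's converted operand type
// is the default-layout memref<4xf32>; the custom target materialization
// returns %sv unchanged instead of asserting that the input is a tensor.
%sv = memref.subview %out_[%i] [4] [1]
    : memref<8xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
```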
@@ -401,6 +535,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
+      TiledLoopOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
      >(typeConverter, patterns.getContext());