@@ -213,10 +213,8 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
     Location loc = op.getLoc();
     SmallVector<Value, 2> newOutputBuffers;

-    if (op->getParentOfType<TiledLoopOp>()) {
-      newOutputBuffers = adaptor.outputs();
-    } else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
-                                                newOutputBuffers, rewriter))) {
+    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
+                                         newOutputBuffers, rewriter))) {
       return op.emitOpError()
              << "Failed to allocate buffers for tensor results.";
     }
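
With the `tiled_loop` special case removed, every Linalg op on tensors now takes the generic allocation path. For readers without the surrounding file, a result-buffer allocation helper is, in rough outline, the following. This is a simplified sketch assuming statically shaped ranked tensor results; the real allocateBuffersForResults in this file also materializes the dynamic dimension operands for the allocation.

// Sketch only: one memref.alloc per tensor result, static shapes assumed.
static LogicalResult
allocateBuffersForResultsSketch(Location loc, linalg::LinalgOp op,
                                SmallVectorImpl<Value> &resultBuffers,
                                OpBuilder &b) {
  for (Type resultType : op->getResultTypes()) {
    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType || !tensorType.hasStaticShape())
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
  }
  return success();
}
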
@@ -233,14 +231,6 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
   }
 };

-bool IsBlockArgOfTiledLoop(Value tensor) {
-  if (auto tensorLoad = tensor.getDefiningOp<memref::TensorLoadOp>())
-    if (auto blockArgument = tensorLoad.memref().dyn_cast<BlockArgument>())
-      if (isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp()))
-        return true;
-  return false;
-}
-
 /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
 /// alloc + copy pattern.
 /// ```
@@ -263,15 +253,6 @@ class ExtractSliceOpConverter
     Value sourceMemref = adaptor.source();
     assert(sourceMemref.getType().isa<MemRefType>());

-    // Block arguments of the tiled_loop can be bufferized inplace.
-    if (IsBlockArgOfTiledLoop(op.source())) {
-      Value subView = rewriter.create<memref::SubViewOp>(
-          op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
-          op.getMixedStrides());
-      rewriter.replaceOp(op, subView);
-      return success();
-    }
-
     MemRefType subviewMemRefType =
         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
     // op.sizes() capture exactly the dynamic alloc operands matching the
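
The general path, which now applies unconditionally, emits the documented alloc + copy pattern. A hedged sketch of how the rest of this function plays out, using the names visible in this hunk; the exact body is cut off by the diff context, so treat this as an assumption rather than a quote:

// Allocate a buffer for the slice; op.sizes() supplies the dynamic sizes.
Value alloc = rewriter.create<memref::AllocOp>(
    op.getLoc(), subviewMemRefType, op.sizes());
// View the requested slice of the source and copy it into the new buffer.
Value subView = rewriter.create<memref::SubViewOp>(
    op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
    op.getMixedStrides());
rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
rewriter.replaceOp(op, alloc);
return success();
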
@@ -315,12 +296,7 @@ class InsertSliceOpConverter
     // For now, be conservative and copy the converted input memref.
     // In general, the converted input memref here could be aliased or could
     // point into constant memory, so mutating it would lead to miscompilations.
-    // Block arguments of the tiled_loop can be bufferized inplace.
-    Value destMemRef;
-    if (IsBlockArgOfTiledLoop(op.dest()))
-      destMemRef = adaptor.dest();
-    else
-      destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
+    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
     assert(destMemRef.getType().isa<MemRefType>());

     // Take a subview to copy the small memref.
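
cloneMemref is a helper used here but not shown in this diff; its effect is approximately the following sketch (dynamic dimensions resolved via memref.dim; an assumption about the helper, not its verbatim body):

// Allocate a same-typed buffer and copy the source memref into it.
static Value cloneMemrefSketch(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  SmallVector<Value, 4> dynOperands;
  for (auto dim : llvm::enumerate(memrefType.getShape()))
    if (dim.value() == ShapedType::kDynamicSize)
      dynOperands.push_back(
          b.create<memref::DimOp>(loc, memref, dim.index()));
  Value alloc = b.create<memref::AllocOp>(loc, memrefType, dynOperands);
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}
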
@@ -334,60 +310,115 @@ class InsertSliceOpConverter
   }
 };

+bool isBlockArgOfTiledLoop(Value tensor) {
+  if (auto blockArgument = tensor.dyn_cast<BlockArgument>())
+    return isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp());
+  return false;
+}
+
+SmallVector<Value, 3> convertOperands(ValueRange operands,
+                                      BlockAndValueMapping &bvm) {
+  SmallVector<Value, 3> newOperands;
+  newOperands.reserve(operands.size());
+  for (auto operand : operands)
+    newOperands.push_back(bvm.lookupOrDefault(operand));
+  return newOperands;
+}
+
 class TiledLoopOpConverter : public OpConversionPattern<TiledLoopOp> {
 public:
   using OpConversionPattern<TiledLoopOp>::OpConversionPattern;

   LogicalResult
-  matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef<Value> operands,
+  matchAndRewrite(TiledLoopOp loop, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary());
-    Location loc = tiledLoop.getLoc();
-    if (tiledLoop.getNumResults() == 0)
+    TiledLoopOp::Adaptor adaptor(operands, loop->getAttrDictionary());
+    if (loop.getNumResults() == 0)
       return failure();
-    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+
+    Location loc = loop.getLoc();
+    auto newLoop = rewriter.create<TiledLoopOp>(
         loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(),
         adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(),
         adaptor.distribution_types());
+
     // Clone the region.
     BlockAndValueMapping bvm;
-    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
+    bvm.map(loop.getRegionInputArgs(), newLoop.getRegionInputArgs());
+    bvm.map(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs());

     OpBuilder innerBuilder =
-        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
-
-    // Remap input block arguments.
-    SmallVector<Value, 2> inputs;
-    for (auto en : llvm::zip(newTiledLoop.getRegionInputArgs(),
-                             tiledLoop.getRegionInputArgs())) {
-      auto &newInputArg = std::get<0>(en);
-      if (!newInputArg.getType().isa<ShapedType>()) {
-        inputs.push_back(std::get<0>(en));
-        continue;
+        OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
+
+    for (auto &op : loop.getBody()->getOperations()) {
+      Location loc = op.getLoc();
+      if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
+        if (isBlockArgOfTiledLoop(extractSlice.source())) {
+          auto newOperands = convertOperands(extractSlice.getOperands(), bvm);
+          auto srcMemRefType =
+              bvm.lookup(extractSlice.source()).getType().cast<MemRefType>();
+          auto dstMemRefType =
+              memref::SubViewOp::inferResultType(
+                  srcMemRefType,
+                  extractFromI64ArrayAttr(extractSlice.static_offsets()),
+                  extractFromI64ArrayAttr(extractSlice.static_sizes()),
+                  extractFromI64ArrayAttr(extractSlice.static_strides()))
+                  .cast<MemRefType>();
+
+          Value subView = innerBuilder.create<memref::SubViewOp>(
+              loc, TypeRange{dstMemRefType}, newOperands,
+              extractSlice->getAttrs());
+          bvm.map(extractSlice.getResult(), subView);
+          continue;
+        }
       }
-      inputs.push_back(
-          innerBuilder.create<memref::TensorLoadOp>(loc, newInputArg));
-    }
-    bvm.map(tiledLoop.getRegionInputArgs(), inputs);
-
-    // Remap output block arguments.
-    SmallVector<Value, 2> outputs;
-    for (auto en : llvm::zip(newTiledLoop.getRegionOutputArgs(),
-                             tiledLoop.getRegionOutputArgs())) {
-      auto &newOutputArg = std::get<0>(en);
-      if (!newOutputArg.getType().isa<ShapedType>()) {
-        outputs.push_back(std::get<0>(en));
+      if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
+        if (isBlockArgOfTiledLoop(insertSlice.dest())) {
+          continue;
+        }
+      }
+      if (auto yield = dyn_cast<linalg::YieldOp>(op)) {
+        for (OpOperand &operand : yield->getOpOperands()) {
+          if (auto insert =
+                  operand.get().getDefiningOp<tensor::InsertSliceOp>()) {
+
+            auto dstMemRefType = memref::SubViewOp::inferResultType(
+                getTypeConverter()
+                    ->convertType(insert.source().getType())
+                    .cast<MemRefType>(),
+                extractFromI64ArrayAttr(insert.static_offsets()),
+                extractFromI64ArrayAttr(insert.static_sizes()),
+                extractFromI64ArrayAttr(insert.static_strides()));
+
+            Value subView = innerBuilder.create<memref::SubViewOp>(
+                loc, dstMemRefType, bvm.lookup(insert.dest()),
+                convertOperands(insert.offsets(), bvm),
+                convertOperands(insert.sizes(), bvm),
+                convertOperands(insert.strides(), bvm), insert.static_offsets(),
+                insert.static_sizes(), insert.static_strides());
+
+            Value cast = innerBuilder.create<memref::BufferCastOp>(
+                loc,
+                getTypeConverter()
+                    ->convertType(insert.source().getType())
+                    .cast<MemRefType>(),
+                bvm.lookup(insert.source()));
+
+            innerBuilder.create<linalg::CopyOp>(loc, cast, subView);
+            continue;
+          }
+          auto dst = newLoop.getRegionOutputArgs()[operand.getOperandNumber()];
+          Value cast = innerBuilder.create<memref::BufferCastOp>(
+              loc, dst.getType(), bvm.lookup(operand.get()));
+          innerBuilder.create<linalg::CopyOp>(loc, cast, dst);
+        }
         continue;
       }
-      outputs.push_back(
-          innerBuilder.create<memref::TensorLoadOp>(loc, newOutputArg));
-    }
-    bvm.map(tiledLoop.getRegionOutputArgs(), outputs);
-
-    for (auto &op : tiledLoop.getBody()->without_terminator())
       innerBuilder.clone(op, bvm);
+    }
     innerBuilder.create<linalg::YieldOp>(loc);
-    rewriter.replaceOp(tiledLoop, newTiledLoop.outputs());
+    rewriter.replaceOp(loop, newLoop.outputs());
     return success();
   }
 };
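
For completeness, the converter is registered alongside the other bufferization patterns of this file. A hedged sketch of that registration, assuming the usual populateLinalgBufferizePatterns entry point; the actual pattern list there contains more converters than shown here:

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // Register the op-interface pattern plus the slice and loop converters.
  patterns.add<BufferizeAnyLinalgOp, ExtractSliceOpConverter,
               InsertSliceOpConverter, TiledLoopOpConverter>(
      typeConverter, patterns.getContext());
}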