@@ -56,16 +56,34 @@ static MemRefType unpackOneDim(MemRefType type) {
                                          vectorType.getElementType()));
 }
 
-// TODO: Parallelism and threadlocal considerations.
-static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
+/// Helper data structure for data and mask buffers.
+struct BufferAllocs {
+  Value dataBuffer;
+  Value maskBuffer;
+};
+
+/// Allocate temporary buffers for data (vector) and mask (if present).
+/// TODO: Parallelism and threadlocal considerations.
+template <typename OpTy>
+static BufferAllocs allocBuffers(OpTy xferOp) {
   auto &b = ScopedContext::getBuilderRef();
   OpBuilder::InsertionGuard guard(b);
   Operation *scope =
-      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
   assert(scope && "Expected op to be inside automatic allocation scope");
   b.setInsertionPointToStart(&scope->getRegion(0).front());
-  Value res = memref_alloca(type);
-  return res;
+
+  BufferAllocs result;
+  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
+  result.dataBuffer = memref_alloca(bufferType).value;
+
+  if (xferOp.mask()) {
+    auto maskType = MemRefType::get({}, xferOp.mask().getType());
+    result.maskBuffer = memref_alloca(maskType).value;
+    memref_store(xferOp.mask(), result.maskBuffer);
+  }
+
+  return result;
 }
 
 /// Given a vector transfer op, calculate which dimension of the `source`
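Note (illustrative, not part of the patch): for a masked transfer op on `vector<5x4xf32>` with a `vector<5x4xi1>` mask (types assumed for this sketch), `allocBuffers` produces IR along these lines:

```mlir
%dataBuf = memref.alloca() : memref<vector<5x4xf32>>
%maskBuf = memref.alloca() : memref<vector<5x4xi1>>
memref.store %mask, %maskBuf[] : memref<vector<5x4xi1>>
```

The mask is stored eagerly because later patterns locate it through this buffer rather than through the original SSA value.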
@@ -238,6 +256,16 @@ static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
   return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front());
 }
 
+/// Given a transfer op, find the memref from which the mask is loaded. This
+/// is similar to Strategy<TransferWriteOp>::getBuffer.
+template <typename OpTy>
+static Value getMaskBuffer(OpTy xferOp) {
+  assert(xferOp.mask() && "Expected that transfer op has mask");
+  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
+  assert(loadOp && "Expected transfer op mask produced by LoadOp");
+  return loadOp.getMemRef();
+}
+
 /// Codegen strategy, depending on the operation.
 template <typename OpTy>
 struct Strategy;
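Note (illustrative): `getMaskBuffer` relies on the invariant, established by the prepare patterns further down, that the mask operand is the result of a `memref.load` from the mask buffer. A sketch of the chain it walks back (types assumed):

```mlir
%maskBuf = memref.alloca() : memref<vector<5x4xi1>>
memref.store %mask, %maskBuf[] : memref<vector<5x4xi1>>
%m = memref.load %maskBuf[] : memref<vector<5x4xi1>>
// %m becomes the mask operand of the transfer op; getMaskBuffer follows
// %m to its defining memref.load and returns %maskBuf.
```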
@@ -266,9 +294,9 @@ struct Strategy<TransferReadOp> {
     return getStoreOp(xferOp).getMemRef();
   }
 
-  /// Retrieve the indices of the current StoreOp.
-  static void getStoreIndices(TransferReadOp xferOp,
-                              SmallVector<Value, 8> &indices) {
+  /// Retrieve the indices of the current StoreOp that stores into the buffer.
+  static void getBufferIndices(TransferReadOp xferOp,
+                               SmallVector<Value, 8> &indices) {
     auto storeOp = getStoreOp(xferOp);
     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
     indices.append(prevIndices.begin(), prevIndices.end());
@@ -300,10 +328,11 @@ struct Strategy<TransferReadOp> {
   ///
   /// Note: The loop and type cast are generated in TransferOpConversion.
   /// The original TransferReadOp and store op are deleted in `cleanup`.
-  static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
-                        Value buffer, Value iv) {
+  /// Note: The `mask` operand is set in TransferOpConversion.
+  static TransferReadOp rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
+                                  Value buffer, Value iv) {
     SmallVector<Value, 8> storeIndices;
-    getStoreIndices(xferOp, storeIndices);
+    getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
 
     SmallVector<Value, 8> xferIndices;
@@ -321,6 +350,7 @@ struct Strategy<TransferReadOp> {
     newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
 
     memref_store(newXfer, buffer, storeIndices);
+    return newXfer.getDefiningOp<TransferReadOp>();
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
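Note (illustrative): one unpacking step of a `vector<5x4xf32>` read could look as follows, where `%i` is the loop induction variable; the index arithmetic on the unpacked dimension is omitted (names and shapes assumed):

```mlir
// %idx is the original index on the unpacked dimension, offset by %i.
%v = vector.transfer_read %A[%a, %idx, %c], %cst
    : memref<?x?x?xf32>, vector<4xf32>
memref.store %v, %castedBuffer[%i] : memref<5xvector<4xf32>>
```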
// / Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
@@ -329,7 +359,7 @@ struct Strategy<TransferReadOp> {
329
359
OpBuilder &/* builder*/ , TransferReadOp xferOp, Value buffer,
330
360
Value iv) {
331
361
SmallVector<Value, 8 > storeIndices;
332
- getStoreIndices (xferOp, storeIndices);
362
+ getBufferIndices (xferOp, storeIndices);
333
363
storeIndices.push_back (iv);
334
364
335
365
auto bufferType = buffer.getType ().dyn_cast <ShapedType>();
@@ -361,9 +391,9 @@ struct Strategy<TransferWriteOp> {
     return loadOp.getMemRef();
   }
 
-  /// Retrieve the indices of the current LoadOp.
-  static void getLoadIndices(TransferWriteOp xferOp,
-                             SmallVector<Value, 8> &indices) {
+  /// Retrieve the indices of the current LoadOp that loads from the buffer.
+  static void getBufferIndices(TransferWriteOp xferOp,
+                               SmallVector<Value, 8> &indices) {
     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
     indices.append(prevIndices.begin(), prevIndices.end());
@@ -378,10 +408,10 @@ struct Strategy<TransferWriteOp> {
   /// to memory.
   ///
   /// Note: For more details, see comments on Strategy<TransferReadOp>.
-  static void rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
-                        Value buffer, Value iv) {
+  static TransferWriteOp rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
+                                   Value buffer, Value iv) {
     SmallVector<Value, 8> loadIndices;
-    getLoadIndices(xferOp, loadIndices);
+    getBufferIndices(xferOp, loadIndices);
     loadIndices.push_back(iv);
 
     SmallVector<Value, 8> xferIndices;
@@ -397,6 +427,8 @@ struct Strategy<TransferWriteOp> {
 
     if (vecType.getRank() > kTargetRank)
       newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
+
+    return newXfer;
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
@@ -416,8 +448,6 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
     return failure();
   if (xferOp.getVectorType().getRank() <= kTargetRank)
     return failure();
-  if (xferOp.mask())
-    return failure();
   return success();
 }
 
@@ -442,6 +472,8 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
 ///   memref.store %1, %0[] : memref<vector<5x4xf32>>
 ///   %vec = memref.load %0[] : memref<vector<5x4xf32>>
 /// ```
+///
+/// Note: A second temporary buffer may be allocated for the `mask` operand.
 struct PrepareTransferReadConversion
     : public OpRewritePattern<TransferReadOp> {
   using OpRewritePattern<TransferReadOp>::OpRewritePattern;
@@ -452,12 +484,16 @@ struct PrepareTransferReadConversion
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    auto allocType = MemRefType::get({}, xferOp.getVectorType());
-    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
+    auto buffers = allocBuffers(xferOp);
     auto *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
-    memref_store(newXfer->getResult(0), buffer);
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffer);
+    if (xferOp.mask()) {
+      auto loadedMask = memref_load(buffers.maskBuffer);
+      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(loadedMask);
+    }
+
+    memref_store(newXfer->getResult(0), buffers.dataBuffer);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
 
     return success();
   }
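Note (illustrative): with a mask present, the prepared IR from the pattern's doc comment gains a mask buffer round-trip, roughly (mask type assumed):

```mlir
%maskBuf = memref.alloca() : memref<vector<5x4xi1>>
memref.store %mask, %maskBuf[] : memref<vector<5x4xi1>>
%m = memref.load %maskBuf[] : memref<vector<5x4xi1>>
%1 = vector.transfer_read %A[%a, %b, %c], %cst, %m
    { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
```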
@@ -484,6 +520,8 @@ struct PrepareTransferReadConversion
 ///   vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
 ///       : vector<5x4xf32>, memref<?x?x?xf32>
 /// ```
+///
+/// Note: A second temporary buffer may be allocated for the `mask` operand.
 struct PrepareTransferWriteConversion
     : public OpRewritePattern<TransferWriteOp> {
   using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
@@ -494,16 +532,20 @@ struct PrepareTransferWriteConversion
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    auto allocType = MemRefType::get({}, xferOp.getVectorType());
-    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
-    memref_store(xferOp.vector(), buffer);
-    auto loadedVec = memref_load(buffer);
-
+    auto buffers = allocBuffers(xferOp);
+    memref_store(xferOp.vector(), buffers.dataBuffer);
+    auto loadedVec = memref_load(buffers.dataBuffer);
     rewriter.updateRootInPlace(xferOp, [&]() {
       xferOp.vectorMutable().assign(loadedVec);
       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
     });
 
+    if (xferOp.mask()) {
+      auto loadedMask = memref_load(buffers.maskBuffer);
+      rewriter.updateRootInPlace(
+          xferOp, [&]() { xferOp.maskMutable().assign(loadedMask); });
+    }
+
     return success();
   }
 };
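Note (illustrative): the analogous prepared IR for a masked write routes the mask through its own buffer as well, roughly (mask type assumed):

```mlir
%maskBuf = memref.alloca() : memref<vector<5x4xi1>>
memref.store %mask, %maskBuf[] : memref<vector<5x4xi1>>
%m = memref.load %maskBuf[] : memref<vector<5x4xi1>>
vector.transfer_write %1, %A[%a, %b, %c], %m { __vector_to_scf_lowering__ }
    : vector<5x4xf32>, memref<?x?x?xf32>
```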
@@ -535,16 +577,28 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    // How the buffer can be found depends on OpTy.
-    auto buffer = Strategy<OpTy>::getBuffer(xferOp);
-    auto bufferType = buffer.getType().template dyn_cast<MemRefType>();
-    auto castedType = unpackOneDim(bufferType);
-    auto casted = vector_type_cast(castedType, buffer);
+
+    // Find and cast data buffer. How the buffer can be found depends on OpTy.
+    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
+    auto castedDataType = unpackOneDim(dataBufferType);
+    auto castedDataBuffer = vector_type_cast(castedDataType, dataBuffer);
+
+    // If the xferOp has a mask: Find and cast mask buffer.
+    Value castedMaskBuffer;
+    if (xferOp.mask()) {
+      auto maskBuffer = getMaskBuffer(xferOp);
+      auto maskBufferType =
+          maskBuffer.getType().template dyn_cast<MemRefType>();
+      auto castedMaskType = unpackOneDim(maskBufferType);
+      castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+    }
 
     // Loop bounds and step.
     auto lb = std_constant_index(0).value;
     auto ub = std_constant_index(
-        castedType.getDimSize(castedType.getRank() - 1)).value;
+                  castedDataType.getDimSize(castedDataType.getRank() - 1))
+                  .value;
     auto step = std_constant_index(1).value;
 
     // Generate for loop.
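Note (illustrative): `unpackOneDim` plus `vector_type_cast` peel one vector dimension off both buffers in lockstep, e.g. (shapes assumed):

```mlir
%castedData = vector.type_cast %dataBuf
    : memref<vector<5x4xf32>> to memref<5xvector<4xf32>>
%castedMask = vector.type_cast %maskBuf
    : memref<vector<5x4xi1>> to memref<5xvector<4xi1>>
```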
@@ -555,11 +609,31 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
           ScopedContext scope(b, loc);
           generateInBoundsCheck(
               xferOp, iv, b, unpackedDim(xferOp),
-              /*inBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::rewriteOp(b, xferOp, casted, iv);
-              }, /*outOfBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, casted, iv);
-              });
+              /*inBoundsCase=*/
+              [&](OpBuilder &b, Location /*loc*/) {
+                // Create new transfer op.
+                OpTy newXfer =
+                    Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
+
+                // If old transfer op has a mask: Set mask on new transfer op.
+                if (xferOp.mask()) {
+                  OpBuilder::InsertionGuard guard(b);
+                  b.setInsertionPoint(newXfer); // Insert load before newXfer.
+
+                  SmallVector<Value, 8> loadIndices;
+                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+                  loadIndices.push_back(iv);
+
+                  auto mask = memref_load(castedMaskBuffer, loadIndices);
+                  rewriter.updateRootInPlace(
+                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
+                }
+              },
+              /*outOfBoundsCase=*/
+              [&](OpBuilder &b, Location /*loc*/) {
+                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, castedDataBuffer,
+                                                     iv);
+              });
           b.create<scf::YieldOp>(loc);
         });
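Note (illustrative): inside the generated loop, the mask slice for the current iteration is loaded from the casted mask buffer immediately before the new transfer op (shapes assumed):

```mlir
scf.for %i = %lb to %ub step %step {
  %m = memref.load %castedMask[%i] : memref<5xvector<4xi1>>
  // The rewritten (n-1)-D transfer op takes %m as its mask operand.
}
```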