
Commit 64f7fb5

[mlir] Support masked N-D vector transfer ops in ProgressiveVectorToSCF.
Mask vectors are handled similarly to data vectors in N-D TransferWriteOp: they are copied into a temporary memory buffer, which can be indexed into with non-constant values.

Differential Revision: https://reviews.llvm.org/D101136
1 parent c623945 commit 64f7fb5
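In other words (an illustrative MLIR sketch, not verbatim output of the pass; names like %buf and %iv are hypothetical): a vector value cannot be indexed with a non-constant index such as a loop induction variable, but a memref of vectors can, so the mask is spilled to a buffer, unpacked by one dimension, and reloaded slice by slice:

  %buf = memref.alloca() : memref<vector<4x9xi1>>           // at function entry
  memref.store %mask, %buf[] : memref<vector<4x9xi1>>
  %casted = vector.type_cast %buf
      : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>
  %sub_mask = memref.load %casted[%iv] : memref<4xvector<9xi1>>  // %iv may be dynamic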

File tree

2 files changed (+132, -46 lines)

mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp

Lines changed: 114 additions & 40 deletions
@@ -56,16 +56,34 @@ static MemRefType unpackOneDim(MemRefType type) {
                             vectorType.getElementType()));
 }
 
-// TODO: Parallelism and threadlocal considerations.
-static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
+/// Helper data structure for data and mask buffers.
+struct BufferAllocs {
+  Value dataBuffer;
+  Value maskBuffer;
+};
+
+/// Allocate temporary buffers for data (vector) and mask (if present).
+/// TODO: Parallelism and threadlocal considerations.
+template <typename OpTy>
+static BufferAllocs allocBuffers(OpTy xferOp) {
   auto &b = ScopedContext::getBuilderRef();
   OpBuilder::InsertionGuard guard(b);
   Operation *scope =
-      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
   assert(scope && "Expected op to be inside automatic allocation scope");
   b.setInsertionPointToStart(&scope->getRegion(0).front());
-  Value res = memref_alloca(type);
-  return res;
+
+  BufferAllocs result;
+  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
+  result.dataBuffer = memref_alloca(bufferType).value;
+
+  if (xferOp.mask()) {
+    auto maskType = MemRefType::get({}, xferOp.mask().getType());
+    result.maskBuffer = memref_alloca(maskType).value;
+    memref_store(xferOp.mask(), result.maskBuffer);
+  }
+
+  return result;
 }
 
 /// Given a vector transfer op, calculate which dimension of the `source`
@@ -238,6 +256,16 @@ static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
   return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front());
 }
 
+/// Given a transfer op, find the memref from which the mask is loaded. This
+/// is similar to Strategy<TransferWriteOp>::getBuffer.
+template <typename OpTy>
+static Value getMaskBuffer(OpTy xferOp) {
+  assert(xferOp.mask() && "Expected that transfer op has mask");
+  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
+  assert(loadOp && "Expected transfer op mask produced by LoadOp");
+  return loadOp.getMemRef();
+}
+
 /// Codegen strategy, depending on the operation.
 template <typename OpTy>
 struct Strategy;
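For orientation (an illustrative sketch, not code from this commit): after the prepare patterns below have run, the mask operand of a labeled transfer op is the result of a memref.load from the mask buffer, so the helper only has to walk back through that load:

  %m = memref.load %maskBuffer[] : memref<vector<4x9xi1>>   // inserted by prepare
  %v = vector.transfer_read %A[%a, %b], %pad, %m
      : memref<?x?xf32>, vector<4x9xf32>
  // getMaskBuffer on this transfer op returns %maskBuffer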
@@ -266,9 +294,9 @@ struct Strategy<TransferReadOp> {
     return getStoreOp(xferOp).getMemRef();
   }
 
-  /// Retrieve the indices of the current StoreOp.
-  static void getStoreIndices(TransferReadOp xferOp,
-                              SmallVector<Value, 8> &indices) {
+  /// Retrieve the indices of the current StoreOp that stores into the buffer.
+  static void getBufferIndices(TransferReadOp xferOp,
+                               SmallVector<Value, 8> &indices) {
     auto storeOp = getStoreOp(xferOp);
     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
     indices.append(prevIndices.begin(), prevIndices.end());
@@ -300,10 +328,11 @@ struct Strategy<TransferReadOp> {
   ///
   /// Note: The loop and type cast are generated in TransferOpConversion.
   /// The original TransferReadOp and store op are deleted in `cleanup`.
-  static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
-                        Value buffer, Value iv) {
+  /// Note: The `mask` operand is set in TransferOpConversion.
+  static TransferReadOp rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
+                                  Value buffer, Value iv) {
     SmallVector<Value, 8> storeIndices;
-    getStoreIndices(xferOp, storeIndices);
+    getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
 
     SmallVector<Value, 8> xferIndices;
@@ -321,6 +350,7 @@ struct Strategy<TransferReadOp> {
     newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
 
     memref_store(newXfer, buffer, storeIndices);
+    return newXfer.getDefiningOp<TransferReadOp>();
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
@@ -329,7 +359,7 @@ struct Strategy<TransferReadOp> {
       OpBuilder &/*builder*/, TransferReadOp xferOp, Value buffer,
       Value iv) {
     SmallVector<Value, 8> storeIndices;
-    getStoreIndices(xferOp, storeIndices);
+    getBufferIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
 
     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
@@ -361,9 +391,9 @@ struct Strategy<TransferWriteOp> {
     return loadOp.getMemRef();
   }
 
-  /// Retrieve the indices of the current LoadOp.
-  static void getLoadIndices(TransferWriteOp xferOp,
-                             SmallVector<Value, 8> &indices) {
+  /// Retrieve the indices of the current LoadOp that loads from the buffer.
+  static void getBufferIndices(TransferWriteOp xferOp,
+                               SmallVector<Value, 8> &indices) {
     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
     indices.append(prevIndices.begin(), prevIndices.end());
@@ -378,10 +408,10 @@ struct Strategy<TransferWriteOp> {
   /// to memory.
   ///
   /// Note: For more details, see comments on Strategy<TransferReadOp>.
-  static void rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
-                        Value buffer, Value iv) {
+  static TransferWriteOp rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
+                                   Value buffer, Value iv) {
     SmallVector<Value, 8> loadIndices;
-    getLoadIndices(xferOp, loadIndices);
+    getBufferIndices(xferOp, loadIndices);
     loadIndices.push_back(iv);
 
     SmallVector<Value, 8> xferIndices;
@@ -397,6 +427,8 @@ struct Strategy<TransferWriteOp> {
 
     if (vecType.getRank() > kTargetRank)
       newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
+
+    return newXfer;
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
@@ -416,8 +448,6 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
     return failure();
   if (xferOp.getVectorType().getRank() <= kTargetRank)
     return failure();
-  if (xferOp.mask())
-    return failure();
   return success();
 }
 
@@ -442,6 +472,8 @@ LogicalResult checkPrepareXferOp(OpTy xferOp) {
 ///   memref.store %1, %0[] : memref<vector<5x4xf32>>
 ///   %vec = memref.load %0[] : memref<vector<5x4xf32>>
 /// ```
+///
+/// Note: A second temporary buffer may be allocated for the `mask` operand.
 struct PrepareTransferReadConversion
     : public OpRewritePattern<TransferReadOp> {
   using OpRewritePattern<TransferReadOp>::OpRewritePattern;
@@ -452,12 +484,16 @@ struct PrepareTransferReadConversion
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    auto allocType = MemRefType::get({}, xferOp.getVectorType());
-    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
+    auto buffers = allocBuffers(xferOp);
     auto *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
-    memref_store(newXfer->getResult(0), buffer);
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffer);
+    if (xferOp.mask()) {
+      auto loadedMask = memref_load(buffers.maskBuffer);
+      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(loadedMask);
+    }
+
+    memref_store(newXfer->getResult(0), buffers.dataBuffer);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
 
     return success();
   }
@@ -484,6 +520,8 @@ struct PrepareTransferReadConversion
 ///   vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
 ///       : vector<5x4xf32>, memref<?x?x?xf32>
 /// ```
+///
+/// Note: A second temporary buffer may be allocated for the `mask` operand.
 struct PrepareTransferWriteConversion
     : public OpRewritePattern<TransferWriteOp> {
   using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
@@ -494,16 +532,20 @@ struct PrepareTransferWriteConversion
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    auto allocType = MemRefType::get({}, xferOp.getVectorType());
-    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
-    memref_store(xferOp.vector(), buffer);
-    auto loadedVec = memref_load(buffer);
-
+    auto buffers = allocBuffers(xferOp);
+    memref_store(xferOp.vector(), buffers.dataBuffer);
+    auto loadedVec = memref_load(buffers.dataBuffer);
     rewriter.updateRootInPlace(xferOp, [&]() {
       xferOp.vectorMutable().assign(loadedVec);
       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
     });
 
+    if (xferOp.mask()) {
+      auto loadedMask = memref_load(buffers.maskBuffer);
+      rewriter.updateRootInPlace(
+          xferOp, [&]() { xferOp.maskMutable().assign(loadedMask); });
+    }
+
     return success();
   }
 };
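The write side mirrors the read case (again a hedged sketch extending the unmasked example above; %m and %mask are illustrative):

  %0 = memref.alloca() : memref<vector<5x4xf32>>
  %m = memref.alloca() : memref<vector<5x4xi1>>
  memref.store %mask, %m[] : memref<vector<5x4xi1>>
  memref.store %vec, %0[] : memref<vector<5x4xf32>>
  %1 = memref.load %0[] : memref<vector<5x4xf32>>
  %loaded_mask = memref.load %m[] : memref<vector<5x4xi1>>
  vector.transfer_write %1, %A[%a, %b, %c], %loaded_mask
      { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>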
@@ -535,16 +577,28 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
       return failure();
 
     ScopedContext scope(rewriter, xferOp.getLoc());
-    // How the buffer can be found depends on OpTy.
-    auto buffer = Strategy<OpTy>::getBuffer(xferOp);
-    auto bufferType = buffer.getType().template dyn_cast<MemRefType>();
-    auto castedType = unpackOneDim(bufferType);
-    auto casted = vector_type_cast(castedType, buffer);
+
+    // Find and cast data buffer. How the buffer can be found depends on OpTy.
+    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
+    auto castedDataType = unpackOneDim(dataBufferType);
+    auto castedDataBuffer = vector_type_cast(castedDataType, dataBuffer);
+
+    // If the xferOp has a mask: Find and cast mask buffer.
+    Value castedMaskBuffer;
+    if (xferOp.mask()) {
+      auto maskBuffer = getMaskBuffer(xferOp);
+      auto maskBufferType =
+          maskBuffer.getType().template dyn_cast<MemRefType>();
+      auto castedMaskType = unpackOneDim(maskBufferType);
+      castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+    }
 
     // Loop bounds and step.
     auto lb = std_constant_index(0).value;
     auto ub = std_constant_index(
-        castedType.getDimSize(castedType.getRank() - 1)).value;
+                  castedDataType.getDimSize(castedDataType.getRank() - 1))
+                  .value;
     auto step = std_constant_index(1).value;
 
     // Generate for loop.
@@ -555,11 +609,31 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
           ScopedContext scope(b, loc);
           generateInBoundsCheck(
               xferOp, iv, b, unpackedDim(xferOp),
-              /*inBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::rewriteOp(b, xferOp, casted, iv);
-              }, /*outOfBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, casted, iv);
-              });
+              /*inBoundsCase=*/
+              [&](OpBuilder &b, Location /*loc*/) {
+                // Create new transfer op.
+                OpTy newXfer =
+                    Strategy<OpTy>::rewriteOp(b, xferOp, castedDataBuffer, iv);
+
+                // If old transfer op has a mask: Set mask on new transfer op.
+                if (xferOp.mask()) {
+                  OpBuilder::InsertionGuard guard(b);
+                  b.setInsertionPoint(newXfer); // Insert load before newXfer.
+
+                  SmallVector<Value, 8> loadIndices;
+                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+                  loadIndices.push_back(iv);
+
+                  auto mask = memref_load(castedMaskBuffer, loadIndices);
+                  rewriter.updateRootInPlace(
+                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
+                }
+              },
+              /*outOfBoundsCase=*/
+              [&](OpBuilder &b, Location /*loc*/) {
+                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, castedDataBuffer,
+                                                     iv);
+              });
           b.create<scf::YieldOp>(loc);
         });
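Putting the pieces together for a masked 2-D transfer_read like the one in the integration test below, the generated code looks roughly like this (a hedged sketch with illustrative names; %casted_mask and %casted_data are the vector.type_cast results of the two buffers, and the index computation is simplified):

  scf.for %iv = %c0 to %c4 step %c1 {
    %sub_mask = memref.load %casted_mask[%iv] : memref<4xvector<9xi1>>
    %idx = addi %base1, %iv : index
    %v1d = vector.transfer_read %A[%idx, %base2], %pad, %sub_mask
        : memref<?x?xf32>, vector<9xf32>
    memref.store %v1d, %casted_data[%iv] : memref<4xvector<9xf32>>
  }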

mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir

Lines changed: 18 additions & 6 deletions
@@ -1,8 +1,3 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
 // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
@@ -17,6 +12,19 @@ func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   return
 }
 
+func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
+                          [0, 0, 1, 1, 1, 1, 1, 0, 1],
+                          [1, 1, 1, 1, 1, 1, 1, 0, 1],
+                          [0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
+    memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
 func @transfer_read_2d_transposed(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
@@ -80,7 +88,10 @@ func @entry() {
   call @transfer_write_2d(%A, %c3, %c1) : (memref<?x?xf32>, index, index) -> ()
   // Read shifted by 0 and pad with -42:
   call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but transposed
+  // Same as above, but apply a mask
+  call @transfer_read_2d_mask(%A, %c0, %c0)
+      : (memref<?x?xf32>, index, index) -> ()
+  // Same as above, but without mask and transposed
   call @transfer_read_2d_transposed(%A, %c0, %c0)
       : (memref<?x?xf32>, index, index) -> ()
   // Second vector dimension is a broadcast
@@ -92,5 +103,6 @@ func @entry() {
 // CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+// CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
