Commit 545f98e

[mlir] Support masked 1D vector transfer ops in ProgressiveVectorToSCF
Support for masked N-D vector transfer ops will be added in a subsequent commit.

Differential Revision: https://reviews.llvm.org/D101132
1 parent c229754 commit 545f98e
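
For illustration, here is a minimal usage sketch (not part of this commit) that drives the patterns registered by populateProgressiveVectorToSCFConversionPatterns, which is defined in the diff below, using the greedy pattern rewrite driver. The header path and the wrapper function name are assumptions made for this example; the populate entry point and the pass factory are the ones shown at the end of the diff.

// Hypothetical driver, for illustration only. The populate function exists in
// the file changed by this commit; the header path below is an assumption.
#include "mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h" // assumed path
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Applies the progressive vector-to-SCF lowering patterns to one function.
// After this commit, masked 1D transfer ops are lowered by these patterns
// instead of being rejected by TransferOp1dConversion.
static void lowerVectorTransfersProgressively(mlir::FuncOp func) {
  mlir::RewritePatternSet patterns(func.getContext());
  mlir::populateProgressiveVectorToSCFConversionPatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(func.getOperation(),
                                           std::move(patterns));
}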

File tree: 2 files changed (+170, -105 lines)


mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp

Lines changed: 118 additions & 92 deletions
@@ -74,9 +74,9 @@ static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
 template <typename OpTy>
 static Optional<int64_t> unpackedDim(OpTy xferOp) {
   auto map = xferOp.permutation_map();
-  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>())
+  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
     return expr.getPosition();
-
+  }
   assert(map.getResult(0).template isa<AffineConstantExpr>() &&
          "Expected AffineDimExpr or AffineConstantExpr");
   return None;
@@ -88,8 +88,9 @@ static Optional<int64_t> unpackedDim(OpTy xferOp) {
 template <typename OpTy>
 static AffineMap unpackedPermutationMap(OpTy xferOp, OpBuilder &builder) {
   auto map = xferOp.permutation_map();
-  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
-                        builder.getContext());
+  return AffineMap::get(
+      map.getNumDims(), 0, map.getResults().drop_front(),
+      builder.getContext());
 }
 
 /// Calculate the indices for the new vector transfer op.
@@ -114,15 +115,29 @@ static void getXferIndices(OpTy xferOp, Value iv,
   }
 }
 
-static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
-                            Value value) {
+static void maybeYieldValue(
+    bool hasRetVal, OpBuilder builder, Location loc, Value value) {
   if (hasRetVal) {
     builder.create<scf::YieldOp>(loc, value);
   } else {
     builder.create<scf::YieldOp>(loc);
   }
 }
 
+/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
+/// is set to true. Does not return a Value if the transfer op is not 1D or
+/// if the transfer op does not have a mask.
+template <typename OpTy>
+static Value maybeGenerateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
+  if (xferOp.getVectorType().getRank() != 1)
+    return Value();
+  if (!xferOp.mask())
+    return Value();
+
+  auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+  return vector_extract_element(xferOp.mask(), ivI32).value;
+}
+
 /// Helper function TransferOpConversion and TransferOp1dConversion.
 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
 /// specified dimension `dim` with the loop iteration variable `iv`.
@@ -140,6 +155,10 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
 /// (out-of-bounds case)
 /// }
 /// ```
+///
+/// If the transfer is 1D and has a mask, this function generates a more complex
+/// check also accounts for potentially masked out elements.
+///
 /// This function variant returns the value returned by `inBoundsCase` or
 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
 /// `resultTypes`.
@@ -150,33 +169,45 @@ static Value generateInBoundsCheck(
     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
   bool hasRetVal = !resultTypes.empty();
-  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
+  Value cond; // Condition to be built...
+
+  // Condition check 1: Access in-bounds?
+  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
     auto memrefDim =
         memref_dim(xferOp.source(), std_constant_index(dim.getValue()));
     using edsc::op::operator+;
     auto memrefIdx = xferOp.indices()[dim.getValue()] + iv;
-    auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
+    cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
+  }
+
+  // Condition check 2: Masked in?
+  if (auto maskCond = maybeGenerateMaskCheck(builder, xferOp, iv)) {
+    if (cond) {
+      cond = builder.create<AndOp>(xferOp.getLoc(), cond, maskCond);
+    } else {
+      cond = maskCond;
+    }
+  }
+
+  // If the condition is non-empty, generate an SCF::IfOp.
+  if (cond) {
     auto check = builder.create<scf::IfOp>(
         xferOp.getLoc(), resultTypes, cond,
-        /*thenBuilder=*/
-        [&](OpBuilder &builder, Location loc) {
-          maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
-        },
-        /*elseBuilder=*/
-        [&](OpBuilder &builder, Location loc) {
-          if (outOfBoundsCase) {
-            maybeYieldValue(hasRetVal, builder, loc,
-                            outOfBoundsCase(builder, loc));
-          } else {
-            builder.create<scf::YieldOp>(loc);
-          }
-        });
+        /*thenBuilder=*/[&](OpBuilder &builder, Location loc) {
+          maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
+        }, /*elseBuilder=*/[&](OpBuilder &builder, Location loc) {
+          if (outOfBoundsCase) {
+            maybeYieldValue(hasRetVal, builder, loc, outOfBoundsCase(builder, loc));
+          } else {
+            builder.create<scf::YieldOp>(loc);
+          }
+        });
 
     return hasRetVal ? check.getResult(0) : Value();
   }
 
-  // No runtime check needed if dim is guaranteed to be in-bounds.
+  // Condition is empty, no need for an SCF::IfOp.
   return inBoundsCase(builder, xferOp.getLoc());
 }
 
@@ -189,15 +220,13 @@ static void generateInBoundsCheck(
     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
   generateInBoundsCheck(
       xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(),
-      /*inBoundsCase=*/
-      [&](OpBuilder &builder, Location loc) {
+      /*inBoundsCase=*/[&](OpBuilder &builder, Location loc) {
        inBoundsCase(builder, loc);
        return Value();
      },
-      /*outOfBoundsCase=*/
-      [&](OpBuilder &builder, Location loc) {
+      /*outOfBoundsCase=*/[&](OpBuilder &builder, Location loc) {
        if (outOfBoundsCase)
-          outOfBoundsCase(builder, loc);
+          outOfBoundsCase(builder, loc);
        return Value();
      });
 }
@@ -271,8 +300,8 @@ struct Strategy<TransferReadOp> {
   ///
   /// Note: The loop and type cast are generated in TransferOpConversion.
   /// The original TransferReadOp and store op are deleted in `cleanup`.
-  static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp, Value buffer,
-                        Value iv) {
+  static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp,
+                        Value buffer, Value iv) {
     SmallVector<Value, 8> storeIndices;
     getStoreIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -283,24 +312,22 @@ struct Strategy<TransferReadOp> {
     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
     auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
-    auto newXfer =
-        vector_transfer_read(
-            vecType, xferOp.source(), xferIndices,
-            AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
-            xferOp.padding(), Value(), inBoundsAttr)
-            .value;
+    auto newXfer = vector_transfer_read(
+        vecType, xferOp.source(), xferIndices,
+        AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
+        xferOp.padding(), Value(), inBoundsAttr).value;
 
     if (vecType.getRank() > kTargetRank)
-      newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
+      newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
 
     memref_store(newXfer, buffer, storeIndices);
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
   /// padding value to the temporary buffer.
-  static void handleOutOfBoundsDim(OpBuilder & /*builder*/,
-                                   TransferReadOp xferOp, Value buffer,
-                                   Value iv) {
+  static void handleOutOfBoundsDim(
+      OpBuilder &/*builder*/, TransferReadOp xferOp, Value buffer,
+      Value iv) {
     SmallVector<Value, 8> storeIndices;
     getStoreIndices(xferOp, storeIndices);
     storeIndices.push_back(iv);
@@ -365,16 +392,17 @@ struct Strategy<TransferWriteOp> {
     auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
     auto newXfer = vector_transfer_write(
         Type(), vec, xferOp.source(), xferIndices,
-        AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
-        inBoundsAttr);
+        AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
+        Value(), inBoundsAttr);
 
     if (vecType.getRank() > kTargetRank)
-      newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
+      newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
   }
 
   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
-  static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp,
-                                   Value buffer, Value iv) {}
+  static void handleOutOfBoundsDim(
+      OpBuilder &builder, TransferWriteOp xferOp, Value buffer,
+      Value iv) {}
 
   /// Cleanup after rewriting the op.
   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
@@ -522,20 +550,18 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
     // Generate for loop.
     rewriter.create<scf::ForOp>(
         xferOp.getLoc(), lb, ub, step, ValueRange(),
-        [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
-          ScopedContext scope(b, loc);
-          generateInBoundsCheck(
-              xferOp, iv, b, unpackedDim(xferOp),
-              /*inBoundsCase=*/
-              [&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::rewriteOp(b, xferOp, casted, iv);
-              },
-              /*outOfBoundsCase=*/
-              [&](OpBuilder &b, Location /*loc*/) {
-                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, casted, iv);
-              });
-          b.create<scf::YieldOp>(loc);
-        });
+        [&](OpBuilder &b, Location loc, Value iv,
+            ValueRange /*loopState*/) {
+          ScopedContext scope(b, loc);
+          generateInBoundsCheck(
+              xferOp, iv, b, unpackedDim(xferOp),
+              /*inBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
+                Strategy<OpTy>::rewriteOp(b, xferOp, casted, iv);
+              }, /*outOfBoundsCase=*/[&](OpBuilder &b, Location /*loc*/) {
+                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, casted, iv);
+              });
+          b.create<scf::YieldOp>(loc);
+        });
 
     Strategy<OpTy>::cleanup(rewriter, xferOp);
     return success();
@@ -546,9 +572,8 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
 /// part of TransferOp1dConversion. Return the memref dimension on which
 /// the transfer is operating. A return value of None indicates a broadcast.
 template <typename OpTy>
-static Optional<int64_t>
-get1dMemrefIndices(OpTy xferOp, Value iv,
-                   SmallVector<Value, 8> &memrefIndices) {
+static Optional<int64_t> get1dMemrefIndices(
+    OpTy xferOp, Value iv, SmallVector<Value, 8> &memrefIndices) {
   auto indices = xferOp.indices();
   auto map = xferOp.permutation_map();
 
@@ -575,25 +600,25 @@ struct Strategy1d;
 /// Codegen strategy for TransferReadOp.
 template <>
 struct Strategy1d<TransferReadOp> {
-  static void generateForLoopBody(OpBuilder &builder, Location loc,
-                                  TransferReadOp xferOp, Value iv,
-                                  ValueRange loopState) {
+  static void generateForLoopBody(
+      OpBuilder &builder, Location loc, TransferReadOp xferOp, Value iv,
+      ValueRange loopState) {
     SmallVector<Value, 8> indices;
     auto dim = get1dMemrefIndices(xferOp, iv, indices);
-    auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+    auto ivI32 = std_index_cast(
+        IntegerType::get(builder.getContext(), 32), iv);
     auto vec = loopState[0];
 
     // In case of out-of-bounds access, leave `vec` as is (was initialized with
     // padding value).
     auto nextVec = generateInBoundsCheck(
         xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()),
-        /*inBoundsCase=*/
-        [&](OpBuilder & /*b*/, Location loc) {
-          auto val = memref_load(xferOp.source(), indices);
-          return vector_insert_element(val, vec, ivI32.value).value;
-        },
-        /*outOfBoundsCase=*/
-        [&](OpBuilder & /*b*/, Location loc) { return vec; });
+        /*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
+          auto val = memref_load(xferOp.source(), indices);
+          return vector_insert_element(val, vec, ivI32.value).value;
+        }, /*outOfBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
+          return vec;
+        });
     builder.create<scf::YieldOp>(loc, nextVec);
   }
 
@@ -606,24 +631,27 @@ struct Strategy1d<TransferReadOp> {
 /// Codegen strategy for TransferWriteOp.
 template <>
 struct Strategy1d<TransferWriteOp> {
-  static void generateForLoopBody(OpBuilder &builder, Location loc,
-                                  TransferWriteOp xferOp, Value iv,
-                                  ValueRange /*loopState*/) {
+  static void generateForLoopBody(
+      OpBuilder &builder, Location loc, TransferWriteOp xferOp, Value iv,
+      ValueRange /*loopState*/) {
     SmallVector<Value, 8> indices;
     auto dim = get1dMemrefIndices(xferOp, iv, indices);
-    auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+    auto ivI32 = std_index_cast(
+        IntegerType::get(builder.getContext(), 32), iv);
 
     // Nothing to do in case of out-of-bounds access.
     generateInBoundsCheck(
         xferOp, iv, builder, dim,
-        /*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) {
-          auto val = vector_extract_element(xferOp.vector(), ivI32.value);
-          memref_store(val, xferOp.source(), indices);
-        });
+        /*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
+          auto val = vector_extract_element(xferOp.vector(), ivI32.value);
+          memref_store(val, xferOp.source(), indices);
+        });
     builder.create<scf::YieldOp>(loc);
   }
 
-  static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
+  static Value initialLoopState(TransferWriteOp xferOp) {
+    return Value();
+  }
 };
 
 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
@@ -667,11 +695,9 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
     auto map = xferOp.permutation_map();
 
     if (xferOp.getVectorType().getRank() != 1)
-      return failure();
-    if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
-      return failure();
-    if (xferOp.mask())
-      return failure();
+      return failure();
+    if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
+      return failure();
 
     // Loop bounds, step, state...
     auto vecType = xferOp.getVectorType();
@@ -684,10 +710,10 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
     rewriter.replaceOpWithNewOp<scf::ForOp>(
         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
         [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
-          ScopedContext nestedScope(builder, loc);
-          Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv,
-                                                loopState);
-        });
+          ScopedContext nestedScope(builder, loc);
+          Strategy1d<OpTy>::generateForLoopBody(
+              builder, loc, xferOp, iv, loopState);
+        });
 
     return success();
   }
@@ -699,7 +725,8 @@ namespace mlir {
 
 void populateProgressiveVectorToSCFConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
+  patterns.add<PrepareTransferReadConversion,
+               PrepareTransferWriteConversion,
                TransferOpConversion<TransferReadOp>,
                TransferOpConversion<TransferWriteOp>>(patterns.getContext());
 
@@ -725,4 +752,3 @@ std::unique_ptr<Pass>
 mlir::createProgressiveConvertVectorToSCFPass() {
   return std::make_unique<ConvertProgressiveVectorToSCFPass>();
 }
-