Skip to content

Commit f6a3e92

Browse files
[mlir] Use SCF for loops in ProgressiveVectorToSCF
Use SCF for loops instead of Affine for loops. Differential Revision: https://reviews.llvm.org/D101013
1 parent ab78e09 commit f6a3e92

File tree

1 file changed

+85
-71
lines changed

1 file changed

+85
-71
lines changed

mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp

Lines changed: 85 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ static void getXferIndices(OpTy xferOp, Value iv,
107107
indices[dim] = adaptor.indices()[dim] + iv;
108108
}
109109

110-
static void maybeYieldValue(
111-
bool hasRetVal, OpBuilder builder, Location loc, Value value) {
110+
static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
111+
Value value) {
112112
if (hasRetVal) {
113113
builder.create<scf::YieldOp>(loc, value);
114114
} else {
@@ -150,15 +150,19 @@ static Value generateInBoundsCheck(
150150
auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
151151
auto check = builder.create<scf::IfOp>(
152152
xferOp.getLoc(), resultTypes, cond,
153-
/*thenBuilder=*/[&](OpBuilder &builder, Location loc) {
154-
maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
155-
}, /*elseBuilder=*/[&](OpBuilder &builder, Location loc) {
156-
if (outOfBoundsCase) {
157-
maybeYieldValue(hasRetVal, builder, loc, outOfBoundsCase(builder, loc));
158-
} else {
159-
builder.create<scf::YieldOp>(loc);
160-
}
161-
});
153+
/*thenBuilder=*/
154+
[&](OpBuilder &builder, Location loc) {
155+
maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
156+
},
157+
/*elseBuilder=*/
158+
[&](OpBuilder &builder, Location loc) {
159+
if (outOfBoundsCase) {
160+
maybeYieldValue(hasRetVal, builder, loc,
161+
outOfBoundsCase(builder, loc));
162+
} else {
163+
builder.create<scf::YieldOp>(loc);
164+
}
165+
});
162166

163167
return hasRetVal ? check.getResult(0) : Value();
164168
}
@@ -176,22 +180,24 @@ static void generateInBoundsCheck(
176180
function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
177181
generateInBoundsCheck(
178182
xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(),
179-
/*inBoundsCase=*/[&](OpBuilder &builder, Location loc) {
183+
/*inBoundsCase=*/
184+
[&](OpBuilder &builder, Location loc) {
180185
inBoundsCase(builder, loc);
181186
return Value();
182187
},
183-
/*outOfBoundsCase=*/[&](OpBuilder &builder, Location loc) {
188+
/*outOfBoundsCase=*/
189+
[&](OpBuilder &builder, Location loc) {
184190
if (outOfBoundsCase)
185-
outOfBoundsCase(builder, loc);
191+
outOfBoundsCase(builder, loc);
186192
return Value();
187193
});
188194
}
189195

190196
/// Given an ArrayAttr, return a copy where the first element is dropped.
191-
static ArrayAttr dropFirstElem(PatternRewriter &rewriter, ArrayAttr attr) {
197+
static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
192198
if (!attr)
193199
return attr;
194-
return ArrayAttr::get(rewriter.getContext(), attr.getValue().drop_front());
200+
return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front());
195201
}
196202

197203
/// Codegen strategy, depending on the operation.
@@ -256,8 +262,8 @@ struct Strategy<TransferReadOp> {
256262
///
257263
/// Note: The loop and type cast are generated in TransferOpConversion.
258264
/// The original TransferReadOp and store op are deleted in `cleanup`.
259-
static void rewriteOp(PatternRewriter &rewriter, TransferReadOp xferOp,
260-
Value buffer, Value iv) {
265+
static void rewriteOp(OpBuilder &builder, TransferReadOp xferOp, Value buffer,
266+
Value iv) {
261267
SmallVector<Value, 8> storeIndices;
262268
getStoreIndices(xferOp, storeIndices);
263269
storeIndices.push_back(iv);
@@ -267,25 +273,25 @@ struct Strategy<TransferReadOp> {
267273

268274
auto bufferType = buffer.getType().dyn_cast<ShapedType>();
269275
auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
270-
auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr());
276+
auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
271277
auto newXfer =
272278
vector_transfer_read(
273279
vecType, xferOp.source(), xferIndices,
274-
AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)),
280+
AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
275281
xferOp.padding(), Value(), inBoundsAttr)
276282
.value;
277283

278284
if (vecType.getRank() > kTargetRank)
279-
newXfer.getDefiningOp()->setAttr(kPassLabel, rewriter.getUnitAttr());
285+
newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr());
280286

281287
memref_store(newXfer, buffer, storeIndices);
282288
}
283289

284290
/// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
285291
/// padding value to the temporary buffer.
286-
static void handleOutOfBoundsDim(
287-
PatternRewriter &rewriter, TransferReadOp xferOp, Value buffer,
288-
Value iv) {
292+
static void handleOutOfBoundsDim(OpBuilder & /*builder*/,
293+
TransferReadOp xferOp, Value buffer,
294+
Value iv) {
289295
SmallVector<Value, 8> storeIndices;
290296
getStoreIndices(xferOp, storeIndices);
291297
storeIndices.push_back(iv);
@@ -336,7 +342,7 @@ struct Strategy<TransferWriteOp> {
336342
/// to memory.
337343
///
338344
/// Note: For more details, see comments on Strategy<TransferReadOp>.
339-
static void rewriteOp(PatternRewriter &rewriter, TransferWriteOp xferOp,
345+
static void rewriteOp(OpBuilder &builder, TransferWriteOp xferOp,
340346
Value buffer, Value iv) {
341347
SmallVector<Value, 8> loadIndices;
342348
getLoadIndices(xferOp, loadIndices);
@@ -347,20 +353,19 @@ struct Strategy<TransferWriteOp> {
347353

348354
auto vec = memref_load(buffer, loadIndices);
349355
auto vecType = vec.value.getType().dyn_cast<VectorType>();
350-
auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr());
356+
auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
351357
auto newXfer = vector_transfer_write(
352358
Type(), vec, xferOp.source(), xferIndices,
353-
AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)), Value(),
359+
AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
354360
inBoundsAttr);
355361

356362
if (vecType.getRank() > kTargetRank)
357-
newXfer.op->setAttr(kPassLabel, rewriter.getUnitAttr());
363+
newXfer.op->setAttr(kPassLabel, builder.getUnitAttr());
358364
}
359365

360366
/// Handle out-of-bounds accesses on the to-be-unpacked dimension.
361-
static void handleOutOfBoundsDim(
362-
PatternRewriter &rewriter, TransferWriteOp xferOp, Value buffer,
363-
Value iv) {}
367+
static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp,
368+
Value buffer, Value iv) {}
364369

365370
/// Cleanup after rewriting the op.
366371
static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
@@ -499,18 +504,29 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
499504
auto castedType = unpackOneDim(bufferType);
500505
auto casted = vector_type_cast(castedType, buffer);
501506

507+
// Loop bounds and step.
502508
auto lb = std_constant_index(0).value;
503509
auto ub = std_constant_index(
504510
castedType.getDimSize(castedType.getRank() - 1)).value;
505-
affineLoopBuilder(lb, ub, 1, [&](Value iv) {
506-
generateInBoundsCheck(
507-
xferOp, iv, rewriter, unpackedDim(xferOp),
508-
/*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
509-
Strategy<OpTy>::rewriteOp(rewriter, xferOp, casted, iv);
510-
}, /*outOfBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
511-
Strategy<OpTy>::handleOutOfBoundsDim(rewriter, xferOp, casted, iv);
512-
});
513-
});
511+
auto step = std_constant_index(1).value;
512+
513+
// Generate for loop.
514+
rewriter.create<scf::ForOp>(
515+
xferOp.getLoc(), lb, ub, step, ValueRange(),
516+
[&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
517+
ScopedContext scope(b, loc);
518+
generateInBoundsCheck(
519+
xferOp, iv, b, unpackedDim(xferOp),
520+
/*inBoundsCase=*/
521+
[&](OpBuilder &b, Location /*loc*/) {
522+
Strategy<OpTy>::rewriteOp(b, xferOp, casted, iv);
523+
},
524+
/*outOfBoundsCase=*/
525+
[&](OpBuilder &b, Location /*loc*/) {
526+
Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, casted, iv);
527+
});
528+
b.create<scf::YieldOp>(loc);
529+
});
514530

515531
Strategy<OpTy>::cleanup(rewriter, xferOp);
516532
return success();
@@ -546,25 +562,25 @@ struct Strategy1d;
546562
/// Codegen strategy for TransferReadOp.
547563
template <>
548564
struct Strategy1d<TransferReadOp> {
549-
static void generateForLoopBody(
550-
OpBuilder &builder, Location loc, TransferReadOp xferOp, Value iv,
551-
ValueRange loopState) {
565+
static void generateForLoopBody(OpBuilder &builder, Location loc,
566+
TransferReadOp xferOp, Value iv,
567+
ValueRange loopState) {
552568
SmallVector<Value, 8> indices;
553569
auto dim = get1dMemrefIndices(xferOp, iv, indices);
554-
auto ivI32 = std_index_cast(
555-
IntegerType::get(builder.getContext(), 32), iv);
570+
auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
556571
auto vec = loopState[0];
557572

558573
// In case of out-of-bounds access, leave `vec` as is (was initialized with
559574
// padding value).
560575
auto nextVec = generateInBoundsCheck(
561576
xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()),
562-
/*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
563-
auto val = memref_load(xferOp.source(), indices);
564-
return vector_insert_element(val, vec, ivI32.value).value;
565-
}, /*outOfBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
566-
return vec;
567-
});
577+
/*inBoundsCase=*/
578+
[&](OpBuilder & /*b*/, Location loc) {
579+
auto val = memref_load(xferOp.source(), indices);
580+
return vector_insert_element(val, vec, ivI32.value).value;
581+
},
582+
/*outOfBoundsCase=*/
583+
[&](OpBuilder & /*b*/, Location loc) { return vec; });
568584
builder.create<scf::YieldOp>(loc, nextVec);
569585
}
570586

@@ -577,27 +593,24 @@ struct Strategy1d<TransferReadOp> {
577593
/// Codegen strategy for TransferWriteOp.
578594
template <>
579595
struct Strategy1d<TransferWriteOp> {
580-
static void generateForLoopBody(
581-
OpBuilder &builder, Location loc, TransferWriteOp xferOp, Value iv,
582-
ValueRange /*loopState*/) {
596+
static void generateForLoopBody(OpBuilder &builder, Location loc,
597+
TransferWriteOp xferOp, Value iv,
598+
ValueRange /*loopState*/) {
583599
SmallVector<Value, 8> indices;
584600
auto dim = get1dMemrefIndices(xferOp, iv, indices);
585-
auto ivI32 = std_index_cast(
586-
IntegerType::get(builder.getContext(), 32), iv);
601+
auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
587602

588603
// Nothing to do in case of out-of-bounds access.
589604
generateInBoundsCheck(
590605
xferOp, iv, builder, dim,
591-
/*inBoundsCase=*/[&](OpBuilder& /*b*/, Location loc) {
592-
auto val = vector_extract_element(xferOp.vector(), ivI32.value);
593-
memref_store(val, xferOp.source(), indices);
594-
});
606+
/*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) {
607+
auto val = vector_extract_element(xferOp.vector(), ivI32.value);
608+
memref_store(val, xferOp.source(), indices);
609+
});
595610
builder.create<scf::YieldOp>(loc);
596611
}
597612

598-
static Value initialLoopState(TransferWriteOp xferOp) {
599-
return Value();
600-
}
613+
static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
601614
};
602615

603616
/// Lower a 1D vector transfer op that operates on a dimension different from
@@ -631,11 +644,11 @@ struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
631644
auto map = xferOp.permutation_map();
632645

633646
if (xferOp.getVectorType().getRank() != 1)
634-
return failure();
635-
if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
636-
return failure();
647+
return failure();
648+
if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
649+
return failure();
637650
if (xferOp.mask())
638-
return failure();
651+
return failure();
639652

640653
// Loop bounds, step, state...
641654
auto vecType = xferOp.getVectorType();
@@ -648,10 +661,10 @@ struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
648661
rewriter.replaceOpWithNewOp<scf::ForOp>(
649662
xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
650663
[&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
651-
ScopedContext nestedScope(builder, loc);
652-
Strategy1d<OpTy>::generateForLoopBody(
653-
builder, loc, xferOp, iv, loopState);
654-
});
664+
ScopedContext nestedScope(builder, loc);
665+
Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv,
666+
loopState);
667+
});
655668

656669
return success();
657670
}
@@ -689,3 +702,4 @@ std::unique_ptr<Pass>
689702
mlir::createProgressiveConvertVectorToSCFPass() {
690703
return std::make_unique<ConvertProgressiveVectorToSCFPass>();
691704
}
705+

0 commit comments

Comments
 (0)