@@ -107,8 +107,8 @@ static void getXferIndices(OpTy xferOp, Value iv,
107
107
indices[dim] = adaptor.indices ()[dim] + iv;
108
108
}
109
109
110
- static void maybeYieldValue (
111
- bool hasRetVal, OpBuilder builder, Location loc, Value value) {
110
+ static void maybeYieldValue (bool hasRetVal, OpBuilder builder, Location loc,
111
+ Value value) {
112
112
if (hasRetVal) {
113
113
builder.create <scf::YieldOp>(loc, value);
114
114
} else {
@@ -150,15 +150,19 @@ static Value generateInBoundsCheck(
150
150
auto cond = std_cmpi_sgt (memrefDim.value , memrefIdx);
151
151
auto check = builder.create <scf::IfOp>(
152
152
xferOp.getLoc (), resultTypes, cond,
153
- /* thenBuilder=*/ [&](OpBuilder &builder, Location loc) {
154
- maybeYieldValue (hasRetVal, builder, loc, inBoundsCase (builder, loc));
155
- }, /* elseBuilder=*/ [&](OpBuilder &builder, Location loc) {
156
- if (outOfBoundsCase) {
157
- maybeYieldValue (hasRetVal, builder, loc, outOfBoundsCase (builder, loc));
158
- } else {
159
- builder.create <scf::YieldOp>(loc);
160
- }
161
- });
153
+ /* thenBuilder=*/
154
+ [&](OpBuilder &builder, Location loc) {
155
+ maybeYieldValue (hasRetVal, builder, loc, inBoundsCase (builder, loc));
156
+ },
157
+ /* elseBuilder=*/
158
+ [&](OpBuilder &builder, Location loc) {
159
+ if (outOfBoundsCase) {
160
+ maybeYieldValue (hasRetVal, builder, loc,
161
+ outOfBoundsCase (builder, loc));
162
+ } else {
163
+ builder.create <scf::YieldOp>(loc);
164
+ }
165
+ });
162
166
163
167
return hasRetVal ? check.getResult (0 ) : Value ();
164
168
}
@@ -176,22 +180,24 @@ static void generateInBoundsCheck(
176
180
function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
177
181
generateInBoundsCheck (
178
182
xferOp, iv, builder, dim, /* resultTypes=*/ TypeRange (),
179
- /* inBoundsCase=*/ [&](OpBuilder &builder, Location loc) {
183
+ /* inBoundsCase=*/
184
+ [&](OpBuilder &builder, Location loc) {
180
185
inBoundsCase (builder, loc);
181
186
return Value ();
182
187
},
183
- /* outOfBoundsCase=*/ [&](OpBuilder &builder, Location loc) {
188
+ /* outOfBoundsCase=*/
189
+ [&](OpBuilder &builder, Location loc) {
184
190
if (outOfBoundsCase)
185
- outOfBoundsCase (builder, loc);
191
+ outOfBoundsCase (builder, loc);
186
192
return Value ();
187
193
});
188
194
}
189
195
190
196
// / Given an ArrayAttr, return a copy where the first element is dropped.
191
- static ArrayAttr dropFirstElem (PatternRewriter &rewriter , ArrayAttr attr) {
197
+ static ArrayAttr dropFirstElem (OpBuilder &builder , ArrayAttr attr) {
192
198
if (!attr)
193
199
return attr;
194
- return ArrayAttr::get (rewriter .getContext (), attr.getValue ().drop_front ());
200
+ return ArrayAttr::get (builder .getContext (), attr.getValue ().drop_front ());
195
201
}
196
202
197
203
// / Codegen strategy, depending on the operation.
@@ -256,8 +262,8 @@ struct Strategy<TransferReadOp> {
256
262
// /
257
263
// / Note: The loop and type cast are generated in TransferOpConversion.
258
264
// / The original TransferReadOp and store op are deleted in `cleanup`.
259
- static void rewriteOp (PatternRewriter &rewriter , TransferReadOp xferOp,
260
- Value buffer, Value iv) {
265
+ static void rewriteOp (OpBuilder &builder , TransferReadOp xferOp, Value buffer ,
266
+ Value iv) {
261
267
SmallVector<Value, 8 > storeIndices;
262
268
getStoreIndices (xferOp, storeIndices);
263
269
storeIndices.push_back (iv);
@@ -267,25 +273,25 @@ struct Strategy<TransferReadOp> {
267
273
268
274
auto bufferType = buffer.getType ().dyn_cast <ShapedType>();
269
275
auto vecType = bufferType.getElementType ().dyn_cast <VectorType>();
270
- auto inBoundsAttr = dropFirstElem (rewriter , xferOp.in_boundsAttr ());
276
+ auto inBoundsAttr = dropFirstElem (builder , xferOp.in_boundsAttr ());
271
277
auto newXfer =
272
278
vector_transfer_read (
273
279
vecType, xferOp.source (), xferIndices,
274
- AffineMapAttr::get (unpackedPermutationMap (xferOp, rewriter )),
280
+ AffineMapAttr::get (unpackedPermutationMap (xferOp, builder )),
275
281
xferOp.padding (), Value (), inBoundsAttr)
276
282
.value ;
277
283
278
284
if (vecType.getRank () > kTargetRank )
279
- newXfer.getDefiningOp ()->setAttr (kPassLabel , rewriter .getUnitAttr ());
285
+ newXfer.getDefiningOp ()->setAttr (kPassLabel , builder .getUnitAttr ());
280
286
281
287
memref_store (newXfer, buffer, storeIndices);
282
288
}
283
289
284
290
// / Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
285
291
// / padding value to the temporary buffer.
286
- static void handleOutOfBoundsDim (
287
- PatternRewriter &rewriter, TransferReadOp xferOp, Value buffer,
288
- Value iv) {
292
+ static void handleOutOfBoundsDim (OpBuilder & /* builder */ ,
293
+ TransferReadOp xferOp, Value buffer,
294
+ Value iv) {
289
295
SmallVector<Value, 8 > storeIndices;
290
296
getStoreIndices (xferOp, storeIndices);
291
297
storeIndices.push_back (iv);
@@ -336,7 +342,7 @@ struct Strategy<TransferWriteOp> {
336
342
// / to memory.
337
343
// /
338
344
// / Note: For more details, see comments on Strategy<TransferReadOp>.
339
- static void rewriteOp (PatternRewriter &rewriter , TransferWriteOp xferOp,
345
+ static void rewriteOp (OpBuilder &builder , TransferWriteOp xferOp,
340
346
Value buffer, Value iv) {
341
347
SmallVector<Value, 8 > loadIndices;
342
348
getLoadIndices (xferOp, loadIndices);
@@ -347,20 +353,19 @@ struct Strategy<TransferWriteOp> {
347
353
348
354
auto vec = memref_load (buffer, loadIndices);
349
355
auto vecType = vec.value .getType ().dyn_cast <VectorType>();
350
- auto inBoundsAttr = dropFirstElem (rewriter , xferOp.in_boundsAttr ());
356
+ auto inBoundsAttr = dropFirstElem (builder , xferOp.in_boundsAttr ());
351
357
auto newXfer = vector_transfer_write (
352
358
Type (), vec, xferOp.source (), xferIndices,
353
- AffineMapAttr::get (unpackedPermutationMap (xferOp, rewriter )), Value (),
359
+ AffineMapAttr::get (unpackedPermutationMap (xferOp, builder )), Value (),
354
360
inBoundsAttr);
355
361
356
362
if (vecType.getRank () > kTargetRank )
357
- newXfer.op ->setAttr (kPassLabel , rewriter .getUnitAttr ());
363
+ newXfer.op ->setAttr (kPassLabel , builder .getUnitAttr ());
358
364
}
359
365
360
366
// / Handle out-of-bounds accesses on the to-be-unpacked dimension.
361
- static void handleOutOfBoundsDim (
362
- PatternRewriter &rewriter, TransferWriteOp xferOp, Value buffer,
363
- Value iv) {}
367
+ static void handleOutOfBoundsDim (OpBuilder &builder, TransferWriteOp xferOp,
368
+ Value buffer, Value iv) {}
364
369
365
370
// / Cleanup after rewriting the op.
366
371
static void cleanup (PatternRewriter &rewriter, TransferWriteOp xferOp) {
@@ -499,18 +504,29 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
499
504
auto castedType = unpackOneDim (bufferType);
500
505
auto casted = vector_type_cast (castedType, buffer);
501
506
507
+ // Loop bounds and step.
502
508
auto lb = std_constant_index (0 ).value ;
503
509
auto ub = std_constant_index (
504
510
castedType.getDimSize (castedType.getRank () - 1 )).value ;
505
- affineLoopBuilder (lb, ub, 1 , [&](Value iv) {
506
- generateInBoundsCheck (
507
- xferOp, iv, rewriter, unpackedDim (xferOp),
508
- /* inBoundsCase=*/ [&](OpBuilder& /* b*/ , Location loc) {
509
- Strategy<OpTy>::rewriteOp (rewriter, xferOp, casted, iv);
510
- }, /* outOfBoundsCase=*/ [&](OpBuilder& /* b*/ , Location loc) {
511
- Strategy<OpTy>::handleOutOfBoundsDim (rewriter, xferOp, casted, iv);
512
- });
513
- });
511
+ auto step = std_constant_index (1 ).value ;
512
+
513
+ // Generate for loop.
514
+ rewriter.create <scf::ForOp>(
515
+ xferOp.getLoc (), lb, ub, step, ValueRange (),
516
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange /* loopState*/ ) {
517
+ ScopedContext scope (b, loc);
518
+ generateInBoundsCheck (
519
+ xferOp, iv, b, unpackedDim (xferOp),
520
+ /* inBoundsCase=*/
521
+ [&](OpBuilder &b, Location /* loc*/ ) {
522
+ Strategy<OpTy>::rewriteOp (b, xferOp, casted, iv);
523
+ },
524
+ /* outOfBoundsCase=*/
525
+ [&](OpBuilder &b, Location /* loc*/ ) {
526
+ Strategy<OpTy>::handleOutOfBoundsDim (b, xferOp, casted, iv);
527
+ });
528
+ b.create <scf::YieldOp>(loc);
529
+ });
514
530
515
531
Strategy<OpTy>::cleanup (rewriter, xferOp);
516
532
return success ();
@@ -546,25 +562,25 @@ struct Strategy1d;
546
562
// / Codegen strategy for TransferReadOp.
547
563
template <>
548
564
struct Strategy1d <TransferReadOp> {
549
- static void generateForLoopBody (
550
- OpBuilder &builder, Location loc, TransferReadOp xferOp, Value iv,
551
- ValueRange loopState) {
565
+ static void generateForLoopBody (OpBuilder &builder, Location loc,
566
+ TransferReadOp xferOp, Value iv,
567
+ ValueRange loopState) {
552
568
SmallVector<Value, 8 > indices;
553
569
auto dim = get1dMemrefIndices (xferOp, iv, indices);
554
- auto ivI32 = std_index_cast (
555
- IntegerType::get (builder.getContext (), 32 ), iv);
570
+ auto ivI32 = std_index_cast (IntegerType::get (builder.getContext (), 32 ), iv);
556
571
auto vec = loopState[0 ];
557
572
558
573
// In case of out-of-bounds access, leave `vec` as is (was initialized with
559
574
// padding value).
560
575
auto nextVec = generateInBoundsCheck (
561
576
xferOp, iv, builder, dim, TypeRange (xferOp.getVectorType ()),
562
- /* inBoundsCase=*/ [&](OpBuilder& /* b*/ , Location loc) {
563
- auto val = memref_load (xferOp.source (), indices);
564
- return vector_insert_element (val, vec, ivI32.value ).value ;
565
- }, /* outOfBoundsCase=*/ [&](OpBuilder& /* b*/ , Location loc) {
566
- return vec;
567
- });
577
+ /* inBoundsCase=*/
578
+ [&](OpBuilder & /* b*/ , Location loc) {
579
+ auto val = memref_load (xferOp.source (), indices);
580
+ return vector_insert_element (val, vec, ivI32.value ).value ;
581
+ },
582
+ /* outOfBoundsCase=*/
583
+ [&](OpBuilder & /* b*/ , Location loc) { return vec; });
568
584
builder.create <scf::YieldOp>(loc, nextVec);
569
585
}
570
586
@@ -577,27 +593,24 @@ struct Strategy1d<TransferReadOp> {
577
593
// / Codegen strategy for TransferWriteOp.
578
594
template <>
579
595
struct Strategy1d <TransferWriteOp> {
580
- static void generateForLoopBody (
581
- OpBuilder &builder, Location loc, TransferWriteOp xferOp, Value iv,
582
- ValueRange /* loopState*/ ) {
596
+ static void generateForLoopBody (OpBuilder &builder, Location loc,
597
+ TransferWriteOp xferOp, Value iv,
598
+ ValueRange /* loopState*/ ) {
583
599
SmallVector<Value, 8 > indices;
584
600
auto dim = get1dMemrefIndices (xferOp, iv, indices);
585
- auto ivI32 = std_index_cast (
586
- IntegerType::get (builder.getContext (), 32 ), iv);
601
+ auto ivI32 = std_index_cast (IntegerType::get (builder.getContext (), 32 ), iv);
587
602
588
603
// Nothing to do in case of out-of-bounds access.
589
604
generateInBoundsCheck (
590
605
xferOp, iv, builder, dim,
591
- /* inBoundsCase=*/ [&](OpBuilder& /* b*/ , Location loc) {
592
- auto val = vector_extract_element (xferOp.vector (), ivI32.value );
593
- memref_store (val, xferOp.source (), indices);
594
- });
606
+ /* inBoundsCase=*/ [&](OpBuilder & /* b*/ , Location loc) {
607
+ auto val = vector_extract_element (xferOp.vector (), ivI32.value );
608
+ memref_store (val, xferOp.source (), indices);
609
+ });
595
610
builder.create <scf::YieldOp>(loc);
596
611
}
597
612
598
- static Value initialLoopState (TransferWriteOp xferOp) {
599
- return Value ();
600
- }
613
+ static Value initialLoopState (TransferWriteOp xferOp) { return Value (); }
601
614
};
602
615
603
616
// / Lower a 1D vector transfer op that operates on a dimension different from
@@ -631,11 +644,11 @@ struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
631
644
auto map = xferOp.permutation_map ();
632
645
633
646
if (xferOp.getVectorType ().getRank () != 1 )
634
- return failure ();
635
- if (map.isMinorIdentity ()) // Handled by ConvertVectorToLLVM
636
- return failure ();
647
+ return failure ();
648
+ if (map.isMinorIdentity ()) // Handled by ConvertVectorToLLVM
649
+ return failure ();
637
650
if (xferOp.mask ())
638
- return failure ();
651
+ return failure ();
639
652
640
653
// Loop bounds, step, state...
641
654
auto vecType = xferOp.getVectorType ();
@@ -648,10 +661,10 @@ struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
648
661
rewriter.replaceOpWithNewOp <scf::ForOp>(
649
662
xferOp, lb, ub, step, loopState ? ValueRange (loopState) : ValueRange (),
650
663
[&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
651
- ScopedContext nestedScope (builder, loc);
652
- Strategy1d<OpTy>::generateForLoopBody (
653
- builder, loc, xferOp, iv, loopState);
654
- });
664
+ ScopedContext nestedScope (builder, loc);
665
+ Strategy1d<OpTy>::generateForLoopBody (builder, loc, xferOp, iv,
666
+ loopState);
667
+ });
655
668
656
669
return success ();
657
670
}
@@ -689,3 +702,4 @@ std::unique_ptr<Pass>
689
702
mlir::createProgressiveConvertVectorToSCFPass () {
690
703
return std::make_unique<ConvertProgressiveVectorToSCFPass>();
691
704
}
705
+
0 commit comments