@@ -74,9 +74,9 @@ static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
74
74
template <typename OpTy>
75
75
static Optional<int64_t > unpackedDim (OpTy xferOp) {
76
76
auto map = xferOp.permutation_map ();
77
- if (auto expr = map.getResult (0 ).template dyn_cast <AffineDimExpr>())
77
+ if (auto expr = map.getResult (0 ).template dyn_cast <AffineDimExpr>()) {
78
78
return expr.getPosition ();
79
-
79
+ }
80
80
assert (map.getResult (0 ).template isa <AffineConstantExpr>() &&
81
81
" Expected AffineDimExpr or AffineConstantExpr" );
82
82
return None;
@@ -88,8 +88,9 @@ static Optional<int64_t> unpackedDim(OpTy xferOp) {
88
88
template <typename OpTy>
89
89
static AffineMap unpackedPermutationMap (OpTy xferOp, OpBuilder &builder) {
90
90
auto map = xferOp.permutation_map ();
91
- return AffineMap::get (map.getNumDims (), 0 , map.getResults ().drop_front (),
92
- builder.getContext ());
91
+ return AffineMap::get (
92
+ map.getNumDims (), 0 , map.getResults ().drop_front (),
93
+ builder.getContext ());
93
94
}
94
95
95
96
// / Calculate the indices for the new vector transfer op.
@@ -114,15 +115,29 @@ static void getXferIndices(OpTy xferOp, Value iv,
114
115
}
115
116
}
116
117
117
- static void maybeYieldValue (bool hasRetVal, OpBuilder builder, Location loc,
118
- Value value) {
118
+ static void maybeYieldValue (
119
+ bool hasRetVal, OpBuilder builder, Location loc, Value value) {
119
120
if (hasRetVal) {
120
121
builder.create <scf::YieldOp>(loc, value);
121
122
} else {
122
123
builder.create <scf::YieldOp>(loc);
123
124
}
124
125
}
125
126
127
+ // / Generates a boolean Value that is true if the iv-th bit in xferOp's mask
128
+ // / is set to true. Does not return a Value if the transfer op is not 1D or
129
+ // / if the transfer op does not have a mask.
130
+ template <typename OpTy>
131
+ static Value maybeGenerateMaskCheck (OpBuilder &builder, OpTy xferOp, Value iv) {
132
+ if (xferOp.getVectorType ().getRank () != 1 )
133
+ return Value ();
134
+ if (!xferOp.mask ())
135
+ return Value ();
136
+
137
+ auto ivI32 = std_index_cast (IntegerType::get (builder.getContext (), 32 ), iv);
138
+ return vector_extract_element (xferOp.mask (), ivI32).value ;
139
+ }
140
+
126
141
// / Helper function TransferOpConversion and TransferOp1dConversion.
127
142
// / Generate an in-bounds check if the transfer op may go out-of-bounds on the
128
143
// / specified dimension `dim` with the loop iteration variable `iv`.
@@ -140,6 +155,10 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
140
155
// / (out-of-bounds case)
141
156
// / }
142
157
// / ```
158
+ // /
159
+ // / If the transfer is 1D and has a mask, this function generates a more complex
160
+ // / check that also accounts for potentially masked out elements.
161
+ // /
143
162
// / This function variant returns the value returned by `inBoundsCase` or
144
163
// / `outOfBoundsCase`. The MLIR type of the return value must be specified in
145
164
// / `resultTypes`.
@@ -150,33 +169,45 @@ static Value generateInBoundsCheck(
150
169
function_ref<Value(OpBuilder &, Location)> inBoundsCase,
151
170
function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
152
171
bool hasRetVal = !resultTypes.empty ();
153
- bool isBroadcast = !dim.hasValue (); // No in-bounds check for broadcasts.
172
+ Value cond; // Condition to be built...
173
+
174
+ // Condition check 1: Access in-bounds?
175
+ bool isBroadcast = !dim.hasValue (); // No in-bounds check for broadcasts.
154
176
if (!xferOp.isDimInBounds (0 ) && !isBroadcast) {
155
177
auto memrefDim =
156
178
memref_dim (xferOp.source (), std_constant_index (dim.getValue ()));
157
179
using edsc::op::operator +;
158
180
auto memrefIdx = xferOp.indices ()[dim.getValue ()] + iv;
159
- auto cond = std_cmpi_sgt (memrefDim.value , memrefIdx);
181
+ cond = std_cmpi_sgt (memrefDim.value , memrefIdx);
182
+ }
183
+
184
+ // Condition check 2: Masked in?
185
+ if (auto maskCond = maybeGenerateMaskCheck (builder, xferOp, iv)) {
186
+ if (cond) {
187
+ cond = builder.create <AndOp>(xferOp.getLoc (), cond, maskCond);
188
+ } else {
189
+ cond = maskCond;
190
+ }
191
+ }
192
+
193
+ // If the condition is non-empty, generate an SCF::IfOp.
194
+ if (cond) {
160
195
auto check = builder.create <scf::IfOp>(
161
196
xferOp.getLoc (), resultTypes, cond,
162
- /* thenBuilder=*/
163
- [&](OpBuilder &builder, Location loc) {
164
- maybeYieldValue (hasRetVal, builder, loc, inBoundsCase (builder, loc));
165
- },
166
- /* elseBuilder=*/
167
- [&](OpBuilder &builder, Location loc) {
168
- if (outOfBoundsCase) {
169
- maybeYieldValue (hasRetVal, builder, loc,
170
- outOfBoundsCase (builder, loc));
171
- } else {
172
- builder.create <scf::YieldOp>(loc);
173
- }
174
- });
197
+ /* thenBuilder=*/ [&](OpBuilder &builder, Location loc) {
198
+ maybeYieldValue (hasRetVal, builder, loc, inBoundsCase (builder, loc));
199
+ }, /* elseBuilder=*/ [&](OpBuilder &builder, Location loc) {
200
+ if (outOfBoundsCase) {
201
+ maybeYieldValue (hasRetVal, builder, loc, outOfBoundsCase (builder, loc));
202
+ } else {
203
+ builder.create <scf::YieldOp>(loc);
204
+ }
205
+ });
175
206
176
207
return hasRetVal ? check.getResult (0 ) : Value ();
177
208
}
178
209
179
- // No runtime check needed if dim is guaranteed to be in-bounds .
210
+ // Condition is empty, no need for an SCF::IfOp .
180
211
return inBoundsCase (builder, xferOp.getLoc ());
181
212
}
182
213
@@ -189,15 +220,13 @@ static void generateInBoundsCheck(
189
220
function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
190
221
generateInBoundsCheck (
191
222
xferOp, iv, builder, dim, /* resultTypes=*/ TypeRange (),
192
- /* inBoundsCase=*/
193
- [&](OpBuilder &builder, Location loc) {
223
+ /* inBoundsCase=*/ [&](OpBuilder &builder, Location loc) {
194
224
inBoundsCase (builder, loc);
195
225
return Value ();
196
226
},
197
- /* outOfBoundsCase=*/
198
- [&](OpBuilder &builder, Location loc) {
227
+ /* outOfBoundsCase=*/ [&](OpBuilder &builder, Location loc) {
199
228
if (outOfBoundsCase)
200
- outOfBoundsCase (builder, loc);
229
+ outOfBoundsCase (builder, loc);
201
230
return Value ();
202
231
});
203
232
}
@@ -271,8 +300,8 @@ struct Strategy<TransferReadOp> {
271
300
// /
272
301
// / Note: The loop and type cast are generated in TransferOpConversion.
273
302
// / The original TransferReadOp and store op are deleted in `cleanup`.
274
- static void rewriteOp (OpBuilder &builder, TransferReadOp xferOp, Value buffer,
275
- Value iv) {
303
+ static void rewriteOp (OpBuilder &builder, TransferReadOp xferOp,
304
+ Value buffer, Value iv) {
276
305
SmallVector<Value, 8 > storeIndices;
277
306
getStoreIndices (xferOp, storeIndices);
278
307
storeIndices.push_back (iv);
@@ -283,24 +312,22 @@ struct Strategy<TransferReadOp> {
283
312
auto bufferType = buffer.getType ().dyn_cast <ShapedType>();
284
313
auto vecType = bufferType.getElementType ().dyn_cast <VectorType>();
285
314
auto inBoundsAttr = dropFirstElem (builder, xferOp.in_boundsAttr ());
286
- auto newXfer =
287
- vector_transfer_read (
288
- vecType, xferOp.source (), xferIndices,
289
- AffineMapAttr::get (unpackedPermutationMap (xferOp, builder)),
290
- xferOp.padding (), Value (), inBoundsAttr)
291
- .value ;
315
+ auto newXfer = vector_transfer_read (
316
+ vecType, xferOp.source (), xferIndices,
317
+ AffineMapAttr::get (unpackedPermutationMap (xferOp, builder)),
318
+ xferOp.padding (), Value (), inBoundsAttr).value ;
292
319
293
320
if (vecType.getRank () > kTargetRank )
294
- newXfer.getDefiningOp ()->setAttr (kPassLabel , builder.getUnitAttr ());
321
+ newXfer.getDefiningOp ()->setAttr (kPassLabel , builder.getUnitAttr ());
295
322
296
323
memref_store (newXfer, buffer, storeIndices);
297
324
}
298
325
299
326
// / Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
300
327
// / padding value to the temporary buffer.
301
- static void handleOutOfBoundsDim (OpBuilder & /* builder */ ,
302
- TransferReadOp xferOp, Value buffer,
303
- Value iv) {
328
+ static void handleOutOfBoundsDim (
329
+ OpBuilder & /* builder */ , TransferReadOp xferOp, Value buffer,
330
+ Value iv) {
304
331
SmallVector<Value, 8 > storeIndices;
305
332
getStoreIndices (xferOp, storeIndices);
306
333
storeIndices.push_back (iv);
@@ -365,16 +392,17 @@ struct Strategy<TransferWriteOp> {
365
392
auto inBoundsAttr = dropFirstElem (builder, xferOp.in_boundsAttr ());
366
393
auto newXfer = vector_transfer_write (
367
394
Type (), vec, xferOp.source (), xferIndices,
368
- AffineMapAttr::get (unpackedPermutationMap (xferOp, builder)), Value (),
369
- inBoundsAttr);
395
+ AffineMapAttr::get (unpackedPermutationMap (xferOp, builder)),
396
+ Value (), inBoundsAttr);
370
397
371
398
if (vecType.getRank () > kTargetRank )
372
- newXfer.op ->setAttr (kPassLabel , builder.getUnitAttr ());
399
+ newXfer.op ->setAttr (kPassLabel , builder.getUnitAttr ());
373
400
}
374
401
375
402
// / Handle out-of-bounds accesses on the to-be-unpacked dimension.
376
- static void handleOutOfBoundsDim (OpBuilder &builder, TransferWriteOp xferOp,
377
- Value buffer, Value iv) {}
403
+ static void handleOutOfBoundsDim (
404
+ OpBuilder &builder, TransferWriteOp xferOp, Value buffer,
405
+ Value iv) {}
378
406
379
407
// / Cleanup after rewriting the op.
380
408
static void cleanup (PatternRewriter &rewriter, TransferWriteOp xferOp) {
@@ -522,20 +550,18 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
522
550
// Generate for loop.
523
551
rewriter.create <scf::ForOp>(
524
552
xferOp.getLoc (), lb, ub, step, ValueRange (),
525
- [&](OpBuilder &b, Location loc, Value iv, ValueRange /* loopState*/ ) {
526
- ScopedContext scope (b, loc);
527
- generateInBoundsCheck (
528
- xferOp, iv, b, unpackedDim (xferOp),
529
- /* inBoundsCase=*/
530
- [&](OpBuilder &b, Location /* loc*/ ) {
531
- Strategy<OpTy>::rewriteOp (b, xferOp, casted, iv);
532
- },
533
- /* outOfBoundsCase=*/
534
- [&](OpBuilder &b, Location /* loc*/ ) {
535
- Strategy<OpTy>::handleOutOfBoundsDim (b, xferOp, casted, iv);
536
- });
537
- b.create <scf::YieldOp>(loc);
538
- });
553
+ [&](OpBuilder &b, Location loc, Value iv,
554
+ ValueRange /* loopState*/ ) {
555
+ ScopedContext scope (b, loc);
556
+ generateInBoundsCheck (
557
+ xferOp, iv, b, unpackedDim (xferOp),
558
+ /* inBoundsCase=*/ [&](OpBuilder &b, Location /* loc*/ ) {
559
+ Strategy<OpTy>::rewriteOp (b, xferOp, casted, iv);
560
+ }, /* outOfBoundsCase=*/ [&](OpBuilder &b, Location /* loc*/ ) {
561
+ Strategy<OpTy>::handleOutOfBoundsDim (b, xferOp, casted, iv);
562
+ });
563
+ b.create <scf::YieldOp>(loc);
564
+ });
539
565
540
566
Strategy<OpTy>::cleanup (rewriter, xferOp);
541
567
return success ();
@@ -546,9 +572,8 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
546
572
// / part of TransferOp1dConversion. Return the memref dimension on which
547
573
// / the transfer is operating. A return value of None indicates a broadcast.
548
574
template <typename OpTy>
549
- static Optional<int64_t >
550
- get1dMemrefIndices (OpTy xferOp, Value iv,
551
- SmallVector<Value, 8 > &memrefIndices) {
575
+ static Optional<int64_t > get1dMemrefIndices (
576
+ OpTy xferOp, Value iv, SmallVector<Value, 8 > &memrefIndices) {
552
577
auto indices = xferOp.indices ();
553
578
auto map = xferOp.permutation_map ();
554
579
@@ -575,25 +600,25 @@ struct Strategy1d;
575
600
// / Codegen strategy for TransferReadOp.
576
601
template <>
577
602
struct Strategy1d <TransferReadOp> {
578
- static void generateForLoopBody (OpBuilder &builder, Location loc,
579
- TransferReadOp xferOp, Value iv,
580
- ValueRange loopState) {
603
+ static void generateForLoopBody (
604
+ OpBuilder &builder, Location loc, TransferReadOp xferOp, Value iv,
605
+ ValueRange loopState) {
581
606
SmallVector<Value, 8 > indices;
582
607
auto dim = get1dMemrefIndices (xferOp, iv, indices);
583
- auto ivI32 = std_index_cast (IntegerType::get (builder.getContext (), 32 ), iv);
608
+ auto ivI32 = std_index_cast (
609
+ IntegerType::get (builder.getContext (), 32 ), iv);
584
610
auto vec = loopState[0 ];
585
611
586
612
// In case of out-of-bounds access, leave `vec` as is (was initialized with
587
613
// padding value).
588
614
auto nextVec = generateInBoundsCheck (
589
615
xferOp, iv, builder, dim, TypeRange (xferOp.getVectorType ()),
590
- /* inBoundsCase=*/
591
- [&](OpBuilder & /* b*/ , Location loc) {
592
- auto val = memref_load (xferOp.source (), indices);
593
- return vector_insert_element (val, vec, ivI32.value ).value ;
594
- },
595
- /* outOfBoundsCase=*/
596
- [&](OpBuilder & /* b*/ , Location loc) { return vec; });
616
+ /* inBoundsCase=*/ [&](OpBuilder& /* b*/ , Location loc) {
617
+ auto val = memref_load (xferOp.source (), indices);
618
+ return vector_insert_element (val, vec, ivI32.value ).value ;
619
+ }, /* outOfBoundsCase=*/ [&](OpBuilder& /* b*/ , Location loc) {
620
+ return vec;
621
+ });
597
622
builder.create <scf::YieldOp>(loc, nextVec);
598
623
}
599
624
@@ -606,24 +631,27 @@ struct Strategy1d<TransferReadOp> {
606
631
// / Codegen strategy for TransferWriteOp.
607
632
template <>
608
633
struct Strategy1d <TransferWriteOp> {
609
- static void generateForLoopBody (OpBuilder &builder, Location loc,
610
- TransferWriteOp xferOp, Value iv,
611
- ValueRange /* loopState*/ ) {
634
+ static void generateForLoopBody (
635
+ OpBuilder &builder, Location loc, TransferWriteOp xferOp, Value iv,
636
+ ValueRange /* loopState*/ ) {
612
637
SmallVector<Value, 8 > indices;
613
638
auto dim = get1dMemrefIndices (xferOp, iv, indices);
614
- auto ivI32 = std_index_cast (IntegerType::get (builder.getContext (), 32 ), iv);
639
+ auto ivI32 = std_index_cast (
640
+ IntegerType::get (builder.getContext (), 32 ), iv);
615
641
616
642
// Nothing to do in case of out-of-bounds access.
617
643
generateInBoundsCheck (
618
644
xferOp, iv, builder, dim,
619
- /* inBoundsCase=*/ [&](OpBuilder & /* b*/ , Location loc) {
620
- auto val = vector_extract_element (xferOp.vector (), ivI32.value );
621
- memref_store (val, xferOp.source (), indices);
622
- });
645
+ /* inBoundsCase=*/ [&](OpBuilder& /* b*/ , Location loc) {
646
+ auto val = vector_extract_element (xferOp.vector (), ivI32.value );
647
+ memref_store (val, xferOp.source (), indices);
648
+ });
623
649
builder.create <scf::YieldOp>(loc);
624
650
}
625
651
626
- static Value initialLoopState (TransferWriteOp xferOp) { return Value (); }
652
+ static Value initialLoopState (TransferWriteOp xferOp) {
653
+ return Value ();
654
+ }
627
655
};
628
656
629
657
// / Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
@@ -667,11 +695,9 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
667
695
auto map = xferOp.permutation_map ();
668
696
669
697
if (xferOp.getVectorType ().getRank () != 1 )
670
- return failure ();
671
- if (map.isMinorIdentity ()) // Handled by ConvertVectorToLLVM
672
- return failure ();
673
- if (xferOp.mask ())
674
- return failure ();
698
+ return failure ();
699
+ if (map.isMinorIdentity ()) // Handled by ConvertVectorToLLVM
700
+ return failure ();
675
701
676
702
// Loop bounds, step, state...
677
703
auto vecType = xferOp.getVectorType ();
@@ -684,10 +710,10 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
684
710
rewriter.replaceOpWithNewOp <scf::ForOp>(
685
711
xferOp, lb, ub, step, loopState ? ValueRange (loopState) : ValueRange (),
686
712
[&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
687
- ScopedContext nestedScope (builder, loc);
688
- Strategy1d<OpTy>::generateForLoopBody (builder, loc, xferOp, iv,
689
- loopState);
690
- });
713
+ ScopedContext nestedScope (builder, loc);
714
+ Strategy1d<OpTy>::generateForLoopBody (
715
+ builder, loc, xferOp, iv, loopState);
716
+ });
691
717
692
718
return success ();
693
719
}
@@ -699,7 +725,8 @@ namespace mlir {
699
725
700
726
void populateProgressiveVectorToSCFConversionPatterns (
701
727
RewritePatternSet &patterns) {
702
- patterns.add <PrepareTransferReadConversion, PrepareTransferWriteConversion,
728
+ patterns.add <PrepareTransferReadConversion,
729
+ PrepareTransferWriteConversion,
703
730
TransferOpConversion<TransferReadOp>,
704
731
TransferOpConversion<TransferWriteOp>>(patterns.getContext ());
705
732
@@ -725,4 +752,3 @@ std::unique_ptr<Pass>
725
752
mlir::createProgressiveConvertVectorToSCFPass () {
726
753
return std::make_unique<ConvertProgressiveVectorToSCFPass>();
727
754
}
728
-
0 commit comments