@@ -93,36 +93,90 @@ static void getXferIndices(OpTy xferOp, Value iv,
93
93
indices[dim] = adaptor.indices ()[dim] + iv;
94
94
}
95
95
96
- // / Generate an in-bounds check if the transfer op on the to-be-unpacked
97
- // / dimension may go out-of-bounds.
98
- template <typename OpTy>
99
- static void generateInBoundsCheck (
100
- OpTy xferOp, Value iv, PatternRewriter &rewriter,
101
- function_ref<void (OpBuilder &, Location)> inBoundsCase,
102
- function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
103
- // Corresponding memref dim of the vector dim that is unpacked.
104
- auto dim = unpackedDim (xferOp);
96
+ static void maybeYieldValue (bool hasRetVal, OpBuilder builder, Location loc,
97
+ Value value) {
98
+ if (hasRetVal) {
99
+ builder.create <scf::YieldOp>(loc, value);
100
+ } else {
101
+ builder.create <scf::YieldOp>(loc);
102
+ }
103
+ }
105
104
105
+ // / Helper function TransferOpConversion and Strided1dTransferOpConversion.
106
+ // / Generate an in-bounds check if the transfer op may go out-of-bounds on the
107
+ // / specified dimension `dim` with the loop iteration variable `iv`.
108
+ // / E.g., when unpacking dimension 0 from:
109
+ // / ```
110
+ // / %vec = vector.transfer_read %A[%a, %b] %cst
111
+ // / : vector<5x4xf32>, memref<?x?xf32>
112
+ // / ```
113
+ // / An if check similar to this will be generated inside the loop:
114
+ // / ```
115
+ // / %d = memref.dim %A, %c0 : memref<?x?xf32>
116
+ // / if (%a + iv < %d) {
117
+ // / (in-bounds case)
118
+ // / } else {
119
+ // / (out-of-bounds case)
120
+ // / }
121
+ // / ```
122
+ // / This function variant returns the value returned by `inBoundsCase` or
123
+ // / `outOfBoundsCase`. The MLIR type of the return value must be specified in
124
+ // / `resultTypes`.
125
+ template <typename OpTy>
126
+ static Value generateInBoundsCheck (
127
+ OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
128
+ TypeRange resultTypes,
129
+ function_ref<Value(OpBuilder &, Location)> inBoundsCase,
130
+ function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
131
+ bool hasRetVal = !resultTypes.empty ();
106
132
if (!xferOp.isDimInBounds (0 )) {
107
133
auto memrefDim = memref_dim (xferOp.source (), std_constant_index (dim));
108
134
using edsc::op::operator +;
109
135
auto memrefIdx = xferOp.indices ()[dim] + iv;
110
136
auto cond = std_cmpi_sgt (memrefDim.value , memrefIdx);
111
- rewriter.create <scf::IfOp>(
112
- xferOp.getLoc (), cond,
137
+ auto check = builder.create <scf::IfOp>(
138
+ xferOp.getLoc (), resultTypes, cond,
139
+ /* thenBuilder=*/
113
140
[&](OpBuilder &builder, Location loc) {
114
- inBoundsCase (builder, loc);
115
- builder.create <scf::YieldOp>(xferOp.getLoc ());
141
+ maybeYieldValue (hasRetVal, builder, loc, inBoundsCase (builder, loc));
116
142
},
143
+ /* elseBuilder=*/
117
144
[&](OpBuilder &builder, Location loc) {
118
- if (outOfBoundsCase)
119
- outOfBoundsCase (builder, loc);
120
- builder.create <scf::YieldOp>(xferOp.getLoc ());
145
+ if (outOfBoundsCase) {
146
+ maybeYieldValue (hasRetVal, builder, loc,
147
+ outOfBoundsCase (builder, loc));
148
+ } else {
149
+ builder.create <scf::YieldOp>(loc);
150
+ }
121
151
});
122
- } else {
123
- // No runtime check needed if dim is guaranteed to be in-bounds.
124
- inBoundsCase (rewriter, xferOp.getLoc ());
152
+
153
+ return hasRetVal ? check.getResult (0 ) : Value ();
125
154
}
155
+
156
+ // No runtime check needed if dim is guaranteed to be in-bounds.
157
+ return inBoundsCase (builder, xferOp.getLoc ());
158
+ }
159
+
160
+ // / In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
161
+ // / a return value. Consequently, this function does not have a return value.
162
+ template <typename OpTy>
163
+ static void generateInBoundsCheck (
164
+ OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
165
+ function_ref<void (OpBuilder &, Location)> inBoundsCase,
166
+ function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
167
+ generateInBoundsCheck (
168
+ xferOp, iv, builder, dim, /* resultTypes=*/ TypeRange (),
169
+ /* inBoundsCase=*/
170
+ [&](OpBuilder &builder, Location loc) {
171
+ inBoundsCase (builder, loc);
172
+ return Value ();
173
+ },
174
+ /* outOfBoundsCase=*/
175
+ [&](OpBuilder &builder, Location loc) {
176
+ if (outOfBoundsCase)
177
+ outOfBoundsCase (builder, loc);
178
+ return Value ();
179
+ });
126
180
}
127
181
128
182
// / Given an ArrayAttr, return a copy where the first element is dropped.
@@ -442,7 +496,7 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
442
496
.value ;
443
497
affineLoopBuilder (lb, ub, 1 , [&](Value iv) {
444
498
generateInBoundsCheck (
445
- xferOp, iv, rewriter,
499
+ xferOp, iv, rewriter, unpackedDim (xferOp),
446
500
/* inBoundsCase=*/
447
501
[&](OpBuilder & /* b*/ , Location loc) {
448
502
Strategy<OpTy>::rewriteOp (rewriter, xferOp, casted, iv);
@@ -458,6 +512,143 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
458
512
}
459
513
};
460
514
515
+ // / Compute the indices into the memref for the LoadOp/StoreOp generated as
516
+ // / part of Strided1dTransferOpConversion. Return the memref dimension on which
517
+ // / the transfer is operating.
518
+ template <typename OpTy>
519
+ static unsigned get1dMemrefIndices (OpTy xferOp, Value iv,
520
+ SmallVector<Value, 8 > &memrefIndices) {
521
+ auto indices = xferOp.indices ();
522
+ auto map = xferOp.permutation_map ();
523
+
524
+ memrefIndices.append (indices.begin (), indices.end ());
525
+ assert (map.getNumResults () == 1 &&
526
+ " Expected 1 permutation map result for 1D transfer" );
527
+ // TODO: Handle broadcast
528
+ auto expr = map.getResult (0 ).template dyn_cast <AffineDimExpr>();
529
+ assert (expr && " Expected AffineDimExpr in permutation map result" );
530
+ auto dim = expr.getPosition ();
531
+ using edsc::op::operator +;
532
+ memrefIndices[dim] = memrefIndices[dim] + iv;
533
+ return dim;
534
+ }
535
+
536
+ // / Codegen strategy for Strided1dTransferOpConversion, depending on the
537
+ // / operation.
538
+ template <typename OpTy>
539
+ struct Strategy1d ;
540
+
541
+ // / Codegen strategy for TransferReadOp.
542
+ template <>
543
+ struct Strategy1d <TransferReadOp> {
544
+ static void generateForLoopBody (OpBuilder &builder, Location loc,
545
+ TransferReadOp xferOp, Value iv,
546
+ ValueRange loopState) {
547
+ SmallVector<Value, 8 > indices;
548
+ auto dim = get1dMemrefIndices (xferOp, iv, indices);
549
+ auto ivI32 = std_index_cast (IntegerType::get (builder.getContext (), 32 ), iv);
550
+ auto vec = loopState[0 ];
551
+
552
+ // In case of out-of-bounds access, leave `vec` as is (was initialized with
553
+ // padding value).
554
+ auto nextVec = generateInBoundsCheck (
555
+ xferOp, iv, builder, dim, TypeRange (xferOp.getVectorType ()),
556
+ /* inBoundsCase=*/
557
+ [&](OpBuilder & /* b*/ , Location loc) {
558
+ auto val = memref_load (xferOp.source (), indices);
559
+ return vector_insert_element (val, vec, ivI32.value ).value ;
560
+ },
561
+ /* outOfBoundsCase=*/
562
+ [&](OpBuilder & /* b*/ , Location loc) { return vec; });
563
+ builder.create <scf::YieldOp>(loc, nextVec);
564
+ }
565
+
566
+ static Value initialLoopState (TransferReadOp xferOp) {
567
+ // Inititalize vector with padding value.
568
+ return std_splat (xferOp.getVectorType (), xferOp.padding ()).value ;
569
+ }
570
+ };
571
+
572
+ // / Codegen strategy for TransferWriteOp.
573
+ template <>
574
+ struct Strategy1d <TransferWriteOp> {
575
+ static void generateForLoopBody (OpBuilder &builder, Location loc,
576
+ TransferWriteOp xferOp, Value iv,
577
+ ValueRange /* loopState*/ ) {
578
+ SmallVector<Value, 8 > indices;
579
+ auto dim = get1dMemrefIndices (xferOp, iv, indices);
580
+ auto ivI32 = std_index_cast (IntegerType::get (builder.getContext (), 32 ), iv);
581
+
582
+ // Nothing to do in case of out-of-bounds access.
583
+ generateInBoundsCheck (
584
+ xferOp, iv, builder, dim,
585
+ /* inBoundsCase=*/ [&](OpBuilder & /* b*/ , Location loc) {
586
+ auto val = vector_extract_element (xferOp.vector (), ivI32.value );
587
+ memref_store (val, xferOp.source (), indices);
588
+ });
589
+ builder.create <scf::YieldOp>(loc);
590
+ }
591
+
592
+ static Value initialLoopState (TransferWriteOp xferOp) { return Value (); }
593
+ };
594
+
595
+ // / Lower a 1D vector transfer op that operates on a dimension different from
596
+ // / the last one. Instead of accessing contiguous chunks (vectors) of memory,
597
+ // / such ops access memory in a strided fashion.
598
+ // /
599
+ // / 1. Generate a for loop iterating over each vector element.
600
+ // / 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
601
+ // / depending on OpTy.
602
+ // /
603
+ // / E.g.:
604
+ // / ```
605
+ // / vector.transfer_write %vec, %A[%a, %b]
606
+ // / {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
607
+ // / : vector<9xf32>, memref<?x?xf32>
608
+ // / ```
609
+ // / Is rewritten to approximately the following pseudo-IR:
610
+ // / ```
611
+ // / for i = 0 to 9 {
612
+ // / %t = vector.extractelement %vec[i] : vector<9xf32>
613
+ // / memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
614
+ // / }
615
+ // / ```
616
+ template <typename OpTy>
617
+ struct Strided1dTransferOpConversion : public OpRewritePattern <OpTy> {
618
+ using OpRewritePattern<OpTy>::OpRewritePattern;
619
+
620
+ LogicalResult matchAndRewrite (OpTy xferOp,
621
+ PatternRewriter &rewriter) const override {
622
+ ScopedContext scope (rewriter, xferOp.getLoc ());
623
+ auto map = xferOp.permutation_map ();
624
+
625
+ if (xferOp.getVectorType ().getRank () != 1 )
626
+ return failure ();
627
+ if (map.isMinorIdentity ()) // Handled by ConvertVectorToLLVM
628
+ return failure ();
629
+ if (xferOp.mask ())
630
+ return failure ();
631
+
632
+ // Loop bounds, step, state...
633
+ auto vecType = xferOp.getVectorType ();
634
+ auto lb = std_constant_index (0 );
635
+ auto ub = std_constant_index (vecType.getDimSize (0 ));
636
+ auto step = std_constant_index (1 );
637
+ auto loopState = Strategy1d<OpTy>::initialLoopState (xferOp);
638
+
639
+ // Generate for loop.
640
+ rewriter.replaceOpWithNewOp <scf::ForOp>(
641
+ xferOp, lb, ub, step, loopState ? ValueRange (loopState) : ValueRange (),
642
+ [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
643
+ ScopedContext nestedScope (builder, loc);
644
+ Strategy1d<OpTy>::generateForLoopBody (builder, loc, xferOp, iv,
645
+ loopState);
646
+ });
647
+
648
+ return success ();
649
+ }
650
+ };
651
+
461
652
} // namespace
462
653
463
654
namespace mlir {
@@ -466,7 +657,10 @@ void populateProgressiveVectorToSCFConversionPatterns(
466
657
RewritePatternSet &patterns) {
467
658
patterns.add <PrepareTransferReadConversion, PrepareTransferWriteConversion,
468
659
TransferOpConversion<TransferReadOp>,
469
- TransferOpConversion<TransferWriteOp>>(patterns.getContext ());
660
+ TransferOpConversion<TransferWriteOp>,
661
+ Strided1dTransferOpConversion<TransferReadOp>,
662
+ Strided1dTransferOpConversion<TransferWriteOp>>(
663
+ patterns.getContext ());
470
664
}
471
665
472
666
struct ConvertProgressiveVectorToSCFPass
0 commit comments