@@ -70,13 +70,16 @@ static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
70
70
71
71
// / Given a vector transfer op, calculate which dimension of the `source`
72
72
// / memref should be unpacked in the next application of TransferOpConversion.
73
+ // / A return value of None indicates a broadcast.
73
74
template <typename OpTy>
74
- static unsigned unpackedDim (OpTy xferOp) {
75
+ static Optional< int64_t > unpackedDim (OpTy xferOp) {
75
76
auto map = xferOp.permutation_map ();
76
- // TODO: Handle broadcast
77
- auto expr = map.getResult (0 ).template dyn_cast <AffineDimExpr>();
78
- assert (expr && " Expected AffineDimExpr in permutation map result" );
79
- return expr.getPosition ();
77
+ if (auto expr = map.getResult (0 ).template dyn_cast <AffineDimExpr>())
78
+ return expr.getPosition ();
79
+
80
+ assert (map.getResult (0 ).template isa <AffineConstantExpr>() &&
81
+ " Expected AffineDimExpr or AffineConstantExpr" );
82
+ return None;
80
83
}
81
84
82
85
// / Compute the permutation map for the new (N-1)-D vector transfer op. This
@@ -103,8 +106,12 @@ static void getXferIndices(OpTy xferOp, Value iv,
103
106
auto dim = unpackedDim (xferOp);
104
107
auto prevIndices = adaptor.indices ();
105
108
indices.append (prevIndices.begin (), prevIndices.end ());
106
- using edsc::op::operator +;
107
- indices[dim] = adaptor.indices ()[dim] + iv;
109
+
110
+ bool isBroadcast = !dim.hasValue ();
111
+ if (!isBroadcast) {
112
+ using edsc::op::operator +;
113
+ indices[dim.getValue ()] = adaptor.indices ()[dim.getValue ()] + iv;
114
+ }
108
115
}
109
116
110
117
static void maybeYieldValue (bool hasRetVal, OpBuilder builder, Location loc,
@@ -116,7 +123,7 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
116
123
}
117
124
}
118
125
119
- // / Helper function TransferOpConversion and Strided1dTransferOpConversion .
126
+ // / Helper function TransferOpConversion and TransferOp1dConversion .
120
127
// / Generate an in-bounds check if the transfer op may go out-of-bounds on the
121
128
// / specified dimension `dim` with the loop iteration variable `iv`.
122
129
// / E.g., when unpacking dimension 0 from:
@@ -138,15 +145,17 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
138
145
// / `resultTypes`.
139
146
template <typename OpTy>
140
147
static Value generateInBoundsCheck (
141
- OpTy xferOp, Value iv, OpBuilder &builder, unsigned dim,
148
+ OpTy xferOp, Value iv, OpBuilder &builder, Optional< int64_t > dim,
142
149
TypeRange resultTypes,
143
150
function_ref<Value(OpBuilder &, Location)> inBoundsCase,
144
151
function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
145
152
bool hasRetVal = !resultTypes.empty ();
146
- if (!xferOp.isDimInBounds (0 )) {
147
- auto memrefDim = memref_dim (xferOp.source (), std_constant_index (dim));
153
+ bool isBroadcast = !dim.hasValue (); // No in-bounds check for broadcasts.
154
+ if (!xferOp.isDimInBounds (0 ) && !isBroadcast) {
155
+ auto memrefDim =
156
+ memref_dim (xferOp.source (), std_constant_index (dim.getValue ()));
148
157
using edsc::op::operator +;
149
- auto memrefIdx = xferOp.indices ()[dim] + iv;
158
+ auto memrefIdx = xferOp.indices ()[dim. getValue () ] + iv;
150
159
auto cond = std_cmpi_sgt (memrefDim.value , memrefIdx);
151
160
auto check = builder.create <scf::IfOp>(
152
161
xferOp.getLoc (), resultTypes, cond,
@@ -175,7 +184,7 @@ static Value generateInBoundsCheck(
175
184
// / a return value. Consequently, this function does not have a return value.
176
185
template <typename OpTy>
177
186
static void generateInBoundsCheck (
178
- OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
187
+ OpTy xferOp, Value iv, OpBuilder &builder, Optional< int64_t > dim,
179
188
function_ref<void (OpBuilder &, Location)> inBoundsCase,
180
189
function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
181
190
generateInBoundsCheck (
@@ -534,27 +543,31 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
534
543
};
535
544
536
545
// / Compute the indices into the memref for the LoadOp/StoreOp generated as
537
- // / part of Strided1dTransferOpConversion . Return the memref dimension on which
538
- // / the transfer is operating.
546
+ // / part of TransferOp1dConversion . Return the memref dimension on which
547
+ // / the transfer is operating. A return value of None indicates a broadcast.
539
548
template <typename OpTy>
540
- static unsigned get1dMemrefIndices (OpTy xferOp, Value iv,
541
- SmallVector<Value, 8 > &memrefIndices) {
549
+ static Optional<int64_t >
550
+ get1dMemrefIndices (OpTy xferOp, Value iv,
551
+ SmallVector<Value, 8 > &memrefIndices) {
542
552
auto indices = xferOp.indices ();
543
553
auto map = xferOp.permutation_map ();
544
554
545
555
memrefIndices.append (indices.begin (), indices.end ());
546
556
assert (map.getNumResults () == 1 &&
547
557
" Expected 1 permutation map result for 1D transfer" );
548
- // TODO: Handle broadcast
549
- auto expr = map.getResult (0 ).template dyn_cast <AffineDimExpr>();
550
- assert (expr && " Expected AffineDimExpr in permutation map result" );
551
- auto dim = expr.getPosition ();
552
- using edsc::op::operator +;
553
- memrefIndices[dim] = memrefIndices[dim] + iv;
554
- return dim;
558
+ if (auto expr = map.getResult (0 ).template dyn_cast <AffineDimExpr>()) {
559
+ auto dim = expr.getPosition ();
560
+ using edsc::op::operator +;
561
+ memrefIndices[dim] = memrefIndices[dim] + iv;
562
+ return dim;
563
+ }
564
+
565
+ assert (map.getResult (0 ).template isa <AffineConstantExpr>() &&
566
+ " Expected AffineDimExpr or AffineConstantExpr" );
567
+ return None;
555
568
}
556
569
557
- // / Codegen strategy for Strided1dTransferOpConversion , depending on the
570
+ // / Codegen strategy for TransferOp1dConversion , depending on the
558
571
// / operation.
559
572
template <typename OpTy>
560
573
struct Strategy1d ;
@@ -613,14 +626,24 @@ struct Strategy1d<TransferWriteOp> {
613
626
static Value initialLoopState (TransferWriteOp xferOp) { return Value (); }
614
627
};
615
628
616
- // / Lower a 1D vector transfer op that operates on a dimension different from
617
- // / the last one. Instead of accessing contiguous chunks (vectors) of memory,
618
- // / such ops access memory in a strided fashion.
629
+ // / Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
630
+ // / necessary in cases where a 1D vector transfer op cannot be lowered into
631
+ // / vector load/stores due to non-unit strides or broadcasts:
632
+ // /
633
+ // / * Transfer dimension is not the last memref dimension
634
+ // / * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
635
+ // / * Memref has a layout map with non-unit stride on the last dimension
636
+ // /
637
+ // / This pattern generates IR as follows:
619
638
// /
620
639
// / 1. Generate a for loop iterating over each vector element.
621
640
// / 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
622
641
// / depending on OpTy.
623
642
// /
643
+ // / TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
644
+ // / can be generated instead of TransferOp1dConversion. Add such a pattern
645
+ // / to ConvertVectorToLLVM.
646
+ // /
624
647
// / E.g.:
625
648
// / ```
626
649
// / vector.transfer_write %vec, %A[%a, %b]
@@ -635,7 +658,7 @@ struct Strategy1d<TransferWriteOp> {
635
658
// / }
636
659
// / ```
637
660
template <typename OpTy>
638
- struct Strided1dTransferOpConversion : public OpRewritePattern <OpTy> {
661
+ struct TransferOp1dConversion : public OpRewritePattern <OpTy> {
639
662
using OpRewritePattern<OpTy>::OpRewritePattern;
640
663
641
664
LogicalResult matchAndRewrite (OpTy xferOp,
@@ -681,8 +704,8 @@ void populateProgressiveVectorToSCFConversionPatterns(
681
704
TransferOpConversion<TransferWriteOp>>(patterns.getContext ());
682
705
683
706
if (kTargetRank == 1 ) {
684
- patterns.add <Strided1dTransferOpConversion <TransferReadOp>,
685
- Strided1dTransferOpConversion <TransferWriteOp>>(
707
+ patterns.add <TransferOp1dConversion <TransferReadOp>,
708
+ TransferOp1dConversion <TransferWriteOp>>(
686
709
patterns.getContext ());
687
710
}
688
711
}
0 commit comments