Skip to content

Commit a819e73

Browse files
[mlir] Support broadcast dimensions in ProgressiveVectorToSCF
This commit adds support for broadcast dimensions in permutation maps of vector transfer ops. Also fixes a bug in VectorToSCF that generated incorrect in-bounds checks for broadcast dimensions. Differential Revision: https://reviews.llvm.org/D101019
1 parent 74854d0 commit a819e73

File tree

5 files changed

+107
-32
lines changed

5 files changed

+107
-32
lines changed

mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,16 @@ static Value setAllocAtFunctionEntry(MemRefType type, Operation *op) {
7070

7171
/// Given a vector transfer op, calculate which dimension of the `source`
7272
/// memref should be unpacked in the next application of TransferOpConversion.
73+
/// A return value of None indicates a broadcast.
7374
template <typename OpTy>
74-
static unsigned unpackedDim(OpTy xferOp) {
75+
static Optional<int64_t> unpackedDim(OpTy xferOp) {
7576
auto map = xferOp.permutation_map();
76-
// TODO: Handle broadcast
77-
auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
78-
assert(expr && "Expected AffineDimExpr in permutation map result");
79-
return expr.getPosition();
77+
if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>())
78+
return expr.getPosition();
79+
80+
assert(map.getResult(0).template isa<AffineConstantExpr>() &&
81+
"Expected AffineDimExpr or AffineConstantExpr");
82+
return None;
8083
}
8184

8285
/// Compute the permutation map for the new (N-1)-D vector transfer op. This
@@ -103,8 +106,12 @@ static void getXferIndices(OpTy xferOp, Value iv,
103106
auto dim = unpackedDim(xferOp);
104107
auto prevIndices = adaptor.indices();
105108
indices.append(prevIndices.begin(), prevIndices.end());
106-
using edsc::op::operator+;
107-
indices[dim] = adaptor.indices()[dim] + iv;
109+
110+
bool isBroadcast = !dim.hasValue();
111+
if (!isBroadcast) {
112+
using edsc::op::operator+;
113+
indices[dim.getValue()] = adaptor.indices()[dim.getValue()] + iv;
114+
}
108115
}
109116

110117
static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
@@ -116,7 +123,7 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
116123
}
117124
}
118125

119-
/// Helper function TransferOpConversion and Strided1dTransferOpConversion.
126+
/// Helper function for TransferOpConversion and TransferOp1dConversion.
120127
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
121128
/// specified dimension `dim` with the loop iteration variable `iv`.
122129
/// E.g., when unpacking dimension 0 from:
@@ -138,15 +145,17 @@ static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
138145
/// `resultTypes`.
139146
template <typename OpTy>
140147
static Value generateInBoundsCheck(
141-
OpTy xferOp, Value iv, OpBuilder &builder, unsigned dim,
148+
OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim,
142149
TypeRange resultTypes,
143150
function_ref<Value(OpBuilder &, Location)> inBoundsCase,
144151
function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
145152
bool hasRetVal = !resultTypes.empty();
146-
if (!xferOp.isDimInBounds(0)) {
147-
auto memrefDim = memref_dim(xferOp.source(), std_constant_index(dim));
153+
bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
154+
if (!xferOp.isDimInBounds(0) && !isBroadcast) {
155+
auto memrefDim =
156+
memref_dim(xferOp.source(), std_constant_index(dim.getValue()));
148157
using edsc::op::operator+;
149-
auto memrefIdx = xferOp.indices()[dim] + iv;
158+
auto memrefIdx = xferOp.indices()[dim.getValue()] + iv;
150159
auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
151160
auto check = builder.create<scf::IfOp>(
152161
xferOp.getLoc(), resultTypes, cond,
@@ -175,7 +184,7 @@ static Value generateInBoundsCheck(
175184
/// a return value. Consequently, this function does not have a return value.
176185
template <typename OpTy>
177186
static void generateInBoundsCheck(
178-
OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
187+
OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim,
179188
function_ref<void(OpBuilder &, Location)> inBoundsCase,
180189
function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
181190
generateInBoundsCheck(
@@ -534,27 +543,31 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
534543
};
535544

536545
/// Compute the indices into the memref for the LoadOp/StoreOp generated as
537-
/// part of Strided1dTransferOpConversion. Return the memref dimension on which
538-
/// the transfer is operating.
546+
/// part of TransferOp1dConversion. Return the memref dimension on which
547+
/// the transfer is operating. A return value of None indicates a broadcast.
539548
template <typename OpTy>
540-
static unsigned get1dMemrefIndices(OpTy xferOp, Value iv,
541-
SmallVector<Value, 8> &memrefIndices) {
549+
static Optional<int64_t>
550+
get1dMemrefIndices(OpTy xferOp, Value iv,
551+
SmallVector<Value, 8> &memrefIndices) {
542552
auto indices = xferOp.indices();
543553
auto map = xferOp.permutation_map();
544554

545555
memrefIndices.append(indices.begin(), indices.end());
546556
assert(map.getNumResults() == 1 &&
547557
"Expected 1 permutation map result for 1D transfer");
548-
// TODO: Handle broadcast
549-
auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
550-
assert(expr && "Expected AffineDimExpr in permutation map result");
551-
auto dim = expr.getPosition();
552-
using edsc::op::operator+;
553-
memrefIndices[dim] = memrefIndices[dim] + iv;
554-
return dim;
558+
if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
559+
auto dim = expr.getPosition();
560+
using edsc::op::operator+;
561+
memrefIndices[dim] = memrefIndices[dim] + iv;
562+
return dim;
563+
}
564+
565+
assert(map.getResult(0).template isa<AffineConstantExpr>() &&
566+
"Expected AffineDimExpr or AffineConstantExpr");
567+
return None;
555568
}
556569

557-
/// Codegen strategy for Strided1dTransferOpConversion, depending on the
570+
/// Codegen strategy for TransferOp1dConversion, depending on the
558571
/// operation.
559572
template <typename OpTy>
560573
struct Strategy1d;
@@ -613,14 +626,24 @@ struct Strategy1d<TransferWriteOp> {
613626
static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
614627
};
615628

616-
/// Lower a 1D vector transfer op that operates on a dimension different from
617-
/// the last one. Instead of accessing contiguous chunks (vectors) of memory,
618-
/// such ops access memory in a strided fashion.
629+
/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
630+
/// necessary in cases where a 1D vector transfer op cannot be lowered into
631+
/// vector loads/stores due to non-unit strides or broadcasts:
632+
///
633+
/// * Transfer dimension is not the last memref dimension
634+
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
635+
/// * Memref has a layout map with non-unit stride on the last dimension
636+
///
637+
/// This pattern generates IR as follows:
619638
///
620639
/// 1. Generate a for loop iterating over each vector element.
621640
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
622641
/// depending on OpTy.
623642
///
643+
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
644+
/// can be generated instead of TransferOp1dConversion. Add such a pattern
645+
/// to ConvertVectorToLLVM.
646+
///
624647
/// E.g.:
625648
/// ```
626649
/// vector.transfer_write %vec, %A[%a, %b]
@@ -635,7 +658,7 @@ struct Strategy1d<TransferWriteOp> {
635658
/// }
636659
/// ```
637660
template <typename OpTy>
638-
struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
661+
struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
639662
using OpRewritePattern<OpTy>::OpRewritePattern;
640663

641664
LogicalResult matchAndRewrite(OpTy xferOp,
@@ -681,8 +704,8 @@ void populateProgressiveVectorToSCFConversionPatterns(
681704
TransferOpConversion<TransferWriteOp>>(patterns.getContext());
682705

683706
if (kTargetRank == 1) {
684-
patterns.add<Strided1dTransferOpConversion<TransferReadOp>,
685-
Strided1dTransferOpConversion<TransferWriteOp>>(
707+
patterns.add<TransferOp1dConversion<TransferReadOp>,
708+
TransferOp1dConversion<TransferWriteOp>>(
686709
patterns.getContext());
687710
}
688711
}

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,10 @@ emitInBoundsCondition(PatternRewriter &rewriter,
230230
Value iv = std::get<0>(it), off = std::get<1>(it), ub = std::get<2>(it);
231231
using namespace mlir::edsc::op;
232232
majorIvsPlusOffsets.push_back(iv + off);
233-
if (!xferOp.isDimInBounds(leadingRank + idx)) {
233+
auto affineConstExpr =
234+
xferOp.permutation_map().getResult(idx).dyn_cast<AffineConstantExpr>();
235+
bool isBroadcast = affineConstExpr && affineConstExpr.getValue() == 0;
236+
if (!xferOp.isDimInBounds(leadingRank + idx) && !isBroadcast) {
234237
Value inBoundsCond = onTheFlyFoldSLT(majorIvsPlusOffsets.back(), ub);
235238
if (inBoundsCond)
236239
inBoundsCondition = (inBoundsCondition)

mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@
1010

1111
// Test for special cases of 1D vector transfer ops.
1212

13+
func @transfer_read_2d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
14+
%fm42 = constant -42.0: f32
15+
%f = vector.transfer_read %A[%base1, %base2], %fm42
16+
{permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
17+
: memref<?x?xf32>, vector<5x6xf32>
18+
vector.print %f: vector<5x6xf32>
19+
return
20+
}
21+
1322
func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
1423
%fm42 = constant -42.0: f32
1524
%f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -19,6 +28,16 @@ func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
1928
return
2029
}
2130

31+
func @transfer_read_1d_broadcast(
32+
%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
33+
%fm42 = constant -42.0: f32
34+
%f = vector.transfer_read %A[%base1, %base2], %fm42
35+
{permutation_map = affine_map<(d0, d1) -> (0)>}
36+
: memref<?x?xf32>, vector<9xf32>
37+
vector.print %f: vector<9xf32>
38+
return
39+
}
40+
2241
func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
2342
%fn1 = constant -1.0 : f32
2443
%vf0 = splat %fn1 : vector<7xf32>
@@ -53,8 +72,11 @@ func @entry() {
5372
call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
5473
call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
5574
call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
75+
call @transfer_read_1d_broadcast(%A, %c1, %c2)
76+
: (memref<?x?xf32>, index, index) -> ()
5677
return
5778
}
5879

5980
// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
6081
// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
82+
// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )

mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ func @transfer_read_2d_transposed(
2727
return
2828
}
2929

30+
func @transfer_read_2d_broadcast(
31+
%A : memref<?x?xf32>, %base1: index, %base2: index) {
32+
%fm42 = constant -42.0: f32
33+
%f = vector.transfer_read %A[%base1, %base2], %fm42
34+
{permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
35+
memref<?x?xf32>, vector<4x9xf32>
36+
vector.print %f: vector<4x9xf32>
37+
return
38+
}
39+
3040
func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
3141
%fn1 = constant -1.0 : f32
3242
%vf0 = splat %fn1 : vector<1x4xf32>
@@ -73,10 +83,14 @@ func @entry() {
7383
// Same as above, but transposed
7484
call @transfer_read_2d_transposed(%A, %c0, %c0)
7585
: (memref<?x?xf32>, index, index) -> ()
86+
// Second vector dimension is a broadcast
87+
call @transfer_read_2d_broadcast(%A, %c1, %c2)
88+
: (memref<?x?xf32>, index, index) -> ()
7689
return
7790
}
7891

7992
// CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
8093
// CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
8194
// CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
8295
// CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
96+
// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )

mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@ func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
1919
return
2020
}
2121

22+
func @transfer_read_3d_broadcast(%A : memref<?x?x?x?xf32>,
23+
%o: index, %a: index, %b: index, %c: index) {
24+
%fm42 = constant -42.0: f32
25+
%f = vector.transfer_read %A[%o, %a, %b, %c], %fm42
26+
{permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>}
27+
: memref<?x?x?x?xf32>, vector<2x5x3xf32>
28+
vector.print %f: vector<2x5x3xf32>
29+
return
30+
}
31+
2232
func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
2333
%o: index, %a: index, %b: index, %c: index) {
2434
%fm42 = constant -42.0: f32
@@ -78,9 +88,12 @@ func @entry() {
7888
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
7989
call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
8090
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
91+
call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0)
92+
: (memref<?x?x?x?xf32>, index, index, index, index) -> ()
8193
return
8294
}
8395

8496
// CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
8597
// CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
8698
// CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
99+
// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )

0 commit comments

Comments
 (0)