Commit afaf36b

[mlir] Handle strided 1D vector transfer ops in ProgressiveVectorToSCF
Strided 1D vector transfer ops are 1D transfers operating on a memref dimension different from the last one. Such transfer ops do not access contiguous memory blocks (vectors), but access memory in a strided fashion. In the absence of a mask, strided 1D vector transfer ops can also be lowered using matrix.column.major.* LLVM instructions (in a later commit).

Subsequent commits will extend the pass to handle the remaining missing permutation maps (broadcasts, transposes, etc.).

Differential Revision: https://reviews.llvm.org/D100946
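For illustration, a minimal strided 1D transfer looks like the following (modeled on the integration test added by this commit; `%pad` is an illustrative f32 padding value). The vector runs along d0 while d1 stays fixed, so consecutive elements are one row apart in memory rather than adjacent:

```mlir
// Read 9 elements along dimension d0 (not the innermost dimension) of a
// 2D memref. Consecutive vector elements are a full row apart in memory,
// i.e. the access is strided rather than contiguous.
%f = vector.transfer_read %A[%base1, %base2], %pad
    {permutation_map = affine_map<(d0, d1) -> (d0)>}
    : memref<?x?xf32>, vector<9xf32>
```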
1 parent 027d673 commit afaf36b

File tree

2 files changed: +275, -21 lines

mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp

Lines changed: 215 additions & 21 deletions
@@ -93,36 +93,90 @@ static void getXferIndices(OpTy xferOp, Value iv,
   indices[dim] = adaptor.indices()[dim] + iv;
 }

-/// Generate an in-bounds check if the transfer op on the to-be-unpacked
-/// dimension may go out-of-bounds.
-template <typename OpTy>
-static void generateInBoundsCheck(
-    OpTy xferOp, Value iv, PatternRewriter &rewriter,
-    function_ref<void(OpBuilder &, Location)> inBoundsCase,
-    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
-  // Corresponding memref dim of the vector dim that is unpacked.
-  auto dim = unpackedDim(xferOp);
+static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
+                            Value value) {
+  if (hasRetVal) {
+    builder.create<scf::YieldOp>(loc, value);
+  } else {
+    builder.create<scf::YieldOp>(loc);
+  }
+}

+/// Helper function for TransferOpConversion and Strided1dTransferOpConversion.
+/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
+/// specified dimension `dim` with the loop iteration variable `iv`.
+/// E.g., when unpacking dimension 0 from:
+/// ```
+/// %vec = vector.transfer_read %A[%a, %b] %cst
+///     : vector<5x4xf32>, memref<?x?xf32>
+/// ```
+/// An if check similar to this will be generated inside the loop:
+/// ```
+/// %d = memref.dim %A, %c0 : memref<?x?xf32>
+/// if (%a + iv < %d) {
+///   (in-bounds case)
+/// } else {
+///   (out-of-bounds case)
+/// }
+/// ```
+/// This function variant returns the value returned by `inBoundsCase` or
+/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
+/// `resultTypes`.
+template <typename OpTy>
+static Value generateInBoundsCheck(
+    OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+    TypeRange resultTypes,
+    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
+    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
+  bool hasRetVal = !resultTypes.empty();
   if (!xferOp.isDimInBounds(0)) {
     auto memrefDim = memref_dim(xferOp.source(), std_constant_index(dim));
     using edsc::op::operator+;
     auto memrefIdx = xferOp.indices()[dim] + iv;
     auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
-    rewriter.create<scf::IfOp>(
-        xferOp.getLoc(), cond,
+    auto check = builder.create<scf::IfOp>(
+        xferOp.getLoc(), resultTypes, cond,
+        /*thenBuilder=*/
         [&](OpBuilder &builder, Location loc) {
-          inBoundsCase(builder, loc);
-          builder.create<scf::YieldOp>(xferOp.getLoc());
+          maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
         },
+        /*elseBuilder=*/
         [&](OpBuilder &builder, Location loc) {
-          if (outOfBoundsCase)
-            outOfBoundsCase(builder, loc);
-          builder.create<scf::YieldOp>(xferOp.getLoc());
+          if (outOfBoundsCase) {
+            maybeYieldValue(hasRetVal, builder, loc,
+                            outOfBoundsCase(builder, loc));
+          } else {
+            builder.create<scf::YieldOp>(loc);
+          }
         });
-  } else {
-    // No runtime check needed if dim is guaranteed to be in-bounds.
-    inBoundsCase(rewriter, xferOp.getLoc());
+
+    return hasRetVal ? check.getResult(0) : Value();
   }
+
+  // No runtime check needed if dim is guaranteed to be in-bounds.
+  return inBoundsCase(builder, xferOp.getLoc());
+}
+
+/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
+/// a return value. Consequently, this function does not have a return value.
+template <typename OpTy>
+static void generateInBoundsCheck(
+    OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+    function_ref<void(OpBuilder &, Location)> inBoundsCase,
+    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
+  generateInBoundsCheck(
+      xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(),
+      /*inBoundsCase=*/
+      [&](OpBuilder &builder, Location loc) {
+        inBoundsCase(builder, loc);
+        return Value();
+      },
+      /*outOfBoundsCase=*/
+      [&](OpBuilder &builder, Location loc) {
+        if (outOfBoundsCase)
+          outOfBoundsCase(builder, loc);
+        return Value();
+      });
 }

 /// Given an ArrayAttr, return a copy where the first element is dropped.
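To make the new value-returning variant concrete: the read strategy added further below uses it roughly as follows. This is a hand-written sketch of the generated IR, not verbatim pass output, and the value names are illustrative:

```mlir
// Bounds check that yields a value: either the vector with one element
// inserted (in-bounds) or the unchanged vector (out-of-bounds, keeps padding).
%d = memref.dim %A, %c0 : memref<?x?xf32>
%idx = addi %base1, %iv : index
%in_bounds = cmpi sgt, %d, %idx : index
%next = scf.if %in_bounds -> (vector<9xf32>) {
  %val = memref.load %A[%idx, %base2] : memref<?x?xf32>
  %ins = vector.insertelement %val, %vec[%iv_i32 : i32] : vector<9xf32>
  scf.yield %ins : vector<9xf32>
} else {
  scf.yield %vec : vector<9xf32>
}
```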
@@ -442,7 +496,7 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
             .value;
     affineLoopBuilder(lb, ub, 1, [&](Value iv) {
       generateInBoundsCheck(
-          xferOp, iv, rewriter,
+          xferOp, iv, rewriter, unpackedDim(xferOp),
           /*inBoundsCase=*/
           [&](OpBuilder & /*b*/, Location loc) {
             Strategy<OpTy>::rewriteOp(rewriter, xferOp, casted, iv);
@@ -458,6 +512,143 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
   }
 };

+/// Compute the indices into the memref for the LoadOp/StoreOp generated as
+/// part of Strided1dTransferOpConversion. Return the memref dimension on which
+/// the transfer is operating.
+template <typename OpTy>
+static unsigned get1dMemrefIndices(OpTy xferOp, Value iv,
+                                   SmallVector<Value, 8> &memrefIndices) {
+  auto indices = xferOp.indices();
+  auto map = xferOp.permutation_map();
+
+  memrefIndices.append(indices.begin(), indices.end());
+  assert(map.getNumResults() == 1 &&
+         "Expected 1 permutation map result for 1D transfer");
+  // TODO: Handle broadcast
+  auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
+  assert(expr && "Expected AffineDimExpr in permutation map result");
+  auto dim = expr.getPosition();
+  using edsc::op::operator+;
+  memrefIndices[dim] = memrefIndices[dim] + iv;
+  return dim;
+}
+
+/// Codegen strategy for Strided1dTransferOpConversion, depending on the
+/// operation.
+template <typename OpTy>
+struct Strategy1d;
+
+/// Codegen strategy for TransferReadOp.
+template <>
+struct Strategy1d<TransferReadOp> {
+  static void generateForLoopBody(OpBuilder &builder, Location loc,
+                                  TransferReadOp xferOp, Value iv,
+                                  ValueRange loopState) {
+    SmallVector<Value, 8> indices;
+    auto dim = get1dMemrefIndices(xferOp, iv, indices);
+    auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+    auto vec = loopState[0];
+
+    // In case of out-of-bounds access, leave `vec` as is (was initialized with
+    // padding value).
+    auto nextVec = generateInBoundsCheck(
+        xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()),
+        /*inBoundsCase=*/
+        [&](OpBuilder & /*b*/, Location loc) {
+          auto val = memref_load(xferOp.source(), indices);
+          return vector_insert_element(val, vec, ivI32.value).value;
+        },
+        /*outOfBoundsCase=*/
+        [&](OpBuilder & /*b*/, Location loc) { return vec; });
+    builder.create<scf::YieldOp>(loc, nextVec);
+  }
+
+  static Value initialLoopState(TransferReadOp xferOp) {
+    // Initialize vector with padding value.
+    return std_splat(xferOp.getVectorType(), xferOp.padding()).value;
+  }
+};
+
+/// Codegen strategy for TransferWriteOp.
+template <>
+struct Strategy1d<TransferWriteOp> {
+  static void generateForLoopBody(OpBuilder &builder, Location loc,
+                                  TransferWriteOp xferOp, Value iv,
+                                  ValueRange /*loopState*/) {
+    SmallVector<Value, 8> indices;
+    auto dim = get1dMemrefIndices(xferOp, iv, indices);
+    auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+
+    // Nothing to do in case of out-of-bounds access.
+    generateInBoundsCheck(
+        xferOp, iv, builder, dim,
+        /*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) {
+          auto val = vector_extract_element(xferOp.vector(), ivI32.value);
+          memref_store(val, xferOp.source(), indices);
+        });
+    builder.create<scf::YieldOp>(loc);
+  }
+
+  static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
+};
+
+/// Lower a 1D vector transfer op that operates on a dimension different from
+/// the last one. Instead of accessing contiguous chunks (vectors) of memory,
+/// such ops access memory in a strided fashion.
+///
+/// 1. Generate a for loop iterating over each vector element.
+/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
+///    depending on OpTy.
+///
+/// E.g.:
+/// ```
+/// vector.transfer_write %vec, %A[%a, %b]
+///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
+///    : vector<9xf32>, memref<?x?xf32>
+/// ```
+/// Is rewritten to approximately the following pseudo-IR:
+/// ```
+/// for i = 0 to 9 {
+///   %t = vector.extractelement %vec[i] : vector<9xf32>
+///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
+/// }
+/// ```
+template <typename OpTy>
+struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy xferOp,
+                                PatternRewriter &rewriter) const override {
+    ScopedContext scope(rewriter, xferOp.getLoc());
+    auto map = xferOp.permutation_map();
+
+    if (xferOp.getVectorType().getRank() != 1)
+      return failure();
+    if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
+      return failure();
+    if (xferOp.mask())
+      return failure();
+
+    // Loop bounds, step, state...
+    auto vecType = xferOp.getVectorType();
+    auto lb = std_constant_index(0);
+    auto ub = std_constant_index(vecType.getDimSize(0));
+    auto step = std_constant_index(1);
+    auto loopState = Strategy1d<OpTy>::initialLoopState(xferOp);
+
+    // Generate for loop.
+    rewriter.replaceOpWithNewOp<scf::ForOp>(
+        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
+        [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
+          ScopedContext nestedScope(builder, loc);
+          Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv,
+                                                loopState);
+        });
+
+    return success();
+  }
+};
+
 } // namespace

 namespace mlir {
@@ -466,7 +657,10 @@ void populateProgressiveVectorToSCFConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
                TransferOpConversion<TransferReadOp>,
-               TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+               TransferOpConversion<TransferWriteOp>,
+               Strided1dTransferOpConversion<TransferReadOp>,
+               Strided1dTransferOpConversion<TransferWriteOp>>(
+      patterns.getContext());
 }

 struct ConvertProgressiveVectorToSCFPass
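The doc comment of Strided1dTransferOpConversion above shows the write case; for a strided transfer_read the pattern produces, roughly, a loop that threads the vector through the iteration arguments. The following is a sketch in the same pseudo-IR style as that doc comment, with illustrative names:

```mlir
// Pseudo-IR for lowering a strided 1D transfer_read of vector<9xf32>.
%init = splat %pad : vector<9xf32>        // initialLoopState: all padding
%res = scf.for %i = %c0 to %c9 step %c1
    iter_args(%vec = %init) -> (vector<9xf32>) {
  // Strategy1d<TransferReadOp>::generateForLoopBody emits the bounds check
  // sketched earlier: load + insertelement when in-bounds, else keep %vec.
  %next = scf.if ...
  scf.yield %next : vector<9xf32>
}
```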
(new test file)

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Test for special cases of 1D vector transfer ops.
+
+func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {permutation_map = affine_map<(d0, d1) -> (d0)>}
+      : memref<?x?xf32>, vector<9xf32>
+  vector.print %f: vector<9xf32>
+  return
+}
+
+func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fn1 = constant -1.0 : f32
+  %vf0 = splat %fn1 : vector<7xf32>
+  vector.transfer_write %vf0, %A[%base1, %base2]
+      {permutation_map = affine_map<(d0, d1) -> (d0)>}
+      : vector<7xf32>, memref<?x?xf32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c2 = constant 2: index
+  %c3 = constant 3: index
+  %f10 = constant 10.0: f32
+  // Memref dims are intentionally smaller than the vector sizes, so some
+  // transfer accesses go out-of-bounds.
+  %first = constant 5: index
+  %second = constant 6: index
+  %A = memref.alloc(%first, %second) : memref<?x?xf32>
+  scf.for %i = %c0 to %first step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    %fi10 = mulf %fi, %f10 : f32
+    scf.for %j = %c0 to %second step %c1 {
+      %j32 = index_cast %j : index to i32
+      %fj = sitofp %j32 : i32 to f32
+      %fres = addf %fi10, %fj : f32
+      memref.store %fres, %A[%i, %j] : memref<?x?xf32>
+    }
+  }
+
+  call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+  return
+}
+
+// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
+// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
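For readers checking the FileCheck expectations, the values follow directly from the initialization loop above, which fills a 5x6 memref with A[i][j] = 10 * i + j. A short hand derivation:

```mlir
// Read 1: transfer_read_1d(%A, %c1, %c2) walks column 2 starting at row 1:
//   A[1][2]=12, A[2][2]=22, A[3][2]=32, A[4][2]=42; rows 5..8 are
//   out-of-bounds, so the padding value -42 is kept -> first CHECK line.
// Write:  transfer_write_1d(%A, %c3, %c2) stores -1.0 into column 2 starting
//   at row 3; only rows 3 and 4 are in-bounds, rows 5..9 are skipped.
// Read 2: transfer_read_1d(%A, %c0, %c2) walks column 2 starting at row 0:
//   2, 12, 22, then the freshly written -1, -1, then -42 padding
//   -> second CHECK line.
```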
