Skip to content

Commit bba9af2

Browse files
committed
[mlir][vector] Decouple unrolling gather and gather to llvm lowering
1 parent 64555e3 commit bba9af2

File tree

7 files changed

+36
-87
lines changed

7 files changed

+36
-87
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,17 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
244244
/// [FlattenGather]
245245
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
246246
/// outermost dimension.
247+
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
248+
PatternBenefit benefit = 1);
249+
250+
/// Populate the pattern set with the following patterns:
247251
///
248252
/// [Gather1DToConditionalLoads]
249253
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
250254
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
251255
/// loads/extracts are made conditional using `scf.if` ops.
252-
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
253-
PatternBenefit benefit = 1);
256+
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
257+
PatternBenefit benefit = 1);
254258

255259
/// Populates instances of `MaskOpRewritePattern` to lower masked operations
256260
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -269,49 +269,32 @@ class VectorGatherOpConversion
269269
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
270270
return failure();
271271

272+
VectorType vType = gather.getVectorType();
273+
if (vType.getRank() > 1)
274+
return failure();
275+
272276
auto loc = gather->getLoc();
273277

274278
// Resolve alignment.
275279
unsigned align;
276280
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
277281
return failure();
278282

283+
// Resolve address.
279284
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
280285
adaptor.getIndices(), rewriter);
281286
Value base = adaptor.getBase();
282287

283-
auto llvmNDVectorTy = adaptor.getIndexVec().getType();
284288
// Handle the simple case of 1-D vector.
285-
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
286-
auto vType = gather.getVectorType();
287-
// Resolve address.
288-
Value ptrs =
289-
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
290-
base, ptr, adaptor.getIndexVec(), vType);
291-
// Replace with the gather intrinsic.
292-
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
293-
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
294-
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
295-
return success();
296-
}
297-
298-
const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
299-
auto callback = [align, memRefType, base, ptr, loc, &rewriter,
300-
&typeConverter](Type llvm1DVectorTy,
301-
ValueRange vectorOperands) {
302-
// Resolve address.
303-
Value ptrs = getIndexedPtrs(
304-
rewriter, loc, typeConverter, memRefType, base, ptr,
305-
/*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
306-
// Create the gather intrinsic.
307-
return rewriter.create<LLVM::masked_gather>(
308-
loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
309-
/*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
310-
};
311-
SmallVector<Value> vectorOperands = {
312-
adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
313-
return LLVM::detail::handleMultidimensionalVectors(
314-
gather, vectorOperands, *getTypeConverter(), callback, rewriter);
289+
// Resolve address.
290+
Value ptrs =
291+
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
292+
base, ptr, adaptor.getIndexVec(), vType);
293+
// Replace with the gather intrinsic.
294+
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
295+
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
296+
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
297+
return success();
315298
}
316299
};
317300

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8181
populateVectorInsertExtractStridedSliceTransforms(patterns);
8282
populateVectorStepLoweringPatterns(patterns);
8383
populateVectorRankReducingFMAPattern(patterns);
84+
populateVectorGatherLoweringPatterns(patterns);
8485
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
8586
}
8687

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ using namespace mlir;
3838
using namespace mlir::vector;
3939

4040
namespace {
41-
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
41+
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
4242
/// outermost dimension. For example:
4343
/// ```
4444
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
5656
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
5757
///
5858
/// Supports vector types with a fixed leading dimension.
59-
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
59+
struct UnrollGather : OpRewritePattern<vector::GatherOp> {
6060
using OpRewritePattern::OpRewritePattern;
6161

6262
LogicalResult matchAndRewrite(vector::GatherOp op,
6363
PatternRewriter &rewriter) const override {
6464
VectorType resultTy = op.getType();
6565
if (resultTy.getRank() < 2)
66-
return rewriter.notifyMatchFailure(op, "already flat");
66+
return rewriter.notifyMatchFailure(op, "already 1-D");
6767

6868
// Unrolling doesn't take vscale into account. Pattern is disabled for
6969
// vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
107107
/// ```mlir
108108
/// %subview = memref.subview %M (...)
109109
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
110-
/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
110+
/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
111+
/// strided<[3]>>
111112
/// ```
112113
/// ==>
113114
/// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
269270

270271
void mlir::vector::populateVectorGatherLoweringPatterns(
271272
RewritePatternSet &patterns, PatternBenefit benefit) {
272-
patterns.add<FlattenGather, RemoveStrideFromGatherSource,
273-
Gather1DToConditionalLoads>(patterns.getContext(), benefit);
273+
patterns.add<UnrollGather>(patterns.getContext(), benefit);
274+
}
275+
276+
void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
277+
RewritePatternSet &patterns, PatternBenefit benefit) {
278+
patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
279+
patterns.getContext(), benefit);
274280
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
20742074

20752075
// -----
20762076

2077-
func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
2078-
%0 = arith.constant 0: index
2079-
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
2080-
return %1 : vector<2x3xf32>
2081-
}
2082-
2083-
// CHECK-LABEL: func @gather_2d_from_1d
2084-
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
2085-
// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
2086-
// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
2087-
// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
2088-
// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
2089-
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2090-
// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
2091-
// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
2092-
// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
2093-
// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
2094-
// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
2095-
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
2096-
// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
2097-
2098-
// -----
2099-
2100-
func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
2101-
%0 = arith.constant 0: index
2102-
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
2103-
return %1 : vector<2x[3]xf32>
2104-
}
2105-
2106-
// CHECK-LABEL: func @gather_2d_from_1d_scalable
2107-
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
2108-
// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
2109-
// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
2110-
// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
2111-
// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2112-
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2113-
// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
2114-
// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
2115-
// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
2116-
// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
2117-
// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
2118-
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
2119-
// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
2120-
2121-
// -----
2122-
21232077

21242078
func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
21252079
%0 = arith.constant 3 : index

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
16631663

16641664
func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
16651665
%0 = arith.constant 0: index
1666-
%1 = vector.constant_mask [1, 2] : vector<2x3xi1>
1666+
%1 = vector.constant_mask [2, 2] : vector<2x3xi1>
16671667
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
16681668
return %2 : vector<2x3xf32>
16691669
}
@@ -1679,7 +1679,7 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
16791679
// vector.constant_mask only supports 'none set' or 'all set' scalable
16801680
// dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
16811681
// width vectors above.
1682-
%1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
1682+
%1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
16831683
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
16841684
return %2 : vector<2x[3]xf32>
16851685
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,7 @@ struct TestVectorGatherLowering
782782
void runOnOperation() override {
783783
RewritePatternSet patterns(&getContext());
784784
populateVectorGatherLoweringPatterns(patterns);
785+
populateVectorGatherToConditionalLoadPatterns(patterns);
785786
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
786787
}
787788
};

0 commit comments

Comments
 (0)