Skip to content

Commit 3715de9

Browse files
electriclilies (Lily Orth-Smith)
and others authored
[mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (llvm#137389)
In ConvertVectorToLLVM, the only option for setting alignment of `vector.gather`, `vector.scatter`, and the `vector.load/store` ops was to extract it from the datatype of the memref type. However, this is insufficient for hardware backends requiring alignment of vector types. This PR introduces the `use-vector-alignment` option to the `ConvertVectorToLLVMPass`, which makes the pass use the alignment of the vector type of these operations instead of the alignment of the memref type. --------- Co-authored-by: Lily Orth-Smith <[email protected]>
1 parent 1101b76 commit 3715de9

File tree

7 files changed

+206
-19
lines changed

7 files changed

+206
-19
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,6 +1411,13 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14111411
"bool", /*default=*/"true",
14121412
"Allows compiler to assume vector indices fit in 32-bit if that "
14131413
"yields faster code">,
1414+
Option<"useVectorAlignment", "use-vector-alignment",
1415+
"bool", /*default=*/"false",
1416+
"Use the preferred alignment of a vector type in load/store "
1417+
"operations instead of the alignment of the element type of the "
1418+
"memref. This flag is intended for use with hardware which requires "
1419+
"vector alignment, or in application contexts where it is known all "
1420+
"vector accesses are naturally aligned.">,
14141421
Option<"amx", "enable-amx",
14151422
"bool", /*default=*/"false",
14161423
"Enables the use of AMX dialect while lowering the vector "

mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
2222
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
2323
void populateVectorToLLVMConversionPatterns(
2424
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
25-
bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
25+
bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
26+
bool useVectorAlignment = false);
2627

2728
namespace vector {
2829
void registerConvertVectorToLLVMInterface(DialectRegistry &registry);

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
3232

3333
let arguments = (ins
3434
DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
35-
DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
35+
DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices,
36+
DefaultValuedAttr<BoolAttr, "false">:$use_vector_alignment);
3637
let assemblyFormat = "attr-dict";
3738
}
3839

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
6767
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
6868
}
6969

70+
// Helper that returns data layout alignment of a vector.
71+
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
72+
VectorType vectorType, unsigned &align) {
73+
Type convertedVectorTy = typeConverter.convertType(vectorType);
74+
if (!convertedVectorTy)
75+
return failure();
76+
77+
llvm::LLVMContext llvmContext;
78+
align = LLVM::TypeToLLVMIRTranslator(llvmContext)
79+
.getPreferredAlignment(convertedVectorTy,
80+
typeConverter.getDataLayout());
81+
82+
return success();
83+
}
84+
7085
// Helper that returns data layout alignment of a memref.
7186
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
7287
MemRefType memrefType, unsigned &align) {
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
8297
return success();
8398
}
8499

100+
// Helper to resolve the alignment for vector load/store, gather and scatter
101+
// ops. If useVectorAlignment is true, get the preferred alignment for the
102+
// vector type in the operation. This option is used for hardware backends with
103+
// vectorization. Otherwise, use the preferred alignment of the element type of
104+
// the memref. Note that if you choose to use vector alignment, the shape of the
105+
// vector type must be resolved before the ConvertVectorToLLVM pass is run.
106+
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
107+
VectorType vectorType,
108+
MemRefType memrefType, unsigned &align,
109+
bool useVectorAlignment) {
110+
if (useVectorAlignment) {
111+
if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
112+
return failure();
113+
}
114+
} else {
115+
if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
116+
return failure();
117+
}
118+
}
119+
return success();
120+
}
121+
85122
// Check if the last stride is non-unit and has a valid memory space.
86123
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
87124
const LLVMTypeConverter &converter) {
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
224261
template <class LoadOrStoreOp>
225262
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
226263
public:
264+
explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
265+
bool useVectorAlign)
266+
: ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
267+
useVectorAlignment(useVectorAlign) {}
227268
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
228269

229270
LogicalResult
@@ -240,8 +281,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
240281

241282
// Resolve alignment.
242283
unsigned align;
243-
if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
244-
return failure();
284+
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
285+
memRefTy, align, useVectorAlignment)))
286+
return rewriter.notifyMatchFailure(loadOrStoreOp,
287+
"could not resolve alignment");
245288

246289
// Resolve address.
247290
auto vtype = cast<VectorType>(
@@ -252,12 +295,23 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
252295
rewriter);
253296
return success();
254297
}
298+
299+
private:
300+
// If true, use the preferred alignment of the vector type.
301+
// If false, use the preferred alignment of the element type
302+
// of the memref. This flag is intended for use with hardware
303+
// backends that require alignment of vector operations.
304+
const bool useVectorAlignment;
255305
};
256306

257307
/// Conversion pattern for a vector.gather.
258308
class VectorGatherOpConversion
259309
: public ConvertOpToLLVMPattern<vector::GatherOp> {
260310
public:
311+
explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
312+
bool useVectorAlign)
313+
: ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
314+
useVectorAlignment(useVectorAlign) {}
261315
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
262316

263317
LogicalResult
@@ -278,10 +332,9 @@ class VectorGatherOpConversion
278332

279333
// Resolve alignment.
280334
unsigned align;
281-
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
282-
return rewriter.notifyMatchFailure(gather,
283-
"could not resolve memref alignment");
284-
}
335+
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
336+
memRefType, align, useVectorAlignment)))
337+
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
285338

286339
// Resolve address.
287340
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -297,12 +350,24 @@ class VectorGatherOpConversion
297350
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
298351
return success();
299352
}
353+
354+
private:
355+
// If true, use the preferred alignment of the vector type.
356+
// If false, use the preferred alignment of the element type
357+
// of the memref. This flag is intended for use with hardware
358+
// backends that require alignment of vector operations.
359+
const bool useVectorAlignment;
300360
};
301361

302362
/// Conversion pattern for a vector.scatter.
303363
class VectorScatterOpConversion
304364
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
305365
public:
366+
explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
367+
bool useVectorAlign)
368+
: ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
369+
useVectorAlignment(useVectorAlign) {}
370+
306371
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
307372

308373
LogicalResult
@@ -322,10 +387,10 @@ class VectorScatterOpConversion
322387

323388
// Resolve alignment.
324389
unsigned align;
325-
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
390+
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
391+
memRefType, align, useVectorAlignment)))
326392
return rewriter.notifyMatchFailure(scatter,
327-
"could not resolve memref alignment");
328-
}
393+
"could not resolve alignment");
329394

330395
// Resolve address.
331396
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -340,6 +405,13 @@ class VectorScatterOpConversion
340405
rewriter.getI32IntegerAttr(align));
341406
return success();
342407
}
408+
409+
private:
410+
// If true, use the preferred alignment of the vector type.
411+
// If false, use the preferred alignment of the element type
412+
// of the memref. This flag is intended for use with hardware
413+
// backends that require alignment of vector operations.
414+
const bool useVectorAlignment;
343415
};
344416

345417
/// Conversion pattern for a vector.expandload.
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
19282000
/// Populate the given list with patterns that convert from Vector to LLVM.
19292001
void mlir::populateVectorToLLVMConversionPatterns(
19302002
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1931-
bool reassociateFPReductions, bool force32BitVectorIndices) {
2003+
bool reassociateFPReductions, bool force32BitVectorIndices,
2004+
bool useVectorAlignment) {
19322005
// This function populates only ConversionPatterns, not RewritePatterns.
19332006
MLIRContext *ctx = converter.getDialect()->getContext();
19342007
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
19352008
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2009+
patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2010+
VectorLoadStoreConversion<vector::MaskedLoadOp>,
2011+
VectorLoadStoreConversion<vector::StoreOp>,
2012+
VectorLoadStoreConversion<vector::MaskedStoreOp>,
2013+
VectorGatherOpConversion, VectorScatterOpConversion>(
2014+
converter, useVectorAlignment);
19362015
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
19372016
VectorExtractElementOpConversion, VectorExtractOpConversion,
19382017
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
19392018
VectorInsertOpConversion, VectorPrintOpConversion,
19402019
VectorTypeCastOpConversion, VectorScaleOpConversion,
1941-
VectorLoadStoreConversion<vector::LoadOp>,
1942-
VectorLoadStoreConversion<vector::MaskedLoadOp>,
1943-
VectorLoadStoreConversion<vector::StoreOp>,
1944-
VectorLoadStoreConversion<vector::MaskedStoreOp>,
1945-
VectorGatherOpConversion, VectorScatterOpConversion,
19462020
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
19472021
VectorSplatOpLowering, VectorSplatNdOpLowering,
19482022
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9292
populateVectorTransferLoweringPatterns(patterns);
9393
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
9494
populateVectorToLLVMConversionPatterns(
95-
converter, patterns, reassociateFPReductions, force32BitVectorIndices);
95+
converter, patterns, reassociateFPReductions, force32BitVectorIndices,
96+
useVectorAlignment);
9697
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
9798

9899
// Architecture specific augmentations.

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
3434
TypeConverter &typeConverter, RewritePatternSet &patterns) {
3535
populateVectorToLLVMConversionPatterns(
3636
static_cast<LLVMTypeConverter &>(typeConverter), patterns,
37-
getReassociateFpReductions(), getForce_32bitVectorIndices());
37+
getReassociateFpReductions(), getForce_32bitVectorIndices(),
38+
getUseVectorAlignment());
3839
}
3940

4041
LogicalResult
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
2+
// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
3+
4+
5+
//===----------------------------------------------------------------------===//
6+
// vector.load
7+
//===----------------------------------------------------------------------===//
8+
9+
func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
10+
%0 = vector.load %base[%i, %j] : memref<200x100xf32>, vector<8xf32>
11+
return %0 : vector<8xf32>
12+
}
13+
14+
// ALL-LABEL: func @load
15+
16+
// VEC-ALIGN: llvm.load %{{.*}} {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
17+
// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>
18+
19+
// -----
20+
21+
//===----------------------------------------------------------------------===//
22+
// vector.store
23+
//===----------------------------------------------------------------------===//
24+
25+
func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
26+
%val = arith.constant dense<11.0> : vector<4xf32>
27+
vector.store %val, %base[%i, %j] : memref<200x100xf32>, vector<4xf32>
28+
return
29+
}
30+
31+
// ALL-LABEL: func @store
32+
33+
// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} : vector<4xf32>, !llvm.ptr
34+
// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} : vector<4xf32>, !llvm.ptr
35+
36+
// -----
37+
38+
//===----------------------------------------------------------------------===//
39+
// vector.maskedload
40+
//===----------------------------------------------------------------------===//
41+
42+
func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
43+
%c0 = arith.constant 0: index
44+
%0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
45+
return %0 : vector<16xf32>
46+
}
47+
48+
// ALL-LABEL: func @masked_load
49+
50+
// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
51+
// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
52+
53+
// -----
54+
55+
//===----------------------------------------------------------------------===//
56+
// vector.maskedstore
57+
//===----------------------------------------------------------------------===//
58+
59+
func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
60+
%c0 = arith.constant 0: index
61+
vector.maskedstore %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32>
62+
return
63+
}
64+
65+
// ALL-LABEL: func @masked_store
66+
67+
// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
68+
// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
69+
70+
// -----
71+
72+
//===----------------------------------------------------------------------===//
73+
// vector.scatter
74+
//===----------------------------------------------------------------------===//
75+
76+
func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
77+
%0 = arith.constant 0: index
78+
vector.scatter %base[%0][%index], %mask, %value : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
79+
return
80+
}
81+
82+
// ALL-LABEL: func @scatter
83+
84+
// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
85+
// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
86+
87+
// -----
88+
89+
//===----------------------------------------------------------------------===//
90+
// vector.gather
91+
//===----------------------------------------------------------------------===//
92+
93+
func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
94+
%0 = arith.constant 0: index
95+
%1 = vector.gather %base[%0][%index], %mask, %passthru : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
96+
return %1 : vector<3xf32>
97+
}
98+
99+
// ALL-LABEL: func @gather
100+
101+
// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
102+
// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>

0 commit comments

Comments
 (0)