Skip to content

Commit efa44d5

Browse files
author
Lily Orth-Smith
committed
Use flag to determine whether we use vector alignment or memref alignment
1 parent de73656 commit efa44d5

File tree

6 files changed

+103
-40
lines changed

6 files changed

+103
-40
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,6 +1394,11 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
13941394
"bool", /*default=*/"true",
13951395
"Allows compiler to assume vector indices fit in 32-bit if that "
13961396
"yields faster code">,
1397+
Option<"useVectorAlignment", "use-vector-alignment",
1398+
"bool", /*default=*/"false",
1399+
"Use the preferred alignment of a vector type in load/store "
1400+
"operations instead of the alignment of the element type of the "
1401+
"memref">,
13971402
Option<"amx", "enable-amx",
13981403
"bool", /*default=*/"false",
13991404
"Enables the use of AMX dialect while lowering the vector "

mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
2222
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
2323
void populateVectorToLLVMConversionPatterns(
2424
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
25-
bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
25+
bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
26+
bool useVectorAlignment = false);
2627

2728
namespace vector {
2829
void registerConvertVectorToLLVMInterface(DialectRegistry &registry);

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
3232

3333
let arguments = (ins
3434
DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
35-
DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
35+
DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices,
36+
DefaultValuedAttr<BoolAttr, "false">:$use_vector_alignment);
3637
let assemblyFormat = "attr-dict";
3738
}
3839

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 90 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -67,29 +67,33 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
6767
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
6868
}
6969

70+
// Helper that returns data layout alignment of a vector.
71+
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
72+
VectorType vectorType, unsigned &align) {
73+
Type convertedVectorTy = typeConverter.convertType(vectorType);
74+
if (!convertedVectorTy)
75+
return failure();
76+
77+
llvm::LLVMContext llvmContext;
78+
align = LLVM::TypeToLLVMIRTranslator(llvmContext)
79+
.getPreferredAlignment(convertedVectorTy,
80+
typeConverter.getDataLayout());
81+
82+
return success();
83+
}
84+
7085
// Helper that returns data layout alignment of a memref.
7186
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
7287
MemRefType memrefType, unsigned &align) {
73-
// If shape is statically known, assign MemRefTypes to the alignment of a
74-
// VectorType with the same size and dtype. Otherwise, fall back to the
75-
// alignment of the element type.
76-
Type convertedType;
77-
if (memrefType.hasStaticShape()) {
78-
convertedType = typeConverter.convertType(VectorType::get(
79-
memrefType.getNumElements(), memrefType.getElementType()));
80-
} else {
81-
convertedType = typeConverter.convertType(memrefType.getElementType());
82-
}
83-
84-
if (!convertedType)
88+
Type elementTy = typeConverter.convertType(memrefType.getElementType());
89+
if (!elementTy)
8590
return failure();
8691

8792
// TODO: this should use the MLIR data layout when it becomes available and
8893
// stop depending on translation.
8994
llvm::LLVMContext llvmContext;
90-
align =
91-
LLVM::TypeToLLVMIRTranslator(llvmContext)
92-
.getPreferredAlignment(convertedType, typeConverter.getDataLayout());
95+
align = LLVM::TypeToLLVMIRTranslator(llvmContext)
96+
.getPreferredAlignment(elementTy, typeConverter.getDataLayout());
9397
return success();
9498
}
9599

@@ -235,6 +239,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
235239
template <class LoadOrStoreOp>
236240
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
237241
public:
242+
explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
243+
bool useVectorAlign)
244+
: ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
245+
useVectorAlignment(useVectorAlign) {}
238246
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
239247

240248
LogicalResult
@@ -251,8 +259,17 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
251259

252260
// Resolve alignment.
253261
unsigned align;
254-
if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
255-
return failure();
262+
if (useVectorAlignment) {
263+
if (failed(
264+
getVectorAlignment(*this->getTypeConverter(), vectorTy, align)))
265+
return rewriter.notifyMatchFailure(
266+
loadOrStoreOp, "could not resolve vector alignment");
267+
} else {
268+
if (failed(
269+
getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
270+
return rewriter.notifyMatchFailure(
271+
loadOrStoreOp, "could not resolve memref alignment");
272+
}
256273

257274
// Resolve address.
258275
auto vtype = cast<VectorType>(
@@ -263,12 +280,19 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
263280
rewriter);
264281
return success();
265282
}
283+
284+
private:
285+
const bool useVectorAlignment;
266286
};
267287

268288
/// Conversion pattern for a vector.gather.
269289
class VectorGatherOpConversion
270290
: public ConvertOpToLLVMPattern<vector::GatherOp> {
271291
public:
292+
explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
293+
bool useVectorAlign)
294+
: ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
295+
useVectorAlignment(useVectorAlign) {}
272296
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
273297

274298
LogicalResult
@@ -289,9 +313,15 @@ class VectorGatherOpConversion
289313

290314
// Resolve alignment.
291315
unsigned align;
292-
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
293-
return rewriter.notifyMatchFailure(gather,
294-
"could not resolve memref alignment");
316+
if (useVectorAlignment) {
317+
if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
318+
return rewriter.notifyMatchFailure(
319+
gather, "could not resolve vector alignment");
320+
} else {
321+
if (failed(
322+
getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
323+
return rewriter.notifyMatchFailure(
324+
gather, "could not resolve memref alignment");
295325
}
296326

297327
// Resolve address.
@@ -308,12 +338,20 @@ class VectorGatherOpConversion
308338
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
309339
return success();
310340
}
341+
342+
private:
343+
const bool useVectorAlignment;
311344
};
312345

313346
/// Conversion pattern for a vector.scatter.
314347
class VectorScatterOpConversion
315348
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
316349
public:
350+
explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
351+
bool useVectorAlign)
352+
: ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
353+
useVectorAlignment(useVectorAlign) {}
354+
317355
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
318356

319357
LogicalResult
@@ -333,9 +371,15 @@ class VectorScatterOpConversion
333371

334372
// Resolve alignment.
335373
unsigned align;
336-
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
337-
return rewriter.notifyMatchFailure(scatter,
338-
"could not resolve memref alignment");
374+
if (useVectorAlignment) {
375+
if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
376+
return rewriter.notifyMatchFailure(
377+
scatter, "could not resolve vector alignment");
378+
} else {
379+
if (failed(
380+
getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
381+
return rewriter.notifyMatchFailure(
382+
scatter, "could not resolve memref alignment");
339383
}
340384

341385
// Resolve address.
@@ -351,6 +395,9 @@ class VectorScatterOpConversion
351395
rewriter.getI32IntegerAttr(align));
352396
return success();
353397
}
398+
399+
private:
400+
const bool useVectorAlignment;
354401
};
355402

356403
/// Conversion pattern for a vector.expandload.
@@ -1939,7 +1986,8 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
19391986
/// Populate the given list with patterns that convert from Vector to LLVM.
19401987
void mlir::populateVectorToLLVMConversionPatterns(
19411988
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1942-
bool reassociateFPReductions, bool force32BitVectorIndices) {
1989+
bool reassociateFPReductions, bool force32BitVectorIndices,
1990+
bool useVectorAlignment) {
19431991
// This function populates only ConversionPatterns, not RewritePatterns.
19441992
MLIRContext *ctx = converter.getDialect()->getContext();
19451993
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
@@ -1948,18 +1996,24 @@ void mlir::populateVectorToLLVMConversionPatterns(
19481996
VectorExtractElementOpConversion, VectorExtractOpConversion,
19491997
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
19501998
VectorInsertOpConversion, VectorPrintOpConversion,
1951-
VectorTypeCastOpConversion, VectorScaleOpConversion,
1952-
VectorLoadStoreConversion<vector::LoadOp>,
1953-
VectorLoadStoreConversion<vector::MaskedLoadOp>,
1954-
VectorLoadStoreConversion<vector::StoreOp>,
1955-
VectorLoadStoreConversion<vector::MaskedStoreOp>,
1956-
VectorGatherOpConversion, VectorScatterOpConversion,
1957-
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1958-
VectorSplatOpLowering, VectorSplatNdOpLowering,
1959-
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1960-
MaskedReductionOpConversion, VectorInterleaveOpLowering,
1961-
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1962-
VectorScalableStepOpLowering>(converter);
1999+
VectorTypeCastOpConversion, VectorScaleOpConversion>(ctx),
2000+
patterns.add<VectorLoadStoreConversion<vector::LoadOp>>(
2001+
ctx, useVectorAlignment),
2002+
patterns.add<VectorLoadStoreConversion<vector::MaskedLoadOp>>(
2003+
ctx, useVectorAlignment),
2004+
patterns.add<VectorLoadStoreConversion<vector::StoreOp>>(
2005+
ctx, useVectorAlignment),
2006+
patterns.add<VectorLoadStoreConversion<vector::MaskedStoreOp>>(
2007+
ctx, useVectorAlignment),
2008+
patterns.add<VectorGatherOpConversion>(ctx, useVectorAlignment),
2009+
patterns.add<VectorScatterOpConversion>(ctx, useVectorAlignment),
2010+
patterns.add<VectorExpandLoadOpConversion,
2011+
VectorCompressStoreOpConversion, VectorSplatOpLowering,
2012+
VectorSplatNdOpLowering, VectorScalableInsertOpLowering,
2013+
VectorScalableExtractOpLowering, MaskedReductionOpConversion,
2014+
VectorInterleaveOpLowering, VectorDeinterleaveOpLowering,
2015+
VectorFromElementsLowering, VectorScalableStepOpLowering>(
2016+
converter);
19632017
}
19642018

19652019
void mlir::populateVectorToLLVMMatrixConversionPatterns(

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9292
populateVectorTransferLoweringPatterns(patterns);
9393
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
9494
populateVectorToLLVMConversionPatterns(
95-
converter, patterns, reassociateFPReductions, force32BitVectorIndices);
95+
converter, patterns, reassociateFPReductions, force32BitVectorIndices,
96+
useVectorAlignment);
9697
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
9798

9899
// Architecture specific augmentations.

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
3434
TypeConverter &typeConverter, RewritePatternSet &patterns) {
3535
populateVectorToLLVMConversionPatterns(
3636
static_cast<LLVMTypeConverter &>(typeConverter), patterns,
37-
getReassociateFpReductions(), getForce_32bitVectorIndices());
37+
getReassociateFpReductions(), getForce_32bitVectorIndices(),
38+
getUseVectorAlignment());
3839
}
3940

4041
LogicalResult

0 commit comments

Comments
 (0)