-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass #137389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1394,6 +1394,13 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { | |
"bool", /*default=*/"true", | ||
"Allows compiler to assume vector indices fit in 32-bit if that " | ||
"yields faster code">, | ||
Option<"useVectorAlignment", "use-vector-alignment", | ||
"bool", /*default=*/"false", | ||
"Use the preferred alignment of a vector type in load/store " | ||
"operations instead of the alignment of the element type of the " | ||
"memref. This flag is intended for use with hardware which requires" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "or in application contexts where it is known all vector accesses are naturally aligned"? |
||
"vector alignment, or in application contexts where it is known all " | ||
"vector access are naturally aligned. ">, | ||
Option<"amx", "enable-amx", | ||
"bool", /*default=*/"false", | ||
"Enables the use of AMX dialect while lowering the vector " | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter, | |
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); | ||
} | ||
|
||
/// Helper that returns the data layout alignment of a vector.
///
/// Converts \p vectorType to its LLVM equivalent and queries the preferred
/// alignment for the converted type against the type converter's data layout.
/// On success, \p align holds the preferred alignment in bytes. Fails if the
/// vector type cannot be converted (e.g. unsupported element type).
///
/// NOTE(review): a throwaway LLVMContext is created per call purely to drive
/// the type translation; the alignment itself comes from
/// `typeConverter.getDataLayout()`, not from any MLIR DLTI attribute —
/// presumably the default LLVM data layout unless the converter was seeded
/// otherwise. TODO: confirm once MLIR<=>LLVM data layout translation lands.
static LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
                                        VectorType vectorType,
                                        unsigned &align) {
  Type convertedVectorTy = typeConverter.convertType(vectorType);
  if (!convertedVectorTy)
    return failure();

  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(convertedVectorTy,
                                     typeConverter.getDataLayout());

  return success();
}
|
||
// Helper that returns data layout alignment of a memref. | ||
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, | ||
MemRefType memrefType, unsigned &align) { | ||
|
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, | |
return success(); | ||
} | ||
|
||
/// Helper to resolve the alignment for vector load/store, gather and scatter
/// ops.
///
/// When \p useVectorAlignment is true, the preferred alignment of the
/// operation's vector type is used — intended for hardware backends with
/// vectorization. Otherwise the preferred alignment of the memref's element
/// type is used. Note that when vector alignment is requested, the shape of
/// the vector type must be resolved before the ConvertVectorToLLVM pass runs.
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
                                       VectorType vectorType,
                                       MemRefType memrefType, unsigned &align,
                                       bool useVectorAlignment) {
  // Delegate to whichever alignment query the caller selected; both helpers
  // report their own failure, which we simply propagate.
  return useVectorAlignment
             ? getVectorAlignment(typeConverter, vectorType, align)
             : getMemRefAlignment(typeConverter, memrefType, align);
}
|
||
// Check if the last stride is non-unit and has a valid memory space. | ||
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, | ||
const LLVMTypeConverter &converter) { | ||
|
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, | |
template <class LoadOrStoreOp> | ||
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { | ||
public: | ||
explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv, | ||
bool useVectorAlign) | ||
: ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv), | ||
useVectorAlignment(useVectorAlign) {} | ||
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; | ||
|
||
LogicalResult | ||
|
@@ -240,8 +281,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { | |
|
||
// Resolve alignment. | ||
unsigned align; | ||
if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) | ||
return failure(); | ||
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy, | ||
memRefTy, align, useVectorAlignment))) | ||
return rewriter.notifyMatchFailure(loadOrStoreOp, | ||
"could not resolve alignment"); | ||
|
||
// Resolve address. | ||
auto vtype = cast<VectorType>( | ||
|
@@ -252,12 +295,23 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { | |
rewriter); | ||
return success(); | ||
} | ||
|
||
private: | ||
// If true, use the preferred alignment of the vector type. | ||
// If false, use the preferred alignment of the element type | ||
// of the memref. This flag is intended for use with hardware | ||
// backends that require alignment of vector operations. | ||
const bool useVectorAlignment; | ||
}; | ||
|
||
/// Conversion pattern for a vector.gather. | ||
class VectorGatherOpConversion | ||
: public ConvertOpToLLVMPattern<vector::GatherOp> { | ||
public: | ||
explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv, | ||
bool useVectorAlign) | ||
: ConvertOpToLLVMPattern<vector::GatherOp>(typeConv), | ||
useVectorAlignment(useVectorAlign) {} | ||
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; | ||
|
||
LogicalResult | ||
|
@@ -278,10 +332,9 @@ class VectorGatherOpConversion | |
|
||
// Resolve alignment. | ||
unsigned align; | ||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) { | ||
return rewriter.notifyMatchFailure(gather, | ||
"could not resolve memref alignment"); | ||
} | ||
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType, | ||
memRefType, align, useVectorAlignment))) | ||
return rewriter.notifyMatchFailure(gather, "could not resolve alignment"); | ||
|
||
// Resolve address. | ||
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), | ||
|
@@ -297,12 +350,24 @@ class VectorGatherOpConversion | |
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); | ||
return success(); | ||
} | ||
|
||
private: | ||
// If true, use the preferred alignment of the vector type. | ||
// If false, use the preferred alignment of the element type | ||
// of the memref. This flag is intended for use with hardware | ||
// backends that require alignment of vector operations. | ||
const bool useVectorAlignment; | ||
electriclilies marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}; | ||
|
||
/// Conversion pattern for a vector.scatter. | ||
class VectorScatterOpConversion | ||
: public ConvertOpToLLVMPattern<vector::ScatterOp> { | ||
public: | ||
explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv, | ||
bool useVectorAlign) | ||
: ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv), | ||
useVectorAlignment(useVectorAlign) {} | ||
|
||
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; | ||
|
||
LogicalResult | ||
|
@@ -322,10 +387,10 @@ class VectorScatterOpConversion | |
|
||
// Resolve alignment. | ||
unsigned align; | ||
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) { | ||
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType, | ||
memRefType, align, useVectorAlignment))) | ||
return rewriter.notifyMatchFailure(scatter, | ||
"could not resolve memref alignment"); | ||
} | ||
"could not resolve alignment"); | ||
|
||
// Resolve address. | ||
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), | ||
|
@@ -340,6 +405,13 @@ class VectorScatterOpConversion | |
rewriter.getI32IntegerAttr(align)); | ||
return success(); | ||
} | ||
|
||
private: | ||
// If true, use the preferred alignment of the vector type. | ||
// If false, use the preferred alignment of the element type | ||
// of the memref. This flag is intended for use with hardware | ||
// backends that require alignment of vector operations. | ||
const bool useVectorAlignment; | ||
}; | ||
|
||
/// Conversion pattern for a vector.expandload. | ||
|
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern( | |
/// Populate the given list with patterns that convert from Vector to LLVM. | ||
void mlir::populateVectorToLLVMConversionPatterns( | ||
const LLVMTypeConverter &converter, RewritePatternSet &patterns, | ||
bool reassociateFPReductions, bool force32BitVectorIndices) { | ||
bool reassociateFPReductions, bool force32BitVectorIndices, | ||
bool useVectorAlignment) { | ||
// This function populates only ConversionPatterns, not RewritePatterns. | ||
MLIRContext *ctx = converter.getDialect()->getContext(); | ||
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions); | ||
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices); | ||
patterns.add<VectorLoadStoreConversion<vector::LoadOp>, | ||
VectorLoadStoreConversion<vector::MaskedLoadOp>, | ||
VectorLoadStoreConversion<vector::StoreOp>, | ||
VectorLoadStoreConversion<vector::MaskedStoreOp>, | ||
VectorGatherOpConversion, VectorScatterOpConversion>( | ||
converter, useVectorAlignment); | ||
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, | ||
VectorExtractElementOpConversion, VectorExtractOpConversion, | ||
VectorFMAOp1DConversion, VectorInsertElementOpConversion, | ||
VectorInsertOpConversion, VectorPrintOpConversion, | ||
VectorTypeCastOpConversion, VectorScaleOpConversion, | ||
VectorLoadStoreConversion<vector::LoadOp>, | ||
VectorLoadStoreConversion<vector::MaskedLoadOp>, | ||
VectorLoadStoreConversion<vector::StoreOp>, | ||
VectorLoadStoreConversion<vector::MaskedStoreOp>, | ||
VectorGatherOpConversion, VectorScatterOpConversion, | ||
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, | ||
VectorSplatOpLowering, VectorSplatNdOpLowering, | ||
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN | ||
// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN | ||
Comment on lines
+1
to
+2
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previously this would also specify the actual alignment - what happened here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the option I was passing in was not actually doing anything-- it was setting the data layout as a MLIR attribute at the module level but I guess that wasn't being propagated into the LLVM context? The only test that uses the option just checks that the data layout description attr is present at the module level, and doesn't do anything else with it. There's a note in the original getMemRefAlignment method saying that we should be getting the data layout description from MLIR but instead it's gotten from the LLVMContext. So I think it's now just using whatever the default is? I'll try and look into it a bit more There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the explanation and for being so diligent about this. No need to dig deeper, this is sufficient for me. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah - in some future point we should have a MLIR data layout <=> LLVM data layout translation that'll make everything work nicely. I think you'll get what you want by changing the
That'll let you pick things up from an MLIR DLTI, though I don't think that supports vector alignments yet There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can also check for |
||
|
||
|
||
//===----------------------------------------------------------------------===// | ||
// vector.load | ||
//===----------------------------------------------------------------------===// | ||
|
||
// 8xf32 load: CHECK lines below expect alignment 32 (vector) vs 4 (f32 element).
func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
  %v = vector.load %base[%i, %j] : memref<200x100xf32>, vector<8xf32>
  return %v : vector<8xf32>
}
|
||
// ALL-LABEL: func @load | ||
|
||
// VEC-ALIGN: llvm.load %{{.*}} {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32> | ||
// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32> | ||
|
||
// ----- | ||
|
||
//===----------------------------------------------------------------------===// | ||
// vector.store | ||
//===----------------------------------------------------------------------===// | ||
|
||
// 4xf32 store of a constant: CHECK lines expect alignment 16 (vector) vs 4 (f32).
func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
  %cst = arith.constant dense<11.0> : vector<4xf32>
  vector.store %cst, %base[%i, %j] : memref<200x100xf32>, vector<4xf32>
  return
}
|
||
// ALL-LABEL: func @store | ||
|
||
// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} : vector<4xf32>, !llvm.ptr | ||
// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} : vector<4xf32>, !llvm.ptr | ||
|
||
// ----- | ||
|
||
//===----------------------------------------------------------------------===// | ||
// vector.maskedload | ||
//===----------------------------------------------------------------------===// | ||
|
||
// 16xf32 masked load: CHECK lines expect alignment 64 (vector) vs 4 (f32).
func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
  %zero = arith.constant 0: index
  %res = vector.maskedload %base[%zero], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %res : vector<16xf32>
}
|
||
// ALL-LABEL: func @masked_load | ||
|
||
// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32> | ||
// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32> | ||
|
||
// ----- | ||
|
||
//===----------------------------------------------------------------------===// | ||
// vector.maskedstore | ||
//===----------------------------------------------------------------------===// | ||
|
||
// 16xf32 masked store: CHECK lines expect alignment 64 (vector) vs 4 (f32).
func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
  %zero = arith.constant 0: index
  vector.maskedstore %base[%zero], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32>
  return
}
|
||
// ALL-LABEL: func @masked_store | ||
|
||
// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr | ||
// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr | ||
|
||
// ----- | ||
|
||
//===----------------------------------------------------------------------===// | ||
// vector.scatter | ||
//===----------------------------------------------------------------------===// | ||
|
||
// 3xf32 scatter: CHECK lines expect alignment 16 (vector) vs 4 (f32 element).
func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
  %zero = arith.constant 0: index
  vector.scatter %base[%zero][%index], %mask, %value : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
  return
}
|
||
// ALL-LABEL: func @scatter | ||
|
||
// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr> | ||
// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr> | ||
|
||
// ----- | ||
|
||
//===----------------------------------------------------------------------===// | ||
// vector.gather | ||
//===----------------------------------------------------------------------===// | ||
|
||
// 3xf32 gather: CHECK lines expect alignment 16 (vector) vs 4 (f32 element).
func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
  %zero = arith.constant 0: index
  %res = vector.gather %base[%zero][%index], %mask, %passthru : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
  return %res : vector<3xf32>
}
|
||
// ALL-LABEL: func @gather | ||
|
||
// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> | ||
// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> |
Uh oh!
There was an error while loading. Please reload this page.