
[mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass #137389

Merged 2 commits on May 2, 2025
7 changes: 7 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
@@ -1394,6 +1394,13 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"true",
"Allows compiler to assume vector indices fit in 32-bit if that "
"yields faster code">,
Option<"useVectorAlignment", "use-vector-alignment",
"bool", /*default=*/"false",
"Use the preferred alignment of a vector type in load/store "
"operations instead of the alignment of the element type of the "
"memref. This flag is intended for use with hardware which requires"
Contributor:
"or in application contexts where it is known all vector accesses are naturally aligned"?

"vector alignment, or in application contexts where it is known all "
"vector access are naturally aligned. ">,
Option<"amx", "enable-amx",
"bool", /*default=*/"false",
"Enables the use of AMX dialect while lowering the vector "
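The new option is driven from the command line exactly as in the RUN lines of the test added at the end of this patch (input.mlir is a placeholder):

    mlir-opt input.mlir --convert-vector-to-llvm='use-vector-alignment=1'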
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -22,7 +22,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
bool useVectorAlignment = false);

namespace vector {
void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -32,7 +32,8 @@ def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,

let arguments = (ins
DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices,
DefaultValuedAttr<BoolAttr, "false">:$use_vector_alignment);
let assemblyFormat = "attr-dict";
}
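A hypothetical sketch of setting the new attribute from the transform dialect; the op mnemonic, the surrounding apply_conversion_patterns scaffolding, and the memref type-converter op are assumed from existing transform-dialect usage, not shown in this patch:

    transform.apply_conversion_patterns to %module {
      transform.apply_conversion_patterns.vector.vector_to_llvm {use_vector_alignment = true}
    } with type_converter {
      transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
    } : !transform.any_op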

104 changes: 89 additions & 15 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}

// Helper that returns data layout alignment of a vector.
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
VectorType vectorType, unsigned &align) {
Type convertedVectorTy = typeConverter.convertType(vectorType);
if (!convertedVectorTy)
return failure();

llvm::LLVMContext llvmContext;
align = LLVM::TypeToLLVMIRTranslator(llvmContext)
.getPreferredAlignment(convertedVectorTy,
typeConverter.getDataLayout());

return success();
}

// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
return success();
}

// Helper to resolve the alignment for vector load/store, gather and scatter
// ops. If useVectorAlignment is true, get the preferred alignment for the
// vector type in the operation. This option is used for hardware backends with
// vectorization. Otherwise, use the preferred alignment of the element type of
// the memref. Note that if you choose to use vector alignment, the shape of the
// vector type must be resolved before the ConvertVectorToLLVM pass is run.
Contributor:
Could you document this as well? Some context on useVectorAlignment would be helpful (similar to what you did elsewhere).

LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
VectorType vectorType,
MemRefType memrefType, unsigned &align,
bool useVectorAlignment) {
if (useVectorAlignment) {
if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
return failure();
}
} else {
if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
return failure();
}
}
return success();
}
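To make the two branches concrete (alignment values mirror the CHECK lines of the test added below, under the default data layout; %p is a placeholder pointer value), a load of vector<8xf32> from a memref of f32 lowers as:

    // use-vector-alignment=1: preferred alignment of vector<8xf32> (32 bytes).
    llvm.load %p {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
    // use-vector-alignment=0: natural alignment of the f32 element type (4 bytes).
    llvm.load %p {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>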

// Check if the last stride is non-unit and has a valid memory space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
const LLVMTypeConverter &converter) {
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
bool useVectorAlign)
: ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

LogicalResult
@@ -240,8 +281,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {

// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
return failure();
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
memRefTy, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(loadOrStoreOp,
"could not resolve alignment");

// Resolve address.
auto vtype = cast<VectorType>(
@@ -252,12 +295,23 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
rewriter);
return success();
}

private:
// If true, use the preferred alignment of the vector type.
// If false, use the preferred alignment of the element type
// of the memref. This flag is intended for use with hardware
// backends that require alignment of vector operations.
const bool useVectorAlignment;
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
: public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
bool useVectorAlign)
: ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

LogicalResult
@@ -278,10 +332,9 @@ class VectorGatherOpConversion

// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
return rewriter.notifyMatchFailure(gather,
"could not resolve memref alignment");
}
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");

// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
Expand All @@ -297,12 +350,24 @@ class VectorGatherOpConversion
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
return success();
}

private:
// If true, use the preferred alignment of the vector type.
// If false, use the preferred alignment of the element type
// of the memref. This flag is intended for use with hardware
// backends that require alignment of vector operations.
const bool useVectorAlignment;
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
bool useVectorAlign)
: ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
useVectorAlignment(useVectorAlign) {}

using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

LogicalResult
@@ -322,10 +387,10 @@ class VectorScatterOpConversion

// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(scatter,
"could not resolve memref alignment");
}
"could not resolve alignment");

// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
Expand All @@ -340,6 +405,13 @@ class VectorScatterOpConversion
rewriter.getI32IntegerAttr(align));
return success();
}

private:
// If true, use the preferred alignment of the vector type.
// If false, use the preferred alignment of the element type
// of the memref. This flag is intended for use with hardware
// backends that require alignment of vector operations.
const bool useVectorAlignment;
};

/// Conversion pattern for a vector.expandload.
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions, bool force32BitVectorIndices) {
bool reassociateFPReductions, bool force32BitVectorIndices,
bool useVectorAlignment) {
// This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
VectorLoadStoreConversion<vector::MaskedLoadOp>,
VectorLoadStoreConversion<vector::StoreOp>,
VectorLoadStoreConversion<vector::MaskedStoreOp>,
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorLoadStoreConversion<vector::LoadOp>,
VectorLoadStoreConversion<vector::MaskedLoadOp>,
VectorLoadStoreConversion<vector::StoreOp>,
VectorLoadStoreConversion<vector::MaskedStoreOp>,
VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, force32BitVectorIndices);
converter, patterns, reassociateFPReductions, force32BitVectorIndices,
useVectorAlignment);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);

// Architecture specific augmentations.
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -34,7 +34,8 @@ void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
populateVectorToLLVMConversionPatterns(
static_cast<LLVMTypeConverter &>(typeConverter), patterns,
getReassociateFpReductions(), getForce_32bitVectorIndices());
getReassociateFpReductions(), getForce_32bitVectorIndices(),
getUseVectorAlignment());
}

LogicalResult
102 changes: 102 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -0,0 +1,102 @@
// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
Comment on lines +1 to +2
Contributor:
Previously this would also specify the actual alignment - what happened here?

Contributor Author:
I think the option I was passing in was not actually doing anything -- it was setting the data layout as an MLIR attribute at the module level, but I guess that wasn't being propagated into the LLVM context? The only test that uses the option just checks that the data layout description attr is present at the module level, and doesn't do anything else with it.

There's a note in the original getMemRefAlignment method saying that we should be getting the data layout description from MLIR, but instead it's gotten from the LLVMContext. So I think it's now just using whatever the default is? I'll try and look into it a bit more.

Contributor:
Thanks for the explanation and for being so diligent about this. No need to dig deeper, this is sufficient for me.

Contributor:
Yeah - at some point in the future we should have an MLIR data layout <=> LLVM data layout translation that'll make everything work nicely.

I think you'll get what you want by changing the options(context) to

    LowerToLLVMOptions options(
        ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));

That'll let you pick things up from an MLIR DLTI, though I don't think that supports vector alignments yet.

Contributor:
You can also check for llvm.data_layout attributes (aka LLVM::LLVMDialect::getDataLayoutAttrName()) on the op you're processing's enclosing module and do options.datalayout =
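A hypothetical completion of that suggestion (module, ctx, and options are placeholder names, not from this patch):

    // Assumption: read an explicit llvm.data_layout string attribute off the
    // enclosing module, if present, and feed it into the lowering options.
    LowerToLLVMOptions options(ctx);
    if (auto dlAttr = module->getAttrOfType<StringAttr>(
            LLVM::LLVMDialect::getDataLayoutAttrName()))
      options.dataLayout = llvm::DataLayout(dlAttr.getValue());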



//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//

func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
%0 = vector.load %base[%i, %j] : memref<200x100xf32>, vector<8xf32>
return %0 : vector<8xf32>
}

// ALL-LABEL: func @load

// VEC-ALIGN: llvm.load %{{.*}} {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>

// -----

//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//

func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
%val = arith.constant dense<11.0> : vector<4xf32>
vector.store %val, %base[%i, %j] : memref<200x100xf32>, vector<4xf32>
return
}

// ALL-LABEL: func @store

// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} : vector<4xf32>, !llvm.ptr
// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} : vector<4xf32>, !llvm.ptr

// -----

//===----------------------------------------------------------------------===//
// vector.maskedload
//===----------------------------------------------------------------------===//

func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0: index
%0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %0 : vector<16xf32>
}

// ALL-LABEL: func @masked_load

// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>

// -----

//===----------------------------------------------------------------------===//
// vector.maskedstore
//===----------------------------------------------------------------------===//

func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
%c0 = arith.constant 0: index
vector.maskedstore %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}

// ALL-LABEL: func @masked_store

// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr

// -----

//===----------------------------------------------------------------------===//
// vector.scatter
//===----------------------------------------------------------------------===//

func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
%0 = arith.constant 0: index
vector.scatter %base[%0][%index], %mask, %value : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
return
}

// ALL-LABEL: func @scatter

// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>

// -----

//===----------------------------------------------------------------------===//
// vector.gather
//===----------------------------------------------------------------------===//

func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
%0 = arith.constant 0: index
%1 = vector.gather %base[%0][%index], %mask, %passthru : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
return %1 : vector<3xf32>
}

// ALL-LABEL: func @gather

// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>