Skip to content

Commit de73656

Browse files
author
Lily Orth-Smith
committed
Fix how we get alignment for memrefs
1 parent d1c7fa8 commit de73656

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,26 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
7070
// Helper that returns data layout alignment of a memref.
7171
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
7272
MemRefType memrefType, unsigned &align) {
73-
Type elementTy = typeConverter.convertType(memrefType.getElementType());
74-
if (!elementTy)
73+
// If shape is statically known, assign MemRefTypes to the alignment of a
74+
// VectorType with the same size and dtype. Otherwise, fall back to the
75+
// alignment of the element type.
76+
Type convertedType;
77+
if (memrefType.hasStaticShape()) {
78+
convertedType = typeConverter.convertType(VectorType::get(
79+
memrefType.getNumElements(), memrefType.getElementType()));
80+
} else {
81+
convertedType = typeConverter.convertType(memrefType.getElementType());
82+
}
83+
84+
if (!convertedType)
7585
return failure();
7686

7787
// TODO: this should use the MLIR data layout when it becomes available and
7888
// stop depending on translation.
7989
llvm::LLVMContext llvmContext;
80-
align = LLVM::TypeToLLVMIRTranslator(llvmContext)
81-
.getPreferredAlignment(elementTy, typeConverter.getDataLayout());
90+
align =
91+
LLVM::TypeToLLVMIRTranslator(llvmContext)
92+
.getPreferredAlignment(convertedType, typeConverter.getDataLayout());
8293
return success();
8394
}
8495

0 commit comments

Comments
 (0)