Skip to content

Commit c336a06

Browse files
ubfx authored and ftynse committed
[mlir] [memref] Fix alignment bug in memref.copy lowering
memref.copy is sometimes lowered to a function call, and that function is passed the element size of the memref in bytes as an argument. The element size passed to the copyMemRef() function call can be miscalculated if the LLVM IR uses aligned access to the memory. This can be fixed by using llvm.getelementptr to calculate the element size natively, as is already done in the other lowering path, which lowers to an intrinsic. Fixes llvm#64072. Reviewed By: ftynse. Differential Revision: https://reviews.llvm.org/D156126
1 parent 263fc4c commit c336a06

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -879,10 +879,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
879879
auto sourcePtr = promote(unrankedSource);
880880
auto targetPtr = promote(unrankedTarget);
881881

882-
unsigned typeSize =
883-
mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
884-
auto elemSize = rewriter.create<LLVM::ConstantOp>(
885-
loc, getIndexType(), rewriter.getIndexAttr(typeSize));
882+
// Derive size from llvm.getelementptr which will account for any
883+
// potential alignment
884+
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
886885
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
887886
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
888887
rewriter.create<LLVM::CallOp>(loc, copyFn,

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,8 @@ func.func @memref_copy_unranked() {
558558
// CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.struct<(i64, ptr)>, !llvm.ptr
559559
// CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr)> : (i64) -> !llvm.ptr
560560
// CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.struct<(i64, ptr)>, !llvm.ptr
561-
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
561+
// CHECK: [[SIZEPTR:%.*]] = llvm.getelementptr {{%.*}}[1] : (!llvm.ptr) -> !llvm.ptr, i1
562+
// CHECK: [[SIZE:%.*]] = llvm.ptrtoint [[SIZEPTR]] : !llvm.ptr to i64
562563
// CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr, !llvm.ptr) -> ()
563564
// CHECK: llvm.intr.stackrestore [[STACKSAVE]]
564565
return

mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ func.func @memref_copy_unranked() {
8282
// CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
8383
// CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
8484
// CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
85-
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
85+
// CHECK: [[SIZEPTR:%.*]] = llvm.getelementptr {{%.*}}[1] : (!llvm.ptr<i1>) -> !llvm.ptr<i1>
86+
// CHECK: [[SIZE:%.*]] = llvm.ptrtoint [[SIZEPTR]] : !llvm.ptr<i1> to i64
8687
// CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr<struct<(i64, ptr<i8>)>>, !llvm.ptr<struct<(i64, ptr<i8>)>>) -> ()
8788
// CHECK: llvm.intr.stackrestore [[STACKSAVE]]
8889
return

0 commit comments

Comments
 (0)