Skip to content

Commit 8654d4b

Browse files
committed
[flang] In AllocMemOp lowering, convert types for calling malloc on 32-bit
1 parent 9a54c77 commit 8654d4b

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,7 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
982982
template <typename ModuleOp>
983983
static mlir::SymbolRefAttr
984984
getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
985-
mlir::ConversionPatternRewriter &rewriter) {
985+
mlir::ConversionPatternRewriter &rewriter, bool addr32) {
986986
static constexpr char mallocName[] = "malloc";
987987
if (auto mallocFunc =
988988
mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
@@ -992,7 +992,7 @@ getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
992992
return mlir::SymbolRefAttr::get(userMalloc);
993993

994994
mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
995-
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
995+
auto indexType = mlir::IntegerType::get(op.getContext(), addr32 ? 32 : 64);
996996
auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
997997
op.getLoc(), mallocName,
998998
mlir::LLVM::LLVMFunctionType::get(getLlvmPtrType(op.getContext()),
@@ -1002,12 +1002,13 @@ getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
10021002
}
10031003

10041004
/// Return the LLVMFuncOp corresponding to the standard malloc call.
1005-
static mlir::SymbolRefAttr
1006-
getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1005+
static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op,
1006+
mlir::ConversionPatternRewriter &rewriter,
1007+
bool addr32) {
10071008
if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
1008-
return getMallocInModule(mod, op, rewriter);
1009+
return getMallocInModule(mod, op, rewriter, addr32);
10091010
auto mod = op->getParentOfType<mlir::ModuleOp>();
1010-
return getMallocInModule(mod, op, rewriter);
1011+
return getMallocInModule(mod, op, rewriter, addr32);
10111012
}
10121013

10131014
/// Helper function for generating the LLVM IR that computes the distance
@@ -1057,6 +1058,7 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
10571058
mlir::Type heapTy = heap.getType();
10581059
mlir::Location loc = heap.getLoc();
10591060
auto ity = lowerTy().indexType();
1061+
auto addr32 = lowerTy().getPointerBitwidth(0) == 32;
10601062
mlir::Type dataTy = fir::unwrapRefType(heapTy);
10611063
mlir::Type llvmObjectTy = convertObjectType(dataTy);
10621064
if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
@@ -1067,7 +1069,11 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
10671069
for (mlir::Value opnd : adaptor.getOperands())
10681070
size = rewriter.create<mlir::LLVM::MulOp>(
10691071
loc, ity, size, integerCast(loc, rewriter, ity, opnd));
1070-
heap->setAttr("callee", getMalloc(heap, rewriter));
1072+
if (addr32) {
1073+
auto i32ty = mlir::IntegerType::get(rewriter.getContext(), 32);
1074+
size = integerCast(loc, rewriter, i32ty, size);
1075+
}
1076+
heap->setAttr("callee", getMalloc(heap, rewriter, addr32));
10711077
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
10721078
heap, ::getLlvmPtrType(heap.getContext()), size,
10731079
addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 1));

flang/lib/Optimizer/CodeGen/TypeConverter.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,22 @@
2828

2929
namespace fir {
3030

31+
static mlir::LowerToLLVMOptions MakeLowerOptions(mlir::ModuleOp module) {
32+
llvm::StringRef dataLayoutString;
33+
auto dataLayoutAttr = module->template getAttrOfType<mlir::StringAttr>(
34+
mlir::LLVM::LLVMDialect::getDataLayoutAttrName());
35+
if (dataLayoutAttr)
36+
dataLayoutString = dataLayoutAttr.getValue();
37+
38+
auto options = mlir::LowerToLLVMOptions(module.getContext());
39+
options.dataLayout = llvm::DataLayout(dataLayoutString);
40+
return options;
41+
}
42+
3143
LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
3244
bool forceUnifiedTBAATree,
3345
const mlir::DataLayout &dl)
34-
: mlir::LLVMTypeConverter(module.getContext()),
46+
: mlir::LLVMTypeConverter(module.getContext(), MakeLowerOptions(module)),
3547
kindMapping(getKindMapping(module)),
3648
specifics(CodeGenSpecifics::get(
3749
module.getContext(), getTargetTriple(module), getKindMapping(module),

0 commit comments

Comments
 (0)