@@ -982,7 +982,7 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
982
982
template <typename ModuleOp>
983
983
static mlir::SymbolRefAttr
984
984
getMallocInModule (ModuleOp mod, fir::AllocMemOp op,
985
- mlir::ConversionPatternRewriter &rewriter) {
985
+ mlir::ConversionPatternRewriter &rewriter, bool addr32 ) {
986
986
static constexpr char mallocName[] = " malloc" ;
987
987
if (auto mallocFunc =
988
988
mod.template lookupSymbol <mlir::LLVM::LLVMFuncOp>(mallocName))
@@ -992,7 +992,7 @@ getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
992
992
return mlir::SymbolRefAttr::get (userMalloc);
993
993
994
994
mlir::OpBuilder moduleBuilder (mod.getBodyRegion ());
995
- auto indexType = mlir::IntegerType::get (op.getContext (), 64 );
995
+ auto indexType = mlir::IntegerType::get (op.getContext (), addr32 ? 32 : 64 );
996
996
auto mallocDecl = moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
997
997
op.getLoc (), mallocName,
998
998
mlir::LLVM::LLVMFunctionType::get (getLlvmPtrType (op.getContext ()),
@@ -1002,12 +1002,13 @@ getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
1002
1002
}
1003
1003
1004
1004
// / Return the LLVMFuncOp corresponding to the standard malloc call.
1005
- static mlir::SymbolRefAttr
1006
- getMalloc (fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
1005
+ static mlir::SymbolRefAttr getMalloc (fir::AllocMemOp op,
1006
+ mlir::ConversionPatternRewriter &rewriter,
1007
+ bool addr32) {
1007
1008
if (auto mod = op->getParentOfType <mlir::gpu::GPUModuleOp>())
1008
- return getMallocInModule (mod, op, rewriter);
1009
+ return getMallocInModule (mod, op, rewriter, addr32 );
1009
1010
auto mod = op->getParentOfType <mlir::ModuleOp>();
1010
- return getMallocInModule (mod, op, rewriter);
1011
+ return getMallocInModule (mod, op, rewriter, addr32 );
1011
1012
}
1012
1013
1013
1014
// / Helper function for generating the LLVM IR that computes the distance
@@ -1057,6 +1058,7 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
1057
1058
mlir::Type heapTy = heap.getType ();
1058
1059
mlir::Location loc = heap.getLoc ();
1059
1060
auto ity = lowerTy ().indexType ();
1061
+ auto addr32 = lowerTy ().getPointerBitwidth (0 ) == 32 ;
1060
1062
mlir::Type dataTy = fir::unwrapRefType (heapTy);
1061
1063
mlir::Type llvmObjectTy = convertObjectType (dataTy);
1062
1064
if (fir::isRecordWithTypeParameters (fir::unwrapSequenceType (dataTy)))
@@ -1067,7 +1069,11 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
1067
1069
for (mlir::Value opnd : adaptor.getOperands ())
1068
1070
size = rewriter.create <mlir::LLVM::MulOp>(
1069
1071
loc, ity, size, integerCast (loc, rewriter, ity, opnd));
1070
- heap->setAttr (" callee" , getMalloc (heap, rewriter));
1072
+ if (addr32) {
1073
+ auto i32ty = mlir::IntegerType::get (rewriter.getContext (), 32 );
1074
+ size = integerCast (loc, rewriter, i32ty, size);
1075
+ }
1076
+ heap->setAttr (" callee" , getMalloc (heap, rewriter, addr32));
1071
1077
rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
1072
1078
heap, ::getLlvmPtrType (heap.getContext ()), size,
1073
1079
addLLVMOpBundleAttrs (rewriter, heap->getAttrs (), 1 ));
0 commit comments