41
41
#include " mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
42
42
#include " mlir/Dialect/Arith/IR/Arith.h"
43
43
#include " mlir/Dialect/DLTI/DLTI.h"
44
+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
44
45
#include " mlir/Dialect/LLVMIR/LLVMAttrs.h"
45
46
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
46
47
#include " mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
@@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
920
921
};
921
922
} // namespace
922
923
923
- // / Return the LLVMFuncOp corresponding to the standard malloc call.
924
+ template < typename ModuleOp>
924
925
static mlir::SymbolRefAttr
925
- getMalloc (fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
926
+ getMallocInModule (ModuleOp mod, fir::AllocMemOp op,
927
+ mlir::ConversionPatternRewriter &rewriter) {
926
928
static constexpr char mallocName[] = " malloc" ;
927
- auto module = op-> getParentOfType <mlir::ModuleOp>();
928
- if ( auto mallocFunc = module . lookupSymbol <mlir::LLVM::LLVMFuncOp>(mallocName))
929
+ if ( auto mallocFunc =
930
+ mod. template lookupSymbol <mlir::LLVM::LLVMFuncOp>(mallocName))
929
931
return mlir::SymbolRefAttr::get (mallocFunc);
930
- if (auto userMalloc = module .lookupSymbol <mlir::func::FuncOp>(mallocName))
932
+ if (auto userMalloc =
933
+ mod.template lookupSymbol <mlir::func::FuncOp>(mallocName))
931
934
return mlir::SymbolRefAttr::get (userMalloc);
932
- mlir::OpBuilder moduleBuilder (
933
- op-> getParentOfType < mlir::ModuleOp>() .getBodyRegion ());
935
+
936
+ mlir::OpBuilder moduleBuilder (mod .getBodyRegion ());
934
937
auto indexType = mlir::IntegerType::get (op.getContext (), 64 );
935
938
auto mallocDecl = moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
936
939
op.getLoc (), mallocName,
@@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
940
943
return mlir::SymbolRefAttr::get (mallocDecl);
941
944
}
942
945
946
+ // / Return the LLVMFuncOp corresponding to the standard malloc call.
947
+ static mlir::SymbolRefAttr
948
+ getMalloc (fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
949
+ if (auto mod = op->getParentOfType <mlir::gpu::GPUModuleOp>())
950
+ return getMallocInModule (mod, op, rewriter);
951
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
952
+ return getMallocInModule (mod, op, rewriter);
953
+ }
954
+
943
955
// / Helper function for generating the LLVM IR that computes the distance
944
956
// / in bytes between adjacent elements pointed to by a pointer
945
957
// / of type \p ptrTy. The result is returned as a value of \p idxTy integer
@@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
1016
1028
} // namespace
1017
1029
1018
1030
// / Return the LLVMFuncOp corresponding to the standard free call.
1019
- static mlir::SymbolRefAttr getFree (fir::FreeMemOp op,
1020
- mlir::ConversionPatternRewriter &rewriter) {
1031
+ template <typename ModuleOp>
1032
+ static mlir::SymbolRefAttr
1033
+ getFreeInModule (ModuleOp mod, fir::FreeMemOp op,
1034
+ mlir::ConversionPatternRewriter &rewriter) {
1021
1035
static constexpr char freeName[] = " free" ;
1022
- auto module = op->getParentOfType <mlir::ModuleOp>();
1023
1036
// Check if free already defined in the module.
1024
- if (auto freeFunc = module .lookupSymbol <mlir::LLVM::LLVMFuncOp>(freeName))
1037
+ if (auto freeFunc =
1038
+ mod.template lookupSymbol <mlir::LLVM::LLVMFuncOp>(freeName))
1025
1039
return mlir::SymbolRefAttr::get (freeFunc);
1026
1040
if (auto freeDefinedByUser =
1027
- module . lookupSymbol <mlir::func::FuncOp>(freeName))
1041
+ mod. template lookupSymbol <mlir::func::FuncOp>(freeName))
1028
1042
return mlir::SymbolRefAttr::get (freeDefinedByUser);
1029
1043
// Create llvm declaration for free.
1030
- mlir::OpBuilder moduleBuilder (module .getBodyRegion ());
1044
+ mlir::OpBuilder moduleBuilder (mod .getBodyRegion ());
1031
1045
auto voidType = mlir::LLVM::LLVMVoidType::get (op.getContext ());
1032
1046
auto freeDecl = moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
1033
1047
rewriter.getUnknownLoc (), freeName,
@@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
1037
1051
return mlir::SymbolRefAttr::get (freeDecl);
1038
1052
}
1039
1053
1054
+ static mlir::SymbolRefAttr getFree (fir::FreeMemOp op,
1055
+ mlir::ConversionPatternRewriter &rewriter) {
1056
+ if (auto mod = op->getParentOfType <mlir::gpu::GPUModuleOp>())
1057
+ return getFreeInModule (mod, op, rewriter);
1058
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
1059
+ return getFreeInModule (mod, op, rewriter);
1060
+ }
1061
+
1040
1062
static unsigned getDimension (mlir::LLVM::LLVMArrayType ty) {
1041
1063
unsigned result = 1 ;
1042
1064
for (auto eleTy =
@@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
3730
3752
mlir::configureOpenMPToLLVMConversionLegality (target, typeConverter);
3731
3753
target.addLegalDialect <mlir::omp::OpenMPDialect>();
3732
3754
target.addLegalDialect <mlir::acc::OpenACCDialect>();
3755
+ target.addLegalDialect <mlir::gpu::GPUDialect>();
3733
3756
3734
3757
// required NOPs for applying a full conversion
3735
3758
target.addLegalOp <mlir::ModuleOp>();
0 commit comments