|
24 | 24 | #include "flang/Optimizer/Support/TypeCode.h"
|
25 | 25 | #include "flang/Optimizer/Support/Utils.h"
|
26 | 26 | #include "flang/Runtime/CUDA/descriptor.h"
|
| 27 | +#include "flang/Runtime/CUDA/memory.h" |
27 | 28 | #include "flang/Runtime/allocator-registry-consts.h"
|
28 | 29 | #include "flang/Runtime/descriptor-consts.h"
|
29 | 30 | #include "flang/Semantics/runtime-type-info.h"
|
@@ -1141,6 +1142,93 @@ convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy,
|
1141 | 1142 | return result;
|
1142 | 1143 | }
|
1143 | 1144 |
|
| 1145 | +static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod, |
| 1146 | + mlir::ConversionPatternRewriter &rewriter) { |
| 1147 | + auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); |
| 1148 | + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) { |
| 1149 | + auto fn = flc.getFilename().str() + '\0'; |
| 1150 | + std::string globalName = fir::factory::uniqueCGIdent("cl", fn); |
| 1151 | + |
| 1152 | + if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) { |
| 1153 | + return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
| 1154 | + } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) { |
| 1155 | + return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
| 1156 | + } |
| 1157 | + |
| 1158 | + auto crtInsPt = rewriter.saveInsertionPoint(); |
| 1159 | + rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); |
| 1160 | + auto arrayTy = mlir::LLVM::LLVMArrayType::get( |
| 1161 | + mlir::IntegerType::get(rewriter.getContext(), 8), fn.size()); |
| 1162 | + mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>( |
| 1163 | + loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce, |
| 1164 | + globalName, mlir::Attribute()); |
| 1165 | + |
| 1166 | + mlir::Region ®ion = globalOp.getInitializerRegion(); |
| 1167 | + mlir::Block *block = rewriter.createBlock(®ion); |
| 1168 | + rewriter.setInsertionPoint(block, block->begin()); |
| 1169 | + mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| 1170 | + loc, arrayTy, rewriter.getStringAttr(fn)); |
| 1171 | + rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue); |
| 1172 | + rewriter.restoreInsertionPoint(crtInsPt); |
| 1173 | + return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, |
| 1174 | + globalOp.getName()); |
| 1175 | + } |
| 1176 | + return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy); |
| 1177 | +} |
| 1178 | + |
| 1179 | +static mlir::Value genSourceLine(mlir::Location loc, |
| 1180 | + mlir::ConversionPatternRewriter &rewriter) { |
| 1181 | + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) |
| 1182 | + return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
| 1183 | + flc.getLine()); |
| 1184 | + return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); |
| 1185 | +} |
| 1186 | + |
| 1187 | +static mlir::Value |
| 1188 | +genCUFAllocDescriptor(mlir::Location loc, |
| 1189 | + mlir::ConversionPatternRewriter &rewriter, |
| 1190 | + mlir::ModuleOp mod, fir::BaseBoxType boxTy, |
| 1191 | + const fir::LLVMTypeConverter &typeConverter) { |
| 1192 | + std::optional<mlir::DataLayout> dl = |
| 1193 | + fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true); |
| 1194 | + if (!dl) |
| 1195 | + mlir::emitError(mod.getLoc(), |
| 1196 | + "module operation must carry a data layout attribute " |
| 1197 | + "to generate llvm IR from FIR"); |
| 1198 | + |
| 1199 | + mlir::Value sourceFile = genSourceFile(loc, mod, rewriter); |
| 1200 | + mlir::Value sourceLine = genSourceLine(loc, rewriter); |
| 1201 | + |
| 1202 | + mlir::MLIRContext *ctx = mod.getContext(); |
| 1203 | + |
| 1204 | + mlir::LLVM::LLVMPointerType llvmPointerType = |
| 1205 | + mlir::LLVM::LLVMPointerType::get(ctx); |
| 1206 | + mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32); |
| 1207 | + mlir::Type llvmIntPtrType = |
| 1208 | + mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0)); |
| 1209 | + auto fctTy = mlir::LLVM::LLVMFunctionType::get( |
| 1210 | + llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type}); |
| 1211 | + |
| 1212 | + auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>( |
| 1213 | + RTNAME_STRING(CUFAllocDesciptor)); |
| 1214 | + auto funcFunc = |
| 1215 | + mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor)); |
| 1216 | + if (!llvmFunc && !funcFunc) |
| 1217 | + mlir::OpBuilder::atBlockEnd(mod.getBody()) |
| 1218 | + .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor), |
| 1219 | + fctTy); |
| 1220 | + |
| 1221 | + mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); |
| 1222 | + std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; |
| 1223 | + mlir::Value sizeInBytes = |
| 1224 | + genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); |
| 1225 | + llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; |
| 1226 | + return rewriter |
| 1227 | + .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor), |
| 1228 | + args) |
| 1229 | + .getResult(); |
| 1230 | +} |
| 1231 | + |
1144 | 1232 | /// Common base class for embox to descriptor conversion.
|
1145 | 1233 | template <typename OP>
|
1146 | 1234 | struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
|
@@ -1554,15 +1642,24 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
|
1554 | 1642 | mlir::Value
|
1555 | 1643 | placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter,
|
1556 | 1644 | mlir::Location loc, mlir::Type boxTy,
|
1557 |
| - mlir::Value boxValue) const { |
| 1645 | + mlir::Value boxValue, |
| 1646 | + bool needDeviceAllocation = false) const { |
1558 | 1647 | if (isInGlobalOp(rewriter))
|
1559 | 1648 | return boxValue;
|
1560 | 1649 | mlir::Type llvmBoxTy = boxValue.getType();
|
1561 |
| - auto alloca = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, |
1562 |
| - defaultAlign, rewriter); |
1563 |
| - auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, alloca); |
| 1650 | + mlir::Value storage; |
| 1651 | + if (needDeviceAllocation) { |
| 1652 | + auto mod = boxValue.getDefiningOp()->getParentOfType<mlir::ModuleOp>(); |
| 1653 | + auto baseBoxTy = mlir::dyn_cast<fir::BaseBoxType>(boxTy); |
| 1654 | + storage = |
| 1655 | + genCUFAllocDescriptor(loc, rewriter, mod, baseBoxTy, this->lowerTy()); |
| 1656 | + } else { |
| 1657 | + storage = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, defaultAlign, |
| 1658 | + rewriter); |
| 1659 | + } |
| 1660 | + auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, storage); |
1564 | 1661 | this->attachTBAATag(storeOp, boxTy, boxTy, nullptr);
|
1565 |
| - return alloca; |
| 1662 | + return storage; |
1566 | 1663 | }
|
1567 | 1664 | };
|
1568 | 1665 |
|
@@ -1614,6 +1711,18 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
|
1614 | 1711 | }
|
1615 | 1712 | };
|
1616 | 1713 |
|
| 1714 | +static bool isDeviceAllocation(mlir::Value val) { |
| 1715 | + if (auto convertOp = |
| 1716 | + mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp())) |
| 1717 | + val = convertOp.getValue(); |
| 1718 | + if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp())) |
| 1719 | + if (callOp.getCallee() && |
| 1720 | + callOp.getCallee().value().getRootReference().getValue().starts_with( |
| 1721 | + RTNAME_STRING(CUFMemAlloc))) |
| 1722 | + return true; |
| 1723 | + return false; |
| 1724 | +} |
| 1725 | + |
1617 | 1726 | /// Create a generic box on a memory reference.
|
1618 | 1727 | struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
|
1619 | 1728 | using EmboxCommonConversion::EmboxCommonConversion;
|
@@ -1797,9 +1906,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
|
1797 | 1906 | dest = insertBaseAddress(rewriter, loc, dest, base);
|
1798 | 1907 | if (fir::isDerivedTypeWithLenParams(boxTy))
|
1799 | 1908 | TODO(loc, "fir.embox codegen of derived with length parameters");
|
1800 |
| - |
1801 |
| - mlir::Value result = |
1802 |
| - placeInMemoryIfNotGlobalInit(rewriter, loc, boxTy, dest); |
| 1909 | + mlir::Value result = placeInMemoryIfNotGlobalInit( |
| 1910 | + rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref())); |
1803 | 1911 | rewriter.replaceOp(xbox, result);
|
1804 | 1912 | return mlir::success();
|
1805 | 1913 | }
|
@@ -2977,93 +3085,6 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
|
2977 | 3085 | }
|
2978 | 3086 | };
|
2979 | 3087 |
|
2980 |
| -static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod, |
2981 |
| - mlir::ConversionPatternRewriter &rewriter) { |
2982 |
| - auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); |
2983 |
| - if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) { |
2984 |
| - auto fn = flc.getFilename().str() + '\0'; |
2985 |
| - std::string globalName = fir::factory::uniqueCGIdent("cl", fn); |
2986 |
| - |
2987 |
| - if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) { |
2988 |
| - return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
2989 |
| - } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) { |
2990 |
| - return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName()); |
2991 |
| - } |
2992 |
| - |
2993 |
| - auto crtInsPt = rewriter.saveInsertionPoint(); |
2994 |
| - rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); |
2995 |
| - auto arrayTy = mlir::LLVM::LLVMArrayType::get( |
2996 |
| - mlir::IntegerType::get(rewriter.getContext(), 8), fn.size()); |
2997 |
| - mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>( |
2998 |
| - loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce, |
2999 |
| - globalName, mlir::Attribute()); |
3000 |
| - |
3001 |
| - mlir::Region ®ion = globalOp.getInitializerRegion(); |
3002 |
| - mlir::Block *block = rewriter.createBlock(®ion); |
3003 |
| - rewriter.setInsertionPoint(block, block->begin()); |
3004 |
| - mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>( |
3005 |
| - loc, arrayTy, rewriter.getStringAttr(fn)); |
3006 |
| - rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue); |
3007 |
| - rewriter.restoreInsertionPoint(crtInsPt); |
3008 |
| - return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, |
3009 |
| - globalOp.getName()); |
3010 |
| - } |
3011 |
| - return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy); |
3012 |
| -} |
3013 |
| - |
3014 |
| -static mlir::Value genSourceLine(mlir::Location loc, |
3015 |
| - mlir::ConversionPatternRewriter &rewriter) { |
3016 |
| - if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) |
3017 |
| - return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
3018 |
| - flc.getLine()); |
3019 |
| - return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); |
3020 |
| -} |
3021 |
| - |
3022 |
| -static mlir::Value |
3023 |
| -genCUFAllocDescriptor(mlir::Location loc, |
3024 |
| - mlir::ConversionPatternRewriter &rewriter, |
3025 |
| - mlir::ModuleOp mod, fir::BaseBoxType boxTy, |
3026 |
| - const fir::LLVMTypeConverter &typeConverter) { |
3027 |
| - std::optional<mlir::DataLayout> dl = |
3028 |
| - fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true); |
3029 |
| - if (!dl) |
3030 |
| - mlir::emitError(mod.getLoc(), |
3031 |
| - "module operation must carry a data layout attribute " |
3032 |
| - "to generate llvm IR from FIR"); |
3033 |
| - |
3034 |
| - mlir::Value sourceFile = genSourceFile(loc, mod, rewriter); |
3035 |
| - mlir::Value sourceLine = genSourceLine(loc, rewriter); |
3036 |
| - |
3037 |
| - mlir::MLIRContext *ctx = mod.getContext(); |
3038 |
| - |
3039 |
| - mlir::LLVM::LLVMPointerType llvmPointerType = |
3040 |
| - mlir::LLVM::LLVMPointerType::get(ctx); |
3041 |
| - mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32); |
3042 |
| - mlir::Type llvmIntPtrType = |
3043 |
| - mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0)); |
3044 |
| - auto fctTy = mlir::LLVM::LLVMFunctionType::get( |
3045 |
| - llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type}); |
3046 |
| - |
3047 |
| - auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>( |
3048 |
| - RTNAME_STRING(CUFAllocDesciptor)); |
3049 |
| - auto funcFunc = |
3050 |
| - mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDesciptor)); |
3051 |
| - if (!llvmFunc && !funcFunc) |
3052 |
| - mlir::OpBuilder::atBlockEnd(mod.getBody()) |
3053 |
| - .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDesciptor), |
3054 |
| - fctTy); |
3055 |
| - |
3056 |
| - mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); |
3057 |
| - std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8; |
3058 |
| - mlir::Value sizeInBytes = |
3059 |
| - genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize); |
3060 |
| - llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine}; |
3061 |
| - return rewriter |
3062 |
| - .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDesciptor), |
3063 |
| - args) |
3064 |
| - .getResult(); |
3065 |
| -} |
3066 |
| - |
3067 | 3088 | /// `fir.load` --> `llvm.load`
|
3068 | 3089 | struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
|
3069 | 3090 | using FIROpConversion::FIROpConversion;
|
|
0 commit comments