7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " flang/Optimizer/Transforms/CUFGPUToLLVMConversion.h"
10
+ #include " flang/Optimizer/Builder/CUFCommon.h"
10
11
#include " flang/Optimizer/CodeGen/TypeConverter.h"
12
+ #include " flang/Optimizer/Dialect/CUF/CUFOps.h"
11
13
#include " flang/Optimizer/Support/DataLayout.h"
12
14
#include " flang/Runtime/CUDA/common.h"
13
15
#include " flang/Support/Fortran.h"
14
16
#include " mlir/Conversion/LLVMCommon/Pattern.h"
15
17
#include " mlir/Dialect/GPU/IR/GPUDialect.h"
18
+ #include " mlir/Dialect/LLVMIR/NVVMDialect.h"
16
19
#include " mlir/Pass/Pass.h"
17
20
#include " mlir/Transforms/DialectConversion.h"
18
21
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -175,6 +178,69 @@ struct GPULaunchKernelConversion
175
178
}
176
179
};
177
180
181
+ static std::string getFuncName (cuf::SharedMemoryOp op) {
182
+ if (auto gpuFuncOp = op->getParentOfType <mlir::gpu::GPUFuncOp>())
183
+ return gpuFuncOp.getName ().str ();
184
+ if (auto funcOp = op->getParentOfType <mlir::func::FuncOp>())
185
+ return funcOp.getName ().str ();
186
+ if (auto llvmFuncOp = op->getParentOfType <mlir::LLVM::LLVMFuncOp>())
187
+ return llvmFuncOp.getSymName ().str ();
188
+ return " " ;
189
+ }
190
+
191
+ static mlir::Value createAddressOfOp (mlir::ConversionPatternRewriter &rewriter,
192
+ mlir::Location loc,
193
+ gpu::GPUModuleOp gpuMod,
194
+ std::string &sharedGlobalName) {
195
+ auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get (
196
+ rewriter.getContext (), mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace );
197
+ if (auto g = gpuMod.lookupSymbol <fir::GlobalOp>(sharedGlobalName))
198
+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
199
+ g.getSymName ());
200
+ if (auto g = gpuMod.lookupSymbol <mlir::LLVM::GlobalOp>(sharedGlobalName))
201
+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
202
+ g.getSymName ());
203
+ return {};
204
+ }
205
+
206
+ struct CUFSharedMemoryOpConversion
207
+ : public mlir::ConvertOpToLLVMPattern<cuf::SharedMemoryOp> {
208
+ explicit CUFSharedMemoryOpConversion (
209
+ const fir::LLVMTypeConverter &typeConverter, mlir::PatternBenefit benefit)
210
+ : mlir::ConvertOpToLLVMPattern<cuf::SharedMemoryOp>(typeConverter,
211
+ benefit) {}
212
+ using OpAdaptor = typename cuf::SharedMemoryOp::Adaptor;
213
+
214
+ mlir::LogicalResult
215
+ matchAndRewrite (cuf::SharedMemoryOp op, OpAdaptor adaptor,
216
+ mlir::ConversionPatternRewriter &rewriter) const override {
217
+ mlir::Location loc = op->getLoc ();
218
+ if (!op.getOffset ())
219
+ mlir::emitError (loc,
220
+ " cuf.shared_memory must have an offset for code gen" );
221
+
222
+ auto gpuMod = op->getParentOfType <gpu::GPUModuleOp>();
223
+ std::string sharedGlobalName =
224
+ (getFuncName (op) + llvm::Twine (cudaSharedMemSuffix)).str ();
225
+ mlir::Value sharedGlobalAddr =
226
+ createAddressOfOp (rewriter, loc, gpuMod, sharedGlobalName);
227
+
228
+ if (!sharedGlobalAddr)
229
+ mlir::emitError (loc, " Could not find the shared global operation\n " );
230
+
231
+ auto castPtr = rewriter.create <mlir::LLVM::AddrSpaceCastOp>(
232
+ loc, mlir::LLVM::LLVMPointerType::get (rewriter.getContext ()),
233
+ sharedGlobalAddr);
234
+ mlir::Type baseType = castPtr->getResultTypes ().front ();
235
+ llvm::SmallVector<mlir::LLVM::GEPArg> gepArgs = {
236
+ static_cast <int32_t >(*op.getOffset ())};
237
+ mlir::Value shmemPtr = rewriter.create <mlir::LLVM::GEPOp>(
238
+ loc, baseType, rewriter.getI8Type (), castPtr, gepArgs);
239
+ rewriter.replaceOp (op, {shmemPtr});
240
+ return mlir::success ();
241
+ }
242
+ };
243
+
178
244
class CUFGPUToLLVMConversion
179
245
: public fir::impl::CUFGPUToLLVMConversionBase<CUFGPUToLLVMConversion> {
180
246
public:
@@ -194,6 +260,7 @@ class CUFGPUToLLVMConversion
194
260
/* forceUnifiedTBAATree=*/ false , *dl);
195
261
cuf::populateCUFGPUToLLVMConversionPatterns (typeConverter, patterns);
196
262
target.addIllegalOp <mlir::gpu::LaunchFuncOp>();
263
+ target.addIllegalOp <cuf::SharedMemoryOp>();
197
264
target.addLegalDialect <mlir::LLVM::LLVMDialect>();
198
265
if (mlir::failed (mlir::applyPartialConversion (getOperation (), target,
199
266
std::move (patterns)))) {
@@ -208,5 +275,6 @@ class CUFGPUToLLVMConversion
208
275
void cuf::populateCUFGPUToLLVMConversionPatterns (
209
276
const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
210
277
mlir::PatternBenefit benefit) {
211
- patterns.add <GPULaunchKernelConversion>(converter, benefit);
278
+ patterns.add <CUFSharedMemoryOpConversion, GPULaunchKernelConversion>(
279
+ converter, benefit);
212
280
}
0 commit comments