@@ -366,22 +366,47 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
366
366
const fir::LLVMTypeConverter *typeConverter;
367
367
};
368
368
369
// Helper on the removed (`-`) side of this diff. It emits a call to the
// CUFGetDeviceAddress runtime function for `inputArg` and returns the call
// result converted back to `inputArg`'s type.
- static mlir::Value genGetDeviceAddress (mlir::PatternRewriter &rewriter,
370
- mlir::ModuleOp mod, mlir::Location loc,
371
- mlir::Value inputArg) {
372
- fir::FirOpBuilder builder (rewriter, mod);
373
- mlir::func::FuncOp callee =
374
- fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc, builder);
375
- auto fTy = callee.getFunctionType ();
376
// Convert the incoming value to the type of the callee's first parameter.
- mlir::Value conv = createConvertOp (rewriter, loc, fTy .getInput (0 ), inputArg);
377
// The source file name and line number derived from `loc` are passed as the
// remaining call arguments.
- mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
378
- mlir::Value sourceLine =
379
- fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
380
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
381
- builder, loc, fTy , conv, sourceFile, sourceLine)};
382
- auto call = rewriter.create <fir::CallOp>(loc, callee, args);
383
// Convert the call result back to the original type of `inputArg`.
- return createConvertOp (rewriter, loc, inputArg.getType (), call->getResult (0 ));
384
- }
369
// Rewrite pattern for cuf::DeviceAddressOp. When the op's host symbol resolves
// to a fir::GlobalOp in the symbol table, the op is replaced by a call to the
// CUFGetDeviceAddress runtime function applied to the address of that global;
// otherwise the pattern reports failure and leaves the op untouched.
+ struct CUFDeviceAddressOpConversion
370
+ : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
371
+ using OpRewritePattern::OpRewritePattern;
372
+
373
// `symtab` is held by reference (see the `symTab` member below), not copied.
+ CUFDeviceAddressOpConversion (mlir::MLIRContext *context,
374
+ const mlir::SymbolTable &symtab)
375
+ : OpRewritePattern(context), symTab{symtab} {}
376
+
377
+ mlir::LogicalResult
378
+ matchAndRewrite (cuf::DeviceAddressOp op,
379
+ mlir::PatternRewriter &rewriter) const override {
380
// Only rewrite when the root of the host symbol names a fir::GlobalOp.
+ if (auto global = symTab.lookup <fir::GlobalOp>(
381
+ op.getHostSymbol ().getRootReference ().getValue ())) {
382
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
383
+ mlir::Location loc = op.getLoc ();
384
// Materialize the address of the host global, typed as a reference to the
// global's type.
+ auto hostAddr = rewriter.create <fir::AddrOfOp>(
385
+ loc, fir::ReferenceType::get (global.getType ()), op.getHostSymbol ());
386
+ fir::FirOpBuilder builder (rewriter, mod);
387
+ mlir::func::FuncOp callee =
388
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc,
389
+ builder);
390
+ auto fTy = callee.getFunctionType ();
391
// Convert the host address to the type of the callee's first parameter.
+ mlir::Value conv =
392
+ createConvertOp (rewriter, loc, fTy .getInput (0 ), hostAddr);
393
// The source file name and line number derived from `loc` are passed as the
// remaining call arguments.
+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
394
+ mlir::Value sourceLine =
395
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
396
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
397
+ builder, loc, fTy , conv, sourceFile, sourceLine)};
398
+ auto call = rewriter.create <fir::CallOp>(loc, callee, args);
399
// Convert the call result back to the host address type, then replace the
// original op with the operation that defines that converted value.
+ mlir::Value addr = createConvertOp (rewriter, loc, hostAddr.getType (),
400
+ call->getResult (0 ));
401
+ rewriter.replaceOp (op, addr.getDefiningOp ());
402
+ return success ();
403
+ }
404
// The symbol lookup did not yield a fir::GlobalOp: no rewrite is performed.
+ return failure ();
405
+ }
406
+
407
+ private:
408
// Symbol table used to resolve the op's host symbol.
+ const mlir::SymbolTable &symTab;
409
+ };
385
410
386
411
struct DeclareOpConversion : public mlir ::OpRewritePattern<fir::DeclareOp> {
387
412
using OpRewritePattern::OpRewritePattern;
@@ -398,9 +423,8 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
398
423
addrOfOp.getSymbol ().getRootReference ().getValue ())) {
399
424
if (cuf::isRegisteredDeviceGlobal (global)) {
400
425
rewriter.setInsertionPointAfter (addrOfOp);
401
- auto mod = op->getParentOfType <mlir::ModuleOp>();
402
- mlir::Value devAddr = genGetDeviceAddress (rewriter, mod, op.getLoc (),
403
- addrOfOp.getResult ());
426
+ mlir::Value devAddr = rewriter.create <cuf::DeviceAddressOp>(
427
+ op.getLoc (), addrOfOp.getType (), addrOfOp.getSymbol ());
404
428
rewriter.startOpModification (op);
405
429
op.getMemrefMutable ().assign (devAddr);
406
430
rewriter.finalizeOpModification (op);
@@ -773,7 +797,6 @@ struct CUFLaunchOpConversion
773
797
}
774
798
}
775
799
llvm::SmallVector<mlir::Value> args;
776
- auto mod = op->getParentOfType <mlir::ModuleOp>();
777
800
for (mlir::Value arg : op.getArgs ()) {
778
801
// If the argument is a global descriptor, make sure we pass the device
779
802
// copy of this descriptor and not the host one.
@@ -785,8 +808,11 @@ struct CUFLaunchOpConversion
785
808
if (auto global = symTab.lookup <fir::GlobalOp>(
786
809
addrOfOp.getSymbol ().getRootReference ().getValue ())) {
787
810
if (cuf::isRegisteredDeviceGlobal (global)) {
788
- arg = genGetDeviceAddress (rewriter, mod, op.getLoc (),
789
- declareOp.getResult ());
811
+ arg = rewriter
812
+ .create <cuf::DeviceAddressOp>(op.getLoc (),
813
+ addrOfOp.getType (),
814
+ addrOfOp.getSymbol ())
815
+ .getResult ();
790
816
}
791
817
}
792
818
}
@@ -907,10 +933,12 @@ void cuf::populateCUFToFIRConversionPatterns(
907
933
patterns.getContext ());
908
934
patterns.insert <CUFDataTransferOpConversion>(patterns.getContext (), symtab,
909
935
&dl, &converter);
910
- patterns.insert <CUFLaunchOpConversion>(patterns.getContext (), symtab);
936
+ patterns.insert <CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
937
+ patterns.getContext (), symtab);
911
938
}
912
939
913
940
// Adds conversion patterns to `patterns`, each constructed with the pattern
// set's context and `symtab`. The `-` line is the previous registration
// (DeclareOpConversion only); the `+` lines register DeclareOpConversion
// together with CUFDeviceAddressOpConversion.
void cuf::populateFIRCUFConversionPatterns (const mlir::SymbolTable &symtab,
914
941
mlir::RewritePatternSet &patterns) {
915
- patterns.insert <DeclareOpConversion>(patterns.getContext (), symtab);
942
+ patterns.insert <DeclareOpConversion, CUFDeviceAddressOpConversion>(
943
+ patterns.getContext (), symtab);
916
944
}
0 commit comments