Skip to content

Commit 6dcd2b0

Browse files
authored
[flang][cuda] Convert cuf.sync_descriptor to runtime call (#121524)
Convert the op to a new entry point in the runtime `CUFSyncGlobalDescriptor`
1 parent 4b17a8b commit 6dcd2b0

File tree

4 files changed

+72
-1
lines changed

4 files changed

+72
-1
lines changed

flang/include/flang/Runtime/CUDA/descriptor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ void *RTDECL(CUFGetDeviceAddress)(
3333
void RTDECL(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
3434
const char *sourceFile = nullptr, int sourceLine = 0);
3535

36+
/// Get the device address of registered with the \p hostPtr and sync them.
37+
void RTDECL(CUFSyncGlobalDescriptor)(
38+
void *hostPtr, const char *sourceFile = nullptr, int sourceLine = 0);
39+
3640
} // extern "C"
3741

3842
} // namespace Fortran::runtime::cuda

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,45 @@ struct CUFLaunchOpConversion
788788
const mlir::SymbolTable &symTab;
789789
};
790790

791+
struct CUFSyncDescriptorOpConversion
792+
: public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
793+
using OpRewritePattern::OpRewritePattern;
794+
795+
CUFSyncDescriptorOpConversion(mlir::MLIRContext *context,
796+
const mlir::SymbolTable &symTab)
797+
: OpRewritePattern(context), symTab{symTab} {}
798+
799+
mlir::LogicalResult
800+
matchAndRewrite(cuf::SyncDescriptorOp op,
801+
mlir::PatternRewriter &rewriter) const override {
802+
auto mod = op->getParentOfType<mlir::ModuleOp>();
803+
fir::FirOpBuilder builder(rewriter, mod);
804+
mlir::Location loc = op.getLoc();
805+
806+
auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
807+
if (!globalOp)
808+
return mlir::failure();
809+
810+
auto hostAddr = builder.create<fir::AddrOfOp>(
811+
loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
812+
mlir::func::FuncOp callee =
813+
fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
814+
builder);
815+
auto fTy = callee.getFunctionType();
816+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
817+
mlir::Value sourceLine =
818+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
819+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
820+
builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
821+
builder.create<fir::CallOp>(loc, callee, args);
822+
op.erase();
823+
return mlir::success();
824+
}
825+
826+
private:
827+
const mlir::SymbolTable &symTab;
828+
};
829+
791830
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
792831
public:
793832
void runOnOperation() override {
@@ -851,7 +890,8 @@ void cuf::populateCUFToFIRConversionPatterns(
851890
CUFFreeOpConversion>(patterns.getContext());
852891
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
853892
&dl, &converter);
854-
patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
893+
patterns.insert<CUFLaunchOpConversion, CUFSyncDescriptorOpConversion>(
894+
patterns.getContext(), symtab);
855895
}
856896

857897
void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,

flang/runtime/CUDA/descriptor.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
4646
(void *)dst, (const void *)src, count, cudaMemcpyHostToDevice));
4747
}
4848

49+
void RTDEF(CUFSyncGlobalDescriptor)(
50+
void *hostPtr, const char *sourceFile, int sourceLine) {
51+
void *devAddr{RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine)};
52+
RTNAME(CUFDescriptorSync)
53+
((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine);
54+
}
55+
4956
RT_EXT_API_GROUP_END
5057
}
5158
} // namespace Fortran::runtime::cuda
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: fir-opt --cuf-convert %s | FileCheck %s
2+
3+
module attributes {dlti.dl_spec = #dlti.dl_spec<i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 ([email protected]:clementval/llvm-project.git f37e52237791f58438790c77edeb8de08f692987)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
4+
fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> {
5+
%0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
6+
%c0 = arith.constant 0 : index
7+
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
8+
%2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
9+
fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>>
10+
}
11+
func.func @_QQmain() {
12+
cuf.sync_descriptor @_QMdevptrEdev_ptr
13+
return
14+
}
15+
}
16+
17+
// CHECK-LABEL: func.func @_QQmain()
18+
// CHECK: %[[HOST_ADDR:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
19+
// CHECK: %[[HOST_ADDR_PTR:.*]] = fir.convert %[[HOST_ADDR]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
20+
// CHECK: fir.call @_FortranACUFSyncGlobalDescriptor(%[[HOST_ADDR_PTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32)

0 commit comments

Comments
 (0)