-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][cuda] Convert cuf.sync_descriptor to runtime call #121524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-runtime Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesConvert the op to a new entry point in the runtime Full diff: https://github.com/llvm/llvm-project/pull/121524.diff 4 Files Affected:
diff --git a/flang/include/flang/Runtime/CUDA/descriptor.h b/flang/include/flang/Runtime/CUDA/descriptor.h
index 55878aaac57fb3..0ee7feca10e44c 100644
--- a/flang/include/flang/Runtime/CUDA/descriptor.h
+++ b/flang/include/flang/Runtime/CUDA/descriptor.h
@@ -33,6 +33,10 @@ void *RTDECL(CUFGetDeviceAddress)(
void RTDECL(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
const char *sourceFile = nullptr, int sourceLine = 0);
+/// Get the device address of registered with the \p hostPtr and sync them.
+void RTDECL(CUFSyncGlobalDescriptor)(
+ void *hostPtr, const char *sourceFile = nullptr, int sourceLine = 0);
+
} // extern "C"
} // namespace Fortran::runtime::cuda
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index fb0ef246546444..f08f9e412b8857 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -788,6 +788,45 @@ struct CUFLaunchOpConversion
const mlir::SymbolTable &symTab;
};
+struct CUFSyncDescriptorOpConversion
+ : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ CUFSyncDescriptorOpConversion(mlir::MLIRContext *context,
+ const mlir::SymbolTable &symTab)
+ : OpRewritePattern(context), symTab{symTab} {}
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::SyncDescriptorOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+
+ auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
+ if (!globalOp)
+ return mlir::failure();
+
+ auto hostAddr = builder.create<fir::AddrOfOp>(
+ loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
+ mlir::func::FuncOp callee =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
+ builder);
+ auto fTy = callee.getFunctionType();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
+ builder.create<fir::CallOp>(loc, callee, args);
+ op.erase();
+ return mlir::success();
+ }
+
+private:
+ const mlir::SymbolTable &symTab;
+};
+
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
public:
void runOnOperation() override {
@@ -851,7 +890,8 @@ void cuf::populateCUFToFIRConversionPatterns(
CUFFreeOpConversion>(patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
- patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
+ patterns.insert<CUFLaunchOpConversion, CUFSyncDescriptorOpConversion>(
+ patterns.getContext(), symtab);
}
void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
diff --git a/flang/runtime/CUDA/descriptor.cpp b/flang/runtime/CUDA/descriptor.cpp
index 391c47e84241d4..947eeb66aa3d6c 100644
--- a/flang/runtime/CUDA/descriptor.cpp
+++ b/flang/runtime/CUDA/descriptor.cpp
@@ -46,6 +46,13 @@ void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
(void *)dst, (const void *)src, count, cudaMemcpyHostToDevice));
}
+void RTDEF(CUFSyncGlobalDescriptor)(
+ void *hostPtr, const char *sourceFile, int sourceLine) {
+ void *devAddr{RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine)};
+ RTNAME(CUFDescriptorSync)
+ ((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine);
+}
+
RT_EXT_API_GROUP_END
}
} // namespace Fortran::runtime::cuda
diff --git a/flang/test/Fir/CUDA/cuda-sync-desc.mlir b/flang/test/Fir/CUDA/cuda-sync-desc.mlir
new file mode 100644
index 00000000000000..20b317f34a7f26
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-sync-desc.mlir
@@ -0,0 +1,20 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 ([email protected]:clementval/llvm-project.git f37e52237791f58438790c77edeb8de08f692987)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+ fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> {
+ %0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
+ %c0 = arith.constant 0 : index
+ %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+ %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ }
+ func.func @_QQmain() {
+ cuf.sync_descriptor @_QMdevptrEdev_ptr
+ return
+ }
+}
+
+// CHECK-LABEL: func.func @_QQmain()
+// CHECK: %[[HOST_ADDR:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[HOST_ADDR_PTR:.*]] = fir.convert %[[HOST_ADDR]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFSyncGlobalDescriptor(%[[HOST_ADDR_PTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32)
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/16247 Here is the relevant piece of the build log for the reference
|
Failed buildbot after #121524
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/201/builds/502 Here is the relevant piece of the build log for the reference
|
Failed buildbot after llvm/llvm-project#121524
Convert the op to a new entry point in the runtime
CUFSyncGlobalDescriptor