Skip to content

[flang][cuda] Add runtime check for passing device arrays #144003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flang-rt/lib/cuda/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ void RTDEF(CUFSyncGlobalDescriptor)(
((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine);
}

// Runtime contiguity check for a descriptor.
// A null `desc` is accepted and ignored. When `desc` is non-null and
// IsContiguous() is false, Terminator::Crash is invoked with the given
// source position.
void RTDEF(CUFDescriptorCheckSection)(
    const Descriptor *desc, const char *sourceFile, int sourceLine) {
  // Nothing to report for a missing descriptor or a contiguous one.
  if (!desc || desc->IsContiguous()) {
    return;
  }
  Terminator terminator{sourceFile, sourceLine};
  terminator.Crash("device array section argument is not contiguous");
}

RT_EXT_API_GROUP_END
}
} // namespace Fortran::runtime::cuda
3 changes: 3 additions & 0 deletions flang/include/flang/Lower/LoweringOptions.def
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,8 @@ ENUM_LOWERINGOPT(StackRepackArrays, unsigned, 1, 0)
/// in the leading dimension.
ENUM_LOWERINGOPT(RepackArraysWhole, unsigned, 1, 0)

/// If true, CUDA Fortran runtime checks are inserted during lowering.
ENUM_LOWERINGOPT(CUDARuntimeCheck, unsigned, 1, 0)

#undef LOWERINGOPT
#undef ENUM_LOWERINGOPT
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ namespace fir::runtime::cuda {
void genSyncGlobalDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value hostPtr);

/// Generate a runtime call that checks the section described by \p desc and
/// raises an error if it is not contiguous. The source file name and line
/// number passed to the runtime are derived from \p loc.
void genDescriptorCheckSection(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value desc);

} // namespace fir::runtime::cuda

#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_CUDA_DESCRIPTOR_H_
4 changes: 4 additions & 0 deletions flang/include/flang/Runtime/CUDA/descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ void RTDECL(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
void RTDECL(CUFSyncGlobalDescriptor)(
void *hostPtr, const char *sourceFile = nullptr, int sourceLine = 0);

/// Check that the array described by the given descriptor is contiguous.
/// A null descriptor is accepted. If the descriptor is not contiguous, the
/// program is terminated with an error reported at \p sourceFile and
/// \p sourceLine.
void RTDECL(CUFDescriptorCheckSection)(
const Descriptor *, const char *sourceFile = nullptr, int sourceLine = 0);

} // extern "C"

} // namespace Fortran::runtime::cuda
Expand Down
14 changes: 14 additions & 0 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
#include "flang/Optimizer/Builder/Runtime/Derived.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
Expand Down Expand Up @@ -543,6 +544,19 @@ Fortran::lower::genCallOpAndResult(
fir::FortranProcedureFlagsEnumAttr procAttrs =
caller.getProcedureAttrs(builder.getContext());

// Optional CUDA Fortran runtime checking: for a call that does not use the
// chevron syntax, emit a descriptor section check for every operand of
// fir::BaseBoxType whose dummy argument is a CUDA device symbol.
if (converter.getLoweringOptions().getCUDARuntimeCheck() &&
    caller.getCallDescription().chevrons().empty()) {
  for (auto [oper, arg] : llvm::zip(operands, caller.getPassedArguments())) {
    // Only the type test matters here; the casted value is never used, so
    // isa<> is preferred over binding an unused dyn_cast<> result.
    if (!mlir::isa<fir::BaseBoxType>(oper.getType()))
      continue;
    // Arguments without an associated dummy symbol are skipped.
    const Fortran::semantics::Symbol *sym = caller.getDummySymbol(arg);
    if (sym && Fortran::evaluate::IsCUDADeviceSymbol(*sym))
      fir::runtime::cuda::genDescriptorCheckSection(builder, loc, oper);
  }
}

if (!caller.getCallDescription().chevrons().empty()) {
// A call to a CUDA kernel with the chevron syntax.

Expand Down
15 changes: 15 additions & 0 deletions flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,18 @@ void fir::runtime::cuda::genSyncGlobalDescriptor(fir::FirOpBuilder &builder,
builder, loc, fTy, hostPtr, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, callee, args);
}

/// Emit a fir.call to the CUFDescriptorCheckSection runtime entry point,
/// passing \p desc together with the file name and line number derived from
/// \p loc.
void fir::runtime::cuda::genDescriptorCheckSection(fir::FirOpBuilder &builder,
                                                   mlir::Location loc,
                                                   mlir::Value desc) {
  mlir::func::FuncOp callee =
      fir::runtime::getRuntimeFunc<mkRTKey(CUFDescriptorCheckSection)>(
          loc, builder);
  mlir::FunctionType calleeTy = callee.getFunctionType();

  // Source position operands; the line number is created with the type of
  // the runtime function's third parameter.
  mlir::Value fileName = fir::factory::locationToFilename(builder, loc);
  mlir::Value lineNo =
      fir::factory::locationToLineNo(builder, loc, calleeTy.getInput(2));

  llvm::SmallVector<mlir::Value> callArgs{fir::runtime::createArguments(
      builder, loc, calleeTy, desc, fileName, lineNo)};
  builder.create<fir::CallOp>(loc, callee, callArgs);
}
22 changes: 22 additions & 0 deletions flang/test/Lower/CUDA/cuda-runtime-check.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s

! Verify that lowering inserts the CUDA Fortran runtime check.
! The main program passes the strided section a(1:10,1:10:2) of a device
! array to a subroutine whose dummy argument is a device assumed-shape array,
! so a call to the descriptor section check runtime routine is expected
! before the call to foo.

interface
subroutine foo(a)
real, device, dimension(:,:) :: a
end subroutine
end interface

real, device, allocatable, dimension(:,:) :: a
allocate(a(10,10))
call foo(a(1:10,1:10:2))
end

subroutine foo(a)
real, device, dimension(:,:) :: a
end subroutine

! CHECK-LABEL: func.func @_QQmain()
! CHECK: fir.call @_FortranACUFDescriptorCheckSection
! CHECK: fir.call @_QPfoo
2 changes: 2 additions & 0 deletions flang/tools/bbc/bbc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
loweringOptions.setStackRepackArrays(stackRepackArrays);
loweringOptions.setRepackArrays(repackArrays);
loweringOptions.setRepackArraysWhole(repackArraysWhole);
// Turn on the CUDA Fortran runtime checks whenever CUDA support is enabled;
// otherwise the option keeps its existing value.
if (enableCUDA)
loweringOptions.setCUDARuntimeCheck(true);
std::vector<Fortran::lower::EnvironmentDefault> envDefaults = {};
Fortran::frontend::TargetOptions targetOpts;
Fortran::frontend::CodeGenOptions cgOpts;
Expand Down
Loading