Skip to content

[flang][cuda] Relax compatibility rules when host,device procedure is involved #134926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Support/Fortran.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ std::string AsFortran(IgnoreTKRSet);

bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr>,
std::optional<CUDADataAttr>, IgnoreTKRSet, std::optional<std::string> *,
bool allowUnifiedMatchingRule,
bool allowUnifiedMatchingRule, bool isHostDeviceProcedure,
const LanguageFeatureControl *features = nullptr);

static constexpr char blankCommonObjectName[] = "__BLNK__";
Expand Down
6 changes: 4 additions & 2 deletions flang/lib/Evaluate/characteristics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ bool DummyDataObject::IsCompatibleWith(const DummyDataObject &actual,
if (!attrs.test(Attr::Value) &&
!common::AreCompatibleCUDADataAttrs(cudaDataAttr, actual.cudaDataAttr,
ignoreTKR, warning,
/*allowUnifiedMatchingRule=*/false)) {
/*allowUnifiedMatchingRule=*/false,
/*=isHostDeviceProcedure*/ false)) {
if (whyNot) {
*whyNot = "incompatible CUDA data attributes";
}
Expand Down Expand Up @@ -1776,7 +1777,8 @@ bool DistinguishUtils::Distinguishable(
return true;
} else if (!common::AreCompatibleCUDADataAttrs(x.cudaDataAttr, y.cudaDataAttr,
x.ignoreTKR | y.ignoreTKR, nullptr,
/*allowUnifiedMatchingRule=*/false)) {
/*allowUnifiedMatchingRule=*/false,
/*=isHostDeviceProcedure*/ false)) {
return true;
} else if (features_.IsEnabled(
common::LanguageFeature::DistinguishableSpecifics) &&
Expand Down
7 changes: 5 additions & 2 deletions flang/lib/Semantics/check-call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1016,9 +1016,12 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
}
}
std::optional<std::string> warning;
bool isHostDeviceProc = procedure.cudaSubprogramAttrs &&
*procedure.cudaSubprogramAttrs ==
common::CUDASubprogramAttrs::HostDevice;
if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr,
dummy.ignoreTKR, &warning,
/*allowUnifiedMatchingRule=*/true, &context.languageFeatures())) {
dummy.ignoreTKR, &warning, /*allowUnifiedMatchingRule=*/true,
isHostDeviceProc, &context.languageFeatures())) {
auto toStr{[](std::optional<common::CUDADataAttr> x) {
return x ? "ATTRIBUTES("s +
parser::ToUpperCaseLetters(common::EnumToString(*x)) + ")"s
Expand Down
5 changes: 4 additions & 1 deletion flang/lib/Support/Fortran.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ std::string AsFortran(IgnoreTKRSet tkr) {
bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR,
std::optional<std::string> *warning, bool allowUnifiedMatchingRule,
const LanguageFeatureControl *features) {
bool isHostDeviceProcedure, const LanguageFeatureControl *features) {
bool isCudaManaged{features
? features->IsEnabled(common::LanguageFeature::CudaManaged)
: false};
Expand All @@ -114,6 +114,9 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
if (ignoreTKR.test(common::IgnoreTKR::Device)) {
return true;
}
if (!y && isHostDeviceProcedure) {
return true;
}
if (!x && !y) {
return true;
} else if (x && y && *x == *y) {
Expand Down
11 changes: 11 additions & 0 deletions flang/test/Semantics/cuf10.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,15 @@ module m
type (int) :: c, a, b
c = a+b ! ok resolve to addDevice
end subroutine overload

attributes(host,device) subroutine hostdev(a)
integer :: a(*)
end subroutine

subroutine host()
integer :: a(10)
call hostdev(a) ! ok because hostdev is attributes(host,device)
end subroutine


end
Loading