Skip to content

Commit 614d30c

Browse files
clementvalvar-const
authored andcommitted
[flang][cuda] Relax compatibility rules when host,device procedure is involved (llvm#134926)
Relax too restrictive rule for host, device procedure.
1 parent 5c3ff79 commit 614d30c

File tree

5 files changed

+25
-6
lines changed

5 files changed

+25
-6
lines changed

flang/include/flang/Support/Fortran.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ std::string AsFortran(IgnoreTKRSet);
9696

9797
bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr>,
9898
std::optional<CUDADataAttr>, IgnoreTKRSet, std::optional<std::string> *,
99-
bool allowUnifiedMatchingRule,
99+
bool allowUnifiedMatchingRule, bool isHostDeviceProcedure,
100100
const LanguageFeatureControl *features = nullptr);
101101

102102
static constexpr char blankCommonObjectName[] = "__BLNK__";

flang/lib/Evaluate/characteristics.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,8 @@ bool DummyDataObject::IsCompatibleWith(const DummyDataObject &actual,
370370
if (!attrs.test(Attr::Value) &&
371371
!common::AreCompatibleCUDADataAttrs(cudaDataAttr, actual.cudaDataAttr,
372372
ignoreTKR, warning,
373-
/*allowUnifiedMatchingRule=*/false)) {
373+
/*allowUnifiedMatchingRule=*/false,
374+
/*=isHostDeviceProcedure*/ false)) {
374375
if (whyNot) {
375376
*whyNot = "incompatible CUDA data attributes";
376377
}
@@ -1776,7 +1777,8 @@ bool DistinguishUtils::Distinguishable(
17761777
return true;
17771778
} else if (!common::AreCompatibleCUDADataAttrs(x.cudaDataAttr, y.cudaDataAttr,
17781779
x.ignoreTKR | y.ignoreTKR, nullptr,
1779-
/*allowUnifiedMatchingRule=*/false)) {
1780+
/*allowUnifiedMatchingRule=*/false,
1781+
/*=isHostDeviceProcedure*/ false)) {
17801782
return true;
17811783
} else if (features_.IsEnabled(
17821784
common::LanguageFeature::DistinguishableSpecifics) &&

flang/lib/Semantics/check-call.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,9 +1016,12 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
10161016
}
10171017
}
10181018
std::optional<std::string> warning;
1019+
bool isHostDeviceProc = procedure.cudaSubprogramAttrs &&
1020+
*procedure.cudaSubprogramAttrs ==
1021+
common::CUDASubprogramAttrs::HostDevice;
10191022
if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr,
1020-
dummy.ignoreTKR, &warning,
1021-
/*allowUnifiedMatchingRule=*/true, &context.languageFeatures())) {
1023+
dummy.ignoreTKR, &warning, /*allowUnifiedMatchingRule=*/true,
1024+
isHostDeviceProc, &context.languageFeatures())) {
10221025
auto toStr{[](std::optional<common::CUDADataAttr> x) {
10231026
return x ? "ATTRIBUTES("s +
10241027
parser::ToUpperCaseLetters(common::EnumToString(*x)) + ")"s

flang/lib/Support/Fortran.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ std::string AsFortran(IgnoreTKRSet tkr) {
104104
bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
105105
std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR,
106106
std::optional<std::string> *warning, bool allowUnifiedMatchingRule,
107-
const LanguageFeatureControl *features) {
107+
bool isHostDeviceProcedure, const LanguageFeatureControl *features) {
108108
bool isCudaManaged{features
109109
? features->IsEnabled(common::LanguageFeature::CudaManaged)
110110
: false};
@@ -114,6 +114,9 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
114114
if (ignoreTKR.test(common::IgnoreTKR::Device)) {
115115
return true;
116116
}
117+
if (!y && isHostDeviceProcedure) {
118+
return true;
119+
}
117120
if (!x && !y) {
118121
return true;
119122
} else if (x && y && *x == *y) {

flang/test/Semantics/cuf10.cuf

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,15 @@ module m
4949
type (int) :: c, a, b
5050
c = a+b ! ok resolve to addDevice
5151
end subroutine overload
52+
53+
attributes(host,device) subroutine hostdev(a)
54+
integer :: a(*)
55+
end subroutine
56+
57+
subroutine host()
58+
integer :: a(10)
59+
call hostdev(a) ! ok because hostdev is attributes(host,device)
60+
end subroutine
61+
62+
5263
end

0 commit comments

Comments
 (0)