Skip to content

Commit e8eb52d

Browse files
authored
[flang][cuda] Extends matching distance computation (#91810)
Extends the computation of the matching distance in the generic resolution to support options described in the table: https://docs.nvidia.com/hpc-sdk/archive/24.3/compilers/cuda-fortran-prog-guide/index.html#cfref-var-attr-unified-data Options are added as language features in the `SemanticsContext` and a flag is added in bbc for testing purpose.
1 parent be7c9e3 commit e8eb52d

File tree

8 files changed

+180
-21
lines changed

8 files changed

+180
-21
lines changed

flang/include/flang/Common/Fortran-features.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
4949
IndistinguishableSpecifics, SubroutineAndFunctionSpecifics,
5050
EmptySequenceType, NonSequenceCrayPointee, BranchIntoConstruct,
5151
BadBranchTarget, ConvertedArgument, HollerithPolymorphic, ListDirectedSize,
52-
NonBindCInteroperability)
52+
NonBindCInteroperability, CudaManaged, CudaUnified)
5353

5454
// Portability and suspicious usage warnings
5555
ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,
@@ -81,6 +81,8 @@ class LanguageFeatureControl {
8181
disable_.set(LanguageFeature::OpenACC);
8282
disable_.set(LanguageFeature::OpenMP);
8383
disable_.set(LanguageFeature::CUDA); // !@cuf
84+
disable_.set(LanguageFeature::CudaManaged);
85+
disable_.set(LanguageFeature::CudaUnified);
8486
disable_.set(LanguageFeature::ImplicitNoneTypeNever);
8587
disable_.set(LanguageFeature::ImplicitNoneTypeAlways);
8688
disable_.set(LanguageFeature::DefaultSave);

flang/include/flang/Common/Fortran.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <string>
2020

2121
namespace Fortran::common {
22+
class LanguageFeatureControl;
2223

2324
// Fortran has five kinds of intrinsic data types, plus the derived types.
2425
ENUM_CLASS(TypeCategory, Integer, Real, Complex, Character, Logical, Derived)
@@ -115,7 +116,8 @@ static constexpr IgnoreTKRSet ignoreTKRAll{IgnoreTKR::Type, IgnoreTKR::Kind,
115116
std::string AsFortran(IgnoreTKRSet);
116117

117118
bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr>,
118-
std::optional<CUDADataAttr>, IgnoreTKRSet, bool allowUnifiedMatchingRule);
119+
std::optional<CUDADataAttr>, IgnoreTKRSet, bool allowUnifiedMatchingRule,
120+
const LanguageFeatureControl *features = nullptr);
119121

120122
static constexpr char blankCommonObjectName[] = "__BLNK__";
121123

flang/lib/Common/Fortran.cpp

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "flang/Common/Fortran.h"
10+
#include "flang/Common/Fortran-features.h"
1011

1112
namespace Fortran::common {
1213

@@ -102,7 +103,13 @@ std::string AsFortran(IgnoreTKRSet tkr) {
102103
/// dummy argument attribute while `y` represents the actual argument attribute.
103104
bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
104105
std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR,
105-
bool allowUnifiedMatchingRule) {
106+
bool allowUnifiedMatchingRule, const LanguageFeatureControl *features) {
107+
bool isCudaManaged{features
108+
? features->IsEnabled(common::LanguageFeature::CudaManaged)
109+
: false};
110+
bool isCudaUnified{features
111+
? features->IsEnabled(common::LanguageFeature::CudaUnified)
112+
: false};
106113
if (!x && !y) {
107114
return true;
108115
} else if (x && y && *x == *y) {
@@ -120,19 +127,27 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
120127
return true;
121128
} else if (allowUnifiedMatchingRule) {
122129
if (!x) { // Dummy argument has no attribute -> host
123-
if (y && (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) {
130+
if ((y && (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) ||
131+
(!y && (isCudaUnified || isCudaManaged))) {
124132
return true;
125133
}
126134
} else {
127-
if (*x == CUDADataAttr::Device && y &&
128-
(*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) {
129-
return true;
130-
} else if (*x == CUDADataAttr::Managed && y &&
131-
*y == CUDADataAttr::Unified) {
132-
return true;
133-
} else if (*x == CUDADataAttr::Unified && y &&
134-
*y == CUDADataAttr::Managed) {
135-
return true;
135+
if (*x == CUDADataAttr::Device) {
136+
if ((y &&
137+
(*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) ||
138+
(!y && (isCudaUnified || isCudaManaged))) {
139+
return true;
140+
}
141+
} else if (*x == CUDADataAttr::Managed) {
142+
if ((y && *y == CUDADataAttr::Unified) ||
143+
(!y && (isCudaUnified || isCudaManaged))) {
144+
return true;
145+
}
146+
} else if (*x == CUDADataAttr::Unified) {
147+
if ((y && *y == CUDADataAttr::Managed) ||
148+
(!y && (isCudaUnified || isCudaManaged))) {
149+
return true;
150+
}
136151
}
137152
}
138153
return false;

flang/lib/Semantics/check-call.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
914914
}
915915
if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr,
916916
dummy.ignoreTKR,
917-
/*allowUnifiedMatchingRule=*/true)) {
917+
/*allowUnifiedMatchingRule=*/true, &context.languageFeatures())) {
918918
auto toStr{[](std::optional<common::CUDADataAttr> x) {
919919
return x ? "ATTRIBUTES("s +
920920
parser::ToUpperCaseLetters(common::EnumToString(*x)) + ")"s

flang/lib/Semantics/expression.cpp

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2501,8 +2501,13 @@ static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
25012501

25022502
// Compute the matching distance as described in section 3.2.3 of the CUDA
25032503
// Fortran references.
2504-
static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
2504+
static int GetMatchingDistance(const common::LanguageFeatureControl &features,
2505+
const characteristics::DummyArgument &dummy,
25052506
const std::optional<ActualArgument> &actual) {
2507+
bool isCudaManaged{features.IsEnabled(common::LanguageFeature::CudaManaged)};
2508+
bool isCudaUnified{features.IsEnabled(common::LanguageFeature::CudaUnified)};
2509+
CHECK(!(isCudaUnified && isCudaManaged) && "expect only one enabled.");
2510+
25062511
std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
25072512
if (actual) {
25082513
if (auto *expr{actual->UnwrapExpr()}) {
@@ -2529,6 +2534,9 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
25292534

25302535
if (!dummyDataAttr) {
25312536
if (!actualDataAttr) {
2537+
if (isCudaUnified || isCudaManaged) {
2538+
return 3;
2539+
}
25322540
return 0;
25332541
} else if (*actualDataAttr == common::CUDADataAttr::Device) {
25342542
return cudaInfMatchingValue;
@@ -2538,6 +2546,9 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
25382546
}
25392547
} else if (*dummyDataAttr == common::CUDADataAttr::Device) {
25402548
if (!actualDataAttr) {
2549+
if (isCudaUnified || isCudaManaged) {
2550+
return 2;
2551+
}
25412552
return cudaInfMatchingValue;
25422553
} else if (*actualDataAttr == common::CUDADataAttr::Device) {
25432554
return 0;
@@ -2546,15 +2557,21 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
25462557
return 2;
25472558
}
25482559
} else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
2549-
if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
2560+
if (!actualDataAttr) {
2561+
return isCudaUnified ? 1 : isCudaManaged ? 0 : cudaInfMatchingValue;
2562+
}
2563+
if (*actualDataAttr == common::CUDADataAttr::Device) {
25502564
return cudaInfMatchingValue;
25512565
} else if (*actualDataAttr == common::CUDADataAttr::Managed) {
25522566
return 0;
25532567
} else if (*actualDataAttr == common::CUDADataAttr::Unified) {
25542568
return 1;
25552569
}
25562570
} else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
2557-
if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
2571+
if (!actualDataAttr) {
2572+
return isCudaUnified ? 0 : isCudaManaged ? 1 : cudaInfMatchingValue;
2573+
}
2574+
if (*actualDataAttr == common::CUDADataAttr::Device) {
25582575
return cudaInfMatchingValue;
25592576
} else if (*actualDataAttr == common::CUDADataAttr::Managed) {
25602577
return 1;
@@ -2566,6 +2583,7 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
25662583
}
25672584

25682585
static int ComputeCudaMatchingDistance(
2586+
const common::LanguageFeatureControl &features,
25692587
const characteristics::Procedure &procedure,
25702588
const ActualArguments &actuals) {
25712589
const auto &dummies{procedure.dummyArguments};
@@ -2574,7 +2592,7 @@ static int ComputeCudaMatchingDistance(
25742592
for (std::size_t i{0}; i < dummies.size(); ++i) {
25752593
const characteristics::DummyArgument &dummy{dummies[i]};
25762594
const std::optional<ActualArgument> &actual{actuals[i]};
2577-
int d{GetMatchingDistance(dummy, actual)};
2595+
int d{GetMatchingDistance(features, dummy, actual)};
25782596
if (d == cudaInfMatchingValue)
25792597
return d;
25802598
distance += d;
@@ -2666,7 +2684,9 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
26662684
CheckCompatibleArguments(*procedure, localActuals)) {
26672685
if ((procedure->IsElemental() && elemental) ||
26682686
(!procedure->IsElemental() && nonElemental)) {
2669-
int d{ComputeCudaMatchingDistance(*procedure, localActuals)};
2687+
int d{ComputeCudaMatchingDistance(
2688+
context_.languageFeatures(), *procedure, localActuals)};
2689+
llvm::errs() << "matching distance: " << d << "\n";
26702690
if (d != crtMatchingDistance) {
26712691
if (d > crtMatchingDistance) {
26722692
continue;
@@ -2688,8 +2708,8 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
26882708
} else {
26892709
elemental = &specific;
26902710
}
2691-
crtMatchingDistance =
2692-
ComputeCudaMatchingDistance(*procedure, localActuals);
2711+
crtMatchingDistance = ComputeCudaMatchingDistance(
2712+
context_.languageFeatures(), *procedure, localActuals);
26932713
}
26942714
}
26952715
}

flang/test/Semantics/cuf14.cuf

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
! RUN: bbc -emit-hlfir -fcuda -gpu=unified %s -o - | FileCheck %s
2+
3+
module matching
4+
interface host_and_device
5+
module procedure sub_host
6+
module procedure sub_device
7+
end interface
8+
9+
interface all
10+
module procedure sub_host
11+
module procedure sub_device
12+
module procedure sub_managed
13+
module procedure sub_unified
14+
end interface
15+
16+
interface all_without_unified
17+
module procedure sub_host
18+
module procedure sub_device
19+
module procedure sub_managed
20+
end interface
21+
22+
contains
23+
subroutine sub_host(a)
24+
integer :: a(:)
25+
end
26+
27+
subroutine sub_device(a)
28+
integer, device :: a(:)
29+
end
30+
31+
subroutine sub_managed(a)
32+
integer, managed :: a(:)
33+
end
34+
35+
subroutine sub_unified(a)
36+
integer, unified :: a(:)
37+
end
38+
end module
39+
40+
program m
41+
use matching
42+
43+
integer, allocatable :: actual_host(:)
44+
45+
allocate(actual_host(10))
46+
47+
call host_and_device(actual_host) ! Should resolve to sub_device
48+
call all(actual_host) ! Should resolved to unified
49+
call all_without_unified(actual_host) ! Should resolved to managed
50+
end
51+
52+
! CHECK: fir.call @_QMmatchingPsub_device
53+
! CHECK: fir.call @_QMmatchingPsub_unified
54+
! CHECK: fir.call @_QMmatchingPsub_managed
55+

flang/test/Semantics/cuf15.cuf

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
! RUN: bbc -emit-hlfir -fcuda -gpu=managed %s -o - | FileCheck %s
2+
3+
module matching
4+
interface host_and_device
5+
module procedure sub_host
6+
module procedure sub_device
7+
end interface
8+
9+
interface all
10+
module procedure sub_host
11+
module procedure sub_device
12+
module procedure sub_managed
13+
module procedure sub_unified
14+
end interface
15+
16+
interface all_without_managed
17+
module procedure sub_host
18+
module procedure sub_device
19+
module procedure sub_unified
20+
end interface
21+
22+
contains
23+
subroutine sub_host(a)
24+
integer :: a(:)
25+
end
26+
27+
subroutine sub_device(a)
28+
integer, device :: a(:)
29+
end
30+
31+
subroutine sub_managed(a)
32+
integer, managed :: a(:)
33+
end
34+
35+
subroutine sub_unified(a)
36+
integer, unified :: a(:)
37+
end
38+
end module
39+
40+
program m
41+
use matching
42+
43+
integer, allocatable :: actual_host(:)
44+
45+
allocate(actual_host(10))
46+
47+
call host_and_device(actual_host) ! Should resolve to sub_device
48+
call all(actual_host) ! Should resolved to unified
49+
call all_without_managed(actual_host) ! Should resolved to managed
50+
end
51+
52+
! CHECK: fir.call @_QMmatchingPsub_device
53+
! CHECK: fir.call @_QMmatchingPsub_managed
54+
! CHECK: fir.call @_QMmatchingPsub_unified
55+

flang/tools/bbc/bbc.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
204204
llvm::cl::desc("enable CUDA Fortran"),
205205
llvm::cl::init(false));
206206

207+
static llvm::cl::opt<std::string>
208+
enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
209+
llvm::cl::init(""));
210+
207211
static llvm::cl::opt<bool> fixedForm("ffixed-form",
208212
llvm::cl::desc("enable fixed form"),
209213
llvm::cl::init(false));
@@ -495,6 +499,12 @@ int main(int argc, char **argv) {
495499
options.features.Enable(Fortran::common::LanguageFeature::CUDA);
496500
}
497501

502+
if (enableGPUMode == "managed") {
503+
options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
504+
} else if (enableGPUMode == "unified") {
505+
options.features.Enable(Fortran::common::LanguageFeature::CudaUnified);
506+
}
507+
498508
if (fixedForm) {
499509
options.isFixedForm = fixedForm;
500510
}

0 commit comments

Comments
 (0)