Skip to content

Commit 37143fe

Browse files
authored
[flang][cuda] Make launch configuration optional for cuf kernel (#115947)
1 parent 01d233f commit 37143fe

File tree

6 files changed

+67
-43
lines changed

6 files changed

+67
-43
lines changed

flang/include/flang/Parser/dump-parse-tree.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ class ParseTreeDumper {
236236
NODE(parser, CUFKernelDoConstruct)
237237
NODE(CUFKernelDoConstruct, StarOrExpr)
238238
NODE(CUFKernelDoConstruct, Directive)
239+
NODE(CUFKernelDoConstruct, LaunchConfiguration)
239240
NODE(parser, CUFReduction)
240241
NODE(parser, CycleStmt)
241242
NODE(parser, DataComponentDefStmt)

flang/include/flang/Parser/parse-tree.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4527,12 +4527,17 @@ struct CUFReduction {
45274527
struct CUFKernelDoConstruct {
45284528
TUPLE_CLASS_BOILERPLATE(CUFKernelDoConstruct);
45294529
WRAPPER_CLASS(StarOrExpr, std::optional<ScalarIntExpr>);
4530+
struct LaunchConfiguration {
4531+
TUPLE_CLASS_BOILERPLATE(LaunchConfiguration);
4532+
std::tuple<std::list<StarOrExpr>, std::list<StarOrExpr>,
4533+
std::optional<ScalarIntExpr>>
4534+
t;
4535+
};
45304536
struct Directive {
45314537
TUPLE_CLASS_BOILERPLATE(Directive);
45324538
CharBlock source;
4533-
std::tuple<std::optional<ScalarIntConstantExpr>, std::list<StarOrExpr>,
4534-
std::list<StarOrExpr>, std::optional<ScalarIntExpr>,
4535-
std::list<CUFReduction>>
4539+
std::tuple<std::optional<ScalarIntConstantExpr>,
4540+
std::optional<LaunchConfiguration>, std::list<CUFReduction>>
45364541
t;
45374542
};
45384543
std::tuple<Directive, std::optional<DoConstruct>> t;

flang/lib/Lower/Bridge.cpp

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2862,14 +2862,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
28622862
if (nestedLoops > 1)
28632863
n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
28642864

2865-
const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
2866-
std::get<1>(dir.t);
2867-
const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &block =
2868-
std::get<2>(dir.t);
2869-
const std::optional<Fortran::parser::ScalarIntExpr> &stream =
2870-
std::get<3>(dir.t);
2865+
const auto &launchConfig = std::get<std::optional<
2866+
Fortran::parser::CUFKernelDoConstruct::LaunchConfiguration>>(dir.t);
2867+
28712868
const std::list<Fortran::parser::CUFReduction> &cufreds =
2872-
std::get<4>(dir.t);
2869+
std::get<2>(dir.t);
28732870

28742871
llvm::SmallVector<mlir::Value> reduceOperands;
28752872
llvm::SmallVector<mlir::Attribute> reduceAttrs;
@@ -2913,35 +2910,45 @@ class FirConverter : public Fortran::lower::AbstractConverter {
29132910
builder->createIntegerConstant(loc, builder->getI32Type(), 0);
29142911

29152912
llvm::SmallVector<mlir::Value> gridValues;
2916-
if (!isOnlyStars(grid)) {
2917-
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
2918-
grid) {
2919-
if (expr.v) {
2920-
gridValues.push_back(fir::getBase(
2921-
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
2922-
} else {
2923-
gridValues.push_back(zero);
2913+
llvm::SmallVector<mlir::Value> blockValues;
2914+
mlir::Value streamValue;
2915+
2916+
if (launchConfig) {
2917+
const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
2918+
std::get<0>(launchConfig->t);
2919+
const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr>
2920+
&block = std::get<1>(launchConfig->t);
2921+
const std::optional<Fortran::parser::ScalarIntExpr> &stream =
2922+
std::get<2>(launchConfig->t);
2923+
if (!isOnlyStars(grid)) {
2924+
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
2925+
grid) {
2926+
if (expr.v) {
2927+
gridValues.push_back(fir::getBase(
2928+
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
2929+
} else {
2930+
gridValues.push_back(zero);
2931+
}
29242932
}
29252933
}
2926-
}
2927-
llvm::SmallVector<mlir::Value> blockValues;
2928-
if (!isOnlyStars(block)) {
2929-
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
2930-
block) {
2931-
if (expr.v) {
2932-
blockValues.push_back(fir::getBase(
2933-
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
2934-
} else {
2935-
blockValues.push_back(zero);
2934+
if (!isOnlyStars(block)) {
2935+
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
2936+
block) {
2937+
if (expr.v) {
2938+
blockValues.push_back(fir::getBase(
2939+
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
2940+
} else {
2941+
blockValues.push_back(zero);
2942+
}
29362943
}
29372944
}
2945+
2946+
if (stream)
2947+
streamValue = builder->createConvert(
2948+
loc, builder->getI32Type(),
2949+
fir::getBase(
2950+
genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));
29382951
}
2939-
mlir::Value streamValue;
2940-
if (stream)
2941-
streamValue = builder->createConvert(
2942-
loc, builder->getI32Type(),
2943-
fir::getBase(
2944-
genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));
29452952

29462953
const auto &outerDoConstruct =
29472954
std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);

flang/lib/Parser/executable-parsers.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -563,11 +563,15 @@ TYPE_PARSER(("REDUCTION"_tok || "REDUCE"_tok) >>
563563
parenthesized(construct<CUFReduction>(Parser<CUFReduction::Operator>{},
564564
":" >> nonemptyList(scalar(variable)))))
565565

566+
TYPE_PARSER("<<<" >>
567+
construct<CUFKernelDoConstruct::LaunchConfiguration>(gridOrBlock,
568+
"," >> gridOrBlock,
569+
maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>"))
570+
566571
TYPE_PARSER(sourced(beginDirective >> "$CUF KERNEL DO"_tok >>
567572
construct<CUFKernelDoConstruct::Directive>(
568-
maybe(parenthesized(scalarIntConstantExpr)), "<<<" >> gridOrBlock,
569-
"," >> gridOrBlock,
570-
maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>",
573+
maybe(parenthesized(scalarIntConstantExpr)),
574+
maybe(Parser<CUFKernelDoConstruct::LaunchConfiguration>{}),
571575
many(Parser<CUFReduction>{}) / endDirective)))
572576
TYPE_CONTEXT_PARSER("!$CUF KERNEL DO construct"_en_US,
573577
extension<LanguageFeature::CUDA>(construct<CUFKernelDoConstruct>(

flang/lib/Parser/unparse.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2932,11 +2932,9 @@ class UnparseVisitor {
29322932
Word("*");
29332933
}
29342934
}
2935-
void Unparse(const CUFKernelDoConstruct::Directive &x) {
2936-
Word("!$CUF KERNEL DO");
2937-
Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
2935+
void Unparse(const CUFKernelDoConstruct::LaunchConfiguration &x) {
29382936
Word(" <<<");
2939-
const auto &grid{std::get<1>(x.t)};
2937+
const auto &grid{std::get<0>(x.t)};
29402938
if (grid.empty()) {
29412939
Word("*");
29422940
} else if (grid.size() == 1) {
@@ -2945,18 +2943,24 @@ class UnparseVisitor {
29452943
Walk("(", grid, ",", ")");
29462944
}
29472945
Word(",");
2948-
const auto &block{std::get<2>(x.t)};
2946+
const auto &block{std::get<1>(x.t)};
29492947
if (block.empty()) {
29502948
Word("*");
29512949
} else if (block.size() == 1) {
29522950
Walk(block.front());
29532951
} else {
29542952
Walk("(", block, ",", ")");
29552953
}
2956-
if (const auto &stream{std::get<3>(x.t)}) {
2954+
if (const auto &stream{std::get<2>(x.t)}) {
29572955
Word(",STREAM="), Walk(*stream);
29582956
}
29592957
Word(">>>");
2958+
}
2959+
void Unparse(const CUFKernelDoConstruct::Directive &x) {
2960+
Word("!$CUF KERNEL DO");
2961+
Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
2962+
Walk(std::get<std::optional<CUFKernelDoConstruct::LaunchConfiguration>>(
2963+
x.t));
29602964
Walk(" ", std::get<std::list<CUFReduction>>(x.t), " ");
29612965
Word("\n");
29622966
}

flang/test/Parser/cuf-sanity-common

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ module m
3131
!$cuf kernel do <<<1, (2, 3), stream = 1>>>
3232
do j = 1, 10
3333
end do
34+
!$cuf kernel do
35+
do j = 1, 10
36+
end do
3437
!$cuf kernel do <<<*, *>>> reduce(+:x,y) reduce(*:z)
3538
do j = 1, 10
3639
x = x + a(j)

0 commit comments

Comments
 (0)