Skip to content

Commit f6a2a55

Browse files
authored
[flang][cuda] Handle lowering of stars in cuf kernel launch parameters (#85695)
Parsing of the cuf kernel loop directive has been updated to handle variants with the `*` syntax. This patch updates the lowering to make use of them: - If the grid or block syntax uses only stars, then the operation's variadic operand remains empty. - If there are both values and stars, then the stars are represented as a zero constant value.
1 parent aec50cd commit f6a2a55

File tree

3 files changed

+58
-16
lines changed

3 files changed

+58
-16
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3131,6 +3131,16 @@ def fir_BoxOffsetOp : fir_Op<"box_offset", [NoMemoryEffect]> {
31313131
def fir_CUDAKernelOp : fir_Op<"cuda_kernel", [AttrSizedOperandSegments,
31323132
DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
31333133

3134+
let description = [{
3135+
Represent the CUDA Fortran kernel directive. The operation is a loop like
3136+
operation that represents the iteration range of the embedded loop nest.
3137+
3138+
When grid or block variadic operands are empty, a `*` only syntax was used
3139+
in the Fortran code.
3140+
If the `*` is mixed with values for either grid or block, these are
3141+
represented by a 0 constant value.
3142+
}];
3143+
31343144
let arguments = (ins
31353145
Variadic<I32>:$grid, // empty means `*`
31363146
Variadic<I32>:$block, // empty means `*`

flang/lib/Lower/Bridge.cpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2529,23 +2529,42 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25292529
const std::optional<Fortran::parser::ScalarIntExpr> &stream =
25302530
std::get<3>(dir.t);
25312531

2532+
auto isOnlyStars =
2533+
[&](const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr>
2534+
&list) -> bool {
2535+
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
2536+
list) {
2537+
if (expr.v)
2538+
return false;
2539+
}
2540+
return true;
2541+
};
2542+
2543+
mlir::Value zero =
2544+
builder->createIntegerConstant(loc, builder->getI32Type(), 0);
2545+
25322546
llvm::SmallVector<mlir::Value> gridValues;
2533-
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr : grid) {
2534-
if (expr.v) {
2535-
gridValues.push_back(fir::getBase(
2536-
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
2537-
} else {
2538-
// TODO: '*'
2547+
if (!isOnlyStars(grid)) {
2548+
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
2549+
grid) {
2550+
if (expr.v) {
2551+
gridValues.push_back(fir::getBase(
2552+
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
2553+
} else {
2554+
gridValues.push_back(zero);
2555+
}
25392556
}
25402557
}
25412558
llvm::SmallVector<mlir::Value> blockValues;
2542-
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
2543-
block) {
2544-
if (expr.v) {
2545-
blockValues.push_back(fir::getBase(
2546-
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
2547-
} else {
2548-
// TODO: '*'
2559+
if (!isOnlyStars(block)) {
2560+
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
2561+
block) {
2562+
if (expr.v) {
2563+
blockValues.push_back(fir::getBase(
2564+
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
2565+
} else {
2566+
blockValues.push_back(zero);
2567+
}
25492568
}
25502569
}
25512570
mlir::Value streamValue;

flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,20 @@ subroutine sub1()
4242
! CHECK: fir.cuda_kernel<<<%c1{{.*}}, (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
4343
! CHECK: {n = 2 : i64}
4444

45-
! TODO: lowering for these cases
46-
! !$cuf kernel do(2) <<< (1,*), (256,1) >>>
47-
! !$cuf kernel do(2) <<< (*,*), (32,4) >>>
45+
!$cuf kernel do(2) <<< (1,*), (256,1) >>>
46+
do i = 1, n
47+
do j = 1, n
48+
c(i,j) = c(i,j) * d(i,j)
49+
end do
50+
end do
51+
! CHECK: fir.cuda_kernel<<<(%c1{{.*}}, %c0{{.*}}), (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
52+
53+
!$cuf kernel do(2) <<< (*,*), (32,4) >>>
54+
do i = 1, n
55+
do j = 1, n
56+
c(i,j) = c(i,j) * d(i,j)
57+
end do
58+
end do
59+
60+
! CHECK: fir.cuda_kernel<<<*, (%c32{{.*}}, %c4{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
4861
end

0 commit comments

Comments
 (0)