Skip to content

Commit 7106389

Browse files
authored
[flang][cuda] Lower kernel launch to fir.cuda_kernel_launch (#81891)
This patch introduces a new `fir.cuda_kernel_launch` operation to represents the call to CUDA kernels with the chervon notation. The chevrons values in the parse tree can be scalar integer expr or dim3 derived type. The operation describes the grid/block values explicitly as i32 values. It lowers the parse-tree call to this new operation.
1 parent 737bc9f commit 7106389

File tree

3 files changed

+185
-1
lines changed

3 files changed

+185
-1
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2436,6 +2436,65 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
24362436
}];
24372437
}
24382438

2439+
def fir_CUDAKernelLaunch : fir_Op<"cuda_kernel_launch", [CallOpInterface,
2440+
AttrSizedOperandSegments]> {
2441+
let summary = "call CUDA kernel";
2442+
2443+
let description = [{
2444+
Launch a CUDA kernel from the host.
2445+
2446+
```
2447+
// launch simple kernel with no arguments. bytes and stream value are
2448+
// optional in the chevron notation.
2449+
fir.cuda_kernel_launch @kernel<<<%gx, %gy, %bx, %by, %bz>>>()
2450+
```
2451+
}];
2452+
2453+
let arguments = (ins
2454+
SymbolRefAttr:$callee,
2455+
I32:$grid_x,
2456+
I32:$grid_y,
2457+
I32:$block_x,
2458+
I32:$block_y,
2459+
I32:$block_z,
2460+
Optional<I32>:$bytes,
2461+
Optional<I32>:$stream,
2462+
Variadic<AnyType>:$args
2463+
);
2464+
2465+
let assemblyFormat = [{
2466+
$callee `<` `<` `<` $grid_x `,` $grid_y `,` $block_x `,` $block_y `,`
2467+
$block_z ( `,` $bytes^ ( `,` $stream^ )? )? `>` `>` `>`
2468+
`` `(` ( $args^ `:` type($args) )? `)` attr-dict
2469+
}];
2470+
2471+
let extraClassDeclaration = [{
2472+
mlir::CallInterfaceCallable getCallableForCallee() {
2473+
return getCalleeAttr();
2474+
}
2475+
2476+
void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
2477+
(*this)->setAttr(getCalleeAttrName(), callee.get<mlir::SymbolRefAttr>());
2478+
}
2479+
mlir::FunctionType getFunctionType();
2480+
2481+
unsigned getNbNoArgOperand() {
2482+
unsigned nbNoArgOperand = 5; // grids and blocks values are always present.
2483+
if (getBytes()) ++nbNoArgOperand;
2484+
if (getStream()) ++nbNoArgOperand;
2485+
return nbNoArgOperand;
2486+
}
2487+
2488+
operand_range getArgOperands() {
2489+
return {operand_begin() + getNbNoArgOperand(), operand_end()};
2490+
}
2491+
mlir::MutableOperandRange getArgOperandsMutable() {
2492+
return mlir::MutableOperandRange(
2493+
*this, getNbNoArgOperand(), getArgs().size() - 1);
2494+
}
2495+
}];
2496+
}
2497+
24392498
// Constant operations that support Fortran
24402499

24412500
def fir_StringLitOp : fir_Op<"string_lit", [NoMemoryEffect]> {

flang/lib/Lower/ConvertCall.cpp

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,21 @@ static bool mustCastFuncOpToCopeWithImplicitInterfaceMismatch(
149149
return false;
150150
}
151151

152+
static mlir::Value readDim3Value(fir::FirOpBuilder &builder, mlir::Location loc,
153+
mlir::Value dim3Addr, llvm::StringRef comp) {
154+
mlir::Type i32Ty = builder.getI32Type();
155+
mlir::Type refI32Ty = fir::ReferenceType::get(i32Ty);
156+
llvm::SmallVector<mlir::Value> lenParams;
157+
158+
mlir::Value designate = builder.create<hlfir::DesignateOp>(
159+
loc, refI32Ty, dim3Addr, /*component=*/comp,
160+
/*componentShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
161+
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
162+
mlir::Value{}, lenParams);
163+
164+
return hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{designate});
165+
}
166+
152167
std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
153168
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
154169
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
@@ -394,7 +409,67 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
394409

395410
mlir::Value callResult;
396411
unsigned callNumResults;
397-
if (caller.requireDispatchCall()) {
412+
413+
if (!caller.getCallDescription().chevrons().empty()) {
414+
// A call to a CUDA kernel with the chevron syntax.
415+
416+
mlir::Type i32Ty = builder.getI32Type();
417+
mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
418+
419+
mlir::Value grid_x, grid_y;
420+
if (caller.getCallDescription().chevrons()[0].GetType()->category() ==
421+
Fortran::common::TypeCategory::Integer) {
422+
// If grid is an integer, it is converted to dim3(grid,1,1). Since z is
423+
// not used for the number of thread blocks, it is omitted in the op.
424+
grid_x = builder.createConvert(
425+
loc, i32Ty,
426+
fir::getBase(converter.genExprValue(
427+
caller.getCallDescription().chevrons()[0], stmtCtx)));
428+
grid_y = one;
429+
} else {
430+
auto dim3Addr = converter.genExprAddr(
431+
caller.getCallDescription().chevrons()[0], stmtCtx);
432+
grid_x = readDim3Value(builder, loc, fir::getBase(dim3Addr), "x");
433+
grid_y = readDim3Value(builder, loc, fir::getBase(dim3Addr), "y");
434+
}
435+
436+
mlir::Value block_x, block_y, block_z;
437+
if (caller.getCallDescription().chevrons()[1].GetType()->category() ==
438+
Fortran::common::TypeCategory::Integer) {
439+
// If block is an integer, it is converted to dim3(block,1,1).
440+
block_x = builder.createConvert(
441+
loc, i32Ty,
442+
fir::getBase(converter.genExprValue(
443+
caller.getCallDescription().chevrons()[1], stmtCtx)));
444+
block_y = one;
445+
block_z = one;
446+
} else {
447+
auto dim3Addr = converter.genExprAddr(
448+
caller.getCallDescription().chevrons()[1], stmtCtx);
449+
block_x = readDim3Value(builder, loc, fir::getBase(dim3Addr), "x");
450+
block_y = readDim3Value(builder, loc, fir::getBase(dim3Addr), "y");
451+
block_z = readDim3Value(builder, loc, fir::getBase(dim3Addr), "z");
452+
}
453+
454+
mlir::Value bytes; // bytes is optional.
455+
if (caller.getCallDescription().chevrons().size() > 2)
456+
bytes = builder.createConvert(
457+
loc, i32Ty,
458+
fir::getBase(converter.genExprValue(
459+
caller.getCallDescription().chevrons()[2], stmtCtx)));
460+
461+
mlir::Value stream; // stream is optional.
462+
if (caller.getCallDescription().chevrons().size() > 3)
463+
stream = builder.createConvert(
464+
loc, i32Ty,
465+
fir::getBase(converter.genExprValue(
466+
caller.getCallDescription().chevrons()[3], stmtCtx)));
467+
468+
builder.create<fir::CUDAKernelLaunch>(
469+
loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, block_x,
470+
block_y, block_z, bytes, stream, operands);
471+
callNumResults = 0;
472+
} else if (caller.requireDispatchCall()) {
398473
// Procedure call requiring a dynamic dispatch. Call is created with
399474
// fir.dispatch.
400475

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
! Test lowering of CUDA procedure calls.
4+
5+
module test_call
6+
use, intrinsic :: __fortran_builtins, only: __builtin_dim3
7+
contains
8+
attributes(global) subroutine dev_kernel0()
9+
end
10+
11+
attributes(global) subroutine dev_kernel1(a)
12+
real :: a
13+
end
14+
15+
subroutine host()
16+
real, device :: a
17+
! CHECK-LABEL: func.func @_QMtest_callPhost()
18+
! CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QMtest_callFhostEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
19+
20+
call dev_kernel0<<<10, 20>>>()
21+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}>>>()
22+
23+
call dev_kernel0<<< __builtin_dim3(1,1), __builtin_dim3(32,1,1) >>>
24+
! CHECK: %[[ADDR_DIM3_GRID:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
25+
! CHECK: %[[DIM3_GRID:.*]]:2 = hlfir.declare %[[ADDR_DIM3_GRID]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.0"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>)
26+
! CHECK: %[[GRID_X:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"x"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
27+
! CHECK: %[[GRID_X_LOAD:.*]] = fir.load %[[GRID_X]] : !fir.ref<i32>
28+
! CHECK: %[[GRID_Y:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"y"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
29+
! CHECK: %[[GRID_Y_LOAD:.*]] = fir.load %[[GRID_Y]] : !fir.ref<i32>
30+
! CHECK: %[[ADDR_DIM3_BLOCK:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
31+
! CHECK: %[[DIM3_BLOCK:.*]]:2 = hlfir.declare %[[ADDR_DIM3_BLOCK]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.1"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>)
32+
! CHECK: %[[BLOCK_X:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"x"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
33+
! CHECK: %[[BLOCK_X_LOAD:.*]] = fir.load %[[BLOCK_X]] : !fir.ref<i32>
34+
! CHECK: %[[BLOCK_Y:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"y"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
35+
! CHECK: %[[BLOCK_Y_LOAD:.*]] = fir.load %[[BLOCK_Y]] : !fir.ref<i32>
36+
! CHECK: %[[BLOCK_Z:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"z"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
37+
! CHECK: %[[BLOCK_Z_LOAD:.*]] = fir.load %[[BLOCK_Z]] : !fir.ref<i32>
38+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>()
39+
40+
call dev_kernel0<<<10, 20, 2>>>()
41+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>()
42+
43+
call dev_kernel0<<<10, 20, 2, 0>>>()
44+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
45+
46+
call dev_kernel1<<<1, 32>>>(a)
47+
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%1#1 : !fir.ref<f32>)
48+
end
49+
50+
end

0 commit comments

Comments
 (0)