Skip to content

[flang][cuda] Lower kernel launch to fir.cuda_kernel_launch #81891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2436,6 +2436,65 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
}];
}

def fir_CUDAKernelLaunch : fir_Op<"cuda_kernel_launch", [CallOpInterface,
AttrSizedOperandSegments]> {
let summary = "call CUDA kernel";

let description = [{
Launch a CUDA kernel from the host.

```
// launch simple kernel with no arguments. bytes and stream value are
// optional in the chevron notation.
fir.cuda_kernel_launch @kernel<<<%gx, %gy, %bx, %by, %bz>>>()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice that you could reproduce the chevron syntax in the MLIR format!

```
}];

let arguments = (ins
SymbolRefAttr:$callee,
I32:$grid_x,
I32:$grid_y,
I32:$block_x,
I32:$block_y,
I32:$block_z,
Optional<I32>:$bytes,
Optional<I32>:$stream,
Variadic<AnyType>:$args
);

let assemblyFormat = [{
$callee `<` `<` `<` $grid_x `,` $grid_y `,` $block_x `,` $block_y `,`
$block_z ( `,` $bytes^ ( `,` $stream^ )? )? `>` `>` `>`
`` `(` ( $args^ `:` type($args) )? `)` attr-dict
}];

let extraClassDeclaration = [{
mlir::CallInterfaceCallable getCallableForCallee() {
return getCalleeAttr();
}

void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
(*this)->setAttr(getCalleeAttrName(), callee.get<mlir::SymbolRefAttr>());
}
mlir::FunctionType getFunctionType();

unsigned getNbNoArgOperand() {
unsigned nbNoArgOperand = 5; // grids and blocks values are always present.
if (getBytes()) ++nbNoArgOperand;
if (getStream()) ++nbNoArgOperand;
return nbNoArgOperand;
}

operand_range getArgOperands() {
return {operand_begin() + getNbNoArgOperand(), operand_end()};
}
mlir::MutableOperandRange getArgOperandsMutable() {
return mlir::MutableOperandRange(
*this, getNbNoArgOperand(), getArgs().size() - 1);
}
}];
}

// Constant operations that support Fortran

def fir_StringLitOp : fir_Op<"string_lit", [NoMemoryEffect]> {
Expand Down
77 changes: 76 additions & 1 deletion flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,21 @@ static bool mustCastFuncOpToCopeWithImplicitInterfaceMismatch(
return false;
}

static mlir::Value readDim3Value(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value dim3Addr, llvm::StringRef comp) {
mlir::Type i32Ty = builder.getI32Type();
mlir::Type refI32Ty = fir::ReferenceType::get(i32Ty);
llvm::SmallVector<mlir::Value> lenParams;

mlir::Value designate = builder.create<hlfir::DesignateOp>(
loc, refI32Ty, dim3Addr, /*component=*/comp,
/*componentShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
mlir::Value{}, lenParams);

return hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{designate});
}

std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
Expand Down Expand Up @@ -394,7 +409,67 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(

mlir::Value callResult;
unsigned callNumResults;
if (caller.requireDispatchCall()) {

if (!caller.getCallDescription().chevrons().empty()) {
// A call to a CUDA kernel with the chevron syntax.

mlir::Type i32Ty = builder.getI32Type();
mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);

mlir::Value grid_x, grid_y;
if (caller.getCallDescription().chevrons()[0].GetType()->category() ==
Fortran::common::TypeCategory::Integer) {
// If grid is an integer, it is converted to dim3(grid,1,1). Since z is
// not used for the number of thread blocks, it is omitted in the op.
grid_x = builder.createConvert(
loc, i32Ty,
fir::getBase(converter.genExprValue(
caller.getCallDescription().chevrons()[0], stmtCtx)));
grid_y = one;
} else {
auto dim3Addr = converter.genExprAddr(
caller.getCallDescription().chevrons()[0], stmtCtx);
grid_x = readDim3Value(builder, loc, fir::getBase(dim3Addr), "x");
grid_y = readDim3Value(builder, loc, fir::getBase(dim3Addr), "y");
}

mlir::Value block_x, block_y, block_z;
if (caller.getCallDescription().chevrons()[1].GetType()->category() ==
Fortran::common::TypeCategory::Integer) {
// If block is an integer, it is converted to dim3(block,1,1).
block_x = builder.createConvert(
loc, i32Ty,
fir::getBase(converter.genExprValue(
caller.getCallDescription().chevrons()[1], stmtCtx)));
block_y = one;
block_z = one;
} else {
auto dim3Addr = converter.genExprAddr(
caller.getCallDescription().chevrons()[1], stmtCtx);
block_x = readDim3Value(builder, loc, fir::getBase(dim3Addr), "x");
block_y = readDim3Value(builder, loc, fir::getBase(dim3Addr), "y");
block_z = readDim3Value(builder, loc, fir::getBase(dim3Addr), "z");
}

mlir::Value bytes; // bytes is optional.
if (caller.getCallDescription().chevrons().size() > 2)
bytes = builder.createConvert(
loc, i32Ty,
fir::getBase(converter.genExprValue(
caller.getCallDescription().chevrons()[2], stmtCtx)));

mlir::Value stream; // stream is optional.
if (caller.getCallDescription().chevrons().size() > 3)
stream = builder.createConvert(
loc, i32Ty,
fir::getBase(converter.genExprValue(
caller.getCallDescription().chevrons()[3], stmtCtx)));

builder.create<fir::CUDAKernelLaunch>(
loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, block_x,
block_y, block_z, bytes, stream, operands);
callNumResults = 0;
} else if (caller.requireDispatchCall()) {
// Procedure call requiring a dynamic dispatch. Call is created with
// fir.dispatch.

Expand Down
50 changes: 50 additions & 0 deletions flang/test/Lower/CUDA/cuda-kernel-calls.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s

! Test lowering of CUDA procedure calls.

module test_call
use, intrinsic :: __fortran_builtins, only: __builtin_dim3
contains
attributes(global) subroutine dev_kernel0()
end

attributes(global) subroutine dev_kernel1(a)
real :: a
end

subroutine host()
real, device :: a
! CHECK-LABEL: func.func @_QMtest_callPhost()
! CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}} {cuda_attr = #fir.cuda<device>, uniq_name = "_QMtest_callFhostEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)

call dev_kernel0<<<10, 20>>>()
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}>>>()

call dev_kernel0<<< __builtin_dim3(1,1), __builtin_dim3(32,1,1) >>>
! CHECK: %[[ADDR_DIM3_GRID:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
! CHECK: %[[DIM3_GRID:.*]]:2 = hlfir.declare %[[ADDR_DIM3_GRID]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.0"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>)
! CHECK: %[[GRID_X:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"x"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
! CHECK: %[[GRID_X_LOAD:.*]] = fir.load %[[GRID_X]] : !fir.ref<i32>
! CHECK: %[[GRID_Y:.*]] = hlfir.designate %[[DIM3_GRID]]#1{"y"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
! CHECK: %[[GRID_Y_LOAD:.*]] = fir.load %[[GRID_Y]] : !fir.ref<i32>
! CHECK: %[[ADDR_DIM3_BLOCK:.*]] = fir.address_of(@_QQro._QM__fortran_builtinsT__builtin_dim3.{{.*}}) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
! CHECK: %[[DIM3_BLOCK:.*]]:2 = hlfir.declare %[[ADDR_DIM3_BLOCK]] {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QQro._QM__fortran_builtinsT__builtin_dim3.1"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>)
! CHECK: %[[BLOCK_X:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"x"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
! CHECK: %[[BLOCK_X_LOAD:.*]] = fir.load %[[BLOCK_X]] : !fir.ref<i32>
! CHECK: %[[BLOCK_Y:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"y"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
! CHECK: %[[BLOCK_Y_LOAD:.*]] = fir.load %[[BLOCK_Y]] : !fir.ref<i32>
! CHECK: %[[BLOCK_Z:.*]] = hlfir.designate %[[DIM3_BLOCK]]#1{"z"} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>) -> !fir.ref<i32>
! CHECK: %[[BLOCK_Z_LOAD:.*]] = fir.load %[[BLOCK_Z]] : !fir.ref<i32>
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%[[GRID_X_LOAD]], %[[GRID_Y_LOAD]], %[[BLOCK_X_LOAD]], %[[BLOCK_Y_LOAD]], %[[BLOCK_Z_LOAD]]>>>()

call dev_kernel0<<<10, 20, 2>>>()
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>()

call dev_kernel0<<<10, 20, 2, 0>>>()
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()

call dev_kernel1<<<1, 32>>>(a)
! CHECK: fir.cuda_kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%1#1 : !fir.ref<f32>)
end

end