Skip to content

Commit 91f9f0f

Browse files
authored
[flang][cuda] Update cuf.kernel_launch stream and conversion (#136179)
Update `cuf.kernel_launch` to take the stream as a reference. Update the conversion to insert the `cuf.stream_cast` op so the stream can be set as a dependency.
1 parent 9f9c1f9 commit 91f9f0f

File tree

6 files changed

+37
-15
lines changed

6 files changed

+37
-15
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
200200

201201
let arguments = (ins SymbolRefAttr:$callee, I32:$grid_x, I32:$grid_y,
202202
I32:$grid_z, I32:$block_x, I32:$block_y, I32:$block_z,
203-
Optional<I32>:$bytes, Optional<AnyIntegerType>:$stream,
203+
Optional<I32>:$bytes, Optional<fir_ReferenceType>:$stream,
204204
Variadic<AnyType>:$args, OptionalAttr<DictArrayAttr>:$arg_attrs,
205205
OptionalAttr<DictArrayAttr>:$res_attrs);
206206

@@ -237,6 +237,8 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
237237
*this, getNbNoArgOperand(), getArgs().size() - 1);
238238
}
239239
}];
240+
241+
let hasVerifier = 1;
240242
}
241243

242244
def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,

flang/lib/Lower/ConvertCall.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ Fortran::lower::genCallOpAndResult(
589589

590590
mlir::Value stream; // stream is optional.
591591
if (caller.getCallDescription().chevrons().size() > 3)
592-
stream = fir::getBase(converter.genExprValue(
592+
stream = fir::getBase(converter.genExprAddr(
593593
caller.getCallDescription().chevrons()[3], stmtCtx));
594594

595595
builder.create<cuf::KernelLaunchOp>(

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,24 @@ llvm::LogicalResult cuf::DeallocateOp::verify() {
139139
return mlir::success();
140140
}
141141

142+
//===----------------------------------------------------------------------===//
143+
// KernelLaunchOp
144+
//===----------------------------------------------------------------------===//
145+
146+
template <typename OpTy>
147+
static llvm::LogicalResult checkStreamType(OpTy op) {
148+
if (!op.getStream())
149+
return mlir::success();
150+
auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getStream().getType());
151+
if (!refTy.getEleTy().isInteger(64))
152+
return op.emitOpError("stream is expected to be a i64 reference");
153+
return mlir::success();
154+
}
155+
156+
llvm::LogicalResult cuf::KernelLaunchOp::verify() {
157+
return checkStreamType(*this);
158+
}
159+
142160
//===----------------------------------------------------------------------===//
143161
// KernelOp
144162
//===----------------------------------------------------------------------===//
@@ -324,10 +342,7 @@ void cuf::SharedMemoryOp::build(
324342
//===----------------------------------------------------------------------===//
325343

326344
llvm::LogicalResult cuf::StreamCastOp::verify() {
327-
auto refTy = mlir::dyn_cast<fir::ReferenceType>(getStream().getType());
328-
if (!refTy.getEleTy().isInteger(64))
329-
return emitOpError("stream is expected to be a i64 reference");
330-
return mlir::success();
345+
return checkStreamType(*this);
331346
}
332347

333348
// Tablegen operators

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -879,8 +879,13 @@ struct CUFLaunchOpConversion
879879
gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
880880
gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ);
881881
}
882-
if (op.getStream())
883-
gpuLaunchOp.getAsyncObjectMutable().assign(op.getStream());
882+
if (op.getStream()) {
883+
mlir::OpBuilder::InsertionGuard guard(rewriter);
884+
rewriter.setInsertionPoint(gpuLaunchOp);
885+
mlir::Value stream =
886+
rewriter.create<cuf::StreamCastOp>(loc, op.getStream());
887+
gpuLaunchOp.getAsyncDependenciesMutable().append(stream);
888+
}
884889
if (procAttr)
885890
gpuLaunchOp->setAttr(cuf::getProcAttrName(), procAttr);
886891
rewriter.replaceOp(op, gpuLaunchOp);
@@ -933,6 +938,7 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
933938
/*forceUnifiedTBAATree=*/false, *dl);
934939
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
935940
mlir::gpu::GPUDialect>();
941+
target.addLegalOp<cuf::StreamCastOp>();
936942
cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
937943
patterns);
938944
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,

flang/test/Fir/CUDA/cuda-launch.fir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,13 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
146146
%1:2 = hlfir.declare %0 {uniq_name = "_QMtest_callFhostEstream"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
147147
%c1_i32 = arith.constant 1 : i32
148148
%c0_i32 = arith.constant 0 : i32
149-
%2 = fir.load %1#0 : !fir.ref<i64>
150-
cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c0_i32, %2 : i64>>>()
149+
cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c0_i32, %1#0 : !fir.ref<i64>>>>()
151150
return
152151
}
153152
}
154153

155154
// CHECK-LABEL: func.func @_QQmain()
156155
// CHECK: %[[STREAM:.*]] = fir.alloca i64 {bindc_name = "stream", uniq_name = "_QMtest_callFhostEstream"}
157156
// CHECK: %[[DECL_STREAM:.*]]:2 = hlfir.declare %[[STREAM]] {uniq_name = "_QMtest_callFhostEstream"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
158-
// CHECK: %[[STREAM_LOADED:.*]] = fir.load %[[DECL_STREAM]]#0 : !fir.ref<i64>
159-
// CHECK: gpu.launch_func <%[[STREAM_LOADED]] : i64> @cuda_device_mod::@_QMdevptrPtest
157+
// CHECK: %[[TOKEN:.*]] = cuf.stream_cast %[[DECL_STREAM]]#0 : <i64>
158+
// CHECK: gpu.launch_func [%[[TOKEN]]] @cuda_device_mod::@_QMdevptrPtest

flang/test/Lower/CUDA/cuda-kernel-calls.cuf

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ contains
4545
call dev_kernel0<<<10, 20, 2>>>()
4646
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>()
4747

48-
call dev_kernel0<<<10, 20, 2, 0>>>()
49-
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
48+
call dev_kernel0<<<10, 20, 2, 0_8>>>()
49+
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %{{.*}} : !fir.ref<i64>>>>()
5050

5151
call dev_kernel1<<<1, 32>>>(a)
5252
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}}) : (!fir.ref<f32>)
@@ -55,7 +55,7 @@ contains
5555
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}})
5656

5757
call dev_kernel1<<<*,32,0,stream>>>(a)
58-
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}, %c0{{.*}}, %{{.*}} : i64>>>(%{{.*}}) : (!fir.ref<f32>)
58+
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}, %c0{{.*}}, %{{.*}} : !fir.ref<i64>>>>(%{{.*}}) : (!fir.ref<f32>)
5959

6060
end
6161

0 commit comments

Comments (0)