
Commit 9ee4fdf

[flang][cuda] Introduce stream cast op (#136050)
Cast a stream object reference to a GPU async token. This makes it possible to connect the stream representation of CUDA Fortran to the async mechanism of the GPU dialect. Later in lowering, this op will become a no-op.
1 parent 728f6de commit 9ee4fdf
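
As a rough illustration of the intent, here is a sketch assembled from the test added in this commit (the SSA value names are illustrative; the kernel symbol comes from that test). The op turns a CUDA Fortran stream reference into a token that can feed the async dependency list of a kernel launch:

  %stream = fir.alloca i64
  %c1 = arith.constant 1 : index
  // Adapt the CUDA Fortran stream reference to a GPU dialect async token.
  %token = cuf.stream_cast %stream : !fir.ref<i64>
  // The token feeds the async dependency list of the launch.
  gpu.launch_func [%token] @cuda_device_mod::@_QMmod1Psub1 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args()

The complete module, including the gpu.module that defines the kernel, is in the new test flang/test/Fir/CUDA/cuda-stream.mlir below.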

5 files changed: +55 -3 lines

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 22 additions & 0 deletions
@@ -18,6 +18,7 @@ include "flang/Optimizer/Dialect/CUF/CUFDialect.td"
 include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
 include "flang/Optimizer/Dialect/FIRTypes.td"
 include "flang/Optimizer/Dialect/FIRAttr.td"
+include "mlir/Dialect/GPU/IR/GPUBase.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/BuiltinAttributes.td"
@@ -370,4 +371,25 @@ def cuf_SharedMemoryOp
       CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>];
 }
 
+def cuf_StreamCastOp : cuf_Op<"stream_cast", [NoMemoryEffect]> {
+  let summary = "Adapt a stream value to a GPU async token";
+
+  let description = [{
+    Cast a stream object reference as a GPU async token. This is useful to be
+    able to connect the stream representation of CUDA Fortran and the async
+    mechanism of the GPU dialect.
+    Later in the lowering this will become a no op.
+  }];
+
+  let arguments = (ins fir_ReferenceType:$stream);
+
+  let results = (outs GPU_AsyncToken:$token);
+
+  let assemblyFormat = [{
+    $stream attr-dict `:` type($stream)
+  }];
+
+  let hasVerifier = 1;
+}
+
 #endif // FORTRAN_DIALECT_CUF_CUF_OPS
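
Note that, per the assemblyFormat above, only the type of the stream operand is printed; the result is always a GPU async token and its type stays implicit. A hedged sketch of the resulting custom syntax (%s and %t are placeholder values):

  // %s : !fir.ref<i64>; the result %t is a !gpu.async.token, not spelled in the syntax.
  %t = cuf.stream_cast %s : !fir.ref<i64>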

flang/include/flang/Optimizer/Support/InitFIR.h

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ namespace fir::support {
       mlir::cf::ControlFlowDialect, mlir::func::FuncDialect, \
       mlir::vector::VectorDialect, mlir::math::MathDialect, \
       mlir::complex::ComplexDialect, mlir::DLTIDialect, cuf::CUFDialect, \
-      mlir::NVVM::NVVMDialect
+      mlir::NVVM::NVVMDialect, mlir::gpu::GPUDialect
 
 #define FLANG_CODEGEN_DIALECT_LIST FIRCodeGenDialect, mlir::LLVM::LLVMDialect

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 11 additions & 0 deletions
@@ -319,6 +319,17 @@ void cuf::SharedMemoryOp::build(
   result.addAttributes(attributes);
 }
 
+//===----------------------------------------------------------------------===//
+// StreamCastOp
+//===----------------------------------------------------------------------===//
+
+llvm::LogicalResult cuf::StreamCastOp::verify() {
+  auto refTy = mlir::dyn_cast<fir::ReferenceType>(getStream().getType());
+  if (!refTy.getEleTy().isInteger(64))
+    return emitOpError("stream is expected to be a i64 reference");
+  return mlir::success();
+}
+
 // Tablegen operators
 
 #define GET_OP_CLASSES
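
The verifier only accepts a stream operand that is a reference to i64. A small sketch of what this means at the IR level (value names are illustrative):

  // Accepted: the operand references an i64 stream handle.
  %tok = cuf.stream_cast %stream_i64 : !fir.ref<i64>
  // Rejected with "stream is expected to be a i64 reference":
  // %bad = cuf.stream_cast %stream_i32 : !fir.ref<i32>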

flang/test/Fir/CUDA/cuda-stream.mlir

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+// RUN: fir-opt --split-input-file %s | FileCheck %s
+
+module attributes {gpu.container_module} {
+  gpu.module @cuda_device_mod {
+    gpu.func @_QMmod1Psub1() kernel {
+      gpu.return
+    }
+  }
+  func.func @_QMmod1Phost_sub() {
+    %0 = fir.alloca i64
+    %1 = arith.constant 1 : index
+    %asyncTok = cuf.stream_cast %0 : !fir.ref<i64>
+    gpu.launch_func [%asyncTok] @cuda_device_mod::@_QMmod1Psub1 blocks in (%1, %1, %1) threads in (%1, %1, %1) args() {cuf.proc_attr = #cuf.cuda_proc<grid_global>}
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @_QMmod1Phost_sub()
+// CHECK: %[[STREAM:.*]] = fir.alloca i64
+// CHECK: %[[TOKEN:.*]] = cuf.stream_cast %[[STREAM]] : <i64>
+// CHECK: gpu.launch_func [%[[TOKEN]]] @cuda_device_mod::@_QMmod1Psub1

flang/tools/fir-opt/fir-opt.cpp

Lines changed: 0 additions & 2 deletions
@@ -44,8 +44,6 @@ int main(int argc, char **argv) {
 #endif
   DialectRegistry registry;
   fir::support::registerDialects(registry);
-  registry.insert<mlir::gpu::GPUDialect>();
-  registry.insert<mlir::NVVM::NVVMDialect>();
   fir::support::addFIRExtensions(registry);
   return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n",
                             registry));
