Skip to content

Commit cd92c6a

Browse files
authored
[flang][cuda] Run target rewrite in gpu.module (#118592)
Apply signature conversion for `func.func` in the gpu.module. More work will need to be done for gpu.func op and implement the NVVM ABI for conversion in the gpu module.
1 parent d057b53 commit cd92c6a

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
2828
#include "flang/Optimizer/Support/DataLayout.h"
2929
#include "mlir/Dialect/DLTI/DLTI.h"
30+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3031
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3132
#include "mlir/Transforms/DialectConversion.h"
3233
#include "llvm/ADT/STLExtras.h"
@@ -720,6 +721,11 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
720721

721722
convertSignature(fn);
722723
}
724+
725+
for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
726+
for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
727+
convertSignature(fn);
728+
723729
return mlir::success();
724730
}
725731

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: fir-opt --target-rewrite %s | FileCheck %s
2+
3+
gpu.module @testmod {
4+
gpu.func @_QPvcpowdk(%arg0: !fir.ref<complex<f64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
5+
%0 = fir.alloca i64
6+
%1 = fir.load %0 : !fir.ref<i64>
7+
%2 = fir.load %arg0 : !fir.ref<complex<f64>>
8+
%3 = fir.call @_FortranAzpowk(%2, %1) fastmath<contract> : (complex<f64>, i64) -> complex<f64>
9+
gpu.return
10+
}
11+
func.func private @_FortranAzpowk(complex<f64>, i64) -> complex<f64> attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}
12+
}
13+
14+
// CHECK-LABEL: gpu.func @_QPvcpowdk
15+
// CHECK: %{{.*}} = fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, i64) -> tuple<f64, f64>
16+
// CHECK: func.func private @_FortranAzpowk(f64, f64, i64) -> tuple<f64, f64> attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}

0 commit comments

Comments
 (0)