Skip to content

[flang][cuda] Avoid intrinsics simplification in device context #117026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Transforms/CUFCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ namespace cuf {
mlir::gpu::GPUModuleOp getOrCreateGPUModule(mlir::ModuleOp mod,
mlir::SymbolTable &symTab);

bool isInCUDADeviceContext(mlir::Operation *op);

} // namespace cuf

#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
17 changes: 17 additions & 0 deletions flang/lib/Optimizer/Transforms/CUFCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

/// Retrieve or create the CUDA Fortran GPU module in the give in \p mod.
Expand All @@ -26,3 +28,18 @@ mlir::gpu::GPUModuleOp cuf::getOrCreateGPUModule(mlir::ModuleOp mod,
symTab.insert(gpuMod, insertPt);
return gpuMod;
}

bool cuf::isInCUDADeviceContext(mlir::Operation *op) {
if (!op)
return false;
if (op->getParentOfType<cuf::KernelOp>() ||
op->getParentOfType<mlir::gpu::GPUFuncOp>())
return true;
if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
if (auto cudaProcAttr = funcOp->getAttrOfType<cuf::ProcAttributeAttr>(
cuf::getProcAttrName())) {
return cudaProcAttr.getValue() != cuf::ProcAttribute::Host;
}
}
return false;
}
3 changes: 3 additions & 0 deletions flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Optimizer/Transforms/Utils.h"
#include "flang/Runtime/entry-names.h"
Expand Down Expand Up @@ -1276,6 +1277,8 @@ void SimplifyIntrinsicsPass::runOnOperation() {
fir::KindMapping kindMap = fir::getKindMapping(module);
module.walk([&](mlir::Operation *op) {
if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
if (cuf::isInCUDADeviceContext(op))
return;
if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
mlir::StringRef funcName = callee.getLeafReference().getValue();
// Replace call to runtime function for SUM when it has single
Expand Down
49 changes: 49 additions & 0 deletions flang/test/Fir/CUDA/cuda-device-context.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: fir-opt --simplify-intrinsics %s | FileCheck %s

func.func @_QPsum_in_device(%arg0: !fir.ref<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %arg1: i32 {fir.bindc_name = "n"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
%c5_i32 = arith.constant 5 : i32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c-1 = arith.constant -1 : index
%0 = fir.dummy_scope : !fir.dscope
%1 = fir.shape %c-1 : (index) -> !fir.shape<1>
%2 = fir.declare %arg0(%1) dummy_scope %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsum_in_deviceEa"} : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>, !fir.dscope) -> !fir.ref<!fir.array<?xi32>>
%3 = fir.embox %2(%1) : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xi32>>
%4 = fir.alloca i32
fir.store %arg1 to %4 : !fir.ref<i32>
%5 = fir.declare %4 dummy_scope %0 {fortran_attrs = #fir.var_attrs<value>, uniq_name = "_QFsum_in_deviceEn"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
%12 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsum_in_deviceEi"}
%13 = fir.declare %12 {uniq_name = "_QFsum_in_deviceEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
%14 = fir.address_of(@_QM__fortran_builtinsE__builtin_threadidx) : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_dim3{x:i32,y:i32,z:i32}>>
%18 = fir.load %5 : !fir.ref<i32>
%19 = fir.convert %18 : (i32) -> index
%20 = arith.cmpi sgt, %19, %c0 : index
%21 = arith.select %20, %19, %c0 : index
%22 = fir.alloca !fir.array<?xi32>, %21 {bindc_name = "auto", uniq_name = "_QFsum_in_deviceEauto"}
%23 = fir.shape %21 : (index) -> !fir.shape<1>
%24 = fir.declare %22(%23) {uniq_name = "_QFsum_in_deviceEauto"} : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<?xi32>>
%25 = fir.embox %24(%23) : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xi32>>
%26 = fir.undefined index
%27 = fir.slice %c1, %19, %c1 : (index, index, index) -> !fir.slice<1>
%28 = fir.embox %24(%23) [%27] : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
%29 = fir.absent !fir.box<i1>
%30 = fir.address_of(@_QQclX91d13f6e74caa2f03965d7a7c6a8fdd5) : !fir.ref<!fir.char<1,50>>
%31 = fir.convert %28 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
%32 = fir.convert %30 : (!fir.ref<!fir.char<1,50>>) -> !fir.ref<i8>
%33 = fir.convert %c0 : (index) -> i32
%34 = fir.convert %29 : (!fir.box<i1>) -> !fir.box<none>
%35 = fir.call @_FortranASumInteger4(%31, %32, %c5_i32, %33, %34) fastmath<contract> : (!fir.box<none>, !fir.ref<i8>, i32, i32, !fir.box<none>) -> i32
%36 = fir.load %13 : !fir.ref<i32>
%37 = fir.convert %36 : (i32) -> i64
%38 = fir.array_coor %2(%1) %37 : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>, i64) -> !fir.ref<i32>
fir.store %35 to %38 : !fir.ref<i32>
return
}

// Check that intrinsic simplification is disabled in CUDA Fortran context. The simplified intrinsic is
// created in the module op but the device func will be migrated into a gpu module op resulting in a
// missing symbol error.
// The simplified intrinsic could also be migrated to the gpu module but the choice has not be made
// at this point.
// CHECK-LABEL: func.func @_QPsum_in_device
// CHECK-NOT: fir.call @_FortranASumInteger4x1_contract_simplified
Loading