Skip to content

Commit 7288f1b

Browse files
authored
[flang][cuda] Use nvvm operation for match any (#134283)
The string used for the intrinsic name was incorrect: "llvm.nvvm.match.any.sync.i32p" has an extra `p` at the end. Use the NVVM operation instead of spelling out the intrinsic name, so we don't duplicate it.
1 parent b393ca6 commit 7288f1b

File tree

3 files changed

+16
-23
lines changed

3 files changed

+16
-23
lines changed

flang/include/flang/Optimizer/Support/InitFIR.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/Dialect/Affine/Passes.h"
2323
#include "mlir/Dialect/Complex/IR/Complex.h"
2424
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
25+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
2526
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
2627
#include "mlir/InitAllDialects.h"
2728
#include "mlir/Pass/Pass.h"
@@ -37,7 +38,8 @@ namespace fir::support {
3738
mlir::scf::SCFDialect, mlir::arith::ArithDialect, \
3839
mlir::cf::ControlFlowDialect, mlir::func::FuncDialect, \
3940
mlir::vector::VectorDialect, mlir::math::MathDialect, \
40-
mlir::complex::ComplexDialect, mlir::DLTIDialect, cuf::CUFDialect
41+
mlir::complex::ComplexDialect, mlir::DLTIDialect, cuf::CUFDialect, \
42+
mlir::NVVM::NVVMDialect
4143

4244
#define FLANG_CODEGEN_DIALECT_LIST FIRCodeGenDialect, mlir::LLVM::LLVMDialect
4345

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "mlir/Dialect/Complex/IR/Complex.h"
4949
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5050
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
51+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
5152
#include "mlir/Dialect/Math/IR/Math.h"
5253
#include "mlir/Dialect/Vector/IR/VectorOps.h"
5354
#include "llvm/Support/CommandLine.h"
@@ -6552,23 +6553,15 @@ IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
65526553
assert(args.size() == 2);
65536554
bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
65546555

6555-
llvm::StringRef funcName =
6556-
is32 ? "llvm.nvvm.match.any.sync.i32p" : "llvm.nvvm.match.any.sync.i64p";
6557-
mlir::MLIRContext *context = builder.getContext();
6558-
mlir::Type i32Ty = builder.getI32Type();
6559-
mlir::Type i64Ty = builder.getI64Type();
6560-
mlir::Type valTy = is32 ? i32Ty : i64Ty;
6556+
mlir::Value arg1 = args[1];
6557+
if (arg1.getType().isF32() || arg1.getType().isF64())
6558+
arg1 = builder.create<fir::ConvertOp>(
6559+
loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);
65616560

6562-
mlir::FunctionType ftype =
6563-
mlir::FunctionType::get(context, {i32Ty, valTy}, {i32Ty});
6564-
auto funcOp = builder.createFunction(loc, funcName, ftype);
6565-
llvm::SmallVector<mlir::Value> filteredArgs;
6566-
filteredArgs.push_back(args[0]);
6567-
if (args[1].getType().isF32() || args[1].getType().isF64())
6568-
filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
6569-
else
6570-
filteredArgs.push_back(args[1]);
6571-
return builder.create<fir::CallOp>(loc, funcOp, filteredArgs).getResult(0);
6561+
return builder
6562+
.create<mlir::NVVM::MatchSyncOp>(loc, resultType, args[0], arg1,
6563+
mlir::NVVM::MatchSyncKind::any)
6564+
.getResult();
65726565
}
65736566

65746567
// MATMUL

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,10 @@ attributes(device) subroutine testMatchAny()
143143
end subroutine
144144

145145
! CHECK-LABEL: func.func @_QPtestmatchany()
146-
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
147-
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
148-
! CHECK: fir.convert %{{.*}} : (f32) -> i32
149-
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
150-
! CHECK: fir.convert %{{.*}} : (f64) -> i64
151-
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
146+
! CHECK: %{{.*}} = nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
147+
! CHECK: %{{.*}} = nvvm.match.sync any %{{.*}}, %{{.*}} : i64 -> i32
148+
! CHECK: %{{.*}} = nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
149+
! CHECK: %{{.*}} = nvvm.match.sync any %{{.*}}, %{{.*}} : i64 -> i32
152150

153151
attributes(device) subroutine testAtomic(aa, n)
154152
integer :: aa(*)

0 commit comments

Comments
 (0)