Skip to content

Commit dfac79a

Browse files
ftynseJaddyen
authored andcommitted
[mlir] fix MemRefToLLVM lowering of atomic operations (llvm#139045)
We have been confusingly, and arguably incorrectly, lowering `m**imumf` atomic RMW operations in the MemRef dialect to `fm**` atomic RMW operations in the LLVM dialect, which have different NaN-propagation semantics: `m**imumf` propagates NaNs from either operand whereas `fm**`, which lowers to the `fm**num` intrinsic returns the non-NaN operand. This also contradicts the lowering of `arith.m**imumf` and `arith.m**numf` operations. Change the lowering to match the terminology in arith. Add tests for these lowerings. Keep a debug message in case of surprising behavior downstream (the code may be producing more NaNs now).
1 parent cd03653 commit dfac79a

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
#include "llvm/Support/MathExtras.h"
2929
#include <optional>
3030

31+
#define DEBUG_TYPE "memref-to-llvm"
32+
#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
33+
3134
namespace mlir {
3235
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
3336
#include "mlir/Conversion/Passes.h.inc"
@@ -1782,12 +1785,22 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
17821785
case arith::AtomicRMWKind::assign:
17831786
return LLVM::AtomicBinOp::xchg;
17841787
case arith::AtomicRMWKind::maximumf:
1788+
// TODO: remove this by end of 2025.
1789+
LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
1790+
"from fmax to fmaximum, expect more NaNs");
1791+
return LLVM::AtomicBinOp::fmaximum;
1792+
case arith::AtomicRMWKind::maxnumf:
17851793
return LLVM::AtomicBinOp::fmax;
17861794
case arith::AtomicRMWKind::maxs:
17871795
return LLVM::AtomicBinOp::max;
17881796
case arith::AtomicRMWKind::maxu:
17891797
return LLVM::AtomicBinOp::umax;
17901798
case arith::AtomicRMWKind::minimumf:
1799+
// TODO: remove this by end of 2025.
1800+
LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
1801+
"from fmin to fminimum, expect more NaNs");
1802+
return LLVM::AtomicBinOp::fminimum;
1803+
case arith::AtomicRMWKind::minnumf:
17911804
return LLVM::AtomicBinOp::fmin;
17921805
case arith::AtomicRMWKind::mins:
17931806
return LLVM::AtomicBinOp::min;

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,11 +452,19 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
452452
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
453453
memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
454454
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
455+
memref.atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
456+
// CHECK: llvm.atomicrmw fmaximum %{{.*}}, %{{.*}} acq_rel
457+
memref.atomic_rmw maxnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
458+
// CHECK: llvm.atomicrmw fmax %{{.*}}, %{{.*}} acq_rel
459+
memref.atomic_rmw minimumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
460+
// CHECK: llvm.atomicrmw fminimum %{{.*}}, %{{.*}} acq_rel
461+
memref.atomic_rmw minnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
462+
// CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} acq_rel
455463
memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
456464
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
457465
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
458466
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
459-
// CHECK-INTERFACE-COUNT-9: llvm.atomicrmw
467+
// CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
460468
return
461469
}
462470

0 commit comments

Comments
 (0)