Skip to content

Commit 52ccd6c

Browse files
ubfxkuhar
authored andcommitted
[mlir][arith] Match folding of arith.remf to llvm.frem semantics (llvm#96537)
There are multiple ways to define a remainder operation. Depending on the definition, the result could be either always positive or have the sign of the dividend. The pattern lowering `arith.remf` to LLVM assumes that the semantics match `llvm.frem`, which seems to be reasonable. The folder, however, is implemented via `APFloat::remainder()` which has different semantics. This patch matches the folding behaviour to lowering behavior by using `APFloat::mod()`, which matches the behavior of `llvm.frem` and libm's `fmod()`. It also updates the documentation of `arith.remf` to explain this behavior: The sign of the result of the remainder operation always matches the sign of the dividend (LHS operand). frem documentation: https://llvm.org/docs/LangRef.html#frem-instruction Fix llvm#94431 --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent ffbb0d0 commit 52ccd6c

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,10 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
11331133

11341134
def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
11351135
let summary = "floating point division remainder operation";
1136+
let description = [{
1137+
Returns the floating point division remainder.
1138+
The remainder has the same sign as the dividend (lhs operand).
1139+
}];
11361140
let hasFolder = 1;
11371141
}
11381142

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,10 @@ OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
12041204
return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
12051205
[](const APFloat &a, const APFloat &b) {
12061206
APFloat result(a);
1207-
(void)result.remainder(b);
1207+
// APFloat::mod() offers the remainder
1208+
// behavior we want, i.e. the result has
1209+
// the sign of LHS operand.
1210+
(void)result.mod(b);
12081211
return result;
12091212
});
12101213
}

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2467,7 +2467,7 @@ func.func @test_remsi_1(%arg : vector<4xi32>) -> (vector<4xi32>) {
24672467
// -----
24682468

24692469
// CHECK-LABEL: @test_remf(
2470-
// CHECK: %[[res:.+]] = arith.constant -1.000000e+00 : f32
2470+
// CHECK: %[[res:.+]] = arith.constant 1.000000e+00 : f32
24712471
// CHECK: return %[[res]]
24722472
func.func @test_remf() -> (f32) {
24732473
%v1 = arith.constant 3.0 : f32
@@ -2476,11 +2476,24 @@ func.func @test_remf() -> (f32) {
24762476
return %0 : f32
24772477
}
24782478

2479+
// CHECK-LABEL: @test_remf2(
2480+
// CHECK: %[[respos:.+]] = arith.constant 1.000000e+00 : f32
2481+
// CHECK: %[[resneg:.+]] = arith.constant -1.000000e+00 : f32
2482+
// CHECK: return %[[respos]], %[[resneg]]
2483+
func.func @test_remf2() -> (f32, f32) {
2484+
%v1 = arith.constant 3.0 : f32
2485+
%v2 = arith.constant -2.0 : f32
2486+
%v3 = arith.constant -3.0 : f32
2487+
%0 = arith.remf %v1, %v2 : f32
2488+
%1 = arith.remf %v3, %v2 : f32
2489+
return %0, %1 : f32, f32
2490+
}
2491+
24792492
// CHECK-LABEL: @test_remf_vec(
24802493
// CHECK: %[[res:.+]] = arith.constant dense<[1.000000e+00, 0.000000e+00, -1.000000e+00, 0.000000e+00]> : vector<4xf32>
24812494
// CHECK: return %[[res]]
24822495
func.func @test_remf_vec() -> (vector<4xf32>) {
2483-
%v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
2496+
%v1 = arith.constant dense<[1.0, 2.0, -3.0, 4.0]> : vector<4xf32>
24842497
%v2 = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf32>
24852498
%0 = arith.remf %v1, %v2 : vector<4xf32>
24862499
return %0 : vector<4xf32>

0 commit comments

Comments
 (0)