Skip to content

Commit 164c7af

Browse files
committed
[MLIR][Math] Add constant folder for powf
Constant fold powf, given two constant operands and a compatible type Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D121845
1 parent d65cc85 commit 164c7af

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ def Math_PowFOp : Math_FloatBinaryOp<"powf"> {
684684
%x = math.powf %y, %z : tensor<4x?xbf16>
685685
```
686686
}];
687+
let hasFolder = 1;
687688
}
688689

689690
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Math/IR/MathOps.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,40 @@ OpFoldResult math::Log2Op::fold(ArrayRef<Attribute> operands) {
6363
return FloatAttr::get(getType(), log2(apf.convertToDouble()));
6464

6565
if (ft.getWidth() == 32)
66-
return FloatAttr::get(getType(), log2f(apf.convertToDouble()));
66+
return FloatAttr::get(getType(), log2f(apf.convertToFloat()));
67+
68+
return {};
69+
}
70+
71+
//===----------------------------------------------------------------------===//
72+
// PowFOp folder
73+
//===----------------------------------------------------------------------===//
74+
75+
OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
76+
auto ft = getType().dyn_cast<FloatType>();
77+
if (!ft)
78+
return {};
79+
80+
APFloat vals[2]{APFloat(ft.getFloatSemantics()),
81+
APFloat(ft.getFloatSemantics())};
82+
for (int i = 0; i < 2; ++i) {
83+
if (!operands[i])
84+
return {};
85+
86+
auto attr = operands[i].dyn_cast<FloatAttr>();
87+
if (!attr)
88+
return {};
89+
90+
vals[i] = attr.getValue();
91+
}
92+
93+
if (ft.getWidth() == 64)
94+
return FloatAttr::get(
95+
getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble()));
96+
97+
if (ft.getWidth() == 32)
98+
return FloatAttr::get(
99+
getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat()));
67100

68101
return {};
69102
}

mlir/test/Dialect/Math/canonicalize.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,12 @@ func @log2_nofold2_64() -> f64 {
7373
%r = math.log2 %c : f64
7474
return %r : f64
7575
}
76+
77+
// CHECK-LABEL: @powf_fold
78+
// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
79+
// CHECK: return %[[cst]]
80+
func @powf_fold() -> f32 {
81+
%c = arith.constant 2.0 : f32
82+
%r = math.powf %c, %c : f32
83+
return %r : f32
84+
}

0 commit comments

Comments
 (0)