Skip to content

Commit 095ce65

Browse files
committed
[mlir][math] Simplify pow(x, 0.75) into sqrt(sqrt(x)) * sqrt(x).
Trivial simplification for CPU2017/503.bwaves resulting in 3.89% speed-up on icelake. Differential Revision: https://reviews.llvm.org/D137351
1 parent 234e08e commit 095ce65

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,15 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
109109
return success();
110110
}
111111

112+
// Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
113+
if (isExponentValue(0.75)) {
114+
Value pow_half = rewriter.create<math::SqrtOp>(op.getLoc(), x);
115+
Value pow_quarter = rewriter.create<math::SqrtOp>(op.getLoc(), pow_half);
116+
rewriter.replaceOpWithNewOp<arith::MulFOp>(
117+
op, ValueRange{pow_half, pow_quarter});
118+
return success();
119+
}
120+
112121
return failure();
113122
}
114123

mlir/test/Dialect/Math/algebraic-simplification.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
7474
return %0, %1 : f32, vector<4xf32>
7575
}
7676

77+
// CHECK-LABEL: @pow_0_75
78+
func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
79+
// CHECK: %[[SQRT1S:.*]] = math.sqrt %arg0
80+
// CHECK: %[[SQRT2S:.*]] = math.sqrt %[[SQRT1S]]
81+
// CHECK: %[[SCALAR:.*]] = arith.mulf %[[SQRT1S]], %[[SQRT2S]]
82+
// CHECK: %[[SQRT1V:.*]] = math.sqrt %arg1
83+
// CHECK: %[[SQRT2V:.*]] = math.sqrt %[[SQRT1V]]
84+
// CHECK: %[[VECTOR:.*]] = arith.mulf %[[SQRT1V]], %[[SQRT2V]]
85+
// CHECK: return %[[SCALAR]], %[[VECTOR]]
86+
%c = arith.constant 0.75 : f32
87+
%v = arith.constant dense <0.75> : vector<4xf32>
88+
%0 = math.powf %arg0, %c : f32
89+
%1 = math.powf %arg1, %v : vector<4xf32>
90+
return %0, %1 : f32, vector<4xf32>
91+
}
92+
7793
// CHECK-LABEL: @ipowi_zero_exp(
7894
// CHECK-SAME: %[[ARG0:.+]]: i32
7995
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>

0 commit comments

Comments
 (0)