Skip to content

Commit 35553d4

Browse files
cotaezhulenev
authored andcommitted
[mlir] Add polynomial approximation for vectorized math::Rsqrt
This patch adds a polynomial approximation that matches the approximation in Eigen. Note that the approximation only applies to vectorized inputs; the scalar rsqrt is left unmodified. The approximation is protected with a flag since it emits an AVX2 intrinsic (generated via the X86Vector). This is the only reasonably clean way that I could find to generate the exact approximation that I wanted (i.e. an identical one to Eigen's). I considered two alternatives: 1. Introduce a Rsqrt intrinsic in LLVM, which doesn't exist yet. I believe this is because there is no definition of Rsqrt that all backends could agree on, since hardware instructions that implement it have widely varying degrees of precision. This is something that the standard could mandate, but Rsqrt is not part of IEEE754, so I don't think this option is feasible. 2. Emit fdiv(1.0, sqrt) with fast math flags to allow reciprocal transformations. Although portable, this doesn't allow us to generate exactly the code we want; it is the LLVM backend, and not MLIR, who controls what code is generated based on the target CPU. Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D112192
1 parent c534835 commit 35553d4

File tree

10 files changed

+177
-3
lines changed

10 files changed

+177
-3
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@ void populateExpandTanhPattern(RewritePatternSet &patterns);
1717

1818
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
1919

20-
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns);
20+
struct MathPolynomialApproximationOptions {
21+
// Enables the use of AVX2 intrinsics in some of the approximations.
22+
bool enableAvx2 = false;
23+
};
24+
25+
void populateMathPolynomialApproximationPatterns(
26+
RewritePatternSet &patterns,
27+
const MathPolynomialApproximationOptions &options = {});
2128

2229
} // namespace mlir
2330

mlir/lib/Dialect/Math/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ add_mlir_dialect_library(MLIRMathTransforms
1313
MLIRPass
1414
MLIRStandard
1515
MLIRTransforms
16+
MLIRX86Vector
1617
)

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Math/IR/Math.h"
1616
#include "mlir/Dialect/Math/Transforms/Passes.h"
1717
#include "mlir/Dialect/Vector/VectorOps.h"
18+
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
1819
#include "mlir/IR/Builders.h"
1920
#include "mlir/IR/ImplicitLocOpBuilder.h"
2021
#include "mlir/Transforms/Bufferize.h"
@@ -778,13 +779,79 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
778779
return success();
779780
}
780781

782+
//----------------------------------------------------------------------------//
783+
// Rsqrt approximation.
784+
//----------------------------------------------------------------------------//
785+
786+
namespace {
787+
struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
788+
using OpRewritePattern::OpRewritePattern;
789+
790+
LogicalResult matchAndRewrite(math::RsqrtOp op,
791+
PatternRewriter &rewriter) const final;
792+
};
793+
} // namespace
794+
795+
LogicalResult
796+
RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
797+
PatternRewriter &rewriter) const {
798+
auto width = vectorWidth(op.operand().getType(), isF32);
799+
// Only support already-vectorized rsqrt's.
800+
if (!width.hasValue() || *width != 8)
801+
return rewriter.notifyMatchFailure(op, "unsupported operand type");
802+
803+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
804+
auto bcast = [&](Value value) -> Value {
805+
return broadcast(builder, value, *width);
806+
};
807+
808+
Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
809+
Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
810+
Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
811+
Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
812+
813+
Value negHalf = builder.create<arith::MulFOp>(op.operand(), cstNegHalf);
814+
815+
// Select only the inverse sqrt of positive normals (denormals are
816+
// flushed to zero).
817+
Value ltMinMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
818+
op.operand(), cstMinNormPos);
819+
Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
820+
op.operand(), cstPosInf);
821+
Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
822+
823+
// Compute an approximate result.
824+
Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand());
825+
826+
// Do a single step of Newton-Raphson iteration to improve the approximation.
827+
// This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
828+
// It is essential to evaluate the inner term like this because forming
829+
// y_n^2 may over- or underflow.
830+
Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
831+
Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
832+
Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
833+
834+
// Select the result of the Newton-Raphson step for positive normal arguments.
835+
// For other arguments, choose the output of the intrinsic. This will
836+
// return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
837+
// x is zero or a positive denormalized float (equivalent to flushing positive
838+
// denormalized inputs to zero).
839+
Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton);
840+
rewriter.replaceOp(op, res);
841+
842+
return success();
843+
}
844+
781845
//----------------------------------------------------------------------------//
782846

783847
void mlir::populateMathPolynomialApproximationPatterns(
784-
RewritePatternSet &patterns) {
848+
RewritePatternSet &patterns,
849+
const MathPolynomialApproximationOptions &options) {
785850
patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
786851
Log1pApproximation, ExpApproximation, ExpM1Approximation,
787852
SinAndCosApproximation<true, math::SinOp>,
788853
SinAndCosApproximation<false, math::CosOp>>(
789854
patterns.getContext());
855+
if (options.enableAvx2)
856+
patterns.add<RsqrtApproximation>(patterns.getContext());
790857
}

mlir/test/Dialect/Math/polynomial-approximation.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
// RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s
2+
// RUN: mlir-opt %s -test-math-polynomial-approximation=enable-avx2 \
3+
// RUN: | FileCheck --check-prefix=AVX2 %s
24

35
// Check that all math functions lowered to approximations built from
46
// standard operations (add, mul, fma, shift, etc...).
@@ -300,3 +302,37 @@ func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
300302
%0 = math.tanh %arg0 : vector<8xf32>
301303
return %0 : vector<8xf32>
302304
}
305+
306+
// We only approximate rsqrt for vectors and when the AVX2 option is enabled.
307+
// CHECK-LABEL: func @rsqrt_scalar
308+
// AVX2-LABEL: func @rsqrt_scalar
309+
// CHECK: math.rsqrt
310+
// AVX2: math.rsqrt
311+
func @rsqrt_scalar(%arg0: f32) -> f32 {
312+
%0 = math.rsqrt %arg0 : f32
313+
return %0 : f32
314+
}
315+
316+
// CHECK-LABEL: func @rsqrt_vector
317+
// CHECK: math.rsqrt
318+
// AVX2-LABEL: func @rsqrt_vector(
319+
// AVX2-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
320+
// AVX2: %[[VAL_1:.*]] = arith.constant dense<0x7F800000> : vector<8xf32>
321+
// AVX2: %[[VAL_2:.*]] = arith.constant dense<1.500000e+00> : vector<8xf32>
322+
// AVX2: %[[VAL_3:.*]] = arith.constant dense<-5.000000e-01> : vector<8xf32>
323+
// AVX2: %[[VAL_4:.*]] = arith.constant dense<1.17549435E-38> : vector<8xf32>
324+
// AVX2: %[[VAL_5:.*]] = arith.mulf %[[VAL_0]], %[[VAL_3]] : vector<8xf32>
325+
// AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32>
326+
// AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32>
327+
// AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1>
328+
// AVX2: %[[VAL_9:.*]] = x86vector.avx.rsqrt %[[VAL_0]] : vector<8xf32>
329+
// AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32>
330+
// AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32>
331+
// AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32>
332+
// AVX2: %[[VAL_13:.*]] = select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32>
333+
// AVX2: return %[[VAL_13]] : vector<8xf32>
334+
// AVX2: }
335+
func @rsqrt_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
336+
%0 = math.rsqrt %arg0 : vector<8xf32>
337+
return %0 : vector<8xf32>
338+
}

mlir/test/lib/Dialect/Math/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ add_mlir_library(MLIRMathTestPasses
1111
MLIRPass
1212
MLIRTransformUtils
1313
MLIRVector
14+
MLIRX86Vector
1415
)

mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Math/IR/Math.h"
1616
#include "mlir/Dialect/Math/Transforms/Passes.h"
1717
#include "mlir/Dialect/Vector/VectorOps.h"
18+
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2021

@@ -23,23 +24,37 @@ using namespace mlir;
2324
namespace {
2425
struct TestMathPolynomialApproximationPass
2526
: public PassWrapper<TestMathPolynomialApproximationPass, FunctionPass> {
27+
TestMathPolynomialApproximationPass() = default;
28+
TestMathPolynomialApproximationPass(
29+
const TestMathPolynomialApproximationPass &pass) {}
30+
2631
void runOnFunction() override;
2732
void getDependentDialects(DialectRegistry &registry) const override {
2833
registry.insert<arith::ArithmeticDialect, math::MathDialect,
2934
vector::VectorDialect>();
35+
if (enableAvx2)
36+
registry.insert<x86vector::X86VectorDialect>();
3037
}
3138
StringRef getArgument() const final {
3239
return "test-math-polynomial-approximation";
3340
}
3441
StringRef getDescription() const final {
3542
return "Test math polynomial approximations";
3643
}
44+
45+
Option<bool> enableAvx2{
46+
*this, "enable-avx2",
47+
llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the "
48+
"X86Vector dialect"),
49+
llvm::cl::init(false)};
3750
};
3851
} // end anonymous namespace
3952

4053
void TestMathPolynomialApproximationPass::runOnFunction() {
4154
RewritePatternSet patterns(&getContext());
42-
populateMathPolynomialApproximationPatterns(patterns);
55+
MathPolynomialApproximationOptions approx_options;
56+
approx_options.enableAvx2 = enableAvx2;
57+
populateMathPolynomialApproximationPatterns(patterns, approx_options);
4358
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
4459
}
4560

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import sys
2+
3+
# X86Vector tests must be enabled via build flag.
4+
if not config.mlir_run_x86vector_tests:
5+
config.unsupported = True
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt %s -test-math-polynomial-approximation="enable-avx2" \
2+
// RUN: -convert-arith-to-llvm \
3+
// RUN: -convert-vector-to-llvm="enable-x86vector" \
4+
// RUN: -convert-math-to-llvm \
5+
// RUN: -convert-std-to-llvm \
6+
// RUN: -reconcile-unrealized-casts \
7+
// RUN: | mlir-cpu-runner \
8+
// RUN: -e main -entry-point-result=void -O0 \
9+
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
10+
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
11+
// RUN: | FileCheck %s
12+
13+
// -------------------------------------------------------------------------- //
14+
// rsqrt.
15+
// -------------------------------------------------------------------------- //
16+
17+
func @rsqrt() {
18+
// Sanity-check that the scalar rsqrt still works OK.
19+
// CHECK: inf
20+
%0 = arith.constant 0.0 : f32
21+
%rsqrt_0 = math.rsqrt %0 : f32
22+
vector.print %rsqrt_0 : f32
23+
// CHECK: 0.707107
24+
%two = arith.constant 2.0: f32
25+
%rsqrt_two = math.rsqrt %two : f32
26+
vector.print %rsqrt_two : f32
27+
28+
// Check that the vectorized approximation is reasonably accurate.
29+
// CHECK: 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107
30+
%vec8 = arith.constant dense<2.0> : vector<8xf32>
31+
%rsqrt_vec8 = math.rsqrt %vec8 : vector<8xf32>
32+
vector.print %rsqrt_vec8 : vector<8xf32>
33+
34+
return
35+
}
36+
37+
func @main() {
38+
call @rsqrt(): () -> ()
39+
return
40+
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7057,6 +7057,7 @@ cc_library(
70577057
":Support",
70587058
":Transforms",
70597059
":VectorOps",
7060+
":X86Vector",
70607061
"//llvm:Support",
70617062
],
70627063
)

utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ cc_library(
406406
"//mlir:Pass",
407407
"//mlir:TransformUtils",
408408
"//mlir:VectorOps",
409+
"//mlir:X86Vector",
409410
],
410411
)
411412

0 commit comments

Comments
 (0)