Skip to content

Commit fcabccd

Browse files
Lubomir LitchevFrank Laub
authored andcommitted
[MLIR] Add the sqrt operation to mlir.
Summary: Add and pipe through the sqrt operation for Standard and LLVM dialects. Reviewers: nicolasvasilache, ftynse Reviewed By: ftynse Subscribers: frej, ftynse, merge_guards_bot, flaub, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D73571
1 parent 38ab3b8 commit fcabccd

File tree

7 files changed

+63
-3
lines changed

7 files changed

+63
-3
lines changed

mlir/docs/Dialects/Standard.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,25 @@ operand and returns one result of the same type. This type may be a float
587587
scalar type, a vector whose element type is float, or a tensor of floats. It
588588
has no standard attributes.
589589

590+
### 'sqrt' operation
591+
592+
Syntax:
593+
594+
```
595+
operation ::= ssa-id `=` `sqrt` ssa-use `:` type
596+
```
597+
598+
Examples:
599+
600+
```mlir
601+
// Scalar square root value.
602+
%a = sqrt %b : f64
603+
// SIMD vector element-wise square root value.
604+
%f = sqrt %g : vector<4xf32>
605+
// Tensor element-wise square root value.
606+
%x = sqrt %y : tensor<4x?xf32>
607+
```
608+
590609
### 'tanh' operation
591610

592611
Syntax:

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,7 @@ def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">;
716716
def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
717717
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
718718
def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">;
719+
def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
719720

720721
def LLVM_LogOp : LLVM_Op<"intr.log", [NoSideEffect]>,
721722
Arguments<(ins LLVM_Type:$in)>,

mlir/include/mlir/Dialect/StandardOps/Ops.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,16 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
14021402
let hasCanonicalizer = 1;
14031403
}
14041404

1405+
def SqrtOp : FloatUnaryOp<"sqrt"> {
1406+
let summary = "sqrt of the specified value";
1407+
let description = [{
1408+
The `sqrt` operation computes the square root. It takes one operand and
1409+
returns one result of the same type. This type may be a float scalar type, a
1410+
vector whose element type is float, or a tensor of floats. It has no standard
1411+
attributes.
1412+
}];
1413+
}
1414+
14051415
def TanhOp : FloatUnaryOp<"tanh"> {
14061416
let summary = "hyperbolic tangent of the specified value";
14071417
let description = [{

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,9 @@ struct SignedDivIOpLowering
807807
: public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
808808
using Super::Super;
809809
};
810+
struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> {
811+
using Super::Super;
812+
};
810813
struct UnsignedDivIOpLowering
811814
: public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
812815
using Super::Super;
@@ -2108,6 +2111,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
21082111
SignedShiftRightOpLowering,
21092112
SplatOpLowering,
21102113
SplatNdOpLowering,
2114+
SqrtOpLowering,
21112115
SubFOpLowering,
21122116
SubIOpLowering,
21132117
TanhOpLowering,

mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,8 @@ func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>
398398
}
399399

400400
// CHECK-LABEL: @ops
401-
func @ops(f32, f32, i32, i32) -> (f32, i32) {
402-
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32):
401+
func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
402+
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
403403
// CHECK-NEXT: %0 = llvm.fsub %arg0, %arg1 : !llvm.float
404404
%0 = subf %arg0, %arg1: f32
405405
// CHECK-NEXT: %1 = llvm.sub %arg2, %arg3 : !llvm.i32
@@ -440,7 +440,10 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) {
440440
%19 = shift_right_signed %arg2, %arg3 : i32
441441
// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
442442
%20 = shift_right_unsigned %arg2, %arg3 : i32
443-
443+
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
444+
%21 = std.sqrt %arg0 : f32
445+
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
446+
%22 = std.sqrt %arg4 : f64
444447
return %0, %4 : f32, i32
445448
}
446449

mlir/test/IR/core-ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,18 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
494494
// CHECK: %{{[0-9]+}} = shift_right_unsigned %cst_4, %cst_4 : tensor<42xi32>
495495
%138 = shift_right_unsigned %tci32, %tci32 : tensor<42 x i32>
496496

497+
// CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
498+
%139 = "std.sqrt"(%f) : (f32) -> f32
499+
500+
// CHECK: %{{[0-9]+}} = sqrt %arg1 : f32
501+
%140 = sqrt %f : f32
502+
503+
// CHECK: %{{[0-9]+}} = sqrt %cst_8 : vector<4xf32>
504+
%141 = sqrt %vcf32 : vector<4xf32>
505+
506+
// CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32>
507+
%142 = sqrt %t : tensor<4x4x?xf32>
508+
497509
return
498510
}
499511

mlir/test/Target/llvmir-intrinsics.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ llvm.func @fabs_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
5959
llvm.return
6060
}
6161

62+
// CHECK-LABEL: @sqrt_test
63+
llvm.func @sqrt_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
64+
// CHECK: call float @llvm.sqrt.f32
65+
"llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
66+
// CHECK: call <8 x float> @llvm.sqrt.v8f32
67+
"llvm.intr.sqrt"(%arg1) : (!llvm<"<8 x float>">) -> !llvm<"<8 x float>">
68+
llvm.return
69+
}
70+
6271
// CHECK-LABEL: @ceil_test
6372
llvm.func @ceil_test(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">) {
6473
// CHECK: call float @llvm.ceil.f32
@@ -100,6 +109,8 @@ llvm.func @copysign_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<
100109
// CHECK: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0
101110
// CHECK: declare float @llvm.fabs.f32(float)
102111
// CHECK: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0
112+
// CHECK: declare float @llvm.sqrt.f32(float)
113+
// CHECK: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0
103114
// CHECK: declare float @llvm.ceil.f32(float)
104115
// CHECK: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0
105116
// CHECK: declare float @llvm.cos.f32(float)

0 commit comments

Comments
 (0)