Skip to content

[WebAssembly] Implement all f16x8 unary instructions. #94063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 4, 2024

Conversation

brendandahl
Copy link
Contributor

All of these instructions can be generated using regular LLVM IR intrinsics.

Specified at:
https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md

@brendandahl brendandahl requested a review from aheejin May 31, 2024 22:26
@llvmbot llvmbot added backend:WebAssembly mc Machine (object) code labels May 31, 2024
@brendandahl brendandahl requested a review from dschuff May 31, 2024 22:26
@llvmbot
Copy link
Member

llvmbot commented May 31, 2024

@llvm/pr-subscribers-backend-webassembly

@llvm/pr-subscribers-mc

Author: Brendan Dahl (brendandahl)

Changes

All of these instructions can be generated using regular LLVM IR intrinsics.

Specified at:
https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md


Full diff: https://github.com/llvm/llvm-project/pull/94063.diff

3 Files Affected:

  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td (+18-2)
  • (modified) llvm/test/CodeGen/WebAssembly/half-precision.ll (+89)
  • (modified) llvm/test/MC/WebAssembly/simd-encodings.s (+21)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index b800123ac0fff..3c97befcea1a4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -828,12 +828,18 @@ multiclass SIMDBitwise<SDPatternOperator node, string name, bits<32> simdop,
             (!cast<NI>(NAME) $lhs, $rhs)>;
 }
 
-multiclass SIMDUnary<Vec vec, SDPatternOperator node, string name, bits<32> simdop> {
+multiclass SIMDUnary<Vec vec, SDPatternOperator node, string name,
+                     bits<32> simdop, list<Predicate> reqs = []> {
   defm _#vec : SIMD_I<(outs V128:$dst), (ins V128:$v), (outs), (ins),
                       [(set (vec.vt V128:$dst),
                         (vec.vt (node (vec.vt V128:$v))))],
                       vec.prefix#"."#name#"\t$dst, $v",
-                      vec.prefix#"."#name, simdop>;
+                      vec.prefix#"."#name, simdop, reqs>;
+}
+
+multiclass HalfPrecisionUnary<Vec vec, SDPatternOperator node, string name,
+                              bits<32> simdop> {
+  defm "" : SIMDUnary<vec, node, name, simdop, [HasHalfPrecision]>;
 }
 
 // Bitwise logic: v128.not
@@ -1190,6 +1196,10 @@ defm EXTMUL_HIGH_U :
 multiclass SIMDUnaryFP<SDNode node, string name, bits<32> baseInst> {
   defm "" : SIMDUnary<F32x4, node, name, baseInst>;
   defm "" : SIMDUnary<F64x2, node, name, !add(baseInst, 12)>;
+  // Unlike F32x4 and F64x2 there's not a gap in the opcodes between "neg" and
+  // "sqrt" so subtract one from the offset.
+  defm "" : HalfPrecisionUnary<F16x8, node, name,
+                               !add(baseInst,!if(!eq(name, "sqrt"), 80, 81))>;
 }
 
 // Absolute value: abs
@@ -1210,14 +1220,20 @@ defm CEIL : SIMDUnary<F64x2, fceil, "ceil", 0x74>;
 defm FLOOR : SIMDUnary<F64x2, ffloor, "floor", 0x75>;
 defm TRUNC: SIMDUnary<F64x2, ftrunc, "trunc", 0x7a>;
 defm NEAREST: SIMDUnary<F64x2, fnearbyint, "nearest", 0x94>;
+defm CEIL : HalfPrecisionUnary<F16x8, fceil, "ceil", 0x13c>;
+defm FLOOR : HalfPrecisionUnary<F16x8, ffloor, "floor", 0x13d>;
+defm TRUNC : HalfPrecisionUnary<F16x8, ftrunc, "trunc", 0x13e>;
+defm NEAREST : HalfPrecisionUnary<F16x8, fnearbyint, "nearest", 0x13f>;
 
 // WebAssembly doesn't expose inexact exceptions, so map frint to fnearbyint.
 def : Pat<(v4f32 (frint (v4f32 V128:$src))), (NEAREST_F32x4 V128:$src)>;
 def : Pat<(v2f64 (frint (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
+def : Pat<(v8f16 (frint (v8f16 V128:$src))), (NEAREST_F16x8 V128:$src)>;
 
 // WebAssembly always rounds ties-to-even, so map froundeven to fnearbyint.
 def : Pat<(v4f32 (froundeven (v4f32 V128:$src))), (NEAREST_F32x4 V128:$src)>;
 def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
+def : Pat<(v8f16 (froundeven (v8f16 V128:$src))), (NEAREST_F16x8 V128:$src)>;
 
 //===----------------------------------------------------------------------===//
 // Floating-point binary arithmetic
diff --git a/llvm/test/CodeGen/WebAssembly/half-precision.ll b/llvm/test/CodeGen/WebAssembly/half-precision.ll
index cca25b485cdf2..0f0a159091514 100644
--- a/llvm/test/CodeGen/WebAssembly/half-precision.ll
+++ b/llvm/test/CodeGen/WebAssembly/half-precision.ll
@@ -157,3 +157,92 @@ define <8 x i1> @compare_oge_v8f16 (<8 x half> %x, <8 x half> %y) {
   %res = fcmp oge <8 x half> %x, %y
   ret <8 x i1> %res
 }
+
+; CHECK-LABEL: abs_v8f16:
+; CHECK-NEXT:  .functype abs_v8f16 (v128) -> (v128)
+; CHECK-NEXT:  f16x8.abs $push0=, $0
+; CHECK-NEXT:  return $pop0
+declare <8 x half> @llvm.fabs.v8f16(<8 x half>) nounwind readnone
+define <8 x half> @abs_v8f16(<8 x half> %x) {
+  %a = call <8 x half> @llvm.fabs.v8f16(<8 x half> %x)
+  ret <8 x half> %a
+}
+
+; CHECK-LABEL: neg_v8f16:
+; CHECK-NEXT:  .functype neg_v8f16 (v128) -> (v128)
+; CHECK-NEXT:  f16x8.neg $push0=, $0
+; CHECK-NEXT:  return $pop0
+define <8 x half> @neg_v8f16(<8 x half> %x) {
+  %a = fsub nsz <8 x half> <half 0., half 0., half 0., half 0., half 0., half 0., half 0., half 0.>, %x
+  ret <8 x half> %a
+}
+
+; CHECK-LABEL: sqrt_v8f16:
+; CHECK-NEXT:  .functype sqrt_v8f16 (v128) -> (v128)
+; CHECK-NEXT:  f16x8.sqrt $push0=, $0
+; CHECK-NEXT:  return $pop0
+declare <8 x half> @llvm.sqrt.v8f16(<8 x half> %x)
+define <8 x half> @sqrt_v8f16(<8 x half> %x) {
+  %a = call <8 x half> @llvm.sqrt.v8f16(<8 x half> %x)
+  ret <8 x half> %a
+}
+
+; CHECK-LABEL: ceil_v8f16:
+; CHECK-NEXT:  .functype ceil_v8f16 (v128) -> (v128){{$}}
+; CHECK-NEXT:  f16x8.ceil $push[[R:[0-9]+]]=, $0{{$}}
+; CHECK-NEXT:  return $pop[[R]]{{$}}
+declare <8 x half> @llvm.ceil.v8f16(<8 x half>)
+define <8 x half> @ceil_v8f16(<8 x half> %a) {
+  %v = call <8 x half> @llvm.ceil.v8f16(<8 x half> %a)
+  ret <8 x half> %v
+}
+
+; CHECK-LABEL: floor_v8f16:
+; CHECK-NEXT:  .functype floor_v8f16 (v128) -> (v128){{$}}
+; CHECK-NEXT:  f16x8.floor $push[[R:[0-9]+]]=, $0{{$}}
+; CHECK-NEXT:  return $pop[[R]]{{$}}
+declare <8 x half> @llvm.floor.v8f16(<8 x half>)
+define <8 x half> @floor_v8f16(<8 x half> %a) {
+  %v = call <8 x half> @llvm.floor.v8f16(<8 x half> %a)
+  ret <8 x half> %v
+}
+
+; CHECK-LABEL: trunc_v8f16:
+; CHECK-NEXT:  .functype trunc_v8f16 (v128) -> (v128){{$}}
+; CHECK-NEXT:  f16x8.trunc $push[[R:[0-9]+]]=, $0{{$}}
+; CHECK-NEXT:  return $pop[[R]]{{$}}
+declare <8 x half> @llvm.trunc.v8f16(<8 x half>)
+define <8 x half> @trunc_v8f16(<8 x half> %a) {
+  %v = call <8 x half> @llvm.trunc.v8f16(<8 x half> %a)
+  ret <8 x half> %v
+}
+
+; CHECK-LABEL: nearest_v8f16:
+; CHECK-NEXT:  .functype nearest_v8f16 (v128) -> (v128){{$}}
+; CHECK-NEXT:  f16x8.nearest $push[[R:[0-9]+]]=, $0{{$}}
+; CHECK-NEXT:  return $pop[[R]]{{$}}
+declare <8 x half> @llvm.nearbyint.v8f16(<8 x half>)
+define <8 x half> @nearest_v8f16(<8 x half> %a) {
+  %v = call <8 x half> @llvm.nearbyint.v8f16(<8 x half> %a)
+  ret <8 x half> %v
+}
+
+; CHECK-LABEL: nearest_v8f16_via_rint:
+; CHECK-NEXT:  .functype nearest_v8f16_via_rint (v128) -> (v128){{$}}
+; CHECK-NEXT:  f16x8.nearest $push[[R:[0-9]+]]=, $0{{$}}
+; CHECK-NEXT:  return $pop[[R]]{{$}}
+declare <8 x half> @llvm.rint.v8f16(<8 x half>)
+define <8 x half> @nearest_v8f16_via_rint(<8 x half> %a) {
+  %v = call <8 x half> @llvm.rint.v8f16(<8 x half> %a)
+  ret <8 x half> %v
+}
+
+; CHECK-LABEL: nearest_v8f16_via_roundeven:
+; CHECK-NEXT:  .functype nearest_v8f16_via_roundeven (v128) -> (v128){{$}}
+; CHECK-NEXT:  f16x8.nearest $push[[R:[0-9]+]]=, $0{{$}}
+; CHECK-NEXT:  return $pop[[R]]{{$}}
+declare <8 x half> @llvm.roundeven.v8f16(<8 x half>)
+define <8 x half> @nearest_v8f16_via_roundeven(<8 x half> %a) {
+  %v = call <8 x half> @llvm.roundeven.v8f16(<8 x half> %a)
+  ret <8 x half> %v
+}
diff --git a/llvm/test/MC/WebAssembly/simd-encodings.s b/llvm/test/MC/WebAssembly/simd-encodings.s
index aa70815245e5d..8e4d9301b6026 100644
--- a/llvm/test/MC/WebAssembly/simd-encodings.s
+++ b/llvm/test/MC/WebAssembly/simd-encodings.s
@@ -893,4 +893,25 @@ main:
     # CHECK: f16x8.ge # encoding: [0xfd,0xc5,0x02]
     f16x8.ge
 
+    # CHECK: f16x8.abs # encoding: [0xfd,0xb1,0x02]
+    f16x8.abs
+
+    # CHECK: f16x8.neg # encoding: [0xfd,0xb2,0x02]
+    f16x8.neg
+
+    # CHECK: f16x8.sqrt # encoding: [0xfd,0xb3,0x02]
+    f16x8.sqrt
+
+    # CHECK: f16x8.ceil # encoding: [0xfd,0xbc,0x02]
+    f16x8.ceil
+
+    # CHECK: f16x8.floor # encoding: [0xfd,0xbd,0x02]
+    f16x8.floor
+
+    # CHECK: f16x8.trunc # encoding: [0xfd,0xbe,0x02]
+    f16x8.trunc
+
+    # CHECK: f16x8.nearest # encoding: [0xfd,0xbf,0x02]
+    f16x8.nearest
+
     end_function

@brendandahl brendandahl merged commit dfd1a2f into llvm:main Jun 4, 2024
7 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:WebAssembly mc Machine (object) code
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants