Skip to content

[WebAssembly] Implement all f16x8 unary instructions. #94063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -828,12 +828,18 @@ multiclass SIMDBitwise<SDPatternOperator node, string name, bits<32> simdop,
(!cast<NI>(NAME) $lhs, $rhs)>;
}

// SIMDUnary: defines the instruction `_<vec>` with a single V128 input ($v) and
// a V128 result ($dst). The selection pattern applies `node` to $v at vec.vt,
// the assembly mnemonic is vec.prefix#"."#name, and `simdop` is the opcode
// handed to SIMD_I. `reqs` (default: empty list) is passed through to SIMD_I.
// NOTE(review): this span is a flattened diff, so both the pre-change and the
// post-change versions of the signature and of the final SIMD_I argument line
// appear below.
multiclass SIMDUnary<Vec vec, SDPatternOperator node, string name, bits<32> simdop> {
multiclass SIMDUnary<Vec vec, SDPatternOperator node, string name,
bits<32> simdop, list<Predicate> reqs = []> {
defm _#vec : SIMD_I<(outs V128:$dst), (ins V128:$v), (outs), (ins),
[(set (vec.vt V128:$dst),
(vec.vt (node (vec.vt V128:$v))))],
vec.prefix#"."#name#"\t$dst, $v",
vec.prefix#"."#name, simdop>;
vec.prefix#"."#name, simdop, reqs>;
}

// HalfPrecisionUnary: thin wrapper around SIMDUnary that always passes
// [HasHalfPrecision] as the `reqs` predicate list, so each half-precision
// instruction definition does not have to repeat it.
multiclass HalfPrecisionUnary<Vec vec, SDPatternOperator node, string name,
bits<32> simdop> {
defm "" : SIMDUnary<vec, node, name, simdop, [HasHalfPrecision]>;
}

// Bitwise logic: v128.not
Expand Down Expand Up @@ -1190,6 +1196,10 @@ defm EXTMUL_HIGH_U :
// SIMDUnaryFP: instantiates one floating-point unary operation for F32x4,
// F64x2 and F16x8. `baseInst` is used directly as the F32x4 opcode; the F64x2
// opcode is baseInst + 12.
multiclass SIMDUnaryFP<SDNode node, string name, bits<32> baseInst> {
defm "" : SIMDUnary<F32x4, node, name, baseInst>;
defm "" : SIMDUnary<F64x2, node, name, !add(baseInst, 12)>;
// The F16x8 opcode is baseInst + 81, except for "sqrt". Unlike F32x4 and
// F64x2, the F16x8 opcodes have no gap between "neg" and "sqrt", so "sqrt"
// uses an offset that is one smaller (80).
defm "" : HalfPrecisionUnary<F16x8, node, name,
!add(baseInst,!if(!eq(name, "sqrt"), 80, 81))>;
}

// Absolute value: abs
Expand All @@ -1210,14 +1220,20 @@ defm CEIL : SIMDUnary<F64x2, fceil, "ceil", 0x74>;
defm FLOOR : SIMDUnary<F64x2, ffloor, "floor", 0x75>;
defm TRUNC: SIMDUnary<F64x2, ftrunc, "trunc", 0x7a>;
defm NEAREST: SIMDUnary<F64x2, fnearbyint, "nearest", 0x94>;
// F16x8 rounding instructions. They are defined through HalfPrecisionUnary,
// which adds the HasHalfPrecision requirement.
defm CEIL : HalfPrecisionUnary<F16x8, fceil, "ceil", 0x13c>;
defm FLOOR : HalfPrecisionUnary<F16x8, ffloor, "floor", 0x13d>;
defm TRUNC : HalfPrecisionUnary<F16x8, ftrunc, "trunc", 0x13e>;
defm NEAREST : HalfPrecisionUnary<F16x8, fnearbyint, "nearest", 0x13f>;

// WebAssembly doesn't expose inexact exceptions, so frint is selected to the
// same `nearest` instructions that implement fnearbyint.
def : Pat<(v4f32 (frint (v4f32 V128:$src))), (NEAREST_F32x4 V128:$src)>;
def : Pat<(v2f64 (frint (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
def : Pat<(v8f16 (frint (v8f16 V128:$src))), (NEAREST_F16x8 V128:$src)>;

// WebAssembly always rounds ties-to-even, so froundeven is likewise selected
// to the `nearest` instructions.
def : Pat<(v4f32 (froundeven (v4f32 V128:$src))), (NEAREST_F32x4 V128:$src)>;
def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
def : Pat<(v8f16 (froundeven (v8f16 V128:$src))), (NEAREST_F16x8 V128:$src)>;

//===----------------------------------------------------------------------===//
// Floating-point binary arithmetic
Expand Down
89 changes: 89 additions & 0 deletions llvm/test/CodeGen/WebAssembly/half-precision.ll
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,92 @@ define <8 x i1> @compare_oge_v8f16 (<8 x half> %x, <8 x half> %y) {
%res = fcmp oge <8 x half> %x, %y
ret <8 x i1> %res
}

; Verify that llvm.fabs on <8 x half> is selected to a single f16x8.abs.
; CHECK-LABEL: abs_v8f16:
; CHECK-NEXT: .functype abs_v8f16 (v128) -> (v128)
; CHECK-NEXT: f16x8.abs $push0=, $0
; CHECK-NEXT: return $pop0
declare <8 x half> @llvm.fabs.v8f16(<8 x half>) nounwind readnone
define <8 x half> @abs_v8f16(<8 x half> %x) {
%a = call <8 x half> @llvm.fabs.v8f16(<8 x half> %x)
ret <8 x half> %a
}

; Verify that an `fsub nsz` of %x from an all-zero <8 x half> vector is
; selected to a single f16x8.neg.
; CHECK-LABEL: neg_v8f16:
; CHECK-NEXT: .functype neg_v8f16 (v128) -> (v128)
; CHECK-NEXT: f16x8.neg $push0=, $0
; CHECK-NEXT: return $pop0
define <8 x half> @neg_v8f16(<8 x half> %x) {
%a = fsub nsz <8 x half> <half 0., half 0., half 0., half 0., half 0., half 0., half 0., half 0.>, %x
ret <8 x half> %a
}

; Verify that llvm.sqrt on <8 x half> is selected to a single f16x8.sqrt.
; CHECK-LABEL: sqrt_v8f16:
; CHECK-NEXT: .functype sqrt_v8f16 (v128) -> (v128)
; CHECK-NEXT: f16x8.sqrt $push0=, $0
; CHECK-NEXT: return $pop0
declare <8 x half> @llvm.sqrt.v8f16(<8 x half> %x)
define <8 x half> @sqrt_v8f16(<8 x half> %x) {
%a = call <8 x half> @llvm.sqrt.v8f16(<8 x half> %x)
ret <8 x half> %a
}

; Verify that llvm.ceil on <8 x half> is selected to a single f16x8.ceil.
; CHECK-LABEL: ceil_v8f16:
; CHECK-NEXT: .functype ceil_v8f16 (v128) -> (v128){{$}}
; CHECK-NEXT: f16x8.ceil $push[[R:[0-9]+]]=, $0{{$}}
; CHECK-NEXT: return $pop[[R]]{{$}}
declare <8 x half> @llvm.ceil.v8f16(<8 x half>)
define <8 x half> @ceil_v8f16(<8 x half> %a) {
%v = call <8 x half> @llvm.ceil.v8f16(<8 x half> %a)
ret <8 x half> %v
}

; Verify that llvm.floor on <8 x half> is selected to a single f16x8.floor.
; CHECK-LABEL: floor_v8f16:
; CHECK-NEXT: .functype floor_v8f16 (v128) -> (v128){{$}}
; CHECK-NEXT: f16x8.floor $push[[R:[0-9]+]]=, $0{{$}}
; CHECK-NEXT: return $pop[[R]]{{$}}
declare <8 x half> @llvm.floor.v8f16(<8 x half>)
define <8 x half> @floor_v8f16(<8 x half> %a) {
%v = call <8 x half> @llvm.floor.v8f16(<8 x half> %a)
ret <8 x half> %v
}

; Verify that llvm.trunc on <8 x half> is selected to a single f16x8.trunc.
; CHECK-LABEL: trunc_v8f16:
; CHECK-NEXT: .functype trunc_v8f16 (v128) -> (v128){{$}}
; CHECK-NEXT: f16x8.trunc $push[[R:[0-9]+]]=, $0{{$}}
; CHECK-NEXT: return $pop[[R]]{{$}}
declare <8 x half> @llvm.trunc.v8f16(<8 x half>)
define <8 x half> @trunc_v8f16(<8 x half> %a) {
%v = call <8 x half> @llvm.trunc.v8f16(<8 x half> %a)
ret <8 x half> %v
}

; Verify that llvm.nearbyint on <8 x half> is selected to a single
; f16x8.nearest.
; CHECK-LABEL: nearest_v8f16:
; CHECK-NEXT: .functype nearest_v8f16 (v128) -> (v128){{$}}
; CHECK-NEXT: f16x8.nearest $push[[R:[0-9]+]]=, $0{{$}}
; CHECK-NEXT: return $pop[[R]]{{$}}
declare <8 x half> @llvm.nearbyint.v8f16(<8 x half>)
define <8 x half> @nearest_v8f16(<8 x half> %a) {
%v = call <8 x half> @llvm.nearbyint.v8f16(<8 x half> %a)
ret <8 x half> %v
}

; Verify that llvm.rint on <8 x half> is also selected to f16x8.nearest.
; CHECK-LABEL: nearest_v8f16_via_rint:
; CHECK-NEXT: .functype nearest_v8f16_via_rint (v128) -> (v128){{$}}
; CHECK-NEXT: f16x8.nearest $push[[R:[0-9]+]]=, $0{{$}}
; CHECK-NEXT: return $pop[[R]]{{$}}
declare <8 x half> @llvm.rint.v8f16(<8 x half>)
define <8 x half> @nearest_v8f16_via_rint(<8 x half> %a) {
%v = call <8 x half> @llvm.rint.v8f16(<8 x half> %a)
ret <8 x half> %v
}

; Verify that llvm.roundeven on <8 x half> is also selected to f16x8.nearest.
; CHECK-LABEL: nearest_v8f16_via_roundeven:
; CHECK-NEXT: .functype nearest_v8f16_via_roundeven (v128) -> (v128){{$}}
; CHECK-NEXT: f16x8.nearest $push[[R:[0-9]+]]=, $0{{$}}
; CHECK-NEXT: return $pop[[R]]{{$}}
declare <8 x half> @llvm.roundeven.v8f16(<8 x half>)
define <8 x half> @nearest_v8f16_via_roundeven(<8 x half> %a) {
%v = call <8 x half> @llvm.roundeven.v8f16(<8 x half> %a)
ret <8 x half> %v
}
21 changes: 21 additions & 0 deletions llvm/test/MC/WebAssembly/simd-encodings.s
Original file line number Diff line number Diff line change
Expand Up @@ -893,4 +893,25 @@ main:
# CHECK: f16x8.ge # encoding: [0xfd,0xc5,0x02]
f16x8.ge

# f16x8 unary instructions: abs, neg, sqrt, ceil, floor, trunc, nearest.
# Each is expected to encode as the 0xfd prefix followed by a two-byte opcode.
# CHECK: f16x8.abs # encoding: [0xfd,0xb1,0x02]
f16x8.abs

# CHECK: f16x8.neg # encoding: [0xfd,0xb2,0x02]
f16x8.neg

# CHECK: f16x8.sqrt # encoding: [0xfd,0xb3,0x02]
f16x8.sqrt

# CHECK: f16x8.ceil # encoding: [0xfd,0xbc,0x02]
f16x8.ceil

# CHECK: f16x8.floor # encoding: [0xfd,0xbd,0x02]
f16x8.floor

# CHECK: f16x8.trunc # encoding: [0xfd,0xbe,0x02]
f16x8.trunc

# CHECK: f16x8.nearest # encoding: [0xfd,0xbf,0x02]
f16x8.nearest

end_function
Loading