Skip to content

[WebAssembly] Implement prototype f16x8.splat instruction. #93228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Basic/BuiltinsWebAssembly.def
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ TARGET_BUILTIN(__builtin_wasm_relaxed_dot_bf16x8_add_f32_f32x4, "V4fV8UsV8UsV4f"
// Half-Precision (fp16)
TARGET_BUILTIN(__builtin_wasm_loadf16_f32, "fh*", "nU", "half-precision")
TARGET_BUILTIN(__builtin_wasm_storef16_f32, "vfh*", "n", "half-precision")
TARGET_BUILTIN(__builtin_wasm_splat_f16x8, "V8hf", "nc", "half-precision")

// Reference Types builtins
// Some builtins are custom type-checked - see 't' as part of the third argument,
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Basic/Targets/WebAssembly.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo {

StringRef getABI() const override;
bool setABI(const std::string &Name) override;
bool useFP16ConversionIntrinsics() const override {
return !HasHalfPrecision;
}

protected:
void getTargetDefines(const LangOptions &Opts,
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21230,6 +21230,11 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_storef16_f32);
return Builder.CreateCall(Callee, {Val, Addr});
}
case WebAssembly::BI__builtin_wasm_splat_f16x8: {
Value *Val = EmitScalarExpr(E->getArg(0));
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_splat_f16x8);
return Builder.CreateCall(Callee, {Val});
}
case WebAssembly::BI__builtin_wasm_table_get: {
assert(E->getArg(0)->getType()->isArrayType());
Value *Table = EmitArrayToPointerDecay(E->getArg(0)).emitRawPointer(*this);
Expand Down
6 changes: 6 additions & 0 deletions clang/test/CodeGen/builtins-wasm.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ typedef unsigned char u8x16 __attribute((vector_size(16)));
typedef unsigned short u16x8 __attribute((vector_size(16)));
typedef unsigned int u32x4 __attribute((vector_size(16)));
typedef unsigned long long u64x2 __attribute((vector_size(16)));
typedef __fp16 f16x8 __attribute((vector_size(16)));
typedef float f32x4 __attribute((vector_size(16)));
typedef double f64x2 __attribute((vector_size(16)));

Expand Down Expand Up @@ -813,6 +814,11 @@ void store_f16_f32(float val, __fp16 *addr) {
// WEBASSEMBLY-NEXT: ret
}

f16x8 splat_f16x8(float a) {
// WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.splat.f16x8(float %a)
// WEBASSEMBLY-NEXT: ret <8 x half> %0
return __builtin_wasm_splat_f16x8(a);
}
__externref_t externref_null() {
return __builtin_wasm_ref_null_extern();
// WEBASSEMBLY: tail call ptr addrspace(10) @llvm.wasm.ref.null.extern()
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsWebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def int_wasm_storef16_f32:
[llvm_float_ty, llvm_ptr_ty],
[IntrWriteMem, IntrArgMemOnly],
"", [SDNPMemOperand]>;
def int_wasm_splat_f16x8:
DefaultAttrsIntrinsic<[llvm_v8f16_ty],
[llvm_float_ty],
[IntrNoMem, IntrSpeculatable]>;


//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ wasm::ValType WebAssembly::toValType(MVT Type) {
case MVT::v8i16:
case MVT::v4i32:
case MVT::v2i64:
case MVT::v8f16:
case MVT::v4f32:
case MVT::v2f64:
return wasm::ValType::V128;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
addRegisterClass(MVT::v2i64, &WebAssembly::V128RegClass);
addRegisterClass(MVT::v2f64, &WebAssembly::V128RegClass);
}
if (Subtarget->hasHalfPrecision()) {
addRegisterClass(MVT::v8f16, &WebAssembly::V128RegClass);
}
if (Subtarget->hasReferenceTypes()) {
addRegisterClass(MVT::externref, &WebAssembly::EXTERNREFRegClass);
addRegisterClass(MVT::funcref, &WebAssembly::FUNCREFRegClass);
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ multiclass RELAXED_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
asmstr_s, simdop, HasRelaxedSIMD>;
}

multiclass HALF_PRECISION_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
string asmstr_s = "", bits<32> simdop = -1> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
asmstr_s, simdop, HasHalfPrecision>;
}


defm "" : ARGUMENT<V128, v16i8>;
defm "" : ARGUMENT<V128, v8i16>;
Expand Down Expand Up @@ -591,6 +598,14 @@ defm "" : Splat<I64x2, 18>;
defm "" : Splat<F32x4, 19>;
defm "" : Splat<F64x2, 20>;

// Half values are not fully supported so an intrinsic is used instead of a
// regular Splat pattern as above.
defm SPLAT_F16x8 :
HALF_PRECISION_I<(outs V128:$dst), (ins F32:$x),
(outs), (ins),
[(set (v8f16 V128:$dst), (int_wasm_splat_f16x8 F32:$x))],
"f16x8.splat\t$dst, $x", "f16x8.splat", 0x120>;

// scalar_to_vector leaves high lanes undefined, so can be a splat
foreach vec = AllVecs in
def : Pat<(vec.vt (scalar_to_vector (vec.lane_vt vec.lane_rc:$x))),
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def I32 : WebAssemblyRegClass<[i32], 32, (add FP32, SP32, I32_0)>;
def I64 : WebAssemblyRegClass<[i64], 64, (add FP64, SP64, I64_0)>;
def F32 : WebAssemblyRegClass<[f32], 32, (add F32_0)>;
def F64 : WebAssemblyRegClass<[f64], 64, (add F64_0)>;
def V128 : WebAssemblyRegClass<[v4f32, v2f64, v2i64, v4i32, v16i8, v8i16], 128,
(add V128_0)>;
def V128 : WebAssemblyRegClass<[v8f16, v4f32, v2f64, v2i64, v4i32, v16i8,
v8i16],
128, (add V128_0)>;
def FUNCREF : WebAssemblyRegClass<[funcref], 0, (add FUNCREF_0)>;
def EXTERNREF : WebAssemblyRegClass<[externref], 0, (add EXTERNREF_0)>;
12 changes: 10 additions & 2 deletions llvm/test/CodeGen/WebAssembly/half-precision.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: llc < %s --mtriple=wasm32-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision | FileCheck %s
; RUN: llc < %s --mtriple=wasm64-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision | FileCheck %s
; RUN: llc < %s --mtriple=wasm32-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision,+simd128 | FileCheck %s
; RUN: llc < %s --mtriple=wasm64-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision,+simd128 | FileCheck %s

declare float @llvm.wasm.loadf32.f16(ptr)
declare void @llvm.wasm.storef16.f32(float, ptr)
Expand All @@ -19,3 +19,11 @@ define void @stf16_32(float %v, ptr %p) {
tail call void @llvm.wasm.storef16.f32(float %v, ptr %p)
ret void
}

; CHECK-LABEL: splat_v8f16:
; CHECK: f16x8.splat $push0=, $0
; CHECK-NEXT: return $pop0
define <8 x half> @splat_v8f16(float %x) {
%v = call <8 x half> @llvm.wasm.splat.f16x8(float %x)
ret <8 x half> %v
}
3 changes: 3 additions & 0 deletions llvm/test/MC/WebAssembly/simd-encodings.s
Original file line number Diff line number Diff line change
Expand Up @@ -845,4 +845,7 @@ main:
# CHECK: f32.store_f16 32 # encoding: [0xfc,0x31,0x01,0x20]
f32.store_f16 32

# CHECK: f16x8.splat # encoding: [0xfd,0xa0,0x02]
f16x8.splat

end_function
Loading