Skip to content

Commit 072621b

Browse files
TDecking authored and Amanieu committed
Refactor avx512bw: avg + mulhi + abs
1 parent be95083 commit 072621b

File tree

2 files changed

+160
-24
lines changed

2 files changed

+160
-24
lines changed

crates/core_arch/src/simd.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,3 +743,142 @@ simd_ty!(
743743
x6,
744744
x7
745745
);
746+
747+
// 1024-bit wide types:
748+
simd_ty!(
749+
u16x64[u16]:
750+
x0,
751+
x1,
752+
x2,
753+
x3,
754+
x4,
755+
x5,
756+
x6,
757+
x7,
758+
x8,
759+
x9,
760+
x10,
761+
x11,
762+
x12,
763+
x13,
764+
x14,
765+
x15,
766+
x16,
767+
x17,
768+
x18,
769+
x19,
770+
x20,
771+
x21,
772+
x22,
773+
x23,
774+
x24,
775+
x25,
776+
x26,
777+
x27,
778+
x28,
779+
x29,
780+
x30,
781+
x31,
782+
x32,
783+
x33,
784+
x34,
785+
x35,
786+
x36,
787+
x37,
788+
x38,
789+
x39,
790+
x40,
791+
x41,
792+
x42,
793+
x43,
794+
x44,
795+
x45,
796+
x46,
797+
x47,
798+
x48,
799+
x49,
800+
x50,
801+
x51,
802+
x52,
803+
x53,
804+
x54,
805+
x55,
806+
x56,
807+
x57,
808+
x58,
809+
x59,
810+
x60,
811+
x61,
812+
x62,
813+
x63
814+
);
815+
simd_ty!(
816+
i32x32[i32]:
817+
x0,
818+
x1,
819+
x2,
820+
x3,
821+
x4,
822+
x5,
823+
x6,
824+
x7,
825+
x8,
826+
x9,
827+
x10,
828+
x11,
829+
x12,
830+
x13,
831+
x14,
832+
x15,
833+
x16,
834+
x17,
835+
x18,
836+
x19,
837+
x20,
838+
x21,
839+
x22,
840+
x23,
841+
x24,
842+
x25,
843+
x26,
844+
x27,
845+
x28,
846+
x29,
847+
x30,
848+
x31
849+
);
850+
simd_ty!(
851+
u32x32[u32]:
852+
x0,
853+
x1,
854+
x2,
855+
x3,
856+
x4,
857+
x5,
858+
x6,
859+
x7,
860+
x8,
861+
x9,
862+
x10,
863+
x11,
864+
x12,
865+
x13,
866+
x14,
867+
x15,
868+
x16,
869+
x17,
870+
x18,
871+
x19,
872+
x20,
873+
x21,
874+
x22,
875+
x23,
876+
x24,
877+
x25,
878+
x26,
879+
x27,
880+
x28,
881+
x29,
882+
x30,
883+
x31
884+
);

crates/core_arch/src/x86/avx512bw.rs

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{
22
arch::asm,
33
core_arch::{simd::*, x86::*},
44
intrinsics::simd::*,
5-
mem, ptr,
5+
ptr,
66
};
77

88
#[cfg(test)]
@@ -17,11 +17,8 @@ use stdarch_test::assert_instr;
1717
#[cfg_attr(test, assert_instr(vpabsw))]
1818
pub unsafe fn _mm512_abs_epi16(a: __m512i) -> __m512i {
1919
let a = a.as_i16x32();
20-
// all-0 is a properly initialized i16x32
21-
let zero: i16x32 = mem::zeroed();
22-
let sub = simd_sub(zero, a);
23-
let cmp: i16x32 = simd_gt(a, zero);
24-
transmute(simd_select(cmp, a, sub))
20+
let cmp: i16x32 = simd_gt(a, i16x32::splat(0));
21+
transmute(simd_select(cmp, a, simd_neg(a)))
2522
}
2623

2724
/// Compute the absolute value of packed signed 16-bit integers in a, and store the unsigned results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -108,11 +105,8 @@ pub unsafe fn _mm_maskz_abs_epi16(k: __mmask8, a: __m128i) -> __m128i {
108105
#[cfg_attr(test, assert_instr(vpabsb))]
109106
pub unsafe fn _mm512_abs_epi8(a: __m512i) -> __m512i {
110107
let a = a.as_i8x64();
111-
// all-0 is a properly initialized i8x64
112-
let zero: i8x64 = mem::zeroed();
113-
let sub = simd_sub(zero, a);
114-
let cmp: i8x64 = simd_gt(a, zero);
115-
transmute(simd_select(cmp, a, sub))
108+
let cmp: i8x64 = simd_gt(a, i8x64::splat(0));
109+
transmute(simd_select(cmp, a, simd_neg(a)))
116110
}
117111

118112
/// Compute the absolute value of packed signed 8-bit integers in a, and store the unsigned results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -1368,7 +1362,10 @@ pub unsafe fn _mm_maskz_subs_epi8(k: __mmask16, a: __m128i, b: __m128i) -> __m12
13681362
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
13691363
#[cfg_attr(test, assert_instr(vpmulhuw))]
13701364
pub unsafe fn _mm512_mulhi_epu16(a: __m512i, b: __m512i) -> __m512i {
1371-
transmute(vpmulhuw(a.as_u16x32(), b.as_u16x32()))
1365+
let a = simd_cast::<_, u32x32>(a.as_u16x32());
1366+
let b = simd_cast::<_, u32x32>(b.as_u16x32());
1367+
let r = simd_shr(simd_mul(a, b), u32x32::splat(16));
1368+
transmute(simd_cast::<u32x32, u16x32>(r))
13721369
}
13731370

13741371
/// Multiply the packed unsigned 16-bit integers in a and b, producing intermediate 32-bit integers, and store the high 16 bits of the intermediate integers in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -1464,7 +1461,10 @@ pub unsafe fn _mm_maskz_mulhi_epu16(k: __mmask8, a: __m128i, b: __m128i) -> __m1
14641461
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
14651462
#[cfg_attr(test, assert_instr(vpmulhw))]
14661463
pub unsafe fn _mm512_mulhi_epi16(a: __m512i, b: __m512i) -> __m512i {
1467-
transmute(vpmulhw(a.as_i16x32(), b.as_i16x32()))
1464+
let a = simd_cast::<_, i32x32>(a.as_i16x32());
1465+
let b = simd_cast::<_, i32x32>(b.as_i16x32());
1466+
let r = simd_shr(simd_mul(a, b), i32x32::splat(16));
1467+
transmute(simd_cast::<i32x32, i16x32>(r))
14681468
}
14691469

14701470
/// Multiply the packed signed 16-bit integers in a and b, producing intermediate 32-bit integers, and store the high 16 bits of the intermediate integers in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -5505,7 +5505,10 @@ pub unsafe fn _mm_maskz_packus_epi16(k: __mmask16, a: __m128i, b: __m128i) -> __
55055505
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
55065506
#[cfg_attr(test, assert_instr(vpavgw))]
55075507
pub unsafe fn _mm512_avg_epu16(a: __m512i, b: __m512i) -> __m512i {
5508-
transmute(vpavgw(a.as_u16x32(), b.as_u16x32()))
5508+
let a = simd_cast::<_, u32x32>(a.as_u16x32());
5509+
let b = simd_cast::<_, u32x32>(b.as_u16x32());
5510+
let r = simd_shr(simd_add(simd_add(a, b), u32x32::splat(1)), u32x32::splat(1));
5511+
transmute(simd_cast::<_, u16x32>(r))
55095512
}
55105513

55115514
/// Average packed unsigned 16-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -5591,7 +5594,10 @@ pub unsafe fn _mm_maskz_avg_epu16(k: __mmask8, a: __m128i, b: __m128i) -> __m128
55915594
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
55925595
#[cfg_attr(test, assert_instr(vpavgb))]
55935596
pub unsafe fn _mm512_avg_epu8(a: __m512i, b: __m512i) -> __m512i {
5594-
transmute(vpavgb(a.as_u8x64(), b.as_u8x64()))
5597+
let a = simd_cast::<_, u16x64>(a.as_u8x64());
5598+
let b = simd_cast::<_, u16x64>(b.as_u8x64());
5599+
let r = simd_shr(simd_add(simd_add(a, b), u16x64::splat(1)), u16x64::splat(1));
5600+
transmute(simd_cast::<_, u8x64>(r))
55955601
}
55965602

55975603
/// Average packed unsigned 8-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -10645,10 +10651,6 @@ extern "C" {
1064510651
#[link_name = "llvm.x86.avx512.mask.psubs.b.128"]
1064610652
fn vpsubsb128(a: i8x16, b: i8x16, src: i8x16, mask: u16) -> i8x16;
1064710653

10648-
#[link_name = "llvm.x86.avx512.pmulhu.w.512"]
10649-
fn vpmulhuw(a: u16x32, b: u16x32) -> u16x32;
10650-
#[link_name = "llvm.x86.avx512.pmulh.w.512"]
10651-
fn vpmulhw(a: i16x32, b: i16x32) -> i16x32;
1065210654
#[link_name = "llvm.x86.avx512.pmul.hr.sw.512"]
1065310655
fn vpmulhrsw(a: i16x32, b: i16x32) -> i16x32;
1065410656

@@ -10712,11 +10714,6 @@ extern "C" {
1071210714
#[link_name = "llvm.x86.avx512.packuswb.512"]
1071310715
fn vpackuswb(a: i16x32, b: i16x32) -> u8x64;
1071410716

10715-
#[link_name = "llvm.x86.avx512.pavg.w.512"]
10716-
fn vpavgw(a: u16x32, b: u16x32) -> u16x32;
10717-
#[link_name = "llvm.x86.avx512.pavg.b.512"]
10718-
fn vpavgb(a: u8x64, b: u8x64) -> u8x64;
10719-
1072010717
#[link_name = "llvm.x86.avx512.psll.w.512"]
1072110718
fn vpsllw(a: i16x32, count: i16x8) -> i16x32;
1072210719

0 commit comments

Comments
 (0)