Commit 0508932

cranelift: Align Scalar and SIMD shift semantics (#4520)
* cranelift: Reorganize test suite. Group some SIMD operations by instruction.
* cranelift: Deduplicate some shift tests. Also, new tests with the mod behaviour.
* aarch64: Lower shifts with mod behaviour
* x64: Lower shifts with mod behaviour
* wasmtime: Don't mask SIMD shifts
1 parent e121c20 commit 0508932

15 files changed: +314, -423 lines
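The "mod behaviour" referred to throughout means a shift amount is interpreted modulo the lane width, so SIMD shifts now wrap the same way scalar shifts do. A minimal sketch of that semantics in plain Rust (a standalone illustration; the helper names here are made up, not Cranelift's API):

fn masked_shift_amount(lane_bits: u32, amt: u32) -> u32 {
    // Lane widths are powers of two, so `lane_bits - 1` is the wrap mask.
    amt & (lane_bits - 1)
}

fn ishl_lane_i8(x: u8, amt: u32) -> u8 {
    // Shifting an 8-bit lane by 9 behaves like shifting by 1.
    x << masked_shift_amount(8, amt)
}

For example, ishl_lane_i8(1, 9) yields 2, matching the wrapping scalar result.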

cranelift/codegen/src/isa/aarch64/lower.isle

Lines changed: 6 additions & 3 deletions
@@ -927,7 +927,8 @@
 ;; Shift for vector types.
 (rule (lower (has_type (ty_vec128 ty) (ishl x y)))
   (let ((size VectorSize (vector_size ty))
-        (shift Reg (vec_dup y size)))
+        (masked_shift_amt Reg (and_imm $I32 y (shift_mask ty)))
+        (shift Reg (vec_dup masked_shift_amt size)))
     (sshl x shift size)))

 ;; Helper function to emit a shift operation with the opcode specified and

@@ -986,7 +987,8 @@
 ;; Vector shifts.
 (rule (lower (has_type (ty_vec128 ty) (ushr x y)))
   (let ((size VectorSize (vector_size ty))
-        (shift Reg (vec_dup (sub $I32 (zero_reg) y) size)))
+        (masked_shift_amt Reg (and_imm $I32 y (shift_mask ty)))
+        (shift Reg (vec_dup (sub $I64 (zero_reg) masked_shift_amt) size)))
     (ushl x shift size)))

 ;; lsr lo_rshift, src_lo, amt

@@ -1035,7 +1037,8 @@
 ;; Note that right shifts are implemented with a negative left shift.
 (rule (lower (has_type (ty_vec128 ty) (sshr x y)))
   (let ((size VectorSize (vector_size ty))
-        (shift Reg (vec_dup (sub $I32 (zero_reg) y) size)))
+        (masked_shift_amt Reg (and_imm $I32 y (shift_mask ty)))
+        (shift Reg (vec_dup (sub $I64 (zero_reg) masked_shift_amt) size)))
     (sshl x shift size)))

 ;; lsr lo_rshift, src_lo, amt
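As the comment above notes, the aarch64 rules keep implementing `ushr`/`sshr` as a left shift by a negated amount; the new `and_imm` wraps the amount before negation. A rough per-lane Rust model of that trick (illustrative names only, not the actual lowering code):

fn aarch64_ushl_lane_u8(x: u8, shift: i8) -> u8 {
    // `ushl` shifts left for a non-negative amount and right for a negative one,
    // which is how `ushr` becomes "negate, then ushl" in the rules above.
    if shift >= 0 {
        x.checked_shl(shift as u32).unwrap_or(0)
    } else {
        x.checked_shr(shift.unsigned_abs() as u32).unwrap_or(0)
    }
}

fn lowered_ushr_lane_u8(x: u8, amt: u32) -> u8 {
    let masked = (amt & 7) as i8;      // shift_mask for 8-bit lanes
    aarch64_ushl_lane_u8(x, -masked)   // left shift by a negative amount = right shift
}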

cranelift/codegen/src/isa/aarch64/lower/isle.rs

Lines changed: 3 additions & 1 deletion
@@ -335,7 +335,9 @@ where
     }

     fn shift_mask(&mut self, ty: Type) -> ImmLogic {
-        let mask = (ty.bits() - 1) as u64;
+        debug_assert!(ty.lane_bits().is_power_of_two());
+
+        let mask = (ty.lane_bits() - 1) as u64;
         ImmLogic::maybe_from_u64(mask, I32).unwrap()
     }
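The key change here is that the mask is now derived from the lane width rather than the full vector width, so an `i8x16` shift wraps at 8 bits instead of 128. A standalone sanity-check sketch (not the Cranelift helper itself):

fn lane_shift_mask(lane_bits: u64) -> u64 {
    debug_assert!(lane_bits.is_power_of_two());
    lane_bits - 1 // 8 -> 0b111, 16 -> 0b1111, 32 -> 0b1_1111, 64 -> 0b11_1111
}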

cranelift/codegen/src/isa/x64/inst.isle

Lines changed: 9 additions & 0 deletions
@@ -1147,6 +1147,10 @@
 (decl reg_mem_to_xmm_mem (RegMem) XmmMem)
 (extern constructor reg_mem_to_xmm_mem reg_mem_to_xmm_mem)

+;; Construct a new `RegMemImm` from the given `Reg`.
+(decl reg_to_reg_mem_imm (Reg) RegMemImm)
+(extern constructor reg_to_reg_mem_imm reg_to_reg_mem_imm)
+
 ;; Construct a new `GprMemImm` from the given `RegMemImm`.
 ;;
 ;; Asserts that the `RegMemImm`'s register, if any, is an GPR register.

@@ -1354,6 +1358,10 @@
 (decl const_to_type_masked_imm8 (u64 Type) Imm8Gpr)
 (extern constructor const_to_type_masked_imm8 const_to_type_masked_imm8)

+;; Generate a mask for the bit-width of the given type
+(decl shift_mask (Type) u32)
+(extern constructor shift_mask shift_mask)
+
 ;; Extract a constant `GprMemImm.Imm` from a value operand.
 (decl simm32_from_value (GprMemImm) Value)
 (extern extractor simm32_from_value simm32_from_value)

@@ -3043,6 +3051,7 @@
 (convert Xmm RegMem xmm_to_reg_mem)
 (convert Reg Xmm xmm_new)
 (convert Reg XmmMem reg_to_xmm_mem)
+(convert Reg RegMemImm reg_to_reg_mem_imm)
 (convert RegMem XmmMem reg_mem_to_xmm_mem)
 (convert RegMemImm XmmMemImm mov_rmi_to_xmm)
 (convert Xmm XmmMem xmm_to_xmm_mem)

cranelift/codegen/src/isa/x64/lower.isle

Lines changed: 40 additions & 26 deletions
@@ -531,13 +531,15 @@
 ;; in higher feature sets like AVX), we lower the `ishl.i8x16` to a sequence of
 ;; instructions. The basic idea, whether the amount to shift by is an immediate
 ;; or not, is to use a 16x8 shift and then mask off the incorrect bits to 0s.
-(rule (lower (has_type $I8X16 (ishl src amt)))
+(rule (lower (has_type ty @ $I8X16 (ishl src amt)))
       (let (
+            ;; Mask the amount to ensure wrapping behaviour
+            (masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty))))
             ;; Shift `src` using 16x8. Unfortunately, a 16x8 shift will only be
             ;; correct for half of the lanes; the others must be fixed up with
             ;; the mask below.
-            (unmasked Xmm (x64_psllw src (mov_rmi_to_xmm amt)))
-            (mask_addr SyntheticAmode (ishl_i8x16_mask amt))
+            (unmasked Xmm (x64_psllw src (mov_rmi_to_xmm masked_amt)))
+            (mask_addr SyntheticAmode (ishl_i8x16_mask masked_amt))
             (mask Reg (x64_load $I8X16 mask_addr (ExtKind.None))))
         (sse_and $I8X16 unmasked (RegMem.Reg mask))))

@@ -571,16 +573,19 @@
 (rule (ishl_i8x16_mask (RegMemImm.Mem amt))
       (ishl_i8x16_mask (RegMemImm.Reg (x64_load $I64 amt (ExtKind.None)))))

-;; 16x8, 32x4, and 64x2 shifts can each use a single instruction.
+;; 16x8, 32x4, and 64x2 shifts can each use a single instruction, once the shift amount is masked.

-(rule (lower (has_type $I16X8 (ishl src amt)))
-      (x64_psllw src (mov_rmi_to_xmm amt)))
+(rule (lower (has_type ty @ $I16X8 (ishl src amt)))
+      (let ((masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty)))))
+        (x64_psllw src (mov_rmi_to_xmm masked_amt))))

-(rule (lower (has_type $I32X4 (ishl src amt)))
-      (x64_pslld src (mov_rmi_to_xmm amt)))
+(rule (lower (has_type ty @ $I32X4 (ishl src amt)))
+      (let ((masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty)))))
+        (x64_pslld src (mov_rmi_to_xmm masked_amt))))

-(rule (lower (has_type $I64X2 (ishl src amt)))
-      (x64_psllq src (mov_rmi_to_xmm amt)))
+(rule (lower (has_type ty @ $I64X2 (ishl src amt)))
+      (let ((masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty)))))
+        (x64_psllq src (mov_rmi_to_xmm masked_amt))))

 ;;;; Rules for `ushr` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

@@ -630,13 +635,15 @@
 ;; There are no 8x16 shifts in x64. Do the same 16x8-shift-and-mask thing we do
 ;; with 8x16 `ishl`.
-(rule (lower (has_type $I8X16 (ushr src amt)))
+(rule (lower (has_type ty @ $I8X16 (ushr src amt)))
       (let (
+            ;; Mask the amount to ensure wrapping behaviour
+            (masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty))))
             ;; Shift `src` using 16x8. Unfortunately, a 16x8 shift will only be
             ;; correct for half of the lanes; the others must be fixed up with
             ;; the mask below.
-            (unmasked Xmm (x64_psrlw src (mov_rmi_to_xmm amt)))
-            (mask_addr SyntheticAmode (ushr_i8x16_mask amt))
+            (unmasked Xmm (x64_psrlw src (mov_rmi_to_xmm masked_amt)))
+            (mask_addr SyntheticAmode (ushr_i8x16_mask masked_amt))
             (mask Reg (x64_load $I8X16 mask_addr (ExtKind.None))))
         (sse_and $I8X16
          unmasked

@@ -673,16 +680,19 @@
 (rule (ushr_i8x16_mask (RegMemImm.Mem amt))
       (ushr_i8x16_mask (RegMemImm.Reg (x64_load $I64 amt (ExtKind.None)))))

-;; 16x8, 32x4, and 64x2 shifts can each use a single instruction.
+;; 16x8, 32x4, and 64x2 shifts can each use a single instruction, once the shift amount is masked.

-(rule (lower (has_type $I16X8 (ushr src amt)))
-      (x64_psrlw src (mov_rmi_to_xmm amt)))
+(rule (lower (has_type ty @ $I16X8 (ushr src amt)))
+      (let ((masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty)))))
+        (x64_psrlw src (mov_rmi_to_xmm masked_amt))))

-(rule (lower (has_type $I32X4 (ushr src amt)))
-      (x64_psrld src (mov_rmi_to_xmm amt)))
+(rule (lower (has_type ty @ $I32X4 (ushr src amt)))
+      (let ((masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty)))))
+        (x64_psrld src (mov_rmi_to_xmm masked_amt))))

-(rule (lower (has_type $I64X2 (ushr src amt)))
-      (x64_psrlq src (mov_rmi_to_xmm amt)))
+(rule (lower (has_type ty @ $I64X2 (ushr src amt)))
+      (let ((masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty)))))
+        (x64_psrlq src (mov_rmi_to_xmm masked_amt))))

 ;;;; Rules for `sshr` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

@@ -746,14 +756,16 @@
 ; hi.i16x8 = [(s8, s8), (s9, s9), ..., (s15, s15)]
 ; shifted_hi.i16x8 = shift each lane of `high`
 ; result = [s0'', s1'', ..., s15'']
-(rule (lower (has_type $I8X16 (sshr src amt @ (value_type amt_ty))))
+(rule (lower (has_type ty @ $I8X16 (sshr src amt @ (value_type amt_ty))))
       (let ((src_ Xmm (put_in_xmm src))
+            ;; Mask the amount to ensure wrapping behaviour
+            (masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty))))
             ;; In order for `packsswb` later to only use the high byte of each
             ;; 16x8 lane, we shift right an extra 8 bits, relying on `psraw` to
             ;; fill in the upper bits appropriately.
             (lo Xmm (x64_punpcklbw src_ src_))
             (hi Xmm (x64_punpckhbw src_ src_))
-            (amt_ XmmMemImm (sshr_i8x16_bigger_shift amt_ty amt))
+            (amt_ XmmMemImm (sshr_i8x16_bigger_shift amt_ty masked_amt))
             (shifted_lo Xmm (x64_psraw lo amt_))
             (shifted_hi Xmm (x64_psraw hi amt_)))
         (x64_packsswb shifted_lo shifted_hi)))

@@ -773,11 +785,13 @@
 ;; `sshr.{i16x8,i32x4}` can be a simple `psra{w,d}`, we just have to make sure
 ;; that if the shift amount is in a register, it is in an XMM register.

-(rule (lower (has_type $I16X8 (sshr src amt)))
-      (x64_psraw src (mov_rmi_to_xmm amt)))
+(rule (lower (has_type ty @ $I16X8 (sshr src amt)))
+      (let ((masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty)))))
+        (x64_psraw src (mov_rmi_to_xmm masked_amt))))

-(rule (lower (has_type $I32X4 (sshr src amt)))
-      (x64_psrad src (mov_rmi_to_xmm amt)))
+(rule (lower (has_type ty @ $I32X4 (sshr src amt)))
+      (let ((masked_amt Reg (x64_and $I64 amt (RegMemImm.Imm (shift_mask ty)))))
+        (x64_psrad src (mov_rmi_to_xmm masked_amt))))

 ;; The `sshr.i64x2` CLIF instruction has no single x86 instruction in the older
 ;; feature sets. Newer ones like AVX512VL + AVX512F include `vpsraq`, a 128-bit
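Since x64 has no 8x16 shift, the `i8x16` rules above shift as 16x8 lanes and then clear the bits that bled in from the neighbouring byte (the `ishl_i8x16_mask`/`ushr_i8x16_mask` constants). A rough two-lane scalar model of the `ishl` path, purely illustrative:

fn ishl_i8x2_via_16x8(pair: u16, amt: u32) -> u16 {
    let masked_amt = amt & 7;                       // wrap the amount to the 8-bit lane width
    let shifted = pair << masked_amt;               // 16-bit shift: the low byte spills into the high byte
    let byte_mask = (0xffu16 << masked_amt) & 0xff; // per-byte mask that zeroes the spilled-in bits
    shifted & (byte_mask | (byte_mask << 8))
}

For instance, with pair = 0x01ff and amt = 1 the plain 16-bit shift gives 0x03fe, and the mask restores the per-lane result 0x02fe.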

cranelift/codegen/src/isa/x64/lower/isle.rs

Lines changed: 10 additions & 0 deletions
@@ -229,6 +229,11 @@ where
             .unwrap()
     }

+    #[inline]
+    fn shift_mask(&mut self, ty: Type) -> u32 {
+        ty.lane_bits() - 1
+    }
+
     #[inline]
     fn simm32_from_value(&mut self, val: Value) -> Option<GprMemImm> {
         let inst = self.lower_ctx.dfg().value_def(val).inst()?;

@@ -415,6 +420,11 @@ where
         Writable::from_reg(Xmm::new(self.temp_writable_reg(I8X16).to_reg()).unwrap())
     }

+    #[inline]
+    fn reg_to_reg_mem_imm(&mut self, reg: Reg) -> RegMemImm {
+        RegMemImm::Reg { reg }
+    }
+
     #[inline]
     fn reg_mem_to_xmm_mem(&mut self, rm: &RegMem) -> XmmMem {
         XmmMem::new(rm.clone()).unwrap()

cranelift/filetests/filetests/isa/aarch64/arithmetic.clif

Lines changed: 4 additions & 4 deletions
@@ -344,9 +344,10 @@ block0(v0: i8x16):

 ; block0:
 ; movz x3, #1
-; sub w5, wzr, w3
-; dup v7.16b, w5
-; ushl v0.16b, v0.16b, v7.16b
+; and w5, w3, #7
+; sub x7, xzr, x5
+; dup v17.16b, w7
+; ushl v0.16b, v0.16b, v17.16b
 ; ret

 function %add_i128(i128, i128) -> i128 {

@@ -492,4 +493,3 @@ block0(v0: i64):
 ; b.vc 8 ; udf
 ; sdiv x0, x0, x3
 ; ret
-

cranelift/filetests/filetests/isa/x64/simd-bitwise-compile.clif

Lines changed: 35 additions & 26 deletions
@@ -206,12 +206,13 @@ block0(v0: i32):
 ; movq %rsp, %rbp
 ; block0:
 ; load_const VCodeConstant(1), %xmm0
-; movd %edi, %xmm5
-; psllw %xmm0, %xmm5, %xmm0
-; lea const(VCodeConstant(0)), %rsi
+; andq %rdi, $7, %rdi
+; movd %edi, %xmm7
+; psllw %xmm0, %xmm7, %xmm0
+; lea const(VCodeConstant(0)), %rax
 ; shlq $4, %rdi, %rdi
-; movdqu 0(%rsi,%rdi,1), %xmm13
-; pand %xmm0, %xmm13, %xmm0
+; movdqu 0(%rax,%rdi,1), %xmm15
+; pand %xmm0, %xmm15, %xmm0
 ; movq %rbp, %rsp
 ; popq %rbp
 ; ret

@@ -228,9 +229,14 @@ block0:
 ; movq %rsp, %rbp
 ; block0:
 ; load_const VCodeConstant(1), %xmm0
-; psrlw %xmm0, $1, %xmm0
-; movdqu const(VCodeConstant(0)), %xmm5
-; pand %xmm0, %xmm5, %xmm0
+; movl $1, %r11d
+; andq %r11, $7, %r11
+; movd %r11d, %xmm7
+; psrlw %xmm0, %xmm7, %xmm0
+; lea const(VCodeConstant(0)), %rax
+; shlq $4, %r11, %r11
+; movdqu 0(%rax,%r11,1), %xmm15
+; pand %xmm0, %xmm15, %xmm0
 ; movq %rbp, %rsp
 ; popq %rbp
 ; ret

@@ -245,15 +251,16 @@ block0(v0: i32):
 ; pushq %rbp
 ; movq %rsp, %rbp
 ; block0:
-; load_const VCodeConstant(0), %xmm9
-; movdqa %xmm9, %xmm0
-; punpcklbw %xmm0, %xmm9, %xmm0
-; punpckhbw %xmm9, %xmm9, %xmm9
+; load_const VCodeConstant(0), %xmm10
+; andq %rdi, $7, %rdi
+; movdqa %xmm10, %xmm0
+; punpcklbw %xmm0, %xmm10, %xmm0
+; punpckhbw %xmm10, %xmm10, %xmm10
 ; addl %edi, $8, %edi
-; movd %edi, %xmm11
-; psraw %xmm0, %xmm11, %xmm0
-; psraw %xmm9, %xmm11, %xmm9
-; packsswb %xmm0, %xmm9, %xmm0
+; movd %edi, %xmm13
+; psraw %xmm0, %xmm13, %xmm0
+; psraw %xmm10, %xmm13, %xmm10
+; packsswb %xmm0, %xmm10, %xmm0
 ; movq %rbp, %rsp
 ; popq %rbp
 ; ret

@@ -267,17 +274,19 @@ block0(v0: i8x16, v1: i32):
 ; pushq %rbp
 ; movq %rsp, %rbp
 ; block0:
-; movdqa %xmm0, %xmm9
-; punpcklbw %xmm9, %xmm0, %xmm9
+; movl $3, %esi
+; andq %rsi, $7, %rsi
+; movdqa %xmm0, %xmm15
+; punpcklbw %xmm15, %xmm0, %xmm15
+; movdqa %xmm15, %xmm13
 ; punpckhbw %xmm0, %xmm0, %xmm0
-; movdqa %xmm9, %xmm12
-; psraw %xmm12, $11, %xmm12
-; movdqa %xmm12, %xmm9
-; psraw %xmm0, $11, %xmm0
-; movdqa %xmm9, %xmm1
-; packsswb %xmm1, %xmm0, %xmm1
-; movdqa %xmm1, %xmm9
-; movdqa %xmm9, %xmm0
+; movdqa %xmm0, %xmm7
+; addl %esi, $8, %esi
+; movd %esi, %xmm15
+; movdqa %xmm13, %xmm0
+; psraw %xmm0, %xmm15, %xmm0
+; psraw %xmm7, %xmm15, %xmm7
+; packsswb %xmm0, %xmm7, %xmm0
 ; movq %rbp, %rsp
 ; popq %rbp
 ; ret

cranelift/filetests/filetests/runtests/simd-bitselect.clif

Lines changed: 30 additions & 0 deletions
@@ -13,3 +13,33 @@ block0(v0: i32x4, v1: i32x4, v2: i32x4):
 ; run: %bitselect_i32x4(0x11111111111111111111111111111111, 0x11111111111111111111111111111111, 0x00000000000000000000000000000000) == 0x11111111111111111111111111111111
 ; run: %bitselect_i32x4(0x01010011000011110000000011111111, 0x11111111111111111111111111111111, 0x00000000000000000000000000000000) == 0x01010011000011110000000011111111
 ; run: %bitselect_i32x4(0x00000000000000001111111111111111, 0x00000000000000000000000000000000, 0x11111111111111111111111111111111) == 0x11111111111111110000000000000000
+
+function %bitselect_i8x16(i8x16, i8x16, i8x16) -> i8x16 {
+block0(v0: i8x16, v1: i8x16, v2: i8x16):
+    v3 = bitselect v0, v1, v2
+    return v3
+}
+; Remember that bitselect accepts: 1) the selector vector, 2) the "if true" vector, and 3) the "if false" vector.
+; run: %bitselect_i8x16([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 255], [127 0 0 0 0 0 0 0 0 0 0 0 0 0 0 42], [42 0 0 0 0 0 0 0 0 0 0 0 0 0 0 127]) == [42 0 0 0 0 0 0 0 0 0 0 0 0 0 0 42]
+
+function %bitselect_i8x16() -> b1 {
+block0:
+    v0 = vconst.i8x16 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 255] ; the selector vector
+    v1 = vconst.i8x16 [127 0 0 0 0 0 0 0 0 0 0 0 0 0 0 42] ; for each 1-bit in v0 the bit of v1 is selected
+    v2 = vconst.i8x16 [42 0 0 0 0 0 0 0 0 0 0 0 0 0 0 127] ; for each 0-bit in v0 the bit of v2 is selected
+    v3 = bitselect v0, v1, v2
+
+    v4 = extractlane v3, 0
+    v5 = icmp_imm eq v4, 42
+
+    v6 = extractlane v3, 1
+    v7 = icmp_imm eq v6, 0
+
+    v8 = extractlane v3, 15
+    v9 = icmp_imm eq v8, 42
+
+    v10 = band v5, v7
+    v11 = band v10, v9
+    return v11
+}
+; run
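For reference (an illustration only, not part of the patch), `bitselect c, t, f` picks each result bit from `t` where the selector bit is 1 and from `f` where it is 0, which is what the new runtests check lane by lane:

fn bitselect_byte(c: u8, t: u8, f: u8) -> u8 {
    // For each 1-bit in the selector take the bit from `t`, otherwise from `f`.
    (c & t) | (!c & f)
}

With the constants above, lane 15 (selector 255) takes 42 from v1 and lane 0 (selector 0) takes 42 from v2, matching the extractlane/icmp_imm assertions.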
