Skip to content

Commit 2366a90

Browse files
committed
Auto merge of #3124 - eduardosm:fix-sse41-round, r=RalfJung
Fix rounding mode check in SSE4.1 round functions. Now it masks out the correct bit and adds some explanatory comments. Also extends the tests.
2 parents f408492 + a8aa303 commit 2366a90

File tree

2 files changed

+64
-5
lines changed

2 files changed

+64
-5
lines changed

src/tools/miri/src/shims/x86/sse41.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,20 @@ fn round_first<'tcx, F: rustc_apfloat::Float>(
283283
assert_eq!(dest_len, left_len);
284284
assert_eq!(dest_len, right_len);
285285

286-
let rounding = match this.read_scalar(rounding)?.to_i32()? & !0x80 {
287-
0x00 => rustc_apfloat::Round::NearestTiesToEven,
288-
0x01 => rustc_apfloat::Round::TowardNegative,
289-
0x02 => rustc_apfloat::Round::TowardPositive,
290-
0x03 => rustc_apfloat::Round::TowardZero,
286+
// The fourth bit of `rounding` only affects the SSE status
287+
// register, which cannot be accessed from Miri (or from Rust,
288+
// for that matter), so we can ignore it.
289+
let rounding = match this.read_scalar(rounding)?.to_i32()? & !0b1000 {
290+
// When the third bit is 0, the rounding mode is determined by the
291+
// first two bits.
292+
0b000 => rustc_apfloat::Round::NearestTiesToEven,
293+
0b001 => rustc_apfloat::Round::TowardNegative,
294+
0b010 => rustc_apfloat::Round::TowardPositive,
295+
0b011 => rustc_apfloat::Round::TowardZero,
296+
// When the third bit is 1, the rounding mode is determined by the
297+
// SSE status register. Since we do not support modifying it from
298+
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
299+
0b100..=0b111 => rustc_apfloat::Round::NearestTiesToEven,
291300
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
292301
};
293302

src/tools/miri/tests/pass/intrinsics-x86-sse41.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,31 @@ unsafe fn test_sse41() {
119119
let r = _mm_round_sd::<_MM_FROUND_TO_NEAREST_INT>(a, b);
120120
let e = _mm_setr_pd(-2.0, 3.5);
121121
assert_eq_m128d(r, e);
122+
123+
let a = _mm_setr_pd(1.5, 3.5);
124+
let b = _mm_setr_pd(-2.5, -4.5);
125+
let r = _mm_round_sd::<_MM_FROUND_TO_NEG_INF>(a, b);
126+
let e = _mm_setr_pd(-3.0, 3.5);
127+
assert_eq_m128d(r, e);
128+
129+
let a = _mm_setr_pd(1.5, 3.5);
130+
let b = _mm_setr_pd(-2.5, -4.5);
131+
let r = _mm_round_sd::<_MM_FROUND_TO_POS_INF>(a, b);
132+
let e = _mm_setr_pd(-2.0, 3.5);
133+
assert_eq_m128d(r, e);
134+
135+
let a = _mm_setr_pd(1.5, 3.5);
136+
let b = _mm_setr_pd(-2.5, -4.5);
137+
let r = _mm_round_sd::<_MM_FROUND_TO_ZERO>(a, b);
138+
let e = _mm_setr_pd(-2.0, 3.5);
139+
assert_eq_m128d(r, e);
140+
141+
// Assume round-to-nearest by default
142+
let a = _mm_setr_pd(1.5, 3.5);
143+
let b = _mm_setr_pd(-2.5, -4.5);
144+
let r = _mm_round_sd::<_MM_FROUND_CUR_DIRECTION>(a, b);
145+
let e = _mm_setr_pd(-2.0, 3.5);
146+
assert_eq_m128d(r, e);
122147
}
123148
test_mm_round_sd();
124149

@@ -129,6 +154,31 @@ unsafe fn test_sse41() {
129154
let r = _mm_round_ss::<_MM_FROUND_TO_NEAREST_INT>(a, b);
130155
let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5);
131156
assert_eq_m128(r, e);
157+
158+
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
159+
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
160+
let r = _mm_round_ss::<_MM_FROUND_TO_NEG_INF>(a, b);
161+
let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5);
162+
assert_eq_m128(r, e);
163+
164+
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
165+
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
166+
let r = _mm_round_ss::<_MM_FROUND_TO_POS_INF>(a, b);
167+
let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5);
168+
assert_eq_m128(r, e);
169+
170+
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
171+
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
172+
let r = _mm_round_ss::<_MM_FROUND_TO_ZERO>(a, b);
173+
let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5);
174+
assert_eq_m128(r, e);
175+
176+
// Assume round-to-nearest by default
177+
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
178+
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
179+
let r = _mm_round_ss::<_MM_FROUND_CUR_DIRECTION>(a, b);
180+
let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5);
181+
assert_eq_m128(r, e);
132182
}
133183
test_mm_round_ss();
134184

0 commit comments

Comments
 (0)