Skip to content

Commit a8aa303

Browse files
committed
Fix rounding mode check in SSE4.1 round functions
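The rounding-mode check now masks out the correct bit (the fourth bit, `0b1000`, instead of `0x80`) and gains explanatory comments. The tests are also extended to cover the remaining rounding modes (toward negative infinity, toward positive infinity, toward zero, and the current-direction default).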
1 parent c1dbc19 commit a8aa303

File tree

2 files changed

+64
-5
lines changed

2 files changed

+64
-5
lines changed

src/tools/miri/src/shims/x86/sse41.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,20 @@ fn round_first<'tcx, F: rustc_apfloat::Float>(
283283
assert_eq!(dest_len, left_len);
284284
assert_eq!(dest_len, right_len);
285285

286-
let rounding = match this.read_scalar(rounding)?.to_i32()? & !0x80 {
287-
0x00 => rustc_apfloat::Round::NearestTiesToEven,
288-
0x01 => rustc_apfloat::Round::TowardNegative,
289-
0x02 => rustc_apfloat::Round::TowardPositive,
290-
0x03 => rustc_apfloat::Round::TowardZero,
286+
// The fourth bit of `rounding` only affects the SSE status
287+
// register, which cannot be accessed from Miri (or from Rust,
288+
// for that matter), so we can ignore it.
289+
let rounding = match this.read_scalar(rounding)?.to_i32()? & !0b1000 {
290+
// When the third bit is 0, the rounding mode is determined by the
291+
// first two bits.
292+
0b000 => rustc_apfloat::Round::NearestTiesToEven,
293+
0b001 => rustc_apfloat::Round::TowardNegative,
294+
0b010 => rustc_apfloat::Round::TowardPositive,
295+
0b011 => rustc_apfloat::Round::TowardZero,
296+
// When the third bit is 1, the rounding mode is determined by the
297+
// SSE status register. Since we do not support modifying it from
298+
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
299+
0b100..=0b111 => rustc_apfloat::Round::NearestTiesToEven,
291300
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
292301
};
293302

src/tools/miri/tests/pass/intrinsics-x86-sse41.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,31 @@ unsafe fn test_sse41() {
119119
let r = _mm_round_sd::<_MM_FROUND_TO_NEAREST_INT>(a, b);
120120
let e = _mm_setr_pd(-2.0, 3.5);
121121
assert_eq_m128d(r, e);
122+
123+
let a = _mm_setr_pd(1.5, 3.5);
124+
let b = _mm_setr_pd(-2.5, -4.5);
125+
let r = _mm_round_sd::<_MM_FROUND_TO_NEG_INF>(a, b);
126+
let e = _mm_setr_pd(-3.0, 3.5);
127+
assert_eq_m128d(r, e);
128+
129+
let a = _mm_setr_pd(1.5, 3.5);
130+
let b = _mm_setr_pd(-2.5, -4.5);
131+
let r = _mm_round_sd::<_MM_FROUND_TO_POS_INF>(a, b);
132+
let e = _mm_setr_pd(-2.0, 3.5);
133+
assert_eq_m128d(r, e);
134+
135+
let a = _mm_setr_pd(1.5, 3.5);
136+
let b = _mm_setr_pd(-2.5, -4.5);
137+
let r = _mm_round_sd::<_MM_FROUND_TO_ZERO>(a, b);
138+
let e = _mm_setr_pd(-2.0, 3.5);
139+
assert_eq_m128d(r, e);
140+
141+
// Assume round-to-nearest by default
142+
let a = _mm_setr_pd(1.5, 3.5);
143+
let b = _mm_setr_pd(-2.5, -4.5);
144+
let r = _mm_round_sd::<_MM_FROUND_CUR_DIRECTION>(a, b);
145+
let e = _mm_setr_pd(-2.0, 3.5);
146+
assert_eq_m128d(r, e);
122147
}
123148
test_mm_round_sd();
124149

@@ -129,6 +154,31 @@ unsafe fn test_sse41() {
129154
let r = _mm_round_ss::<_MM_FROUND_TO_NEAREST_INT>(a, b);
130155
let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5);
131156
assert_eq_m128(r, e);
157+
158+
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
159+
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
160+
let r = _mm_round_ss::<_MM_FROUND_TO_NEG_INF>(a, b);
161+
let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5);
162+
assert_eq_m128(r, e);
163+
164+
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
165+
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
166+
let r = _mm_round_ss::<_MM_FROUND_TO_POS_INF>(a, b);
167+
let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5);
168+
assert_eq_m128(r, e);
169+
170+
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
171+
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
172+
let r = _mm_round_ss::<_MM_FROUND_TO_ZERO>(a, b);
173+
let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5);
174+
assert_eq_m128(r, e);
175+
176+
// Assume round-to-nearest by default
177+
let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
178+
let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
179+
let r = _mm_round_ss::<_MM_FROUND_CUR_DIRECTION>(a, b);
180+
let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5);
181+
assert_eq_m128(r, e);
132182
}
133183
test_mm_round_ss();
134184

0 commit comments

Comments
 (0)