Skip to content

Commit a33f1a6

Browse files
committed
Auto merge of #2031 - RalfJung:simd, r=RalfJung
implement SIMD sqrt and fma. Cc #1912
2 parents a9a0d0e + 4fd5dca commit a33f1a6

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

src/shims/intrinsics.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
329329
| "simd_ceil"
330330
| "simd_floor"
331331
| "simd_round"
332-
| "simd_trunc" => {
332+
| "simd_trunc"
333+
| "simd_fsqrt" => {
333334
let &[ref op] = check_arg_count(args)?;
334335
let (op, op_len) = this.operand_to_simd(op)?;
335336
let (dest, dest_len) = this.place_to_simd(dest)?;
@@ -342,6 +343,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
342343
Floor,
343344
Round,
344345
Trunc,
346+
Sqrt,
345347
}
346348
#[derive(Copy, Clone)]
347349
enum Op {
@@ -356,6 +358,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
356358
"simd_floor" => Op::HostOp(HostFloatOp::Floor),
357359
"simd_round" => Op::HostOp(HostFloatOp::Round),
358360
"simd_trunc" => Op::HostOp(HostFloatOp::Trunc),
361+
"simd_fsqrt" => Op::HostOp(HostFloatOp::Sqrt),
359362
_ => unreachable!(),
360363
};
361364

@@ -388,6 +391,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
388391
HostFloatOp::Floor => f.floor(),
389392
HostFloatOp::Round => f.round(),
390393
HostFloatOp::Trunc => f.trunc(),
394+
HostFloatOp::Sqrt => f.sqrt(),
391395
};
392396
Scalar::from_u32(res.to_bits())
393397
}
@@ -398,6 +402,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
398402
HostFloatOp::Floor => f.floor(),
399403
HostFloatOp::Round => f.round(),
400404
HostFloatOp::Trunc => f.trunc(),
405+
HostFloatOp::Sqrt => f.sqrt(),
401406
};
402407
Scalar::from_u64(res.to_bits())
403408
}
@@ -508,6 +513,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
508513
this.write_scalar(val, &dest.into())?;
509514
}
510515
}
516+
"simd_fma" => {
517+
let &[ref a, ref b, ref c] = check_arg_count(args)?;
518+
let (a, a_len) = this.operand_to_simd(a)?;
519+
let (b, b_len) = this.operand_to_simd(b)?;
520+
let (c, c_len) = this.operand_to_simd(c)?;
521+
let (dest, dest_len) = this.place_to_simd(dest)?;
522+
523+
assert_eq!(dest_len, a_len);
524+
assert_eq!(dest_len, b_len);
525+
assert_eq!(dest_len, c_len);
526+
527+
for i in 0..dest_len {
528+
let a = this.read_immediate(&this.mplace_index(&a, i)?.into())?.to_scalar()?;
529+
let b = this.read_immediate(&this.mplace_index(&b, i)?.into())?.to_scalar()?;
530+
let c = this.read_immediate(&this.mplace_index(&c, i)?.into())?.to_scalar()?;
531+
let dest = this.mplace_index(&dest, i)?;
532+
533+
// Works for f32 and f64.
534+
let ty::Float(float_ty) = dest.layout.ty.kind() else {
535+
bug!("{} operand is not a float", intrinsic_name)
536+
};
537+
let val = match float_ty {
538+
FloatTy::F32 =>
539+
Scalar::from_f32(a.to_f32()?.mul_add(b.to_f32()?, c.to_f32()?).value),
540+
FloatTy::F64 =>
541+
Scalar::from_f64(a.to_f64()?.mul_add(b.to_f64()?, c.to_f64()?).value),
542+
};
543+
this.write_scalar(val, &dest.into())?;
544+
}
545+
}
511546
#[rustfmt::skip]
512547
| "simd_reduce_and"
513548
| "simd_reduce_or"

tests/run-pass/portable-simd.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ fn simd_ops_f32() {
1515
assert_eq!(a.max(b * f32x4::splat(4.0)), f32x4::from_array([10.0, 10.0, 12.0, 10.0]));
1616
assert_eq!(a.min(b * f32x4::splat(4.0)), f32x4::from_array([4.0, 8.0, 10.0, -16.0]));
1717

18+
assert_eq!(a.mul_add(b, a), (a*b)+a);
19+
assert_eq!(b.mul_add(b, a), (b*b)+a);
20+
assert_eq!((a*a).sqrt(), a);
21+
assert_eq!((b*b).sqrt(), b.abs());
22+
1823
assert_eq!(a.lanes_eq(f32x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
1924
assert_eq!(a.lanes_ne(f32x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
2025
assert_eq!(a.lanes_le(f32x4::splat(5.0) * b), Mask::from_array([false, true, true, false]));
@@ -59,6 +64,11 @@ fn simd_ops_f64() {
5964
assert_eq!(a.max(b * f64x4::splat(4.0)), f64x4::from_array([10.0, 10.0, 12.0, 10.0]));
6065
assert_eq!(a.min(b * f64x4::splat(4.0)), f64x4::from_array([4.0, 8.0, 10.0, -16.0]));
6166

67+
assert_eq!(a.mul_add(b, a), (a*b)+a);
68+
assert_eq!(b.mul_add(b, a), (b*b)+a);
69+
assert_eq!((a*a).sqrt(), a);
70+
assert_eq!((b*b).sqrt(), b.abs());
71+
6272
assert_eq!(a.lanes_eq(f64x4::splat(5.0) * b), Mask::from_array([false, true, false, false]));
6373
assert_eq!(a.lanes_ne(f64x4::splat(5.0) * b), Mask::from_array([true, false, true, true]));
6474
assert_eq!(a.lanes_le(f64x4::splat(5.0) * b), Mask::from_array([false, true, true, false]));

0 commit comments

Comments
 (0)