Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 01350a2

Browse files
Merge portable-simd#203 - deantvv/add-spectral-norm
Add spectral_norm example from packed_simd
2 parents 03f6fbb + 861a6e8 commit 01350a2

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,77 @@
1+
#![feature(portable_simd)]
2+
3+
use core_simd::simd::*;
4+
5+
/// Returns `(i + j) * (i + j + 1) / 2 + i + 1` converted to `f64`.
///
/// `mult_av` and `mult_atv` divide by this value, so the matrix they apply has
/// entry `1 / a(i, j)` at row `i`, column `j`.
///
/// The division by two is exact: `i + j` and `i + j + 1` are consecutive
/// integers, so their product is always even.
fn a(i: usize, j: usize) -> f64 {
    let diag = i + j;
    (diag * (diag + 1) / 2 + i + 1) as f64
}
8+
9+
fn mult_av(v: &[f64], out: &mut [f64]) {
10+
assert!(v.len() == out.len());
11+
assert!(v.len() % 2 == 0);
12+
13+
for (i, out) in out.iter_mut().enumerate() {
14+
let mut sum = f64x2::splat(0.0);
15+
16+
let mut j = 0;
17+
while j < v.len() {
18+
let b = f64x2::from_slice(&v[j..]);
19+
let a = f64x2::from_array([a(i, j), a(i, j + 1)]);
20+
sum += b / a;
21+
j += 2
22+
}
23+
*out = sum.horizontal_sum();
24+
}
25+
}
26+
27+
fn mult_atv(v: &[f64], out: &mut [f64]) {
28+
assert!(v.len() == out.len());
29+
assert!(v.len() % 2 == 0);
30+
31+
for (i, out) in out.iter_mut().enumerate() {
32+
let mut sum = f64x2::splat(0.0);
33+
34+
let mut j = 0;
35+
while j < v.len() {
36+
let b = f64x2::from_slice(&v[j..]);
37+
let a = f64x2::from_array([a(j, i), a(j + 1, i)]);
38+
sum += b / a;
39+
j += 2
40+
}
41+
*out = sum.horizontal_sum();
42+
}
43+
}
44+
45+
/// Applies `mult_av` and then `mult_atv`: `tmp` receives the result of
/// `mult_av(v, ..)`, and `out` receives the result of `mult_atv(tmp, ..)`.
///
/// `tmp` is scratch space supplied by the caller; its previous contents are
/// overwritten.
///
/// # Panics
///
/// Panics if the three slices do not all have the same even length (checked by
/// the assertions inside `mult_av` and `mult_atv`).
fn mult_atav(v: &[f64], out: &mut [f64], tmp: &mut [f64]) {
    mult_av(v, tmp);
    mult_atv(tmp, out);
}
49+
50+
pub fn spectral_norm(n: usize) -> f64 {
51+
assert!(n % 2 == 0, "only even lengths are accepted");
52+
53+
let mut u = vec![1.0; n];
54+
let mut v = u.clone();
55+
let mut tmp = u.clone();
56+
57+
for _ in 0..10 {
58+
mult_atav(&u, &mut v, &mut tmp);
59+
mult_atav(&v, &mut u, &mut tmp);
60+
}
61+
(dot(&u, &v) / dot(&v, &v)).sqrt()
62+
}
63+
64+
/// Returns the sum of the element-wise products of `x` and `y`.
///
/// The slices are walked with `zip`, so if their lengths differ the extra
/// elements of the longer one are ignored.
fn dot(x: &[f64], y: &[f64]) -> f64 {
    // Kept as a plain iterator chain rather than explicit SIMD; the original
    // author notes that this form is auto-vectorized.
    x.iter().zip(y).map(|(&x, &y)| x * y).sum()
}
68+
69+
// Regression check: the result for n = 100, rounded to nine decimal places,
// must match the expected string exactly.
#[cfg(test)]
#[test]
fn test() {
    assert_eq!(&format!("{:.9}", spectral_norm(100)), "1.274219991");
}
74+
75+
/// Intentionally empty entry point; it exists only so that cargo can build
/// this example as a binary.
fn main() {
    // Empty main to make cargo happy
}

0 commit comments

Comments
 (0)