Skip to content

Commit 3ee092f

Browse files
Flamefire and pytorchmergebot
authored and committed
VSX: Fix overflow in complex division (pytorch#116972)
For large complex values the division produces inf or NaN values, which leads other functions to produce such values too, e.g. `torch._refs.sgn` used in a test.

Example:
```
$ python -c 'import torch; print(torch._refs.sgn(torch.complex(torch.tensor([-501]*16, dtype=torch.float32), torch.tensor([-1e20]*16, dtype=torch.float32))))'
tensor([-0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj])
$ python -c 'import torch; t = torch.complex(torch.tensor([-501]*16, dtype=torch.float32), torch.tensor([-1e20]*16, dtype=torch.float32)); print(t / t.abs())'
tensor([-0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj])
```

Implement the same algorithm as used in numpy and x86 (pytorch#93277).

The reason is that for a tensor with a component of `1e20`, the abs-squared value used in the division contains a term `1e20 * 1e20`, which overflows the dynamic range of float32 (roughly 3.4e38) and yields an "inf", so the division yields "nan".

Output after change:
```
$ python -c 'import torch; t = torch.complex(torch.tensor([-501]*16, dtype=torch.float32), torch.tensor([-1e20]*16, dtype=torch.float32)); print(torch._refs.sgn(t), t.sgn(), t / t.abs())'
tensor([-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j])
tensor([-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j])
tensor([-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j])
```

CC @quickwritereader, who wrote the initial code, @VitalyFedyunin, who was involved in the initial review, and @lezcano, who reviewed pytorch#93277.

Pull Request resolved: pytorch#116972
Approved by: https://github.com/lezcano
1 parent afabed6 commit 3ee092f

File tree

3 files changed

+59
-58
lines changed

3 files changed

+59
-58
lines changed

aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h

Lines changed: 29 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -212,12 +212,19 @@ class Vectorized<ComplexDbl> {
212212
static Vectorized<ComplexDbl> el_mergee(
213213
Vectorized<ComplexDbl>& first,
214214
Vectorized<ComplexDbl>& second) {
215-
// as mergee phased in , we can use vec_perm with mask
216215
return {
217216
vec_mergeh(first._vec0, second._vec0),
218217
vec_mergeh(first._vec1, second._vec1)};
219218
}
220219

220+
static Vectorized<ComplexDbl> el_mergeo(
221+
Vectorized<ComplexDbl>& first,
222+
Vectorized<ComplexDbl>& second) {
223+
return {
224+
vec_mergel(first._vec0, second._vec0),
225+
vec_mergel(first._vec1, second._vec1)};
226+
}
227+
221228
Vectorized<ComplexDbl> abs_2_() const {
222229
auto a = (*this).elwise_mult(*this);
223230
auto permuted = a.el_swapped();
@@ -385,13 +392,11 @@ class Vectorized<ComplexDbl> {
385392
static Vectorized<ComplexDbl> horizontal_add(
386393
Vectorized<ComplexDbl>& first,
387394
Vectorized<ComplexDbl>& second) {
388-
auto first_perm = first.el_swapped(); // 2perm
389-
auto second_perm = second.el_swapped(); // 2perm
390-
// summ
391-
auto first_ret = first + first_perm; // 2add
392-
auto second_ret = second + second_perm; // 2 add
393-
// now lets choose evens
394-
return el_mergee(first_ret, second_ret); // 2 mergee's
395+
// Operates on individual floats, see _mm_hadd_ps
396+
// {f0+f1, s0+s1, f2+f3, s2+s3, ...}
397+
// i.e. it sums the re and im of each value and interleaves first and second:
398+
// {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...}
399+
return el_mergee(first, second) + el_mergeo(first, second);
395400
}
396401

397402
static Vectorized<ComplexDbl> horizontal_sub(
@@ -432,25 +437,20 @@ class Vectorized<ComplexDbl> {
432437
// re + im*i = (a + bi) / (c + di)
433438
// re = (ac + bd)/abs_2()
434439
// im = (bc - ad)/abs_2()
435-
#if 1
436-
auto vi = b.el_mergeo();
437-
auto vr = b.el_mergee();
438-
auto abs_b = b.abs_2_();
439-
vi = vi ^ vd_isign_mask;
440-
auto ret = elwise_mult(vr);
441-
auto vx_swapped = el_swapped();
442-
ret = vx_swapped.el_madd(vi, ret);
443-
ret = ret.elwise_div(abs_b);
444-
#else
445-
// Vectorized x86 simulation
446-
auto ac_bd = elwise_mult(b);
447-
auto d_c = b.el_swapped();
448-
d_c = d_c ^ vd_rsign_mask;
449-
auto ad_bc = elwise_mult(d_c);
450-
auto abs_b = b.abs_2_();
451-
auto re_im = horizontal_add(ac_bd, ad_bc);
452-
auto ret = re_im.elwise_div(abs_b);
453-
#endif
440+
auto fabs_cd = Vectorized{
441+
vec_andc(b._vec0, vd_sign_mask),
442+
vec_andc(b._vec1, vd_sign_mask)}; // |c| |d|
443+
auto fabs_dc = fabs_cd.el_swapped(); // |d| |c|
444+
auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|)
445+
auto a2 = elwise_div(scale); // a/sc b/sc
446+
auto b2 = b.elwise_div(scale); // c/sc d/sc
447+
auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/sc^2
448+
auto dc2 = b2.el_swapped(); // d/sc c/sc
449+
dc2 = dc2 ^ vd_rsign_mask; // -d/sc c/sc
450+
auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2
451+
auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2
452+
auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
453+
ret = ret.elwise_div(denom2);
454454
return ret;
455455
}
456456

@@ -511,13 +511,14 @@ class Vectorized<ComplexDbl> {
511511
DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and)
512512
DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or)
513513
DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor)
514-
// elelemtwise helpers
514+
// elementwise helpers
515515
DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul)
516516
DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div)
517517
DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt)
518518
DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge)
519519
DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt)
520520
DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple)
521+
DEFINE_MEMBER_OP(elwise_max, ComplexDbl, vec_max)
521522
};
522523

523524
template <>

aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,14 @@ class Vectorized<ComplexFlt> {
238238
return loadu(tmp);
239239
}
240240

241-
static Vectorized<ComplexFlt> horizontal_add_permD8(
241+
static Vectorized<ComplexFlt> horizontal_add(
242242
Vectorized<ComplexFlt>& first,
243243
Vectorized<ComplexFlt>& second) {
244-
// we will simulate it differently with 6 instructions total
245-
// lets permute second so that we can add it getting horizontal sums
246-
auto first_perm = first.el_swapped(); // 2perm
247-
auto second_perm = second.el_swapped(); // 2perm
248-
// sum
249-
auto first_ret = first + first_perm; // 2add
250-
auto second_ret = second + second_perm; // 2 add
251-
// now lets choose evens
252-
return el_mergee(first_ret, second_ret); // 2 mergee's
244+
// Operates on individual floats, see _mm_hadd_ps
245+
// {f0+f1, s0+s1, f2+f3, s2+s3, ...}
246+
// i.e. it sums the re and im of each value and interleaves first and second:
247+
// {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...}
248+
return el_mergee(first, second) + el_mergeo(first, second);
253249
}
254250

255251
static Vectorized<ComplexFlt> horizontal_sub_permD8(
@@ -353,12 +349,19 @@ class Vectorized<ComplexFlt> {
353349
static Vectorized<ComplexFlt> el_mergee(
354350
Vectorized<ComplexFlt>& first,
355351
Vectorized<ComplexFlt>& second) {
356-
// as mergee phased in , we can use vec_perm with mask
357352
return {
358353
vec_mergee(first._vecb0, second._vecb0),
359354
vec_mergee(first._vecb1, second._vecb1)};
360355
}
361356

357+
static Vectorized<ComplexFlt> el_mergeo(
358+
Vectorized<ComplexFlt>& first,
359+
Vectorized<ComplexFlt>& second) {
360+
return {
361+
vec_mergeo(first._vecb0, second._vecb0),
362+
vec_mergeo(first._vecb1, second._vecb1)};
363+
}
364+
362365
Vectorized<ComplexFlt> angle_() const {
363366
// angle = atan2(b/a)
364367
// auto b_a = _mm256_permute_ps(values, 0xB1); // b a
@@ -488,25 +491,20 @@ class Vectorized<ComplexFlt> {
488491
// re + im*i = (a + bi) / (c + di)
489492
// re = (ac + bd)/abs_2()
490493
// im = (bc - ad)/abs_2()
491-
#if 1
492-
auto vi = b.el_mergeo();
493-
auto vr = b.el_mergee();
494-
auto abs_b = b.abs_2_();
495-
vi = vi ^ isign_mask;
496-
auto ret = elwise_mult(vr);
497-
auto vx_swapped = el_swapped();
498-
ret = vx_swapped.el_madd(vi, ret);
499-
ret = ret.elwise_div(abs_b);
500-
#else
501-
// Vectorized x86 simulation
502-
auto ac_bd = elwise_mult(b);
503-
auto d_c = b.el_swapped();
504-
d_c = d_c ^ rsign_mask;
505-
auto ad_bc = elwise_mult(d_c);
506-
auto abs_b = b.abs_2_();
507-
auto re_im = horizontal_add_permD8(ac_bd, ad_bc);
508-
auto ret = re_im.elwise_div(abs_b);
509-
#endif
494+
auto fabs_cd = Vectorized{
495+
vec_andc(b._vec0, sign_mask),
496+
vec_andc(b._vec1, sign_mask)}; // |c| |d|
497+
auto fabs_dc = fabs_cd.el_swapped(); // |d| |c|
498+
auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|)
499+
auto a2 = elwise_div(scale); // a/sc b/sc
500+
auto b2 = b.elwise_div(scale); // c/sc d/sc
501+
auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/sc^2
502+
auto dc2 = b2.el_swapped(); // d/sc c/sc
503+
dc2 = dc2 ^ rsign_mask; // -d/sc c/sc
504+
auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2
505+
auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2
506+
auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
507+
ret = ret.elwise_div(denom2);
510508
return ret;
511509
}
512510

@@ -589,6 +587,7 @@ class Vectorized<ComplexFlt> {
589587
DEFINE_MEMBER_OP(elwise_ge, ComplexFlt, vec_cmpge)
590588
DEFINE_MEMBER_OP(elwise_lt, ComplexFlt, vec_cmplt)
591589
DEFINE_MEMBER_OP(elwise_le, ComplexFlt, vec_cmple)
590+
DEFINE_MEMBER_OP(elwise_max, ComplexFlt, vec_max)
592591
};
593592

594593
template <>

aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ const vbool32 imag_mask = vbool32{0x0, 0xFFFFFFFF, 0x0, 0xFFFFFFFF};
391391
const vbool32 isign_mask = vbool32{0x0, 0x80000000, 0x0, 0x80000000};
392392
const vbool32 rsign_mask = vbool32{0x80000000, 0x0, 0x80000000, 0x0};
393393

394+
const vbool64 vd_sign_mask = vbool64{0x8000000000000000, 0x8000000000000000};
394395
const vbool64 vd_imag_mask = vbool64{0x0, 0xFFFFFFFFFFFFFFFF};
395396
const vbool64 vd_real_mask = vbool64{0xFFFFFFFFFFFFFFFF, 0x0};
396397
const vbool64 vd_isign_mask = vbool64{0x0, 0x8000000000000000};

0 commit comments

Comments
 (0)