For complex values with large components, the division produces `inf` or `NaN` values, which then propagate into functions built on top of it, e.g. `torch._refs.sgn` as used in a test.
Example:
```
$ python -c 'import torch; print(torch._refs.sgn(torch.complex(torch.tensor([-501]*16, dtype=torch.float32), torch.tensor([-1e20]*16, dtype=torch.float32))))'
tensor([-0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj])
$ python -c 'import torch; t = torch.complex(torch.tensor([-501]*16, dtype=torch.float32), torch.tensor([-1e20]*16, dtype=torch.float32)); print(t / t.abs())'
tensor([-0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj, -0.+nanj])
```
This implements the same algorithm as used in numpy and in the x86 path (pytorch#93277).
The reason for the failure is that for a tensor with a component of `1e20`, the squared absolute value used in the division contains the term `1e20 * 1e20`, which overflows the dynamic range of float32 (max ≈ 3.4e38) and yields `inf`, so the division yields `nan`.
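The overflow-safe idea can be illustrated with a small sketch (not the actual PyTorch implementation; `stable_sgn` is a hypothetical helper name): scale both components by the larger magnitude before squaring, so the intermediate squares stay within float32 range.

```python
import numpy as np

def stable_sgn(re, im):
    # Hypothetical sketch of the overflow-safe approach: divide both
    # components by the larger magnitude first, so the squared terms
    # below are at most 1 and cannot overflow float32.
    re = np.float32(re)
    im = np.float32(im)
    scale = max(abs(re), abs(im))
    if scale == 0:
        return complex(0.0, 0.0)
    rs = re / scale
    ims = im / scale
    # rs, ims are in [-1, 1], so rs*rs + ims*ims <= 2: no overflow.
    mag = np.sqrt(rs * rs + ims * ims)
    return complex(rs / mag, ims / mag)

# Naive float32 computation overflows: (1e20)**2 = 1e40 > float32 max.
print(stable_sgn(-501.0, -1e20))
```

With the scaling, the example from above yields a finite result with imaginary part close to -1 instead of `nan`.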
Output after change:
```
$ python -c 'import torch; t = torch.complex(torch.tensor([-501]*16, dtype=torch.float32), torch.tensor([-1e20]*16, dtype=torch.float32)); print(torch._refs.sgn(t), t.sgn(), t / t.abs())'
tensor([-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j,
-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j,
-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j,
-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j]) tensor([-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j,
-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j,
-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j,
-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j]) tensor([-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j,
-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j,
-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j,
-5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j, -5.0100e-18-1.j])
```
CC @quickwritereader who wrote the initial code and @VitalyFedyunin who was involved in the initial review and @lezcano who reviewed pytorch#93277
Pull Request resolved: pytorch#116972
Approved by: https://github.com/lezcano