Skip to content

Commit 6364c08

Browse files
Merge pull request #1034 from IntelPython/use-no-associative-math
Use no associative math
2 parents: 47e4ae4 + 5a126fd; commit 6364c08

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -230,13 +230,28 @@ template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
230230
wTy w = wTy(n - i) / n;
231231
using dpctl::tensor::type_utils::is_complex;
232232
if constexpr (is_complex<Ty>::value) {
233-
auto _w = static_cast<typename Ty::value_type>(w);
234-
auto _wc = static_cast<typename Ty::value_type>(wc);
235-
auto re_comb = start_v.real() * _w + end_v.real() * _wc;
236-
auto im_comb = start_v.imag() * _w + end_v.imag() * _wc;
233+
using reT = typename Ty::value_type;
234+
auto _w = static_cast<reT>(w);
235+
auto _wc = static_cast<reT>(wc);
236+
auto re_comb = sycl::fma(start_v.real(), _w, reT(0));
237+
re_comb =
238+
sycl::fma(end_v.real(), _wc,
239+
re_comb); // start_v.real() * _w + end_v.real() * _wc;
240+
auto im_comb =
241+
sycl::fma(start_v.imag(), _w,
242+
reT(0)); // start_v.imag() * _w + end_v.imag() * _wc;
243+
im_comb = sycl::fma(end_v.imag(), _wc, im_comb);
237244
Ty affine_comb = Ty{re_comb, im_comb};
238245
p[i] = affine_comb;
239246
}
247+
else if constexpr (std::is_floating_point<Ty>::value) {
248+
Ty _w = static_cast<Ty>(w);
249+
Ty _wc = static_cast<Ty>(wc);
250+
auto affine_comb =
251+
sycl::fma(start_v, _w, Ty(0)); // start_v * w + end_v * wc;
252+
affine_comb = sycl::fma(end_v, _wc, affine_comb);
253+
p[i] = affine_comb;
254+
}
240255
else {
241256
using dpctl::tensor::type_utils::convert_impl;
242257
auto affine_comb = start_v * w + end_v * wc;

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1168,6 +1168,21 @@ def test_linspace_fp():
11681168
assert X.strides == (1,)
11691169

11701170

1171+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
1172+
def test_linspace_fp_max(dtype):
1173+
q = get_queue_or_skip()
1174+
skip_if_dtype_not_supported(dtype, q)
1175+
n = 16
1176+
dt = dpt.dtype(dtype)
1177+
max_ = dpt.finfo(dt).max
1178+
X = dpt.linspace(max_, max_, endpoint=True, num=n, dtype=dt, sycl_queue=q)
1179+
assert X.shape == (n,)
1180+
assert X.strides == (1,)
1181+
assert np.allclose(
1182+
dpt.asnumpy(X), np.linspace(max_, max_, endpoint=True, num=n, dtype=dt)
1183+
)
1184+
1185+
11711186
@pytest.mark.parametrize(
11721187
"dt",
11731188
_all_dtypes,

0 commit comments

Comments
 (0)