Skip to content

Commit 7849858

Browse files
Used sycl::fma for floating-point types Ty in an attempt to fix a test on Windows
1 parent 47e4ae4 commit 7849858

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -230,13 +230,28 @@ template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
230230
wTy w = wTy(n - i) / n;
231231
using dpctl::tensor::type_utils::is_complex;
232232
if constexpr (is_complex<Ty>::value) {
233-
auto _w = static_cast<typename Ty::value_type>(w);
234-
auto _wc = static_cast<typename Ty::value_type>(wc);
235-
auto re_comb = start_v.real() * _w + end_v.real() * _wc;
236-
auto im_comb = start_v.imag() * _w + end_v.imag() * _wc;
233+
using reT = typename Ty::value_type;
234+
auto _w = static_cast<reT>(w);
235+
auto _wc = static_cast<reT>(wc);
236+
auto re_comb = sycl::fma(start_v.real(), _w, reT(0));
237+
re_comb =
238+
sycl::fma(end_v.real(), _wc,
239+
re_comb); // start_v.real() * _w + end_v.real() * _wc;
240+
auto im_comb =
241+
sycl::fma(start_v.imag(), _w,
242+
reT(0)); // start_v.imag() * _w + end_v.imag() * _wc;
243+
im_comb = sycl::fma(end_v.imag(), _wc, im_comb);
237244
Ty affine_comb = Ty{re_comb, im_comb};
238245
p[i] = affine_comb;
239246
}
247+
else if constexpr (std::is_floating_point<Ty>::value) {
248+
Ty _w = static_cast<Ty>(w);
249+
Ty _wc = static_cast<Ty>(wc);
250+
auto affine_comb =
251+
sycl::fma(start_v, _w, Ty(0)); // start_v * w + end_v * wc;
252+
affine_comb = sycl::fma(end_v, _wc, affine_comb);
253+
p[i] = affine_comb;
254+
}
240255
else {
241256
using dpctl::tensor::type_utils::convert_impl;
242257
auto affine_comb = start_v * w + end_v * wc;

0 commit comments

Comments
 (0)