@@ -230,13 +230,28 @@ template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
230
230
wTy w = wTy (n - i) / n;
231
231
using dpctl::tensor::type_utils::is_complex;
232
232
if constexpr (is_complex<Ty>::value) {
233
- auto _w = static_cast <typename Ty::value_type>(w);
234
- auto _wc = static_cast <typename Ty::value_type>(wc);
235
- auto re_comb = start_v.real () * _w + end_v.real () * _wc;
236
- auto im_comb = start_v.imag () * _w + end_v.imag () * _wc;
233
+ using reT = typename Ty::value_type;
234
+ auto _w = static_cast <reT>(w);
235
+ auto _wc = static_cast <reT>(wc);
236
+ auto re_comb = sycl::fma (start_v.real (), _w, reT (0 ));
237
+ re_comb =
238
+ sycl::fma (end_v.real (), _wc,
239
+ re_comb); // start_v.real() * _w + end_v.real() * _wc;
240
+ auto im_comb =
241
+ sycl::fma (start_v.imag (), _w,
242
+ reT (0 )); // start_v.imag() * _w + end_v.imag() * _wc;
243
+ im_comb = sycl::fma (end_v.imag (), _wc, im_comb);
237
244
Ty affine_comb = Ty{re_comb, im_comb};
238
245
p[i] = affine_comb;
239
246
}
247
+ else if constexpr (std::is_floating_point<Ty>::value) {
248
+ Ty _w = static_cast <Ty>(w);
249
+ Ty _wc = static_cast <Ty>(wc);
250
+ auto affine_comb =
251
+ sycl::fma (start_v, _w, Ty (0 )); // start_v * w + end_v * wc;
252
+ affine_comb = sycl::fma (end_v, _wc, affine_comb);
253
+ p[i] = affine_comb;
254
+ }
240
255
else {
241
256
using dpctl::tensor::type_utils::convert_impl;
242
257
auto affine_comb = start_v * w + end_v * wc;
0 commit comments