@@ -414,6 +414,9 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
414
414
const size_t norm,
415
415
const DPCTLEventVectorRef dep_event_vec_ref)
416
416
{
417
+ static_assert (sycl::detail::is_complex<_DataType_output>::value,
418
+ " Output data type must be a complex type." );
419
+
417
420
DPCTLSyclEventRef event_ref = nullptr ;
418
421
419
422
if (!shape_size || !array1_in || !result_out) {
@@ -476,8 +479,10 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
476
479
else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
477
480
std::is_same<_DataType_input, int64_t >::value)
478
481
{
479
- double *array1_copy = reinterpret_cast <double *>(
480
- dpnp_memory_alloc_c (q_ref, input_size * sizeof (double )));
482
+ using CastType = typename _DataType_output::value_type;
483
+
484
+ CastType *array1_copy = reinterpret_cast <CastType *>(
485
+ dpnp_memory_alloc_c (q_ref, input_size * sizeof (CastType)));
481
486
482
487
shape_elem_type *copy_strides = reinterpret_cast <shape_elem_type *>(
483
488
dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
@@ -486,15 +491,17 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
486
491
dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
487
492
*copy_shape = input_size;
488
493
shape_elem_type copy_shape_size = 1 ;
489
- event_ref = dpnp_copyto_c<_DataType_input, double >(
494
+ event_ref = dpnp_copyto_c<_DataType_input, CastType >(
490
495
q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
491
496
copy_strides, array1_in, input_size, copy_shape_size,
492
497
copy_shape, copy_strides, NULL , dep_event_vec_ref);
493
498
DPCTLEvent_WaitAndThrow (event_ref);
494
499
DPCTLEvent_Delete (event_ref);
495
500
496
- event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double ,
497
- desc_dp_real_t >(
501
+ event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
502
+ CastType, CastType,
503
+ std::conditional_t <std::is_same<CastType, double >::value,
504
+ desc_dp_real_t , desc_sp_real_t >>(
498
505
q_ref, array1_copy, result_out, input_shape, result_shape,
499
506
shape_size, input_size, result_size, inverse, norm, 0 );
500
507
@@ -577,6 +584,8 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
577
584
const size_t norm,
578
585
const DPCTLEventVectorRef dep_event_vec_ref)
579
586
{
587
+ static_assert (sycl::detail::is_complex<_DataType_output>::value,
588
+ " Output data type must be a complex type." );
580
589
DPCTLSyclEventRef event_ref = nullptr ;
581
590
582
591
if (!shape_size || !array1_in || !result_out) {
@@ -617,8 +626,10 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
617
626
else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
618
627
std::is_same<_DataType_input, int64_t >::value)
619
628
{
620
- double *array1_copy = reinterpret_cast <double *>(
621
- dpnp_memory_alloc_c (q_ref, input_size * sizeof (double )));
629
+ using CastType = typename _DataType_output::value_type;
630
+
631
+ CastType *array1_copy = reinterpret_cast <CastType *>(
632
+ dpnp_memory_alloc_c (q_ref, input_size * sizeof (CastType)));
622
633
623
634
shape_elem_type *copy_strides = reinterpret_cast <shape_elem_type *>(
624
635
dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
@@ -627,15 +638,17 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
627
638
dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
628
639
*copy_shape = input_size;
629
640
shape_elem_type copy_shape_size = 1 ;
630
- event_ref = dpnp_copyto_c<_DataType_input, double >(
641
+ event_ref = dpnp_copyto_c<_DataType_input, CastType >(
631
642
q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
632
643
copy_strides, array1_in, input_size, copy_shape_size,
633
644
copy_shape, copy_strides, NULL , dep_event_vec_ref);
634
645
DPCTLEvent_WaitAndThrow (event_ref);
635
646
DPCTLEvent_Delete (event_ref);
636
647
637
- event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double ,
638
- desc_dp_real_t >(
648
+ event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
649
+ CastType, CastType,
650
+ std::conditional_t <std::is_same<CastType, double >::value,
651
+ desc_dp_real_t , desc_sp_real_t >>(
639
652
q_ref, array1_copy, result_out, input_shape, result_shape,
640
653
shape_size, input_size, result_size, inverse, norm, 1 );
641
654
@@ -721,9 +734,11 @@ void func_map_init_fft_func(func_map_t &fmap)
721
734
dpnp_fft_fft_default_c<std::complex<double >, std::complex<double >>};
722
735
723
736
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_INT][eft_INT] = {
724
- eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t , std::complex<double >>};
737
+ eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t , std::complex<double >>,
738
+ eft_C64, (void *)dpnp_fft_fft_ext_c<int32_t , std::complex<float >>};
725
739
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_LNG][eft_LNG] = {
726
- eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t , std::complex<double >>};
740
+ eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t , std::complex<double >>,
741
+ eft_C64, (void *)dpnp_fft_fft_ext_c<int64_t , std::complex<float >>};
727
742
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_FLT][eft_FLT] = {
728
743
eft_C64, (void *)dpnp_fft_fft_ext_c<float , std::complex<float >>};
729
744
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_DBL][eft_DBL] = {
@@ -748,9 +763,11 @@ void func_map_init_fft_func(func_map_t &fmap)
748
763
(void *)dpnp_fft_rfft_default_c<double , std::complex<double >>};
749
764
750
765
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_INT][eft_INT] = {
751
- eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t , std::complex<double >>};
766
+ eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t , std::complex<double >>,
767
+ eft_C64, (void *)dpnp_fft_rfft_ext_c<int32_t , std::complex<float >>};
752
768
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_LNG][eft_LNG] = {
753
- eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t , std::complex<double >>};
769
+ eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t , std::complex<double >>,
770
+ eft_C64, (void *)dpnp_fft_rfft_ext_c<int64_t , std::complex<float >>};
754
771
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_FLT][eft_FLT] = {
755
772
eft_C64, (void *)dpnp_fft_rfft_ext_c<float , std::complex<float >>};
756
773
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_DBL][eft_DBL] = {
0 commit comments