Skip to content

Commit a717c69

Browse files
Update dpnp fft implementations to run on Iris Xe (#1524)
* Improve dpnp_fft impl for Iris Xe * Improve dpnp_rfft impl for Iris Xe * Update test_fft and cupy test * Refresh dpnp_iface_fft.py * Fix cupy test_fft.py * Apply review remarks * Reduce python version to 3.10 for generate_coverage * Raise TypeError for boolean data type * Add a new test for fft funcs
1 parent 92e9b1c commit a717c69

File tree

6 files changed

+404
-281
lines changed

6 files changed

+404
-281
lines changed

.github/workflows/generate_coverage.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
shell: bash -l {0}
1515

1616
env:
17-
python-ver: '3.11'
17+
python-ver: '3.10'
1818
CHANNELS: '-c dppy/label/dev -c intel -c conda-forge --override-channels'
1919

2020
steps:

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,9 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
414414
const size_t norm,
415415
const DPCTLEventVectorRef dep_event_vec_ref)
416416
{
417+
static_assert(sycl::detail::is_complex<_DataType_output>::value,
418+
"Output data type must be a complex type.");
419+
417420
DPCTLSyclEventRef event_ref = nullptr;
418421

419422
if (!shape_size || !array1_in || !result_out) {
@@ -476,8 +479,10 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
476479
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
477480
std::is_same<_DataType_input, int64_t>::value)
478481
{
479-
double *array1_copy = reinterpret_cast<double *>(
480-
dpnp_memory_alloc_c(q_ref, input_size * sizeof(double)));
482+
using CastType = typename _DataType_output::value_type;
483+
484+
CastType *array1_copy = reinterpret_cast<CastType *>(
485+
dpnp_memory_alloc_c(q_ref, input_size * sizeof(CastType)));
481486

482487
shape_elem_type *copy_strides = reinterpret_cast<shape_elem_type *>(
483488
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
@@ -486,15 +491,17 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
486491
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
487492
*copy_shape = input_size;
488493
shape_elem_type copy_shape_size = 1;
489-
event_ref = dpnp_copyto_c<_DataType_input, double>(
494+
event_ref = dpnp_copyto_c<_DataType_input, CastType>(
490495
q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
491496
copy_strides, array1_in, input_size, copy_shape_size,
492497
copy_shape, copy_strides, NULL, dep_event_vec_ref);
493498
DPCTLEvent_WaitAndThrow(event_ref);
494499
DPCTLEvent_Delete(event_ref);
495500

496-
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double,
497-
desc_dp_real_t>(
501+
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
502+
CastType, CastType,
503+
std::conditional_t<std::is_same<CastType, double>::value,
504+
desc_dp_real_t, desc_sp_real_t>>(
498505
q_ref, array1_copy, result_out, input_shape, result_shape,
499506
shape_size, input_size, result_size, inverse, norm, 0);
500507

@@ -577,6 +584,8 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
577584
const size_t norm,
578585
const DPCTLEventVectorRef dep_event_vec_ref)
579586
{
587+
static_assert(sycl::detail::is_complex<_DataType_output>::value,
588+
"Output data type must be a complex type.");
580589
DPCTLSyclEventRef event_ref = nullptr;
581590

582591
if (!shape_size || !array1_in || !result_out) {
@@ -617,8 +626,10 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
617626
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
618627
std::is_same<_DataType_input, int64_t>::value)
619628
{
620-
double *array1_copy = reinterpret_cast<double *>(
621-
dpnp_memory_alloc_c(q_ref, input_size * sizeof(double)));
629+
using CastType = typename _DataType_output::value_type;
630+
631+
CastType *array1_copy = reinterpret_cast<CastType *>(
632+
dpnp_memory_alloc_c(q_ref, input_size * sizeof(CastType)));
622633

623634
shape_elem_type *copy_strides = reinterpret_cast<shape_elem_type *>(
624635
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
@@ -627,15 +638,17 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
627638
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
628639
*copy_shape = input_size;
629640
shape_elem_type copy_shape_size = 1;
630-
event_ref = dpnp_copyto_c<_DataType_input, double>(
641+
event_ref = dpnp_copyto_c<_DataType_input, CastType>(
631642
q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
632643
copy_strides, array1_in, input_size, copy_shape_size,
633644
copy_shape, copy_strides, NULL, dep_event_vec_ref);
634645
DPCTLEvent_WaitAndThrow(event_ref);
635646
DPCTLEvent_Delete(event_ref);
636647

637-
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double,
638-
desc_dp_real_t>(
648+
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
649+
CastType, CastType,
650+
std::conditional_t<std::is_same<CastType, double>::value,
651+
desc_dp_real_t, desc_sp_real_t>>(
639652
q_ref, array1_copy, result_out, input_shape, result_shape,
640653
shape_size, input_size, result_size, inverse, norm, 1);
641654

@@ -721,9 +734,11 @@ void func_map_init_fft_func(func_map_t &fmap)
721734
dpnp_fft_fft_default_c<std::complex<double>, std::complex<double>>};
722735

723736
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_INT][eft_INT] = {
724-
eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t, std::complex<double>>};
737+
eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t, std::complex<double>>,
738+
eft_C64, (void *)dpnp_fft_fft_ext_c<int32_t, std::complex<float>>};
725739
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_LNG][eft_LNG] = {
726-
eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t, std::complex<double>>};
740+
eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t, std::complex<double>>,
741+
eft_C64, (void *)dpnp_fft_fft_ext_c<int64_t, std::complex<float>>};
727742
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_FLT][eft_FLT] = {
728743
eft_C64, (void *)dpnp_fft_fft_ext_c<float, std::complex<float>>};
729744
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_DBL][eft_DBL] = {
@@ -748,9 +763,11 @@ void func_map_init_fft_func(func_map_t &fmap)
748763
(void *)dpnp_fft_rfft_default_c<double, std::complex<double>>};
749764

750765
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_INT][eft_INT] = {
751-
eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t, std::complex<double>>};
766+
eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t, std::complex<double>>,
767+
eft_C64, (void *)dpnp_fft_rfft_ext_c<int32_t, std::complex<float>>};
752768
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_LNG][eft_LNG] = {
753-
eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t, std::complex<double>>};
769+
eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t, std::complex<double>>,
770+
eft_C64, (void *)dpnp_fft_rfft_ext_c<int64_t, std::complex<float>>};
754771
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_FLT][eft_FLT] = {
755772
eft_C64, (void *)dpnp_fft_rfft_ext_c<float, std::complex<float>>};
756773
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_DBL][eft_DBL] = {

dpnp/fft/dpnp_algo_fft.pyx

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,15 @@ cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
6868

6969
input_obj = input.get_array()
7070

71+
# get FPTR function and return type
72+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
73+
input_obj.sycl_device.has_aspect_fp64)
74+
cdef DPNPFuncType return_type = ret_type_and_func[0]
75+
cdef fptr_dpnp_fft_fft_t func = < fptr_dpnp_fft_fft_t > ret_type_and_func[1]
76+
7177
# create result array with type given by FPTR data
7278
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(output_shape,
73-
kernel_data.return_type,
79+
return_type,
7480
None,
7581
device=input_obj.sycl_device,
7682
usm_type=input_obj.usm_type,
@@ -81,7 +87,6 @@ cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
8187
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
8288
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
8389

84-
cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
8590
# call FPTR function
8691
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
8792
input.get_data(),
@@ -122,9 +127,15 @@ cpdef utils.dpnp_descriptor dpnp_rfft(utils.dpnp_descriptor input,
122127

123128
input_obj = input.get_array()
124129

130+
# get FPTR function and return type
131+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
132+
input_obj.sycl_device.has_aspect_fp64)
133+
cdef DPNPFuncType return_type = ret_type_and_func[0]
134+
cdef fptr_dpnp_fft_fft_t func = < fptr_dpnp_fft_fft_t > ret_type_and_func[1]
135+
125136
# create result array with type given by FPTR data
126137
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(output_shape,
127-
kernel_data.return_type,
138+
return_type,
128139
None,
129140
device=input_obj.sycl_device,
130141
usm_type=input_obj.usm_type,
@@ -135,7 +146,6 @@ cpdef utils.dpnp_descriptor dpnp_rfft(utils.dpnp_descriptor input,
135146
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
136147
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
137148

138-
cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
139149
# call FPTR function
140150
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
141151
input.get_data(),

0 commit comments

Comments
 (0)