Skip to content

Commit d240d22

Browse files
committed
Finalize CFD support in fft
1 parent ec7bafa commit d240d22

File tree

6 files changed

+168
-235
lines changed

6 files changed

+168
-235
lines changed

.github/workflows/conda-package.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,7 @@ jobs:
241241

242242
# TODO: run the whole scope once the issues on CPU are resolved
243243
- name: Run tests
244-
run: python -m pytest -q -ra --disable-warnings -vv test_arraycreation.py test_dparray.py test_mathematical.py test_special.py
244+
run: python -m pytest -q -ra --disable-warnings -vv test_arraycreation.py test_dparray.py test_fft.py test_mathematical.py test_special.py
245245
env:
246246
SYCL_ENABLE_HOST_DEVICE: '1'
247247
working-directory: ${{ env.tests-path }}
@@ -416,7 +416,7 @@ jobs:
416416

417417
# TODO: run the whole scope once the issues on CPU are resolved
418418
- name: Run tests
419-
run: python -m pytest -q -ra --disable-warnings -vv test_arraycreation.py test_dparray.py test_mathematical.py test_special.py
419+
run: python -m pytest -q -ra --disable-warnings -vv test_arraycreation.py test_dparray.py test_fft.py test_mathematical.py test_special.py
420420
working-directory: ${{ env.tests-path }}
421421

422422
upload_linux:

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 68 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -48,17 +48,17 @@ template <typename _KernelNameSpecialization1, typename _KernelNameSpecializatio
4848
class dpnp_fft_fft_c_kernel;
4949

5050
template <typename _DataType_input, typename _DataType_output>
51-
void dpnp_fft_fft_sycl_c(DPCTLSyclQueueRef q_ref,
52-
const void* array1_in,
53-
void* result_out,
54-
const shape_elem_type* input_shape,
55-
const shape_elem_type* output_shape,
56-
size_t shape_size,
57-
const size_t result_size,
58-
const size_t input_size,
59-
long axis,
60-
long input_boundarie,
61-
size_t inverse)
51+
static void dpnp_fft_fft_sycl_c(DPCTLSyclQueueRef q_ref,
52+
const void* array1_in,
53+
void* result_out,
54+
const shape_elem_type* input_shape,
55+
const shape_elem_type* output_shape,
56+
size_t shape_size,
57+
const size_t result_size,
58+
const size_t input_size,
59+
long axis,
60+
long input_boundarie,
61+
size_t inverse)
6262
{
6363
if (!(input_size && result_size && shape_size))
6464
{
@@ -71,9 +71,8 @@ void dpnp_fft_fft_sycl_c(DPCTLSyclQueueRef q_ref,
7171

7272
sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));
7373

74-
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, input_size);
75-
const _DataType_input* array_1 = input1_ptr.get_ptr();
76-
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
74+
_DataType_input* array_1 = static_cast<_DataType_input *>(const_cast<void *>(array1_in));
75+
_DataType_output* result = static_cast<_DataType_output *>(result_out);
7776

7877
// kernel specific temporal data
7978
shape_elem_type* output_shape_offsets =
@@ -171,29 +170,28 @@ void dpnp_fft_fft_sycl_c(DPCTLSyclQueueRef q_ref,
171170
}
172171

173172
template <typename _DataType_input, typename _DataType_output, typename _Descriptor_type>
174-
void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
175-
const void* array1_in,
176-
void* result_out,
177-
const shape_elem_type* input_shape,
178-
const shape_elem_type*,
179-
const size_t shape_size,
180-
const size_t input_size,
181-
const size_t result_size,
182-
_Descriptor_type& desc,
183-
size_t inverse,
184-
const size_t norm)
173+
static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
174+
const void* array1_in,
175+
void* result_out,
176+
const shape_elem_type* input_shape,
177+
const shape_elem_type* result_shape,
178+
const size_t shape_size,
179+
const size_t input_size,
180+
const size_t result_size,
181+
_Descriptor_type& desc,
182+
size_t inverse,
183+
const size_t norm)
185184
{
186-
if (!shape_size)
187-
{
185+
(void)result_shape;
186+
187+
if (!shape_size) {
188188
return;
189189
}
190190

191191
sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));
192192

193-
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, input_size);
194-
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref, result_out, result_size);
195-
_DataType_input* array_1 = input1_ptr.get_ptr();
196-
_DataType_output* result = result_ptr.get_ptr();
193+
_DataType_input* array_1 = static_cast<_DataType_input *>(const_cast<void *>(array1_in));
194+
_DataType_output* result = static_cast<_DataType_output *>(result_out);
197195

198196
const size_t n_iter =
199197
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
@@ -242,32 +240,29 @@ template <typename _KernelNameSpecialization1, typename _KernelNameSpecializatio
242240
class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel;
243241

244242
template <typename _DataType_input, typename _DataType_output, typename _Descriptor_type>
245-
DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
246-
const void* array1_in,
247-
void* result_out,
248-
const shape_elem_type* input_shape,
249-
const shape_elem_type* result_shape,
250-
const size_t shape_size,
251-
const size_t input_size,
252-
const size_t result_size,
253-
_Descriptor_type& desc,
254-
size_t inverse,
255-
const size_t norm,
256-
const size_t real)
243+
static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
244+
const void* array1_in,
245+
void* result_out,
246+
const shape_elem_type* input_shape,
247+
const shape_elem_type* result_shape,
248+
const size_t shape_size,
249+
const size_t input_size,
250+
const size_t result_size,
251+
_Descriptor_type& desc,
252+
size_t inverse,
253+
const size_t norm,
254+
const size_t real)
257255
{
258-
DPCTLSyclEventRef event_ref = nullptr;;
259-
if (!shape_size)
260-
{
256+
DPCTLSyclEventRef event_ref = nullptr;
257+
if (!shape_size) {
261258
return event_ref;
262259
}
263260

264-
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, input_size);
265-
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref, result_out, result_size * 2, true, true);
266-
_DataType_input* array_1 = input1_ptr.get_ptr();
267-
_DataType_output* result = result_ptr.get_ptr();
268-
269261
sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));
270262

263+
_DataType_input* array_1 = static_cast<_DataType_input *>(const_cast<void *>(array1_in));
264+
_DataType_output* result = static_cast<_DataType_output *>(result_out);
265+
271266
const size_t n_iter =
272267
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
273268

@@ -323,7 +318,8 @@ DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
323318
{
324319
size_t j = global_id[1];
325320
{
326-
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * (i + 1) - (j + 1)) = std::conj(*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1)));
321+
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * (i + 1) - (j + 1)) =
322+
std::conj(*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1)));
327323
}
328324
}
329325
};
@@ -337,7 +333,10 @@ DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
337333

338334
if (inverse) {
339335
event.wait();
340-
event = oneapi::mkl::vm::conj(queue, result_size, reinterpret_cast<std::complex<_DataType_output>*>(result), reinterpret_cast<std::complex<_DataType_output>*>(result));
336+
event = oneapi::mkl::vm::conj(queue,
337+
result_size,
338+
reinterpret_cast<std::complex<_DataType_output>*>(result),
339+
reinterpret_cast<std::complex<_DataType_output>*>(result));
341340
}
342341

343342
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
@@ -411,21 +410,25 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
411410
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
412411
std::is_same<_DataType_input, int64_t>::value)
413412
{
414-
double* array1_copy = reinterpret_cast<double*>(dpnp_memory_alloc_c(input_size * sizeof(double)));
413+
double* array1_copy = reinterpret_cast<double*>(dpnp_memory_alloc_c(q_ref, input_size * sizeof(double)));
415414

416415
shape_elem_type* copy_strides = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
417416
*copy_strides = 1;
418417
shape_elem_type* copy_shape = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
419418
*copy_shape = input_size;
420419
shape_elem_type copy_shape_size = 1;
421-
dpnp_copyto_c<_DataType_input, double>(q_ref, array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
422-
array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL, dep_event_vec_ref);
420+
event_ref = dpnp_copyto_c<_DataType_input, double>(q_ref, array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
421+
array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL, dep_event_vec_ref);
422+
DPCTLEvent_WaitAndThrow(event_ref);
423+
DPCTLEvent_Delete(event_ref);
423424

424425
desc_dp_real_t desc(dim);
425426
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
426427
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
427428

428429
DPCTLEvent_WaitAndThrow(event_ref);
430+
DPCTLEvent_Delete(event_ref);
431+
event_ref = nullptr;
429432

430433
dpnp_memory_free_c(q_ref, array1_copy);
431434
dpnp_memory_free_c(q_ref, copy_strides);
@@ -475,6 +478,7 @@ void dpnp_fft_fft_c(const void* array1_in,
475478
norm,
476479
dep_event_vec_ref);
477480
DPCTLEvent_WaitAndThrow(event_ref);
481+
DPCTLEvent_Delete(event_ref);
478482
}
479483

480484
template <typename _DataType_input, typename _DataType_output>
@@ -529,7 +533,6 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
529533

530534
size_t dim = input_shape[shape_size - 1];
531535

532-
533536
if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
534537
std::is_same<_DataType_output, std::complex<double>>::value)
535538
{
@@ -552,21 +555,25 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
552555
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
553556
std::is_same<_DataType_input, int64_t>::value)
554557
{
555-
double* array1_copy = reinterpret_cast<double*>(dpnp_memory_alloc_c(input_size * sizeof(double)));
558+
double* array1_copy = reinterpret_cast<double*>(dpnp_memory_alloc_c(q_ref, input_size * sizeof(double)));
556559

557560
shape_elem_type* copy_strides = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
558561
*copy_strides = 1;
559562
shape_elem_type* copy_shape = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
560563
*copy_shape = input_size;
561564
shape_elem_type copy_shape_size = 1;
562-
dpnp_copyto_c<_DataType_input, double>(q_ref, array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
563-
array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL, dep_event_vec_ref);
565+
event_ref = dpnp_copyto_c<_DataType_input, double>(q_ref, array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
566+
array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL, dep_event_vec_ref);
567+
DPCTLEvent_WaitAndThrow(event_ref);
568+
DPCTLEvent_Delete(event_ref);
564569

565570
desc_dp_real_t desc(dim);
566571
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
567572
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
568573

569574
DPCTLEvent_WaitAndThrow(event_ref);
575+
DPCTLEvent_Delete(event_ref);
576+
event_ref = nullptr;
570577

571578
dpnp_memory_free_c(q_ref, array1_copy);
572579
dpnp_memory_free_c(q_ref, copy_strides);
@@ -603,6 +610,7 @@ void dpnp_fft_rfft_c(const void* array1_in,
603610
norm,
604611
dep_event_vec_ref);
605612
DPCTLEvent_WaitAndThrow(event_ref);
613+
DPCTLEvent_Delete(event_ref);
606614
}
607615

608616
template <typename _DataType_input, typename _DataType_output>

0 commit comments

Comments
 (0)