Skip to content

Commit e8faed0

Browse files
committed
fix rfft
1 parent 8afa83c commit e8faed0

File tree

5 files changed

+93
-58
lines changed

5 files changed

+93
-58
lines changed

dpnp/backend/include/dpnp_iface_fft.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
* @param[in] input_boundarie Limit number of elements for @ref axis.
6363
* @param[in] inverse Using inverse algorithm.
6464
* @param[in] norm Normalization mode. 0 - backward, 1 - forward, 2 - ortho.
65+
* @param[in] real Real mode, 1 if type of array1_in is real and size of result_out is input_size/ 2 + 1, else 0.
6566
* @param[in] dep_event_vec_ref Reference to vector of SYCL events.
6667
*/
6768
template <typename _DataType>
@@ -75,5 +76,6 @@ INP_DLLEXPORT void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
7576
long input_boundarie,
7677
size_t inverse,
7778
const size_t norm,
79+
const size_t real,
7880
const DPCTLEventVectorRef dep_event_vec_ref);
7981
#endif // BACKEND_IFACE_FFT_H

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
175175
const void* array1_in,
176176
void* result_out,
177177
const shape_elem_type* input_shape,
178+
const shape_elem_type* result_shape,
178179
const size_t shape_size,
180+
const size_t input_size,
179181
const size_t result_size,
180182
_Descriptor_type& desc,
181183
const size_t norm)
@@ -187,7 +189,7 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
187189

188190
sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));
189191

190-
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, result_size);
192+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, input_size);
191193
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref, result_out, result_size);
192194
_DataType_input* array_1 = input1_ptr.get_ptr();
193195
_DataType_output* result = result_ptr.get_ptr();
@@ -205,7 +207,7 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
205207
// enum value from math library C interface
206208
// instead of mkl_dft::config_value::NOT_INPLACE
207209
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
208-
desc.commit(DPNP_QUEUE);
210+
desc.commit(queue);
209211

210212
std::vector<sycl::event> fft_events;
211213
fft_events.reserve(n_iter);
@@ -227,72 +229,81 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
227229
const void* array1_in,
228230
void* result_out,
229231
const shape_elem_type* input_shape,
232+
const shape_elem_type* result_shape,
230233
const size_t shape_size,
234+
const size_t input_size,
231235
const size_t result_size,
232236
_Descriptor_type& desc,
233-
const size_t norm)
237+
const size_t norm,
238+
const size_t real)
234239
{
235240
if (!shape_size)
236241
{
237242
return;
238243
}
239244

240-
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, result_size);
245+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, array1_in, input_size);
241246
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref, result_out, result_size * 2, true, true);
242247
_DataType_input* array_1 = input1_ptr.get_ptr();
243248
_DataType_output* result = result_ptr.get_ptr();
244249

250+
sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));
251+
245252
const size_t n_iter =
246253
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
247254

248-
const size_t shift = input_shape[shape_size - 1];
255+
const size_t input_shift = input_shape[shape_size - 1];
256+
const size_t result_shift = result_shape[shape_size - 1];;
249257

250258
double forward_scale = 1.0;
251-
double backward_scale = 1.0 / shift;
259+
double backward_scale = 1.0 / input_shift;
252260

253261
desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
254262
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
255263

256-
desc.commit(DPNP_QUEUE);
264+
desc.commit(queue);
257265

258266
std::vector<sycl::event> fft_events;
259267
fft_events.reserve(n_iter);
260268

261269
for (size_t i = 0; i < n_iter; ++i) {
262-
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift * 2));
270+
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
263271
}
264272

265273
sycl::event::wait(fft_events);
266274

267-
size_t n_conj = shift % 2 == 0 ? shift / 2 - 1 : shift / 2;
275+
if (!real) {
268276

269-
sycl::event event;
277+
size_t n_conj = result_shift % 2 == 0 ? result_shift / 2 - 1 : result_shift / 2;
270278

271-
sycl::range<2> gws(n_iter, n_conj);
279+
sycl::event event;
272280

273-
auto kernel_parallel_for_func = [=](sycl::id<2> global_id) {
274-
size_t i = global_id[0];
275-
{
276-
size_t j = global_id[1];
281+
sycl::range<2> gws(n_iter, n_conj);
282+
283+
auto kernel_parallel_for_func = [=](sycl::id<2> global_id) {
284+
size_t i = global_id[0];
277285
{
278-
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + shift * (i + 1) - (j + 1)) = std::conj(*(reinterpret_cast<std::complex<_DataType_output>*>(result) + shift * i + (j + 1)));
286+
size_t j = global_id[1];
287+
{
288+
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * (i + 1) - (j + 1)) = std::conj(*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1)));
289+
}
279290
}
280-
}
281-
};
291+
};
282292

283-
auto kernel_func = [&](sycl::handler& cgh) {
284-
cgh.parallel_for<class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel<_DataType_input, _DataType_output, _Descriptor_type>>(
285-
gws, kernel_parallel_for_func);
286-
};
293+
auto kernel_func = [&](sycl::handler& cgh) {
294+
cgh.parallel_for<class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel<_DataType_input, _DataType_output, _Descriptor_type>>(
295+
gws, kernel_parallel_for_func);
296+
};
287297

288-
event = DPNP_QUEUE.submit(kernel_func);
289-
event.wait();
298+
event = queue.submit(kernel_func);
299+
event.wait();
300+
}
290301

291302
return;
292303
}
293304

294305
template <typename _DataType_input, typename _DataType_output>
295-
void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
306+
DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
296307
const void* array1_in,
297308
void* result_out,
298309
const shape_elem_type* input_shape,
@@ -302,10 +313,9 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
302313
long input_boundarie,
303314
size_t inverse,
304315
const size_t norm,
316+
const size_t real,
305317
const DPCTLEventVectorRef dep_event_vec_ref)
306318
{
307-
(void)dep_event_vec_ref;
308-
309319
DPCTLSyclEventRef event_ref = nullptr;
310320

311321
if (!shape_size || !array1_in || !result_out)
@@ -317,8 +327,6 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
317327
std::accumulate(result_shape, result_shape + shape_size, 1, std::multiplies<shape_elem_type>());
318328
const size_t input_size =
319329
std::accumulate(input_shape, input_shape + shape_size, 1, std::multiplies<shape_elem_type>());
320-
321-
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
322330

323331
size_t dim = input_shape[shape_size - 1];
324332

@@ -330,15 +338,15 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
330338
{
331339
desc_dp_cmplx_t desc(dim);
332340
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
333-
q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
341+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm);
334342
}
335343
/* complex-to-complex, single precision */
336344
else if constexpr (std::is_same<_DataType_input, std::complex<float>>::value &&
337345
std::is_same<_DataType_output, std::complex<float>>::value)
338346
{
339347
desc_sp_cmplx_t desc(dim);
340348
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
341-
q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
349+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm);
342350
}
343351
/* real-to-complex, double precision */
344352
else if constexpr (std::is_same<_DataType_input, double>::value &&
@@ -347,36 +355,36 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
347355
desc_dp_real_t desc(dim);
348356

349357
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
350-
q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
358+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, real);
351359
}
352360
/* real-to-complex, single precision */
353361
else if constexpr (std::is_same<_DataType_input, float>::value &&
354362
std::is_same<_DataType_output, std::complex<float>>::value)
355363
{
356364
desc_sp_real_t desc(dim); // try: 2 * result_size
357365
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
358-
q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
366+
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, real);
359367
}
360368
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
361369
std::is_same<_DataType_input, int64_t>::value)
362370
{
363371
double* array1_copy = reinterpret_cast<double*>(dpnp_memory_alloc_c(input_size * sizeof(double)));
364372

365-
shape_elem_type* copy_strides = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(sizeof(shape_elem_type)));
373+
shape_elem_type* copy_strides = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
366374
*copy_strides = 1;
367-
shape_elem_type* copy_shape = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(sizeof(shape_elem_type)));
375+
shape_elem_type* copy_shape = reinterpret_cast<shape_elem_type*>(dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
368376
*copy_shape = input_size;
369377
shape_elem_type copy_shape_size = 1;
370-
dpnp_copyto_c<_DataType_input, double>(array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
371-
array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL);
378+
dpnp_copyto_c<_DataType_input, double>(q_ref, array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
379+
array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL, dep_event_vec_ref);
372380

373381
desc_dp_real_t desc(dim);
374382
dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
375-
array1_copy, result_out, input_shape, shape_size, result_size, desc, norm);
383+
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, real);
376384

377-
dpnp_memory_free_c(array1_copy);
378-
dpnp_memory_free_c(copy_strides);
379-
dpnp_memory_free_c(copy_shape);
385+
dpnp_memory_free_c(q_ref, array1_copy);
386+
dpnp_memory_free_c(q_ref, copy_strides);
387+
dpnp_memory_free_c(q_ref, copy_shape);
380388
}
381389
else
382390
{
@@ -406,7 +414,8 @@ void dpnp_fft_fft_c(const void* array1_in,
406414
long axis,
407415
long input_boundarie,
408416
size_t inverse,
409-
const size_t norm)
417+
const size_t norm,
418+
const size_t real)
410419
{
411420
DPCTLSyclQueueRef q_ref = reinterpret_cast<DPCTLSyclQueueRef>(&DPNP_QUEUE);
412421
DPCTLEventVectorRef dep_event_vec_ref = nullptr;
@@ -420,6 +429,7 @@ void dpnp_fft_fft_c(const void* array1_in,
420429
input_boundarie,
421430
inverse,
422431
norm,
432+
real,
423433
dep_event_vec_ref);
424434
DPCTLEvent_WaitAndThrow(event_ref);
425435
}
@@ -433,6 +443,7 @@ void (*dpnp_fft_fft_default_c)(const void*,
433443
long,
434444
long,
435445
size_t,
446+
const size_t,
436447
const size_t) = dpnp_fft_fft_c<_DataType_input, _DataType_output>;
437448

438449
template <typename _DataType_input, typename _DataType_output>
@@ -446,6 +457,7 @@ DPCTLSyclEventRef (*dpnp_fft_fft_ext_c)(DPCTLSyclQueueRef,
446457
long,
447458
size_t,
448459
const size_t,
460+
const size_t,
449461
const DPCTLEventVectorRef) = dpnp_fft_fft_c<_DataType_input, _DataType_output>;
450462

451463
void func_map_init_fft_func(func_map_t& fmap)

dpnp/fft/dpnp_algo_fft.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,16 @@ __all__ = [
4242
]
4343

4444
ctypedef void(*fptr_dpnp_fft_fft_t)(void *, void * , shape_elem_type * , shape_elem_type * ,
45-
size_t, long, long, size_t, size_t)
45+
size_t, long, long, size_t, size_t, size_t)
4646

4747

4848
cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
4949
size_t input_boundarie,
5050
size_t output_boundarie,
5151
long axis,
5252
size_t inverse,
53-
size_t norm):
53+
size_t norm,
54+
size_t real):
5455

5556
cdef shape_type_c input_shape = input.shape
5657
cdef shape_type_c output_shape = input_shape
@@ -70,6 +71,6 @@ cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
7071
cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
7172
# call FPTR function
7273
func(input.get_data(), result.get_data(), input_shape.data(),
73-
output_shape.data(), input_shape.size(), axis_norm, input_boundarie, inverse, norm)
74+
output_shape.data(), input_shape.size(), axis_norm, input_boundarie, inverse, norm, real)
7475

7576
return result

0 commit comments

Comments
 (0)