@@ -175,7 +175,9 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
175
175
const void * array1_in,
176
176
void * result_out,
177
177
const shape_elem_type* input_shape,
178
+ const shape_elem_type* result_shape,
178
179
const size_t shape_size,
180
+ const size_t input_size,
179
181
const size_t result_size,
180
182
_Descriptor_type& desc,
181
183
const size_t norm)
@@ -187,7 +189,7 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
187
189
188
190
sycl::queue queue = *(reinterpret_cast <sycl::queue*>(q_ref));
189
191
190
- DPNPC_ptr_adapter<_DataType_input> input1_ptr (q_ref, array1_in, result_size );
192
+ DPNPC_ptr_adapter<_DataType_input> input1_ptr (q_ref, array1_in, input_size );
191
193
DPNPC_ptr_adapter<_DataType_output> result_ptr (q_ref, result_out, result_size);
192
194
_DataType_input* array_1 = input1_ptr.get_ptr ();
193
195
_DataType_output* result = result_ptr.get_ptr ();
@@ -227,72 +229,81 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
227
229
const void * array1_in,
228
230
void * result_out,
229
231
const shape_elem_type* input_shape,
232
+ const shape_elem_type* result_shape,
230
233
const size_t shape_size,
234
+ const size_t input_size,
231
235
const size_t result_size,
232
236
_Descriptor_type& desc,
233
- const size_t norm)
237
+ const size_t norm,
238
+ const size_t real)
234
239
{
235
240
if (!shape_size)
236
241
{
237
242
return ;
238
243
}
239
244
240
- DPNPC_ptr_adapter<_DataType_input> input1_ptr (q_ref, array1_in, result_size );
245
+ DPNPC_ptr_adapter<_DataType_input> input1_ptr (q_ref, array1_in, input_size );
241
246
DPNPC_ptr_adapter<_DataType_output> result_ptr (q_ref, result_out, result_size * 2 , true , true );
242
247
_DataType_input* array_1 = input1_ptr.get_ptr ();
243
248
_DataType_output* result = result_ptr.get_ptr ();
244
249
250
+ sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
251
+
245
252
const size_t n_iter =
246
253
std::accumulate (input_shape, input_shape + shape_size - 1 , 1 , std::multiplies<shape_elem_type>());
247
254
248
- const size_t shift = input_shape[shape_size - 1 ];
255
+ const size_t input_shift = input_shape[shape_size - 1 ];
256
+ const size_t result_shift = result_shape[shape_size - 1 ];;
249
257
250
258
double forward_scale = 1.0 ;
251
- double backward_scale = 1.0 / shift ;
259
+ double backward_scale = 1.0 / input_shift ;
252
260
253
261
desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
254
262
desc.set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
255
263
256
- desc.commit (DPNP_QUEUE );
264
+ desc.commit (q );
257
265
258
266
std::vector<sycl::event> fft_events;
259
267
fft_events.reserve (n_iter);
260
268
261
269
for (size_t i = 0 ; i < n_iter; ++i) {
262
- fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * shift , result + i * shift * 2 ));
270
+ fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * input_shift , result + i * result_shift * 2 ));
263
271
}
264
272
265
273
sycl::event::wait (fft_events);
266
274
267
- size_t n_conj = shift % 2 == 0 ? shift / 2 - 1 : shift / 2 ;
275
+ if (!real) {
268
276
269
- sycl::event event ;
277
+ size_t n_conj = result_shift % 2 == 0 ? result_shift / 2 - 1 : result_shift / 2 ;
270
278
271
- sycl::range< 2 > gws (n_iter, n_conj) ;
279
+ sycl::event event ;
272
280
273
- auto kernel_parallel_for_func = [=]( sycl::id <2 > global_id) {
274
- size_t i = global_id[ 0 ];
275
- {
276
- size_t j = global_id[1 ];
281
+ sycl::range <2 > gws (n_iter, n_conj);
282
+
283
+ auto kernel_parallel_for_func = [=](sycl::id< 2 > global_id) {
284
+ size_t i = global_id[0 ];
277
285
{
278
- *(reinterpret_cast <std::complex<_DataType_output>*>(result) + shift * (i + 1 ) - (j + 1 )) = std::conj (*(reinterpret_cast <std::complex<_DataType_output>*>(result) + shift * i + (j + 1 )));
286
+ size_t j = global_id[1 ];
287
+ {
288
+ *(reinterpret_cast <std::complex<_DataType_output>*>(result) + result_shift * (i + 1 ) - (j + 1 )) = std::conj (*(reinterpret_cast <std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1 )));
289
+ }
279
290
}
280
- }
281
- };
291
+ };
282
292
283
- auto kernel_func = [&](sycl::handler& cgh) {
284
- cgh.parallel_for <class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel <_DataType_input, _DataType_output, _Descriptor_type>>(
285
- gws, kernel_parallel_for_func);
286
- };
293
+ auto kernel_func = [&](sycl::handler& cgh) {
294
+ cgh.parallel_for <class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel <_DataType_input, _DataType_output, _Descriptor_type>>(
295
+ gws, kernel_parallel_for_func);
296
+ };
287
297
288
- event = DPNP_QUEUE.submit (kernel_func);
289
- event.wait ();
298
+ event = q.submit (kernel_func);
299
+ event.wait ();
300
+ }
290
301
291
302
return ;
292
303
}
293
304
294
305
template <typename _DataType_input, typename _DataType_output>
295
- void dpnp_fft_fft_c (DPCTLSyclQueueRef q_ref,
306
+ DPCTLSyclEventRef dpnp_fft_fft_c (DPCTLSyclQueueRef q_ref,
296
307
const void * array1_in,
297
308
void * result_out,
298
309
const shape_elem_type* input_shape,
@@ -302,10 +313,9 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
302
313
long input_boundarie,
303
314
size_t inverse,
304
315
const size_t norm,
316
+ const size_t real,
305
317
const DPCTLEventVectorRef dep_event_vec_ref)
306
318
{
307
- (void )dep_event_vec_ref;
308
-
309
319
DPCTLSyclEventRef event_ref = nullptr ;
310
320
311
321
if (!shape_size || !array1_in || !result_out)
@@ -317,8 +327,6 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
317
327
std::accumulate (result_shape, result_shape + shape_size, 1 , std::multiplies<shape_elem_type>());
318
328
const size_t input_size =
319
329
std::accumulate (input_shape, input_shape + shape_size, 1 , std::multiplies<shape_elem_type>());
320
-
321
- sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
322
330
323
331
size_t dim = input_shape[shape_size - 1 ];
324
332
@@ -330,15 +338,15 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
330
338
{
331
339
desc_dp_cmplx_t desc (dim);
332
340
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
333
- q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
341
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size , result_size, desc, norm);
334
342
}
335
343
/* complex-to-complex, single precision */
336
344
else if constexpr (std::is_same<_DataType_input, std::complex<float >>::value &&
337
345
std::is_same<_DataType_output, std::complex<float >>::value)
338
346
{
339
347
desc_sp_cmplx_t desc (dim);
340
348
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
341
- q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
349
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size , result_size, desc, norm);
342
350
}
343
351
/* real-to-complex, double precision */
344
352
else if constexpr (std::is_same<_DataType_input, double >::value &&
@@ -347,36 +355,36 @@ void dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
347
355
desc_dp_real_t desc (dim);
348
356
349
357
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double , desc_dp_real_t >(
350
- q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
358
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, real );
351
359
}
352
360
/* real-to-complex, single precision */
353
361
else if constexpr (std::is_same<_DataType_input, float >::value &&
354
362
std::is_same<_DataType_output, std::complex<float >>::value)
355
363
{
356
364
desc_sp_real_t desc (dim); // try: 2 * result_size
357
365
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float , desc_sp_real_t >(
358
- q_ref, array1_in, result_out, input_shape, shape_size, result_size, desc, norm);
366
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, real );
359
367
}
360
368
else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
361
369
std::is_same<_DataType_input, int64_t >::value)
362
370
{
363
371
double * array1_copy = reinterpret_cast <double *>(dpnp_memory_alloc_c (input_size * sizeof (double )));
364
372
365
- shape_elem_type* copy_strides = reinterpret_cast <shape_elem_type*>(dpnp_memory_alloc_c (sizeof (shape_elem_type)));
373
+ shape_elem_type* copy_strides = reinterpret_cast <shape_elem_type*>(dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
366
374
*copy_strides = 1 ;
367
- shape_elem_type* copy_shape = reinterpret_cast <shape_elem_type*>(dpnp_memory_alloc_c (sizeof (shape_elem_type)));
375
+ shape_elem_type* copy_shape = reinterpret_cast <shape_elem_type*>(dpnp_memory_alloc_c (q_ref, sizeof (shape_elem_type)));
368
376
*copy_shape = input_size;
369
377
shape_elem_type copy_shape_size = 1 ;
370
- dpnp_copyto_c<_DataType_input, double >(array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
371
- array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL );
378
+ dpnp_copyto_c<_DataType_input, double >(q_ref, array1_copy, input_size, copy_shape_size, copy_shape, copy_strides,
379
+ array1_in, input_size, copy_shape_size, copy_shape, copy_strides, NULL , dep_event_vec_ref );
372
380
373
381
desc_dp_real_t desc (dim);
374
382
dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double , desc_dp_real_t >(
375
- array1_copy, result_out, input_shape, shape_size, result_size, desc, norm);
383
+ q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, real );
376
384
377
- dpnp_memory_free_c (array1_copy);
378
- dpnp_memory_free_c (copy_strides);
379
- dpnp_memory_free_c (copy_shape);
385
+ dpnp_memory_free_c (q_ref, array1_copy);
386
+ dpnp_memory_free_c (q_ref, copy_strides);
387
+ dpnp_memory_free_c (q_ref, copy_shape);
380
388
}
381
389
else
382
390
{
@@ -406,7 +414,8 @@ void dpnp_fft_fft_c(const void* array1_in,
406
414
long axis,
407
415
long input_boundarie,
408
416
size_t inverse,
409
- const size_t norm)
417
+ const size_t norm,
418
+ const size_t real)
410
419
{
411
420
DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(&DPNP_QUEUE);
412
421
DPCTLEventVectorRef dep_event_vec_ref = nullptr ;
@@ -420,6 +429,7 @@ void dpnp_fft_fft_c(const void* array1_in,
420
429
input_boundarie,
421
430
inverse,
422
431
norm,
432
+ real,
423
433
dep_event_vec_ref);
424
434
DPCTLEvent_WaitAndThrow (event_ref);
425
435
}
@@ -433,6 +443,7 @@ void (*dpnp_fft_fft_default_c)(const void*,
433
443
long ,
434
444
long ,
435
445
size_t ,
446
+ const size_t ,
436
447
const size_t ) = dpnp_fft_fft_c<_DataType_input, _DataType_output>;
437
448
438
449
template <typename _DataType_input, typename _DataType_output>
@@ -446,6 +457,7 @@ DPCTLSyclEventRef (*dpnp_fft_fft_ext_c)(DPCTLSyclQueueRef,
446
457
long ,
447
458
size_t ,
448
459
const size_t ,
460
+ const size_t ,
449
461
const DPCTLEventVectorRef) = dpnp_fft_fft_c<_DataType_input, _DataType_output>;
450
462
451
463
void func_map_init_fft_func (func_map_t & fmap)
0 commit comments