@@ -180,7 +180,9 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
180
180
const size_t input_size,
181
181
const size_t result_size,
182
182
_Descriptor_type& desc,
183
- const size_t norm)
183
+ size_t inverse,
184
+ double backward_scale,
185
+ double forward_scale)
184
186
{
185
187
if (!shape_size)
186
188
{
@@ -199,9 +201,6 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
199
201
200
202
const size_t shift = input_shape[shape_size - 1 ];
201
203
202
- double forward_scale = 1.0 ;
203
- double backward_scale = 1.0 / shift;
204
-
205
204
desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
206
205
desc.set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
207
206
// enum value from math library C interface
@@ -213,7 +212,11 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
213
212
fft_events.reserve (n_iter);
214
213
215
214
for (size_t i = 0 ; i < n_iter; ++i) {
216
- fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * shift, result + i * shift));
215
+ if (inverse) {
216
+ fft_events.push_back (mkl_dft::compute_backward (desc, array_1 + i * shift, result + i * shift));
217
+ } else {
218
+ fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * shift, result + i * shift));
219
+ }
217
220
}
218
221
219
222
sycl::event::wait (fft_events);
@@ -234,7 +237,9 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
234
237
const size_t input_size,
235
238
const size_t result_size,
236
239
_Descriptor_type& desc,
237
- const size_t norm,
240
+ size_t inverse,
241
+ double backward_scale,
242
+ double forward_scale,
238
243
const size_t real)
239
244
{
240
245
if (!shape_size)
@@ -255,19 +260,21 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
255
260
const size_t input_shift = input_shape[shape_size - 1 ];
256
261
const size_t result_shift = result_shape[shape_size - 1 ];;
257
262
258
- double forward_scale = 1.0 ;
259
- double backward_scale = 1.0 / input_shift;
260
-
261
263
desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
262
264
desc.set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
265
+ desc.set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
263
266
264
267
desc.commit (queue);
265
268
266
269
std::vector<sycl::event> fft_events;
267
270
fft_events.reserve (n_iter);
268
271
269
272
for (size_t i = 0 ; i < n_iter; ++i) {
270
- fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * input_shift, result + i * result_shift * 2 ));
273
+ if (inverse) {
274
+ fft_events.push_back (mkl_dft::compute_backward (desc, array_1 + i * input_shift, result + i * result_shift * 2 ));
275
+ } else {
276
+ fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * input_shift, result + i * result_shift * 2 ));
277
+ }
271
278
}
272
279
273
280
sycl::event::wait (fft_events);
@@ -330,6 +337,21 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
330
337
331
338
size_t dim = input_shape[shape_size - 1 ];
332
339
340
+ double backward_scale = 1 ;
341
+ double forward_scale = 1 ;
342
+
343
+ if (norm == 0 ) { // norm = "backward"
344
+ backward_scale = 1 . / dim;
345
+ } else if (norm == 1 ) { // norm = "forward"
346
+ forward_scale = 1 . / dim;
347
+ } else { // norm = "ortho"
348
+ if (inverse) {
349
+ backward_scale = 1 . / sqrt (dim);
350
+ } else {
351
+ forward_scale = 1 . / sqrt (dim);
352
+ }
353
+ }
354
+
333
355
if constexpr (std::is_same<_DataType_output, std::complex<float >>::value ||
334
356
std::is_same<_DataType_output, std::complex<double >>::value)
335
357
{
@@ -338,15 +360,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
338
360
{
339
361
desc_dp_cmplx_t desc (dim);
340
362
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
341
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm );
363
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale );
342
364
}
343
365
/* complex-to-complex, single precision */
344
366
else if constexpr (std::is_same<_DataType_input, std::complex<float >>::value &&
345
367
std::is_same<_DataType_output, std::complex<float >>::value)
346
368
{
347
369
desc_sp_cmplx_t desc (dim);
348
370
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
349
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm );
371
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale );
350
372
}
351
373
/* real-to-complex, double precision */
352
374
else if constexpr (std::is_same<_DataType_input, double >::value &&
@@ -355,15 +377,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
355
377
desc_dp_real_t desc (dim);
356
378
357
379
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double , desc_dp_real_t >(
358
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm , 0 );
380
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 0 );
359
381
}
360
382
/* real-to-complex, single precision */
361
383
else if constexpr (std::is_same<_DataType_input, float >::value &&
362
384
std::is_same<_DataType_output, std::complex<float >>::value)
363
385
{
364
386
desc_sp_real_t desc (dim); // try: 2 * result_size
365
387
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float , desc_sp_real_t >(
366
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm , 0 );
388
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 0 );
367
389
}
368
390
else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
369
391
std::is_same<_DataType_input, int64_t >::value)
@@ -380,7 +402,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
380
402
381
403
desc_dp_real_t desc (dim);
382
404
dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double , desc_dp_real_t >(
383
- q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm , 0 );
405
+ q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 0 );
384
406
385
407
dpnp_memory_free_c (q_ref, array1_copy);
386
408
dpnp_memory_free_c (q_ref, copy_strides);
@@ -484,6 +506,20 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
484
506
485
507
size_t dim = input_shape[shape_size - 1 ];
486
508
509
+ double backward_scale = 1 ;
510
+ double forward_scale = 1 ;
511
+ if (norm == 0 ) { // norm = "backward"
512
+ backward_scale = 1 . / dim;
513
+ } else if (norm == 1 ) { // norm = "forward"
514
+ forward_scale = 1 . / dim;
515
+ } else { // norm = "ortho"
516
+ if (inverse) {
517
+ backward_scale = 1 . / sqrt (dim);
518
+ } else {
519
+ forward_scale = 1 . / sqrt (dim);
520
+ }
521
+ }
522
+
487
523
if constexpr (std::is_same<_DataType_output, std::complex<float >>::value ||
488
524
std::is_same<_DataType_output, std::complex<double >>::value)
489
525
{
@@ -493,15 +529,15 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
493
529
desc_dp_real_t desc (dim);
494
530
495
531
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double , desc_dp_real_t >(
496
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1l );
532
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1 );
497
533
}
498
534
/* real-to-complex, single precision */
499
535
else if constexpr (std::is_same<_DataType_input, float >::value &&
500
536
std::is_same<_DataType_output, std::complex<float >>::value)
501
537
{
502
538
desc_sp_real_t desc (dim); // try: 2 * result_size
503
539
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float , desc_sp_real_t >(
504
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm , 1 );
540
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 1 );
505
541
}
506
542
else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
507
543
std::is_same<_DataType_input, int64_t >::value)
@@ -518,7 +554,7 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
518
554
519
555
desc_dp_real_t desc (dim);
520
556
dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double , desc_dp_real_t >(
521
- q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm , 1 );
557
+ q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 1 );
522
558
523
559
dpnp_memory_free_c (q_ref, array1_copy);
524
560
dpnp_memory_free_c (q_ref, copy_strides);
0 commit comments