@@ -181,8 +181,7 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
181
181
const size_t result_size,
182
182
_Descriptor_type& desc,
183
183
size_t inverse,
184
- double backward_scale,
185
- double forward_scale)
184
+ const size_t norm)
186
185
{
187
186
if (!shape_size)
188
187
{
@@ -201,6 +200,21 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
201
200
202
201
const size_t shift = input_shape[shape_size - 1 ];
203
202
203
+ double backward_scale = 1 .;
204
+ double forward_scale = 1 .;
205
+
206
+ if (norm == 0 ) { // norm = "backward"
207
+ backward_scale = 1 . / shift;
208
+ } else if (norm == 1 ) { // norm = "forward"
209
+ forward_scale = 1 . / shift;
210
+ } else { // norm = "ortho"
211
+ if (inverse) {
212
+ backward_scale = 1 . / sqrt (shift);
213
+ } else {
214
+ forward_scale = 1 . / sqrt (shift);
215
+ }
216
+ }
217
+
204
218
desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
205
219
desc.set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
206
220
// enum value from math library C interface
@@ -238,8 +252,7 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
238
252
const size_t result_size,
239
253
_Descriptor_type& desc,
240
254
size_t inverse,
241
- double backward_scale,
242
- double forward_scale,
255
+ const size_t norm,
243
256
const size_t real)
244
257
{
245
258
if (!shape_size)
@@ -258,7 +271,26 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
258
271
std::accumulate (input_shape, input_shape + shape_size - 1 , 1 , std::multiplies<shape_elem_type>());
259
272
260
273
const size_t input_shift = input_shape[shape_size - 1 ];
261
- const size_t result_shift = result_shape[shape_size - 1 ];;
274
+ const size_t result_shift = result_shape[shape_size - 1 ];
275
+
276
+ double backward_scale = 1 .;
277
+ double forward_scale = 1 .;
278
+
279
+ if (norm == 0 ) { // norm = "backward"
280
+ if (inverse) {
281
+ forward_scale = 1 . / result_shift;
282
+ } else {
283
+ backward_scale = 1 . / result_shift;
284
+ }
285
+ } else if (norm == 1 ) { // norm = "forward"
286
+ if (inverse) {
287
+ backward_scale = 1 . / result_shift;
288
+ } else {
289
+ forward_scale = 1 . / result_shift;
290
+ }
291
+ } else { // norm = "ortho"
292
+ forward_scale = 1 . / sqrt (result_shift);
293
+ }
262
294
263
295
desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
264
296
desc.set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
@@ -270,11 +302,7 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
270
302
fft_events.reserve (n_iter);
271
303
272
304
for (size_t i = 0 ; i < n_iter; ++i) {
273
- if (inverse) {
274
- fft_events.push_back (mkl_dft::compute_backward (desc, array_1 + i * input_shift, result + i * result_shift * 2 ));
275
- } else {
276
- fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * input_shift, result + i * result_shift * 2 ));
277
- }
305
+ fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * input_shift, result + i * result_shift * 2 ));
278
306
}
279
307
280
308
sycl::event::wait (fft_events);
@@ -307,6 +335,11 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref,
307
335
event = queue.submit (kernel_func);
308
336
event.wait ();
309
337
338
+ if (inverse) {
339
+ event = oneapi::mkl::vm::conj (queue, result_size, reinterpret_cast <std::complex<_DataType_output>*>(result), reinterpret_cast <std::complex<_DataType_output>*>(result));
340
+ event.wait ();
341
+ }
342
+
310
343
return ;
311
344
}
312
345
@@ -337,21 +370,6 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
337
370
338
371
size_t dim = input_shape[shape_size - 1 ];
339
372
340
- double backward_scale = 1 ;
341
- double forward_scale = 1 ;
342
-
343
- if (norm == 0 ) { // norm = "backward"
344
- backward_scale = 1 . / dim;
345
- } else if (norm == 1 ) { // norm = "forward"
346
- forward_scale = 1 . / dim;
347
- } else { // norm = "ortho"
348
- if (inverse) {
349
- backward_scale = 1 . / sqrt (dim);
350
- } else {
351
- forward_scale = 1 . / sqrt (dim);
352
- }
353
- }
354
-
355
373
if constexpr (std::is_same<_DataType_output, std::complex<float >>::value ||
356
374
std::is_same<_DataType_output, std::complex<double >>::value)
357
375
{
@@ -360,15 +378,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
360
378
{
361
379
desc_dp_cmplx_t desc (dim);
362
380
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
363
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale );
381
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm );
364
382
}
365
383
/* complex-to-complex, single precision */
366
384
else if constexpr (std::is_same<_DataType_input, std::complex<float >>::value &&
367
385
std::is_same<_DataType_output, std::complex<float >>::value)
368
386
{
369
387
desc_sp_cmplx_t desc (dim);
370
388
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
371
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale );
389
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm );
372
390
}
373
391
/* real-to-complex, double precision */
374
392
else if constexpr (std::is_same<_DataType_input, double >::value &&
@@ -377,15 +395,15 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
377
395
desc_dp_real_t desc (dim);
378
396
379
397
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double , desc_dp_real_t >(
380
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 0 );
398
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm , 0 );
381
399
}
382
400
/* real-to-complex, single precision */
383
401
else if constexpr (std::is_same<_DataType_input, float >::value &&
384
402
std::is_same<_DataType_output, std::complex<float >>::value)
385
403
{
386
404
desc_sp_real_t desc (dim); // try: 2 * result_size
387
405
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float , desc_sp_real_t >(
388
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 0 );
406
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm , 0 );
389
407
}
390
408
else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
391
409
std::is_same<_DataType_input, int64_t >::value)
@@ -402,7 +420,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
402
420
403
421
desc_dp_real_t desc (dim);
404
422
dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double , desc_dp_real_t >(
405
- q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 0 );
423
+ q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm , 0 );
406
424
407
425
dpnp_memory_free_c (q_ref, array1_copy);
408
426
dpnp_memory_free_c (q_ref, copy_strides);
@@ -506,19 +524,6 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
506
524
507
525
size_t dim = input_shape[shape_size - 1 ];
508
526
509
- double backward_scale = 1 ;
510
- double forward_scale = 1 ;
511
- if (norm == 0 ) { // norm = "backward"
512
- backward_scale = 1 . / dim;
513
- } else if (norm == 1 ) { // norm = "forward"
514
- forward_scale = 1 . / dim;
515
- } else { // norm = "ortho"
516
- if (inverse) {
517
- backward_scale = 1 . / sqrt (dim);
518
- } else {
519
- forward_scale = 1 . / sqrt (dim);
520
- }
521
- }
522
527
523
528
if constexpr (std::is_same<_DataType_output, std::complex<float >>::value ||
524
529
std::is_same<_DataType_output, std::complex<double >>::value)
@@ -529,15 +534,15 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
529
534
desc_dp_real_t desc (dim);
530
535
531
536
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double , desc_dp_real_t >(
532
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 1 );
537
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm , 1 );
533
538
}
534
539
/* real-to-complex, single precision */
535
540
else if constexpr (std::is_same<_DataType_input, float >::value &&
536
541
std::is_same<_DataType_output, std::complex<float >>::value)
537
542
{
538
543
desc_sp_real_t desc (dim); // try: 2 * result_size
539
544
dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float , desc_sp_real_t >(
540
- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 1 );
545
+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm , 1 );
541
546
}
542
547
else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
543
548
std::is_same<_DataType_input, int64_t >::value)
@@ -554,7 +559,7 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
554
559
555
560
desc_dp_real_t desc (dim);
556
561
dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double , desc_dp_real_t >(
557
- q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale , 1 );
562
+ q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm , 1 );
558
563
559
564
dpnp_memory_free_c (q_ref, array1_copy);
560
565
dpnp_memory_free_c (q_ref, copy_strides);
0 commit comments