Skip to content

Commit 019189c

Browse files
Enable Compute Follows Data in Cython in fft, random, linalg
1 parent f525309 commit 019189c

14 files changed

+1614
-331
lines changed

dpnp/backend/include/dpnp_iface_random.hpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -335,23 +335,23 @@ INP_DLLEXPORT void dpnp_rng_lognormal_c(void* result, const _DataType mean, cons
335335
* @param [in] q_ref Reference to SYCL queue.
336336
* @param [out] result Output array.
337337
* @param [in] ntrial Number of independent trials.
338-
* @param [in] p_vector Probability vector of possible outcomes (k length).
339-
* @param [in] p_vector_size Length of `p_vector`.
338+
* @param [in] p_in Probability of possible outcomes (k length).
339+
* @param [in] p_size Length of `p_in`.
340340
* @param [in] size Number of elements in `result` arrays.
341341
* @param [in] dep_event_vec_ref Reference to vector of SYCL events.
342342
*/
343343
template <typename _DataType>
344344
INP_DLLEXPORT DPCTLSyclEventRef dpnp_rng_multinomial_c(DPCTLSyclQueueRef q_ref,
345345
void* result,
346346
const int ntrial,
347-
const double* p_vector,
348-
const size_t p_vector_size,
347+
const double* p_in,
348+
const size_t p_size,
349349
const size_t size,
350350
const DPCTLEventVectorRef dep_event_vec_ref);
351351

352352
template <typename _DataType>
353353
INP_DLLEXPORT void dpnp_rng_multinomial_c(
354-
void* result, const int ntrial, const double* p_vector, const size_t p_vector_size, const size_t size);
354+
void* result, const int ntrial, const double* p_in, const size_t p_size, const size_t size);
355355

356356
/**
357357
* @ingroup BACKEND_RANDOM_API
@@ -360,31 +360,31 @@ INP_DLLEXPORT void dpnp_rng_multinomial_c(
360360
* @param [in] q_ref Reference to SYCL queue.
361361
* @param [out] result Output array.
362362
* @param [in] dimen Dimension of output random vectors.
363-
* @param [in] mean_vector Mean vector a of dimension.
364-
* @param [in] mean_vector_size Length of `mean_vector`.
365-
* @param [in] cov_vector Variance-covariance matrix.
366-
* @param [in] cov_vector_size Length of `cov_vector`.
363+
* @param [in] mean_in Mean arry of dimension.
364+
* @param [in] mean_size Length of `mean_in`.
365+
* @param [in] cov Variance-covariance matrix.
366+
* @param [in] cov_size Length of `cov_in`.
367367
* @param [in] size Number of elements in `result` arrays.
368368
* @param [in] dep_event_vec_ref Reference to vector of SYCL events.
369369
*/
370370
template <typename _DataType>
371371
INP_DLLEXPORT DPCTLSyclEventRef dpnp_rng_multivariate_normal_c(DPCTLSyclQueueRef q_ref,
372372
void* result,
373373
const int dimen,
374-
const double* mean_vector,
375-
const size_t mean_vector_size,
376-
const double* cov_vector,
377-
const size_t cov_vector_size,
374+
const double* mean_in,
375+
const size_t mean_size,
376+
const double* cov_in,
377+
const size_t cov_size,
378378
const size_t size,
379379
const DPCTLEventVectorRef dep_event_vec_ref);
380380

381381
template <typename _DataType>
382382
INP_DLLEXPORT void dpnp_rng_multivariate_normal_c(void* result,
383383
const int dimen,
384-
const double* mean_vector,
385-
const size_t mean_vector_size,
386-
const double* cov_vector,
387-
const size_t cov_vector_size,
384+
const double* mean_in,
385+
const size_t mean_size,
386+
const double* cov_in,
387+
const size_t cov_size,
388388
const size_t size);
389389

390390
/**

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,11 +1086,25 @@ void func_map_init_linalg(func_map_t& fmap)
10861086
fmap[DPNPFuncName::DPNP_FN_EIG][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_eig_default_c<float, float>};
10871087
fmap[DPNPFuncName::DPNP_FN_EIG][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_eig_default_c<double, double>};
10881088

1089+
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_eig_ext_c<int32_t, double>};
1090+
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_eig_ext_c<int64_t, double>};
1091+
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_eig_ext_c<float, float>};
1092+
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_eig_ext_c<double, double>};
1093+
10891094
fmap[DPNPFuncName::DPNP_FN_EIGVALS][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_eigvals_default_c<int32_t, double>};
10901095
fmap[DPNPFuncName::DPNP_FN_EIGVALS][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_eigvals_default_c<int64_t, double>};
10911096
fmap[DPNPFuncName::DPNP_FN_EIGVALS][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_eigvals_default_c<float, float>};
10921097
fmap[DPNPFuncName::DPNP_FN_EIGVALS][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_eigvals_default_c<double, double>};
10931098

1099+
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_INT][eft_INT] = {eft_DBL,
1100+
(void*)dpnp_eigvals_ext_c<int32_t, double>};
1101+
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_LNG][eft_LNG] = {eft_DBL,
1102+
(void*)dpnp_eigvals_ext_c<int64_t, double>};
1103+
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1104+
(void*)dpnp_eigvals_ext_c<float, float>};
1105+
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1106+
(void*)dpnp_eigvals_ext_c<double, double>};
1107+
10941108
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_initval_default_c<bool>};
10951109
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_initval_default_c<int32_t>};
10961110
fmap[DPNPFuncName::DPNP_FN_INITVAL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_initval_default_c<int64_t>};

dpnp/backend/kernels/dpnp_krnl_fft.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,5 +377,18 @@ void func_map_init_fft_func(func_map_t& fmap)
377377
eft_C64, (void*)dpnp_fft_fft_default_c<std::complex<float>, std::complex<float>>};
378378
fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_C128][eft_C128] = {
379379
eft_C128, (void*)dpnp_fft_fft_default_c<std::complex<double>, std::complex<double>>};
380+
381+
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_INT][eft_INT] = {
382+
eft_C128, (void*)dpnp_fft_fft_ext_c<int32_t, std::complex<double>>};
383+
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_LNG][eft_LNG] = {
384+
eft_C128, (void*)dpnp_fft_fft_ext_c<int64_t, std::complex<double>>};
385+
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_FLT][eft_FLT] = {
386+
eft_C64, (void*)dpnp_fft_fft_ext_c<float, std::complex<float>>};
387+
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_DBL][eft_DBL] = {
388+
eft_C128, (void*)dpnp_fft_fft_ext_c<double, std::complex<double>>};
389+
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_C64][eft_C64] = {
390+
eft_C64, (void*)dpnp_fft_fft_ext_c<std::complex<float>, std::complex<float>>};
391+
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_C128][eft_C128] = {
392+
eft_C128, (void*)dpnp_fft_fft_ext_c<std::complex<double>, std::complex<double>>};
380393
return;
381394
}

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -878,16 +878,29 @@ void func_map_init_linalg_func(func_map_t& fmap)
878878
fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_cholesky_default_c<float>};
879879
fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_cholesky_default_c<double>};
880880

881+
fmap[DPNPFuncName::DPNP_FN_CHOLESKY_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_cholesky_ext_c<float>};
882+
fmap[DPNPFuncName::DPNP_FN_CHOLESKY_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_cholesky_ext_c<double>};
883+
881884
fmap[DPNPFuncName::DPNP_FN_DET][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_det_default_c<int32_t>};
882885
fmap[DPNPFuncName::DPNP_FN_DET][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_det_default_c<int64_t>};
883886
fmap[DPNPFuncName::DPNP_FN_DET][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_det_default_c<float>};
884887
fmap[DPNPFuncName::DPNP_FN_DET][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_det_default_c<double>};
885888

889+
fmap[DPNPFuncName::DPNP_FN_DET_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_det_ext_c<int32_t>};
890+
fmap[DPNPFuncName::DPNP_FN_DET_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_det_ext_c<int64_t>};
891+
fmap[DPNPFuncName::DPNP_FN_DET_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_det_ext_c<float>};
892+
fmap[DPNPFuncName::DPNP_FN_DET_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_det_ext_c<double>};
893+
886894
fmap[DPNPFuncName::DPNP_FN_INV][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_inv_default_c<int32_t, double>};
887895
fmap[DPNPFuncName::DPNP_FN_INV][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_inv_default_c<int64_t, double>};
888896
fmap[DPNPFuncName::DPNP_FN_INV][eft_FLT][eft_FLT] = {eft_DBL, (void*)dpnp_inv_default_c<float, double>};
889897
fmap[DPNPFuncName::DPNP_FN_INV][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_inv_default_c<double, double>};
890898

899+
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_inv_ext_c<int32_t, double>};
900+
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_inv_ext_c<int64_t, double>};
901+
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_FLT][eft_FLT] = {eft_DBL, (void*)dpnp_inv_ext_c<float, double>};
902+
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_inv_ext_c<double, double>};
903+
891904
fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_INT] = {eft_INT,
892905
(void*)dpnp_kron_default_c<int32_t, int32_t, int32_t>};
893906
fmap[DPNPFuncName::DPNP_FN_KRON][eft_INT][eft_LNG] = {eft_LNG,
@@ -995,19 +1008,46 @@ void func_map_init_linalg_func(func_map_t& fmap)
9951008
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_matrix_rank_default_c<float>};
9961009
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_matrix_rank_default_c<double>};
9971010

1011+
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_matrix_rank_ext_c<int32_t>};
1012+
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_matrix_rank_ext_c<int64_t>};
1013+
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_matrix_rank_ext_c<float>};
1014+
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_matrix_rank_ext_c<double>};
1015+
9981016
fmap[DPNPFuncName::DPNP_FN_QR][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_qr_default_c<int32_t, double>};
9991017
fmap[DPNPFuncName::DPNP_FN_QR][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_qr_default_c<int64_t, double>};
10001018
fmap[DPNPFuncName::DPNP_FN_QR][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_qr_default_c<float, float>};
10011019
fmap[DPNPFuncName::DPNP_FN_QR][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_qr_default_c<double, double>};
10021020
// fmap[DPNPFuncName::DPNP_FN_QR][eft_C128][eft_C128] = {
10031021
// eft_C128, (void*)dpnp_qr_c<std::complex<double>, std::complex<double>>};
10041022

1005-
fmap[DPNPFuncName::DPNP_FN_SVD][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_svd_default_c<int32_t, double, double>};
1006-
fmap[DPNPFuncName::DPNP_FN_SVD][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_svd_default_c<int64_t, double, double>};
1007-
fmap[DPNPFuncName::DPNP_FN_SVD][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_svd_default_c<float, float, float>};
1008-
fmap[DPNPFuncName::DPNP_FN_SVD][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_svd_default_c<double, double, double>};
1023+
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_qr_ext_c<int32_t, double>};
1024+
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_qr_ext_c<int64_t, double>};
1025+
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_qr_ext_c<float, float>};
1026+
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_qr_ext_c<double, double>};
1027+
// fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_C128][eft_C128] = {
1028+
// eft_C128, (void*)dpnp_qr_c<std::complex<double>, std::complex<double>>};
1029+
1030+
fmap[DPNPFuncName::DPNP_FN_SVD][eft_INT][eft_INT] = {eft_DBL,
1031+
(void*)dpnp_svd_default_c<int32_t, double, double>};
1032+
fmap[DPNPFuncName::DPNP_FN_SVD][eft_LNG][eft_LNG] = {eft_DBL,
1033+
(void*)dpnp_svd_default_c<int64_t, double, double>};
1034+
fmap[DPNPFuncName::DPNP_FN_SVD][eft_FLT][eft_FLT] = {eft_FLT,
1035+
(void*)dpnp_svd_default_c<float, float, float>};
1036+
fmap[DPNPFuncName::DPNP_FN_SVD][eft_DBL][eft_DBL] = {eft_DBL,
1037+
(void*)dpnp_svd_default_c<double, double, double>};
10091038
fmap[DPNPFuncName::DPNP_FN_SVD][eft_C128][eft_C128] = {
10101039
eft_C128, (void*)dpnp_svd_default_c<std::complex<double>, std::complex<double>, double>};
1040+
1041+
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_INT][eft_INT] = {eft_DBL,
1042+
(void*)dpnp_svd_ext_c<int32_t, double, double>};
1043+
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_LNG][eft_LNG] = {eft_DBL,
1044+
(void*)dpnp_svd_ext_c<int64_t, double, double>};
1045+
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1046+
(void*)dpnp_svd_ext_c<float, float, float>};
1047+
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1048+
(void*)dpnp_svd_ext_c<double, double, double>};
1049+
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_C128][eft_C128] = {
1050+
eft_C128, (void*)dpnp_svd_ext_c<std::complex<double>, std::complex<double>, double>};
10111051

10121052
return;
10131053
}

0 commit comments

Comments
 (0)