Skip to content

Commit c7b541c

Browse files
Fix nanvar (#945)
1 parent 61120c7 commit c7b541c

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

dpnp/dpnp_algo/dpnp_algo_statistics.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -417,10 +417,10 @@ cpdef utils.dpnp_descriptor dpnp_min(utils.dpnp_descriptor input, axis):
417417

418418

419419
cpdef utils.dpnp_descriptor dpnp_nanvar(utils.dpnp_descriptor arr, ddof):
420-
cdef utils.dpnp_descriptor mask_arr = dpnp_isnan(arr)
421-
n = sum(mask_arr.get_pyobj())
420+
# dpnp_isnan does not support USM array as input in comparison to dpnp.isnan
421+
cdef utils.dpnp_descriptor mask_arr = dpnp.get_dpnp_descriptor(dpnp.isnan(arr.get_pyobj()))
422+
n = dpnp.count_nonzero(mask_arr.get_pyobj())
422423
res_size = arr.size - n
423-
424424
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
425425

426426
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_NANVAR, param1_type, param1_type)

dpnp/dpnp_iface_counting.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
import dpnp
4444
import numpy
4545

46+
import dpnp.config as config
47+
from dpnp.dpnp_utils import *
48+
4649
from dpnp.dpnp_algo.dpnp_algo import * # TODO need to investigate why dpnp.dpnp_algo can not be used
4750

4851
__all__ = [
@@ -72,7 +75,6 @@ def count_nonzero(x1, axis=None, *, keepdims=False):
7275
5
7376
7477
"""
75-
7678
x1_desc = dpnp.get_dpnp_descriptor(x1)
7779
if x1_desc:
7880
if axis is not None:
@@ -85,4 +87,4 @@ def count_nonzero(x1, axis=None, *, keepdims=False):
8587

8688
return result
8789

88-
return numpy.count_nonzero(x1, axis, keepdims=keepdims)
90+
return call_origin(numpy.count_nonzero, x1, axis, keepdims=keepdims)

0 commit comments

Comments
 (0)