Skip to content

Commit 4bbb5f2

Browse files
committed
update dpnp.cond
1 parent a8bcdaf commit 4bbb5f2

File tree

8 files changed

+258
-89
lines changed

8 files changed

+258
-89
lines changed

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ cimport numpy
4545
cimport dpnp.dpnp_utils as utils
4646

4747
__all__ = [
48-
"dpnp_cond",
4948
"dpnp_eig",
5049
"dpnp_eigvals",
5150
]
@@ -60,30 +59,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_2in_1out_func_ptr_t)(c_dpctl.D
6059
const c_dpctl.DPCTLEventVectorRef)
6160

6261

63-
cpdef object dpnp_cond(object input, object p):
64-
if p in ('f', 'fro'):
65-
# TODO: change order='K' when support is implemented
66-
input = dpnp.ravel(input, order='C')
67-
sqnorm = dpnp.dot(input, input)
68-
res = dpnp.sqrt(sqnorm)
69-
ret = dpnp.array([res])
70-
elif p == dpnp.inf:
71-
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=1)
72-
ret = dpnp.max(dpnp_sum_val)
73-
elif p == -dpnp.inf:
74-
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=1)
75-
ret = dpnp.min(dpnp_sum_val)
76-
elif p == 1:
77-
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=0)
78-
ret = dpnp.max(dpnp_sum_val)
79-
elif p == -1:
80-
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=0)
81-
ret = dpnp.min(dpnp_sum_val)
82-
else:
83-
ret = dpnp.array([input.item(0)])
84-
return ret
85-
86-
8762
cpdef tuple dpnp_eig(utils.dpnp_descriptor x1):
8863
cdef shape_type_c x1_shape = x1.shape
8964

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
check_stacked_2d,
4949
check_stacked_square,
5050
dpnp_cholesky,
51+
dpnp_cond,
5152
dpnp_det,
5253
dpnp_eigh,
5354
dpnp_inv,
@@ -144,32 +145,60 @@ def cholesky(a, upper=False):
144145
return dpnp_cholesky(a, upper=upper)
145146

146147

147-
def cond(input, p=None):
148+
def cond(x, p=None):
148149
"""
149150
Compute the condition number of a matrix.
150151
151152
For full documentation refer to :obj:`numpy.linalg.cond`.
152153
153-
Limitations
154-
-----------
155-
Input array is supported as :obj:`dpnp.ndarray`.
156-
Parameter p=[None, 1, -1, 2, -2, dpnp.inf, -dpnp.inf, 'fro'] is supported.
154+
Parameters
155+
----------
156+
x : {dpnp.ndarray, usm_ndarray}
157+
The matrix whose condition number is sought.
158+
p : {None, 1, -1, 2, -2, inf, -inf, "fro"}, optional
159+
Order of the norm used in the condition number computation.
160+
inf means dpnp's `inf` object. The default is ``None``.
161+
162+
Returns
163+
-------
164+
out : dpnp.ndarray
165+
The condition number of the matrix. May be infinite.
157166
158167
See Also
159168
--------
160169
:obj:`dpnp.norm` : Matrix or vector norm.
161-
"""
162170
163-
if not use_origin_backend(input):
164-
if p in [None, 1, -1, 2, -2, dpnp.inf, -dpnp.inf, "fro"]:
165-
result_obj = dpnp_cond(input, p)
166-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
171+
Examples
172+
--------
173+
>>> import dpnp as np
174+
>>> a = np.array([[1, 0, -1], [0, 1, 0], [1, 0, 1]])
175+
>>> a
176+
array([[ 1, 0, -1],
177+
[ 0, 1, 0],
178+
[ 1, 0, 1]])
179+
>>> np.linalg.cond(a)
180+
array(1.4142135623730951)
181+
>>> np.linalg.cond(a, 'fro')
182+
array(3.1622776601683795)
183+
>>> np.linalg.cond(a, np.inf)
184+
array(2.)
185+
>>> np.linalg.cond(a, -np.inf)
186+
array(1.)
187+
>>> np.linalg.cond(a, 1)
188+
array(2.)
189+
>>> np.linalg.cond(a, -1)
190+
array(1.)
191+
>>> np.linalg.cond(a, 2)
192+
array(1.4142135623730951)
193+
>>> np.linalg.cond(a, -2)
194+
array(0.70710678118654746) # may vary
195+
>>> min(np.linalg.svd(a, compute_uv=False))*min(np.linalg.svd(np.linalg.inv(a), compute_uv=False))
196+
array(0.70710678118654746) # may vary
167197
168-
return result
169-
else:
170-
pass
198+
"""
171199

172-
return call_origin(numpy.linalg.cond, input, p)
200+
dpnp.check_supported_arrays_type(x)
201+
return dpnp_cond(x, p)
173202

174203

175204
def det(a):

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"check_stacked_2d",
3939
"check_stacked_square",
4040
"dpnp_cholesky",
41+
"dpnp_cond",
4142
"dpnp_det",
4243
"dpnp_eigh",
4344
"dpnp_inv",
@@ -199,6 +200,11 @@ def _common_inexact_type(default_dtype, *dtypes):
199200
return dpnp.result_type(*inexact_dtypes)
200201

201202

203+
def _is_empty_2d(arr):
204+
# check size first for efficiency
205+
return arr.size == 0 and prod(arr.shape[-2:]) == 0
206+
207+
202208
def _lu_factor(a, res_type):
203209
"""
204210
Compute pivoted LU decomposition.
@@ -841,6 +847,40 @@ def dpnp_cholesky(a, upper):
841847
return a_h
842848

843849

850+
def dpnp_cond(x, p=None):
851+
"""Compute the condition number of a matrix."""
852+
853+
if _is_empty_2d(x):
854+
raise dpnp.linalg.LinAlgError("cond is not defined on empty arrays")
855+
if p is None or p == 2 or p == -2:
856+
s = dpnp.linalg.svd(x, compute_uv=False)
857+
with numpy.errstate(all="ignore"):
858+
if p == -2:
859+
r = s[..., -1] / s[..., 0]
860+
else:
861+
r = s[..., 0] / s[..., -1]
862+
else:
863+
# Call inv(x) ignoring errors. The result array will
864+
# contain nans in the entries where inversion failed.
865+
check_stacked_2d(x)
866+
check_stacked_square(x)
867+
result_t = _common_type(x)
868+
with numpy.errstate(all="ignore"):
869+
invx = dpnp.linalg.inv(x)
870+
r = dpnp.linalg.norm(x, p, axis=(-2, -1)) * dpnp.linalg.norm(
871+
invx, p, axis=(-2, -1)
872+
)
873+
r = r.astype(result_t, copy=False)
874+
875+
# Convert nans to infs unless the original array had nan entries
876+
nan_mask = dpnp.isnan(r)
877+
if nan_mask.any():
878+
nan_mask &= ~dpnp.isnan(x).any(axis=(-2, -1))
879+
r[nan_mask] = dpnp.inf
880+
881+
return r
882+
883+
844884
def dpnp_det(a):
845885
"""
846886
dpnp_det(a)
@@ -1222,18 +1262,18 @@ def dpnp_multi_dot(n, arrays, out=None):
12221262
"""Compute the dot product of two or more arrays in a single function call."""
12231263

12241264
if not arrays[0].ndim in [1, 2]:
1225-
raise numpy.linalg.LinAlgError(
1265+
raise dpnp.linalg.LinAlgError(
12261266
f"{arrays[0].ndim}-dimensional array given. First array must be 1-D or 2-D."
12271267
)
12281268

12291269
if not arrays[-1].ndim in [1, 2]:
1230-
raise numpy.linalg.LinAlgError(
1270+
raise dpnp.linalg.LinAlgError(
12311271
f"{arrays[-1].ndim}-dimensional array given. Last array must be 1-D or 2-D."
12321272
)
12331273

12341274
for arr in arrays[1:-1]:
12351275
if arr.ndim != 2:
1236-
raise numpy.linalg.LinAlgError(
1276+
raise dpnp.linalg.LinAlgError(
12371277
f"{arr.ndim}-dimensional array given. Inner arrays must be 2-D."
12381278
)
12391279

tests/skipped_tests.tbl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,6 @@ tests/third_party/cupy/fft_tests/test_fft.py::TestFftn_param_23_{axes=None, norm
3737

3838
tests/third_party/intel/test_zero_copy_test1.py::test_dpnp_interaction_with_dpctl_memory
3939

40-
tests/test_linalg.py::test_cond[-1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
41-
tests/test_linalg.py::test_cond[1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
42-
tests/test_linalg.py::test_cond[-2-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
43-
tests/test_linalg.py::test_cond[2-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
44-
tests/test_linalg.py::test_cond[-2-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
45-
tests/test_linalg.py::test_cond[2-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
46-
tests/test_linalg.py::test_cond["fro"-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
47-
tests/test_linalg.py::test_cond["fro"-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
48-
tests/test_linalg.py::test_cond[None-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
49-
tests/test_linalg.py::test_cond[None-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
50-
tests/test_linalg.py::test_cond[-numpy.inf-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
51-
tests/test_linalg.py::test_cond[numpy.inf-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
52-
5340
tests/test_linalg.py::test_matrix_rank[None-[0, 1]-float64]
5441
tests/test_linalg.py::test_matrix_rank[None-[0, 1]-float32]
5542
tests/test_linalg.py::test_matrix_rank[None-[0, 1]-int64]

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -158,19 +158,6 @@ tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsMult
158158

159159
tests/third_party/intel/test_zero_copy_test1.py::test_dpnp_interaction_with_dpctl_memory
160160

161-
tests/test_linalg.py::test_cond[-1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
162-
tests/test_linalg.py::test_cond[1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
163-
tests/test_linalg.py::test_cond[-2-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
164-
tests/test_linalg.py::test_cond[2-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
165-
tests/test_linalg.py::test_cond[-2-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
166-
tests/test_linalg.py::test_cond[2-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
167-
tests/test_linalg.py::test_cond["fro"-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
168-
tests/test_linalg.py::test_cond["fro"-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
169-
tests/test_linalg.py::test_cond[None-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
170-
tests/test_linalg.py::test_cond[None-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
171-
tests/test_linalg.py::test_cond[-numpy.inf-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
172-
tests/test_linalg.py::test_cond[numpy.inf-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
173-
174161
tests/test_linalg.py::test_matrix_rank[None-[0, 1]-float64]
175162
tests/test_linalg.py::test_matrix_rank[None-[0, 1]-float32]
176163
tests/test_linalg.py::test_matrix_rank[None-[0, 1]-int64]

0 commit comments

Comments
 (0)