Skip to content

Commit fb57ef2

Browse files
committed
address comments
1 parent 1d360ad commit fb57ef2

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
### Changed
1717

1818
* Allowed input array of `uint64` dtype in `dpnp.bincount` [#2361](https://github.com/IntelPython/dpnp/pull/2361)
19-
* The vector norms `ord={1, 2, inf}` and the matrix norms `ord={1, 2, inf, "fro", "nuc"}` now consistently return zero for empty arrays, which are arrays with at least one axis of size zero. This change affects `dpnp.linalg.norm`, `dpnp.linalg.vector_norm`, and `dpnp.linalg.matrix_norm`. Previously, dpnp would either raise errors or return zero depending on the parameters provided [#2371](https://github.com/IntelPython/dpnp/pull/2371)
19+
* The vector norms `ord={None, 1, 2, inf}` and the matrix norms `ord={None, 1, 2, inf, "fro", "nuc"}` now consistently return zero for empty arrays, which are arrays with at least one axis of size zero. This change affects `dpnp.linalg.norm`, `dpnp.linalg.vector_norm`, and `dpnp.linalg.matrix_norm`. Previously, dpnp would either raise errors or return zero depending on the parameters provided [#2371](https://github.com/IntelPython/dpnp/pull/2371)
2020

2121
### Fixed
2222

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,9 @@ def _norm_int_axis(x, ord, axis, keepdims):
11851185
"""
11861186

11871187
if ord == dpnp.inf:
1188+
if x.shape[axis] == 0:
1189+
x = dpnp.moveaxis(x, axis, -1)
1190+
return dpnp.zeros_like(x, shape=x.shape[:-1])
11881191
return dpnp.abs(x).max(axis=axis, keepdims=keepdims)
11891192
if ord == -dpnp.inf:
11901193
return dpnp.abs(x).min(axis=axis, keepdims=keepdims)
@@ -1220,6 +1223,10 @@ def _norm_tuple_axis(x, ord, row_axis, col_axis, keepdims):
12201223
"""
12211224

12221225
axis = (row_axis, col_axis)
1226+
flag = x.shape[row_axis] == 0 or x.shape[col_axis] == 0
1227+
if flag and ord in [1, 2, dpnp.inf]:
1228+
x = dpnp.moveaxis(x, axis, (-2, -1))
1229+
return dpnp.zeros_like(x, shape=x.shape[:-2])
12231230
if row_axis == col_axis:
12241231
raise ValueError("Duplicate axes given.")
12251232
if ord == 2:
@@ -2401,17 +2408,10 @@ def dpnp_norm(x, ord=None, axis=None, keepdims=False):
24012408
axis = (axis,)
24022409

24032410
if len(axis) == 1:
2404-
if x.shape[axis[0]] == 0 and ord in [1, 2, dpnp.inf]:
2405-
x = dpnp.moveaxis(x, axis, -1)
2406-
return dpnp.zeros_like(x, shape=x.shape[:-1])
24072411
axis = normalize_axis_index(axis[0], ndim)
24082412
return _norm_int_axis(x, ord, axis, keepdims)
24092413

24102414
if len(axis) == 2:
2411-
flag = x.shape[axis[0]] == 0 or x.shape[axis[1]] == 0
2412-
if flag and ord in ["fro", "nuc", 1, 2, dpnp.inf]:
2413-
x = dpnp.moveaxis(x, axis, (-2, -1))
2414-
return dpnp.zeros_like(x, shape=x.shape[:-2])
24152415
row_axis, col_axis = axis
24162416
row_axis = normalize_axis_index(row_axis, ndim)
24172417
col_axis = normalize_axis_index(col_axis, ndim)

dpnp/tests/test_linalg.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,13 +2097,13 @@ def test_empty(self, shape, ord, axis, keepdims):
20972097
assert_raises(ValueError, dpnp.linalg.norm, ia, **kwarg)
20982098
assert_raises(ValueError, numpy.linalg.norm, a, **kwarg)
20992099
elif axis is None and a.ndim != 1 and a.shape[-1] == 0:
2100-
# TODO: when similar changes in numpy are available,
2101-
# instead of assert_equal with zero, we should compare with numpy
21022100
if ord in [-2, -1, 0, 3]:
21032101
# reduction cannot be performed over zero-size axes
21042102
assert_raises(ValueError, dpnp.linalg.norm, ia, **kwarg)
21052103
assert_raises(ValueError, numpy.linalg.norm, a, **kwarg)
21062104
else:
2105+
# TODO: when similar changes in numpy are available, instead
2106+
# of assert_equal with zero, we should compare with numpy
21072107
# ord in [None, 1, 2]
21082108
assert_equal(dpnp.linalg.norm(ia, **kwarg), 0)
21092109
else:
@@ -2295,14 +2295,15 @@ def test_matrix_norm(self, ord, keepdims):
22952295

22962296
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.int32])
22972297
@pytest.mark.parametrize(
2298-
"shape_axis", [[(2, 0), None], [(2, 0, 3), (0, 1)]]
2298+
"shape_axis", [[(2, 0), None], [(2, 0), (0, 1)], [(0, 2), (0, 1)]]
22992299
)
23002300
def test_matrix_norm_empty(self, dtype, shape_axis):
23012301
shape, axis = shape_axis[0], shape_axis[1]
23022302
x = dpnp.zeros(shape, dtype=dtype)
23032303

23042304
# TODO: when similar changes in numpy are available,
23052305
# instead of assert_equal with zero, we should compare with numpy
2306+
assert_equal(dpnp.linalg.norm(x, axis=axis), 0)
23062307
assert_equal(dpnp.linalg.norm(x, axis=axis, ord="fro"), 0)
23072308
assert_equal(dpnp.linalg.norm(x, axis=axis, ord="nuc"), 0)
23082309
assert_equal(dpnp.linalg.norm(x, axis=axis, ord=2), 0)
@@ -2315,6 +2316,7 @@ def test_vector_norm_empty(self, dtype, axis):
23152316
x = dpnp.zeros(0, dtype=dtype)
23162317
# TODO: when similar changes in numpy are available,
23172318
# instead of assert_equal with zero, we should compare with numpy
2319+
assert_equal(dpnp.linalg.vector_norm(x, axis=axis), 0)
23182320
assert_equal(dpnp.linalg.vector_norm(x, axis=axis, ord=1), 0)
23192321
assert_equal(dpnp.linalg.vector_norm(x, axis=axis, ord=2), 0)
23202322
assert_equal(dpnp.linalg.vector_norm(x, axis=axis, ord=dpnp.inf), 0)

0 commit comments

Comments
 (0)