Skip to content

Commit f6d175b

Browse files
vtavanaantonwolfy
andauthored
Fuzz testing fixes (#1596)
* fix cholesky for 0-D array (#1584) Co-authored-by: Anton <[email protected]> * fix `dpnp.linalg.qr` and `dpnp.linalg.det` functions (#1592) * modifying dpnp.linalg.qr function * modifying dpnp.linalg.det function * fix pre-commit --------- Co-authored-by: Anton <[email protected]>
1 parent 75b0dd1 commit f6d175b

File tree

4 files changed

+79
-13
lines changed

4 files changed

+79
-13
lines changed

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ DPCTLSyclEventRef dpnp_cholesky_c(DPCTLSyclQueueRef q_ref,
4747
(void)dep_event_vec_ref;
4848

4949
DPCTLSyclEventRef event_ref = nullptr;
50+
if (!data_size) {
51+
return event_ref;
52+
}
5053
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
5154

5255
sycl::event event;
@@ -609,6 +612,9 @@ DPCTLSyclEventRef dpnp_qr_c(DPCTLSyclQueueRef q_ref,
609612
(void)dep_event_vec_ref;
610613

611614
DPCTLSyclEventRef event_ref = nullptr;
615+
if (!size_m || !size_n) {
616+
return event_ref;
617+
}
612618
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
613619

614620
sycl::event event;

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,9 @@ cpdef object dpnp_cond(object input, object p):
142142
cpdef utils.dpnp_descriptor dpnp_det(utils.dpnp_descriptor input):
143143
cdef shape_type_c input_shape = input.shape
144144
cdef size_t n = input.shape[-1]
145-
cdef size_t size_out = 1
145+
cdef shape_type_c result_shape = (1,)
146146
if input.ndim != 2:
147-
output_shape = tuple((list(input.shape))[:-2])
148-
for i in range(len(output_shape)):
149-
size_out *= output_shape[i]
150-
151-
cdef shape_type_c result_shape = (size_out,)
152-
if size_out > 1:
153-
result_shape = output_shape
147+
result_shape = tuple((list(input.shape))[:-2])
154148

155149
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
156150

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def det(input):
159159

160160
x1_desc = dpnp.get_dpnp_descriptor(input, copy_when_nondefault_queue=False)
161161
if x1_desc:
162-
if x1_desc.shape[-1] == x1_desc.shape[-2]:
162+
if x1_desc.ndim < 2:
163+
pass
164+
elif x1_desc.shape[-1] == x1_desc.shape[-2]:
163165
result_obj = dpnp_det(x1_desc).get_pyobj()
164166
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
165167

@@ -488,7 +490,9 @@ def qr(x1, mode="reduced"):
488490

489491
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
490492
if x1_desc:
491-
if mode != "reduced":
493+
if x1_desc.ndim != 2:
494+
pass
495+
elif mode != "reduced":
492496
pass
493497
else:
494498
result_tup = dpnp_qr(x1_desc, mode)

tests/test_linalg.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,25 @@ def test_cholesky(array):
6060
assert_array_equal(expected, result)
6161

6262

63+
@pytest.mark.parametrize(
64+
"shape",
65+
[
66+
(0, 0),
67+
(3, 0, 0),
68+
],
69+
ids=[
70+
"(0, 0)",
71+
"(3, 0, 0)",
72+
],
73+
)
74+
def test_cholesky_0D(shape):
75+
a = numpy.empty(shape)
76+
ia = inp.array(a)
77+
result = inp.linalg.cholesky(ia)
78+
expected = numpy.linalg.cholesky(a)
79+
assert_array_equal(expected, result)
80+
81+
6382
@pytest.mark.parametrize(
6483
"arr",
6584
[[[1, 0, -1], [0, 1, 0], [1, 0, 1]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]],
@@ -109,6 +128,20 @@ def test_det(array):
109128
assert_allclose(expected, result)
110129

111130

131+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
132+
def test_det_empty():
133+
a = numpy.empty((0, 0, 2, 2), dtype=numpy.float32)
134+
ia = inp.array(a)
135+
136+
np_det = numpy.linalg.det(a)
137+
dpnp_det = inp.linalg.det(ia)
138+
139+
assert dpnp_det.dtype == np_det.dtype
140+
assert dpnp_det.shape == np_det.shape
141+
142+
assert_allclose(np_det, dpnp_det)
143+
144+
112145
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
113146
@pytest.mark.parametrize("size", [2, 4, 8, 16, 300])
114147
def test_eig_arange(type, size):
@@ -339,8 +372,8 @@ def test_norm3(array, ord, axis):
339372
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
340373
@pytest.mark.parametrize(
341374
"shape",
342-
[(2, 2), (3, 4), (5, 3), (16, 16)],
343-
ids=["(2,2)", "(3,4)", "(5,3)", "(16,16)"],
375+
[(2, 2), (3, 4), (5, 3), (16, 16), (0, 0), (0, 2), (2, 0)],
376+
ids=["(2,2)", "(3,4)", "(5,3)", "(16,16)", "(0,0)", "(0,2)", "(2,0)"],
344377
)
345378
@pytest.mark.parametrize(
346379
"mode", ["complete", "reduced"], ids=["complete", "reduced"]
@@ -369,7 +402,7 @@ def test_qr(type, shape, mode):
369402
# check decomposition
370403
assert_allclose(
371404
ia,
372-
numpy.dot(inp.asnumpy(dpnp_q), inp.asnumpy(dpnp_r)),
405+
inp.dot(dpnp_q, dpnp_r),
373406
rtol=tol,
374407
atol=tol,
375408
)
@@ -390,6 +423,35 @@ def test_qr(type, shape, mode):
390423
assert_allclose(dpnp_r, np_r, rtol=tol, atol=tol)
391424

392425

426+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
427+
def test_qr_not_2D():
428+
a = numpy.arange(12, dtype=numpy.float32).reshape((3, 2, 2))
429+
ia = inp.array(a)
430+
431+
np_q, np_r = numpy.linalg.qr(a)
432+
dpnp_q, dpnp_r = inp.linalg.qr(ia)
433+
434+
assert dpnp_q.dtype == np_q.dtype
435+
assert dpnp_r.dtype == np_r.dtype
436+
assert dpnp_q.shape == np_q.shape
437+
assert dpnp_r.shape == np_r.shape
438+
439+
assert_allclose(ia, inp.matmul(dpnp_q, dpnp_r))
440+
441+
a = numpy.empty((0, 3, 2), dtype=numpy.float32)
442+
ia = inp.array(a)
443+
444+
np_q, np_r = numpy.linalg.qr(a)
445+
dpnp_q, dpnp_r = inp.linalg.qr(ia)
446+
447+
assert dpnp_q.dtype == np_q.dtype
448+
assert dpnp_r.dtype == np_r.dtype
449+
assert dpnp_q.shape == np_q.shape
450+
assert dpnp_r.shape == np_r.shape
451+
452+
assert_allclose(ia, inp.matmul(dpnp_q, dpnp_r))
453+
454+
393455
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
394456
@pytest.mark.parametrize(
395457
"shape",

0 commit comments

Comments
 (0)