Skip to content

Commit 8ab37a1

Browse files
committed
Adds tests for reduction out kwarg
1 parent cd8acbe commit 8ab37a1

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import dpctl.tensor as dpt
2424
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
25+
from dpctl.utils import ExecutionPlacementError
2526

2627
_no_complex_dtypes = [
2728
"?",
@@ -497,3 +498,173 @@ def test_tree_reduction_axis1_axis0():
497498
rtol=tol,
498499
atol=tol,
499500
)
501+
502+
503+
def test_numeric_reduction_out_kwarg():
504+
get_queue_or_skip()
505+
506+
n1, n2, n3 = 3, 4, 5
507+
x = dpt.ones((n1, n2, n3), dtype="i8")
508+
out = dpt.zeros((2 * n1, 3 * n2), dtype="i8")
509+
res = dpt.sum(x, axis=-1, out=out[::-2, 1::3])
510+
assert dpt.all(out[::-2, 0::3] == 0)
511+
assert dpt.all(out[::-2, 2::3] == 0)
512+
assert dpt.all(out[::-2, 1::3] == res)
513+
assert dpt.all(out[::-2, 1::3] == 5)
514+
515+
out = dpt.zeros((2 * n1, 3 * n2, 1), dtype="i8")
516+
res = dpt.sum(x, axis=-1, keepdims=True, out=out[::-2, 1::3])
517+
assert res.shape == (n1, n2, 1)
518+
assert dpt.all(out[::-2, 0::3] == 0)
519+
assert dpt.all(out[::-2, 2::3] == 0)
520+
assert dpt.all(out[::-2, 1::3] == res)
521+
assert dpt.all(out[::-2, 1::3] == 5)
522+
523+
res = dpt.sum(x, axis=0, out=x[-1])
524+
assert dpt.all(x[-1] == res)
525+
assert dpt.all(x[-1] == 3)
526+
assert dpt.all(x[0:-1] == 1)
527+
528+
# test no-op case
529+
x = dpt.ones((n1, n2, n3), dtype="i8")
530+
out = dpt.zeros((2 * n1, 3 * n2, n3), dtype="i8")
531+
res = dpt.sum(x, axis=(), out=out[::-2, 1::3])
532+
assert dpt.all(out[::-2, 0::3] == 0)
533+
assert dpt.all(out[::-2, 2::3] == 0)
534+
assert dpt.all(out[::-2, 1::3] == x)
535+
536+
# test with dtype kwarg
537+
x = dpt.ones((n1, n2, n3), dtype="i4")
538+
out = dpt.zeros((2 * n1, 3 * n2), dtype="f4")
539+
res = dpt.sum(x, axis=-1, dtype="f4", out=out[::-2, 1::3])
540+
assert dpt.allclose(out[::-2, 0::3], dpt.zeros_like(res))
541+
assert dpt.allclose(out[::-2, 2::3], dpt.zeros_like(res))
542+
assert dpt.allclose(out[::-2, 1::3], res)
543+
assert dpt.allclose(out[::-2, 1::3], dpt.full_like(res, 5, dtype="f4"))
544+
545+
546+
def test_comparison_reduction_out_kwarg():
547+
get_queue_or_skip()
548+
549+
n1, n2, n3 = 3, 4, 5
550+
x = dpt.reshape(dpt.arange(n1 * n2 * n3, dtype="i4"), (n1, n2, n3))
551+
out = dpt.zeros((2 * n1, 3 * n2), dtype="i4")
552+
res = dpt.max(x, axis=-1, out=out[::-2, 1::3])
553+
assert dpt.all(out[::-2, 0::3] == 0)
554+
assert dpt.all(out[::-2, 2::3] == 0)
555+
assert dpt.all(out[::-2, 1::3] == res)
556+
assert dpt.all(out[::-2, 1::3] == x[:, :, -1])
557+
558+
out = dpt.zeros((2 * n1, 3 * n2, 1), dtype="i4")
559+
res = dpt.max(x, axis=-1, keepdims=True, out=out[::-2, 1::3])
560+
assert res.shape == (n1, n2, 1)
561+
assert dpt.all(out[::-2, 0::3] == 0)
562+
assert dpt.all(out[::-2, 2::3] == 0)
563+
assert dpt.all(out[::-2, 1::3] == res)
564+
assert dpt.all(out[::-2, 1::3] == x[:, :, -1, dpt.newaxis])
565+
566+
# test no-op case
567+
out = dpt.zeros((2 * n1, 3 * n2, n3), dtype="i4")
568+
res = dpt.max(x, axis=(), out=out[::-2, 1::3])
569+
assert dpt.all(out[::-2, 0::3] == 0)
570+
assert dpt.all(out[::-2, 2::3] == 0)
571+
assert dpt.all(out[::-2, 1::3] == x)
572+
573+
# test overlap
574+
res = dpt.max(x, axis=0, out=x[0])
575+
assert dpt.all(x[0] == res)
576+
assert dpt.all(x[0] == x[-1])
577+
578+
579+
def test_search_reduction_out_kwarg():
580+
get_queue_or_skip()
581+
582+
n1, n2, n3 = 3, 4, 5
583+
dt = dpt.__array_namespace_info__().default_dtypes()["indexing"]
584+
585+
x = dpt.reshape(dpt.arange(n1 * n2 * n3, dtype=dt), (n1, n2, n3))
586+
out = dpt.zeros((2 * n1, 3 * n2), dtype=dt)
587+
res = dpt.argmax(x, axis=-1, out=out[::-2, 1::3])
588+
assert dpt.all(out[::-2, 0::3] == 0)
589+
assert dpt.all(out[::-2, 2::3] == 0)
590+
assert dpt.all(out[::-2, 1::3] == res)
591+
assert dpt.all(out[::-2, 1::3] == n2)
592+
593+
out = dpt.zeros((2 * n1, 3 * n2, 1), dtype=dt)
594+
res = dpt.argmax(x, axis=-1, keepdims=True, out=out[::-2, 1::3])
595+
assert res.shape == (n1, n2, 1)
596+
assert dpt.all(out[::-2, 0::3] == 0)
597+
assert dpt.all(out[::-2, 2::3] == 0)
598+
assert dpt.all(out[::-2, 1::3] == res)
599+
assert dpt.all(out[::-2, 1::3] == n3 - 1)
600+
601+
# test no-op case
602+
x = dpt.ones((), dtype=dt)
603+
out = dpt.ones(2, dtype=dt)
604+
res = dpt.argmax(x, axis=None, out=out[1])
605+
assert dpt.all(out[0] == 1)
606+
assert dpt.all(out[1] == 0)
607+
608+
# test overlap
609+
x = dpt.reshape(dpt.arange(n1 * n2, dtype=dt), (n1, n2))
610+
res = dpt.argmax(x, axis=0, out=x[0])
611+
assert dpt.all(x[0] == res)
612+
assert dpt.all(x[0] == n1 - 1)
613+
614+
615+
def test_reduction_out_kwarg_arg_validation():
616+
q1 = get_queue_or_skip()
617+
q2 = get_queue_or_skip()
618+
619+
ind_dt = dpt.__array_namespace_info__().default_dtypes()["indexing"]
620+
621+
x = dpt.ones(10, dtype="f4")
622+
out_wrong_queue = dpt.empty((), dtype="f4", sycl_queue=q2)
623+
out_wrong_dtype = dpt.empty((), dtype="i4", sycl_queue=q1)
624+
out_wrong_shape = dpt.empty(1, dtype="f4", sycl_queue=q1)
625+
out_wrong_keepdims = dpt.empty((), dtype="f4", sycl_queue=q1)
626+
out_not_writable = dpt.empty((), dtype="f4", sycl_queue=q1)
627+
out_not_writable.flags["W"] = False
628+
629+
with pytest.raises(TypeError):
630+
dpt.sum(x, out=dict())
631+
with pytest.raises(TypeError):
632+
dpt.max(x, out=dict())
633+
with pytest.raises(TypeError):
634+
dpt.argmax(x, out=dict())
635+
with pytest.raises(ExecutionPlacementError):
636+
dpt.sum(x, out=out_wrong_queue)
637+
with pytest.raises(ExecutionPlacementError):
638+
dpt.max(x, out=out_wrong_queue)
639+
with pytest.raises(ExecutionPlacementError):
640+
dpt.argmax(x, out=dpt.empty_like(out_wrong_queue, dtype=ind_dt))
641+
with pytest.raises(ValueError):
642+
dpt.sum(x, out=out_wrong_dtype)
643+
with pytest.raises(ValueError):
644+
dpt.max(x, out=out_wrong_dtype)
645+
with pytest.raises(ValueError):
646+
dpt.argmax(x, out=dpt.empty_like(out_wrong_dtype, dtype="f4"))
647+
with pytest.raises(ValueError):
648+
dpt.sum(x, out=out_wrong_shape)
649+
with pytest.raises(ValueError):
650+
dpt.max(x, out=out_wrong_shape)
651+
with pytest.raises(ValueError):
652+
dpt.argmax(x, out=dpt.empty_like(out_wrong_shape, dtype=ind_dt))
653+
with pytest.raises(ValueError):
654+
dpt.sum(x, out=out_not_writable)
655+
with pytest.raises(ValueError):
656+
dpt.max(x, out=out_not_writable)
657+
with pytest.raises(ValueError):
658+
search_not_writable = dpt.empty_like(out_not_writable, dtype=ind_dt)
659+
search_not_writable.flags["W"] = False
660+
dpt.argmax(x, out=search_not_writable)
661+
with pytest.raises(ValueError):
662+
dpt.sum(x, keepdims=True, out=out_wrong_keepdims)
663+
with pytest.raises(ValueError):
664+
dpt.max(x, keepdims=True, out=out_wrong_keepdims)
665+
with pytest.raises(ValueError):
666+
dpt.argmax(
667+
x,
668+
keepdims=True,
669+
out=dpt.empty_like(out_wrong_keepdims, dtype=ind_dt),
670+
)

0 commit comments

Comments
 (0)