Skip to content

Commit 323a791

Browse files
committed
Adds tests for min and max
1 parent a733988 commit 323a791

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl.tensor as dpt
20+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
21+
22+
23+
def test_max_min_axis():
24+
get_queue_or_skip()
25+
26+
x = dpt.reshape(
27+
dpt.arange((3 * 4 * 5 * 6 * 7), dtype="i4"), (3, 4, 5, 6, 7)
28+
)
29+
30+
m = dpt.max(x, axis=(1, 2, -1))
31+
assert m.shape == (3, 6)
32+
assert dpt.all(m == x[:, -1, -1, :, -1])
33+
34+
m = dpt.min(x, axis=(1, 2, -1))
35+
assert m.shape == (3, 6)
36+
assert dpt.all(m == x[:, 0, 0, :, 0])
37+
38+
39+
def test_reduction_keepdims():
40+
get_queue_or_skip()
41+
42+
x = dpt.ones((3, 4, 5, 6, 7), dtype="i4")
43+
m = dpt.max(x, axis=(1, 2, -1), keepdims=True)
44+
45+
assert m.shape == (3, 1, 1, 6, 1)
46+
assert dpt.all(m == dpt.reshape(x[:, 0, 0, :, 0], m.shape))
47+
48+
49+
def test_max_scalar():
50+
get_queue_or_skip()
51+
52+
x = dpt.ones(())
53+
m = dpt.max(x)
54+
55+
assert m.shape == ()
56+
assert x == m
57+
58+
59+
@pytest.mark.parametrize("arg_dtype", ["i4", "f4", "c8"])
60+
def test_reduction_kernels(arg_dtype):
61+
# i4 - always uses atomics w/ sycl group reduction
62+
# f4 - always uses atomics w/ custom group reduction
63+
# c8 - always uses temps w/ custom group reduction
64+
q = get_queue_or_skip()
65+
skip_if_dtype_not_supported(arg_dtype, q)
66+
67+
x = dpt.reshape(dpt.arange(24 * 1025, dtype=arg_dtype), (24, 1025))
68+
69+
m = dpt.max(x)
70+
assert m == x[-1, -1]
71+
m = dpt.max(x, axis=0)
72+
assert dpt.all(m == x[-1, :])
73+
m = dpt.max(x, axis=1)
74+
assert dpt.all(m == x[:, -1])
75+
76+
m = dpt.min(x)
77+
assert m == x[0, 0]
78+
m = dpt.min(x, axis=0)
79+
assert dpt.all(m == x[0, :])
80+
m = dpt.min(x, axis=1)
81+
assert dpt.all(m == x[:, 0])
82+
83+
84+
def test_max_min_nan_propagation():
85+
get_queue_or_skip()
86+
87+
# float, finites
88+
x = dpt.arange(4, dtype="f4")
89+
x[0] = dpt.nan
90+
assert dpt.isnan(dpt.max(x))
91+
assert dpt.isnan(dpt.min(x))
92+
93+
# float, infinities
94+
x[1:] = dpt.inf
95+
assert dpt.isnan(dpt.max(x))
96+
x[1:] = -dpt.inf
97+
assert dpt.isnan(dpt.min(x))
98+
99+
# complex
100+
x = dpt.arange(4, dtype="c8")
101+
x[0] = complex(dpt.nan, 0)
102+
assert dpt.isnan(dpt.max(x))
103+
assert dpt.isnan(dpt.min(x))
104+
105+
x[0] = complex(0, dpt.nan)
106+
assert dpt.isnan(dpt.max(x))
107+
assert dpt.isnan(dpt.min(x))

0 commit comments

Comments
 (0)