Skip to content

Commit cdcc9b5

Browse files
Add tests for subtract and multiply
1 parent 3273a6f commit cdcc9b5

File tree

2 files changed

+325
-0
lines changed

2 files changed

+325
-0
lines changed
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import ctypes
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl
23+
import dpctl.tensor as dpt
24+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
25+
26+
from .utils import _all_dtypes, _compare_dtypes, _usm_types
27+
28+
29+
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
30+
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
31+
def test_multiply_dtype_matrix(op1_dtype, op2_dtype):
32+
q = get_queue_or_skip()
33+
skip_if_dtype_not_supported(op1_dtype, q)
34+
skip_if_dtype_not_supported(op2_dtype, q)
35+
36+
sz = 127
37+
ar1 = dpt.ones(sz, dtype=op1_dtype)
38+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
39+
40+
r = dpt.multiply(ar1, ar2)
41+
assert isinstance(r, dpt.usm_ndarray)
42+
expected = np.multiply(
43+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
44+
)
45+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
46+
assert r.shape == ar1.shape
47+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
48+
assert r.sycl_queue == ar1.sycl_queue
49+
50+
ar3 = dpt.ones(sz, dtype=op1_dtype)
51+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
52+
53+
r = dpt.multiply(ar3[::-1], ar4[::2])
54+
assert isinstance(r, dpt.usm_ndarray)
55+
expected = np.multiply(
56+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
57+
)
58+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
59+
assert r.shape == ar3.shape
60+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
61+
62+
63+
@pytest.mark.parametrize("op1_usm_type", _usm_types)
64+
@pytest.mark.parametrize("op2_usm_type", _usm_types)
65+
def test_multiply_usm_type_matrix(op1_usm_type, op2_usm_type):
66+
get_queue_or_skip()
67+
68+
sz = 128
69+
ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type)
70+
ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type)
71+
72+
r = dpt.multiply(ar1, ar2)
73+
assert isinstance(r, dpt.usm_ndarray)
74+
expected_usm_type = dpctl.utils.get_coerced_usm_type(
75+
(op1_usm_type, op2_usm_type)
76+
)
77+
assert r.usm_type == expected_usm_type
78+
79+
80+
def test_multiply_order():
81+
get_queue_or_skip()
82+
83+
ar1 = dpt.ones((20, 20), dtype="i4", order="C")
84+
ar2 = dpt.ones((20, 20), dtype="i4", order="C")
85+
r1 = dpt.multiply(ar1, ar2, order="C")
86+
assert r1.flags.c_contiguous
87+
r2 = dpt.multiply(ar1, ar2, order="F")
88+
assert r2.flags.f_contiguous
89+
r3 = dpt.multiply(ar1, ar2, order="A")
90+
assert r3.flags.c_contiguous
91+
r4 = dpt.multiply(ar1, ar2, order="K")
92+
assert r4.flags.c_contiguous
93+
94+
ar1 = dpt.ones((20, 20), dtype="i4", order="F")
95+
ar2 = dpt.ones((20, 20), dtype="i4", order="F")
96+
r1 = dpt.multiply(ar1, ar2, order="C")
97+
assert r1.flags.c_contiguous
98+
r2 = dpt.multiply(ar1, ar2, order="F")
99+
assert r2.flags.f_contiguous
100+
r3 = dpt.multiply(ar1, ar2, order="A")
101+
assert r3.flags.f_contiguous
102+
r4 = dpt.multiply(ar1, ar2, order="K")
103+
assert r4.flags.f_contiguous
104+
105+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
106+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
107+
r4 = dpt.multiply(ar1, ar2, order="K")
108+
assert r4.strides == (20, -1)
109+
110+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
111+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
112+
r4 = dpt.multiply(ar1, ar2, order="K")
113+
assert r4.strides == (-1, 20)
114+
115+
116+
def test_multiply_broadcasting():
117+
get_queue_or_skip()
118+
119+
m = dpt.ones((100, 5), dtype="i4")
120+
v = dpt.arange(1, 6, dtype="i4")
121+
122+
r = dpt.multiply(m, v)
123+
124+
expected = np.multiply(
125+
np.ones((100, 5), dtype="i4"), np.arange(1, 6, dtype="i4")
126+
)
127+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
128+
129+
r2 = dpt.multiply(v, m)
130+
expected2 = np.multiply(
131+
np.arange(1, 6, dtype="i4"), np.ones((100, 5), dtype="i4")
132+
)
133+
assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all()
134+
135+
136+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
137+
def test_multiply_python_scalar(arr_dt):
138+
q = get_queue_or_skip()
139+
skip_if_dtype_not_supported(arr_dt, q)
140+
141+
X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q)
142+
py_ones = (
143+
bool(1),
144+
int(1),
145+
float(1),
146+
complex(1),
147+
np.float32(1),
148+
ctypes.c_int(1),
149+
)
150+
for sc in py_ones:
151+
R = dpt.multiply(X, sc)
152+
assert isinstance(R, dpt.usm_ndarray)
153+
R = dpt.multiply(sc, X)
154+
assert isinstance(R, dpt.usm_ndarray)
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import ctypes
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl
23+
import dpctl.tensor as dpt
24+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
25+
26+
from .utils import _all_dtypes, _compare_dtypes, _usm_types
27+
28+
29+
@pytest.mark.parametrize("op1_dtype", _all_dtypes[1:])
30+
@pytest.mark.parametrize("op2_dtype", _all_dtypes[1:])
31+
def test_subtract_dtype_matrix(op1_dtype, op2_dtype):
32+
q = get_queue_or_skip()
33+
skip_if_dtype_not_supported(op1_dtype, q)
34+
skip_if_dtype_not_supported(op2_dtype, q)
35+
36+
sz = 127
37+
ar1 = dpt.ones(sz, dtype=op1_dtype)
38+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
39+
40+
r = dpt.subtract(ar1, ar2)
41+
assert isinstance(r, dpt.usm_ndarray)
42+
expected_dtype = np.subtract(
43+
np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype)
44+
).dtype
45+
assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q)
46+
assert r.shape == ar1.shape
47+
assert (dpt.asnumpy(r) == np.full(r.shape, 0, dtype=r.dtype)).all()
48+
assert r.sycl_queue == ar1.sycl_queue
49+
50+
r2 = dpt.empty_like(ar1, dtype=r.dtype)
51+
dpt.subtract(ar1, ar2, out=r2)
52+
assert (dpt.asnumpy(r2) == np.full(r2.shape, 0, dtype=r2.dtype)).all()
53+
54+
ar3 = dpt.ones(sz, dtype=op1_dtype)
55+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
56+
57+
r = dpt.subtract(ar3[::-1], ar4[::2])
58+
assert isinstance(r, dpt.usm_ndarray)
59+
expected_dtype = np.subtract(
60+
np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype)
61+
).dtype
62+
assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q)
63+
assert r.shape == ar3.shape
64+
assert (dpt.asnumpy(r) == np.full(r.shape, 0, dtype=r.dtype)).all()
65+
66+
r2 = dpt.empty_like(ar1, dtype=r.dtype)
67+
dpt.subtract(ar3[::-1], ar4[::2], out=r2)
68+
assert (dpt.asnumpy(r2) == np.full(r2.shape, 0, dtype=r2.dtype)).all()
69+
70+
71+
@pytest.mark.parametrize("op1_usm_type", _usm_types)
72+
@pytest.mark.parametrize("op2_usm_type", _usm_types)
73+
def test_subtract_usm_type_matrix(op1_usm_type, op2_usm_type):
74+
get_queue_or_skip()
75+
76+
sz = 128
77+
ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type)
78+
ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type)
79+
80+
r = dpt.subtract(ar1, ar2)
81+
assert isinstance(r, dpt.usm_ndarray)
82+
expected_usm_type = dpctl.utils.get_coerced_usm_type(
83+
(op1_usm_type, op2_usm_type)
84+
)
85+
assert r.usm_type == expected_usm_type
86+
87+
88+
def test_subtract_order():
89+
get_queue_or_skip()
90+
91+
test_shape = (
92+
20,
93+
20,
94+
)
95+
test_shape2 = tuple(2 * dim for dim in test_shape)
96+
n = test_shape[-1]
97+
98+
for dt1, dt2 in zip(["i4", "i4", "f4"], ["i4", "f4", "i4"]):
99+
ar1 = dpt.ones(test_shape, dtype=dt1, order="C")
100+
ar2 = dpt.ones(test_shape, dtype=dt2, order="C")
101+
r1 = dpt.subtract(ar1, ar2, order="C")
102+
assert r1.flags.c_contiguous
103+
r2 = dpt.subtract(ar1, ar2, order="F")
104+
assert r2.flags.f_contiguous
105+
r3 = dpt.subtract(ar1, ar2, order="A")
106+
assert r3.flags.c_contiguous
107+
r4 = dpt.subtract(ar1, ar2, order="K")
108+
assert r4.flags.c_contiguous
109+
110+
ar1 = dpt.ones(test_shape, dtype=dt1, order="F")
111+
ar2 = dpt.ones(test_shape, dtype=dt2, order="F")
112+
r1 = dpt.subtract(ar1, ar2, order="C")
113+
assert r1.flags.c_contiguous
114+
r2 = dpt.subtract(ar1, ar2, order="F")
115+
assert r2.flags.f_contiguous
116+
r3 = dpt.subtract(ar1, ar2, order="A")
117+
assert r3.flags.f_contiguous
118+
r4 = dpt.subtract(ar1, ar2, order="K")
119+
assert r4.flags.f_contiguous
120+
121+
ar1 = dpt.ones(test_shape2, dtype=dt1, order="C")[:20, ::-2]
122+
ar2 = dpt.ones(test_shape2, dtype=dt2, order="C")[:20, ::-2]
123+
r4 = dpt.subtract(ar1, ar2, order="K")
124+
assert r4.strides == (n, -1)
125+
r5 = dpt.subtract(ar1, ar2, order="C")
126+
assert r5.strides == (n, 1)
127+
128+
ar1 = dpt.ones(test_shape2, dtype=dt1, order="C")[:20, ::-2].mT
129+
ar2 = dpt.ones(test_shape2, dtype=dt2, order="C")[:20, ::-2].mT
130+
r4 = dpt.subtract(ar1, ar2, order="K")
131+
assert r4.strides == (-1, n)
132+
r5 = dpt.subtract(ar1, ar2, order="C")
133+
assert r5.strides == (n, 1)
134+
135+
136+
def test_subtract_broadcasting():
137+
get_queue_or_skip()
138+
139+
m = dpt.ones((100, 5), dtype="i4")
140+
v = dpt.arange(5, dtype="i4")
141+
142+
r = dpt.subtract(m, v)
143+
assert (
144+
dpt.asnumpy(r) == np.arange(1, -4, step=-1, dtype="i4")[np.newaxis, :]
145+
).all()
146+
147+
r2 = dpt.subtract(v, m)
148+
assert (
149+
dpt.asnumpy(r2) == np.arange(-1, 4, dtype="i4")[np.newaxis, :]
150+
).all()
151+
152+
153+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
154+
def test_subtract_python_scalar(arr_dt):
155+
q = get_queue_or_skip()
156+
skip_if_dtype_not_supported(arr_dt, q)
157+
158+
X = dpt.zeros((10, 10), dtype=arr_dt, sycl_queue=q)
159+
py_zeros = (
160+
bool(0),
161+
int(0),
162+
float(0),
163+
complex(0),
164+
np.float32(0),
165+
ctypes.c_int(0),
166+
)
167+
for sc in py_zeros:
168+
R = dpt.subtract(X, sc)
169+
assert isinstance(R, dpt.usm_ndarray)
170+
R = dpt.subtract(sc, X)
171+
assert isinstance(R, dpt.usm_ndarray)

0 commit comments

Comments
 (0)