Skip to content

Commit 0a19b73

Browse files
authored
dpnp_array must expose usm_type (#1228)
1 parent 2c3c6b0 commit 0a19b73

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

dpnp/dpnp_array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ def sycl_context(self):
115115
def device(self):
116116
return self._array_obj.device
117117

118+
@property
119+
def usm_type(self):
120+
return self._array_obj.usm_type
121+
118122
def __abs__(self):
119123
return dpnp.abs(self)
120124

tests/test_random_state.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
)
1616

1717

18+
def assert_cfd(data, exp_sycl_queue, exp_usm_type=None):
19+
assert exp_sycl_queue == data.sycl_queue
20+
if exp_usm_type:
21+
assert exp_usm_type == data.usm_type
22+
23+
1824
class TestNormal:
1925
@pytest.mark.parametrize("dtype",
2026
[dpnp.float32, dpnp.float64, None],
@@ -47,7 +53,7 @@ def test_distr(self, dtype, usm_type):
4753
assert_array_almost_equal(dpnp.asnumpy(data), desired, decimal=precision)
4854

4955
# check if compute follows data isn't broken
50-
assert sycl_queue == data.sycl_queue
56+
assert_cfd(data, sycl_queue, usm_type)
5157

5258

5359
@pytest.mark.parametrize("dtype",
@@ -138,7 +144,7 @@ def test_fallback(self, loc, scale):
138144
assert_array_almost_equal(actual, desired, decimal=precision)
139145

140146
# check if compute follows data isn't broken
141-
assert sycl_queue == data.sycl_queue
147+
assert_cfd(data, sycl_queue)
142148

143149

144150
@pytest.mark.parametrize("dtype",
@@ -174,17 +180,17 @@ def test_distr(self, usm_type):
174180

175181
precision = numpy.finfo(dtype=numpy.float64).precision
176182
assert_array_almost_equal(dpnp.asnumpy(data), desired, decimal=precision)
177-
assert sycl_queue == data.sycl_queue
183+
assert_cfd(data, sycl_queue, usm_type)
178184

179185
# call with the same seed has to draw the same values
180186
data = RandomState(seed, sycl_queue=sycl_queue).rand(3, 2, usm_type=usm_type)
181187
assert_array_almost_equal(dpnp.asnumpy(data), desired, decimal=precision)
182-
assert sycl_queue == data.sycl_queue
188+
assert_cfd(data, sycl_queue, usm_type)
183189

184190
# call with omitted dimensions has to draw the first element from desired
185191
data = RandomState(seed, sycl_queue=sycl_queue).rand(usm_type=usm_type)
186192
assert_array_almost_equal(dpnp.asnumpy(data), desired[0, 0], decimal=precision)
187-
assert sycl_queue == data.sycl_queue
193+
assert_cfd(data, sycl_queue, usm_type)
188194

189195
# rand() is an alias on random_sample(), map arguments
190196
with mock.patch('dpnp.random.RandomState.random_sample') as m:
@@ -245,7 +251,7 @@ def test_distr(self, dtype, usm_type):
245251
[5, 3],
246252
[5, 7]], dtype=numpy.int32)
247253
assert_array_equal(dpnp.asnumpy(data), desired)
248-
assert sycl_queue == data.sycl_queue
254+
assert_cfd(data, sycl_queue, usm_type)
249255

250256
# call with the same seed has to draw the same values
251257
data = RandomState(seed, sycl_queue=sycl_queue).randint(low=low,
@@ -254,15 +260,15 @@ def test_distr(self, dtype, usm_type):
254260
dtype=dtype,
255261
usm_type=usm_type)
256262
assert_array_equal(dpnp.asnumpy(data), desired)
257-
assert sycl_queue == data.sycl_queue
263+
assert_cfd(data, sycl_queue, usm_type)
258264

259265
# call with omitted dimensions has to draw the first element from desired
260266
data = RandomState(seed, sycl_queue=sycl_queue).randint(low=low,
261267
high=high,
262268
dtype=dtype,
263269
usm_type=usm_type)
264270
assert_array_equal(dpnp.asnumpy(data), desired[0, 0])
265-
assert sycl_queue == data.sycl_queue
271+
assert_cfd(data, sycl_queue, usm_type)
266272

267273
# rand() is an alias on random_sample(), map arguments
268274
with mock.patch('dpnp.random.RandomState.uniform') as m:
@@ -701,7 +707,7 @@ def test_distr(self, bounds, dtype, usm_type):
701707
assert_array_equal(dpnp.asnumpy(data), desired)
702708

703709
# check if compute follows data isn't broken
704-
assert sycl_queue == data.sycl_queue
710+
assert_cfd(data, sycl_queue, usm_type)
705711

706712

707713
@pytest.mark.parametrize("dtype",
@@ -766,7 +772,7 @@ def test_fallback(self, low, high):
766772
assert_array_almost_equal(actual, desired, decimal=precision)
767773

768774
# check if compute follows data isn't broken
769-
assert sycl_queue == data.sycl_queue
775+
assert_cfd(data, sycl_queue)
770776

771777

772778
@pytest.mark.parametrize("dtype",

tests/test_sycl_queue.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,7 @@ def test_uniform(usm_type, size):
278278
high = 2.0
279279
res = dpnp.random.uniform(low, high, size=size, usm_type=usm_type)
280280

281-
res_usm_type = res.get_array().usm_type
282-
assert usm_type == res_usm_type
281+
assert usm_type == res.usm_type
283282

284283

285284
@pytest.mark.parametrize("usm_type",
@@ -295,8 +294,7 @@ def test_rs_uniform(usm_type, seed):
295294
rs = dpnp.random.RandomState(seed, sycl_queue=sycl_queue)
296295
res = rs.uniform(low, high, usm_type=usm_type)
297296

298-
res_usm_type = res.get_array().usm_type
299-
assert usm_type == res_usm_type
297+
assert usm_type == res.usm_type
300298

301299
res_sycl_queue = res.get_array().sycl_queue
302300
assert_sycl_queue_equal(res_sycl_queue, sycl_queue)

0 commit comments

Comments
 (0)