Skip to content

Commit bca670a

Browse files
committed
More device pass through
1 parent 6cc7bac commit bca670a

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

array_api_strict/_creation_functions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ def ones_like(
349349

350350
_check_valid_dtype(dtype)
351351
_check_device(device)
352+
if device is None:
353+
device = x.device
352354

353355
if dtype is not None:
354356
dtype = dtype._np_dtype
@@ -366,7 +368,7 @@ def tril(x: Array, /, *, k: int = 0) -> Array:
366368
if x.ndim < 2:
367369
# Note: Unlike np.tril, x must be at least 2-D
368370
raise ValueError("x must be at least 2-dimensional for tril")
369-
return Array._new(np.tril(x._array, k=k))
371+
return Array._new(np.tril(x._array, k=k), device=x.device)
370372

371373

372374
def triu(x: Array, /, *, k: int = 0) -> Array:
@@ -380,7 +382,7 @@ def triu(x: Array, /, *, k: int = 0) -> Array:
380382
if x.ndim < 2:
381383
# Note: Unlike np.triu, x must be at least 2-D
382384
raise ValueError("x must be at least 2-dimensional for triu")
383-
return Array._new(np.triu(x._array, k=k), device=device)
385+
return Array._new(np.triu(x._array, k=k), device=x.device)
384386

385387

386388
def zeros(
@@ -416,6 +418,8 @@ def zeros_like(
416418

417419
_check_valid_dtype(dtype)
418420
_check_device(device)
421+
if device is None:
422+
device = x.device
419423

420424
if dtype is not None:
421425
dtype = dtype._np_dtype

0 commit comments

Comments
 (0)