Skip to content

Commit 325b9d0

Browse files
committed
Loop device through elementwise functions
1 parent 600df5e commit 325b9d0

File tree

3 files changed

+185
-88
lines changed

3 files changed

+185
-88
lines changed

array_api_strict/_array_object.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@
4545

4646
class Device:
4747
def __init__(self, device="CPU_DEVICE"):
48+
if device not in ("CPU_DEVICE", "device1", "device2"):
49+
raise ValueError(f"The device '{device}' is not a valid choice.")
4850
self._device = device
4951

5052
def __repr__(self):
51-
return f"Device('{self._device}')"
53+
return f"array_api_strict.Device('{self._device}')"
5254

5355
def __eq__(self, other):
5456
return self._device == other._device
@@ -77,7 +79,7 @@ class Array:
7779
# Use a custom constructor instead of __init__, as manually initializing
7880
# this class is not supported API.
7981
@classmethod
80-
def _new(cls, x, /, device=CPU_DEVICE):
82+
def _new(cls, x, /, device=None):
8183
"""
8284
This is a private method for initializing the array API Array
8385
object.
@@ -123,7 +125,11 @@ def __repr__(self: Array, /) -> str:
123125
"""
124126
Performs the operation __repr__.
125127
"""
126-
suffix = f", dtype={self.dtype})"
128+
suffix = f", dtype={self.dtype}"
129+
if self.device != CPU_DEVICE:
130+
suffix += f", device={self.device})"
131+
else:
132+
suffix += ")"
127133
if 0 in self.shape:
128134
prefix = "empty("
129135
mid = str(self.shape)
@@ -202,6 +208,15 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor
202208

203209
return other
204210

211+
def _check_device(self, other):
212+
"""Check that other is on a device compatible with the current array"""
213+
if isinstance(other, (int, complex, float, bool)):
214+
return other
215+
elif isinstance(other, Array):
216+
if self.device != other.device:
217+
raise RuntimeError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
218+
return other
219+
205220
# Helper function to match the type promotion rules in the spec
206221
def _promote_scalar(self, scalar):
207222
"""
@@ -477,23 +492,25 @@ def __add__(self: Array, other: Union[int, float, Array], /) -> Array:
477492
"""
478493
Performs the operation __add__.
479494
"""
495+
other = self._check_device(other)
480496
other = self._check_allowed_dtypes(other, "numeric", "__add__")
481497
if other is NotImplemented:
482498
return other
483499
self, other = self._normalize_two_args(self, other)
484500
res = self._array.__add__(other._array)
485-
return self.__class__._new(res)
501+
return self.__class__._new(res, device=self.device)
486502

487503
def __and__(self: Array, other: Union[int, bool, Array], /) -> Array:
488504
"""
489505
Performs the operation __and__.
490506
"""
507+
other = self._check_device(other)
491508
other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
492509
if other is NotImplemented:
493510
return other
494511
self, other = self._normalize_two_args(self, other)
495512
res = self._array.__and__(other._array)
496-
return self.__class__._new(res)
513+
return self.__class__._new(res, device=self.device)
497514

498515
def __array_namespace__(
499516
self: Array, /, *, api_version: Optional[str] = None
@@ -577,14 +594,15 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
577594
"""
578595
Performs the operation __eq__.
579596
"""
597+
other = self._check_device(other)
580598
# Even though "all" dtypes are allowed, we still require them to be
581599
# promotable with each other.
582600
other = self._check_allowed_dtypes(other, "all", "__eq__")
583601
if other is NotImplemented:
584602
return other
585603
self, other = self._normalize_two_args(self, other)
586604
res = self._array.__eq__(other._array)
587-
return self.__class__._new(res)
605+
return self.__class__._new(res, device=self.device)
588606

589607
def __float__(self: Array, /) -> float:
590608
"""
@@ -602,23 +620,25 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
602620
"""
603621
Performs the operation __floordiv__.
604622
"""
623+
other = self._check_device(other)
605624
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
606625
if other is NotImplemented:
607626
return other
608627
self, other = self._normalize_two_args(self, other)
609628
res = self._array.__floordiv__(other._array)
610-
return self.__class__._new(res)
629+
return self.__class__._new(res, device=self.device)
611630

612631
def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
613632
"""
614633
Performs the operation __ge__.
615634
"""
635+
other = self._check_device(other)
616636
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
617637
if other is NotImplemented:
618638
return other
619639
self, other = self._normalize_two_args(self, other)
620640
res = self._array.__ge__(other._array)
621-
return self.__class__._new(res)
641+
return self.__class__._new(res, device=self.device)
622642

623643
def __getitem__(
624644
self: Array,
@@ -634,19 +654,21 @@ def __getitem__(
634654
"""
635655
Performs the operation __getitem__.
636656
"""
657+
# XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE?
637658
# Note: Only indices required by the spec are allowed. See the
638659
# docstring of _validate_index
639660
self._validate_index(key)
640661
if isinstance(key, Array):
641662
# Indexing self._array with array_api_strict arrays can be erroneous
642663
key = key._array
643664
res = self._array.__getitem__(key)
644-
return self._new(res)
665+
return self._new(res, device=self.device)
645666

646667
def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
647668
"""
648669
Performs the operation __gt__.
649670
"""
671+
other = self._check_device(other)
650672
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
651673
if other is NotImplemented:
652674
return other
@@ -680,7 +702,7 @@ def __invert__(self: Array, /) -> Array:
680702
if self.dtype not in _integer_or_boolean_dtypes:
681703
raise TypeError("Only integer or boolean dtypes are allowed in __invert__")
682704
res = self._array.__invert__()
683-
return self.__class__._new(res)
705+
return self.__class__._new(res, device=self.device)
684706

685707
def __iter__(self: Array, /):
686708
"""
@@ -695,85 +717,92 @@ def __iter__(self: Array, /):
695717
# define __iter__, but it doesn't disallow it. The default Python
696718
# behavior is to implement iter as a[0], a[1], ... when __getitem__ is
697719
# implemented, which implies iteration on 1-D arrays.
698-
return (Array._new(i) for i in self._array)
720+
return (Array._new(i, device=self.device) for i in self._array)
699721

700722
def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
701723
"""
702724
Performs the operation __le__.
703725
"""
726+
other = self._check_device(other)
704727
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
705728
if other is NotImplemented:
706729
return other
707730
self, other = self._normalize_two_args(self, other)
708731
res = self._array.__le__(other._array)
709-
return self.__class__._new(res)
732+
return self.__class__._new(res, device=self.device)
710733

711734
def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
712735
"""
713736
Performs the operation __lshift__.
714737
"""
738+
other = self._check_device(other)
715739
other = self._check_allowed_dtypes(other, "integer", "__lshift__")
716740
if other is NotImplemented:
717741
return other
718742
self, other = self._normalize_two_args(self, other)
719743
res = self._array.__lshift__(other._array)
720-
return self.__class__._new(res)
744+
return self.__class__._new(res, device=self.device)
721745

722746
def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
723747
"""
724748
Performs the operation __lt__.
725749
"""
750+
other = self._check_device(other)
726751
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
727752
if other is NotImplemented:
728753
return other
729754
self, other = self._normalize_two_args(self, other)
730755
res = self._array.__lt__(other._array)
731-
return self.__class__._new(res)
756+
return self.__class__._new(res, device=self.device)
732757

733758
def __matmul__(self: Array, other: Array, /) -> Array:
734759
"""
735760
Performs the operation __matmul__.
736761
"""
762+
other = self._check_device(other)
737763
# matmul is not defined for scalars, but without this, we may get
738764
# the wrong error message from asarray.
739765
other = self._check_allowed_dtypes(other, "numeric", "__matmul__")
740766
if other is NotImplemented:
741767
return other
742768
res = self._array.__matmul__(other._array)
743-
return self.__class__._new(res)
769+
return self.__class__._new(res, device=self.device)
744770

745771
def __mod__(self: Array, other: Union[int, float, Array], /) -> Array:
746772
"""
747773
Performs the operation __mod__.
748774
"""
775+
other = self._check_device(other)
749776
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
750777
if other is NotImplemented:
751778
return other
752779
self, other = self._normalize_two_args(self, other)
753780
res = self._array.__mod__(other._array)
754-
return self.__class__._new(res)
781+
return self.__class__._new(res, device=self.device)
755782

756783
def __mul__(self: Array, other: Union[int, float, Array], /) -> Array:
757784
"""
758785
Performs the operation __mul__.
759786
"""
787+
other = self._check_device(other)
760788
other = self._check_allowed_dtypes(other, "numeric", "__mul__")
761789
if other is NotImplemented:
762790
return other
763791
self, other = self._normalize_two_args(self, other)
764792
res = self._array.__mul__(other._array)
765-
return self.__class__._new(res)
793+
return self.__class__._new(res, device=self.device)
766794

767795
def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
768796
"""
769797
Performs the operation __ne__.
770798
"""
799+
other = self._check_device(other)
771800
other = self._check_allowed_dtypes(other, "all", "__ne__")
772801
if other is NotImplemented:
773802
return other
774803
self, other = self._normalize_two_args(self, other)
775804
res = self._array.__ne__(other._array)
776-
return self.__class__._new(res)
805+
return self.__class__._new(res, device=self.device)
777806

778807
def __neg__(self: Array, /) -> Array:
779808
"""
@@ -782,18 +811,19 @@ def __neg__(self: Array, /) -> Array:
782811
if self.dtype not in _numeric_dtypes:
783812
raise TypeError("Only numeric dtypes are allowed in __neg__")
784813
res = self._array.__neg__()
785-
return self.__class__._new(res)
814+
return self.__class__._new(res, device=self.device)
786815

787816
def __or__(self: Array, other: Union[int, bool, Array], /) -> Array:
788817
"""
789818
Performs the operation __or__.
790819
"""
820+
other = self._check_device(other)
791821
other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
792822
if other is NotImplemented:
793823
return other
794824
self, other = self._normalize_two_args(self, other)
795825
res = self._array.__or__(other._array)
796-
return self.__class__._new(res)
826+
return self.__class__._new(res, device=self.device)
797827

798828
def __pos__(self: Array, /) -> Array:
799829
"""
@@ -802,14 +832,15 @@ def __pos__(self: Array, /) -> Array:
802832
if self.dtype not in _numeric_dtypes:
803833
raise TypeError("Only numeric dtypes are allowed in __pos__")
804834
res = self._array.__pos__()
805-
return self.__class__._new(res)
835+
return self.__class__._new(res, device=self.device)
806836

807837
def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
808838
"""
809839
Performs the operation __pow__.
810840
"""
811841
from ._elementwise_functions import pow
812842

843+
other = self._check_device(other)
813844
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
814845
if other is NotImplemented:
815846
return other

0 commit comments

Comments
 (0)