45
45
46
46
class Device :
47
47
def __init__ (self , device = "CPU_DEVICE" ):
48
+ if device not in ("CPU_DEVICE" , "device1" , "device2" ):
49
+ raise ValueError (f"The device '{ device } ' is not a valid choice." )
48
50
self ._device = device
49
51
50
52
def __repr__ (self ):
51
- return f"Device('{ self ._device } ')"
53
+ return f"array_api_strict. Device('{ self ._device } ')"
52
54
53
55
def __eq__ (self , other ):
54
56
return self ._device == other ._device
@@ -77,7 +79,7 @@ class Array:
77
79
# Use a custom constructor instead of __init__, as manually initializing
78
80
# this class is not supported API.
79
81
@classmethod
80
- def _new (cls , x , / , device = CPU_DEVICE ):
82
+ def _new (cls , x , / , device = None ):
81
83
"""
82
84
This is a private method for initializing the array API Array
83
85
object.
@@ -123,7 +125,11 @@ def __repr__(self: Array, /) -> str:
123
125
"""
124
126
Performs the operation __repr__.
125
127
"""
126
- suffix = f", dtype={ self .dtype } )"
128
+ suffix = f", dtype={ self .dtype } "
129
+ if self .device != CPU_DEVICE :
130
+ suffix += f", device={ self .device } )"
131
+ else :
132
+ suffix += ")"
127
133
if 0 in self .shape :
128
134
prefix = "empty("
129
135
mid = str (self .shape )
@@ -202,6 +208,15 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor
202
208
203
209
return other
204
210
211
+ def _check_device (self , other ):
212
+ """Check that other is on a device compatible with the current array"""
213
+ if isinstance (other , (int , complex , float , bool )):
214
+ return other
215
+ elif isinstance (other , Array ):
216
+ if self .device != other .device :
217
+ raise RuntimeError (f"Arrays from two different devices ({ self .device } and { other .device } ) can not be combined." )
218
+ return other
219
+
205
220
# Helper function to match the type promotion rules in the spec
206
221
def _promote_scalar (self , scalar ):
207
222
"""
@@ -477,23 +492,25 @@ def __add__(self: Array, other: Union[int, float, Array], /) -> Array:
477
492
"""
478
493
Performs the operation __add__.
479
494
"""
495
+ other = self ._check_device (other )
480
496
other = self ._check_allowed_dtypes (other , "numeric" , "__add__" )
481
497
if other is NotImplemented :
482
498
return other
483
499
self , other = self ._normalize_two_args (self , other )
484
500
res = self ._array .__add__ (other ._array )
485
- return self .__class__ ._new (res )
501
+ return self .__class__ ._new (res , device = self . device )
486
502
487
503
def __and__ (self : Array , other : Union [int , bool , Array ], / ) -> Array :
488
504
"""
489
505
Performs the operation __and__.
490
506
"""
507
+ other = self ._check_device (other )
491
508
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__and__" )
492
509
if other is NotImplemented :
493
510
return other
494
511
self , other = self ._normalize_two_args (self , other )
495
512
res = self ._array .__and__ (other ._array )
496
- return self .__class__ ._new (res )
513
+ return self .__class__ ._new (res , device = self . device )
497
514
498
515
def __array_namespace__ (
499
516
self : Array , / , * , api_version : Optional [str ] = None
@@ -577,14 +594,15 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
577
594
"""
578
595
Performs the operation __eq__.
579
596
"""
597
+ other = self ._check_device (other )
580
598
# Even though "all" dtypes are allowed, we still require them to be
581
599
# promotable with each other.
582
600
other = self ._check_allowed_dtypes (other , "all" , "__eq__" )
583
601
if other is NotImplemented :
584
602
return other
585
603
self , other = self ._normalize_two_args (self , other )
586
604
res = self ._array .__eq__ (other ._array )
587
- return self .__class__ ._new (res )
605
+ return self .__class__ ._new (res , device = self . device )
588
606
589
607
def __float__ (self : Array , / ) -> float :
590
608
"""
@@ -602,23 +620,25 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
602
620
"""
603
621
Performs the operation __floordiv__.
604
622
"""
623
+ other = self ._check_device (other )
605
624
other = self ._check_allowed_dtypes (other , "real numeric" , "__floordiv__" )
606
625
if other is NotImplemented :
607
626
return other
608
627
self , other = self ._normalize_two_args (self , other )
609
628
res = self ._array .__floordiv__ (other ._array )
610
- return self .__class__ ._new (res )
629
+ return self .__class__ ._new (res , device = self . device )
611
630
612
631
def __ge__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
613
632
"""
614
633
Performs the operation __ge__.
615
634
"""
635
+ other = self ._check_device (other )
616
636
other = self ._check_allowed_dtypes (other , "real numeric" , "__ge__" )
617
637
if other is NotImplemented :
618
638
return other
619
639
self , other = self ._normalize_two_args (self , other )
620
640
res = self ._array .__ge__ (other ._array )
621
- return self .__class__ ._new (res )
641
+ return self .__class__ ._new (res , device = self . device )
622
642
623
643
def __getitem__ (
624
644
self : Array ,
@@ -634,19 +654,21 @@ def __getitem__(
634
654
"""
635
655
Performs the operation __getitem__.
636
656
"""
657
+ # XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE?
637
658
# Note: Only indices required by the spec are allowed. See the
638
659
# docstring of _validate_index
639
660
self ._validate_index (key )
640
661
if isinstance (key , Array ):
641
662
# Indexing self._array with array_api_strict arrays can be erroneous
642
663
key = key ._array
643
664
res = self ._array .__getitem__ (key )
644
- return self ._new (res )
665
+ return self ._new (res , device = self . device )
645
666
646
667
def __gt__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
647
668
"""
648
669
Performs the operation __gt__.
649
670
"""
671
+ other = self ._check_device (other )
650
672
other = self ._check_allowed_dtypes (other , "real numeric" , "__gt__" )
651
673
if other is NotImplemented :
652
674
return other
@@ -680,7 +702,7 @@ def __invert__(self: Array, /) -> Array:
680
702
if self .dtype not in _integer_or_boolean_dtypes :
681
703
raise TypeError ("Only integer or boolean dtypes are allowed in __invert__" )
682
704
res = self ._array .__invert__ ()
683
- return self .__class__ ._new (res )
705
+ return self .__class__ ._new (res , device = self . device )
684
706
685
707
def __iter__ (self : Array , / ):
686
708
"""
@@ -695,85 +717,92 @@ def __iter__(self: Array, /):
695
717
# define __iter__, but it doesn't disallow it. The default Python
696
718
# behavior is to implement iter as a[0], a[1], ... when __getitem__ is
697
719
# implemented, which implies iteration on 1-D arrays.
698
- return (Array ._new (i ) for i in self ._array )
720
+ return (Array ._new (i , device = self . device ) for i in self ._array )
699
721
700
722
def __le__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
701
723
"""
702
724
Performs the operation __le__.
703
725
"""
726
+ other = self ._check_device (other )
704
727
other = self ._check_allowed_dtypes (other , "real numeric" , "__le__" )
705
728
if other is NotImplemented :
706
729
return other
707
730
self , other = self ._normalize_two_args (self , other )
708
731
res = self ._array .__le__ (other ._array )
709
- return self .__class__ ._new (res )
732
+ return self .__class__ ._new (res , device = self . device )
710
733
711
734
def __lshift__ (self : Array , other : Union [int , Array ], / ) -> Array :
712
735
"""
713
736
Performs the operation __lshift__.
714
737
"""
738
+ other = self ._check_device (other )
715
739
other = self ._check_allowed_dtypes (other , "integer" , "__lshift__" )
716
740
if other is NotImplemented :
717
741
return other
718
742
self , other = self ._normalize_two_args (self , other )
719
743
res = self ._array .__lshift__ (other ._array )
720
- return self .__class__ ._new (res )
744
+ return self .__class__ ._new (res , device = self . device )
721
745
722
746
def __lt__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
723
747
"""
724
748
Performs the operation __lt__.
725
749
"""
750
+ other = self ._check_device (other )
726
751
other = self ._check_allowed_dtypes (other , "real numeric" , "__lt__" )
727
752
if other is NotImplemented :
728
753
return other
729
754
self , other = self ._normalize_two_args (self , other )
730
755
res = self ._array .__lt__ (other ._array )
731
- return self .__class__ ._new (res )
756
+ return self .__class__ ._new (res , device = self . device )
732
757
733
758
def __matmul__ (self : Array , other : Array , / ) -> Array :
734
759
"""
735
760
Performs the operation __matmul__.
736
761
"""
762
+ other = self ._check_device (other )
737
763
# matmul is not defined for scalars, but without this, we may get
738
764
# the wrong error message from asarray.
739
765
other = self ._check_allowed_dtypes (other , "numeric" , "__matmul__" )
740
766
if other is NotImplemented :
741
767
return other
742
768
res = self ._array .__matmul__ (other ._array )
743
- return self .__class__ ._new (res )
769
+ return self .__class__ ._new (res , device = self . device )
744
770
745
771
def __mod__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
746
772
"""
747
773
Performs the operation __mod__.
748
774
"""
775
+ other = self ._check_device (other )
749
776
other = self ._check_allowed_dtypes (other , "real numeric" , "__mod__" )
750
777
if other is NotImplemented :
751
778
return other
752
779
self , other = self ._normalize_two_args (self , other )
753
780
res = self ._array .__mod__ (other ._array )
754
- return self .__class__ ._new (res )
781
+ return self .__class__ ._new (res , device = self . device )
755
782
756
783
def __mul__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
757
784
"""
758
785
Performs the operation __mul__.
759
786
"""
787
+ other = self ._check_device (other )
760
788
other = self ._check_allowed_dtypes (other , "numeric" , "__mul__" )
761
789
if other is NotImplemented :
762
790
return other
763
791
self , other = self ._normalize_two_args (self , other )
764
792
res = self ._array .__mul__ (other ._array )
765
- return self .__class__ ._new (res )
793
+ return self .__class__ ._new (res , device = self . device )
766
794
767
795
def __ne__ (self : Array , other : Union [int , float , bool , Array ], / ) -> Array :
768
796
"""
769
797
Performs the operation __ne__.
770
798
"""
799
+ other = self ._check_device (other )
771
800
other = self ._check_allowed_dtypes (other , "all" , "__ne__" )
772
801
if other is NotImplemented :
773
802
return other
774
803
self , other = self ._normalize_two_args (self , other )
775
804
res = self ._array .__ne__ (other ._array )
776
- return self .__class__ ._new (res )
805
+ return self .__class__ ._new (res , device = self . device )
777
806
778
807
def __neg__ (self : Array , / ) -> Array :
779
808
"""
@@ -782,18 +811,19 @@ def __neg__(self: Array, /) -> Array:
782
811
if self .dtype not in _numeric_dtypes :
783
812
raise TypeError ("Only numeric dtypes are allowed in __neg__" )
784
813
res = self ._array .__neg__ ()
785
- return self .__class__ ._new (res )
814
+ return self .__class__ ._new (res , device = self . device )
786
815
787
816
def __or__ (self : Array , other : Union [int , bool , Array ], / ) -> Array :
788
817
"""
789
818
Performs the operation __or__.
790
819
"""
820
+ other = self ._check_device (other )
791
821
other = self ._check_allowed_dtypes (other , "integer or boolean" , "__or__" )
792
822
if other is NotImplemented :
793
823
return other
794
824
self , other = self ._normalize_two_args (self , other )
795
825
res = self ._array .__or__ (other ._array )
796
- return self .__class__ ._new (res )
826
+ return self .__class__ ._new (res , device = self . device )
797
827
798
828
def __pos__ (self : Array , / ) -> Array :
799
829
"""
@@ -802,14 +832,15 @@ def __pos__(self: Array, /) -> Array:
802
832
if self .dtype not in _numeric_dtypes :
803
833
raise TypeError ("Only numeric dtypes are allowed in __pos__" )
804
834
res = self ._array .__pos__ ()
805
- return self .__class__ ._new (res )
835
+ return self .__class__ ._new (res , device = self . device )
806
836
807
837
def __pow__ (self : Array , other : Union [int , float , Array ], / ) -> Array :
808
838
"""
809
839
Performs the operation __pow__.
810
840
"""
811
841
from ._elementwise_functions import pow
812
842
843
+ other = self ._check_device (other )
813
844
other = self ._check_allowed_dtypes (other , "numeric" , "__pow__" )
814
845
if other is NotImplemented :
815
846
return other
0 commit comments