5
5
import logging
6
6
import warnings
7
7
from collections import OrderedDict , namedtuple
8
- from typing import TYPE_CHECKING , Callable , List , Optional , Set , Tuple , Union
8
+ from typing import TYPE_CHECKING , Callable , List , Optional , Sequence , Set , Tuple , Union
9
+ from typing import cast as type_cast
9
10
10
11
import numpy as np
11
12
21
22
graph_inputs ,
22
23
)
23
24
from aesara .graph .op import get_test_value
25
+ from aesara .graph .type import HasDataType
24
26
from aesara .graph .utils import TestValueError
25
- from aesara .tensor .basic import AllocEmpty , get_scalar_constant_value
27
+ from aesara .tensor .basic import AllocEmpty , cast , get_scalar_constant_value
26
28
from aesara .tensor .subtensor import set_subtensor
27
29
from aesara .tensor .var import TensorConstant
28
30
@@ -55,8 +57,11 @@ def safe_new(
55
57
nw_name = None
56
58
57
59
if isinstance (x , Constant ):
58
- if dtype and x .dtype != dtype :
59
- casted_x = x .astype (dtype )
60
+ # TODO: Do something better about this
61
+ assert isinstance (x .type , HasDataType )
62
+
63
+ if dtype and x .type .dtype != dtype :
64
+ casted_x = cast (x , dtype )
60
65
nwx = type (x )(casted_x .type , x .data , x .name )
61
66
nwx .tag = copy .copy (x .tag )
62
67
return nwx
@@ -89,8 +94,12 @@ def safe_new(
89
94
pass
90
95
91
96
# Cast `x` if needed. If `x` has a test value, this will also cast it.
92
- if dtype and x .dtype != dtype :
93
- x = x .astype (dtype )
97
+ if dtype :
98
+ # TODO: Do something better about this
99
+ assert isinstance (x .type , HasDataType )
100
+
101
+ if x .type .dtype != dtype :
102
+ x = cast (x , dtype )
94
103
95
104
nw_x = x .type ()
96
105
nw_x .name = nw_name
@@ -717,13 +726,13 @@ class ScanArgs:
717
726
718
727
def __init__ (
719
728
self ,
720
- outer_inputs ,
721
- outer_outputs ,
722
- _inner_inputs ,
723
- _inner_outputs ,
724
- info ,
725
- as_while ,
726
- clone = True ,
729
+ outer_inputs : Sequence [ Variable ] ,
730
+ outer_outputs : Sequence [ Variable ] ,
731
+ _inner_inputs : Sequence [ Variable ] ,
732
+ _inner_outputs : Sequence [ Variable ] ,
733
+ info : "ScanInfo" ,
734
+ as_while : bool ,
735
+ clone : Optional [ bool ] = True ,
727
736
):
728
737
self .n_steps = outer_inputs [0 ]
729
738
self .as_while = as_while
@@ -745,16 +754,16 @@ def __init__(
745
754
q = 0
746
755
747
756
n_seqs = info .n_seqs
748
- self .outer_in_seqs = outer_inputs [p : p + n_seqs ]
749
- self .inner_in_seqs = inner_inputs [q : q + n_seqs ]
757
+ self .outer_in_seqs = list ( outer_inputs [p : p + n_seqs ])
758
+ self .inner_in_seqs = list ( inner_inputs [q : q + n_seqs ])
750
759
p += n_seqs
751
760
q += n_seqs
752
761
753
762
n_mit_mot = info .n_mit_mot
754
763
n_mit_sot = info .n_mit_sot
755
764
756
- self .mit_mot_in_slices = info .tap_array [:n_mit_mot ]
757
- self .mit_sot_in_slices = info .tap_array [n_mit_mot : n_mit_mot + n_mit_sot ]
765
+ self .mit_mot_in_slices = list ( info .tap_array [:n_mit_mot ])
766
+ self .mit_sot_in_slices = list ( info .tap_array [n_mit_mot : n_mit_mot + n_mit_sot ])
758
767
759
768
n_mit_mot_ins = sum (len (s ) for s in self .mit_mot_in_slices )
760
769
n_mit_sot_ins = sum (len (s ) for s in self .mit_sot_in_slices )
@@ -775,63 +784,63 @@ def __init__(
775
784
qq += len (sl )
776
785
q += n_mit_sot_ins
777
786
778
- self .outer_in_mit_mot = outer_inputs [p : p + n_mit_mot ]
787
+ self .outer_in_mit_mot = list ( outer_inputs [p : p + n_mit_mot ])
779
788
p += n_mit_mot
780
- self .outer_in_mit_sot = outer_inputs [p : p + n_mit_sot ]
789
+ self .outer_in_mit_sot = list ( outer_inputs [p : p + n_mit_sot ])
781
790
p += n_mit_sot
782
791
783
792
n_sit_sot = info .n_sit_sot
784
- self .outer_in_sit_sot = outer_inputs [p : p + n_sit_sot ]
785
- self .inner_in_sit_sot = inner_inputs [q : q + n_sit_sot ]
793
+ self .outer_in_sit_sot = list ( outer_inputs [p : p + n_sit_sot ])
794
+ self .inner_in_sit_sot = list ( inner_inputs [q : q + n_sit_sot ])
786
795
p += n_sit_sot
787
796
q += n_sit_sot
788
797
789
798
n_shared_outs = info .n_shared_outs
790
- self .outer_in_shared = outer_inputs [p : p + n_shared_outs ]
791
- self .inner_in_shared = inner_inputs [q : q + n_shared_outs ]
799
+ self .outer_in_shared = list ( outer_inputs [p : p + n_shared_outs ])
800
+ self .inner_in_shared = list ( inner_inputs [q : q + n_shared_outs ])
792
801
p += n_shared_outs
793
802
q += n_shared_outs
794
803
795
804
n_nit_sot = info .n_nit_sot
796
- self .outer_in_nit_sot = outer_inputs [p : p + n_nit_sot ]
805
+ self .outer_in_nit_sot = list ( outer_inputs [p : p + n_nit_sot ])
797
806
p += n_nit_sot
798
807
799
- self .outer_in_non_seqs = outer_inputs [p :]
800
- self .inner_in_non_seqs = inner_inputs [q :]
808
+ self .outer_in_non_seqs = list ( outer_inputs [p :])
809
+ self .inner_in_non_seqs = list ( inner_inputs [q :])
801
810
802
811
# now for the outputs
803
812
p = 0
804
813
q = 0
805
814
806
815
self .mit_mot_out_slices = info .mit_mot_out_slices
807
816
n_mit_mot_outs = info .n_mit_mot_outs
808
- self .outer_out_mit_mot = outer_outputs [p : p + n_mit_mot ]
809
- iomm = inner_outputs [q : q + n_mit_mot_outs ]
810
- self .inner_out_mit_mot = ()
817
+ self .outer_out_mit_mot = list ( outer_outputs [p : p + n_mit_mot ])
818
+ iomm = list ( inner_outputs [q : q + n_mit_mot_outs ])
819
+ self .inner_out_mit_mot : Tuple [ List [ Variable ], ...] = ()
811
820
qq = 0
812
821
for sl in self .mit_mot_out_slices :
813
822
self .inner_out_mit_mot += (iomm [qq : qq + len (sl )],)
814
823
qq += len (sl )
815
824
p += n_mit_mot
816
825
q += n_mit_mot_outs
817
826
818
- self .outer_out_mit_sot = outer_outputs [p : p + n_mit_sot ]
819
- self .inner_out_mit_sot = inner_outputs [q : q + n_mit_sot ]
827
+ self .outer_out_mit_sot = list ( outer_outputs [p : p + n_mit_sot ])
828
+ self .inner_out_mit_sot = list ( inner_outputs [q : q + n_mit_sot ])
820
829
p += n_mit_sot
821
830
q += n_mit_sot
822
831
823
- self .outer_out_sit_sot = outer_outputs [p : p + n_sit_sot ]
824
- self .inner_out_sit_sot = inner_outputs [q : q + n_sit_sot ]
832
+ self .outer_out_sit_sot = list ( outer_outputs [p : p + n_sit_sot ])
833
+ self .inner_out_sit_sot = list ( inner_outputs [q : q + n_sit_sot ])
825
834
p += n_sit_sot
826
835
q += n_sit_sot
827
836
828
- self .outer_out_nit_sot = outer_outputs [p : p + n_nit_sot ]
829
- self .inner_out_nit_sot = inner_outputs [q : q + n_nit_sot ]
837
+ self .outer_out_nit_sot = list ( outer_outputs [p : p + n_nit_sot ])
838
+ self .inner_out_nit_sot = list ( inner_outputs [q : q + n_nit_sot ])
830
839
p += n_nit_sot
831
840
q += n_nit_sot
832
841
833
- self .outer_out_shared = outer_outputs [p : p + n_shared_outs ]
834
- self .inner_out_shared = inner_outputs [q : q + n_shared_outs ]
842
+ self .outer_out_shared = list ( outer_outputs [p : p + n_shared_outs ])
843
+ self .inner_out_shared = list ( inner_outputs [q : q + n_shared_outs ])
835
844
p += n_shared_outs
836
845
q += n_shared_outs
837
846
@@ -978,12 +987,18 @@ def get_alt_field(
978
987
-------
979
988
The alternate variable.
980
989
"""
990
+ _var_info : FieldInfo
981
991
if not isinstance (var_info , FieldInfo ):
982
- var_info = self .find_among_fields (var_info )
992
+ find_var_info = self .find_among_fields (var_info )
993
+ if find_var_info is None :
994
+ raise ValueError (f"Couldn't find { var_info } among fields" )
995
+ _var_info = find_var_info
996
+ else :
997
+ _var_info = var_info
983
998
984
- alt_type = var_info .name [(var_info .name .index ("_" , 6 ) + 1 ) :]
985
- alt_var = getattr (self , f"{ alt_prefix } _{ alt_type } " )[var_info .index ]
986
- return alt_var
999
+ alt_type = _var_info .name [(_var_info .name .index ("_" , 6 ) + 1 ) :]
1000
+ alt_var = getattr (self , f"{ alt_prefix } _{ alt_type } " )[_var_info .index ]
1001
+ return type_cast ( Variable , alt_var )
987
1002
988
1003
def find_among_fields (
989
1004
self , i : Variable , field_filter : Callable [[str ], bool ] = default_filter
@@ -1054,7 +1069,7 @@ def _remove_from_fields(
1054
1069
return field_info
1055
1070
1056
1071
def get_dependent_nodes (
1057
- self , i : Variable , seen : Optional [Set [int ]] = None
1072
+ self , i : Variable , seen : Optional [Set [Variable ]] = None
1058
1073
) -> Set [Variable ]:
1059
1074
if seen is None :
1060
1075
seen = {i }
0 commit comments