@@ -32,6 +32,137 @@ def nonzero(args, kwargs) -> List[Optional[int]]:
     return [eval_expr(args[0].numel()), len(args[0].shape)]
 
 
+@register_upper_bound_inference(exir_ops.edge.aten.index.Tensor)
+@register_upper_bound_inference(torch.ops.aten.index.Tensor)
+def index_Tensor(args, kwargs) -> List[Optional[int]]:
+    tensor = args[0]
+    indices = args[1]
+
+    # Compute the number of contiguous blocks of non-null indices.
+    # For example, if A, B, C, D, E are non-null tensors, then
+    # [None, None, A, B, None, C, D, E, None] has 2 blocks.
+    index_blocks = 0
+    in_block = False
+    for index in indices:
+        if index is not None:
+            if not in_block:
+                in_block = True
+                index_blocks += 1
+        else:
+            in_block = False
+
+    if index_blocks == 0:
+        # If no dimensions are actually being indexed, either because the indices list is empty
+        # or all indices are null, then the result is just the same as the input tensor.
+        return tensor.shape
+
+    adjacent = index_blocks == 1
+
+    # Number of leading null indices in the indices list.
+    num_leading_null_indices = 0
+    for index in indices:
+        if index is None:
+            num_leading_null_indices += 1
+        else:
+            break
+
+    # Total number of null indices in the indices list.
+    num_null_indices = sum([ix is None for ix in indices])
+
+    # Number of dimensions being indexed (bool/byte tensors are treated as masks, and index
+    # as many input dimensions as their own number of dimensions).
+    num_indexed_dims = 0
+    mask_indices = []
+    int_indices = []
+    for index in indices:
+        if index is not None:
+            if index.dtype in [torch.bool, torch.uint8]:
+                num_indexed_dims += index.dim()
+                mask_indices.append(index)
+            else:
+                num_indexed_dims += 1
+                int_indices.append(index)
+
+    broadcast_sizes = []
+    if len(int_indices) > 0:
+        # All of the integer index tensors (non-mask & non-null index tensors) need to
+        # broadcast; compute the resulting broadcast shape.
+        curr_ndim = 0
+        rev_shape = []
+        for index in int_indices:
+            for j in range(index.dim()):
+                rev_j_size = eval_expr(index.size(index.dim() - j - 1))
+                if j >= curr_ndim:
+                    curr_ndim += 1
+                    rev_shape.append(rev_j_size)
+                elif rev_shape[j] == 1:
+                    rev_shape[j] = rev_j_size
+        broadcast_sizes = list(reversed(rev_shape))
+
+    # The numbers of True elements in the mask indices must broadcast (i.e. some might be 1,
+    # but the others must all be equal). They also need to broadcast with broadcast_sizes[0].
+    # Therefore, if broadcast_sizes[0] != 1, we don't need to worry about the mask indices,
+    # since we are assuming that the inputs are valid. However, if broadcast_sizes[0] == 1,
+    # we do need to consider them. We can't know how many True elements there are in each
+    # mask, but since a mask can never have more True elements than it has elements in total,
+    # the broadcasted size can't be larger than the minimum numel across all mask indices
+    # whose numel is not 1.
+    if len(mask_indices) > 0 and (len(broadcast_sizes) == 0 or broadcast_sizes[0] == 1):
+        upper_bound_broadcast_size = 1
+        initialized = False
+        for mask in mask_indices:
+            mask_numel = eval_expr(mask.numel())
+            if mask_numel != 1:
+                if initialized:
+                    assert isinstance(
+                        mask_numel, int
+                    ), "Expect mask_numel to be a concrete int"
+                    assert isinstance(
+                        upper_bound_broadcast_size, int
+                    ), "Expect upper_bound_broadcast_size to be a concrete int"
+                    if upper_bound_broadcast_size > mask_numel:
+                        upper_bound_broadcast_size = mask_numel
+                else:
+                    upper_bound_broadcast_size = mask_numel
+                    initialized = True
+        if len(broadcast_sizes) == 0:
+            broadcast_sizes.append(upper_bound_broadcast_size)
+        else:
+            broadcast_sizes[0] = upper_bound_broadcast_size
+
+    broadcast_ndim = len(broadcast_sizes)
+
+    out_ndim = tensor.dim() + broadcast_ndim - num_indexed_dims
+    out_sizes: List[Optional[int]] = [0 for _ in range(out_ndim)]
+
+    if adjacent:
+        for i in range(num_leading_null_indices):
+            out_sizes[i] = eval_expr(tensor.size(i))
+        for i in range(broadcast_ndim):
+            out_sizes[i + num_leading_null_indices] = broadcast_sizes[i]
+        for i in range(num_indexed_dims + num_leading_null_indices, tensor.dim()):
+            out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i))
+    else:
+        for i in range(broadcast_ndim):
+            out_sizes[i] = broadcast_sizes[i]
+        in_i = 0
+        out_i = broadcast_ndim
+        for index in indices:
+            if index is None:
+                out_sizes[out_i] = eval_expr(tensor.size(in_i))
+                out_i += 1
+                in_i += 1
+            else:
+                if index.dtype in [torch.bool, torch.uint8]:
+                    in_i += index.dim()
+                else:
+                    in_i += 1
+
+        for i in range(num_indexed_dims + num_null_indices, tensor.dim()):
+            out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i))
+
+    return out_sizes
+
+
 class HintBasedSymShapeEvalPass(PassBase):
     """
     If we enable dynamic shape tracing, a tensor's shape may become a symbolic
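For context, here is a minimal sketch of the eager-mode PyTorch advanced-indexing behavior that the upper-bound inference above mirrors. The tensors and sizes are illustrative assumptions and are not part of the commit:

import torch

x = torch.randn(2, 3, 4, 5)
idx = torch.tensor([0, 2, 1])

# One contiguous block of non-null indices ("adjacent"): the broadcast
# dimensions replace the indexed dimensions in place.
print(x[:, idx, idx].shape)   # torch.Size([2, 3, 5])

# Two blocks (non-adjacent): the broadcast dimensions move to the front
# of the result.
print(x[idx, :, idx].shape)   # torch.Size([3, 3, 5])

# Boolean mask: it indexes as many dimensions as it has, and the resulting
# dimension's size is the number of True elements, bounded above by numel().
mask = torch.tensor([True, False, True])
print(x[:, mask].shape)       # torch.Size([2, 2, 4, 5])

The masked case is why only an upper bound is possible: the number of True elements is data dependent, so the pass bounds it by the mask's element count instead.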