
Commit 0969ca8

manuelcandales authored and facebook-github-bot committed

Register upper bound inference for index.Tensor (#1018)
Summary: Pull Request resolved: #1018

Register upper bound inference for index.Tensor for HintBasedSymShapeEvalPass

Reviewed By: guangy10

Differential Revision: D49394788

fbshipit-source-id: 3820054c10448a9a9099e49b4f830abd56d872aa
1 parent 8d52580 commit 0969ca8
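For readers skimming the diff: the change registers a per-operator shape function via register_upper_bound_inference, which HintBasedSymShapeEvalPass can then consult to replace an operator's symbolic output sizes with concrete upper bounds. Below is a minimal sketch of that registration pattern, assuming a plain dict registry; the table name and decorator body are hypothetical stand-ins, not the actual implementation in exir/passes/sym_shape_eval_pass.py.

from typing import Callable, Dict, List, Optional

# Hypothetical registry: maps an operator overload to a function that
# returns concrete upper-bound sizes for that operator's output.
_upper_bound_table: Dict[object, Callable[..., List[Optional[int]]]] = {}

def register_upper_bound_inference(op):
    def decorator(fn):
        # The pass can later resolve a node's target to fn and call
        # fn(node.args, node.kwargs) to obtain an upper-bound output shape.
        _upper_bound_table[op] = fn
        return fn
    return decorator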

File tree

1 file changed: +131 −0 lines changed

exir/passes/sym_shape_eval_pass.py

Lines changed: 131 additions & 0 deletions
@@ -32,6 +32,137 @@ def nonzero(args, kwargs) -> List[Optional[int]]:
     return [eval_expr(args[0].numel()), len(args[0].shape)]


+@register_upper_bound_inference(exir_ops.edge.aten.index.Tensor)
+@register_upper_bound_inference(torch.ops.aten.index.Tensor)
+def index_Tensor(args, kwargs) -> List[Optional[int]]:
+    tensor = args[0]
+    indices = args[1]
+
+    # Compute the number of contiguous blocks of non-null indices.
+    # For example, if A, B, C, D, E are non-null tensors, then
+    # [None, None, A, B, None, C, D, E, None] has 2 blocks.
+    index_blocks = 0
+    in_block = False
+    for index in indices:
+        if index is not None:
+            if not in_block:
+                in_block = True
+                index_blocks += 1
+        else:
+            in_block = False
+
+    if index_blocks == 0:
+        # If no dimensions are actually being indexed, either because the indices
+        # list is empty or because all indices are null, then the result has the
+        # same shape as the input tensor.
+        return tensor.shape
+
+    adjacent = index_blocks == 1
+
+    # Number of leading null indices in the indices list.
+    num_leading_null_indices = 0
+    for index in indices:
+        if index is None:
+            num_leading_null_indices += 1
+        else:
+            break
+
+    # Total number of null indices in the indices list.
+    num_null_indices = sum(ix is None for ix in indices)
+
+    # Number of dimensions being indexed (bool/byte tensors are treated as masks,
+    # and index as many input dimensions as their number of dimensions).
+    num_indexed_dims = 0
+    mask_indices = []
+    int_indices = []
+    for index in indices:
+        if index is not None:
+            if index.dtype in [torch.bool, torch.uint8]:
+                num_indexed_dims += index.dim()
+                mask_indices.append(index)
+            else:
+                num_indexed_dims += 1
+                int_indices.append(index)
+
+    broadcast_sizes = []
+    if len(int_indices) > 0:
+        # All of the integer index tensors (non-mask & non-null index tensors)
+        # need to broadcast. We need to compute the resulting shape.
+        curr_ndim = 0
+        rev_shape = []
+        for index in int_indices:
+            for j in range(index.dim()):
+                rev_j_size = eval_expr(index.size(index.dim() - j - 1))
+                if j >= curr_ndim:
+                    curr_ndim += 1
+                    rev_shape.append(rev_j_size)
+                elif rev_shape[j] == 1:
+                    rev_shape[j] = rev_j_size
+        broadcast_sizes = list(reversed(rev_shape))
+
+    # The number of True elements in the mask indices must broadcast (i.e. some
+    # might be 1, but the others must all be equal). They also need to broadcast
+    # with broadcast_sizes[0]. Therefore, if broadcast_sizes[0] != 1, we don't
+    # need to worry about the mask indices, since we are assuming that the inputs
+    # are valid. However, if broadcast_sizes[0] == 1, we do need to consider them.
+    # We can't know how many True elements there are in each mask, but that count
+    # can't exceed the mask's numel, so the broadcasted size can't be larger than
+    # the minimum numel across all mask indices whose numel is not 1.
+    if len(mask_indices) > 0 and (len(broadcast_sizes) == 0 or broadcast_sizes[0] == 1):
+        upper_bound_broadcast_size = 1
+        initialized = False
+        for mask in mask_indices:
+            mask_numel = eval_expr(mask.numel())
+            if mask_numel != 1:
+                if initialized:
+                    assert isinstance(
+                        mask_numel, int
+                    ), "Expect mask_numel to be a concrete int"
+                    assert isinstance(
+                        upper_bound_broadcast_size, int
+                    ), "Expect upper_bound_broadcast_size to be a concrete int"
+                    if upper_bound_broadcast_size > mask_numel:
+                        upper_bound_broadcast_size = mask_numel
+                else:
+                    upper_bound_broadcast_size = mask_numel
+                    initialized = True
+        if len(broadcast_sizes) == 0:
+            broadcast_sizes.append(upper_bound_broadcast_size)
+        else:
+            broadcast_sizes[0] = upper_bound_broadcast_size
+
+    broadcast_ndim = len(broadcast_sizes)
+
+    out_ndim = tensor.dim() + broadcast_ndim - num_indexed_dims
+    out_sizes: List[Optional[int]] = [0 for _ in range(out_ndim)]
+
+    if adjacent:
+        for i in range(num_leading_null_indices):
+            out_sizes[i] = eval_expr(tensor.size(i))
+        for i in range(broadcast_ndim):
+            out_sizes[i + num_leading_null_indices] = broadcast_sizes[i]
+        for i in range(num_indexed_dims + num_leading_null_indices, tensor.dim()):
+            out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i))
+    else:
+        for i in range(broadcast_ndim):
+            out_sizes[i] = broadcast_sizes[i]
+        in_i = 0
+        out_i = broadcast_ndim
+        for index in indices:
+            if index is None:
+                out_sizes[out_i] = eval_expr(tensor.size(in_i))
+                out_i += 1
+                in_i += 1
+            else:
+                if index.dtype in [torch.bool, torch.uint8]:
+                    in_i += index.dim()
+                else:
+                    in_i += 1
+        for i in range(num_indexed_dims + num_null_indices, tensor.dim()):
+            out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i))
+
+    return out_sizes
+
+
 class HintBasedSymShapeEvalPass(PassBase):
     """
     If we enable dynamic shape tracing, a tensor's shape may become a symbolic
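As a sanity check on the shape rules index_Tensor encodes, the standalone snippet below (not part of the commit) shows how eager PyTorch advanced indexing shapes its output in the three cases the function distinguishes: one contiguous block of non-null indices, multiple blocks, and a boolean mask.

import torch

t = torch.randn(5, 6, 7, 8)
a = torch.zeros(3, dtype=torch.long)
b = torch.zeros(3, dtype=torch.long)

# One contiguous block of non-null indices ("adjacent"): the broadcast
# shape (3,) replaces the two indexed dims in place.
print(t[:, a, b, :].shape)  # torch.Size([5, 3, 8])

# Two blocks ("non-adjacent"): the broadcast shape moves to the front,
# followed by the non-indexed dims in order.
print(t[a, :, b, :].shape)  # torch.Size([3, 6, 8])

# A boolean mask indexes as many input dims as it has dims; the output
# size depends on the number of True elements (3 here), which is unknown
# at compile time, so the pass bounds it by the mask's numel (5).
m = torch.tensor([True, False, True, True, False])
print(t[m].shape)  # torch.Size([3, 6, 7, 8])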
