Register upper bound inference for index.Tensor #1018

Closed
131 changes: 131 additions & 0 deletions exir/passes/sym_shape_eval_pass.py
@@ -32,6 +32,137 @@ def nonzero(args, kwargs) -> List[Optional[int]]:
    return [eval_expr(args[0].numel()), len(args[0].shape)]


@register_upper_bound_inference(exir_ops.edge.aten.index.Tensor)
@register_upper_bound_inference(torch.ops.aten.index.Tensor)
def index_Tensor(args, kwargs) -> List[Optional[int]]:
    tensor = args[0]
    indices = args[1]

    # Compute the number of contiguous blocks of non-null indices.
    # For example, if A, B, C, D, E are non-null tensors, then
    # [None, None, A, B, None, C, D, E, None] has 2 blocks.
    index_blocks = 0
    in_block = False
    for index in indices:
        if index is not None:
            if not in_block:
                in_block = True
                index_blocks += 1
        else:
            in_block = False
    if index_blocks == 0:
        # If no dimensions are actually being indexed, either because the indices list is empty
        # or all indices are null, then the result is just the same as the input tensor.
        return tensor.shape

    adjacent = index_blocks == 1

    # Number of leading null indices in the indices list.
    num_leading_null_indices = 0
    for index in indices:
        if index is None:
            num_leading_null_indices += 1
        else:
            break

    # Number of null indices in total in the indices list.
    num_null_indices = sum(ix is None for ix in indices)

    # Number of dimensions being indexed (bool/byte tensors are treated as masks, and index
    # as many input dimensions as they have dimensions).
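    # For example, a bool mask of shape (2, 3) indexes two input dimensions, while an integer
    # index tensor of any shape indexes exactly one.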
    num_indexed_dims = 0
    mask_indices = []
    int_indices = []
    for index in indices:
        if index is not None:
            if index.dtype in [torch.bool, torch.uint8]:
                num_indexed_dims += index.dim()
                mask_indices.append(index)
            else:
                num_indexed_dims += 1
                int_indices.append(index)

    broadcast_sizes = []
    if len(int_indices) > 0:
        # All of the integer index tensors (non-mask & non-null index tensors) need to broadcast.
        # We need to compute the resulting shape.
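        # For example, integer index tensors of shapes (3, 1) and (4,) broadcast to (3, 4).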
        curr_ndim = 0
        rev_shape = []
        for index in int_indices:
            for j in range(index.dim()):
                rev_j_size = eval_expr(index.size(index.dim() - j - 1))
                if j >= curr_ndim:
                    curr_ndim += 1
                    rev_shape.append(rev_j_size)
                elif rev_shape[j] == 1:
                    rev_shape[j] = rev_j_size
        broadcast_sizes = list(reversed(rev_shape))

    # The number of True elements in the mask indices must broadcast (i.e. some might be 1
    # but the others must all be equal). They also need to broadcast with broadcast_sizes[0].
    # Therefore, if broadcast_sizes[0] != 1, we don't need to worry about the mask indices,
    # since we are assuming that the inputs are valid. However, if broadcast_sizes[0] == 1,
    # we do need to consider them. We can't know how many True elements there are in each mask,
    # but each mask can select at most numel elements, so the broadcasted size can't be larger
    # than the minimum numel across all mask indices whose numel is not 1.
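    # For example, with masks of 6 and 8 elements and no integer index tensors, at most 6
    # elements can be selected, so 6 is the upper bound recorded below.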
    if len(mask_indices) > 0 and (len(broadcast_sizes) == 0 or broadcast_sizes[0] == 1):
        upper_bound_broadcast_size = 1
        initialized = False
        for mask in mask_indices:
            mask_numel = eval_expr(mask.numel())
            if mask_numel != 1:
                if initialized:
                    assert isinstance(
                        mask_numel, int
                    ), "Expect mask_numel to be a concrete int"
                    assert isinstance(
                        upper_bound_broadcast_size, int
                    ), "Expect upper_bound_broadcast_size to be a concrete int"
                    if upper_bound_broadcast_size > mask_numel:
                        upper_bound_broadcast_size = mask_numel
                else:
                    upper_bound_broadcast_size = mask_numel
                    initialized = True
        if len(broadcast_sizes) == 0:
            broadcast_sizes.append(upper_bound_broadcast_size)
        else:
            broadcast_sizes[0] = upper_bound_broadcast_size

    broadcast_ndim = len(broadcast_sizes)

    out_ndim = tensor.dim() + broadcast_ndim - num_indexed_dims
    out_sizes: List[Optional[int]] = [0 for _ in range(out_ndim)]
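    # Note: when the non-null indices form a single contiguous block ("adjacent"), the broadcast
    # dims replace the indexed dims in place; otherwise they are moved to the front of the output.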

    if adjacent:
        for i in range(num_leading_null_indices):
            out_sizes[i] = eval_expr(tensor.size(i))
        for i in range(broadcast_ndim):
            out_sizes[i + num_leading_null_indices] = broadcast_sizes[i]
        for i in range(num_indexed_dims + num_leading_null_indices, tensor.dim()):
            out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i))
    else:
        for i in range(broadcast_ndim):
            out_sizes[i] = broadcast_sizes[i]
        in_i = 0
        out_i = broadcast_ndim
        for index in indices:
            if index is None:
                out_sizes[out_i] = eval_expr(tensor.size(in_i))
                out_i += 1
                in_i += 1
            else:
                if index.dtype in [torch.bool, torch.uint8]:
                    in_i += index.dim()
                else:
                    in_i += 1

        for i in range(num_indexed_dims + num_null_indices, tensor.dim()):
            out_sizes[i + broadcast_ndim - num_indexed_dims] = eval_expr(tensor.size(i))

    return out_sizes
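For reference, a minimal sanity-check sketch (not part of this diff) comparing the cases the inference above distinguishes against eager-mode advanced indexing. It assumes concrete tensors, uses plain Python indexing (which lowers to aten.index.Tensor with None for the un-indexed dimensions), and does not go through eval_expr or the registration decorators:

import torch

x = torch.randn(2, 3, 4, 5)

a = torch.tensor([[0, 1], [1, 0], [0, 0]])  # integer index, shape (3, 2)
b = torch.tensor([2, 0])                    # integer index, shape (2,), broadcasts with a to (3, 2)

# Adjacent non-null indices ([None, a, b]): broadcast dims replace the indexed dims in place.
print(x[:, a, b].shape)   # torch.Size([2, 3, 2, 5])

# Non-adjacent non-null indices ([a, None, b]): broadcast dims move to the front of the output.
print(x[a, :, b].shape)   # torch.Size([3, 2, 3, 5])

# Boolean mask ([None, mask]): the selected count is data-dependent; mask.numel() == 12 is the
# upper bound this pass would report for that dimension (i.e. inferred sizes [2, 12, 5]).
mask = torch.zeros(3, 4, dtype=torch.bool)
mask[0, 1] = True
print(x[:, mask].shape)   # torch.Size([2, 1, 5])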


class HintBasedSymShapeEvalPass(PassBase):
"""
If we enable dynamic shape tracing, a tensor's shape may become a symbolic