Skip to content

Commit ebb752f

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
fix nonzero upper bound (#393)
Summary: Pull Request resolved: #393 If nonzero's input has `n` dimensions, then the resulting output tensor is of size `(z, n)`, where `z` is the total number of non-zero elements in the input tensor. Therefore, an upper bound for out's shape is [input.numel(), input.dim()] Reviewed By: ydwu4 Differential Revision: D49342476 fbshipit-source-id: ac70fd83090cc80c93315b9fc5be830afb6207df
1 parent 60ea5c6 commit ebb752f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

exir/passes/sym_shape_eval_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def inference_deco(f: Callable):
2929
@register_upper_bound_inference(exir_ops.edge.aten.nonzero.default)
3030
@register_upper_bound_inference(torch.ops.aten.nonzero.default)
3131
def nonzero(args, kwargs) -> List[Optional[int]]:
32-
return [eval_expr(args[0].shape[0]), len(args[0].shape)]
32+
return [eval_expr(args[0].numel()), len(args[0].shape)]
3333

3434

3535
class HintBasedSymShapeEvalPass(PassBase):

0 commit comments

Comments
 (0)