
Commit 070a03a

chore: Remove has_implicit_batch_dimension flag in get_axes_for_reduce_op()
1 parent 225d069 commit 070a03a

File tree

1 file changed: +1 −7 lines changed


py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 1 addition & 7 deletions
@@ -726,7 +726,6 @@ def broadcast(
 
 def get_axes_for_reduce_op(
     dim: Union[int, Sequence[int]],
-    has_implicit_batch_dimension: bool = False,
 ) -> int:
     """
     TensorRT reduce layer relies on the binary representation of axes to
@@ -736,8 +735,6 @@ def get_axes_for_reduce_op(
     Args:
         dim (Union[int, Sequence[int]]): An integer or a sequence of integers
             that will be used to generate axes for TensorRT.
-        has_implicit_batch_dimension (bool): Whether the TensorRT network is
-            using implicit batch dimension.
 
     Returns:
         An integer which binary form can be used as axes for TensorRT reduce
@@ -746,12 +743,9 @@ def get_axes_for_reduce_op(
     if isinstance(dim, int):
         dim = (dim,)
 
-    if has_implicit_batch_dimension:
-        assert 0 not in dim, "Can't reduce over batch dimension when it's implicit."
-
     axes = 0
     for d in dim:
-        axes |= 1 << (d - (1 if has_implicit_batch_dimension else 0))
+        axes |= 1 << d
 
     return axes

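For reference, a minimal, self-contained sketch of what the simplified helper computes after this change: every entry of dim sets the matching bit of the axes mask handed to a TensorRT reduce layer. This is an illustrative copy of the logic in the diff above, not the converter_utils module itself.

from typing import Sequence, Union

def get_axes_for_reduce_op(dim: Union[int, Sequence[int]]) -> int:
    # Build a bitmask where bit d is set for every dimension d to reduce over.
    if isinstance(dim, int):
        dim = (dim,)
    axes = 0
    for d in dim:
        axes |= 1 << d
    return axes

# Example: reducing over dimensions 1 and 3 gives 0b1010 == 10.
assert get_axes_for_reduce_op((1, 3)) == 0b1010
# With the implicit-batch offset removed, dimension 0 maps directly to bit 0.
assert get_axes_for_reduce_op(0) == 0b1

With the has_implicit_batch_dimension flag gone, there is no longer a one-bit offset or an assertion forbidding reduction over dimension 0; dimension d always maps to bit d of the mask.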