File tree Expand file tree Collapse file tree 1 file changed +1
-7
lines changed
py/torch_tensorrt/dynamo/conversion Expand file tree Collapse file tree 1 file changed +1
-7
lines changed Original file line number Diff line number Diff line change @@ -726,7 +726,6 @@ def broadcast(
726
726
727
727
def get_axes_for_reduce_op (
728
728
dim : Union [int , Sequence [int ]],
729
- has_implicit_batch_dimension : bool = False ,
730
729
) -> int :
731
730
"""
732
731
TensorRT reduce layer relies on the binary representation of axes to
@@ -736,8 +735,6 @@ def get_axes_for_reduce_op(
736
735
Args:
737
736
dim (Union[int, Sequence[int]]): An integer or a sequence of integers
738
737
that will be used to generate axes for TensorRT.
739
- has_implicit_batch_dimension (bool): Whether the TensorRT network is
740
- using implicit batch dimension.
741
738
742
739
Returns:
743
740
An integer which binary form can be used as axes for TensorRT reduce
@@ -746,12 +743,9 @@ def get_axes_for_reduce_op(
746
743
if isinstance (dim , int ):
747
744
dim = (dim ,)
748
745
749
- if has_implicit_batch_dimension :
750
- assert 0 not in dim , "Can't reduce over batch dimension when it's implicit."
751
-
752
746
axes = 0
753
747
for d in dim :
754
- axes |= 1 << ( d - ( 1 if has_implicit_batch_dimension else 0 ))
748
+ axes |= 1 << d
755
749
756
750
return axes
757
751
You can’t perform that action at this time.
0 commit comments