Commit 46cc402

fix: Remove references to implicit batch for TRT 10 (#2773)
1 parent 2d6ffb4 commit 46cc402
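
Context, as a hedged note rather than text from the commit: TensorRT 10 drops implicit-batch networks entirely, so ctx.net.has_implicit_batch_dimension can no longer be true and the rank/dim bookkeeping it required is dead code. All four converters below shed the same pattern. A minimal runnable sketch of the explicit-batch arithmetic, with get_positive_dim re-implemented here as a simplified stand-in for the library helper:

# With explicit batch there is no hidden leading axis: a converter's rank is
# just len(shape), and negative dims normalize against that rank directly.
def get_positive_dim(dim: int, rank: int) -> int:
    # simplified stand-in for torch_tensorrt's helper of the same name
    return dim % rank

shape = (8, 3, 224, 224)                        # batch axis is part of the shape
assert get_positive_dim(-1, len(shape)) == 3    # no "+1 if implicit batch" offset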

4 files changed (+9, -30 lines)

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 3 additions & 6 deletions
@@ -58,7 +58,7 @@ def batch_norm(
 
     # For BatchNorm1d, reshape 1d to 2d
     output_shape = input.shape
-    if not ctx.net.has_implicit_batch_dimension and len(input.shape) < 4:
+    if len(input.shape) < 4:
         assert (
             len(get_dynamic_dims(input.shape)) <= 1
         ), "BatchNorm1D with more than one dynamic dims is not currently supported."
@@ -75,7 +75,7 @@ def batch_norm(
     output = layer.get_output(0)
 
     # For BatchNorm1d, reshape output back to 1d
-    if not ctx.net.has_implicit_batch_dimension and len(output_shape) < 4:
+    if len(output_shape) < 4:
         output = impl.shuffle.reshape(
             ctx,
             target,
@@ -411,7 +411,7 @@ def softmax(
     input: TRTTensor,
     dim: Optional[Any] = None,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0)
+    input_ranks = len(input.shape)
 
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
@@ -433,9 +433,6 @@ def get_softmax_dim(ndim: int) -> int:
     dim = cast(int, dim)
 
     dim = get_positive_dim(dim, input_ranks)
-    if ctx.net.has_implicit_batch_dimension:
-        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
-        dim -= 1
 
     layer = ctx.net.add_softmax(input)
     layer.axes = 1 << dim
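
A small runnable sketch of why dim can feed layer.axes directly once ranks are explicit. The bit-shift encoding is TensorRT's axes bitmask (ISoftMaxLayer.axes); the shapes here are illustrative:

shape = (8, 3, 224, 224)
input_ranks = len(shape)        # no implicit-batch offset in TRT 10
dim = -1 % input_ranks          # get_positive_dim(-1, 4) -> 3
axes = 1 << dim                 # TensorRT encodes the softmax axis as a bitmask
assert axes == 0b1000           # bit 3 set: softmax over the last axis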

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 4 additions & 11 deletions
@@ -40,19 +40,12 @@ def select(
             "of the TensorRT region!"
         )
 
-    ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0)
+    ranks = len(input.shape)
     dim = get_positive_dim(cast(int, dim), ranks)
     dynamic_shape = has_dynamic_shape(input.shape)
-    if ctx.net.has_implicit_batch_dimension:
-        if dim == 0:
-            raise RuntimeError(
-                f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
-            )
-        dim = dim - 1
-    else:
-        if dynamic_shape:
-            # Check whether slice target dim is dynamic shape dim
-            assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
+    if dynamic_shape:
+        # Check whether slice target dim is dynamic shape dim
+        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
     index = index
 
     if index >= input.shape[dim]:
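
A hedged sketch of the one guard that survives: TensorRT reports dynamic dimensions as -1, and select refuses to index along one. has_dynamic_shape is paraphrased here, not the library's exact code:

def has_dynamic_shape(shape) -> bool:
    # TensorRT marks dynamic dimensions with -1
    return any(s == -1 for s in shape)

shape = (8, -1, 4)              # dynamic second axis
dim = 0                         # selecting along a static axis is fine
if has_dynamic_shape(shape):
    assert shape[dim] != -1, "Can't select on negative shape dimension!"
# dim = 1 would trip the assert, because shape[1] is the dynamic (-1) axis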

py/torch_tensorrt/dynamo/conversion/impl/squeeze.py

Lines changed: 1 addition & 4 deletions
@@ -32,11 +32,8 @@ def squeeze(
     for dim in dims:
         dim = get_positive_dim(
             dim,
-            len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0),
+            len(input.shape),
         )
-        if ctx.net.has_implicit_batch_dimension:
-            assert dim != 0, "We don't support squeeze batch dim when it's implicit."
-            dim -= 1
 
         assert input.shape[dim] != -1, "We don't support squeeze dynamic dim."
         assert (
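
Illustrative only (the size-1 check is assumed from squeeze semantics, not shown in full in this hunk): with the implicit-batch branch gone, a squeeze dim normalizes against the full rank and must point at a static, size-1 axis:

shape = (1, 3, 1, 5)
dim = -2 % len(shape)           # get_positive_dim(-2, 4) -> 2, no batch offset
assert shape[dim] != -1         # dynamic dims can't be squeezed
assert shape[dim] == 1          # only size-1 axes are removable
squeezed = tuple(s for i, s in enumerate(shape) if i != dim)
assert squeezed == (1, 3, 5)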

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

Lines changed: 1 addition & 9 deletions
@@ -29,17 +29,9 @@ def unsqueeze(
 
     dim = cast(int, dim)
 
-    input_shape_size = (
-        len(input_val.shape) + 1
-        if ctx.net.has_implicit_batch_dimension
-        else len(input_val.shape)
-    )
+    input_shape_size = len(input_val.shape)
     dim = get_positive_dim(dim, input_shape_size + 1)
 
-    if ctx.net.has_implicit_batch_dimension:
-        assert dim != 0
-        dim -= 1
-
     assert (
         len(get_dynamic_dims(input_val.shape)) <= 1
     ), "Currently we don't support unsqueeze with more than one dynamic dims."
