Skip to content

Commit ae7e6c8

Browse files
authored
fix: get_padded_shape_tensors can now handle dynamic pads (#3123)
1 parent d75f588 commit ae7e6c8

File tree

1 file changed

+22
-11
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+22
-11
lines changed

py/torch_tensorrt/dynamo/conversion/impl/pad.py

Lines changed: 22 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ def get_padded_shape_tensors(
2626
source_ir: Optional[SourceIR],
2727
name: str,
2828
input: TRTTensor,
29-
pad: Sequence[int],
29+
pad: Sequence[Union[int, TRTTensor]],
3030
) -> TRTTensor:
3131
rank = len(input.shape)
3232
if len(pad) // 2 > rank:
@@ -47,11 +47,11 @@ def get_padded_shape_tensors(
4747
start_list = [0] * rank
4848
for i in range(len(pad) // 2):
4949
dim_index = rank - (i + 1)
50-
pad_before = pad[i * 2]
51-
pad_after = pad[i * 2 + 1]
50+
pad_before = get_trt_tensor(ctx, pad[i * 2], f"{name}_pad_before_{i}")
51+
pad_after = get_trt_tensor(ctx, pad[i * 2 + 1], f"{name}_pad_after_{i}")
5252

53-
pad_sum = get_trt_tensor(
54-
ctx, pad_before + pad_after, f"{name}_pad_sum_{i}", dtype=np.int32
53+
pad_sum = impl.elementwise.add(
54+
ctx, target, source_ir, f"{name}_pad_sum_{i}", pad_before, pad_after
5555
)
5656
dim_shape = ctx.net.add_slice(
5757
input_shape_tensor,
@@ -63,7 +63,9 @@ def get_padded_shape_tensors(
6363
new_dim_shape = impl.elementwise.add(
6464
ctx, target, source_ir, f"{name}_shape_dim_{i}", dim_shape, pad_sum
6565
)
66-
start_list[dim_index] = -pad_before
66+
start_list[dim_index] = impl.elementwise.sub(
67+
ctx, target, source_ir, f"{name}_pad_before_neg_{i}", 0, pad_before
68+
)
6769

6870
slices = []
6971
for j in range(rank):
@@ -79,14 +81,23 @@ def get_padded_shape_tensors(
7981
).get_output(0)
8082
)
8183
padded_shape_tensor = impl.cat.cat(
82-
ctx, target, source_ir, f"{name}_cat_dim_{i}", slices, 0
84+
ctx,
85+
target,
86+
source_ir,
87+
f"{name}_cat_dim_{i}",
88+
slices,
89+
0,
90+
cast_dtype=padded_shape_tensor.dtype,
8391
)
8492

85-
start_indices_tensor = get_trt_tensor(
93+
start_indices_tensor = impl.cat.cat(
8694
ctx,
87-
np.array(start_list, dtype=np.int32),
95+
target,
96+
source_ir,
8897
f"{name}_start_indices_tensor",
89-
dtype=np.int32,
98+
start_list,
99+
0,
100+
cast_dtype=padded_shape_tensor.dtype,
90101
)
91102

92103
return start_indices_tensor, padded_shape_tensor
@@ -98,7 +109,7 @@ def constant_padNd(
98109
source_ir: Optional[SourceIR],
99110
name: str,
100111
input: TRTTensor,
101-
pad: Sequence[int],
112+
pad: Sequence[Union[int, TRTTensor]],
102113
value: Union[int, float] = 0,
103114
) -> TRTTensor:
104115

0 commit comments

Comments (0)