Skip to content

Commit f72e842

Browse files
committed
fix: Handle dynamic shapes in where ops
1 parent c21866b commit f72e842

File tree

1 file changed

+17
-62
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/condition

1 file changed

+17
-62
lines changed

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 17 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Optional, Union
22

33
import numpy as np
4-
import tensorrt as trt
54
import torch
65
from torch.fx.node import Target
76
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -10,8 +9,7 @@
109
broadcastable,
1110
get_trt_tensor,
1211
)
13-
from torch_tensorrt.dynamo.conversion.impl.slice import expand
14-
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
12+
from torch_tensorrt.fx.converters.converter_utils import prepend_ones, set_layer_name
1513
from torch_tensorrt.fx.types import TRTTensor
1614

1715

@@ -30,73 +28,30 @@ def where(
3028
x_shape = list(input.shape)
3129
y_shape = list(other.shape)
3230
condition_shape = list(condition.shape)
31+
max_shape_len = max(len(x_shape), len(y_shape), len(condition_shape))
3332

34-
output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
35-
36-
# expand shape
3733
if not isinstance(condition, TRTTensor):
3834
assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
39-
if condition_shape != output_shape:
40-
condition = (
41-
condition.expand(output_shape)
42-
if isinstance(condition, torch.Tensor)
43-
else np.broadcast_to(condition, output_shape)
44-
)
45-
condition_val = get_trt_tensor(ctx, condition, f"{name}_condition")
46-
else:
47-
assert condition.dtype == trt.bool, "mask dtype is not bool!"
48-
if condition_shape != output_shape:
49-
condition_val = expand(
50-
ctx, target, source_ir, f"{name}_expand", condition, output_shape
51-
)
52-
else:
53-
condition_val = condition
35+
condition = get_trt_tensor(ctx, condition, f"{name}_condition")
36+
diff = max_shape_len - len(condition_shape)
37+
if diff > 0:
38+
condition = prepend_ones(
39+
ctx.net, condition, f"{name}_condition_broadcast", diff
40+
)
5441

5542
if not isinstance(input, TRTTensor):
56-
if x_shape != output_shape:
57-
# special case where 1 element in input
58-
if len(input.shape) == 0:
59-
input = (
60-
input.unsqueeze(0)
61-
if isinstance(input, torch.Tensor)
62-
else np.expand_dims(input, axis=0)
63-
)
64-
input = (
65-
input.expand(output_shape)
66-
if isinstance(input, torch.Tensor)
67-
else np.broadcast_to(input, output_shape)
68-
)
69-
x_val = get_trt_tensor(ctx, input, f"{name}_x")
70-
else:
71-
x_val = input
72-
if x_shape != output_shape:
73-
x_val = expand(
74-
ctx, target, source_ir, f"{name}_x_expand", input, output_shape
75-
)
43+
input = get_trt_tensor(ctx, input, f"{name}_x")
44+
diff = max_shape_len - len(x_shape)
45+
if diff > 0:
46+
input = prepend_ones(ctx.net, input, f"{name}_input_broadcast", diff)
7647

7748
if not isinstance(other, TRTTensor):
78-
if y_shape != output_shape:
79-
# special case where 1 element in other
80-
if len(other.shape) == 0:
81-
other = (
82-
other.unsqueeze(0)
83-
if isinstance(other, torch.Tensor)
84-
else np.expand_dims(other, axis=0)
85-
)
86-
other = (
87-
other.expand(output_shape)
88-
if isinstance(other, torch.Tensor)
89-
else np.broadcast_to(other, output_shape)
90-
)
91-
y_val = get_trt_tensor(ctx, other, f"{name}_y")
92-
else:
93-
y_val = other
94-
if y_shape != output_shape:
95-
y_val = expand(
96-
ctx, target, source_ir, f"{name}_y_expand", y_val, output_shape
97-
)
49+
other = get_trt_tensor(ctx, other, f"{name}_y")
50+
diff = max_shape_len - len(y_shape)
51+
if diff > 0:
52+
other = prepend_ones(ctx.net, other, f"{name}_other_broadcast", diff)
9853

99-
return select(ctx, target, source_ir, name, x_val, y_val, condition_val)
54+
return select(ctx, target, source_ir, name, input, other, condition)
10055

10156

10257
def select(

0 commit comments

Comments
 (0)