Skip to content

Commit ac4bf90

Browse files
committed
chore: rebase and updates
1 parent e37c091 commit ac4bf90

File tree

3 files changed

+18
-27
lines changed

3 files changed

+18
-27
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
ConverterRegistry,
1616
DynamoConverterImplSignature,
1717
)
18-
from torch_tensorrt.fx.converters.converter_utils import (
18+
from torch_tensorrt.fx.converters.converter_utils import ( # noqa: F401
1919
broadcast,
2020
get_axes_for_reduce_op,
21+
prepend_ones,
22+
set_layer_name,
2123
)
2224
from torch_tensorrt.fx.types import TRTDataType, TRTTensor
2325

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from torch_tensorrt.dynamo.conversion.converter_utils import (
99
broadcastable,
1010
get_trt_tensor,
11+
prepend_ones,
12+
set_layer_name,
1113
)
12-
from torch_tensorrt.fx.converters.converter_utils import prepend_ones, set_layer_name
1314
from torch_tensorrt.fx.types import TRTTensor
1415

1516

tests/py/dynamo/conversion/test_where_aten.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -77,59 +77,44 @@ def forward(self, condition):
7777
(condition,),
7878
)
7979

80-
# shape, min shape range, opt shape range, max shape range
80+
# min/opt/max shape for condition/x/y input
8181
@parameterized.expand(
8282
[
8383
(
8484
"3d_condition_3d_xshape_3d_yshape",
85-
(-1, -1, -1),
8685
(1, 1, 1),
8786
(1, 2, 3),
8887
(3, 3, 3),
89-
(-1, -1, -1),
9088
(1, 1, 1),
9189
(1, 2, 3),
9290
(3, 3, 3),
93-
(-1, -1, -1),
9491
(1, 1, 1),
9592
(1, 2, 3),
9693
(3, 3, 3),
9794
),
9895
(
9996
"1d_condition_3d_xshape_2d_yshape",
100-
(-1),
10197
(1,),
10298
(2,),
10399
(4,),
104-
(-1, -1, -1),
105100
(1, 1, 1),
106101
(3, 2, 2),
107102
(3, 2, 4),
108-
(
109-
-1,
110-
-1,
111-
),
112103
(1, 1),
113104
(2, 2),
114105
(2, 4),
115106
),
116107
(
117108
"2d_condition_3d_xshape_2d_yshape",
118-
(-1, -1),
119109
(4, 1),
120110
(4, 2),
121111
(5, 4),
122-
(-1, -1, -1),
123112
(1, 1, 1),
124113
(3, 1, 2),
125114
(3, 1, 4),
126-
(
127-
-1,
128-
-1,
129-
),
130-
(4, 1),
131-
(4, 2),
132-
(5, 4),
115+
(1, 1),
116+
(1, 2),
117+
(1, 4),
133118
),
134119
]
135120
)
@@ -140,19 +125,22 @@ def forward(self, condition, x, y):
140125

141126
input_specs = [
142127
Input(
143-
shape=args[1],
128+
min_shape=args[1],
129+
opt_shape=args[2],
130+
max_shape=args[3],
144131
dtype=torch.bool,
145-
shape_ranges=[(args[2], args[3], args[4])],
146132
),
147133
Input(
148-
shape=args[5],
134+
min_shape=args[4],
135+
opt_shape=args[5],
136+
max_shape=args[6],
149137
dtype=torch.float32,
150-
shape_ranges=[(args[6], args[7], args[8])],
151138
),
152139
Input(
153-
shape=args[9],
140+
min_shape=args[7],
141+
opt_shape=args[8],
142+
max_shape=args[9],
154143
dtype=torch.float32,
155-
shape_ranges=[(args[10], args[11], args[12])],
156144
),
157145
]
158146
self.run_test_with_dynamic_shape(Where(), input_specs)

0 commit comments

Comments
 (0)