Skip to content

Commit 8c0deaa

Browse files
authored
Merge pull request #1837 from gs-olive/_native_batch_norm_legit_no_training_fix
fix: Remove references to `_native_batch_norm_legit_no_training` for PyTorch 2.0 Stable
2 parents 5fdb1c9 + 7457a30 commit 8c0deaa

File tree

1 file changed

+0
-11
lines changed

1 file changed

+0
-11
lines changed

py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.Graph
165165
torch.ops.aten.max_pool3d_with_indices.default,
166166
torch.ops.aten.native_batch_norm.default,
167167
torch.ops.aten._native_batch_norm_legit.default,
168-
torch.ops.aten._native_batch_norm_legit_no_training.default,
169168
):
170169
modified = True
171170
if len(n.users) != 1:
@@ -186,16 +185,6 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.Graph
186185
new_args = list(n.args)
187186
new_args.append(False)
188187
new_args = tuple(new_args)
189-
elif (
190-
n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
191-
):
192-
new_op = torch.ops.aten.batch_norm
193-
new_args = list(n.args)
194-
new_args.append(False)
195-
# _native_batch_norm_legit_no_training doesn't take in a training arg (assumed to be false)
196-
# but batchnorm takes in a training arg at position 5.
197-
new_args.insert(5, False)
198-
new_args = tuple(new_args)
199188

200189
getitem_node = next(iter(n.users))
201190
with module.graph.inserting_after(getitem_node):

0 commit comments

Comments
 (0)