Skip to content

Commit d07a815

Browse files
committed
fix: Repair version checking system for Torch
- Address version parsing issue for NV versions of Torch - Add specialized check for NV Torch versions such as `2.0.0.nv23.05`
1 parent d4c9c06 commit d07a815

File tree

6 files changed

+35
-7
lines changed

6 files changed

+35
-7
lines changed

py/torch_tensorrt/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,11 @@ def _find_lib(name, paths):
9494

9595
from torch_tensorrt import fx
9696

97-
if version.parse(torch.__version__) >= version.parse("2.1.dev"):
97+
if version.parse(
98+
torch.__version__
99+
if ".nv" not in torch.__version__
100+
else torch.__version__.split(".nv")[0]
101+
) >= version.parse("2.1.dev"):
98102
from torch_tensorrt import dynamo
99103
from torch_tensorrt.dynamo import backend
100104

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import torch
22
from packaging import version
33

4-
if version.parse(torch.__version__) >= version.parse("2.1.dev"):
4+
if version.parse(
5+
torch.__version__
6+
if ".nv" not in torch.__version__
7+
else torch.__version__.split(".nv")[0]
8+
) >= version.parse("2.1.dev"):
59
from torch_tensorrt.dynamo import fx_ts_compat
610
from .backend import compile

py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ def forward(self, x, y):
4343
%reshape : [num_users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
4444
return reshape
4545
"""
46-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
46+
if version.parse(
47+
torch.__version__
48+
if ".nv" not in torch.__version__
49+
else torch.__version__.split(".nv")[0]
50+
) < version.parse("2.1.0.dev20230620"):
4751
expected_graph = expected_graph.replace("num_users", "#users")
4852

4953
assert (

py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ def is_leaf_module(self, m, qn):
5757
return add
5858
""".strip()
5959

60-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
60+
if version.parse(
61+
torch.__version__
62+
if ".nv" not in torch.__version__
63+
else torch.__version__.split(".nv")[0]
64+
) < version.parse("2.1.0.dev20230620"):
6165
ttop_graph_expected = ttop_graph_expected.replace("num_users", "#users")
6266

6367
assert (
@@ -71,7 +75,11 @@ def is_leaf_module(self, m, qn):
7175
return (x,)
7276
""".strip()
7377

74-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
78+
if version.parse(
79+
torch.__version__
80+
if ".nv" not in torch.__version__
81+
else torch.__version__.split(".nv")[0]
82+
) < version.parse("2.1.0.dev20230620"):
7583
ttop_a_graph_expected = ttop_a_graph_expected.replace("num_users", "#users")
7684

7785
assert (

py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
import torch
88

9-
if version.parse(torch.__version__) >= version.parse("2.dev"):
9+
if version.parse(
10+
torch.__version__
11+
if ".nv" not in torch.__version__
12+
else torch.__version__.split(".nv")[0]
13+
) >= version.parse("2.dev"):
1014
import torch._dynamo as torchdynamo
1115

1216
from torch.fx.passes.infra.pass_base import PassResult

py/torch_tensorrt/fx/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,11 @@ def nested_decorator(f: Callable):
160160
def function_wrapper(*args, **kwargs):
161161
# Parse minimum and current Torch versions
162162
min_version = version.parse(min_torch_version)
163-
current_version = version.parse(torch.__version__)
163+
current_version = version.parse(
164+
torch.__version__
165+
if ".nv" not in torch.__version__
166+
else torch.__version__.split(".nv")[0]
167+
)
164168

165169
if current_version < min_version:
166170
raise AssertionError(

0 commit comments

Comments
 (0)