Skip to content

Commit a8df33a

Browse files
committed
fix: Repair version checking system for Torch
- Address version parsing issue for NV versions of Torch - Add specialized check for NV Torch versions such as `2.0.0.nv23.05`
1 parent d4c9c06 commit a8df33a

File tree

6 files changed

+20
-9
lines changed

6 files changed

+20
-9
lines changed

py/torch_tensorrt/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@ def _find_lib(name, paths):
9393
from torch_tensorrt._Device import Device
9494

9595
from torch_tensorrt import fx
96+
from torch_tensorrt._version import sanitized_torch_version
9697

97-
if version.parse(torch.__version__) >= version.parse("2.1.dev"):
98+
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
9899
from torch_tensorrt import dynamo
99100
from torch_tensorrt.dynamo import backend
100101

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import torch
21
from packaging import version
2+
from torch_tensorrt._version import sanitized_torch_version
33

4-
if version.parse(torch.__version__) >= version.parse("2.1.dev"):
4+
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
55
from torch_tensorrt.dynamo import fx_ts_compat
66
from .backend import compile

py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch.testing._internal.common_utils import run_tests, TestCase
1212
from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim
1313
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
14+
from torch_tensorrt._version import sanitized_torch_version
1415

1516
_LOGGER = logging.getLogger(__name__)
1617

@@ -43,7 +44,9 @@ def forward(self, x, y):
4344
%reshape : [num_users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
4445
return reshape
4546
"""
46-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
47+
if version.parse(sanitized_torch_version()) < version.parse(
48+
"2.1.0.dev20230620"
49+
):
4750
expected_graph = expected_graph.replace("num_users", "#users")
4851

4952
assert (

py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import torch.nn as nn
99

1010
import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup
11+
from torch_tensorrt._version import sanitized_torch_version
12+
1113
from torch.testing._internal.common_utils import run_tests, TestCase
1214

1315
_LOGGER = logging.getLogger(__name__)
@@ -57,7 +59,9 @@ def is_leaf_module(self, m, qn):
5759
return add
5860
""".strip()
5961

60-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
62+
if version.parse(sanitized_torch_version()) < version.parse(
63+
"2.1.0.dev20230620"
64+
):
6165
ttop_graph_expected = ttop_graph_expected.replace("num_users", "#users")
6266

6367
assert (
@@ -71,7 +75,9 @@ def is_leaf_module(self, m, qn):
7175
return (x,)
7276
""".strip()
7377

74-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
78+
if version.parse(sanitized_torch_version()) < version.parse(
79+
"2.1.0.dev20230620"
80+
):
7581
ttop_a_graph_expected = ttop_a_graph_expected.replace("num_users", "#users")
7682

7783
assert (

py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from contextlib import contextmanager
44
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
55
from packaging import version
6+
from torch_tensorrt._version import sanitized_torch_version
67

78
import torch
89

9-
if version.parse(torch.__version__) >= version.parse("2.dev"):
10+
if version.parse(sanitized_torch_version()) >= version.parse("2.dev"):
1011
import torch._dynamo as torchdynamo
1112

1213
from torch.fx.passes.infra.pass_base import PassResult

py/torch_tensorrt/fx/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
replace_op_with_indices,
1313
run_const_fold,
1414
)
15-
15+
from torch_tensorrt._version import sanitized_torch_version
1616
from .types import Shape, TRTDataType
1717

1818

@@ -160,7 +160,7 @@ def nested_decorator(f: Callable):
160160
def function_wrapper(*args, **kwargs):
161161
# Parse minimum and current Torch versions
162162
min_version = version.parse(min_torch_version)
163-
current_version = version.parse(torch.__version__)
163+
current_version = version.parse(sanitized_torch_version())
164164

165165
if current_version < min_version:
166166
raise AssertionError(

0 commit comments

Comments
 (0)