Skip to content

Commit e884820

Browse files
authored
fix: Repair version checking system for Torch (#2118)
1 parent 230cf6a commit e884820

File tree

7 files changed

+27
-9
lines changed

7 files changed

+27
-9
lines changed

py/torch_tensorrt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _find_lib(name, paths):
9494

9595
from torch_tensorrt import fx
9696

97-
if version.parse(torch.__version__) >= version.parse("2.1.dev"):
97+
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
9898
from torch_tensorrt import dynamo
9999
from torch_tensorrt.dynamo import backend
100100

py/torch_tensorrt/_util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ def get_build_info() -> str:
3030

3131
def set_device(gpu_id):
3232
_C.set_device(gpu_id)
33+
34+
35+
def sanitized_torch_version() -> str:
36+
return (
37+
torch.__version__
38+
if ".nv" not in torch.__version__
39+
else torch.__version__.split(".nv")[0]
40+
)

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import torch
21
from packaging import version
2+
from torch_tensorrt._util import sanitized_torch_version
33

4-
if version.parse(torch.__version__) >= version.parse("2.1.dev"):
4+
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
55
from torch_tensorrt.dynamo import fx_ts_compat
66
from .backend import compile

py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch.testing._internal.common_utils import run_tests, TestCase
1212
from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim
1313
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
14+
from torch_tensorrt._util import sanitized_torch_version
1415

1516
_LOGGER = logging.getLogger(__name__)
1617

@@ -43,7 +44,9 @@ def forward(self, x, y):
4344
%reshape : [num_users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
4445
return reshape
4546
"""
46-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
47+
if version.parse(sanitized_torch_version()) < version.parse(
48+
"2.1.0.dev20230620"
49+
):
4750
expected_graph = expected_graph.replace("num_users", "#users")
4851

4952
assert (

py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import torch.nn as nn
99

1010
import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup
11+
from torch_tensorrt._util import sanitized_torch_version
12+
1113
from torch.testing._internal.common_utils import run_tests, TestCase
1214

1315
_LOGGER = logging.getLogger(__name__)
@@ -57,7 +59,9 @@ def is_leaf_module(self, m, qn):
5759
return add
5860
""".strip()
5961

60-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
62+
if version.parse(sanitized_torch_version()) < version.parse(
63+
"2.1.0.dev20230620"
64+
):
6165
ttop_graph_expected = ttop_graph_expected.replace("num_users", "#users")
6266

6367
assert (
@@ -71,7 +75,9 @@ def is_leaf_module(self, m, qn):
7175
return (x,)
7276
""".strip()
7377

74-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
78+
if version.parse(sanitized_torch_version()) < version.parse(
79+
"2.1.0.dev20230620"
80+
):
7581
ttop_a_graph_expected = ttop_a_graph_expected.replace("num_users", "#users")
7682

7783
assert (

py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from contextlib import contextmanager
44
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
55
from packaging import version
6+
from torch_tensorrt._util import sanitized_torch_version
67

78
import torch
89

9-
if version.parse(torch.__version__) >= version.parse("2.dev"):
10+
if version.parse(sanitized_torch_version()) >= version.parse("2.dev"):
1011
import torch._dynamo as torchdynamo
1112

1213
from torch.fx.passes.infra.pass_base import PassResult

py/torch_tensorrt/fx/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
replace_op_with_indices,
1313
run_const_fold,
1414
)
15-
15+
from torch_tensorrt._util import sanitized_torch_version
1616
from .types import Shape, TRTDataType
1717

1818

@@ -160,7 +160,7 @@ def nested_decorator(f: Callable):
160160
def function_wrapper(*args, **kwargs):
161161
# Parse minimum and current Torch versions
162162
min_version = version.parse(min_torch_version)
163-
current_version = version.parse(torch.__version__)
163+
current_version = version.parse(sanitized_torch_version())
164164

165165
if current_version < min_version:
166166
raise AssertionError(

0 commit comments

Comments
 (0)