Skip to content

Commit 7cc5eb7

Browse files
committed
chore: Address test failures
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 2ae5aaa commit 7cc5eb7

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1-
from ._settings import *
2-
from .compile import compile
3-
from .aten_tracer import trace
1+
from packaging import version
2+
from torch_tensorrt._util import sanitized_torch_version
3+
4+
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
5+
from ._settings import *
6+
from .compile import compile
7+
from .aten_tracer import trace

tests/py/ts/models/test_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def test_resnet18(self):
2525
"gpu_id": 0,
2626
},
2727
"enabled_precisions": {torch.float},
28+
"ir": "ts",
2829
}
2930

3031
trt_mod = torchtrt.compile(self.model, **compile_spec)
@@ -49,6 +50,7 @@ def test_mobilenet_v2(self):
4950
"gpu_id": 0,
5051
},
5152
"enabled_precisions": {torch.float},
53+
"ir": "ts",
5254
}
5355

5456
trt_mod = torchtrt.compile(self.model, **compile_spec)
@@ -75,6 +77,7 @@ def test_efficientnet_b0(self):
7577
"gpu_id": 0,
7678
},
7779
"enabled_precisions": {torch.float},
80+
"ir": "ts",
7881
}
7982

8083
trt_mod = torchtrt.compile(self.model, **compile_spec)
@@ -107,6 +110,7 @@ def test_bert_base_uncased(self):
107110
},
108111
"enabled_precisions": {torch.float},
109112
"truncate_long_and_double": True,
113+
"ir": "ts",
110114
}
111115
with torchtrt.logging.errors():
112116
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -137,6 +141,7 @@ def test_resnet18_half(self):
137141
"gpu_id": 0,
138142
},
139143
"enabled_precisions": {torch.half},
144+
"ir": "ts",
140145
}
141146

142147
trt_mod = torchtrt.compile(self.scripted_model, **compile_spec)

tests/py/ts/models/test_multiple_registered_engines.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_multiple_engines(self):
2727
"gpu_id": 0,
2828
},
2929
"enabled_precisions": {torch.float},
30+
"ir": "ts",
3031
}
3132
rn18_trt_mod = torchtrt.compile(self.resnet18, **compile_spec)
3233
rn50_trt_mod = torchtrt.compile(self.resnet50, **compile_spec)

0 commit comments

Comments
 (0)