Skip to content

Commit 2ae5aaa

Browse files
committed
chore: address review comments
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 366cd31 commit 2ae5aaa

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

py/torch_tensorrt/dynamo/compile.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch_tensorrt import EngineCapability, Device
99
from torch.fx.passes.pass_manager import PassManager
1010
from torch.fx.passes.shape_prop import ShapeProp
11+
from torch.fx.passes.splitter_base import SplitResult
1112
from torch_tensorrt.dynamo.aten_tracer import trace
1213
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
1314
from torch_tensorrt.dynamo.lowering import (
@@ -17,6 +18,7 @@
1718
from torch_tensorrt.dynamo import CompilationSettings
1819
from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
1920
from torch_tensorrt.dynamo.backend import torch_tensorrt_backend
21+
from torch_tensorrt.dynamo.backend.backends import _compile_module
2022
from torch_tensorrt.dynamo.conversion import convert_module
2123

2224
from torch_tensorrt.dynamo._defaults import (
@@ -124,7 +126,7 @@ def compile(
124126

125127

126128
def _compile_graph(
127-
split_result: TRTSplitter,
129+
split_result: SplitResult,
128130
inputs: Any,
129131
settings: CompilationSettings = CompilationSettings(),
130132
**kwargs,

tests/py/ts/api/test_module_fallback.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def test_fallback_resnet18(self):
2323
},
2424
"enabled_precisions": {torch.float},
2525
"torch_executed_modules": ["torchvision.models.resnet.BasicBlock"],
26+
"ir": "ts",
2627
}
2728
trt_mod = torchtrt.compile(self.model, **compile_spec)
2829
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
@@ -49,6 +50,7 @@ def test_fallback_mobilenet_v2(self):
4950
"torchvision.models.mobilenetv2.ConvBNActivation"
5051
],
5152
"min_block_size": 5,
53+
"ir": "ts",
5254
}
5355
trt_mod = torchtrt.compile(self.model, **compile_spec)
5456
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))

tests/py/ts/api/test_operator_fallback.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def test_fallback_resnet18(self):
2323
},
2424
"enabled_precisions": {torch.float},
2525
"torch_executed_ops": ["aten::add"],
26+
"ir": "ts",
2627
}
2728
trt_mod = torchtrt.compile(self.model, **compile_spec)
2829
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
@@ -49,6 +50,7 @@ def test_fallback_resnet18_with_tensor_domain(self):
4950
},
5051
"enabled_precisions": {torch.float},
5152
"torch_executed_ops": ["aten::add"],
53+
"ir": "ts",
5254
}
5355
trt_mod = torchtrt.compile(self.model, **compile_spec)
5456
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
@@ -72,6 +74,7 @@ def test_fallback_mobilenet_v2(self):
7274
},
7375
"enabled_precisions": {torch.float},
7476
"torch_executed_ops": ["aten::hardtanh"],
77+
"ir": "ts",
7578
}
7679
trt_mod = torchtrt.compile(self.model, **compile_spec)
7780
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))

0 commit comments

Comments
 (0)