
Commit 4f7a549

Merge pull request #816 from cyfwry/810
Fix the bug of incorrect model type identification
2 parents b652045 + c055c3c commit 4f7a549
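
The bug being fixed: the old _module_ir helper decided whether a module could take the TorchScript path by testing isinstance(module, torch.nn.Module), but torch.jit.ScriptModule and torch.fx.GraphModule are themselves subclasses of torch.nn.Module, so already-scripted or FX-traced modules could be mis-identified. The commit replaces that check with an explicit _ModuleType classification that tests the most specific types first. A minimal sketch (not part of the commit) of the subclass relationship that tripped up the old check:

    import torch
    import torch.nn as nn

    class Add(nn.Module):
        def forward(self, x):
            return x + 1

    scripted = torch.jit.script(Add())          # a torch.jit.ScriptModule
    fx_traced = torch.fx.symbolic_trace(Add())  # a torch.fx.GraphModule

    # Both specialised module types still pass a plain nn.Module check,
    # so isinstance(module, nn.Module) alone cannot tell the three apart.
    print(isinstance(scripted, nn.Module))               # True
    print(isinstance(fx_traced, nn.Module))               # True
    print(isinstance(scripted, torch.jit.ScriptModule))   # True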

2 files changed: 43 additions, 12 deletions
py/torch_tensorrt/_compile.py

Lines changed: 30 additions & 12 deletions
@@ -14,19 +14,36 @@ class _IRType(Enum):
     fx = 1
 
 
-def _module_ir(module: Any, ir: str) -> _IRType.ts:
-    # Possible module types
-    module_is_tsable = any(
-        isinstance(module, t) for t in [torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction])
-    module_is_fxable = any(isinstance(module, t) for t in [torch.nn.Module, torch.fx.GraphModule])
+class _ModuleType(Enum):
+    """Enum to identify the type of module provided to Torch-TensorRT
+    """
+    nn = 0
+    ts = 1
+    fx = 2
+
+
+def _parse_module_type(module: Any) -> _ModuleType:
+    if any(isinstance(module, t) for t in [torch.jit.ScriptModule, torch.jit.ScriptFunction]):
+        return _ModuleType.ts
+    elif isinstance(module, torch.fx.GraphModule):
+        return _ModuleType.fx
+    elif isinstance(module, torch.nn.Module):
+        return _ModuleType.nn
+    else:
+        raise RuntimeError("Module is an unknown format")
+
+
+def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
+    module_is_tsable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.ts]])
+    module_is_fxable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.fx]])
 
     ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
     ir_targets_fx = ir == "fx"
 
     if module_is_tsable and ir_targets_torchscript:
         return _IRType.ts
     elif module_is_fxable and ir_targets_fx:
-        if isinstance(module, torch.fx.GraphModule):
+        if module_type == _ModuleType.fx:
             raise ValueError("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT")
         elif ir_targets_fx:
             raise ValueError("Preferred ir was set to \"fx\" which is currently not supported by Torch-TensorRT")
@@ -85,10 +102,11 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
     Returns:
         torch.nn.Module: Compiled Module, when run it will execute via TensorRT
     """
-    target_ir = _module_ir(module, ir)
+    module_type = _parse_module_type(module)
+    target_ir = _get_target_ir(module_type, ir)
     if target_ir == _IRType.ts:
         ts_mod = module
-        if isinstance(module, torch.nn.Module):
+        if module_type == _ModuleType.nn:
             logging.log(
                 logging.Level.Info,
                 "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
@@ -134,14 +152,14 @@ def convert_method_to_trt_engine(module: Any,
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
         **kwargs: Additional settings for the specific requested strategy (See submodules for more info)
-
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
-    target_ir = _module_ir(module, ir)
+    module_type = _parse_module_type(module)
+    target_ir = _get_target_ir(module_type, ir)
     if target_ir == _IRType.ts:
         ts_mod = module
-        if isinstance(module, torch.nn.Module):
+        if module_type == _ModuleType.nn:
             logging.log(
                 logging.Level.Info,
                 "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
@@ -155,4 +173,4 @@ def convert_method_to_trt_engine(module: Any,
     elif target_ir == _IRType.fx:
         raise RuntimeError("fx is currently not supported")
     else:
-        raise RuntimeError("Module is an unknown format or the ir requested is unknown")
+        raise RuntimeError("Module is an unknown format or the ir requested is unknown")

tests/py/test_api.py

Lines changed: 13 additions & 0 deletions
@@ -506,6 +506,18 @@ def test_dynamic_shape(self):
             self.assertTrue(self._verify_correctness(i, target))
 
 
+class TestModule(unittest.TestCase):
+
+    def test_module_type(self):
+        nn_module = models.alexnet(pretrained=True).eval().to("cuda")
+        ts_module = torch.jit.trace(nn_module, torch.ones([1, 3, 224, 224]).to("cuda"))
+        fx_module = torch.fx.symbolic_trace(nn_module)
+
+        self.assertEqual(torchtrt._compile._parse_module_type(nn_module), torchtrt._compile._ModuleType.nn)
+        self.assertEqual(torchtrt._compile._parse_module_type(ts_module), torchtrt._compile._ModuleType.ts)
+        self.assertEqual(torchtrt._compile._parse_module_type(fx_module), torchtrt._compile._ModuleType.fx)
+
+
 def test_suite():
     suite = unittest.TestSuite()
     suite.addTest(unittest.makeSuite(TestLoggingAPIs))
@@ -527,6 +539,7 @@ def test_suite():
     suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
     suite.addTest(unittest.makeSuite(TestDevice))
     suite.addTest(unittest.makeSuite(TestInput))
+    suite.addTest(unittest.makeSuite(TestModule))
 
     return suite
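
The new TestModule case can also be run on its own with the standard unittest runner. A hypothetical invocation sketch (file path and import name assumed from the diff; the test itself needs a CUDA device and downloads the torchvision AlexNet weights):

    import unittest
    from test_api import TestModule  # assumes the working directory is tests/py

    # Load only the new module-type test and run it with verbose output.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestModule)
    unittest.TextTestRunner(verbosity=2).run(suite)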
