Skip to content

Commit 6814350

Browse files
authored
feat: Add ExportedProgram as an IR (#2191)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 08a2ee4 commit 6814350

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import torch.fx
99
import torch_tensorrt.ts
10+
from torch._export import ExportedProgram
1011
from torch_tensorrt._enums import dtype
1112
from torch_tensorrt._Input import Input
1213
from torch_tensorrt.dynamo.compile import compile as dynamo_compile
@@ -43,6 +44,7 @@ class _IRType(Enum):
4344
fx = 1
4445
dynamo = 2
4546
torch_compile = 3
47+
exported_program = 4
4648

4749

4850
class _ModuleType(Enum):
@@ -51,6 +53,7 @@ class _ModuleType(Enum):
5153
nn = 0
5254
ts = 1
5355
fx = 2
56+
ep = 3
5457

5558

5659
def _parse_module_type(module: Any) -> _ModuleType:
@@ -61,6 +64,8 @@ def _parse_module_type(module: Any) -> _ModuleType:
6164
return _ModuleType.ts
6265
elif isinstance(module, torch.fx.GraphModule):
6366
return _ModuleType.fx
67+
elif isinstance(module, ExportedProgram):
68+
return _ModuleType.ep
6469
elif isinstance(module, torch.nn.Module):
6570
return _ModuleType.nn
6671
else:
@@ -70,6 +75,7 @@ def _parse_module_type(module: Any) -> _ModuleType:
7075
def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
7176
module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts])
7277
module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx])
78+
module_is_exportable = module_type == _ModuleType.ep
7379

7480
ir_targets_torchscript = any(ir == opt for opt in ["torchscript", "ts"])
7581
ir_targets_fx = ir == "fx"
@@ -95,8 +101,16 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
95101
"Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript"
96102
)
97103
return _IRType.ts
104+
elif module_is_exportable:
105+
raise ValueError(
106+
"Input graph is an ExportedProgram which is not currently supported. Please provide torch.nn.Module or torch.fx.GraphModule as input."
107+
)
98108
else:
99109
raise ValueError("Module was provided in an unsupported format")
110+
elif ir == "exported_program":
111+
raise ValueError(
112+
"ir=exported_program is not currently supported. Supported ir options : ts|fx|dynamo"
113+
)
100114
else:
101115
raise ValueError("Unknown ir was requested")
102116

0 commit comments

Comments
 (0)