Skip to content

Commit f519935

Browse files
committed
feat(collections): Enable grouped inputs via partial compilation
HACK: This PR enables grouped input features by leveraging partial compilation and disabling the tuple and list evaluators in the case where grouped inputs are used. The intention is that this workaround (WAR) is removed in the next release. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 8b891fb commit f519935

File tree

5 files changed

+49
-26
lines changed

5 files changed

+49
-26
lines changed

cpp/src/compile_spec.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IV
6363
}
6464
}
6565

66-
torchtrt::core::CompileSpec init_compile_spec(CompileSpec external) {
66+
torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
6767
if (external.graph_inputs.inputs.size() > 0) {
6868
torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.graph_inputs.inputs));
6969
return internal;
@@ -72,6 +72,25 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec external) {
7272
LOG_WARNING( "Input signature parsing is an experimental feature, behavior and APIs may change");
7373
to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
7474
torchtrt::core::CompileSpec internal(converted_input_signature);
75+
76+
TORCHTRT_CHECK(!external.require_full_compilation, \
77+
"Grouped inputs currently requires partial compilation to be enabled, \
78+
this restriction will be relaxed in a future release");
79+
80+
LOG_DEBUG("Grouped inputs currently requires additional settings to enable the feature");
81+
LOG_DEBUG("Adding the following ops to torch_executed_ops:" \
82+
<< std::endl << " - aten::__getitem__" \
83+
<< std::endl << " - prim::ListConstruct" \
84+
<< std::endl << " - prim::ListUnpack" \
85+
<< std::endl << " - prim::TupleIndex" \
86+
<< std::endl << " - prim::TupleConstruct" \
87+
<< std::endl << " - prim::TupleUnpack");
88+
external.torch_executed_ops.push_back("aten::__getitem__");
89+
external.torch_executed_ops.push_back("prim::ListConstruct");
90+
external.torch_executed_ops.push_back("prim::ListUnpack");
91+
external.torch_executed_ops.push_back("prim::TupleIndex");
92+
external.torch_executed_ops.push_back("prim::TupleConstruct");
93+
external.torch_executed_ops.push_back("prim::TupleUnpack");
7594
return internal;
7695
}
7796
}

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch_tensorrt.logging import Level, log
99
from typing import Tuple, List, Dict
1010
import warnings
11+
from copy import deepcopy
1112

1213

1314
def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input:
@@ -188,7 +189,9 @@ def _parse_input_signature(input_signature: Any):
188189
else:
189190
raise KeyError("Input signature contains an unsupported type {}".format(type(input_signature)))
190191

191-
def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec:
192+
def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
193+
# TODO: Remove deep copy once collections does not need partial compilation
194+
compile_spec = deepcopy(compile_spec_)
192195
info = _ts_C.CompileSpec()
193196

194197
if len(compile_spec["inputs"]) > 0:
@@ -204,6 +207,25 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec:
204207
signature = _parse_input_signature(compile_spec["input_signature"])
205208
info.input_signature = _C.InputSignature(signature) # py_object
206209

210+
if not compile_spec["torch_fallback"]["enabled"]:
211+
raise ValueError("Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release")
212+
213+
log(Level.Debug, "Grouped inputs currently requires additional settings to enable the feature")
214+
log(Level.Debug, """Adding the following ops to torch_executed_ops:
215+
- aten::__getitem__
216+
- prim::ListConstruct
217+
- prim::ListUnpack
218+
- prim::TupleIndex
219+
- prim::TupleConstruct
220+
- prim::TupleUnpack
221+
""")
222+
compile_spec["torch_fallback"]["forced_fallback_ops"].append("aten::__getitem__")
223+
compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListConstruct")
224+
compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
225+
compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
226+
compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleConstruct")
227+
compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleUnpack")
228+
207229
else:
208230
raise KeyError(
209231
"Module input definitions are requried to compile module. Provide a list of torch_tensorrt.Input keyed to \"inputs\" in the compile spec"

py/torch_tensorrt/ts/_compiler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ def compile(module: torch.jit.ScriptModule,
103103

104104
if require_full_compilation and (len(torch_executed_modules) > 0 or len(torch_executed_ops) > 0):
105105
raise ValueError(
106-
"require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: "
107-
+ torch_executed_ops + ", torch_executed_modules: " + torch_executed_modules)
106+
f"require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: {torch_executed_ops}, torch_executed_modules: {torch_executed_modules}")
108107

109108
spec = {
110109
"inputs": inputs,

tests/cpp/test_collections.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
3434
input_range.push_back({in0.sizes(), torch::kF16});
3535
input_range.push_back({in0.sizes(), torch::kF16});
3636
torch_tensorrt::ts::CompileSpec compile_settings(input_range);
37-
compile_settings.require_full_compilation = true;
3837
compile_settings.min_block_size = 1;
3938

4039
// // FP16 execution
@@ -78,7 +77,6 @@ TEST(CppAPITests, TestCollectionTupleInput) {
7877
torch::jit::IValue complex_input_shape2(input_tuple2);
7978

8079
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
81-
compile_settings.require_full_compilation = true;
8280
compile_settings.min_block_size = 1;
8381

8482
// // FP16 execution
@@ -136,7 +134,6 @@ TEST(CppAPITests, TestCollectionListInput) {
136134
torch::jit::IValue complex_input_shape2(input_tuple2);
137135

138136
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
139-
compile_settings.require_full_compilation = true;
140137
compile_settings.min_block_size = 1;
141138
//compile_settings.torch_executed_ops.push_back("aten::__getitem__");
142139

@@ -184,7 +181,6 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {
184181
// torch::jit::IValue complex_input_shape(list);
185182

186183
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
187-
compile_settings.require_full_compilation = true;
188184
compile_settings.min_block_size = 1;
189185

190186
// compile_settings.torch_executed_ops.push_back("prim::TupleConstruct");
@@ -248,12 +244,8 @@ TEST(CppAPITests, TestCollectionListInputOutput) {
248244
torch::jit::IValue complex_input_shape2(input_tuple2);
249245

250246
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
251-
compile_settings.require_full_compilation = true;
252247
compile_settings.min_block_size = 1;
253248

254-
// Need to skip the conversion of __getitem__ and ListConstruct
255-
//compile_settings.torch_executed_ops.push_back("aten::__getitem__");
256-
257249
// // FP16 execution
258250
compile_settings.enabled_precisions = {torch::kHalf};
259251
// // Compile module
@@ -313,12 +305,8 @@ TEST(CppAPITests, TestCollectionComplexModel) {
313305
torch::jit::IValue complex_input_shape2(input_tuple2);
314306

315307
auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
316-
compile_settings.require_full_compilation = true;
317308
compile_settings.min_block_size = 1;
318309

319-
// Need to skip the conversion of __getitem__ and ListConstruct
320-
//compile_settings.torch_executed_ops.push_back("aten::__getitem__");
321-
322310
// // FP16 execution
323311
compile_settings.enabled_precisions = {torch::kHalf};
324312
// // Compile module

tests/py/api/test_collections.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ def test_compile(self):
4848
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
4949
"device": torchtrt.Device("gpu:0"),
5050
"enabled_precisions": {torch.float},
51-
"require_full_compilation": False,
52-
"min_block_size": 3
51+
"min_block_size": 1
5352
}
5453

5554
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -69,8 +68,7 @@ def test_compile(self):
6968
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
7069
"device": torchtrt.Device("gpu:0"),
7170
"enabled_precisions": {torch.float},
72-
"require_full_compilation": False,
73-
"min_block_size": 3
71+
"min_block_size": 1
7472
}
7573

7674
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -89,8 +87,7 @@ def test_compile(self):
8987
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
9088
"device": torchtrt.Device("gpu:0"),
9189
"enabled_precisions": {torch.float},
92-
"require_full_compilation": False,
93-
"min_block_size": 3
90+
"min_block_size": 1
9491
}
9592

9693
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -111,8 +108,7 @@ def test_compile(self):
111108
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
112109
"device": torchtrt.Device("gpu:0"),
113110
"enabled_precisions": {torch.float},
114-
"require_full_compilation": False,
115-
"min_block_size": 3
111+
"min_block_size": 1
116112
}
117113

118114
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
@@ -134,8 +130,7 @@ def test_compile(self):
134130
"input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],),
135131
"device": torchtrt.Device("gpu:0"),
136132
"enabled_precisions": {torch.float},
137-
"require_full_compilation": False,
138-
"min_block_size": 3
133+
"min_block_size": 1
139134
}
140135

141136
trt_mod = torchtrt.ts.compile(self.model, **compile_spec)

0 commit comments

Comments
 (0)