
Commit e46f525

fix: Add test case, move config condition
- Add a test case to elicit the behavior where full compilation is requested but the TRT engine size falls below the default `min_block_size=3`
- Move the `min_block_size` condition to a narrower scope
1 parent 17753fc
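To make the scenario concrete, here is a minimal sketch mirroring the test added below (module name, shapes, and the `import torch_tensorrt as torchtrt` alias are illustrative): a graph whose lone TRT-convertible op falls below the default `min_block_size=3` but is compiled with `require_full_compilation=True`.

```python
import torch
import torch_tensorrt as torchtrt

class TinyAdd(torch.nn.Module):
    # A single convertible op (aten::add): block size 1, below the default
    # min_block_size=3 that would otherwise leave it executing in Torch.
    def forward(self, x, y):
        return (x + y,)

scripted_mod = torch.jit.script(TinyAdd().eval().to("cuda"))
inputs = [
    torchtrt.Input((5, 5), dtype=torch.float),
    torchtrt.Input((5, 5), dtype=torch.float),
]

# With this commit, require_full_compilation=True overrides min_block_size
# to 1 before partitioning, so the one-op graph still becomes a TRT engine.
trt_mod = torchtrt.ts.compile(
    scripted_mod, inputs=inputs, require_full_compilation=True
)
```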

File tree: 2 files changed, +45 −9 lines

core/compiler.cpp

Lines changed: 14 additions & 9 deletions
```diff
@@ -143,11 +143,6 @@ partitioning::GraphAndMapping BuildHybridGraph(
   auto convert_info = cfg.convert_info;
   auto partitioning_info = cfg.partitioning_info;
 
-  // Any nonzero block size is valid if full compilation to TRT is desired
-  if (expect_full_compilation) {
-    partitioning_info.min_block_size = 1;
-  }
-
   auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
   partitioning_ctx.input_types_map = first_use_types;
 
@@ -197,9 +192,10 @@ partitioning::GraphAndMapping BuildHybridGraph(
       if (expect_full_compilation) {
         for (auto torch_node : seg_block.block()->nodes()) {
           if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
-            LOG_ERROR(
+            TORCHTRT_THROW_ERROR(
                 "Full compilation specified but node " << torch_node->kind().toQualString()
-                                                       << " was executed in Torch.");
+                                                       << " was executed in Torch."
+                                                       << " Try recompiling with require_full_compilation=False.");
           }
         }
       }
@@ -209,10 +205,11 @@ partitioning::GraphAndMapping BuildHybridGraph(
     // If full compilation is expected, cannot have more than 2 Torch segments
     // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
     if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
-      LOG_ERROR(
+      TORCHTRT_THROW_ERROR(
           "Full compilation specified but number of torch segments was "
           << num_torch_segments << " and number of trt segments was " << num_trt_segments
-          << ". Was expecting at most 2 Torch segments and 1 TRT segment.");
+          << ". Was expecting at most 2 Torch segments and 1 TRT segment."
+          << " Try recompiling with require_full_compilation=False.");
     }
   }
 
@@ -384,6 +381,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     // If the model is fully-compilable and the user has specified full compilation, run partitioning
     // to generate collection-processing code in Torch
     auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info.enabled);
+
+    // Any nonzero block size is valid if full compilation to TRT is desired
+    // Override the default min_block_size to ensure all TRT-supported operations are
+    // executed in TRT, regardless of the size of the graph
+    if (expect_full_compilation) {
+      cfg.partitioning_info.min_block_size = 1;
+    }
+
     auto graph_and_mapping =
         BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
     new_g = graph_and_mapping.first;
```
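Because `LOG_ERROR` becomes `TORCHTRT_THROW_ERROR` in both checks above, a failed full compilation should now surface as an exception rather than only a log line. A hedged sketch of how a caller might handle it, assuming the C++ error reaches Python as a `RuntimeError` and that `torch.nonzero` (data-dependent output shape) lacks a TRT converter, forcing a Torch fallback segment:

```python
import torch
import torch_tensorrt as torchtrt

class HasFallback(torch.nn.Module):
    def forward(self, x):
        # Assumption: aten::nonzero has no TRT converter, so partitioning
        # must leave it in a Torch segment.
        return torch.nonzero(x > 0)

scripted_mod = torch.jit.script(HasFallback().eval().to("cuda"))
inputs = [torchtrt.Input((5, 5), dtype=torch.float)]

try:
    torchtrt.ts.compile(scripted_mod, inputs=inputs, require_full_compilation=True)
except RuntimeError as err:
    # Previously only logged via LOG_ERROR; after this commit the message,
    # including the hint to retry with require_full_compilation=False, is raised.
    print(err)
```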

tests/py/api/test_e2e_behavior.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -146,6 +146,37 @@ def forward(self, x, y, z):
             trt_output, torch_output
         ), "Found differing output formatting between Torch-TRT and Torch"
 
+    def test_tuple_output_with_full_compilation(self):
+        class Sample(torch.nn.Module):
+            def __init__(self):
+                super(Sample, self).__init__()
+
+            def forward(self, x, y):
+                a = x + y
+                return (a,)
+
+        self.model = Sample().eval().to("cuda")
+        self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_2 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        scripted_mod = torch.jit.script(self.model)
+
+        inputs = [
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+        ]
+
+        trt_mod = torchtrt.ts.compile(
+            scripted_mod,
+            inputs=inputs,
+            require_full_compilation=True,
+            enabled_precisions={torch.float, torch.half},
+        )
+        trt_output = trt_mod(self.input_1, self.input_2)
+        torch_output = self.model(self.input_1, self.input_2)
+        assert same_output_format(
+            trt_output, torch_output
+        ), "Found differing output formatting between Torch-TRT and Torch"
+
 
 if __name__ == "__main__":
     unittest.main()
```
