Skip to content

Commit 2b35c09

Browse files
committed
fix: Move min_block_size update into partitioning
- Add argument to partitioning function indicating full compiliation is expected, which (currently) sets the minimum block size to 1 and informs the user of this setting if the user had also specified a minimum block size
1 parent 9aba865 commit 2b35c09

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

core/compiler.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
150150
// TODO: Combine this within partition call
151151
partitioning::populateInputIValues(&partitioning_ctx);
152152

153-
partitioning::partition(&partitioning_ctx);
153+
partitioning::partition(&partitioning_ctx, expect_full_compilation);
154154

155155
for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
156156
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
@@ -382,13 +382,6 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
382382
// to generate collection-processing code in Torch
383383
auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info.enabled);
384384

385-
// Any nonzero block size is valid if full compilation to TRT is desired
386-
// Override the default min_block_size to ensure all TRT-supported operations are
387-
// executed in TRT, regardless of the size of the graph
388-
if (expect_full_compilation) {
389-
cfg.partitioning_info.min_block_size = 1;
390-
}
391-
392385
auto graph_and_mapping =
393386
BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
394387
new_g = graph_and_mapping.first;

core/partitioning/partitioning.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,21 @@ void populateInputIValues(PartitioningCtx* ctx) {
564564
}
565565
}
566566

567-
void partition(PartitioningCtx* ctx) {
567+
void partition(PartitioningCtx* ctx, bool expect_full_compilation) {
568+
// If full compilation is expected, overwrite minimum block size
569+
// Any nonzero block size is valid if full compilation to TRT is desired
570+
// Override the default min_block_size to ensure all TRT-supported operations are
571+
// executed in TRT, regardless of the size of the graph
572+
if (expect_full_compilation) {
573+
// If minimum block size is different from the default, the user must have specified it
574+
if (ctx->settings.min_block_size != 3) {
575+
LOG_WARNING(
576+
"Detected user-specified min_block_size with require_full_compilation=True "
577+
<< "disregarding min_block_size.");
578+
}
579+
ctx->settings.min_block_size = 1;
580+
}
581+
568582
LOG_DEBUG(ctx->settings);
569583

570584
// Go through all the blocks to do the partitioning

core/partitioning/partitioning.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
4848

4949
GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);
5050

51-
void partition(PartitioningCtx* ctx);
51+
void partition(PartitioningCtx* ctx, bool expect_full_compilation = false);
5252

5353
} // namespace partitioning
5454
} // namespace core

0 commit comments

Comments
 (0)