Skip to content

Commit 5110480

Browse files
committed
chore: refactor code structures according to PR
Signed-off-by: Bo Wang <[email protected]>
1 parent 80b1038 commit 5110480

File tree

4 files changed

+38
-24
lines changed

4 files changed

+38
-24
lines changed

core/compiler.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "torch/csrc/jit/frontend/function_schema_parser.h"
1313
#include "torch/csrc/jit/ir/ir.h"
1414
#include "torch/csrc/jit/passes/graph_fuser.h"
15+
#include "torch/csrc/jit/passes/loop_unrolling.h"
1516
#include "torch/csrc/jit/passes/lower_graph.h"
1617
#include "torch/csrc/jit/passes/pass_manager.h"
1718
#include "torch/custom_class.h"
@@ -30,10 +31,9 @@ void AddEngineToGraph(
3031
torch::jit::script::Module mod,
3132
std::shared_ptr<torch::jit::Graph>& g,
3233
const std::string& serialized_engine,
33-
int engine_id = 0,
34+
std::string engine_id = "",
3435
bool fallback = false) {
35-
auto engine_ptr =
36-
c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + std::to_string(engine_id), serialized_engine);
36+
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine);
3737
// Get required metadata about the engine out
3838
auto num_io = engine_ptr->num_io;
3939
auto name = engine_ptr->name;
@@ -200,14 +200,17 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
200200
return mod;
201201
}
202202

203-
int trt_engine_id = 0;
204203
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
205204
// add global graph's input to old_to_new_g mapping
206205
for (auto input : g->inputs()) {
207206
util::getOrAddInputForValue(input, new_g, old_to_new_g);
208207
}
209208
for (auto& seg_block : segmented_blocks) {
210-
LOG_INFO(*g << "(MiniGraphInSegmentedBlock)\n");
209+
std::string cur_block_target =
210+
seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
211+
LOG_INFO(*g << "(MiniGraphIn" << cur_block_target << "Block\n");
212+
std::ostringstream trt_engine_id;
213+
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
211214
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
212215
std::vector<ir::InputRange> input_ranges;
213216
for (auto& shape : seg_block.in_shape()) {
@@ -217,7 +220,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
217220
convert_cfg.input_ranges = input_ranges;
218221
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
219222
auto temp_g = std::make_shared<torch::jit::Graph>();
220-
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++, true);
223+
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id.str(), true);
221224

222225
seg_block.update_graph(temp_g);
223226
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);

core/partitioning/partitioning.cpp

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <queue>
44
#include "core/conversion/conversion.h"
55
#include "core/partitioning/shape_analysis.h"
6-
#include "torch/csrc/jit/passes/constant_pooling.h"
76
#include "torch/csrc/jit/passes/dead_code_elimination.h"
87

98
namespace trtorch {
@@ -30,7 +29,7 @@ bool isAllNodesSupported(const std::vector<torch::jit::Node*>& nodes) {
3029
return true;
3130
}
3231

33-
bool containNonTensorInputs(torch::jit::Node* n, const std::unordered_set<torch::jit::Value*>& target_inputs) {
32+
bool containTargetInputs(torch::jit::Node* n, const std::unordered_set<torch::jit::Value*>& target_inputs) {
3433
for (auto input : n->inputs()) {
3534
if (!isTensorOrTensorList(input) && target_inputs.count(input)) {
3635
return true;
@@ -94,7 +93,7 @@ std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_bl
9493
bool prev_non_tensor_outputs = false;
9594
for (auto n : seg_block.raw_nodes()) {
9695
// it's a kTorch block if it uses the nonTensor input and the nonTensor input is produced in kTorch block
97-
if (containNonTensorInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
96+
if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
9897
if (!tensorrt_nodes.empty()) {
9998
new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
10099
tensorrt_nodes.clear();
@@ -278,18 +277,8 @@ std::vector<SegmentedBlock> Partition(
278277
// register input/output torch::jit::Value for segmented graphs
279278
registerSegmentsOutputs(segmented_blocks, g);
280279

281-
// store the mapping from lowering graph torch::jit::Value => torch::jit::IValue that we get by running segments
282-
std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
283-
std::vector<torch::jit::IValue> random_inputs = generateRandomInputs(input_ranges);
284-
for (size_t i = 0; i < g->inputs().size(); ++i) {
285-
ivalues_maps[g->inputs()[i]] = random_inputs[i];
286-
}
287-
288-
// register every segment's input shape, and its running output IValues
289-
for (auto& seg_block : segmented_blocks) {
290-
torch::jit::ConstantPooling(seg_block.g());
291-
getSegmentsOutputByRunning(seg_block, ivalues_maps);
292-
}
280+
// run shape analysis on each segmented block
281+
runShapeAnalysis(segmented_blocks, input_ranges, g);
293282

294283
return segmented_blocks;
295284
}

core/partitioning/shape_analysis.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "core/partitioning/shape_analysis.h"
22
#include "core/util/prelude.h"
33
#include "torch/csrc/jit/api/module.h"
4+
#include "torch/csrc/jit/passes/constant_pooling.h"
45

56
namespace trtorch {
67
namespace core {
@@ -97,6 +98,25 @@ void getSegmentsOutputByRunning(
9798
seg_block.register_inshape(input_shape);
9899
}
99100

101+
// Runs shape analysis over every segmented block.
//
// Seeds a Value -> IValue map with the inputs produced by
// generateRandomInputs(input_ranges), keyed by the inputs of graph `g`, then
// for each block applies ConstantPooling to the block's graph and calls
// getSegmentsOutputByRunning with the shared map.
//
// segmented_blocks: blocks to analyze; passed by mutable reference.
// input_ranges:     forwarded unchanged to generateRandomInputs.
// g:                graph whose inputs are used as the initial map keys.
void runShapeAnalysis(
102+
std::vector<SegmentedBlock>& segmented_blocks,
103+
std::vector<ir::InputRange>& input_ranges,
104+
std::shared_ptr<torch::jit::Graph> g) {
105+
// mapping from a torch::jit::Value of the lowered graph => the torch::jit::IValue obtained for it by running segments
106+
std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
107+
std::vector<torch::jit::IValue> random_inputs = generateRandomInputs(input_ranges);
108+
// NOTE(review): random_inputs[i] is indexed by the graph's input count with no bounds check;
// this assumes generateRandomInputs yields at least g->inputs().size() values -- TODO confirm.
for (size_t i = 0; i < g->inputs().size(); ++i) {
109+
ivalues_maps[g->inputs()[i]] = random_inputs[i];
110+
}
111+
112+
// register every segment's input shape, and its running output IValues
113+
for (auto& seg_block : segmented_blocks) {
114+
// pool constants in the segment's graph before it is handed to getSegmentsOutputByRunning
torch::jit::ConstantPooling(seg_block.g());
115+
// the same map is passed for every block (presumably so later blocks see earlier outputs -- confirm)
getSegmentsOutputByRunning(seg_block, ivalues_maps);
116+
}
117+
return;
118+
}
119+
100120
} // namespace partitioning
101121
} // namespace core
102122
} // namespace trtorch

core/partitioning/shape_analysis.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#include "core/ir/ir.h"
22
#include "core/partitioning/SegmentedBlock.h"
3+
#include "torch/csrc/jit/ir/ir.h"
34

45
namespace trtorch {
56
namespace core {
67
namespace partitioning {
78

89
std::vector<torch::jit::IValue> generateRandomInputs(std::vector<ir::InputRange>& input_ranges);
910

10-
void getSegmentsOutputByRunning(
11-
SegmentedBlock& seg_block,
12-
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps);
11+
// Entry point for shape analysis over a set of segmented blocks.
// Takes the blocks and input ranges by mutable reference and the graph by shared_ptr;
// this declaration replaces getSegmentsOutputByRunning in this header.
void runShapeAnalysis(
12+
std::vector<SegmentedBlock>& segmented_blocks,
13+
std::vector<ir::InputRange>& input_ranges,
14+
std::shared_ptr<torch::jit::Graph> g);
1315

1416
} // namespace partitioning
1517
} // namespace core

0 commit comments

Comments
 (0)