Skip to content

feat: support prim::If in automatic fallback #447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 135 additions & 42 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "torch/csrc/jit/frontend/function_schema_parser.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/ir_views.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/loop_unrolling.h"
#include "torch/csrc/jit/passes/lower_graph.h"
Expand Down Expand Up @@ -173,10 +174,131 @@ void AddSegmentedBlockToGraph(
for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
}
size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
if (!old_to_new_g.count(seg.raw_inputs()[i])) {
old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
}
}

return;
}

// Pair of (converted graph, mapping from values of the original lowered graph
// to values of the converted graph). One such pair is produced by
// ConstructFallbackGraph for each block of a prim::If node.
using GraphAndMapping =
    std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>;

// Splices a prim::If node into new_g. A fresh prim::If is created whose
// condition is resolved through old_to_new_g, and each already-converted
// block graph in graph_and_mappings is cloned into one of its blocks.
//
// new_g:              graph being assembled (the fallback global graph)
// if_node:            the original prim::If node from the lowered graph
// graph_and_mappings: converted (graph, value-mapping) pairs, one per block
//                     of if_node, in block order
// old_to_new_g:       running map from lowered-graph values to new_g values;
//                     the new node's outputs are recorded here
void AddIfBlockToGraph(
    std::shared_ptr<torch::jit::Graph>& new_g,
    torch::jit::Node* if_node,
    const std::vector<GraphAndMapping>& graph_and_mappings,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
  torch::jit::IfView if_view(if_node);

  // Create a new prim::If node in new_g and wire up its condition input.
  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));

  // Clone every converted block into the newly created prim::If.
  // const& avoids copying a shared_ptr<Graph> plus an unordered_map per block.
  for (const auto& graph_and_mapping : graph_and_mappings) {
    auto new_if_block = new_if->addBlock();
    const auto& cur_block_graph = graph_and_mapping.first;
    const auto& cur_block_mapping = graph_and_mapping.second;
    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
    for (const auto& i : cur_block_mapping) {
      // for every pair in the block mapping, old_value => mini graph value; if old_value
      // also appears in old_to_new_g, then it's an input of the mini graph
      if (old_to_new_g.count(i.first)) {
        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
      }
    }

    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
    new_if_block->cloneFrom(cur_block_graph->block(), env);
    // If the block graph carries a module "self" input (its type name contains
    // "__torch__"), make sure new_g exposes a matching self input and map the
    // block's self onto it.
    if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
        auto self = new_g->insertInput(0, "self_1");
        self->setType(cur_block_graph->inputs()[0]->type());
      }
      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
    }
    // Blocks of a prim::If take no inputs: rewrite every use of the cloned
    // block inputs to the mapped values, then erase them. Iterate backwards
    // so erasing does not shift the indices still to be visited.
    for (int i = static_cast<int>(cur_block_graph->inputs().size()) - 1; i >= 0; --i) {
      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
      new_if_block->eraseInput(i);
    }
  }
  // Mirror the original node's outputs on the new node and record the
  // correspondence so downstream segments can resolve their inputs.
  for (auto ov : if_view.outputs()) {
    auto no = new_if->addOutput();
    old_to_new_g[ov] = no;
    no->copyMetadata(ov);
  }
}

// Recursively reconstructs a fallback graph for `block`: segments the block,
// compiles TensorRT-convertible segments into engines spliced in as
// engine-execution nodes, keeps unsupported segments as Torch JIT code, and
// recurses into nested prim::If blocks. Returns the new graph together with
// the mapping from values of the original lowered graph to values of the new
// graph (the mapping is what allows a parent prim::If to stitch this result
// back into its own graph).
//
// new_mod:           module that will own the compiled engines (registered by
//                    AddEngineToGraph)
// block:             block of the lowered graph to reconstruct
// input_ivalues_map: example ivalues used for shape analysis while
//                    partitioning (passed by value; Partition takes it by
//                    non-const reference)
// cfg:               compile spec carrying conversion + partitioning settings
// named_params:      graph parameters resolved from the module
GraphAndMapping ConstructFallbackGraph(
    torch::jit::script::Module& new_mod,
    torch::jit::Block* block,
    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
    CompileSpec cfg,
    conversion::GraphParams named_params) {
  auto convert_cfg = cfg.convert_info;
  auto partition_info = cfg.partition_info;

  auto new_g = std::make_shared<torch::jit::Graph>();

  auto segmented_blocks = partitioning::Partition(block, input_ivalues_map, partition_info);

  // the mapping from lowering graph => fallback global graph
  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
  for (auto input : block->inputs()) {
    util::getOrAddInputForValue(input, new_g, old_to_new_g);
  }

  for (auto& seg_block : segmented_blocks) {
    LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
    // The segment's address serves as a unique engine id; static_cast to
    // const void* streams the same address without the misleading
    // reinterpret_cast<const int*> detour.
    std::ostringstream trt_engine_id;
    trt_engine_id << static_cast<const void*>(&seg_block);

    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
      std::vector<ir::Input> inputs;
      inputs.reserve(seg_block.in_shape().size());
      for (auto& shape : seg_block.in_shape()) {
        inputs.push_back(ir::Input(shape));
      }
      // update the input ranges for each segment
      convert_cfg.inputs = inputs;
      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
      auto temp_g = std::make_shared<torch::jit::Graph>();
      auto device_spec = convert_cfg.engine_settings.device;
      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

      seg_block.update_graph(temp_g);
      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
    } else {
      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
        // A prim::If is always segmented on its own; recursively convert each
        // of its blocks, then splice a new prim::If into new_g.
        auto if_node = seg_block.raw_nodes()[0];

        std::vector<GraphAndMapping> graph_and_mappings;
        for (auto cur_block : if_node->blocks()) {
          graph_and_mappings.push_back(
              ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, named_params));
        }
        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);

      } else {
        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
      }
    }
  }

  // Register the block's outputs on the new graph. An output may be absent
  // from the mapping (e.g. produced inside a nested block); skip those.
  for (auto& output : block->outputs()) {
    if (old_to_new_g.count(output)) {
      new_g->registerOutput(old_to_new_g[output]);
    }
  }
  return {new_g, old_to_new_g};
}

torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
// TODO: Should be doing a functional transform but need PR #31978
// [jit] More robust mangling
Expand All @@ -192,53 +314,24 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
auto g = graph_and_parameters.first;
auto params = graph_and_parameters.second;
auto named_params = conversion::get_named_params(g->inputs(), params);
auto convert_cfg = std::move(cfg.convert_info);
LOG_INFO(*g << "(LoweringGraph)\n");
LOG_INFO("(LoweredGraph)\n" << *g);

// segment the graph and convert segmented TensorRT block
auto segmented_blocks = partitioning::Partition(g, convert_cfg.inputs, cfg.partition_info);
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
std::unordered_map<torch::jit::Value*, ir::Input> inputs;
for (size_t i = 0; i < g->inputs().size(); ++i) {
inputs.insert({g->inputs()[i], cfg.convert_info.inputs[i]});
}
auto input_ivalues_map = partitioning::generateRandomInputs(inputs);
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
new_g = graph_and_mapping.first;
LOG_INFO("(FallbackGraph)\n" << *new_g);

// if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
// module
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
return mod;
}

std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
// add global graph's input to old_to_new_g mapping
for (auto input : g->inputs()) {
util::getOrAddInputForValue(input, new_g, old_to_new_g);
}
for (auto& seg_block : segmented_blocks) {
std::string cur_block_target =
seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
std::ostringstream trt_engine_id;
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
std::vector<ir::Input> inputs;
for (auto& shape : seg_block.in_shape()) {
inputs.push_back(ir::Input(shape));
}
// update the input ranges for each segments
convert_cfg.inputs = inputs;
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_cfg.engine_settings.device;
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

seg_block.update_graph(temp_g);
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
} else {
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
}
}

for (auto& output : g->outputs()) {
new_g->registerOutput(old_to_new_g[output]);
}

LOG_INFO(*new_g << "(FallbackGraph)\n");

auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
new_mod.type()->addMethod(new_method);
Expand Down
2 changes: 1 addition & 1 deletion core/partitioning/SegmentedBlock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace trtorch {
namespace core {
namespace partitioning {

SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes)
SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes)
: target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
for (auto& node : nodes) {
nodes_.push_back(node);
Expand Down
2 changes: 1 addition & 1 deletion core/partitioning/SegmentedBlock.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct SegmentedBlock {

SegmentedBlock() = default;
SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes);
SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}

torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);
Expand Down
49 changes: 34 additions & 15 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <queue>
#include "core/conversion/conversion.h"
#include "core/partitioning/shape_analysis.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"

namespace trtorch {
Expand Down Expand Up @@ -85,8 +86,14 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the
// dependency nodes at the beginning of the current segmented_block and return this merged segmented_block
if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
// if current node is prim::If, just ensure that we have all required input in kTorch
if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
new_seg_blocks.push_back(seg_block);
} else {
dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
}
} else {
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
Expand Down Expand Up @@ -127,7 +134,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
return std::move(new_seg_blocks);
}

void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shared_ptr<torch::jit::Graph> g
// create a list so we can insert SegmentedBlock without losing the iterators
std::list<SegmentedBlock> segmented_blocks_list(segmented_blocks.begin(), segmented_blocks.end());
std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> idx_to_iter;
Expand Down Expand Up @@ -169,8 +176,10 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
if (!updated_segments.count(first_torch_id)) {
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
// TRTorch doesn't support non-tensor inputs for a module.
auto new_torch_block = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]).front();
*idx_to_iter[first_torch_id] = new_torch_block;
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
segmented_blocks.insert(
segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
updated_segments.insert(first_torch_id);
}
}
Expand All @@ -191,7 +200,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
return;
}

void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) {
// find the corresponding raw values in original global graph for this segmented block's inputs/outputs
std::set<torch::jit::Value*> input_values;
for (auto& seg_block : segmented_blocks) {
Expand All @@ -200,7 +209,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
}
}

for (auto& graph_output : g->outputs()) {
for (auto& graph_output : block->outputs()) {
input_values.insert(graph_output);
}

Expand Down Expand Up @@ -249,12 +258,12 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
return;
}

std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info) {
std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
auto min_block_size = partition_info.min_block_size;
std::unordered_set<std::string> forced_fallback_operators(
partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());

auto nodes = g->block()->nodes();
auto nodes = block->nodes();
std::vector<SegmentedBlock> segmented_blocks;

// segment the nodes
Expand All @@ -278,6 +287,16 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
}
tensorrt_nodes.clear();
// if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
// we shouldn't inject node for this block in dependency analysis process
if (n->kind() == torch::jit::prim::If) {
if (!pytorch_nodes.empty()) {
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
pytorch_nodes.clear();
}
segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
continue;
}
pytorch_nodes.push_back(n);
}
}
Expand All @@ -295,21 +314,21 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
}

std::vector<SegmentedBlock> Partition(
std::shared_ptr<torch::jit::Graph> g,
std::vector<ir::Input>& inputs,
torch::jit::Block* block,
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
const PartitionInfo& partition_info) {
LOG_DEBUG(partition_info);
// segment lowering global graph into blocks
std::vector<SegmentedBlock> segmented_blocks = segment_graph(g, partition_info);
std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);

// resolve nonTensor inputs/outputs
resolveNonTensorInputs(segmented_blocks, g);
resolveNonTensorInputs(segmented_blocks);

// register input/output torch::jit::Value for segmented graphs
registerSegmentsOutputs(segmented_blocks, g);
registerSegmentsOutputs(segmented_blocks, block);

// run shape analysis on each segmented block
runShapeAnalysis(segmented_blocks, inputs, g);
runShapeAnalysis(segmented_blocks, input_ivalues_map);

return segmented_blocks;
}
Expand Down
9 changes: 5 additions & 4 deletions core/partitioning/partitioning.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "core/ir/ir.h"
#include "core/partitioning/PartitionInfo.h"
#include "core/partitioning/SegmentedBlock.h"
#include "core/partitioning/shape_analysis.h"
#include "core/util/prelude.h"
#include "torch/csrc/jit/ir/ir.h"

Expand All @@ -14,13 +15,13 @@ namespace partitioning {

typedef std::vector<SegmentedBlock> PartitionedGraph;

PartitionedGraph segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info);
PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info);

std::vector<SegmentedBlock> Partition(
std::shared_ptr<torch::jit::Graph> g,
std::vector<ir::Input>& inputs,
torch::jit::Block* block,
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
const PartitionInfo& partition_info);

} // namespace partitioning
} // namespace core
} // namespace trtorch
} // namespace trtorch
Loading