Skip to content

fix: fix the fallback related issue after merging collection #1206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 28, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 12 additions & 51 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,13 @@ struct usage_info {
std::vector<size_t> tensorrt_use_id; // ids of segmented blocks which are of type TensorRT
};

// Returns true when the type of `val` is a subtype of either the Tensor type
// or the list-of-Tensors type; false otherwise.
inline bool isTensorOrTensorList(torch::jit::Value* val) {
  const auto& value_type = val->type();
  // Plain Tensor check first; the list check only runs when this one fails.
  if (value_type->isSubtypeOf(torch::jit::TensorType::get())) {
    return true;
  }
  return value_type->isSubtypeOf(torch::jit::ListType::ofTensors());
}

// Returns true when the type of `val` is a subtype of the list-of-Tensors type.
inline bool isTensorList(torch::jit::Value* val) {
  const auto& value_type = val->type();
  return value_type->isSubtypeOf(torch::jit::ListType::ofTensors());
}

// Returns true when the type of `val` is a subtype of the Tensor type.
inline bool isTensor(torch::jit::Value* val) {
  const auto& value_type = val->type();
  return value_type->isSubtypeOf(torch::jit::TensorType::get());
}

bool containNonTensorOutputs(torch::jit::Node* n) {
for (auto output : n->outputs()) {
if (!isTensorOrTensorList(output)) {
if (!isTensor(output)) {
return true;
}
}
Expand Down Expand Up @@ -68,6 +59,7 @@ std::vector<torch::jit::Node*> findModifyingNodes(
return modifying_nodes;
}

// This function is only used when a TRT segment produces nonTensor values which are used by a later TRT segment
std::vector<torch::jit::Node*> getDependencyNodes(
const std::vector<torch::jit::Value*>& vals,
const SegmentedBlock& seg_block) {
Expand All @@ -88,7 +80,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
stk.push_back(node);
for (auto input : node->inputs()) {
if (!isTensorOrTensorList(input)) {
if (!isTensor(input)) {
q.push(input);
}
}
Expand Down Expand Up @@ -124,7 +116,8 @@ void find_all_fallback_nodes(
if (!isTensor(output)) {
for (auto use : output->uses()) {
auto node = use.user;
if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
if (node->kind() != torch::jit::prim::Constant &&
global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
q.push(node);
}
}
Expand Down Expand Up @@ -176,7 +169,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
seg_block.raw_inputs().end() &&
seg_block.contain_raw_value(mini_graph_input)) {
if (!isTensorOrTensorList(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
if (!isTensor(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
continue;
seg_block.registerOutput(mini_graph_input);
}
Expand Down Expand Up @@ -242,36 +235,6 @@ bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::ji
"Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
<< util::node_info(n));
}
}
return false;
}

// Returns true if at least one output of `n` has a type whose kind is
// TupleType or ListType; returns false when no output matches (including
// when the node has no outputs).
bool is_collection(torch::jit::Node* n) {
  for (auto output_value : n->outputs()) {
    // Read the type kind once and compare it against both collection kinds.
    const auto output_kind = output_value->type()->kind();
    const bool is_tuple = output_kind == torch::jit::TypeKind::TupleType;
    const bool is_list = output_kind == torch::jit::TypeKind::ListType;
    if (is_tuple || is_list) {
      return true;
    }
  }
  return false;
}

bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set<std::string>& torch_ops) {
// If the op is not supported by the conversion phase it should run in PyTorch
if (!conversion::OpSupported(n)) {
LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
return false;
}

// If the user specifies the op to run in Torch it should run in PyTorch
if (torch_ops.find(n->kind().toQualString()) != torch_ops.end()) {
LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
return false;
}

// If the user specifies the module containing this op to run in torch it should run in PyTorch
const auto to_compile_sym = c10::Symbol::attr("to_compile");
if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
return false;
}

Expand Down Expand Up @@ -390,19 +353,18 @@ PartitionedGraph segment_graph(
find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);

auto nodes = block->nodes();
auto reverse_nodes = nodes.reverse(); // merge from output side to input side
PartitionedGraph segmented_blocks;

// segment the nodes
std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
for (const auto n : reverse_nodes) {
for (const auto n : nodes) {
// Skip constant nodes as they are resources for both kinds of modules
if (n->kind() == torch::jit::prim::Constant) {
continue;
}
// the outputs of trt subgraph shouldn't be collections
if (should_run_in_trt(n, forced_fallback_ops) && !(in_prog_trt_blk_nodes.size() == 0 && is_collection(n))) {
in_prog_trt_blk_nodes.insert(in_prog_trt_blk_nodes.begin(), n);
if (check_node_fallback(n, global_fallback_nodes)) {
in_prog_trt_blk_nodes.push_back(n);

// If there is an active PyTorch block and we have passed the threshold for a valid TRT
// block then segment and reset the active PyTorch block
Expand All @@ -418,7 +380,7 @@ PartitionedGraph segment_graph(
LOG_DEBUG(
"In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
in_prog_pyt_blk_nodes.insert(
in_prog_pyt_blk_nodes.begin(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
}
in_prog_trt_blk_nodes.clear();
// if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
Expand All @@ -437,14 +399,14 @@ PartitionedGraph segment_graph(
finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
}
if (checkLoopEvaluatable(n)) {
in_prog_trt_blk_nodes.insert(in_prog_trt_blk_nodes.begin(), n);
in_prog_trt_blk_nodes.push_back(n);
} else {
auto loop_node = std::vector<torch::jit::Node*>{n};
finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node);
}
continue;
}
in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.begin(), n);
in_prog_pyt_blk_nodes.push_back(n);
}
}

Expand All @@ -459,7 +421,6 @@ PartitionedGraph segment_graph(
in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
}
std::reverse(segmented_blocks.begin(), segmented_blocks.end());
return segmented_blocks;
}

Expand Down