Skip to content

fix: Avoid resolving non-tensor inputs to torch segment_blocks when unnecessary #1020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 38 additions & 30 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
return false;
}

std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
std::vector<torch::jit::Node*> getDependencyNodes(const std::vector<torch::jit::Value*>& vals) {
// use bfs to get the DAG dependency nodes for input value
std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
Expand Down Expand Up @@ -137,17 +137,10 @@ std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock
return std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock>(append_blocks, trt_block);
}

PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
// reconstruct segmented_block if this block requires nonTensor input
std::vector<torch::jit::Value*> nontensor_inputs;
// Gather all non-tensor inputs for this seg_block
for (auto input : seg_block.raw_inputs()) {
if (!isTensorOrTensorList(input)) {
nontensor_inputs.push_back(input);
}
}

std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
PartitionedGraph segmentBlocksWithSpecifiedInputs(
SegmentedBlock& seg_block,
const std::vector<torch::jit::Value*>& inputs_to_resolve) {
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve);
PartitionedGraph new_seg_blocks;
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the
// dependency nodes at the beginning of the current segmented_block and return this merged segmented_block
Expand All @@ -162,15 +155,15 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
}
} else {
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
std::unordered_set<torch::jit::Value*> inputs_to_resolve_set(inputs_to_resolve.begin(), inputs_to_resolve.end());
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end());

bool prev_non_tensor_outputs = false;
for (auto n : seg_block.raw_nodes()) {
// Check if the node has non-tensor inputs or if it consumes non-tensor outputs of previous node.
// In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT
// SegmentedBlock.
if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
if (containTargetInputs(n, inputs_to_resolve_set) || prev_non_tensor_outputs) {
// If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
// TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments.
if (!tensorrt_nodes.empty()) {
Expand Down Expand Up @@ -201,6 +194,18 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
return new_seg_blocks;
}

PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
  // Re-segment this block when it consumes non-tensor inputs: collect every
  // raw input that is not a Tensor/TensorList and delegate the actual
  // re-segmentation to segmentBlocksWithSpecifiedInputs.
  std::vector<torch::jit::Value*> non_tensor_inputs;
  for (auto blk_input : seg_block.raw_inputs()) {
    const bool tensor_like = isTensorOrTensorList(blk_input);
    if (!tensor_like) {
      non_tensor_inputs.push_back(blk_input);
    }
  }
  return segmentBlocksWithSpecifiedInputs(seg_block, non_tensor_inputs);
}

std::unordered_map<torch::jit::Value*, usage_info> getInputUsageCounts(
const PartitionedGraph& segmented_blocks,
const std::function<bool(torch::jit::Value*)>& condition) {
Expand Down Expand Up @@ -248,6 +253,10 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); });
auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);

std::map<int, std::vector<torch::jit::Value*>>
torch_values_to_fix; // Only need to resolve values generated by tensorrt
std::set<int> tensorrt_blocks_to_fix; // Need to resolve ALL non-tensor inputs

// update blocks_list
std::unordered_set<int> updated_segments;
for (auto& use : usage_counts) {
Expand All @@ -256,27 +265,26 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
// kTorch segment.
if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
auto first_torch_id = use_info.torch_use_id.back();
if (!updated_segments.count(first_torch_id)) {
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
// Torch-TensorRT doesn't support non-tensor inputs for a module.
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
updated_segments.insert(first_torch_id);
}
torch_values_to_fix[first_torch_id].push_back(use.first);
}
// kTensorRT segments always need to inject nodes for the nonTensor inputs
for (auto i : use_info.tensorrt_use_id) {
if (!updated_segments.count(i)) {
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
// Torch-TensorRT doesn't support non-tensor inputs for a module.
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
updated_segments.insert(i);
}
tensorrt_blocks_to_fix.insert(i);
}
}
for (auto torch_block_pair : torch_values_to_fix) {
auto to_inject_blocks =
segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
}

for (auto i : tensorrt_blocks_to_fix) {
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
}

segmented_blocks.clear();
segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end());
return;
Expand Down
144 changes: 144 additions & 0 deletions tests/core/partitioning/test_resolve_nontensor_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,147 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
int count = count_trt_engines(fallback_g);
ASSERT_TRUE(count == 2);
}

// Verifies that partitioning resolves only the non-tensor inputs that actually
// need resolution: the graph below must partition into exactly one TensorRT
// block (with tensor-only inputs) and two Torch blocks built around the
// Dict(str, Tensor) value, rather than re-resolving every non-tensor input.
// NOTE(review): "Neccessary" is a typo for "Necessary"; left as-is because the
// TEST name is the registered test identifier (CI filters may reference it).
TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
  /* parseIR does not support "= aten::_set_item" so we will build this graph manually
  const auto graph = R"IR(
  graph(%x : Tensor,
      %y : Tensor):
  %2 : str = prim::Constant[value="INS"]()
  %3 : str = prim::Constant[value="OUTS"]()
  %4 : bool = prim::Constant[value=0]()
  %5 : int = prim::Constant[value=-1]()
  %6 : Dict(str, Tensor) = prim::DictConstruct()
  = aten::_set_item(%6, %2, %x)
  %7 : Tensor = aten::__getitem__(%6, %2)
  %8 : Tensor = aten::lt(%7, %y)
  %9 : Tensor?[] = prim::ListConstruct(%8)
  %10 : int = prim::dtype(%7)
  %11 : Device = prim::device(%7)
  %12 : Tensor = aten::tensor(%5, %10, %11, %4)
  %13 : Tensor = aten::index_put_(%7, %9, %12, %4)
  = aten::_set_item(%6, %3, %7)
  %14 : Tensor = aten::__getitem__(%6, %2)
  %15 : Tensor = aten::__getitem__(%6, %3)
  return (%14, %15))IR";
  */
  // ---- Manual graph construction (mirrors the IR in the comment above) ----
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  // String constants used as dictionary keys.
  torch::jit::IValue ins_key("INS");
  auto ins_key_val = g->insertConstant(ins_key);
  torch::jit::IValue outs_key("OUTS");
  auto outs_key_val = g->insertConstant(outs_key);
  // Bool constant 'false' (%4): built from int 0, then retyped to Bool.
  torch::jit::IValue zero(0);
  auto false_const_val = g->insertConstant(zero);
  false_const_val->setType(c10::BoolType::get());
  // Int constant -1 (%5), later used as the fill value for aten::tensor.
  torch::jit::IValue neg_one(-1);
  auto neg_one_const_val = g->insertConstant(neg_one);
  // %6: empty Dict(str, Tensor) — the non-tensor value the partitioner must handle.
  auto dict_node = g->createDict(
      ins_key_val->type(),
      x->type(),
      torch::jit::ArrayRef<torch::jit::Value*>(),
      torch::jit::ArrayRef<torch::jit::Value*>());
  g->insertNode(dict_node);
  // = aten::_set_item(%6, "INS", %x): store input x under "INS".
  auto set_node = g->create(
      torch::jit::Symbol::fromQualString("aten::_set_item"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val, x},
      0);
  g->insertNode(set_node);
  // %7 = aten::__getitem__(%6, "INS")
  auto get_node = g->create(
      torch::jit::Symbol::fromQualString("aten::__getitem__"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
      1);
  g->insertNode(get_node);
  // %8 = aten::lt(%7, %y)
  auto lt_node = g->create(
      torch::jit::Symbol::fromQualString("aten::lt"),
      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), y},
      1);
  g->insertNode(lt_node);
  // %9 = prim::ListConstruct(%8) : Tensor?[] — index list for index_put_.
  auto list_node = g->createList(
      at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output()});
  g->insertNode(list_node);
  // %10 = prim::dtype(%7) — output retyped to int to match the IR.
  auto dtype_node = g->create(
      torch::jit::Symbol::fromQualString("prim::dtype"),
      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
      1);
  dtype_node->output()->setType(neg_one_const_val->type());
  g->insertNode(dtype_node);
  // %11 = prim::device(%7)
  auto device_node = g->create(
      torch::jit::Symbol::fromQualString("prim::device"),
      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
      1);
  device_node->output()->setType(c10::DeviceObjType::get());
  g->insertNode(device_node);
  // %12 = aten::tensor(-1, %10, %11, false)
  auto tensor_node = g->create(
      torch::jit::Symbol::fromQualString("aten::tensor"),
      torch::jit::ArrayRef<torch::jit::Value*>{
          neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val},
      1);
  g->insertNode(tensor_node);
  // %13 = aten::index_put_(%7, %9, %12, false) — in-place write through the dict value.
  auto index_put_node = g->create(
      torch::jit::Symbol::fromQualString("aten::index_put_"),
      torch::jit::ArrayRef<torch::jit::Value*>{
          get_node->output(), list_node->output(), tensor_node->output(), false_const_val},
      1);
  g->insertNode(index_put_node);
  // = aten::_set_item(%6, "OUTS", %7): store the mutated tensor under "OUTS".
  auto out_set_node = g->create(
      torch::jit::Symbol::fromQualString("aten::_set_item"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val, get_node->output()},
      0);
  g->insertNode(out_set_node);
  // %14 / %15: read both entries back out and return them.
  auto get_ins_node = g->create(
      torch::jit::Symbol::fromQualString("aten::__getitem__"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
      1);
  g->insertNode(get_ins_node);
  auto get_outs_node = g->create(
      torch::jit::Symbol::fromQualString("aten::__getitem__"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val},
      1);
  g->insertNode(get_outs_node);
  g->registerOutput(get_ins_node->output());
  g->registerOutput(get_outs_node->output());

  // ---- Partition with fallback enabled and two 4x4 float inputs ----
  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
  partition_info.enabled = true;
  std::vector<torch_tensorrt::core::ir::Input> inputs;
  inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
  inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));

  std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
  std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
  for (size_t i = 0; i < g->inputs().size(); ++i) {
    inputs_map.insert({g->inputs()[i], inputs[i]});
    input_types.insert({g->inputs()[i], {at::kFloat}});
  }
  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
  auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);

  // ---- Assertions ----
  // Every TensorRT block must take only Tensor-typed inputs; each Torch block
  // must either consume the dict or produce it, but never both (XOR below).
  int torch_block_cnt = 0, trt_block_cnt = 0;
  for (const auto& segmented_block : segmented_blocks) {
    if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) {
      ++trt_block_cnt;
      ASSERT_TRUE(checkSegmentedBlockInputType(segmented_block, [](torch::jit::TypePtr type_ptr) {
        return type_ptr->isSubtypeOf(torch::jit::TensorType::get());
      }));
    } else {
      ++torch_block_cnt;
      bool output_dict = false;
      bool input_dict = false;
      auto dict_type = dict_node->output()->type();
      for (auto in : segmented_block.raw_inputs()) {
        if (in->type()->isSubtypeOf(dict_type)) {
          input_dict = true;
        }
      }
      for (auto out : segmented_block.raw_outputs()) {
        if (out->type()->isSubtypeOf(dict_type)) {
          output_dict = true;
        }
      }
      EXPECT_TRUE(output_dict ^ input_dict);
    }
  }
  // Expected partition shape: exactly 1 TensorRT block and 2 Torch blocks.
  ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 2);
}