Commit b35f69a

Merge pull request #870 from NVIDIA/aten_int_num_tensor
Adding limited support for aten::Int
2 parents c0734dc + 83ae991 commit b35f69a

17 files changed: +469, -45 lines

core/lowering/lowering.cpp

Lines changed: 3 additions & 0 deletions

@@ -1,6 +1,7 @@
 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
 #include "torch/csrc/jit/passes/create_functional_graphs.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/erase_number_types.h"
 #include "torch/csrc/jit/passes/freeze_module.h"
 #include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/guard_elimination.h"
@@ -64,6 +65,8 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::RemoveNOPs(g);
   passes::AliasOperators(g);
   passes::SiluToSigmoidMultipication(g);
+  passes::RemoveSingleUse0DTensors(g);
+  passes::RemoveUnnecessaryCasts(g);
   LOG_GRAPH(*g);
 }
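Both new passes are registered at the end of the lowering pipeline, so they apply to every module that goes through compilation with no user opt-in. As a minimal sketch (assuming the torch_tensorrt Python frontend built from this repo; TinyModel and the input shape are placeholders):

    import torch
    import torch_tensorrt

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    # Lowering passes, including the newly added RemoveSingleUse0DTensors and
    # RemoveUnnecessaryCasts, run on the TorchScript graph inside compile().
    model = torch.jit.script(TinyModel().eval())
    trt_model = torch_tensorrt.compile(
        model, inputs=[torch_tensorrt.Input((1, 3, 224, 224))]
    )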

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ cc_library(
         "view_to_reshape.cpp",
         "remove_dropout.cpp",
         "remove_nops.cpp",
+        "remove_unnecessary_casts.cpp",
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",

core/lowering/passes/passes.h

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,8 @@ void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
 void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
+void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
+void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
#include <stack>
#include <unordered_set>

#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/lowering/passes/passes.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void RemoveSetAttrs(const torch::jit::Module& mod, std::string method_name) {
  auto g = mod.get_method(method_name).graph();

  std::string set_attr_pattern = R"IR(
    graph(%self, %0):
        None = prim::SetAttr[name="_has_warned"](%self, %0)
        return ())IR";
  std::string no_set_attr_pattern = R"IR(
    graph(%self, %0):
        return ())IR";

  // Remove the SetAttr node
  torch::jit::SubgraphRewriter remove_set_attr;
  remove_set_attr.RegisterRewritePattern(set_attr_pattern, no_set_attr_pattern);
  remove_set_attr.runOnGraph(g);
  LOG_GRAPH("Post remove SetAttr: " << *g);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
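For reference, the _has_warned pattern targeted here typically comes from modules that flip a boolean attribute inside forward() so they only warn once; scripting such a module produces exactly the prim::SetAttr node this pass strips. A minimal, hypothetical sketch:

    import torch

    class WarnOnce(torch.nn.Module):
        # Hypothetical module: mutating a bool attribute in forward() is
        # scripted as prim::SetAttr[name="_has_warned"], the node that
        # RemoveSetAttrs removes from the method graph.
        def __init__(self):
            super().__init__()
            self._has_warned = False

        def forward(self, x):
            if not self._has_warned:
                self._has_warned = True
            return x * 2

    # The printed graph contains prim::SetAttr[name="_has_warned"]
    print(torch.jit.script(WarnOnce()).graph)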
core/lowering/passes/remove_unnecessary_casts.cpp

Lines changed: 168 additions & 0 deletions

@@ -0,0 +1,168 @@
#include "torch/csrc/jit/ir/constants.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

#include <vector>

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

// Presumably this is safe, since torch::jit::EraseNumberTypesOnBlock exists and
// just removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor
// nodes outright
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string int_cast_pattern = R"IR(
    graph(%1: int):
      %2: Tensor = aten::NumToTensor(%1)
      %3: int = aten::Int(%2)
      return (%3))IR";
  std::string int_clean_pattern = R"IR(
    graph(%1: int):
      return (%1))IR";

  std::string float_cast_pattern = R"IR(
    graph(%1: float):
      %2: Tensor = aten::NumToTensor(%1)
      %3: float = aten::Float(%2)
      return (%3))IR";
  std::string float_clean_pattern = R"IR(
    graph(%1: float):
      return (%1))IR";

  std::string bool_cast_pattern = R"IR(
    graph(%1: bool):
      %2: Tensor = aten::NumToTensor(%1)
      %3: bool = aten::Bool(%2)
      return (%3))IR";
  std::string bool_clean_pattern = R"IR(
    graph(%1: bool):
      return (%1))IR";

  torch::jit::SubgraphRewriter int_cast_rewriter;
  int_cast_rewriter.RegisterRewritePattern(int_cast_pattern, int_clean_pattern);
  int_cast_rewriter.runOnGraph(graph);

  torch::jit::SubgraphRewriter float_cast_rewriter;
  float_cast_rewriter.RegisterRewritePattern(float_cast_pattern, float_clean_pattern);
  float_cast_rewriter.runOnGraph(graph);

  torch::jit::SubgraphRewriter bool_cast_rewriter;
  bool_cast_rewriter.RegisterRewritePattern(bool_cast_pattern, bool_clean_pattern);
  bool_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
}

void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
  for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) {
    if (it->kind() == torch::jit::prim::Constant) {
      // A constant with a single use can be fused into its user
      if (it->output()->type()->isSubtypeOf(c10::TensorType::get())) {
        // Get the tensor stored in the constant
        at::Tensor t = *torch::jit::constant_as<at::Tensor>(it->output());
        // If the shape is 0D
        if (t.sizes() == std::vector<int64_t>({})) {
          LOG_GRAPH("Found a 0D Tensor: " << it->output()->debugName());
          LOG_GRAPH("Number of uses: " << it->output()->uses().size());
          // If the tensor is only used once
          if (it->output()->uses().size() == 1) {
            auto use = it->output()->uses()[0];
            auto user = use.user;

            // Is a prim::NumToTensor / aten::[Int/Float] case
            if (user->outputs().size() == 1 && user->outputs()[0]->type()->isSubtypeOf(c10::TensorType::get())) {
              if (user->output()->uses().size() == 1) {
                auto potential_cast = user->output()->uses()[0].user;
                // The downstream user is aten::Int or aten::Float
                if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int") ||
                    potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
                  LOG_GRAPH("Downstream user is aten::Int/aten::Float");
                  auto arg = use.offset;

                  for (size_t k = 0; k < user->inputs().size(); ++k) {
                    if (k != arg) {
                      if (user->inputs()[k]->type()->isSubtypeOf(c10::TensorType::get())) {
                        LOG_GRAPH("Input " << k << " is a Tensor");
                        if (user->inputs()[k]->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
                          auto num_to_tensor = user->inputs()[k]->node();

                          LOG_GRAPH(
                              "Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n  "
                              << *(*it) << *num_to_tensor << *user << *potential_cast);

                          // Replace the Tensor constant with a scalar constant
                          LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
                          torch::jit::WithInsertPoint guard(*it);

                          auto new_const_val = g->insertConstant(t.item(), c10::nullopt, it->scope());
                          new_const_val->copyMetadata(it->output());
                          // How to determine the internal scalar type instead of assuming?
                          if (potential_cast->kind() == c10::aten::Int) {
                            new_const_val->setType(c10::IntType::get());
                          } else if (potential_cast->kind() == c10::aten::Float) {
                            new_const_val->setType(c10::FloatType::get());
                          }
                          it->output()->replaceAllUsesWith(new_const_val);
                          it.destroyCurrent();

                          LOG_GRAPH("New constant: " << *new_const_val->node());

                          // Delete the NumToTensor node
                          LOG_GRAPH("Deleting NumToTensor: " << *num_to_tensor);
                          num_to_tensor->output()->replaceAllUsesWith(num_to_tensor->inputs()[0]);
                          num_to_tensor->destroy();

                          // Change the intermediate op's output type
                          LOG_GRAPH(user->schema());

                          torch::jit::Node* new_node;
                          switch (user->kind()) {
                            // Handle special cases where the scalar version of the intermediate
                            // operator has a different schema than the original
                            case c10::aten::add:
                              new_node = g->create(
                                  user->kind(),
                                  torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
                                  1);
                              new_node->insertAfter(user);
                              new_node->outputs()[0]->setType(c10::IntType::get());
                              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
                              user->destroy();
                              break;
                            default:
                              new_node = g->create(user->kind(), user->inputs(), 1);
                              new_node->insertAfter(user);
                              new_node->outputs()[0]->setType(c10::IntType::get());
                              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
                              user->destroy();
                              break;
                          }

                          LOG_GRAPH("New intermediate operation: " << *new_node);
                          LOG_GRAPH(new_node->schema());

                          // Delete the aten::Int / aten::Float cast
                          LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
                          potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
                          potential_cast->destroy();
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
  LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
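To see the graph shapes both passes target, trace a function that feeds tensor sizes through Python int conversions: tracing records size() results via prim::NumToTensor and int() as aten::Int. A minimal sketch (illustrative only; the exact graph varies with the PyTorch version):

    import torch

    def f(x):
        # Recorded as aten::size -> prim::NumToTensor -> aten::Int: the
        # round trip that RemoveUnnecessaryCasts folds away.
        n = int(x.size(0))
        # The literal 1 becomes a single-use 0-dim tensor constant feeding an
        # aten::add, followed by aten::Int: the Constant / NumToTensor / add /
        # Int chain RemoveSingleUse0DTensors rewrites into scalar arithmetic.
        m = int(x.size(1) + 1)
        return torch.zeros(n, m)

    traced = torch.jit.trace(f, torch.randn(4, 8))
    print(traced.graph)  # inspect the NumToTensor / aten::Int nodes before lowering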

noxfile.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ def install_deps(session):

 def download_models(session, use_host_env=False):
     print("Downloading test models")
-    session.install('timm')
+    session.install("-r", os.path.join(TOP_DIR, "tests", "modules", "requirements.txt"))
     print(TOP_DIR)
     session.chdir(os.path.join(TOP_DIR, "tests", "modules"))
     if use_host_env:
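This swaps the hard-coded timm install for the shared tests/modules/requirements.txt, so the test-model dependencies are pinned in one place rather than inside the nox session.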

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions

@@ -54,6 +54,10 @@ lowering_test(
     name = "test_remove_detach_pass",
 )

+lowering_test(
+    name = "test_remove_unnecessary_casts",
+)
+
 lowering_test(
     name = "test_view_to_reshape_pass",
 )
@@ -85,6 +89,7 @@ test_suite(
         ":test_remove_detach_pass",
         ":test_view_to_reshape_pass",
         ":test_remove_dropout_pass",
+        ":test_remove_unnecessary_casts",
         ":test_reduce_to_pass",
         ":test_reduce_remainder",
         ":test_reduce_gelu",
