Skip to content

Commit e2d0e82

Browse files
committed
feat(aten::Int): Lowers out aten::Int
This commit adds a pass to lower out aten::[Int/Float/Bool] / aten::NumToTensor pairs without exception. We are assuming this is safe since there are similar passes in PyTorch for ONNX lowering; however, the scope of this rule is intentionally limited to avoid possible cases where it is not safe. Therefore it should not be expected that all aten::Int issues will be solved by this change, and the operator itself remains a limitation of TorchTRT. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 11c4608 commit e2d0e82

File tree

6 files changed

+149
-0
lines changed

6 files changed

+149
-0
lines changed

core/lowering/lowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
22
#include "torch/csrc/jit/passes/create_functional_graphs.h"
33
#include "torch/csrc/jit/passes/dead_code_elimination.h"
4+
#include "torch/csrc/jit/passes/erase_number_types.h"
45
#include "torch/csrc/jit/passes/freeze_module.h"
56
#include "torch/csrc/jit/passes/fuse_linear.h"
67
#include "torch/csrc/jit/passes/guard_elimination.h"
@@ -63,6 +64,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
6364
passes::RemoveNOPs(g);
6465
passes::AliasOperators(g);
6566
passes::SiluToSigmoidMultipication(g);
67+
passes::RemoveUnnecessaryCasts(g);
6668
LOG_GRAPH(*g);
6769
}
6870

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ cc_library(
2323
"view_to_reshape.cpp",
2424
"remove_dropout.cpp",
2525
"remove_nops.cpp",
26+
"remove_unnecessary_casts.cpp",
2627
"silu_to_sigmoid_multiplication.cpp",
2728
"unpack_addmm.cpp",
2829
"unpack_batch_norm.cpp",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
2727
void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
2828
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
2929
void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
30+
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
3031
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
3132
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
3233
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
2+
3+
#include "core/util/prelude.h"
4+
5+
#include <vector>
6+
7+
namespace torch_tensorrt {
8+
namespace core {
9+
namespace lowering {
10+
namespace passes {
11+
12+
13+
// Presumably this is safe since torch::jit::EraseNumberTypesOnBlock exists which just
14+
// removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
15+
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
16+
std::string int_cast_pattern = R"IR(
17+
graph(%1: int):
18+
%2: Tensor = aten::NumToTensor(%1)
19+
%3: int = aten::Int(%2)
20+
return (%3))IR";
21+
std::string int_clean_pattern = R"IR(
22+
graph(%1: int):
23+
return (%1))IR";
24+
25+
std::string float_cast_pattern = R"IR(
26+
graph(%1: float):
27+
%2: Tensor = aten::NumToTensor(%1)
28+
%3: float = aten::Float(%2)
29+
return (%3))IR";
30+
std::string float_clean_pattern = R"IR(
31+
graph(%1: float):
32+
return (%1))IR";
33+
34+
std::string bool_cast_pattern = R"IR(
35+
graph(%1: bool):
36+
%2: Tensor = aten::NumToTensor(%1)
37+
%3: bool = aten::Bool(%2)
38+
return (%3))IR";
39+
std::string bool_clean_pattern = R"IR(
40+
graph(%1: bool):
41+
return (%1))IR";
42+
43+
torch::jit::SubgraphRewriter int_cast_rewriter;
44+
int_cast_rewriter.RegisterRewritePattern(int_cast_pattern, int_clean_pattern);
45+
int_cast_rewriter.runOnGraph(graph);
46+
47+
torch::jit::SubgraphRewriter float_cast_rewriter;
48+
float_cast_rewriter.RegisterRewritePattern(float_cast_pattern, float_clean_pattern);
49+
float_cast_rewriter.runOnGraph(graph);
50+
51+
torch::jit::SubgraphRewriter bool_cast_rewriter;
52+
bool_cast_rewriter.RegisterRewritePattern(bool_cast_pattern, bool_clean_pattern);
53+
bool_cast_rewriter.runOnGraph(graph);
54+
55+
LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
56+
}
57+
58+
} // namespace passes
59+
} // namespace lowering
60+
} // namespace core
61+
} // namespace torch_tensorrt

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ lowering_test(
5050
name = "test_remove_detach_pass",
5151
)
5252

53+
lowering_test(
54+
name = "test_remove_unnecessary_casts",
55+
)
56+
5357
lowering_test(
5458
name = "test_view_to_reshape_pass",
5559
)
@@ -81,6 +85,7 @@ test_suite(
8185
":test_remove_detach_pass",
8286
":test_view_to_reshape_pass",
8387
":test_remove_dropout_pass",
88+
":test_remove_unnecessary_casts",
8489
":test_reduce_to_pass",
8590
":test_reduce_gelu",
8691
":test_unpack_hardswish",
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
  // NumToTensor immediately followed by aten::Int is a no-op round trip; after
  // the pass, %1 should feed the add directly and the cast nodes disappear.
  std::string source_graph = R"IR(
    graph(%1: int):
      %2: Tensor = aten::NumToTensor(%1)
      %3: int = aten::Int(%2)
      %4: int = aten::add(%3, %3, %3)
      return (%4))IR";
  std::string target_graph = R"IR(
    graph(%1: int):
      %4: int = aten::add(%1, %1, %1)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  // Fix: previously invoked RemoveContiguous (copy-paste error), so the pass
  // under test never ran and the assertion below could not catch regressions.
  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // The rewritten source graph must contain the cast-free target pattern.
  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}
32+
33+
TEST(LoweringPasses, RemoveUnnecessaryCastFloatCorrectly) {
  // NumToTensor immediately followed by aten::Float is a no-op round trip;
  // after the pass, %1 should feed the add directly.
  // Fix: source graph previously returned %3, leaving %4 dead and the
  // source/target graphs inconsistent; both now return %4.
  std::string source_graph = R"IR(
    graph(%1: float):
      %2: Tensor = aten::NumToTensor(%1)
      %3: float = aten::Float(%2)
      %4: float = aten::add(%3, %3, %3)
      return (%4))IR";
  std::string target_graph = R"IR(
    graph(%1: float):
      %4: float = aten::add(%1, %1, %1)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  // Fix: previously invoked RemoveContiguous (copy-paste error), so the pass
  // under test never ran.
  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // The rewritten source graph must contain the cast-free target pattern.
  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}
56+
57+
TEST(LoweringPasses, RemoveUnnecessaryCastBoolCorrectly) {
  // NumToTensor immediately followed by aten::Bool is a no-op round trip;
  // after the pass, %1 should feed __and__ directly.
  // Fix: source graph previously returned %3, leaving %4 dead and the
  // source/target graphs inconsistent; both now return %4.
  std::string source_graph = R"IR(
    graph(%1: bool):
      %2: Tensor = aten::NumToTensor(%1)
      %3: bool = aten::Bool(%2)
      %4: bool = aten::__and__(%3, %3)
      return (%4))IR";
  std::string target_graph = R"IR(
    graph(%1: bool):
      %4: bool = aten::__and__(%1, %1)
      return (%4))IR";

  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  // Fix: previously invoked RemoveContiguous (copy-paste error), so the pass
  // under test never ran.
  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // The rewritten source graph must contain the cast-free target pattern.
  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

0 commit comments

Comments
 (0)