Skip to content

Commit 21b5f51

Browse files
committed
feat: Add lowering pass for rsqrt operator
- Add unpack rsqrt lowering pass - Add test cases for positive inputs, int and float - Add references to new function in headers and BUILD files
1 parent 1011ac1 commit 21b5f51

File tree

5 files changed

+75
-0
lines changed

5 files changed

+75
-0
lines changed

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ cc_library(
3333
"unpack_hardsigmoid.cpp",
3434
"unpack_hardswish.cpp",
3535
"unpack_log_softmax.cpp",
36+
"unpack_rsqrt.cpp",
3637
"unpack_std.cpp",
3738
"unpack_var.cpp",
3839
"view_to_reshape.cpp",

core/lowering/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ target_sources(${lib_name}
2020
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp"
2121
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp"
2222
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp"
23+
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_rsqrt.cpp"
2324
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
2425
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
2526
"${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
3333
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
3434
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
3535
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
36+
void UnpackRsqrt(std::shared_ptr<torch::jit::Graph>& graph);
3637
void UnpackStd(std::shared_ptr<torch::jit::Graph>& graph);
3738
void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
3839
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/unpack_rsqrt.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace torch_tensorrt {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
// Lowering pass that unpacks aten::rsqrt into two simpler ops.
//
// Registers a single subgraph rewrite pattern that replaces
//   %out = aten::rsqrt(%x)
// with
//   %intermediate = aten::sqrt(%x)
//   %out          = aten::reciprocal(%intermediate)
// and runs that rewriter over `graph`.
//
// @param graph  TorchScript graph handed to the rewriter; the rewritten graph
//               is then emitted through LOG_GRAPH.
void UnpackRsqrt(std::shared_ptr<torch::jit::Graph>& graph) {
11+
// Pattern to match: a lone aten::rsqrt applied to the graph input.
std::string rsqrt_pattern = R"IR(
12+
graph(%1):
13+
%out: Tensor = aten::rsqrt(%1)
14+
return (%out))IR";
15+
// Replacement: sqrt followed by reciprocal of the same input.
std::string unpacked_pattern = R"IR(
16+
graph(%1):
17+
%intermediate: Tensor = aten::sqrt(%1)
18+
%out: Tensor = aten::reciprocal(%intermediate)
19+
return (%out))IR";
20+
21+
// Register the (match -> replacement) pair and apply it to `graph`.
torch::jit::SubgraphRewriter rsqrt_rewriter;
22+
rsqrt_rewriter.RegisterRewritePattern(rsqrt_pattern, unpacked_pattern);
23+
rsqrt_rewriter.runOnGraph(graph);
24+
// Dump the graph as it looks after the rewrite.
LOG_GRAPH("Post unpack rsqrt: " << *graph);
25+
}
26+
27+
} // namespace passes
28+
} // namespace lowering
29+
} // namespace core
30+
} // namespace torch_tensorrt

tests/core/lowering/test_unpack_reduce_ops.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,45 @@ TEST(LoweringPasses, UnpackStdUnbiasedKeepDimsLowersCorrectly) {
202202
ASSERT_TRUE(
203203
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
204204
}
205+
206+
// Checks that UnpackRsqrt preserves numerical results: the same graph is
// evaluated on the same input before and after the pass, and the two outputs
// must agree within a tolerance of 2e-6.
TEST(LoweringPasses, UnpackRsqrtLowersCorrectly) {
207+
// Minimal graph containing a single aten::rsqrt node.
const auto graph = R"IR(
208+
graph(%x.1 : Tensor):
209+
%2 : Tensor = aten::rsqrt(%x.1)
210+
return (%2))IR";
211+
212+
// Shift the random CUDA input by 0.01 (range roughly [0.01, 1.01]) so every
// element is positive; sqrt of a negative value would yield NaN.
213+
auto in = at::rand({2, 3, 5, 7}, {at::kCUDA}) + 0.01;
214+
215+
// Parse the IR text into a fresh graph object.
auto g = std::make_shared<torch::jit::Graph>();
216+
torch::jit::parseIR(graph, g.get());
217+
218+
// Reference result from the unmodified graph.
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
219+
// Apply the pass under test, then run common-subexpression elimination
// before re-evaluating the same graph object.
torch_tensorrt::core::lowering::passes::UnpackRsqrt(g);
220+
torch::jit::EliminateCommonSubexpression(g);
221+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
222+
223+
// Compare the first output of the pre- and post-lowering runs.
ASSERT_TRUE(
224+
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
225+
}
226+
227+
// Same before/after comparison as the float rsqrt test, but with an input
// produced by at::randint: outputs of the graph before and after UnpackRsqrt
// must agree within a tolerance of 2e-6.
TEST(LoweringPasses, UnpackRsqrtIntLowersCorrectly) {
228+
// Minimal graph containing a single aten::rsqrt node.
const auto graph = R"IR(
229+
graph(%x.1 : Tensor):
230+
%2 : Tensor = aten::rsqrt(%x.1)
231+
return (%2))IR";
232+
233+
// Make range of ints [1, 10] (randint's upper bound 11 is exclusive).
// NOTE(review): no dtype is passed in the options here — confirm the
// resulting tensor is actually integer-typed, otherwise this test exercises
// the same float path as the test above.
234+
auto in = at::randint(1, 11, {2, 3, 5, 7}, {at::kCUDA});
235+
236+
// Parse the IR text into a fresh graph object.
auto g = std::make_shared<torch::jit::Graph>();
237+
torch::jit::parseIR(graph, g.get());
238+
239+
// Reference result from the unmodified graph.
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
240+
// Apply the pass under test, then run common-subexpression elimination
// before re-evaluating the same graph object.
torch_tensorrt::core::lowering::passes::UnpackRsqrt(g);
241+
torch::jit::EliminateCommonSubexpression(g);
242+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
243+
244+
// Compare the first output of the pre- and post-lowering runs.
ASSERT_TRUE(
245+
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
246+
}

0 commit comments

Comments
 (0)