Skip to content

Commit 8f05ee4

Browse files
committed
Add testing for NumToTensor, Full, and Masked Fill lowering passes
- Generalize a case for NumToTensor in the lowering pass
1 parent eff9138 commit 8f05ee4

File tree

3 files changed

+124
-1
lines changed

3 files changed

+124
-1
lines changed

core/lowering/passes/device_casting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph) {
3434

3535
void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph) {
3636
std::string num_to_tensor_cast_pattern = R"IR(
37-
graph(%1: int):
37+
graph(%1: Scalar):
3838
%2: Tensor = prim::NumToTensor(%1)
3939
return (%2))IR";
4040

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ lowering_test(
3131
name = "test_conv1d_pass",
3232
)
3333

34+
lowering_test(
35+
name = "test_device_casting",
36+
)
37+
3438
lowering_test(
3539
name = "test_exception_elimination_pass",
3640
)
@@ -91,6 +95,7 @@ test_suite(
9195
name = "lowering_tests",
9296
tests = [
9397
":test_conv1d_pass",
98+
":test_device_casting",
9499
":test_exception_elimination_pass",
95100
":test_linear_to_addmm",
96101
":test_module_fallback_passes",
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "core/util/prelude.h"
5+
#include "gtest/gtest.h"
6+
#include "tests/util/util.h"
7+
#include "torch/csrc/jit/ir/irparser.h"
8+
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
9+
#include "torch/torch.h"
10+
11+
TEST(LoweringPasses, UnpackAndCastMaskedFillLowersCorrectly) {
12+
const auto graph = R"IR(
13+
graph(%x.1: Tensor, %x.2: Tensor, %x.3: float):
14+
%2 : Tensor = aten::masked_fill_(%x.1, %x.2, %x.3)
15+
return (%2))IR";
16+
17+
auto in = at::rand({2, 3, 5, 7}, {at::kCUDA});
18+
auto in2 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kBool);
19+
auto in3 = 7.3;
20+
21+
auto g = std::make_shared<torch::jit::Graph>();
22+
torch::jit::parseIR(graph, g.get());
23+
24+
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3});
25+
torch_tensorrt::core::lowering::passes::UnpackAndCastMaskedFill(g);
26+
torch::jit::EliminateCommonSubexpression(g);
27+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in, in2, in3});
28+
29+
ASSERT_TRUE(
30+
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
31+
}
32+
33+
TEST(LoweringPasses, UnpackAndCastNumToTensorLowersIntCorrectly) {
34+
const auto graph = R"IR(
35+
graph(%x.1: int):
36+
%2 : Tensor = prim::NumToTensor(%x.1)
37+
return (%2))IR";
38+
39+
// NOTE(review): stale copy-paste comment — this test passes the integer 1 to prim::NumToTensor; no range, randomness, or sqrt is involved
40+
auto in = 1;
41+
auto g = std::make_shared<torch::jit::Graph>();
42+
torch::jit::parseIR(graph, g.get());
43+
44+
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
45+
torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
46+
torch::jit::EliminateCommonSubexpression(g);
47+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
48+
49+
ASSERT_TRUE(
50+
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
51+
}
52+
53+
TEST(LoweringPasses, UnpackAndCastNumToTensorLowersFloatCorrectly) {
54+
const auto graph = R"IR(
55+
graph(%x.1: float):
56+
%2 : Tensor = prim::NumToTensor(%x.1)
57+
return (%2))IR";
58+
59+
auto in = 78.1;
60+
61+
auto g = std::make_shared<torch::jit::Graph>();
62+
torch::jit::parseIR(graph, g.get());
63+
64+
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
65+
torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
66+
torch::jit::EliminateCommonSubexpression(g);
67+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
68+
69+
ASSERT_TRUE(
70+
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
71+
}
72+
73+
TEST(LoweringPasses, UnpackAndCastFullIntLowersCorrectly) {
74+
const auto graph = R"IR(
75+
graph(%x.1: int):
76+
%5 : NoneType = prim::Constant()
77+
%2 : int = prim::Constant[value=3]()
78+
%10 : int[] = prim::ListConstruct(%2, %2)
79+
%out : Tensor = aten::full(%10, %x.1, %5, %5, %5, %5)
80+
return (%out))IR";
81+
82+
auto in = 4;
83+
84+
auto g = std::make_shared<torch::jit::Graph>();
85+
torch::jit::parseIR(graph, g.get());
86+
87+
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
88+
torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g);
89+
torch::jit::EliminateCommonSubexpression(g);
90+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
91+
92+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
93+
jit_pre_results[0].toTensor(), jit_post_results[0].toTensor().cpu(), 2e-6));
94+
}
95+
96+
TEST(LoweringPasses, UnpackAndCastFullFloatLowersCorrectly) {
97+
const auto graph = R"IR(
98+
graph(%x.1: float):
99+
%5 : NoneType = prim::Constant()
100+
%2 : int = prim::Constant[value=5]()
101+
%3 : int = prim::Constant[value=4]()
102+
%10 : int[] = prim::ListConstruct(%2, %3)
103+
%out : Tensor = aten::full(%10, %x.1, %5, %5, %5, %5)
104+
return (%out))IR";
105+
106+
auto in = 54.1;
107+
108+
auto g = std::make_shared<torch::jit::Graph>();
109+
torch::jit::parseIR(graph, g.get());
110+
111+
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
112+
torch_tensorrt::core::lowering::passes::UnpackAndCastFull(g);
113+
torch::jit::EliminateCommonSubexpression(g);
114+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
115+
116+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
117+
jit_pre_results[0].toTensor(), jit_post_results[0].toTensor().cpu(), 2e-6));
118+
}

0 commit comments

Comments (0)