Skip to content

Commit 1410ca3

Browse files
committed
Added test cases for aten::ScalarImplicit
- Resolved issue with Tensor casting issue - Added tests to cover int, float, and complex cases
1 parent 8f05ee4 commit 1410ca3

File tree

1 file changed

+77
-1
lines changed

1 file changed

+77
-1
lines changed

tests/core/lowering/test_device_casting.cpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "gtest/gtest.h"
66
#include "tests/util/util.h"
77
#include "torch/csrc/jit/ir/irparser.h"
8+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
89
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
910
#include "torch/torch.h"
1011

@@ -36,8 +37,8 @@ TEST(LoweringPasses, UnpackAndCastNumToTensorLowersIntCorrectly) {
3637
%2 : Tensor = prim::NumToTensor(%x.1)
3738
return (%2))IR";
3839

39-
// Make range [0.01, 1.01] to ensure positives / avoid NaN with negative sqrt
4040
auto in = 1;
41+
4142
auto g = std::make_shared<torch::jit::Graph>();
4243
torch::jit::parseIR(graph, g.get());
4344

@@ -116,3 +117,78 @@ TEST(LoweringPasses, UnpackAndCastFullFloatLowersCorrectly) {
116117
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
117118
jit_pre_results[0].toTensor(), jit_post_results[0].toTensor().cpu(), 2e-6));
118119
}
120+
121+
TEST(LoweringPasses, ReplaceScalarImplicitLowersCorrectly) {
122+
const auto graph = R"IR(
123+
graph(%x.1: Tensor):
124+
%5 : int = prim::Constant[value=0]()
125+
%false : bool = prim::Constant[value=0]()
126+
%none : NoneType = prim::Constant()
127+
%cuda : Device = prim::Constant[value="cuda"]()
128+
%3 : int = aten::size(%x.1, %5)
129+
%y.2 : Tensor = prim::NumToTensor(%3)
130+
%y.1 : Tensor = aten::to(%y.2, %cuda, %none, %false, %false)
131+
%19 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
132+
%21 : Tensor, %22 : Tensor = prim::ListUnpack(%19)
133+
%2 : Scalar = aten::ScalarImplicit(%22)
134+
%out : Tensor = prim::NumToTensor(%2)
135+
return (%out))IR";
136+
137+
auto in = at::rand({2, 3, 5, 7}, {at::kCUDA});
138+
139+
auto g = std::make_shared<torch::jit::Graph>();
140+
torch::jit::parseIR(graph, g.get());
141+
142+
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
143+
torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
144+
torch::jit::EliminateCommonSubexpression(g);
145+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
146+
147+
ASSERT_TRUE(
148+
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
149+
}
150+
151+
TEST(LoweringPasses, ReplaceScalarImplicitIntNumToTensorLowersCorrectly) {
152+
const auto graph = R"IR(
153+
graph(%x.1: int):
154+
%1 : Tensor = prim::NumToTensor(%x.1)
155+
%2 : Scalar = aten::ScalarImplicit(%1)
156+
%3 : Tensor = prim::NumToTensor(%2)
157+
return (%3))IR";
158+
159+
auto in = 25;
160+
161+
auto g = std::make_shared<torch::jit::Graph>();
162+
torch::jit::parseIR(graph, g.get());
163+
164+
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
165+
torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g);
166+
torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
167+
torch::jit::EliminateCommonSubexpression(g);
168+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
169+
170+
ASSERT_TRUE(
171+
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
172+
}
173+
174+
TEST(LoweringPasses, ReplaceScalarImplicitFloatLowersCorrectly) {
175+
const auto graph = R"IR(
176+
graph(%x.1: float):
177+
%1 : Tensor = prim::NumToTensor(%x.1)
178+
%2 : Scalar = aten::ScalarImplicit(%1)
179+
%3 : Tensor = prim::NumToTensor(%2)
180+
return (%3))IR";
181+
182+
auto in = 2.5;
183+
184+
auto g = std::make_shared<torch::jit::Graph>();
185+
torch::jit::parseIR(graph, g.get());
186+
187+
auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
188+
torch_tensorrt::core::lowering::passes::ReplaceScalarImplicit(g);
189+
torch::jit::EliminateCommonSubexpression(g);
190+
auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
191+
192+
ASSERT_TRUE(
193+
torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
194+
}

0 commit comments

Comments
 (0)