Skip to content

Commit d7f0a75

Browse files
committed
test: update the test for aten::to after fixing
Signed-off-by: Bo Wang <[email protected]>
1 parent 871e02e commit d7f0a75

File tree

1 file changed

+4
-26
lines changed

1 file changed

+4
-26
lines changed

tests/core/lowering/test_reduce_to_pass.cpp

100644100755
Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,14 @@
66
#include "torch/csrc/jit/ir/irparser.h"
77
#include "torch/csrc/jit/ir/subgraph_matcher.h"
88

9-
TEST(LoweringPasses, ReduceToCorrectly) {
10-
std::string source_graph = R"IR(
11-
graph(%x, %device, %dtype, %nb, %copy, %format):
12-
%out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
13-
return (%out))IR";
14-
std::string target_graph = R"IR(
15-
graph(%x, %device, %dtype, %nb, %copy, %format):
16-
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
17-
return (%out))IR";
18-
19-
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
20-
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
21-
auto sg = std::make_shared<torch::jit::Graph>();
22-
torch::jit::parseIR(source_graph, &*sg);
23-
torch_tensorrt::core::lowering::passes::ReduceToOperation(sg);
24-
25-
auto tg = std::make_shared<torch::jit::Graph>();
26-
torch::jit::parseIR(target_graph, &*tg);
27-
28-
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
29-
}
30-
319
TEST(LoweringPasses, ReduceToDtypeLayoutCorrectly) {
3210
std::string source_graph = R"IR(
33-
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
34-
%out : Tensor = aten::to(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format)
11+
graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format):
12+
%out : Tensor = aten::to(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format)
3513
return (%out))IR";
3614
std::string target_graph = R"IR(
37-
graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format):
38-
%out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format)
15+
graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format):
16+
%out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format)
3917
return (%out))IR";
4018

4119
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(

0 commit comments

Comments
 (0)