|
6 | 6 | #include "torch/csrc/jit/ir/irparser.h"
|
7 | 7 | #include "torch/csrc/jit/ir/subgraph_matcher.h"
|
8 | 8 |
|
9 |
| -TEST(LoweringPasses, ReduceToCorrectly) { |
10 |
| - std::string source_graph = R"IR( |
11 |
| - graph(%x, %device, %dtype, %nb, %copy, %format): |
12 |
| - %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format) |
13 |
| - return (%out))IR"; |
14 |
| - std::string target_graph = R"IR( |
15 |
| - graph(%x, %device, %dtype, %nb, %copy, %format): |
16 |
| - %out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format) |
17 |
| - return (%out))IR"; |
18 |
| - |
19 |
| - torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( |
20 |
| - torch_tensorrt::core::util::logging::LogLevel::kGRAPH); |
21 |
| - auto sg = std::make_shared<torch::jit::Graph>(); |
22 |
| - torch::jit::parseIR(source_graph, &*sg); |
23 |
| - torch_tensorrt::core::lowering::passes::ReduceToOperation(sg); |
24 |
| - |
25 |
| - auto tg = std::make_shared<torch::jit::Graph>(); |
26 |
| - torch::jit::parseIR(target_graph, &*tg); |
27 |
| - |
28 |
| - ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); |
29 |
| -} |
30 |
| - |
31 | 9 | TEST(LoweringPasses, ReduceToDtypeLayoutCorrectly) {
|
32 | 10 | std::string source_graph = R"IR(
|
33 |
| - graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format): |
34 |
| - %out : Tensor = aten::to(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format) |
| 11 | + graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format): |
| 12 | + %out : Tensor = aten::to(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format) |
35 | 13 | return (%out))IR";
|
36 | 14 | std::string target_graph = R"IR(
|
37 |
| - graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format): |
38 |
| - %out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format) |
| 15 | + graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format): |
| 16 | + %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format) |
39 | 17 | return (%out))IR";
|
40 | 18 |
|
41 | 19 | torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
|
|
0 commit comments