|
4 | 4 | #include "tests/util/util.h"
|
5 | 5 | #include "core/compiler.h"
|
6 | 6 |
|
7 |
| -TEST(Converters, ATenLogConvertsCorrectly) { |
8 |
| - const auto graph = R"IR( |
| 7 | +namespace { |
| 8 | +std::string gen_test_graph(const std::string &unary) { |
| 9 | + return R"IR( |
9 | 10 | graph(%0 : Tensor):
|
10 |
| - %3 : Tensor = aten::log(%0) |
| 11 | + %3 : Tensor = aten::)IR" + |
| 12 | + unary + R"IR((%0) |
11 | 13 | return (%3))IR";
|
| 14 | +} |
| 15 | +} // namespace |
12 | 16 |
|
13 |
| - auto g = std::make_shared<torch::jit::Graph>(); |
14 |
| - torch::jit::script::parseIR(graph, &*g); |
15 |
| - |
16 |
| - auto in = at::randint(1, 5, {5}, {at::kCUDA}); |
17 |
| - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {}); |
18 |
| - auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); |
| 17 | +#define test_unary(unary, name) \ |
| 18 | + TEST(Converters, ATen##name##ConvertsCorrectly) { \ |
| 19 | + const auto graph = gen_test_graph(#unary); \ |
| 20 | + \ |
| 21 | + auto g = std::make_shared<torch::jit::Graph>(); \ |
| 22 | + torch::jit::script::parseIR(graph, &*g); \ |
| 23 | + \ |
| 24 | + auto in = at::empty({10}, {at::kCUDA}).uniform_(0, 0.5); \ |
| 25 | + auto params = \ |
| 26 | + trtorch::core::conversion::get_named_params(g->inputs(), {}); \ |
| 27 | + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); \ |
| 28 | + \ |
| 29 | + in = at::clone(in); \ |
| 30 | + params = trtorch::core::conversion::get_named_params(g->inputs(), {}); \ |
| 31 | + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); \ |
| 32 | + \ |
| 33 | + ASSERT_TRUE( \ |
| 34 | + trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); \ |
| 35 | + } |
19 | 36 |
|
20 |
| - in = at::clone(in); |
21 |
| - params = trtorch::core::conversion::get_named_params(g->inputs(), {}); |
22 |
| - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); |
| 37 | +test_unary(cos, Cos); |
| 38 | +test_unary(acos, Acos); |
| 39 | +test_unary(cosh, Cosh); |
| 40 | +test_unary(sin, Sin); |
| 41 | +test_unary(asin, Asin); |
| 42 | +test_unary(sinh, Sinh); |
| 43 | +test_unary(tan, Tan); |
| 44 | +test_unary(atan, Atan); |
| 45 | +test_unary(abs, Abs); |
| 46 | +test_unary(floor, Floor); |
| 47 | +test_unary(reciprocal, Reciprocal); |
| 48 | +test_unary(log, Log); |
| 49 | +test_unary(ceil, Ceil); |
| 50 | +test_unary(sqrt, Sqrt); |
| 51 | +test_unary(exp, Exp); |
| 52 | +test_unary(neg, Neg); |
23 | 53 |
|
24 |
| - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0])); |
25 |
| -} |
| 54 | +#undef test_unary |
0 commit comments