Skip to content

Commit 125a0a6

Browse files
committed
chore: Fix missing headers in tests
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 5eba455 commit 125a0a6

File tree

4 files changed

+11
-2
lines changed

4 files changed

+11
-2
lines changed

tests/core/conversion/converters/test_masked_fill.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <torch/torch.h>
12
#include <string>
23
#include "core/compiler.h"
34
#include "core/lowering/passes/passes.h"

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ std::string gen_keepdim_graph(const std::string& op) {
6262
return (%5))IR";
6363
}
6464

65-
void test_body(const std::string& graph, at::Tensor& in) {
65+
void test_body(const std::string& graph, at::Tensor& in, bool dynamic = false) {
6666
auto g = std::make_shared<torch::jit::Graph>();
6767
torch::jit::parseIR(graph, g.get());
6868

@@ -71,7 +71,12 @@ void test_body(const std::string& graph, at::Tensor& in) {
7171

7272
in = at::clone(in);
7373
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
74-
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
74+
std::vector<at::Tensor> trt_results;
75+
if (dynamic) {
76+
trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in});
77+
} else {
78+
trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
79+
}
7580
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
7681
}
7782
} // namespace

tests/core/conversion/converters/test_unpack.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "gtest/gtest.h"
55
#include "tests/util/util.h"
66
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
8+
#include "torch/torch.h"
79

810
TEST(Converters, UnpackVarLowersCorrectly) {
911
const auto graph = R"IR(

tests/core/conversion/converters/test_where.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <torch/torch.h>
12
#include <string>
23
#include "core/compiler.h"
34
#include "core/lowering/passes/passes.h"

0 commit comments

Comments
 (0)