Skip to content

Commit f205377

Browse files
committed
test: add test suite for conditionals
Signed-off-by: Bo Wang <[email protected]>
1 parent 2af7935 commit f205377

File tree

3 files changed

+98
-3
lines changed

3 files changed

+98
-3
lines changed

tests/core/partitioning/BUILD

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ config_setting(
77
}
88
)
99

10+
filegroup(
11+
name = "jit_models",
12+
srcs = ["//tests/modules:resnet50_traced.jit.pt",
13+
"//tests/modules:mobilenet_v2_traced.jit.pt",
14+
"//tests/modules:conditional_scripted.jit.pt"]
15+
)
16+
1017
partitioning_test(
1118
name = "test_segmentation",
1219
)
@@ -35,17 +42,34 @@ cc_test(
3542
"//conditions:default": ["@libtorch//:libtorch"],
3643
}),
3744
data = [
38-
"//tests/modules:jit_models"
45+
":jit_models"
46+
]
47+
)
48+
49+
cc_test(
50+
name = "test_conditionals",
51+
srcs = ["test_conditionals.cpp"],
52+
deps = [
53+
"//tests/util",
54+
"//core",
55+
"@googletest//:gtest_main",
56+
] + select({
57+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
58+
"//conditions:default": ["@libtorch//:libtorch"],
59+
}),
60+
data = [
61+
":jit_models"
3962
]
4063
)
4164

4265
test_suite(
43-
name = "partitioning_tests",
66+
name = "partitioning_test",
4467
tests = [
4568
":test_segmentation",
4669
":test_shape_analysis",
4770
":test_tensorrt_conversion",
4871
":test_stitched_graph",
49-
":test_fallback_graph_output"
72+
":test_fallback_graph_output",
73+
":test_conditionals"
5074
]
5175
)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include <string>
2+
#include <unordered_set>
3+
#include "core/compiler.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/script.h"
7+
8+
size_t count_trt_engines_in_conditionals(std::shared_ptr<torch::jit::Graph> g) {
9+
size_t count = 0;
10+
for (auto n : g->nodes()) {
11+
if (n->kind() == torch::jit::prim::If) {
12+
std::vector<torch::jit::Block*> blocks{n->blocks()[0], n->blocks()[1]};
13+
for (auto cur_block : blocks) {
14+
for (auto n : cur_block->nodes()) {
15+
if (n->kind().toQualString() == std::string("tensorrt::execute_engine")) {
16+
++count;
17+
}
18+
}
19+
}
20+
}
21+
}
22+
return count;
23+
}
24+
25+
TEST(Partitioning, FallbackOnConditionalsCorrectly) {
26+
torch::jit::script::Module mod;
27+
try {
28+
mod = torch::jit::load("tests/modules/conditional_scripted.jit.pt");
29+
} catch (const c10::Error& e) {
30+
std::cerr << "error loading the model\n";
31+
return;
32+
}
33+
34+
std::vector<trtorch::core::ir::InputRange> input_ranges{trtorch::core::ir::InputRange({3, 3, 16, 16})};
35+
trtorch::core::CompileSpec cfg(input_ranges);
36+
cfg.partition_info.enabled = true;
37+
torch::jit::script::Module new_mod = trtorch::core::CompileGraph(mod, cfg);
38+
auto g = new_mod.get_method("forward").graph();
39+
40+
auto conditional_engines_count = count_trt_engines_in_conditionals(g);
41+
42+
ASSERT_TRUE(conditional_engines_count == 2);
43+
}

tests/modules/hub.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,31 @@ def forward(self, x):
9595

9696
trace_model = torch.jit.trace(model, x)
9797
torch.jit.save(trace_model, "pooling_traced.jit.pt")
98+
99+
100+
# Sample Conditional Model (for testing partitioning and fallback in conditionals)
101+
class FallbackIf(torch.nn.Module):
102+
def __init__(self):
103+
super(FallbackIf, self).__init__()
104+
self.relu1 = torch.nn.ReLU()
105+
self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
106+
self.log_sig = torch.nn.LogSigmoid()
107+
self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)
108+
self.conv3 = torch.nn.Conv2d(32, 3, 3, 1, 1)
109+
110+
def forward(self, x):
111+
x = self.relu1(x)
112+
x_first = x[0][0][0][0].item()
113+
if x_first > 0:
114+
x = self.conv1(x)
115+
x1 = self.log_sig(x)
116+
x2 = self.conv2(x)
117+
x = self.conv3(x1 + x2)
118+
else:
119+
x = self.log_sig(x)
120+
x = self.conv1(x)
121+
return x
122+
123+
conditional_model = FallbackIf().eval().cuda()
124+
conditional_script_model = torch.jit.script(conditional_model)
125+
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")

0 commit comments

Comments
 (0)