Skip to content

Commit f9295ce

Browse files
author
Anurag Dixit
committed
(//test): Added test case for set_device API
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 648a4f5 commit f9295ce

File tree

5 files changed

+52
-4
lines changed

5 files changed

+52
-4
lines changed

core/execution/TRTEngine.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, std::s
4343
util::logging::get_logger().get_is_colored_output_on()) {
4444

4545
// Deserialize device meta data if device_info is non-empty
46-
if (!serialized_device_info.empty())
47-
{
46+
if (!serialized_device_info.empty()) {
4847
auto cuda_device = deserialize_device(serialized_device_info);
4948
// Set CUDA device as configured in serialized meta data
5049
set_cuda_device(cuda_device);
@@ -124,7 +123,7 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion = torch::class_<TRTEngine>("te
124123
serialize_info.push_back(trt_engine);
125124
return serialize_info;
126125
},
127-
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
126+
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
128127
return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
129128
}
130129
);

tests/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ test_suite(
22
name = "tests",
33
tests = [
44
"//tests/core/converters:test_converters",
5-
"//tests/modules:test_modules"
5+
"//tests/modules:test_modules",
6+
"//tests/api:test_apis"
67
],
78
)
89

tests/api/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
load("//tests/api:api_test.bzl", "api_test")
2+
3+
api_test(
4+
name = "test_device"
5+
)
6+
7+
test_suite(
8+
name = "test_apis",
9+
tests = [
10+
":test_device"
11+
]
12+
)

tests/api/api_test.bzl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
def api_test(name, visibility=None):
3+
native.cc_test(
4+
name = name,
5+
srcs = [name + ".cpp"],
6+
visibility = visibility,
7+
deps = [
8+
"//tests/util",
9+
"//core",
10+
"@googletest//:gtest_main",
11+
],
12+
timeout="short"
13+
)
14+

tests/api/test_device.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include <string>
2+
#include "cuda_runtime_api.h"
3+
#include "gtest/gtest.h"
4+
#include "torch/csrc/jit/ir/irparser.h"
5+
#include "tests/util/util.h"
6+
#include "core/compiler.h"
7+
8+
TEST(API, TRTorchSetDeviceTest) {
9+
// Check number of CUDA capable device on the target
10+
int device_count = -1;
11+
assert(cudaGetDeviceCount(&device_count) == cudaSuccess);
12+
assert(device_count != 0);
13+
14+
int gpu_id = device_count-1;
15+
trtorch::core::set_device(gpu_id);
16+
17+
// Verify if the device ID is set correctly
18+
int device = -1;
19+
assert(cudaGetDevice(&device) == cudaSuccess);
20+
21+
ASSERT_TRUE(device == gpu_id);
22+
}

0 commit comments

Comments
 (0)