Skip to content

Commit 70a7bb3

Browse files
committed
feat(//cpp): Adding example tensors as a way to set input spec
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 01d525d commit 70a7bb3

File tree

4 files changed

+70
-5
lines changed

4 files changed

+70
-5
lines changed

cpp/include/trtorch/trtorch.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ struct TRTORCH_API CompileSpec {
427427
Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
428428

429429
/**
430-
* @brief Construct a new Input Range object dynamic input size from
430+
* @brief Construct a new Input spec object dynamic input size from
431431
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
432432
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
433433
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
@@ -462,7 +462,7 @@ struct TRTORCH_API CompileSpec {
462462
TensorFormat format = TensorFormat::kContiguous);
463463

464464
/**
465-
* @brief Construct a new Input Range object dynamic input size from
465+
* @brief Construct a new Input spec object dynamic input size from
466466
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
467467
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
468468
* / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
@@ -479,7 +479,7 @@ struct TRTORCH_API CompileSpec {
479479
TensorFormat format = TensorFormat::kContiguous);
480480

481481
/**
482-
* @brief Construct a new Input Range object dynamic input size from
482+
* @brief Construct a new Input spec object dynamic input size from
483483
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
484484
* supported sizes
485485
*
@@ -496,6 +496,16 @@ struct TRTORCH_API CompileSpec {
496496
DataType dtype,
497497
TensorFormat format = TensorFormat::kContiguous);
498498

499+
/**
500+
* @brief Construct a new Input spec object using a torch tensor as an example
501+
* The tensor's shape, type and layout inform the spec's values
502+
*
503+
* Note: You cannot set dynamic shape through this method, you must use an alternative constructor
504+
*
505+
* @param tensor Reference tensor to set shape, type and layout
506+
*/
507+
Input(at::Tensor tensor);
508+
499509
bool get_explicit_set_dtype() {
500510
return explicit_set_dtype;
501511
}

cpp/src/compile_spec.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,24 @@ CompileSpec::Input::Input(
287287
this->input_is_dynamic = true;
288288
}
289289

290+
CompileSpec::Input::Input(at::Tensor tensor) {
291+
this->opt_shape = tensor.sizes().vec();
292+
this->min_shape = tensor.sizes().vec();
293+
this->max_shape = tensor.sizes().vec();
294+
this->shape = tensor.sizes().vec();
295+
this->dtype = tensor.scalar_type();
296+
this->explicit_set_dtype = true;
297+
TRTORCH_ASSERT(tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous), "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last");
298+
at::MemoryFormat frmt;
299+
if (tensor.is_contiguous(at::MemoryFormat::Contiguous)) {
300+
frmt = at::MemoryFormat::Contiguous;
301+
} else {
302+
frmt = at::MemoryFormat::ChannelsLast;
303+
}
304+
this->format = frmt;
305+
this->input_is_dynamic = false;
306+
}
307+
290308
/* ==========================================*/
291309

292310
core::ir::Input to_internal_input(CompileSpec::InputRange& i) {

tests/cpp/BUILD

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ test_suite(
1616
":test_modules_as_engines",
1717
":test_multiple_registered_engines",
1818
":test_serialization",
19-
":test_module_fallback"
19+
":test_module_fallback",
20+
":test_example_tensors"
2021
],
2122
)
2223

@@ -28,7 +29,8 @@ test_suite(
2829
":test_modules_as_engines",
2930
":test_multiple_registered_engines",
3031
":test_serialization",
31-
":test_module_fallback"
32+
":test_module_fallback",
33+
":test_example_tensors"
3234
],
3335
)
3436

@@ -43,6 +45,17 @@ cc_test(
4345
],
4446
)
4547

48+
cc_test(
49+
name = "test_example_tensors",
50+
srcs = ["test_example_tensors.cpp"],
51+
data = [
52+
"//tests/modules:jit_models",
53+
],
54+
deps = [
55+
":cpp_api_test",
56+
],
57+
)
58+
4659
cc_test(
4760
name = "test_serialization",
4861
srcs = ["test_serialization.cpp"],

tests/cpp/test_example_tensors.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "cpp_api_test.h"
2+
3+
TEST_P(CppAPITests, InputsFromTensors) {
4+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
5+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
6+
for (auto in_shape : input_shapes) {
7+
auto in = at::randn(in_shape, {at::kCUDA});
8+
jit_inputs_ivalues.push_back(in.clone());
9+
trt_inputs_ivalues.push_back(in.clone());
10+
}
11+
12+
auto spec = trtorch::CompileSpec({trt_inputs_ivalues[0].toTensor()});
13+
14+
auto trt_mod = trtorch::CompileGraph(mod, spec);
15+
torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
16+
std::vector<at::Tensor> trt_results;
17+
trt_results.push_back(trt_results_ivalues.toTensor());
18+
}
19+
20+
INSTANTIATE_TEST_SUITE_P(
21+
CompiledModuleForwardIsCloseSuite,
22+
CppAPITests,
23+
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
24+

0 commit comments

Comments
 (0)