Skip to content

Commit ef1caa2

Browse files
narendasanabhi-iyer
authored andcommitted
refactor(//cpp/trtorchexec): Demonstrate serialization in trtorchexec
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 493e465 commit ef1caa2

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

cpp/trtorchexec/main.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ int main(int argc, const char* argv[]) {
3838
}
3939

4040
mod.to(at::kCUDA);
41+
mod.eval();
4142

4243
std::vector<std::vector<int64_t>> dims;
4344
for (int i = 2; i < argc; i++) {
@@ -92,7 +93,7 @@ int main(int argc, const char* argv[]) {
9293
std::cout << "Running TRT module" << std::endl;
9394
torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
9495
std::vector<at::Tensor> trt_results;
95-
if (trt_results_ivalues.isTensor()) {
96+
if (trt_results_ivalues.isTensor()) {
9697
trt_results.push_back(trt_results_ivalues.toTensor());
9798
} else {
9899
auto results = trt_results_ivalues.toTuple()->elements();
@@ -106,5 +107,8 @@ int main(int argc, const char* argv[]) {
106107
}
107108

108109
std::cout << "Converted Engine saved to /tmp/engine_converted_from_jit.trt" << std::endl;
110+
111+
trt_mod.save("/tmp/ts_trt.ts");
112+
std::cout << "Compiled TorchScript program saved to /tmp/ts_trt.ts" << std::endl;
109113
std::cout << "ok\n";
110114
}

0 commit comments

Comments
 (0)