Skip to content

Commit 8af0422

Browse files
authored
Merge pull request #1004 from NVIDIA/support_multiple_delimiters
fix(//core/runtime): Support more delimiter variants
2 parents e44493c + 65af9d1 commit 8af0422

File tree

6 files changed

+60
-5
lines changed

6 files changed

+60
-5
lines changed

core/ir/ir.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ InputSpecMap pair_input_vals_with_specs(std::vector<const torch::jit::Value*> va
2121

2222
std::unordered_map<const torch::jit::Value*, core::ir::Input> a;
2323
for (size_t i = 0; i < vals.size(); i++) {
24-
LOG_DEBUG("Paring " << i << ": " << vals[i]->debugName() << " : " << specs[i]);
24+
LOG_DEBUG("Pairing " << i << ": " << vals[i]->debugName() << " : " << specs[i]);
2525
a.insert({vals[i], specs[i]});
2626
}
2727
return a;

core/runtime/TRTEngine.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,25 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
5353
TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");
5454

5555
exec_ctx = make_trt(cuda_engine->createExecutionContext());
56+
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");
5657

5758
uint64_t inputs = 0;
5859
uint64_t outputs = 0;
5960

6061
for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
6162
std::string bind_name = cuda_engine->getBindingName(x);
62-
std::string idx_s = bind_name.substr(bind_name.find("_") + 1);
63+
LOG_DEBUG("Binding name: " << bind_name);
64+
auto delim = bind_name.find(".");
65+
if (delim == std::string::npos) {
66+
delim = bind_name.find("_");
67+
TORCHTRT_CHECK(
68+
delim != std::string::npos,
69+
"Unable to determine binding index for input "
70+
<< bind_name
71+
<< "\nEnsure module was compiled with Torch-TensorRT.ts or follows Torch-TensorRT Runtime conventions");
72+
}
73+
74+
std::string idx_s = bind_name.substr(delim + 1);
6375
uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));
6476

6577
if (cuda_engine->bindingIsInput(x)) {
@@ -71,6 +83,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
7183
}
7284
}
7385
num_io = std::make_pair(inputs, outputs);
86+
87+
LOG_DEBUG(*this);
7488
}
7589

7690
TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
@@ -82,6 +96,34 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
8296
return (*this);
8397
}
8498

99+
std::string TRTEngine::to_str() const {
100+
std::stringstream ss;
101+
ss << "Torch-TensorRT TensorRT Engine:" << std::endl;
102+
ss << " Name: " << name << std::endl;
103+
ss << " Inputs: [" << std::endl;
104+
for (uint64_t i = 0; i < num_io.first; i++) {
105+
ss << " id: " << i << std::endl;
106+
ss << " shape: " << exec_ctx->getBindingDimensions(i) << std::endl;
107+
ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(i)) << std::endl;
108+
}
109+
ss << " ]" << std::endl;
110+
ss << " Outputs: [" << std::endl;
111+
for (uint64_t o = 0; o < num_io.second; o++) {
112+
ss << " id: " << o << std::endl;
113+
ss << " shape: " << exec_ctx->getBindingDimensions(o) << std::endl;
114+
ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(o)) << std::endl;
115+
}
116+
ss << " ]" << std::endl;
117+
ss << " Device: " << device_info << std::endl;
118+
119+
return ss.str();
120+
}
121+
122+
std::ostream& operator<<(std::ostream& os, const TRTEngine& engine) {
123+
os << engine.to_str();
124+
return os;
125+
}
126+
85127
// TODO: Implement a call method
86128
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
87129
// auto input_vec = inputs.vec();
@@ -96,6 +138,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
96138
.def(torch::init<std::vector<std::string>>())
97139
// TODO: .def("__call__", &TRTEngine::Run)
98140
// TODO: .def("run", &TRTEngine::Run)
141+
.def("__str__", &TRTEngine::to_str)
99142
.def_pickle(
100143
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
101144
// Serialize TensorRT engine

core/runtime/runtime.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ struct TRTEngine : torch::CustomClassHolder {
5959
TRTEngine(std::vector<std::string> serialized_info);
6060
TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device);
6161
TRTEngine& operator=(const TRTEngine& other);
62+
std::string to_str() const;
63+
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
6264
// TODO: Implement a call method
6365
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
6466
};

cpp/include/torch_tensorrt/torch_tensorrt.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,12 @@ TORCHTRT_API std::string convert_method_to_trt_engine(
739739
* module. Registers execution of the engine as the forward method of the module
740740
* Forward is defined as: forward(Tensor[]) -> Tensor[]
741741
*
742-
* @return: A new module trageting a TensorRT engine
742+
* TensorRT bindings must have names with the following format:
743+
* - [symbol].[index in input / output array]
744+
* ex.
745+
* - [x.0, x.1, x.2] -> [y.0]
746+
*
747+
* @return: A new module targeting a TensorRT engine
743748
*/
744749
TORCHTRT_API torch::jit::Module embed_engine_in_new_module(const std::string& engine, Device device);
745750
} // namespace torchscript

py/torch_tensorrt/ts/_compiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,11 @@ def embed_engine_in_new_module(serialized_engine: bytes, device=Device._current_
207207
208208
forward(Tensor[]) -> Tensor[]
209209
210+
TensorRT bindings must have names with the following format:
211+
- [symbol].[index in input / output array]
212+
ex.
213+
- [x.0, x.1, x.2] -> [y.0]
214+
210215
Module can be save with engine embedded with torch.jit.save and moved / loaded according to torch_tensorrt portability rules
211216
212217
Arguments:

tests/util/run_graph_engine.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ std::vector<core::ir::Input> toInputs(std::vector<at::Tensor> ten) {
2121
for (auto i : ten) {
2222
a.push_back(core::ir::Input(core::util::toVec(i.sizes())));
2323
}
24-
return std::move(a);
24+
return a;
2525
}
2626

2727
std::vector<core::ir::Input> toInputsDynamic(std::vector<at::Tensor> ten, bool dynamic_batch) {
@@ -49,7 +49,7 @@ std::vector<core::ir::Input> toInputsDynamic(std::vector<at::Tensor> ten, bool d
4949
}
5050
}
5151

52-
return std::move(a);
52+
return a;
5353
}
5454

5555
std::vector<at::Tensor> RunEngine(std::string& eng, std::vector<at::Tensor> inputs) {

0 commit comments

Comments
 (0)