Skip to content

Commit d6c8d31

Browse files
authored
Merge pull request #88 from NVIDIA/execution_improvements
Adds a destructor for the new class that was previously in the old execution manager
2 parents 1168092 + 9eac5c9 commit d6c8d31

File tree

5 files changed

+14
-9
lines changed

5 files changed

+14
-9
lines changed

core/compiler.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,12 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str
4242

4343

4444
void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
45-
auto engine = execution::TRTEngine(mod._ivalue()->name(), serialized_engine);
45+
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(mod._ivalue()->name(), serialized_engine);
4646
// Get required metadata about the engine out
47-
auto num_io = engine.num_io;
48-
auto name = engine.name;
47+
auto num_io = engine_ptr->num_io;
48+
auto name = engine_ptr->name;
4949

5050
// Add the engine as an attribute of the module, this will let the engine be serialized and deserialized
51-
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(engine);
5251
mod.register_attribute(
5352
name,
5453
c10::getCustomClassType<c10::intrusive_ptr<execution::TRTEngine>>(),

core/conversion/converters/NodeConverterRegistry.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ using ConverterLUT = std::unordered_map<c10::OperatorName, OpConverter>;
4646
class NodeConverterRegistry {
4747
public:
4848
bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
49-
LOG_DEBUG("Registering Converter for " << canonical_schema_string(*signature));
49+
LOG_DEBUG("Registering converter for " << canonical_schema_string(*signature));
5050
auto name = signature->operator_name();
5151
converter_lut_[name] = std::move(converter);
5252
return true;

core/execution/TRTEngine.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
6060
return (*this);
6161
}
6262

63+
TRTEngine::~TRTEngine() {
64+
exec_ctx->destroy();
65+
cuda_engine->destroy();
66+
rt->destroy();
67+
}
68+
6369
// TODO: Implement a call method
6470
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
6571
// auto input_vec = inputs.vec();

core/execution/execution.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct TRTEngine : torch::CustomClassHolder {
2222
std::string name;
2323
util::logging::TRTorchLogger logger;
2424

25-
TRTEngine() = default;
25+
~TRTEngine();
2626
TRTEngine(std::string serialized_engine);
2727
TRTEngine(std::string mod_name, std::string serialized_engine);
2828
TRTEngine& operator=(const TRTEngine& other);

core/util/macros.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,21 @@
1111
l.log(sev, ss.str()); \
1212
} while (0)
1313

14-
#define LOG_GRAPH_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kGRAPH, s)
14+
#define LOG_GRAPH_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kGRAPH, s)
1515
#define LOG_DEBUG_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kDEBUG, s)
1616
#define LOG_INFO_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINFO, s)
1717
#define LOG_WARNING_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kWARNING, s)
1818
#define LOG_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kERROR, s)
1919
#define LOG_INTERNAL_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINTERNAL_ERROR, s)
2020

21-
#define LOG_GRAPH_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kGRAPH, s)
21+
#define LOG_GRAPH_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kGRAPH, s)
2222
#define LOG_DEBUG_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kDEBUG, s)
2323
#define LOG_INFO_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINFO, s)
2424
#define LOG_WARNING_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kWARNING, s)
2525
#define LOG_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kERROR, s)
2626
#define LOG_INTERNAL_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINTERNAL_ERROR, s)
2727

28-
#define LOG_GRAPH(...) GET_MACRO(__VA_ARGS__, LOG_GRAPH_OWN, LOG_GRAPH_GLOBAL)(__VA_ARGS__)
28+
#define LOG_GRAPH(...) GET_MACRO(__VA_ARGS__, LOG_GRAPH_OWN, LOG_GRAPH_GLOBAL)(__VA_ARGS__)
2929
#define LOG_DEBUG(...) GET_MACRO(__VA_ARGS__, LOG_DEBUG_OWN, LOG_DEBUG_GLOBAL)(__VA_ARGS__)
3030
#define LOG_INFO(...) GET_MACRO(__VA_ARGS__, LOG_INFO_OWN, LOG_INFO_GLOBAL)(__VA_ARGS__)
3131
#define LOG_WARNING(...) GET_MACRO(__VA_ARGS__, LOG_WARNING_OWN, LOG_WARNING_GLOBAL)(__VA_ARGS__)

0 commit comments

Comments
 (0)