Commit cc0d8af

feat: Support exporting Torch-TRT compiled Graphmodules (#3262)
Co-authored-by: lanluo-nvidia <[email protected]>
1 parent c24ef24

File tree: 12 files changed, +984 −43 lines

.github/workflows/build-test-linux.yml

Lines changed: 1 addition & 0 deletions
@@ -196,6 +196,7 @@ jobs:
           pushd .
           cd tests/py/dynamo
           python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
+          python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
           popd

   tests-py-torch-compile-be:

.github/workflows/build-test-windows.yml

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@ jobs:
           pushd .
           cd tests/py/dynamo
           python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
+          python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/reexport_test_results.xml --ir dynamo models/test_reexport.py
           popd

   tests-py-torch-compile-be:

core/runtime/TRTEngine.cpp

Lines changed: 62 additions & 0 deletions
@@ -8,6 +8,7 @@

 #include "core/runtime/runtime.h"
 #include "core/util/prelude.h"
+#include "torch/torch.h"

 namespace torch_tensorrt {
 namespace core {
@@ -253,6 +254,28 @@ std::string TRTEngine::get_engine_layer_info() {
   return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
 }

+std::vector<at::Tensor> TRTEngine::infer_outputs(std::vector<std::vector<int64_t>> input_shapes) {
+  std::vector<at::Tensor> outputs;
+  TORCHTRT_CHECK(
+      (in_binding_names.size() == input_shapes.size()),
+      "The number of input shapes provided doesn't match with the number of input names registered.");
+  // Set all input shapes
+  for (size_t i = 0; i < input_shapes.size(); i++) {
+    exec_ctx->setInputShape(in_binding_names[i].c_str(), core::util::toDims(input_shapes[i]));
+  }
+  for (size_t i = 0; i < out_binding_names.size(); i++) {
+    auto output_shape = core::util::toVec(exec_ctx->getTensorShape(out_binding_names[i].c_str()));
+    auto output_dtype =
+        core::util::TRTDataTypeToScalarType(cuda_engine->getTensorDataType(out_binding_names[i].c_str()));
+    auto output_tensor = torch::empty(output_shape, torch::dtype(output_dtype));
+    outputs.push_back(output_tensor);
+  }
+  TORCHTRT_CHECK(
+      (out_binding_names.size() == outputs.size()),
+      "The number of output shapes inferred doesn't match with the number of output names registered.");
+  return outputs;
+}
+
 void TRTEngine::set_profiling_paths() {
   device_profile_path =
       std::filesystem::path{profile_path_prefix + "/" + name + "_device_config_profile.trace"}.string();
@@ -354,6 +377,45 @@ void TRTEngine::verify_serialization_fmt(const std::vector<std::string>& seriali
       << ")");
 }

+FlattenedState TRTEngine::__obj_flatten__() {
+  // This method would be called by meta kernel of this custom class and it only needs to return a tuple
+  std::vector<std::string> serialized_info = this->serialize();
+
+  return std::tuple(
+      std::tuple("version", serialized_info[ABI_TARGET_IDX]),
+      std::tuple("name", serialized_info[NAME_IDX]),
+      std::tuple("device_info", serialized_info[DEVICE_IDX]),
+      std::tuple("serialized_engine", serialized_info[ENGINE_IDX]),
+      std::tuple("in_binding_names", serialized_info[INPUT_BINDING_NAMES_IDX]),
+      std::tuple("out_binding_names", serialized_info[OUTPUT_BINDING_NAMES_IDX]),
+      std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
+      std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
+      std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
+}
+
+std::vector<std::string> TRTEngine::serialize() {
+  // Serialize TensorRT engine
+  auto serialized_trt_engine = make_trt(this->cuda_engine->serialize());
+
+  // Adding device info related meta data to the serialized file
+  auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
+
+  std::vector<std::string> serialized_info;
+  serialized_info.resize(SERIALIZATION_LEN);
+
+  serialized_info[ABI_TARGET_IDX] = ABI_VERSION;
+  serialized_info[NAME_IDX] = this->name;
+  serialized_info[DEVICE_IDX] = this->device_info.serialize();
+  serialized_info[ENGINE_IDX] = base64_encode(trt_engine);
+  serialized_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(this->in_binding_names);
+  serialized_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(this->out_binding_names);
+  serialized_info[HW_COMPATIBLE_IDX] = this->hardware_compatible ? "1" : "0";
+  serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
+  serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
+
+  return serialized_info;
+}
+
 } // namespace runtime
 } // namespace core
 } // namespace torch_tensorrt
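
The new infer_outputs method runs no inference: it feeds candidate input shapes into the execution context, reads back the resulting output shapes and dtypes, and returns empty tensors carrying only that metadata. A minimal sketch of how it could be poked from Python through the torchbind binding (the trt_module variable and its engine attribute are illustrative assumptions, not part of this diff):

import torch
import torch_tensorrt  # noqa: F401  -- loads the tensorrt torchbind classes

# Assumption: `trt_module` is a TorchTensorRTModule produced by the dynamo path,
# and `engine` is its __torch__.torch.classes.tensorrt.Engine instance.
proxy_outputs = trt_module.engine.infer_outputs([[8, 3, 224, 224]])  # one shape per input binding
for t in proxy_outputs:
    print(t.shape, t.dtype)  # empty CPU tensors: shape/dtype metadata only, no data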

core/runtime/TRTEngine.h

Lines changed: 17 additions & 0 deletions
@@ -19,6 +19,17 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {

+using FlattenedState = std::tuple<
+    std::tuple<std::string, std::string>, // ABI_VERSION
+    std::tuple<std::string, std::string>, // name
+    std::tuple<std::string, std::string>, // device
+    std::tuple<std::string, std::string>, // engine
+    std::tuple<std::string, std::string>, // input binding names
+    std::tuple<std::string, std::string>, // output binding names
+    std::tuple<std::string, std::string>, // HW compatibility
+    std::tuple<std::string, std::string>, // serialized metadata
+    std::tuple<std::string, std::string>>; // Platform
+
 struct TRTEngine : torch::CustomClassHolder {
   // Each engine needs it's own runtime object
   std::shared_ptr<nvinfer1::IRuntime> rt;
@@ -69,15 +80,21 @@ struct TRTEngine : torch::CustomClassHolder {
   void enable_profiling();
   void disable_profiling();
   std::string get_engine_layer_info();
+
   void dump_engine_layer_info_to_file(const std::string& path);
   void dump_engine_layer_info();
   int64_t get_device_memory_budget();
   bool set_device_memory_budget(int64_t budget);
   int64_t get_streamable_device_memory_budget();
   int64_t get_automatic_device_memory_budget();
+  std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
   friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
   static const char BINDING_DELIM = '%';

+  // Serde re-export functionality
+  FlattenedState __obj_flatten__();
+  std::vector<std::string> serialize();
+
   // CUDAGraph-Related Functionality
   at::cuda::CUDAGraph cudagraph = {};
   at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
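
Each FlattenedState entry is a (key, value) pair of strings, so __obj_flatten__ turns the engine into a plain pytree that torch.export's fake-class machinery can record. Roughly what the Python side sees (a sketch with abbreviated values, not captured output):

flattened = engine.__obj_flatten__()
# (("version", "..."), ("name", "..."), ("device_info", "..."),
#  ("serialized_engine", "<base64 string>"), ("in_binding_names", "..."),
#  ("out_binding_names", "..."), ("hardware_compatible", "0"),
#  ("serialized_metadata", "..."), ("target_platform", "..."))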

core/runtime/register_jit_hooks.cpp

Lines changed: 9 additions & 26 deletions
@@ -7,7 +7,6 @@
 namespace torch_tensorrt {
 namespace core {
 namespace runtime {
-namespace {

 std::string serialize_bindings(const std::vector<std::string>& bindings) {
   std::stringstream ss;
@@ -66,6 +65,7 @@ std::string base64_decode(const std::string& in) {
   return out;
 }

+namespace {
 // TODO: Implement a call method
 // c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
 //   auto input_vec = inputs.vec();
@@ -80,51 +80,30 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         // TODO: .def("run", &TRTEngine::Run)
         .def("__str__", &TRTEngine::to_str)
         .def("__repr__", &TRTEngine::to_str)
+        .def("__obj_flatten__", &TRTEngine::__obj_flatten__)
         .def("enable_profiling", &TRTEngine::enable_profiling)
         .def("disable_profiling", &TRTEngine::disable_profiling)
         .def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
         .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
         .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
         .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
+        .def("infer_outputs", &TRTEngine::infer_outputs)
         .def_property(
             "device_memory_budget",
             &TRTEngine::get_device_memory_budget,
             &TRTEngine::set_device_memory_budget)
         .def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget)
         .def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget)
         .def_pickle(
-            [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
-              // Serialize TensorRT engine
-              auto serialized_trt_engine = make_trt(self->cuda_engine->serialize());
-
-              // Adding device info related meta data to the serialized file
-              auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
-
-              std::vector<std::string> serialize_info;
-              serialize_info.resize(SERIALIZATION_LEN);
-
-              serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
-              serialize_info[NAME_IDX] = self->name;
-              serialize_info[DEVICE_IDX] = self->device_info.serialize();
-              serialize_info[ENGINE_IDX] = base64_encode(trt_engine);
-              serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
-              serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
-              serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
-              serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
-              serialize_info[TARGET_PLATFORM_IDX] = self->target_platform.serialize();
-              LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
-              LOG_DEBUG("Serialized Target Platform: " << self->target_platform);
-
-              return serialize_info;
-            },
+            [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
             [](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
               serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
               TRTEngine::verify_serialization_fmt(serialized_info);
               return c10::make_intrusive<TRTEngine>(serialized_info);
             });

 TORCH_LIBRARY(tensorrt, m) {
-  m.def("execute_engine", execute_engine);
+  m.def("execute_engine(Tensor[] input_tensors, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]");
   m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
   m.def("SERIALIZED_RT_DEVICE_DELIM", []() -> std::string { return DEVICE_INFO_DELIM; });
   m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
@@ -171,6 +150,10 @@ TORCH_LIBRARY(tensorrt, m) {
   });
 }

+TORCH_LIBRARY_IMPL(tensorrt, CompositeExplicitAutograd, m) {
+  m.impl("execute_engine", execute_engine);
+}
+
 } // namespace
 } // namespace runtime
 } // namespace core
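
Splitting execute_engine into a schema-only m.def plus a TORCH_LIBRARY_IMPL registration is what lets a Python fake (meta) kernel be attached to the same op during tracing. A standalone sketch of that def/impl/fake pattern, using a hypothetical demo::double_it op (not part of this commit):

import torch

# Declare the schema once, register the real kernel for CompositeExplicitAutograd,
# and register a fake kernel that only propagates shape/dtype under tracing.
lib = torch.library.Library("demo", "DEF")
lib.define("double_it(Tensor x) -> Tensor")

@torch.library.impl("demo::double_it", "CompositeExplicitAutograd")
def double_it_impl(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@torch.library.register_fake("demo::double_it")
def double_it_fake(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)  # metadata only; never runs on real tensors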

core/runtime/runtime.h

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,10 @@ typedef enum {
   SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
 } SerializedInfoIndex;

+std::string base64_encode(const std::string& in);
+std::string base64_decode(const std::string& in);
+std::string serialize_bindings(const std::vector<std::string>& bindings);
+
 c10::optional<RTDevice> get_most_compatible_device(
     const RTDevice& target_device,
     const RTDevice& curr_device = RTDevice(),

py/torch_tensorrt/_compile.py

Lines changed: 7 additions & 6 deletions
@@ -666,14 +666,15 @@ def save(
         exp_program = export(module)
         torch.export.save(exp_program, file_path)
     else:
-        from torch._higher_order_ops.torchbind import enable_torchbind_tracing

         if arg_inputs is None:
             raise ValueError(
                 "Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
             )
-        with enable_torchbind_tracing():
-            exp_program = torch.export.export(
-                module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
-            )
-            torch.export.save(exp_program, file_path)
+        exp_program = torch.export.export(
+            module,
+            tuple(arg_inputs),
+            kwargs=kwarg_inputs,
+            strict=False,
+        )
+        torch.export.save(exp_program, file_path)
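
With the fake class registered, re-exporting a compiled GraphModule no longer needs the enable_torchbind_tracing context manager. A hedged usage sketch of the resulting save/reload flow (model, shapes, and file name are illustrative):

import torch
import torch_tensorrt

model = torch.nn.Linear(8, 4).eval().cuda()
inputs = [torch.randn(2, 8).cuda()]

trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
# retrace=True takes the torch.export.export path shown above.
torch_tensorrt.save(trt_gm, "trt_model.ep", arg_inputs=inputs, retrace=True)
reloaded = torch.export.load("trt_model.ep").module()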

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 0 additions & 1 deletion
@@ -412,7 +412,6 @@ def get_decompositions(
         return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
     else:
         # changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/135080
-        # changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/140085
         decomp_table = default_decompositions()
         DECOMP_TABLE_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = {
             decomp: decomp_table[decomp]

py/torch_tensorrt/dynamo/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: F401
     TorchTensorRTModule,
 )
+from torch_tensorrt.dynamo.runtime.register_fake_class import *
py/torch_tensorrt/dynamo/runtime/register_fake_class.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+import base64
+from collections import defaultdict
+from typing import Any, List
+
+import torch
+from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape
+
+
+@torch.library.register_fake("tensorrt::execute_engine")  # type: ignore
+def fake_tensorrt_execute_engine(
+    inputs: List[torch.Tensor], fake_trt_engine: Any
+) -> Any:
+    """
+    We infer outputs using the TRT engine and inputs and return fake tensors in this meta kernel.
+    """
+    # Here's what we are doing
+    # 1) Check if inputs are dynamic (they have sym ints in their shapes)
+    # 2) For dynamic inputs, we gather min_input_shape and max_input_shape for all inputs
+    # 3) For the above min and max input shapes, capture the corresponding min and max output shapes using TensorRT's set/get shapes mechanism
+    # 4) Create a new symbolic fake tensor using min and max output shape for each output and return them
+    # 5) For static inputs, the output shape will be static and we won't need to create sym ints
+    is_dynamic_execution = input_is_dynamic(inputs)
+    if is_dynamic_execution:
+        modes = ["min", "max", "opt"]
+    else:
+        modes = ["opt"]
+
+    # Get the TRTEngine class and infer output shapes based on input shapes
+    trt_engine = fake_trt_engine.wrapped_obj.engine
+    outputs_mode_dict = defaultdict(list)
+    for mode in modes:
+        input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
+        proxy_outputs = trt_engine.infer_outputs(input_shapes)
+        outputs_mode_dict[mode].extend(proxy_outputs)
+
+    # Store the number of outputs
+    if {"min", "max"}.issubset(outputs_mode_dict):
+        assert len(outputs_mode_dict["min"]) == len(outputs_mode_dict["max"])
+        num_outputs = len(outputs_mode_dict["min"])
+    elif "opt" in outputs_mode_dict:
+        num_outputs = len(outputs_mode_dict["opt"])
+
+    fake_outputs = []
+    for out_idx in range(num_outputs):
+        output_shape = []
+        if is_dynamic_execution:
+            # Create output symbolic shape using unbacked symint.
+            # Note: We can't establish a relationship b/w incoming input symbolic shape (eg: s0)
+            # and TensorRT's output shape (represented as unbacked u0). This situation doesn't seem
+            # to affect compilation results / serialization during our testing.
+            output_min_shape = outputs_mode_dict["min"][out_idx].size()
+            output_opt_shape = outputs_mode_dict["opt"][out_idx].size()
+            output_max_shape = outputs_mode_dict["max"][out_idx].size()
+
+            ctx = torch._custom_ops.get_ctx()
+            for min_val, opt_val, max_val in zip(
+                output_min_shape, output_opt_shape, output_max_shape
+            ):
+                if min_val != max_val:
+                    output_sym_int = ctx.new_dynamic_size(min=min_val, max=max_val)
+                    # Update var to val (hint)
+                    output_sym_int_shape_env = output_sym_int.node.shape_env
+                    output_sym_int_shape_env.add_var_to_val(
+                        output_sym_int.node.expr, opt_val
+                    )
+                    output_shape.append(output_sym_int)
+                else:
+                    output_shape.append(min_val)
+        else:
+            output_shape.extend(outputs_mode_dict["opt"][out_idx].size())
+
+        fake_outputs.append(
+            torch.empty(output_shape, dtype=outputs_mode_dict["opt"][out_idx].dtype)
+        )
+
+    return fake_outputs
+
+
+@torch._library.register_fake_class("tensorrt::Engine")
+class FakeTRTEngine:
+    def __init__(self, engine_info: List[str]) -> None:
+        self.engine = torch.classes.tensorrt.Engine(engine_info)
+
+    @classmethod
+    def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
+        engine_idx = torch.ops.tensorrt.ENGINE_IDX()
+        engine_info = [info[1] for info in flattened_tq]
+        engine_info[engine_idx] = base64.b64decode(engine_info[engine_idx])
+
+        return cls(engine_info)
+
+    def enable_profiling(self) -> Any:
+        pass
+
+    def disable_profiling(self) -> Any:
+        pass
+
+    def dump_engine_layer_info_to_file(self, path: str) -> Any:
+        pass
+
+    def dump_engine_layer_info(self) -> Any:
+        pass
+
+    def get_engine_layer_info(self) -> Any:
+        pass
+
+    def profile_path_prefix_getter(self) -> Any:
+        pass
+
+    def profile_path_prefix_setter(self) -> Any:
+        pass
+
+    def device_memory_budget_getter(self) -> Any:
+        pass
+
+    def device_memory_budget_setter(self) -> Any:
+        pass
+
+    def streamable_device_memory_budget_getter(self) -> Any:
+        pass
+
+    def automatic_device_memory_budget_getter(self) -> Any:
+        pass
+
+    def infer_outputs(self, input_shapes: List[Any]) -> Any:
+        pass
+
+    def __setstate__(self, serialized_state: List[str]) -> Any:
+        pass
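
For intuition on the unbacked-symint branch in the fake kernel above, a standalone sketch (hypothetical demo::variable_rows op, not from this commit) of allocating a data-dependent output size via new_dynamic_size:

import torch

@torch.library.custom_op("demo::variable_rows", mutates_args=())
def variable_rows(x: torch.Tensor) -> torch.Tensor:
    return x[x[:, 0] > 0]  # row count depends on the data

@variable_rows.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Under FakeTensor tracing there is no data, so the row count becomes an
    # unbacked symint bounded by [0, x.size(0)], much like the min/max bounds
    # gathered from TensorRT in fake_tensorrt_execute_engine above.
    ctx = torch.library.get_ctx()
    nrows = ctx.new_dynamic_size(min=0, max=x.size(0))
    return x.new_empty(nrows, x.size(1))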
