Skip to content

Commit eab0dba

Browse files
committed
chore: Fix save failures
1 parent 3ece71b commit eab0dba

File tree

4 files changed

+100
-48
lines changed

4 files changed

+100
-48
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ std::string TRTEngine::to_str() const {
241241
exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
242242
<< std::endl;
243243
}
244-
ss << " }" << std::endl;
244+
ss << " ]" << std::endl;
245245
ss << " Device: " << device_info << std::endl;
246246
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
247247
// clang-format on

py/torch_tensorrt/_compile.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,5 +397,10 @@ def save(
397397
exp_program = export(module, inputs)
398398
torch.export.save(exp_program, file_path)
399399
else:
400-
exp_program = torch.export.export(module, tuple(inputs), strict=False)
401-
torch.export.save(exp_program, file_path)
400+
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
401+
402+
with enable_torchbind_tracing():
403+
exp_program = torch.export.export(
404+
module, tuple(inputs), strict=False
405+
)
406+
torch.export.save(exp_program, file_path)

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 91 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import operator
23
from typing import Any, Dict, Sequence, Tuple, cast
34

@@ -6,8 +7,11 @@
67
from torch._subclasses.fake_tensor import FakeTensor
78
from torch.export import ExportedProgram, ExportGraphSignature
89
from torch.export.exported_program import (
10+
CustomObjArgument,
911
InputKind,
1012
InputSpec,
13+
ModuleCallEntry,
14+
ModuleCallSignature,
1115
OutputKind,
1216
OutputSpec,
1317
TensorArgument,
@@ -44,24 +48,27 @@ def transform(
4448
4549
Returns an inlined torch.fx.GraphModule
4650
"""
51+
gm_export = copy.deepcopy(gm)
4752
# Run shape analysis
48-
_, outputs_map = partitioning.run_shape_analysis(gm, inputs)
53+
_, outputs_map = partitioning.run_shape_analysis(gm_export, inputs)
4954

5055
# Inline TensorRT submodules
51-
inline_trt_modules(gm, outputs_map)
56+
inline_trt_modules(gm_export, outputs_map)
5257

5358
# Inline pytorch submodules
54-
inline_torch_modules(gm)
59+
inline_torch_modules(gm_export)
5560

5661
# Clean the graph
57-
gm.delete_all_unused_submodules()
58-
gm.graph.eliminate_dead_code()
59-
gm.graph.lint()
62+
gm_export.delete_all_unused_submodules()
63+
gm_export.graph.eliminate_dead_code()
64+
gm_export.graph.lint()
6065

61-
return gm
66+
return gm_export
6267

6368

64-
def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule:
69+
def lift(
70+
gm: torch.fx.GraphModule, graph_signature: Any
71+
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature, Dict[str, Any], Dict[str, Any]]:
6572
"""
6673
Given an unlifted fx.GraphModule, lift all parameters, buffers and constants (including custom objects) into placeholders.
6774
Arguments:
@@ -75,6 +82,7 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
7582
# exp_program.state_dict contains parameters and buffers whereas a graph_module's state_dict
7683
# has all parameters registered as torch.tensors.
7784
state_dict = gm.state_dict()
85+
constants = {}
7886

7987
fake_mode = detect_fake_mode(
8088
tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
@@ -89,52 +97,68 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
8997
break
9098

9199
# Initially, graph_signature.input_specs contains only the user_inputs, hence non_user_input_idx=0
92-
# The input_specs should be of the form [params, buffers, constant_tensors, user_inputs]
100+
# The input_specs should be of the form [params, buffers, constant_tensors, custom_obj, user_inputs]
93101
non_user_input_idx = 0
94102
for node in gm.graph.nodes:
95103
if node.op == "get_attr":
96-
if node.target not in state_dict:
97-
raise ValueError(
98-
f"The get_attr node : {node.name} with target: {node.target} value could not be found in state_dict. Please check the input exported_program's graphmodule parameters."
99-
)
100104

101-
constant_tensor = state_dict[node.target]
102-
input_kind = InputKind.CONSTANT_TENSOR
105+
lift_val = None
106+
input_kind = None
103107

104-
# state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively.
105-
for name, _ in gm.named_parameters():
106-
if node.target == name:
107-
input_kind = InputKind.PARAMETER
108-
state_dict[name] = torch.nn.Parameter(state_dict[name])
109-
break
110-
for name, _ in gm.named_buffers():
111-
if node.target == name:
112-
input_kind = InputKind.BUFFER
113-
break
108+
if node.target not in state_dict:
109+
constants[node.target] = getattr(gm, node.target)
110+
input_kind = InputKind.CUSTOM_OBJ
111+
lift_val = constants[node.target]
112+
else:
113+
lift_val = state_dict[node.target]
114+
115+
input_kind = InputKind.CONSTANT_TENSOR
116+
117+
# state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively.
118+
for name, _ in gm.named_parameters():
119+
if node.target == name:
120+
input_kind = InputKind.PARAMETER
121+
state_dict[name] = torch.nn.Parameter(state_dict[name])
122+
break
123+
for name, _ in gm.named_buffers():
124+
if node.target == name:
125+
input_kind = InputKind.BUFFER
126+
break
127+
128+
assert lift_val is not None and input_kind is not None
114129

115130
# Replace get_attr nodes with placeholder nodes and copy metadata.
116131
with gm.graph.inserting_before(first_user_input):
117-
const_placeholder_node = gm.graph.placeholder(node.target)
132+
const_placeholder_node = gm.graph.placeholder(
133+
node.target.replace(".", "_")
134+
)
118135
# Copy the node meta into this new placeholder node
119136
const_placeholder_node.meta = node.meta
120-
const_placeholder_node.meta["val"] = cast(
121-
FakeTensor,
122-
torch.empty_strided(
123-
tuple(constant_tensor.shape),
124-
tuple([1] * len(constant_tensor.shape)),
125-
),
126-
)
137+
138+
if isinstance(lift_val, torch.Tensor):
139+
const_placeholder_node.meta["val"] = cast(
140+
FakeTensor,
141+
torch.empty_strided(
142+
tuple(lift_val.shape),
143+
tuple([1] * len(lift_val.shape)),
144+
),
145+
)
127146

128147
node.replace_all_uses_with(const_placeholder_node)
129148
gm.graph.erase_node(node)
130149

131150
# Add these parameters/buffers/constants to the existing graph signature
132151
# before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
152+
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
153+
if input_kind == InputKind.CUSTOM_OBJ:
154+
input_spec_arg = CustomObjArgument(
155+
name=const_placeholder_node.name, class_fqn=""
156+
)
133157
graph_signature.input_specs.insert(
134158
non_user_input_idx,
135159
InputSpec(
136160
kind=input_kind,
137-
arg=TensorArgument(name=const_placeholder_node.name),
161+
arg=input_spec_arg,
138162
target=node.target,
139163
),
140164
)
@@ -143,7 +167,7 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
143167
gm.graph.eliminate_dead_code()
144168
gm.graph.lint()
145169

146-
return gm, graph_signature, state_dict
170+
return gm, graph_signature, state_dict, constants
147171

148172

149173
def get_duplicate_nodes(
@@ -281,18 +305,30 @@ def create_trt_exp_program(
281305
input_specs=input_specs, output_specs=output_specs
282306
)
283307

308+
module_call_graph = [
309+
ModuleCallEntry(
310+
"",
311+
ModuleCallSignature(
312+
inputs=[],
313+
outputs=[],
314+
in_spec=gm.graph._codegen.pytree_info.in_spec,
315+
out_spec=gm.graph._codegen.pytree_info.out_spec,
316+
),
317+
)
318+
]
319+
284320
# Lift parameters/buffers/constants in the graph
285321
# torch.export serialization expects them to be lifted
286-
gm, trt_graph_signature, state_dict = lift(gm, trt_graph_signature)
322+
gm, trt_graph_signature, state_dict, constants = lift(gm, trt_graph_signature)
287323

288324
trt_exp_program = ExportedProgram(
289-
gm,
290-
gm.graph,
291-
trt_graph_signature,
292-
state_dict,
293-
{},
294-
[],
295-
[],
325+
root=gm,
326+
graph=gm.graph,
327+
graph_signature=trt_graph_signature,
328+
state_dict=state_dict,
329+
range_constraints={},
330+
module_call_graph=module_call_graph,
331+
constants=constants,
296332
)
297333

298334
return trt_exp_program
@@ -319,9 +355,13 @@ def inline_trt_modules(
319355
num_outputs = len(outputs_map[trt_module_node.name])
320356
# Insert a call_function node to perform inference on TRT engine
321357
with gm.graph.inserting_before(trt_module_node):
358+
engine_name = f"{name}_engine"
359+
setattr(gm, engine_name, trt_module.engine)
360+
engine_node = gm.graph.get_attr(engine_name)
361+
322362
trt_node = gm.graph.call_function(
323363
torch.ops.tensorrt.execute_engine.default,
324-
(trt_module_node.args, trt_module.engine),
364+
(trt_module_node.args, engine_node),
325365
)
326366
trt_node.meta["val"] = []
327367
assert num_outputs > 0
@@ -337,6 +377,13 @@ def inline_trt_modules(
337377
)
338378
)
339379

380+
# meta["val"] should be a lightweight stand-in for a tensor, e.g. a FakeTensor (carrying only the output shape and dtype properties)
381+
# A lightweight equivalent for a custom_obj is not clearly defined. meta["val"] has no strict type expectations, but
382+
# for custom object nodes, it should be CustomObjArgument
383+
engine_node.meta["val"] = CustomObjArgument(
384+
name=engine_node.name, class_fqn=""
385+
)
386+
340387
if num_outputs == 1:
341388
# Insert getitem nodes as outputs (for export serialization to work)
342389
with gm.graph.inserting_after(trt_node):

tests/py/dynamo/models/test_export_serde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ def forward(self, x):
146146
)
147147
],
148148
"ir": ir,
149-
"debug": True,
150149
}
151150

152151
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
@@ -306,6 +305,7 @@ def forward(self, x):
306305

307306
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
308307
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
308+
309309
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
310310
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
311311
deser_trt_module = deser_trt_exp_program.module()

0 commit comments

Comments
 (0)