1
+ import copy
1
2
import operator
2
3
from typing import Any , Dict , Sequence , Tuple , cast
3
4
6
7
from torch ._subclasses .fake_tensor import FakeTensor
7
8
from torch .export import ExportedProgram , ExportGraphSignature
8
9
from torch .export .exported_program import (
10
+ CustomObjArgument ,
9
11
InputKind ,
10
12
InputSpec ,
13
+ ModuleCallEntry ,
14
+ ModuleCallSignature ,
11
15
OutputKind ,
12
16
OutputSpec ,
13
17
TensorArgument ,
@@ -44,24 +48,27 @@ def transform(
44
48
45
49
Returns an inlined torch.fx.GraphModule
46
50
"""
51
+ gm_export = copy .deepcopy (gm )
47
52
# Run shape analysis
48
- _ , outputs_map = partitioning .run_shape_analysis (gm , inputs )
53
+ _ , outputs_map = partitioning .run_shape_analysis (gm_export , inputs )
49
54
50
55
# Inline TensorRT submodules
51
- inline_trt_modules (gm , outputs_map )
56
+ inline_trt_modules (gm_export , outputs_map )
52
57
53
58
# Inline pytorch submodules
54
- inline_torch_modules (gm )
59
+ inline_torch_modules (gm_export )
55
60
56
61
# Clean the graph
57
- gm .delete_all_unused_submodules ()
58
- gm .graph .eliminate_dead_code ()
59
- gm .graph .lint ()
62
+ gm_export .delete_all_unused_submodules ()
63
+ gm_export .graph .eliminate_dead_code ()
64
+ gm_export .graph .lint ()
60
65
61
- return gm
66
+ return gm_export
62
67
63
68
64
- def lift (gm : torch .fx .GraphModule , graph_signature : Any ) -> torch .fx .GraphModule :
69
+ def lift (
70
+ gm : torch .fx .GraphModule , graph_signature : Any
71
+ ) -> Tuple [torch .fx .GraphModule , ExportGraphSignature , Dict [str , Any ], Dict [str , Any ]]:
65
72
"""
66
73
Given an unlifted fx.GraphModule, lift all parameters, buffers into placeholders.
67
74
Arguments:
@@ -75,6 +82,7 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
75
82
# exp_program.state_dict contains parameters and buffers whereas a graph_module's state_dict
76
83
# has all parameters registered as torch.tensors.
77
84
state_dict = gm .state_dict ()
85
+ constants = {}
78
86
79
87
fake_mode = detect_fake_mode (
80
88
tuple (node .meta ["val" ] for node in gm .graph .nodes if node .op == "placeholder" )
@@ -89,52 +97,68 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
89
97
break
90
98
91
99
# At first the user_inputs are only present in the graph_signature.input_specs and hence non_user_input_idx=0
92
- # The input_specs should be of the form [params, buffers, constant_tensors, user_inputs]
100
+ # The input_specs should be of the form [params, buffers, constant_tensors, custom_obj, user_inputs]
93
101
non_user_input_idx = 0
94
102
for node in gm .graph .nodes :
95
103
if node .op == "get_attr" :
96
- if node .target not in state_dict :
97
- raise ValueError (
98
- f"The get_attr node : { node .name } with target: { node .target } value could not be found in state_dict. Please check the input exported_program's graphmodule parameters."
99
- )
100
104
101
- constant_tensor = state_dict [ node . target ]
102
- input_kind = InputKind . CONSTANT_TENSOR
105
+ lift_val = None
106
+ input_kind = None
103
107
104
- # state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively.
105
- for name , _ in gm .named_parameters ():
106
- if node .target == name :
107
- input_kind = InputKind .PARAMETER
108
- state_dict [name ] = torch .nn .Parameter (state_dict [name ])
109
- break
110
- for name , _ in gm .named_buffers ():
111
- if node .target == name :
112
- input_kind = InputKind .BUFFER
113
- break
108
+ if node .target not in state_dict :
109
+ constants [node .target ] = getattr (gm , node .target )
110
+ input_kind = InputKind .CUSTOM_OBJ
111
+ lift_val = constants [node .target ]
112
+ else :
113
+ lift_val = state_dict [node .target ]
114
+
115
+ input_kind = InputKind .CONSTANT_TENSOR
116
+
117
+ # state_dict holds these parameters/buffers as plain torch.Tensors. Parameters are re-wrapped as torch.nn.Parameter below; buffers remain torch.Tensors.
118
+ for name , _ in gm .named_parameters ():
119
+ if node .target == name :
120
+ input_kind = InputKind .PARAMETER
121
+ state_dict [name ] = torch .nn .Parameter (state_dict [name ])
122
+ break
123
+ for name , _ in gm .named_buffers ():
124
+ if node .target == name :
125
+ input_kind = InputKind .BUFFER
126
+ break
127
+
128
+ assert lift_val is not None and input_kind is not None
114
129
115
130
# Replace get_attr nodes with placeholder nodes and copy metadata.
116
131
with gm .graph .inserting_before (first_user_input ):
117
- const_placeholder_node = gm .graph .placeholder (node .target )
132
+ const_placeholder_node = gm .graph .placeholder (
133
+ node .target .replace ("." , "_" )
134
+ )
118
135
# Copy the node meta into this new placeholder node
119
136
const_placeholder_node .meta = node .meta
120
- const_placeholder_node .meta ["val" ] = cast (
121
- FakeTensor ,
122
- torch .empty_strided (
123
- tuple (constant_tensor .shape ),
124
- tuple ([1 ] * len (constant_tensor .shape )),
125
- ),
126
- )
137
+
138
+ if isinstance (lift_val , torch .Tensor ):
139
+ const_placeholder_node .meta ["val" ] = cast (
140
+ FakeTensor ,
141
+ torch .empty_strided (
142
+ tuple (lift_val .shape ),
143
+ tuple ([1 ] * len (lift_val .shape )),
144
+ ),
145
+ )
127
146
128
147
node .replace_all_uses_with (const_placeholder_node )
129
148
gm .graph .erase_node (node )
130
149
131
150
# Add these parameters/buffers/constants to the existing graph signature
132
151
# before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
152
+ input_spec_arg = TensorArgument (name = const_placeholder_node .name )
153
+ if input_kind == InputKind .CUSTOM_OBJ :
154
+ input_spec_arg = CustomObjArgument (
155
+ name = const_placeholder_node .name , class_fqn = ""
156
+ )
133
157
graph_signature .input_specs .insert (
134
158
non_user_input_idx ,
135
159
InputSpec (
136
160
kind = input_kind ,
137
- arg = TensorArgument ( name = const_placeholder_node . name ) ,
161
+ arg = input_spec_arg ,
138
162
target = node .target ,
139
163
),
140
164
)
@@ -143,7 +167,7 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
143
167
gm .graph .eliminate_dead_code ()
144
168
gm .graph .lint ()
145
169
146
- return gm , graph_signature , state_dict
170
+ return gm , graph_signature , state_dict , constants
147
171
148
172
149
173
def get_duplicate_nodes (
@@ -281,18 +305,30 @@ def create_trt_exp_program(
281
305
input_specs = input_specs , output_specs = output_specs
282
306
)
283
307
308
+ module_call_graph = [
309
+ ModuleCallEntry (
310
+ "" ,
311
+ ModuleCallSignature (
312
+ inputs = [],
313
+ outputs = [],
314
+ in_spec = gm .graph ._codegen .pytree_info .in_spec ,
315
+ out_spec = gm .graph ._codegen .pytree_info .out_spec ,
316
+ ),
317
+ )
318
+ ]
319
+
284
320
# Lift parameters/buffers/constants in the graph
285
321
# torch.export serialization expects them to be lifted
286
- gm , trt_graph_signature , state_dict = lift (gm , trt_graph_signature )
322
+ gm , trt_graph_signature , state_dict , constants = lift (gm , trt_graph_signature )
287
323
288
324
trt_exp_program = ExportedProgram (
289
- gm ,
290
- gm .graph ,
291
- trt_graph_signature ,
292
- state_dict ,
293
- {},
294
- [] ,
295
- [] ,
325
+ root = gm ,
326
+ graph = gm .graph ,
327
+ graph_signature = trt_graph_signature ,
328
+ state_dict = state_dict ,
329
+ range_constraints = {},
330
+ module_call_graph = module_call_graph ,
331
+ constants = constants ,
296
332
)
297
333
298
334
return trt_exp_program
@@ -319,9 +355,13 @@ def inline_trt_modules(
319
355
num_outputs = len (outputs_map [trt_module_node .name ])
320
356
# Insert a call_function node to perform inference on TRT engine
321
357
with gm .graph .inserting_before (trt_module_node ):
358
+ engine_name = f"{ name } _engine"
359
+ setattr (gm , engine_name , trt_module .engine )
360
+ engine_node = gm .graph .get_attr (engine_name )
361
+
322
362
trt_node = gm .graph .call_function (
323
363
torch .ops .tensorrt .execute_engine .default ,
324
- (trt_module_node .args , trt_module . engine ),
364
+ (trt_module_node .args , engine_node ),
325
365
)
326
366
trt_node .meta ["val" ] = []
327
367
assert num_outputs > 0
@@ -337,6 +377,13 @@ def inline_trt_modules(
337
377
)
338
378
)
339
379
380
+ # meta["val"] should be a lightweight stand-in for the real value; for a tensor, e.g., it should be a FakeTensor (with output shape and dtype properties).
381
+ # A lightweight equivalent of a custom_obj is not clearly defined. meta["val"] does not have any type expectations, but
382
+ # for custom object nodes it should be a CustomObjArgument.
383
+ engine_node .meta ["val" ] = CustomObjArgument (
384
+ name = engine_node .name , class_fqn = ""
385
+ )
386
+
340
387
if num_outputs == 1 :
341
388
# Insert getitem nodes as outputs (for export serialization to work)
342
389
with gm .graph .inserting_after (trt_node ):
0 commit comments