@@ -1,5 +1,4 @@
 import torch
-from torch.fx.node import _get_qualified_name
 from typing import Any, Union, Sequence, Dict
 from torch_tensorrt import _Input, Device
 
@@ -66,149 +65,3 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
         )
 
     return device
-
-
-def _extract_downstream_get_nodes(
-    module_node: torch.fx.Node, output_indices: Sequence[int]
-) -> Sequence[torch.fx.Node]:
-    """Extracts downstream users of a node which get the item at a particular index
-
-    Certain module-type nodes have multiple outputs (a tuple of outputs). This function
-    returns downstream nodes which call the _operator.getitem function, which extracts
-    the element at a particular index in the tuple
-
-    Args:
-        module_node: FX module-type node to analyze
-        output_indices: Indices in the module node output to search for
-    Returns:
-        List of nodes which get the item at the specified index in the module node output
-    """
-    get_nodes = []
-
-    # Iterate over all downstream users of the node object
-    for user in module_node.users:
-        # If the user is a "get" node accessing the specified index, store it
-        if _get_qualified_name(user.target) == "_operator.getitem" and (
-            user.args[1] in output_indices
-        ):
-            get_nodes.append(user)
-
-    return get_nodes
-
-
-def repair_long_or_double_input(
-    gm: torch.fx.GraphModule,
-    position: int,
-    submodule_name: str,
-    submodule_outputs: Union[torch.Tensor, Sequence[torch.Tensor]],
-    dtype: torch.dtype,
-):
-    """Fixes Long/Double type inputs to TRT-accelerated subgraphs
-
-    In-place modifies the provided graph
-
-    Inserts a cast to the 32-bit equivalent type for TRT, then, if necessary,
-    inserts an upcast back to the 64-bit type for subsequent Torch operations
-
-    Args:
-        gm: FX GraphModule enclosing the TRT subgraph
-        position: Index in the submodule inputs at which the long or double input is found
-        submodule_name: Name of the TRT-accelerated subgraph module in the FX graph
-        submodule_outputs: Output tensor(s) of the TRT-accelerated subgraph (used for dtypes/structure)
-        dtype: Data type of the tensor at `position` in the submodule (double/long)
-    """
-    assert dtype in (
-        torch.int64,
-        torch.float64,
-    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
-
-    # Determine the target data type in 32- and 64-bit forms
-    dtype_64bit = dtype
-    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
-
-    # Find the node representing the submodule in the graph
-    module_node = None
-
-    # Iterate over all nodes in the graph, seeking a target-module name match
-    for n in gm.graph.nodes:
-        if n.op == "call_module" and str(n.target) == submodule_name:
-            module_node = n
-            break
-
-    if module_node is None:
-        raise AssertionError(
-            f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}"
-        )
-
-    # Extract the 64-bit input node of the submodule
-    node_64bit = module_node.all_input_nodes[position]
-
-    # Prior to the module, insert a cast to the 32-bit equivalent type
-    with gm.graph.inserting_before(module_node):
-        node_32bit = gm.graph.call_function(
-            torch.ops.aten._to_copy.default,
-            args=(node_64bit,),
-            kwargs={"dtype": dtype_32bit},
-        )
-
-    # Replace the 64-bit input to the TRT module with the new 32-bit cast node
-    module_node.replace_input_with(node_64bit, node_32bit)
-
-    output_positions_64bit = set()
-    outputs_list = (
-        [submodule_outputs]
-        if isinstance(submodule_outputs, torch.Tensor)
-        else submodule_outputs
-    )
-
-    # Determine whether any outputs of the module are 64-bit and store their indices
-    for output_position, output in enumerate(outputs_list):
-        if output.dtype == dtype_64bit:
-            output_positions_64bit.add(output_position)
-
-    # Only enter this code block if there exists a 64-bit output
-    # This implies a cast is needed, since TRT cannot output 64-bit tensors
-    if output_positions_64bit:
-        # Determine whether the outputs of the module are tuple-type or not
-        is_tuple_output = False
-        if isinstance(submodule_outputs, tuple):
-            is_tuple_output = True
-
-        if not is_tuple_output:
-            # If the output is a single tensor, insert a cast back to the 64-bit type
-            with gm.graph.inserting_after(module_node):
-                cast_node_64bit = gm.graph.call_function(
-                    torch.ops.aten._to_copy.default,
-                    args=(module_node,),
-                    kwargs={"dtype": dtype_64bit},
-                )
-
-            # Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
-            module_node.replace_all_uses_with(
-                cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
-            )
-
-        else:
-            # If the output is a tuple of tensors, extract downstream users for each 64-bit output
-            get_nodes = _extract_downstream_get_nodes(
-                module_node, output_positions_64bit
-            )
-
-            # For each downstream user, append a cast node back to the 64-bit precision
-            for get_node in get_nodes:
-                with gm.graph.inserting_after(get_node):
-                    cast_node_64bit = gm.graph.call_function(
-                        torch.ops.aten._to_copy.default,
-                        args=(get_node,),
-                        kwargs={"dtype": dtype_64bit},
-                    )
-
-                get_node.replace_all_uses_with(
-                    cast_node_64bit,
-                    delete_user_cb=lambda user: (user != cast_node_64bit),
-                )
-
-    # Clean up graph and ensure invariants are preserved
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
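For reference, the deleted `_extract_downstream_get_nodes` helper matched a common FX pattern: a multi-output node whose consumers access individual tuple elements through `operator.getitem`. Below is a minimal, runnable sketch of that pattern on a toy graph. The `TwoOutputs` module name is illustrative, and the sketch compares `user.target` against `operator.getitem` directly rather than via `_get_qualified_name`; neither choice is part of the deleted code.

```python
import operator

import torch


class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        # torch.topk returns a (values, indices) pair; indexing the traced
        # result is recorded as operator.getitem call_function nodes
        out = torch.topk(x, k=2)
        return out[0] + 1.0, out[1]


gm = torch.fx.symbolic_trace(TwoOutputs())
topk_node = next(n for n in gm.graph.nodes if n.target == torch.topk)

# Mirror the deleted helper's loop: collect users of the multi-output node
# that are getitem calls accessing one of the requested output indices
output_indices = (0, 1)
get_nodes = [
    user
    for user in topk_node.users
    if user.op == "call_function"
    and user.target == operator.getitem
    and user.args[1] in output_indices
]
print(get_nodes)  # one getitem node per accessed tuple element
```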
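The removed `repair_long_or_double_input` pass can likewise be exercised without TensorRT. The sketch below applies the same graph surgery, a 32-bit downcast before a `call_module` node and a 64-bit upcast after it, to an ordinary submodule; only the single-tensor (non-tuple) branch is mirrored. `Wrapper`, the submodule name `inner`, and the `torch.nn.ReLU` stand-in are assumptions for illustration, not part of the deleted code.

```python
import torch


class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.ReLU()  # stand-in for a TRT-accelerated submodule

    def forward(self, x):
        return self.inner(x) * 2


gm = torch.fx.symbolic_trace(Wrapper())

# Locate the call_module node by name, as in the deleted search loop
module_node = next(
    n for n in gm.graph.nodes if n.op == "call_module" and str(n.target) == "inner"
)
node_64bit = module_node.all_input_nodes[0]

# Insert a 32-bit downcast ahead of the module and swap it in as the input
with gm.graph.inserting_before(module_node):
    node_32bit = gm.graph.call_function(
        torch.ops.aten._to_copy.default,
        args=(node_64bit,),
        kwargs={"dtype": torch.int32},
    )
module_node.replace_input_with(node_64bit, node_32bit)

# Insert an upcast after the module; the delete_user_cb filter keeps the
# upcast itself consuming the module output while all other users are rewired
with gm.graph.inserting_after(module_node):
    cast_node_64bit = gm.graph.call_function(
        torch.ops.aten._to_copy.default,
        args=(module_node,),
        kwargs={"dtype": torch.int64},
    )
module_node.replace_all_uses_with(
    cast_node_64bit, delete_user_cb=lambda user: user != cast_node_64bit
)

gm.graph.lint()
gm.recompile()

out = gm(torch.ones(4, dtype=torch.int64))
print(out.dtype)  # torch.int64 outside; the submodule itself now sees int32
```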