 import torch
-
+from torch.fx.node import _get_qualified_name
 from typing import Any, Union, Sequence, Dict
 from torch_tensorrt import _Input, Device
 
@@ -66,3 +66,149 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
         )
 
     return device
+
+
+def _extract_downstream_get_nodes(
+    module_node: torch.fx.Node, output_indices: Sequence[int]
+) -> Sequence[torch.fx.Node]:
+    """Extracts downstream users of a node which get the item at a particular index
+
+    Certain module-type nodes have multiple outputs (a tuple of outputs). This function
+    returns the downstream nodes which call the _operator.getitem function to extract
+    the element at a particular index of that tuple.
+
+    Args:
+        module_node: FX module-type node to analyze
+        output_indices: Indices in the module node output to search for
+    Returns:
+        List of nodes which get the item at one of the specified indices in the module node output
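+
+    Example (illustrative; assumes ``mod_node`` is a call_module node whose
+    submodule returns a 2-tuple unpacked via getitem during tracing)::
+
+        getters = _extract_downstream_get_nodes(mod_node, [0, 1])
+        # getters now holds the getitem nodes for outputs 0 and 1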
+    """
+    get_nodes = []
+
+    # Iterate over all downstream users of the node object
+    for user in module_node.users:
+        # If the user is a getitem node accessing one of the specified indices, store it
+        if _get_qualified_name(user.target) == "_operator.getitem" and (
+            user.args[1] in output_indices
+        ):
+            get_nodes.append(user)
+
+    return get_nodes
+
+
+def repair_long_or_double_input(
+    gm: torch.fx.GraphModule,
+    position: int,
+    submodule_name: str,
+    submodule_outputs: Union[torch.Tensor, Sequence[torch.Tensor]],
+    dtype: torch.dtype,
+) -> None:
+    """Fixes Long/Double type inputs to TRT-accelerated subgraphs
+
+    Modifies the provided graph in place
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then, if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        gm: FX GraphModule enclosing the TRT subgraph
+        position: Index in the submodule inputs at which the long or double input is found
+        submodule_name: Name of the TRT-accelerated subgraph module in the FX graph
+        submodule_outputs: Output tensor(s) of the TRT-accelerated subgraph (used for dtypes/structure)
+        dtype: Data type of the tensor at position in the submodule inputs (torch.int64 or torch.float64)
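+
+    Example (illustrative; the submodule name and outputs are hypothetical)::
+
+        # Input 0 of submodule "_run_on_acc_0" is torch.int64: cast it to
+        # int32 for TRT, then restore int64 for downstream Torch ops
+        repair_long_or_double_input(
+            gm, 0, "_run_on_acc_0", example_outputs, torch.int64
+        )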
+    """
+    assert dtype in (
+        torch.int64,
+        torch.float64,
+    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
+
+    # Determine target data type in 32 and 64 bit forms
+    dtype_64bit = dtype
+    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
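+    # e.g. torch.int64 -> torch.int32, torch.float64 -> torch.float32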
+
+    # Find the node representing the submodule in the graph
+    module_node = None
+
+    # Iterate over all nodes in the graph, seeking target module name match
+    for n in gm.graph.nodes:
+        if n.op == "call_module" and str(n.target) == submodule_name:
+            module_node = n
+            break
+
+    if module_node is None:
+        raise AssertionError(
+            f"Sought module node {submodule_name}, but could not find it in graph:\n{gm.graph}"
+        )
+
+    # Extract the 64-bit node of the input
+    node_64bit = module_node.all_input_nodes[position]
+
+    # Prior to the module, insert a cast to the 32-bit equivalent node
+    with gm.graph.inserting_before(module_node):
+        node_32bit = gm.graph.call_function(
+            torch.ops.aten._to_copy.default,
+            args=(node_64bit,),
+            kwargs={"dtype": dtype_32bit},
+        )
+
+    # Replace 64-bit input to TRT module with new 32-bit cast node
+    module_node.replace_input_with(node_64bit, node_32bit)
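+    # The input path is now: node_64bit -> _to_copy(dtype=dtype_32bit) -> submodule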
+
+    output_positions_64bit = set()
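+    # Wrap a lone Tensor output in a one-element list so single- and
+    # multi-output submodules can be scanned by the same loop below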
+    outputs_list = (
+        [submodule_outputs]
+        if isinstance(submodule_outputs, torch.Tensor)
+        else submodule_outputs
+    )
+
+    # Determine if any outputs of the module are 64-bit type and store their indices
+    for output_position, output in enumerate(outputs_list):
+        if output.dtype == dtype_64bit:
+            output_positions_64bit.add(output_position)
+
+    # Only enter this code block if there exists a 64-bit output
+    # This implies a cast is needed, since TRT cannot output 64-bit tensors
+    if output_positions_64bit:
+        # Determine whether the outputs of the module are tuple-type or not
+        is_tuple_output = isinstance(submodule_outputs, tuple)
+
+        if not is_tuple_output:
+            # If the output is a single tensor, insert a cast back to the 64-bit type
+            with gm.graph.inserting_after(module_node):
+                cast_node_64bit = gm.graph.call_function(
+                    torch.ops.aten._to_copy.default,
+                    args=(module_node,),
+                    kwargs={"dtype": dtype_64bit},
+                )
+
+            # Replace all uses of the TRT module (except the cast node itself) with the 64-bit equivalent
+            module_node.replace_all_uses_with(
+                cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
+            )
+
+        else:
+            # If the output is a tuple of tensors, extract downstream users for each 64-bit output
+            get_nodes = _extract_downstream_get_nodes(
+                module_node, output_positions_64bit
+            )
+
+            # For each downstream getitem user, insert a cast back to the 64-bit precision
+            for get_node in get_nodes:
+                with gm.graph.inserting_after(get_node):
+                    cast_node_64bit = gm.graph.call_function(
+                        torch.ops.aten._to_copy.default,
+                        args=(get_node,),
+                        kwargs={"dtype": dtype_64bit},
+                    )
+
+                get_node.replace_all_uses_with(
+                    cast_node_64bit,
+                    delete_user_cb=lambda user: (user != cast_node_64bit),
+                )
+
+    # Clean up graph and ensure invariants are preserved
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()