Skip to content

Commit 5dadfd6

Browse files
committed
changes to make Llama example work
1 parent bba4153 commit 5dadfd6

File tree

7 files changed

+290
-9
lines changed

7 files changed

+290
-9
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torch_tensorrt.dynamo._compiler import compile_module
1515
from torch_tensorrt.dynamo.lowering import (
1616
get_decompositions,
17+
modify_reshape_complex_nodes,
1718
post_lowering,
1819
remove_detach,
1920
remove_sym_nodes,
@@ -61,9 +62,15 @@ def aot_torch_tensorrt_aten_backend(
6162
settings_aot_autograd["decompostions"] = get_decompositions(
6263
settings.enable_experimental_decompositions
6364
)
64-
return aot_autograd(fw_compiler=_pretraced_backend_autograd)(
65-
gm, sample_inputs, **settings_aot_autograd
66-
)
65+
# This is added since detach lowering leads to alias nodes
66+
# Error - View operation returned a tensor that is the same as the input base tensor
67+
# torch nop_decompositions in torch/_decomp/decompositions.py
68+
if aten.detach in settings_aot_autograd["decompositions"]:
69+
del settings_aot_autograd["decompositions"][aten.detach]
70+
return aot_autograd(
71+
fw_compiler=_pretraced_backend_autograd,
72+
decompositions=get_decompositions(settings.enable_experimental_decompositions),
73+
)(gm, sample_inputs)
6774

6875

6976
def _pretraced_backend(
@@ -103,6 +110,16 @@ def _pretraced_backend(
103110
# Remove detach nodes
104111
remove_detach(gm, settings)
105112

113+
complexInputIndices = []
114+
for i, torch_input in enumerate(torch_inputs):
115+
if torch_inputs[i].dtype == torch.complex64:
116+
complexInputIndices.append(i)
117+
torch_input_real = torch_inputs[i].real
118+
torch_input_imaginary = torch_inputs[i].imag
119+
torch_inputs[i] = torch.stack(
120+
(torch_input_real, torch_input_imaginary), dim=-1
121+
)
122+
106123
# Invoke AOTAutograd to translate operators to aten
107124
if settings.use_aot_joint_export:
108125
gm = aot_export_joint_simple(
@@ -120,6 +137,12 @@ def _pretraced_backend(
120137

121138
logger.debug("Lowered Input graph:\n " + str(gm.graph))
122139

140+
if complexInputIndices:
141+
modify_reshape_complex_nodes(gm, complexInputIndices)
142+
logger.debug(
143+
"Input graph after modifying complex nodes:\n " + str(gm.graph)
144+
)
145+
123146
torchtrt_inputs = prepare_inputs(
124147
torch_inputs, disable_memory_format_check=True
125148
)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
has_static_shapes_in_args,
1717
)
1818
from torch_tensorrt.dynamo.conversion.converter_utils import (
19+
args_bounds_check,
1920
enforce_tensor_types,
2021
get_positive_dim,
2122
is_only_operator_on_placeholder,
@@ -25,12 +26,6 @@
2526
_LOGGER: logging.Logger = logging.getLogger(__name__)
2627

2728

28-
def args_bounds_check(
29-
args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
30-
) -> Any:
31-
return args[i] if len(args) > i and args[i] is not None else replacement
32-
33-
3429
def get_ir(target: Target) -> SourceIR:
3530
target_module = getattr(target, "__module__", "None")
3631
if any(

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,9 @@ def set_layer_name(
913913
else f"{source_ir}_ops.{target.__name__}"
914914
)
915915
layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
916+
917+
918+
def args_bounds_check(
    args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
) -> Any:
    """Fetch ``args[i]``, falling back to ``replacement`` when it is unavailable.

    Args:
        args: Positional arguments of a node.
        i: Index of the argument to read.
        replacement: Value returned when ``args`` has no usable entry at ``i``.

    Returns:
        ``args[i]`` if ``len(args) > i`` and that entry is not ``None``;
        otherwise ``replacement``.
    """
    if len(args) > i and args[i] is not None:
        return args[i]
    return replacement
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from ._aten_lowering_pass import *
2+
from ._modify_reshape_complex_nodes import modify_reshape_complex_nodes
23
from .remove_sym_nodes import remove_sym_nodes
34
from .repair_input_aliasing import repair_input_aliasing
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import logging
2+
3+
import torch
4+
5+
logger = logging.getLogger(__name__)
6+
7+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
8+
clean_up_graph_after_modifications,
9+
find_complex_nodes,
10+
)
11+
12+
from ._replace_complex_placeholder_to_tuple import replace_complex_placeholder_to_tuple
13+
14+
15+
def tensorrt_complex_mul(args0, args1):
    """Multiply two complex tensors stored as real tensors with a trailing (real, imag) axis.

    Both operands are split along their last dimension into a real and an
    imaginary component, and the product is formed with real arithmetic:
    ``(a + bi) * (c + di) = (ac - bd) + (ad + bc)i``.

    Args:
        args0: Left operand; its last dimension must split into exactly two
            size-1 chunks (real, imaginary).
        args1: Right operand, laid out the same way.

    Returns:
        The product, with real and imaginary parts stacked on a new last axis.
    """
    aten = torch.ops.aten

    # Size-1 chunks along the last axis: (real, imag) for each operand.
    lhs_re_chunk, lhs_im_chunk = aten.split.Tensor(args0, 1, -1)
    rhs_re_chunk, rhs_im_chunk = aten.split.Tensor(args1, 1, -1)

    # Drop the now-singleton trailing axis from every component.
    lhs_re = aten.squeeze.dim(lhs_re_chunk, -1)
    lhs_im = aten.squeeze.dim(lhs_im_chunk, -1)
    rhs_re = aten.squeeze.dim(rhs_re_chunk, -1)
    rhs_im = aten.squeeze.dim(rhs_im_chunk, -1)

    product_real = aten.sub(
        aten.mul(lhs_re, rhs_re),
        aten.mul(lhs_im, rhs_im),
    )
    product_imag = aten.add(
        aten.mul(lhs_re, rhs_im),
        aten.mul(lhs_im, rhs_re),
    )

    return aten.stack((product_real, product_imag), -1)
34+
35+
36+
def remove_complex_real_view_nodes(gm: torch.fx.GraphModule):
    """Bypass and erase ``view_as_complex`` / ``view_as_real`` nodes in ``gm``.

    Nodes are selected by name (any node whose name contains
    ``"view_as_complex"`` or ``"view_as_real"``). Every user of such a node is
    rewired to the node's first positional argument, then the node is erased.
    If anything was removed, the graph is cleaned up and logged at debug level.

    Args:
        gm: Graph module to rewrite in place.
    """
    # Snapshot first so the graph is not mutated while it is being scanned.
    nodes_to_remove = [
        node
        for node in gm.graph.nodes
        if "view_as_complex" in node.name or "view_as_real" in node.name
    ]

    modified_graph = False
    for node in nodes_to_remove:
        if not node.args:
            # No positional input to forward users to; leave the node alone
            # rather than replacing its uses with None.
            logger.debug(
                "Skipping %s: no positional input to forward its users to", node.name
            )
            continue

        # replace_all_uses_with also rewrites uses nested inside list/tuple
        # arguments and kwargs, which a scan over top-level args would miss
        # (erase_node would then fail because the node still has users).
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
        modified_graph = True

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after removing view_as_complex nodes and view_as_real nodes:\n{gm.graph}"
        )
60+
61+
62+
def modify_reshape_nodes(gm: torch.fx.GraphModule, complex_nodes):
    """Append a trailing dimension of size 2 to the target shape of complex reshape nodes.

    Only nodes that are in ``complex_nodes`` and whose name contains
    ``"reshape"`` are touched; their args become ``(input, shape + (2,))``.
    Other nodes (for example slice and transpose) are left as they are.

    Args:
        gm: Graph module whose nodes are updated in place.
        complex_nodes: Collection of nodes considered complex-valued.
    """
    for candidate in gm.graph.nodes:
        if candidate not in complex_nodes:
            continue
        if "reshape" not in candidate.name:
            continue
        widened_shape = (*candidate.args[1], 2)
        candidate.args = (candidate.args[0], widened_shape)
69+
70+
71+
def modify_mul_nodes(gm: torch.fx.GraphModule, complex_nodes):
    """Swap complex ``mul`` nodes for calls to ``tensorrt_complex_mul``.

    A node is rewritten when it is in ``complex_nodes`` and its name contains
    ``"mul"``. The replacement is inserted right after the original, receives
    the original's first two positional arguments and metadata, and takes over
    all of its users before the original node is erased.

    Args:
        gm: Graph module to rewrite in place.
        complex_nodes: Collection of nodes considered complex-valued.
    """
    graph = gm.graph
    graph_changed = False

    for candidate in graph.nodes:
        if candidate not in complex_nodes or "mul" not in candidate.name:
            continue

        operands = (candidate.args[0], candidate.args[1])
        with graph.inserting_after(candidate):
            complex_mul_node = graph.create_node(
                "call_function",
                tensorrt_complex_mul,
                args=operands,
            )

        # Users must be redirected before the original node can be erased.
        candidate.replace_all_uses_with(complex_mul_node)
        complex_mul_node.meta.update(candidate.meta)
        graph_changed = True
        graph.erase_node(candidate)

    if graph_changed:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after custom complex mul nodes is applied to the graph:\n{gm.graph}"
        )
93+
94+
95+
def modify_complex_nodes(gm: torch.fx.GraphModule, complex_nodes):
    """Run the complex-number graph rewrites on ``gm`` in a fixed order.

    Order: widen reshape shapes, remove ``view_as_complex``/``view_as_real``
    nodes, then replace complex ``mul`` nodes.

    Args:
        gm: Graph module to rewrite in place.
        complex_nodes: Collection of nodes considered complex-valued.
    """
    # 1) reshape targets gain a trailing dimension of size 2
    modify_reshape_nodes(gm, complex_nodes)
    # 2) complex<->real view nodes are bypassed and erased
    remove_complex_real_view_nodes(gm)
    # 3) complex multiplications become tensorrt_complex_mul calls
    modify_mul_nodes(gm, complex_nodes)
99+
100+
101+
def modify_reshape_complex_nodes(gm: torch.fx.GraphModule, complexInputIndices):
    """Entry point of the complex-number lowering pass.

    Does nothing when ``find_complex_nodes`` reports no complex nodes.
    Otherwise the placeholders at ``complexInputIndices`` are rewritten first,
    followed by the reshape / view / mul rewrites.

    Args:
        gm: Graph module to rewrite in place.
        complexInputIndices: Indices of the graph inputs that were complex.
    """
    complex_nodes = find_complex_nodes(gm)
    if not complex_nodes:
        return

    replace_complex_placeholder_to_tuple(gm, complexInputIndices)
    modify_complex_nodes(gm, complex_nodes)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import logging
2+
3+
import torch
4+
from torch.fx.node import _get_qualified_name
5+
from torch_tensorrt.dynamo._settings import CompilationSettings
6+
from torch_tensorrt.dynamo.conversion.converter_utils import args_bounds_check
7+
8+
# dead-code elimination, linting, and recompilation for graph, in-place
9+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
10+
clean_up_graph_after_modifications,
11+
)
12+
13+
logger = logging.getLogger(__name__)
14+
15+
# for now creating this node, but mostly will want to modify this in input
16+
17+
18+
def replace_complex_placeholder_to_tuple(
    gm: torch.fx.GraphModule, inputListindices
) -> torch.fx.GraphModule:
    """Rewrite the metadata of complex placeholders to a real tensor with a trailing axis of 2.

    Placeholders are matched by target name ``arg{index}_1`` for every index in
    ``inputListindices``. For each match, ``meta["val"]`` is replaced by a fake
    tensor of shape ``original_shape + (2,)`` whose dtype is ``float32`` when
    the original was ``complex64`` and ``float64`` otherwise. The change is
    then propagated to downstream nodes via
    ``propogate_complex_num_shape_change_till_complex_mul``.

    Args:
        gm: Graph module to update in place.
        inputListindices: Indices of the graph inputs that were complex.

    Returns:
        The graph module, cleaned up if any placeholder was modified.
    """
    # Function-scope import, hoisted out of the loop; this module does not
    # import FakeTensorMode at top level.
    from torch._subclasses.fake_tensor import FakeTensorMode

    modified_graph = False
    input_arg_list = [f"arg{inputListIndex}_1" for inputListIndex in inputListindices]

    for node in gm.graph.nodes:
        if node.op != "placeholder" or node.target not in input_arg_list:
            continue

        original_val = node.meta["val"]
        new_node_shape = original_val.size() + (2,)
        # complex64 maps to float32; any other dtype is treated as complex128.
        if original_val.dtype == torch.complex64:
            new_node_dtype = torch.float32
        else:
            new_node_dtype = torch.float64

        real_tensor = torch.empty(new_node_shape, dtype=new_node_dtype)
        with FakeTensorMode() as fake_mode:
            new_placeholder_tuple = fake_mode.from_tensor(real_tensor)
        node.meta["val"] = new_placeholder_tuple
        modified_graph = True

        # Propagate the metadata change to the downstream ops.
        # TODO: check whether this is required in all cases.
        propogate_complex_num_shape_change_till_complex_mul(gm, node, fake_mode)

    # If the graph was modified, clean it up.
    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            "Graph after replacing complex placeholders with real tensors:\n%s",
            gm.graph,
        )

    return gm
53+
54+
55+
def infer_slice_shape(node):
    """Compute the output shape of a slice node from its input's metadata.

    The node's args are read as ``(input, dim, start, end, step)``. Missing or
    ``None`` entries fall back to ``dim=0``, ``start=0``, ``end=size`` and
    ``step=1``. Negative ``start``/``end`` are counted from the end of the
    dimension, and both bounds are clamped to ``[0, size]``. This means a very
    large ``end`` (such as an INT64_MAX "slice to the end" sentinel) yields the
    real extent instead of an enormous one.

    Args:
        node: Slice node; ``node.args[0].meta["val"].shape`` must be available.

    Returns:
        The input shape as a tuple, with the sliced dimension replaced by the
        number of selected elements (never negative).
    """
    slice_args = node.args
    input_shape = slice_args[0].meta["val"].shape

    def _arg_or(index, default):
        # Positional argument at `index`, or `default` if absent or None.
        value = slice_args[index] if len(slice_args) > index else None
        return default if value is None else value

    dim = _arg_or(1, 0)
    size = input_shape[dim]
    start = _arg_or(2, 0)
    end = _arg_or(3, size)
    step = _arg_or(4, 1)

    # Normalise negative indices, then clamp both bounds into [0, size].
    if start < 0:
        start += size
    if end < 0:
        end += size
    start = min(max(start, 0), size)
    end = min(max(end, 0), size)

    new_shape = list(input_shape)
    # Ceil-division; an empty range (end <= start) gives 0.
    new_shape[dim] = max(0, (end - start + step - 1) // step)
    return tuple(new_shape)
65+
66+
67+
def infer_reshape_shape(node):
    """Return the target shape of a reshape node.

    Args:
        node: Reshape node whose second positional argument is the new shape.

    Returns:
        ``node.args[1]``, unchanged.
    """
    target_shape = node.args[1]
    return target_shape
69+
70+
71+
# Registry of shape-inference helpers, keyed by the qualified name of the
# call_function target. update_node_meta looks a node's target up here to
# recompute its output shape.
shape_inference_funcs = {
    "torch.ops.aten.slice.Tensor": infer_slice_shape,
    "torch.ops.aten.reshape.default": infer_reshape_shape,
}
75+
76+
77+
# NOTE: this function targets the Llama use case, where the graph has the form
# complex placeholder -> reshape -> slice -> complex mul.
# Hence mul is the terminating op.
def propogate_complex_num_shape_change_till_complex_mul(
    node: torch.fx.Node,
    start_node: torch.fx.Node,
    fake_mode: "torch._subclasses.fake_tensor.FakeTensorMode",
) -> None:
    """Refresh node metadata downstream of ``start_node`` until a ``mul`` is reached.

    Walks the users of ``start_node`` depth-first, calling ``update_node_meta``
    once per visited node. Users that are ``call_function`` nodes targeting
    ``torch.ops.aten.mul.Tensor`` are not visited, so the walk stops there.

    Args:
        node: Not read; the traversal starts from ``start_node``. Kept so the
            existing three-argument call signature stays valid.
        start_node: Node whose metadata change is propagated to its users.
        fake_mode: Fake tensor mode passed through to ``update_node_meta``.
            Annotated as a string because ``FakeTensorMode`` is not imported at
            module level; a bare name would raise ``NameError`` when this
            ``def`` is executed.
    """
    visited_nodes = set()
    stack = [start_node]
    while stack:
        current = stack.pop()
        if current in visited_nodes:
            continue
        visited_nodes.add(current)
        update_node_meta(current, fake_mode)
        for user in current.users:
            if (
                user.op == "call_function"
                and _get_qualified_name(user.target) == "torch.ops.aten.mul.Tensor"
            ):
                # The complex mul terminates propagation along this path.
                continue
            stack.append(user)
98+
99+
100+
def update_node_meta(node, fake_mode):
    """Recompute ``node.meta["val"]`` using the registered shape-inference helper.

    For ``call_function`` nodes the target is converted to its qualified name
    and looked up in ``shape_inference_funcs``. On a hit, the helper's shape is
    used to build a new fake tensor that keeps the dtype of the existing
    ``meta["val"]``. On a miss, the node is left untouched and a debug message
    is logged.

    Args:
        node: Node whose metadata should be refreshed.
        fake_mode: Fake tensor mode used to create the replacement value.
    """
    op_name = node.name
    op_target = node.target

    if node.op == "call_function":
        op_target = _get_qualified_name(node.target)

    if op_target in shape_inference_funcs:
        new_shape = shape_inference_funcs[op_target](node)
        real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
        node.meta["val"] = fake_mode.from_tensor(real_tensor)
    else:
        # Use the module logger rather than print() so library users control
        # the output; the metadata is deliberately left unchanged.
        logger.debug(
            "No shape inference function registered for node %s (target %s)",
            op_name,
            op_target,
        )

py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,42 @@ def get_tensor_placeholders(
2929
]
3030

3131
return placeholders
32+
33+
34+
def find_complex_nodes(gm: torch.fx.GraphModule):
    """Collect, in graph order, every node that ``is_node_complex`` flags.

    Args:
        gm: Graph module to scan.

    Returns:
        List of nodes for which ``is_node_complex`` returned ``True``.
    """
    # Shared across calls so that already-flagged node names are remembered.
    flagged_names = {}
    return [
        candidate
        for candidate in gm.graph.nodes
        if is_node_complex(candidate, flagged_names)
    ]
41+
42+
43+
def is_node_complex(node: torch.fx.Node, complexNodes):
    """Decide whether ``node`` is a ``call_function`` consuming a complex value.

    A node qualifies when one of its positional arguments is either

    * a list/tuple containing an element for which this function returns
      ``True``, or
    * an object whose ``meta["val"]`` (or any entry of it, when it is a
      list/tuple) has dtype ``complex64`` or ``complex128``.

    ``int`` arguments are skipped. Qualifying node names are recorded in
    ``complexNodes``, and a name already present there short-circuits to
    ``True``.

    Args:
        node: Candidate node; anything that is not a ``torch.fx.Node`` yields
            ``False``.
        complexNodes: Mutable mapping of node name to ``True``, used as a
            cache and updated in place.

    Returns:
        ``True`` if the node is considered complex, ``False`` otherwise.
    """
    complex_dtypes = (torch.complex64, torch.complex128)

    if not isinstance(node, torch.fx.Node):
        return False
    if node.name in complexNodes:
        return True
    if node.op != "call_function" or node.args is None:
        return False

    for operand in node.args:
        if isinstance(operand, int):
            continue

        if isinstance(operand, (list, tuple)):
            # Nested container: recurse into each element.
            hit = any(is_node_complex(member, complexNodes) for member in operand)
        elif hasattr(operand, "meta") and "val" in operand.meta:
            value = operand.meta["val"]
            if isinstance(value, (list, tuple)):
                hit = any(entry.dtype in complex_dtypes for entry in value)
            else:
                hit = value.dtype in complex_dtypes
        else:
            hit = False

        if hit:
            complexNodes[node.name] = True
            return True

    return False

0 commit comments

Comments
 (0)