
Commit 2844630

Merge pull request #1979 from pytorch/dynamo_module_level_acceleration
feat: Module-Acceleration in Dynamo [5 / x]
2 parents b38fa5b + c9f06fc

10 files changed: +481 -7 lines changed


py/torch_tensorrt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def _find_lib(name, paths):

 from torch_tensorrt import fx

-if version.parse(torch.__version__) >= version.parse("2.dev"):
+if version.parse(torch.__version__) >= version.parse("2.1.dev"):
     from torch_tensorrt import dynamo
     from torch_tensorrt.dynamo import backend

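The tightened gate only imports the dynamo package on torch 2.1 nightlies and newer, rather than on any 2.x build. A quick sketch (not part of the commit) of how packaging's version comparison treats the two specifiers:

from packaging import version

# The old gate "2.dev" is satisfied by every torch 2.x release, including 2.0.x
version.parse("2.0.1") >= version.parse("2.dev")      # True
# The new gate "2.1.dev" excludes 2.0.x but admits 2.1 nightlies and later releases
version.parse("2.0.1") >= version.parse("2.1.dev")    # False
version.parse("2.1.0.dev20230601") >= version.parse("2.1.dev")  # True
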
py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 8 additions & 0 deletions
@@ -8,6 +8,9 @@
 from torch_tensorrt.dynamo.backend.lowering._decompositions import (
     get_decompositions,
 )
+from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
+    pre_aot_substitutions,
+)
 from torch_tensorrt.dynamo.backend.lowering._partition import (
     partition,
     get_submod_inputs,
@@ -41,6 +44,9 @@ def aot_torch_tensorrt_aten_backend(
         settings=settings,
     )

+    # Perform Pre-AOT Lowering for Module-Level Replacement
+    gm = pre_aot_substitutions(gm)
+
     # Invoke AOTAutograd to translate operators to aten
     return aot_module_simplified(
         gm,
@@ -65,6 +71,8 @@ def _pretraced_backend(
         Compiled FX GraphModule
     """
     try:
+        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+
         trt_compiled = _compile_module(
             gm,
             sample_inputs,
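With this change, `pre_aot_substitutions` runs on the Dynamo-captured graph before AOT Autograd decomposes it, so registered modules and functions are still visible as single nodes when they are swapped for their custom operators. A minimal sketch of the same ordering in a standalone backend (the backend function and toy model are hypothetical; only `pre_aot_substitutions` comes from this commit, and no substitution applies to this model, so the pass is a no-op here):

import torch
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import pre_aot_substitutions

def substitution_only_backend(gm: torch.fx.GraphModule, sample_inputs):
    # Swap registered modules/functions for custom ops before any AOT tracing or decomposition
    gm = pre_aot_substitutions(gm)
    return gm  # a GraphModule is callable, so returning it acts as an eager fallback backend

model = torch.nn.Sequential(torch.nn.Linear(4, 4))  # hypothetical toy model
compiled = torch.compile(model, backend=substitution_only_backend)
compiled(torch.randn(2, 4))  # triggers graph capture followed by the substitution pass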

py/torch_tensorrt/dynamo/backend/lowering/__init__.py

Lines changed: 6 additions & 4 deletions
@@ -1,7 +1,9 @@
-from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+from ._decompositions import (
     get_decompositions,
 )
-from torch_tensorrt.dynamo.backend.lowering._partition import (
-    partition,
-    get_submod_inputs,
+from ._pre_aot_lowering import (
+    SUBSTITUTION_REGISTRY,
+    register_substitution,
 )
+from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
+from .substitutions import *
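These re-exports let the partitioner and the substitution modules pull the registry and decorator from the package root instead of the private `_pre_aot_lowering` module, as the new `einsum` substitution below does, e.g.:

from torch_tensorrt.dynamo.backend.lowering import register_substitution, SUBSTITUTION_REGISTRY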

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

Lines changed: 10 additions & 2 deletions
@@ -1,9 +1,10 @@
 import logging
-from typing import Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence, Set

 import torch

 from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
@@ -14,6 +15,11 @@

 logger = logging.getLogger(__name__)

+DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
+    _get_qualified_name(to_replace.new_operator)
+    for to_replace in SUBSTITUTION_REGISTRY.values()
+)
+

 class TRTPartitioner(CapabilityBasedPartitioner):
     """Partitioner to split an FX graph into subgraphs based on operator support
@@ -35,7 +41,9 @@ def __init__(
         operator_support: OperatorSupport,
         *,
         non_compute_ops: Optional[Sequence[str]] = None,
-        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[
+            Sequence[str]
+        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size=MIN_BLOCK_SIZE,
     ) -> None:
         super().__init__(
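`DEFAULT_SINGLE_NODE_PARTITIONS` collects the qualified names of every registered replacement operator, so a partition consisting of just one substituted op (for example a lone accelerated einsum) remains eligible for TensorRT conversion despite the minimum-block-size rule. A rough sketch of inspecting it (the exact strings depend on which substitutions are registered at import time):

from torch.fx.node import _get_qualified_name
from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY, DEFAULT_SINGLE_NODE_PARTITIONS

# Same construction as above: qualified names of the torch.ops.tensorrt.* replacement operators
names = {_get_qualified_name(sub.new_operator) for sub in SUBSTITUTION_REGISTRY.values()}
assert names == DEFAULT_SINGLE_NODE_PARTITIONS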

py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Type, Union
+import torch
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class Substitution:
+    """Class to store key functionality for module replacement"""
+
+    # torch.ops.___ name for replacement function for module
+    new_operator: torch._ops.OpOverload
+
+    # Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
+    # and returning a replacement node, with type 'call_function', or raising an Error if
+    # incompatibility is detected
+    # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
+    subgraph_insertion_fn: Callable[
+        [torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node
+    ]
+
+
+# Dictionary mapping module to Substitution instance
+SUBSTITUTION_REGISTRY: Dict[
+    Union[Type[torch.nn.Module], Callable], Substitution
+] = dict()
+
+
+def register_substitution(
+    module_or_function_to_replace: Union[Type[torch.nn.Module], Callable],
+    new_operator: torch._ops.OpOverload,
+    enabled: bool = True,
+) -> Callable[[Any], Any]:
+    """Decorator to register subgraph insertion functions
+
+    Args:
+        module_or_function_to_replace: nn.Module or node target Callable to replace
+        new_operator: Custom torch operator to replace with
+        enabled: Whether the substitution is enabled or disabled
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def enable_substitution(subgraph_insertion_fn):
+        """Function for use if substitution is enabled"""
+        replacement = Substitution(
+            new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
+        )
+        SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement
+        return subgraph_insertion_fn
+
+    def disable_substitution(subgraph_insertion_fn):
+        """Function for use if substitution is disabled"""
+        return subgraph_insertion_fn
+
+    return enable_substitution if enabled else disable_substitution
+
+
+def pre_aot_substitutions(gm: torch.fx.GraphModule):
+    """Perform graph substitutions prior to AOT tracing
+
+    Args:
+        gm: FX GraphModule to perform substitution on
+    Returns:
+        torch.fx.GraphModule
+
+    """
+    logger.debug("Pre-module replacement graph:\n" + str(gm.graph))
+
+    # Ensure all parameters are in inference mode
+    for param in gm.parameters():
+        param.requires_grad = False
+
+    # Iterate over graph nodes, extracting module calls, to check for interceptions
+    for n in gm.graph.nodes:
+        exists_in_registry = False
+        to_replace = None
+
+        if n.op == "call_module":
+            # Extract submodule from graph, validate in registry
+            submodule = gm.get_submodule(n.target)
+            to_replace = type(submodule)
+            exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
+        elif n.op == "call_function":
+            # Extract function from graph, validate in registry
+            to_replace = n.target
+            exists_in_registry = n.target in SUBSTITUTION_REGISTRY
+
+        # If submodule/function is a member of the substitution registry, replace it
+        if exists_in_registry:
+            try:
+                replacement = SUBSTITUTION_REGISTRY[to_replace]
+                op, insertion_fn = (
+                    replacement.new_operator,
+                    replacement.subgraph_insertion_fn,
+                )
+                logger.debug(f"Replacing node of type {to_replace} with {op}")
+
+                # Insert new node prior to older node
+                with gm.graph.inserting_before(n):
+                    new_node = insertion_fn(
+                        gm, n, submodule if n.op == "call_module" else None
+                    )
+
+                # If submodule is not a native torch.nn module, it must be manually excluded
+                # from Dynamo tracing
+                if n.op == "call_module" and not type(submodule).__module__.startswith(
+                    "torch.nn"
+                ):
+                    torch._dynamo.allowed_functions._allowed_function_ids.add(
+                        id(to_replace)
+                    )
+
+                # Replace all original node uses and clean up graph
+                n.replace_all_uses_with(new_node)
+                gm.graph.eliminate_dead_code()
+                gm.graph.lint()
+                gm.recompile()
+
+            # A replacement can fail in the event that the specific instance of the submodule/function
+            # cannot be replaced
+            except Exception:
+                logger.debug(
+                    f"Encountered error while replacing {to_replace}",
+                    exc_info=True,
+                )
+                continue
+
+    # Perform cleanup and recompilation before returning module
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+    logger.debug("Post-module replacement graph:\n" + str(gm.graph))
+
+    return gm
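A minimal sketch of exercising the pass directly on an FX-traced module, assuming torch >= 2.1 and that the einsum substitution shown further below has been registered by importing the lowering package (the toy module is hypothetical):

import torch
import torch_tensorrt.dynamo.backend.lowering  # importing the package registers the bundled substitutions
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import pre_aot_substitutions

class ToyEinsum(torch.nn.Module):  # hypothetical example module
    def forward(self, x, y):
        return torch.einsum("ij,jk->ik", x, y)

gm = torch.fx.symbolic_trace(ToyEinsum())
gm = pre_aot_substitutions(gm)
print(gm.graph)  # the einsum call_function node should now target torch.ops.tensorrt.einsum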

py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .maxpool1d import *
+from .einsum import *

py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+from typing import Dict, Tuple
+import torch
+from torch._custom_op.impl import custom_op
+from torch.fx.node import Argument, Target
+
+from torch_tensorrt.fx.converter_registry import tensorrt_converter
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from torch_tensorrt.dynamo.backend.lowering import register_substitution
+
+
+@custom_op(
+    qualname="tensorrt::einsum",
+    manual_schema="(str equation, Tensor[] tensors) -> Tensor",
+)
+def einsum(equation, tensors):
+    # Defines operator schema, name, namespace, and function header
+    ...
+
+
+@einsum.impl("cpu")
+@einsum.impl("cuda")
+@einsum.impl_abstract()
+def einsum_generic(
+    *args,
+    **kwargs,
+):
+    # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation
+    return torch.einsum(
+        *args,
+        **kwargs,
+    )
+
+
+@tensorrt_converter(torch.ops.tensorrt.einsum.default)
+def aten_ops_einsum(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    # Defines converter replacing the default operator for this function
+    for input_trt in args[1]:
+        if not isinstance(input_trt, TRTTensor):
+            raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}")
+
+    einsum_layer = network.add_einsum(inputs=args[1], equation=args[0])
+
+    set_layer_name(einsum_layer, target, name)
+    return einsum_layer.get_output(0)
+
+
+@register_substitution(torch.einsum, torch.ops.tensorrt.einsum)
+def einsum_insertion_fn(
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    _unused: None = None,
+) -> torch.fx.Node:
+    equation = node.args[0]
+
+    # Ensure inputs is a list of (Tensor) arguments
+    if isinstance(node.args[1], (tuple, list)):
+        inputs = node.args[1]
+    else:
+        inputs = node.args[1:]
+
+    assert (
+        1 <= len(inputs) <= 2
+    ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors"
+
+    # Ensure the input is formatted as an equation and
+    new_node = gm.graph.call_function(
+        torch.ops.tensorrt.einsum,
+        args=(equation, inputs),
+        kwargs=node.kwargs,
+    )
+
+    return new_node
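A hedged usage sketch of the registered operator itself, assuming torch >= 2.1 (where `torch._custom_op` is available): the schema takes the equation string plus a list of tensors, and the CPU/CUDA implementations fall back to `torch.einsum`, so the call below should match a plain matmul.

import torch
import torch_tensorrt.dynamo.backend.lowering  # importing the package defines tensorrt::einsum

a, b = torch.randn(2, 3), torch.randn(3, 4)
out = torch.ops.tensorrt.einsum("ij,jk->ik", [a, b])  # dispatches to the generic impl above
assert torch.allclose(out, a @ b)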
