perf: Add lowering passes to improve TRT runtime on SD #2351

Merged
merged 2 commits into from Sep 29, 2023
10 changes: 5 additions & 5 deletions docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
@@ -12,7 +12,7 @@ Lowering Pass Requirements
------------

An ATen lowering pass function in Torch-TRT must satisfy two requirements:
- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
- The function must take as input a `torch.fx.GraphModule` and a sequence of torch Tensors, `Sequence[torch.Tensor]`, and return the lowered `torch.fx.GraphModule`
- The function must leave the graph in a valid and invoke-able state, including performing any necessary linting and recompilation

See this link for information on `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ in FX. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT Engines.
@@ -22,7 +22,7 @@ Example Lowering Pass

.. code-block:: python
def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
def repair_input_as_output(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
"""Repair scenarios where inputs are also outputs of the graph
TRT does not allow such cases, so we insert a clone (identity) layer
@@ -82,15 +82,15 @@ For instance, to insert the pass at the default location (end of the list), the
.. code-block:: python
@_aten_lowering_pass
def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
...
Alternatively, to insert the pass at a custom index (such as the front of the list) in the passlist, the following code can be used:

.. code-block:: python
@_aten_lowering_pass(index=0)
def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
...
There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.
@@ -101,7 +101,7 @@ There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for
print(dump_lowering_passes())
# Apply lowering passes to a GraphModule
apply_lowering_passes(graph_module)
apply_lowering_passes(graph_module, sample_inputs)
# Remove the lowering pass at index 1
_remove_lowering_pass(index=1)
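A minimal sketch (not part of this diff) of why sample inputs are now threaded through every pass: a pass can use them to run shape propagation before rewriting the graph, much as `fuse_prims_broadcast` does later in this PR. The pass name is illustrative, and the import path is assumed to match the module modified below.

```python
from typing import Sequence

import torch
from torch.fx.passes.shape_prop import ShapeProp
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
)


@_aten_lowering_pass
def annotate_shapes(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    # Illustrative only: populate node.meta["tensor_meta"] so that later
    # passes can make shape-dependent decisions
    ShapeProp(gm).propagate(*sample_inputs)
    return gm
```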
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/aten_tracer.py
@@ -28,6 +28,6 @@ def trace(
"torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
):
graph_module = export(model, tuple(inputs)).module()
graph_module = apply_lowering_passes(graph_module)
graph_module = apply_lowering_passes(graph_module, inputs)
logger.debug("Post export graph: " + str(graph_module.graph))
return graph_module
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -87,7 +87,7 @@ def _pretraced_backend(

logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = apply_lowering_passes(gm)
gm = apply_lowering_passes(gm, sample_inputs)

trt_compiled = compile_module(
gm,
15 changes: 15 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1517,3 +1517,18 @@ def aten_ops_max_pool(
dilation=args_bounds_check(args, 4, replacement=1),
ceil_mode=args_bounds_check(args, 5, replacement=False),
)


@dynamo_tensorrt_converter(
torch.nn.functional.scaled_dot_product_attention,
) # type: ignore[misc]
def tensorrt_scaled_dot_product_attention(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.attention.scaled_dot_product_attention(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
)
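A rough end-to-end sketch (not included in this PR) of exercising the new converter through the Dynamo backend registered in `backends.py`; it assumes a CUDA device with TensorRT installed, and the module and shapes are illustrative.

```python
import torch
import torch_tensorrt  # noqa: F401  # registers the "torch_tensorrt" backend


class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


q, k, v = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(3))
compiled = torch.compile(Attn().cuda(), backend="torch_tensorrt")
out = compiled(q, k, v)  # SDPA is intended to hit the converter added above
```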
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,6 +2,7 @@

from . import (
activation,
attention,
cast,
condition,
conv,
50 changes: 50 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -0,0 +1,50 @@
import math
from typing import Optional, Union

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.fx.types import TRTTensor


def scaled_dot_product_attention(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
query: TRTTensor,
key: TRTTensor,
value: TRTTensor,
) -> TRTTensor:
mm = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
name + "_mm",
query,
key,
other_matrix_op=trt.MatrixOperation.TRANSPOSE,
)
div = impl.elementwise.div(
ctx,
target,
source_ir,
name + "_scale",
mm,
math.sqrt(query.shape[-1]),
)
softmax = impl.normalization.softmax(
ctx, target, source_ir, name + "_softmax", div, -1
)
out = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
name + "_out",
softmax,
value,
)

return out
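For reference, the layer sequence emitted above mirrors the standard attention formula; a quick PyTorch-only check (illustrative, not part of the converter tests) of the same math:

```python
import math

import torch

q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))

# Same math as the converter: softmax(Q @ K^T / sqrt(d_k)) @ V
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
manual = torch.softmax(scores, dim=-1) @ v

reference = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(manual, reference, atol=1e-5)
```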
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -149,6 +149,7 @@
aten.special_log_ndtr,
aten.special_xlog1py,
aten.stack,
aten.std,
aten.t,
aten.tanh_backward,
aten.threshold,
@@ -163,6 +164,8 @@
aten.upsample_bilinear2d,
aten.upsample_bilinear2d.vec,
aten.upsample_nearest2d_backward,
aten.var,
aten.var_mean,
aten.xlogy,
aten.zero,
aten.zero_,
55 changes: 49 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -1,5 +1,5 @@
import logging
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional

import torch
from torch._decomp import register_decomposition
@@ -83,11 +83,6 @@ def inplace_op(*args, **kwargs):  # type: ignore
replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)


@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
def std_replacement(*args, **kwargs) -> torch.Tensor: # type: ignore
return torch.sqrt(torch.var(*args, **kwargs))


@register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
def rsqrt_replacement(*args, **kwargs) -> torch.Tensor: # type: ignore
return torch.reciprocal(torch.sqrt(*args, **kwargs))
@@ -135,6 +130,54 @@ def reciprocal_replacement(
return torch.div(1, input_)


@register_torch_trt_decomposition(
torch.ops.prims.var.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def var_decomposition(
input_tensor: torch.Tensor,
dims: Optional[List[int]],
correction: int,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if dims is None:
dims = []

# If the dimensions are empty, variance is taken over all dimensions
if isinstance(dims, (tuple, list)) and len(dims) == 0:
N = input_tensor.numel()
# Otherwise, the number of samples is the product of the dimensions reduced over
else:
N = 1
for dim_i in dims:
N *= input_tensor.shape[dim_i]

# Compute the mean, difference, and correction term as per the formula:
# https://pytorch.org/docs/stable/generated/torch.var.html

# Additionally, prims does not support keepdim, and so we only keep dimensions
# on the first reduction, then remove it for the second
sample_mean = torch.mean(input_tensor, dims, keepdim=True)
diff = input_tensor - sample_mean
squared_diff = diff * diff
variance_unnormalized = torch.sum(squared_diff, dims, keepdim=False)

if correction is None:
correction_term = float(N - 1)
elif isinstance(correction, int):
correction_term = float(N - correction)
elif isinstance(correction, float):
correction_term = float(N) - correction
else:
raise RuntimeError("correction must be int or float")

if correction_term <= 0:
raise RuntimeError(f"correction term was non-positive, got: {correction_term}")

variance = variance_unnormalized / correction_term

return variance


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
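A small sanity check (illustrative, assuming a PyTorch version where `torch.var` accepts the `correction` keyword) that the decomposition above reproduces `torch.var` for a reduced-dimension, corrected variance:

```python
import torch

x = torch.randn(3, 5, 7)
dims, correction = [1, 2], 1

# Same arithmetic as var_decomposition: mean with keepdim, squared
# differences, then division by (N - correction)
n = 1
for d in dims:
    n *= x.shape[d]
mean = x.mean(dims, keepdim=True)
manual = ((x - mean) ** 2).sum(dims) / float(n - correction)

assert torch.allclose(
    manual, torch.var(x, dim=dims, correction=correction), atol=1e-5
)
```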
20 changes: 15 additions & 5 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -1,9 +1,11 @@
import logging
from typing import Callable, Optional
from typing import Callable, Optional, Sequence, Union

import torch

from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_efficient_attention import lower_efficient_attention
from .pass_manager import DynamoPassManager
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
@@ -13,19 +15,25 @@
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
lower_efficient_attention,
fuse_prims_broadcast,
]
)

logger = logging.getLogger(__name__)


LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
LoweringPassSignature = Callable[
[torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
]


def _aten_lowering_pass(
*args: LoweringPassSignature,
index: Optional[int] = None,
) -> LoweringPassSignature:
) -> Union[
LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
]:
"""Adds a lowering pass to the registry, at a specified index if desired
If no index is specified, the lowering pass is inserted at the end of the list
@@ -65,12 +73,14 @@ def _remove_lowering_pass(*, index: int) -> None:
return


def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
def apply_lowering_passes(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Applies the lowering passes to a graph module, returns the modified GraphModule"""
logging.debug(
f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
)
return ATEN_LOWERING_PASSES(gm)
return ATEN_LOWERING_PASSES(gm, sample_inputs)


def dump_lowering_passes() -> str:
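The broadened return annotation reflects that `_aten_lowering_pass` can be used either bare or called with an `index`; a brief sketch of both registration forms (pass bodies are placeholders, and the import path is assumed to match this module):

```python
from typing import Sequence

import torch
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
    _remove_lowering_pass,
)


@_aten_lowering_pass  # bare: appended to the end of the pass list
def pass_a(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    return gm


@_aten_lowering_pass(index=0)  # called with arguments: inserted at the front
def pass_b(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    return gm


_remove_lowering_pass(index=0)  # undo the second registration
```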
5 changes: 4 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -1,4 +1,5 @@
import logging
from typing import Sequence

import torch
from torch_tensorrt._utils import sanitized_torch_version
@@ -21,7 +22,9 @@


@torch.utils._python_dispatch._disable_current_modes() # type: ignore
def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
def constant_fold(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Adapted from:
https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
82 changes: 82 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py
@@ -0,0 +1,82 @@
import logging
from typing import Sequence

import torch
from torch.fx.passes.shape_prop import ShapeProp
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


# TODO: Add relevant prims to this fusion
def fuse_prims_broadcast(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Fuses prim nodes which are effectively the ATen equivalents with keep_dim=True"""
modified_graph = False

# Propagate shapes through the graph to determine if broadcast can be resolved
try:
ShapeProp(gm).propagate(*sample_inputs)
except (RuntimeError, AssertionError):
logger.warning(
"Shape Propagation Failed on Graph, skipping fuse_prims_broadcast lowering pass",
exc_info=True,
)
return gm

for node in gm.graph.nodes:
# If the node is a sum prims operator, with broadcast_in_dim being the only consumer
# it is a candidate for fusing
if (
node.target in (torch.ops.prims.sum.default,)
and len(node.users) == 1
and list(node.users)[0].target == torch.ops.prims.broadcast_in_dim.default
):
# Get broadcasted shape, reduced dimensions, and original tensor shape
broadcast_node = list(node.users)[0]
broadcasted_shape = broadcast_node.args[1]
reduced_dims = node.args[1]
original_shape = node.args[0].meta["tensor_meta"].shape

# If the rank of the broadcasted shape is the same as the original
# and the broadcasts are all singletons for the reduced dimensions
# and all of the non-reduced dimensions are identical to the originals

# Then the broadcast is effectively performing a "keep_dim=True" operation
if (
len(broadcasted_shape) == len(original_shape)
and all(broadcasted_shape[i] == 1 for i in reduced_dims)
and all(
broadcasted_shape[j] == original_shape[j]
for j in range(len(original_shape))
if j not in reduced_dims
)
):
# Fuse the operator to its convertible alternative
with gm.graph.inserting_after(broadcast_node):
modified_graph = True

if node.target == torch.ops.prims.sum.default:
fused_node = gm.graph.call_function(
torch.ops.aten.sum.dim_IntList,
args=(node.args[0], reduced_dims, True),
)

# Replace all uses of the placeholder except the cloned node
# with the cloned placeholder
broadcast_node.replace_all_uses_with(
fused_node,
)

# Erase uses of the broadcast node and original
gm.graph.erase_node(broadcast_node)
gm.graph.erase_node(node)

if modified_graph:
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after fusing prims-broadcast paradigm:\n{gm.graph}")

return gm
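To illustrate the pattern this pass rewrites (a standalone check, not part of the pass itself): a `prims.sum` followed by a `broadcast_in_dim` back to the original rank is equivalent to a single keepdim ATen reduction, which has a TRT converter.

```python
import torch

x = torch.randn(2, 3, 4)
reduced_dims = [1]

# prims-style reduction, then a broadcast restoring the original rank
summed = torch.ops.prims.sum.default(x, reduced_dims)
broadcast = torch.ops.prims.broadcast_in_dim.default(summed, [2, 1, 4], [0, 2])

# ...matches the fused keepdim form emitted by the pass
fused = torch.ops.aten.sum.dim_IntList(x, reduced_dims, True)
assert torch.allclose(broadcast, fused)
```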