fix: Split addmm nodes to not cast bias for FP32 accumulation and flux example fixes. #3395

Merged · 9 commits · Feb 25, 2025
18 changes: 12 additions & 6 deletions examples/dynamo/torch_export_flux_dev.py
@@ -9,11 +9,11 @@

**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.

Install the following dependencies before compilation
To run this demo, you need access to the FLUX model (request access on the `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ page if you do not already have it) and you need to install the following dependencies

.. code-block:: python

pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2"
pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"

There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency)
@@ -38,11 +38,10 @@
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)

# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer

backbone = pipe.transformer.to(DEVICE)

# %%
# Export the backbone using torch.export
@@ -63,6 +62,8 @@
"txt_ids": {0: SEQ_LEN},
"img_ids": {0: IMG_ID},
"guidance": {0: BATCH},
"joint_attention_kwargs": {},
"return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
@@ -79,6 +80,8 @@
"txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
"img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
"guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
"joint_attention_kwargs": {},
"return_dict": False,
}
# This will create an exported program which is going to be compiled with Torch-TensorRT
ep = _export(
@@ -116,8 +119,11 @@
# ---------------------------
# Release the GPU memory occupied by the exported program and the pipe.transformer
# Set the transformer in the Flux pipeline to the Torch-TRT compiled model
backbone.to("cpu")

del ep
backbone.to("cpu")
pipe.to(DEVICE)
torch.cuda.empty_cache()
pipe.transformer = trt_gm
pipe.transformer.config = config

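With the compiled transformer swapped back into the pipeline, image generation works like any other diffusers pipeline. Below is a minimal usage sketch that continues from the example above; the prompt, resolution, and step count are placeholder assumptions, not values taken from this PR.

```python
# Minimal usage sketch (assumed values, not from this PR): run the Flux pipeline
# with the Torch-TensorRT compiled transformer and save the result.
prompt = "A photo of a red vintage car parked by the ocean"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=28,
).images[0]
image.save("flux_trt_output.png")
```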
@@ -10,7 +10,7 @@
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_assert_nodes import remove_assert_nodes
from .remove_detach import remove_detach
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
@@ -27,7 +27,7 @@
        replace_max_pool_with_indices,
        lower_scaled_dot_product_attention,
        view_to_reshape,
        remove_assert_scalar,
        remove_assert_nodes,
        accumulate_fp32_matmul,
    ]
)
41 changes: 39 additions & 2 deletions py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -9,17 +9,54 @@
logger = logging.getLogger(__name__)


def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
Review comments:

Collaborator: Don't we have a decomp for this?

Collaborator: Should this not just be a decomp?

Author: Modified it to use the torch decomposition now.

    target = torch.ops.aten.addmm.default
    addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
    for addmm_node in addmm_nodes:
        bias, mat1, mat2 = addmm_node.all_input_nodes
        beta = addmm_node.kwargs.get("beta")
        alpha = addmm_node.kwargs.get("alpha")

        with gm.graph.inserting_before(addmm_node):
            mm_node = gm.graph.call_function(
                torch.ops.aten.mm.default,
                args=(mat1, mat2),
            )
            if alpha:
                mm_node = gm.graph.call_function(
                    torch.ops.aten.mul.Tensor,
                    args=(mm_node, alpha),
                )

            if beta:
                bias = gm.graph.call_function(
                    torch.ops.aten.mul.Tensor,
                    args=(bias, beta),
                )
            add_node = gm.graph.call_function(
                torch.ops.aten.add.Tensor,
                args=(bias, mm_node),
            )

        addmm_node.replace_all_uses_with(add_node, propagate_meta=True)
        gm.graph.erase_node(addmm_node)

    return gm


def accumulate_fp32_matmul(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Replace a matmul layer with fp32 accumulation nodes"""
    """Add cast to FP32/16 nodes around a matmul layer. This pattern is detected by TensorRT and will enable FP32 accumulation during execution."""
    if settings.use_fp32_acc:
        matmul_targets = [
            torch.ops.aten.mm.default,
            torch.ops.aten.bmm.default,
            torch.ops.aten.addmm.default,
        ]

        # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
        split_addmm_nodes(gm)

        matmul_nodes = [
            node for node in gm.graph.nodes if node.target in matmul_targets
        ]
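For intuition, here is a rough numerical sketch (an illustration of the intended pattern, not code from this PR) of what the rewritten graph computes for an `addmm` after the split: only the `mm` is cast up to FP32 and back, which is the pattern TensorRT recognizes for FP32 accumulation, while the bias is left in its original precision.

```python
import torch


# Emulate addmm(bias, mat1, mat2, beta, alpha) after the split:
# add(mul(bias, beta), mul(mm(mat1, mat2), alpha)), with casts only around the mm.
def addmm_with_fp32_acc(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    # The matmul inputs are promoted to FP32 and the product is cast back to FP16.
    mm_out = (mat1.float() @ mat2.float()).to(torch.float16)
    # The bias path is never cast, which is the point of splitting addmm.
    return beta * bias + alpha * mm_out


bias = torch.randn(3, 5, dtype=torch.float16)
mat1 = torch.randn(3, 4, dtype=torch.float16)
mat2 = torch.randn(4, 5, dtype=torch.float16)
out = addmm_with_fp32_acc(bias, mat1, mat2, beta=2.0, alpha=0.5)
print(out.dtype, out.shape)  # torch.float16 torch.Size([3, 5])
```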
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -51,6 +51,8 @@ def constant_fold(
            gm.graph.erase_node(node)

    gm = clean_up_graph_after_modifications(gm)
    # Delete the constant folder instance which holds GPU memory
    del cf

    logger.debug(f"Graph after constant folding:\n{gm.graph}")

@@ -9,15 +9,15 @@
logger = logging.getLogger(__name__)


def remove_assert_scalar(
def remove_assert_nodes(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Remove assert_scalar ops in the graph"""
    count = 0
    for node in gm.graph.nodes:
        if (
            node.target == torch.ops.aten._assert_scalar.default
            or node == torch.ops.aten._assert_tensor_metadata.default
            or node.target == torch.ops.aten._assert_tensor_metadata.default
        ):
            gm.graph.erase_node(node)
            count += 1
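The change in this hunk is easy to miss: the old condition compared the FX node object itself to the op overload, which is never true, so `_assert_tensor_metadata` nodes were never erased. A small standalone illustration of the difference (my own sketch, not repository code):

```python
import torch

# Build a tiny FX graph containing an _assert_tensor_metadata call.
graph = torch.fx.Graph()
x = graph.placeholder("x")
assert_node = graph.call_function(
    torch.ops.aten._assert_tensor_metadata.default, args=(x,)
)
graph.output(x)

op = torch.ops.aten._assert_tensor_metadata.default
print(assert_node == op)         # False: a Node never equals an OpOverload (old check)
print(assert_node.target == op)  # True: comparing node.target matches (fixed check)
```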
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import gc
import logging
import warnings
from dataclasses import fields, replace
@@ -30,6 +31,7 @@
DYNAMIC_DIM = -1
RTOL = 5e-3
ATOL = 5e-3
CPU_DEVICE = "cpu"


class Frameworks(Enum):
@@ -81,6 +83,17 @@ class Frameworks(Enum):
}


def delete_module(module: torch.fx.GraphModule) -> None:
    """
    This is a helper function to delete the instance of a module. We first move it to CPU and then
    delete the object. This function ensures the GPU memory occupied by the module is released effectively after this call.
    """
    module.to(CPU_DEVICE)
    del module
    torch.cuda.empty_cache()
    gc.collect()


def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
"""Parses a user-provided input argument regarding Python runtime

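A short usage sketch for the new helper; the import path is inferred from the file shown above (`py/torch_tensorrt/dynamo/utils.py`), and a CUDA device plus a build that includes this change are assumed.

```python
import torch

# Import path assumed from the diff above.
from torch_tensorrt.dynamo.utils import delete_module

# Create a throwaway GraphModule on the GPU, then release its memory with one call
# instead of the manual cpu()/del/empty_cache()/gc sequence used in the Flux example.
gm = torch.fx.symbolic_trace(torch.nn.Linear(1024, 1024)).cuda()
before = torch.cuda.memory_allocated()
delete_module(gm)  # moves the module to CPU, drops its reference, empties the CUDA cache, runs gc
after = torch.cuda.memory_allocated()
print(f"GPU memory freed: {(before - after) / 1e6:.1f} MB")
```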
42 changes: 42 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -269,6 +269,48 @@ def forward(self, input, weight):
        )
        torch._dynamo.reset()

    def test_fp32_acc_for_addmm(self):
        class FP32Acc(torch.nn.Module):
            def forward(self, input, mat1, mat2):
                out = torch.ops.aten.addmm.default(input, mat1, mat2, beta=20, alpha=2)
                return out

        inputs = [
            torch.rand((3, 5)).cuda(),
            torch.rand((3, 4)).cuda(),
            torch.rand((4, 5)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(FP32Acc())
        expected_ops = {
            torch.ops.aten._to_copy.default,
            torch.ops.aten.mm.default,
            torch.ops.aten.add.Tensor,
        }
        unexpected_ops = {}

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
            use_fp32_acc=True,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()


class TestLowerEfficientAttention(TestCase):
def test_lower_efficient_attention(self):