
Commit a613584

peri044 authored and cehongwang committed

chore: updates

1 parent be3f910 commit a613584

File tree

2 files changed

+21
-246
lines changed


examples/dynamo/flux.py

Lines changed: 0 additions & 234 deletions
This file was deleted.

examples/dynamo/torch_export_flux_dev.py

Lines changed: 21 additions & 12 deletions
@@ -7,15 +7,16 @@
 This example illustrates the state of the art model `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ optimized using
 Torch-TensorRT.
 
-**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications
+**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.
+
 Install the following dependencies before compilation
 
 .. code-block:: python
 
     pip install -r requirements.txt
 
-There are different components of the FLUX.1-dev pipeline such as `transformer`, `vae`, `text_encoder`, `tokenizer` and `scheduler`. In this example,
-we demonstrate optimizing the `transformer` component of the model (which typically consumes >95% of the e2el diffusion latency)
+There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
+we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency)
 """
 
 # %%
@@ -30,7 +31,7 @@
 # Define the FLUX-1.dev model
 # -----------------------------
 # Load the ``FLUX-1.dev`` pretrained pipeline using ``FluxPipeline`` class.
-# ``FluxPipeline`` includes all the different components such as `transformer`, `vae`, `text_encoder`, `tokenizer` and `scheduler` necessary
+# ``FluxPipeline`` includes different components such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler`` necessary
 # to generate an image. We load the weights in ``FP16`` precision using ``torch_dtype`` argument
 DEVICE = "cuda:0"
 pipe = FluxPipeline.from_pretrained(
@@ -44,14 +45,14 @@
 
 
 # %%
-# Export the backbone using ``torch.export``
+# Export the backbone using torch.export
 # --------------------------------------------------
-# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a batch size of 2
-# due to 0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
+# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a ``batch_size=2``
+# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
 batch_size = 2
 BATCH = torch.export.Dim("batch", min=1, max=2)
 SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
-# This particular min, max values are recommended by torch dynamo during the export of the model.
+# These particular min, max values for the ``img_id`` input are recommended by torch dynamo during the export of the model.
 # To see this recommendation, you can try exporting using min=1, max=4096
 IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
 dynamic_shapes = {
@@ -92,18 +93,24 @@
 # %%
 # Torch-TensorRT compilation
 # ---------------------------
-# We enable FP32 matmul accumulation using ``use_fp32_acc=True`` to preserve accuracy with the original Pytorch model.
-# Since this is a 12 billion parameter model, it takes around 20-30 min on H100 GPU
+# .. note::
+#    The compilation requires a GPU with high memory (> 80GB) since TensorRT stores the weights in FP32 precision. This is a known issue and will be resolved in the future.
+#
+#
+# We enable ``FP32`` matmul accumulation using ``use_fp32_acc=True`` to ensure accuracy is preserved by introducing cast-to-``FP32`` nodes.
+# We also enable explicit typing to ensure TensorRT respects the datatypes set by the user, which is a requirement for FP32 matmul accumulation.
+# Since this is a 12 billion parameter model, it takes around 20-30 min to compile on an H100 GPU. The model is completely convertible and results in
+# a single TensorRT engine.
 trt_gm = torch_tensorrt.dynamo.compile(
     ep,
     inputs=dummy_inputs,
     enabled_precisions={torch.float32},
     truncate_double=True,
     min_block_size=1,
-    debug=True,
     use_fp32_acc=True,
     use_explicit_typing=True,
 )
+
 # %%
 # Post Processing
 # ---------------------------
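Why FP32 matmul accumulation matters can be seen in a small FP16 rounding demo. This is a generic illustration in plain PyTorch, not the commit's code: once a running value is large, FP16 cannot represent small addends, whereas accumulating in FP32 (as `use_fp32_acc=True` arranges for matmuls) keeps them.

```python
import torch

# FP16 spacing (ulp) in [2048, 4096) is 2.0, so adding 0.5 in FP16
# rounds back to 2048.0 and the contribution is lost.
big = torch.tensor(2048.0, dtype=torch.float16)
small = torch.tensor(0.5, dtype=torch.float16)

fp16_sum = big + small                  # accumulate in FP16: addend vanishes
fp32_sum = big.float() + small.float()  # accumulate in FP32: addend preserved

print(fp16_sum.item())  # 2048.0
print(fp32_sum.item())  # 2048.5
```

In a 12-billion-parameter matmul, millions of such small products accumulate per output element, which is why FP16-only accumulation can drift visibly from the PyTorch baseline.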
@@ -138,4 +145,6 @@ def generate_image(pipe, prompt, image_name):
 
 # %%
 # The generated image is as shown below
-# .. image:: dog_code.png
+#
+# .. image:: dog_code.png
+#

0 commit comments