Commit ee65be4

peri044 authored and cehongwang committed
chore: updates
1 parent 7bae16e

File tree: 3 files changed, +14 -146 lines


examples/dynamo/README.rst

Lines changed: 1 addition & 1 deletion
@@ -21,4 +21,4 @@ Model Zoo
 * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
-* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_sam2`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)

examples/dynamo/torch_export_flux_dev.py

Lines changed: 13 additions & 22 deletions
@@ -7,16 +7,15 @@
 This example illustrates the state of the art model `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ optimized using
 Torch-TensorRT.
 
-**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.
-
+**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications
 Install the following dependencies before compilation
 
 .. code-block:: python
 
-    pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2"
+    pip install -r requirements.txt
 
-There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
-we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency)
+There are different components of the FLUX.1-dev pipeline such as `transformer`, `vae`, `text_encoder`, `tokenizer` and `scheduler`. In this example,
+we demonstrate optimizing the `transformer` component of the model (which typically consumes >95% of the e2el diffusion latency)
 """
 
 # %%
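Note on the dependency change above: the pinned packages from the deleted ``pip install`` line presumably moved into the example's ``requirements.txt``. A hypothetical sketch of that file, inferred only from the removed pins (the file itself is not part of this diff):

.. code-block:: text

    # Hypothetical requirements.txt contents, reconstructed from the
    # pin list this commit removed from the docstring
    sentencepiece==0.2.0
    transformers==4.48.2
    accelerate==1.3.0
    diffusers==0.32.2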
@@ -31,7 +30,7 @@
 # Define the FLUX-1.dev model
 # -----------------------------
 # Load the ``FLUX-1.dev`` pretrained pipeline using ``FluxPipeline`` class.
-# ``FluxPipeline`` includes different components such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler`` necessary
+# ``FluxPipeline`` includes all the different components such as `transformer`, `vae`, `text_encoder`, `tokenizer` and `scheduler` necessary
 # to generate an image. We load the weights in ``FP16`` precision using ``torch_dtype`` argument
 DEVICE = "cuda:0"
 pipe = FluxPipeline.from_pretrained(
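For readers following along, a minimal sketch of the (unchanged) pipeline-loading code these context lines belong to, assuming the standard ``diffusers`` API and the model ID from the Hugging Face link in the example header:

.. code-block:: python

    import torch
    from diffusers import FluxPipeline

    DEVICE = "cuda:0"
    # Load the FLUX.1-dev pipeline with weights in FP16, as the
    # comment block above describes; the model ID is an assumption
    # based on the link in the docstring.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.float16,
    ).to(DEVICE)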
@@ -45,14 +44,14 @@
 
 
 # %%
-# Export the backbone using torch.export
+# Export the backbone using ``torch.export``
 # --------------------------------------------------
-# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a ``batch_size=2``
-# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
+# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes with a batch size of 2
+# due to 0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
 batch_size = 2
 BATCH = torch.export.Dim("batch", min=1, max=2)
 SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
-# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
+# This particular min, max values are recommended by torch dynamo during the export of the model.
 # To see this recommendation, you can try exporting using min=1, max=4096
 IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
 dynamic_shapes = {
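The ``torch.export.Dim`` objects above feed the ``dynamic_shapes`` mapping whose body is truncated by this hunk. A hedged sketch of the usual wiring; the input names are assumptions about the FLUX transformer's forward signature, not taken from the commit:

.. code-block:: python

    # Each entry maps an input name to {dim_index: Dim}. The ranges come
    # from the Dim definitions above (batch 1-2, seq_len 1-512,
    # img_id 3586-4096); the input names here are illustrative.
    dynamic_shapes = {
        "hidden_states": {0: BATCH},
        "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
        "img_ids": {0: IMG_ID},
    }
    # ep = torch.export.export(
    #     pipe.transformer, args=(), kwargs=dummy_inputs,
    #     dynamic_shapes=dynamic_shapes,
    # )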
@@ -93,24 +92,18 @@
 # %%
 # Torch-TensorRT compilation
 # ---------------------------
-# .. note::
-#    The compilation requires a GPU with high memory (> 80GB) since TensorRT is storing the weights in FP32 precision. This is a known issue and will be resolved in the future.
-#
-#
-# We enable ``FP32`` matmul accumulation using ``use_fp32_acc=True`` to ensure accuracy is preserved by introducing cast to ``FP32`` nodes.
-# We also enable explicit typing to ensure TensorRT respects the datatypes set by the user which is a requirement for FP32 matmul accumulation.
-# Since this is a 12 billion parameter model, it takes around 20-30 min to compile on H100 GPU. The model is completely convertible and results in
-# a single TensorRT engine.
+# We enable FP32 matmul accumulation using ``use_fp32_acc=True`` to preserve accuracy with the original Pytorch model.
+# Since this is a 12 billion parameter model, it takes around 20-30 min on H100 GPU
 trt_gm = torch_tensorrt.dynamo.compile(
     ep,
     inputs=dummy_inputs,
     enabled_precisions={torch.float32},
     truncate_double=True,
     min_block_size=1,
+    debug=True,
     use_fp32_acc=True,
     use_explicit_typing=True,
 )
-
 # %%
 # Post Processing
 # ---------------------------
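Between the compilation call and the post-processing shown here, the compiled ``trt_gm`` is typically swapped in for the FP16 backbone before generation. A minimal sketch of that pattern, assuming the standard ``diffusers`` inference path; the prompt and sampler settings are illustrative, not from the commit:

.. code-block:: python

    # Replace the PyTorch transformer with the TensorRT-compiled module,
    # then run the ordinary diffusers generation loop.
    pipe.transformer = trt_gm

    image = pipe(
        "a dog sitting at a laptop writing code",  # illustrative prompt
        num_inference_steps=20,
        generator=torch.Generator(device=DEVICE).manual_seed(42),
    ).images[0]
    image.save("dog_code.png")  # filename matches the example's image directive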
@@ -145,6 +138,4 @@ def generate_image(pipe, prompt, image_name):
 
 # %%
 # The generated image is as shown below
-#
-# .. image:: dog_code.png
-#
+# .. image:: dog_code.png

examples/dynamo/torch_export_flux_fp8.py

Lines changed: 0 additions & 123 deletions
This file was deleted.
