
Commit be3f910

peri044 and cehongwang authored and committed
chore: updates
1 parent bfb549e commit be3f910

File tree

6 files changed: +152 -125 lines changed

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
@@ -141,6 +141,7 @@ Model Zoo
 * :ref:`torch_export_gpt2`
 * :ref:`torch_export_llama2`
 * :ref:`torch_export_sam2`
+* :ref:`torch_export_flux_dev`
 * :ref:`notebooks`

 .. toctree::
@@ -157,6 +158,7 @@ Model Zoo
 tutorials/_rendered_examples/dynamo/torch_export_gpt2
 tutorials/_rendered_examples/dynamo/torch_export_llama2
 tutorials/_rendered_examples/dynamo/torch_export_sam2
+tutorials/_rendered_examples/dynamo/torch_export_flux_dev
 tutorials/notebooks

 Python API Documentation

examples/dynamo/README.rst

Lines changed: 2 additions & 1 deletion
@@ -20,4 +20,5 @@ Model Zoo
 * :ref:`_torch_compile_gpt2`: Compiling a GPT2 model using ``torch.compile``
 * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
-* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)

examples/dynamo/torch_export_flux_dev.py

Lines changed: 141 additions & 0 deletions

@@ -0,0 +1,141 @@
"""
.. _torch_export_flux_dev:

Compiling FLUX.1-dev model using the Torch-TensorRT dynamo backend
===================================================================

This example illustrates the state of the art model `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ optimized using
Torch-TensorRT.

**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.

Install the following dependencies before compilation:

.. code-block:: sh

    pip install -r requirements.txt

There are different components of the FLUX.1-dev pipeline such as `transformer`, `vae`, `text_encoder`, `tokenizer` and `scheduler`. In this example,
we demonstrate optimizing the `transformer` component of the model (which typically consumes >95% of the end-to-end diffusion latency).
"""

# %%
# Import the following libraries
# -----------------------------
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from torch.export._trace import _export

# %%
# Define the FLUX.1-dev model
# -----------------------------
# Load the ``FLUX.1-dev`` pretrained pipeline using the ``FluxPipeline`` class.
# ``FluxPipeline`` includes all the different components such as `transformer`, `vae`, `text_encoder`, `tokenizer` and `scheduler` necessary
# to generate an image. We load the weights in ``FP16`` precision using the ``torch_dtype`` argument.
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)
# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer


# %%
# Export the backbone using ``torch.export``
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes using a batch size of 2
# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These particular min and max values are recommended by torch dynamo during the export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
}
# This will create an exported program which is going to be compiled with Torch-TensorRT
ep = _export(
    backbone,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    allow_complex_guards_as_runtime_asserts=True,
)

# %%
# Torch-TensorRT compilation
# ---------------------------
# We enable FP32 matmul accumulation using ``use_fp32_acc=True`` to preserve accuracy with the original PyTorch model.
# Since this is a 12 billion parameter model, compilation takes around 20-30 min on an H100 GPU.
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    min_block_size=1,
    debug=True,
    use_fp32_acc=True,
    use_explicit_typing=True,
)
# %%
# Post Processing
# ---------------------------
# Release the GPU memory occupied by the exported program and the pipe.transformer
# Set the transformer in the Flux pipeline to the Torch-TRT compiled model
backbone.to("cpu")
del ep
pipe.transformer = trt_gm
pipe.transformer.config = config

# %%
# Image generation using prompt
# ------------------------------
# Provide a prompt and the file name of the image to be generated. Here we use the
# prompt ``A golden retriever holding a sign to code``.


# Function which generates images from the flux pipeline
def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Image generated using {image_name} model saved as {image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")

# %%
# The generated image is as shown below
# .. image:: dog_code.png
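
Once the pipeline is rebuilt around the TRT-compiled transformer, a rough end-to-end timing check is a useful sanity test. Below is a minimal sketch (not part of the committed example); it simply reuses the prompt and step count from the code above, and absolute numbers depend on the GPU:

import time

torch.cuda.synchronize()
start = time.perf_counter()
_ = pipe(
    "A golden retriever holding a sign to code",
    output_type="pil",
    num_inference_steps=20,
    generator=torch.Generator("cuda").manual_seed(42),
)
torch.cuda.synchronize()
print(f"20 denoising steps took {time.perf_counter() - start:.1f} s")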

examples/dynamo/torch_export_flux_fp8.py

Lines changed: 0 additions & 123 deletions
This file was deleted.

py/torch_tensorrt/dynamo/utils.py

Lines changed: 7 additions & 1 deletion
@@ -243,7 +243,10 @@ def prepare_inputs(
     inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
     disable_memory_format_check: bool = False,
 ) -> Any:
-    if isinstance(inputs, Input):
+    if inputs is None:
+        return None
+
+    elif isinstance(inputs, Input):
         return inputs

     elif isinstance(inputs, (torch.Tensor, int, float, bool)):
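
With this change, an input that is explicitly None survives input preparation instead of falling through to an error. A minimal sketch of the intent (the dictionary below is illustrative, not taken from the commit; it mimics a pipeline that passes an optional argument such as ``guidance`` as None):

import torch
from torch_tensorrt.dynamo.utils import prepare_inputs

example_kwargs = {
    "timestep": torch.tensor([1.0], dtype=torch.float16),
    "guidance": None,  # explicitly-None input
}
prepared = prepare_inputs(example_kwargs)
assert prepared["guidance"] is None  # None entries pass through unchanged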
@@ -400,6 +403,9 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
         return torch.tensor(tensor).dtype
     elif isinstance(tensor, torch.SymInt):
         return torch.int64
+    elif tensor is None:
+        # Case where we explicitly pass one of the inputs to be None (eg: FLUX.1-dev)
+        return None
     else:
         raise ValueError(f"Found invalid tensor type {type(tensor)}")
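
``unwrap_tensor_dtype`` now mirrors that behaviour for dtype extraction. A minimal sketch, assuming the branch not shown in the hunk returns ``tensor.dtype`` for ordinary tensors:

import torch
from torch_tensorrt.dynamo.utils import unwrap_tensor_dtype

assert unwrap_tensor_dtype(torch.ones(2, dtype=torch.float16)) == torch.float16
assert unwrap_tensor_dtype(None) is None  # new behaviour for explicitly-None inputs (eg: FLUX.1-dev)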
