This example illustrates the state-of-the-art model `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ optimized using
Torch-TensorRT.

**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.
Install the following dependencies before compilation:

.. code-block:: sh

    pip install -r requirements.txt

``requirements.txt`` pins the libraries this example needs, such as ``sentencepiece``, ``transformers``, ``accelerate`` and ``diffusers``.

There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the end-to-end diffusion latency).
"""

# %%
import torch
import torch_tensorrt
from diffusers import FluxPipeline

# %%
# Define the FLUX.1-dev model
# -----------------------------
# Load the ``FLUX.1-dev`` pretrained pipeline using the ``FluxPipeline`` class.
# ``FluxPipeline`` includes the different components, such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``, necessary
# to generate an image. We load the weights in ``FP16`` precision using the ``torch_dtype`` argument.
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
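
# The lines between here and the export section are elided in this excerpt. As a
# minimal sketch (the name ``backbone`` is ours, not necessarily the example's),
# the module we export and compile below is the pipeline's ``transformer``,
# moved to the target device:
backbone = pipe.transformer.to(DEVICE)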

# %%
# Export the backbone using ``torch.export``
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes, using a batch size of 2
# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These particular min and max values for the ``img_id`` input are recommended by torch dynamo during the export of the model.
# To see this recommendation, you can try exporting using ``min=1, max=4096``.
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
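    # The dict body and the export call are elided in this excerpt. What follows
    # is a hedged sketch, not the example's exact code: the keys follow the
    # argument names of ``FluxTransformer2DModel.forward``, mapping each dynamic
    # dimension index to the ``Dim`` objects defined above.
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
}

# A minimal sketch of the export itself, assuming ``dummy_inputs`` (defined in
# the elided lines) is a dict of example tensors keyed by the same argument names:
with torch.no_grad():
    ep = torch.export.export(
        backbone,
        args=(),
        kwargs=dummy_inputs,
        dynamic_shapes=dynamic_shapes,
    )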

# %%
# Torch-TensorRT compilation
# ---------------------------
# We enable ``FP32`` matmul accumulation using ``use_fp32_acc=True`` to preserve accuracy relative to the original PyTorch model.
# We also enable explicit typing via ``use_explicit_typing=True`` so that TensorRT respects the datatypes we set, which is a requirement for FP32 matmul accumulation.
# Since this is a 12 billion parameter model, compilation takes around 20-30 min on an H100 GPU.
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    min_block_size=1,
    debug=True,
    use_fp32_acc=True,
    use_explicit_typing=True,
)

# %%
# Post Processing
# ---------------------------
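
# The section body and the ``generate_image(pipe, prompt, image_name)`` helper it
# uses are elided in this excerpt. A minimal sketch, assuming the compiled
# ``trt_gm`` simply replaces the pipeline's original backbone (the prompt string
# below is a hypothetical placeholder):
pipe.transformer = trt_gm
generate_image(pipe, "a photo of a dog writing code", "dog_code")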

# %%
# The generated image is shown below
#
# .. image:: dog_code.png