This example illustrates the state-of-the-art model `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ optimized using
Torch-TensorRT.

**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.

Install the following dependencies before compilation:

.. code-block:: sh

    pip install -r requirements.txt

There are several components in the ``FLUX.1-dev`` pipeline, such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
we demonstrate optimizing the ``transformer`` component of the model, which typically accounts for >95% of the end-to-end (e2e) diffusion latency.
19 | 20 | """
|
20 | 21 |
|
21 | 22 | # %%
|
|
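# The import cell is elided in this excerpt; the imports below are a minimal
# reconstruction covering what the example uses later: ``torch``,
# ``torch_tensorrt`` and ``FluxPipeline`` from ``diffusers``.
import torch

import torch_tensorrt
from diffusers import FluxPipeline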
# %%
# Define the FLUX.1-dev model
# -----------------------------
# Load the ``FLUX.1-dev`` pretrained pipeline using the ``FluxPipeline`` class.
# ``FluxPipeline`` includes the different components such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler`` necessary
# to generate an image. We load the weights in ``FP16`` precision using the ``torch_dtype`` argument.
DEVICE = "cuda:0"
# (Completed from context: the checkpoint id comes from the model card linked
# above; loading in FP16 and moving the pipeline to DEVICE follow the comment
# above and are assumptions, not the original code verbatim.)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
).to(DEVICE)


# %%
# Export the backbone using torch.export
# --------------------------------------------------
# Define the dummy inputs and their respective dynamic shapes. We export the transformer backbone with dynamic shapes and a ``batch_size=2``
# due to `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These particular min and max values for the ``img_id`` input are recommended by torch dynamo during export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
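    # A hedged sketch: the original entries are elided in this excerpt, and
    # these input names are assumptions based on the FLUX transformer's
    # forward signature.
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
}

# Build dummy inputs that satisfy the dynamic shapes above and export the
# backbone with ``torch.export``. A minimal sketch, assuming these tensor
# shapes and dtypes; they are not the released example verbatim.
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16, device=DEVICE),
    "encoder_hidden_states": torch.randn((batch_size, 512, 4096), dtype=torch.float16, device=DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16, device=DEVICE),
    "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16, device=DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16, device=DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16, device=DEVICE),
    "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32, device=DEVICE),
}
ep = torch.export.export(
    pipe.transformer,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
)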

# %%
# Torch-TensorRT compilation
# ---------------------------
# .. note::
#    Compilation requires a GPU with high memory (> 80GB) because TensorRT stores the weights in FP32 precision. This is a known issue and will be resolved in the future.
#
# We enable ``FP32`` matmul accumulation using ``use_fp32_acc=True`` to ensure accuracy is preserved by introducing ``FP32`` cast nodes.
# We also enable explicit typing via ``use_explicit_typing=True`` so that TensorRT respects the datatypes set by the user, which is a requirement for FP32 matmul accumulation.
# Since this is a 12 billion parameter model, compilation takes around 20-30 min on an H100 GPU. The model is completely convertible and results in
# a single TensorRT engine.
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    min_block_size=1,
    use_fp32_acc=True,
    use_explicit_typing=True,
)

# %%
# Post Processing
# ---------------------------
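# Install the compiled TensorRT module as the pipeline's backbone and generate
# an image. ``generate_image`` appears in the original example; its body and
# the prompt below are illustrative assumptions, not the original code.
pipe.transformer = trt_gm


def generate_image(pipe, prompt, image_name):
    # Fix the seed so the generated image is reproducible.
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(42),
    ).images[0]
    image.save(f"{image_name}.png")


generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")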

# %%
# The generated image is as shown below
#
# .. image:: dog_code.png
#