Skip to content

Commit d2d30d4

Browse files
committed
examples: Stable Diffusion torch.compile
- Tutorial for using torch.compile with Stable Diffusion and Torch-TensorRT - Demonstration of output images without need to rerun the whole script upon initial compilation of the docs
1 parent 4e5b0f6 commit d2d30d4

File tree

4 files changed

+58
-0
lines changed

4 files changed

+58
-0
lines changed

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ Tutorials
7979
tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
8080
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
8181
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
82+
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
83+
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
8284

8385
Python API Documenation
8486
------------------------
494 KB
Loading

examples/dynamo/README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ a number of ways you can leverage this backend to accelerate inference.
99
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
1010
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
1111
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
12+
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
.. _torch_compile_stable_diffusion:
3+
4+
Torch Compile Stable Diffusion
5+
======================================================
6+
7+
This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a Stable Diffusion model. A sample output is featured below:
8+
9+
.. image:: /tutorials/images/majestic_castle.png
10+
:width: 512px
11+
:height: 512px
12+
:scale: 50 %
13+
:align: right
14+
"""
15+
16+
# %%
17+
# Imports and Model Definition
18+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
19+
20+
import torch
21+
from diffusers import DiffusionPipeline
22+
23+
import torch_tensorrt
24+
25+
model_id = "CompVis/stable-diffusion-v1-4"
26+
device = "cuda:0"
27+
28+
# Instantiate Stable Diffusion Pipeline with FP16 weights
29+
pipe = DiffusionPipeline.from_pretrained(
30+
model_id, revision="fp16", torch_dtype=torch.float16
31+
)
32+
pipe = pipe.to(device)
33+
34+
backend = "torch_tensorrt"
35+
36+
# Optimize the UNet portion with Torch-TensorRT
37+
pipe.unet = torch.compile(
38+
pipe.unet,
39+
backend=backend,
40+
options={
41+
"truncate_long_and_double": True,
42+
"precision": torch.float16,
43+
},
44+
dynamic=False,
45+
)
46+
47+
# %%
48+
# Inference
49+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
50+
51+
prompt = "a majestic castle in the clouds"
52+
image = pipe(prompt).images[0]
53+
54+
image.save("images/majestic_castle.png")
55+
image.show()

0 commit comments

Comments
 (0)