Skip to content

Commit 54a8348

Browse files
Chengzhe Xucehongwang
authored andcommitted
init commit for flux torch.compile
1 parent 5a4dd33 commit 54a8348

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

examples/dynamo/torch_compile_flux.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# https://huggingface.co/black-forest-labs/FLUX.1-schnell
2+
import torch
3+
from diffusers import FluxPipeline
4+
import torch_tensorrt
5+
6+
device = "cuda:0"
7+
backend = "torch_tensorrt"
8+
9+
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16, device_map="balanced", max_memory={0: "32GB"})
10+
11+
# pipe = pipe.to(device)
12+
pipe.reset_device_map()
13+
pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
14+
15+
# Optimize the transformer portion with Torch-TensorRT
16+
pipe.transformer = torch.compile(
17+
pipe.transformer,
18+
backend=backend,
19+
options={
20+
"truncate_long_and_double": True,
21+
"enabled_precisions": {torch.float16},
22+
# "use_fp32_acc": True,
23+
},
24+
dynamic=False,
25+
)
26+
27+
# pipe.transformer.config['num_layers'] = 5
28+
# pipe.transformer.config.num_layers = 5
29+
30+
prompt = "A cat holding a sign that says hello world"
31+
32+
with torch_tensorrt.logging.debug():
33+
image = pipe(
34+
prompt,
35+
guidance_scale=0.0,
36+
num_inference_steps=4,
37+
max_sequence_length=128,
38+
generator=torch.Generator("cpu").manual_seed(0)
39+
).images[0]
40+
41+
image.save("images/flux-schnell.png")

0 commit comments

Comments
 (0)