Skip to content

Commit c5692e6

Browse files
peri044 authored and cehongwang committed
chore: updates
1 parent a015538 commit c5692e6

File tree

1 file changed

+72
-28
lines changed

1 file changed

+72
-28
lines changed

examples/dynamo/torch_export_flux_fp8.py

Lines changed: 72 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,14 @@
1-
import modelopt.torch.opt as mto
2-
import modelopt.torch.quantization as mtq
1+
# import modelopt.torch.opt as mto
2+
# import modelopt.torch.quantization as mtq
3+
# from modelopt.torch.quantization.utils import export_torch_mode
34
import torch
45
import torch_tensorrt
5-
from diffusers import FluxPipeline
6-
from modelopt.torch.quantization.utils import export_torch_mode
6+
from diffusers import (
7+
DiffusionPipeline,
8+
FluxPipeline,
9+
StableDiffusion3Pipeline,
10+
StableDiffusionPipeline,
11+
)
712

813
# from onnx_utils.export import generate_dummy_inputs
914
from torch.export._trace import _export
@@ -22,58 +27,97 @@ def generate_image(pipe, prompt, image_name):
2227

2328

2429
device = "cuda"
30+
breakpoint()
2531
pipe = FluxPipeline.from_pretrained(
2632
"black-forest-labs/FLUX.1-dev",
2733
torch_dtype=torch.float16,
2834
)
2935

36+
breakpoint()
3037
pipe.to(device)
38+
pipe.to(torch.float16)
3139
backbone = pipe.transformer
3240

33-
# Restore FP8 weights
34-
mto.restore(backbone, "./schnell_fp8.pt")
41+
# mto.restore(backbone, "./schnell_fp8.pt")
3542

3643
# dummy_inputs = generate_dummy_inputs("flux-dev", "cuda", True)
37-
batch_size = 1
44+
batch_size = 2
3845
BATCH = torch.export.Dim("batch", min=1, max=2)
39-
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=256)
40-
dynamic_shapes = (
41-
{0: BATCH},
42-
{0: BATCH, 1: SEQ_LEN},
43-
{0: BATCH},
44-
{0: BATCH},
45-
{0: BATCH},
46-
{0: BATCH, 1: SEQ_LEN},
47-
)
46+
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
47+
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
48+
# dynamic_shapes = (
49+
# {0: BATCH},
50+
# {0: BATCH},
51+
# {0: BATCH},
52+
# {0: BATCH},
53+
# {0: BATCH, 1: SEQ_LEN},
54+
# {0: BATCH, 1: SEQ_LEN},
55+
# {0: BATCH},
56+
# {}
57+
# )
58+
#
59+
# dummy_inputs = (
60+
# torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
61+
# torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
62+
# torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
63+
# torch.randn((batch_size, 768), dtype=torch.float16).to(device),
64+
# torch.randn((batch_size, 512, 4096), dtype=torch.float16).to(device),
65+
# torch.randn((batch_size, 512, 3), dtype=torch.float16).to(device),
66+
# torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
67+
# )
4868

49-
dummy_inputs = (
50-
torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
51-
torch.randn((batch_size, 256, 4096), dtype=torch.float16).to(device),
52-
torch.randn((batch_size, 768), dtype=torch.float16).to(device),
53-
torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
54-
torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
55-
torch.randn((batch_size, 256, 3), dtype=torch.float16).to(device),
56-
)
69+
dynamic_shapes = {
70+
"hidden_states": {0: BATCH},
71+
"encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
72+
"pooled_projections": {0: BATCH},
73+
"timestep": {0: BATCH},
74+
"txt_ids": {0: BATCH, 1: SEQ_LEN},
75+
"img_ids": {0: BATCH, 1: IMG_ID},
76+
"guidance": {0: BATCH},
77+
# "joint_attention_kwargs": {},
78+
# "return_dict": {}
79+
}
80+
81+
dummy_inputs = {
82+
"hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
83+
device
84+
),
85+
"encoder_hidden_states": torch.randn(
86+
(batch_size, 512, 4096), dtype=torch.float16
87+
).to(device),
88+
"pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
89+
device
90+
),
91+
"timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
92+
"txt_ids": torch.randn((batch_size, 512, 3), dtype=torch.float16).to(device),
93+
"img_ids": torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
94+
"guidance": torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
95+
# "joint_attention_kwargs": {},
96+
# "return_dict": torch.tensor(False)
97+
}
5798
with export_torch_mode():
5899
ep = _export(
59100
backbone,
60-
dummy_inputs,
101+
args=(),
102+
kwargs=dummy_inputs,
61103
dynamic_shapes=dynamic_shapes,
62104
strict=False,
63105
allow_complex_guards_as_runtime_asserts=True,
64106
)
65107

108+
# breakpoint()
66109
with torch_tensorrt.logging.debug():
67110
trt_gm = torch_tensorrt.dynamo.compile(
68111
ep,
69112
inputs=dummy_inputs,
70-
enabled_precisions={torch.float8_e4m3fn, torch.float16},
113+
enabled_precisions={torch.float16},
71114
truncate_double=True,
72-
dryrun=True,
115+
dryrun=False,
116+
min_block_size=1,
73117
debug=True,
74118
)
75119

76-
120+
breakpoint()
77121
backbone.to("cpu")
78122
config = pipe.transformer.config
79123
pipe.transformer = trt_gm

0 commit comments

Comments
 (0)