
Commit f9fc37a

peri044 authored and cehongwang committed
chore: updates
1 parent 54a8348 commit f9fc37a

File tree

2 files changed: +232 -41 lines changed


examples/dynamo/flux.py

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import argparse
import logging
import time
from contextlib import contextmanager
from typing import Any, Dict, Optional

import torch
import torch_tensorrt
from diffusers import FluxPipeline
from torch.export import Dim

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

@contextmanager
def timer(logger, name: str):
    logger.info(f"{name} section Start...")
    start = time.time()
    yield
    end = time.time()
    logger.info(f"{name} section End...")
    logger.info(f"{name} section elapsed time: {end - start} seconds")
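
# Example usage of the timer helper above (illustrative; `do_work` is a
# hypothetical placeholder):
#
#     with timer(logger, "my_step"):
#         do_work()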

class MyModule(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
        **kwargs,
    ):
        return self.module.forward(
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            img_ids,
            txt_ids,
        )
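
# Note: the wrapper above accepts the full transformer signature but forwards
# only the six tensor arguments used in this example; `guidance`,
# `joint_attention_kwargs`, and `return_dict` are accepted and dropped so that
# torch.export sees a fixed positional signature.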

if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run inference on a model with random input values"
    )
    # The following options are manual user-provided settings
    arg_parser.add_argument(
        "--use_fp32_acc",
        action="store_true",
        help="Use FP32 accumulation",
    )
    arg_parser.add_argument(
        "--save_engine",
        action="store_true",
        help="Just save the TRT engine and stop the program",
    )
    arg_parser.add_argument(
        "--export",
        action="store_true",
        help="Re-export the TRT module",
    )
    args = arg_parser.parse_args()

    # Parameter settings
    batch_size = 2
    max_seq_len = 256
    prompt = "A cat holding a sign that says hello world"
    cuda_device = "cuda:0"
    device = cuda_device
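
    # Typical invocations (illustrative):
    #
    #     python flux.py                  # compile with Torch-TensorRT, then generate an image
    #     python flux.py --save_engine    # serialize a standalone TRT engine and exit
    #     python flux.py --use_fp32_acc   # compile with FP32 accumulation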

    with torch.no_grad():
        # Define the model
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16
        )
        pipe.to(device)

        example_inputs = (
            torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
            torch.randn((batch_size, 256, 4096), dtype=torch.float16).to(device),
            torch.randn((batch_size, 768), dtype=torch.float16).to(device),
            # Timestep tensor sized to the batch rather than hardcoded to 2
            torch.tensor([1.0] * batch_size, dtype=torch.float16).to(device),
            torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
            torch.randn((batch_size, 256, 3), dtype=torch.float16).to(device),
        )
        BATCH = Dim("batch", min=1, max=batch_size)
        SEQ_LEN = Dim("seq_len", min=1, max=max_seq_len)
        dynamic_shapes = (
            {0: BATCH},
            {0: BATCH, 1: SEQ_LEN},
            {0: BATCH},
            {0: BATCH},
            {0: BATCH},
            {0: BATCH, 1: SEQ_LEN},
        )
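        # Each dict in dynamic_shapes maps a dimension of the corresponding
        # positional input in example_inputs to a Dim object: dim 0 of every
        # input is the batch axis, and dim 1 of the text-side inputs
        # (encoder_hidden_states, txt_ids) is the sequence axis.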
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"== After model declaration == Free mem: {free}, Total mem: {total}")

        # Export the transformer
        with timer(logger=logger, name="ep_gen"):
            model = MyModule(pipe.transformer).eval().half().to(device)
            logger.info("Using _export directly because torch.export.export does not work here")
            # This API is used to express the constraint violation guards as asserts in the graph.
            from torch.export._trace import _export

            ep = _export(
                model,
                args=example_inputs,
                dynamic_shapes=dynamic_shapes,
                strict=False,
                allow_complex_guards_as_runtime_asserts=True,
            )
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"== After model export == Free mem: {free}, Total mem: {total}")

        # Torch-TensorRT compilation
        logger.info("Generating TRT engine now.")
        use_explicit_typing, use_fp32_acc = False, False
        enabled_precisions = {torch.float16}
        if args.use_fp32_acc:
            use_explicit_typing = True
            use_fp32_acc = True
            enabled_precisions = {torch.float32}
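        # With --use_fp32_acc, matmul accumulation runs in FP32 while weights
        # and activations stay in FP16, trading some speed for numerical
        # stability; use_explicit_typing enables the strong typing this mode
        # requires (an interpretation of these flags, not authoritative).
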
        if args.save_engine:
            with torch_tensorrt.logging.debug():
                serialized_engine = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
                    ep,
                    inputs=list(example_inputs),
                    enabled_precisions=enabled_precisions,
                    truncate_double=True,
                    device=torch.device(cuda_device),
                    disable_tf32=True,
                    use_explicit_typing=use_explicit_typing,
                    debug=True,
                    use_fp32_acc=use_fp32_acc,
                )
            with open("flux_trt.engine", "wb") as file:
                file.write(serialized_engine)

            free, total = torch.cuda.mem_get_info(cuda_device)
            print(
                f"== After saving TRT engine == Free mem: {free}, Total mem: {total}"
            )
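            # Optional sanity check (a sketch, assuming the `tensorrt` Python
            # package that ships alongside torch_tensorrt is importable):
            # deserialize the standalone engine to confirm the blob is valid.
            #
            #     import tensorrt as trt
            #
            #     runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
            #     engine = runtime.deserialize_cuda_engine(serialized_engine)
            #     assert engine is not None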
        else:
            with timer(logger, "trt_gen"):
                with torch_tensorrt.logging.debug():
                    trt_model = torch_tensorrt.dynamo.compile(
                        ep,
                        inputs=list(example_inputs),
                        enabled_precisions=enabled_precisions,
                        truncate_double=True,
                        device=torch.device(cuda_device),
                        disable_tf32=True,
                        use_explicit_typing=use_explicit_typing,
                        debug=True,
                        use_fp32_acc=use_fp32_acc,
                    )
            pipe.transformer = trt_model

            free, total = torch.cuda.mem_get_info(cuda_device)
            print(
                f"== After compiling TRT model and before image gen == Free mem: {free}, Total mem: {total}"
            )
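            # Swapping pipe.transformer for the compiled module lets the
            # unmodified diffusers pipeline drive the TRT engine: every
            # denoising step below now runs through TensorRT.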
            del ep
            del model
            print("=== FINISHED TRT COMPILATION. GENERATING IMAGE NOW ...")
            image = pipe(
                prompt,
                guidance_scale=0.0,
                num_inference_steps=4,
                max_sequence_length=128,
                generator=torch.Generator("cpu").manual_seed(0),
            ).images[0]
            image.save("./flux-schnell.png")

            free, total = torch.cuda.mem_get_info(cuda_device)
            print(f"== After image gen == Free mem: {free}, Total mem: {total}")
            if args.export:
                with timer(logger, "trt_save"):
                    try:
                        trt_ep = torch.export.export(
                            trt_model,
                            args=example_inputs,
                            dynamic_shapes=dynamic_shapes,
                            strict=False,
                        )
                        torch.export.save(trt_ep, "trt.ep")
                        free, total = torch.cuda.mem_get_info(cuda_device)
                        print(
                            f"== After TRT model re-export == Free mem: {free}, Total mem: {total}"
                        )
                    except Exception:
                        # Log the full traceback, then fall back to
                        # torch_tensorrt.save for serialization.
                        import traceback

                        tb = traceback.format_exc()
                        logger.warning("An error occurred. Here's the traceback:")
                        logger.warning(tb)
                        torch_tensorrt.save(trt_model, "trt.ep")
                        free, total = torch.cuda.mem_get_info(cuda_device)
                        print(
                            f"== After saving TRT module via torch_tensorrt.save == Free mem: {free}, Total mem: {total}"
                        )
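
# To restore the compiled transformer in a later process (a sketch, assuming
# "trt.ep" was produced by the --export path above):
#
#     trt_ep = torch.export.load("trt.ep")
#     pipe.transformer = trt_ep.module()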

examples/dynamo/torch_compile_flux.py

Lines changed: 0 additions & 41 deletions
This file was deleted.
