Skip to content

Commit bfb549e

Browse files
peri044 and cehongwang
authored and committed
chore: updates
1 parent c5692e6 commit bfb549e

File tree

1 file changed

+45
-49
lines changed

1 file changed

+45
-49
lines changed

examples/dynamo/torch_export_flux_fp8.py

Lines changed: 45 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -27,55 +27,34 @@ def generate_image(pipe, prompt, image_name):
2727

2828

2929
device = "cuda"
30-
breakpoint()
3130
pipe = FluxPipeline.from_pretrained(
3231
"black-forest-labs/FLUX.1-dev",
3332
torch_dtype=torch.float16,
3433
)
3534

36-
breakpoint()
37-
pipe.to(device)
38-
pipe.to(torch.float16)
35+
pipe.to(device).to(torch.float16)
36+
config = pipe.transformer.config
37+
# from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
38+
# pipe.transformer = FluxTransformer2DModel(patch_size=1, in_channels=64, num_layers=1, num_single_layers=1, guidance_embeds=True).to("cuda:0").to(torch.float16)
3939
backbone = pipe.transformer
40-
40+
# generate_image(pipe, ["A cat holding a sign that says hello world"], "flux-dev")
41+
# breakpoint()
4142
# mto.restore(backbone, "./schnell_fp8.pt")
4243

43-
# dummy_inputs = generate_dummy_inputs("flux-dev", "cuda", True)
4444
batch_size = 2
4545
BATCH = torch.export.Dim("batch", min=1, max=2)
4646
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
4747
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
48-
# dynamic_shapes = (
49-
# {0: BATCH},
50-
# {0: BATCH},
51-
# {0: BATCH},
52-
# {0: BATCH},
53-
# {0: BATCH, 1: SEQ_LEN},
54-
# {0: BATCH, 1: SEQ_LEN},
55-
# {0: BATCH},
56-
# {}
57-
# )
58-
#
59-
# dummy_inputs = (
60-
# torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
61-
# torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
62-
# torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
63-
# torch.randn((batch_size, 768), dtype=torch.float16).to(device),
64-
# torch.randn((batch_size, 512, 4096), dtype=torch.float16).to(device),
65-
# torch.randn((batch_size, 512, 3), dtype=torch.float16).to(device),
66-
# torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
67-
# )
48+
6849

6950
dynamic_shapes = {
7051
"hidden_states": {0: BATCH},
7152
"encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
7253
"pooled_projections": {0: BATCH},
7354
"timestep": {0: BATCH},
74-
"txt_ids": {0: BATCH, 1: SEQ_LEN},
75-
"img_ids": {0: BATCH, 1: IMG_ID},
55+
"txt_ids": {0: SEQ_LEN},
56+
"img_ids": {0: IMG_ID},
7657
"guidance": {0: BATCH},
77-
# "joint_attention_kwargs": {},
78-
# "return_dict": {}
7958
}
8059

8160
dummy_inputs = {
@@ -89,39 +68,56 @@ def generate_image(pipe, prompt, image_name):
8968
device
9069
),
9170
"timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
92-
"txt_ids": torch.randn((batch_size, 512, 3), dtype=torch.float16).to(device),
93-
"img_ids": torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
94-
"guidance": torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
95-
# "joint_attention_kwargs": {},
96-
# "return_dict": torch.tensor(False)
71+
"txt_ids": torch.randn((512, 3), dtype=torch.float16).to(device),
72+
"img_ids": torch.randn((4096, 3), dtype=torch.float16).to(device),
73+
"guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(device),
9774
}
98-
with export_torch_mode():
99-
ep = _export(
100-
backbone,
101-
args=(),
102-
kwargs=dummy_inputs,
103-
dynamic_shapes=dynamic_shapes,
104-
strict=False,
105-
allow_complex_guards_as_runtime_asserts=True,
106-
)
75+
# with export_torch_mode():
76+
ep = _export(
77+
backbone,
78+
args=(),
79+
kwargs=dummy_inputs,
80+
dynamic_shapes=dynamic_shapes,
81+
strict=False,
82+
allow_complex_guards_as_runtime_asserts=True,
83+
)
10784

10885
# breakpoint()
10986
with torch_tensorrt.logging.debug():
11087
trt_gm = torch_tensorrt.dynamo.compile(
11188
ep,
11289
inputs=dummy_inputs,
113-
enabled_precisions={torch.float16},
90+
enabled_precisions={torch.float32},
11491
truncate_double=True,
11592
dryrun=False,
11693
min_block_size=1,
94+
# use_python_runtime=True,
11795
debug=True,
96+
use_fp32_acc=True,
97+
use_explicit_typing=True,
11898
)
119-
99+
# breakpoint()
100+
# out_pyt = backbone(**dummy_inputs)
101+
# out_trt = trt_gm(**dummy_inputs)
120102
breakpoint()
103+
104+
105+
class TRTModule(torch.nn.Module):
106+
def __init__(self, trt_mod):
107+
super(TRTModule, self).__init__()
108+
self.trt_mod = trt_mod
109+
110+
def __call__(self, *args, **kwargs):
111+
# breakpoint()
112+
kwargs.pop("joint_attention_kwargs")
113+
kwargs.pop("return_dict")
114+
115+
return self.trt_mod(**kwargs)
116+
117+
121118
backbone.to("cpu")
122-
config = pipe.transformer.config
123-
pipe.transformer = trt_gm
119+
pipe.transformer = TRTModule(trt_gm)
124120
pipe.transformer.config = config
125121

126122
# Generate an image
127-
generate_image(pipe, "A cat holding a sign that says hello world", "flux-dev")
123+
generate_image(pipe, ["A cat holding a sign that says hello world"], "flux-dev")

0 commit comments

Comments (0)