-import modelopt.torch.opt as mto
-import modelopt.torch.quantization as mtq
+# import modelopt.torch.opt as mto
+# import modelopt.torch.quantization as mtq
+from modelopt.torch.quantization.utils import export_torch_mode
 import torch
 import torch_tensorrt
-from diffusers import FluxPipeline
-from modelopt.torch.quantization.utils import export_torch_mode
+from diffusers import (
+    DiffusionPipeline,
+    FluxPipeline,
+    StableDiffusion3Pipeline,
+    StableDiffusionPipeline,
+)

 # from onnx_utils.export import generate_dummy_inputs
 from torch.export._trace import _export
@@ -22,58 +27,97 @@ def generate_image(pipe, prompt, image_name):


 device = "cuda"
+breakpoint()
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.float16,
 )
+breakpoint()
 pipe.to(device)
+pipe.to(torch.float16)
 backbone = pipe.transformer

-# Restore FP8 weights
-mto.restore(backbone, "./schnell_fp8.pt")
+# mto.restore(backbone, "./schnell_fp8.pt")
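+# FP8 weight restore via modelopt is disabled for now, so the backbone is
+# exported and compiled as a plain FP16 model (see enabled_precisions below).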

 # dummy_inputs = generate_dummy_inputs("flux-dev", "cuda", True)
-batch_size = 1
+batch_size = 2
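+# torch.export.Dim bounds each dynamic axis; any axis not listed in
+# dynamic_shapes below is specialized to the size of its dummy input.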
 BATCH = torch.export.Dim("batch", min=1, max=2)
-SEQ_LEN = torch.export.Dim("seq_len", min=1, max=256)
-dynamic_shapes = (
-    {0: BATCH},
-    {0: BATCH, 1: SEQ_LEN},
-    {0: BATCH},
-    {0: BATCH},
-    {0: BATCH},
-    {0: BATCH, 1: SEQ_LEN},
-)
+SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
+IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
+# dynamic_shapes = (
+#     {0: BATCH},
+#     {0: BATCH},
+#     {0: BATCH},
+#     {0: BATCH},
+#     {0: BATCH, 1: SEQ_LEN},
+#     {0: BATCH, 1: SEQ_LEN},
+#     {0: BATCH},
+#     {}
+# )
+#
+# dummy_inputs = (
+#     torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
+#     torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
+#     torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
+#     torch.randn((batch_size, 768), dtype=torch.float16).to(device),
+#     torch.randn((batch_size, 512, 4096), dtype=torch.float16).to(device),
+#     torch.randn((batch_size, 512, 3), dtype=torch.float16).to(device),
+#     torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
+# )

-dummy_inputs = (
-    torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
-    torch.randn((batch_size, 256, 4096), dtype=torch.float16).to(device),
-    torch.randn((batch_size, 768), dtype=torch.float16).to(device),
-    torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
-    torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
-    torch.randn((batch_size, 256, 3), dtype=torch.float16).to(device),
-)
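+# Keys must match the keyword-argument names of the transformer's forward()
+# signature, since the module is exported with kwargs below.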
+dynamic_shapes = {
+    "hidden_states": {0: BATCH},
+    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
+    "pooled_projections": {0: BATCH},
+    "timestep": {0: BATCH},
+    "txt_ids": {0: BATCH, 1: SEQ_LEN},
+    "img_ids": {0: BATCH, 1: IMG_ID},
+    "guidance": {0: BATCH},
+    # "joint_attention_kwargs": {},
+    # "return_dict": {}
+}
+
+dummy_inputs = {
+    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
+        device
+    ),
+    "encoder_hidden_states": torch.randn(
+        (batch_size, 512, 4096), dtype=torch.float16
+    ).to(device),
+    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
+        device
+    ),
+    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
+    "txt_ids": torch.randn((batch_size, 512, 3), dtype=torch.float16).to(device),
+    "img_ids": torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
+    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
+    # "joint_attention_kwargs": {},
+    # "return_dict": torch.tensor(False)
+}
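+# The dummy shapes appear sized for a 1024x1024 image: 4096 packed latent
+# tokens of 64 channels (one img_ids row per token) and a 512-token T5 text
+# sequence. torch.export._trace._export (a private API) is used below rather
+# than torch.export.export so allow_complex_guards_as_runtime_asserts can be
+# set.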
 with export_torch_mode():
     ep = _export(
         backbone,
-        dummy_inputs,
+        args=(),
+        kwargs=dummy_inputs,
         dynamic_shapes=dynamic_shapes,
         strict=False,
         allow_complex_guards_as_runtime_asserts=True,
     )

+# breakpoint()
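+# dryrun=False builds real TensorRT engines (dryrun=True only partitions the
+# graph and prints a report); min_block_size=1 lets even single-op subgraphs
+# be converted to TensorRT instead of falling back to PyTorch.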
 with torch_tensorrt.logging.debug():
     trt_gm = torch_tensorrt.dynamo.compile(
         ep,
         inputs=dummy_inputs,
-        enabled_precisions={torch.float8_e4m3fn, torch.float16},
+        enabled_precisions={torch.float16},
         truncate_double=True,
-        dryrun=True,
+        dryrun=False,
+        min_block_size=1,
         debug=True,
     )

-
+breakpoint()
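+# The FP16 backbone is moved off the GPU and its config kept, presumably so
+# the compiled GraphModule can stand in for it as pipe.transformer.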
 backbone.to("cpu")
 config = pipe.transformer.config
 pipe.transformer = trt_gm