@@ -27,55 +27,34 @@ def generate_image(pipe, prompt, image_name):


device = "cuda"
-breakpoint()
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)

-breakpoint()
-pipe.to(device)
-pipe.to(torch.float16)
+pipe.to(device).to(torch.float16)
+config = pipe.transformer.config
+# from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+# pipe.transformer = FluxTransformer2DModel(patch_size=1, in_channels=64, num_layers=1, num_single_layers=1, guidance_embeds=True).to("cuda:0").to(torch.float16)
backbone = pipe.transformer
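+# NOTE: `config` is captured above so it can be re-attached once the transformer is
+# swapped for the compiled module below; FluxPipeline reads pipe.transformer.config
+# (e.g. guidance_embeds) during denoising.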
-
+# generate_image(pipe, ["A cat holding a sign that says hello world"], "flux-dev")
+# breakpoint()
# mto.restore(backbone, "./schnell_fp8.pt")

-# dummy_inputs = generate_dummy_inputs("flux-dev", "cuda", True)
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
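+# NOTE: these Dims declare the ranges torch.export may treat as dynamic, not concrete
+# shapes: SEQ_LEN covers the T5 text-token axis (FLUX.1-dev uses up to 512 tokens) and
+# IMG_ID the packed-latent token axis (4096 tokens corresponds to a 1024x1024 image).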
-# dynamic_shapes = (
-#     {0: BATCH},
-#     {0: BATCH},
-#     {0: BATCH},
-#     {0: BATCH},
-#     {0: BATCH, 1: SEQ_LEN},
-#     {0: BATCH, 1: SEQ_LEN},
-#     {0: BATCH},
-#     {}
-# )
-#
-# dummy_inputs = (
-#     torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
-#     torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
-#     torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
-#     torch.randn((batch_size, 768), dtype=torch.float16).to(device),
-#     torch.randn((batch_size, 512, 4096), dtype=torch.float16).to(device),
-#     torch.randn((batch_size, 512, 3), dtype=torch.float16).to(device),
-#     torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
-# )
+

dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
-    "txt_ids": {0: BATCH, 1: SEQ_LEN},
-    "img_ids": {0: BATCH, 1: IMG_ID},
+    "txt_ids": {0: SEQ_LEN},
+    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
-    # "joint_attention_kwargs": {},
-    # "return_dict": {}
}
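+# NOTE: newer diffusers releases pass txt_ids/img_ids as unbatched 2D tensors of shape
+# (seq, 3) and warn that the batched 3D form is deprecated, which is presumably why the
+# BATCH dim was dropped from those two entries.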

dummy_inputs = {
@@ -89,39 +68,56 @@ def generate_image(pipe, prompt, image_name):
        device
    ),
    "timestep": torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
-    "txt_ids": torch.randn((batch_size, 512, 3), dtype=torch.float16).to(device),
-    "img_ids": torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
-    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
-    # "joint_attention_kwargs": {},
-    # "return_dict": torch.tensor(False)
+    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(device),
+    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(device),
+    "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(device),
}
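+# NOTE: guidance is deliberately float32 while the other inputs are float16; this
+# matches FluxPipeline, which builds the guidance-scale tensor as float32 at runtime.
+# Treat the exact dtype behavior as an assumption about the diffusers version in use.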
-with export_torch_mode():
-    ep = _export(
-        backbone,
-        args=(),
-        kwargs=dummy_inputs,
-        dynamic_shapes=dynamic_shapes,
-        strict=False,
-        allow_complex_guards_as_runtime_asserts=True,
-    )
+# with export_torch_mode():
+ep = _export(
+    backbone,
+    args=(),
+    kwargs=dummy_inputs,
+    dynamic_shapes=dynamic_shapes,
+    strict=False,
+    allow_complex_guards_as_runtime_asserts=True,
+)
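+# NOTE: _export is torch.export's private entrypoint, used here because the public
+# torch.export.export does not expose allow_complex_guards_as_runtime_asserts.
+# export_torch_mode (ModelOpt's quantization export context) is commented out since
+# no quantized checkpoint is restored via mto.restore above.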

# breakpoint()
with torch_tensorrt.logging.debug():
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=dummy_inputs,
-        enabled_precisions={torch.float16},
+        enabled_precisions={torch.float32},
        truncate_double=True,
        dryrun=False,
        min_block_size=1,
+        # use_python_runtime=True,
        debug=True,
+        use_fp32_acc=True,
+        use_explicit_typing=True,
    )
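+# NOTE: use_explicit_typing=True builds a strongly typed TensorRT network whose layer
+# dtypes follow the exported graph (that mode expects enabled_precisions to be
+# {torch.float32}), and use_fp32_acc=True casts matmul inputs so accumulation happens
+# in FP32, trading some speed for FP16 accuracy.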
-
+# breakpoint()
+# out_pyt = backbone(**dummy_inputs)
+# out_trt = trt_gm(**dummy_inputs)
breakpoint()
+
+
+class TRTModule(torch.nn.Module):
+    def __init__(self, trt_mod):
+        super(TRTModule, self).__init__()
+        self.trt_mod = trt_mod
+
+    def __call__(self, *args, **kwargs):
+        # breakpoint()
+        kwargs.pop("joint_attention_kwargs")
+        kwargs.pop("return_dict")
+
+        return self.trt_mod(**kwargs)
+
+
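+# NOTE: FluxPipeline calls its transformer with joint_attention_kwargs and return_dict,
+# neither of which the exported graph's signature accepts, so the wrapper strips them
+# before delegating to the compiled module. kwargs.pop(key, None) would be the
+# defensive variant if either key can be absent.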
backbone.to("cpu")
-config = pipe.transformer.config
-pipe.transformer = trt_gm
+pipe.transformer = TRTModule(trt_gm)

pipe.transformer.config = config

# Generate an image
-generate_image(pipe, "A cat holding a sign that says hello world", "flux-dev")
+generate_image(pipe, ["A cat holding a sign that says hello world"], "flux-dev")
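+# NOTE: the prompt is now wrapped in a one-element list; FluxPipeline accepts either a
+# str or List[str], so this presumably matches how generate_image (not shown in this
+# diff) consumes its prompt argument.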