Skip to content

Commit a015538

Browse files
peri044cehongwang
authored andcommitted
chore: updates
1 parent f9fc37a commit a015538

File tree

4 files changed

+94
-5
lines changed

4 files changed

+94
-5
lines changed

examples/dynamo/flux.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ def forward(
180180
use_fp32_acc=use_fp32_acc,
181181
)
182182
trt_end = time.time()
183+
config = pipe.transformer.config
183184
pipe.transformer = trt_model
185+
pipe.transformer.config = config
184186

185187
free, total = torch.cuda.mem_get_info(cuda_device)
186188
print(
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import modelopt.torch.opt as mto
2+
import modelopt.torch.quantization as mtq
3+
import torch
4+
import torch_tensorrt
5+
from diffusers import FluxPipeline
6+
from modelopt.torch.quantization.utils import export_torch_mode
7+
8+
# from onnx_utils.export import generate_dummy_inputs
9+
from torch.export._trace import _export
10+
11+
12+
def generate_image(pipe, prompt, image_name):
13+
seed = 42
14+
image = pipe(
15+
prompt,
16+
output_type="pil",
17+
num_inference_steps=20,
18+
generator=torch.Generator("cuda").manual_seed(seed),
19+
).images[0]
20+
image.save(f"{image_name}.png")
21+
print(f"Image generated using {image_name} model saved as {image_name}.png")
22+
23+
24+
device = "cuda"
25+
pipe = FluxPipeline.from_pretrained(
26+
"black-forest-labs/FLUX.1-dev",
27+
torch_dtype=torch.float16,
28+
)
29+
30+
pipe.to(device)
31+
backbone = pipe.transformer
32+
33+
# Restore FP8 weights
34+
mto.restore(backbone, "./schnell_fp8.pt")
35+
36+
# dummy_inputs = generate_dummy_inputs("flux-dev", "cuda", True)
37+
batch_size = 1
38+
BATCH = torch.export.Dim("batch", min=1, max=2)
39+
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=256)
40+
dynamic_shapes = (
41+
{0: BATCH},
42+
{0: BATCH, 1: SEQ_LEN},
43+
{0: BATCH},
44+
{0: BATCH},
45+
{0: BATCH},
46+
{0: BATCH, 1: SEQ_LEN},
47+
)
48+
49+
dummy_inputs = (
50+
torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
51+
torch.randn((batch_size, 256, 4096), dtype=torch.float16).to(device),
52+
torch.randn((batch_size, 768), dtype=torch.float16).to(device),
53+
torch.tensor([1.0, 1.0], dtype=torch.float16).to(device),
54+
torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
55+
torch.randn((batch_size, 256, 3), dtype=torch.float16).to(device),
56+
)
57+
with export_torch_mode():
58+
ep = _export(
59+
backbone,
60+
dummy_inputs,
61+
dynamic_shapes=dynamic_shapes,
62+
strict=False,
63+
allow_complex_guards_as_runtime_asserts=True,
64+
)
65+
66+
with torch_tensorrt.logging.debug():
67+
trt_gm = torch_tensorrt.dynamo.compile(
68+
ep,
69+
inputs=dummy_inputs,
70+
enabled_precisions={torch.float8_e4m3fn, torch.float16},
71+
truncate_double=True,
72+
dryrun=True,
73+
debug=True,
74+
)
75+
76+
77+
backbone.to("cpu")
78+
config = pipe.transformer.config
79+
pipe.transformer = trt_gm
80+
pipe.transformer.config = config
81+
82+
# Generate an image
83+
generate_image(pipe, "A cat holding a sign that says hello world", "flux-dev")

py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ def remove_assert_scalar(
1515
"""Remove assert_scalar ops in the graph"""
1616
count = 0
1717
for node in gm.graph.nodes:
18-
if node.target == torch.ops.aten._assert_scalar.default:
18+
if (
19+
node.target == torch.ops.aten._assert_scalar.default
20+
or node == torch.ops.aten._assert_tensor_metadata.default
21+
):
1922
gm.graph.erase_node(node)
2023
count += 1
2124

py/torch_tensorrt/dynamo/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,10 @@ def prepare_inputs(
246246
if isinstance(inputs, Input):
247247
return inputs
248248

249-
elif isinstance(inputs, torch.Tensor):
249+
elif isinstance(inputs, (torch.Tensor, int, float, bool)):
250250
return Input.from_tensor(
251-
inputs, disable_memory_format_check=disable_memory_format_check
251+
torch.tensor(inputs),
252+
disable_memory_format_check=disable_memory_format_check,
252253
)
253254

254255
elif isinstance(inputs, (list, tuple)):
@@ -395,8 +396,8 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
395396
"""
396397
Returns the dtype of torch.tensor or FakeTensor. For symbolic integers, we return int64
397398
"""
398-
if isinstance(tensor, (torch.Tensor, FakeTensor)):
399-
return tensor.dtype
399+
if isinstance(tensor, (torch.Tensor, FakeTensor, int, float, bool)):
400+
return torch.tensor(tensor).dtype
400401
elif isinstance(tensor, torch.SymInt):
401402
return torch.int64
402403
else:

0 commit comments

Comments
 (0)