Commit cef1620

Added dynamic shape support to SDXL example
1 parent a8e0b48 commit cef1620

File tree

1 file changed: +96 -83 lines changed

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 96 additions & 83 deletions
@@ -22,94 +22,107 @@
 import torch_tensorrt as torch_trt
 import torchvision.models as models
 
-np.random.seed(5)
-torch.manual_seed(5)
-inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
-
-# %%
-# Initialize the Mutable Torch TensorRT Module with settings.
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-settings = {
-    "use_python": False,
-    "enabled_precisions": {torch.float32},
-    "immutable_weights": False,
-}
-
-model = models.resnet18(pretrained=True).eval().to("cuda")
-mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
-# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
-mutable_module(*inputs)
-
-# %%
-# Make modifications to the mutable module.
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-# %%
-# Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
-model2 = models.resnet18(pretrained=False).eval().to("cuda")
-mutable_module.load_state_dict(model2.state_dict())
-
-
-# Check the output
-# The refit happens while you call the mutable module again.
-expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
-for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
-    assert torch.allclose(
-        expected_output, refitted_output, 1e-2, 1e-2
-    ), "Refit Result is not correct. Refit failed"
-
-print("Refit successfully!")
-
-# %%
-# Saving Mutable Torch TensorRT Module
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-# Currently, saving is only when "use_python" = False in settings
-torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
-reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
+# np.random.seed(5)
+# torch.manual_seed(5)
+# inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
+
+# # %%
+# # Initialize the Mutable Torch TensorRT Module with settings.
+# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# settings = {
+#     "use_python": False,
+#     "enabled_precisions": {torch.float32},
+#     "immutable_weights": False,
+# }
+
+# model = models.resnet18(pretrained=True).eval().to("cuda")
+# mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
+# # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
+# mutable_module(*inputs)
+
+# # %%
+# # Make modifications to the mutable module.
+# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# # %%
+# # Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
+# model2 = models.resnet18(pretrained=False).eval().to("cuda")
+# mutable_module.load_state_dict(model2.state_dict())
+
+
+# # Check the output
+# # The refit happens while you call the mutable module again.
+# expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
+# for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+#     assert torch.allclose(
+#         expected_output, refitted_output, 1e-2, 1e-2
+#     ), "Refit Result is not correct. Refit failed"
+
+# print("Refit successfully!")
+
+# # %%
+# # Saving Mutable Torch TensorRT Module
+# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# # Currently, saving is only when "use_python" = False in settings
+# torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
+# reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
 
 # %%
 # Stable Diffusion with Huggingface
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-from diffusers import DiffusionPipeline
-
-with torch.no_grad():
-    settings = {
-        "use_python_runtime": True,
-        "enabled_precisions": {torch.float16},
-        "debug": True,
-        "immutable_weights": False,
-    }
-
-    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
-    device = "cuda:0"
-
-    prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
-    negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"
-
-    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
-    pipe.to(device)
-
-    # The only extra line you need
-    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
-
-    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
-    image.save("./without_LoRA_mutable.jpg")
-
-    # Standard Huggingface LoRA loading procedure
-    pipe.load_lora_weights(
-        "stablediffusionapi/load_lora_embeddings",
-        weight_name="all-disney-princess-xl-lo.safetensors",
-        adapter_name="lora1",
-    )
-    pipe.set_adapters(["lora1"], adapter_weights=[1])
-    pipe.fuse_lora()
-    pipe.unload_lora_weights()
-
-    # Refit triggered
-    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
-    image.save("./with_LoRA_mutable.jpg")
+# from diffusers import DiffusionPipeline
+
+# with torch.no_grad():
+#     settings = {
+#         "use_python_runtime": True,
+#         "enabled_precisions": {torch.float16},
+#         "debug": True,
+#         "immutable_weights": False,
+#     }
+
+#     model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+#     device = "cuda:0"
+
+#     prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
+#     negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"
+
+#     pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+#     pipe.to(device)
+
+#     # The only extra line you need
+#     pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
+#     BATCH = torch.export.Dim("BATCH", min=1 * 2, max=12 * 2)
+#     _HEIGHT = torch.export.Dim('_HEIGHT', min=16, max=32)
+#     _WIDTH = torch.export.Dim('_WIDTH', min=16, max=32)
+#     HEIGHT = 4*_HEIGHT
+#     WIDTH = 4*_WIDTH
+#     args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {})
+#     kwargs_dynamic_shapes = {
+#         'encoder_hidden_states': {0: BATCH},
+#         'added_cond_kwargs':{
+#             'text_embeds': {0: BATCH},
+#             'time_ids': {0: BATCH},
+#         }
+#     }
+#     pipe.unet.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shapes)
+#     image = pipe(prompt, negative_prompt=negative, num_inference_steps=30, height=1024, width=768, num_images_per_prompt=2).images[0]
+#     image.save("./without_LoRA_mutable.jpg")
+
+#     # Standard Huggingface LoRA loading procedure
+#     pipe.load_lora_weights(
+#         "stablediffusionapi/load_lora_embeddings",
+#         weight_name="all-disney-princess-xl-lo.safetensors",
+#         adapter_name="lora1",
+#     )
+#     pipe.set_adapters(["lora1"], adapter_weights=[1])
+#     pipe.fuse_lora()
+#     pipe.unload_lora_weights()
+
+#     # Refit triggered
+#     image = pipe(prompt, negative_prompt=negative, num_inference_steps=30, height=1024, width=1024, num_images_per_prompt=1).images[0]
+#     image.save("./with_LoRA_mutable.jpg")
 
 
 # %%
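
For context, below is a minimal sketch of the dynamic-shape wiring this commit adds, pulled out of the commented-out SDXL example above so it can be read on its own. It reuses only the calls that appear in the diff (torch.export.Dim, MutableTorchTensorRTModule, set_expected_dynamic_shape_range); the comments interpreting the dimension ranges (batch doubled by classifier-free guidance, latent size = image size / 8) are explanatory assumptions, not part of the commit, and the prompt strings are placeholders.

import torch
import torch_tensorrt as torch_trt
from diffusers import DiffusionPipeline

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "immutable_weights": False,
    }

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    pipe.to("cuda:0")

    # Wrap the UNet so it is compiled by TensorRT on first call and refitted when weights change.
    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

    # Declare the shape range the engine should cover, so later calls with a different
    # batch size or resolution stay within one engine instead of forcing a recompile.
    # Assumption: the UNet batch is doubled by classifier-free guidance, hence min=1*2, max=12*2.
    BATCH = torch.export.Dim("BATCH", min=1 * 2, max=12 * 2)
    # Assumption: latent H/W are image H/W divided by 8 (SDXL VAE factor); deriving the dims
    # as 4 * Dim(16..32) keeps latents at multiples of 4 between 64 and 128, i.e. image sizes
    # from 512 to 1024 in steps of 32 (height=1024 -> 128, width=768 -> 96).
    _HEIGHT = torch.export.Dim("_HEIGHT", min=16, max=32)
    _WIDTH = torch.export.Dim("_WIDTH", min=16, max=32)
    HEIGHT = 4 * _HEIGHT
    WIDTH = 4 * _WIDTH

    # Positional inputs: the latent sample (batch, channels, height, width) gets dynamic dims;
    # the second positional input (the timestep) keeps a static shape, hence the empty dict.
    args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {})
    # Keyword inputs: the SDXL conditioning tensors vary only along the batch dimension.
    kwargs_dynamic_shapes = {
        "encoder_hidden_states": {0: BATCH},
        "added_cond_kwargs": {
            "text_embeds": {0: BATCH},
            "time_ids": {0: BATCH},
        },
    }
    pipe.unet.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shapes)

    # Both calls fall inside the declared range, so the second one changes resolution and
    # batch size without rebuilding the engine.
    pipe("a scenic mountain lake at dawn", num_inference_steps=30,
         height=1024, width=768, num_images_per_prompt=2)
    pipe("a scenic mountain lake at dawn", num_inference_steps=30,
         height=1024, width=1024, num_images_per_prompt=1)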
