Skip to content

Commit 1a309b8

Browse files
committed
Added dynamic shape support to SDXL example
1 parent a8e0b48 commit 1a309b8

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,30 @@
9393

9494
# The only extra line you need
9595
pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
96-
97-
image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
96+
BATCH = torch.export.Dim("BATCH", min=1 * 2, max=12 * 2)
97+
_HEIGHT = torch.export.Dim("_HEIGHT", min=16, max=32)
98+
_WIDTH = torch.export.Dim("_WIDTH", min=16, max=32)
99+
HEIGHT = 4 * _HEIGHT
100+
WIDTH = 4 * _WIDTH
101+
args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {})
102+
kwargs_dynamic_shapes = {
103+
"encoder_hidden_states": {0: BATCH},
104+
"added_cond_kwargs": {
105+
"text_embeds": {0: BATCH},
106+
"time_ids": {0: BATCH},
107+
},
108+
}
109+
pipe.unet.set_expected_dynamic_shape_range(
110+
args_dynamic_shapes, kwargs_dynamic_shapes
111+
)
112+
image = pipe(
113+
prompt,
114+
negative_prompt=negative,
115+
num_inference_steps=30,
116+
height=1024,
117+
width=768,
118+
num_images_per_prompt=2,
119+
).images[0]
98120
image.save("./without_LoRA_mutable.jpg")
99121

100122
# Standard Huggingface LoRA loading procedure
@@ -108,7 +130,14 @@
108130
pipe.unload_lora_weights()
109131

110132
# Refit triggered
111-
image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
133+
image = pipe(
134+
prompt,
135+
negative_prompt=negative,
136+
num_inference_steps=30,
137+
height=1024,
138+
width=1024,
139+
num_images_per_prompt=1,
140+
).images[0]
112141
image.save("./with_LoRA_mutable.jpg")
113142

114143

0 commit comments

Comments
 (0)