|
22 | 22 | import torch_tensorrt as torch_trt
|
23 | 23 | import torchvision.models as models
|
24 | 24 |
|
25 |
| -np.random.seed(5) |
26 |
| -torch.manual_seed(5) |
27 |
| -inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] |
28 |
| - |
29 |
| -# %% |
30 |
| -# Initialize the Mutable Torch TensorRT Module with settings. |
31 |
| -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
32 |
| -settings = { |
33 |
| - "use_python": False, |
34 |
| - "enabled_precisions": {torch.float32}, |
35 |
| - "immutable_weights": False, |
36 |
| -} |
37 |
| - |
38 |
| -model = models.resnet18(pretrained=True).eval().to("cuda") |
39 |
| -mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings) |
40 |
| -# You can use the mutable module just like the original PyTorch module. Compilation happens the first time you call the mutable module. |
41 |
| -mutable_module(*inputs) |
42 |
| - |
43 |
| -# %% |
44 |
| -# Make modifications to the mutable module. |
45 |
| -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
46 |
| - |
47 |
| -# %% |
48 |
| -# Making changes to the mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation. |
49 |
| -model2 = models.resnet18(pretrained=False).eval().to("cuda") |
50 |
| -mutable_module.load_state_dict(model2.state_dict()) |
51 |
| - |
52 |
| - |
53 |
| -# Check the output |
54 |
| -# The refit happens when you call the mutable module again. |
55 |
| -expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs) |
56 |
| -for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): |
57 |
| - assert torch.allclose( |
58 |
| - expected_output, refitted_output, 1e-2, 1e-2 |
59 |
| - ), "Refit Result is not correct. Refit failed" |
60 |
| - |
61 |
| -print("Refit successfully!") |
62 |
| - |
63 |
| -# %% |
64 |
| -# Saving Mutable Torch TensorRT Module |
65 |
| -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
66 |
| - |
67 |
| -# Currently, saving is only supported when "use_python" = False in settings |
68 |
| -torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl") |
69 |
| -reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl") |
| 25 | +# np.random.seed(5) |
| 26 | +# torch.manual_seed(5) |
| 27 | +# inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] |
| 28 | + |
| 29 | +# # %% |
| 30 | +# # Initialize the Mutable Torch TensorRT Module with settings. |
| 31 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 32 | +# settings = { |
| 33 | +# "use_python": False, |
| 34 | +# "enabled_precisions": {torch.float32}, |
| 35 | +# "immutable_weights": False, |
| 36 | +# } |
| 37 | + |
| 38 | +# model = models.resnet18(pretrained=True).eval().to("cuda") |
| 39 | +# mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings) |
| 40 | +# # You can use the mutable module just like the original PyTorch module. Compilation happens the first time you call the mutable module. |
| 41 | +# mutable_module(*inputs) |
| 42 | + |
| 43 | +# # %% |
| 44 | +# # Make modifications to the mutable module. |
| 45 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 46 | + |
| 47 | +# # %% |
| 48 | +# # Making changes to the mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation. |
| 49 | +# model2 = models.resnet18(pretrained=False).eval().to("cuda") |
| 50 | +# mutable_module.load_state_dict(model2.state_dict()) |
| 51 | + |
| 52 | + |
| 53 | +# # Check the output |
| 54 | +# # The refit happens when you call the mutable module again. |
| 55 | +# expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs) |
| 56 | +# for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): |
| 57 | +# assert torch.allclose( |
| 58 | +# expected_output, refitted_output, 1e-2, 1e-2 |
| 59 | +# ), "Refit Result is not correct. Refit failed" |
| 60 | + |
| 61 | +# print("Refit successfully!") |
| 62 | + |
| 63 | +# # %% |
| 64 | +# # Saving Mutable Torch TensorRT Module |
| 65 | +# # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 66 | + |
| 67 | +# # Currently, saving is only supported when "use_python" = False in settings |
| 68 | +# torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl") |
| 69 | +# reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl") |
70 | 70 |
|
71 | 71 | # %%
|
72 | 72 | # Stable Diffusion with Huggingface
|
73 | 73 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
74 | 74 |
|
75 |
| -from diffusers import DiffusionPipeline |
76 |
| - |
77 |
| -with torch.no_grad(): |
78 |
| - settings = { |
79 |
| - "use_python_runtime": True, |
80 |
| - "enabled_precisions": {torch.float16}, |
81 |
| - "debug": True, |
82 |
| - "immutable_weights": False, |
83 |
| - } |
84 |
| - |
85 |
| - model_id = "stabilityai/stable-diffusion-xl-base-1.0" |
86 |
| - device = "cuda:0" |
87 |
| - |
88 |
| - prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed" |
89 |
| - negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude" |
90 |
| - |
91 |
| - pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) |
92 |
| - pipe.to(device) |
93 |
| - |
94 |
| - # The only extra line you need |
95 |
| - pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings) |
96 |
| - |
97 |
| - image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
98 |
| - image.save("./without_LoRA_mutable.jpg") |
99 |
| - |
100 |
| - # Standard Huggingface LoRA loading procedure |
101 |
| - pipe.load_lora_weights( |
102 |
| - "stablediffusionapi/load_lora_embeddings", |
103 |
| - weight_name="all-disney-princess-xl-lo.safetensors", |
104 |
| - adapter_name="lora1", |
105 |
| - ) |
106 |
| - pipe.set_adapters(["lora1"], adapter_weights=[1]) |
107 |
| - pipe.fuse_lora() |
108 |
| - pipe.unload_lora_weights() |
109 |
| - |
110 |
| - # Refit triggered |
111 |
| - image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0] |
112 |
| - image.save("./with_LoRA_mutable.jpg") |
| 75 | +# from diffusers import DiffusionPipeline |
| 76 | + |
| 77 | +# with torch.no_grad(): |
| 78 | +# settings = { |
| 79 | +# "use_python_runtime": True, |
| 80 | +# "enabled_precisions": {torch.float16}, |
| 81 | +# "debug": True, |
| 82 | +# "immutable_weights": False, |
| 83 | +# } |
| 84 | + |
| 85 | +# model_id = "stabilityai/stable-diffusion-xl-base-1.0" |
| 86 | +# device = "cuda:0" |
| 87 | + |
| 88 | +# prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed" |
| 89 | +# negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude" |
| 90 | + |
| 91 | +# pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) |
| 92 | +# pipe.to(device) |
| 93 | + |
| 94 | +# # The only extra line you need |
| 95 | +# pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings) |
| 96 | +# BATCH = torch.export.Dim("BATCH", min=1 * 2, max=12 * 2) |
| 97 | +# _HEIGHT = torch.export.Dim('_HEIGHT', min=16, max=32) |
| 98 | +# _WIDTH = torch.export.Dim('_WIDTH', min=16, max=32) |
| 99 | +# HEIGHT = 4*_HEIGHT |
| 100 | +# WIDTH = 4*_WIDTH |
| 101 | +# args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {}) |
| 102 | +# kwargs_dynamic_shapes = { |
| 103 | +# 'encoder_hidden_states': {0: BATCH}, |
| 104 | +# 'added_cond_kwargs':{ |
| 105 | +# 'text_embeds': {0: BATCH}, |
| 106 | +# 'time_ids': {0: BATCH}, |
| 107 | +# } |
| 108 | +# } |
| 109 | +# pipe.unet.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shapes) |
| 110 | +# image = pipe(prompt, negative_prompt=negative, num_inference_steps=30, height=1024, width=768, num_images_per_prompt=2).images[0] |
| 111 | +# image.save("./without_LoRA_mutable.jpg") |
| 112 | + |
| 113 | +# # Standard Huggingface LoRA loading procedure |
| 114 | +# pipe.load_lora_weights( |
| 115 | +# "stablediffusionapi/load_lora_embeddings", |
| 116 | +# weight_name="all-disney-princess-xl-lo.safetensors", |
| 117 | +# adapter_name="lora1", |
| 118 | +# ) |
| 119 | +# pipe.set_adapters(["lora1"], adapter_weights=[1]) |
| 120 | +# pipe.fuse_lora() |
| 121 | +# pipe.unload_lora_weights() |
| 122 | + |
| 123 | +# # Refit triggered |
| 124 | +# image = pipe(prompt, negative_prompt=negative, num_inference_steps=30, height=1024, width=1024, num_images_per_prompt=1).images[0] |
| 125 | +# image.save("./with_LoRA_mutable.jpg") |
113 | 126 |
|
114 | 127 |
|
115 | 128 | # %%
|
|
0 commit comments