Implemented basic pipeline for Refitting #2886
Merged
Commits (71, all by cehongwang)
dc98f23  Implemented basic pipeline for Refitting
74d458e  Organized code for refitting
c47bef3  Renamed function
869aaad  Supported multi-engine
e4cb669  Support both TRTModules and return a new copy
388dadc  Enabled module saving with settings
f822b28  Enabled three types of runtime. Build an interface for user to easy r…
56eb549  Reorganized the code
94483a8  Added weight type check and number check
4ba84b7  Deleted the save in compilation
ee6f123  deleted more compilation save
bc23ddb  Supported different dtypes. Support all possible layers. Support deep…
501e5d9  Delete the outdated file
bd5fb55  Deleted setting loading
578927c  Fixed bugs when handling multiple engines. Tested with custom module …
82cd252  Fixed dtype bugs
e3cf823  Made a note to add INormalization Layer
c3b0862  Update the unsupported torch module weights
400bcac  Cleaned up the code. Added refitting outcome check
6f08664  Use enums to change dtype from np to trt
2250239  Moved check output to util and added a gated flag
c906d0e  fixed a bug in check_output function. Changed to only check once afte…
51cba6f  reverse the main function
e3576fa  Added support to inline module w/ or w/o graph break
cde8fe9  Added an extra attribute to TRT engine in cpp
b575105  Added an attribute in TorchTRTModule in python.
9923125  Fixed a type
e25941e  Fixed a bug for inline_module refit
646da9e  Added refit example documentation
924f4a8  Added backward compatibility
3c25a3a  Rename the setting enum
0c9637d  Cleaned up cpp constructors
bb5fdba  Fixed a type of setting storage checking
e47bcb2  Renamed settings to metadata
e6e71ca  Added refit to __init__ of dynamo
cf43a79  Added docstring. Added support for dynamic shape
cfeb6bf  Changed the check_output function to return a boolean
a092229  Changed get_settings to a static method in TorchTensorRTModule
4819a6d  Simplified the code
bd77f22  Added three testcases
1b3a769  Supported torch ops in settings
1456ad9  Updated the example
1acfe31  Wrote 6 test cases for refitting feature to cover different scenarios
d38e422  Fixed a bug in tests
880afde  Delete settings check
2dc5bfa  Fixed a bug of modifying settings inplace
410689c  added it to //docsrc/py_api/dynamo.rst so that it gets rendered in th…
381f14a  Added reference to doc
eebe883  Changed the default outputcheck to false
2a3d567  Changed the assertion
003380a  Renamed the imported name
323db97  Renamed
de0ab94  Fixed a bug of serialized info signature
de3da26  Changed the refit condition check
91c6036  Changed the file path in test file
bd43882  Fixed minor format
5ef9af7  Deleted setting repetitions
8882425  Changed min_block_size to 1
7f1f958  Added comments
b8e023d  Merged two if statements
df9cd39  Changed the weight type
b33fa0f  Fixed hardcoded index
911984d  Fixed a type causing extra overhead
0a1c8ca  Added comments and replaced the index to enum
257db26  Fixed inline module check
fef6766  Added deprecate warning. Renamed refit flag to make_refitable
d6dbdd4  Merge branch 'main' into refitter-support
7381221  Updated lowering process to conform with latest main branch
e7768f7  Handled default setting usecases
51a03c9  Fixed circular import bugs
33bde0f  Changed deprecated behavior
@@ -0,0 +1,98 @@
""" | ||
.. _refit_engine_example: | ||
|
||
Refit TenorRT Graph Module with Torch-TensorRT | ||
=================================================================== | ||
|
||
We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights. | ||
|
||
In many cases, we frequently update the weights of models, such as applying various LoRA to Stable Diffusion or constant A/B testing of AI products. | ||
That poses challenges for TensorRT inference optimizations, as compiling the TensorRT engines takes significant time, making repetitive compilation highly inefficient. | ||
Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow. | ||
|
||
In this tutorial, we are going to walk through | ||
1. Compiling a PyTorch model to a TensorRT Graph Module | ||
2. Save and load a graph module | ||
3. Refit the graph module | ||
""" | ||
|
||
# %% | ||
# Standard Workflow | ||
# ----------------------------- | ||
|
||
# %% | ||
# Imports and model definition | ||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
import numpy as np | ||
import torch | ||
import torch_tensorrt as torch_trt | ||
import torchvision.models as models | ||
from torch_tensorrt.dynamo import refit_module_weights | ||
|
||
np.random.seed(0) | ||
torch.manual_seed(0) | ||
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")] | ||
|
||
|
||
# %% | ||
# Compile the module for the first time and save it. | ||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
model = models.resnet18(pretrained=False).eval().to("cuda") | ||
exp_program = torch.export.export(model, tuple(inputs)) | ||
enabled_precisions = {torch.float} | ||
debug = False | ||
workspace_size = 20 << 30 | ||
min_block_size = 0 | ||
use_python_runtime = False | ||
torch_executed_ops = {} | ||
trt_gm = torch_trt.dynamo.compile( | ||
exp_program, | ||
tuple(inputs), | ||
use_python_runtime=use_python_runtime, | ||
enabled_precisions=enabled_precisions, | ||
debug=debug, | ||
min_block_size=min_block_size, | ||
torch_executed_ops=torch_executed_ops, | ||
make_refitable=True, | ||
) # Output is a torch.fx.GraphModule | ||
|
||
# Save the graph module as an exported program | ||
# This is only supported when use_python_runtime = False | ||
torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs) | ||
|
||
|
||
# %% | ||
# Refit the module with update model weights | ||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
# Create and compile the updated model | ||
model2 = models.resnet18(pretrained=True).eval().to("cuda") | ||
exp_program2 = torch.export.export(model2, tuple(inputs)) | ||
|
||
|
||
compiled_trt_ep = torch_trt.load("./compiled.ep") | ||
|
||
# This returns a new module with updated weights | ||
new_trt_gm = refit_module_weights( | ||
compiled_module=compiled_trt_ep, | ||
new_weight_module=exp_program2, | ||
inputs=inputs, | ||
) | ||
|
||
# Check the output | ||
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs) | ||
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs): | ||
assert torch.allclose( | ||
expected_output, refitted_output, 1e-2, 1e-2 | ||
), "Refit Result is not correct. Refit failed" | ||
|
||
print("Refit successfully!") | ||
|
||
# %% | ||
# Alterative Workflow using Python Runtime | ||
# ----------------------------- | ||
|
||
# Currently python runtime does not support engine serialization. So the refitting will be done in the same runtime. | ||
# This usecase is more useful when you need to switch different weights in the same runtime, such as using Stable Diffusion. |
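# %%
# The sketch below is an editor's illustration, not part of the original
# example: it assumes refit_module_weights accepts the live module produced
# with use_python_runtime=True, mirroring the standard workflow above while
# skipping the save/load round trip.

trt_gm_py = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=True,  # Python runtime: the engine stays in-process
    enabled_precisions=enabled_precisions,
    min_block_size=min_block_size,
    torch_executed_ops=torch_executed_ops,
    make_refitable=True,
)

# Refit the in-memory module directly with the updated weights
new_trt_gm_py = refit_module_weights(
    compiled_module=trt_gm_py,
    new_weight_module=exp_program2,
    inputs=inputs,
)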
@@ -351,9 +351,9 @@ def convert_method_to_trt_engine(
     torchtrt_inputs = prepare_inputs(inputs)
     exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)

-    return dynamo_convert_module_to_trt_engine(  # type: ignore[no-any-return]
+    return dynamo_convert_module_to_trt_engine(
         exp_program,
-        inputs=inputs,
+        inputs=tuple(inputs),
         enabled_precisions=enabled_precisions_set,
         **kwargs,
     )

Review thread on the `inputs=tuple(inputs)` change:
Reviewer: Any reason this needs to be a tuple?
Reply: The function signature defines it as a Tuple. If I do not change it, I am not allowed to commit.
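To illustrate the point made in that thread, here is a minimal, self-contained sketch (the names takes_tuple and inputs_list are illustrative, not from Torch-TensorRT): a parameter annotated as Tuple makes a static type checker such as mypy reject a plain list, which is why the call site converts with tuple(inputs).

from typing import Any, Tuple


def takes_tuple(inputs: Tuple[Any, ...]) -> None:
    # Stand-in for an entrypoint whose signature annotates inputs as a Tuple
    pass


inputs_list = [1, 2, 3]
# takes_tuple(inputs_list)       # mypy error: "list[int]" is not "Tuple[Any, ...]"
takes_tuple(tuple(inputs_list))  # accepted: the explicit conversion matches the annotation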