feat: Support exporting Torch-TRT compiled Graphmodules #3262
Merged
Commits (47, showing changes from all commits)
458a4d1 skip run_shape_analysis (lanluo-nvidia)
2f408f9 test (lanluo-nvidia)
1c5e86c test (lanluo-nvidia)
ba487dc test (lanluo-nvidia)
99d2274 Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia)
2b43480 test (lanluo-nvidia)
17b57a6 feat: Add re-export functionality for Torch-TRT modules (peri044)
b4e02e1 Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia)
3d94f8b test (lanluo-nvidia)
cb03ca1 feat: add support for re-exporting graph modules (peri044)
28ba6cc Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia)
b89cbe0 resolve comments (lanluo-nvidia)
2843d37 Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia)
3eb48d7 test (lanluo-nvidia)
839c72e chore: updates (peri044)
50eb0d8 replace dummy inference (lanluo-nvidia)
95ed602 test (lanluo-nvidia)
120f30d test (lanluo-nvidia)
424cbf7 add run_test_with_dynamic_shape change (lanluo-nvidia)
2fc9cef Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia)
ef54cfc split the PR, add dummy inference for converter test (lanluo-nvidia)
14f5d61 test (lanluo-nvidia)
7563959 test (lanluo-nvidia)
77355f0 test (lanluo-nvidia)
13361fd add linear lowering meta val (lanluo-nvidia)
fca16a5 chore: updates (peri044)
f0a9fef add linear_lowering change (lanluo-nvidia)
cff64a4 test (lanluo-nvidia)
933abac test (lanluo-nvidia)
8417684 resolve comments (lanluo-nvidia)
8676f88 test (lanluo-nvidia)
df13856 chore: updates (peri044)
d406366 chore: updates (peri044)
595ea6e chore: updates (peri044)
076f47a resolve comments (lanluo-nvidia)
8250179 Merge branch 'main' into lluo/save_remove_inputs (lanluo-nvidia)
96e93e4 resolve comments (lanluo-nvidia)
675667b chore: updates (peri044)
4e1a538 chore: updates (peri044)
fb12021 chore: updates (peri044)
6b3f94c chore: updates (peri044)
1983c60 chore: add tests (peri044)
dd94194 chore: updates (peri044)
ea226d6 chore: address comments (peri044)
0d04111 chore: rebase with main (peri044)
772e5d1 chore: updates (peri044)
f739f57 chore: fix tests (peri044)
py/torch_tensorrt/dynamo/runtime/register_fake_class.py (new file: 129 additions, 0 deletions)
import base64
from collections import defaultdict
from typing import Any, List

import torch
from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape


@torch.library.register_fake("tensorrt::execute_engine")  # type: ignore
def fake_tensorrt_execute_engine(
    inputs: List[torch.Tensor], fake_trt_engine: Any
) -> Any:
    """
    We infer outputs using the TRT engine and inputs and return fake tensors in this meta kernel.
    """
    # Here's what we are doing:
    # 1) Check if inputs are dynamic (they have sym ints in their shapes)
    # 2) For dynamic inputs, gather the min and max input shapes of all inputs
    # 3) For those min and max input shapes, capture the corresponding min and max
    #    output shapes using TensorRT's set/get shapes mechanism
    # 4) Create a new symbolic fake tensor using the min and max output shapes for
    #    each output and return them
    # 5) For static inputs, the output shape will be static and we won't need to create sym ints
    is_dynamic_execution = input_is_dynamic(inputs)
    if is_dynamic_execution:
        modes = ["min", "max", "opt"]
    else:
        modes = ["opt"]

    # Get the TRTEngine class and infer output shapes based on input shapes
    trt_engine = fake_trt_engine.wrapped_obj.engine
    outputs_mode_dict = defaultdict(list)
    for mode in modes:
        input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
        proxy_outputs = trt_engine.infer_outputs(input_shapes)
        outputs_mode_dict[mode].extend(proxy_outputs)

    # Store the number of outputs
    if {"min", "max"}.issubset(outputs_mode_dict):
        assert len(outputs_mode_dict["min"]) == len(outputs_mode_dict["max"])
        num_outputs = len(outputs_mode_dict["min"])
    elif "opt" in outputs_mode_dict:
        num_outputs = len(outputs_mode_dict["opt"])

    fake_outputs = []
    for out_idx in range(num_outputs):
        output_shape = []
        if is_dynamic_execution:
            # Create output symbolic shape using unbacked symints.
            # Note: We can't establish a relationship b/w an incoming input symbolic
            # shape (eg: s0) and TensorRT's output shape (represented as an unbacked
            # u0). This situation doesn't seem to affect compilation results /
            # serialization during our testing.
            output_min_shape = outputs_mode_dict["min"][out_idx].size()
            output_opt_shape = outputs_mode_dict["opt"][out_idx].size()
            output_max_shape = outputs_mode_dict["max"][out_idx].size()

            ctx = torch._custom_ops.get_ctx()
            for min_val, opt_val, max_val in zip(
                output_min_shape, output_opt_shape, output_max_shape
            ):
                if min_val != max_val:
                    output_sym_int = ctx.new_dynamic_size(min=min_val, max=max_val)
                    # Update var to val (hint)
                    output_sym_int_shape_env = output_sym_int.node.shape_env
                    output_sym_int_shape_env.add_var_to_val(
                        output_sym_int.node.expr, opt_val
                    )
                    output_shape.append(output_sym_int)
                else:
                    output_shape.append(min_val)
        else:
            output_shape.extend(outputs_mode_dict["opt"][out_idx].size())

        fake_outputs.append(
            torch.empty(output_shape, dtype=outputs_mode_dict["opt"][out_idx].dtype)
        )

    return fake_outputs


@torch._library.register_fake_class("tensorrt::Engine")
class FakeTRTEngine:
    def __init__(self, engine_info: List[str]) -> None:
        self.engine = torch.classes.tensorrt.Engine(engine_info)

    @classmethod
    def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
        engine_idx = torch.ops.tensorrt.ENGINE_IDX()
        engine_info = [info[1] for info in flattened_tq]
        engine_info[engine_idx] = base64.b64decode(engine_info[engine_idx])

        return cls(engine_info)

    def enable_profiling(self) -> Any:
        pass

    def disable_profiling(self) -> Any:
        pass

    def dump_engine_layer_info_to_file(self, path: str) -> Any:
        pass

    def dump_engine_layer_info(self) -> Any:
        pass

    def get_engine_layer_info(self) -> Any:
        pass

    def profile_path_prefix_getter(self) -> Any:
        pass

    def profile_path_prefix_setter(self) -> Any:
        pass

    def device_memory_budget_getter(self) -> Any:
        pass

    def device_memory_budget_setter(self) -> Any:
        pass

    def streamable_device_memory_budget_getter(self) -> Any:
        pass

    def automatic_device_memory_budget_getter(self) -> Any:
        pass

    def infer_outputs(self, input_shapes: List[Any]) -> Any:
        pass

    def __setstate__(self, serialized_state: List[str]) -> Any:
        pass
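Together, the fake execute_engine kernel and FakeTRTEngine are what let torch.export trace a Torch-TRT compiled GraphModule without running real inference. A minimal sketch of the resulting workflow follows; the model, shapes, and file name are illustrative, and torch_tensorrt.compile/torch_tensorrt.save and torch.export.load are the public entry points:

import torch
import torch_tensorrt

class MyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1

model = MyModel().eval().cuda()
inputs = [torch.randn(2, 4, device="cuda")]

# Compile with the dynamo frontend; the result is a GraphModule whose
# graph calls tensorrt::execute_engine on serialized TRT engines.
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# Re-export and serialize as an ExportedProgram. During export tracing,
# FakeTRTEngine and the fake execute_engine kernel above supply output
# shapes via infer_outputs instead of running the engine.
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

# Reload the ExportedProgram later and run it.
loaded = torch.export.load("trt.ep").module()
out = loaded(*inputs)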