-
Notifications
You must be signed in to change notification settings - Fork 364
[FX] refactor the fx path in compile function #1141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
b547a33
compile interface
0f5ef06
add compile method
5f99c11
update
4c670de
Merge branch 'master' into fx2trt_wei_2
7be21ca
update
dc7e1a5
Merge remote-tracking branch 'origin/master' into fx2trt_wei_2
f1dfc92
Merge branch 'master' into fx2trt_wei_2
96f9aa3
Update lower_setting.py
596ac14
update fx2trt_example
17e8f94
Merge branch 'fx2trt_wei_2' of https://github.com/pytorch/TensorRT in…
e367e11
add docstring
834a4b0
update dynamic_batch default to False
09babb5
add docstring
9eb349d
add save/load module
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import torch | ||
import copy | ||
import torchvision | ||
import torch_tensorrt | ||
from torch_tensorrt.fx import InputTensorSpec | ||
|
||
|
||
def test_torch_tensorrt(model, inputs): | ||
# torchscript path | ||
model_ts = copy.deepcopy(model) | ||
inputs_ts = copy.deepcopy(inputs) | ||
# fp32 test | ||
with torch.inference_mode(): | ||
ref_fp32 = model_ts(*inputs_ts) | ||
trt_ts_module = torch_tensorrt.compile( | ||
model_ts, inputs=inputs_ts, enabled_precisions={torch.float32} | ||
) | ||
result_fp32 = trt_ts_module(*inputs_ts) | ||
assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999) | ||
# fp16 test | ||
model_ts = model_ts.half() | ||
inputs_ts = [i.cuda().half() for i in inputs_ts] | ||
with torch.inference_mode(): | ||
ref_fp16 = model_ts(*inputs_ts) | ||
trt_ts_module = torch_tensorrt.compile( | ||
model_ts, inputs=inputs_ts, enabled_precisions={torch.float16} | ||
) | ||
result_fp16 = trt_ts_module(*inputs_ts) | ||
assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99) | ||
|
||
# FX path | ||
model_fx = copy.deepcopy(model) | ||
inputs_fx = copy.deepcopy(inputs) | ||
# fp32 test | ||
with torch.inference_mode(): | ||
ref_fp32 = model_fx(*inputs_fx) | ||
trt_fx_module = torch_tensorrt.compile( | ||
model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32} | ||
) | ||
result_fp32 = trt_fx_module(*inputs_fx) | ||
assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999) | ||
# fp16 test | ||
model_fx = model_fx.cuda().half() | ||
inputs_fx = [i.cuda().half() for i in inputs_fx] | ||
with torch.inference_mode(): | ||
ref_fp16 = model_fx(*inputs_fx) | ||
trt_fx_module = torch_tensorrt.compile( | ||
model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16} | ||
) | ||
result_fp16 = trt_fx_module(*inputs_fx) | ||
assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99 ) | ||
|
||
|
||
if __name__ == "__main__": | ||
model = torchvision.models.resnet18(pretrained=True).cuda().eval() | ||
inputs = [torch.ones((32, 3, 224, 224), device=torch.device('cuda'))] # type: ignore[attr-defined] | ||
test_torch_tensorrt(model, inputs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.