Commit b9433ec

Merge remote-tracking branch 'dynamo_changes' into sample_backend
2 parents 226cc79 + cf5bb20

152 files changed: 21204 additions, 260 deletions


.circleci/config.yml

Lines changed: 208 additions & 1 deletion
@@ -508,6 +508,7 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs
 
+  # =================== FX tests start ======================== #
   test-fx_core:
     description: "Test the fx core"
     steps:
@@ -707,6 +708,167 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs
 
+  # =================== FX tests end ======================== #
+
+  # =================== Dynamo tests start ======================== #
+  test-dynamo-fx_ts_core:
+    description: "Test the Dynamo core"
+    steps:
+      - run:
+          name: Run Dynamo core tests
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd core/
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/core/test_results.xml
+            popd
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts_converters_acc:
+    description: "Test the Dynamo acc converters"
+    steps:
+      - run:
+          name: Run FX converter tests
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd converters/acc_op/
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/acc_op/test_results.xml
+            popd
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts_converters_aten:
+    description: "Test the dynamo aten converters"
+    steps:
+      - run:
+          name: Run dynamo converter tests
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd converters/aten_op/
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/aten_op/test_results.xml
+            popd
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts_converters_vanilla:
+    description: "Test the dynamo vanilla converters"
+    steps:
+      - run:
+          name: Run dynamo converter tests
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd converters/vanilla/
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/vanilla/test_results.xml
+            popd
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts_passes:
+    description: "Test the dynamo passes"
+    steps:
+      - run:
+          name: Run dynamo passes
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd passes
+            list_passes=$(ls | grep -v test_setitem*)
+            pytest $list_passes --junitxml=/tmp/artifacts/test_results/dynamo/passes/test_results.xml
+            popd
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts_tools:
+    description: "Test the dynamo tools"
+    steps:
+      - run:
+          name: Run dynamo tools
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd tools
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/tools/test_results.xml
+            popd
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts_trt_lower:
+    description: "Test the dynamo TRT lowering"
+    steps:
+      - run:
+          name: Run dynamo TRT lowering
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd trt_lower
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/trt_lower/test_results.xml
+            popd
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts_tracer:
+    description: "Test all dynamo tracers"
+    steps:
+      - run:
+          name: Run dynamo tracer
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd tracer
+            list_tracer=$(ls | grep -v test_dispatch_*)
+            pytest $list_tracer --junitxml=/tmp/artifacts/test_results/fx/tracer/test_results.xml
+            popd
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts_tracer_acc:
+    description: "Test the dynamo acc tracer only"
+    steps:
+      - run:
+          name: Run dynamo tracer
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd tracer
+            list_tracer=$(ls | grep test_acc)
+            pytest $list_tracer --junitxml=/tmp/artifacts/test_results/dynamo/tracer/test_results.xml
+            popd
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts_quant:
+    description: "Test the dynamo quant"
+    steps:
+      - run:
+          name: Run dynamo quant tests
+          command: |
+            cd py/torch_tensorrt/dynamo/fx_ts_compat/test
+            pushd quant/
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/quant/test_results.xml
+            popd
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
   test-dynamo-torch_compile:
     description: "Test the Dynamo torch_compile path"
     steps:
@@ -719,11 +881,55 @@ commands:
             pip3 install transformers
             pytest --junitxml=/tmp/artifacts/test_results/dynamo/test_results.xml --ir torch_compile
             popd
+
       - store_test_results:
           path: /tmp/artifacts
       - store_artifacts:
           path: /tmp/testlogs
 
+  test-dynamo-fx_ts:
+    description: "Test the dynamo backend"
+    steps:
+      - run:
+          name: Run dynamo tests
+          command: |
+            mkdir -p /tmp/artifacts/test_results
+      - test-dynamo-fx_ts_converters_acc
+      - test-dynamo-fx_ts_converters_aten
+      - test-dynamo-fx_ts_converters_vanilla
+      - test-dynamo-fx_ts_passes
+      - test-dynamo-fx_ts_tools
+      - test-dynamo-fx_ts_trt_lower
+      - test-dynamo-fx_ts_tracer
+      - test-dynamo-fx_ts_core
+      - test-dynamo-fx_ts_quant
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  test-dynamo-fx_ts-no-aten:
+    description: "Test the dynamo backend without aten operators"
+    steps:
+      - run:
+          name: Run dynamo tests without aten ops
+          command: |
+            mkdir -p /tmp/artifacts/test_results
+      - test-dynamo-fx_ts_converters_acc
+      - test-dynamo-fx_ts_converters_vanilla
+      - test-dynamo-fx_ts_passes
+      - test-dynamo-fx_ts_tools
+      - test-dynamo-fx_ts_trt_lower
+      - test-dynamo-fx_ts_tracer_acc
+      - test-dynamo-fx_ts_core
+      - test-dynamo-fx_ts_quant
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
+  # =================== Dynamo tests end ======================== #
+
 # Define a job to be invoked later in a workflow.
 # See: https://circleci.com/docs/2.0/configuration-reference/#jobs
 jobs:
@@ -930,6 +1136,7 @@ jobs:
       # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
       - dump-test-env
       - test-dynamo-torch_compile
+      - test-dynamo-fx_ts
 
   test-py-dynamo-x86_64-linux-no-aten:
     parameters:
@@ -960,7 +1167,7 @@ jobs:
           command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
       # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
      - dump-test-env
-      - test-dynamo-torch_compile
+      - test-dynamo-fx_ts-no-aten
 
   package-x86_64-linux:
     parameters:
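
The composite `test-dynamo-fx_ts` and `test-dynamo-fx_ts-no-aten` commands above simply chain the individual pytest suites rooted at `py/torch_tensorrt/dynamo/fx_ts_compat/test`. For local debugging, the same suites can be driven directly from Python; below is a minimal sketch of the `test-dynamo-fx_ts_core` step, assuming the repository root as the working directory and an installed `torch_tensorrt` build (the directory and report paths are copied from the CI commands above).

```python
# Minimal local sketch of the CI's "test-dynamo-fx_ts_core" step.
# Assumes: run from the repository root, pytest and torch_tensorrt installed.
import os
import pytest

TEST_ROOT = "py/torch_tensorrt/dynamo/fx_ts_compat/test"  # path from the CI config

# Equivalent of: pushd core/ && pytest --junitxml=... && popd
exit_code = pytest.main([
    os.path.join(TEST_ROOT, "core"),
    "--junitxml=/tmp/artifacts/test_results/dynamo/core/test_results.xml",
])

raise SystemExit(int(exit_code))
```

The remaining suites (converters, passes, tools, trt_lower, tracer, quant) follow the same pattern with their respective subdirectories and report paths.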

README.md

Lines changed: 9 additions & 2 deletions
@@ -73,6 +73,7 @@ import torch_tensorrt
 ...
 
 trt_ts_module = torch_tensorrt.compile(torch_script_module,
+    # If the inputs to the module are plain Tensors, specify them via the `inputs` argument:
     inputs = [example_tensor, # Provide example tensor for input shape or...
         torch_tensorrt.Input( # Specify input object with shape and dtype
             min_shape=[1, 3, 224, 224],
@@ -81,6 +82,12 @@ trt_ts_module = torch_tensorrt.compile(torch_script_module,
                               # For static size shape=[1, 3, 224, 224]
             dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
     ],
+
+    # For inputs containing tuples or lists of tensors, use the `input_signature` argument:
+    # Below, we have an input consisting of a Tuple of two Tensors (Tuple[Tensor, Tensor])
+    # input_signature = ( (torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
+    #                      torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)), ),
+
     enabled_precisions = {torch.half}, # Run with FP16
 )
 
@@ -114,7 +121,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
 These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
 
 - Bazel 5.2.0
-- Libtorch 2.0.0.dev20230103 (built with CUDA 11.7)
+- Libtorch 2.1.0.dev20230314 (built with CUDA 11.7)
 - CUDA 11.7
 - cuDNN 8.5.0
 - TensorRT 8.5.1.7
@@ -124,7 +131,7 @@ These are the following dependencies used to verify the testcases. Torch-TensorR
 Releases: https://github.com/pytorch/TensorRT/releases
 
 ```
-pip install torch-tensorrt==1.2.0 --find-links https://github.com/pytorch/TensorRT/releases/expanded_assets/v1.2.0
+pip install torch-tensorrt
 ```
 
 ## Compiling Torch-TensorRT
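
To make the new `input_signature` comment concrete, here is a hedged sketch of compiling a module whose `forward` takes a `Tuple[Tensor, Tensor]`. The `TupleInputModule` class is hypothetical (not part of the repository), the shapes and dtype mirror the commented example above, and full support for a given collection layout depends on the collection-handling changes in `core/compiler.cpp` below.

```python
# Hypothetical example (not from the repo): a TorchScript module whose forward
# takes Tuple[Tensor, Tensor], compiled via the `input_signature` argument
# referenced in the README comment above.
from typing import Tuple

import torch
import torch_tensorrt

class TupleInputModule(torch.nn.Module):
    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        a, b = xs
        return a + b

scripted = torch.jit.script(TupleInputModule().eval())

trt_module = torch_tensorrt.compile(
    scripted,
    # Nested structure mirrors the Tuple[Tensor, Tensor] accepted by forward()
    input_signature=((torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
                      torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)),),
    enabled_precisions={torch.half},
)
```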

WORKSPACE

Lines changed: 4 additions & 4 deletions
@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "8b3b48615169c83c1b643c0efade078ea080b1da598e15fcf01bc59421f3095e",
+    sha256 = "7c4b8754830fef23ec19c5eaf414794cee9597b435df055f5c1d0471d3e81568",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.0.dev20230219%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
 )
 
 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "aa7fd06079d260ff83c344d043fb84fbd9cf831cf375ed8b5a1b62416817af31",
+    sha256 = "f1e64a75dd12d0ba4c8c1f61947299e0a9c50684dff64f0cfbf355aa7a13e8cf",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.0.0.dev20230219%2Bcu117.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu117/libtorch-shared-with-deps-2.1.0.dev20230314%2Bcu117.zip"],
 )
 
 # Download these tarballs manually from the NVIDIA website

core/compiler.cpp

Lines changed: 2 additions & 1 deletion
@@ -352,8 +352,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Determine if the block is convertible/has collection output, and based on the result,
   // whether full compilation can be expected
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto inputIsCollection = conversion::InputIsCollection(g->block());
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+  auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));
 
   // Determine whether user specifications necessitate partitioning
   auto isFallbackRequested = userRequestedFallback(cfg);

core/conversion/conversion.cpp

Lines changed: 11 additions & 1 deletion
@@ -556,10 +556,20 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }
 
+bool InputIsCollection(const torch::jit::Block* b) {
+  for (auto in : b->inputs()) {
+    if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool OutputIsCollection(const torch::jit::Block* b) {
   for (auto out : b->outputs()) {
     if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
-        out->type()->kind() == torch::jit::TypeKind::ListType) {
+        out->type()->kind() == torch::jit::TypeKind::ListType ||
+        out->type()->kind() == torch::jit::TypeKind::DictType) {
       return true;
     }
   }
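
As a hedged, Python-level illustration of what the new checks match: `InputIsCollection` flags block inputs of Tuple or List type, and `OutputIsCollection` now also flags Dict outputs. The module below is hypothetical and only shows a scripted graph whose input and output types would satisfy those checks.

```python
# Hypothetical module (not from the repo) whose scripted graph would be flagged
# by the new checks: the input is a List[Tensor] (InputIsCollection) and the
# output is a Dict[str, Tensor] (OutputIsCollection now also matches DictType).
from typing import Dict, List

import torch

class CollectionModule(torch.nn.Module):
    def forward(self, xs: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {"sum": xs[0] + xs[1], "diff": xs[0] - xs[1]}

scripted = torch.jit.script(CollectionModule())
# The graph's block inputs/outputs carry ListType and DictType, so
# requires_collection_handling in core/compiler.cpp would now be considered
# (provided the block is otherwise convertible).
print(scripted.graph)
```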

core/conversion/conversion.h

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@ std::string ConvertBlockToEngine(
 
 bool OpSupported(const torch::jit::Node* n);
 
+bool InputIsCollection(const torch::jit::Block* b);
+
 bool OutputIsCollection(const torch::jit::Block* b);
 
 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);
