Skip to content

Commit 75715dc

Browse files
peri044 and gs-olive committed
feat: Refactor FX APIs under dynamo namespace for parity with TS APIs (#1807)
Signed-off-by: Dheeraj Peri <[email protected]> Co-authored-by: gs-olive <[email protected]>
1 parent df294de commit 75715dc

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed: +3834 lines added, -159 lines removed

.circleci/config.yml

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ commands:
508508
- store_artifacts:
509509
path: /tmp/testlogs
510510

511+
# =================== FX tests start ======================== #
511512
test-fx_core:
512513
description: "Test the fx core"
513514
steps:
@@ -707,6 +708,61 @@ commands:
707708
- store_artifacts:
708709
path: /tmp/testlogs
709710

711+
# =================== FX tests end ======================== #
712+
713+
# =================== Dynamo tests start ======================== #
714+
test-dynamo-fx_ts:
715+
description: "Test the Dynamo fx_ts_compat path"
716+
steps:
717+
- run:
718+
name: Run Dynamo fx_ts_compat core tests
719+
command: |
720+
cd py/torch_tensorrt/dynamo/fx_ts_compat/test
721+
pushd core/
722+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/fx_ts_compat/test_results.xml
723+
popd
724+
725+
- store_test_results:
726+
path: /tmp/artifacts
727+
- store_artifacts:
728+
path: /tmp/testlogs
729+
730+
test-dynamo-torch_compile-core:
731+
description: "Test the Dynamo torch_compile path"
732+
steps:
733+
- run:
734+
name: Run Dynamo torch_compile core tests
735+
command: |
736+
cd py/torch_tensorrt/dynamo/torch_compile
737+
pushd test/
738+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
739+
popd
740+
741+
- store_test_results:
742+
path: /tmp/artifacts
743+
- store_artifacts:
744+
path: /tmp/testlogs
745+
746+
test-dynamo-torch_compile:
747+
description: "Test the Dynamo torch_compile path"
748+
steps:
749+
- run:
750+
name: Run Dynamo torch_compile E2E tests
751+
command: |
752+
cd py/torch_tensorrt/dynamo/
753+
pushd test/
754+
pip3 install timm
755+
pip3 install transformers
756+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile
757+
popd
758+
759+
- store_test_results:
760+
path: /tmp/artifacts
761+
- store_artifacts:
762+
path: /tmp/testlogs
763+
764+
# =================== Dynamo tests end ======================== #
765+
710766
# Define a job to be invoked later in a workflow.
711767
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
712768
jobs:
@@ -883,6 +939,39 @@ jobs:
883939
- dump-test-env
884940
- test-fx-no-aten
885941

942+
test-py-dynamo-x86_64-linux:
943+
parameters:
944+
torch-build:
945+
type: string
946+
torch-build-index:
947+
type: string
948+
trt-version-long:
949+
type: string
950+
machine:
951+
image: ubuntu-2004-cuda-11.4:202110-01
952+
resource_class: gpu.nvidia.large
953+
steps:
954+
- checkout
955+
- attach_workspace:
956+
at: /tmp/dist/
957+
- install-torch-from-index:
958+
torch-build: << parameters.torch-build >>
959+
torch-build-index: << parameters.torch-build-index >>
960+
- create-py-env:
961+
trt-version-long: << parameters.trt-version-long >>
962+
- install-cudnn
963+
# - run:
964+
# name: "Set LD_LIBRARY_PATH path to include the installed CUDNN"
965+
# command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
966+
- run:
967+
name: "Install torch-tensorrt"
968+
command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
969+
# We install torch after torch-trt because pip automatically enforces the version constraint otherwise
970+
- dump-test-env
971+
- test-dynamo-torch_compile
972+
- test-dynamo-torch_compile-core
973+
- test-dynamo-fx_ts
974+
886975
package-x86_64-linux:
887976
parameters:
888977
enabled:
@@ -1261,6 +1350,13 @@ workflows:
12611350
requires:
12621351
- build-x86_64-linux
12631352

1353+
- test-py-dynamo-x86_64-linux:
1354+
torch-build: << pipeline.parameters.torch-build >>
1355+
torch-build-index: << pipeline.parameters.torch-build-index >>
1356+
trt-version-long: << pipeline.parameters.trt-version-long >>
1357+
requires:
1358+
- build-x86_64-linux
1359+
12641360
- build-x86_64-linux:
12651361
name: build-x86_64-linux-legacy
12661362
torch-build: << pipeline.parameters.torch-build-legacy >>
@@ -1328,6 +1424,13 @@ workflows:
13281424
requires:
13291425
- package-x86_64-linux
13301426

1427+
- test-py-dynamo-x86_64-linux:
1428+
torch-build: << pipeline.parameters.torch-build >>
1429+
torch-build-index: << pipeline.parameters.torch-build-index >>
1430+
trt-version-long: << pipeline.parameters.trt-version-long >>
1431+
requires:
1432+
- package-x86_64-linux
1433+
13311434
on-push:
13321435
jobs:
13331436
- build-x86_64-linux:
@@ -1357,6 +1460,13 @@ workflows:
13571460
requires:
13581461
- build-x86_64-linux
13591462

1463+
- test-py-dynamo-x86_64-linux:
1464+
torch-build: << pipeline.parameters.torch-build >>
1465+
torch-build-index: << pipeline.parameters.torch-build-index >>
1466+
trt-version-long: << pipeline.parameters.trt-version-long >>
1467+
requires:
1468+
- build-x86_64-linux
1469+
13601470
- build-x86_64-linux-cmake:
13611471
torch-build: << pipeline.parameters.torch-build >>
13621472
torch-build-index: << pipeline.parameters.torch-build-index >>

py/setup.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,10 @@ def run(self):
356356
"torch_tensorrt.fx.tools",
357357
"torch_tensorrt.fx.tracer.acc_tracer",
358358
"torch_tensorrt.fx.tracer.dispatch_tracer",
359+
"torch_tensorrt.dynamo",
360+
"torch_tensorrt.dynamo.fx_ts_compat",
361+
"torch_tensorrt.dynamo.fx_ts_compat.passes",
362+
"torch_tensorrt.dynamo.fx_ts_compat.tools",
359363
]
360364
package_dir = {
361365
"torch_tensorrt.fx": "torch_tensorrt/fx",
@@ -364,11 +368,47 @@ def run(self):
364368
"torch_tensorrt.fx.tools": "torch_tensorrt/fx/tools",
365369
"torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer",
366370
"torch_tensorrt.fx.tracer.dispatch_tracer": "torch_tensorrt/fx/tracer/dispatch_tracer",
371+
"torch_tensorrt.dynamo": "torch_tensorrt/dynamo",
372+
"torch_tensorrt.dynamo.fx_ts_compat": "torch_tensorrt/dynamo/fx_ts_compat",
373+
"torch_tensorrt.dynamo.fx_ts_compat.passes": "torch_tensorrt/dynamo/fx_ts_compat/passes",
374+
"torch_tensorrt.dynamo.fx_ts_compat.tools": "torch_tensorrt/dynamo/fx_ts_compat/tools",
367375
}
368376

369377
with open("README.md", "r", encoding="utf-8") as fh:
370378
long_description = fh.read()
371379

380+
if FX_ONLY:
381+
package_data_list = [
382+
"_Input.py",
383+
]
384+
else:
385+
package_data_list = [
386+
"lib/*",
387+
"include/torch_tensorrt/*.h",
388+
"include/torch_tensorrt/core/*.h",
389+
"include/torch_tensorrt/core/conversion/*.h",
390+
"include/torch_tensorrt/core/conversion/conversionctx/*.h",
391+
"include/torch_tensorrt/core/conversion/converters/*.h",
392+
"include/torch_tensorrt/core/conversion/evaluators/*.h",
393+
"include/torch_tensorrt/core/conversion/tensorcontainer/*.h",
394+
"include/torch_tensorrt/core/conversion/var/*.h",
395+
"include/torch_tensorrt/core/ir/*.h",
396+
"include/torch_tensorrt/core/lowering/*.h",
397+
"include/torch_tensorrt/core/lowering/passes/*.h",
398+
"include/torch_tensorrt/core/partitioning/*.h",
399+
"include/torch_tensorrt/core/partitioning/segmentedblock/*.h",
400+
"include/torch_tensorrt/core/partitioning/partitioninginfo/*.h",
401+
"include/torch_tensorrt/core/partitioning/partitioningctx/*.h",
402+
"include/torch_tensorrt/core/plugins/*.h",
403+
"include/torch_tensorrt/core/plugins/impl/*.h",
404+
"include/torch_tensorrt/core/runtime/*.h",
405+
"include/torch_tensorrt/core/util/*.h",
406+
"include/torch_tensorrt/core/util/logging/*.h",
407+
"bin/*",
408+
"BUILD",
409+
"WORKSPACE",
410+
]
411+
372412
setup(
373413
name="torch_tensorrt",
374414
version=__version__,
@@ -412,32 +452,7 @@ def run(self):
412452
python_requires=">=3.7",
413453
include_package_data=True,
414454
package_data={
415-
"torch_tensorrt": [
416-
"lib/*",
417-
"include/torch_tensorrt/*.h",
418-
"include/torch_tensorrt/core/*.h",
419-
"include/torch_tensorrt/core/conversion/*.h",
420-
"include/torch_tensorrt/core/conversion/conversionctx/*.h",
421-
"include/torch_tensorrt/core/conversion/converters/*.h",
422-
"include/torch_tensorrt/core/conversion/evaluators/*.h",
423-
"include/torch_tensorrt/core/conversion/tensorcontainer/*.h",
424-
"include/torch_tensorrt/core/conversion/var/*.h",
425-
"include/torch_tensorrt/core/ir/*.h",
426-
"include/torch_tensorrt/core/lowering/*.h",
427-
"include/torch_tensorrt/core/lowering/passes/*.h",
428-
"include/torch_tensorrt/core/partitioning/*.h",
429-
"include/torch_tensorrt/core/partitioning/segmentedblock/*.h",
430-
"include/torch_tensorrt/core/partitioning/partitioninginfo/*.h",
431-
"include/torch_tensorrt/core/partitioning/partitioningctx/*.h",
432-
"include/torch_tensorrt/core/plugins/*.h",
433-
"include/torch_tensorrt/core/plugins/impl/*.h",
434-
"include/torch_tensorrt/core/runtime/*.h",
435-
"include/torch_tensorrt/core/util/*.h",
436-
"include/torch_tensorrt/core/util/logging/*.h",
437-
"bin/*",
438-
"BUILD",
439-
"WORKSPACE",
440-
],
455+
"torch_tensorrt": package_data_list,
441456
},
442457
exclude_package_data={
443458
"": ["*.cpp"],

py/torch_tensorrt/_Device.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
import torch
22

3-
from torch_tensorrt import _enums
3+
# from torch_tensorrt import _enums
4+
import tensorrt as trt
45
from torch_tensorrt import logging
5-
from torch_tensorrt import _C
6-
76
import warnings
87

8+
try:
9+
from torch_tensorrt import _C
10+
except:
11+
warnings.warn(
12+
"Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable."
13+
)
14+
915

1016
class Device(object):
1117
"""
@@ -51,7 +57,7 @@ def __init__(self, *args, **kwargs):
5157
)
5258
else:
5359
(self.device_type, id) = Device._parse_device_str(args[0])
54-
if self.device_type == _enums.DeviceType.GPU:
60+
if self.device_type == trt.DeviceType.GPU:
5561
self.gpu_id = id
5662
else:
5763
self.dla_core = id
@@ -64,7 +70,7 @@ def __init__(self, *args, **kwargs):
6470
elif len(args) == 0:
6571
if "gpu_id" in kwargs or "dla_core" in kwargs:
6672
if "dla_core" in kwargs:
67-
self.device_type = _enums.DeviceType.DLA
73+
self.device_type = trt.DeviceType.DLA
6874
self.dla_core = kwargs["dla_core"]
6975
if "gpu_id" in kwargs:
7076
self.gpu_id = kwargs["gpu_id"]
@@ -76,7 +82,7 @@ def __init__(self, *args, **kwargs):
7682
)
7783
else:
7884
self.gpu_id = kwargs["gpu_id"]
79-
self.device_type = _enums.DeviceType.GPU
85+
self.device_type = trt.DeviceType.GPU
8086
else:
8187
raise ValueError(
8288
"Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg"
@@ -97,15 +103,23 @@ def __init__(self, *args, **kwargs):
97103
def __str__(self) -> str:
98104
return (
99105
"Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) + ")"
100-
if self.device_type == _enums.DeviceType.GPU
106+
if self.device_type == trt.DeviceType.GPU
101107
else ", dla_core={}, allow_gpu_fallback={}".format(
102108
self.dla_core, self.allow_gpu_fallback
103109
)
104110
)
105111

106112
def _to_internal(self) -> _C.Device:
107113
internal_dev = _C.Device()
108-
internal_dev.device_type = self.device_type
114+
if self.device_type == trt.DeviceType.GPU:
115+
internal_dev.device_type = _C.DeviceType.GPU
116+
elif self.device_type == trt.DeviceType.DLA:
117+
internal_dev.device_type = _C.DeviceType.DLA
118+
else:
119+
raise ValueError(
120+
"Invalid DeviceType detected while parsing the Device class"
121+
)
122+
109123
internal_dev.gpu_id = self.gpu_id
110124
internal_dev.dla_core = self.dla_core
111125
internal_dev.allow_gpu_fallback = self.allow_gpu_fallback
@@ -136,6 +150,6 @@ def _parse_device_str(s):
136150
s = s.lower()
137151
spec = s.split(":")
138152
if spec[0] == "gpu" or spec[0] == "cuda":
139-
return (_enums.DeviceType.GPU, int(spec[1]))
153+
return (trt.DeviceType.GPU, int(spec[1]))
140154
elif spec[0] == "dla":
141-
return (_enums.DeviceType.DLA, int(spec[1]))
155+
return (trt.DeviceType.DLA, int(spec[1]))

0 commit comments

Comments (0)