cherry-pick: Reorganize + Upgrade Dynamo (release/1.4) #1931

Merged 2 commits on May 19, 2023
22 changes: 11 additions & 11 deletions .circleci/config.yml
@@ -740,33 +740,33 @@ commands:
      - store_artifacts:
          path: /tmp/testlogs

-  test-dynamo-torch_compile-core:
-    description: "Test the Dynamo torch_compile path"
+  test-dynamo-compile-core:
+    description: "Test the Dynamo compile path"
    steps:
      - run:
-          name: Run Dynamo torch_compile core tests
+          name: Run Dynamo compile core tests
          command: |
-            cd py/torch_tensorrt/dynamo/torch_compile
+            cd py/torch_tensorrt/dynamo/backend
            pushd test/
-            pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml
            popd

      - store_test_results:
          path: /tmp/artifacts
      - store_artifacts:
          path: /tmp/testlogs

-  test-dynamo-torch_compile:
-    description: "Test the Dynamo torch_compile path"
+  test-dynamo-compile:
+    description: "Test the Dynamo compile path"
    steps:
      - run:
-          name: Run Dynamo torch_compile E2E tests
+          name: Run Dynamo compile E2E tests
          command: |
            cd py/torch_tensorrt/dynamo/
            pushd test/
            pip3 install timm
            pip3 install transformers
-            pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo_compile
            popd

      - store_test_results:

@@ -1000,8 +1000,8 @@ jobs:
          command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
      # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
      - dump-test-env
-      - test-dynamo-torch_compile
-      - test-dynamo-torch_compile-core
+      - test-dynamo-compile
+      - test-dynamo-compile-core
      - test-dynamo-fx_ts

  package-x86_64-linux:
2 changes: 1 addition & 1 deletion py/torch_tensorrt/__init__.py
@@ -97,7 +97,7 @@ def _find_lib(name, paths):

if version.parse(torch.__version__) >= version.parse("2.dev"):
    from torch_tensorrt import dynamo
-    from torch_tensorrt.dynamo import torch_compile
+    from torch_tensorrt.dynamo import backend


def _register_with_torch():
12 changes: 6 additions & 6 deletions py/torch_tensorrt/_compile.py
@@ -16,7 +16,7 @@ class _IRType(Enum):
    ts = 0
    fx = 1
    fx_ts_compat = 2
-    torch_compile = 3
+    dynamo_compile = 3


class _ModuleType(Enum):
@@ -47,7 +47,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:

    ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
    ir_targets_fx = ir == "fx"
-    ir_targets_torch_compile = ir == "torch_compile"
+    ir_targets_dynamo_compile = ir == "dynamo_compile"
    ir_targets_fx_ts_compat = ir == "fx_ts_compat"

    if module_is_tsable and ir_targets_torchscript:
@@ -56,8 +56,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
        return _IRType.fx
    elif module_is_fxable and ir_targets_fx_ts_compat:
        return _IRType.fx_ts_compat
-    elif module_is_fxable and ir_targets_torch_compile:
-        return _IRType.torch_compile
+    elif module_is_fxable and ir_targets_dynamo_compile:
+        return _IRType.dynamo_compile
    else:
        if ir == "default":
            # Options are listed in order of preference
@@ -156,8 +156,8 @@ def compile(
            dynamic_batch=False,
            **kwargs,
        )
-    elif target_ir == _IRType.torch_compile:
-        return torch_tensorrt.dynamo.torch_compile(
+    elif target_ir == _IRType.dynamo_compile:
+        return torch_tensorrt.dynamo.compile(
            module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
        )
    elif target_ir == _IRType.fx_ts_compat:
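Note (not part of the diff): a minimal sketch of how the renamed IR is selected through the top-level API after this change. The toy model, input shapes, and precision set are placeholders; only the `ir="dynamo_compile"` string and the dispatch to `torch_tensorrt.dynamo.compile` come from the hunks above.

```python
import torch
import torch_tensorrt

# Placeholder model and input; any FX-traceable module would do.
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(2, 8).cuda()]

# "torch_compile" is no longer a valid value for `ir`; the Dynamo path is now
# requested as "dynamo_compile", which _compile.py routes to torch_tensorrt.dynamo.compile.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo_compile",
    inputs=inputs,
    enabled_precisions={torch.float},
)
```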
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/__init__.py
@@ -1,2 +1,2 @@
from torch_tensorrt.dynamo import fx_ts_compat
-from .torch_compile import compile as torch_compile
+from .backend import compile
py/torch_tensorrt/dynamo/{torch_compile → backend}/__init__.py
@@ -4,18 +4,18 @@
import torch_tensorrt
from functools import partial

-from typing import Any
+from typing import Any, Sequence
from torch_tensorrt import EngineCapability, Device
from torch_tensorrt.fx.utils import LowerPrecision

-from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
-from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, prepare_device
-from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend
-from torch_tensorrt.dynamo.torch_compile._defaults import (
+from torch_tensorrt.dynamo.backend._settings import CompilationSettings
+from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
+from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
+from torch_tensorrt.dynamo.backend._defaults import (
    PRECISION,
    DEBUG,
    MAX_WORKSPACE_SIZE,
-    MAX_NUM_TRT_ENGINES,
+    MIN_BLOCK_SIZE,
)


@@ -41,7 +41,7 @@ def compile(
    calibrator=None,
    truncate_long_and_double=False,
    require_full_compilation=False,
-    min_block_size=3,
+    min_block_size=MIN_BLOCK_SIZE,
    torch_executed_ops=[],
    torch_executed_modules=[],
    **kwargs,
@@ -50,7 +50,7 @@
    logger.warn(
        "The Dynamo backend is an experimental feature, for which only the "
        + "following arguments are supported: "
-        + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}"
+        + "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
    )

    if not isinstance(inputs, collections.abc.Sequence):
@@ -80,6 +80,8 @@ def compile(
        precision=lower_precision,
        debug=debug,
        workspace_size=workspace_size,
+        min_block_size=min_block_size,
+        torch_executed_ops=torch_executed_ops,
        **kwargs,
    )

@@ -100,7 +102,8 @@ def create_backend(
    precision: LowerPrecision = PRECISION,
    debug: bool = DEBUG,
    workspace_size: int = MAX_WORKSPACE_SIZE,
-    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Sequence[str] = set(),
    **kwargs,
):
    """Create torch.compile backend given specified arguments
@@ -117,10 +120,11 @@
        debug=debug,
        precision=precision,
        workspace_size=workspace_size,
-        max_num_trt_engines=max_num_trt_engines,
+        min_block_size=min_block_size,
+        torch_executed_ops=torch_executed_ops,
    )

    return partial(
-        tensorrt_backend,
+        torch_tensorrt_backend,
        settings=settings,
    )
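Note (assumption, not shown in this diff): since `create_backend` returns a partial of `torch_tensorrt_backend`, the returned callable should be usable directly as a `torch.compile` backend, with the new `min_block_size`/`torch_executed_ops` knobs replacing `max_num_trt_engines`. The model, shapes, and the op name below are illustrative placeholders.

```python
import torch
from torch_tensorrt.dynamo.backend import create_backend

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()).eval().cuda()
x = torch.randn(2, 8).cuda()

# create_backend packs the arguments into CompilationSettings and returns
# partial(torch_tensorrt_backend, settings=settings).
backend = create_backend(
    debug=True,
    min_block_size=2,                        # replaces max_num_trt_engines
    torch_executed_ops={"aten.sub.Tensor"},  # illustrative op forced to stay in PyTorch
)
compiled = torch.compile(model, backend=backend)
compiled(x)
```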
py/torch_tensorrt/dynamo/{torch_compile → backend}/_defaults.py
@@ -4,4 +4,4 @@
PRECISION = LowerPrecision.FP32
DEBUG = False
MAX_WORKSPACE_SIZE = 20 << 30
-MAX_NUM_TRT_ENGINES = 200
+MIN_BLOCK_SIZE = 5
py/torch_tensorrt/dynamo/{torch_compile → backend}/_settings.py
@@ -1,11 +1,12 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import Sequence

from torch_tensorrt.fx.utils import LowerPrecision
-from torch_tensorrt.dynamo.torch_compile._defaults import (
+from torch_tensorrt.dynamo.backend._defaults import (
    PRECISION,
    DEBUG,
    MAX_WORKSPACE_SIZE,
-    MAX_NUM_TRT_ENGINES,
+    MIN_BLOCK_SIZE,
)


@@ -14,4 +15,5 @@ class CompilationSettings:
    precision: LowerPrecision = PRECISION
    debug: bool = DEBUG
    workspace_size: int = MAX_WORKSPACE_SIZE
-    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES
+    min_block_size: int = MIN_BLOCK_SIZE
+    torch_executed_ops: Sequence[str] = field(default_factory=set)
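Aside on the `field(default_factory=set)` default above: dataclasses reject bare mutable defaults, so the new `torch_executed_ops` field needs a factory. A self-contained illustration of the pattern (not code from the PR):

```python
from dataclasses import dataclass, field
from typing import Sequence


@dataclass
class Example:
    # `ops: Sequence[str] = set()` would raise ValueError at class-definition
    # time ("mutable default ... use default_factory"); the factory builds a
    # fresh set per instance instead.
    ops: Sequence[str] = field(default_factory=set)


a, b = Example(), Example()
assert a.ops is not b.ops  # each instance owns its own set
```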
py/torch_tensorrt/dynamo/{torch_compile → backend}/backends.py
@@ -4,30 +4,42 @@
from functools import partial
import torch._dynamo as td

-from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
-from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
+from torch_tensorrt.dynamo.backend._settings import CompilationSettings
+from torch_tensorrt.dynamo.backend.lowering._decompositions import (
    get_decompositions,
)
-from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
+from torch_tensorrt.dynamo.backend.lowering._partition import (
    partition,
    get_submod_inputs,
)
-from torch_tensorrt.dynamo.torch_compile.conversion import convert_module
+from torch_tensorrt.dynamo.backend.conversion import convert_module

from torch._dynamo.backends.common import fake_tensor_unsupported

from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler


-@td.register_backend(name="tensorrt")
+@td.register_backend(name="torch_tensorrt")
@fake_tensor_unsupported
-def tensorrt_backend(
-    gm: torch.nn.Module,
+def torch_tensorrt_backend(
+    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    settings: CompilationSettings = CompilationSettings(),
):
+    DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
+
+    return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)
+
+
+@td.register_backend(name="aot_torch_tensorrt_aten")
+@fake_tensor_unsupported
+def aot_torch_tensorrt_aten_backend(
+    gm: torch.fx.GraphModule,
+    sample_inputs: Sequence[torch.Tensor],
+    settings: CompilationSettings = CompilationSettings(),
+):
    custom_backend = partial(
-        fx_dynamo_backend,
+        _pretraced_backend,
        settings=settings,
    )

@@ -40,14 +52,12 @@ def tensorrt_backend(
    )


-@td.register_backend(name="fx_tensorrt")
-@fake_tensor_unsupported
-def fx_dynamo_backend(
+def _pretraced_backend(
    gm: torch.fx.GraphModule,
-    example_inputs: Sequence[torch.Tensor],
+    sample_inputs: Sequence[torch.Tensor],
    settings: CompilationSettings = CompilationSettings(),
):
-    """Helper function to manage translation of FX module to TRT engines
+    """Helper function to manage translation of traced FX module to TRT engines

    Args:
        module: FX GraphModule to convert
@@ -57,9 +67,9 @@ def fx_dynamo_backend(
        Compiled FX GraphModule
    """
    try:
-        trt_compiled = compile_module(
+        trt_compiled = _compile_module(
            gm,
-            example_inputs,
+            sample_inputs,
            settings=settings,
        )
        return trt_compiled
@@ -72,12 +82,12 @@ def fx_dynamo_backend(
        return gm.forward


-def compile_module(
+def _compile_module(
    gm: torch.fx.GraphModule,
-    example_inputs: Sequence[torch.Tensor],
+    sample_inputs: Sequence[torch.Tensor],
    settings: CompilationSettings = CompilationSettings(),
) -> torch.fx.GraphModule:
-    """Compile an FX module
+    """Compile a traced FX module

    Includes: Partitioning + Conversion Phases

@@ -90,7 +100,10 @@ def compile_module(
    """
    # Partition module into components that can be TRT-accelerated
    partitioned_module = partition(
-        gm, verbose=settings.debug, max_num_trt_engines=settings.max_num_trt_engines
+        gm,
+        verbose=settings.debug,
+        min_block_size=settings.min_block_size,
+        torch_executed_ops=settings.torch_executed_ops,
    )

    # Iterate over all components that can be accelerated
@@ -100,7 +113,7 @@

        # Get submodule inputs
        submodule_inputs = get_submod_inputs(
-            partitioned_module, submodule, example_inputs
+            partitioned_module, submodule, sample_inputs
        )

        # Create TRT Module from submodule
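Note (sketch, not in the diff): once `torch_tensorrt` is imported, the `@td.register_backend` decorators above run (via `torch_tensorrt.dynamo.backend.backends`), so the backends become addressable by name through `torch.compile`. The model and input below are placeholders.

```python
import torch
import torch_tensorrt  # importing runs the register_backend decorators above

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()).eval().cuda()
x = torch.randn(2, 8).cuda()

# "torch_tensorrt" resolves to torch_tensorrt_backend, which forwards to
# aot_torch_tensorrt_aten_backend (its DEFAULT_BACKEND); the old "tensorrt"
# and "fx_tensorrt" names are no longer registered.
compiled = torch.compile(model, backend="torch_tensorrt")
compiled(x)
```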
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/__init__.py
@@ -0,0 +1,7 @@
+from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+    get_decompositions,
+)
+from torch_tensorrt.dynamo.backend.lowering._partition import (
+    partition,
+    get_submod_inputs,
+)
py/torch_tensorrt/dynamo/{torch_compile → backend}/lowering/_decompositions.py
@@ -41,5 +41,20 @@ def inplace_op(*args, **kwargs):
replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)


+@register_decomposition(aten.std, registry=DECOMPOSITIONS)
+def std_replacement(*args, **kwargs) -> torch.Tensor:
+    return torch.sqrt(torch.var(*args, **kwargs))
+
+
+@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS)
+def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
+    return torch.reciprocal(torch.sqrt(*args, **kwargs))
+
+
+@register_decomposition(aten.alias, registry=DECOMPOSITIONS)
+def alias_replacement(x: torch.Tensor) -> torch.Tensor:
+    return x
+
+
def get_decompositions():
    return DECOMPOSITIONS
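Note on the new decompositions (the check below is illustrative, not from the PR): the three additions land in `DECOMPOSITIONS`, which `get_decompositions()` exposes and which backends.py imports alongside `aot_module_simplified` (presumably as the AOTAutograd decomposition table), so `aten.std`, `aten.rsqrt`, and `aten.alias` get rewritten into ops the converters already handle. Each replacement is numerically equivalent to the op it stands in for:

```python
import torch
from torch_tensorrt.dynamo.backend.lowering import get_decompositions

decomps = get_decompositions()
aten = torch.ops.aten
# register_decomposition keys the table by op overload(s).
assert aten.rsqrt.default in decomps or aten.rsqrt in decomps

x = torch.rand(4) + 0.1
torch.testing.assert_close(torch.rsqrt(x), torch.reciprocal(torch.sqrt(x)))
torch.testing.assert_close(torch.std(x), torch.sqrt(torch.var(x)))  # both default to correction=1
```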