Commit 7ed0904

feat: Refactor dynamo export
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent dea70b8 commit 7ed0904

36 files changed: +366 −1979 lines

py/torch_tensorrt/_compile.py

Lines changed: 16 additions & 18 deletions
@@ -15,8 +15,8 @@ class _IRType(Enum):
 
     ts = 0
     fx = 1
-    fx_ts_compat = 2
-    dynamo_compile = 3
+    dynamo = 2
+    torch_compile = 3
 
 
 class _ModuleType(Enum):
@@ -47,31 +47,29 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
 
     ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
     ir_targets_fx = ir == "fx"
-    ir_targets_dynamo_compile = ir == "dynamo_compile"
-    ir_targets_fx_ts_compat = ir == "fx_ts_compat"
+    ir_targets_dynamo = ir == "dynamo"
+    ir_targets_torch_compile = ir == "torch_compile"
 
     if module_is_tsable and ir_targets_torchscript:
         return _IRType.ts
     elif module_is_fxable and ir_targets_fx:
         return _IRType.fx
-    elif module_is_fxable and ir_targets_fx_ts_compat:
-        return _IRType.fx_ts_compat
-    elif module_is_fxable and ir_targets_dynamo_compile:
-        return _IRType.dynamo_compile
+    elif module_is_fxable and ir_targets_dynamo:
+        return _IRType.dynamo
+    elif module_is_fxable and ir_targets_torch_compile:
+        return _IRType.torch_compile
     else:
         if ir == "default":
             # Options are listed in order of preference
-            if module_is_tsable:
+            if module_is_fxable:
                 logging.log(
-                    logging.Level.Info, "ir was set to default, using TorchScript as ir"
+                    logging.Level.Info, "ir was set to default, using dynamo as ir"
                 )
-                return _IRType.ts
-            elif module_is_fxable:
+                return _IRType.dynamo
+            elif module_is_tsable:
                 raise ValueError(
-                    "Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT"
+                    "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to compile."
                 )
-                # logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx")
-                # return _IRType.fx
             else:
                 raise ValueError("Module was provided with in an unsupported format")
         else:
@@ -156,12 +154,12 @@ def compile(
             dynamic_batch=False,
             **kwargs,
         )
-    elif target_ir == _IRType.dynamo_compile:
+    elif target_ir == _IRType.dynamo:
         return torch_tensorrt.dynamo.compile(
             module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
         )
-    elif target_ir == _IRType.fx_ts_compat:
-        return torch_tensorrt.dynamo.fx_ts_compat.compile(
+    elif target_ir == _IRType.torch_compile:
+        return torch_tensorrt.dynamo.backend.compile(
             module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
         )
     else:
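
For context, a minimal sketch of how the renamed ir strings select a frontend through the top-level API after this change (the module and input here are placeholders, not from this commit):

```python
import torch
import torch_tensorrt

model = torch.nn.Linear(16, 4).eval().cuda()  # placeholder module
inputs = [torch.randn(2, 16, device="cuda")]

# ir="dynamo" now routes to torch_tensorrt.dynamo.compile and is the
# default chosen by _get_target_ir for fx-able modules.
trt_mod = torch_tensorrt.compile(
    model, ir="dynamo", inputs=inputs, enabled_precisions={torch.float}
)

# ir="torch_compile" routes to torch_tensorrt.dynamo.backend.compile.
trt_mod = torch_tensorrt.compile(
    model, ir="torch_compile", inputs=inputs, enabled_precisions={torch.float}
)
```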

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 2 additions & 6 deletions
@@ -1,6 +1,2 @@
-import torch
-from packaging import version
-
-if version.parse(torch.__version__) >= version.parse("2.1.dev"):
-    from torch_tensorrt.dynamo import fx_ts_compat
-    from .backend import compile
+from ._settings import *
+from .compile import compile
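
The version gate on torch 2.1 is gone, so the re-export is unconditional; a quick smoke check of the new package surface (illustrative):

```python
# After this refactor, both names resolve from the package root:
from torch_tensorrt.dynamo import compile              # now defined in dynamo/compile.py
from torch_tensorrt.dynamo import CompilationSettings  # pulled in via _settings
```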

py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.fx.utils import LowerPrecision
 
-from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
+from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
 from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
 from torch_tensorrt.dynamo._defaults import (
     PRECISION,

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 from functools import partial
 import torch._dynamo as td
 
-from torch_tensorrt.dynamo.common import CompilationSettings
+from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.backend.lowering._decompositions import (
     get_decompositions,
 )
@@ -15,8 +15,8 @@
     partition,
     get_submod_inputs,
 )
-from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
-from torch_tensorrt.dynamo.backend.conversion import convert_module
+from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
+from torch_tensorrt.dynamo.conversion import convert_module
 
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
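Since backends.py wires torch_tensorrt_backend into TorchDynamo through aot_module_simplified, the same path can also be reached directly via torch.compile. A minimal sketch, assuming the backend is registered under the name "torch_tensorrt" (the registration string is an assumption; it is not shown in this diff):

```python
import torch
import torch_tensorrt  # noqa: F401  (importing registers the dynamo backend)

model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU()).eval().cuda()

# Keyword options are forwarded to the backend, where parse_dynamo_kwargs
# (imported above) turns them into a CompilationSettings object.
optimized = torch.compile(
    model,
    backend="torch_tensorrt",  # assumed registration name
    options={"enabled_precisions": {torch.float}},
)
out = optimized(torch.randn(8, 32, device="cuda"))
```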

py/torch_tensorrt/dynamo/backend/lowering/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@
 )
 from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
 from .substitutions import *
+from ._fusers import *
py/torch_tensorrt/dynamo/backend/lowering/_fusers.py

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+import torch
+from torch_tensorrt.fx.tracer.acc_tracer import acc_ops
+
+
+def check_permute(node: torch.fx.Node):
+    ranks = len(node.meta["tensor_meta"].shape)
+    permutation = list(i % ranks for i in node.kwargs["permutation"])  # type: ignore[union-attr]
+    allowed_permutation = list(i for i in range(ranks))
+    allowed_permutation[-1] = ranks - 2
+    allowed_permutation[-2] = ranks - 1
+    return permutation == allowed_permutation
+
+
+def fuse_permute_matmul(gm: torch.fx.GraphModule):
+    """
+    Fuse pattern like permute + matmul if permute is transposing the last two dimension.
+    """
+    for node in gm.graph.nodes:
+        if node.target == acc_ops.matmul:
+            lhs, rhs = node.kwargs["input"], node.kwargs["other"]
+            lhs_transposed = rhs_tranposed = False
+            skip = False
+
+            if lhs.target == acc_ops.permute and check_permute(lhs):
+                lhs_transposed = True
+                lhs = lhs.kwargs["input"]
+
+            if rhs.target == acc_ops.permute and check_permute(rhs):
+                rhs_tranposed = True
+                rhs = rhs.kwargs["input"]
+
+            if (not skip) and (lhs_transposed or rhs_tranposed):
+                with gm.graph.inserting_before(node):
+                    fused_node = gm.graph.call_function(
+                        trt_transposed_matmul,
+                        args=(lhs, rhs, lhs_transposed, rhs_tranposed),
+                    )
+                node.replace_all_uses_with(fused_node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+
+def trt_transposed_linear(
+    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+):
+    return torch.matmul(input.transpose(-1, -2), weight.t()) + bias
+
+
+def fuse_permute_linear(gm: torch.fx.GraphModule):
+    """
+    Fuse pattern like permute + linear if permute is transposing the last two dimension.
+    """
+    for node in gm.graph.nodes:
+        if node.target == acc_ops.linear:
+            inp = node.kwargs["input"]
+            if inp.target == acc_ops.permute and check_permute(inp):
+                inp = inp.kwargs["input"]
+                weight = node.kwargs["weight"]
+                bias = node.kwargs["bias"]
+                with gm.graph.inserting_before(node):
+                    fused_node = gm.graph.call_function(
+                        trt_transposed_linear, args=(inp, weight, bias)
+                    )
+                node.replace_all_uses_with(fused_node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
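
Note that fuse_permute_matmul calls trt_transposed_matmul, which is not defined in this file, so it presumably has to be imported from the FX lowering passes for that fuser to fire. A minimal sketch exercising fuse_permute_linear, whose target trt_transposed_linear is defined above; it assumes acc_tracer.trace populates node.meta["tensor_meta"], which check_permute reads:

```python
import torch
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer


class PermuteLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the permute swaps the last two dims, so the pattern is fusable
        return self.linear(x.permute(0, 2, 1))


x = torch.randn(2, 3, 4)
gm = acc_tracer.trace(PermuteLinear(), [x])  # acc-normalized graph with shape metadata
gm = fuse_permute_linear(gm)  # permute + linear -> one trt_transposed_linear call
print(gm.graph)  # the standalone permute is gone after dead-code elimination
```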

py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py

Lines changed: 2 additions & 3 deletions
@@ -2,9 +2,8 @@
 from torch.testing._internal.common_utils import run_tests, TestCase
 import torch
 from copy import deepcopy
-from torch_tensorrt.dynamo import compile
-from utils import lower_graph_testing
-from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT
+from torch_tensorrt.dynamo.backend import compile
+from utils import lower_graph_testing, DECIMALS_OF_AGREEMENT
 
 
 class TestTRTModuleNextCompilation(TestCase):

py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from torch_tensorrt.dynamo.backend.utils import prepare_device, prepare_inputs
+from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
 from utils import same_output_format
 import torch_tensorrt
 import unittest

py/torch_tensorrt/dynamo/backend/test/test_decompositions.py

Lines changed: 2 additions & 3 deletions
@@ -1,9 +1,8 @@
 from functools import partial
-from utils import lower_graph_testing
+from utils import lower_graph_testing, DECIMALS_OF_AGREEMENT
 from torch.testing._internal.common_utils import run_tests, TestCase
 import torch
-from torch_tensorrt.dynamo import compile
-from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT
+from torch_tensorrt.dynamo.backend import compile
 
 
 class TestLowering(TestCase):

py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 from utils import lower_graph_testing
 from torch.testing._internal.common_utils import run_tests, TestCase
-from torch_tensorrt.dynamo import compile
+from torch_tensorrt.dynamo.backend import compile
 
 
 class TestMaxPool1D(TestCase):

py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from utils import lower_graph_testing
 from torch.testing._internal.common_utils import run_tests, TestCase
 import torch
-from torch_tensorrt.dynamo import compile
+from torch_tensorrt.dynamo.backend import compile
 
 
 class TestFakeTensors(TestCase):

py/torch_tensorrt/dynamo/backend/test/utils.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@
 
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
+DECIMALS_OF_AGREEMENT = 4
+
 
 @fake_tensor_unsupported
 def fx_dynamo_testing_backend(
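
DECIMALS_OF_AGREEMENT now lives with the test helpers instead of the torch_tensorrt.dynamo.common.test_utils module deleted below. A hypothetical sketch of the kind of check it gates (the helper is illustrative, not from this commit):

```python
import torch

DECIMALS_OF_AGREEMENT = 4

def assert_outputs_agree(torch_out: torch.Tensor, trt_out: torch.Tensor) -> None:
    # Illustrative: require element-wise agreement to 4 decimal places,
    # mirroring assertAlmostEqual(..., places=DECIMALS_OF_AGREEMENT).
    max_diff = float((torch_out - trt_out).abs().max())
    assert round(max_diff, DECIMALS_OF_AGREEMENT) == 0, f"max diff: {max_diff}"
```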

py/torch_tensorrt/dynamo/common/__init__.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

py/torch_tensorrt/dynamo/common/test_utils.py

Lines changed: 0 additions & 16 deletions
This file was deleted.
