
Commit 09b099a

Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into export_prototype
2 parents: 980dc1c + dfc4899

File tree

4 files changed: +109, -29 lines

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 13 additions & 16 deletions
@@ -3,27 +3,30 @@
 import math
 import operator
 import warnings
-from typing import Dict, Optional, Sequence, Tuple, Union, cast
+from typing import cast, Dict, Optional, Sequence, Tuple, Union

 import numpy as np

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
+
+from ..converter_registry import tensorrt_converter
+
+from ..tracer.acc_tracer import acc_ops
+from ..types import *  # noqa: F403
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target
-from torch_tensorrt.fx.converters.impl import activation, convolution
+
+from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks
+
+from .converter_utils import *  # noqa: F403
 from torch_tensorrt.fx.passes.lower_basic_pass import (
     trt_transposed_linear,
     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
-
-from ..converter_registry import tensorrt_converter
-from ..tracer.acc_tracer import acc_ops
-from ..types import *  # noqa: F403
-from ..utils import Frameworks, get_dynamic_dims, unified_dtype_converter
-from .converter_utils import *  # noqa: F403
+from torch_tensorrt.fx.converters.impl import activation, convolution

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -2711,14 +2714,8 @@ def acc_ops_linear(
             "dim for linear and it can't be the last dim."
         )

-    if isinstance(kwargs["weight"], (torch.Tensor, np.ndarray)):
-        weight = get_trt_tensor(
-            network,
-            kwargs["weight"].t()
-            if isinstance(kwargs["weight"], torch.Tensor)
-            else kwargs["weight"].T,
-            f"{name}_weight",
-        )
+    if isinstance(kwargs["weight"], torch.Tensor):
+        weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
     if target not in (acc_ops.linear, torch.ops.aten.linear):
         weight_op = trt.MatrixOperation.TRANSPOSE
     else:

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 4 additions & 4 deletions
@@ -1,8 +1,8 @@
 import operator
 import warnings
-from enum import Enum, auto
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

+from enum import Enum, auto
 import numpy as np

 # @manual=//deeplearning/trt/python:py_tensorrt
@@ -20,7 +20,7 @@
     TRTPluginFieldCollection,
     TRTTensor,
 )
-from ..utils import Frameworks, unified_dtype_converter
+from ..utils import unified_dtype_converter, Frameworks


 class SourceIR(Enum):
@@ -271,7 +271,7 @@ def create_constant(
     """
     constant = network.add_constant(
         (1,) if isinstance(value, (int, float)) else value.shape,
-        to_numpy(value, dtype).copy(),
+        to_numpy(value, dtype),
     )
     constant.name = name
     return constant.get_output(0)
@@ -311,7 +311,7 @@ def get_trt_tensor(
     elif isinstance(input_val, np.ndarray) and (
         input_val.dtype == np.bool_ or input_val.dtype == np.int64
     ):
-        input_val = input_val.astype(np.int32)
+        input_val = input_val.to(np.int32)

     if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
         return create_constant(network, input_val, name, dtype)

py/torch_tensorrt/fx/converters/impl/convolution.py

Lines changed: 12 additions & 8 deletions
@@ -1,23 +1,27 @@
-from typing import Any, Optional, Sequence, Union
-
 import numpy as np
+from typing import Any, Optional, Sequence, Union

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
-from torch_tensorrt.fx.converters import acc_ops_converters
+
 from torch_tensorrt.fx.converters.converter_utils import (
     SourceIR,
     extend_attr_to_tuple,
     get_dyn_range,
-    get_trt_tensor,
-    has_dynamic_shape,
     mark_as_int8_layer,
     set_layer_name,
+    has_dynamic_shape,
     to_numpy,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.converters import acc_ops_converters
+
+from torch_tensorrt.fx.types import (
+    TRTNetwork,
+    TRTTensor,
 )
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


 def convNd(
@@ -50,7 +54,7 @@ def convNd(
     )

     # Process bias terms
-    if isinstance(bias, (torch.Tensor, np.ndarray)):
+    if isinstance(bias, torch.Tensor):
         # Transform the bias constant into a Numpy array
         bias = to_numpy(bias)

@@ -75,7 +79,7 @@ def convNd(
             network, target, tuple(), kwargs, name + "_unsqueeze_weight"
         )

-    elif isinstance(weight, (torch.Tensor, np.ndarray)):
+    elif isinstance(weight, torch.Tensor):
         # Transform the weight constant into a Numpy array
         weight = to_numpy(weight)

tests/py/dynamo/backend/test_specialized_models.py

Lines changed: 80 additions & 1 deletion
@@ -2,7 +2,7 @@
 import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests

-from ..testing_utilities import lower_graph_testing
+from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


 class TestFakeTensors(TestCase):
@@ -157,5 +157,84 @@ def forward(self, x):
         torch._dynamo.reset()


+class TestTensorFreezing(TestCase):
+    def test_tensor_freeze_attr(self):
+        class TensorFreeze(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = torch.ones((8, 2), device="cuda")
+
+            def forward(self, x):
+                return x @ self.const
+
+        inputs = [
+            torch.ones(
+                7,
+                8,
+            ).cuda()
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(TensorFreeze())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"Frozen-Tensor TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_constant_fold(self):
+        class Arange(torch.nn.Module):
+            def forward(self, x):
+                y = torch.arange(10, device="cuda")
+                return x + y
+
+        inputs = [
+            torch.rand(
+                10,
+                10,
+            ).cuda()
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Arange())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"Constant Folded TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()
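For quick reference, here is a minimal standalone sketch of the pattern the new tests exercise: compiling a module whose tensor attribute is frozen into the graph as a constant. It is adapted directly from `test_tensor_freeze_attr` above and assumes a CUDA device plus a Torch-TensorRT build that exposes the `torch_compile` frontend used in the diff.

```python
import torch
import torch_tensorrt


class TensorFreeze(torch.nn.Module):
    """Module whose tensor attribute is treated as a trace-time constant."""

    def __init__(self):
        super().__init__()
        self.const = torch.ones((8, 2), device="cuda")

    def forward(self, x):
        return x @ self.const


inputs = [torch.ones(7, 8).cuda()]
fx_graph = torch.fx.symbolic_trace(TensorFreeze())

# Compile through the torch_compile frontend, mirroring the new tests
optimized_model = torch_tensorrt.compile(
    fx_graph,
    "torch_compile",
    inputs,
    min_block_size=1,
    pass_through_build_failures=True,
)

# Compare the TRT output against the original FX graph
trt_results = optimized_model(*inputs).detach().cpu()
torch_results = fx_graph(*inputs).detach().cpu()
print(float(torch.max(torch.abs(trt_results - torch_results))))
```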
