Skip to content

Commit 2e21ce6

Browse files
frank-weiWei Wei
and
Wei Wei
authored
[FX] Changes done internally at Facebook (#1603)
Co-authored-by: Wei Wei <[email protected]>
1 parent df65620 commit 2e21ce6

31 files changed

+1277
-207
lines changed

.circleci/config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ commands:
263263
parameters:
264264
torch-build:
265265
type: string
266-
default: "2.0.0.dev20230103+cu117"
266+
default: "2.0.0.dev20230120+cu117"
267267
torch-build-index:
268268
type: string
269269
default: "https://download.pytorch.org/whl/nightly/cu117"
@@ -992,7 +992,7 @@ parameters:
992992
# Nightly platform config
993993
torch-build:
994994
type: string
995-
default: "2.0.0.dev20230103+cu117"
995+
default: "2.0.0.dev20230120+cu117"
996996
torch-build-index:
997997
type: string
998998
default: "https://download.pytorch.org/whl/nightly/cu117"

py/torch_tensorrt/fx/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ FX2TRT is merged as FX module in Torch-TensorRT
55

66
* Method 1. Follow the instructions for Torch-TensorRT
77
* Method 2. To install FX path only (Python path) and avoid the C++ build for torchscript path
8-
```
8+
`
99
$ conda create --name python_env python=3.8
1010
$ conda activate python_env
1111
# Recommended: install PyTorch 1.12 or later
@@ -18,4 +18,4 @@ FX2TRT is merged as FX module in Torch-TensorRT
1818
$ python -c "import torch_tensorrt.fx"
1919
# Test an example by
2020
$ python py/torch_tensorrt/fx/example/lower_example.py
21-
```
21+
`

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2802,7 +2802,7 @@ def acc_ops_linear(
28022802

28032803
if isinstance(kwargs["weight"], torch.Tensor):
28042804
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
2805-
if target is not acc_ops.linear:
2805+
if target not in (acc_ops.linear, torch.ops.aten.linear):
28062806
weight_op = trt.MatrixOperation.TRANSPOSE
28072807
else:
28082808
weight_op = trt.MatrixOperation.NONE

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -187,27 +187,20 @@ def aten_ops_fmod(
187187
return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
188188

189189

190-
@tensorrt_converter(torch.ops.aten.mm.default)
191-
@tensorrt_converter(torch.ops.aten.addmm.default)
190+
@tensorrt_converter(torch.ops.aten.linear)
192191
def aten_ops_linear(
193192
network: TRTNetwork,
194193
target: Target,
195194
args: Tuple[Argument, ...],
196195
kwargs: Dict[str, Argument],
197196
name: str,
198197
) -> Union[TRTTensor, Sequence[TRTTensor]]:
199-
if target == torch.ops.aten.addmm.default:
200-
kwargs_new = {
201-
"bias": args[0],
202-
"input": args[1],
203-
"weight": args[2],
204-
}
205-
elif target == torch.ops.aten.mm.default:
206-
kwargs_new = {
207-
"bias": None,
208-
"input": args[0],
209-
"weight": args[1],
210-
}
198+
kwargs_new = {
199+
"input": args[0],
200+
"weight": args[1],
201+
"bias": args[2],
202+
}
203+
211204
return acc_ops_converters.acc_ops_linear(network, target, None, kwargs_new, name)
212205

213206

@@ -320,3 +313,35 @@ def aten_ops_reshape(
320313
"acc_out_ty": acc_utils.build_raw_tensor_meta(shape=args[1]),
321314
}
322315
return acc_ops_converters.acc_ops_reshape(network, target, None, kwargs_new, name)
316+
317+
318+
@tensorrt_converter(torch.ops.aten.cat.default)
319+
def aten_ops_cat(
320+
network: TRTNetwork,
321+
target: Target,
322+
args: Tuple[Argument, ...],
323+
kwargs: Dict[str, Argument],
324+
name: str,
325+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
326+
kwargs_new = {
327+
"tensors": args[0],
328+
"dim": args[1],
329+
}
330+
return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)
331+
332+
333+
@tensorrt_converter(torch.ops.aten.expand.default)
334+
def aten_ops_expand(
335+
network: TRTNetwork,
336+
target: Target,
337+
args: Tuple[Argument, ...],
338+
kwargs: Dict[str, Argument],
339+
name: str,
340+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
341+
kwargs_new = {
342+
"input": args[0],
343+
"sizes": args[1],
344+
}
345+
return acc_ops_converters.acc_ops_expand_tensor(
346+
network, target, None, kwargs_new, name
347+
)

py/torch_tensorrt/fx/fx2trt.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,12 @@ def call_method(self, target, args, kwargs):
341341

342342
def output(self, target, args, kwargs):
343343
assert len(args) == 1
344-
outputs = args[0] if isinstance(args[0], tuple) else (args[0],)
344+
if isinstance(args[0], tuple):
345+
outputs = args[0]
346+
elif isinstance(args[0], list):
347+
outputs = tuple(args[0])
348+
else:
349+
outputs = (args[0],)
345350

346351
if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
347352
raise RuntimeError("TensorRT requires all outputs to be Tensor!")

py/torch_tensorrt/fx/lower.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import dataclasses as dc
22
import logging
3-
import dataclasses as dc
4-
import logging
53
from typing import Any, Callable, Optional, Sequence
64

75
# @manual=//deeplearning/trt/python:py_tensorrt
@@ -180,8 +178,9 @@ def lower_pass(
180178
interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
181179
if lower_setting.use_experimental_rt:
182180
import io
183-
from torch_tensorrt._TRTModuleNext import TRTModuleNext
181+
184182
from torch_tensorrt._Device import Device
183+
from torch_tensorrt._TRTModuleNext import TRTModuleNext
185184

186185
with io.BytesIO() as engine_bytes:
187186
engine_bytes.write(interp_res.engine.serialize())

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ class LowerSetting(LowerSettingBasic):
6666
cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
6767
preset_lowerer (str): when specified, use a preset logic to build the
6868
instance of Lowerer.
69-
opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is
70-
only used by explicit batch dim with dynamic shape mode.
69+
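opt_profile_replica (int): the number of opt profile sets for the TensorRT engine; this field is only used by explicit batch dim with dynamic shape mode. In general, we use a 2-GPU setting with
70+
2 streams on each. Set the total number to 8 as a safe default value.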
7171
dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension.
7272
tactic_sources: tactic sources for TensorRT kernel selection. Default to None,
7373
meaning all possible tactic sources.
@@ -81,17 +81,21 @@ class LowerSetting(LowerSettingBasic):
8181
explicit_precision: bool = False
8282
max_workspace_size: int = 1 << 30
8383
strict_type_constraints: bool = False
84-
customized_fuse_pass: PassManager = PassManager.build_from_passlist([])
85-
lower_basic_fuse_pass: PassManager = PassManager.build_from_passlist(
86-
[fuse_permute_matmul, fuse_permute_linear]
84+
customized_fuse_pass: PassManager = dc.field(
85+
default_factory=lambda: PassManager.build_from_passlist([])
86+
)
87+
lower_basic_fuse_pass: PassManager = dc.field(
88+
default_factory=lambda: PassManager.build_from_passlist(
89+
[fuse_permute_matmul, fuse_permute_linear]
90+
)
8791
)
8892
verbose_log: bool = False
8993
algo_selector = None
9094
timing_cache_prefix: str = ""
9195
save_timing_cache: bool = False
9296
cuda_graph_batch_size: int = -1
9397
preset_lowerer: str = ""
94-
opt_profile_replica: int = 1
98+
opt_profile_replica: int = 8
9599
dynamic_batch: bool = True
96100
tactic_sources: Optional[int] = None
97101
correctness_atol: float = 0.1

py/torch_tensorrt/fx/passes/lower_basic_pass.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,3 +608,25 @@ def _get_shape(node: fx.Node) -> Optional[torch.Size]:
608608
# shape info not available
609609
return None
610610
return node.meta["tensor_meta"].shape
611+
612+
613+
@log_before_after
614+
@validate_inference(atol=1e-3, rtol=1e-2)
615+
def fix_clamp_numerical_limits_to_fp16(
616+
mod: torch.fx.GraphModule, input: Input
617+
) -> torch.fx.GraphModule:
618+
MIN_FP16 = -65504.0
619+
MAX_FP16 = 65504.0
620+
for node in mod.graph.nodes:
621+
if node.op == "call_function" and "clamp" in str(node.target):
622+
input_kwargs = node.kwargs
623+
if input_kwargs["min"] < MIN_FP16 and input_kwargs["max"] > MAX_FP16:
624+
new_kwargs = {
625+
"input": input_kwargs["input"],
626+
"min": MIN_FP16,
627+
"max": MAX_FP16,
628+
}
629+
node.kwargs = new_kwargs
630+
631+
mod.recompile()
632+
return mod

0 commit comments

Comments
 (0)