
[FX] Sync enhancement done internally at Meta #1161

Merged 1 commit on Jul 5, 2022
2 changes: 1 addition & 1 deletion README.md
@@ -213,7 +213,7 @@ bazel build //:libtorchtrt --compilation_mode opt
```

### FX path (Python only) installation
-If the user plan to try FX path (Python only) and would like to avoid bazel build. Please follow the steps below.
+If the user plans to try FX path (Python only) and would like to avoid bazel build. Please follow the steps below.
``` shell
cd py && python3 setup.py install --fx-only
```
8 changes: 6 additions & 2 deletions examples/fx/quantized_resnet_test.py
@@ -6,7 +6,11 @@

import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
import torchvision.models as models
-from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
+from torch.ao.quantization.quantize_fx import (
+    convert_fx,
+    convert_to_reference,
+    prepare_fx,
+)
from torch.fx.experimental.normalize import NormalizeArgs
from torch.fx.passes import shape_prop
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
@@ -48,7 +52,7 @@ def build_int8_trt(rn18):
prepared = prepare_fx(rn18, {"": qconfig})
for _ in range(10):
prepared(data)
-quantized_rn18 = convert_fx(prepared, is_reference=True)
+quantized_rn18 = convert_to_reference(prepared)
ref_res = quantized_rn18(data)
print("quantized model:", quantized_rn18)

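The test above swaps `convert_fx(prepared, is_reference=True)` for the dedicated `convert_to_reference` entry point imported at the top of the file. Below is a minimal sketch of the resulting calibrate-then-convert flow; the fbgemm qconfig and the random calibration batch are assumptions, not part of this diff (later PyTorch releases renamed `convert_to_reference` to `convert_to_reference_fx`).

```python
import torch
import torchvision.models as models
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_to_reference, prepare_fx

# Sketch only: the fbgemm qconfig and random calibration data are assumptions.
rn18 = models.resnet18().eval()
data = torch.randn(1, 3, 224, 224)

qconfig = get_default_qconfig("fbgemm")
prepared = prepare_fx(rn18, {"": qconfig})       # insert observers
for _ in range(10):                              # calibrate on sample data
    prepared(data)
quantized_rn18 = convert_to_reference(prepared)  # reference-quantized module, as in the diff
```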
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -2222,7 +2222,7 @@ def acc_ops_adaptive_avg_poolnd(
extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
assert all(
input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
), "AdaptiveAvgPool2d currently doesn't support dynamic shapes for last two dims."
), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."

output_size = cast(
Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len)
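The reworded assertion above rejects inputs whose last two (for `adaptive_avg_pool2d`) or three (for `adaptive_avg_pool3d`) dimensions are dynamic. As a hedged illustration, an input spec the converter does accept keeps those trailing dimensions concrete while the leading dims stay dynamic; the values below are borrowed from the dynamic-shape tests later in this PR.

```python
import torch
from torch_tensorrt.fx import InputTensorSpec

# Illustration only: batch and channel dims may be dynamic (-1), but the spatial dims
# that adaptive_avg_pool2d reduces must be fixed for the converter's assertion to pass.
input_specs = [
    InputTensorSpec(
        shape=(-1, -1, 256, 256),  # trailing dims pinned to concrete sizes
        dtype=torch.float32,
        shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))],
    ),
]
```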
25 changes: 15 additions & 10 deletions py/torch_tensorrt/fx/converters/converter_utils.py
@@ -415,9 +415,8 @@ def add_binary_elementwise_layer(
This function adds a TensorRT elementwise layer. We allow both operands to be
constant (not a trt tensor) because in implicit batch dimension mode, we could
introduce constant via .size() op. Other scenario should be const folded first.
-If any operand is not a trt tensor, we make it a trt constant layer which has
-the same type as the other trt tensor. Then we broadcast these two inputs to
-have the same number of dimensions.
+If any operand is not a trt tensor, we make it a trt constant layer while preserve
+its dtype. Then we broadcast these two inputs to have the same number of dimensions.

Limitation:
If we are using implicit batch dim mode, the operand that is not a trt
@@ -436,14 +435,16 @@
Returns:
The output of TensorRT Elementwise layer.
"""
-dtype = None
+lhs_dtype = None
+rhs_dtype = None
is_lhs_trt_tensor = False
is_rhs_trt_tensor = False

if isinstance(lhs_val, TRTTensor):
-dtype = torch_dtype_from_trt(lhs_val.dtype)
+lhs_dtype = torch_dtype_from_trt(lhs_val.dtype)
is_lhs_trt_tensor = True
if isinstance(rhs_val, TRTTensor):
-dtype = torch_dtype_from_trt(rhs_val.dtype)
+rhs_dtype = torch_dtype_from_trt(rhs_val.dtype)
is_rhs_trt_tensor = True

if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
Expand All @@ -463,10 +464,14 @@ def add_binary_elementwise_layer(
# this way the shape will become [1], and then will be properly squeezed
# into [], meaning that the result will have shape [], which is what we
# expect.
#
# Note that the dtype here is supposed to be the same as the scalar
# dtype but we don't have a way to detect whether it makes sense for the
# scalar to be float or half. Hence we go with the lhs dtype.
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-rhs_val = torch.tensor([rhs_val], dtype=dtype)
+rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-lhs_val = torch.tensor([lhs_val], dtype=dtype)
+lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)

# When lhs is scalar, and rhs has shape [1,], then currently the assert
# will fail because lhs shape has fewer dimensions than rhs shape. This
@@ -482,8 +487,8 @@
if isinstance(rhs_val, torch.Tensor):
rhs_val = squeeze_left(rhs_val)

-lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", dtype)
-rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", dtype)
+lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
+rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)

# Check the limitation in the doc string.
if network.has_implicit_batch_dimension:
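The change above keeps a separate `lhs_dtype` and `rhs_dtype` instead of one shared `dtype` that the right-hand branch could overwrite, so a Python scalar is now promoted with the dtype of the tensor operand it is paired with. A standalone sketch of that promotion rule in plain PyTorch (not the converter API, with a hypothetical helper name):

```python
import torch

def promote_scalar_operands(lhs_val, rhs_val):
    # Hypothetical helper mirroring the bookkeeping above: each operand keeps its own dtype.
    lhs_dtype = lhs_val.dtype if isinstance(lhs_val, torch.Tensor) else None
    rhs_dtype = rhs_val.dtype if isinstance(rhs_val, torch.Tensor) else None
    if lhs_dtype is not None and isinstance(rhs_val, (float, int)):
        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)  # scalar follows the lhs dtype
    if rhs_dtype is not None and isinstance(lhs_val, (float, int)):
        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)  # scalar follows the rhs dtype
    return lhs_val, rhs_val

lhs, rhs = promote_scalar_operands(torch.ones(2, dtype=torch.float16), 3)
print(rhs.dtype)  # torch.float16, taken from the tensor operand rather than a shared default
```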
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/fx2trt.py
@@ -80,7 +80,7 @@ def __init__(
] = dict()

def validate_input_specs(self):
-for shape, dtpe, _, shape_ranges, has_batch_dim in self.input_specs:
+for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
if not self.network.has_implicit_batch_dimension:
assert (
has_batch_dim
74 changes: 74 additions & 0 deletions py/torch_tensorrt/fx/passes/graph_opts.py
@@ -0,0 +1,74 @@
from collections.abc import Sequence

import torch
import torch.fx


def common_subexpression_elimination(graph_module: torch.fx.GraphModule) -> bool:
"""
Optimize quantization by removing repeated subexpressions.

Args:
graph_module(torch.fx.GraphModule): target module to be optimized

Returns:
Graph changed or not.
"""

def seq_hashable(seq):
if seq is None:
return None

items = []
for old in seq:
if isinstance(old, Sequence) and not isinstance(old, str):
new = seq_hashable(old)
elif isinstance(old, dict):
new = dict_hashable(old)
elif isinstance(old, slice):
new = old.__reduce__()
else:
new = old

items.append(new)

return tuple(items)

def dict_hashable(d):
if d is None:
return None

items = []
for k, old_v in d.items():
if isinstance(old_v, Sequence):
new_v = seq_hashable(old_v)
elif isinstance(old_v, dict):
new_v = dict_hashable(old_v)
elif isinstance(old_v, slice):
new_v = old_v.__reduce__()
else:
new_v = old_v

items.append((k, new_v))
return tuple(sorted(items))

changed = False
env = {}
for n in graph_module.graph.nodes:
# do not CSE away impure ops
if n.op not in {"call_function", "call_method"} or n.is_impure():
continue

# hash target, args, kwargs
hash_val = (n.target, seq_hashable(n.args), dict_hashable(n.kwargs))

# check if a node has a substitute and can be eliminated
if hash_val in env:
n.replace_all_uses_with(env[hash_val])
graph_module.graph.erase_node(n)
changed = True
continue

env[hash_val] = n

return changed
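A quick way to exercise the new pass is to run it over a symbolically traced module that repeats a subexpression. The sketch below is a usage illustration, not part of the PR; the import path mirrors where this file is added, and `recompile()` is the usual torch.fx step after mutating a graph.

```python
import torch
import torch.fx
from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination


class Repeated(torch.nn.Module):
    def forward(self, x):
        # x + 1 is traced into two identical call_function nodes.
        return (x + 1) * (x + 1)


gm = torch.fx.symbolic_trace(Repeated())
changed = common_subexpression_elimination(gm)  # folds the duplicate add onto the first one
gm.recompile()                                  # regenerate forward() from the edited graph

print(changed)   # True
print(gm.graph)  # a single add node feeds both mul operands
assert torch.allclose(gm(torch.ones(2)), (torch.ones(2) + 1) ** 2)
```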
5 changes: 4 additions & 1 deletion py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -1,5 +1,5 @@
from functools import partial, wraps
-from typing import Any, Callable, NamedTuple, Sequence
+from typing import Any, Callable, Sequence

import torch
from torch import nn
@@ -10,6 +10,7 @@
from ..lower_setting import LowerSetting
from ..observer import Observer
from ..passes.remove_duplicate_output_args import remove_duplicate_output_args
from .graph_opts import common_subexpression_elimination

from .lower_basic_pass import run_const_fold

@@ -94,6 +95,8 @@ def graph_optimization_pass(self) -> PassManager:
passes.append(wrapper(p, self._input))
for p in self.lower_setting.lower_basic_fuse_pass.passes:
passes.append(wrapper(p, self._input))

passes.append(inplace_wrapper(common_subexpression_elimination))
passes.append(
inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
)
@@ -92,6 +92,8 @@ def forward(self, x):
TestModule(), input_specs, expected_ops={acc_ops.adaptive_avg_pool3d}
)

# Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."


if __name__ == "__main__":
run_tests()
22 changes: 22 additions & 0 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_any.py
@@ -5,6 +5,8 @@
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase

# from torch_tensorrt.fx.tools.common_fx2trt import InputTensorSpec


class TestAnyConverters(AccTestCase):
@parameterized.expand(
@@ -64,6 +66,26 @@ def forward(self, x):
test_implicit_batch_dim=False,
)

# Testing with shape (-1, -1, -1, -1) results into error: torch.zeros(tuple([*input_t.shape])). Trying to create tensor with negative dimension -1: [-1, -1, -1, -1]
"""
def test_ops_with_dynamic_shape_four_dimensions(self):
class TestModule(nn.Module):
def forward(self, x):
return torch.any(x)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))],
),
]

self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={acc_ops.any}
)
"""


if __name__ == "__main__":
run_tests()
22 changes: 21 additions & 1 deletion py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py
@@ -3,7 +3,7 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


class TestConverter(AccTestCase):
@@ -30,6 +30,26 @@ def forward(self, x):
test_implicit_batch_dim=False,
)

# Testing with shape (-1, -1, -1, -1) results into error: RuntimeError: setStorage: sizes [2, 3], strides [1, 2], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 1
"""
def test_as_strided_with_dynamic_shape_four_dimensions(self):
class Stride(nn.Module):
def forward(self, x):
return torch.as_strided(torch.tensor([5, 5]), (2, 3), (1, 2), 0)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
),
]

self.run_test_with_dynamic_shape(
Stride(), input_specs, expected_ops={acc_ops.as_strided}
)
"""


if __name__ == "__main__":
run_tests()