[FX] Changes done internally at Facebook #1456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 7 commits · Nov 28, 2022
6 changes: 3 additions & 3 deletions docs/_sources/tutorials/ptq.rst.txt
@@ -136,7 +136,7 @@ Then all that's required to set up the module for INT8 calibration is to set the f
If you have an existing Calibrator implementation for TensorRT, you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.
From here, not much changes in terms of how execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the Torch-TensorRT demo that takes you from training a VGG16 network on
-CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/cpp/ptq
+CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/examples/int8/ptq

.. _writing_ptq_python:

@@ -194,8 +194,8 @@ to use ``CacheCalibrator`` to use in INT8 mode.
calibrator=calibrator)

If you already have an existing calibrator class (implemented directly using the TensorRT API), you can directly set the calibrator field to your class, which can be very convenient.
-For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
-and https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_trt_calibrator.py
+For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_dataloader_calibrator.py
+and https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_trt_calibrator.py
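For orientation, here is a minimal sketch of the Python PTQ flow those tests exercise, assuming a trained model and a CIFAR10-style dataloader; the cache path, input shape, and calibration algorithm below are illustrative assumptions, not taken from this PR:

```python
import torch
import torch_tensorrt

def build_int8_module(model, testing_dataloader):
    # Calibrator that feeds batches from an ordinary DataLoader to TensorRT.
    calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
        testing_dataloader,
        cache_file="./calibration.cache",  # illustrative path
        use_cache=False,
        algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
        device=torch.device("cuda:0"),
    )
    # Inputs stay FP32 at the interface; quantization happens inside the engine.
    return torch_tensorrt.compile(
        model,
        inputs=[torch_tensorrt.Input((1, 3, 32, 32))],  # CIFAR10-shaped batch
        enabled_precisions={torch.int8},
        calibrator=calibrator,
    )
```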

Citations
^^^^^^^^^^^
12 changes: 6 additions & 6 deletions examples/fx/hugging_face_torchdynamo_example.py
@@ -15,12 +15,12 @@
)
from transformers import BertConfig, ReformerConfig, XLNetModel, XLNetConfig

-import torchdynamo
-from torchdynamo.optimizations import backends
-from torchdynamo.optimizations.training import aot_autograd_debug_strategy1
-from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
-from torchdynamo.testing import collect_results
-from torchdynamo.testing import same
+import torch._dynamo as torchdynamo
+from torch._dynamo.optimizations import backends
+from torch._dynamo.optimizations.training import aot_autograd_debug_strategy1
+from torch._dynamo.optimizations.training import aot_autograd_speedup_strategy
+from torch._dynamo.testing import collect_results
+from torch._dynamo.testing import same

torch.backends.cuda.matmul.allow_tf32 = True

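The import churn above reflects torchdynamo moving into PyTorch core as ``torch._dynamo``; aliasing the import back to ``torchdynamo`` means existing call sites need no other edits. A minimal sketch of the aliased usage, assuming a PyTorch build that ships ``torch._dynamo`` (late-2022 nightly or newer); the toy function and the built-in ``"eager"`` backend are illustrative, not taken from this PR:

```python
import torch
import torch._dynamo as torchdynamo  # old: `import torchdynamo`

def toy_fn(x):
    return torch.relu(x) + 1.0

# optimize() wraps a callable so dynamo captures and compiles its graphs;
# "eager" is a pass-through backend, handy for checking the capture itself.
compiled_fn = torchdynamo.optimize("eager")(toy_fn)
print(compiled_fn(torch.randn(4)))
```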
4 changes: 2 additions & 2 deletions examples/fx/torchdynamo_example.py
@@ -3,11 +3,11 @@
from dataclasses import dataclass, field, replace

import torch
-import torchdynamo
+import torch._dynamo as torchdynamo
import torchvision
from torch_tensorrt.fx.lower import compile
from torch_tensorrt.fx.utils import LowerPrecision
-from torchdynamo.optimizations import backends
+from torch._dynamo.optimizations import backends

"""
The purpose of this example is to demonstrate the lowering flow to TRT and Torchdynamo
4 changes: 4 additions & 0 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -25,6 +25,7 @@
    trt_transposed_linear,
    trt_transposed_matmul,
)
+from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -3371,6 +3372,9 @@ def acc_ops_gelu(
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_val = kwargs["input"]
+    approximate = kwargs["approximate"]
+    if approximate != "none":
+        raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"GELU received input {input_val} that is not part "
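For context, ``torch.nn.GELU`` takes an ``approximate`` argument ("none" or "tanh"), and the guard above makes the converter reject everything but the exact erf-based form. A small illustration in plain PyTorch (not part of this PR) of how the two variants differ:

```python
import torch

x = torch.randn(4)
exact = torch.nn.GELU()                   # approximate="none", erf-based GELU
fast = torch.nn.GELU(approximate="tanh")  # tanh approximation ("fast gelu")

# The variants agree closely but not bit-exactly; only the exact form is
# supported by the FX GELU converter after this change.
print(torch.max(torch.abs(exact(x) - fast(x))))
```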
27 changes: 24 additions & 3 deletions py/torch_tensorrt/fx/lower.py
@@ -31,6 +31,7 @@
def compile(
    module: nn.Module,
    input,
+    min_acc_module_size: int = 10,
    max_batch_size: int = 2048,
    max_workspace_size=1 << 25,
    explicit_batch_dimension=False,
@@ -51,6 +52,7 @@ def compile(
        module: Original module for lowering.
        input: Input for module.
        max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)
+        min_acc_module_size: Minimal number of nodes for an accelerated submodule
        max_workspace_size: Maximum size of workspace given to TensorRT.
        explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
        lower_precision: lower_precision config given to TRTModule.
@@ -70,6 +72,7 @@

    lower_setting = LowerSetting(
        max_batch_size=max_batch_size,
+        min_acc_module_size=min_acc_module_size,
        max_workspace_size=max_workspace_size,
        explicit_batch_dimension=explicit_batch_dimension,
        lower_precision=lower_precision,
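A hedged usage sketch of the new knob: raising ``min_acc_module_size`` keeps small accelerated islands in eager PyTorch instead of spawning many tiny TensorRT engines. The model, input shape, and threshold below are illustrative assumptions, not taken from this PR:

```python
import torch
import torchvision.models as models
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision

model = models.resnet18().cuda().eval()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Only subgraphs with at least 20 supported nodes become TRT engines;
# smaller islands stay in PyTorch, avoiding per-engine launch overhead.
trt_model = fx_compile(
    model,
    inputs,
    min_acc_module_size=20,
    lower_precision=LowerPrecision.FP16,
)
```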
@@ -268,6 +271,7 @@ def __call__(
        module: nn.Module,
        inputs: Input,
        additional_inputs: Optional[Input] = None,
+        fp16_conversion_fn: Optional[Callable[[Input], Input]] = None,
    ) -> nn.Module:
        lower_setting = self.lower_pass_manager_builder.lower_setting
        atol = lower_setting.correctness_atol
@@ -284,9 +288,26 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
                == LowerPrecision.FP16
            ):
                module.half()
-                inputs = tuple(
-                    x.half() if x is not None and x.dtype == torch.float32 else x
-                    for x in inputs
+                # A custom conversion function can be passed to the lowerer to
+                # handle inputs with custom types. By default, just handle
+                # tensors and NoneType.
+                if fp16_conversion_fn is None:
+                    conversion_fn = (
+                        lambda x: x.half()
+                        if x is not None and x.dtype == torch.float32
+                        else x
+                    )
+                else:
+                    conversion_fn = fp16_conversion_fn
+
+                inputs = tuple(conversion_fn(x) for x in inputs)
+            if lower_setting.is_aten:
+                pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
+                    inputs, additional_inputs
+                )
+            else:
+                pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
+                    inputs, additional_inputs
+                )
-            if lower_setting.is_aten:
-                pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
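The new hook makes FP16 lowering work for inputs that are not plain tensors. A sketch of a custom conversion function for dict-of-tensor inputs; everything except the one-argument signature is an assumption for illustration:

```python
import torch

def dict_aware_half(x):
    # Called once per input element by do_lower when lowering to FP16.
    if isinstance(x, dict):
        return {
            k: v.half() if v.dtype == torch.float32 else v for k, v in x.items()
        }
    if isinstance(x, torch.Tensor) and x.dtype == torch.float32:
        return x.half()
    return x  # None and non-FP32 values pass through unchanged

# Hypothetical invocation (Lowerer construction elided):
# trt_mod = lowerer(module, inputs, fp16_conversion_fn=dict_aware_half)
```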
49 changes: 41 additions & 8 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py
@@ -1,32 +1,51 @@
import torch
import torch.nn as nn
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


class TestCatConverter(AccTestCase):
-    def test_cat(self):
+    @parameterized.expand(
+        [
+            param("cat", torch.cat),
+            param("concat", torch.concat),
+        ]
+    )
+    def test_cat(self, _, op):
        class Cat(nn.Module):
            def forward(self, x, y, z):
-                return torch.cat((x, y, z), 1)
+                return op((x, y, z), 1)

        inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
        self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})

-    def test_cat_neg(self):
+    @parameterized.expand(
+        [
+            param("cat", torch.cat),
+            param("concat", torch.concat),
+        ]
+    )
+    def test_cat_neg(self, _, op):
        class Cat(nn.Module):
            def forward(self, x, y, z):
-                return torch.cat((x, y, z), -1)
+                return op((x, y, z), -1)

        inputs = [torch.randn(1, 2, 3), torch.randn(1, 2, 3), torch.randn(1, 2, 2)]
        self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})

-    def test_cat_with_dynamic_shape(self):
+    @parameterized.expand(
+        [
+            param("cat", torch.cat),
+            param("concat", torch.concat),
+        ]
+    )
+    def test_cat_with_dynamic_shape(self, _, op):
        class Cat(nn.Module):
            def forward(self, x, y):
                x = x + y
-                return torch.cat((x, y), 0)
+                return op((x, y), 0)

        input_specs = [
            InputTensorSpec(
@@ -42,11 +61,17 @@ def forward(self, x, y):
        ]
        self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat})

-    def test_cat_with_dynamic_shape_four_dimensions(self):
+    @parameterized.expand(
+        [
+            param("cat", torch.cat),
+            param("concat", torch.concat),
+        ]
+    )
+    def test_cat_with_dynamic_shape_four_dimensions(self, _, op):
        class Cat(nn.Module):
            def forward(self, x, y):
                x = x + y
-                return torch.cat((x, y), 0)
+                return op((x, y), 0)

        input_specs = [
            InputTensorSpec(
@@ -63,6 +88,14 @@ def forward(self, x, y):

        self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat})

+    def test_concat(self):
+        class Cat(nn.Module):
+            def forward(self, x, y, z):
+                return torch.concat((x, y, z), 1)
+
+        inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
+        self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})


if __name__ == "__main__":
    run_tests()
33 changes: 33 additions & 0 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py
@@ -57,6 +57,39 @@ def forward(self, x):
            TestModule(), input_specs, expected_ops={acc_ops.gelu}
        )

+    def test_gelu_module(self):
+        class TestModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.gelu = torch.nn.GELU()
+
+            def forward(self, x):
+                return self.gelu(x)
+
+        inputs = [torch.randn(3, 10, 20)]
+        self.run_test(
+            TestModule(),
+            inputs,
+            expected_ops={acc_ops.gelu},
+            test_implicit_batch_dim=False,
+        )
+
+    def test_gelu_module_throw(self):
+        class TestModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.gelu = torch.nn.GELU(approximate="tanh")
+
+            def forward(self, x):
+                return self.gelu(x)
+
+        inputs = [torch.randn(3, 10, 20)]
+        self.run_test_with_assert_error(
+            TestModule(),
+            inputs,
+            expect_error=RuntimeError,
+        )


if __name__ == "__main__":
    run_tests()
30 changes: 0 additions & 30 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py
@@ -6,19 +6,6 @@


class TestNewOnesConverter(AccTestCase):
-    def test_newone(self):
-        class TestModule(nn.Module):
-            def forward(self, x):
-                return x.new_ones((3, 5), dtype=torch.float16)
-
-        inputs = [torch.randn(1, 10)]
-        self.run_test(
-            TestModule(),
-            inputs,
-            expected_ops={acc_ops.new_ones},
-            test_implicit_batch_dim=False,
-        )
-
    def test_newone_no_dtype(self):
        class TestModule(nn.Module):
            def forward(self, x):
@@ -47,23 +34,6 @@ def forward(self, x):


class TestNewOnesConverterWithDynamicShape(AccTestCase):
-    def test_newone(self):
-        class TestModule(nn.Module):
-            def forward(self, x):
-                return x.new_ones((3, 5), dtype=torch.float16)
-
-        input_specs = [
-            InputTensorSpec(
-                shape=(-1, -1, -1, -1),
-                dtype=torch.float32,
-                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
-            ),
-        ]
-
-        self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={acc_ops.new_ones}
-        )
-
    def test_newone_no_dtype(self):
        class TestModule(nn.Module):
            def forward(self, x):
83 changes: 42 additions & 41 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py
@@ -271,47 +271,48 @@ def forward(self, x):
            precision=LowerPrecision.FP16,
        )

-    # tensor.int()
-    def test_int(self):
-        class To(torch.nn.Module):
-            def forward(self, x):
-                x = x.int()
-                # we do not expect int to be output type, so add an extra layer
-                x = x.float()
-                return x
-
-        input = torch.randn(2, 2)
-        inputs = [
-            input,
-        ]
-        self.run_test(
-            To(),
-            inputs,
-            expected_ops={acc_ops.to_dtype},
-            test_implicit_batch_dim=False,
-            precision=LowerPrecision.FP32,
-        )
-
-    # tensor.int()
-    def test_int_with_dynamic_shape_four_dimensions(self):
-        class To(torch.nn.Module):
-            def forward(self, x):
-                x = x.int()
-                # we do not expect int to be output type, so add an extra layer
-                x = x.float()
-                return x
-
-        input_specs = [
-            InputTensorSpec(
-                shape=(-1, -1, -1, -1),
-                dtype=torch.int,
-                shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
-            ),
-        ]
-
-        self.run_test_with_dynamic_shape(
-            To(), input_specs, expected_ops={acc_ops.to_dtype}
-        )
+    # TODO: Re-enable in the future; TRT 8.5 does not work for this test.
+    # The test is a rare case; we may need to remove the pattern from the graph instead.
+    # def test_int(self):
+    #     class To(torch.nn.Module):
+    #         def forward(self, x):
+    #             x = x.int()
+    #             # we do not expect int to be output type, so add an extra layer
+    #             x = x.float()
+    #             return x
+
+    #     input = torch.randn(2, 2)
+    #     inputs = [
+    #         input,
+    #     ]
+    #     self.run_test(
+    #         To(),
+    #         inputs,
+    #         expected_ops={acc_ops.to_dtype},
+    #         test_implicit_batch_dim=False,
+    #         precision=LowerPrecision.FP32,
+    #     )
+
+    # # tensor.int()
+    # def test_int_with_dynamic_shape_four_dimensions(self):
+    #     class To(torch.nn.Module):
+    #         def forward(self, x):
+    #             x = x.int()
+    #             # we do not expect int to be output type, so add an extra layer
+    #             x = x.float()
+    #             return x
+
+    #     input_specs = [
+    #         InputTensorSpec(
+    #             shape=(-1, -1, -1, -1),
+    #             dtype=torch.int,
+    #             shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))],
+    #         ),
+    #     ]
+
+    #     self.run_test_with_dynamic_shape(
+    #         To(), input_specs, expected_ops={acc_ops.to_dtype}
+    #     )


if __name__ == "__main__":
    run_tests()