Skip to content

Commit ac07351

Browse files
author
Wei Wei
committed
Changes done internally at Facebook
d269be2fc7d84738a642d1d53eb44e6886a28d0c Alex Beloi <[email protected]> [fx] add deferred weights (xl_weight) and tracing for xl_embedding_bag 6f233bc9c72d90a908db0548c9d2dbe853895137 Alex Beloi <[email protected]> [fx] fix out of bounds indices/offsets for embedding_bag ops with xl_weight 3ca3b21c6a85ab9a6e9de503d0f13ee713a7b67c Janet Yang <[email protected]> Support div, torch.norm 52955d93d25e857510ed1b765220e8e5b0b0bb08 Janet Yang <[email protected]> Pass to replace sum(elmtwise(X))/numel(X) w/ mean(elmtwise(X)) 89c56ef76a7a329f244a013ac5ccb099cb00c3c0 Janet Yang <[email protected]> Support scalar clamp, fixes for nan_to_num and benchmark afdc533da031a64e162bb08c8629ff38739e24f8 Wei Wei <[email protected]> [fx2trt] disable dispatch trace leaf node test d160a7a5e554d37c142e13f100bf4d8739ced232 Wei Wei <[email protected]> add option to remove passes c22f691e6eae1b06ecd301eb6285b32d5dc9717c Mike Iovine <[email protected]> [fx2trt] Support dict inputs in acc tracer 8c05a3c57b1f5c63108b979ef8c61411525d0b1f Mike Iovine <[email protected]> [fx2trt] Support namedtuple access in acc tracer getattr ff2000594e3f3ff75e0074edf9c38b5609128bbd Janet Yang <[email protected]> Generalize remove split ops more 1580805d827eb40c941e769b0b99e7c6a3ed6f89 Wei Wei <[email protected]> [fx2trt] add reshape unit test d6a975462071a3747d18edcbe87a3b143b3ece88 Archie Sravankumar <[email protected]> Added FX tracing for `log_softmax` 6943ac0e322077b36a03c50c4c9065de6cd32837 Sungmin Cho <[email protected]> Add replace_mutable_op lower pass baab27b81b1275de92fdaf760a158ce951564d33 Donglin Xia <[email protected]> Register avg_pool3d for acc_op in acc_op.py ae4c4e2c3c18d78542140fcc30e1c24f7c647ef3 Wei Wei <[email protected]> [aten2trt] init check-in 87ef03338c9a25c5a610a2eb590345e8935f8d75 Wei Wei <[email protected]> [aten2trt] add binary ops 2bb168517ace7e638cffc7a241b1cbf528790b92 Mike Iovine <[email protected]> [fx2trt] Add acc normalization blocklist 8c912e085cf8722d572698286020ae1ce055023d Zhijing Li (Accelerator Enablement) <[email protected]> Skip unstable test_conv_add_standalone_module b80dca9c9afa3b7d253e7806f48a890b9f83bf04 Jonathan Amazon <[email protected]> [PyTorch][FX][Compiler] Add acc_op tracing support for torch.baddbmm in FX 137a3977ffeb03d0387e8a95ff2f32f3d15b3de8 Wei Wei <[email protected]> [aten2trt] resnet support fef54c237589a70c007c861e2d59c4052e3de054 Kefei Lu <[email protected]> [easy] fx2xxx: fix fuse_parallel_linear which changes getitem slices from tuple to list 4b062ef361cd7797e72c51bb4dc41766aca7b6db Kefei Lu <[email protected]> fx2trt: fix bad reshape pattern x.reshape(y.size(0), ...) 49573920892bb2fe75fe011a8cad9887bdc8bd04 Alex Beloi <[email protected]> [FX] add tracing for torch.detach fe3cc75e775af53f603a83e8b4899b28f3cb6ddc Yinghai Lu <[email protected]> [fx2ait] add support to torch.clip 42c54d69c68dc58ac348647acada88b1e5634b40 Fei Kou <[email protected]> Fix clamping float32 boundary values e013621dedf5960f81b915cef8d2ce19ca349a7a Kefei Lu <[email protected]> trt lower: change preset application logic to in-place instead of immutable update adc9f8ff48c01a0ce70080c930221ac81f048563 Kefei Lu <[email protected]> [easy]: fix another instance of [slice(), ...] to (slice(), ...) a22e9ff2cc55eb8669690eedd6971be93a2a356b Rui Zhu <[email protected]> Support NoneType in acc_tracing by setting its meta shape to be 1 4f54ce9283f02fe416ff3f502ef1a4e4f80c0f37 Mike Iovine <[email protected]> [fx2ait] Avoid extra copies from view ops 0baf42ebf6ce4146df1bee2d2e62fa2b77dbd7fb Mor Tzur <[email protected]> add torch.concat to acc_ops 9cd933707772b0f05b8aca62bcc813929bd52868 Shirong Wu <[email protected]> replace assert_allclose with assert_close e418d0653752022ea4ee186036b79dc8ca0ae87b Valeriu Lacatusu <[email protected]> [PyTorch][FX][Compiler] Add acc_op tracing support for torch.nn.functional.softplus in FX afb2f560b3995ea3a1cd440df3cdd66d92472e46 Wei Wei <[email protected]> [fx2trt] test fix to adopt new interface of dynamo 8ca2307c744f13ef15bad49f5030dddd2b787b9d Huamin Li <[email protected]> rename test_setitem to test_setitem_trt e0b75bbfda8604d4b60599ddba4d4aa7023887a5 Valeriu Lacatusu <[email protected]> [FX] Replace deprecated torch.testing.assert_allclose with torch.testing.assert_close 4a233da979a755fa605e9750c6035ed885597afa Valeriu Lacatusu <[email protected]> [PyTorch][FX][Compiler] Add acc_op tracing support for torch.ops._caffe2.RoIAlign in FX
1 parent e3b9929 commit ac07351

File tree

18 files changed

+409
-162
lines changed

18 files changed

+409
-162
lines changed

docs/_sources/tutorials/ptq.rst.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ Then all thats required to setup the module for INT8 calibration is to set the f
136136
If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.
137137
From here not much changes in terms of how to execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
138138
in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the Torch-TensorRT demo that takes you from training a VGG16 network on
139-
CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/cpp/ptq
139+
CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/examples/int8/ptq
140140

141141
.. _writing_ptq_python:
142142

@@ -194,8 +194,8 @@ to use ``CacheCalibrator`` to use in INT8 mode.
194194
calibrator=calibrator)
195195
196196
If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
197-
For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
198-
and https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_trt_calibrator.py
197+
For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_dataloader_calibrator.py
198+
and https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_trt_calibrator.py
199199

200200
Citations
201201
^^^^^^^^^^^

examples/fx/hugging_face_torchdynamo_example.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
)
1616
from transformers import BertConfig, ReformerConfig, XLNetModel, XLNetConfig
1717

18-
import torchdynamo
19-
from torchdynamo.optimizations import backends
20-
from torchdynamo.optimizations.training import aot_autograd_debug_strategy1
21-
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
22-
from torchdynamo.testing import collect_results
23-
from torchdynamo.testing import same
18+
import torch._dynamo as torchdynamo
19+
from torch._dynamo.optimizations import backends
20+
from torch._dynamo.optimizations.training import aot_autograd_debug_strategy1
21+
from torch._dynamo.optimizations.training import aot_autograd_speedup_strategy
22+
from torch._dynamo.testing import collect_results
23+
from torch._dynamo.testing import same
2424

2525
torch.backends.cuda.matmul.allow_tf32 = True
2626

examples/fx/torchdynamo_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from dataclasses import dataclass, field, replace
44

55
import torch
6-
import torchdynamo
6+
import torch._dynamo as torchdynamo
77
import torchvision
88
from torch_tensorrt.fx.lower import compile
99
from torch_tensorrt.fx.utils import LowerPrecision
10-
from torchdynamo.optimizations import backends
10+
from torch._dynamo.optimizations import backends
1111

1212
"""
1313
The purpose of this example is to demostrate the lowering flow to TRT and Torchdynamo

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
trt_transposed_linear,
2626
trt_transposed_matmul,
2727
)
28+
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
2829

2930
_LOGGER: logging.Logger = logging.getLogger(__name__)
3031

@@ -3371,6 +3372,9 @@ def acc_ops_gelu(
33713372
name: str,
33723373
) -> Union[TRTTensor, Sequence[TRTTensor]]:
33733374
input_val = kwargs["input"]
3375+
approximate = kwargs["approximate"]
3376+
if approximate is not "none":
3377+
raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")
33743378
if not isinstance(input_val, TRTTensor):
33753379
raise RuntimeError(
33763380
f"GELU received input {input_val} that is not part "

py/torch_tensorrt/fx/lower.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
def compile(
3030
module: nn.Module,
3131
input,
32+
min_acc_module_size: int = 10,
3233
max_batch_size: int = 2048,
3334
max_workspace_size=1 << 25,
3435
explicit_batch_dimension=False,
@@ -48,6 +49,7 @@ def compile(
4849
module: Original module for lowering.
4950
input: Input for module.
5051
max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)
52+
min_acc_module_size: Minimal number of nodes for an accelerated submodule
5153
max_workspace_size: Maximum size of workspace given to TensorRT.
5254
explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
5355
lower_precision: lower_precision config given to TRTModule.
@@ -61,6 +63,7 @@ def compile(
6163
"""
6264
lower_setting = LowerSetting(
6365
max_batch_size=max_batch_size,
66+
min_acc_module_size=min_acc_module_size,
6467
max_workspace_size=max_workspace_size,
6568
explicit_batch_dimension=explicit_batch_dimension,
6669
lower_precision=lower_precision,
@@ -237,6 +240,7 @@ def __call__(
237240
module: nn.Module,
238241
inputs: Input,
239242
additional_inputs: Optional[Input] = None,
243+
fp16_conversion_fn: Optional[Callable[[Input], Input]] = None,
240244
) -> nn.Module:
241245
lower_setting = self.lower_pass_manager_builder.lower_setting
242246
atol = lower_setting.correctness_atol
@@ -253,9 +257,26 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
253257
== LowerPrecision.FP16
254258
):
255259
module.half()
256-
inputs = tuple(
257-
x.half() if x is not None and x.dtype == torch.float32 else x
258-
for x in inputs
260+
# A custom conversion function can be passed to the lowerer to
261+
# handle inputs with custom types. By default, just handle
262+
# tensors and NoneType.
263+
if fp16_conversion_fn is None:
264+
conversion_fn = (
265+
lambda x: x.half()
266+
if x is not None and x.dtype == torch.float32
267+
else x
268+
)
269+
else:
270+
conversion_fn = fp16_conversion_fn
271+
272+
inputs = tuple(conversion_fn(x) for x in inputs)
273+
if lower_setting.is_aten:
274+
pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
275+
inputs, additional_inputs
276+
)
277+
else:
278+
pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
279+
inputs, additional_inputs
259280
)
260281
if lower_setting.is_aten:
261282
pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(

py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,51 @@
11
import torch
22
import torch.nn as nn
33
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
4+
from parameterized import param, parameterized
45
from torch.testing._internal.common_utils import run_tests
56
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
67

78

89
class TestCatConverter(AccTestCase):
9-
def test_cat(self):
10+
@parameterized.expand(
11+
[
12+
param("cat", torch.cat),
13+
param("concat", torch.concat),
14+
]
15+
)
16+
def test_cat(self, _, op):
1017
class Cat(nn.Module):
1118
def forward(self, x, y, z):
12-
return torch.cat((x, y, z), 1)
19+
return op((x, y, z), 1)
1320

1421
inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
1522
self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})
1623

17-
def test_cat_neg(self):
24+
@parameterized.expand(
25+
[
26+
param("cat", torch.cat),
27+
param("concat", torch.concat),
28+
]
29+
)
30+
def test_cat_neg(self, _, op):
1831
class Cat(nn.Module):
1932
def forward(self, x, y, z):
20-
return torch.cat((x, y, z), -1)
33+
return op((x, y, z), -1)
2134

2235
inputs = [torch.randn(1, 2, 3), torch.randn(1, 2, 3), torch.randn(1, 2, 2)]
2336
self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})
2437

25-
def test_cat_with_dynamic_shape(self):
38+
@parameterized.expand(
39+
[
40+
param("cat", torch.cat),
41+
param("concat", torch.concat),
42+
]
43+
)
44+
def test_cat_with_dynamic_shape(self, _, op):
2645
class Cat(nn.Module):
2746
def forward(self, x, y):
2847
x = x + y
29-
return torch.cat((x, y), 0)
48+
return op((x, y), 0)
3049

3150
input_specs = [
3251
InputTensorSpec(
@@ -42,11 +61,17 @@ def forward(self, x, y):
4261
]
4362
self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat})
4463

45-
def test_cat_with_dynamic_shape_four_dimensions(self):
64+
@parameterized.expand(
65+
[
66+
param("cat", torch.cat),
67+
param("concat", torch.concat),
68+
]
69+
)
70+
def test_cat_with_dynamic_shape_four_dimensions(self, _, op):
4671
class Cat(nn.Module):
4772
def forward(self, x, y):
4873
x = x + y
49-
return torch.cat((x, y), 0)
74+
return op((x, y), 0)
5075

5176
input_specs = [
5277
InputTensorSpec(
@@ -63,6 +88,14 @@ def forward(self, x, y):
6388

6489
self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat})
6590

91+
def test_concat(self):
92+
class Cat(nn.Module):
93+
def forward(self, x, y, z):
94+
return torch.concat((x, y, z), 1)
95+
96+
inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)]
97+
self.run_test(Cat(), inputs, expected_ops={acc_ops.cat})
98+
6699

67100
if __name__ == "__main__":
68101
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,39 @@ def forward(self, x):
5757
TestModule(), input_specs, expected_ops={acc_ops.gelu}
5858
)
5959

60+
def test_gelu_module(self):
61+
class TestModule(nn.Module):
62+
def __init__(self):
63+
super().__init__()
64+
self.gelu = torch.nn.GELU()
65+
66+
def forward(self, x):
67+
return self.gelu(x)
68+
69+
inputs = [torch.randn(3, 10, 20)]
70+
self.run_test(
71+
TestModule(),
72+
inputs,
73+
expected_ops={acc_ops.gelu},
74+
test_implicit_batch_dim=False,
75+
)
76+
77+
def test_gelu_module_throw(self):
78+
class TestModule(nn.Module):
79+
def __init__(self):
80+
super().__init__()
81+
self.gelu = torch.nn.GELU(approximate="tanh")
82+
83+
def forward(self, x):
84+
return self.gelu(x)
85+
86+
inputs = [torch.randn(3, 10, 20)]
87+
self.run_test_with_assert_error(
88+
TestModule(),
89+
inputs,
90+
expect_error=RuntimeError,
91+
)
92+
6093

6194
if __name__ == "__main__":
6295
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,6 @@
66

77

88
class TestNewOnesConverter(AccTestCase):
9-
def test_newone(self):
10-
class TestModule(nn.Module):
11-
def forward(self, x):
12-
return x.new_ones((3, 5), dtype=torch.float16)
13-
14-
inputs = [torch.randn(1, 10)]
15-
self.run_test(
16-
TestModule(),
17-
inputs,
18-
expected_ops={acc_ops.new_ones},
19-
test_implicit_batch_dim=False,
20-
)
21-
229
def test_newone_no_dtype(self):
2310
class TestModule(nn.Module):
2411
def forward(self, x):
@@ -47,23 +34,6 @@ def forward(self, x):
4734

4835

4936
class TestNewOnesConverterWithDynamicShape(AccTestCase):
50-
def test_newone(self):
51-
class TestModule(nn.Module):
52-
def forward(self, x):
53-
return x.new_ones((3, 5), dtype=torch.float16)
54-
55-
input_specs = [
56-
InputTensorSpec(
57-
shape=(-1, -1, -1, -1),
58-
dtype=torch.float32,
59-
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
60-
),
61-
]
62-
63-
self.run_test_with_dynamic_shape(
64-
TestModule(), input_specs, expected_ops={acc_ops.new_ones}
65-
)
66-
6737
def test_newone_no_dtype(self):
6838
class TestModule(nn.Module):
6939
def forward(self, x):

py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,31 @@
99
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
1010

1111
unary_ops = [
12-
(torch.sin, acc_ops.sin),
13-
(torch.cos, acc_ops.cos),
14-
(torch.tan, acc_ops.tan),
15-
(torch.sinh, acc_ops.sinh),
16-
(torch.cosh, acc_ops.cosh),
17-
(torch.asin, acc_ops.asin),
18-
(torch.acos, acc_ops.acos),
19-
(torch.atan, acc_ops.atan),
20-
(torch.abs, acc_ops.abs),
21-
(torch.neg, acc_ops.neg),
22-
(torch.reciprocal, acc_ops.reciprocal),
23-
(torch.sqrt, acc_ops.sqrt),
24-
(torch.log, acc_ops.log),
25-
(torch.exp, acc_ops.exp),
26-
(torch.floor, acc_ops.floor),
27-
(torch.ceil, acc_ops.ceil),
28-
(torch.sign, acc_ops.sign),
12+
(torch.sin, acc_ops.sin, False),
13+
(torch.cos, acc_ops.cos, False),
14+
(torch.tan, acc_ops.tan, False),
15+
(torch.sinh, acc_ops.sinh, False),
16+
(torch.cosh, acc_ops.cosh, False),
17+
(torch.asin, acc_ops.asin, True),
18+
(torch.acos, acc_ops.acos, True),
19+
(torch.atan, acc_ops.atan, True),
20+
(torch.abs, acc_ops.abs, False),
21+
(torch.neg, acc_ops.neg, False),
22+
(torch.reciprocal, acc_ops.reciprocal, False),
23+
(torch.sqrt, acc_ops.sqrt, False),
24+
(torch.log, acc_ops.log, False),
25+
(torch.exp, acc_ops.exp, False),
26+
(torch.floor, acc_ops.floor, False),
27+
(torch.ceil, acc_ops.ceil, False),
28+
(torch.sign, acc_ops.sign, False),
2929
]
3030

3131

3232
class TestUnaryOpConverters(AccTestCase):
33-
@parameterized.expand([(op[1].__name__, op[0], op[1]) for op in unary_ops])
34-
def test_unary_ops(self, name, orig_op: Callable, expected_op):
33+
@parameterized.expand([(op[1].__name__, op[0], op[1], op[2]) for op in unary_ops])
34+
def test_unary_ops(
35+
self, name, orig_op: Callable, expected_op: Callable, range_req: bool
36+
):
3537
class TestModule(nn.Module):
3638
def __init__(self, orig_op):
3739
super().__init__()
@@ -41,11 +43,15 @@ def forward(self, x):
4143
return self.orig_op(x)
4244

4345
m = TestModule(orig_op)
44-
inputs = [torch.randn(2, 2, 3)]
46+
inputs = (
47+
[torch.distributions.uniform.Uniform(-1, 1).sample([2, 2, 3])]
48+
if range_req
49+
else [torch.randn(2, 2, 3)]
50+
)
4551
self.run_test(m, inputs, expected_ops={expected_op})
4652

4753

48-
class TestUnaryOpConvertersWithDynamicShapeFourDimensions(AccTestCase):
54+
class TestUnaryVOpConvertersWithDynamicShapeFourDimensions(AccTestCase):
4955
@parameterized.expand([(op[1].__name__, op[0], op[1]) for op in unary_ops])
5056
def test_unary_ops(self, name, orig_op: Callable, expected_op):
5157
class TestModule(nn.Module):

py/torch_tensorrt/fx/test/core/test_trt_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def forward(self, x):
2727
torch.save(trt_mod, "trt.pt")
2828
reload_trt_mod = torch.load("trt.pt")
2929

30-
torch.testing.assert_allclose(
30+
torch.testing.assert_close(
3131
reload_trt_mod(inputs[0].cuda()).cpu(), ref_output, rtol=1e-04, atol=1e-04
3232
)
3333
os.remove(f"{os.getcwd()}/trt.pt")
@@ -49,7 +49,7 @@ def forward(self, x):
4949
new_trt_mod = TRTModule()
5050
new_trt_mod.load_state_dict(st)
5151

52-
torch.testing.assert_allclose(
52+
torch.testing.assert_close(
5353
new_trt_mod(inputs[0].cuda()).cpu(), ref_output, rtol=1e-04, atol=1e-04
5454
)
5555

0 commit comments

Comments
 (0)