
Commit ac93bae

Author: Wei Wei

[fx2trt] improve to_dtype (#48)

Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/48

Currently, to_dtype supports only
1) to(dtype)

This diff extends the op to handle more cases:
2) to(torch.device)         # gpu
3) to(torch.device, dtype)  # gpu

(Note: this ignores all push blocking failures!)

Reviewed By: 842974287

Differential Revision: D35331003

fbshipit-source-id: 4dee2b3c7899805fa4f3c91d0a16207241396647
1 parent 5f83e2c commit ac93bae
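
In plain PyTorch terms, the three call patterns the lowered to_dtype op now covers look like this (a minimal sketch; the module and tensor names are illustrative, not part of the commit):

    import torch

    class ToVariants(torch.nn.Module):
        def forward(self, x):
            a = x.to(torch.float16)                        # case 1: to(dtype)
            b = x.to(torch.device("cuda"))                 # case 2: to(torch.device), gpu only
            c = x.to(torch.device("cuda"), torch.float16)  # case 3: to(torch.device, dtype)
            return a, b, c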

File tree

4 files changed, +161 -9 lines changed


fx/converters/acc_ops_converters.py

Lines changed: 29 additions & 0 deletions

@@ -14,6 +14,7 @@
 from fx2trt_oss.fx.types import *  # noqa: F403
 from fx2trt_oss.fx.utils import (
     torch_dtype_from_trt,
+    torch_dtype_to_trt,
     get_dynamic_dims,
 )
 from torch.fx.immutable_collections import immutable_list
@@ -1237,6 +1238,34 @@ def acc_ops_minimum(
     )


+@tensorrt_converter(acc_ops.device)
+def acc_ops_device(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # TRT always assumes the device is cuda, not cpu.
+    return torch.device("cuda")
+
+
+@tensorrt_converter(acc_ops.to_dtype)
+def acc_ops_to_dtype(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    input_val = kwargs["input"]
+    input_dtype = kwargs["acc_out_ty"].dtype
+    input_t = get_trt_tensor(network, input_val, f"{name}_input_t")
+
+    if input_dtype:
+        input_dtype = torch_dtype_to_trt(input_dtype)
+        input_t = type_cast(network, target, f"{name}_input", input_t, input_dtype)
+    return input_t
+

 @tensorrt_converter(acc_ops.logical_not)
 def acc_ops_logical_not(
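
Here torch_dtype_to_trt maps the torch dtype to its TensorRT equivalent, and type_cast inserts the actual cast; type_cast itself is not part of this diff. A sketch of how such a cast is commonly expressed with the TensorRT Python API, as an assumption about the helper rather than the repo's actual implementation:

    import tensorrt as trt

    def type_cast_sketch(network, name, input_t, trt_dtype):
        # Hypothetical stand-in for the repo's type_cast helper: TRT casts are
        # often expressed as an identity layer with a requested output type.
        layer = network.add_identity(input_t)
        layer.set_output_type(0, trt_dtype)  # e.g. trt.float16
        layer.name = name
        return layer.get_output(0)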
Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+import torch
+import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
+from torch.testing._internal.common_fx2trt import AccTestCase
+from torch.testing._internal.common_utils import run_tests
+from fx2trt_oss.fx.utils import LowerPrecision
+
+
+class TestToDtypeConverter(AccTestCase):
+    def test_fp16(self):
+        class To(torch.nn.Module):
+            def forward(self, x):
+                return x.to(torch.float16)
+
+        input = torch.randn(2, 2)
+        inputs = [
+            input,
+        ]
+        self.run_test(To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False, precision=LowerPrecision.FP16)
+
+    def test_fp32(self):
+        class To(torch.nn.Module):
+            def forward(self, x):
+                return x.to(torch.float32)
+
+        input = torch.randn(2, 2).to(torch.float16)
+        inputs = [
+            input,
+        ]
+        self.run_test(To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False)
+
+    def test_cuda_fp16(self):
+        class To(torch.nn.Module):
+            def forward(self, x):
+                return x.to(torch.device('cuda:0'), torch.float16)
+
+        input = torch.randn(2, 2)
+        inputs = [
+            input,
+        ]
+        self.run_test(To(), inputs, expected_ops={acc_ops.to_dtype}, test_implicit_batch_dim=False, precision=LowerPrecision.FP16)
+
+    def test_cuda(self):
+        class To(torch.nn.Module):
+            def forward(self, x):
+                x = x.to(torch.device('cuda'))
+                # Append an extra layer, since to(device) is skipped in TRT.
+                return x + torch.randn(2, 2).cuda()
+
+        input = torch.randn(2, 2)
+        inputs = [
+            input,
+        ]
+        self.run_test(To(), inputs, expected_ops={acc_ops.to_dtype, acc_ops.add}, test_implicit_batch_dim=False, precision=LowerPrecision.FP32)
+
+    def test_device(self):
+        class To(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.randn(2, 2)
+
+            def forward(self, x):
+                idevice = x.device
+                a = self.a.to(idevice)
+                return x + a
+
+        input = torch.randn(2, 2).cuda()
+        inputs = [
+            input,
+        ]
+        self.run_test(To(), inputs, expected_ops={}, test_implicit_batch_dim=False, precision=LowerPrecision.FP32)
+
+    def test_device_fp16(self):
+        class To(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.randn(2, 2)
+
+            def forward(self, x):
+                idevice = x.device
+                a = self.a.to(idevice)
+                # The fx tracer could not handle "to(idevice, torch.float16)":
+                # TypeError: to() received an invalid combination of arguments - got (Attribute, torch.dtype)
+                return a.to(torch.float16)
+
+        input = torch.randn(2, 2).half().cuda()
+        inputs = [
+            input,
+        ]
+        self.run_test(To(), inputs, expected_ops={}, test_implicit_batch_dim=False, precision=LowerPrecision.FP16)
+
+
+if __name__ == '__main__':
+    run_tests()
test/tracer/test_acc_tracer.py

Lines changed: 2 additions & 1 deletion

@@ -2371,6 +2371,7 @@ def test_all_acc_ops_registered(self):
                 acc_ops.interpolate,
                 acc_ops.logical_and,
                 acc_ops.logical_not,
-                acc_ops.ne
+                acc_ops.ne,
+                acc_ops.device,
             },
         )

tracer/acc_tracer/acc_ops.py

Lines changed: 40 additions & 8 deletions

@@ -184,6 +184,10 @@ def sign(*, input):
 def size(*, input):
     return input.size()

+@register_acc_op_properties(AccOpProperty.unary)
+@register_acc_op
+def device(*, input):
+    return input.device

 @register_custom_acc_mapper_fn(
     op_and_target=("call_function", getattr),
@@ -203,10 +207,15 @@ def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
         input_obj.meta["type"] == torch.Tensor
     ), f"Expected torch.Tensor type for {input_obj.meta['type']}"
     assert (
-        attr_name == "shape"
-    ), f"Only supporting shape getattr for now, not {attr_name}"
+        attr_name == "shape" or attr_name == "device"
+    ), f"Only supporting shape and device getattr for now, not {attr_name}"
+    if attr_name == "shape":
+        func = size
+    elif attr_name == "device":
+        func = device
+
     with node.graph.inserting_before(node):
-        size_node = node.graph.call_function(size, kwargs={"input": input_obj})
+        size_node = node.graph.call_function(func, kwargs={"input": input_obj})
         size_node.meta = node.meta.copy()
         return size_node
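
With the getattr mapper extended, an x.device access in a traced module is rewritten into an acc_ops.device call, which the converter shown earlier then resolves to cuda. Roughly, assuming the usual acc_tracer.trace entry point (a sketch, not code from the commit):

    import torch
    import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer

    class M(torch.nn.Module):
        def forward(self, x):
            return x.to(x.device)  # x.device lowers to acc_ops.device(input=x)

    traced = acc_tracer.trace(M(), [torch.randn(2, 2)])
    print(traced.graph)  # expect an acc_ops.device node feeding acc_ops.to_dtype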

@@ -1993,29 +2002,52 @@ def custom_tensor_reshape_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.

 @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
 @register_acc_op
-def to_dtype(input, acc_out_ty=None):
+def to_dtype(input, acc_out_ty=None, device=None):
     assert acc_out_ty is not None
-    return input.to(dtype=acc_out_ty.dtype)
+    return input.to(dtype=acc_out_ty.dtype, device=device)


 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "to"),
     arg_replacement_tuples=[
         ("input", "input"),
         ("dtype", "dtype"),
+        ("device", "device", this_arg_is_optional),
     ],
 )
 def custom_tensor_to_mapper(node: torch.fx.Node, _: nn.Module):
-    dest_dtype = node.kwargs["dtype"]
+    dest = node.kwargs["dtype"]
     mem_format = node.kwargs.get("memory_format")
-    device = node.kwargs.get("device")
-    assert dest_dtype is not None
+    dest_other = node.kwargs.get("device")
+    assert dest is not None
     assert mem_format is None or mem_format == torch.preserve_format
-    assert device is None
+
+    dest_dtype = dest_device = None
+    if isinstance(dest, torch.fx.node.Node):
+        meta_type = dest.meta["type"]
+        # Assume the device is gpu only; the meta info is too limited to give a clear device type.
+        if dest.meta["type"] == torch.device:
+            dest_device = dest
+        else:
+            # Due to a limitation of FX, we cannot support to(torch.Tensor),
+            # since meta only contains 'type': <class 'torch.Tensor'>.
+            raise RuntimeError(f"We currently do not support to({meta_type})")
+    elif isinstance(dest, torch.device):
+        # Only the device is set; dtype stays None.
+        if dest_other is None:
+            dest_device = dest
+        # Both the device and dtype are set.
+        else:
+            dest_dtype = dest_other
+            dest_device = dest
+    # Only the dtype is set.
+    else:
+        dest_dtype = dest

     new_kwargs = {
         "input": node.kwargs["input"],
         "acc_out_ty": acc_utils.build_raw_tensor_meta(dtype=dest_dtype),
+        "device": dest_device,
     }

     with node.graph.inserting_before(node):
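
Put together, the mapper's branching on dest and dest_other reduces to a small dtype/device split. A standalone restatement of that logic (the Node branch is omitted, and the helper name is made up for illustration):

    import torch

    def split_to_args(dest, dest_other):
        # Mirrors custom_tensor_to_mapper: returns (dest_dtype, dest_device)
        # for the rewritten to_dtype call.
        if isinstance(dest, torch.device):
            # device alone (dest_other is None), or device plus dtype
            return (dest_other, dest)
        return (dest, None)  # plain to(dtype)

    assert split_to_args(torch.float16, None) == (torch.float16, None)
    assert split_to_args(torch.device('cuda'), None) == (None, torch.device('cuda'))
    assert split_to_args(torch.device('cuda'), torch.float16) == (torch.float16, torch.device('cuda'))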
