
Commit 86e6e47

Author: Wei Wei (committed)
[fx2trt] dispatch tracer improvement (#64)

Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/64

As titled:
1. Able to handle in-place operations.
2. Handle the corner case where the module is a ReLU and contains only an aten relu with inplace=True.
3. Introduce functionalize to de-inplace the `inplace` and `view` operations.
4. Introduce two methods of setting leaf nodes: leaf (op registration) and leaf (override call_module). Only the former works with functionalize.
5. More notes are in `test_dispatch_tracer.py`.

Reviewed By: yinghai

Differential Revision: D35684102

fbshipit-source-id: ff2439587d563c4eee3a07b1ae3c1082af2679c8
1 parent 5760875 commit 86e6e47
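
For context on items 1 and 3 of the summary, a minimal sketch (not part of this commit) of what functionalize does when composed with an fx-style tracer: in-place ops in the traced function come out as their out-of-place equivalents. It uses only functorch APIs that the new test file below imports; the function f is a made-up example.

    import torch
    from functorch import make_fx
    from functorch.experimental import functionalize

    def f(x):
        # made-up example: an in-place relu on an intermediate tensor
        y = x + 1
        y.relu_()
        return y

    # Tracing through functionalize records out-of-place ops
    # (aten.relu rather than aten.relu_) in the resulting graph.
    gm = make_fx(functionalize(f))(torch.ones(3))
    print(gm.code)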

File tree

2 files changed (+206, -14 lines)


test/tracer/test_dispatch_tracer.py

Lines changed: 188 additions & 6 deletions
@@ -1,31 +1,140 @@
 import unittest
 
 import torch
+import torchdynamo
+import torchvision
+
+from functorch import make_fx as make_fx_pk
+from functorch.experimental import functionalize
 from fx2trt_oss.tracer.dispatch_tracer.tracer import make_fx
+from torch.library import Library
+from torchdynamo.optimizations.normalize import normalize_ir
+from torchdynamo.optimizations.python_key import fake_signature
 
 torch.manual_seed(0)
 
+wrap_lib = Library("wrap", "DEF")
+"""
+There are two methods for setting leaf_module: leaf (op registration) and leaf (override call_module).
+Only leaf (op registration) can work together with functionalize.
+If you do not need functionalize, you can choose either leaf module method.
+
+Test coverage:
+PythonkeyTracerTest.test_leaf_operator_reg: python_key tracer + functionalize + leaf (op registration)
+
+DispatchTracerTest.test_leaf_operator_reg: dispatch tracer + functionalize + leaf (op registration)
+DispatchTracerTest.test_leaf: dispatch tracer + leaf (override call_module)
+DispatchTracerTest.test_non_tensor_input: dispatch tracer
+DispatchTracerTest.test_resnet18: dispatch tracer
+DispatchTracerTest.test_reference_copy: dispatch tracer + functionalize
+DispatchTracerTest.test_reference_copy_torchdynamo: dispatch tracer + torchdynamo + functionalize
+"""
+
+
+class PythonkeyTracerTest(unittest.TestCase):
+    def test_leaf_operator_reg(self):
+        class Leaf(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y + torch.nn.Parameter(torch.ones(5))
+
+        leaf = Leaf()
+        wrap_lib.define("wrapped_foo(Tensor x, Tensor y) -> Tensor")
+        wrap_lib.impl("wrapped_foo", leaf, "CPU")
+
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super(Bar, self).__init__()
+                self.foo = torch.ops.wrap.wrapped_foo
+                self.other = torch.nn.Parameter(torch.ones(5))
+
+            def forward(self, x, y):
+                x = self.foo(x, y)
+                x = x + self.other
+                return x
+
+        mod = Bar()
+
+        def f(x, y):
+            return mod(x, y)
+
+        gm = make_fx_pk(functionalize(f))(torch.ones(5), torch.ones(5))
+        inputs = [torch.ones(5) + 5, torch.ones(5) + 8]
+        output = gm(*inputs)
+        ref_output = f(*inputs)
+        torch.testing.assert_close(output, ref_output)
+
 
 class DispatchTracerTest(unittest.TestCase):
-    def test_leaf_module_list(self):
-        class TestModule(torch.nn.Module):
+    def test_leaf_operator_reg(self):
+        class Leaf(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y + torch.nn.Parameter(torch.ones(5))
+
+        leaf = Leaf()
+        wrap_lib.define("wrapped_leaf(Tensor x, Tensor y) -> Tensor")
+        wrap_lib.impl("wrapped_leaf", leaf, "CPU")
+
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super(Bar, self).__init__()
+                self.leaf = torch.ops.wrap.wrapped_leaf
+                self.other = torch.nn.Parameter(torch.ones(5))
+
+            def forward(self, x, y):
+                x = self.leaf(x, y)
+                x = x + self.other
+                return x
+
+        mod = Bar()
+
+        def f(x, y):
+            return mod(x, y)
+
+        gm = make_fx(functionalize(f))(torch.ones(5), torch.ones(5))
+        inputs = [torch.ones(5) + 5, torch.ones(5) + 8]
+        output = gm(*inputs)
+        ref_output = f(*inputs)
+        torch.testing.assert_close(output, ref_output)
+        # with the op-registration method, the module appears as a call_function node
+        call_function_node = None
+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.wrap.wrapped_leaf
+            ):
+                call_function_node = node
+        self.assertIsNotNone(call_function_node)
+
+    def test_leaf(self):
+        class TestModuleLeaf(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.conv = torch.nn.Conv2d(3, 10, 1)
-                self.relu = torch.nn.ReLU()
+                self.relu = torch.nn.ReLU(inplace=True)
 
             def forward(self, x):
                 x = self.conv(x)
                 return self.relu(x)
 
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.relu = torch.nn.ReLU(inplace=True)
+                self.leaf = TestModuleLeaf()
+
+            def forward(self, x):
+                x = self.leaf(x)
+                return self.relu(x)
+
         mod = TestModule()
 
         def f(x):
             return mod(x)
 
         a = torch.randn(1, 3, 1, 1)
         ref_output = f(a)
-        func = make_fx(f, leaf_module_list={"torch.nn.modules.activation.ReLU"})
+        func = make_fx(f, leaf_module_list={"test_dispatch_tracer.TestModuleLeaf"})
         gm = func(a)
         output = gm(a)
         torch.testing.assert_close(output, ref_output)
@@ -36,17 +145,90 @@ def f(x):
             if node.op == "call_module":
                 call_module_node = node
         self.assertIsNotNone(call_module_node)
-        self.assertEqual(call_module_node.target, "ReLU_0")
+        self.assertEqual(call_module_node.target, "TestModuleLeaf_0")
 
     def test_non_tensor_input(self):
         def foo(x):
             a = x["a"]
             b = x["b"]
             return a + b
 
-        x = {"a": torch.randn(1), "b": torch.randn(1)}
+        x = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
        ref_output = foo(x)
         func = make_fx(foo)
         gm = func(x)
         output = gm(x)
         torch.testing.assert_close(output, ref_output)
+
+    def test_resnet18(self):
+        mod = torchvision.models.resnet18(pretrained=False)
+
+        def f(x):
+            return mod(x)
+
+        a = torch.randn(1, 3, 224, 224)
+        ref_output = f(a)
+        gm = make_fx(f)(a)
+        output = gm(a)
+        torch.testing.assert_close(output, ref_output)
+
+    def test_reference_copy(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                y[:, 0] = x[:, 0]
+                return y
+
+        mod = TestModule()
+
+        def f(x, y):
+            return mod(x, y)
+
+        a = torch.ones(2, 2) + 2
+        b = torch.ones(2, 2)
+        b_copy = torch.ones(2, 2)
+        ref_output = f(a, b)
+        gm = make_fx(functionalize(f))(a, b)
+        output = gm(a, b_copy)
+        torch.testing.assert_close(output, ref_output)
+
+    def test_reference_copy_torchdynamo(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.relu = torch.nn.ReLU(inplace=True)
+
+            def forward(self, x, y):
+                y = y + 3
+                y = self.relu(y)
+                y[:, 0] = x[:, 0]
+                return y
+
+        mod = TestModule()
+
+        def f(x, y):
+            return mod(x, y)
+
+        a = torch.ones(2, 2) + 2
+        b = torch.ones(2, 2)
+        inputs = [a, b]
+        ref_output = f(*inputs)
+
+        def compile_dispatch(gm, example_inputs):
+            # after normalization, in-place relu is removed
+            gm = normalize_ir(gm, example_inputs)
+            # dispatch tracer
+            nargs = len(example_inputs)
+            gm = make_fx(functionalize(fake_signature(gm, nargs)))(*example_inputs)
+            return gm
+
+        optimize_ctx = torchdynamo.optimize(
+            compile_dispatch,
+            nopython=True,
+        )
+
+        with optimize_ctx:
+            output = mod(*inputs)
+        torch.testing.assert_close(output, ref_output)
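
The two test_leaf_operator_reg tests above both lean on the same torch.library pattern: define a schema, then bind a module as the op's CPU implementation, so the whole module traces as a single call_function node. A stripped-down sketch of just that registration step, with a hypothetical namespace and op name (not from this commit):

    import torch
    from torch.library import Library

    demo_lib = Library("demo", "DEF")  # hypothetical namespace
    demo_lib.define("my_leaf(Tensor x, Tensor y) -> Tensor")

    class Leaf(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    # The module instance becomes the CPU kernel for the registered op;
    # a dispatch-level tracer then records torch.ops.demo.my_leaf as one
    # opaque node instead of tracing into the module body.
    demo_lib.impl("my_leaf", Leaf(), "CPU")

    out = torch.ops.demo.my_leaf(torch.ones(2), torch.ones(2))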

tracer/dispatch_tracer/tracer.py

Lines changed: 18 additions & 8 deletions
@@ -26,17 +26,22 @@ def unwrap_proxy(e):
     return e.proxy if isinstance(e, DispatchTensor) else e
 
 
-def build_outputs(func, args, kwargs, proxy_out):
+def build_outputs(func, func_overload, args, kwargs, proxy_out, call_module=False):
     # Kind of a hacky way to test if an op is in-place or not
     if func.__name__[-1] == "_" and func.__name__[0] != "_":
         args[0].proxy = proxy_out
 
     with no_dispatch():
-        real_out = func(*args, **kwargs)
+        real_out = func_overload(*args, **kwargs)
 
     def wrap_with_proxy(e, proxy):
-        if isinstance(e, torch.Tensor):
+        if e is None:
+            e = torch.empty(())
+        if type(e) == torch.Tensor:
             return DispatchTensor(e, proxy)
+        # if the module output is a DispatchTensor, then all ops inside it are in-place
+        elif type(e) == DispatchTensor and call_module:
+            e.proxy = proxy_out
         else:
             return e
 
@@ -46,7 +51,7 @@ def wrap_with_proxy(e, proxy):
         )
     elif isinstance(real_out, list):
         return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]
-    elif isinstance(real_out, torch.Tensor):
+    elif type(real_out) == torch.Tensor:
         return wrap_with_proxy(real_out, proxy_out)
     else:
         return real_out
@@ -78,11 +83,12 @@ def __repr__(self):
     __torch_function__ = _disabled_torch_function_impl
 
     @classmethod
-    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
+        func = func_overload.overloadpacket
         proxy_args = pytree.tree_map(unwrap_proxy, args)
         proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
         proxy_out = func(*proxy_args, **proxy_kwargs)
-        return build_outputs(func, args, kwargs, proxy_out)
+        return build_outputs(func, func_overload, args, kwargs, proxy_out)
 
 
 class DispatchTracer(Tracer):
@@ -101,6 +107,7 @@ def __init__(self, leaf_module_list: Optional[Set[str]] = None):
             DEFAULT_LEAF_MODULE_LIST
         )
 
+    # Users can set leaf_module_list, but it won't work in combination with functionalize
     def call_module(
         self,
         m: torch.nn.Module,
@@ -121,7 +128,10 @@ def call_module(
             proxy_out = self.create_proxy(
                 "call_module", qualname, proxy_args, proxy_kwargs
             )
-            return build_outputs(forward, args, kwargs, proxy_out)
+
+            return build_outputs(
+                forward, forward, args, kwargs, proxy_out, call_module=True
+            )
         return forward(*args, **kwargs)
 
     def is_leaf_module(self, m) -> bool:
@@ -170,7 +180,7 @@ def dispatch_trace(
     gm = GraphModule(tracer.root, graph, name)
     gm.graph.eliminate_dead_code()
     gm.recompile()
-    return NormalizeArgs(gm).transform()
+    return gm
 
 
 def wrap_key(f, inps):
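
A side note on the in-place check that build_outputs keeps (func.__name__[-1] == "_" and func.__name__[0] != "_"): it relies entirely on the ATen naming convention that in-place variants carry a trailing underscore, while a leading underscore marks dunder or private names that should not count. The same heuristic in isolation, with a helper name of our own rather than the commit's:

    import torch

    def is_inplace(func) -> bool:
        # ATen in-place ops end with "_" (add_, relu_); names that start
        # with "_" (e.g. __add__) are excluded by the second condition.
        name = func.__name__
        return name.endswith("_") and not name.startswith("_")

    assert is_inplace(torch.Tensor.relu_)
    assert not is_inplace(torch.Tensor.relu)
    assert not is_inplace(torch.Tensor.__add__)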
