
Commit 4daac3b

842974287 (Wei Wei) authored and committed
[fx2trt] dispatch tracer (#20)
Summary:
Pull Request resolved: pytorch/fx2trt#20

Adds the initial version of the dispatch tracer for fx2trt, mostly copied from functorch. Leaf module support is added. Decomposition support is removed from this version for simplicity and will be added back later. This version doesn't support tracing size() as a node.

Reviewed By: wushirong

Differential Revision: D34868356

fbshipit-source-id: 52b34d8e3a34fd9209cd713363016259a735b866
1 parent ac93bae commit 4daac3b
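
For context, the leaf module support mentioned above lets a caller keep selected modules as opaque call_module nodes instead of tracing into their underlying ops. A minimal sketch of the call shape, mirroring the test added in this commit (fn and example_input are placeholder names):

from fx2trt_oss.tracer.dispatch_tracer.tracer import make_fx

# ReLU modules stay as call_module nodes; everything else is traced op by op.
gm = make_fx(fn, leaf_module_list={"torch.nn.modules.activation.ReLU"})(example_input)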

File tree

3 files changed, +245 -0 lines changed

test/tracer/test_dispatch_tracer.py

Lines changed: 52 additions & 0 deletions
import unittest

import torch
from fx2trt_oss.tracer.dispatch_tracer.tracer import make_fx

torch.manual_seed(0)


class DispatchTracerTest(unittest.TestCase):
    def test_leaf_module_list(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 10, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                return self.relu(x)

        mod = TestModule()

        def f(x):
            return mod(x)

        a = torch.randn(1, 3, 1, 1)
        ref_output = f(a)
        func = make_fx(f, leaf_module_list={"torch.nn.modules.activation.ReLU"})
        gm = func(a)
        output = gm(a)
        torch.testing.assert_close(output, ref_output)

        # There should be a call_module node in the graph.
        call_module_node = None
        for node in gm.graph.nodes:
            if node.op == "call_module":
                call_module_node = node
        self.assertIsNotNone(call_module_node)
        self.assertEqual(call_module_node.target, "ReLU_0")

    def test_non_tensor_input(self):
        def foo(x):
            a = x["a"]
            b = x["b"]
            return a + b

        x = {"a": torch.randn(1), "b": torch.randn(1)}
        ref_output = foo(x)
        func = make_fx(foo)
        gm = func(x)
        output = gm(x)
        torch.testing.assert_close(output, ref_output)
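
The new test can be run directly with a test runner, assuming pytest is available and fx2trt_oss is importable in the current environment, e.g.:

python -m pytest test/tracer/test_dispatch_tracer.py -v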

tracer/dispatch_tracer/__init__.py

Whitespace-only changes.

tracer/dispatch_tracer/tracer.py

Lines changed: 193 additions & 0 deletions
import functools
from typing import Set, Any, Callable, Dict, Optional, Tuple

import torch
import torch.fx as fx
import torch.utils._pytree as pytree
from torch._C import _disabled_torch_function_impl
from torch.fx import GraphModule, Tracer
from torch.fx.experimental.normalize import NormalizeArgs
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager

DEFAULT_LEAF_MODULE_LIST = {}


@contextmanager
def no_dispatch():
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def unwrap_proxy(e):
    return e.proxy if isinstance(e, DispatchTensor) else e
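# build_outputs runs the real op under no_dispatch() so the call is not
# intercepted again, then wraps any tensor results as DispatchTensors that
# carry the recorded proxy, so tracing continues through downstream ops.
# For in-place ops (trailing underscore), the proxy is attached to the mutated
# first argument.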
def build_outputs(func, args, kwargs, proxy_out):
    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out

    with no_dispatch():
        real_out = func(*args, **kwargs)

    def wrap_with_proxy(e, proxy):
        if isinstance(e, torch.Tensor):
            return DispatchTensor(e, proxy)
        else:
            return e

    if isinstance(real_out, tuple):
        return tuple([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
    elif isinstance(real_out, list):
        return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]
    elif isinstance(real_out, torch.Tensor):
        return wrap_with_proxy(real_out, proxy_out)
    else:
        return real_out


class DispatchTensor(torch.Tensor):
    """
    Copied from the python key tensor in functorch
    https://github.com/pytorch/functorch/blob/b83273b25213f556f05a065163163ba531e24750/functorch/_src/python_key.py
    and the tracer tensor in subclass_zoo
    https://github.com/albanD/subclass_zoo/blob/main/tracer_tensor.py

    The differences are
    1. when creating the tensor we always set require_grad to False as we are only using
       it here for inference purposes.
    """

    @staticmethod
    def __new__(cls, elem, proxy):
        return torch.Tensor._make_subclass(cls, elem, require_grad=False)

    def __init__(self, elem, proxy):
        self.proxy = proxy
        proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self)

    def __repr__(self):
        return f"DispatchTensor({torch.Tensor._make_subclass(torch.Tensor, self)})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
        proxy_out = func(*proxy_args, **proxy_kwargs)
        return build_outputs(func, args, kwargs, proxy_out)


class DispatchTracer(Tracer):
    """
    Copied from the python key tracer in functorch
    https://github.com/pytorch/functorch/blob/b83273b25213f556f05a065163163ba531e24750/functorch/_src/python_key.py

    The differences are
    1. this tracer allows specifying leaf modules and will preserve them as call_module nodes
       in the graph.
    """

    def __init__(self, leaf_module_list: Optional[Set[str]] = None):
        super().__init__()
        self.leaf_module_list = (leaf_module_list or set()).union(DEFAULT_LEAF_MODULE_LIST)

    def call_module(
        self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        if self.is_leaf_module(m):
            i = 0
            while True:
                qualname = f"{type(m).__name__}_{i}"
                if not hasattr(self.root, qualname):
                    break
                i += 1
            setattr(self.root, qualname, m)
            proxy_args = pytree.tree_map(unwrap_proxy, args)
            proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
            proxy_out = self.create_proxy("call_module", qualname, proxy_args, proxy_kwargs)
            return build_outputs(forward, args, kwargs, proxy_out)
        return forward(*args, **kwargs)

    def is_leaf_module(self, m) -> bool:
        return torch.typename(m) in self.leaf_module_list

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if isinstance(attr_val, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        proxy = self.create_proxy("get_attr", n, (), {})
                        parameter_proxy_cache[n] = DispatchTensor(attr_val, proxy)
                    return parameter_proxy_cache[n]
            return attr_val
        return attr_val

    def create_arg(self, a: Any):
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})
            qualname: Optional[str] = None

            i = 0
            while True:
                qualname = f"_param_constant{i}"
                if not hasattr(self.root, qualname):
                    break
                i += 1
            setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})
        return super().create_arg(a)


def dispatch_trace(
    root: torch.nn.Module,
    leaf_module_list: Optional[Set[str]] = None,
    concrete_args=None,
) -> GraphModule:
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    tracer = DispatchTracer(leaf_module_list)
    graph = tracer.trace(root, concrete_args=concrete_args)
    gm = GraphModule(tracer.root, graph, name)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return NormalizeArgs(gm).transform()


def wrap_key(f, inps):
    flat_inps, inp_spec = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert len(flat_args) == len(flat_inps)
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = DispatchTensor(flat_inps[idx], arg)
            else:
                flat_args[idx] = flat_inps[idx]
        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], DispatchTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped
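# make_fx wraps a callable so that calling it with example inputs returns a
# normalized GraphModule: each input becomes an fx placeholder, wrap_key pairs
# the real example tensors with those placeholder proxies as DispatchTensors,
# and DispatchTracer records every dispatched op (or leaf module call) as a
# graph node.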
def make_fx(f, leaf_module_list: Optional[Set[str]] = None):
    @functools.wraps(f)
    def wrapped(*args):
        phs = pytree.tree_map(lambda x: fx.PH, args)
        t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs), leaf_module_list=leaf_module_list)
        return t

    return wrapped
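
For reference, a minimal sketch of driving the tracer end to end, mirroring the new test (module definition, shapes, and variable names here are illustrative):

import torch
from fx2trt_oss.tracer.dispatch_tracer.tracer import make_fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 10, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

mod = MyModule()

def f(x):
    return mod(x)

x = torch.randn(1, 3, 1, 1)
gm = make_fx(f, leaf_module_list={"torch.nn.modules.activation.ReLU"})(x)

# The conv is traced into its underlying aten ops, while ReLU is preserved as a
# call_module node (target "ReLU_0"); the traced module matches the eager output.
gm.graph.print_tabular()
torch.testing.assert_close(gm(x), f(x))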
