import functools
from contextlib import contextmanager
from typing import Set, Any, Callable, Dict, Optional, Tuple

import torch
import torch.fx as fx
import torch.utils._pytree as pytree
from torch._C import _disabled_torch_function_impl
from torch.fx import GraphModule, Tracer
from torch.fx.experimental.normalize import NormalizeArgs
from torch.fx.passes.shape_prop import _extract_tensor_metadata

# Module type names (as returned by torch.typename) that are always treated as leaf modules.
DEFAULT_LEAF_MODULE_LIST: Set[str] = set()


@contextmanager
def no_dispatch():
    # While the guard is alive, tensor ops bypass __torch_dispatch__, so callers can
    # run the real op and get concrete results without recording extra graph nodes.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def unwrap_proxy(e):
    return e.proxy if isinstance(e, DispatchTensor) else e


def build_outputs(func, args, kwargs, proxy_out):
    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out

    with no_dispatch():
        real_out = func(*args, **kwargs)

    def wrap_with_proxy(e, proxy):
        if isinstance(e, torch.Tensor):
            return DispatchTensor(e, proxy)
        else:
            return e

    if isinstance(real_out, tuple):
        return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
    elif isinstance(real_out, list):
        return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]
    elif isinstance(real_out, torch.Tensor):
        return wrap_with_proxy(real_out, proxy_out)
    else:
        return real_out


class DispatchTensor(torch.Tensor):
    """
    Copied from the python key tensor in functorch
    https://github.com/pytorch/functorch/blob/b83273b25213f556f05a065163163ba531e24750/functorch/_src/python_key.py
    and the tracer tensor in subclass_zoo
    https://github.com/albanD/subclass_zoo/blob/main/tracer_tensor.py

    The difference is
    1. when creating the tensor we always set requires_grad to False, since we only use
    it here for inference purposes.
    """

    @staticmethod
    def __new__(cls, elem, proxy):
        return torch.Tensor._make_subclass(cls, elem, require_grad=False)

    def __init__(self, elem, proxy):
        self.proxy = proxy
        proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self)

    def __repr__(self):
        return f"DispatchTensor({torch.Tensor._make_subclass(torch.Tensor, self)})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
        proxy_out = func(*proxy_args, **proxy_kwargs)
        return build_outputs(func, args, kwargs, proxy_out)
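

# Illustrative sketch, not part of the original file: how a DispatchTensor records an op
# into an fx.Graph when used in ordinary tensor code. Building a proxy by hand like this
# is an assumption made for demonstration; exact behavior depends on the torch version.
def _example_dispatch_tensor():
    graph = fx.Graph()
    x_proxy = fx.Proxy(graph.placeholder("x"))
    x = DispatchTensor(torch.randn(3), x_proxy)
    y = x + x  # hits __torch_dispatch__, which records the add and wraps the real result
    assert isinstance(y, DispatchTensor)
    return graph  # now holds the placeholder plus a call_function node for the add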


class DispatchTracer(Tracer):
    """
    Copied from the python key tracer in functorch
    https://github.com/pytorch/functorch/blob/b83273b25213f556f05a065163163ba531e24750/functorch/_src/python_key.py

    The difference is
    1. this tracer allows specifying leaf modules and preserves them as call_module nodes
    in the graph.
    """
    def __init__(self, leaf_module_list: Optional[Set[str]] = None):
        super().__init__()
        self.leaf_module_list = (leaf_module_list or set()).union(DEFAULT_LEAF_MODULE_LIST)

    def call_module(
        self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        if self.is_leaf_module(m):
            # Register the leaf module on the root under a fresh name and record an
            # opaque call_module node instead of tracing through its forward.
            i = 0
            while True:
                qualname = f"{type(m).__name__}_{i}"
                if not hasattr(self.root, qualname):
                    break
                i += 1
            setattr(self.root, qualname, m)
            proxy_args = pytree.tree_map(unwrap_proxy, args)
            proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
            proxy_out = self.create_proxy("call_module", qualname, proxy_args, proxy_kwargs)
            return build_outputs(forward, args, kwargs, proxy_out)
        return forward(*args, **kwargs)

    def is_leaf_module(self, m) -> bool:
        return torch.typename(m) in self.leaf_module_list

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        if isinstance(attr_val, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        proxy = self.create_proxy("get_attr", n, (), {})
                        parameter_proxy_cache[n] = DispatchTensor(attr_val, proxy)
                    return parameter_proxy_cache[n]
            return attr_val
        return attr_val

    def create_arg(self, a: Any):
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})
            # The parameter is not registered on the root: stash it as a new attribute
            # and reference it through a get_attr node.
            qualname: Optional[str] = None

            i = 0
            while True:
                qualname = f"_param_constant{i}"
                if not hasattr(self.root, qualname):
                    break
                i += 1
            setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})
        return super().create_arg(a)


def dispatch_trace(
    root: torch.nn.Module,
    leaf_module_list: Optional[Set[str]] = None,
    concrete_args=None,
) -> GraphModule:
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    tracer = DispatchTracer(leaf_module_list)
    graph = tracer.trace(root, concrete_args=concrete_args)
    gm = GraphModule(tracer.root, graph, name)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return NormalizeArgs(gm).transform()
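

# wrap_key bridges the concrete example inputs and the fx placeholders created during
# tracing: each placeholder proxy is paired with its real input as a DispatchTensor, so
# the wrapped function runs on concrete values while every tensor op is recorded, and
# the proxies of the outputs become the traced graph's outputs.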
def wrap_key(f, inps):
    flat_inps, inp_spec = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert len(flat_args) == len(flat_inps)
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = DispatchTensor(flat_inps[idx], arg)
            else:
                flat_args[idx] = flat_inps[idx]
        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], DispatchTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped


def make_fx(f, leaf_module_list: Optional[Set[str]] = None):
    @functools.wraps(f)
    def wrapped(*args):
        phs = pytree.tree_map(lambda x: fx.PH, args)
        t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs), leaf_module_list=leaf_module_list)
        return t

    return wrapped
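

if __name__ == "__main__":
    # Usage sketch, assuming a torch build where this __torch_dispatch__-based tracing
    # works. The Net module, its shapes, and the leaf_module_list entry below are
    # illustrative assumptions, not part of the original file.
    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x))

    # Without leaf_module_list the Linear would be traced down to individual ops; with it,
    # the call is preserved as a single call_module node (matched via torch.typename).
    traced = make_fx(Net(), leaf_module_list={"torch.nn.modules.linear.Linear"})(torch.randn(2, 4))
    print(traced.graph)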