Commit fc4177e

jfix71 authored and Wei Wei committed

[acc_tracer] Remove tensor_meta after NormalizeArgs (#16)
Summary: Pull Request resolved: pytorch/fx2trt#16

Reviewed By: yinghai, wushirong

Differential Revision: D34772842

fbshipit-source-id: fd2577e54b1e5f563df2e1052773ef7b19069abe

1 parent: 18f74b5

File tree: 4 files changed (+52, -17 lines)


test/tracer/test_acc_tracer.py

Lines changed: 5 additions & 5 deletions

@@ -1989,11 +1989,11 @@ def forward(self, a: List[torch.Tensor]) -> torch.Tensor:
             else:
                 self.fail(f"Unexpected node: {node.format_node()}")

-        # Check the tensor metadatas are correct given the input is a list.
-        self.assertTrue(isinstance(ph.meta["tensor_meta"], list))
-        self.assertEqual(len(ph.meta["tensor_meta"]), 2)
-        self.assertEqual(getitem_0.meta["tensor_meta"], ph.meta["tensor_meta"][0])
-        self.assertEqual(getitem_1.meta["tensor_meta"], ph.meta["tensor_meta"][1])
+        # Check the tensor ranks are correct given the input is a list.
+        self.assertTrue(isinstance(ph.meta["tensor_rank"], list))
+        self.assertEqual(len(ph.meta["tensor_rank"]), 2)
+        self.assertEqual(getitem_0.meta["tensor_rank"], ph.meta["tensor_rank"][0])
+        self.assertEqual(getitem_1.meta["tensor_rank"], ph.meta["tensor_rank"][1])

         self.assertTrue(torch.equal(m(input), traced(input)))
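Note: the asserted layout mirrors the nesting of the input — a placeholder fed
a list of two tensors carries a two-element list of ranks. A hedged sketch of
what the test checks; TwoTensorList and the input shapes here are hypothetical
stand-ins, not the test's actual fixture:

    import torch
    import torch.nn as nn
    import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer

    class TwoTensorList(nn.Module):  # hypothetical stand-in for the test module
        def forward(self, a):  # a: List[torch.Tensor]
            return a[0] + a[1]

    input = [torch.randn(2, 3), torch.randn(2, 3)]
    traced = acc_tracer.trace(TwoTensorList(), [input])

    ph = next(n for n in traced.graph.nodes if n.op == "placeholder")
    # The placeholder carries one rank entry per tensor in the input list.
    assert isinstance(ph.meta["tensor_rank"], list)
    assert ph.meta["tensor_rank"] == [2, 2]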

tracer/acc_tracer/acc_ops.py

Lines changed: 3 additions & 4 deletions

@@ -320,8 +320,7 @@ def cat(*, tensors, dim):
 )
 def transpose_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
     # Get the dim-permutation/shuffle
-    shape_as_list = node.meta["tensor_meta"].shape
-    ranks = len(shape_as_list)
+    ranks = node.meta["tensor_rank"]
     shuffle = list(range(ranks))
     dim0 = cast(int, node.kwargs["dim0"])
     dim1 = cast(int, node.kwargs["dim1"])
@@ -422,7 +421,7 @@ def addmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
     ],
 )
 def t_mapper(node: torch.fx.Node, _: nn.Module):
-    ranks = len(node.meta["tensor_meta"].shape)
+    ranks = node.meta["tensor_rank"]
     shuffle = [1, 0] if (ranks > 1) else [0]

     with node.graph.inserting_before(node):
@@ -1842,7 +1841,7 @@ def packed_quantized_linear_mapper(
     }

     new_node = node.graph.call_function(quantized_linear, kwargs=kwargs)
-    new_node.meta = node.meta
+    new_node.meta = node.meta.copy()
     return new_node
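Note on the last hunk: `new_node.meta = node.meta` made the old and new node
share one dict, and _replace_tensor_meta_with_rank (added in this commit)
mutates meta in place, so deleting tensor_meta through one node would have
leaked into the other. A minimal standalone sketch of the aliasing hazard,
using plain dicts rather than real fx nodes:

    # What `new_node.meta = node.meta` did vs. what `.copy()` does.
    old_meta = {"tensor_meta": object(), "type": int}
    aliased = old_meta           # shared reference
    copied = old_meta.copy()     # independent shallow copy

    del old_meta["tensor_meta"]  # e.g. _replace_tensor_meta_with_rank's delete
    assert "tensor_meta" not in aliased  # the alias lost the key too
    assert "tensor_meta" in copied       # the copy is unaffected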

tracer/acc_tracer/acc_tracer.py

Lines changed: 24 additions & 4 deletions

@@ -10,6 +10,7 @@
 import fx2trt_oss.tracer.acc_tracer.acc_normalizer as acc_normalizer
 import fx2trt_oss.tracer.acc_tracer.acc_ops  # noqa: F401
+import fx2trt_oss.tracer.acc_tracer.acc_utils as acc_utils
 import torch
 import torch.jit as jit
 import torch.nn as nn
@@ -249,7 +250,11 @@ def trace(
 }


-def _rewrite(mod_to_rewrite: nn.Module, allow_list: Optional[Set] = None, leaf_module_list: Optional[Set] = None) -> nn.Module:
+def _rewrite(
+    mod_to_rewrite: nn.Module,
+    allow_list: Optional[Set] = None,
+    leaf_module_list: Optional[Set] = None,
+) -> nn.Module:
     if allow_list is None:
         allow_list = DEFAULT_REWRITE_ALLOW_LIST
     else:
@@ -270,7 +275,7 @@ def rewrite_module(m: nn.Module):
         return m

     # If m is an already-rewritten RewrittenModule, then use the original base class.
-    base_class : Type[nn.Module] = getattr(m, "_base_class_origin", type(m))
+    base_class: Type[nn.Module] = getattr(m, "_base_class_origin", type(m))

     # Keep track of all the ConditionalExceptionWrappers that the
     # Acc_Rewriter calls into in this module so we can add them in init
@@ -290,7 +295,9 @@ class RewrittenModule(base_class):  # type: ignore[valid-type, misc]
     for method_name in dir(base_class):
         method = getattr(base_class, method_name, None)
         if method is None and method_name not in {"__doc__"}:
-            _LOGGER.warning(f"{__qualname__} does not have attribute {method_name}")
+            _LOGGER.warning(
+                f"{__qualname__} does not have attribute {method_name}"
+            )

         if builtins.type(method) is not FunctionType:
             continue
@@ -368,6 +375,15 @@ def _remove_exceptions(gm: torch.fx.GraphModule) -> bool:
     return changed


+def _replace_tensor_meta_with_rank(gm: torch.fx.GraphModule):
+    for node in gm.graph.nodes:
+        if node.op != "output" and "tensor_meta" in node.meta:
+            node.meta["tensor_rank"] = acc_utils.map_tensor_metadata(
+                node.meta["tensor_meta"], lambda x: len(x.shape)
+            )
+            del node.meta["tensor_meta"]
+
+
 def trace(
     mod: nn.Module,
     sample_inputs: Sequence[Any],
@@ -450,9 +466,13 @@ def trace(
     # nodes after removing assertions and exceptions.
     traced.graph.eliminate_dead_code()

+    # Run shape prop to add node.meta["type"] to nodes, needed for NormalizeArgs.
+    shape_prop.ShapeProp(traced).propagate(*sample_inputs)
+    # Swap out tensor_meta for tensor_rank, because we don't actually want to rely on
+    # tensor_meta yet for normalization/lowering, though rank shouldn't change.
+    _replace_tensor_meta_with_rank(traced)
     # Now normalize args/kwargs to make default values visible. Leave args/kwargs as
     # they were, since all-kwarg normalization is broken, and we don't need it anyway.
-    shape_prop.ShapeProp(traced).propagate(*sample_inputs)
     traced = NormalizeArgs(traced, normalize_to_only_use_kwargs=False).transform()

     # Normalize to acc-specialized wrappers for consistency across op naming and
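Note: the net effect on trace() is that ShapeProp now runs before NormalizeArgs
rather than after it, and tensor_meta is immediately reduced to tensor_rank. A
self-contained sketch of that ordering using plain torch.fx, covering only the
single-tensor case (the real pass also handles nested list/tuple meta via
acc_utils.map_tensor_metadata):

    import torch
    import torch.nn as nn
    import torch.fx
    from torch.fx.passes import shape_prop

    class Transpose(nn.Module):  # hypothetical example module
        def forward(self, x):
            return x.transpose(0, 1)

    traced = torch.fx.symbolic_trace(Transpose())
    # 1) Shape prop annotates node.meta["tensor_meta"] and node.meta["type"].
    shape_prop.ShapeProp(traced).propagate(torch.randn(2, 3))
    # 2) Keep only the rank; later normalization must not depend on full shapes.
    for node in traced.graph.nodes:
        if node.op != "output" and "tensor_meta" in node.meta:
            node.meta["tensor_rank"] = len(node.meta["tensor_meta"].shape)
            del node.meta["tensor_meta"]

    print([(n.name, n.meta.get("tensor_rank")) for n in traced.graph.nodes])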

tracer/acc_tracer/acc_utils.py

Lines changed: 20 additions & 4 deletions

@@ -1,17 +1,18 @@
 import inspect
 import json
 import os
-from typing import Any, Tuple, Callable, Union, Dict, List, Optional
 import re
+from typing import Any, Tuple, Callable, Union, Dict, List, Optional

 import torch
 import torch.fx
-from torch.fx.passes.graph_manipulation import (
-    serialize_module,
-)
 from torch.fx.graph_module import GraphModule
+from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes import graph_drawer
+from torch.fx.passes.graph_manipulation import (
+    serialize_module,
+)
 from torch.fx.passes.shape_prop import TensorMetadata
@@ -173,3 +174,18 @@ def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str)
         name = f"{base}_{int(num) + 1}"

     return name
+
+
+def map_tensor_metadata(a: Any, fn: Callable):
+    """
+    Map some `fn` to `a`, where `a` is either a TensorMetadata, or else a tuple/list
+    recursively containing TensorMetadata.
+    """
+    if isinstance(a, TensorMetadata):
+        return fn(a)
+    elif isinstance(a, tuple):
+        return tuple(map_tensor_metadata(elem, fn) for elem in a)
+    assert isinstance(
+        a, list
+    ), f"Only supporting tuple/list/TensorMetadata, but found {type(a)}"
+    return immutable_list(map_tensor_metadata(elem, fn) for elem in a)
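Note: map_tensor_metadata recurses through tuples and lists and applies fn at
each TensorMetadata leaf, so a nested input maps to a same-shaped structure of
results. A hedged usage sketch; it builds TensorMetadata via torch.fx's private
_extract_tensor_metadata helper, which is an assumption about the torch version:

    import torch
    from torch.fx.passes.shape_prop import _extract_tensor_metadata
    from fx2trt_oss.tracer.acc_tracer.acc_utils import map_tensor_metadata

    tm_2d = _extract_tensor_metadata(torch.randn(2, 3))
    tm_3d = _extract_tensor_metadata(torch.randn(4, 5, 6))

    # Tuples map to tuples, lists map to fx immutable_lists, leaves map to fn(leaf).
    ranks = map_tensor_metadata([tm_2d, (tm_3d,)], lambda x: len(x.shape))
    print(ranks)  # -> [2, (3,)]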
