Skip to content

Commit 63005e0

Browse files
authored
feat: Improve layer naming (#2162)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 6814350 commit 63005e0

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

66
import numpy
7+
8+
# @manual=//deeplearning/trt/python:py_tensorrt
9+
import tensorrt as trt
710
import torch
811
import torch.fx
912
from torch.fx.node import _get_qualified_name
1013
from torch.fx.passes.shape_prop import TensorMetadata
1114
from torch_tensorrt._Input import Input
15+
from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
1216
from torch_tensorrt.fx.observer import Observer
1317
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
1418

15-
# @manual=//deeplearning/trt/python:py_tensorrt
16-
import tensorrt as trt
1719
from packaging import version
1820

1921
from .converter_registry import DYNAMO_CONVERTERS as CONVERTERS
@@ -232,7 +234,7 @@ def run(
232234
)
233235

234236
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
235-
self._cur_node_name = str(n)
237+
self._cur_node_name = get_node_name(n)
236238
self._cur_node = n
237239
# add "_itensor_to_tensor_meta"
238240
kwargs = dict(n.kwargs)

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,40 @@
1+
import logging
2+
import re
13
from typing import List
24

5+
import tensorrt as trt
36
import torch
47
from torch_tensorrt.fx.converters.converter_utils import (
58
Frameworks,
69
unified_dtype_converter,
710
)
811
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
912

10-
import tensorrt as trt
13+
_LOGGER: logging.Logger = logging.getLogger(__name__)


def get_node_name(node: torch.fx.Node) -> str:
    """Return a readable, hierarchical name for an FX graph node.

    The name is derived from ``node.meta["nn_module_stack"]``, which
    preserves the call stack of the PyTorch ``nn.Module``s that produced
    this node and therefore encodes where the op sits in the network
    architecture. When the stack is absent or empty, the plain ``str(node)``
    name is returned unchanged.

    Args:
        node: The FX node to name.

    Returns:
        A string such as ``"/conv1/add"`` (module path + node name), or
        just ``str(node)`` when no module-stack metadata is available.
    """
    stack_item = node.meta.get("nn_module_stack", None)
    # The current node's module is the LAST entry in the stack.
    # Peek at it via reversed() instead of popitem(): popitem() would
    # destructively remove the entry from node.meta, corrupting the
    # metadata for any later consumer (or a second call on this node).
    mod_stack = next(reversed(stack_item.items())) if stack_item else ""
    node_name = str(node)
    if mod_stack:
        mod_name = str(mod_stack[0]).replace("___", "/")
        # Clean up the module name: drop the "…__self" prefix that the
        # tracer prepends, and turn trailing "_<idx>" into "/<idx>" so
        # sequential-container indices read as path components.
        mod_name = re.sub("^.*__self", "", mod_name)
        mod_name = re.sub(r"_(\d+)$", r"/\g<1>", mod_name)
        node_name = mod_name + "/" + node_name
    else:
        # Try an alternative way to get the module info
        # like the node.meta['source_fn'] attr
        pass

    _LOGGER.debug(f"Node meta name {node_name}")
    return node_name
1138

1239

1340
def dynamic_unsupported(node: torch.fx.Node) -> bool:

0 commit comments

Comments (0)