Skip to content

Commit af53282

Browse files
committed
fix: Address review comments
- Fix typing issues, add depedencies to `setup.py`, add qualified name checking for module registry
1 parent 60df50e commit af53282

File tree

4 files changed

+9
-5
lines changed

4 files changed

+9
-5
lines changed

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ commands:
258258
name: Set up python environment
259259
command: |
260260
pip3 install --upgrade pip
261-
pip3 install wheel setuptools pyyaml
261+
pip3 install wheel setuptools
262262
pip3 install nvidia-pyindex
263263
pip3 install tabulate
264264
pip3 install tensorrt==<< parameters.trt-version-long >> nvidia-cudnn-cu11==<< parameters.cudnn-version-long >>

py/setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,8 @@ def run(self):
427427
ext_modules=ext_modules,
428428
install_requires=[
429429
"torch >=2.1.dev,<2.2" if not LEGACY else "torch >=1.13.0,<2.0",
430+
"pyyaml",
431+
"packaging",
430432
],
431433
setup_requires=[],
432434
cmdclass={

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
logger = logging.getLogger(__name__)
1717

1818
DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
19-
"torch.ops." + str(module.new_operator)
19+
_get_qualified_name(module.new_operator)
2020
for module in MODULE_SUBSTITUTION_REGISTRY.values()
2121
)
2222

py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Any, Callable, Dict
2+
from typing import Any, Callable, Dict, Type
33
import torch
44
import logging
55

@@ -23,11 +23,11 @@ class ModuleReplacement:
2323

2424

2525
# Dictionary mapping module to ModuleReplacement instance
26-
MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict()
26+
MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict()
2727

2828

2929
def module_substitution(
30-
module_to_replace: torch.nn.Module,
30+
module_to_replace: Type[torch.nn.Module],
3131
new_operator: torch._ops.OpOverload,
3232
enabled: bool = True,
3333
) -> Callable[[Any], Any]:
@@ -102,6 +102,7 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
102102
# Replace all original node uses and clean up graph
103103
n.replace_all_uses_with(new_node)
104104
gm.graph.eliminate_dead_code()
105+
gm.graph.lint()
105106
gm.recompile()
106107

107108
# A module replacement can fail in the event that the specific instance of the submodule cannot
@@ -115,5 +116,6 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
115116

116117
# Perform cleanup and recompilation before returning module
117118
gm.graph.eliminate_dead_code()
119+
gm.graph.lint()
118120
gm.recompile()
119121
return gm

0 commit comments

Comments
 (0)