
Commit c5a1e46

Deleted sample_inputs

1 parent: c6d3c1a

13 files changed: +26 / -59 lines
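
Net effect for callers: the lowering entry points no longer take (and thread through) sample inputs. A minimal sketch of the new call pattern, assuming pre_export_lowering and post_lowering are importable from torch_tensorrt.dynamo.lowering as in _compiler.py, with a hypothetical toy module and the decomposition step elided:

import torch
from torch_tensorrt.dynamo.lowering import post_lowering, pre_export_lowering

class Add(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, x, y):
        return x + y

ep = torch.export.export(Add(), (torch.randn(2), torch.randn(2)))
ep = pre_export_lowering(ep)  # previously: pre_export_lowering(ep, sample_inputs)
gm = ep.module()              # run_decompositions(...) would normally happen first
gm = post_lowering(gm)        # previously: post_lowering(gm, sample_inputs)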

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 4 additions & 4 deletions
@@ -182,14 +182,14 @@ def compile(
         raise AssertionError(
             f"Input graph should be an ExportedProgram but got type {type(exported_program)}"
         )
-    exported_program = pre_export_lowering(exported_program, None)
+    exported_program = pre_export_lowering(exported_program)
     exported_program = exported_program.run_decompositions(
         get_decompositions(enable_experimental_decompositions)
     )
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
     # Apply lowering on the graph module
-    gm = post_lowering(gm, None)
+    gm = post_lowering(gm)
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
     compilation_options = {
@@ -602,7 +602,7 @@ def convert_module_to_trt_engine(
         "timing_cache_path": timing_cache_path,
     }
 
-    exported_program = pre_export_lowering(exported_program, torch_inputs)
+    exported_program = pre_export_lowering(exported_program)
     # Decompose the exported program
     exported_program = exported_program.run_decompositions(
         get_decompositions(enable_experimental_decompositions)
@@ -611,7 +611,7 @@
     logger.debug("Input graph: " + str(gm.graph))
 
     # Apply lowering on the graph module
-    gm = post_lowering(gm, torch_inputs)
+    gm = post_lowering(gm)
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
     settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 4 additions & 8 deletions
@@ -86,25 +86,21 @@ def _remove_lowering_pass(*, index: int) -> None:
     return
 
 
-def post_lowering(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
+def post_lowering(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule"""
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
     )
-    return ATEN_POST_LOWERING_PASSES(gm, sample_inputs)
+    return ATEN_POST_LOWERING_PASSES(gm)
 
 
-def pre_export_lowering(
-    ep: torch.export.ExportedProgram, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
+def pre_export_lowering(ep: torch.export.ExportedProgram) -> torch.fx.GraphModule:
    """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule"""
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}"
     )
     gm = ep.graph_module
-    gm = ATEN_PRE_LOWERING_PASSES(gm, sample_inputs)
+    gm = ATEN_PRE_LOWERING_PASSES(gm)
     return ep
 
 

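With sample_inputs removed, a lowering pass is now a plain GraphModule-to-GraphModule callable. A minimal sketch of a pass conforming to the new signature (hypothetical, not part of this commit):

import logging

import torch

logger = logging.getLogger(__name__)


def count_placeholders(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Hypothetical example pass: inspects the graph and returns it unchanged."""
    n = sum(1 for node in gm.graph.nodes if node.op == "placeholder")
    logger.debug(f"Graph has {n} placeholder inputs")
    return gm
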
py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 2 additions & 4 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Sequence
+from typing import Any
 
 import torch
 from torch_tensorrt._utils import sanitized_torch_version
@@ -19,9 +19,7 @@
 
 
 @torch.utils._python_dispatch._disable_current_modes() # type: ignore
-def constant_fold(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
+def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """Adapted from:
     https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
 
py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py

Lines changed: 1 addition & 4 deletions
@@ -1,5 +1,4 @@
 import logging
-from typing import Sequence
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
@@ -10,9 +9,7 @@
 
 
 # TODO: Add relevant prims to this fusion
-def fuse_prims_broadcast(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
+def fuse_prims_broadcast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """Fuses prim nodes which are effectively the ATen equivalents with keep_dim=True"""
     modified_graph = False
 
py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py

Lines changed: 2 additions & 4 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Sequence, Tuple
+from typing import Callable, Tuple
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
@@ -9,9 +9,7 @@
 logger = logging.getLogger(__name__)
 
 
-def lower_linear(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
+def lower_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
     orig, replacement = linear_replacement()
 
py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 
 
 def lower_scaled_dot_product_attention(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+    gm: torch.fx.GraphModule,
 ) -> torch.fx.GraphModule:
     """Replace specific versions of scaled_dot_product_attention with an equivalent
     implementation which can be easily converted to TRT

py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py

Lines changed: 5 additions & 15 deletions
@@ -8,25 +8,15 @@ class DynamoPassManager(PassManager): # type: ignore[misc]
     def __init__(
         self,
         passes: Optional[
-            List[
-                Callable[
-                    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
-                ]
-            ]
+            List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]
         ] = None,
     ):
         super().__init__(passes)
 
     @classmethod
     def build_from_passlist(
         cls,
-        passes: Optional[
-            List[
-                Callable[
-                    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
-                ]
-            ]
-        ],
+        passes: Optional[List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]],
     ) -> Any:
         pm = DynamoPassManager(passes)
         return pm
@@ -47,11 +37,11 @@ def add_pass_with_index(
     def remove_pass_with_index(self, index: int) -> None:
         del self.passes[index]
 
-    def __call__(self, gm: Any, sample_inputs: Any) -> Any:
+    def __call__(self, gm: Any) -> Any:
         self.validate()
-        out, example_inputs = gm, sample_inputs
+        out = gm
         for _pass in self.passes:
-            out = _pass(out, example_inputs)
+            out = _pass(out)
         return out
 
     def __str__(self) -> str:
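
Correspondingly, DynamoPassManager.__call__ now threads only the GraphModule through the pass list. A usage sketch under the new signature, assuming the class is importable from its module path above and using a hypothetical no-op pass:

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_manager import DynamoPassManager


def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Hypothetical no-op pass matching Callable[[GraphModule], GraphModule]."""
    return gm


pm = DynamoPassManager.build_from_passlist([identity_pass])
gm = torch.fx.symbolic_trace(torch.nn.ReLU())
gm = pm(gm)  # previously: pm(gm, sample_inputs)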

py/torch_tensorrt/dynamo/lowering/passes/remove_detach.py

Lines changed: 1 addition & 4 deletions
@@ -1,14 +1,11 @@
 import logging
-from typing import Sequence
 
 import torch
 
 logger = logging.getLogger(__name__)
 
 
-def remove_detach(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
+def remove_detach(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """Remove detach ops in the graph"""
     count = 0
     for node in gm.graph.nodes:

py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py

Lines changed: 1 addition & 4 deletions
@@ -1,5 +1,4 @@
 import logging
-from typing import Sequence
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
@@ -10,9 +9,7 @@
 
 
 # TODO: Delete this lowering pass once aot_export_joint_simple is patched
-def remove_input_alias_fixing_clones(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
+def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """Remove the auxiliary clone nodes inserted to fix input aliasing
 
     See: https://github.com/pytorch/pytorch/issues/108079

py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py

Lines changed: 1 addition & 4 deletions
@@ -1,5 +1,4 @@
 import logging
-from typing import Sequence
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
@@ -10,9 +9,7 @@
 logger = logging.getLogger(__name__)
 
 
-def repair_input_as_output(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
+def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """Repair scenarios where inputs are also outputs of the graph
 
     TRT does not allow such cases, so we insert a clone (identity) layer

py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 import logging
 import operator
-from typing import Sequence
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
@@ -11,7 +10,7 @@
 
 
 def replace_max_pool_with_indices(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+    gm: torch.fx.GraphModule,
 ) -> torch.fx.GraphModule:
     """Replace MaxPool nodes which return unused indices"""
     replacement_dict = {

py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py

Lines changed: 2 additions & 4 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Sequence
+from typing import List
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
@@ -11,9 +11,7 @@
 logger = logging.getLogger(__name__)
 
 
-def view_to_reshape(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
-) -> torch.fx.GraphModule:
+def view_to_reshape(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
     orig_op = torch.ops.aten.view.default
     replacement_op = torch.ops.aten.reshape.default

tests/py/dynamo/models/test_models_export_kwargs.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def forward(self, x, b=5, c=None, d=None):
     cos_sim = cosine_similarity(model(*args, **kwargs), trt_mod(*args, **kwargs)[0])
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
-        msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
     # Clean up model env
