
Commit 6a7f4db

Shirong Wu authored and Wei Wei committed
First step of refactor lower passes (#74219)
Summary:
X-link: pytorch/pytorch#74219
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/18

This is the first diff in a series that refactors and cleans up the lowering process and pass management.

Reviewed By: yinghai

Differential Revision: D34764123

fbshipit-source-id: 57fa930abefae56654219225167b52d7dd79e03a
1 parent 6110041 commit 6a7f4db

File tree

7 files changed: +132 -33 lines changed

fx/lower.py

Lines changed: 16 additions & 25 deletions
@@ -21,10 +21,8 @@
 from .input_tensor_spec import (
     InputTensorSpec,
 )
-from .passes.fuse_pass import (
-    fuse_permute_linear,
-    fuse_permute_matmul,
-)
+from .passes.pass_utils import chain_passes, PassFunc
+from .passes.lower_basic_pass import fuse_permute_matmul, fuse_permute_linear
 from .passes.remove_duplicate_output_args import (
     remove_duplicate_output_args,
 )
@@ -74,9 +72,6 @@ class PassContext(NamedTuple):
     lower_setting: "LowerSetting"
     module_name: str = ""
 
-# Function signature for a graph module pass
-PassFunc = Callable[[nn.Module, PassContext], Tuple[nn.Module, PassContext]]
-
 
 def lower_to_trt(
     module: nn.Module,
@@ -85,7 +80,6 @@ def lower_to_trt(
     max_workspace_size=1 << 25,
     explicit_batch_dimension=False,
     fp16_mode=True,
-    enable_fuse=True,
     verbose_log=False,
     timing_cache_prefix="",
     save_timing_cache=False,
@@ -102,8 +96,6 @@ def lower_to_trt(
         max_workspace_size: Maximum size of workspace given to TensorRT.
         explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
         fp16_mode: fp16 config given to TRTModule.
-        enable_fuse: Enable pass fusion during lowering if set to true. l=Lowering will try to find pattern defined
-        in fx2trt_oss.fx.passes from original module, and replace with optimized pass before apply lowering.
         verbose_log: Enable verbose log for TensorRT if set True.
         timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
         save_timing_cache: Update timing cache with current timing cache data if set to True.
@@ -117,7 +109,6 @@ def lower_to_trt(
         max_workspace_size=max_workspace_size,
         explicit_batch_dimension=explicit_batch_dimension,
         fp16_mode=fp16_mode,
-        enable_fuse=enable_fuse,
         verbose_log=verbose_log,
         timing_cache_prefix=timing_cache_prefix,
         save_timing_cache=save_timing_cache,
@@ -153,14 +144,12 @@ class LowerSetting:
     strict_type_constraints: Require TensorRT engine to strictly follow data type
     setting at execution time.
 
-    enable_fuse: Enable pass fuse duirng lowering, i.e. fuse multiple operations
-    as (a->b->c->d)=>(e). Current available fuse source patterns are:
-    sparse->matmul->add
+    customized_fuse_pass: List of customized passes to apply during the lowering process.
+
+    lower_basic_fuse_pass: Enable basic pass fusion during lowering, i.e. fuse multiple operations
+    as (a->b->c->d)=>(e). Current basic fuse patterns are:
     permute->linear
     permute->matmul
-    unsqueeze->cat->sum
-
-    enable_fuse_for_sparsity: Enable pass fuse for sparsity.
 
     verbose_log: Enable TensorRT engine verbose log mode.
@@ -191,8 +180,8 @@ class LowerSetting:
     int8_mode: bool = False
     max_workspace_size: int = 1 << 30
     strict_type_constraints: bool = False
-    enable_fuse: bool = True
-    enable_fuse_for_sparsity = False
+    customized_fuse_pass: Sequence = ()
+    lower_basic_fuse_pass: Sequence = (fuse_permute_matmul, fuse_permute_linear)
     verbose_log: bool = False
     algo_selector = None
     timing_cache_prefix: str = ""
@@ -249,10 +238,10 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             if self.lower_setting.input_specs
             else InputTensorSpec.from_tensors(input)
         )
-        if self.lower_setting.enable_fuse:
-            mod = fuse_permute_matmul(mod)
-            mod = fuse_permute_linear(mod)
-            FUSE_PASSES_POST_OBSERVER.observe(mod, input)
+
+        if self.lower_setting.lower_basic_fuse_pass:
+            lower_pass = chain_passes(*self.lower_setting.lower_basic_fuse_pass)
+            lower_pass(mod, input)
 
         # Prepare algorithm selector and timing_cache for TRTInterpreter
         algo_selector = None
@@ -363,14 +352,16 @@ def __call__(
         inputs = tuple(x.half() if x.dtype == torch.float32 else x for x in inputs)
 
         # Ensure ast_rewrite is done for input module before const_fold.
-        traced_mod = self.trace_func(module, inputs)  # type: ignore[misc]
+        tracer = chain_passes(self.trace_func, *self.lower_setting.customized_fuse_pass)
+        traced_mod = tracer(module, inputs)  # type: ignore[misc]
 
         # Run const folding.
         traced_mod = run_const_fold(traced_mod)
 
         # Retrace here to eliminate no-op introduced by const folding and map new introduced
         # nodes to acc op nodes.
-        traced_mod = self.trace_func(traced_mod, inputs)  # type: ignore[misc]
+        traced_mod = tracer(traced_mod, inputs)  # type: ignore[misc]
+        FUSE_PASSES_POST_OBSERVER.observe(traced_mod, inputs)
 
         # Run split.
         split_result = self.split_func(traced_mod, inputs, self.lower_setting)  # type: ignore[misc,operator]
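To illustrate the new configuration surface, here is a minimal usage sketch (not part of this diff). It assumes fx2trt_oss is importable with a working CUDA/TensorRT environment; my_extra_pass is a hypothetical placeholder.

import torch
from torch import fx
from fx2trt_oss.fx import lower

def my_extra_pass(module: fx.GraphModule, input) -> fx.GraphModule:
    # A no-op pass that only illustrates the new (module, input) -> module
    # signature that chain_passes expects of customized passes.
    return module

# The old boolean enable_fuse flag is gone: the basic fusions
# (fuse_permute_matmul, fuse_permute_linear) now run by default via
# lower_basic_fuse_pass, and extra passes go in customized_fuse_pass.
setting = lower.LowerSetting(customized_fuse_pass=(my_extra_pass,))
lowerer = lower.Lowerer.create(lower_setting=setting)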

fx/passes/fuse_pass.py renamed to fx/passes/lower_basic_pass.py

Lines changed: 14 additions & 4 deletions
@@ -7,9 +7,16 @@
     get_attr,
 )
 from fx2trt_oss.fx.observer import observable
+from fx2trt_oss.fx.passes.pass_utils import log_before_after, validate_inference
+from typing import Any
 
+# Create an alias for module input type to avoid littering pyre-ignore for Any
+# throughout the file.
+Input = Any
 
-def fuse_sparse_matmul_add(gm: torch.fx.GraphModule):
+@log_before_after
+@validate_inference(atol=1e-3, rtol=1e-2)
+def fuse_sparse_matmul_add(gm: torch.fx.GraphModule, input: Input):
     """
     Replace acc_ops.matmul + acc_ops.add with acc_ops.linear
     TRT8.2 can take advantage of structured sparsity (2:4), but the graph needs contain a single FC layer.
@@ -100,7 +107,9 @@ def check_permute(node: torch.fx.Node):
 
 
 @observable()
-def fuse_permute_linear(gm: torch.fx.GraphModule):
+@log_before_after
+@validate_inference(atol=1e-3, rtol=1e-2)
+def fuse_permute_linear(gm: torch.fx.GraphModule, input: Input):
     """
     Fuse pattern like permute + linear if permute is transposing the last two dimension.
     """
@@ -122,7 +131,9 @@ def fuse_permute_linear(gm: torch.fx.GraphModule):
 
 
 @observable()
-def fuse_permute_matmul(gm: torch.fx.GraphModule):
+@log_before_after
+@validate_inference(atol=1e-3, rtol=1e-2)
+def fuse_permute_matmul(gm: torch.fx.GraphModule, input: Input):
     """
     Fuse pattern like permute + matmul if permute is transposing the last two dimension.
     """
@@ -150,7 +161,6 @@ def fuse_permute_matmul(gm: torch.fx.GraphModule):
     gm.recompile()
     return gm
 
-
 try:
     # @manual=//deeplearning/trt/python:py_tensorrt
     import tensorrt as trt
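The decorators change the pass signature: every basic pass now receives the sample input alongside the module, so it can be logged and validated by running inference before and after. A sketch of a custom pass written against this new shape; remove_dropout is a hypothetical example, not part of this diff, and assumes the module is in eval mode so outputs are unchanged.

import torch
from torch import fx
from fx2trt_oss.fx.passes.pass_utils import log_before_after, validate_inference

@log_before_after
@validate_inference(atol=1e-3, rtol=1e-2)
def remove_dropout(gm: fx.GraphModule, input) -> fx.GraphModule:
    # Replace eval-mode dropout calls with their input tensor;
    # validate_inference then asserts outputs still match within tolerance.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.nn.functional.dropout:
            node.replace_all_uses_with(node.args[0])
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm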

fx/passes/pass_utils.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+from typing import List, Any, Callable
+from torch import fx
+import logging
+import torch
+import tempfile
+from functools import wraps
+from torch.fx.passes.shape_prop import ShapeProp
+
+# Create an alias for module input type to avoid littering pyre-ignore for Any
+# throughout the file.
+Input = Any
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+PassFunc = Callable[[fx.GraphModule, Input], fx.GraphModule]
+
+def chain_passes(*passes: PassFunc) -> PassFunc:
+    """
+    Chains a sequence of pass functions to form a single pass function
+    """
+
+    def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
+        for pass_ in passes:
+            if isinstance(module, torch.fx.GraphModule):
+                ShapeProp(module).propagate(*input)
+            module = pass_(module, input)
+        return module
+
+    return parent_pass
+
+
+def validate_inference(rtol=None, atol=None):
+    def _validate_inference(pass_: PassFunc) -> PassFunc:
+        """
+        Wraps a pass function to validate that its inference results before and
+        after the pass run should be `allclose`.
+        """
+
+        @wraps(pass_)
+        def pass_with_validation(
+            module: fx.GraphModule, input: Input
+        ) -> fx.GraphModule:
+            res0 = module(*input)
+            module = pass_(module, input)
+            res1 = module(*input)
+
+            tensor_res_0 = _collect_tensors(res0)
+            tensor_res_1 = _collect_tensors(res1)
+
+            for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
+                kwargs = {}
+                if rtol:
+                    kwargs["rtol"] = rtol
+                if atol:
+                    kwargs["atol"] = atol
+                assert torch.allclose(
+                    x, y, **kwargs
+                ), f"pass {pass_} failed correctness check due to output {kk}"
+            return module
+
+        return pass_with_validation
+
+    return _validate_inference
+
+
+def log_before_after(pass_: PassFunc) -> PassFunc:
+    """
+    Wraps a pass function to log the module graph before and after the pass
+    """
+
+    @wraps(pass_)
+    def pass_with_before_after_log(
+        module: fx.GraphModule, input: Input
+    ) -> fx.GraphModule:
+        with tempfile.NamedTemporaryFile(
+            mode="w",
+            encoding="utf-8",
+            delete=False,
+        ) as f:
+            print(f"== Log pass {pass_} before/after graph to {f.name}")
+            print(f"[{pass_}] Before:\n{module.graph}", file=f)
+            module = pass_(module, input)
+            print(f"[{pass_}] After:\n{module.graph}", file=f)
+        return module
+
+    return pass_with_before_after_log
+
+
+def _collect_tensors(arg: fx.node.Argument) -> List[torch.Tensor]:
+    """Collects all the tensors found in a nested container object"""
+    res: List[torch.Tensor] = []
+
+    def collect(x: fx.node.Argument) -> fx.node.Argument:
+        if isinstance(x, torch.Tensor):
+            res.append(x)
+        return x
+
+    fx.node.map_aggregate(arg, collect)
+    return res
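A standalone usage sketch for these utilities; the two passes and the traced module are toy examples, not part of this diff. chain_passes re-runs ShapeProp before each pass so shape metadata stays fresh, and validate_inference compares outputs before and after the wrapped pass.

import torch
from torch import fx, nn
from fx2trt_oss.fx.passes.pass_utils import chain_passes, validate_inference

@validate_inference(atol=1e-3, rtol=1e-2)
def noop_pass(gm: fx.GraphModule, input) -> fx.GraphModule:
    # Does nothing; exists only to show the PassFunc shape.
    return gm

def recompile_pass(gm: fx.GraphModule, input) -> fx.GraphModule:
    gm.recompile()  # regenerate forward() from the (unchanged) graph
    return gm

mod = fx.symbolic_trace(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
pipeline = chain_passes(noop_pass, recompile_pass)
mod = pipeline(mod, (torch.randn(2, 4),))  # input is unpacked into ShapeProp.propagate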

test/passes/test_fuse_permute_linear_trt.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import torch
 import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_fx2trt import AccTestCase
-from fx2trt_oss.fx.passes.fuse_pass import (
+from fx2trt_oss.fx.passes.lower_basic_pass import (
     fuse_permute_linear,
     trt_transposed_linear,
 )

test/passes/test_fuse_permute_matmul_trt.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_fx2trt import AccTestCase
 from parameterized import parameterized, param
-from fx2trt_oss.fx.passes.fuse_pass import (
+from fx2trt_oss.fx.passes.lower_basic_pass import (
     fuse_permute_matmul,
     trt_transposed_matmul,
 )

test/passes/test_multi_fuse_trt.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import fx2trt_oss.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_fx2trt import AccTestCase
 from parameterized import parameterized
-from fx2trt_oss.fx.passes.fuse_pass import (
+from fx2trt_oss.fx.passes.lower_basic_pass import (
     fuse_permute_linear,
     trt_transposed_linear,
     fuse_permute_matmul,

test/trt_lower/test_observer_gpu.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def forward(self, x, y):
 
         with execution_verifier() as verify_execution:
 
-            lowerer = lower.Lowerer.create(lower_setting=lower.LowerSetting(enable_fuse=True))
+            lowerer = lower.Lowerer.create(lower_setting=lower.LowerSetting())
             # Update `lowerer.split_func` to make sure the test model is split
             # onto the trt partition:
             lowerer = replace(
