Skip to content

Add fuse_dq_q_pass in exir/passes and also add it to HTP backend #2295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ python_library(
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
)

Expand Down
103 changes: 76 additions & 27 deletions exir/passes/remove_noop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,87 @@

# pyre-strict

from typing import List, Tuple

import torch
from executorch.exir.pass_base import ExportPass, ProxyValue
from torch.utils import _pytree as pytree
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule

_DEQUANT_OPS: Tuple[torch._ops.OpOverload] = (
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
)
_QUANT_OPS: Tuple[torch._ops.OpOverload] = (
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
)


def eliminate_dq_q(
graph_module: GraphModule,
dequant_nodes: List[torch.fx.Node],
) -> None:
for node in dequant_nodes:
assert node.target in _DEQUANT_OPS
for user in list(node.users):
if user.target in _QUANT_OPS:
# Drop the input arg and check that the qparams are the same.
qparams_dq = list(node.args)[1:]
qparams_q = list(user.args)[1:]
if qparams_dq != qparams_q:
continue
user.replace_all_uses_with(node.args[0])


class RemoveNoopPass(ExportPass):
"""
Removes noops that pass through arguments.
"""

# pyre-ignore
def call_operator(self, op, args, kwargs, meta):
if op not in (
torch.ops.aten.to.dtype,
torch.ops.aten.dropout.default,
torch.ops.aten.slice_copy.Tensor,
):
return super().call_operator(op, args, kwargs, meta)

args_data, kwargs_data = pytree.tree_map_only(
ProxyValue, lambda x: x.data, (args, kwargs)
)
orig_tensor = (
args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
)

if orig_tensor is op(*args_data, **kwargs_data):
return args[0]

if op == torch.ops.aten.slice_copy.Tensor:
result = op(*args_data, **kwargs_data)
if orig_tensor.size() == result.size():
return args[0]

return super().call_operator(op, args, kwargs, meta)
def call(self, graph_module: GraphModule) -> PassResult:

# In this list we'll collect all the dequant nodes that are inputs to ops that
# are removed in this pass and later check for redundant dq->q patterns and
# remove them.
dequant_nodes = []

for node in graph_module.graph.nodes:
if node.op != "call_function":
continue

if node.target not in (
torch.ops.aten.to.dtype,
torch.ops.aten.dropout.default,
torch.ops.aten.slice_copy.Tensor,
):
continue

orig_tensor = node.args[0].meta["val"]

if orig_tensor is node.meta["val"]:
# If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
# Otherwise, removing only the op will suffice.
if node.args[0].target in _DEQUANT_OPS:
dequant_nodes += [node.args[0]]
node.replace_all_uses_with(node.args[0])
continue

if node.target == torch.ops.aten.slice_copy.Tensor:
if orig_tensor.size() == node.meta["val"].size():
# If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
# Otherwise, removing only the op will suffice.
if node.args[0].target in _DEQUANT_OPS:
dequant_nodes += [node.args[0]]
node.replace_all_uses_with(node.args[0])

graph_module.graph.eliminate_dead_code()
eliminate_dq_q(graph_module, dequant_nodes)
graph_module.graph.lint()
graph_module.graph.eliminate_dead_code()

return PassResult(graph_module, True)
178 changes: 177 additions & 1 deletion exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Import passes
import executorch.exir.memory_planning # noqa
import torch
from executorch.exir import EdgeCompileConfig, memory, to_edge
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.emit import emit_program
Expand Down Expand Up @@ -50,6 +50,12 @@
from functorch.experimental import control_flow

from torch import nn

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export
from torch.fx import GraphModule, subgraph_rewriter
from torch.fx.experimental.proxy_tensor import make_fx
Expand Down Expand Up @@ -1244,3 +1250,173 @@ def forward(self, x):
# %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
# return (copy__default, aten_add_tensor)
self.assertEqual(count_copies(gm), 1)

def test_remove_quantized_op_noop_pass(self) -> None:
class TestAddSliceNoop(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
x = x + x
x = x + x[:]
return x

class TestAddSliceNotNoop(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
x = x + x
x = x + x[:1]
return x

def count_dq_nodes(gm: torch.fx.GraphModule) -> int:
return sum(
(
node.target
in (
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
)
)
for node in gm.graph.nodes
)

def count_q_nodes(gm: torch.fx.GraphModule) -> int:
return sum(
(
node.target
in (
torch.ops.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
)
)
for node in gm.graph.nodes
)

def quantize_model(
m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor]
) -> Tuple[EdgeProgramManager, int, int]:
# program capture
m = torch._export.capture_pre_autograd_graph(
m_eager,
example_inputs,
)

quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config()
quantizer.set_global(quantization_config)
m = prepare_pt2e(m, quantizer)
m = convert_pt2e(m, fold_quantize=True)
ep = torch.export.export(m, example_inputs)
dq_nodes_pre = count_dq_nodes(ep.graph_module)
q_nodes_pre = count_q_nodes(ep.graph_module)
edge = to_edge(
ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
)
return edge, dq_nodes_pre, q_nodes_pre

example_inputs = (torch.randn(9, 8),)
model = TestAddSliceNoop()
m_eager = model.eval()
edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
# One dq and one q node around the slice copy should have been removed.
self.assertEqual(dq_nodes_pre - dq_nodes_post, 1)
self.assertEqual(q_nodes_pre - q_nodes_post, 1)

# Check that the slice_copy is removed by the RemoveNoopPass.
for node in edge.exported_program().graph_module.graph.nodes:
self.assertFalse("slice" in str(node.target))

model = TestAddSliceNotNoop()
m_eager = model.eval()
edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
# One dq and one q node around the slice copy should have been removed.
self.assertEqual(dq_nodes_pre, dq_nodes_post)
self.assertEqual(q_nodes_pre, q_nodes_post)

# Check that the slice_copy is not removed by the RemoveNoopPass.
self.assertTrue(
any(
"slice" in str(node.target)
for node in edge.exported_program().graph_module.graph.nodes
)
)

def test_dq_q_no_op_pass(self) -> None:
class TestDqQ(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 1.0, 0, -128, 127, torch.int8
)
q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
dq, 1.0, 0, -128, 127, torch.int8
)
return q

model = TestDqQ()
m_eager = model.eval()
ep = torch.export.export(m_eager, (torch.randn(9, 8),))
edge = to_edge(ep)
# Check that the dq and q nodes are not touched by the RemoveNoopPass.
self.assertTrue(
any(
"dequantize" in str(node.target)
for node in edge.exported_program().graph_module.graph.nodes
)
)
self.assertTrue(
any(
"quantize" in str(node.target)
for node in edge.exported_program().graph_module.graph.nodes
)
)

def test_dq_q_different_qparams(self) -> None:
class TestDqQDifferentQParam(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
x, 1.0, 0, -128, 127, torch.int8
)
slice_copy_output = torch.ops.aten.slice_copy.Tensor(dq, 0, 0)
q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
slice_copy_output, 1.0, 0, -127, 127, torch.int8
)
return q

model = TestDqQDifferentQParam()
m_eager = model.eval()
ep = torch.export.export(m_eager, (torch.randn(9, 8),))
edge = to_edge(ep)
print(edge.exported_program().graph_module.graph)
# Check that the dq and q nodes are not touched by the RemoveNoopPass.
self.assertTrue(
any(
"dequantize" in str(node.target)
for node in edge.exported_program().graph_module.graph.nodes
)
)
self.assertTrue(
any(
"quantize" in str(node.target)
for node in edge.exported_program().graph_module.graph.nodes
)
)
self.assertFalse(
any(
"slice" in str(node.target)
for node in edge.exported_program().graph_module.graph.nodes
)
)