-
Notifications
You must be signed in to change notification settings - Fork 608
Add div decomposition in ArmQuantizer #5267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright 2023-2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
import serializer.tosa_serializer as ts | ||
import torch | ||
from executorch.backends.arm.operators.node_visitor import ( | ||
NodeVisitor, | ||
register_node_visitor, | ||
) | ||
from executorch.backends.arm.tosa_mapping import TosaArg | ||
from executorch.backends.arm.tosa_quant_utils import ( | ||
dequantize_value, | ||
get_quant_node_args, | ||
QuantArgs, | ||
quantize_value, | ||
) | ||
from serializer.tosa_serializer import TosaOp | ||
|
||
|
||
@register_node_visitor | ||
class DivVisitor(NodeVisitor): | ||
target = "aten.reciprocal.default" | ||
|
||
def __init__(self, *args): | ||
super().__init__(*args) | ||
|
||
def define_node( | ||
self, | ||
node: torch.fx.Node, | ||
tosa_graph: ts.TosaSerializer, | ||
inputs: List[TosaArg], | ||
output: TosaArg, | ||
is_quant_node: bool, | ||
) -> None: | ||
# 1/X | ||
|
||
if is_quant_node: | ||
input = inputs[0] | ||
input_qargs = get_quant_node_args(node.all_input_nodes[0]) | ||
output_qargs = get_quant_node_args(list(node.users)[0]) | ||
|
||
div_table = div_table_8bit(input_qargs, output_qargs) | ||
|
||
table_attr = ts.TosaSerializerAttribute() | ||
table_attr.TableAttribute(div_table) | ||
tosa_graph.addOperator( | ||
TosaOp.Op().TABLE, [input.name], [output.name], table_attr | ||
) | ||
|
||
else: | ||
tosa_graph.addOperator( | ||
TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name] | ||
) | ||
|
||
|
||
def div_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs): | ||
""" | ||
Returns a table mapping 256 entries to div([qmin,qmax]) | ||
""" | ||
|
||
def div(x): | ||
# Convert quantized input to floating point div input space. | ||
v1 = dequantize_value(x, in_quantargs) | ||
# Compute div. | ||
v2 = 1.0 / v1 | ||
# Convert div output back to quantized space. | ||
v3 = quantize_value(v2, out_quantargs) | ||
|
||
return v3 | ||
|
||
return [ | ||
div(x) | ||
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8) | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from executorch.exir.pass_base import ExportPass | ||
|
||
|
||
def get_div_decomposition(op) -> tuple: | ||
""" | ||
Returns the the (reciprocal_op, mul_op), where the ops depends on if | ||
the div op is in exir_ops torch.ops.aten. | ||
""" | ||
if op == exir_ops.edge.aten.div.Tensor: | ||
return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor) | ||
if op == torch.ops.aten.div.Tensor: | ||
return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor) | ||
raise RuntimeError(f"Can't get div decomposition for op {op}") | ||
|
||
|
||
class DecomposeDivPass(ExportPass): | ||
""" | ||
This pass decomposes div into a mul and a reciprocal node. | ||
|
||
Example: | ||
y = div(a,b) | ||
Becomes: | ||
x = reciprocal(b) | ||
y = mul(a,x) | ||
""" | ||
|
||
def call_operator(self, op, args, kwargs, meta): | ||
if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor): | ||
return super().call_operator(op, args, kwargs, meta) | ||
|
||
reciprocal_op, mul_op = get_div_decomposition(op) | ||
|
||
numerator = args[0] | ||
denominator = args[1] | ||
reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta) | ||
|
||
return super().call_operator(mul_op, (numerator, reciprocal), {}, meta) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import cast, Union | ||
|
||
import torch | ||
from executorch.backends.arm.tosa_mapping import extract_tensor_meta | ||
|
||
from executorch.exir.pass_base import ExportPass, PassResult | ||
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix | ||
from torch.fx import GraphModule, Node | ||
|
||
|
||
class ScalarsToAttributePass(ExportPass): | ||
""" | ||
For ops in 'targeted_ops', convert inputs that are scalar values | ||
to attribute Nodes that output the same value. | ||
""" | ||
|
||
targeted_ops = [ | ||
torch.ops.aten.add.Tensor, | ||
torch.ops.aten.sub.Tensor, | ||
torch.ops.aten.sub_.Tensor, | ||
torch.ops.aten.mul.Tensor, | ||
torch.ops.aten.div.Tensor, | ||
] | ||
|
||
def call(self, graph_module: GraphModule) -> PassResult: | ||
for n in graph_module.graph.nodes: | ||
n = cast(Node, n) | ||
if n.op != "call_function" or n.target not in self.targeted_ops: | ||
continue | ||
|
||
biggest_rank = 1 | ||
for arg in n.args: | ||
if isinstance(arg, Node): | ||
_, shape, _ = extract_tensor_meta(arg.meta) | ||
biggest_rank = max(biggest_rank, len(shape)) | ||
|
||
new_args = [] | ||
for arg in n.args: | ||
if isinstance(arg, Node): | ||
new_args.append(arg) | ||
continue | ||
|
||
prefix = "_tensor_constant_" | ||
get_new_attr_name = get_new_attr_name_with_prefix(prefix) | ||
tensor_constant_name = get_new_attr_name(graph_module) | ||
float_tensor = torch.tensor( | ||
float(cast(Union[int, float], arg)) | ||
).reshape((1,) * biggest_rank) | ||
graph_module.register_buffer(tensor_constant_name, float_tensor) | ||
fake_mode = n.meta["val"].fake_mode | ||
|
||
with graph_module.graph.inserting_before(n): | ||
get_attr_node = graph_module.graph.create_node( | ||
"get_attr", tensor_constant_name, (), {} | ||
) | ||
get_attr_node.meta["val"] = fake_mode.from_tensor( | ||
float_tensor, static_shapes=True | ||
) | ||
new_args.append(get_attr_node) | ||
n.args = tuple(new_args) | ||
|
||
graph_module.recompile() | ||
return PassResult(graph_module, True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the motivation for this API is clear and this is allowed, however I do want to make sure modifying graph in the Quantizer is something we are OK with from general ET lowering point of view. I also see SDPA as a valid use case.
cc @kimishpatel, @mergennachin - any comments?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok - I'll resolve the merge conflict and mark this as ready for review tomorrow unless I get new comments.