
Migrate the quantizer to use aten ops directly #4195

Closed · wants to merge 1 commit
4 changes: 2 additions & 2 deletions backends/cadence/aot/compiler.py
@@ -19,7 +19,7 @@
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceGenericQuantizer,
CadenceAtenQuantizer,
CadenceQuantizer,
)
from executorch.backends.cadence.aot.utils import model_is_quantized
@@ -64,7 +64,7 @@ def quantize_pt2(

# Get patterns and apply fusion of dq -> op -> q to qop
patterns = [
assert_is_instance(q, CadenceGenericQuantizer).pattern
assert_is_instance(q, CadenceAtenQuantizer).pattern
Contributor: Is this a duplicate and not properly stacked with #4047? You can use "gh-stack" to help with this in the future.

for q in quantizer.quantizers
]
QuantFusion(patterns)(converted_model)
14 changes: 4 additions & 10 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -14,21 +14,19 @@
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
LayerNormPattern,
LinearFunctionalPattern,
LinearPattern,
MatmulPattern,
ReluPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
create_zero_bias_int32,
find_sequential_partitions_aten,
get_conv_args,
quantize_tensor_multiplier,
)
from executorch.exir.pass_base import ExportPass
from torch import fx
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.fuser_utils import legalize_graph
@@ -310,7 +308,7 @@ def __init__(self, patterns) -> None:

def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
for pattern in self.patterns:
fused_partitions = find_sequential_partitions(
fused_partitions = find_sequential_partitions_aten(
graph_module,
pattern.partition_types(),
)
@@ -373,9 +371,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
quant_node,
op_node,
)
elif isinstance(pattern, LinearPattern) or isinstance(
pattern, LinearFunctionalPattern
):
elif isinstance(pattern, LinearPattern):
args, kwargs = get_args_and_kwargs_linear(
graph_module,
inputs_inputs,
@@ -385,9 +381,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
bias_inputs,
quant_node,
)
elif isinstance(pattern, LayerNormPattern) or isinstance(
pattern, LayerNormFunctionalPattern
):
elif isinstance(pattern, LayerNormPattern):
args, kwargs = get_args_and_kwargs_layer_norm(
graph_module,
inputs_inputs,
104 changes: 20 additions & 84 deletions backends/cadence/aot/quantizer/patterns.py
@@ -8,7 +8,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple, Type, Union
from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
@@ -47,17 +47,15 @@ class PartitionAnchors:

class QuantizationPattern(ABC):
@abstractmethod
def partition_types(
self,
) -> Union[List[Type[torch.nn.Module]], List[Callable[..., torch.Tensor]]]:
def partition_types(self) -> list[OpOverload]:
Contributor: I think we need to support Python 3.8 here, which doesn't support list[x]; you need to use typing.List[x] instead.

"""
List of types to be passed to find_sequential_partitions.
List of types to be passed to find_sequential_partitions_aten.
"""
pass

@abstractmethod
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> Optional[PartitionAnchors]:
pass
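Regarding the Python 3.8 review comment above: a minimal sketch of the typing.List spelling the reviewer suggests (the torch._ops import location for OpOverload is an assumption, based on the annotations used elsewhere in this file). The built-in generic list[OpOverload] requires Python 3.9+ (PEP 585) unless `from __future__ import annotations` is in effect.

```python
# Hypothetical illustration of the reviewer's suggestion, not code from this PR.
from abc import ABC, abstractmethod
from typing import List

from torch._ops import OpOverload  # assumed import location for OpOverload


class QuantizationPattern(ABC):
    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        """List of aten op overloads to be passed to find_sequential_partitions_aten."""
        pass
```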

@@ -71,8 +69,8 @@ def replacement_op(self) -> OpOverload:


class AddmmPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.addmm]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.addmm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -103,8 +101,8 @@ def replacement_op(self) -> OpOverload:


class BmmPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.bmm]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.bmm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -123,8 +121,8 @@ def replacement_op(self) -> OpOverload:


class Conv1dPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Conv1d]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv1d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -161,8 +159,8 @@ def replacement_op(self) -> OpOverload:


class Conv2dPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Conv2d]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv2d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -199,32 +197,8 @@ def replacement_op(self) -> OpOverload:


class LayerNormPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.LayerNorm]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
layer_norm_node = fused_partition[0].nodes[-1]

# Weights and biases are used as fp32 by our kernel, so they are
# passed in as others here along with the normalized shape.
return PartitionAnchors(
inputs=[(layer_norm_node, 0)],
weights=[],
biases=[],
# Ordering: normalized_shape, weights, bias
others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
output=[(layer_norm_node,)],
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_layer_norm.default


class LayerNormFunctionalPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.nn.functional.layer_norm]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.layer_norm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -257,8 +231,8 @@ def replacement_op(self) -> OpOverload:


class LinearPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.Linear]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.linear.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -294,47 +268,9 @@ def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear.default


class LinearFunctionalPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.nn.functional.linear]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
linear_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
derived_from=[
(linear_node.args[0], linear_node),
(linear_node.args[1], linear_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=torch.per_tensor_affine,
)

# Keep bias empty if not supplied
bias = []
if len(linear_node.args) > 2 and linear_node.args[2] is not None:
bias = [(linear_node, 2, bias_qspec)]

return PartitionAnchors(
inputs=[(linear_node, 0)],
weights=[(linear_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(linear_node,)],
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_linear.default


class MatmulPattern(QuantizationPattern):
def partition_types(self) -> List[Callable[..., torch.Tensor]]:
return [torch.matmul]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.matmul.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -353,8 +289,8 @@ def replacement_op(self) -> OpOverload:


class ReluPattern(QuantizationPattern):
def partition_types(self) -> List[Type[torch.nn.Module]]:
return [torch.nn.ReLU]
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.relu.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
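The rewritten patterns return aten OpOverloads because those are what appear as call_function targets in an aten-level exported graph. A hedged sketch of checking this (the export entry point and the exact op set are assumptions; both vary across PyTorch releases):

```python
# Illustrative only: inspect which aten overloads a pre-dispatch export produces.
import torch


class TinyConvRelu(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(4, 4, kernel_size=3)
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.conv(x))


ep = torch.export.export_for_training(TinyConvRelu(), (torch.randn(1, 4, 8),))
targets = [n.target for n in ep.module().graph.nodes if n.op == "call_function"]
# Expected to include torch.ops.aten.conv1d.default and torch.ops.aten.relu.default,
# i.e. the OpOverloads the rewritten partition_types() methods return, rather than
# nn.Module call targets.
print(targets)
```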
26 changes: 11 additions & 15 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -14,15 +14,14 @@
BmmPattern,
Conv1dPattern,
Conv2dPattern,
LayerNormFunctionalPattern,
LayerNormPattern,
LinearFunctionalPattern,
LinearPattern,
MatmulPattern,
QuantizationPattern,
ReluPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
find_sequential_partitions_aten,
is_annotated,
no_outside_users,
)
@@ -31,7 +30,6 @@
from torch import fx

from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
@@ -63,7 +61,7 @@
bias_qspec: Optional[QuantizationSpec] = None


class CadenceGenericQuantizer(Quantizer):
class CadenceAtenQuantizer(Quantizer):
def __init__(
self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
) -> None:
@@ -72,7 +70,7 @@ def __init__(
self.quantization_config = quantization_config

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
fused_partitions = find_sequential_partitions(
fused_partitions = find_sequential_partitions_aten(
model,
self.pattern.partition_types(),
)
@@ -154,15 +152,13 @@ def __init__(self) -> None:
)
super().__init__(
[
CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
CadenceGenericQuantizer(BmmPattern(), static_qconfig),
CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
CadenceGenericQuantizer(LayerNormFunctionalPattern(), static_qconfig),
CadenceGenericQuantizer(LinearPattern(), static_qconfig),
CadenceGenericQuantizer(LinearFunctionalPattern(), static_qconfig),
CadenceGenericQuantizer(MatmulPattern(), static_qconfig),
CadenceGenericQuantizer(ReluPattern(), static_qconfig),
CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
CadenceAtenQuantizer(BmmPattern(), static_qconfig),
CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
CadenceAtenQuantizer(LinearPattern(), static_qconfig),
CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
CadenceAtenQuantizer(ReluPattern(), static_qconfig),
]
)
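For context on how the composed CadenceQuantizer above is typically driven, a rough sketch of the standard PT2E flow (the export call and the tiny model are placeholders, not taken from this PR; quantize_pt2 in compiler.py appears to wrap a similar prepare/convert sequence before applying QuantFusion):

```python
# Rough PT2E usage sketch; APIs and entry points may differ between releases.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyLinear().eval()
example_inputs = (torch.randn(1, 16),)

# Export to an aten-level graph; with this change the quantizer's patterns match
# OpOverloads such as torch.ops.aten.linear.default directly in this graph.
graph_module = torch.export.export_for_training(model, example_inputs).module()

quantizer = CadenceQuantizer()
prepared = prepare_pt2e(graph_module, quantizer)  # insert observers per annotations
prepared(*example_inputs)                         # calibrate
converted = convert_pt2e(prepared)                # materialize quantize/dequantize ops
```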