Skip to content

ModAI changes to export xnnpack delegated non_lowered_server_model #10989

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 93 additions & 36 deletions exir/passes/quantize_io_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import torch

from executorch.exir import EdgeProgramManager
from executorch.exir import EdgeProgramManager, ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -39,11 +39,33 @@ def quantize_input(
if len(target_placeholder.users) != 1:
raise ValueError(f"Input {input_index} has more than one users")
quantize = next(iter(target_placeholder.users))
if quantize.target not in [
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_tensor.default,
]:
raise ValueError(
f"Input {input_index} is not used by a quantize op. It's used by {quantize.target}"
)

if (
quantize.target
!= exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
):
raise ValueError(f"Input {input_index} is not used by a quantize op")
replacement_op_dequant = (
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
)
replacement_op_quant = (
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
)
elif quantize.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
replacement_op_dequant = (
torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
replacement_op_quant = (
torch.ops.quantized_decomposed.quantize_per_tensor.default
)
else:
raise ValueError(f"Invalid quantize op: {quantize.target}")

# If user specified qparams are different from args of quantize op, we do requantization instead of eliminating quantize op
need_requant = False
Expand Down Expand Up @@ -83,7 +105,7 @@ def quantize_input(

with exported_program.graph_module.graph.inserting_before(quantize):
input_dequant = exported_program.graph_module.graph.call_function(
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
replacement_op_dequant,
args=(
target_placeholder,
*quant_args,
Expand All @@ -106,10 +128,8 @@ def quantize_input(
logger.info(f"Modifying program to take quantized input at index {input_index}")
logger.info(f"Quantization parameters: {quant_args}")

target_placeholder.meta["val"] = (
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
target_placeholder.meta["val"], *quant_args
)
target_placeholder.meta["val"] = replacement_op_quant(
target_placeholder.meta["val"], *quant_args
)
quantize.replace_all_uses_with(quantize.args[0])

Expand Down Expand Up @@ -138,10 +158,10 @@ def quantize_output(exported_program, output_index):
)

target_output = output_list[output_index]
if (
target_output.target
!= exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
):
if target_output.target not in [
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]:
raise ValueError("Output {output_index} is not a dequantize op")

dequant = target_output
Expand Down Expand Up @@ -185,6 +205,7 @@ def __init__(
edge_program_manager: EdgeProgramManager,
quantized_inputs_idx: Union[Dict[int, Dict[str, Any]], List[int]],
method_name: Optional[str] = None,
exported_program: Optional[ExportedProgram] = None,
):
super().__init__()
self.edge_program_manager = edge_program_manager
Expand All @@ -196,31 +217,49 @@ def __init__(
for idx in quantized_inputs_idx:
self.quantized_inputs_idx_dict[idx] = None
self.param_prefix_name = method_name
self.exported_program = exported_program
self.quant_args = {}

def call(self, graph_module: torch.fx.GraphModule):
for i, qparams in self.quantized_inputs_idx_dict.items():
quant_args = quantize_input(
self.edge_program_manager.exported_program(), i, qparams
)

def edge_manager_update_quant_config_method(self, idx, quant_args):
if self.edge_program_manager is not None:
if not self.edge_program_manager._config_methods:
self.edge_program_manager._config_methods = {}

self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "input", i, "scale")
get_config_method_name(self.param_prefix_name, "input", idx, "scale")
] = quant_args[0]
self.edge_program_manager._config_methods[ # pyre-ignore
get_config_method_name(self.param_prefix_name, "input", i, "zp")
self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "input", idx, "zp")
] = quant_args[1]
self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "input", i, "quant_min")
get_config_method_name(
self.param_prefix_name, "input", idx, "quant_min"
)
] = quant_args[2]
self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "input", i, "quant_max")
get_config_method_name(
self.param_prefix_name, "input", idx, "quant_max"
)
] = quant_args[3]
self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "input", i, "dtype")
get_config_method_name(self.param_prefix_name, "input", idx, "dtype")
] = scalar_type_enum(quant_args[4])

def edge_manager_update_quant_config_methods_all(self):
if self.edge_program_manager is not None:
for idx, val in self.quant_args.items():
self.edge_manager_update_quant_config_method(idx, val)

def call(self, graph_module: torch.fx.GraphModule):
for i, qparams in self.quantized_inputs_idx_dict.items():
exported_program = (
self.edge_program_manager.exported_program()
if self.edge_program_manager is not None
else self.exported_program
)
self.quant_args[i] = quantize_input(exported_program, i, qparams)
self.edge_manager_update_quant_config_method(i, self.quant_args[i])

return PassResult(graph_module, True)


Expand All @@ -230,35 +269,53 @@ def __init__(
edge_program_manager: EdgeProgramManager,
quantized_outputs_idx_list: List[int],
method_name: Optional[str] = None,
exported_program: Optional[ExportedProgram] = None,
):
super().__init__()
self.edge_program_manager = edge_program_manager
self.quantized_outputs_idx_list = quantized_outputs_idx_list
self.param_prefix_name = method_name
self.exported_program = exported_program
self.dequant_args = {}

def call(self, graph_module: torch.fx.GraphModule):
for i in self.quantized_outputs_idx_list:
dequant_args = quantize_output(
self.edge_program_manager.exported_program(), i
) # noqa F841

def edge_manager_update_quant_config_method(self, idx, dequant_args):
if self.edge_program_manager is not None:
if not self.edge_program_manager._config_methods:
self.edge_program_manager._config_methods = {}

self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "output", i, "scale")
get_config_method_name(self.param_prefix_name, "output", idx, "scale")
] = dequant_args[0]
self.edge_program_manager._config_methods[ # pyre-ignore
get_config_method_name(self.param_prefix_name, "output", i, "zp")
self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "output", idx, "zp")
] = dequant_args[1]
self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "output", i, "quant_min")
get_config_method_name(
self.param_prefix_name, "output", idx, "quant_min"
)
] = dequant_args[2]
self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "output", i, "quant_max")
get_config_method_name(
self.param_prefix_name, "output", idx, "quant_max"
)
] = dequant_args[3]
self.edge_program_manager._config_methods[
get_config_method_name(self.param_prefix_name, "output", i, "dtype")
get_config_method_name(self.param_prefix_name, "output", idx, "dtype")
] = scalar_type_enum(dequant_args[4])

def edge_manager_update_quant_config_methods_all(self):
if self.edge_program_manager is not None:
for idx, val in self.dequant_args.items():
self.edge_manager_update_quant_config_method(idx, val)

def call(self, graph_module: torch.fx.GraphModule):
for i in self.quantized_outputs_idx_list:
exported_program = (
self.edge_program_manager.exported_program()
if self.edge_program_manager is not None
else self.exported_program
)
self.dequant_args[i] = quantize_output(exported_program, i) # noqa F841
self.edge_manager_update_quant_config_method(i, self.dequant_args[i])

return PassResult(graph_module, True)
Loading