[MPS] Add support for Int4 groupwise quantization #4623

Merged
13 changes: 4 additions & 9 deletions backends/apple/mps/mps_preprocess.py
@@ -25,7 +25,6 @@
from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
convert_to_flatbuffer,
)
from executorch.backends.apple.mps.utils.mps_utils import is_parameter
from executorch.exir._serialize._program import Cord

from executorch.exir.backend.backend_details import (
@@ -127,6 +126,8 @@ def preprocess(
op_handler[node.op](edge_program, node_visitors, node, mps_graph)

segment_data, mps_graph = _extract_constant_segment(mps_graph)
if logging.DEBUG >= logging.root.level:
pretty_print(mps_graph)

# Add to aggregate segments cord with padding.
padding_length = _padding_required(len(segment_data), 16)
@@ -160,9 +161,6 @@ def preprocess(
# Append the segment data to the end of the mps graph
combined.append(segment_data)

if logging.DEBUG >= logging.root.level:
pretty_print(mps_graph)

return PreprocessResult(processed_bytes=bytes(combined))

@staticmethod
@@ -198,10 +196,8 @@ def handle_placeholder(
node: torch.fx.Node,
mps_graph: MPSGraph,
) -> None:
# Handle only constants. Placeholders have already
# been visited in `process_input_placeholders`
if is_parameter(edge_program, node):
node_visitors[node.op].define_tensor(node, mps_graph)
# Constants are handled directly when visiting the nodes.
pass

@staticmethod
def handle_output(
@@ -257,7 +253,6 @@ def tensor_to_str(mps_tensor: MPSTensor):
tensor_str += "datatype=" + str(mps_tensor.datatype) + ", "
tensor_str += "num_dims=" + str(mps_tensor.num_dims) + ", "
tensor_str += "dims=" + str(mps_tensor.dims) + ", "
tensor_str += "constant_buffer=" + str(mps_tensor.constant_buffer) + ", "
tensor_str += "constant_buffer_size=" + str(mps_tensor.constant_buffer_size) + ", "
tensor_str += "segment_offset=" + str(mps_tensor.segment_offset)
tensor_str += ")"
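A note on the alignment logic this file relies on: `_padding_required` pads the serialized graph so the appended constant segment starts on a 16-byte boundary. A minimal sketch of that kind of round-up computation, under the assumption that the helper returns the number of padding bytes (the name `padding_required` below is illustrative, not the real import):

```python
def padding_required(length: int, alignment: int = 16) -> int:
    # Bytes of padding needed so the next segment starts at a
    # multiple of `alignment`.
    remainder = length % alignment
    return 0 if remainder == 0 else alignment - remainder

# Example: 100 bytes of data need 12 bytes of padding so the
# constant segment begins at offset 112 (= 16 * 7).
assert padding_required(100) == 12
assert padding_required(112) == 0
```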
17 changes: 12 additions & 5 deletions backends/apple/mps/operators/__init__.py
@@ -4,9 +4,8 @@
#

from . import ( # noqa
# Activation ops
activation_ops,
# binary ops
# Binary ops
binary_ops,
# Clamp ops
clamp_ops,
@@ -22,6 +21,10 @@
normalization_ops,
op_clone,
op_getitem,
# Quant-Dequant ops
op_quant_dequant,
# Skip ops
op_skip_ops,
# Pad ops
pad_ops,
# Pooling ops
@@ -32,7 +35,7 @@
reduce_ops,
# Shape ops
shape_ops,
# unary ops
# Unary ops
unary_ops,
)

@@ -41,8 +44,6 @@
op_clone,
# Binary ops
binary_ops,
# Unary ops
unary_ops,
# Activation ops
activation_ops,
# Linear algebra ops
@@ -67,4 +68,10 @@
pad_ops,
# Range ops
range_ops,
# Unary ops
unary_ops,
# Quant-Dequant ops
op_quant_dequant,
# Skip ops
op_skip_ops,
]
59 changes: 47 additions & 12 deletions backends/apple/mps/operators/node_visitor.py
@@ -72,6 +72,7 @@ def define_tensor(
self,
node: torch.fx.Node,
mps_graph: MPSGraph,
mps_data_type: MPSDataType = None,
) -> int:
"""Defines a tensor value into the MPSGraph serialization schema
@@ -89,7 +90,7 @@
# Get a unique id for the node.
id = self.get_serialized_id(node, mps_graph)
cb_size, constant_buffer, mps_data_type = self.get_serialized_buffer(
node, mps_graph, id
node, mps_graph, id, mps_data_type
)
dims = get_shape(node)

@@ -143,6 +144,9 @@ def define_tensor_list(self, node: torch.fx.Node, mps_graph: MPSGraph) -> List[int]:
mps_graph.mps_values.append(mps_tensor)
return self.tensor_to_id[node]

def hash_tensor(self, tensor):
return hash(tuple(tensor.reshape(-1).tolist()))

def define_constant(
self,
constant_tensor: torch.tensor,
@@ -155,9 +159,12 @@ def define_constant(
mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer
"""
constant_tensor = constant_tensor.contiguous()
# MPS TODO: cache these values
id = len(mps_graph.mps_values)
self.tensor_to_id[constant_tensor] = id
hash = self.hash_tensor(constant_tensor)
if hash in self.tensor_to_id:
return self.tensor_to_id[hash]

id = self.get_serialized_id(constant_tensor, mps_graph, hash)

mps_data_type = edge_dtype_to_mps_dtype(constant_tensor.dtype)
constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
constant_tensor, mps_graph, mps_data_type, id
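A note on the caching introduced above: `hash_tensor` hashes the flattened contents of a constant, so two constants with identical data are serialized once and share an id. A standalone illustration of the dedup property (just the hashing behavior, not the visitor API):

```python
import torch

def hash_tensor(tensor: torch.Tensor) -> int:
    # Equal contents hash equal, regardless of object identity.
    return hash(tuple(tensor.reshape(-1).tolist()))

a = torch.tensor([[1, 2], [3, 4]])
b = a.clone()  # distinct tensor object, same values
assert a is not b
assert hash_tensor(a) == hash_tensor(b)  # second lookup hits the cache
```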
@@ -189,9 +196,10 @@ def define_scalar(
"""
assert isinstance(val, int) or isinstance(val, float)

# MPS TODO: cache these values
id = len(mps_graph.mps_values)
self.tensor_to_id[val] = id
if val in self.tensor_to_id:
return self.tensor_to_id[val]

id = self.get_serialized_id(val, mps_graph, val)

tensor = torch.tensor(val)
constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
@@ -214,6 +222,7 @@ def get_serialized_buffer(
node: torch.fx.Node,
mps_graph: MPSGraph,
node_id: int,
mps_data_type: MPSDataType = None,
) -> Tuple[int, Buffer, MPSDataType]:
"""
If tensor holds some constant data, serialize it and return the
Expand All @@ -226,7 +235,9 @@ def get_serialized_buffer(
Returns:
_type_: _description_
"""
mps_data_type = self.get_serialized_dtype(node)
mps_data_type = (
self.get_serialized_dtype(node) if mps_data_type is None else mps_data_type
)

# Check if this node is a lifted parameter
if not is_parameter(self.exported_program, node):
@@ -255,6 +266,22 @@ def get_serialized_data(
if id not in mps_graph.constant_ids:
mps_graph.constant_ids.append(id)

if (
mps_data_type is MPSDataType.mps_data_type_int4
and tensor.dtype is torch.int8
):
if tensor.dim() != 2:
raise RuntimeError(f"Unexpected tensor shape {tensor.shape}")

tensor = tensor.to(dtype=torch.int32)
tensor = (((tensor[::, ::2] & 0x0F) << 4) | (tensor[::, 1::2] & 0x0F)).to(
torch.uint8
)
tensor = (
torch._convert_weight_to_int4pack(tensor.to("mps"), 2)
.cpu()
.view(dtype=torch.uint8)
)
array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
array = ctypes.cast(
tensor.untyped_storage().data_ptr(),
Expand All @@ -265,7 +292,7 @@ def get_serialized_data(
return tensor.untyped_storage().nbytes(), buffer, mps_data_type

def get_serialized_id(
self, node: Union[torch.fx.Node, float, int], mps_graph: MPSGraph
self, node: Union[torch.fx.Node, float, int], mps_graph: MPSGraph, hash=None
) -> int:
"""
Map a tensor to a unique id. If the tensor was already mapped, return
@@ -278,19 +305,27 @@
Returns:
int: _description_
"""
if node in self.tensor_to_id:
if hash is not None and hash in self.tensor_to_id:
return self.tensor_to_id[hash]
elif node in self.tensor_to_id:
return self.tensor_to_id[node]

id = len(mps_graph.mps_values)
self.tensor_to_id[node] = id
if hash is not None:
self.tensor_to_id[hash] = id
else:
self.tensor_to_id[node] = id

return id

def torch_dtype_to_mps_dtype(self, torch_dtype: torch.dtype) -> MPSDataType:
return edge_dtype_to_mps_dtype(torch_dtype)

def get_serialized_dtype(
self,
node: torch.fx.Node,
) -> MPSDataType:
return edge_dtype_to_mps_dtype(node.meta["val"].dtype)
return self.torch_dtype_to_mps_dtype(node.meta["val"].dtype)

def create_tertiary_node(
self, node: torch.fx.Node, mps_graph: MPSGraph, tertiary_op: MPSNodeUnion
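For the int4 path added in `get_serialized_data`: pairs of int4 values (stored one per int8) are packed two to a byte, even columns in the high nibble and odd columns in the low nibble, before `torch._convert_weight_to_int4pack` rearranges the bytes into the layout the MPS int4 kernels expect. A self-contained sketch of just the nibble packing (the `_convert_weight_to_int4pack` step is omitted here since it is a private op that needs an MPS device):

```python
import torch

def pack_int4_pairs(t: torch.Tensor) -> torch.Tensor:
    # Pack two int4 values per byte: even columns go to the high
    # nibble, odd columns to the low nibble, as in the hunk above.
    if t.dim() != 2:
        raise RuntimeError(f"Unexpected tensor shape {t.shape}")
    t = t.to(dtype=torch.int32)
    return (((t[:, ::2] & 0x0F) << 4) | (t[:, 1::2] & 0x0F)).to(torch.uint8)

w = torch.tensor([[1, 2, 3, 4]], dtype=torch.int8)
# 1 -> 0001, 2 -> 0010  =>  0001_0010 = 0x12; likewise 3, 4 => 0x34
assert pack_int4_pairs(w).tolist() == [[0x12, 0x34]]
```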
142 changes: 142 additions & 0 deletions backends/apple/mps/operators/op_quant_dequant.py
@@ -0,0 +1,142 @@
#
# Copyright (c) 2024 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import logging
from typing import cast

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
MPSDataType,
MPSDequantizePerChannelGroup,
MPSGraph,
MPSNode,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.DEBUG, format=FORMAT)


@register_node_visitor
class OpDequantizePerChannelGroupDefault(NodeVisitor):
target = "quantized_decomposed.dequantize_per_channel_group.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
mps_graph: MPSGraph,
) -> None:
# Weights placeholders shouldn't have been defined until this point
if get_input_node(node, 0) in self.tensor_to_id:
raise RuntimeError(
f"Placeholder for {node.target.__name__} already visited"
)
output_id = self.define_tensor(node, mps_graph)
input_id = self.define_tensor(
get_input_node(node, 0), mps_graph, MPSDataType.mps_data_type_int4
)
scales_id = self.define_tensor(get_input_node(node, 1), mps_graph)

# there are no zero points in this quantization method (node.args[2] is all zeros)
zero_points_id = -1
quant_min = cast(int, node.args[3])
quant_max = cast(int, node.args[4])
dtype = self.torch_dtype_to_mps_dtype(node.args[5])
group_size = cast(int, node.args[6])
output_dtype = self.torch_dtype_to_mps_dtype(node.args[7])

dequant_node = MPSNode(
mpsnode_union=MPSDequantizePerChannelGroup(
input1_id=input_id,
output_id=output_id,
scales_id=scales_id,
zero_points_id=zero_points_id,
quant_min=quant_min,
quant_max=quant_max,
dtype=dtype,
group_size=group_size,
output_dtype=output_dtype,
)
)
mps_graph.mps_nodes.append(dequant_node)


@register_node_visitor
class OpQuantizePerToken(NodeVisitor):
"""
Dynamic Quantize Per Token Node visitor
"""

target = "quantized_decomposed.quantize_per_token.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
mps_graph: MPSGraph,
) -> None:
"""
Skip activation dynamic quantization for now.
Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
Issue: #133407308
"""
dq_input = self.define_tensor(get_input_node(node, 0), mps_graph)
self.tensor_to_id[node] = dq_input


@register_node_visitor
class OpDequantizePerToken(NodeVisitor):
"""
Dequantize Per Token Node visitor
"""

target = "quantized_decomposed.dequantize_per_token.default"

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
mps_graph: MPSGraph,
) -> None:
"""
Skip activation dynamic quantization for now.
Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
Issue: #133407308
"""
dq_input = self.define_tensor(get_input_node(node, 0), mps_graph)
self.tensor_to_id[node] = dq_input


@register_node_visitor
class OpChooseQparamsToken(NodeVisitor):
"""
do nothing if node is choose_qparams_per_token_asymmetric.tensor
"""

target = "quantized_decomposed.choose_qparams_per_token_asymmetric.default"

def define_node(
self,
node: torch.fx.Node,
mps_graph: MPSGraph,
) -> None:
"""
Skip activation dynamic quantization for now.
Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
Issue: #133407308
"""
input_id = self.define_tensor(get_input_node(node, 0), mps_graph)
self.tensor_to_id[node] = [input_id, input_id]
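For reference, the per-channel-group dequantization that `MPSDequantizePerChannelGroup` encodes: each row of the int4 weight is split into groups of `group_size` columns, and every group shares one scale (zero points are all zero here, per the comment in `define_node`). A reference sketch of those semantics in plain PyTorch, for illustration only; the real computation runs inside the MPS graph:

```python
import torch

def dequantize_per_channel_group_ref(
    w_q: torch.Tensor,     # [out_ch, in_ch] int4 values held in int8
    scales: torch.Tensor,  # [out_ch, in_ch // group_size]
    group_size: int,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Zero points omitted: node.args[2] is all zeros in this scheme.
    out_ch, in_ch = w_q.shape
    grouped = w_q.to(output_dtype).view(out_ch, in_ch // group_size, group_size)
    return (grouped * scales.unsqueeze(-1)).view(out_ch, in_ch)

w_q = torch.tensor([[1, -2, 3, -4]], dtype=torch.int8)
scales = torch.tensor([[0.5, 2.0]])
print(dequantize_per_channel_group_ref(w_q, scales, group_size=2))
# tensor([[ 0.5000, -1.0000,  6.0000, -8.0000]])
```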