Skip to content

Commit 7a8a2c6

Browse files
committed
Add support for Int4 groupwise quantization
1 parent 85b7869 commit 7a8a2c6

23 files changed

+1124
-177
lines changed

backends/apple/mps/mps_preprocess.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
2626
convert_to_flatbuffer,
2727
)
28-
from executorch.backends.apple.mps.utils.mps_utils import is_parameter
2928
from executorch.exir._serialize._program import Cord
3029

3130
from executorch.exir.backend.backend_details import (
@@ -127,6 +126,8 @@ def preprocess(
127126
op_handler[node.op](edge_program, node_visitors, node, mps_graph)
128127

129128
segment_data, mps_graph = _extract_constant_segment(mps_graph)
129+
if logging.DEBUG >= logging.root.level:
130+
pretty_print(mps_graph)
130131

131132
# Add to aggregate segments cord with padding.
132133
padding_length = _padding_required(len(segment_data), 16)
@@ -160,9 +161,6 @@ def preprocess(
160161
# Append the segment data to the end of the mps graph
161162
combined.append(segment_data)
162163

163-
if logging.DEBUG >= logging.root.level:
164-
pretty_print(mps_graph)
165-
166164
return PreprocessResult(processed_bytes=bytes(combined))
167165

168166
@staticmethod
@@ -198,10 +196,8 @@ def handle_placeholder(
198196
node: torch.fx.Node,
199197
mps_graph: MPSGraph,
200198
) -> None:
201-
# Handle only constants. Placeholders have already
202-
# been visited in `process_input_placeholders`
203-
if is_parameter(edge_program, node):
204-
node_visitors[node.op].define_tensor(node, mps_graph)
199+
# Constants are handled directly when visiting the nodes.
200+
pass
205201

206202
@staticmethod
207203
def handle_output(
@@ -257,7 +253,6 @@ def tensor_to_str(mps_tensor: MPSTensor):
257253
tensor_str += "datatype=" + str(mps_tensor.datatype) + ", "
258254
tensor_str += "num_dims=" + str(mps_tensor.num_dims) + ", "
259255
tensor_str += "dims=" + str(mps_tensor.dims) + ", "
260-
tensor_str += "constant_buffer=" + str(mps_tensor.constant_buffer) + ", "
261256
tensor_str += "constant_buffer_size=" + str(mps_tensor.constant_buffer_size) + ", "
262257
tensor_str += "segment_offset=" + str(mps_tensor.segment_offset)
263258
tensor_str += ")"

backends/apple/mps/operators/__init__.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
#
55

66
from . import ( # noqa
7-
# Activation ops
87
activation_ops,
9-
# binary ops
8+
# Binary ops
109
binary_ops,
1110
# Clamp ops
1211
clamp_ops,
@@ -22,6 +21,10 @@
2221
normalization_ops,
2322
op_clone,
2423
op_getitem,
24+
# Quant-Dequant ops
25+
op_quant_dequant,
26+
# Skip ops
27+
op_skip_ops,
2528
# Pad ops
2629
pad_ops,
2730
# Pooling ops
@@ -32,7 +35,7 @@
3235
reduce_ops,
3336
# Shape ops
3437
shape_ops,
35-
# unary ops
38+
# Unary ops
3639
unary_ops,
3740
)
3841

@@ -41,8 +44,6 @@
4144
op_clone,
4245
# Binary ops
4346
binary_ops,
44-
# Unary ops
45-
unary_ops,
4647
# Activation ops
4748
activation_ops,
4849
# Linear algebra ops
@@ -67,4 +68,10 @@
6768
pad_ops,
6869
# Range ops
6970
range_ops,
71+
# Unary ops
72+
unary_ops,
73+
# Quant-Dequant ops
74+
op_quant_dequant,
75+
# Skip ops
76+
op_skip_ops,
7077
]

backends/apple/mps/operators/node_visitor.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def define_tensor(
7272
self,
7373
node: torch.fx.Node,
7474
mps_graph: MPSGraph,
75+
mps_data_type: MPSDataType = None,
7576
) -> int:
7677
"""Defines a tensor value into the MPSGraph serialization schema
7778
@@ -89,7 +90,7 @@ def define_tensor(
8990
# Get a unique id for the node.
9091
id = self.get_serialized_id(node, mps_graph)
9192
cb_size, constant_buffer, mps_data_type = self.get_serialized_buffer(
92-
node, mps_graph, id
93+
node, mps_graph, id, mps_data_type
9394
)
9495
dims = get_shape(node)
9596

@@ -143,6 +144,9 @@ def define_tensor_list(self, node: torch.fx.Node, mps_graph: MPSGraph) -> List[i
143144
mps_graph.mps_values.append(mps_tensor)
144145
return self.tensor_to_id[node]
145146

147+
def hash_tensor(self, tensor):
    """Return a hash identifying a constant tensor for de-duplication.

    The key covers the dtype and the shape in addition to the flattened
    values, so tensors that merely share the same elements (for example
    ``[1, 2, 3, 4]`` vs ``[[1, 2], [3, 4]]``, or an int ``1`` vs a float
    ``1.0``) are not mistaken for the same constant.

    Args:
        tensor: The tensor to hash.

    Returns:
        int: Hash of ``(dtype, shape, flattened values)``.
    """
    # Hashing the values alone would alias constants that differ only in
    # dtype or shape: equal Python numbers hash equally (hash(1) == hash(1.0))
    # and reshape(-1) discards the original dimensions.
    return hash(
        (tensor.dtype, tuple(tensor.shape), tuple(tensor.reshape(-1).tolist()))
    )
149+
146150
def define_constant(
147151
self,
148152
constant_tensor: torch.tensor,
@@ -155,9 +159,12 @@ def define_constant(
155159
mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer
156160
"""
157161
constant_tensor = constant_tensor.contiguous()
158-
# MPS TODO: cache these values
159-
id = len(mps_graph.mps_values)
160-
self.tensor_to_id[constant_tensor] = id
162+
hash = self.hash_tensor(constant_tensor)
163+
if hash in self.tensor_to_id:
164+
return self.tensor_to_id[hash]
165+
166+
id = self.get_serialized_id(constant_tensor, mps_graph, hash)
167+
161168
mps_data_type = edge_dtype_to_mps_dtype(constant_tensor.dtype)
162169
constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
163170
constant_tensor, mps_graph, mps_data_type, id
@@ -189,9 +196,10 @@ def define_scalar(
189196
"""
190197
assert isinstance(val, int) or isinstance(val, float)
191198

192-
# MPS TODO: cache these values
193-
id = len(mps_graph.mps_values)
194-
self.tensor_to_id[val] = id
199+
if val in self.tensor_to_id:
200+
return self.tensor_to_id[val]
201+
202+
id = self.get_serialized_id(val, mps_graph, val)
195203

196204
tensor = torch.tensor(val)
197205
constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
@@ -214,6 +222,7 @@ def get_serialized_buffer(
214222
node: torch.fx.Node,
215223
mps_graph: MPSGraph,
216224
node_id: int,
225+
mps_data_type: MPSDataType = None,
217226
) -> Tuple[int, Buffer, MPSDataType]:
218227
"""
219228
If tensor holds some constant data, serialize it and return the
@@ -226,7 +235,9 @@ def get_serialized_buffer(
226235
Returns:
227236
_type_: _description_
228237
"""
229-
mps_data_type = self.get_serialized_dtype(node)
238+
mps_data_type = (
239+
self.get_serialized_dtype(node) if mps_data_type is None else mps_data_type
240+
)
230241

231242
# Check if this node is a lifted parameter
232243
if not is_parameter(self.exported_program, node):
@@ -255,6 +266,22 @@ def get_serialized_data(
255266
if id not in mps_graph.constant_ids:
256267
mps_graph.constant_ids.append(id)
257268

269+
if (
270+
mps_data_type is MPSDataType.mps_data_type_int4
271+
and tensor.dtype is torch.int8
272+
):
273+
if tensor.dim() != 2:
274+
raise RuntimeError(f"Unexpected tensor shape {tensor.shape}")
275+
276+
tensor = tensor.to(dtype=torch.int32)
277+
tensor = (((tensor[::, ::2] & 0x0F) << 4) | (tensor[::, 1::2] & 0x0F)).to(
278+
torch.uint8
279+
)
280+
tensor = (
281+
torch._convert_weight_to_int4pack(tensor.to("mps"), 2)
282+
.cpu()
283+
.view(dtype=torch.uint8)
284+
)
258285
array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
259286
array = ctypes.cast(
260287
tensor.untyped_storage().data_ptr(),
@@ -265,7 +292,7 @@ def get_serialized_data(
265292
return tensor.untyped_storage().nbytes(), buffer, mps_data_type
266293

267294
def get_serialized_id(
268-
self, node: Union[torch.fx.Node, float, int], mps_graph: MPSGraph
295+
self, node: Union[torch.fx.Node, float, int], mps_graph: MPSGraph, hash=None
269296
) -> int:
270297
"""
271298
Map a tensor to a unique id. If the tensor was already mapped, return
@@ -278,19 +305,27 @@ def get_serialized_id(
278305
Returns:
279306
int: _description_
280307
"""
281-
if node in self.tensor_to_id:
308+
if hash is not None and hash in self.tensor_to_id:
309+
return self.tensor_to_id[hash]
310+
elif node in self.tensor_to_id:
282311
return self.tensor_to_id[node]
283312

284313
id = len(mps_graph.mps_values)
285-
self.tensor_to_id[node] = id
314+
if hash is not None:
315+
self.tensor_to_id[hash] = id
316+
else:
317+
self.tensor_to_id[node] = id
286318

287319
return id
288320

321+
def torch_dtype_to_mps_dtype(self, torch_dtype: torch.dtype) -> MPSDataType:
    """Translate a torch dtype into its MPS data type.

    Thin delegate around `edge_dtype_to_mps_dtype`.

    Args:
        torch_dtype: The torch dtype to convert.

    Returns:
        MPSDataType: Whatever `edge_dtype_to_mps_dtype` returns for
        `torch_dtype`.
    """
    mps_dtype = edge_dtype_to_mps_dtype(torch_dtype)
    return mps_dtype
323+
289324
def get_serialized_dtype(
290325
self,
291326
node: torch.fx.Node,
292327
) -> MPSDataType:
293-
return edge_dtype_to_mps_dtype(node.meta["val"].dtype)
328+
return self.torch_dtype_to_mps_dtype(node.meta["val"].dtype)
294329

295330
def create_tertiary_node(
296331
self, node: torch.fx.Node, mps_graph: MPSGraph, tertiary_op: MPSNodeUnion
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#
2+
# Copyright (c) 2024 Apple Inc. All rights reserved.
3+
# Provided subject to the LICENSE file in the top level directory.
4+
#
5+
6+
import logging
7+
from typing import cast
8+
9+
import torch
10+
from executorch.backends.apple.mps.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
15+
MPSDataType,
16+
MPSDequantizePerChannelGroup,
17+
MPSGraph,
18+
MPSNode,
19+
)
20+
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
21+
22+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
23+
logging.basicConfig(level=logging.DEBUG, format=FORMAT)
24+
25+
26+
@register_node_visitor
class OpDequantizePerChannelGroupDefault(NodeVisitor):
    """
    Dequantize Per Channel Group Node visitor
    """

    target = "quantized_decomposed.dequantize_per_channel_group.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Serialize `node` as an MPSDequantizePerChannelGroup op and append it
        to `mps_graph.mps_nodes`.

        The first input (the weights) is defined with the int4 MPS data type,
        the second input holds the scales. The remaining settings are read
        positionally from `node.args[3:8]`: quant_min, quant_max, dtype,
        group_size and output_dtype.

        Raises:
            RuntimeError: if the weights input already has a serialized id.
        """
        weights_node = get_input_node(node, 0)

        # Weights placeholders shouldn't have been defined until this point
        if weights_node in self.tensor_to_id:
            raise RuntimeError(
                f"Placeholder for {node.target.__name__} already visited"
            )

        # Keep this definition order (output, weights, scales): the ids are
        # handed out in the order the tensors are defined.
        output_id = self.define_tensor(node, mps_graph)
        input_id = self.define_tensor(
            weights_node, mps_graph, MPSDataType.mps_data_type_int4
        )
        scales_id = self.define_tensor(get_input_node(node, 1), mps_graph)

        # there are no zero points in this quantization method (node.args[2] is all zeros)
        no_zero_points = -1

        op_fields = {
            "input1_id": input_id,
            "output_id": output_id,
            "scales_id": scales_id,
            "zero_points_id": no_zero_points,
            "quant_min": cast(int, node.args[3]),
            "quant_max": cast(int, node.args[4]),
            "dtype": self.torch_dtype_to_mps_dtype(node.args[5]),
            "group_size": cast(int, node.args[6]),
            "output_dtype": self.torch_dtype_to_mps_dtype(node.args[7]),
        }
        dequant_op = MPSDequantizePerChannelGroup(**op_fields)
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=dequant_op))
71+
72+
73+
@register_node_visitor
class OpQuantizePerToken(NodeVisitor):
    """
    Node visitor for dynamic per-token quantization
    """

    target = "quantized_decomposed.quantize_per_token.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Skip activation dynamic quantization for now: nothing is appended to
        the graph's nodes here, and this node is simply mapped to the id of
        its first input.
        Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
        Issue: #133407308
        """
        source_node = get_input_node(node, 0)
        source_id = self.define_tensor(source_node, mps_graph)
        # Alias the quantize node to its (unquantized) input.
        self.tensor_to_id[node] = source_id
96+
97+
98+
@register_node_visitor
class OpDequantizePerToken(NodeVisitor):
    """
    Node visitor for per-token dequantization
    """

    target = "quantized_decomposed.dequantize_per_token.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Skip activation dynamic quantization for now: nothing is appended to
        the graph's nodes here, and this node is simply mapped to the id of
        its first input.
        Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
        Issue: #133407308
        """
        source_node = get_input_node(node, 0)
        source_id = self.define_tensor(source_node, mps_graph)
        # Alias the dequantize node to its input.
        self.tensor_to_id[node] = source_id
121+
122+
123+
@register_node_visitor
class OpChooseQparamsToken(NodeVisitor):
    """
    Pass-through visitor for choose_qparams_per_token_asymmetric.default:
    no op is serialized for this node.
    """

    target = "quantized_decomposed.choose_qparams_per_token_asymmetric.default"

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Skip activation dynamic quantization for now: the node is mapped to a
        two-element list holding the id of its first input twice.
        Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
        Issue: #133407308
        """
        source_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        # Both entries point at the same input tensor id.
        self.tensor_to_id[node] = [source_id] * 2

0 commit comments

Comments
 (0)