Commit d5d0ad2 (2 parents: 36684e2 + 0cde6b8)

Update base for Update on "Use external_deps for sentencepiece"

As title.

Differential Revision: [D59770172](https://our.internmc.facebook.com/intern/diff/D59770172/)

[ghstack-poisoned]


80 files changed: +131818 / -993 lines

.ci/docker/ci_commit_pins/pytorch.txt
Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-c017c97333dfb9d17f2e5357980241827e50e8d5
+4e39cdceb1414b2d416339866a5bb044fbed4977
```

.ci/scripts/gather_test_models.py
Lines changed: 3 additions & 3 deletions

```diff
@@ -23,7 +23,7 @@
     "w2l": "linux.12xlarge",
     "ic4": "linux.12xlarge",
     "resnet50": "linux.12xlarge",
-    "llava_encoder": "linux.4xlarge",
+    "llava": "linux.4xlarge",
     # This one causes timeout on smaller runner, the root cause is unclear (T161064121)
     "dl3": "linux.12xlarge",
     "emformer_join": "linux.12xlarge",
@@ -83,7 +83,7 @@ def model_should_run_on_event(model: str, event: str) -> bool:
     We put higher priority and fast models to pull request and rest to push.
     """
     if event == "pull_request":
-        return model in ["add", "ic3", "mv2", "mv3", "resnet18", "vit", "llava_encoder"]
+        return model in ["add", "ic3", "mv2", "mv3", "resnet18", "vit", "llava"]
     return True
 
 
@@ -93,7 +93,7 @@ def model_should_run_on_target_os(model: str, target_os: str) -> bool:
     For example, a big model can be disabled in macos due to the limited macos resources.
     """
     if target_os == "macos":
-        return model not in ["llava_encoder"]
+        return model not in ["llava"]
     return True
```
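
For context, the two helpers above gate which models a CI job runs for a given event and target OS. A minimal, self-contained sketch of how they combine after the llava_encoder -> llava rename (the combined filter and the sample candidate list are illustrative assumptions, not part of this commit):

```python
def model_should_run_on_event(model: str, event: str) -> bool:
    # Fast/high-priority models run on pull requests; everything runs on push.
    if event == "pull_request":
        return model in ["add", "ic3", "mv2", "mv3", "resnet18", "vit", "llava"]
    return True


def model_should_run_on_target_os(model: str, target_os: str) -> bool:
    # llava is excluded on macOS because of the limited macOS runner resources.
    if target_os == "macos":
        return model not in ["llava"]
    return True


# Hypothetical usage: pick the models for a pull-request job on macOS.
candidates = ["add", "vit", "llava", "dl3"]
selected = [
    m
    for m in candidates
    if model_should_run_on_event(m, "pull_request")
    and model_should_run_on_target_os(m, "macos")
]
print(selected)  # ['add', 'vit']: llava is filtered out on macOS, dl3 on PRs
```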

.ci/scripts/test.sh
Lines changed: 2 additions & 2 deletions

```diff
@@ -67,9 +67,9 @@ test_model() {
     run_portable_executor_runner
     rm "./${MODEL_NAME}.pte"
   fi
-  if [[ "${MODEL_NAME}" == "llava_encoder" ]]; then
+  if [[ "${MODEL_NAME}" == "llava" ]]; then
     # Install requirements for llava
-    bash examples/models/llava_encoder/install_requirements.sh
+    bash examples/models/llava/install_requirements.sh
   fi
   # python3 -m examples.portable.scripts.export --model_name="llama2" should works too
   "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}"
```

.gitmodules
Lines changed: 6 additions & 6 deletions

```diff
@@ -55,15 +55,15 @@
 [submodule "examples/third-party/LLaVA"]
   path = examples/third-party/LLaVA
   url = https://github.com/haotian-liu/LLaVA.git
-[submodule "examples/models/llama2/third-party/re2"]
-  path = examples/models/llama2/third-party/re2
-  url = https://github.com/google/re2.git
-[submodule "examples/models/llama2/third-party/abseil-cpp"]
-  path = examples/models/llama2/third-party/abseil-cpp
-  url = https://github.com/abseil/abseil-cpp.git
 [submodule "third-party/ios-cmake"]
   path = third-party/ios-cmake
   url = https://github.com/leetal/ios-cmake
 [submodule "examples/models/phi-3-mini/third-party/sentencepiece"]
   path = examples/models/phi-3-mini/third-party/sentencepiece
   url = https://github.com/google/sentencepiece.git
+[submodule "extension/llm/third-party/re2"]
+  path = extension/llm/third-party/re2
+  url = https://github.com/google/re2.git
+[submodule "extension/llm/third-party/abseil-cpp"]
+  path = extension/llm/third-party/abseil-cpp
+  url = https://github.com/abseil/abseil-cpp.git
```

.lintrunner.toml
Lines changed: 2 additions & 0 deletions

```diff
@@ -7,6 +7,7 @@ exclude_patterns = [
     'third-party/**',
     '**/third-party/**',
     '.github/scripts/**',
+    'exir/serde/**',
 ]
 command = [
     'python',
@@ -37,6 +38,7 @@ include_patterns = [
 exclude_patterns = [
     'third-party/**',
     '**/third-party/**',
+    'exir/serde/**',
 ]
 command = [
     'python',
```

backends/cadence/CMakeLists.txt
Lines changed: 2 additions & 2 deletions

```diff
@@ -25,5 +25,5 @@ include(${EXECUTORCH_ROOT}/build/Utils.cmake)
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)
 
 
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/operators)
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/kernels)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/reference/operators)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/reference/kernels)
```

backends/cadence/aot/compiler.py
Lines changed: 2 additions & 2 deletions

```diff
@@ -19,7 +19,7 @@
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import (
-    CadenceGenericQuantizer,
+    CadenceAtenQuantizer,
     CadenceQuantizer,
 )
 from executorch.backends.cadence.aot.utils import model_is_quantized
@@ -64,7 +64,7 @@ def quantize_pt2(
 
     # Get patterns and apply fusion of dq -> op -> q to qop
     patterns = [
-        assert_is_instance(q, CadenceGenericQuantizer).pattern
+        assert_is_instance(q, CadenceAtenQuantizer).pattern
         for q in quantizer.quantizers
     ]
     QuantFusion(patterns)(converted_model)
```

backends/cadence/aot/functions.yaml
Lines changed: 6 additions & 6 deletions

```diff
@@ -107,30 +107,30 @@
   variants: function
   kernels:
     - arg_meta: null
-      kernel_name: impl::HiFi::quantize_per_tensor_out
+      kernel_name: impl::reference::quantize_per_tensor_out
 
 - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
-      kernel_name: impl::HiFi::dequantize_per_tensor_out
+      kernel_name: impl::reference::dequantize_per_tensor_out
 
 - func: cadence::quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: impl::HiFi::quantized_conv_out
+      kernel_name: impl::reference::quantized_conv_out
 
 - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: impl::HiFi::quantized_layer_norm_out
+      kernel_name: impl::reference::quantized_layer_norm_out
 
 - func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: impl::HiFi::quantized_linear_out
+      kernel_name: impl::reference::quantized_linear_out
 
 - func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: impl::HiFi::quantized_relu_out
+      kernel_name: impl::reference::quantized_relu_out
```

backends/cadence/aot/quantizer/fusion_pass.py
Lines changed: 4 additions & 10 deletions

```diff
@@ -14,21 +14,19 @@
     BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
-    LayerNormFunctionalPattern,
     LayerNormPattern,
-    LinearFunctionalPattern,
     LinearPattern,
     MatmulPattern,
     ReluPattern,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
+    find_sequential_partitions_aten,
     get_conv_args,
     quantize_tensor_multiplier,
 )
 from executorch.exir.pass_base import ExportPass
 from torch import fx
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.fx import GraphModule
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.utils.fuser_utils import legalize_graph
@@ -310,7 +308,7 @@ def __init__(self, patterns) -> None:
 
     def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
         for pattern in self.patterns:
-            fused_partitions = find_sequential_partitions(
+            fused_partitions = find_sequential_partitions_aten(
                 graph_module,
                 pattern.partition_types(),
             )
@@ -373,9 +371,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     quant_node,
                     op_node,
                 )
-            elif isinstance(pattern, LinearPattern) or isinstance(
-                pattern, LinearFunctionalPattern
-            ):
+            elif isinstance(pattern, LinearPattern):
                 args, kwargs = get_args_and_kwargs_linear(
                     graph_module,
                     inputs_inputs,
@@ -385,9 +381,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     bias_inputs,
                     quant_node,
                 )
-            elif isinstance(pattern, LayerNormPattern) or isinstance(
-                pattern, LayerNormFunctionalPattern
-            ):
+            elif isinstance(pattern, LayerNormPattern):
                 args, kwargs = get_args_and_kwargs_layer_norm(
                     graph_module,
                     inputs_inputs,
```

backends/cadence/aot/quantizer/patterns.py
Lines changed: 20 additions & 84 deletions

```diff
@@ -8,7 +8,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Callable, List, Optional, Tuple, Type, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams
@@ -47,17 +47,15 @@ class PartitionAnchors:
 
 class QuantizationPattern(ABC):
     @abstractmethod
-    def partition_types(
-        self,
-    ) -> Union[List[Type[torch.nn.Module]], List[Callable[..., torch.Tensor]]]:
+    def partition_types(self) -> list[OpOverload]:
         """
-        List of types to be passed to find_sequential_partitions.
+        List of types to be passed to find_sequential_partitions_aten.
         """
         pass
 
     @abstractmethod
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
     ) -> Optional[PartitionAnchors]:
         pass
 
@@ -71,8 +69,8 @@ def replacement_op(self) -> OpOverload:
 
 
 class AddmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.addmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.addmm.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -103,8 +101,8 @@ def replacement_op(self) -> OpOverload:
 
 
 class BmmPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.bmm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.bmm.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -123,8 +121,8 @@ def replacement_op(self) -> OpOverload:
 
 
 class Conv1dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv1d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -161,8 +159,8 @@ def replacement_op(self) -> OpOverload:
 
 
 class Conv2dPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Conv2d]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -199,32 +197,8 @@ def replacement_op(self) -> OpOverload:
 
 
 class LayerNormPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.LayerNorm]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        layer_norm_node = fused_partition[0].nodes[-1]
-
-        # Weights and biases are used as fp32 by our kernel, so they are
-        # passed in as others here along with the normalized shape.
-        return PartitionAnchors(
-            inputs=[(layer_norm_node, 0)],
-            weights=[],
-            biases=[],
-            # Ordering: normalized_shape, weights, bias
-            others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
-            output=[(layer_norm_node,)],
-        )
-
-    def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_layer_norm.default
-
-
-class LayerNormFunctionalPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.nn.functional.layer_norm]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.layer_norm.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -257,8 +231,8 @@ def replacement_op(self) -> OpOverload:
 
 
 class LinearPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.Linear]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -294,47 +268,9 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_linear.default
 
 
-class LinearFunctionalPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.nn.functional.linear]
-
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> PartitionAnchors:
-        linear_node = fused_partition[0].nodes[-1]
-
-        bias_qspec = DerivedQuantizationSpec(
-            derived_from=[
-                (linear_node.args[0], linear_node),
-                (linear_node.args[1], linear_node),
-            ],
-            derive_qparams_fn=get_bias_qparams,
-            dtype=torch.int32,
-            quant_min=-(2**31),
-            quant_max=2**31 - 1,
-            qscheme=torch.per_tensor_affine,
-        )
-
-        # Keep bias empty if not supplied
-        bias = []
-        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
-            bias = [(linear_node, 2, bias_qspec)]
-
-        return PartitionAnchors(
-            inputs=[(linear_node, 0)],
-            weights=[(linear_node, 1)],
-            # pyre-fixme[6]: Incompatible parameter type
-            biases=bias,
-            output=[(linear_node,)],
-        )
-
-    def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
-
-
 class MatmulPattern(QuantizationPattern):
-    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
-        return [torch.matmul]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.matmul.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -353,8 +289,8 @@ def replacement_op(self) -> OpOverload:
 
 
 class ReluPattern(QuantizationPattern):
-    def partition_types(self) -> List[Type[torch.nn.Module]]:
-        return [torch.nn.ReLU]
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
```
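
The net effect in this file: every pattern now matches the ATen OpOverloads that appear in the exported graph, rather than `torch.nn` module classes or Python callables, which is what allowed the `*FunctionalPattern` variants to collapse into the base patterns. As a sketch of the revised contract, a hypothetical new pattern might look like this (the `TanhPattern` and the stand-in base class are illustrative so the snippet runs with only `torch` installed; the real base class lives in `executorch.backends.cadence.aot.quantizer.patterns`):

```python
from abc import ABC, abstractmethod
from typing import List

import torch
from torch._ops import OpOverload


class QuantizationPattern(ABC):
    # Stand-in for the real abstract base class in
    # executorch.backends.cadence.aot.quantizer.patterns.
    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        """List of ATen ops to be passed to find_sequential_partitions_aten."""


class TanhPattern(QuantizationPattern):
    # Hypothetical pattern following the new contract: return ATen
    # OpOverloads (torch.ops.aten.tanh.default) instead of an nn.Module
    # class (torch.nn.Tanh) or a callable (torch.tanh).
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.tanh.default]


print(TanhPattern().partition_types())  # prints the aten.tanh.default overload
```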
