Commit 7dc55e6

Merge branch 'main' into tosa_1.0-misc-tests
2 parents 11cae8a + c913634 commit 7dc55e6

23 files changed: +1164 additions, −382 deletions


.buckconfig

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@
 
 [buck2]
 restarter=true
+file_watcher=notify
 
 [oss]
 folly_cxx_tests = False

backends/cadence/aot/fuse_ops.py

Lines changed: 56 additions & 5 deletions
@@ -814,11 +814,61 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulIntoDequantPass(ExportPass):
+class FuseMulScalarIntoDequantPass(ExportPass):
     """
-    Looks for the pattern where aten.mul is multiplying the outputs of dequantize
-    and aten.full. If found, updates the dequant scale to reflect the multiplication
-    and removes the full and mul nodes.
+    Looks for the pattern where aten.mul.Scalar is multiplying the
+    outputs of dequantize. If found, updates the dequant scale
+    to reflect the multiplication and removes the mul node.
+    """
+
+    def attempt_fusion(
+        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
+    ) -> None:
+        if node.target not in {
+            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            exir_ops.edge.cadence.dequantize_per_tensor.default,
+        }:
+            return
+
+        # ensure that the single user of dequant is aten.mul.Scalar
+        user = list(node.users.keys())[0]
+        if len(node.users) != 1 or user.target != exir_ops.edge.aten.mul.Scalar:
+            return
+
+        # ensure that the other arg to mul is a scalar constant (i.e. not a node)
+        if len(user.args) > 1 and isinstance(user.args[1], torch.fx.Node):
+            return
+
+        new_deq_args = list(node.args)
+        assert isinstance(node.args[1], Number)
+        assert isinstance(user.args[1], Number)
+        # pyre-ignore[58]: Unsupported operand *
+        new_deq_args[1] = node.args[1] * user.args[1]
+
+        logging.debug(
+            f"Fused {node} and {user} into {node}. Updated scale from {node.args[1]} to {new_deq_args[1]}"
+        )
+
+        user.replace_all_uses_with(node)
+        node.args = tuple(new_deq_args)
+
+        graph_module.graph.erase_node(user)
+
+        graph_module.recompile()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            self.attempt_fusion(graph_module, node)
+        result = super().call(graph_module)
+        return result
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class FuseMulTensorIntoDequantPass(ExportPass):
+    """
+    Looks for the pattern where aten.mul is multiplying the outputs of dequantize
+    and aten.full, or vice versa. If found, updates the dequant scale to reflect
+    the multiplication and removes the full and mul nodes.
     """
 
     def attempt_fusion(
@@ -1017,7 +1067,8 @@ class CadenceFuseOpsInGraph:
         FuseCascadedTransposeOrPermuteOps,
         FuseCascadedViewOps,
         FuseQuantDequantToRequantizePass,
-        FuseMulIntoDequantPass,
+        FuseMulTensorIntoDequantPass,
+        FuseMulScalarIntoDequantPass,
         FuseFullThenReshapePass,
        FuseTransposeOrPermuteOpPairsPass,
    ]
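
Both passes rest on a simple identity of affine dequantization: dequant(q) = (q − zero_point) * scale, so multiplying the dequantized output by a constant c is the same as dequantizing with scale * c. Below is a minimal standalone sanity check of that identity in plain PyTorch; it is illustrative only and not part of the commit.

import torch

def dequantize(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # Affine dequantization: (q - zero_point) * scale
    return (q.to(torch.float32) - zero_point) * scale

q = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
scale, zero_point, mul_value = 0.006, 5, 0.3

# dequantize followed by a scalar multiply ...
out_mul = dequantize(q, scale, zero_point) * mul_value
# ... equals a single dequantize with the folded scale, which is what the pass rewrites to
out_fused = dequantize(q, scale * mul_value, zero_point)
assert torch.allclose(out_mul, out_fused)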

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 44 additions & 2 deletions
@@ -19,7 +19,8 @@
 )
 from executorch.backends.cadence.aot.fuse_ops import (
     FuseFullThenReshapePass,
-    FuseMulIntoDequantPass,
+    FuseMulScalarIntoDequantPass,
+    FuseMulTensorIntoDequantPass,
     FuseQuantDequantToRequantizePass,
     FuseTransposeOrPermuteOpPairsPass,
 )
@@ -446,7 +447,7 @@ def forward(self, x):
 
         inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
         graph_module = export_to_edge(M(), inputs).exported_program().graph_module
-        graph_module = FuseMulIntoDequantPass()(graph_module).graph_module
+        graph_module = FuseMulTensorIntoDequantPass()(graph_module).graph_module
 
         # verify that the mul and full ops were removed
         self.check_op_counts(
@@ -467,6 +468,47 @@ def forward(self, x):
                 deq_scale = node.args[1]
                 self.assertEqual(deq_scale, 4.5)
 
+    def test_fuse_mul_scalar_into_dequant(self):
+        dequant_scale = 0.006
+        mul_value = 0.3
+
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 1, 0, -128, 127, torch.int8),
+        )
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(quant, dequant_scale, 5, -128, 127, torch.int8),
+        )
+        mul_scalar = builder.call_operator(
+            op=exir_ops.edge.aten.mul.Scalar,
+            args=(dequant, mul_value),
+        )
+        builder.output(mul_scalar)
+        graph_module = builder.get_graph_module()
+
+        graph_module = FuseMulScalarIntoDequantPass()(graph_module).graph_module
+
+        # verify that the mul op was removed
+        self.check_op_counts(
+            graph_module,
+            expected_op_counts={
+                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
+                exir_ops.edge.aten.mul.Scalar: 0,
+            },
+        )
+
+        # verify that the dequant scale value was updated correctly
+        for node in graph_module.graph.nodes:
+            if (
+                node.target
+                == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+            ):
+                deq_scale = node.args[1]
+                self.assertEqual(deq_scale, dequant_scale * mul_value)
+
     def test_fuse_then_transpose_pass(self):
         # Create a graph with full -> transpose.
         builder = GraphBuilder()
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# Example of an external model for the ARM AOT Compiler
+Example of an external Python file to be used as a module by the `run.sh` (and the `aot_arm_compiler.py`) scripts in the `examples/arm` directory.
+Just pass the path of the `add.py` file as `--model_name`.
+
+`ModelUnderTest` should be a `torch.nn.Module` instance.
+
+`ModelInputs` should be a tuple of inputs to the forward function.

examples/arm/example_modules/add.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import torch
+
+
+class myModelAdd(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x + x
+
+
+ModelUnderTest = myModelAdd()
+ModelInputs = (torch.ones(5),)
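
For context, the compiler script is expected to import this file dynamically and read the two module-level names described in the README above. A minimal sketch of how such loading could look, assuming a standard importlib-based approach (the actual loading code in aot_arm_compiler.py may differ):

import importlib.util

def load_external_model(path: str):
    # Load the user-supplied .py file as a throwaway module.
    spec = importlib.util.spec_from_file_location("external_model", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Convention from the README: ModelUnderTest and ModelInputs.
    return module.ModelUnderTest, module.ModelInputs

model, example_inputs = load_external_model("examples/arm/example_modules/add.py")
print(model(*example_inputs))  # tensor([2., 2., 2., 2., 2.])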

examples/models/llama/export_llama_lib.py

Lines changed: 14 additions & 38 deletions
@@ -1273,6 +1273,7 @@ def _get_source_transforms( # noqa
     preq_mode: Optional[str] = None,
     preq_group_size: Optional[int] = None,
     preq_embedding_quantize: Optional[str] = None,
+    local_global_attention: Optional[List[int]] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
@@ -1340,23 +1341,11 @@ def _get_source_transforms( # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-
-        # Create a mock args object with the necessary attributes
-        class Args:
-            pass
-
-        args = Args()
-        args.checkpoint = checkpoint
-        args.tokenizer_path = tokenizer_path
-        args.embedding_quantize = embedding_quantize
-        args.use_shared_embedding = use_shared_embedding
-        args.use_qat = use_qat
-        args.use_lora = use_lora
-        args.preq_mode = preq_mode
-        args.preq_group_size = preq_group_size
-        args.preq_embedding_quantize = preq_embedding_quantize
-
-        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
+        transforms.append(
+            get_quant_embedding_transform(
+                embedding_quantize, use_shared_embedding, checkpoint_dtype
+            )
+        )
 
         # quantization_mode should be applied after embedding_quantize
         # to support shared_embedding
@@ -1374,30 +1363,17 @@ class Args:
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-
-        # Create a mock args object with the necessary attributes
-        class Args:
-            pass
-
-        args = Args()
-        args.checkpoint = checkpoint
-        args.tokenizer_path = tokenizer_path
-        args.quantization_mode = quantization_mode
-        args.group_size = group_size
-        args.use_shared_embedding = use_shared_embedding
-        args.calibration_tasks = calibration_tasks
-        args.calibration_limit = calibration_limit
-        args.calibration_seq_length = calibration_seq_length
-        args.use_shared_embedding = use_shared_embedding
-        args.use_qat = use_qat
-        args.use_lora = use_lora
-        args.preq_mode = preq_mode
-
         transforms.append(
             get_quant_weight_transform(
-                args=args,
+                quantization_mode=quantization_mode,
+                group_size=group_size,
                 computation_dtype=dtype_override,
                 checkpoint_dtype=checkpoint_dtype,
+                checkpoint_path=checkpoint,
+                tokenizer_path=tokenizer_path,
+                calibration_tasks=calibration_tasks,
+                calibration_limit=calibration_limit,
+                calibration_seq_length=calibration_seq_length,
             )
         )
 
@@ -1467,7 +1443,7 @@ class Args:
     if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
-    if getattr(args, "local_global_attention", None) is not None:
+    if local_global_attention:
         transforms.append(
             partial(
                 replace_kv_cache_with_ring_kv_cache,

examples/models/llama/source_transformation/quantize.py

Lines changed: 26 additions & 29 deletions
@@ -41,7 +41,7 @@ def quantize( # noqa C901
     checkpoint_dtype: Optional[DType] = None,
     checkpoint_path: Optional[Path] = None,
     # following arguments only available when setting int4 or gptq quantization.
-    group_size: Optional[int] = 128,
+    group_size: Optional[int] = None,
     # following arguments are only used for GPTQ
     calibration_tasks: Optional[list] = None,
     calibration_limit: Optional[int] = None,
@@ -146,9 +146,9 @@ def quantize( # noqa C901
         print("quantized model:", model)
         return model
     elif qmode == "8da4w":
-        # Check for required args
         if group_size is None:
-            raise Exception("For 8da4w quantization, group size must be specified.")
+            # TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
+            group_size = 128
 
         from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
         from torchao.utils import unwrap_tensor_subclass
@@ -784,16 +784,20 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################
 
 
-def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
-    if args.embedding_quantize.startswith("torchao:"):
+def get_quant_embedding_transform(
+    embedding_quantize: str,
+    use_shared_embedding: bool = False,
+    dtype_override: Optional[DType] = None,
+):
+    if embedding_quantize.startswith("torchao:"):
         from torchao.experimental.quant_api import (
             EmbeddingQuantizer,
             SharedEmbeddingQuantizer,
         )
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import MappingType
 
-        quant_args = args.embedding_quantize.split(":")[1].split(",")
+        quant_args = embedding_quantize.split(":")[1].split(",")
         if len(quant_args) == 2:
             bitwidth, group_size = quant_args
             is_asymmetric = True
@@ -814,7 +818,7 @@ def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
 
         def _torchao_embedding_quantizer(model):
             with torch.no_grad():
-                if not args.use_shared_embedding:
+                if not use_shared_embedding:
                     EmbeddingQuantizer(
                         weight_dtype=weight_dtype,
                         granularity=granularity,
@@ -831,7 +835,7 @@ def _torchao_embedding_quantizer(model):
 
         return _torchao_embedding_quantizer
 
-    bitwidth, group_size = args.embedding_quantize.split(",")
+    bitwidth, group_size = embedding_quantize.split(",")
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
     else:
@@ -848,34 +852,27 @@ def _torchao_embedding_quantizer(model):
 
 
 def get_quant_weight_transform(
-    args,
+    quantization_mode: str,
+    group_size: Optional[int] = None,
     computation_dtype: Optional[DType] = None,
     checkpoint_dtype: Optional[DType] = None,
+    checkpoint_path: Optional[Path] = None,
+    tokenizer_path: Optional[Path] = None,
+    calibration_tasks: Optional[list] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
 ):
-    # If these optional args are None, don't provide them to quantize().
-    quant_args_str = [
-        "group_size",
-        "calibration_tasks",
-        "calibration_limit",
-        "calibration_seq_length",
-    ]
-    arg_dict = vars(args)
-    quant_args = {
-        param: val
-        for param in quant_args_str
-        if (val := arg_dict.get(param)) is not None
-    }
-
     return partial(
         quantize,
-        **quant_args,
-        qmode=args.quantization_mode,
+        qmode=quantization_mode,
        computation_dtype=computation_dtype,
        checkpoint_dtype=checkpoint_dtype,
-        checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
-        tokenizer_path=(
-            Path(path) if (path := args.tokenizer_path) is not None else None
-        ),
+        checkpoint_path=(Path(path) if (path := checkpoint_path) is not None else None),
+        group_size=group_size,
+        calibration_tasks=calibration_tasks,
+        calibration_limit=calibration_limit,
+        calibration_seq_length=calibration_seq_length,
+        tokenizer_path=(Path(path) if (path := tokenizer_path) is not None else None),
    )
 
 
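
The net effect of the 8da4w change above is that a missing group_size now falls back to 128 instead of raising. A standalone sketch of that fallback behavior, separate from the actual quantize() function:

from typing import Optional

def resolve_8da4w_group_size(group_size: Optional[int]) -> int:
    # Mirrors the new behavior: default to 128 when unset rather than raising.
    return 128 if group_size is None else group_size

assert resolve_8da4w_group_size(None) == 128
assert resolve_8da4w_group_size(256) == 256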

examples/models/llava/export_llava.py

Lines changed: 10 additions & 1 deletion
@@ -107,7 +107,16 @@ def forward(self, input_pos, embeddings):
             "4,32",
         ]
     )
-    quant_transform = get_quant_weight_transform(args, dtype_override)
+    quant_transform = get_quant_weight_transform(
+        quantization_mode=args.quantization_mode,
+        group_size=args.group_size,
+        computation_dtype=dtype_override,
+        checkpoint_path=args.checkpoint_path,
+        tokenizer_path=args.tokenizer_path,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+    )
     _, quantizers, _ = get_quantizer_and_quant_params(args)
     source_transforms = []
     if llava.use_sdpa_with_kv_cache_op:
