Skip to content

Arm backend: Add initial Llama model test case #8679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 19, 2025
1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
pool_2d_support,
reduce_sum_support,
right_shift_support,
slice_copy_support,
to_copy_support,
tosa_supported_operators,
)
39 changes: 39 additions & 0 deletions backends/arm/operator_support/slice_copy_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging

import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import getNodeArgs
from executorch.exir.dialects._ops import ops as exir_ops

# Module-level logger named after this module; its threshold is set to
# WARNING, so this logger drops DEBUG and INFO records.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@register_tosa_support_check
class SliceCopySupported(SupportedTOSAOperatorCheck):
    """Support check for ``aten.slice_copy.Tensor`` edge operators.

    A node is accepted for the TOSA-0.80 BI and MI specifications listed in
    ``tosa_specs`` unless it carries a fifth argument (the step) whose value
    is not 1.
    """

    targets = [exir_ops.edge.aten.slice_copy.Tensor]

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:  # type: ignore[override, misc]
        """Return whether ``node`` is supported for ``tosa_spec``.

        Args:
            node: The slice_copy node to inspect.
            tosa_spec: The TOSA specification being targeted.

        Returns:
            False if ``tosa_spec`` is not one of ``tosa_specs`` or if the node
            has a step argument different from 1; True otherwise.
        """
        if tosa_spec not in self.tosa_specs:
            return False

        inputs = getNodeArgs(node)
        # The step is optional: it is only checked when a fifth argument exists.
        if len(inputs) == 5 and (step := inputs[4].number) != 1:
            # Report through the module logger (not the root logger) so the
            # record is attributed to this module and obeys its level.
            logger.warning(
                "%s with step size of %s not supported.", node.target, step
            )
            return False
        return True
2 changes: 0 additions & 2 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
def get_registered_tosa_support_checks(
tosa_spec: TosaSpecification,
) -> list[Type[SupportedTOSAOperatorCheck]]:

if tosa_spec not in _tosa_spec_support:
raise RuntimeError(
f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}"
Expand Down Expand Up @@ -165,7 +164,6 @@ def is_node_supported(
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten._log_softmax.default,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
Expand Down
19 changes: 14 additions & 5 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def define_node(
# Handle int8 (quantized) and int32
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]

dim_order = (
inputs[0].dim_order
if len(inputs[0].shape) > len(inputs[1].shape)
else inputs[1].dim_order
)

if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node
Expand All @@ -61,13 +67,14 @@ def define_node(
# output.dtype == ts.DType.INT32
add_output = output

input1, input2 = tutils.reshape_for_broadcast(
tosa_graph, rescaled_inputs, dim_order
)

# Do the INT32 Add
tosa_graph.addOperator(
TosaOp.Op().ADD,
[
rescaled_inputs[0].name,
rescaled_inputs[1].name,
],
[input1.name, input2.name],
[add_output.name],
None,
)
Expand Down Expand Up @@ -108,10 +115,12 @@ def define_node(
assert inputs[0].dtype == ts.DType.FP32
assert output.dtype == ts.DType.FP32

input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)

# MI lowering
tosa_graph.addOperator(
TosaOp.Op().ADD,
[inputs[0].name, inputs[1].name],
[input1.name, input2.name],
[output.name],
None,
)
26 changes: 21 additions & 5 deletions backends/arm/operators/op_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import reshape_for_broadcast
from serializer.tosa_serializer import TosaOp


Expand All @@ -43,6 +44,12 @@ def define_node(
output: TosaArg,
) -> None:
assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8

dim_order = (
inputs[0].dim_order
if len(inputs[0].shape) > len(inputs[1].shape)
else inputs[1].dim_order
)
input_A = inputs[0]
input_B = inputs[1]
input_qparams = get_input_qparams(node) # pyre-ignore[16]
Expand All @@ -68,15 +75,21 @@ def define_node(
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)

input1, input2 = tutils.reshape_for_broadcast(
tosa_graph,
[
input_A_rescaled,
input_B_rescaled,
],
dim_order,
)

# Do the INT32 Mul
attr = ts.TosaSerializerAttribute()
attr.MulAttribute(shift=0)
tosa_graph.addOperator(
TosaOp.Op().MUL,
[
input_A_rescaled.name,
input_B_rescaled.name,
],
[input1.name, input2.name],
[mul_output.name],
attr,
)
Expand All @@ -101,8 +114,11 @@ def define_node(
) -> None:
if inputs[0].dtype == ts.DType.INT8:
return super().define_node(node, tosa_graph, inputs, output)

input1, input2 = reshape_for_broadcast(tosa_graph, inputs)

attr = ts.TosaSerializerAttribute()
attr.MulAttribute(shift=0)
tosa_graph.addOperator(
TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr
TosaOp.Op().MUL, [input1.name, input2.name], [output.name], attr
)
7 changes: 5 additions & 2 deletions backends/arm/operators/op_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ def define_node(
output: TosaArg,
) -> None:

# See slice_copy_support.py
if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)):
raise ValueError("Unsupported combination of inputs")

# aten.slice_copy supports slicing in 1d at a time.
# The arguments are dimension of slicing, start index and end index.
assert len(inputs) == 4
# The arguments are the actual input, dimension of slicing, start index, end index and optional step or stride.
input_node, dim, start, end = inputs

# Translate and check parameters in Pytorch dim order.
Expand Down
7 changes: 6 additions & 1 deletion backends/arm/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def pytest_configure(config):
)
# Only enable if we also have the TOSA reference model available.
pytest._test_options["corstone_fvp"] = True # type: ignore[attr-defined]

pytest._test_options["llama_inputs"] = config.option.llama_inputs # type: ignore[attr-defined]
pytest._test_options["fast_fvp"] = False # type: ignore[attr-defined]
if getattr(config.option, "fast_fvp", False):
pytest._test_options["fast_fvp"] = config.option.fast_fvp # type: ignore[attr-defined]
Expand All @@ -70,6 +70,11 @@ def try_addoption(*args, **kwargs):
try_addoption("--arm_quantize_io", action="store_true", help="Deprecated.")
try_addoption("--arm_run_corstoneFVP", action="store_true", help="Deprecated.")
try_addoption("--fast_fvp", action="store_true")
try_addoption(
"--llama_inputs",
nargs="+",
help="List of two files. Firstly .pt file. Secondly .json",
)


def pytest_sessionstart(session):
Expand Down
120 changes: 120 additions & 0 deletions backends/arm/test/models/test_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import os
import sys
import unittest

import torch

from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.examples.models.llama.export_llama_lib import (
build_args_parser,
get_llama_model,
)


# Put the project directory (four levels above this file) on sys.path to work
# around the importlib.import_module() conditions in model_factory.py.
this_files_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.abspath(
    os.path.join(this_files_dir, os.pardir, os.pardir, os.pardir, os.pardir)
)
sys.path.append(project_dir)

# Module-level logger with its threshold set to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class TestLlama(unittest.TestCase):
    """
    Test class of Llama models. Type of Llama model depends on command line parameters:
    --llama_inputs <path to .pt file> <path to json file>
    Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
    """

    def prepare_model(self):
        """Build the Llama model named by the ``--llama_inputs`` option.

        Returns:
            The ``(model, example_inputs, meta)`` triple produced by
            ``get_llama_model``, or ``(None, None, None)`` when the
            ``llama_inputs`` option is not enabled.

        Raises:
            AssertionError: If the option does not hold exactly two strings
                or either of them is not an existing file.
        """
        if not conftest.is_option_enabled("llama_inputs"):
            # Use the module logger rather than the root logger so the record
            # is attributed to this test module.
            logger.warning(
                "Skipping Llama test because of lack of input. To run use --llama_inputs <.pt> <.json>"
            )
            return None, None, None

        param_list = conftest.get_option("llama_inputs")
        assert (
            isinstance(param_list, list) and len(param_list) == 2
        ), "invalid number of inputs for --llama_inputs"
        checkpoint, params_file = param_list
        assert isinstance(checkpoint, str) and isinstance(
            params_file, str
        ), "invalid input for --llama_inputs"

        assert os.path.isfile(checkpoint) and os.path.isfile(
            params_file
        ), "Invalid file paths"

        # TODO: Enable key value cache
        cli_args = [
            "--disable_dynamic_shape",
            "-c",
            checkpoint,
            "-p",
            params_file,
            "--model",
            "stories110m",
        ]
        parsed_args = build_args_parser().parse_args(cli_args)

        llama_model, llama_inputs, llama_meta = get_llama_model(parsed_args)

        # TODO: Remove workaround since attention mask should not be persistent,
        # it only works if input shape is always the same
        for i in range(llama_model.n_layers):
            attention = llama_model.layers[i].attention
            # Re-register each buffer under its own name as persistent.
            for module, buffer_name in (
                (attention, "mask"),
                (attention.rope, "freqs_cos"),
                (attention.rope, "freqs_sin"),
            ):
                module.register_buffer(
                    buffer_name, module.get_buffer(buffer_name), persistent=True
                )

        return llama_model, llama_inputs, llama_meta

    def test_llama_tosa_MI(self):
        """Lower the Llama model for TOSA-0.80+MI and compare outputs.

        The test is reported as skipped, not passed, when no model or inputs
        are available.
        """
        llama_model, llama_inputs, llama_meta = self.prepare_model()

        if llama_model is None or llama_inputs is None:
            self.skipTest(
                "Missing model and/or input files. To run use --llama_inputs <.pt> <.json>"
            )

        with torch.no_grad():
            (
                ArmTester(
                    llama_model,
                    example_inputs=llama_inputs,
                    compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
                    constant_methods=llama_meta,
                )
                .export()
                .to_edge_transform_and_lower()
                .check_count({"torch.ops.higher_order.executorch_call_delegate": 14})
                .to_executorch()
                .run_method_and_compare_outputs(
                    inputs=llama_inputs, atol=1.8, rtol=0.01  # TODO: decrease tolerance
                )
            )
)
24 changes: 23 additions & 1 deletion backends/arm/test/ops/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch
Expand Down Expand Up @@ -61,6 +60,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
}


class Add3(torch.nn.Module):
    """Elementwise add of two tensors whose test inputs differ in rank."""

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return ``x + y``."""
        return x + y

    # Test inputs keyed by test name. The value is a dict, so the annotation
    # is a dict type rather than the previous (incorrect) list type.
    test_data: dict[str, input_t2] = {
        "3d_randn_diff_rank": (torch.randn(1, 4, 5), torch.randn(4, 1)),
        "4d_randn_diff_rank": (torch.randn(1, 1, 4, 4), torch.randn(4, 1)),
        "4d_randn_diff_rank_2": (torch.randn(4, 1), torch.randn(1, 1, 4, 5)),
    }


@common.parametrize("test_data", Add.test_data)
def test_add_tosa_MI(test_data: input_t1):
pipeline = TosaPipelineMI[input_t1](Add(), test_data, aten_op, exir_op)
Expand Down Expand Up @@ -129,6 +139,18 @@ def test_add_2_tosa_MI(test_data: input_t2):
pipeline.run()


@common.parametrize("test_data", Add3.test_data)
def test_add3_tosa_MI(test_data: input_t2):
    """Run the TosaPipelineMI pipeline for Add3 on one parametrized input pair."""
    TosaPipelineMI[input_t2](Add3(), test_data, aten_op, exir_op).run()


@common.parametrize("test_data", Add3.test_data)
def test_add3_tosa_BI(test_data: input_t2):
    """Run the TosaPipelineBI pipeline for Add3 on one parametrized input pair."""
    TosaPipelineBI[input_t2](Add3(), test_data, aten_op, exir_op).run()


@common.parametrize("test_data", Add2.test_data)
def test_add_2_tosa_BI(test_data: input_t2):
pipeline = TosaPipelineBI[input_t2](Add2(), test_data, aten_op, exir_op)
Expand Down
Loading
Loading