[DevX] Support filtering benchmark configs in on-demand workflow #7639

Merged (1 commit) on Jan 22, 2025
Changes from all commits
Empty file added .ci/scripts/__init__.py
145 changes: 104 additions & 41 deletions .ci/scripts/gather_benchmark_configs.py
@@ -9,8 +9,10 @@
import logging
import os
import re
-from typing import Any, Dict
+import sys
+from typing import Any, Dict, List

+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from examples.models import MODEL_NAME_TO_MODEL


@@ -45,6 +47,79 @@
}


def extract_all_configs(data, target_os=None):
if isinstance(data, dict):
# If target_os is specified, include "xplat" and the specified branch
include_branches = {"xplat", target_os} if target_os else data.keys()
return [
v
for key, value in data.items()
if key in include_branches
for v in extract_all_configs(value, target_os)
]
elif isinstance(data, list):
return [v for item in data for v in extract_all_configs(item, target_os)]
else:
return [data]

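For illustration, a minimal sketch of what the helper above returns, using a made-up mapping shaped like BENCHMARK_CONFIGS (OS branch mapping to a list of config names; the real table is collapsed in this view):

sample_configs = {
    "xplat": ["xnnpack_q8"],
    "android": ["qnn_q8"],
    "ios": ["coreml_fp16", "mps"],
}
extract_all_configs(sample_configs)             # ['xnnpack_q8', 'qnn_q8', 'coreml_fp16', 'mps']
extract_all_configs(sample_configs, "android")  # ['xnnpack_q8', 'qnn_q8']  ("xplat" plus the requested OS)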

def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
"""
Generate a list of compatible benchmark configurations for a given model name and target OS.
Args:
model_name (str): The name of the model to generate configurations for.
target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
Returns:
List[str]: A list of compatible benchmark configurations.
Raises:
None
Example:
generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
"""
configs = []
if is_valid_huggingface_model_id(model_name):
if model_name.startswith("meta-llama/"):
# LLaMA models
repo_name = model_name.split("meta-llama/")[1]
if "qlora" in repo_name.lower():
configs.append("llama3_qlora")
elif "spinquant" in repo_name.lower():
configs.append("llama3_spinquant")
else:
configs.append("llama3_fb16")
configs.extend(
[
config
for config in BENCHMARK_CONFIGS.get(target_os, [])
if config.startswith("llama")
]
)
else:
# Non-LLaMA models
configs.append("hf_xnnpack_fp32")
elif model_name in MODEL_NAME_TO_MODEL:
# ExecuTorch in-tree non-GenAI models
configs.append("xnnpack_q8")
if target_os != "xplat":
# Add OS-specific configs
configs.extend(
[
config
for config in BENCHMARK_CONFIGS.get(target_os, [])
if not config.startswith("llama")
]
)
else:
# Skip unknown models with a warning
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

return configs

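For orientation, the expected outputs for a few inputs, taken from the docstring example above and the unit tests added below:

generate_compatible_configs("meta-llama/Llama-3.2-1B", "ios")      # ['llama3_fb16', 'llama3_coreml_ane']
generate_compatible_configs("meta-llama/Llama-3.2-1B", "android")  # ['llama3_fb16']
generate_compatible_configs("mv2", "android")                      # ['xnnpack_q8', 'qnn_q8']
generate_compatible_configs("unknown_model", "ios")                # [] (a warning is logged)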

def parse_args() -> Any:
"""
Parse command-line arguments.
@@ -82,6 +157,11 @@ def comma_separated(value: str):
type=comma_separated, # Use the custom parser for comma-separated values
help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
)
parser.add_argument(
"--configs",
type=comma_separated, # Use the custom parser for comma-separated values
help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
)

return parser.parse_args()

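Combined with the existing flags, an on-demand invocation that filters configs might look like the following (the values are taken from the CLI tests added below; only the flag names come from the parser above):

python .ci/scripts/gather_benchmark_configs.py \
    --models mv2,dl3 \
    --os ios \
    --devices apple_iphone_15 \
    --configs coreml_fp16,xnnpack_q8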
@@ -98,11 +178,16 @@ def set_output(name: str, val: Any) -> None:
set_output("benchmark_configs", {"include": [...]})
"""

-    if os.getenv("GITHUB_OUTPUT"):
-        print(f"Setting {val} to GitHub output")
-        with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
-            print(f"{name}={val}", file=env)
-    else:
+    github_output = os.getenv("GITHUB_OUTPUT")
+    if not github_output:
        print(f"::set-output name={name}::{val}")
+        return
+
+    try:
+        with open(github_output, "a") as env:
+            env.write(f"{name}={val}\n")
+    except PermissionError:
+        # Fall back to printing in case of permission error in unit tests
+        print(f"::set-output name={name}::{val}")

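A minimal sketch of the two code paths above (the output file path is whatever the runner provides via GITHUB_OUTPUT; the value here is an illustrative JSON string):

os.environ.pop("GITHUB_OUTPUT", None)
set_output("benchmark_configs", '{"include": []}')
# prints: ::set-output name=benchmark_configs::{"include": []}

os.environ["GITHUB_OUTPUT"] = "/tmp/github_output.txt"  # hypothetical path
set_output("benchmark_configs", '{"include": []}')
# appends a line to /tmp/github_output.txt: benchmark_configs={"include": []}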

Expand All @@ -123,7 +208,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
return bool(re.match(pattern, model_name))

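The regex itself is collapsed in this view, but its behaviour at the boundaries is pinned down by the tests below and by how the callers branch:

is_valid_huggingface_model_id("meta-llama/Llama-3.2-1B")  # True  (asserted in the tests)
is_valid_huggingface_model_id("mv2")                      # False (handled via MODEL_NAME_TO_MODEL instead)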

-def get_benchmark_configs() -> Dict[str, Dict]:
+def get_benchmark_configs() -> Dict[str, Dict]:  # noqa: C901
"""
Gather benchmark configurations for a given set of models on the target operating system and devices.
@@ -153,48 +238,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
}
"""
args = parse_args()
-    target_os = args.os
    devices = args.devices
    models = args.models
+    target_os = args.os
+    target_configs = args.configs

benchmark_configs = {"include": []}

for model_name in models:
configs = []
-        if is_valid_huggingface_model_id(model_name):
-            if model_name.startswith("meta-llama/"):
-                # LLaMA models
-                repo_name = model_name.split("meta-llama/")[1]
-                if "qlora" in repo_name.lower():
-                    configs.append("llama3_qlora")
-                elif "spinquant" in repo_name.lower():
-                    configs.append("llama3_spinquant")
-                else:
-                    configs.append("llama3_fb16")
-                    configs.extend(
-                        [
-                            config
-                            for config in BENCHMARK_CONFIGS.get(target_os, [])
-                            if config.startswith("llama")
-                        ]
+        configs.extend(generate_compatible_configs(model_name, target_os))
+        print(f"Discovered all supported configs for model '{model_name}': {configs}")
+        if target_configs is not None:
+            for config in target_configs:
+                if config not in configs:
+                    raise Exception(
+                        f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
+                        f"Supported configs are: {configs}"
                    )
-            else:
-                # Non-LLaMA models
-                configs.append("hf_xnnpack_fp32")
-        elif model_name in MODEL_NAME_TO_MODEL:
-            # ExecuTorch in-tree non-GenAI models
-            configs.append("xnnpack_q8")
-            configs.extend(
-                [
-                    config
-                    for config in BENCHMARK_CONFIGS.get(target_os, [])
-                    if not config.startswith("llama")
-                ]
-            )
-        else:
-            # Skip unknown models with a warning
-            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
-            continue
+            configs = target_configs
+            print(f"Using provided configs {configs} for model '{model_name}'")

# Add configurations for each valid device
for device in devices:
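The remainder of the loop is truncated in this view; judging from the CLI tests below, the value written to GitHub output is a matrix whose "include" entries pair each model and config with a device, roughly as sketched here (only the "model" and "config" keys are asserted in the tests; the "device" field and the exact set of keys are assumptions):

{
    "include": [
        {"model": "mv2", "config": "xnnpack_q8", "device": "apple_iphone_15"},
        {"model": "mv2", "config": "coreml_fp16", "device": "apple_iphone_15"}
    ]
}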
189 changes: 189 additions & 0 deletions .ci/scripts/tests/test_gather_benchmark_configs.py
@@ -0,0 +1,189 @@
import importlib.util
import os
import subprocess
import sys
import unittest
from unittest.mock import mock_open, patch

import pytest

# Dynamically import the script
script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
gather_benchmark_configs = importlib.util.module_from_spec(spec)
spec.loader.exec_module(gather_benchmark_configs)


@pytest.mark.skipif(
sys.platform != "linux", reason="The script under test runs on Linux runners only"
)
class TestGatherBenchmarkConfigs(unittest.TestCase):

def test_extract_all_configs_android(self):
android_configs = gather_benchmark_configs.extract_all_configs(
gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
)
self.assertIn("xnnpack_q8", android_configs)
self.assertIn("qnn_q8", android_configs)
self.assertIn("llama3_spinquant", android_configs)
self.assertIn("llama3_qlora", android_configs)

def test_extract_all_configs_ios(self):
ios_configs = gather_benchmark_configs.extract_all_configs(
gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
)

self.assertIn("xnnpack_q8", ios_configs)
self.assertIn("coreml_fp16", ios_configs)
self.assertIn("mps", ios_configs)
self.assertIn("llama3_coreml_ane", ios_configs)
self.assertIn("llama3_spinquant", ios_configs)
self.assertIn("llama3_qlora", ios_configs)

def test_generate_compatible_configs_llama_model(self):
model_name = "meta-llama/Llama-3.2-1B"
target_os = "ios"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["llama3_fb16", "llama3_coreml_ane"]
self.assertEqual(result, expected)

target_os = "android"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["llama3_fb16"]
self.assertEqual(result, expected)

def test_generate_compatible_configs_quantized_llama_model(self):
model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
expected = ["llama3_spinquant"]
self.assertEqual(result, expected)

model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
expected = ["llama3_qlora"]
self.assertEqual(result, expected)

def test_generate_compatible_configs_non_genai_model(self):
model_name = "mv2"
target_os = "xplat"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["xnnpack_q8"]
self.assertEqual(result, expected)

target_os = "android"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["xnnpack_q8", "qnn_q8"]
self.assertEqual(result, expected)

target_os = "ios"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
expected = ["xnnpack_q8", "coreml_fp16", "mps"]
self.assertEqual(result, expected)

def test_generate_compatible_configs_unknown_model(self):
model_name = "unknown_model"
target_os = "ios"
result = gather_benchmark_configs.generate_compatible_configs(
model_name, target_os
)
self.assertEqual(result, [])

def test_is_valid_huggingface_model_id_valid(self):
valid_model = "meta-llama/Llama-3.2-1B"
self.assertTrue(
gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
)

@patch("builtins.open", new_callable=mock_open)
@patch("os.getenv", return_value=None)
def test_set_output_no_github_env(self, mock_getenv, mock_file):
with patch("builtins.print") as mock_print:
gather_benchmark_configs.set_output("test_name", "test_value")
mock_print.assert_called_with("::set-output name=test_name::test_value")

def test_device_pools_contains_all_devices(self):
expected_devices = [
"apple_iphone_15",
"apple_iphone_15+ios_18",
"samsung_galaxy_s22",
"samsung_galaxy_s24",
"google_pixel_8_pro",
]
for device in expected_devices:
self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)

def test_gather_benchmark_configs_cli(self):
args = {
"models": "mv2,dl3",
"os": "ios",
"devices": "apple_iphone_15",
"configs": None,
}

cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
for key, value in args.items():
if value is not None:
cmd.append(f"--{key}")
cmd.append(value)

result = subprocess.run(cmd, capture_output=True, text=True)
self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
self.assertIn('"model": "mv2"', result.stdout)
self.assertIn('"model": "dl3"', result.stdout)
self.assertIn('"config": "coreml_fp16"', result.stdout)
self.assertIn('"config": "xnnpack_q8"', result.stdout)
self.assertIn('"config": "mps"', result.stdout)

def test_gather_benchmark_configs_cli_specified_configs(self):
args = {
"models": "mv2,dl3",
"os": "ios",
"devices": "apple_iphone_15",
"configs": "coreml_fp16,xnnpack_q8",
}

cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
for key, value in args.items():
if value is not None:
cmd.append(f"--{key}")
cmd.append(value)

result = subprocess.run(cmd, capture_output=True, text=True)
self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
self.assertIn('"model": "mv2"', result.stdout)
self.assertIn('"model": "dl3"', result.stdout)
self.assertIn('"config": "coreml_fp16"', result.stdout)
self.assertIn('"config": "xnnpack_q8"', result.stdout)
self.assertNotIn('"config": "mps"', result.stdout)

def test_gather_benchmark_configs_cli_specified_configs_raise(self):
args = {
"models": "mv2,dl3",
"os": "ios",
"devices": "apple_iphone_15",
"configs": "qnn_q8",
}

cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
for key, value in args.items():
if value is not None:
cmd.append(f"--{key}")
cmd.append(value)

result = subprocess.run(cmd, capture_output=True, text=True)
self.assertEqual(result.returncode, 1, f"Error: {result.stderr}")
self.assertIn("Unsupported config 'qnn_q8'", result.stderr)


if __name__ == "__main__":
unittest.main()
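These tests use unittest with a pytest skip marker, and script_path is relative, so a typical local run (from the repository root, so that .ci/scripts/gather_benchmark_configs.py resolves) would presumably be:

python -m pytest .ci/scripts/tests/test_gather_benchmark_configs.py -v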