Skip to content

move schema files #436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions backends/xnnpack/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ python_unittest(
"//executorch/backends/xnnpack:xnnpack_preprocess",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//executorch/bundled_program:config",
"//executorch/bundled_program:core",
"//executorch/bundled_program/aot:config",
"//executorch/bundled_program/aot:core",
"//executorch/bundled_program/serialize:lib",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_api",
Expand All @@ -49,8 +49,8 @@ python_unittest(
"//executorch/backends/xnnpack:xnnpack_preprocess",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//executorch/bundled_program:config",
"//executorch/bundled_program:core",
"//executorch/bundled_program/aot:config",
"//executorch/bundled_program/aot:core",
"//executorch/bundled_program/serialize:lib",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_api",
Expand Down Expand Up @@ -78,8 +78,8 @@ python_unittest(
"//executorch/backends/xnnpack:xnnpack_preprocess",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//executorch/bundled_program:config",
"//executorch/bundled_program:core",
"//executorch/bundled_program/aot:config",
"//executorch/bundled_program/aot:core",
"//executorch/bundled_program/serialize:lib",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_api",
Expand Down Expand Up @@ -127,9 +127,7 @@ python_unittest(
]),
deps = [
"//caffe2:torch",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/exir:lib",
"//pytorch/vision:torchvision",
],
)
24 changes: 16 additions & 8 deletions backends/xnnpack/test/test_xnnpack_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import unittest
from random import randint
from typing import Any, Tuple
from typing import Any, List, Tuple

import torch
import torch.nn.functional as F
Expand All @@ -26,8 +26,8 @@
# import the xnnpack backend implementation
from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend

from executorch.bundled_program.config import BundledConfig
from executorch.bundled_program.core import create_bundled_program
from executorch.bundled_program.aot.config import MethodTestCase, MethodTestSuite
from executorch.bundled_program.aot.core import create_bundled_program
from executorch.bundled_program.serialize import (
serialize_from_bundled_program_to_flatbuffer,
)
Expand Down Expand Up @@ -101,14 +101,22 @@ def save_bundled_program(representative_inputs, program, ref_output, output_path
niter = 1

print("generating bundled program inputs / outputs")
inputs_list = [list(representative_inputs) for _ in range(niter)]
expected_outputs_list = [
[[ref_output] for x in inputs_list],

method_test_cases: List[MethodTestCase] = []
for _ in range(niter):
method_test_cases.append(
MethodTestCase(
inputs=list(representative_inputs),
expected_outputs=[ref_output],
)
)

method_test_suites = [
MethodTestSuite(method_name="forward", method_test_cases=method_test_cases)
]
bundled_config = BundledConfig([inputs_list], expected_outputs_list)

print("creating bundled program...")
bundled_program = create_bundled_program(program, bundled_config)
bundled_program = create_bundled_program(program, method_test_suites)

print("serializing bundled program...")
bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(
Expand Down
File renamed without changes.
89 changes: 89 additions & 0 deletions bundled_program/aot/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass
from typing import get_args, List, Optional, Sequence, Union

import torch
from executorch.extension.pytree import tree_flatten

from typing_extensions import TypeAlias

"""
The data types currently supported for element to be bundled. It should be
consistent with the types in bundled_program.schema.Value.
"""
ConfigValue: TypeAlias = Union[
torch.Tensor,
int,
bool,
float,
]

"""
All supported types for input/expected output of MethodTestCase.

Namedtuple is also supported and listed implicity since it is a subclass of tuple.
"""

# pyre-ignore
DataContainer: TypeAlias = Union[list, tuple, dict]


class MethodTestCase:
"""Test case with inputs and expected outputs
The expected_outputs could be None if user only want to user the test case for profiling."""

def __init__(
self, inputs: DataContainer, expected_outputs: Optional[DataContainer] = None
) -> None:
self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
self.expected_outputs: List[ConfigValue] = []
if expected_outputs:
self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)

def _flatten_and_sanity_check(
self, unflatten_data: DataContainer
) -> List[ConfigValue]:
"""Flat the given data and check its legality

Args:
unflatten_data: Data needs to be flatten.

Returns:
flatten_data: Flatten data with legal type.
"""

assert isinstance(
unflatten_data, get_args(DataContainer)
), f"The input or expected output of MethodTestCase should be in list, tuple or dict, but got {type(unflatten_data)}."

# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
flatten_data, _ = tree_flatten(unflatten_data)

for data in flatten_data:
assert isinstance(
data,
get_args(ConfigValue),
), "The type of input {} with type {} is not supported.\n".format(
data, type(data)
)
assert not isinstance(
data,
type(None),
), "The input {} should not be in null type.\n".format(data)

return flatten_data


@dataclass
class MethodTestSuite:
"""All info related to verify method"""

method_name: str
test_cases: Sequence[MethodTestCase]
Loading