Skip to content

Commit 2a74aec

Browse files
Jack-Khuufacebook-github-bot
authored andcommitted
Cleaning et_schema of InferenceRun for OSS (#565)
Summary: These are concepts that aren't useful in the OSS world. Associated tests were also moved - Moved out InferenceRun, OperatorGraphWithStats, ExportedETOperatorGraph for internal - Changed FXOperatorGraph to extending OperatorGraph instead of OperatorGraphWithStats - Tests for EDIR (to be renamed) moved to internal since they are heavily coupled with InferenceRun Reviewed By: tarun292 Differential Revision: D49839120
1 parent 80ee5f6 commit 2a74aec

25 files changed

+86
-691
lines changed

sdk/edir/TARGETS

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,8 @@ python_library(
1717
"et_schema.py",
1818
],
1919
deps = [
20-
"fbsource//third-party/pypi/numpy:numpy",
2120
":base_schema",
2221
"//caffe2:torch",
2322
"//executorch/exir:lib",
24-
"//executorch/exir:schema",
25-
"//executorch/exir/_serialize:lib",
26-
"//executorch/sdk/etdump:schema",
27-
"//executorch/sdk/etdump:serialize",
2823
],
2924
)

sdk/edir/et_schema.py

Lines changed: 2 additions & 589 deletions
Large diffs are not rendered by default.

sdk/edir/tests/TARGETS renamed to sdk/fb/tests/TARGETS

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ python_unittest(
1515
"//caffe2:torch",
1616
"//executorch/exir:lib",
1717
"//executorch/exir:lowered_backend_module",
18-
"//executorch/sdk/edir:base_schema",
1918
"//executorch/sdk/edir:et_schema",
2019
"//executorch/sdk/etdump:schema",
20+
"//executorch/sdk/fb:et_schema",
2121
],
2222
)
2323

@@ -31,9 +31,9 @@ python_library(
3131
"//caffe2:torch",
3232
"//executorch/exir:lib",
3333
"//executorch/exir:lowered_backend_module",
34-
"//executorch/sdk/edir:base_schema",
3534
"//executorch/sdk/edir:et_schema",
3635
"//executorch/sdk/etdump:schema",
36+
"//executorch/sdk/fb:et_schema",
3737
],
3838
)
3939

@@ -42,14 +42,14 @@ python_binary(
4242
srcs = [
4343
"exported_op_graph_test.py",
4444
],
45-
main_module = "executorch.sdk.edir.tests.exported_op_graph_test",
45+
main_module = "executorch.sdk.fb.tests.exported_op_graph_test",
4646
deps = [
4747
"fbsource//third-party/pypi/parameterized:parameterized",
4848
"//caffe2:torch",
4949
"//executorch/exir:lib",
5050
"//executorch/exir:lowered_backend_module",
51-
"//executorch/sdk/edir:base_schema",
5251
"//executorch/sdk/edir:et_schema",
5352
"//executorch/sdk/etdump:schema",
53+
"//executorch/sdk/fb:et_schema",
5454
],
5555
)

sdk/edir/tests/exported_op_graph_test.py renamed to sdk/fb/tests/exported_op_graph_test.py

Lines changed: 62 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# All rights reserved.
3-
#
4-
# This source code is licensed under the BSD-style license found in the
5-
# LICENSE file in the root directory of this source tree.
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
62

73
import copy
84
import unittest
@@ -18,13 +14,9 @@
1814
ExirExportedProgram,
1915
)
2016
from executorch.exir.lowered_backend_module import LoweredBackendModule
21-
from executorch.sdk.edir.base_schema import OperatorNode
22-
from executorch.sdk.edir.et_schema import (
23-
ExportedETOperatorGraph,
24-
FXOperatorGraph,
25-
InferenceRun,
26-
)
17+
from executorch.sdk.edir.et_schema import FXOperatorGraph
2718
from executorch.sdk.etdump.schema import ETDump, ProfileBlock, ProfileEvent, RunData
19+
from executorch.sdk.fb.et_schema import ExportedETOperatorGraph, InferenceRun
2820

2921
from parameterized import parameterized
3022
from torch import Tensor
@@ -697,37 +689,36 @@ def test_gen_from_fx_graph(self, model_name: str, model: torch.nn.Module) -> Non
697689
op_graph = gen_fx_graph_file_contents(et_program.dump_graph_module())
698690
self.check_graph_equal(op_graph, model_name, "et_dialect")
699691

700-
# pyre-ignore
701-
@parameterized.expand(MODELS)
702-
def test_metadata_attaching(self, model_name: str, model: torch.nn.Module) -> None:
703-
_, _, et_program = gen_graphs_from_model(model)
704-
op_graph = FXOperatorGraph.gen_operator_graph(et_program.dump_graph_module())
705-
inference_run = model.gen_inference_run()
706-
op_graph.attach_metadata(inference_run)
707-
708-
def verify_metadata_containment(
709-
graph: FXOperatorGraph, inference_run: InferenceRun
710-
) -> None:
711-
validation_map = inference_run.node_metadata
712-
713-
for node in graph.elements:
714-
# Recursively check subgraph nodes
715-
if isinstance(node, FXOperatorGraph):
716-
verify_metadata_containment(node, inference_run)
717-
# Check that each node contains the corresponding metadata fields
718-
if isinstance(node, OperatorNode) and node.metadata is not None:
719-
metadata = node.metadata
720-
debug_handle = metadata.get("debug_handle")
721-
if debug_handle in validation_map:
722-
self.assertDictContainsSubset(
723-
validation_map[debug_handle], metadata
724-
)
725-
726-
# Check for run level metadata
727-
if op_graph.metadata is not None:
728-
self.assertDictContainsSubset(inference_run.run_metadata, op_graph.metadata)
729-
730-
verify_metadata_containment(op_graph, inference_run)
692+
# @parameterized.expand(MODELS)
693+
# def test_metadata_attaching(self, model_name: str, model: torch.nn.Module) -> None:
694+
# _, _, et_program = gen_graphs_from_model(model)
695+
# op_graph = FXOperatorGraph.gen_operator_graph(et_program.dump_graph_module())
696+
# inference_run = model.gen_inference_run()
697+
# op_graph.attach_metadata(inference_run)
698+
699+
# def verify_metadata_containment(
700+
# graph: FXOperatorGraph, inference_run: InferenceRun
701+
# ) -> None:
702+
# validation_map = inference_run.node_metadata
703+
704+
# for node in graph.elements:
705+
# # Recursively check subgraph nodes
706+
# if isinstance(node, FXOperatorGraph):
707+
# verify_metadata_containment(node, inference_run)
708+
# # Check that each node contains the corresponding metadata fields
709+
# if isinstance(node, OperatorNode) and node.metadata is not None:
710+
# metadata = node.metadata
711+
# debug_handle = metadata.get("debug_handle")
712+
# if debug_handle in validation_map:
713+
# self.assertDictContainsSubset(
714+
# validation_map[debug_handle], metadata
715+
# )
716+
717+
# # Check for run level metadata
718+
# if op_graph.metadata is not None:
719+
# self.assertDictContainsSubset(inference_run.run_metadata, op_graph.metadata)
720+
721+
# verify_metadata_containment(op_graph, inference_run)
731722

732723

733724
class ExportedOpGraphTest(unittest.TestCase):
@@ -747,36 +738,35 @@ def test_gen_from_emitted_program(
747738
"Please run `//executorch/sdk/edir/tests:generate_fixtures` to regenerate the fixtures.",
748739
)
749740

750-
# pyre-ignore
751-
@parameterized.expand(MODELS)
752-
def test_metadata_attaching(self, model_name: str, model: torch.nn.Module) -> None:
753-
op_graph = generate_op_graph(model, model.get_random_inputs())
754-
inference_run = model.gen_inference_run()
755-
op_graph.attach_metadata(inference_run)
756-
757-
def verify_metadata_containment(
758-
graph: ExportedETOperatorGraph, inference_run: InferenceRun
759-
) -> None:
760-
validation_map = inference_run.node_metadata
761-
762-
for node in graph.elements:
763-
# Recursively check subgraph nodes
764-
if isinstance(node, ExportedETOperatorGraph):
765-
verify_metadata_containment(node, inference_run)
766-
# Check that each node contains the corresponding metadata fields
767-
if isinstance(node, OperatorNode) and node.metadata is not None:
768-
metadata = node.metadata
769-
debug_handle = metadata.get("debug_handle")
770-
if debug_handle in validation_map:
771-
self.assertDictContainsSubset(
772-
validation_map[debug_handle], metadata
773-
)
774-
775-
# Check for run level metadata
776-
if op_graph.metadata is not None:
777-
self.assertDictContainsSubset(inference_run.run_metadata, op_graph.metadata)
778-
779-
verify_metadata_containment(op_graph, inference_run)
741+
# @parameterized.expand(MODELS)
742+
# def test_metadata_attaching(self, model_name: str, model: torch.nn.Module) -> None:
743+
# op_graph = generate_op_graph(model, model.get_random_inputs())
744+
# inference_run = model.gen_inference_run()
745+
# op_graph.attach_metadata(inference_run)
746+
747+
# def verify_metadata_containment(
748+
# graph: ExportedETOperatorGraph, inference_run: InferenceRun
749+
# ) -> None:
750+
# validation_map = inference_run.node_metadata
751+
752+
# for node in graph.elements:
753+
# # Recursively check subgraph nodes
754+
# if isinstance(node, ExportedETOperatorGraph):
755+
# verify_metadata_containment(node, inference_run)
756+
# # Check that each node contains the corresponding metadata fields
757+
# if isinstance(node, OperatorNode) and node.metadata is not None:
758+
# metadata = node.metadata
759+
# debug_handle = metadata.get("debug_handle")
760+
# if debug_handle in validation_map:
761+
# self.assertDictContainsSubset(
762+
# validation_map[debug_handle], metadata
763+
# )
764+
765+
# # Check for run level metadata
766+
# if op_graph.metadata is not None:
767+
# self.assertDictContainsSubset(inference_run.run_metadata, op_graph.metadata)
768+
769+
# verify_metadata_containment(op_graph, inference_run)
780770

781771

782772
class InferenceRunTest(unittest.TestCase):

sdk/inspector/TARGETS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ python_library(
1515
":inspector_utils",
1616
"//caffe2:torch",
1717
"//executorch/exir:lib",
18-
"//executorch/sdk/edir:et_schema",
18+
"//executorch/sdk/edir:base_schema",
1919
"//executorch/sdk/etdump:schema_flatcc",
2020
],
2121
)
@@ -26,6 +26,7 @@ python_library(
2626
"_inspector_utils.py",
2727
],
2828
deps = [
29+
"//executorch/sdk/edir:base_schema",
2930
"//executorch/sdk/edir:et_schema",
3031
"//executorch/sdk/etdump:schema_flatcc",
3132
"//executorch/sdk/etdump:serialize",

sdk/inspector/_inspector_utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66

77
from typing import Dict, Mapping, Optional
88

9-
from executorch.sdk.edir.et_schema import (
10-
FXOperatorGraph,
11-
OperatorGraphWithStats,
12-
OperatorNode,
13-
)
9+
from executorch.sdk.edir.base_schema import OperatorNode
10+
11+
from executorch.sdk.edir.et_schema import FXOperatorGraph, OperatorGraph
1412
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC
1513

1614
from executorch.sdk.etdump.serialize import deserialize_from_etdump_flatcc
@@ -21,7 +19,7 @@
2119

2220
def gen_graphs_from_etrecord(
2321
etrecord: ETRecord,
24-
) -> Mapping[str, OperatorGraphWithStats]:
22+
) -> Mapping[str, OperatorGraph]:
2523
op_graph_map = {}
2624
if etrecord.graph_map is not None:
2725
op_graph_map = {
@@ -39,7 +37,7 @@ def gen_graphs_from_etrecord(
3937
# TODO: use anonymous function to avoid passing the dict around
4038
# and move this inside of the OperatorGraphWithStats class
4139
def create_debug_handle_to_op_node_mapping(
42-
op_graph: OperatorGraphWithStats,
40+
op_graph: OperatorGraph,
4341
debug_handle_to_op_node_map: Dict[int, OperatorNode],
4442
) -> None:
4543
"""
@@ -48,7 +46,7 @@ def create_debug_handle_to_op_node_mapping(
4846
"""
4947
# Recursively searches through the metadata of nodes
5048
for element in op_graph.elements:
51-
if isinstance(element, OperatorGraphWithStats):
49+
if isinstance(element, OperatorGraph):
5250
create_debug_handle_to_op_node_mapping(element, debug_handle_to_op_node_map)
5351
if isinstance(element, OperatorNode) and element.metadata is not None:
5452
metadata = element.metadata

sdk/inspector/inspector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import torch
2626
from executorch.exir import ExportedProgram
2727

28-
from executorch.sdk.edir.et_schema import OperatorGraphWithStats, OperatorNode
28+
from executorch.sdk.edir.base_schema import OperatorGraph, OperatorNode
2929
from executorch.sdk.etdump.schema_flatcc import ETDumpFlatCC, ProfileEvent
3030
from executorch.sdk.inspector._inspector_utils import (
3131
create_debug_handle_to_op_node_mapping,
@@ -392,9 +392,9 @@ def __init__(
392392
etdump = gen_etdump_object(etdump_path=etdump_path)
393393
self.event_blocks = EventBlock._gen_from_etdump(etdump, etdump_scale)
394394

395-
self._op_graph_dict: Mapping[
396-
str, OperatorGraphWithStats
397-
] = gen_graphs_from_etrecord(etrecord=self._etrecord)
395+
self._op_graph_dict: Mapping[str, OperatorGraph] = gen_graphs_from_etrecord(
396+
etrecord=self._etrecord
397+
)
398398

399399
# Use the delegate map from etrecord, associate debug handles with each event
400400
for event_block in self.event_blocks:

sdk/inspector/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ python_unittest(
2727
name = "inspector_utils_test",
2828
srcs = ["inspector_utils_test.py"],
2929
deps = [
30+
"//executorch/sdk/edir:base_schema",
3031
"//executorch/sdk/edir:et_schema",
3132
"//executorch/sdk/etrecord:etrecord",
3233
"//executorch/sdk/etrecord/tests:etrecord_test_library",

sdk/inspector/tests/inspector_utils_test.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,9 @@
88
import unittest
99
from typing import Dict, Tuple
1010

11-
from executorch.sdk.edir.et_schema import (
12-
FXOperatorGraph,
13-
OperatorGraphWithStats,
14-
OperatorNode,
15-
ValueNode,
16-
)
11+
from executorch.sdk.edir.base_schema import OperatorGraph, OperatorNode, ValueNode
12+
13+
from executorch.sdk.edir.et_schema import FXOperatorGraph
1714
from executorch.sdk.etrecord import generate_etrecord, parse_etrecord
1815

1916
from executorch.sdk.etrecord.tests.etrecord_test import TestETRecord
@@ -62,7 +59,7 @@ def test_create_debug_handle_to_op_node_mapping(self):
6259

6360

6461
def gen_mock_operator_graph_with_expected_map() -> Tuple[
65-
OperatorGraphWithStats, Dict[int, OperatorNode]
62+
OperatorGraph, Dict[int, OperatorNode]
6663
]:
6764
# Make a mock OperatorGraph instance for testing
6865
node_input = ValueNode("input")
@@ -113,7 +110,7 @@ def gen_mock_operator_graph_with_expected_map() -> Tuple[
113110
mapping[444] = node_div
114111
node_output = ValueNode("output", [node_div])
115112
return (
116-
OperatorGraphWithStats(
113+
OperatorGraph(
117114
"mock_et_model",
118115
[
119116
node_input,

0 commit comments

Comments
 (0)