1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
1
+ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
6
2
7
3
import copy
8
4
import unittest
18
14
ExirExportedProgram ,
19
15
)
20
16
from executorch .exir .lowered_backend_module import LoweredBackendModule
21
- from executorch .sdk .edir .base_schema import OperatorNode
22
- from executorch .sdk .edir .et_schema import (
23
- ExportedETOperatorGraph ,
24
- FXOperatorGraph ,
25
- InferenceRun ,
26
- )
17
+ from executorch .sdk .edir .et_schema import FXOperatorGraph
27
18
from executorch .sdk .etdump .schema import ETDump , ProfileBlock , ProfileEvent , RunData
19
+ from executorch .sdk .fb .et_schema import ExportedETOperatorGraph , InferenceRun
28
20
29
21
from parameterized import parameterized
30
22
from torch import Tensor
@@ -697,37 +689,36 @@ def test_gen_from_fx_graph(self, model_name: str, model: torch.nn.Module) -> Non
697
689
op_graph = gen_fx_graph_file_contents (et_program .dump_graph_module ())
698
690
self .check_graph_equal (op_graph , model_name , "et_dialect" )
699
691
700
- # pyre-ignore
701
- @parameterized .expand (MODELS )
702
- def test_metadata_attaching (self , model_name : str , model : torch .nn .Module ) -> None :
703
- _ , _ , et_program = gen_graphs_from_model (model )
704
- op_graph = FXOperatorGraph .gen_operator_graph (et_program .dump_graph_module ())
705
- inference_run = model .gen_inference_run ()
706
- op_graph .attach_metadata (inference_run )
707
-
708
- def verify_metadata_containment (
709
- graph : FXOperatorGraph , inference_run : InferenceRun
710
- ) -> None :
711
- validation_map = inference_run .node_metadata
712
-
713
- for node in graph .elements :
714
- # Recursively check subgraph nodes
715
- if isinstance (node , FXOperatorGraph ):
716
- verify_metadata_containment (node , inference_run )
717
- # Check that each node contains the corresponding metadata fields
718
- if isinstance (node , OperatorNode ) and node .metadata is not None :
719
- metadata = node .metadata
720
- debug_handle = metadata .get ("debug_handle" )
721
- if debug_handle in validation_map :
722
- self .assertDictContainsSubset (
723
- validation_map [debug_handle ], metadata
724
- )
725
-
726
- # Check for run level metadata
727
- if op_graph .metadata is not None :
728
- self .assertDictContainsSubset (inference_run .run_metadata , op_graph .metadata )
729
-
730
- verify_metadata_containment (op_graph , inference_run )
692
+ # @parameterized.expand(MODELS)
693
+ # def test_metadata_attaching(self, model_name: str, model: torch.nn.Module) -> None:
694
+ # _, _, et_program = gen_graphs_from_model(model)
695
+ # op_graph = FXOperatorGraph.gen_operator_graph(et_program.dump_graph_module())
696
+ # inference_run = model.gen_inference_run()
697
+ # op_graph.attach_metadata(inference_run)
698
+
699
+ # def verify_metadata_containment(
700
+ # graph: FXOperatorGraph, inference_run: InferenceRun
701
+ # ) -> None:
702
+ # validation_map = inference_run.node_metadata
703
+
704
+ # for node in graph.elements:
705
+ # # Recursively check subgraph nodes
706
+ # if isinstance(node, FXOperatorGraph):
707
+ # verify_metadata_containment(node, inference_run)
708
+ # # Check that each node contains the corresponding metadata fields
709
+ # if isinstance(node, OperatorNode) and node.metadata is not None:
710
+ # metadata = node.metadata
711
+ # debug_handle = metadata.get("debug_handle")
712
+ # if debug_handle in validation_map:
713
+ # self.assertDictContainsSubset(
714
+ # validation_map[debug_handle], metadata
715
+ # )
716
+
717
+ # # Check for run level metadata
718
+ # if op_graph.metadata is not None:
719
+ # self.assertDictContainsSubset(inference_run.run_metadata, op_graph.metadata)
720
+
721
+ # verify_metadata_containment(op_graph, inference_run)
731
722
732
723
733
724
class ExportedOpGraphTest (unittest .TestCase ):
@@ -747,36 +738,35 @@ def test_gen_from_emitted_program(
747
738
"Please run `//executorch/sdk/edir/tests:generate_fixtures` to regenerate the fixtures." ,
748
739
)
749
740
750
- # pyre-ignore
751
- @parameterized .expand (MODELS )
752
- def test_metadata_attaching (self , model_name : str , model : torch .nn .Module ) -> None :
753
- op_graph = generate_op_graph (model , model .get_random_inputs ())
754
- inference_run = model .gen_inference_run ()
755
- op_graph .attach_metadata (inference_run )
756
-
757
- def verify_metadata_containment (
758
- graph : ExportedETOperatorGraph , inference_run : InferenceRun
759
- ) -> None :
760
- validation_map = inference_run .node_metadata
761
-
762
- for node in graph .elements :
763
- # Recursively check subgraph nodes
764
- if isinstance (node , ExportedETOperatorGraph ):
765
- verify_metadata_containment (node , inference_run )
766
- # Check that each node contains the corresponding metadata fields
767
- if isinstance (node , OperatorNode ) and node .metadata is not None :
768
- metadata = node .metadata
769
- debug_handle = metadata .get ("debug_handle" )
770
- if debug_handle in validation_map :
771
- self .assertDictContainsSubset (
772
- validation_map [debug_handle ], metadata
773
- )
774
-
775
- # Check for run level metadata
776
- if op_graph .metadata is not None :
777
- self .assertDictContainsSubset (inference_run .run_metadata , op_graph .metadata )
778
-
779
- verify_metadata_containment (op_graph , inference_run )
741
+ # @parameterized.expand(MODELS)
742
+ # def test_metadata_attaching(self, model_name: str, model: torch.nn.Module) -> None:
743
+ # op_graph = generate_op_graph(model, model.get_random_inputs())
744
+ # inference_run = model.gen_inference_run()
745
+ # op_graph.attach_metadata(inference_run)
746
+
747
+ # def verify_metadata_containment(
748
+ # graph: ExportedETOperatorGraph, inference_run: InferenceRun
749
+ # ) -> None:
750
+ # validation_map = inference_run.node_metadata
751
+
752
+ # for node in graph.elements:
753
+ # # Recursively check subgraph nodes
754
+ # if isinstance(node, ExportedETOperatorGraph):
755
+ # verify_metadata_containment(node, inference_run)
756
+ # # Check that each node contains the corresponding metadata fields
757
+ # if isinstance(node, OperatorNode) and node.metadata is not None:
758
+ # metadata = node.metadata
759
+ # debug_handle = metadata.get("debug_handle")
760
+ # if debug_handle in validation_map:
761
+ # self.assertDictContainsSubset(
762
+ # validation_map[debug_handle], metadata
763
+ # )
764
+
765
+ # # Check for run level metadata
766
+ # if op_graph.metadata is not None:
767
+ # self.assertDictContainsSubset(inference_run.run_metadata, op_graph.metadata)
768
+
769
+ # verify_metadata_containment(op_graph, inference_run)
780
770
781
771
782
772
class InferenceRunTest (unittest .TestCase ):
0 commit comments