@@ -725,17 +725,17 @@ def count_nodes(graph_module, target):
725
725
)
726
726
727
727
def test_edge_dialect_non_core_aten_ops (self ):
728
- class LinalgNorm (torch .nn .Module ):
728
+ class LinalgRank (torch .nn .Module ):
729
729
def __init__ (self ):
730
730
super ().__init__ ()
731
731
732
732
def forward (self , x : torch .Tensor ) -> torch .Tensor :
733
- return torch .linalg .norm (x )
733
+ return torch .linalg .matrix_rank (x )
734
734
735
735
from torch ._export .verifier import SpecViolationError
736
736
737
- input = torch .arange ( 9 , dtype = torch .float ) - 4
738
- ep = torch .export .export (LinalgNorm (), (input ,), strict = True )
737
+ input = torch .ones (( 9 , 9 , 9 ), dtype = torch .float )
738
+ ep = torch .export .export (LinalgRank (), (input ,), strict = True )
739
739
740
740
# aten::linalg_norm is not a core op, so it should error out
741
741
with self .assertRaises (SpecViolationError ):
@@ -748,9 +748,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
748
748
ep ,
749
749
compile_config = EdgeCompileConfig (
750
750
_check_ir_validity = True ,
751
- _core_aten_ops_exception_list = [
752
- torch .ops .aten .linalg_vector_norm .default
753
- ],
751
+ _core_aten_ops_exception_list = [torch .ops .aten ._linalg_svd .default ],
754
752
),
755
753
)
756
754
except SpecViolationError :
0 commit comments