|
20 | 20 | from executorch.exir.backend.backend_api import to_backend
|
21 | 21 | from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
|
22 | 22 | from executorch.exir.emit import emit_program # noqa
|
| 23 | +from executorch.exir.error import InternalError |
23 | 24 | from executorch.exir.passes.constant_prop_pass import constant_prop_pass
|
24 | 25 | from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
|
25 | 26 | from executorch.exir.print_program import pretty_print, print_program # noqa
|
@@ -836,6 +837,47 @@ def _compare_execution_plans(
|
836 | 837 | else:
|
837 | 838 | self.assertEqual(single_val, merged_val)
|
838 | 839 |
|
| 840 | + def test_emit_memory_format(self) -> None: |
| 841 | + class SimpleLinear(torch.nn.Module): |
| 842 | + def __init__(self) -> None: |
| 843 | + super().__init__() |
| 844 | + |
| 845 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 846 | + contiguous = x.to( |
| 847 | + dtype=torch.float32, memory_format=torch.contiguous_format |
| 848 | + ) |
| 849 | + preserve = x.to( |
| 850 | + dtype=torch.float32, memory_format=torch.preserve_format |
| 851 | + ) |
| 852 | + return contiguous + preserve |
| 853 | + |
| 854 | + # Should succeed at exporting model with legal memory format (contiguous, preserve) |
| 855 | + model = SimpleLinear() |
| 856 | + inputs = (torch.ones(10, 5),) |
| 857 | + try: |
| 858 | + to_edge( |
| 859 | + export(model, inputs), |
| 860 | + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), |
| 861 | + ).to_executorch() |
| 862 | + except: |
| 863 | + self.fail("Failed to export model with legal memory format") |
| 864 | + |
| 865 | + class SimpleLinear2(torch.nn.Module): |
| 866 | + def __init__(self) -> None: |
| 867 | + super().__init__() |
| 868 | + |
| 869 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 870 | + return x.to(dtype=torch.float32, memory_format=torch.channels_last) |
| 871 | + |
| 872 | + # Failure Expected when exporting model with illegal memory format (channels_last) |
| 873 | + model = SimpleLinear2() |
| 874 | + inputs = (torch.ones(10, 5, 2, 1),) |
| 875 | + with self.assertRaises(InternalError): |
| 876 | + to_edge( |
| 877 | + export(model, inputs), |
| 878 | + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), |
| 879 | + ).to_executorch() |
| 880 | + |
839 | 881 | def test_emit_multiple_entry_points(self) -> None:
|
840 | 882 | class SimpleLinear(torch.nn.Module):
|
841 | 883 | def __init__(self) -> None:
|
|
0 commit comments