Skip to content

Commit 150c8f7

Browse files
Jack-Khuufacebook-github-bot
authored andcommitted
Adding Support for preserve_format when converting Evalues (#2242)
Summary: This diff adds support for `torch.preserve_format` when casting constants to Evalues. --- Differential Revision: D54288165
1 parent e7197a1 commit 150c8f7

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

exir/emit/_emitter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,11 +462,11 @@ def _constant_to_evalue( # noqa: C901
462462
return EValue(Int(layout_enum(val)))
463463

464464
if isinstance(val, torch.memory_format):
465-
if val != torch.contiguous_format:
465+
if val != torch.contiguous_format and val != torch.preserve_format:
466466
raise InternalError(
467467
self._emit_node_specific_error(
468468
self.node,
469-
"Non contiguous tensors are not supported in ExecuTorch",
469+
"Only Tensors that have a contiguous or preserve_format memory_format are supported in ExecuTorch",
470470
)
471471
)
472472
return EValue(Int(memory_format_enum(val)))

exir/emit/test/test_emit.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from executorch.exir.backend.backend_api import to_backend
2121
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2222
from executorch.exir.emit import emit_program # noqa
23+
from executorch.exir.error import InternalError
2324
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
2425
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
2526
from executorch.exir.print_program import pretty_print, print_program # noqa
@@ -836,6 +837,47 @@ def _compare_execution_plans(
836837
else:
837838
self.assertEqual(single_val, merged_val)
838839

840+
def test_emit_memory_format(self) -> None:
841+
class SimpleLinear(torch.nn.Module):
842+
def __init__(self) -> None:
843+
super().__init__()
844+
845+
def forward(self, x: torch.Tensor) -> torch.Tensor:
846+
contiguous = x.to(
847+
dtype=torch.float32, memory_format=torch.contiguous_format
848+
)
849+
preserve = x.to(
850+
dtype=torch.float32, memory_format=torch.preserve_format
851+
)
852+
return contiguous + preserve
853+
854+
# Should succeed at exporting model with legal memory format (contiguous, preserve)
855+
model = SimpleLinear()
856+
inputs = (torch.ones(10, 5),)
857+
try:
858+
to_edge(
859+
export(model, inputs),
860+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
861+
).to_executorch()
862+
except:
863+
self.fail("Failed to export model with legal memory format")
864+
865+
class SimpleLinear2(torch.nn.Module):
866+
def __init__(self) -> None:
867+
super().__init__()
868+
869+
def forward(self, x: torch.Tensor) -> torch.Tensor:
870+
return x.to(dtype=torch.float32, memory_format=torch.channels_last)
871+
872+
# Failure Expected when exporting model with illegal memory format (channels_last)
873+
model = SimpleLinear2()
874+
inputs = (torch.ones(10, 5, 2, 1),)
875+
with self.assertRaises(InternalError):
876+
to_edge(
877+
export(model, inputs),
878+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
879+
).to_executorch()
880+
839881
def test_emit_multiple_entry_points(self) -> None:
840882
class SimpleLinear(torch.nn.Module):
841883
def __init__(self) -> None:

exir/tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
231231
)
232232
table = {
233233
torch.contiguous_format: 0,
234+
torch.preserve_format: 1,
234235
}
235236
return table[memory_format]
236237

0 commit comments

Comments
 (0)