Skip to content

Commit 23c8172

Browse files
Jack-Khuufacebook-github-bot
authored andcommitted
Adding Support for preserve_format when converting Evalues (#2242)
Summary: Pull Request resolved: #2242 This diff adds support for `torch.preserve_format` when casting constants to Evalues. --- Reviewed By: dbort Differential Revision: D54288165 fbshipit-source-id: 3fc2772ecadb9d65e16d3b23ee7ae02fb7a9259c
1 parent f6d0d49 commit 23c8172

File tree

4 files changed

+60
-11
lines changed

4 files changed

+60
-11
lines changed

exir/emit/_emitter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -463,14 +463,15 @@ def _constant_to_evalue( # noqa: C901
463463
return EValue(Int(layout_enum(val)))
464464

465465
if isinstance(val, torch.memory_format):
466-
if val != torch.contiguous_format:
466+
try:
467+
return EValue(Int(memory_format_enum(val)))
468+
except KeyError:
467469
raise InternalError(
468470
self._emit_node_specific_error(
469471
self.node,
470-
"Non contiguous tensors are not supported in ExecuTorch",
472+
f"Tensor has a memory_format that is unsupported in ExecuTorch: {val}",
471473
)
472474
)
473-
return EValue(Int(memory_format_enum(val)))
474475

475476
if isinstance(val, torch.Tensor):
476477
raise ExportError(

exir/emit/test/test_emit.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2222
from executorch.exir.dialects._ops import ops as exir_ops
2323
from executorch.exir.emit import emit_program # noqa
24+
from executorch.exir.error import InternalError
2425
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
2526
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
2627
from executorch.exir.print_program import pretty_print, print_program # noqa
@@ -838,6 +839,48 @@ def _compare_execution_plans(
838839
else:
839840
self.assertEqual(single_val, merged_val)
840841

842+
def test_emit_memory_format_valid(self) -> None:
843+
class SimpleLinear(torch.nn.Module):
844+
def __init__(self) -> None:
845+
super().__init__()
846+
847+
def forward(self, x: torch.Tensor) -> torch.Tensor:
848+
contiguous = x.to(
849+
dtype=torch.float32, memory_format=torch.contiguous_format
850+
)
851+
preserve = x.to(
852+
dtype=torch.float32, memory_format=torch.preserve_format
853+
)
854+
return contiguous + preserve
855+
856+
# Should succeed at exporting model with legal memory format (contiguous, preserve)
857+
model = SimpleLinear()
858+
inputs = (torch.ones(10, 5),)
859+
try:
860+
to_edge(
861+
export(model, inputs),
862+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
863+
).to_executorch()
864+
except:
865+
self.fail("Failed to export model with legal memory format")
866+
867+
def test_emit_memory_format_invalid(self) -> None:
868+
class SimpleLinear(torch.nn.Module):
869+
def __init__(self) -> None:
870+
super().__init__()
871+
872+
def forward(self, x: torch.Tensor) -> torch.Tensor:
873+
return x.to(dtype=torch.float32, memory_format=torch.channels_last)
874+
875+
# Failure expected when exporting model with illegal memory format (channels_last)
876+
model = SimpleLinear()
877+
inputs = (torch.ones(10, 5, 2, 1),)
878+
with self.assertRaises(InternalError):
879+
to_edge(
880+
export(model, inputs),
881+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
882+
).to_executorch()
883+
841884
def test_emit_multiple_entry_points(self) -> None:
842885
class SimpleLinear(torch.nn.Module):
843886
def __init__(self) -> None:

exir/tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
231231
)
232232
table = {
233233
torch.contiguous_format: 0,
234+
torch.preserve_format: 1,
234235
}
235236
return table[memory_format]
236237

runtime/core/portable_type/tensor_options.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,21 @@ namespace torch {
1414
namespace executor {
1515

1616
/**
17-
* Tensor data memory format. This concept only exists for compatibility
18-
* with ATen.
17+
* Tensor data memory formats supported by ExecuTorch. This concept only exists
18+
* for compatibility with ATen; use dim_order to describe non-contiguous
19+
* layouts.
1920
*/
2021
enum class MemoryFormat : int8_t {
2122
/**
22-
* Row-major contiguous data format.
23-
*
24-
* This is the only format supported by ExecuTorch. Use dim orders to
25-
* describe other layouts.
23+
* Row-major contiguous data.
24+
*/
25+
Contiguous = 0,
26+
/**
27+
* Output tensor format should remain the same as the input tensor format.
28+
* E.g. if the input tensor is in channels_last format, operator output
29+
* should be in channels_last format.
2630
*/
27-
Contiguous,
31+
Preserve = 1,
2832
};
2933

3034
/**
@@ -39,7 +43,7 @@ enum class Layout : int8_t {
3943
*
4044
* This is the only layout supported by ExecuTorch.
4145
*/
42-
Strided,
46+
Strided = 0,
4347
};
4448
} // namespace executor
4549
} // namespace torch

0 commit comments

Comments
 (0)