Skip to content

Commit 1ac53ee

Browse files
digantdesai authored and facebook-github-bot committed
Introduce dim_order edge dialect op for .to()
Summary: These ops take dim_order instead of memory_format. These are new edge dialect only ops for now. In the future we are planning to introduce these into aten (and the aten core opset). Reviewed By: larryliu0820 Differential Revision: D48195056 fbshipit-source-id: 2ba97024647553ae99517bfbc42f1aba78395c13
1 parent 7b330cf commit 1ac53ee

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

exir/passes/TARGETS

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,12 @@ python_library(
255255
"//executorch/exir/dialects:lib",
256256
],
257257
)
258+
259+
# Registry of edge-dialect dim_order ops (dim_order replaces memory_format).
python_library(
    name = "dim_order_ops_registry",
    srcs = ["dim_order_ops_registry.py"],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:dim_order_utils",
    ],
)

exir/passes/dim_order_ops_registry.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dim_order_utils import get_memory_format
10+
11+
from torch.library import impl, Library
12+
13+
# Edge-dialect-only namespace: these ops mirror their aten counterparts but
# take a `dim_order` keyword argument in place of aten's `memory_format`.
lib = Library("dim_order_ops", "DEF")

# Functional variant: full TensorOptions (dtype/layout/device/pin_memory).
lib.define(
    "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
)

# Out variant drops TensorOptions
lib.define(
    "_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)
22+
23+
24+
def _op_impl(target, *args, **kwargs):
    """Forward a dim_order op call to its aten ``target``.

    Translates the ``dim_order`` keyword (the dim_order-op convention) into
    the ``memory_format`` keyword the aten op expects, leaving all other
    arguments untouched.

    Args:
        target: The aten op (or overload) to dispatch to.
        *args: Positional arguments forwarded verbatim.
        **kwargs: Keyword arguments; ``dim_order`` is consumed here.

    Returns:
        Whatever ``target`` returns.
    """
    # The schema declares `int[]? dim_order=None`, so callers may omit it
    # entirely — use pop with a default instead of indexing to avoid KeyError.
    dim_order = kwargs.pop("dim_order", None)
    kwargs["memory_format"] = get_memory_format(dim_order)
    res = target(*args, **kwargs)
    # Sanity check kept disabled upstream:
    # assert list(res.dim_order()) == dim_order
    return res
30+
31+
32+
@impl(lib, "_to_dim_order_copy", "CompositeImplicitAutograd")
def _to_dim_order_copy_impl(*args, **kwargs):
    """dim_order variant of aten._to_copy; delegates via the shared shim."""
    aten_target = torch.ops.aten._to_copy
    return _op_impl(aten_target, *args, **kwargs)
35+
36+
37+
@impl(lib, "_to_dim_order_copy.out", "CompositeImplicitAutograd")
def _to_dim_order_copy_out_impl(*args, **kwargs):
    """Out variant of _to_dim_order_copy; delegates via the shared shim."""
    aten_target = torch.ops.aten._to_copy.out
    return _op_impl(aten_target, *args, **kwargs)
40+
41+
42+
# Maps an aten/edge op (by qualified overload name) to the corresponding
# dim_order op for quick lookup.
# NOTE: the original used a free-standing triple-quoted string here, which is
# a no-op expression statement, not a docstring — a comment is the correct form.
DimOrderOpsMap = {
    "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
}

0 commit comments

Comments
 (0)