Skip to content

Commit 28a213f

Browse files
Remove presere ops
Differential Revision: D64151426 Pull Request resolved: #6360
1 parent dfbf6fd commit 28a213f

File tree

5 files changed

+19
-14
lines changed

5 files changed

+19
-14
lines changed

.ci/docker/ci_commit_pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
d1b87e26e5c4343f5b56bb1e6f89b479b389bfac
1+
export-D64151426

exir/program/_program.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -925,9 +925,12 @@ def _gen_edge_manager_for_partitioners(
925925
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
926926
all_ops_no_decomp |= set(curr_ops_no_decomp)
927927

928-
program = program.run_decompositions(
929-
_default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp)
930-
)
928+
table = _default_decomposition_table()
929+
930+
for op in all_ops_no_decomp:
931+
table.pop(op, None)
932+
933+
program = program.run_decompositions(table)
931934
# Among all the preserved aten ops, use the check_op_fn to do an additional
932935
# check on which ops need to be preserved and which ops need to be decomposed
933936
# Those which are truly preserved will be replaced with transformed ops
@@ -1097,9 +1100,10 @@ def to_edge_with_preserved_ops(
10971100

10981101
for name, program in aten_programs.items():
10991102
# Decompose to Core ATen
1100-
program = program.run_decompositions(
1101-
_default_decomposition_table(), _preserve_ops=preserve_ops
1102-
)
1103+
table = _default_decomposition_table()
1104+
for op in preserve_ops:
1105+
table.pop(op, None)
1106+
program = program.run_decompositions(table)
11031107
edge_programs[name] = _generate_edge_program(
11041108
name, config, program, list(preserve_ops)
11051109
)

exir/program/test/test_program.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -573,10 +573,10 @@ def get_num_nondecomposed_ops(self, ep, partitioner):
573573
# which pass the filter_ops fn given by the partitioner
574574
reference_ep = copy.deepcopy(ep)
575575
aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
576-
reference_decomp_ep = reference_ep.run_decompositions(
577-
decomp_table=_default_decomposition_table(),
578-
_preserve_ops=tuple(aten_ops_not_decomposed),
579-
)
576+
table = _default_decomposition_table()
577+
for op in aten_ops_not_decomposed:
578+
table.pop(op, None)
579+
reference_decomp_ep = reference_ep.run_decompositions(decomp_table=table)
580580
num_non_decomposed_aten_ops = 0
581581
for node in reference_decomp_ep.graph.nodes:
582582
if (

exir/tracer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@
4444
from executorch.exir.types import ValueSpec
4545

4646
from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass # @manual
47-
from torch._decomp import core_aten_decompositions, get_decompositions
47+
from torch._decomp import get_decompositions
4848
from torch._dynamo.guards import Guard
4949
from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor
50+
from torch.export import default_decompositions
5051
from torch.func import functionalize
5152
from torch.fx.operator_schemas import normalize_function
5253
from torch.utils._pytree import TreeSpec
@@ -631,7 +632,7 @@ def _default_decomposition_table(
631632
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
632633
return get_decompositions(decomp_opset)
633634
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
634-
return core_aten_decompositions()
635+
return default_decompositions()
635636

636637

637638
def dynamo_trace(

install_requirements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def python_is_compatible():
106106
# NOTE: If a newly-fetched version of the executorch repo changes the value of
107107
# NIGHTLY_VERSION, you should re-run this script to install the necessary
108108
# package versions.
109-
NIGHTLY_VERSION = "dev20241007"
109+
NIGHTLY_VERSION = "dev20241019"
110110

111111
# The pip repository that hosts nightly torch packages.
112112
TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu"

0 commit comments

Comments
 (0)