
Replace view_copy with view (2/3) #2462

Closed · wants to merge 2 commits
9 changes: 9 additions & 0 deletions exir/memory.py
@@ -39,3 +39,12 @@ def free(spec: TensorSpec) -> None:
E.g., it can be the target of call_function node.
"""
pass


def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
"""
This function mimics torch.ops.aten.view.default.
It is used to elide view_copy nodes.
"""
return base.view(size)
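
For reference (not part of the diff), a minimal sketch of the aliasing semantics this operator relies on: unlike view_copy, base.view(size) returns a tensor that shares storage with its base, so writes through the view are visible in the base.

import torch

base = torch.zeros(2, 3)
alias = base.view([6])            # shares storage with `base`
copied = base.reshape(6).clone()  # roughly what view_copy produces

alias[0] = 1.0
assert base[0, 0].item() == 1.0   # the view aliases its base
assert copied[0].item() == 0.0    # the copy is unaffected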
24 changes: 24 additions & 0 deletions exir/passes/TARGETS
@@ -300,3 +300,27 @@ python_library(
"//executorch/exir/dialects/edge:lib",
],
)

python_library(
name = "normalize_view_copy_base_pass",
srcs = [
"normalize_view_copy_base_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir/dialects:lib",
],
)

python_library(
name = "replace_view_copy_with_view_pass",
srcs = [
"replace_view_copy_with_view_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:memory",
"//executorch/exir:tensor",
"//executorch/exir/dialects:lib",
],
)
64 changes: 64 additions & 0 deletions exir/passes/normalize_view_copy_base_pass.py
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging

import torch

from executorch.exir.dialects._ops import ops
from torch.fx.passes.infra.pass_base import PassBase, PassResult


def _is_view_copy(node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in (
torch.ops.aten.view_copy.default,
ops.edge.aten.view_copy.default,
)


class NormalizeViewCopyBasePass(PassBase):
"""
Point each view_copy to the first upstream non-view.

After this pass, the base of each view_copy is not a view_copy.

When combined with dead-code elimination, this pass removes redundant
view_copy nodes.

TODO: replace RemoveRedundantViewCopyPass with NormalizeViewCopyBasePass + dead code elimination.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
n_updated = 0
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if _is_view_copy(node):
base, size = node.args
if _is_view_copy(base):
# Point base to base's base and update node's args.
# Base's base will not be a view_copy because we iterate
# through the graph in topological order, replacing as we go.
base = base.args[0]
node.args = (base, size)
n_updated += 1

module.recompile()

logging.debug(f"Updated the base on {n_updated} view_copy nodes.")
return PassResult(graph_module, n_updated > 0)

def ensures(self, graph_module: torch.fx.GraphModule) -> None:
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if _is_view_copy(node):
base, size = node.args
assert not _is_view_copy(base)
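
For illustration, a hedged sketch (assuming the module path added in this PR) of the pass acting on a chain of view_copy nodes: after normalization, every view_copy takes the original tensor as its base, and the intermediate view_copy becomes dead code.

import torch
from torch.fx import symbolic_trace

from executorch.exir.passes.normalize_view_copy_base_pass import (
    NormalizeViewCopyBasePass,
)

class Chain(torch.nn.Module):
    def forward(self, x):
        a = torch.ops.aten.view_copy.default(x, [6])
        b = torch.ops.aten.view_copy.default(a, [3, 2])
        return b

gm = symbolic_trace(Chain())
NormalizeViewCopyBasePass()(gm)  # rebases b from a onto x, in place
gm.graph.eliminate_dead_code()   # a now has no users and is removed
gm.recompile()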
189 changes: 189 additions & 0 deletions exir/passes/replace_view_copy_with_view_pass.py
@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
import math
from typing import Any, Dict, List, Tuple

import torch
from executorch.exir import memory

from executorch.exir.dialects._ops import ops
from executorch.exir.tensor import (
contiguous_stride_from_shape,
determine_tensor_dynanism,
dim_order_from_stride,
TensorShapeDynamism,
TensorSpec,
)
from torch.fx.passes.infra.pass_base import PassBase, PassResult

logger: logging.Logger = logging.getLogger(__name__)


def _is_view_copy(node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in (
torch.ops.aten.view_copy.default,
ops.edge.aten.view_copy.default,
)


_VIEW_OP = memory.view


class _ViewSpec(TensorSpec):
def __init__(self, base: TensorSpec, shape: List[int]) -> None:
"""
A ViewSpec is an immutable TensorSpec that mirrors its base for non-size
related information.
"""

if math.prod(base.shape) != math.prod(shape):
raise Exception(
f"Cannot create a ViewSpec because the provided shape {shape} is not consistent with the number of elements in the provided base ({math.prod(base.shape)})."
)

self._init_setters = [
"_frozen",
"_base",
"_guards",
"shape",
"stride",
"dim_order",
"shape_dynamism",
]
self._frozen = False
self._base = base
self.shape: List[int] = shape
self.stride: Tuple[int] = contiguous_stride_from_shape(torch.Size(self.shape))
self.dim_order: Tuple[bytes] = dim_order_from_stride(self.stride)
self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(
torch.Size(self.shape)
)

# This spec gives a view into its base.
# The base can be modified (e.g., mem_id) and this spec will
# update accordingly, but there are certain fields we do not expect to change.
# We create guards for these.
self._guards: Dict[str, Any] = {
"shape_dynamism": base.shape_dynamism,
"scalar_type": base.scalar_type,
"layout": base.layout,
"is_sparse": base.is_sparse,
}
self._frozen = True

def _check_guards(self) -> None:
for name in self._guards:
if getattr(self._base, name) != self._guards[name]:
raise Exception(
f"The guarded attribute '{name}' has changed value. At creation of the ViewSpec, it was {self._guards[name]}, but it is now {getattr(self._base, name)}."
)

def __getattribute__(self, name): # pyre-ignore
if name in [
"_init_setters",
"_frozen",
"_base",
"_guards",
"_check_guards",
# Adding debug is needed so that view_spec.debug() shows the right id in
# its string (if debug is excluded, it shows the id(view_spec._base) instead
# of id(view_spec))
"debug",
]:
return object.__getattribute__(self, name)

# Guard check after freeze
if self._frozen:
self._check_guards()

# self._init_setters attributes come from self, others come from base
if name in self._init_setters:
return object.__getattribute__(self, name)
return getattr(self._base, name)

def __setattr__(self, name: str, val) -> None: # pyre-ignore
if name in ["_init_setters", "_frozen"]:
object.__setattr__(self, name, val)
return

# Allow setting during initialization
if name in self._init_setters and not self._frozen:
object.__setattr__(self, name, val)
return

if name in self._init_setters:
raise Exception(
f"ViewSpec is immutable. Cannot set the attribute '{name}' after creation."
)

raise Exception(
f"ViewSpec is immutable. To update the non-size related attribute '{name}', update the base."
)


class ReplaceViewCopyWithViewPass(PassBase):
def __init__(self) -> None:
super().__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
"""
This pass replaces view_copy nodes with view nodes.
This should be run after the NormalizeViewCopyBasePass.
During memory planning, view nodes share the same storage as their base.
"""

n_replaced = 0
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if _is_view_copy(node):
base, _ = node.args
node.target = _VIEW_OP

# Create spec for the node.
# _ViewSpec is an immutable TensorSpec that gives a view into
# its base spec for non-size related information.

# The shape is not the same as node.args[1] because node.args[1]
# can contain an inferred size (-1).
shape = node.meta["val"].shape
node.meta["spec"] = _ViewSpec(base.meta["spec"], shape)

n_replaced += 1

module.recompile()

logger.debug(f"Replaced {n_replaced} view_copy nodes with {_VIEW_OP} nodes.")
return PassResult(graph_module, n_replaced > 0)

def ensures(self, graph_module: torch.fx.GraphModule) -> None:
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
assert not _is_view_copy(node)
if node.op == "call_function" and node.target == _VIEW_OP:
assert isinstance(node.meta["spec"], _ViewSpec)

def requires(self, graph_module: torch.fx.GraphModule) -> None:
"""
This pass should be called after NormalizeViewCopyBasePass.
We check that all view_copy nodes have been normalized.
"""
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if _is_view_copy(node):
base, size = node.args
assert not _is_view_copy(base)
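
A sketch of the intended pass ordering, assuming the module paths added in this PR; the helper name elide_view_copies is hypothetical. Note that the pass reads node.meta["val"] and base.meta["spec"], so specs must already be populated (e.g., by a spec-propagation pass) before it runs.

import torch

from executorch.exir.passes.normalize_view_copy_base_pass import (
    NormalizeViewCopyBasePass,
)
from executorch.exir.passes.replace_view_copy_with_view_pass import (
    ReplaceViewCopyWithViewPass,
)

def elide_view_copies(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Normalize bases first so ReplaceViewCopyWithViewPass.requires() holds,
    # then swap every view_copy for the alias-only memory.view op.
    for p in (NormalizeViewCopyBasePass(), ReplaceViewCopyWithViewPass()):
        gm = p(gm).graph_module
    return gm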
3 changes: 3 additions & 0 deletions exir/tests/TARGETS
@@ -208,6 +208,7 @@ python_unittest(
"//executorch/exir:memory",
"//executorch/exir:memory_planning",
"//executorch/exir:pass_base",
"//executorch/exir:schema",
"//executorch/exir:tensor",
"//executorch/exir/dialects:lib",
"//executorch/exir/dialects/edge:lib",
@@ -216,9 +217,11 @@ python_unittest(
"//executorch/exir/passes:debug_handle_generator_pass",
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
"//executorch/exir/passes:lib",
"//executorch/exir/passes:normalize_view_copy_base_pass",
"//executorch/exir/passes:remove_graph_asserts_pass",
"//executorch/exir/passes:remove_mixed_type_operators",
"//executorch/exir/passes:replace_edge_with_backend_pass",
"//executorch/exir/passes:replace_view_copy_with_view_pass",
"//executorch/exir/passes:scalar_to_tensor_pass",
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/passes:sym_to_tensor_pass",