Support tensors in prim_getters #3203

Closed
82 changes: 2 additions & 80 deletions exir/emit/_emit_program.py
@@ -8,7 +8,6 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
-import executorch.extension.pytree as ex_pytree
 import torch
 import torch.fx
 from executorch.exir.emit._emitter import (
@@ -18,89 +17,12 @@
     _TopLevelEmitter,
 )
 from executorch.exir.error import ExportError, ExportErrorType
-from executorch.exir.schema import (
-    Bool,
-    Chain,
-    ContainerMetadata,
-    Double,
-    EValue,
-    ExecutionPlan,
-    Int,
-    Program,
-    String,
-    SubsegmentOffsets,
-)
-from executorch.exir.tensor import layout_enum, scalar_type_enum
+from executorch.exir.schema import Program, SubsegmentOffsets
 from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION
 from torch.export.exported_program import ExportedProgram, OutputKind
 from torch.utils import _pytree as pytree
 
 
-def _emit_prim_getters(prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
-    """
-    Given a mapping of function names to return values, emit simple execution
-    plans that just return these constant values.
-
-    Precondition: All the values are primitives (bool, float, int, str, enum)
-    or structures (list, dict) of them.
-    """
-    plans = []
-    # flatten any structures
-    for method, vals in prim_getters.items():
-        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
-        flattened_output, spec = ex_pytree.tree_flatten(vals)
-        spec = spec.to_str()
-        chain = Chain(
-            inputs=[],
-            outputs=[],
-            instructions=[],
-            stacktrace=None,
-        )
-
-        # switch on type of prim
-        values = []
-        for val in flattened_output:
-            if isinstance(val, float):
-                values.append(EValue(Double(val)))
-
-            elif isinstance(val, bool):
-                values.append(EValue(Bool(val)))
-
-            elif isinstance(val, int):
-                values.append(EValue(Int(val)))
-
-            elif isinstance(val, str):
-                values.append(EValue(String(val)))
-
-            elif isinstance(val, torch.dtype):
-                values.append(EValue(Int(scalar_type_enum(val))))
-
-            elif isinstance(val, torch.layout):
-                values.append(EValue(Int(layout_enum(val))))
-
-            else:
-                raise ExportError(
-                    ExportErrorType.NOT_SUPPORTED,
-                    f"Error emitting {method} which returns a value of type {type(val)}. which is not a supported primitive",
-                )
-
-        # add to plans
-        plans.append(
-            ExecutionPlan(
-                name=method,
-                values=values,
-                inputs=[],
-                outputs=list(range(0, len(values))),
-                chains=[chain],
-                operators=[],
-                delegates=[],
-                non_const_buffer_sizes=[0, 0],
-                container_meta_type=ContainerMetadata("", spec),
-            )
-        )
-    return plans
-
-
 @dataclass
 class EmitterOutput:
     """
@@ -220,7 +142,7 @@ def emit_program(

     # emit any primitive getters
     if prim_getters is not None:
-        plans.extend(_emit_prim_getters(prim_getters))
+        plans.extend(emitter._emit_prim_getters(prim_getters))
 
     return EmitterOutput(
         debug_handle_map=debug_handle_map,
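Aside: the getter emission relies on pytree flattening, so nested return values (lists, dicts) become a flat run of EValues plus a container spec string. A minimal sketch of that step, assuming executorch.extension.pytree mirrors the torch.utils._pytree flatten API used in this diff:

import executorch.extension.pytree as ex_pytree

# Nested containers flatten into a list of leaves plus a spec object;
# spec.to_str() is what lands in the plan's ContainerMetadata so the
# runtime can rebuild the original structure around the flat outputs.
vals = {"dims": [1, 4], "scale": 2.0}
flat, spec = ex_pytree.tree_flatten(vals)
# flat is [1, 4, 2.0] (leaf order may depend on key ordering)
print(flat, spec.to_str())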
71 changes: 71 additions & 0 deletions exir/emit/_emitter.py
@@ -1173,6 +1173,77 @@ def _emit_free(self, spec: TensorSpec) -> _AbstractValue:
         # The value is not used but the caller expects an AbstractValue returned.
         return _AbstractValue(None, None) # pyre-ignore
 
+    def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
+        """
+        Given a mapping of function names to return values, emit simple execution
+        plans that just return these constant values.
+
+        Precondition: All the values are primitives (bool, float, int, str, enum)
+        or structures (list, dict) of them.
+        """
+        plans = []
+        # flatten any structures
+        for method, vals in prim_getters.items():
+            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
+            flattened_output, spec = ex_pytree.tree_flatten(vals)
+            spec = spec.to_str()
+            chain = Chain(
+                inputs=[],
+                outputs=[],
+                instructions=[],
+                stacktrace=None,
+            )
+
+            # switch on type of prim
+            values = []
+            for val in flattened_output:
+                if isinstance(val, float):
+                    values.append(EValue(Double(val)))
+
+                elif isinstance(val, bool):
+                    values.append(EValue(Bool(val)))
+
+                elif isinstance(val, int):
+                    values.append(EValue(Int(val)))
+
+                elif isinstance(val, str):
+                    values.append(EValue(String(val)))
+
+                elif isinstance(val, torch.dtype):
+                    values.append(EValue(Int(scalar_type_enum(val))))
+
+                elif isinstance(val, torch.layout):
+                    values.append(EValue(Int(layout_enum(val))))
+
+                elif isinstance(val, torch.Tensor):
+                    values.append(
+                        self._tensor_spec_to_evalue(
+                            TensorSpec.from_tensor(val, const=True)
+                        )
+                    )
+
+                else:
+                    raise ExportError(
+                        ExportErrorType.NOT_SUPPORTED,
+                        f"Error emitting {method} which returns a value of type {type(val)}, which is not a supported primitive.",
+                    )
+
+            # add to plans
+            plans.append(
+                ExecutionPlan(
+                    name=method,
+                    values=values,
+                    inputs=[],
+                    outputs=list(range(0, len(values))),
+                    chains=[chain],
+                    operators=[],
+                    delegates=[],
+                    non_const_buffer_sizes=[0],
+                    container_meta_type=ContainerMetadata("", spec),
+                )
+            )
+        return plans
+
     def fetch_attr(self, target: _Target) -> _AbstractValue:
         """Fetch weights and other module parameters. If the attribute is a tensor, emit it."""
         attr = super().fetch_attr(target)
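With the new torch.Tensor branch above, a constant method that returns a tensor (or a structure of tensors) is emitted as an execution plan of constant tensor EValues with a single instruction-free chain. A minimal end-to-end sketch, using the same to_edge/constant_methods API the test below exercises; the module and method names are illustrative:

import torch
from executorch.exir import to_edge
from torch.export import export

class WithConstants(torch.nn.Module):  # hypothetical example module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    def get_scale(self) -> torch.Tensor:
        # now legal: emitted via TensorSpec.from_tensor(..., const=True)
        return torch.ones(1, 4)

model = WithConstants()
program = to_edge(
    export(model, (torch.rand(1, 4),)),
    constant_methods={"get_scale": model.get_scale()},
).to_executorch()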
36 changes: 34 additions & 2 deletions exir/emit/test/test_emit.py
@@ -1065,6 +1065,9 @@ def forward(self, k: torch.Tensor) -> torch.Tensor:
         self.check_tensor_buffer_loc(1, execution_plan.values, 0, 1, 48)
 
     def test_emit_prims(self) -> None:
+        tensor_output = torch.rand(1, 4)
+        tensor_list_output = [torch.rand(1, 4), torch.rand(1, 4)]
+
         class Simple(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -1078,6 +1081,12 @@ def get_ints(self) -> Tuple[int]:
             def get_str(self) -> str:
                 return "foo"
 
+            def get_tensor(self) -> torch.Tensor:
+                return tensor_output
+
+            def get_tensor_list(self) -> List[torch.Tensor]:
+                return tensor_list_output
+
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return torch.nn.functional.sigmoid(self.linear(x))

@@ -1090,9 +1099,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         getters = {}
         getters["get_ints"] = model.get_ints()
         getters["get_str"] = model.get_str()
-        print(getters["get_str"])
+        getters["get_tensor"] = model.get_tensor()
+        getters["get_tensor_list"] = model.get_tensor_list()
 
         merged_program = emit_program(exir_input, False, getters).program
-        self.assertEqual(len(merged_program.execution_plan), 3)
+
+        self.assertEqual(len(merged_program.execution_plan), 5)
+
        self.assertEqual(
             merged_program.execution_plan[0].name,
@@ -1106,6 +1118,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             merged_program.execution_plan[2].name,
             "get_str",
         )
+        self.assertEqual(
+            merged_program.execution_plan[3].name,
+            "get_tensor",
+        )
+        self.assertEqual(
+            merged_program.execution_plan[4].name,
+            "get_tensor_list",
+        )
+
         # no instructions in a getter
         self.assertEqual(
             len(merged_program.execution_plan[1].chains[0].instructions),
@@ -1141,6 +1162,17 @@
             merged_program.execution_plan[2].values[0].val.string_val,
             "foo",
         )
+        self.assertEqual(len(merged_program.execution_plan[3].outputs), 1)
+        self.assertEqual(len(merged_program.execution_plan[4].outputs), 2)
+
+        merged_program = to_edge(
+            export(model, inputs), constant_methods=getters
+        ).to_executorch()
+        executorch_module = _load_for_executorch_from_buffer(merged_program.buffer)
+        self.assertTrue(torch.allclose(executorch_module.run_method("get_tensor", [])[0], tensor_output))
+        model_output = executorch_module.run_method("get_tensor_list", [])
+        for i in range(len(tensor_list_output)):
+            self.assertTrue(torch.allclose(model_output[i], tensor_list_output[i]))
 
     def test_emit_debug_handle_map(self) -> None:
         mul_model = Mul()
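For completeness, a short sketch of reading a constant tensor back through the pybindings loader the test uses, assuming the `program` buffer from the earlier sketch and that the loader lives at the import path below:

import torch
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

# Each getter is a method on the loaded module; run_method returns a
# list of outputs, so index 0 holds the single constant tensor.
module = _load_for_executorch_from_buffer(program.buffer)
scale = module.run_method("get_scale", [])[0]
assert torch.allclose(scale, torch.ones(1, 4))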