Commit aa01508

tarun292 authored and facebook-github-bot committed
Support tensors in prim_getters
Summary: Adding support for tensors and tensor lists in prim getters

Differential Revision: D56426044
1 parent 8dc54d5 commit aa01508
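
To see the change end to end: a module's constant getter methods can now return tensors (and lists of tensors), get baked into the ExecuTorch program as instruction-free execution plans via constant_methods, and be read back at runtime. The sketch below mirrors the updated test in exir/emit/test/test_emit.py; the import paths and helper names are assumptions copied from that test and may differ in other setups.

import torch
from torch.export import export

# Assumed imports, following what the updated test file uses.
from executorch.exir import to_edge
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)


class Simple(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.tensor_output = torch.rand(1, 4)

    def get_tensor(self) -> torch.Tensor:
        # Constant getter: evaluated once at export time, not traced.
        return self.tensor_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


model = Simple()
inputs = (torch.ones(1, 4),)

# Each entry becomes its own execution plan in the program, with no instructions.
getters = {"get_tensor": model.get_tensor()}

program = to_edge(export(model, inputs), constant_methods=getters).to_executorch()
module = _load_for_executorch_from_buffer(program.buffer)

# The baked-in tensor comes back from run_method with no inputs.
print(module.run_method("get_tensor", [])[0])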

File tree (3 files changed, +106 -68 lines):

  exir/emit/_emit_program.py
  exir/emit/_emitter.py
  exir/emit/test/test_emit.py


exir/emit/_emit_program.py
Lines changed: 1 addition & 66 deletions

@@ -36,71 +36,6 @@
 from torch.utils import _pytree as pytree
 
 
-def _emit_prim_getters(prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
-    """
-    Given a mapping of function names to return values, emit simple execution
-    plans that just return these constant values.
-
-    Precondition: All the values are primitives (bool, float, int, str, enum)
-    or structures (list, dict) of them.
-    """
-    plans = []
-    # flatten any structures
-    for method, vals in prim_getters.items():
-        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
-        flattened_output, spec = ex_pytree.tree_flatten(vals)
-        spec = spec.to_str()
-        chain = Chain(
-            inputs=[],
-            outputs=[],
-            instructions=[],
-            stacktrace=None,
-        )
-
-        # switch on type of prim
-        values = []
-        for val in flattened_output:
-            if isinstance(val, float):
-                values.append(EValue(Double(val)))
-
-            elif isinstance(val, bool):
-                values.append(EValue(Bool(val)))
-
-            elif isinstance(val, int):
-                values.append(EValue(Int(val)))
-
-            elif isinstance(val, str):
-                values.append(EValue(String(val)))
-
-            elif isinstance(val, torch.dtype):
-                values.append(EValue(Int(scalar_type_enum(val))))
-
-            elif isinstance(val, torch.layout):
-                values.append(EValue(Int(layout_enum(val))))
-
-            else:
-                raise ExportError(
-                    ExportErrorType.NOT_SUPPORTED,
-                    f"Error emitting {method} which returns a value of type {type(val)}. which is not a supported primitive",
-                )
-
-        # add to plans
-        plans.append(
-            ExecutionPlan(
-                name=method,
-                values=values,
-                inputs=[],
-                outputs=list(range(0, len(values))),
-                chains=[chain],
-                operators=[],
-                delegates=[],
-                non_const_buffer_sizes=[0, 0],
-                container_meta_type=ContainerMetadata("", spec),
-            )
-        )
-    return plans
-
-
 @dataclass
 class EmitterOutput:
     """

@@ -220,7 +155,7 @@ def emit_program(
 
     # emit any primitive getters
     if prim_getters is not None:
-        plans.extend(_emit_prim_getters(prim_getters))
+        plans.extend(emitter._emit_prim_getters(prim_getters))
 
     return EmitterOutput(
         debug_handle_map=debug_handle_map,

exir/emit/_emitter.py
Lines changed: 71 additions & 0 deletions

@@ -1173,6 +1173,77 @@ def _emit_free(self, spec: TensorSpec) -> _AbstractValue:
         # The value is not used but the caller expects an AbstractValue returned.
         return _AbstractValue(None, None)  # pyre-ignore
 
+    def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan]:
+        """
+        Given a mapping of function names to return values, emit simple execution
+        plans that just return these constant values.
+
+        Precondition: All the values are primitives (bool, float, int, str, enum)
+        or structures (list, dict) of them.
+        """
+        plans = []
+        # flatten any structures
+        for method, vals in prim_getters.items():
+            # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
+            flattened_output, spec = ex_pytree.tree_flatten(vals)
+            spec = spec.to_str()
+            chain = Chain(
+                inputs=[],
+                outputs=[],
+                instructions=[],
+                stacktrace=None,
+            )
+
+            # switch on type of prim
+            values = []
+            for val in flattened_output:
+                if isinstance(val, float):
+                    values.append(EValue(Double(val)))
+
+                elif isinstance(val, bool):
+                    values.append(EValue(Bool(val)))
+
+                elif isinstance(val, int):
+                    values.append(EValue(Int(val)))
+
+                elif isinstance(val, str):
+                    values.append(EValue(String(val)))
+
+                elif isinstance(val, torch.dtype):
+                    values.append(EValue(Int(scalar_type_enum(val))))
+
+                elif isinstance(val, torch.layout):
+                    values.append(EValue(Int(layout_enum(val))))
+
+                elif isinstance(val, torch.Tensor):
+                    values.append(
+                        self._tensor_spec_to_evalue(
+                            TensorSpec.from_tensor(val, const=True)
+                        )
+                    )
+
+                else:
+                    raise ExportError(
+                        ExportErrorType.NOT_SUPPORTED,
+                        f"Error emitting {method} which returns a value of type {type(val)}. which is not a supported primitive",
+                    )
+
+            # add to plans
+            plans.append(
+                ExecutionPlan(
+                    name=method,
+                    values=values,
+                    inputs=[],
+                    outputs=list(range(0, len(values))),
+                    chains=[chain],
+                    operators=[],
+                    delegates=[],
+                    non_const_buffer_sizes=[0, 0],
+                    container_meta_type=ContainerMetadata("", spec),
+                )
+            )
+        return plans
+
     def fetch_attr(self, target: _Target) -> _AbstractValue:
         """Fetch weights and other module parameters. If the attribute is a tensor, emit it."""
         attr = super().fetch_attr(target)
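
For each getter the emitter pytree-flattens the return value, converts every leaf to an EValue (tensors now reuse self._tensor_spec_to_evalue with TensorSpec.from_tensor(val, const=True), the same conversion used when emitting module constants), and records the container spec string in ContainerMetadata so the caller can rebuild the original structure. A rough illustration of the flattening step, shown here with torch.utils._pytree rather than exir's ex_pytree (an assumption for illustration; the spec serialization differs):

import torch
from torch.utils import _pytree as pytree

# Hypothetical getter return value, shaped like get_tensor_list in the test below.
vals = [torch.rand(1, 4), torch.rand(1, 4)]

# tree_flatten returns the leaves in order plus a spec describing the container.
# In _emit_prim_getters each leaf becomes one EValue and the plan's outputs are
# the indices 0..len(leaves)-1; the spec string is what lets the list be
# reassembled from those outputs.
leaves, spec = pytree.tree_flatten(vals)
print(len(leaves))  # 2, so the emitted plan would have 2 outputs
print(spec)         # container structure (a two-element list)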

exir/emit/test/test_emit.py
Lines changed: 34 additions & 2 deletions

@@ -1065,6 +1065,9 @@ def forward(self, k: torch.Tensor) -> torch.Tensor:
         self.check_tensor_buffer_loc(1, execution_plan.values, 0, 1, 48)
 
     def test_emit_prims(self) -> None:
+        tensor_output = torch.rand(1, 4)
+        tensor_list_output = [torch.rand(1, 4), torch.rand(1, 4)]
+
         class Simple(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()

@@ -1078,6 +1081,12 @@ def get_ints(self) -> Tuple[int]:
             def get_str(self) -> str:
                 return "foo"
 
+            def get_tensor(self) -> torch.Tensor:
+                return tensor_output
+
+            def get_tensor_list(self) -> List[torch.Tensor]:
+                return tensor_list_output
+
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return torch.nn.functional.sigmoid(self.linear(x))
 

@@ -1090,9 +1099,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         getters = {}
         getters["get_ints"] = model.get_ints()
         getters["get_str"] = model.get_str()
-        print(getters["get_str"])
+        getters["get_tensor"] = model.get_tensor()
+        getters["get_tensor_list"] = model.get_tensor_list()
+
         merged_program = emit_program(exir_input, False, getters).program
-        self.assertEqual(len(merged_program.execution_plan), 3)
+
+        self.assertEqual(len(merged_program.execution_plan), 5)
 
         self.assertEqual(
             merged_program.execution_plan[0].name,

@@ -1106,6 +1118,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             merged_program.execution_plan[2].name,
             "get_str",
         )
+        self.assertEqual(
+            merged_program.execution_plan[3].name,
+            "get_tensor",
+        )
+        self.assertEqual(
+            merged_program.execution_plan[4].name,
+            "get_tensor_list",
+        )
+
         # no instructions in a getter
         self.assertEqual(
             len(merged_program.execution_plan[1].chains[0].instructions),

@@ -1141,6 +1162,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             merged_program.execution_plan[2].values[0].val.string_val,
             "foo",
         )
+        self.assertEqual(len(merged_program.execution_plan[3].outputs), 1)
+        self.assertEqual(len(merged_program.execution_plan[4].outputs), 2)
+
+        merged_program = to_edge(
+            export(model, inputs), constant_methods=getters
+        ).to_executorch()
+        executorch_module = _load_for_executorch_from_buffer(merged_program.buffer)
+        torch.allclose(executorch_module.run_method("get_tensor", [])[0], tensor_output)
+        model_output = executorch_module.run_method("get_tensor_list", [])
+        for i in range(len(tensor_list_output)):
+            torch.allclose(model_output[i], tensor_list_output[i])
 
     def test_emit_debug_handle_map(self) -> None:
         mul_model = Mul()
