Skip to content

Add scuba logging to edge API's #7103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 15 additions & 51 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
exir.print_program.pretty_print(program)

deboxed_int_list = []
for item in program.execution_plan[0].values[5].val.items: # pyre-ignore[16]
deboxed_int_list.append(
program.execution_plan[0].values[item].val.int_val # pyre-ignore[16]
)
for item in program.execution_plan[0].values[5].val.items:
deboxed_int_list.append(program.execution_plan[0].values[item].val.int_val)

self.assertEqual(IntList(deboxed_int_list), IntList([2, 0, 1]))

Expand Down Expand Up @@ -459,11 +457,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# Check the mul operator's stack trace contains f -> g -> h
self.assertTrue(
"return torch.mul(x, torch.randn(3, 2))"
in program.execution_plan[0] # pyre-ignore[16]
.chains[0]
.stacktrace[1]
.items[-1]
.context
in program.execution_plan[0].chains[0].stacktrace[1].items[-1].context
)
self.assertEqual(
program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f"
Expand Down Expand Up @@ -616,11 +610,7 @@ def false_fn(y: torch.Tensor) -> torch.Tensor:
if not isinstance(inst.instr_args, KernelCall):
continue

op = (
program.execution_plan[0]
.operators[inst.instr_args.op_index] # pyre-ignore[16]
.name
)
op = program.execution_plan[0].operators[inst.instr_args.op_index].name

if "mm" in op:
num_mm += 1
Expand Down Expand Up @@ -657,19 +647,13 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# generate the tensor on which this iteration will operate on.
self.assertEqual(
op_table[
program.execution_plan[0] # pyre-ignore[16]
.chains[0]
.instructions[0]
.instr_args.op_index
program.execution_plan[0].chains[0].instructions[0].instr_args.op_index
].name,
"aten::sym_size",
)
self.assertEqual(
op_table[
program.execution_plan[0] # pyre-ignore[16]
.chains[0]
.instructions[1]
.instr_args.op_index
program.execution_plan[0].chains[0].instructions[1].instr_args.op_index
].name,
"aten::select_copy",
)
Expand All @@ -681,28 +665,19 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# We check here that both of these have been generated.
self.assertEqual(
op_table[
program.execution_plan[0] # pyre-ignore[16]
.chains[0]
.instructions[-5]
.instr_args.op_index
program.execution_plan[0].chains[0].instructions[-5].instr_args.op_index
].name,
"executorch_prim::et_copy_index",
)
self.assertEqual(
op_table[
program.execution_plan[0] # pyre-ignore[16]
.chains[0]
.instructions[-4]
.instr_args.op_index
program.execution_plan[0].chains[0].instructions[-4].instr_args.op_index
].name,
"executorch_prim::add",
)
self.assertEqual(
op_table[
program.execution_plan[0] # pyre-ignore[16]
.chains[0]
.instructions[-3]
.instr_args.op_index
program.execution_plan[0].chains[0].instructions[-3].instr_args.op_index
].name,
"executorch_prim::eq",
)
Expand All @@ -716,10 +691,7 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
)
self.assertEqual(
op_table[
program.execution_plan[0] # pyre-ignore[16]
.chains[0]
.instructions[-1]
.instr_args.op_index
program.execution_plan[0].chains[0].instructions[-1].instr_args.op_index
].name,
"executorch_prim::sub",
)
Expand Down Expand Up @@ -1300,9 +1272,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# this triggers the actual emission of the graph
program = program_mul._emitter_output.program
node = None
program.execution_plan[0].chains[0].instructions[ # pyre-ignore[16]
0
].instr_args.op_index
program.execution_plan[0].chains[0].instructions[0].instr_args.op_index

# Find the multiplication node in the graph that was emitted.
for node in program_mul.exported_program().graph.nodes:
Expand All @@ -1314,7 +1284,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# Find the multiplication instruction in the program that was emitted.
for idx in range(len(program.execution_plan[0].chains[0].instructions)):
instruction = program.execution_plan[0].chains[0].instructions[idx]
op_index = instruction.instr_args.op_index # pyre-ignore[16]
op_index = instruction.instr_args.op_index
if "mul" in program.execution_plan[0].operators[op_index].name:
break

Expand Down Expand Up @@ -1453,9 +1423,7 @@ def forward(self, x, y):
exec_prog._emitter_output.program
self.assertIsNotNone(exec_prog.delegate_map)
self.assertIsNotNone(exec_prog.delegate_map.get("forward"))
self.assertIsNotNone(
exec_prog.delegate_map.get("forward").get(0) # pyre-ignore[16]
)
self.assertIsNotNone(exec_prog.delegate_map.get("forward").get(0))
self.assertEqual(
exec_prog.delegate_map.get("forward").get(0).get("name"),
"BackendWithCompilerExample",
Expand Down Expand Up @@ -1568,9 +1536,7 @@ def forward(self, x):
model = model.to_executorch()
model.dump_executorch_program(True)
self.assertTrue(
model.executorch_program.execution_plan[0] # pyre-ignore[16]
.values[0]
.val.allocation_info
model.executorch_program.execution_plan[0].values[0].val.allocation_info
is not None
)
executorch_module = _load_for_executorch_from_buffer(model.buffer)
Expand Down Expand Up @@ -1611,9 +1577,7 @@ def forward(self, x):
)
model.dump_executorch_program(True)
self.assertTrue(
model.executorch_program.execution_plan[0] # pyre-ignore[16]
.values[0]
.val.allocation_info
model.executorch_program.execution_plan[0].values[0].val.allocation_info
is not None
)
executorch_module = _load_for_executorch_from_buffer(model.buffer)
Expand Down
3 changes: 2 additions & 1 deletion exir/program/TARGETS
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

Expand Down Expand Up @@ -43,7 +44,7 @@ python_library(
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/passes:weights_to_outputs_pass",
"//executorch/exir/verification:verifier",
],
] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])
)

python_library(
Expand Down
22 changes: 22 additions & 0 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,24 @@

Val = Any

from typing import Any, Callable

from torch.library import Library

try:
from executorch.exir.program.fb.logger import et_logger
except ImportError:
# Define a stub decorator that does nothing
def et_logger(api_name: str) -> Callable[[Any], Any]:
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
return func(self, *args, **kwargs)

return wrapper

return decorator


# This is the reserved namespace that is used to register ops to that will
# be prevented from being decomposed during to_edge_transform_and_lower.
edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP"
Expand Down Expand Up @@ -957,6 +973,7 @@ def _gen_edge_manager_for_partitioners(
return edge_manager


@et_logger("to_edge_transform_and_lower")
def to_edge_transform_and_lower(
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
transform_passes: Optional[
Expand Down Expand Up @@ -1110,6 +1127,7 @@ def to_edge_with_preserved_ops(
)


@et_logger("to_edge")
def to_edge(
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
constant_methods: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -1204,8 +1222,10 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram:
"""
Returns the ExportedProgram specified by 'method_name'.
"""

return self._edge_programs[method_name]

@et_logger("transform")
def transform(
self,
passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
Expand Down Expand Up @@ -1253,6 +1273,7 @@ def transform(
new_programs, copy.deepcopy(self._config_methods), compile_config
)

@et_logger("to_backend")
def to_backend(
self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
) -> "EdgeProgramManager":
Expand Down Expand Up @@ -1296,6 +1317,7 @@ def to_backend(
new_edge_programs, copy.deepcopy(self._config_methods), config
)

@et_logger("to_executorch")
def to_executorch(
self,
config: Optional[ExecutorchBackendConfig] = None,
Expand Down
26 changes: 8 additions & 18 deletions exir/tests/test_joint_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,21 @@ def forward(self, x, y):

# assert that the weight and bias have proper data_buffer_idx and allocation_info
self.assertEqual(
et.executorch_program.execution_plan[0] # pyre-ignore
.values[0]
.val.data_buffer_idx,
et.executorch_program.execution_plan[0].values[0].val.data_buffer_idx,
1,
)
self.assertEqual(
et.executorch_program.execution_plan[0] # pyre-ignore
.values[1]
.val.data_buffer_idx,
et.executorch_program.execution_plan[0].values[1].val.data_buffer_idx,
2,
)
self.assertEqual(
et.executorch_program.execution_plan[0] # pyre-ignore
et.executorch_program.execution_plan[0]
.values[0]
.val.allocation_info.memory_offset_low,
0,
)
self.assertEqual(
et.executorch_program.execution_plan[0] # pyre-ignore
et.executorch_program.execution_plan[0]
.values[1]
.val.allocation_info.memory_offset_low,
48,
Expand All @@ -106,7 +102,7 @@ def forward(self, x, y):

self.assertTrue(torch.allclose(loss, et_outputs[0]))
self.assertTrue(
torch.allclose(m.linear.weight.grad, et_outputs[1]) # pyre-ignore[6]
torch.allclose(m.linear.weight.grad, et_outputs[1]) # pyre-ignore
)
self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2]))
self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3]))
Expand All @@ -118,23 +114,17 @@ def forward(self, x, y):

# gradient outputs start at index 1
self.assertEqual(
et.executorch_program.execution_plan[1] # pyre-ignore
.values[0]
.val.int_val,
et.executorch_program.execution_plan[1].values[0].val.int_val,
1,
)

self.assertEqual(
et.executorch_program.execution_plan[2] # pyre-ignore
.values[0]
.val.string_val,
et.executorch_program.execution_plan[2].values[0].val.string_val,
"linear.weight",
)

# parameter outputs start at index 3
self.assertEqual(
et.executorch_program.execution_plan[3] # pyre-ignore
.values[0]
.val.int_val,
et.executorch_program.execution_plan[3].values[0].val.int_val,
3,
)
24 changes: 7 additions & 17 deletions exir/tests/test_remove_view_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,24 +196,14 @@ def test_spec(self) -> None:
instructions = plan.chains[0].instructions
self.assertEqual(len(instructions), 7)

self.assertEqual(instructions[0].instr_args.op_index, 0) # view @ idx2
self.assertEqual(instructions[1].instr_args.op_index, 0) # view @ idx3
self.assertEqual(instructions[2].instr_args.op_index, 1) # aten:mul @ idx6
self.assertEqual(instructions[3].instr_args.op_index, 0) # view @ idx7
self.assertEqual(instructions[4].instr_args.op_index, 1) # aten:mul @ idx9
self.assertEqual(
instructions[0].instr_args.op_index, 0 # pyre-ignore
) # view @ idx2
self.assertEqual(
instructions[1].instr_args.op_index, 0 # pyre-ignore
) # view @ idx3
self.assertEqual(
instructions[2].instr_args.op_index, 1 # pyre-ignore
) # aten:mul @ idx6
self.assertEqual(
instructions[3].instr_args.op_index, 0 # pyre-ignore
) # view @ idx7
self.assertEqual(
instructions[4].instr_args.op_index, 1 # pyre-ignore
) # aten:mul @ idx9
self.assertEqual(
instructions[5].instr_args.op_index, 2 # pyre-ignore
instructions[5].instr_args.op_index, 2
) # aten:view_copy @ idx11
self.assertEqual(
instructions[6].instr_args.op_index, 2 # pyre-ignore
instructions[6].instr_args.op_index, 2
) # aten:view_copy @ idx11