
Commit 2326fff (1 parent: c9d7b6e)

Add scuba logging to edge APIs

Differential Revision: D66385141
Pull Request resolved: #7103

5 files changed: +54 −87 lines
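In short, this commit wraps the edge-dialect entry points (to_edge, to_edge_transform_and_lower, and the EdgeProgramManager methods transform, to_backend, and to_executorch) in an et_logger decorator. In Meta-internal builds the decorator records API usage to Scuba; in OSS builds an ImportError fallback installs a no-op stub, so public behavior is unchanged. The test-file diffs below appear to be incidental cleanup: pyre-ignore[16] suppressions that are no longer needed were dropped, and the call chains they annotated were joined back onto single lines.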

exir/emit/test/test_emit.py

Lines changed: 15 additions & 51 deletions
@@ -340,10 +340,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         exir.print_program.pretty_print(program)
 
         deboxed_int_list = []
-        for item in program.execution_plan[0].values[5].val.items:  # pyre-ignore[16]
-            deboxed_int_list.append(
-                program.execution_plan[0].values[item].val.int_val  # pyre-ignore[16]
-            )
+        for item in program.execution_plan[0].values[5].val.items:
+            deboxed_int_list.append(program.execution_plan[0].values[item].val.int_val)
 
         self.assertEqual(IntList(deboxed_int_list), IntList([2, 0, 1]))
 
@@ -459,11 +457,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Check the mul operator's stack trace contains f -> g -> h
         self.assertTrue(
             "return torch.mul(x, torch.randn(3, 2))"
-            in program.execution_plan[0]  # pyre-ignore[16]
-            .chains[0]
-            .stacktrace[1]
-            .items[-1]
-            .context
+            in program.execution_plan[0].chains[0].stacktrace[1].items[-1].context
         )
         self.assertEqual(
             program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f"
@@ -616,11 +610,7 @@ def false_fn(y: torch.Tensor) -> torch.Tensor:
             if not isinstance(inst.instr_args, KernelCall):
                 continue
 
-            op = (
-                program.execution_plan[0]
-                .operators[inst.instr_args.op_index]  # pyre-ignore[16]
-                .name
-            )
+            op = program.execution_plan[0].operators[inst.instr_args.op_index].name
 
             if "mm" in op:
                 num_mm += 1
@@ -657,19 +647,13 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         # generate the tensor on which this iteration will operate on.
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[0]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[0].instr_args.op_index
             ].name,
             "aten::sym_size",
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[1]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[1].instr_args.op_index
             ].name,
             "aten::select_copy",
         )
@@ -681,28 +665,19 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         # We check here that both of these have been generated.
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-5]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-5].instr_args.op_index
             ].name,
             "executorch_prim::et_copy_index",
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-4]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-4].instr_args.op_index
             ].name,
             "executorch_prim::add",
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-3]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-3].instr_args.op_index
             ].name,
             "executorch_prim::eq",
         )
@@ -716,10 +691,7 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         )
         self.assertEqual(
             op_table[
-                program.execution_plan[0]  # pyre-ignore[16]
-                .chains[0]
-                .instructions[-1]
-                .instr_args.op_index
+                program.execution_plan[0].chains[0].instructions[-1].instr_args.op_index
             ].name,
             "executorch_prim::sub",
         )
@@ -1300,9 +1272,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # this triggers the actual emission of the graph
         program = program_mul._emitter_output.program
         node = None
-        program.execution_plan[0].chains[0].instructions[  # pyre-ignore[16]
-            0
-        ].instr_args.op_index
+        program.execution_plan[0].chains[0].instructions[0].instr_args.op_index
 
         # Find the multiplication node in the graph that was emitted.
         for node in program_mul.exported_program().graph.nodes:
@@ -1314,7 +1284,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Find the multiplication instruction in the program that was emitted.
         for idx in range(len(program.execution_plan[0].chains[0].instructions)):
             instruction = program.execution_plan[0].chains[0].instructions[idx]
-            op_index = instruction.instr_args.op_index  # pyre-ignore[16]
+            op_index = instruction.instr_args.op_index
             if "mul" in program.execution_plan[0].operators[op_index].name:
                 break
 
@@ -1453,9 +1423,7 @@ def forward(self, x, y):
         exec_prog._emitter_output.program
         self.assertIsNotNone(exec_prog.delegate_map)
         self.assertIsNotNone(exec_prog.delegate_map.get("forward"))
-        self.assertIsNotNone(
-            exec_prog.delegate_map.get("forward").get(0)  # pyre-ignore[16]
-        )
+        self.assertIsNotNone(exec_prog.delegate_map.get("forward").get(0))
         self.assertEqual(
             exec_prog.delegate_map.get("forward").get(0).get("name"),
             "BackendWithCompilerExample",
@@ -1568,9 +1536,7 @@ def forward(self, x):
         model = model.to_executorch()
         model.dump_executorch_program(True)
         self.assertTrue(
-            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
-            .values[0]
-            .val.allocation_info
+            model.executorch_program.execution_plan[0].values[0].val.allocation_info
             is not None
         )
         executorch_module = _load_for_executorch_from_buffer(model.buffer)
@@ -1611,9 +1577,7 @@ def forward(self, x):
         )
         model.dump_executorch_program(True)
         self.assertTrue(
-            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
-            .values[0]
-            .val.allocation_info
+            model.executorch_program.execution_plan[0].values[0].val.allocation_info
             is not None
         )
         executorch_module = _load_for_executorch_from_buffer(model.buffer)

exir/program/TARGETS

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
@@ -43,7 +44,7 @@ python_library(
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/passes:weights_to_outputs_pass",
         "//executorch/exir/verification:verifier",
-    ],
+    ] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])
 )
 
 python_library(
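The Scuba logger target lives only in Meta's internal tree, so the dependency is appended conditionally: runtime.is_oss is False in fbcode builds, where //executorch/exir/program/fb:logger is pulled in, while OSS builds keep the original dependency list and fall back to the no-op stub added in _program.py below.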

exir/program/_program.py

Lines changed: 22 additions & 0 deletions
@@ -75,8 +75,24 @@
 
 Val = Any
 
+from typing import Any, Callable
+
 from torch.library import Library
 
+try:
+    from executorch.exir.program.fb.logger import et_logger
+except ImportError:
+    # Define a stub decorator that does nothing
+    def et_logger(api_name: str) -> Callable[[Any], Any]:
+        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+            def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+                return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
+
 # This is the reserved namespace that is used to register ops to that will
 # be prevented from being decomposed during to_edge_transform_and_lower.
 edge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP"
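The stub above just forwards the call, so OSS behavior is untouched. The real executorch.exir.program.fb.logger is not included in this diff; purely as an illustration, a hypothetical instrumented variant with the same shape might look like the following (the timing and the logging.info sink are assumptions, not the actual Scuba logger):

import logging
import time
from typing import Any, Callable


def et_logger(api_name: str) -> Callable[[Any], Any]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            start = time.perf_counter()
            try:
                return func(self, *args, **kwargs)
            finally:
                # Hypothetical sink; the internal version would write an
                # event (API name, duration, etc.) to Scuba instead.
                logging.info(
                    "edge API %s finished in %.3fs",
                    api_name,
                    time.perf_counter() - start,
                )

        return wrapper

    return decorator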
@@ -957,6 +973,7 @@ def _gen_edge_manager_for_partitioners(
     return edge_manager
 
 
+@et_logger("to_edge_transform_and_lower")
 def to_edge_transform_and_lower(
     programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
     transform_passes: Optional[
@@ -1110,6 +1127,7 @@ def to_edge_with_preserved_ops(
     )
 
 
+@et_logger("to_edge")
 def to_edge(
     programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
     constant_methods: Optional[Dict[str, Any]] = None,
@@ -1204,8 +1222,10 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram:
         """
         Returns the ExportedProgram specified by 'method_name'.
         """
+
         return self._edge_programs[method_name]
 
+    @et_logger("transform")
     def transform(
         self,
         passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
@@ -1253,6 +1273,7 @@ def transform(
             new_programs, copy.deepcopy(self._config_methods), compile_config
         )
 
+    @et_logger("to_backend")
     def to_backend(
         self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
     ) -> "EdgeProgramManager":
@@ -1296,6 +1317,7 @@ def to_backend(
             new_edge_programs, copy.deepcopy(self._config_methods), config
         )
 
+    @et_logger("to_executorch")
     def to_executorch(
         self,
         config: Optional[ExecutorchBackendConfig] = None,
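With the decorators in place, logging is transparent to callers and the usual lowering pipeline is unchanged. A minimal usage sketch follows; the toy Mul module and example inputs are hypothetical, while to_edge and to_executorch are the APIs decorated in this diff:

import torch
from executorch.exir import to_edge


class Mul(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y


# Each decorated call behaves as a plain call in OSS builds; internally
# it would additionally emit a Scuba event named after the API.
exported = torch.export.export(Mul(), (torch.randn(3), torch.randn(3)))
edge_manager = to_edge(exported)  # @et_logger("to_edge")
executorch_manager = edge_manager.to_executorch()  # @et_logger("to_executorch")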

exir/tests/test_joint_graph.py

Lines changed: 8 additions & 18 deletions
@@ -73,25 +73,21 @@ def forward(self, x, y):
 
         # assert that the weight and bias have proper data_buffer_idx and allocation_info
         self.assertEqual(
-            et.executorch_program.execution_plan[0]  # pyre-ignore
-            .values[0]
-            .val.data_buffer_idx,
+            et.executorch_program.execution_plan[0].values[0].val.data_buffer_idx,
             1,
         )
         self.assertEqual(
-            et.executorch_program.execution_plan[0]  # pyre-ignore
-            .values[1]
-            .val.data_buffer_idx,
+            et.executorch_program.execution_plan[0].values[1].val.data_buffer_idx,
             2,
         )
         self.assertEqual(
-            et.executorch_program.execution_plan[0]  # pyre-ignore
+            et.executorch_program.execution_plan[0]
             .values[0]
             .val.allocation_info.memory_offset_low,
             0,
         )
         self.assertEqual(
-            et.executorch_program.execution_plan[0]  # pyre-ignore
+            et.executorch_program.execution_plan[0]
             .values[1]
             .val.allocation_info.memory_offset_low,
             48,
@@ -106,7 +102,7 @@ def forward(self, x, y):
 
         self.assertTrue(torch.allclose(loss, et_outputs[0]))
         self.assertTrue(
-            torch.allclose(m.linear.weight.grad, et_outputs[1])  # pyre-ignore[6]
+            torch.allclose(m.linear.weight.grad, et_outputs[1])  # pyre-ignore
         )
         self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2]))
         self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3]))
@@ -118,23 +114,17 @@ def forward(self, x, y):
 
         # gradient outputs start at index 1
         self.assertEqual(
-            et.executorch_program.execution_plan[1]  # pyre-ignore
-            .values[0]
-            .val.int_val,
+            et.executorch_program.execution_plan[1].values[0].val.int_val,
             1,
         )
 
         self.assertEqual(
-            et.executorch_program.execution_plan[2]  # pyre-ignore
-            .values[0]
-            .val.string_val,
+            et.executorch_program.execution_plan[2].values[0].val.string_val,
             "linear.weight",
         )
 
         # parameter outputs start at index 3
         self.assertEqual(
-            et.executorch_program.execution_plan[3]  # pyre-ignore
-            .values[0]
-            .val.int_val,
+            et.executorch_program.execution_plan[3].values[0].val.int_val,
             3,
         )

exir/tests/test_remove_view_copy.py

Lines changed: 7 additions & 17 deletions
@@ -196,24 +196,14 @@ def test_spec(self) -> None:
         instructions = plan.chains[0].instructions
         self.assertEqual(len(instructions), 7)
 
+        self.assertEqual(instructions[0].instr_args.op_index, 0)  # view @ idx2
+        self.assertEqual(instructions[1].instr_args.op_index, 0)  # view @ idx3
+        self.assertEqual(instructions[2].instr_args.op_index, 1)  # aten:mul @ idx6
+        self.assertEqual(instructions[3].instr_args.op_index, 0)  # view @ idx7
+        self.assertEqual(instructions[4].instr_args.op_index, 1)  # aten:mul @ idx9
         self.assertEqual(
-            instructions[0].instr_args.op_index, 0  # pyre-ignore
-        )  # view @ idx2
-        self.assertEqual(
-            instructions[1].instr_args.op_index, 0  # pyre-ignore
-        )  # view @ idx3
-        self.assertEqual(
-            instructions[2].instr_args.op_index, 1  # pyre-ignore
-        )  # aten:mul @ idx6
-        self.assertEqual(
-            instructions[3].instr_args.op_index, 0  # pyre-ignore
-        )  # view @ idx7
-        self.assertEqual(
-            instructions[4].instr_args.op_index, 1  # pyre-ignore
-        )  # aten:mul @ idx9
-        self.assertEqual(
-            instructions[5].instr_args.op_index, 2  # pyre-ignore
+            instructions[5].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
         self.assertEqual(
-            instructions[6].instr_args.op_index, 2  # pyre-ignore
+            instructions[6].instr_args.op_index, 2
         )  # aten:view_copy @ idx11
