Skip to content

Commit 0d5e887

Browse files
committed
[CoreML Backend] Backend fixes and propagate debug handles.
1 parent e283967 commit 0d5e887

22 files changed

+313
-166
lines changed

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 156 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
# CoreML backend for delegating a EdgeProgram to CoreML.
44

55
import json
6+
67
import shutil
78
import uuid
89
from dataclasses import asdict, dataclass
910
from enum import Enum
1011

1112
from pathlib import Path
1213

13-
from typing import Dict, final, List
14+
from typing import Any, Dict, final, List, Optional, Tuple
1415

1516
import coremltools as ct
1617
import executorchcoreml
@@ -30,6 +31,13 @@ class COMPILE_SPEC_KEYS(Enum):
3031
MODEL_COMPUTE_PRECISION = "model_compute_precision"
3132

3233

34+
class MODEL_PATHS(Enum):
35+
MODEL = "model.mlpackage"
36+
COMPILED_MODEL = "model.mlmodelc"
37+
METADATA = "metadata.json"
38+
DEBUG_INFO = "debug_info.json"
39+
40+
3341
@dataclass
3442
class ModelMetadata:
3543
# The model input names.
@@ -40,6 +48,16 @@ class ModelMetadata:
4048
identifier: str
4149

4250

51+
@dataclass
52+
class ModelDebugInfo:
53+
# Version info.
54+
versionInfo: Dict[str, str]
55+
# Mapping from debug symbol to operation path.
56+
debugSymbolToOperationPath: Dict[str, List[Dict[str, str]]]
57+
# Mapping from debug symbol to handle.
58+
debugSymbolToHandles: Dict[str, List[int]]
59+
60+
4361
@final
4462
class CoreMLBackend(BackendDetails):
4563
class MODEL_TYPE(Enum):
@@ -165,53 +183,163 @@ def generate_compile_specs(
165183
return compile_specs
166184

167185
@staticmethod
168-
def model_metadata_from_spec(model_spec: ct.proto.Model_pb2) -> Dict[str, str]:
186+
def model_metadata_from_spec(
187+
model_spec: ct.proto.Model_pb2, identifier: str
188+
) -> Dict[str, str]:
169189
input_names: List[str] = [input.name for input in model_spec.description.input]
170190
output_names = [output.name for output in model_spec.description.output]
171-
identifier = uuid.uuid4()
172191

173192
return ModelMetadata(
174-
inputNames=input_names, outputNames=output_names, identifier=str(identifier)
193+
inputNames=input_names, outputNames=output_names, identifier=identifier
194+
)
195+
196+
@staticmethod
197+
def get_debug_symbol(operation_path: List[Dict[str, str]]) -> Optional[str]:
198+
if len(operation_path) == 0:
199+
return None
200+
201+
operator_name: Optional[str] = operation_path[-1].get("Operator", None)
202+
output_name: Optional[str] = operation_path[-1].get("Output", None)
203+
if output_name is None or operator_name is None:
204+
return None
205+
206+
return output_name + ":" + operator_name
207+
208+
@staticmethod
209+
def get_model_debug_info(model_package_dir: Path) -> Optional[ModelDebugInfo]:
210+
delegate_info_file = model_package_dir / "executorch_debug_handle_mapping.json"
211+
212+
if not delegate_info_file.is_file():
213+
return None
214+
215+
delegate_info: Optional[Dict[str, Any]] = None
216+
217+
try:
218+
with open(delegate_info_file) as f:
219+
delegate_info = json.load(f)
220+
except ValueError:
221+
return None
222+
223+
if delegate_info is None:
224+
return None
225+
226+
debug_handle_to_operation_path_mapping: Optional[Dict[str, Any]] = (
227+
delegate_info.get("mapping", None)
228+
)
229+
230+
if debug_handle_to_operation_path_mapping is None:
231+
return None
232+
233+
debug_symbol_to_operation_path: Dict[str, List[Dict[str, str]]] = {}
234+
debug_symbol_to_handles: Dict[str, List[int]] = {}
235+
for (
236+
debug_handle,
237+
operation_paths,
238+
) in debug_handle_to_operation_path_mapping.items():
239+
debug_handle_value: Optional[int] = None
240+
try:
241+
debug_handle_value = int(debug_handle)
242+
except ValueError:
243+
debug_handle_value = None
244+
245+
if debug_handle_value is None:
246+
continue
247+
248+
for operation_path in operation_paths:
249+
debug_symbol: Optional[str] = CoreMLBackend.get_debug_symbol(
250+
operation_path=operation_path
251+
)
252+
253+
if debug_symbol is None:
254+
continue
255+
256+
debug_handle_values: List[int] = debug_symbol_to_handles.get(
257+
debug_symbol, []
258+
)
259+
debug_handle_values.append(debug_handle_value)
260+
debug_symbol_to_handles[debug_symbol] = debug_handle_values
261+
262+
debug_symbol_to_operation_path[debug_symbol] = operation_path
263+
264+
version_info: Dict[str, str] = delegate_info.get("version", {})
265+
266+
return ModelDebugInfo(
267+
versionInfo=version_info,
268+
debugSymbolToOperationPath=debug_symbol_to_operation_path,
269+
debugSymbolToHandles=debug_symbol_to_handles,
175270
)
176271

177272
@staticmethod
178-
def to_bytes(mlmodel: ct.models.MLModel, model_type: MODEL_TYPE) -> bytes:
179-
dir_path: Path = Path("tmp")
273+
def save_model_metadata(model_metadata: ModelMetadata, model_dir_path: Path):
274+
# Store model metadata.
275+
model_metadata_path = Path(model_dir_path) / MODEL_PATHS.METADATA.value
276+
model_metadata_json = json.dumps(asdict(model_metadata))
277+
with open(model_metadata_path, "w") as outfile:
278+
outfile.write(model_metadata_json)
279+
280+
@staticmethod
281+
def save_model_debug_info(model_debug_info: ModelDebugInfo, model_dir_path: Path):
282+
# Store model debug info.
283+
model_debug_info_path = Path(model_dir_path) / MODEL_PATHS.DEBUG_INFO.value
284+
model_debug_info_json = json.dumps(asdict(model_debug_info))
285+
with open(model_debug_info_path, "w") as outfile:
286+
outfile.write(model_debug_info_json)
287+
288+
@staticmethod
289+
def preprocess_model(
290+
mlmodel: ct.models.MLModel, model_type: MODEL_TYPE
291+
) -> PreprocessResult:
292+
identifier = str(uuid.uuid4())
293+
dir_path: Path = Path("tmp") / identifier
180294
model_dir_path: Path = dir_path / "lowered_module"
181295
model_spec: ct.proto.Model_pb2 = mlmodel.get_spec()
182296
model_metadata: ModelMetadata = CoreMLBackend.model_metadata_from_spec(
183-
model_spec
297+
model_spec=model_spec,
298+
identifier=identifier,
184299
)
185-
match model_type:
186-
case CoreMLBackend.MODEL_TYPE.MODEL:
187-
# Store model.
188-
model_path = model_dir_path / "model.mlpackage"
189-
mlmodel.save(model_path)
190300

301+
# Save model.
302+
model_path = model_dir_path / MODEL_PATHS.MODEL.value
303+
mlmodel.save(model_path)
304+
# Extract delegate mapping file.
305+
model_debug_info: Optional[ModelDebugInfo] = CoreMLBackend.get_model_debug_info(
306+
model_path
307+
)
308+
309+
match model_type:
191310
case CoreMLBackend.MODEL_TYPE.COMPILED_MODEL:
192-
# Store compiled model
193-
model_path = model_dir_path / "model.mlmodelc"
311+
shutil.rmtree(str(model_path.resolve()))
312+
model_path = model_dir_path / MODEL_PATHS.COMPILED_MODEL.value
194313
compiled_model_path = mlmodel.get_compiled_model_path()
195-
196-
shutil.copytree(
314+
shutil.move(
197315
compiled_model_path,
198316
str(model_path.resolve()),
199-
dirs_exist_ok=True,
200317
)
201318

202-
# Store model metadata.
203-
model_metadata_path = Path(model_dir_path) / "metadata.json"
204-
model_metadata_json = json.dumps(asdict(model_metadata))
205-
with open(model_metadata_path, "w") as outfile:
206-
outfile.write(model_metadata_json)
319+
case _:
320+
pass
207321

208-
# flatten directory contents and convert it to bytes
209-
flattened_bytes = executorchcoreml.flatten_directory_contents(
322+
CoreMLBackend.save_model_metadata(
323+
model_metadata=model_metadata, model_dir_path=model_dir_path
324+
)
325+
if model_debug_info is not None:
326+
CoreMLBackend.save_model_debug_info(
327+
model_debug_info=model_debug_info, model_dir_path=model_dir_path
328+
)
329+
330+
processed_bytes: bytes = executorchcoreml.flatten_directory_contents(
210331
str(model_dir_path.resolve())
211332
)
212333

213-
shutil.rmtree(str(model_dir_path.resolve()))
214-
return flattened_bytes
334+
debug_handle_map: Optional[Dict[str, Tuple[int]]] = None
335+
if model_debug_info is not None:
336+
debug_handle_map = model_debug_info.debugSymbolToHandles
337+
338+
shutil.rmtree(str(dir_path.resolve()))
339+
return PreprocessResult(
340+
processed_bytes=processed_bytes,
341+
debug_handle_map=debug_handle_map,
342+
)
215343

216344
@classmethod
217345
def preprocess(
@@ -235,25 +363,14 @@ def preprocess(
235363
CoreMLBackend.min_deployment_target_from_compile_specs(module_compile_specs)
236364
)
237365

238-
skip_model_load: bool = False
239-
match model_type:
240-
case CoreMLBackend.MODEL_TYPE.MODEL:
241-
skip_model_load = True
242-
243-
case CoreMLBackend.MODEL_TYPE.COMPILED_MODEL:
244-
skip_model_load = False
245-
246366
mlmodel = ct.convert(
247367
model=edge_program,
248368
source="pytorch",
249369
convert_to="mlprogram",
250370
pass_pipeline=ct.PassPipeline.DEFAULT,
251-
skip_model_load=skip_model_load,
371+
skip_model_load=False,
252372
compute_precision=model_compute_precision,
253373
minimum_deployment_target=minimum_deployment_target,
254374
)
255375

256-
processed_bytes = CoreMLBackend.to_bytes(mlmodel, model_type=model_type)
257-
return PreprocessResult(
258-
processed_bytes=processed_bytes,
259-
)
376+
return CoreMLBackend.preprocess_model(mlmodel, model_type=model_type)

backends/apple/coreml/runtime/delegate/ETCoreMLLogging.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ typedef NS_ERROR_ENUM(ETCoreMLErrorDomain, ETCoreMLError) {
6262
/// `*errorOut` with NSError.
6363
#define ETCoreMLLogUnderlyingErrorAndSetNSError(errorOut, errorCode, underlyingNSError, formatString, ...) \
6464
os_log_error(ETCoreMLErrorUtils.loggingChannel, \
65-
formatString " (Underlying error: %@)", \
65+
formatString ", with underlying error= %@.", \
6666
##__VA_ARGS__, \
6767
(underlyingNSError).localizedDescription); \
6868
if (errorOut) { \
@@ -71,10 +71,10 @@ typedef NS_ERROR_ENUM(ETCoreMLErrorDomain, ETCoreMLError) {
7171
format:@formatString, ##__VA_ARGS__]; \
7272
}
7373

74-
#define ETCoreMLLogError(error, formatString, ...) \
75-
os_log_error(ETCoreMLErrorUtils.loggingChannel, \
76-
formatString " (Underlying error: %@)", \
77-
##__VA_ARGS__, \
74+
#define ETCoreMLLogError(error, formatString, ...) \
75+
os_log_error(ETCoreMLErrorUtils.loggingChannel, \
76+
formatString ", with error= %@.", \
77+
##__VA_ARGS__, \
7878
(error).localizedDescription);
7979

8080

0 commit comments

Comments
 (0)