Skip to content

Commit 70e6a99

Browse files
author
yifan_shen3
committed
update doc with partitioner and quantizer
1 parent 36453fc commit 70e6a99

File tree

2 files changed

+85
-20
lines changed

2 files changed

+85
-20
lines changed

backends/apple/coreml/README.md

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ Core ML is an optimized framework for running machine learning models on Apple d
66

77
## Layout
88
- `compiler/` : Lowers a module to Core ML backend.
9+
- `partition/`: Partitions a module fully or partially to Core ML backend.
10+
- `quantizer/`: Quantizes a module in Core ML favored scheme
911
- `scripts/` : Scripts for installing dependencies and running tests.
1012
- `runtime/`: Core ML delegate runtime implementation.
1113
- `inmemoryfs`: InMemory filesystem implementation used to serialize/de-serialize AOT blob.
@@ -20,41 +22,104 @@ Core ML is an optimized framework for running machine learning models on Apple d
2022
If you have problems or questions or have suggestions for ways to make
2123
implementation and testing better, please create an issue on [github](https://www.github.com/pytorch/executorch/issues).
2224

23-
## Delegation
25+
## Partition and Delegation
2426

25-
For delegating the Program to the **Core ML** backend, the client must be responsible for calling `to_backend` with the **CoreMLBackend** tag.
27+
To delegate a Program to the **Core ML** backend, the client must call `to_backend` with the **CoreMLPartitioner**.
2628

2729
```python
28-
import executorch.exir as exir
2930
import torch
30-
31-
from torch.export import export
32-
33-
from executorch.exir import to_edge
34-
35-
from executorch.exir.backend.backend_api import to_backend
31+
import executorch.exir
3632

3733
from executorch.backends.apple.coreml.compiler import CoreMLBackend
34+
from executorch.backends.apple.coreml.partition.coreml_partitioner import CoreMLPartitioner
3835

39-
class LowerableSubModel(torch.nn.Module):
36+
class Model(torch.nn.Module):
4037
def __init__(self):
4138
super().__init__()
4239

4340
def forward(self, x):
4441
return torch.sin(x)
4542

46-
# Convert the lowerable module to Edge IR Representation
47-
to_be_lowered = LowerableSubModel()
48-
example_input = (torch.ones(1), )
49-
to_be_lowered_exir_submodule = to_edge(export(to_be_lowered, example_input))
43+
source_model = Model()
44+
example_inputs = (torch.ones(1), )
45+
46+
# Export the source model to Edge IR representation
47+
aten_program = torch.export.export(source_model, example_inputs)
48+
edge_program_manager = executorch.exir.to_edge(aten_program)
5049

51-
# Lower to Core ML backend
52-
lowered_module = to_backend('CoreMLBackend', to_be_lowered_exir_submodule.exported_program, [])
50+
# Delegate to Core ML backend
51+
delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
52+
53+
# Serialize delegated program
54+
executorch_program = delegated_program_manager.to_executorch()
55+
with open("model.pte", "wb") as f:
56+
f.write(executorch_program.buffer)
5357
```
5458

55-
Currently, the **Core ML** backend delegates the whole module to **Core ML**. If a specific op is not supported by the **Core ML** backend then the `to_backend` call would throw an exception. We will be adding a **Core ML Partitioner** to resolve the issue.
59+
The module will be fully or partially delegated to **Core ML**, depending on whether all or part of ops are supported by the **Core ML** backend. User may force skip certain ops by `CoreMLPartitioner(skip_ops_for_coreml_delegation=...)`
60+
61+
The `to_backend` implementation is a thin wrapper over [coremltools](https://apple.github.io/coremltools/docs-guides/), `coremltools` is responsible for converting an **ExportedProgram** to a **MLModel**. The converted **MLModel** data is saved, flattened, and returned as bytes to **ExecuTorch**.
62+
63+
## Quantization
64+
65+
To quantize a Program in a Core ML favored way, the client may utilize **CoreMLQuantizer**.
66+
67+
```python
68+
import torch
69+
import executorch.exir
70+
71+
from torch._export import capture_pre_autograd_graph
72+
from torch.ao.quantization.quantize_pt2e import (
73+
convert_pt2e,
74+
prepare_pt2e,
75+
prepare_qat_pt2e,
76+
)
77+
78+
from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer
79+
from coremltools.optimize.torch.quantization.quantization_config import (
80+
LinearQuantizerConfig,
81+
QuantizationScheme,
82+
)
83+
84+
class Model(torch.nn.Module):
85+
def __init__(self) -> None:
86+
super().__init__()
87+
self.conv = torch.nn.Conv2d(
88+
in_channels=3, out_channels=16, kernel_size=3, padding=1
89+
)
90+
self.relu = torch.nn.ReLU()
91+
92+
def forward(self, x: torch.Tensor) -> torch.Tensor:
93+
a = self.conv(x)
94+
return self.relu(a)
95+
96+
source_model = Model()
97+
example_inputs = (torch.randn((1, 3, 256, 256)), )
98+
99+
pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs)
100+
101+
quantization_config = LinearQuantizerConfig.from_dict(
102+
{
103+
"global_config": {
104+
"quantization_scheme": QuantizationScheme.symmetric,
105+
"milestones": [0, 0, 10, 10],
106+
"activation_dtype": torch.uint8,
107+
"weight_dtype": torch.int8,
108+
"weight_per_channel": True,
109+
}
110+
}
111+
)
112+
quantizer = CoreMLQuantizer(quantization_config)
113+
114+
# For post-training quantization, use `prepare_pt2e`
115+
# For quantization-aware trainin,g use `prepare_qat_pt2e`
116+
prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
117+
118+
prepared_graph(*example_inputs)
119+
converted_graph = convert_pt2e(prepared_graph)
120+
```
56121

57-
The `to_backend` implementation is a thin wrapper over `coremltools`, `coremltools` is responsible for converting an **ExportedProgram** to a **MLModel**. The converted **MLModel** data is saved, flattened, and returned as bytes to **ExecuTorch**.
122+
The `converted_graph` is the quantized torch model, and can be delegated to **Core ML** similarly through **CoreMLPartitioner**
58123

59124
## Runtime
60125

backends/apple/coreml/setup.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ python3 -m examples.apple.coreml.scripts.export --model_name add
2929
4. You can now integrate the **Core ML** backend in code.
3030

3131
```python
32-
# Lower to Core ML backend
33-
lowered_module = to_backend('CoreMLBackend', to_be_lowered_exir_submodule, [])
32+
# Delegate to Core ML backend
33+
delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
3434
```
3535

3636

0 commit comments

Comments
 (0)