Exploit the fact that the `unsqueeze` operation can be reduced to a `permute`:
```python
torch.all(torch.permute(x.unsqueeze(0), [1, 0, 2, 3]) == x.unsqueeze(1))
torch.all(torch.permute(x.unsqueeze(0), [1, 2, 0, 3]) == x.unsqueeze(2))
torch.all(torch.permute(x.unsqueeze(0), [1, 2, 3, 0]) == x.unsqueeze(3))
```
This diff introduces a minor change to the Permute implementation: it no longer requires the input rank to match the length of the permute array. This allows `unsqueeze(d)` to be implemented as a no-op `unsqueeze(0)` followed by a permute.
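For illustration, a minimal sketch of the reduction; the helper name `unsqueeze_via_permute` is hypothetical and not part of this diff:

```python
import torch

def unsqueeze_via_permute(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Insert the new axis at position 0 (a no-op on the data layout),
    # then move it to `dim` with a permute.
    rank = x.dim()
    order = list(range(1, dim + 1)) + [0] + list(range(dim + 1, rank + 1))
    return torch.permute(x.unsqueeze(0), order)

x = torch.randn(2, 3, 4)
for d in range(x.dim() + 1):
    assert torch.equal(unsqueeze_via_permute(x, d), x.unsqueeze(d))
```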
Differential Revision: [D56347734](https://our.internmc.facebook.com/intern/diff/D56347734/)
The module will be fully or partially delegated to **Core ML**, depending on whether all or only some of its ops are supported by the **Core ML** backend. The user may force certain ops to be skipped via `CoreMLPartitioner(skip_ops_for_coreml_delegation=...)`.
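A minimal sketch of partitioner-based lowering is shown below; the import path and the `to_edge` flow are assumptions based on the ExecuTorch source layout, and `Model` is a toy module:

```python
import torch
import executorch.exir

# Assumed import path for the partitioner.
from executorch.backends.apple.coreml.partition.coreml_partitioner import (
    CoreMLPartitioner,
)

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x) + 1.0

example_inputs = (torch.randn(4),)
exported_program = torch.export.export(Model().eval(), example_inputs)
edge_program_manager = executorch.exir.to_edge(exported_program)

# Ops supported by Core ML are delegated; everything else stays in the
# default ExecuTorch runtime. Ops can also be force-skipped, e.g.
# CoreMLPartitioner(skip_ops_for_coreml_delegation=["aten.sin.default"]).
delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
executorch_program = delegated_program_manager.to_executorch()
```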
The `to_backend` implementation is a thin wrapper over [coremltools](https://apple.github.io/coremltools/docs-guides/); `coremltools` is responsible for converting an **ExportedProgram** to a **MLModel**. The converted **MLModel** data is saved, flattened, and returned as bytes to **ExecuTorch**.
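The sketch below illustrates the underlying `coremltools` conversion step on an exported toy model; it is an illustration of what the wrapper delegates to, not the backend's actual code path:

```python
import coremltools as ct
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

example_inputs = (torch.randn(2, 3),)
exported_program = torch.export.export(Model().eval(), example_inputs)

# coremltools converts the ExportedProgram to an MLModel; the backend
# then serializes the MLModel data into the payload returned to ExecuTorch.
mlmodel = ct.convert(exported_program)
```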
## Quantization
To quantize a Program in a Core ML favored way, the client may utilize **CoreMLQuantizer**.
```python
import torch
import executorch.exir

from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)

from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer
from coremltools.optimize.torch.quantization.quantization_config import (
    LinearQuantizerConfig,
    QuantizationScheme,
)
```
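A minimal sketch of the prepare/convert flow, assuming a toy `Model` and illustrative `LinearQuantizerConfig` values (for quantization-aware training, `prepare_qat_pt2e` would replace `prepare_pt2e`):

```python
class Model(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

example_inputs = (torch.randn(1, 8),)

# Illustrative config; see the coremltools docs for the full option set.
quantization_config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": QuantizationScheme.symmetric,
            "activation_dtype": torch.quint8,
            "weight_dtype": torch.qint8,
            "weight_per_channel": True,
        }
    }
)
quantizer = CoreMLQuantizer(quantization_config)

# Capture the eager model, insert observers, calibrate, then convert.
pre_autograd_graph = capture_pre_autograd_graph(Model().eval(), example_inputs)
prepared_graph = prepare_pt2e(pre_autograd_graph, quantizer)
prepared_graph(*example_inputs)  # run sample data through for calibration
converted_graph = convert_pt2e(prepared_graph)
```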
The `converted_graph` is the quantized torch model, and it can be delegated to **Core ML** in the same way through **CoreMLPartitioner**.
## Runtime
To execute a Core ML delegated program, the application must link to the `coremldelegate` library. Once linked, no additional steps are required: when running the program, ExecuTorch calls the Core ML runtime to execute the Core ML delegated parts of the program.
Please follow the instructions described in the [Core ML setup](/backends/apple/coreml/setup.md) to link the `coremldelegate` library.
## Help & Improvements
If you have problems or questions, or have suggestions for ways to make implementation and testing better, please create an issue on [github](https://github.com/pytorch/executorch/issues).