### Summary
This PR makes some small updates to the docs in a few places. In the
Build from Source page, I've added an explicit command to generate a
test model (add.pte) to use with the executor_runner.
In the Exporting and Lowering page, I've added some details on advanced
export and lowering techniques, including state management, dynamic
control flow, and multi-method .ptes.
### Test plan
I've built and viewed the docs locally to validate the changes.
cc @mergennachin @byjlw
For more information, see [Runtime API Reference](executorch-runtime-api-reference.md).
## Advanced Topics
While many models will "just work" following the steps above, some more complex models may require additional work to export. These include models with state and models with complex control flow or auto-regressive generation.
See the [Llama model](https://github.com/pytorch/executorch/tree/main/examples/models/llama) for example use of these techniques.
### State Management
Some types of models maintain internal state, such as KV caches in transformers. There are two ways to manage state within ExecuTorch. The first is to bring the state out as model inputs and outputs, effectively making the core model stateless. This is sometimes referred to as managing the state as IO.
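The state-as-IO approach can be sketched as follows. This is an illustrative example, not code from ExecuTorch; the module and a rolling-window cache are hypothetical:

```python
import torch
import torch.nn as nn

class StatelessStep(nn.Module):
    """A single decode step with the cache managed as IO.

    The cache is an explicit input and output, so the module itself
    holds no state and exports as a pure function.
    """

    def forward(self, x: torch.Tensor, cache: torch.Tensor):
        # Shift the rolling cache and append the new entry; return the
        # updated cache alongside the output instead of mutating anything.
        new_cache = torch.cat([cache[1:], x.unsqueeze(0)], dim=0)
        out = new_cache.mean(dim=0)
        return out, new_cache

# The caller owns the state and threads it through each call:
step = StatelessStep()
cache = torch.zeros(4, 8)
out, cache = step(torch.ones(8), cache)
```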
The second approach is to leverage mutable buffers within the model directly. A mutable buffer can be registered using the PyTorch [register_buffer](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer) API on `nn.Module`. Storage for the buffer is managed by the framework, and any mutations to the buffer within the model are written back at the end of method execution.
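A minimal sketch of the mutable-buffer approach, using only standard `nn.Module` APIs (the module, its names, and the cache layout are hypothetical):

```python
import torch
import torch.nn as nn

class CounterCache(nn.Module):
    """Illustrative module with mutable KV-cache-style buffers."""

    def __init__(self, max_len: int = 8, dim: int = 4):
        super().__init__()
        # Registered buffers are tracked through export; mutations are
        # written back at the end of method execution.
        self.register_buffer("cache", torch.zeros(max_len, dim))
        self.register_buffer("pos", torch.zeros(1, dtype=torch.long))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mutate with in-place ops; do not reassign self.cache to a new
        # tensor. detach() avoids gradient-related export errors.
        self.cache[self.pos] = x.detach()
        self.pos.add_(1)
        return self.cache.sum(dim=0)
```

The buffer update persists across calls, so each invocation appends to the next cache slot.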
Mutable buffers have several limitations:

- Export of mutability can be fragile.
  - Consider explicitly calling `detach()` on tensors before assigning to a buffer if you encounter export-time errors related to gradients.
  - Ensure that any operations done on a mutable buffer are done with in-place operations (typically ending in `_`).
  - Do not reassign the buffer variable. Instead, use `copy_` to update the entire buffer content.
- Mutable buffers are not shared between multiple methods within a .pte.
- In-place operations are replaced with non-in-place variants, and the resulting tensor is written back at the end of method execution. This can be a performance bottleneck when using `index_put_`.
- Buffer mutations are not supported on all backends and may cause graph breaks and memory transfers back to CPU.

Support for mutation is experimental and may change in the future.

### Dynamic Control Flow
Control flow is considered dynamic if the path taken is not fixed at export-time. This is commonly the case when if or loop conditions depend on the value of a Tensor, such as a generator loop that terminates when an end-of-sequence token is generated. Shape-dependent control flow can also be dynamic if the tensor shape depends on the input.

To make dynamic if statements exportable, they can be written using [torch.cond](https://docs.pytorch.org/docs/stable/generated/torch.cond.html). Dynamic loops are not currently supported on ExecuTorch. The general approach to enable this type of model is to export the body of the loop as a method, and then handle loop logic from the application code. This is common for handling generator loops in auto-regressive models, such as transformer incremental decoding.
### Multi-method Models
ExecuTorch allows multiple methods to be bundled within a single .pte file. This can be useful for more complex model architectures, such as encoder-decoder models.
To include multiple methods in a .pte, each method must be exported individually with `torch.export.export`, yielding one `ExportedProgram` per method. These can be passed as a dictionary into `to_edge_transform_and_lower`:

```python
encode_ep = torch.export.export(...)
decode_ep = torch.export.export(...)
lowered = to_edge_transform_and_lower({
    "encode": encode_ep,
    "decode": decode_ep,
}).to_executorch()
```
At runtime, the method name can be passed to `load_method` and `execute` on the `Module` class.
Multi-method .ptes have several caveats:
- Methods are individually memory-planned. Activation memory is not currently re-used between methods. For advanced use cases, a [custom memory plan](compiler-memory-planning.md) or [custom memory allocators](https://docs.pytorch.org/executorch/stable/runtime-overview.html#operating-system-considerations) can be used to overlap the allocations.
- Mutable buffers are not shared between methods.
- PyTorch export does not currently allow for exporting methods on a module other than `forward`. To work around this, it is common to create wrapper `nn.Modules` for each method.
The PyTorch and ExecuTorch export and lowering APIs provide a high level of customizability to meet the needs of diverse hardware and models. See [torch.export](https://pytorch.org/docs/main/export.html) and [Export API Reference](export-to-executorch-api-reference.md) for more information.