# Export IR Specification

Export IR is an intermediate representation (IR) for the result of
`torch.export`. To read more on the details of Export IR, please read this
[document](https://pytorch.org/docs/main/export.ir_spec.html).

The Exported IR is a specification that consists of the following parts:

1. A definition of a computation graph model.
2. A set of operators allowed in the graph.

A **dialect** is an Exported IR graph composed of the operations defined
below, but with additional properties (such as restrictions on the operator
set or metadata) that are meant for a specific purpose.

The EXIR dialects that currently exist are:

* [ATen Dialect](./ir-exir-aten-dialect.md)
* [Edge Dialect](./ir-exir-edge-dialect.md)
* [Backend Dialect](./ir-exir-backend-dialect.md)

These dialects represent stages that a captured program goes through from
program capture to conversion into an executable format. For example, the
ExecuTorch compilation process starts by capturing a Python program into the
ATen Dialect, then converts the ATen Dialect to the Edge Dialect, Edge to
Backend, and finally to a binary format for execution.
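
As a rough sketch of this pipeline in code (`MyModule` and `example_inputs`
are assumed user-defined placeholders; target-specific backend lowering is
omitted):

```python
import torch
from executorch import exir

# MyModule and example_inputs are hypothetical, user-defined placeholders.
aten_program = torch.export.export(MyModule(), example_inputs)  # ATen Dialect
edge_program = exir.to_edge(aten_program)                       # Edge Dialect
# ... optional target-specific (Backend Dialect) passes would run here ...
executorch_program = edge_program.to_executorch()               # binary format
```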

## ATen Dialect

The ATen dialect is used as the entry point of the ExecuTorch compilation
pipeline; it is the first time an eager-mode PyTorch program becomes an
Exported IR graph. At this stage, functionalization is performed, so all
tensor aliases are replaced with copies. In addition, all tensors are
converted to contiguous format.

The goal of this dialect is to capture users' programs as faithfully as
possible (while remaining valid Exported IR). Registered custom operators that
the user has called in eager mode are preserved as-is in the ATen dialect.
However, we should refrain from adding custom ops to the graph via passes.

For now, the function of the ATen dialect is to be further lowered to the Edge
dialect. However, in the future it could serve as a common integration point
for other export use cases.

### ATen Dialect Properties

An ATen dialect graph is a valid Export IR graph with the following additional
properties:

1. All operators in `call_function` nodes are either ATen operators (in the
   `torch.ops.aten` namespace), higher-order operators (like control flow
   operators), or registered custom operators. A registered custom operator is
   an operator registered into the current PyTorch eager-mode runtime, usually
   with a `TORCH_LIBRARY` call (which implies a schema). Details on how to
   register a custom operator can be found
   [here](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.3rgxk3v387wl).
2. Every operator must also have a meta kernel. A meta kernel is a function
   that, given the shapes of the input tensors, can return the shape of the
   output tensor. Details on how to write a meta kernel can be found
   [here](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0).
   A sketch covering both registration and a meta kernel is shown after this
   list.
3. Input value types must be "Pytree-able". As a consequence, the output types
   are also Pytree-able, because all operator outputs are Pytree-able.
4. Ops of the ATen dialect can choose to work with dynamic dtypes, implicit
   type promotion, and implicit broadcasting of tensors.
5. All tensor memory formats are `torch.contiguous_format`.
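
The sketch below illustrates properties 1 and 2: registering a custom operator
into the eager-mode runtime together with a meta kernel. The `mylib::add_one`
operator is a hypothetical example, not an existing ExecuTorch op.

```python
import torch

# Hypothetical custom op "mylib::add_one", registered with a schema.
lib = torch.library.Library("mylib", "DEF")
lib.define("add_one(Tensor x) -> Tensor")

@torch.library.impl(lib, "add_one", "CPU")
def add_one_cpu(x):
    # Eager-mode CPU kernel.
    return x + 1

@torch.library.impl(lib, "add_one", "Meta")
def add_one_meta(x):
    # Meta kernel: computes only the output shape/dtype, touches no data.
    return torch.empty_like(x)
```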

### ATen Operator Definition

The operator set definition can be found [here](./ir-ops-set-definition.md).

## Edge Dialect

This dialect is meant to introduce specializations that are useful for Edge
devices but not necessarily for general (server) export. However, we still
refrain from specializing further for each particular hardware target. In
other words, we don't want to introduce any new hardware-dependent concepts
or data beyond those already present in users' original Python programs.

### Edge Dialect Properties

An Edge dialect graph is a valid Export IR graph with the following additional
properties:

1. All operators in `call_function` nodes are either from a predefined
   operator set, called **"Edge Operators"**, or registered custom operators.
   An Edge operator is an ATen operator with dtype specialization.
2. The inputs and outputs of the graph, as well as of every node, cannot be
   Scalar; that is, all scalar types (such as `float` and `int`) are converted
   to Tensor.

## Using the Edge Dialect

A GraphModule in the Edge dialect is represented in memory with the
`torch.fx.GraphModule` Python class. To obtain such a class, one starts with a
`torch.nn.Module`:

```python
import torch
from executorch import exir

class MyModule(torch.nn.Module):
    ...

a = MyModule()
tracing_inputs = (torch.rand(2, 2),)
aten_dialect_program = torch.export.export(a, tracing_inputs)
edge_dialect_program = exir.to_edge(aten_dialect_program)
```

At this point, user-defined graph transformations can be run through
`edge_dialect_program.transform(passes)`. Order matters. Note: if a custom
pass touches `node.target`, be aware that all of the `node.target`s at this
stage are "Edge ops" (more details below) and not torch ops like in the ATen
dialect. A tutorial on pass writing can be found
[here](./compiler-custom-compiler-passes.md). After all these passes are
executed, the graph is re-validated to make sure it is still a valid Edge
dialect graph.
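
As an illustration, here is a minimal sketch of such a pass, assuming the
`ExportPass` base class from `executorch.exir.pass_base` and a hypothetical
add-to-mul rewrite; note that it matches against Edge ops, per the caveat
above:

```python
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

class ReplaceAddWithMulPass(ExportPass):
    """Hypothetical pass: rewrites every edge add into an edge mul."""

    def call_operator(self, op, args, kwargs, meta):
        # At this stage op is an Edge op such as exir_ops.edge.aten.add.Tensor,
        # not torch.ops.aten.add.Tensor.
        if op == exir_ops.edge.aten.add.Tensor:
            op = exir_ops.edge.aten.mul.Tensor
        return super().call_operator(op, args, kwargs, meta)

# Usage sketch:
# edge_dialect_program = edge_dialect_program.transform([ReplaceAddWithMulPass()])
```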

### Edge Operators

As mentioned before, an Edge operator is an ATen core operator with type
specialization. This means an instance of the Edge operator contains a set of
dtype constraints that describe all the tensor dtypes supported by both the
ExecuTorch runtime and their ATen kernels. These dtype constraints are
expressed in a DSL defined in
[edge.yaml](https://github.com/pytorch/executorch/blob/main/exir/dialects/edge/edge.yaml).
Here's an example of the dtype constraints:

```
- func: sigmoid
  namespace: edge
  inherits: aten::sigmoid
  type_alias:
    T0: [Bool, Byte, Char, Int, Long, Short]
    T1: [Double, Float]
    T2: [Float]
  type_constraint:
  - self: T0
    __ret_0: T2
  - self: T1
    __ret_0: T1
```

This says that if the `self` tensor has one of the types `Bool, Byte, Char,
Int, Long, Short`, then the returned tensor is `Float`. If `self` is `Double`
or `Float`, the returned tensor has the same dtype as `self`.

After these dtype constraints are collected and documented in edge.yaml, EXIR
consumes the file and loads the constraints into EXIR Edge operators. This
makes it convenient for developers to look up the supported dtypes of any
argument in an Edge op's schema. For example, we can do the following:

```python
from executorch.exir.dialects._ops import ops as exir_ops  # import dialects ops
sigmoid = exir_ops.edge.aten.sigmoid.default
print(sigmoid._schema)
# aten::sigmoid(Tensor self) -> Tensor
self_arg = sigmoid._schema.arguments[0]
_return = sigmoid._schema.returns[0]

print(self_arg.allowed_types)
# {torch.float32, torch.int8, torch.float64, torch.int16, torch.int32, torch.int64, torch.uint8, torch.bool}

print(_return.allowed_types)
# {torch.float32, torch.float64}
```

These constraints are helpful for someone who wants to write a custom kernel
for this operator. Also, inside EXIR we offer a validator to check that the
graph still complies with these dtype constraints after custom
transformations.

### Op Set (WIP)

Check out
[edge.yaml](https://github.com/pytorch/executorch/blob/main/exir/dialects/edge/edge.yaml)
for the complete list of operators with specified dtype constraints. We are
gradually expanding this operator set, aiming to provide dtype constraints for
all core ATen ops.

## Backend Dialect

Backend dialect is the name we give to the `ExportedProgram` in Edge dialect
after optional **target-specific** passes. The difference between the backend
dialect and the Edge dialect is that the backend dialect is target-aware and
may contain operators or submodules that are only meaningful to the target
backend. Backend-specific operators are new components that may appear in a
backend dialect, compared with the Edge dialect; they are a set of operators
for the target backend.

Another property to notice is that the memory format of a tensor can be any
format (this is subject to change in the near future when we introduce dim
order to the backend dialect).

This dialect allows the introduction of operators that do not conform to the
schemas defined in the canonical ATen operator set and do not show up in any
of the dialects above (the ATen dialect and the Edge dialect). Consider using
backend operators if your use case satisfies one or more of the following
criteria:

1. Your backend provides a library that optimizes a certain operator that is
   equivalent to a subgraph. E.g., a `linear_relu` (equivalent to linear +
   relu) that can be executed faster on a certain backend.
2. There's a need to retrace the graph module after it is already lowered to a
   backend. When we retrace, backend operators can transform back to the
   original subgraph (in the ATen dialect), which a normal custom op cannot
   do.
3. Your backend-specific operator doesn't have a generic CPU kernel, only a
   kernel for a certain backend. Using a backend operator works around this
   issue by using the original subgraph as the default kernel, keeping the
   graph module runnable.

### Running Backend Passes

To lower edge ops to backend ops, a pass will perform pattern matching to
identify the edge ops of interest in the graph, and then replace them with
equivalent backend operators. There are two APIs to register such passes:

* `transform()`. An API on `ExportedProgram` that allows users to provide
  custom passes. Note that this is not guarded by any validator, so the
  soundness of the program is not guaranteed.
* [`ExecutorchBackendConfig.passes`](https://github.com/pytorch/executorch/blob/main/exir/capture/_config.py#L40).
  If added here, the pass will be part of the lowering process from the
  backend dialect to `ExecutorchProgram`.
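
For the second option, a minimal sketch (reusing the hypothetical
`ReplaceAddWithMulPass` from above) could look like this:

```python
from executorch.exir import ExecutorchBackendConfig

# Hypothetical pass list; these passes run during lowering to ExecutorchProgram.
executorch_program = edge_dialect_program.to_executorch(
    ExecutorchBackendConfig(passes=[ReplaceAddWithMulPass()])
)
```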

Example: one such pass is `QuantFusion`. This pass takes a "canonical
quantization pattern", that is, "dequant - some_op - quant", and fuses this
pattern into a single operator that is backend specific, namely
`quantized_decomposed::some_op`. You can find more details
[here](./quantization-custom-quantization.md). Another, simpler example is
[here](https://github.com/pytorch/executorch/blob/main/exir/passes/replace_edge_with_backend_pass.py#L20),
where we replace `sym_size` operators with ones that are understood by
ExecuTorch.

### Backend Dialect Operators

We provide a decorator `bind_pattern_to_op` to help users easily register
their backend operators into Export IR. This decorator takes:

* a `torch.Library` object, indicating which library or namespace this backend
  operator belongs to;
* a name or schema. If we already defined the schema of the backend operator
  in the `torch.Library` object, only a name is needed. Otherwise, we can
  register the schema by passing in a schema string.

This decorator should be added to the pattern we are trying to match (and then
lower to this backend op) in the Edge dialect. This way we are registering
this pattern as a `CompositeImplicitAutograd` kernel for the backend operator.

The operator can then be accessed/used from the passes. The
`CompositeImplicitAutograd` kernel ensures two things:

1. The user doesn't need to write a (CPU-)runnable kernel.
2. The `ExportedProgram` remains retraceable. Once retraced, the backend
   operator will be decomposed into the ATen ops used in the pattern.
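
Here is a minimal sketch of registering such a pattern. The `foo_namespace`
library, the `foo_op` schema, and the fused pattern are all hypothetical, and
the import path of `bind_pattern_to_op` is an assumption (check the ExecuTorch
source for its current location):

```python
import torch
from torch.library import Library

# Assumed import path for the decorator described above.
from executorch.exir.dialects.backend._ops import bind_pattern_to_op

lib = Library("foo_namespace", "DEF")  # hypothetical backend namespace

@bind_pattern_to_op(lib, "foo_op(Tensor a, Tensor b) -> Tensor")
def foo_op_pattern(a, b):
    # The subgraph that the backend op foo_namespace::foo_op replaces;
    # retracing decomposes the backend op back into this pattern.
    return torch.ops.aten.relu.default(torch.ops.aten.add.Tensor(a, b))
```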

Unlike the Edge dialect, where we have a well-defined op set, the backend
dialect is target-aware, so we allow users to use our API to register
target-aware ops, which are grouped by namespace. Here are some examples:
`executorch_prims` ops are used by the ExecuTorch runtime to perform
operations on `SymInt`s; `quantized_decomposed` ops fuse edge operators for
quantization purposes and are meaningful to targets that support quantization.

* `executorch_prims::add.int(SymInt a, SymInt b) -> SymInt`
  * pattern: builtin.add
  * backend: executor
* `executorch_prims::mul.int(SymInt a, SymInt b) -> SymInt`
  * pattern: builtin.mul
  * backend: executor
* `executorch_prims::sub.int(SymInt a, SymInt b) -> SymInt`
  * pattern: builtin.sub
  * backend: executor
* `executorch_prims::floordiv.int(SymInt a, SymInt b) -> SymInt`
  * pattern: builtin.floordiv
  * backend: executor
* `executorch_prims::gt.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.gt
  * backend: executor
* `executorch_prims::lt.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.lt
  * backend: executor
* `executorch_prims::ge.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.ge
  * backend: executor
* `executorch_prims::le.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.le
  * backend: executor
* `executorch_prims::eq.int(SymInt a, SymInt b) -> bool`
  * pattern: builtin.eq
  * backend: executor
* `quantized_decomposed::embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor`
  * pattern: [source](https://github.com/pytorch/executorch/blob/main/exir/passes/_quant_patterns_and_replacements.py)
  * backend: quantization
* `quantized_decomposed::add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc`
  * pattern: [source](https://github.com/pytorch/executorch/blob/main/exir/passes/_quant_patterns_and_replacements.py)
  * backend: quantization
* `quantized_decomposed::add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor`
  * pattern: [source](https://github.com/pytorch/executorch/blob/main/exir/passes/_quant_patterns_and_replacements.py)
  * backend: quantization
* `quantized_decomposed::add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc`
  * pattern: [source](https://github.com/pytorch/executorch/blob/main/exir/passes/_quant_patterns_and_replacements.py)
  * backend: quantization