|
| 1 | +import logging |
| 2 | + |
1 | 3 | import torch
|
2 | 4 |
|
3 | 5 | from executorch.exir.pass_base import ExportPass
|
@@ -95,3 +97,106 @@ def call(self, graph_module: torch.fx.GraphModule):
|
95 | 97 | graph_module.recompile()
|
96 | 98 |
|
97 | 99 | return PassResult(graph_module, graph_changed)
|
| 100 | + |
| 101 | + |
| 102 | +class ReplaceSDPAWithCustomSDPAPass(ExportPass): |
| 103 | + """ |
| 104 | + This pass replaces aten.scaled_dot_product_attention.default with llama.custom_sdpa.default. |
| 105 | + If assume_causal_mask is set to True, this pass will ignore any explicit masks and simply set |
| 106 | + is_causal to True in custoom_spda. |
| 107 | + """ |
| 108 | + |
| 109 | + def __init__(self, assume_causal_mask=False): |
| 110 | + super().__init__() |
| 111 | + self.assume_causal_mask = assume_causal_mask |
| 112 | + |
| 113 | + def call_operator(self, op, args, kwargs, meta): |
| 114 | + from executorch.extension.llm.custom_ops import custom_ops # noqa |
| 115 | + |
| 116 | + if op != torch.ops.aten.scaled_dot_product_attention.default: |
| 117 | + return super().call_operator(op, args, kwargs, meta) |
| 118 | + |
| 119 | + q, k, v, mask, dropout, is_causal, scale = self._extract_args(args, kwargs) |
| 120 | + |
| 121 | + qT = self._transpose(q, meta) |
| 122 | + kT = self._transpose(k, meta) |
| 123 | + vT = self._transpose(v, meta) |
| 124 | + |
| 125 | + if not ( |
| 126 | + q.node.meta["val"].dim() |
| 127 | + == k.node.meta["val"].dim() |
| 128 | + == v.node.meta["val"].dim() |
| 129 | + == 4 |
| 130 | + ): |
| 131 | + logging.info("ReplaceSDPAWithCustomSDPAPass only supports 4D QKV inputs.") |
| 132 | + return super().call_operator(op, args, kwargs, meta) |
| 133 | + |
| 134 | + if self.assume_causal_mask: |
| 135 | + # Ignore specified mask simply set the is_causal flag. |
| 136 | + mask = None |
| 137 | + is_causal = True |
| 138 | + |
| 139 | + if mask is not None: |
| 140 | + mask_fake_tensor = mask.node.meta["val"] |
| 141 | + if mask_fake_tensor.dim() > 2: |
| 142 | + if all(d == 1 for d in mask_fake_tensor.size()[:-2]): |
| 143 | + mask = super().call_operator( |
| 144 | + torch.ops.aten.squeeze.dims, |
| 145 | + (mask, tuple(i for i in range(mask_fake_tensor.dim() - 2))), |
| 146 | + {}, |
| 147 | + meta, |
| 148 | + ) |
| 149 | + else: |
| 150 | + logging.info( |
| 151 | + "ReplaceSDPAWithCustomSDPAPass only supports 2D attention mask." |
| 152 | + ) |
| 153 | + return super().call_operator(op, args, kwargs, meta) |
| 154 | + |
| 155 | + # TODO(kimishpatel): Remove once custom SDPA supports boolean mask. |
| 156 | + if mask_fake_tensor.dtype == torch.bool: |
| 157 | + mask = super().call_operator( |
| 158 | + torch.ops.aten.where.Scalar, |
| 159 | + (mask, 0.0, float("-inf")), |
| 160 | + {}, |
| 161 | + meta, |
| 162 | + ) |
| 163 | + |
| 164 | + custom_sdpa = super().call_operator( |
| 165 | + torch.ops.llama.custom_sdpa.default, |
| 166 | + (qT, kT, vT, 0, mask, dropout, is_causal, scale), |
| 167 | + {}, |
| 168 | + meta, |
| 169 | + ) |
| 170 | + return self._transpose(custom_sdpa, meta) |
| 171 | + |
| 172 | + def _extract_args(self, args, kwargs): |
| 173 | + q, k, v, *rest = args |
| 174 | + mask = None |
| 175 | + dropout = 0.0 |
| 176 | + is_causal = False |
| 177 | + scale = None |
| 178 | + if len(rest) > 0: |
| 179 | + mask = rest[0] |
| 180 | + if len(rest) > 1: |
| 181 | + dropout = rest[1] |
| 182 | + if len(rest) > 2: |
| 183 | + is_causal = rest[2] |
| 184 | + if "scale" in kwargs: |
| 185 | + scale = kwargs["scale"] |
| 186 | + |
| 187 | + return q, k, v, mask, dropout, is_causal, scale |
| 188 | + |
| 189 | + def _transpose(self, x, meta): |
| 190 | + transpose = super().call_operator( |
| 191 | + torch.ops.aten.transpose.int, |
| 192 | + (x, 1, 2), |
| 193 | + {}, |
| 194 | + meta, |
| 195 | + ) |
| 196 | + contiguous = super().call_operator( |
| 197 | + torch.ops.aten.contiguous.default, |
| 198 | + (transpose,), |
| 199 | + {}, |
| 200 | + meta, |
| 201 | + ) |
| 202 | + return contiguous |
0 commit comments