@@ -195,6 +195,136 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
     return module
 
 
+@torch.library.custom_op("coreml::sdpa", mutates_args=())
+def sdpa(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
+) -> torch.Tensor:
+    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
+    return torch.ops.aten.scaled_dot_product_attention.default(
+        q, k, v, attn_mask=attn_mask
+    )
+
+
+@torch.library.register_fake("coreml::sdpa")
+def _(
+    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
+) -> torch.Tensor:
+    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
+    expected_shape = list(q.shape)
+    expected_shape[-1] = v.shape[-1]
+    return q.new_empty(expected_shape)
+
+
+class SDPACoreML(torch.nn.Module):
+    """Similar to SDPASimple, but with coreml custom op to do SDPA calculation."""
+
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.dim = dim
+        self.head_dim = head_dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+        mask,
+    ):
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        attn_mask = mask[None, None, input_pos]
+
+        if self.n_rep > 1:
+            k = k.repeat_interleave(self.n_rep, dim=1)
+            v = v.repeat_interleave(self.n_rep, dim=1)
+
+        y = torch.ops.coreml.sdpa(q, k, v, attn_mask)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
+def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+            )
+        else:
+            replace_sdpa_with_coreml_sdpa(child)
+    return module
+
+
+class KVCacheCoreML(torch.nn.Module):
+    """
+    Rather than k_out[:, :, input_pos] = k_val, use torch.ops.aten.index_put_,
+    which can directly translate to CoreML iOS18.slice_update
+    """
+
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        dtype=torch.float32,
+    ):
+        super().__init__()
+        self.max_seq_length = max_seq_length
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+
+        self.max_batch_size = max_batch_size
+        self.n_heads = n_heads
+        self.head_dim = head_dim
+        self.register_buffer(
+            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+
+    def update(
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
+        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
+        return k_out, v_out
+
+
+def replace_kv_cache_with_coreml_kv_cache(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, KVCache):
+            setattr(
+                module,
+                name,
+                KVCacheCoreML(
+                    child.max_batch_size,
+                    child.max_seq_length,
+                    child.n_heads,
+                    child.head_dim,
+                    child.k_cache.dtype,
+                ),
+            )
+        else:
+            replace_kv_cache_with_coreml_kv_cache(child)
+    return module
+
+
 class KVCacheSimple(torch.nn.Module):
     def __init__(
         self,