Commit 10a8e24

kirklandsign and SS-JIA authored
[ET-VK] Enable custom rotary embedding module replacement (#6424)
Pull Request resolved: #6393

## Context

Implements a module replacement source transform so that the Llama model uses the rotary embedding custom op introduced in the previous diff.

ghstack-source-id: 249175727
@exported-using-ghexport

Differential Revision: [D64697590](https://our.internmc.facebook.com/intern/diff/D64697590/)

Co-authored-by: Stephen Jia <[email protected]>
1 parent 10f51b9 commit 10a8e24

File tree

5 files changed: +96 −1 lines changed

backends/vulkan/_passes/custom_ops_defs.py

Lines changed: 50 additions & 0 deletions
@@ -183,3 +183,53 @@ def linear_weight_int4_impl(
 )
 lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
 linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)
+
+
+######################
+## apply_rotary_emb ##
+######################
+
+
+# Note that this implementation is copied from executorch.examples.models.llama.rope
+# but it is copied here to avoid introducing a dependency on the llama code.
+def apply_rotary_emb_impl(
+    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+):
+    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+        ndim = x.ndim
+        freqs_cis_ndim = freqs_cis.ndim
+        if freqs_cis_ndim == 3:
+            # freqs_cis: (seq_len, n_heads, head_dim // 2)
+            assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
+            shape = [
+                d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
+                for i, d in enumerate(x.shape)
+            ]
+        else:
+            # freqs_cis: (seq_len, head_dim // 2)
+            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(shape)
+
+    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
+    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
+
+    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
+    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
+
+    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
+    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+name = "apply_rotary_emb"
+lib.define(
+    f"{name}(Tensor xq, Tensor xk, Tensor freqs_cos, Tensor freqs_sin) -> (Tensor, Tensor)"
+)
+lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
+apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
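For reference, a minimal smoke test of the registered op could look like the sketch below. It is illustrative only and not part of this commit: it assumes that importing custom_ops_defs registers the et_vk::apply_rotary_emb op, and that inputs follow the Llama convention of (bsz, seq_len, n_heads, head_dim) queries/keys with (seq_len, head_dim // 2) frequency tensors.

# Illustrative smoke test (not part of this diff): exercises the composite
# CPU implementation of the custom op registered above.
import torch

from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
    apply_rotary_emb_op,
)

bsz, seq_len, n_heads, head_dim = 1, 8, 4, 64
xq = torch.randn(bsz, seq_len, n_heads, head_dim)
xk = torch.randn(bsz, seq_len, n_heads, head_dim)
freqs_cos = torch.randn(seq_len, head_dim // 2)
freqs_sin = torch.randn(seq_len, head_dim // 2)

xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
assert xq_out.shape == xq.shape and xk_out.shape == xk.shape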

backends/vulkan/partitioner/supported_ops.py

Lines changed: 2 additions & 1 deletion
@@ -94,8 +94,9 @@ def __contains__(self, op):
     # Convolution
     exir_ops.edge.aten.convolution.default,
     exir_ops.edge.et_vk.conv_with_clamp.default,
-    # Custom ops
+    # Llama ops
     "llama::sdpa_with_kv_cache",
+    exir_ops.edge.et_vk.apply_rotary_emb.default,
 ]

 NO_DYNAMIC_SHAPE = [

examples/models/llama/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ runtime.python_library(
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
         "source_transformation/spin_quant.py",
+        "source_transformation/vulkan_rope.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama",

examples/models/llama/export_llama_lib.py

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,7 @@
     replace_sdpa_with_flex_sdpa,
     replace_sdpa_with_simple_sdpa,
 )
+from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb

 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -943,4 +944,7 @@ def _get_source_transforms(  # noqa
         transforms.append(replace_sdpa_with_simple_sdpa)
         transforms.append(replace_kv_cache_with_coreml_kv_cache)

+    if args.vulkan:
+        transforms.append(replace_with_vulkan_rotary_emb)
+
     return transforms
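The transforms collected here are applied in sequence to the eager model before export. A rough sketch of that usage, for illustration only; the actual call site in export_llama_lib.py is not part of this diff, and model, modelname, dtype_override, and args stand in for names from the real export flow.

# Hypothetical sketch of how the transform list is consumed; the variable names
# below are placeholders and are not taken from this diff.
transforms = _get_source_transforms(modelname, dtype_override, args)
for transform in transforms:
    model = transform(model)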
examples/models/llama/source_transformation/vulkan_rope.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from executorch.backends.vulkan._passes.custom_ops_defs import (  # noqa
+    apply_rotary_emb_op,
+)
+
+from executorch.examples.models.llama.rope import RotaryEmbedding
+
+
+class VkRotaryEmbedding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
+        return xq_out, xk_out
+
+
+def replace_with_vulkan_rotary_emb(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, RotaryEmbedding):
+            new_module = VkRotaryEmbedding()
+            setattr(module, name, new_module)
+        else:
+            replace_with_vulkan_rotary_emb(child)
+
+    return module
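After the transform runs, every eager RotaryEmbedding in the module tree should have been swapped for VkRotaryEmbedding, so rotary embedding lowers to the et_vk.apply_rotary_emb op in the exported graph. A hypothetical check along those lines, where model is assumed to be an eager Llama model instance and is not defined in this diff:

# Hypothetical post-transform check; `model` is a stand-in for the eager Llama
# module and is not part of this commit.
model = replace_with_vulkan_rotary_emb(model)
assert not any(isinstance(m, RotaryEmbedding) for m in model.modules())
assert any(isinstance(m, VkRotaryEmbedding) for m in model.modules())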
