Skip to content

Commit 3b4ec5e

Browse files
committed
Add support for ARWKV7 Hybrid models
Signed-off-by: Molly Sophia <[email protected]>
1 parent b4e6cf4 commit 3b4ec5e

File tree

7 files changed

+420
-72
lines changed

7 files changed

+420
-72
lines changed

convert_hf_to_gguf.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3555,6 +3555,84 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
35553555
yield (new_name, data_torch)
35563556

35573557

3558+
@Model.register("RwkvHybridForCausalLM")
class ARwkv7Model(Model):
    """Converter for ARWKV7 hybrid checkpoints: RWKV-7 time mixing combined
    with a transformer-style (gated) feed-forward block."""

    model_arch = gguf.MODEL_ARCH.ARWKV7

    # Per-layer staging area for the token-shift (lerp) tensors; the six
    # components are fused into a single stacked tensor once all have arrived.
    # Re-initialized per instance in __init__ so state never leaks between
    # converter instances (a bare class-level dict would be shared).
    lerp_weights: dict[int, dict[str, Tensor]] = {}

    # Component order of the fused lerp tensor; must match the layout
    # llama.cpp expects for blk.{bid}.time_mix_lerp_fused.weight.
    _LERP_SUFFIXES = ("r", "w", "k", "v", "a", "g")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Give every instance its own staging dict (see class comment).
        self.lerp_weights = {}

    def set_vocab(self):
        """Prefer the SentencePiece vocab; fall back to GPT-2 BPE when the
        tokenizer model file is absent."""
        try:
            self._set_vocab_sentencepiece()
        except FileNotFoundError:
            self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        """Write ARWKV7 hyperparameters to the GGUF header."""
        block_count = self.hparams["num_hidden_layers"]
        hidden_size = self.hparams["hidden_size"]
        head_size = self.hparams["head_size"]
        rms_norm_eps = self.hparams["rms_norm_eps"]
        intermediate_size = self.hparams["intermediate_size"]
        wkv_has_gate = self.hparams["wkv_has_gate"]
        assert self.hparams["wkv_version"] == 7

        # ICLR: In-Context-Learning-Rate
        # LoRA ranks are fixed by the reference ARWKV7 implementation and are
        # not stored in the HF config, so they are hard-coded here.
        lora_rank_decay = 64
        lora_rank_iclr = 64
        lora_rank_value_residual_mix = 32
        lora_rank_gate = 128 if wkv_has_gate else 0

        # RWKV isn't context limited
        self.gguf_writer.add_context_length(1048576)
        self.gguf_writer.add_embedding_length(hidden_size)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
        self.gguf_writer.add_wkv_head_size(head_size)
        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
        self.gguf_writer.add_feed_forward_length(intermediate_size)
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_token_shift_count(1)

        # required by llama.cpp, unused
        self.gguf_writer.add_head_count(0)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one HF tensor to its GGUF name/layout.

        Token-shift (x_*) tensors are buffered per layer and emitted as one
        fused tensor; everything else is renamed and reshaped as needed.
        """
        if bid is not None and "self_attn.time_mixer.x_" in name:
            # Buffer this component; emit nothing until all six components
            # for this layer have been seen.
            self.lerp_weights.setdefault(bid, {})[name] = data_torch
            prefix = f"model.layers.{bid}.self_attn.time_mixer.x_"
            if all(prefix + s in self.lerp_weights[bid] for s in self._LERP_SUFFIXES):
                fused = torch.stack(
                    [self.lerp_weights[bid][prefix + s].squeeze(0) for s in self._LERP_SUFFIXES],
                    dim=0,
                )
                yield (f"blk.{bid}.time_mix_lerp_fused.weight", fused)
            return

        data_torch = data_torch.squeeze()
        new_name = self.map_tensor_name(name)

        if not new_name.endswith((".weight", ".bias")):
            new_name += ".weight"

        # The low-rank (LoRA) projection pairs are stored transposed in the
        # HF checkpoint relative to what llama.cpp expects.
        if new_name.endswith((
            "time_mix_w1.weight", "time_mix_w2.weight",
            "time_mix_a1.weight", "time_mix_a2.weight",
            "time_mix_v1.weight", "time_mix_v2.weight",
            "time_mix_g1.weight", "time_mix_g2.weight",
        )):
            data_torch = data_torch.transpose(0, 1)

        # r_k is stored flattened in GGUF.
        if 'r_k' in new_name:
            data_torch = data_torch.flatten()

        yield (new_name, data_torch)
3634+
3635+
35583636
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
35593637
class MambaModel(Model):
35603638
model_arch = gguf.MODEL_ARCH.MAMBA

gguf-py/gguf/constants.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ class MODEL_ARCH(IntEnum):
261261
RWKV6 = auto()
262262
RWKV6QWEN2 = auto()
263263
RWKV7 = auto()
264+
ARWKV7 = auto()
264265
MAMBA = auto()
265266
XVERSE = auto()
266267
COMMAND_R = auto()
@@ -461,6 +462,7 @@ class MODEL_TENSOR(IntEnum):
461462
MODEL_ARCH.RWKV6: "rwkv6",
462463
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
463464
MODEL_ARCH.RWKV7: "rwkv7",
465+
MODEL_ARCH.ARWKV7: "arwkv7",
464466
MODEL_ARCH.MAMBA: "mamba",
465467
MODEL_ARCH.XVERSE: "xverse",
466468
MODEL_ARCH.COMMAND_R: "command-r",
@@ -1214,6 +1216,37 @@ class MODEL_TENSOR(IntEnum):
12141216
MODEL_TENSOR.CHANNEL_MIX_KEY,
12151217
MODEL_TENSOR.CHANNEL_MIX_VALUE,
12161218
],
1219+
MODEL_ARCH.ARWKV7: [
1220+
MODEL_TENSOR.TOKEN_EMBD,
1221+
MODEL_TENSOR.TOKEN_EMBD_NORM,
1222+
MODEL_TENSOR.OUTPUT_NORM,
1223+
MODEL_TENSOR.OUTPUT,
1224+
MODEL_TENSOR.ATTN_NORM,
1225+
MODEL_TENSOR.TIME_MIX_LERP_FUSED,
1226+
MODEL_TENSOR.TIME_MIX_W0,
1227+
MODEL_TENSOR.TIME_MIX_W1,
1228+
MODEL_TENSOR.TIME_MIX_W2,
1229+
MODEL_TENSOR.TIME_MIX_A0,
1230+
MODEL_TENSOR.TIME_MIX_A1,
1231+
MODEL_TENSOR.TIME_MIX_A2,
1232+
MODEL_TENSOR.TIME_MIX_V0,
1233+
MODEL_TENSOR.TIME_MIX_V1,
1234+
MODEL_TENSOR.TIME_MIX_V2,
1235+
MODEL_TENSOR.TIME_MIX_G1,
1236+
MODEL_TENSOR.TIME_MIX_G2,
1237+
MODEL_TENSOR.TIME_MIX_K_K,
1238+
MODEL_TENSOR.TIME_MIX_K_A,
1239+
MODEL_TENSOR.TIME_MIX_R_K,
1240+
MODEL_TENSOR.TIME_MIX_KEY,
1241+
MODEL_TENSOR.TIME_MIX_VALUE,
1242+
MODEL_TENSOR.TIME_MIX_RECEPTANCE,
1243+
MODEL_TENSOR.TIME_MIX_LN,
1244+
MODEL_TENSOR.TIME_MIX_OUTPUT,
1245+
MODEL_TENSOR.FFN_NORM,
1246+
MODEL_TENSOR.FFN_GATE,
1247+
MODEL_TENSOR.FFN_DOWN,
1248+
MODEL_TENSOR.FFN_UP,
1249+
],
12171250
MODEL_ARCH.MAMBA: [
12181251
MODEL_TENSOR.TOKEN_EMBD,
12191252
MODEL_TENSOR.OUTPUT_NORM,

0 commit comments

Comments
 (0)