Commit 75c7366

Support DeciLM (Nemotron)
1 parent 29060ad commit 75c7366

2 files changed, +182 -0 lines changed


exllamav3/models/architectures.py

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .qwen2 import Qwen2Config, Qwen2Model
 from .phi3 import Phi3Config, Phi3Model
 from .gemma import Gemma2Config, Gemma2Model
+from .decilm import DeciLMConfig, DeciLMModel

 ARCHITECTURES = {
     "LlamaForCausalLM": {
@@ -31,6 +32,11 @@
         "config_class": Gemma2Config,
         "model_class": Gemma2Model,
     },
+    "DeciLMForCausalLM": {
+        "architecture": "DeciLMForCausalLM",
+        "config_class": DeciLMConfig,
+        "model_class": DeciLMModel,
+    },
 }

 def get_architectures():
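
Note (not part of the commit): the registry entry above is what routes a DeciLM/Nemotron checkpoint to the new classes. A minimal sketch of that lookup, assuming get_architectures() returns the ARCHITECTURES dict and that the loader dispatches on the architecture string found in the checkpoint's config.json:

# Sketch only; get_architectures() appears in the diff, but its return value
# (the ARCHITECTURES dict) is an assumption, and the load path below is hypothetical.
from exllamav3.models.architectures import get_architectures

entry = get_architectures()["DeciLMForCausalLM"]   # entry added by this commit
config_class = entry["config_class"]               # DeciLMConfig
model_class = entry["model_class"]                 # DeciLMModel

# config = config_class("/path/to/nemotron-checkpoint")
# model = model_class(config)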

exllamav3/models/decilm.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
from __future__ import annotations
from typing_extensions import override
import torch
from .config import Config, no_default
from .model import Model
from ..util.rope import RopeSettings, RopeStyle
from ..modules import RMSNorm, Embedding, TransformerBlock, Attention, GatedMLP, Linear
from ..modules.attn import prepare_for_attn

class DeciLMConfig(Config):

    def __init__(
        self,
        directory: str,
        **kwargs,
    ):
        super().__init__(
            directory,
            kwargs.get("arch_string", "DeciLMForCausalLM"),
            DeciLMModel,
            **kwargs
        )

        # Global attention params
        self.head_dim = self.read_cfg(int, "head_dim", None)
        self.hidden_size = self.read_cfg(int, "hidden_size", no_default)
        self.num_q_heads = self.read_cfg(int, "num_attention_heads", no_default)

        if not self.head_dim:
            self.head_dim = self.hidden_size // self.num_q_heads

        # MLP params
        self.assert_cfg(str, "hidden_act", "silu", True)

        # Norms
        self.rms_norm_eps = self.read_cfg(float, "rms_norm_eps", no_default)

        # Layers
        self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
        self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)

        # RoPE
        self.rope_settings = self.read_rope_settings_default(RopeStyle.NEOX)

        # Block configs
        self.block_configs = self.read_cfg(list, "block_configs", no_default)
        assert len(self.block_configs) == self.num_hidden_layers, \
            "Number of hidden layers does not match length of block_configs list"


class DeciLMModel(Model):

    def __init__(
        self,
        config: DeciLMConfig,
        **kwargs
    ):
        super().__init__(config, **kwargs)

        self.modules += [
            Embedding(
                config = config,
                key = "model.embed_tokens",
                vocab_size = config.vocab_size,
                hidden_size = config.hidden_size,
            )
        ]

        self.first_block_idx = len(self.modules)
        self.last_kv_module_idx = 0
        cache_layer_idx = 0

        for idx, cfg in enumerate(config.block_configs):
            cfg_attn = cfg["attention"]
            cfg_ffn = cfg["ffn"]

            if cfg_attn.get("no_op"):
                attn_norm = None
                attn = None
            else:
                assert not cfg_attn.get("num_sink_tokens"), "DeciLM: num_sink_tokens not supported"
                assert not cfg_attn.get("replace_with_linear"), "DeciLM: replace_with_linear not supported"
                assert not cfg_attn.get("sparsify"), "DeciLM: sparsify not supported"
                assert not cfg_attn.get("unshifted_sink"), "DeciLM: unshifted_sink not supported"
                assert not cfg_attn.get("use_prefill_window_in_sink_attention"), \
                    "DeciLM: use_prefill_window_in_sink_attention not supported"
                attn_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.input_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                )
                attn = Attention(
                    config = config,
                    key = f"model.layers.{idx}.self_attn",
                    layer_idx = cache_layer_idx,
                    hidden_size = config.hidden_size,
                    head_dim = config.head_dim,
                    num_q_heads = config.num_q_heads,
                    num_kv_heads = cfg_attn["n_heads_in_group"],
                    rope_settings = config.rope_settings,
                    sm_scale = None,
                    key_q = "q_proj",
                    key_k = "k_proj",
                    key_v = "v_proj",
                    key_o = "o_proj",
                    qmap = "block.attn",
                )
                cache_layer_idx += 1
                self.last_kv_module_idx = len(self.modules)

            if cfg_ffn.get("no_op"):
                mlp_norm = None
                mlp = None
            else:
                assert not cfg_ffn.get("replace_with_linear"), "DeciLM: replace_with_linear not supported"
                assert not cfg_ffn.get("sparsify"), "DeciLM: sparsify not supported"
                mlp_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.post_attention_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                )
                interm_size = int(2 * cfg_ffn["ffn_mult"] * config.hidden_size / 3)
                interm_size = ((interm_size + 255) // 256) * 256
                mlp = GatedMLP(
                    config = config,
                    key = f"model.layers.{idx}.mlp",
                    hidden_size = config.hidden_size,
                    intermediate_size = interm_size,
                    key_up = "up_proj",
                    key_gate = "gate_proj",
                    key_down = "down_proj",
                    qmap = "block.mlp",
                )

            self.modules += [
                TransformerBlock(
                    config = config,
                    key = f"model.layers.{idx}",
                    attn_norm = attn_norm,
                    attn = attn,
                    mlp_norm = mlp_norm,
                    mlp = mlp,
                )
            ]

        head_alt_key = None
        if config.tie_word_embeddings and not self.config.stc.has_tensor("lm_head"):
            head_alt_key = "model.embed_tokens"

        self.modules += [
            RMSNorm(
                config = config,
                key = "model.norm",
                rms_norm_eps = config.rms_norm_eps,
                out_dtype = torch.half,
            ),
            Linear(
                config = config,
                key = "lm_head",
                qbits_key = "head_bits",
                alt_key = head_alt_key,
                in_features = config.hidden_size,
                out_features = config.vocab_size,
                qmap = "block",
                caps = {"logits_output": True}
            )
        ]

        self.logit_layer_idx = len(self.modules) - 1


    @override
    def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
        params["input_ids"] = input_ids
        input_ids = prepare_for_attn(input_ids, params)
        return input_ids
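
Note (not part of the commit): DeciLMModel only reads a few keys from each block_configs entry. A hypothetical entry shape, reconstructed from the accesses above (attention.no_op, attention.n_heads_in_group, ffn.no_op, ffn.ffn_mult); real Nemotron checkpoints carry additional per-block fields, which the asserts above require to be absent or falsy:

# Illustrative only; values are invented, not taken from a real Nemotron config.json.
block_configs = [
    # Regular block: attention with grouped KV heads plus a gated MLP sized by ffn_mult.
    {"attention": {"no_op": False, "n_heads_in_group": 8},
     "ffn": {"no_op": False, "ffn_mult": 1.3}},
    # Attention-pruned block: the TransformerBlock is built with attn=None, MLP only.
    {"attention": {"no_op": True},
     "ffn": {"no_op": False, "ffn_mult": 1.3}},
    # Fully pruned block: both attn=None and mlp=None, effectively a pass-through.
    {"attention": {"no_op": True}, "ffn": {"no_op": True}},
]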

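Worked example of the intermediate-size computation in the MLP branch: the loader scales hidden_size by 2 * ffn_mult / 3 and rounds up to the next multiple of 256. The hidden_size and ffn_mult values below are made up for illustration:

# Reproduces the two interm_size lines from the commit with hypothetical inputs.
hidden_size = 8192
ffn_mult = 1.3

interm_size = int(2 * ffn_mult * hidden_size / 3)   # int(21299.2 / 3) = 7099
interm_size = ((interm_size + 255) // 256) * 256    # round up to 7168 (= 28 * 256)

print(interm_size)  # 7168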