Commit 7c9b7cb

feat(candle): add FlashMistral (#308)
Parent: 5c6151c

39 files changed: +17578, -176 lines

Cargo.lock

Lines changed: 1 addition & 0 deletions
(Generated file; diff not rendered.)

backends/candle/src/flash_attn.rs

Lines changed: 13 additions & 5 deletions
@@ -31,13 +31,17 @@ pub(crate) fn flash_attn_varlen(
     max_seqlen_k: usize,
     softmax_scale: f32,
     causal: bool,
+    window_size_left: Option<usize>,
 ) -> Result<Tensor, candle::Error> {
     let runtime_compute_cap = get_runtime_compute_cap();
 
     if runtime_compute_cap == 75 {
         if alibi_slopes.is_some() {
             candle::bail!("Flash attention v1 does not support alibi");
         }
+        if window_size_left.is_some() {
+            candle::bail!("Flash attention v1 does not support attention windowing");
+        }
 
         #[cfg(feature = "flash-attn-v1")]
         {
@@ -59,10 +63,12 @@ pub(crate) fn flash_attn_varlen(
     } else if (80..90).contains(&runtime_compute_cap) || runtime_compute_cap == 90 {
         #[cfg(feature = "flash-attn")]
        {
-            use candle_flash_attn::{flash_attn_varlen, flash_attn_varlen_alibi};
+            use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed};
+
+            let window_size_right = if causal { Some(0) } else { None };
 
             let attention = if let Some(alibi_slopes) = alibi_slopes {
-                flash_attn_varlen_alibi(
+                flash_attn_varlen_alibi_windowed(
                     q,
                     k,
                     v,
@@ -72,10 +78,11 @@ pub(crate) fn flash_attn_varlen(
                     max_seqlen_q,
                     max_seqlen_k,
                     softmax_scale,
-                    causal,
+                    window_size_left,
+                    window_size_right,
                 )
             } else {
-                flash_attn_varlen(
+                flash_attn_varlen_windowed(
                     q,
                     k,
                     v,
@@ -84,7 +91,8 @@ pub(crate) fn flash_attn_varlen(
                     max_seqlen_q,
                     max_seqlen_k,
                     softmax_scale,
-                    causal,
+                    window_size_left,
+                    window_size_right,
                 )
             };
 

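Note: the new `window_size_left` argument is what enables sliding-window attention on the flash-attn v2 path; with `causal` set, `window_size_right` is pinned to 0, so a query token can only attend to keys at most `window_size_left` positions behind it. The standalone Rust sketch below illustrates only the mask semantics implied by these two parameters, not the CUDA kernels themselves.

// Illustration of the windowed-causal mask semantics; not code from the diff.
// `None` for the left window means unbounded history, i.e. plain causal attention.
fn attends(q_pos: usize, k_pos: usize, window_size_left: Option<usize>) -> bool {
    let causal_ok = k_pos <= q_pos; // window_size_right == 0
    let window_ok = match window_size_left {
        Some(w) => q_pos <= k_pos + w, // i.e. k_pos >= q_pos - w, without underflow
        None => true,
    };
    causal_ok && window_ok
}

fn main() {
    // With a window of 2, position 5 can see keys 3..=5 but not 0..=2.
    assert!(attends(5, 4, Some(2)));
    assert!(!attends(5, 2, Some(2)));
    // Without a window this is ordinary causal attention.
    assert!(attends(5, 0, None));
}
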
backends/candle/src/layers/linear.rs

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ use serde::Deserialize;
 pub enum HiddenAct {
     Gelu,
     Relu,
+    #[serde(alias = "silu")]
     Swiglu,
 }
 

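Note: the `silu` alias lets configs that declare `"hidden_act": "silu"` (as Mistral configs do) deserialize into the existing `Swiglu` variant. A minimal standalone sketch of the serde behavior (this mirror enum and the test string are for illustration only, not taken from the repository):

use serde::Deserialize;

// Standalone mirror of the variant names in linear.rs, for illustration only.
#[derive(Debug, Deserialize, PartialEq)]
pub enum HiddenAct {
    Gelu,
    Relu,
    #[serde(alias = "silu")]
    Swiglu,
}

fn main() -> Result<(), serde_json::Error> {
    // "silu" is accepted as an alternative name for the Swiglu variant.
    let act: HiddenAct = serde_json::from_str(r#""silu""#)?;
    assert_eq!(act, HiddenAct::Swiglu);
    Ok(())
}
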
backends/candle/src/layers/mod.rs

Lines changed: 4 additions & 0 deletions
@@ -2,7 +2,11 @@
 mod cublaslt;
 mod layer_norm;
 mod linear;
+#[allow(dead_code, unused)]
+mod rms_norm;
 
 pub use cublaslt::get_cublas_lt_wrapper;
 pub use layer_norm::LayerNorm;
 pub use linear::{HiddenAct, Linear};
+#[allow(unused_imports)]
+pub use rms_norm::RMSNorm;

backends/candle/src/layers/rms_norm.rs

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct RMSNorm {
+    weight: Tensor,
+    epsilon: f32,
+    span: tracing::Span,
+}
+
+impl RMSNorm {
+    pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
+        Ok(Self {
+            weight: vb
+                .get(hidden_size, "weight")
+                .or_else(|_| vb.get(hidden_size, "gamma"))?,
+            epsilon,
+            span: tracing::span!(tracing::Level::TRACE, "rms-norm"),
+        })
+    }
+
+    pub fn forward(
+        &self,
+        hidden_states: &Tensor,
+        residual: Option<&Tensor>,
+    ) -> Result<(Tensor, Tensor)> {
+        let _enter = self.span.enter();
+
+        match hidden_states.device() {
+            Device::Cpu | Device::Metal(_) => {
+                let mut hidden_states = hidden_states.clone();
+                let residual_add = if let Some(residual) = residual {
+                    let residual_add = hidden_states.add(residual)?;
+                    hidden_states = residual_add.clone();
+                    residual_add
+                } else {
+                    hidden_states.clone()
+                };
+
+                let hidden_states_dtype = hidden_states.dtype();
+                let internal_dtype = match hidden_states_dtype {
+                    DType::F16 | DType::BF16 => DType::F32,
+                    d => d,
+                };
+                let hidden_size = hidden_states.dim(D::Minus1)?;
+                let hidden_states = hidden_states.to_dtype(internal_dtype)?;
+                let norm_hidden_states =
+                    (hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+                let hidden_states_normed = hidden_states
+                    .broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?;
+                Ok((
+                    hidden_states_normed
+                        .to_dtype(hidden_states_dtype)?
+                        .broadcast_mul(&self.weight)?,
+                    residual_add,
+                ))
+            }
+            Device::Cuda(_) => {
+                #[cfg(feature = "cuda")]
+                {
+                    use candle_layer_norm::{fused_add_rms_norm, rms_norm};
+
+                    let original_shape = hidden_states.shape();
+                    let hidden_states = hidden_states.flatten_to(D::Minus2)?;
+
+                    if let Some(residual) = residual {
+                        let residual = residual.flatten_to(D::Minus2)?;
+
+                        let (result, residual_add) = fused_add_rms_norm(
+                            &hidden_states,
+                            &residual,
+                            &self.weight,
+                            None,
+                            self.epsilon,
+                        )?;
+                        Ok((
+                            result.reshape(original_shape)?,
+                            residual_add.reshape(original_shape)?,
+                        ))
+                    } else {
+                        let residual_add = hidden_states.clone();
+
+                        let result = rms_norm(&hidden_states, &self.weight, None, self.epsilon)?;
+
+                        Ok((
+                            result.reshape(original_shape)?,
+                            residual_add.reshape(original_shape)?,
+                        ))
+                    }
+                }
+                #[cfg(not(feature = "cuda"))]
+                candle::bail!("`cuda` feature is not enabled")
+            }
+        }
+    }
+}

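Note: on the CPU/Metal path the forward pass upcasts to f32 and computes y = x / sqrt(mean(x^2) + epsilon) * weight, returning the normalized output together with the residual-added input. A small plain-Rust sketch of that arithmetic on a flat slice (illustration of the formula only; the real layer operates on Candle tensors and, on CUDA, on the fused candle_layer_norm kernels):

// RMS norm over one hidden vector: y[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i].
fn rms_norm(x: &[f32], weight: &[f32], epsilon: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + epsilon).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    let weight = [1.0_f32; 4];
    let y = rms_norm(&x, &weight, 1e-6);
    // mean(x^2) = 7.5, so every element is divided by roughly 2.7386.
    assert!((y[3] - 4.0 / 7.5_f32.sqrt()).abs() < 1e-4);
}
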
backends/candle/src/lib.rs

Lines changed: 78 additions & 19 deletions
@@ -12,12 +12,12 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
-    Model, NomicBertModel, NomicConfig,
+    MistralConfig, Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
     FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
-    FlashNomicBertModel,
+    FlashMistralModel, FlashNomicBertModel,
 };
 use anyhow::Context;
 use candle::{DType, Device};
@@ -56,6 +56,7 @@ enum Config {
     DistilBert(DistilBertConfig),
     #[serde(rename(deserialize = "nomic_bert"))]
     NomicBert(NomicConfig),
+    Mistral(MistralConfig),
 }
 
 pub struct CandleBackend {
@@ -69,6 +70,54 @@ impl CandleBackend {
         dtype: String,
         model_type: ModelType,
     ) -> Result<Self, BackendError> {
+        // Default files
+        let default_safetensors = model_path.join("model.safetensors");
+        let default_pytorch = model_path.join("pytorch_model.bin");
+
+        // Single Files
+        let model_files = if default_safetensors.exists() {
+            vec![default_safetensors]
+        } else if default_pytorch.exists() {
+            vec![default_pytorch]
+        }
+        // Sharded weights
+        else {
+            // Get index file
+            let index_file = model_path.join("model.safetensors.index.json");
+
+            // Parse file
+            let index_file_string: String = std::fs::read_to_string(&index_file)
+                .map_err(|err| BackendError::Start(err.to_string()))?;
+            let json: serde_json::Value = serde_json::from_str(&index_file_string)
+                .map_err(|err| BackendError::Start(err.to_string()))?;
+
+            let weight_map = match json.get("weight_map") {
+                None => {
+                    return Err(BackendError::Start(format!(
+                        "no weight map in {index_file:?}"
+                    )));
+                }
+                Some(serde_json::Value::Object(map)) => map,
+                Some(_) => {
+                    return Err(BackendError::Start(format!(
+                        "weight map in {index_file:?} is not a map"
+                    )));
+                }
+            };
+            let mut safetensors_files = std::collections::HashSet::new();
+            for value in weight_map.values() {
+                if let Some(file) = value.as_str() {
+                    safetensors_files.insert(file.to_string());
+                }
+            }
+
+            // Collect paths
+            safetensors_files
+                .iter()
+                .map(|n| model_path.join(n))
+                .collect()
+        };
+
         // Load config
         let config: String = std::fs::read_to_string(model_path.join("config.json"))
             .context("Unable to read config file")
@@ -115,17 +164,10 @@ impl CandleBackend {
             )))
         }?;
 
-        let safetensors_path = model_path.join("model.safetensors");
-        let vb = if safetensors_path.exists() {
-            unsafe {
-                VarBuilder::from_mmaped_safetensors(
-                    &[model_path.join("model.safetensors")],
-                    dtype,
-                    &device,
-                )
-            }
+        let vb = if model_files.len() == 1 && model_files[0].extension().unwrap() == "bin" {
+            VarBuilder::from_pth(&model_files[0], dtype, &device)
         } else {
-            VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device)
+            unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device) }
         }
         .s()?;
 
@@ -136,7 +178,7 @@ impl CandleBackend {
             )),
             (Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config {
                 BertConfigWrapper::JinaBert(config) => {
-                    tracing::info!("Starting JinaBertModel model on {:?}", device);
+                    tracing::info!("Starting JinaBert model on {:?}", device);
                     Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
                 }
                 BertConfigWrapper::JinaCodeBert(config) => {
@@ -160,15 +202,19 @@ impl CandleBackend {
                 ))
             }
             (Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting DistilBertModel model on {:?}", device);
+                tracing::info!("Starting DistilBert model on {:?}", device);
                 Ok(Box::new(
                     DistilBertModel::load(vb, &config, model_type).s()?,
                 ))
             }
             (Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting NomicBertModel model on {:?}", device);
+                tracing::info!("Starting NomicBert model on {:?}", device);
                 Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
             }
+            (Config::Mistral(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
+                "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
+                    .to_string(),
+            )),
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -198,7 +244,7 @@ impl CandleBackend {
                 } else {
                     match config {
                         BertConfigWrapper::JinaBert(config) => {
-                            tracing::info!("Starting JinaBertModel model on {:?}", device);
+                            tracing::info!("Starting JinaBert model on {:?}", device);
                             Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
                         }
                         BertConfigWrapper::JinaCodeBert(config) => {
@@ -245,7 +291,7 @@ impl CandleBackend {
                     .to_lowercase()
                     == "true"
                 {
-                    tracing::info!("Starting FlashDistilBertModel model on {:?}", device);
+                    tracing::info!("Starting FlashDistilBert model on {:?}", device);
                     Ok(Box::new(
                         FlashDistilBertModel::load(vb, &config, model_type).s()?,
                     ))
@@ -265,15 +311,28 @@ impl CandleBackend {
                    .to_lowercase()
                    == "true"
                {
-                    tracing::info!("Starting FlashNomicBertModel model on {:?}", device);
+                    tracing::info!("Starting FlashNomicBert model on {:?}", device);
                    Ok(Box::new(
                        FlashNomicBertModel::load(vb, &config, model_type).s()?,
                    ))
                } else {
-                    tracing::info!("Starting NomicBertModel model on {:?}", device);
+                    tracing::info!("Starting NomicBert model on {:?}", device);
                    Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
                }
            }
+            #[cfg(feature = "cuda")]
+            (Config::Mistral(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(feature = "flash-attn")
+                    || get_runtime_compute_cap().unwrap() < 80
+                {
+                    return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                }
+                tracing::info!("Starting FlashMistral model on {:?}", device);
+                Ok(Box::new(
+                    FlashMistralModel::load(vb, &config, model_type).s()?,
+                ))
+            }
         };
 
         Ok(Self {

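Note: weight loading now falls back to `model.safetensors.index.json` when neither `model.safetensors` nor `pytorch_model.bin` is present, collecting the distinct shard names referenced by the index's `weight_map` before memory-mapping them. A self-contained sketch of that parsing step against a hypothetical two-shard index (the tensor and file names below are invented for illustration):

use std::collections::HashSet;

// Extract the set of shard file names referenced by a safetensors index.
fn shard_files(index_json: &str) -> Result<HashSet<String>, String> {
    let json: serde_json::Value =
        serde_json::from_str(index_json).map_err(|e| e.to_string())?;
    let weight_map = json
        .get("weight_map")
        .and_then(|v| v.as_object())
        .ok_or_else(|| "no weight map in index".to_string())?;
    // Many tensors point at the same shard, hence the HashSet.
    Ok(weight_map
        .values()
        .filter_map(|v| v.as_str().map(str::to_string))
        .collect())
}

fn main() {
    let index = r#"{
        "weight_map": {
            "embed_tokens.weight": "model-00001-of-00002.safetensors",
            "layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
            "norm.weight": "model-00002-of-00002.safetensors"
        }
    }"#;
    let files = shard_files(index).expect("valid index");
    assert_eq!(files.len(), 2);
}
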
backends/candle/src/models/bert.rs

Lines changed: 6 additions & 0 deletions
@@ -638,6 +638,10 @@ impl BertModel {
                 (pool, Some(classifier), None)
             }
             ModelType::Embedding(pool) => {
+                if pool == Pool::LastToken {
+                    candle::bail!("`last_token` is not supported for Bert");
+                }
+
                 let splade = if pool == Pool::Splade {
                     Some(BertSpladeHead::load_roberta(vb.clone(), config)?)
                 } else {
@@ -832,6 +836,8 @@ impl BertModel {
         let pooled_embeddings = match self.pool {
             // CLS pooling
             Pool::Cls => outputs.i((.., 0))?,
+            // Last token pooling is not supported for this model
+            Pool::LastToken => unreachable!(),
             // Mean pooling
             Pool::Mean => {
                 if let Some(ref attention_mask) = attention_mask {

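Note: `Pool::LastToken` pooling is rejected for Bert at load time, which is why the match arm in the forward pass is `unreachable!()`. For models that do support it, last-token pooling takes the hidden state of each sequence's final non-padded token; the plain-Rust sketch below illustrates that idea only (a hypothetical helper, not code from the repository, which operates on Candle tensors and cumulative sequence lengths).

// Last-token pooling over a padded batch: keep the hidden state of the final
// real token of each sequence.
fn last_token_pool(hidden: &[Vec<Vec<f32>>], seq_lens: &[usize]) -> Vec<Vec<f32>> {
    hidden
        .iter()
        .zip(seq_lens)
        .map(|(seq, &len)| seq[len - 1].clone())
        .collect()
}

fn main() {
    // Two sequences, hidden size 2; the second is padded from length 2 to 3.
    let hidden = vec![
        vec![vec![1.0, 1.0], vec![2.0, 2.0], vec![3.0, 3.0]],
        vec![vec![4.0, 4.0], vec![5.0, 5.0], vec![0.0, 0.0]],
    ];
    let pooled = last_token_pool(&hidden, &[3, 2]);
    assert_eq!(pooled[0], vec![3.0, 3.0]);
    assert_eq!(pooled[1], vec![5.0, 5.0]);
}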