Skip to content

Commit 35aefeb

Browse files
feat(candle): add flash gte (#310)
1 parent 7c9b7cb commit 35aefeb

File tree

12 files changed

+3664
-4
lines changed

12 files changed

+3664
-4
lines changed

backends/candle/src/lib.rs

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,13 +11,13 @@ use crate::compute_cap::{
1111
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
1212
};
1313
use crate::models::{
14-
BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
15-
MistralConfig, Model, NomicBertModel, NomicConfig,
14+
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
15+
JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig,
1616
};
1717
#[cfg(feature = "cuda")]
1818
use crate::models::{
19-
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
20-
FlashMistralModel, FlashNomicBertModel,
19+
FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
20+
FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel,
2121
};
2222
use anyhow::Context;
2323
use candle::{DType, Device};
@@ -57,6 +57,8 @@ enum Config {
5757
#[serde(rename(deserialize = "nomic_bert"))]
5858
NomicBert(NomicConfig),
5959
Mistral(MistralConfig),
60+
#[serde(rename = "new")]
61+
Gte(GTEConfig),
6062
}
6163

6264
pub struct CandleBackend {
@@ -215,6 +217,10 @@ impl CandleBackend {
215217
"Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
216218
.to_string(),
217219
)),
220+
(Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
221+
"GTE is only supported on Cuda devices in fp16 with flash attention enabled"
222+
.to_string(),
223+
)),
218224
#[cfg(feature = "cuda")]
219225
(Config::Bert(config), Device::Cuda(_)) => {
220226
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -333,6 +339,17 @@ impl CandleBackend {
333339
FlashMistralModel::load(vb, &config, model_type).s()?,
334340
))
335341
}
342+
#[cfg(feature = "cuda")]
343+
(Config::Gte(config), Device::Cuda(_)) => {
344+
if dtype != DType::F16
345+
|| !cfg!(feature = "flash-attn")
346+
|| get_runtime_compute_cap().unwrap() < 80
347+
{
348+
return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
349+
}
350+
tracing::info!("Starting FlashGTE model on {:?}", device);
351+
Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
352+
}
336353
};
337354

338355
Ok(Self {

backends/candle/src/models/bert.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@ pub enum PositionEmbeddingType {
3535
#[default]
3636
Absolute,
3737
Alibi,
38+
Rope,
3839
}
3940

4041
#[derive(Debug)]

0 commit comments

Comments (0)