
Commit 11ffc60

alvarobartt and kozistr authored
Add Qwen3Model (#627)
Co-authored-by: Hyeongchan Kim <[email protected]>
1 parent d51a8b9 commit 11ffc60

File tree: 7 files changed, +4717 -2 lines changed


backends/candle/src/lib.rs

Lines changed: 20 additions & 2 deletions
@@ -13,13 +13,13 @@ use crate::compute_cap::{
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
     JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
-    ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
+    ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
     FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
     FlashJinaCodeBertModel, FlashMistralModel, FlashModernBertModel, FlashNomicBertModel,
-    FlashQwen2Model,
+    FlashQwen2Model, FlashQwen3Model,
 };
 use anyhow::Context;
 use candle::{DType, Device};
@@ -103,6 +103,8 @@ enum Config {
     Gte(GTEConfig),
     #[allow(dead_code)]
     Qwen2(Qwen2Config),
+    #[allow(dead_code)]
+    Qwen3(Qwen3Config),
     #[serde(rename = "mpnet")]
     MPNet(MPNetConfig),
     #[serde(rename(deserialize = "modernbert"))]
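
The new `Qwen3(Qwen3Config)` variant means a `config.json` whose `model_type` is `qwen3` now deserializes into the backend's `Config` enum instead of being rejected. Below is a minimal, self-contained sketch of that tagged-enum pattern; the serde attributes and the two `Qwen3Config` fields are assumptions for illustration, not the crate's actual definitions.

use serde::Deserialize;

// Illustrative stand-in for the real Qwen3Config; the field set is assumed.
#[derive(Debug, Deserialize)]
struct Qwen3Config {
    hidden_size: usize,
    num_attention_heads: usize,
}

// Sketch of the dispatching enum: serde picks the variant from `model_type`.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum Config {
    Qwen3(Qwen3Config),
}

fn main() {
    let raw = r#"{"model_type": "qwen3", "hidden_size": 1024, "num_attention_heads": 16}"#;
    let config: Config = serde_json::from_str(raw).unwrap();
    println!("{config:?}");
}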
@@ -273,6 +275,10 @@ impl CandleBackend {
                 "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
+            (Config::Qwen3(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
+                "Qwen3 is only supported on Cuda devices in fp16 with flash attention enabled"
+                    .to_string(),
+            )),
             (Config::MPNet(config), _) => {
                 tracing::info!("Starting MPNet model on {:?}", device);
                 Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
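
The added arm rejects Qwen3 on CPU and Metal before any weights are touched, mirroring the existing Qwen2 arm: the backend matches on the (config, device) pair and fails fast with a descriptive `BackendError::Start`. A standalone sketch of that dispatch shape, using stand-in types rather than the crate's real ones:

// Stand-in types; the real enums live in the candle backend crate.
#[derive(Debug)]
enum Device { Cpu, Metal, Cuda }

#[derive(Debug)]
enum Config { Qwen3, MPNet }

#[derive(Debug)]
enum BackendError { Start(String) }

fn start_model(config: Config, device: Device) -> Result<&'static str, BackendError> {
    match (config, device) {
        // Qwen3 only ships a flash-attention CUDA path, so other devices
        // are rejected up front with the same message the diff adds.
        (Config::Qwen3, Device::Cpu | Device::Metal) => Err(BackendError::Start(
            "Qwen3 is only supported on Cuda devices in fp16 with flash attention enabled"
                .to_string(),
        )),
        (Config::Qwen3, Device::Cuda) => Ok("FlashQwen3"),
        (Config::MPNet, _) => Ok("MPNet"),
    }
}

fn main() {
    println!("{:?}", start_model(Config::Qwen3, Device::Cpu));
    println!("{:?}", start_model(Config::Qwen3, Device::Metal));
    println!("{:?}", start_model(Config::MPNet, Device::Cuda));
}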
@@ -446,6 +452,18 @@ impl CandleBackend {
                     FlashQwen2Model::load(vb, &config, model_type).s()?,
                 ))
             }
+            #[cfg(feature = "cuda")]
+            (Config::Qwen3(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                {
+                    return Err(BackendError::Start("Qwen3 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                }
+                tracing::info!("Starting FlashQwen3 model on {:?}", device);
+                Ok(Box::new(
+                    FlashQwen3Model::load(vb, &config, model_type).s()?,
+                ))
+            }
         };

         Ok(Self {
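
Inside the CUDA arm, the guard combines a runtime dtype check with `cfg!`, which resolves the flash-attention feature flags at compile time, so a build without either feature always returns the error regardless of dtype. A sketch of just that guard, with stand-in types and a made-up `check_qwen3_support` helper (the feature names would need to be declared in Cargo.toml for `cfg!` not to warn):

// Stand-ins for candle's DType and the backend's error type.
#[derive(Debug, PartialEq)]
enum DType { F16, F32 }

#[derive(Debug)]
enum BackendError { Start(String) }

// Hypothetical helper isolating the guard from the new CUDA match arm.
fn check_qwen3_support(dtype: DType) -> Result<(), BackendError> {
    // `cfg!` is evaluated when this crate is compiled, so the feature check
    // costs nothing at runtime; only the dtype is checked per call.
    if dtype != DType::F16 || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) {
        return Err(BackendError::Start(
            "Qwen3 is only supported on Cuda devices in fp16 with flash attention v2 enabled"
                .to_string(),
        ));
    }
    Ok(())
}

fn main() {
    println!("{:?}", check_qwen3_support(DType::F32)); // always Err
    println!("{:?}", check_qwen3_support(DType::F16)); // Err unless a flash-attn feature is built in
}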
