@@ -13,13 +13,13 @@ use crate::compute_cap::{
13
13
use crate :: models:: {
14
14
BertConfig , BertModel , DistilBertConfig , DistilBertModel , GTEConfig , GTEModel , JinaBertModel ,
15
15
JinaCodeBertModel , MPNetConfig , MPNetModel , MistralConfig , Model , ModernBertConfig ,
16
- ModernBertModel , NomicBertModel , NomicConfig , Qwen2Config ,
16
+ ModernBertModel , NomicBertModel , NomicConfig , Qwen2Config , Qwen3Config ,
17
17
} ;
18
18
#[ cfg( feature = "cuda" ) ]
19
19
use crate :: models:: {
20
20
FlashBertModel , FlashDistilBertModel , FlashGTEModel , FlashJinaBertModel ,
21
21
FlashJinaCodeBertModel , FlashMistralModel , FlashModernBertModel , FlashNomicBertModel ,
22
- FlashQwen2Model ,
22
+ FlashQwen2Model , FlashQwen3Model ,
23
23
} ;
24
24
use anyhow:: Context ;
25
25
use candle:: { DType , Device } ;
@@ -103,6 +103,8 @@ enum Config {
103
103
Gte ( GTEConfig ) ,
104
104
#[ allow( dead_code) ]
105
105
Qwen2 ( Qwen2Config ) ,
106
+ #[ allow( dead_code) ]
107
+ Qwen3 ( Qwen3Config ) ,
106
108
#[ serde( rename = "mpnet" ) ]
107
109
MPNet ( MPNetConfig ) ,
108
110
#[ serde( rename( deserialize = "modernbert" ) ) ]
@@ -273,6 +275,10 @@ impl CandleBackend {
273
275
"Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
274
276
. to_string ( ) ,
275
277
) ) ,
278
+ ( Config :: Qwen3 ( _) , Device :: Cpu | Device :: Metal ( _) ) => Err ( BackendError :: Start (
279
+ "Qwen3 is only supported on Cuda devices in fp16 with flash attention enabled"
280
+ . to_string ( ) ,
281
+ ) ) ,
276
282
( Config :: MPNet ( config) , _) => {
277
283
tracing:: info!( "Starting MPNet model on {:?}" , device) ;
278
284
Ok ( Box :: new ( MPNetModel :: load ( vb, & config, model_type) . s ( ) ?) )
@@ -446,6 +452,18 @@ impl CandleBackend {
446
452
FlashQwen2Model :: load ( vb, & config, model_type) . s ( ) ?,
447
453
) )
448
454
}
455
+ #[ cfg( feature = "cuda" ) ]
456
+ ( Config :: Qwen3 ( config) , Device :: Cuda ( _) ) => {
457
+ if dtype != DType :: F16
458
+ || !cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
459
+ {
460
+ return Err ( BackendError :: Start ( "Qwen3 is only supported on Cuda devices in fp16 with flash attention v2 enabled" . to_string ( ) ) ) ;
461
+ }
462
+ tracing:: info!( "Starting FlashQwen3 model on {:?}" , device) ;
463
+ Ok ( Box :: new (
464
+ FlashQwen3Model :: load ( vb, & config, model_type) . s ( ) ?,
465
+ ) )
466
+ }
449
467
} ;
450
468
451
469
Ok ( Self {
0 commit comments