@@ -11,13 +11,13 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
-    MistralConfig, Model, NomicBertModel, NomicConfig,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
+    JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
-    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
-    FlashMistralModel, FlashNomicBertModel,
+    FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
+    FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel,
 };
 use anyhow::Context;
 use candle::{DType, Device};
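The import hunk pulls in the two new types used below: `GTEConfig` in the always-compiled model list and `FlashGTEModel` in the CUDA-only list. Both lists stay alphabetized, which is why the existing names reflow across lines.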
@@ -57,6 +57,8 @@ enum Config {
     #[serde(rename(deserialize = "nomic_bert"))]
     NomicBert(NomicConfig),
     Mistral(MistralConfig),
+    #[serde(rename = "new")]
+    Gte(GTEConfig),
 }
 
 pub struct CandleBackend {
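The `rename = "new"` attribute is what routes GTE checkpoints to this variant: the `config.json` of models such as Alibaba-NLP/gte-large-en-v1.5 declares `"model_type": "new"` rather than `"gte"`. A minimal sketch of that mapping; the tag attribute on the real enum sits outside this hunk, so the internally tagged form below is an assumption, and `GTEConfig`'s `hidden_size` field is purely illustrative:

```rust
use serde::Deserialize;

// Illustrative stand-ins, not the crate's real definitions.
#[derive(Debug, Deserialize)]
struct GTEConfig {
    hidden_size: usize,
}

#[derive(Debug, Deserialize)]
#[serde(tag = "model_type")]
enum Config {
    #[serde(rename = "new")]
    Gte(GTEConfig),
}

fn main() {
    // GTE v1.5 checkpoints ship "model_type": "new" in config.json.
    let raw = r#"{ "model_type": "new", "hidden_size": 1024 }"#;
    let config: Config = serde_json::from_str(raw).unwrap();
    println!("{config:?}"); // Gte(GTEConfig { hidden_size: 1024 })
}
```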
@@ -215,6 +217,10 @@ impl CandleBackend {
                 "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
+            (Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
+                "GTE is only supported on Cuda devices in fp16 with flash attention enabled"
+                    .to_string(),
+            )),
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
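This arm mirrors the Mistral arm directly above it, so a GTE config paired with a CPU or Metal device is rejected at startup with a model-specific message.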
@@ -333,6 +339,17 @@ impl CandleBackend {
                     FlashMistralModel::load(vb, &config, model_type).s()?,
                 ))
             }
+            #[cfg(feature = "cuda")]
+            (Config::Gte(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(feature = "flash-attn")
+                    || get_runtime_compute_cap().unwrap() < 80
+                {
+                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                }
+                tracing::info!("Starting FlashGTE model on {:?}", device);
+                Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
+            }
         };
 
         Ok(Self {
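The runtime gate in this last hunk encodes three requirements: fp16 weights, a build with the `flash-attn` feature, and a GPU with CUDA compute capability 80 or higher (Ampere and newer, e.g. A100 or RTX 30-series), the minimum for the FlashAttention v2 kernels. A standalone sketch of the same check; the helper name and signature are hypothetical, with the feature flag and compute capability passed in as plain values rather than read from `cfg!` and `get_runtime_compute_cap()` as the real code does:

```rust
// Hypothetical helper mirroring the gate above; not part of the crate.
fn check_flash_gte_support(
    is_f16: bool,           // dtype == DType::F16 in the real code
    flash_attn_built: bool, // cfg!(feature = "flash-attn") in the real code
    compute_cap: usize,     // get_runtime_compute_cap() in the real code
) -> Result<(), String> {
    // All three conditions must hold; compute cap 80 corresponds to
    // Ampere (sm_80), the minimum for the flash attention v2 kernels.
    if !is_f16 || !flash_attn_built || compute_cap < 80 {
        return Err(
            "GTE is only supported on Cuda devices in fp16 with flash attention v2 enabled"
                .to_string(),
        );
    }
    Ok(())
}

fn main() {
    assert!(check_flash_gte_support(true, true, 80).is_ok());  // A100 (Ampere)
    assert!(check_flash_gte_support(true, true, 75).is_err()); // T4 (Turing)
}
```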