@@ -12,12 +12,12 @@ use crate::compute_cap::{
};
use crate::models::{
    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
-    Model, NomicBertModel, NomicConfig,
+    MistralConfig, Model, NomicBertModel, NomicConfig,
};
#[cfg(feature = "cuda")]
use crate::models::{
    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
-    FlashNomicBertModel,
+    FlashMistralModel, FlashNomicBertModel,
};
use anyhow::Context;
use candle::{DType, Device};
@@ -56,6 +56,7 @@ enum Config {
    DistilBert(DistilBertConfig),
    #[serde(rename(deserialize = "nomic_bert"))]
    NomicBert(NomicConfig),
+    Mistral(MistralConfig),
}

pub struct CandleBackend {
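Selecting the new variant is driven entirely by the `model_type` field of `config.json`. A minimal, standalone sketch of that dispatch, assuming the enum is internally tagged on `model_type` (the `#[serde(tag = ...)]` attribute sits above this hunk); `ConfigSketch` is a hypothetical stand-in:

use serde::Deserialize;

// Hypothetical stand-in for the real `Config` enum. Assumption: the enum is
// internally tagged on config.json's "model_type" field, so a config with
// "model_type": "mistral" now parses into the new variant instead of failing.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum ConfigSketch {
    Bert,
    Mistral,
}

fn main() {
    let parsed: ConfigSketch = serde_json::from_str(r#"{"model_type": "mistral"}"#).unwrap();
    println!("{parsed:?}"); // prints: Mistral
}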
@@ -69,6 +70,54 @@ impl CandleBackend {
        dtype: String,
        model_type: ModelType,
    ) -> Result<Self, BackendError> {
+        // Default files
+        let default_safetensors = model_path.join("model.safetensors");
+        let default_pytorch = model_path.join("pytorch_model.bin");
+
+        // Single Files
+        let model_files = if default_safetensors.exists() {
+            vec![default_safetensors]
+        } else if default_pytorch.exists() {
+            vec![default_pytorch]
+        }
+        // Sharded weights
+        else {
+            // Get index file
+            let index_file = model_path.join("model.safetensors.index.json");
+
+            // Parse file
+            let index_file_string: String = std::fs::read_to_string(&index_file)
+                .map_err(|err| BackendError::Start(err.to_string()))?;
+            let json: serde_json::Value = serde_json::from_str(&index_file_string)
+                .map_err(|err| BackendError::Start(err.to_string()))?;
+
+            let weight_map = match json.get("weight_map") {
+                None => {
+                    return Err(BackendError::Start(format!(
+                        "no weight map in {index_file:?}"
+                    )));
+                }
+                Some(serde_json::Value::Object(map)) => map,
+                Some(_) => {
+                    return Err(BackendError::Start(format!(
+                        "weight map in {index_file:?} is not a map"
+                    )));
+                }
+            };
+            let mut safetensors_files = std::collections::HashSet::new();
+            for value in weight_map.values() {
+                if let Some(file) = value.as_str() {
+                    safetensors_files.insert(file.to_string());
+                }
+            }
+
+            // Collect paths
+            safetensors_files
+                .iter()
+                .map(|n| model_path.join(n))
+                .collect()
+        };
+
        // Load config
        let config: String = std::fs::read_to_string(model_path.join("config.json"))
            .context("Unable to read config file")
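For context, a sharded checkpoint ships a model.safetensors.index.json whose "weight_map" maps each tensor name to the shard file holding it, so the same shard appears once per tensor and must be deduplicated. A standalone sketch of the resolution this hunk performs, with a hypothetical two-shard index inlined:

use std::collections::HashSet;

fn main() {
    // Hypothetical index contents; real index files map hundreds of tensors.
    let index = r#"{
        "weight_map": {
            "embed_tokens.weight": "model-00001-of-00002.safetensors",
            "layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
            "norm.weight": "model-00002-of-00002.safetensors"
        }
    }"#;
    let json: serde_json::Value = serde_json::from_str(index).unwrap();
    let weight_map = json["weight_map"].as_object().unwrap();

    // Collect distinct shard files only: three tensors, but two shards.
    let shards: HashSet<&str> = weight_map.values().filter_map(|v| v.as_str()).collect();
    assert_eq!(shards.len(), 2);
}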
@@ -115,17 +164,10 @@ impl CandleBackend {
            )))
        }?;

-        let safetensors_path = model_path.join("model.safetensors");
-        let vb = if safetensors_path.exists() {
-            unsafe {
-                VarBuilder::from_mmaped_safetensors(
-                    &[model_path.join("model.safetensors")],
-                    dtype,
-                    &device,
-                )
-            }
+        let vb = if model_files.len() == 1 && model_files[0].extension().unwrap() == "bin" {
+            VarBuilder::from_pth(&model_files[0], dtype, &device)
        } else {
-            VarBuilder::from_pth(model_path.join("pytorch_model.bin"), dtype, &device)
+            unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device) }
        }
        .s()?;

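After resolution, loading reduces to one of two candle-nn VarBuilder constructors. A minimal sketch of the selection above, assuming `model_files` was built as in the previous hunk and using the same crate alias (`candle` for candle-core):

use candle::{DType, Device};
use candle_nn::VarBuilder;
use std::path::PathBuf;

// Sketch of the branch above: a single ".bin" file goes through candle's
// PyTorch pickle loader; anything else is treated as safetensors and
// memory-mapped. The unsafe block reflects that an mmapped file must not
// be modified while the weights are in use.
fn load_weights(
    model_files: &[PathBuf],
    dtype: DType,
    device: &Device,
) -> candle::Result<VarBuilder<'static>> {
    if model_files.len() == 1 && model_files[0].extension().is_some_and(|e| e == "bin") {
        VarBuilder::from_pth(&model_files[0], dtype, device)
    } else {
        unsafe { VarBuilder::from_mmaped_safetensors(model_files, dtype, device) }
    }
}

Note that the hunk itself unwraps `extension()`, which would panic on an extension-less path; the sketch uses `is_some_and` to keep the example total.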
@@ -136,7 +178,7 @@ impl CandleBackend {
            )),
            (Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config {
                BertConfigWrapper::JinaBert(config) => {
-                    tracing::info!("Starting JinaBertModel model on {:?}", device);
+                    tracing::info!("Starting JinaBert model on {:?}", device);
                    Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
                }
                BertConfigWrapper::JinaCodeBert(config) => {
@@ -160,15 +202,19 @@ impl CandleBackend {
                ))
            }
            (Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting DistilBertModel model on {:?}", device);
+                tracing::info!("Starting DistilBert model on {:?}", device);
                Ok(Box::new(
                    DistilBertModel::load(vb, &config, model_type).s()?,
                ))
            }
            (Config::NomicBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting NomicBertModel model on {:?}", device);
+                tracing::info!("Starting NomicBert model on {:?}", device);
                Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
            }
+            (Config::Mistral(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
+                "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
+                    .to_string(),
+            )),
            #[cfg(feature = "cuda")]
            (Config::Bert(config), Device::Cuda(_)) => {
                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -198,7 +244,7 @@ impl CandleBackend {
                } else {
                    match config {
                        BertConfigWrapper::JinaBert(config) => {
-                            tracing::info!("Starting JinaBertModel model on {:?}", device);
+                            tracing::info!("Starting JinaBert model on {:?}", device);
                            Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
                        }
                        BertConfigWrapper::JinaCodeBert(config) => {
@@ -245,7 +291,7 @@ impl CandleBackend {
                    .to_lowercase()
                    == "true"
                {
-                    tracing::info!("Starting FlashDistilBertModel model on {:?}", device);
+                    tracing::info!("Starting FlashDistilBert model on {:?}", device);
                    Ok(Box::new(
                        FlashDistilBertModel::load(vb, &config, model_type).s()?,
                    ))
@@ -265,15 +311,28 @@ impl CandleBackend {
                    .to_lowercase()
                    == "true"
                {
-                    tracing::info!("Starting FlashNomicBertModel model on {:?}", device);
+                    tracing::info!("Starting FlashNomicBert model on {:?}", device);
                    Ok(Box::new(
                        FlashNomicBertModel::load(vb, &config, model_type).s()?,
                    ))
                } else {
-                    tracing::info!("Starting NomicBertModel model on {:?}", device);
+                    tracing::info!("Starting NomicBert model on {:?}", device);
                    Ok(Box::new(NomicBertModel::load(vb, &config, model_type).s()?))
                }
            }
+            #[cfg(feature = "cuda")]
+            (Config::Mistral(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(feature = "flash-attn")
+                    || get_runtime_compute_cap().unwrap() < 80
+                {
+                    return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                }
+                tracing::info!("Starting FlashMistral model on {:?}", device);
+                Ok(Box::new(
+                    FlashMistralModel::load(vb, &config, model_type).s()?,
+                ))
+            }
        };

        Ok(Self {
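The new CUDA arm gates FlashMistralModel on three runtime conditions: fp16 weights, the flash-attn (v2) feature, and a device compute capability of at least 80, which corresponds to Ampere or newer; FlashAttention-2 kernels are not built for older architectures. On any other device or dtype the config still parses, but construction fails with the explicit error added in the CPU/Metal arm above.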