@@ -115,46 +115,6 @@ enum Config {
115115 XlmRoberta ( BertConfig ) ,
116116}
117117
118- #[ derive( Debug , Clone , Deserialize , PartialEq ) ]
119- enum ModuleType {
120- #[ serde( rename = "sentence_transformers.models.Dense" ) ]
121- Dense ,
122- #[ serde( rename = "sentence_transformers.models.Normalize" ) ]
123- Normalize ,
124- #[ serde( rename = "sentence_transformers.models.Pooling" ) ]
125- Pooling ,
126- #[ serde( rename = "sentence_transformers.models.Transformer" ) ]
127- Transformer ,
128- }
129-
130- #[ derive( Debug , Clone , Deserialize ) ]
131- struct ModuleConfig {
132- #[ allow( dead_code) ]
133- idx : usize ,
134- #[ allow( dead_code) ]
135- name : String ,
136- path : String ,
137- #[ serde( rename = "type" ) ]
138- module_type : ModuleType ,
139- }
140-
141- fn parse_dense_paths_from_modules ( model_path : & Path ) -> Result < Vec < String > , std:: io:: Error > {
142- let modules_path = model_path. join ( "modules.json" ) ;
143- if !modules_path. exists ( ) {
144- return Ok ( vec ! [ ] ) ;
145- }
146-
147- let content = std:: fs:: read_to_string ( & modules_path) ?;
148- let modules: Vec < ModuleConfig > = serde_json:: from_str ( & content)
149- . map_err ( |err| std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , err) ) ?;
150-
151- Ok ( modules
152- . into_iter ( )
153- . filter ( |module| module. module_type == ModuleType :: Dense )
154- . map ( |module| module. path )
155- . collect :: < Vec < String > > ( ) )
156- }
157-
158118pub struct CandleBackend {
159119 device : Device ,
160120 model : Box < dyn Model + Send > ,
@@ -551,66 +511,54 @@ impl CandleBackend {
551511 }
552512 } ;
553513
554- // Load modules.json and read the Dense paths from there, unless `dense_paths` is provided
555- // in such case simply use the `dense_paths`
556- // 1. If `dense_paths` is None then try to read the `modules.json` file and parse the
557- // content to read the paths of the default Dense paths, useful when the model directory
558- // is provided as the `model-id` rather than the ID from the Hugging Face Hub
559- // 2. If `dense_paths` is Some (even if empty), respect that explicit choice and do not
560- // read from modules.json, this allows users to explicitly disable dense layers
561514 let mut dense_layers = Vec :: new ( ) ;
562-
563- let paths_to_load = if let Some ( dense_paths) = & dense_paths {
564- // If dense_paths is explicitly provided (even if empty), respect that choice
565- dense_paths. clone ( )
566- } else {
567- // Try to parse modules.json only if dense_paths is None
568- parse_dense_paths_from_modules ( model_path) . unwrap_or_default ( )
569- } ;
570-
571- if !paths_to_load. is_empty ( ) {
572- tracing:: info!( "Loading Dense module/s from path/s: {paths_to_load:?}" ) ;
573-
574- for dense_path in paths_to_load. iter ( ) {
575- let dense_safetensors = model_path. join ( format ! ( "{dense_path}/model.safetensors" ) ) ;
576- let dense_pytorch = model_path. join ( format ! ( "{dense_path}/pytorch_model.bin" ) ) ;
577-
578- if dense_safetensors. exists ( ) || dense_pytorch. exists ( ) {
579- let dense_config_path = model_path. join ( format ! ( "{dense_path}/config.json" ) ) ;
580-
581- let dense_config_str =
582- std:: fs:: read_to_string ( & dense_config_path) . map_err ( |err| {
583- BackendError :: Start ( format ! (
584- "Unable to read `{dense_path}/config.json` file: {err:?}" ,
585- ) )
586- } ) ?;
587- let dense_config: DenseConfig = serde_json:: from_str ( & dense_config_str)
588- . map_err ( |err| {
589- BackendError :: Start ( format ! (
590- "Unable to parse `{dense_path}/config.json`: {err:?}" ,
591- ) )
592- } ) ?;
593-
594- let dense_vb = if dense_safetensors. exists ( ) {
595- unsafe {
596- VarBuilder :: from_mmaped_safetensors (
597- & [ dense_safetensors] ,
598- dtype,
599- & device,
600- )
601- }
602- . s ( ) ?
515+ if let Some ( dense_paths) = dense_paths {
516+ if !dense_paths. is_empty ( ) {
517+ tracing:: info!( "Loading Dense module/s from path/s: {dense_paths:?}" ) ;
518+
519+ for dense_path in dense_paths. iter ( ) {
520+ let dense_safetensors =
521+ model_path. join ( format ! ( "{dense_path}/model.safetensors" ) ) ;
522+ let dense_pytorch = model_path. join ( format ! ( "{dense_path}/pytorch_model.bin" ) ) ;
523+
524+ if dense_safetensors. exists ( ) || dense_pytorch. exists ( ) {
525+ let dense_config_path =
526+ model_path. join ( format ! ( "{dense_path}/config.json" ) ) ;
527+
528+ let dense_config_str = std:: fs:: read_to_string ( & dense_config_path)
529+ . map_err ( |err| {
530+ BackendError :: Start ( format ! (
531+ "Unable to read `{dense_path}/config.json` file: {err:?}" ,
532+ ) )
533+ } ) ?;
534+ let dense_config: DenseConfig = serde_json:: from_str ( & dense_config_str)
535+ . map_err ( |err| {
536+ BackendError :: Start ( format ! (
537+ "Unable to parse `{dense_path}/config.json`: {err:?}" ,
538+ ) )
539+ } ) ?;
540+
541+ let dense_vb = if dense_safetensors. exists ( ) {
542+ unsafe {
543+ VarBuilder :: from_mmaped_safetensors (
544+ & [ dense_safetensors] ,
545+ dtype,
546+ & device,
547+ )
548+ }
549+ . s ( ) ?
550+ } else {
551+ VarBuilder :: from_pth ( & dense_pytorch, dtype, & device) . s ( ) ?
552+ } ;
553+
554+ let dense_layer = Box :: new ( Dense :: load ( dense_vb, & dense_config) . s ( ) ?)
555+ as Box < dyn DenseLayer + Send > ;
556+ dense_layers. push ( dense_layer) ;
557+
558+ tracing:: info!( "Loaded Dense module from path: {dense_path}" ) ;
603559 } else {
604- VarBuilder :: from_pth ( & dense_pytorch, dtype, & device) . s ( ) ?
605- } ;
606-
607- let dense_layer = Box :: new ( Dense :: load ( dense_vb, & dense_config) . s ( ) ?)
608- as Box < dyn DenseLayer + Send > ;
609- dense_layers. push ( dense_layer) ;
610-
611- tracing:: info!( "Loaded Dense module from path: {dense_path}" ) ;
612- } else {
613- tracing:: warn!( "Dense module files not found for path: {dense_path}" , ) ;
560+ tracing:: warn!( "Dense module files not found for path: {dense_path}" , ) ;
561+ }
614562 }
615563 }
616564 }
0 commit comments