@@ -115,6 +115,46 @@ enum Config {
115115 XlmRoberta ( BertConfig ) ,
116116}
117117
118+ #[ derive( Debug , Clone , Deserialize , PartialEq ) ]
119+ enum ModuleType {
120+ #[ serde( rename = "sentence_transformers.models.Dense" ) ]
121+ Dense ,
122+ #[ serde( rename = "sentence_transformers.models.Normalize" ) ]
123+ Normalize ,
124+ #[ serde( rename = "sentence_transformers.models.Pooling" ) ]
125+ Pooling ,
126+ #[ serde( rename = "sentence_transformers.models.Transformer" ) ]
127+ Transformer ,
128+ }
129+
130+ #[ derive( Debug , Clone , Deserialize ) ]
131+ struct ModuleConfig {
132+ #[ allow( dead_code) ]
133+ idx : usize ,
134+ #[ allow( dead_code) ]
135+ name : String ,
136+ path : String ,
137+ #[ serde( rename = "type" ) ]
138+ module_type : ModuleType ,
139+ }
140+
141+ fn parse_dense_paths_from_modules ( model_path : & Path ) -> Result < Vec < String > , std:: io:: Error > {
142+ let modules_path = model_path. join ( "modules.json" ) ;
143+ if !modules_path. exists ( ) {
144+ return Ok ( vec ! [ ] ) ;
145+ }
146+
147+ let content = std:: fs:: read_to_string ( & modules_path) ?;
148+ let modules: Vec < ModuleConfig > = serde_json:: from_str ( & content)
149+ . map_err ( |err| std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , err) ) ?;
150+
151+ Ok ( modules
152+ . into_iter ( )
153+ . filter ( |module| module. module_type == ModuleType :: Dense )
154+ . map ( |module| module. path )
155+ . collect :: < Vec < String > > ( ) )
156+ }
157+
118158pub struct CandleBackend {
119159 device : Device ,
120160 model : Box < dyn Model + Send > ,
@@ -511,55 +551,66 @@ impl CandleBackend {
511551 }
512552 } ;
513553
514- // Load Dense layers from the provided Dense paths
554+ // Load modules.json and read the Dense paths from there, unless `dense_paths` is provided
555+ // in such case simply use the `dense_paths`
556+ // 1. If `dense_paths` is None then try to read the `modules.json` file and parse the
557+ // content to read the paths of the default Dense paths, useful when the model directory
558+ // is provided as the `model-id` rather than the ID from the Hugging Face Hub
559+ // 2. If `dense_paths` is Some (even if empty), respect that explicit choice and do not
560+ // read from modules.json, this allows users to explicitly disable dense layers
515561 let mut dense_layers = Vec :: new ( ) ;
516- if let Some ( dense_paths) = & dense_paths {
517- if !dense_paths. is_empty ( ) {
518- tracing:: info!( "Loading Dense module/s from path/s: {dense_paths:?}" ) ;
519-
520- for dense_path in dense_paths. iter ( ) {
521- let dense_safetensors =
522- model_path. join ( format ! ( "{dense_path}/model.safetensors" ) ) ;
523- let dense_pytorch = model_path. join ( format ! ( "{dense_path}/pytorch_model.bin" ) ) ;
524-
525- if dense_safetensors. exists ( ) || dense_pytorch. exists ( ) {
526- let dense_config_path =
527- model_path. join ( format ! ( "{dense_path}/config.json" ) ) ;
528-
529- let dense_config_str = std:: fs:: read_to_string ( & dense_config_path)
530- . map_err ( |err| {
531- BackendError :: Start ( format ! (
532- "Unable to read `{dense_path}/config.json` file: {err:?}" ,
533- ) )
534- } ) ?;
535- let dense_config: DenseConfig = serde_json:: from_str ( & dense_config_str)
536- . map_err ( |err| {
537- BackendError :: Start ( format ! (
538- "Unable to parse `{dense_path}/config.json`: {err:?}" ,
539- ) )
540- } ) ?;
541-
542- let dense_vb = if dense_safetensors. exists ( ) {
543- unsafe {
544- VarBuilder :: from_mmaped_safetensors (
545- & [ dense_safetensors] ,
546- dtype,
547- & device,
548- )
549- }
550- . s ( ) ?
551- } else {
552- VarBuilder :: from_pth ( & dense_pytorch, dtype, & device) . s ( ) ?
553- } ;
554-
555- let dense_layer = Box :: new ( Dense :: load ( dense_vb, & dense_config) . s ( ) ?)
556- as Box < dyn DenseLayer + Send > ;
557- dense_layers. push ( dense_layer) ;
558-
559- tracing:: info!( "Loaded Dense module from path: {dense_path}" ) ;
562+
563+ let paths_to_load = if let Some ( dense_paths) = & dense_paths {
564+ // If dense_paths is explicitly provided (even if empty), respect that choice
565+ dense_paths. clone ( )
566+ } else {
567+ // Try to parse modules.json only if dense_paths is None
568+ parse_dense_paths_from_modules ( model_path) . unwrap_or_default ( )
569+ } ;
570+
571+ if !paths_to_load. is_empty ( ) {
572+ tracing:: info!( "Loading Dense module/s from path/s: {paths_to_load:?}" ) ;
573+
574+ for dense_path in paths_to_load. iter ( ) {
575+ let dense_safetensors = model_path. join ( format ! ( "{dense_path}/model.safetensors" ) ) ;
576+ let dense_pytorch = model_path. join ( format ! ( "{dense_path}/pytorch_model.bin" ) ) ;
577+
578+ if dense_safetensors. exists ( ) || dense_pytorch. exists ( ) {
579+ let dense_config_path = model_path. join ( format ! ( "{dense_path}/config.json" ) ) ;
580+
581+ let dense_config_str =
582+ std:: fs:: read_to_string ( & dense_config_path) . map_err ( |err| {
583+ BackendError :: Start ( format ! (
584+ "Unable to read `{dense_path}/config.json` file: {err:?}" ,
585+ ) )
586+ } ) ?;
587+ let dense_config: DenseConfig = serde_json:: from_str ( & dense_config_str)
588+ . map_err ( |err| {
589+ BackendError :: Start ( format ! (
590+ "Unable to parse `{dense_path}/config.json`: {err:?}" ,
591+ ) )
592+ } ) ?;
593+
594+ let dense_vb = if dense_safetensors. exists ( ) {
595+ unsafe {
596+ VarBuilder :: from_mmaped_safetensors (
597+ & [ dense_safetensors] ,
598+ dtype,
599+ & device,
600+ )
601+ }
602+ . s ( ) ?
560603 } else {
561- tracing:: warn!( "Dense module files not found for path: {dense_path}" , ) ;
562- }
604+ VarBuilder :: from_pth ( & dense_pytorch, dtype, & device) . s ( ) ?
605+ } ;
606+
607+ let dense_layer = Box :: new ( Dense :: load ( dense_vb, & dense_config) . s ( ) ?)
608+ as Box < dyn DenseLayer + Send > ;
609+ dense_layers. push ( dense_layer) ;
610+
611+ tracing:: info!( "Loaded Dense module from path: {dense_path}" ) ;
612+ } else {
613+ tracing:: warn!( "Dense module files not found for path: {dense_path}" , ) ;
563614 }
564615 }
565616 }
0 commit comments