@@ -5,11 +5,13 @@ import { sum } from "$lib/utils/sum";
5
5
import {
6
6
embeddingEndpoints ,
7
7
embeddingEndpointSchema ,
8
- type EmbeddingEndpoint ,
9
8
} from "$lib/server/embeddingEndpoints/embeddingEndpoints" ;
10
9
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints" ;
11
10
12
11
import JSON5 from "json5" ;
12
+ import type { EmbeddingModel } from "$lib/types/EmbeddingModel" ;
13
+ import { collections } from "./database" ;
14
+ import { ObjectId } from "mongodb" ;
13
15
14
16
const modelConfig = z . object ( {
15
17
/** Used as an identifier in DB */
@@ -42,67 +44,77 @@ const rawEmbeddingModelJSON =
42
44
43
45
// Parse the raw JSON5 text first, then validate the result as an array of model configs.
const parsedEmbeddingModelJSON: unknown = JSON5.parse(rawEmbeddingModelJSON);
const embeddingModelsRaw = z.array(modelConfig).parse(parsedEmbeddingModelJSON);
44
46
45
- const processEmbeddingModel = async ( m : z . infer < typeof modelConfig > ) => ( {
46
- ...m ,
47
- id : m . id || m . name ,
47
/**
 * Embedding models from the parsed configuration, converted to `EmbeddingModel`
 * records. Every record gets a freshly generated `_id` plus `createdAt` /
 * `updatedAt` timestamps taken when this module is evaluated.
 */
const embeddingModels = embeddingModelsRaw.map((raw): EmbeddingModel => {
	const {
		name,
		description,
		websiteUrl,
		modelUrl,
		chunkCharLength,
		maxBatchSize,
		preQuery,
		prePassage,
		endpoints,
	} = raw;

	// Only the fields listed here are copied over; key order is kept stable on purpose.
	return {
		name,
		description,
		websiteUrl,
		modelUrl,
		chunkCharLength,
		maxBatchSize,
		preQuery,
		prePassage,
		_id: new ObjectId(),
		createdAt: new Date(),
		updatedAt: new Date(),
		endpoints,
	};
});
49
65
50
- const addEndpoint = ( m : Awaited < ReturnType < typeof processEmbeddingModel > > ) => ( {
51
- ...m ,
52
- getEndpoint : async ( ) : Promise < EmbeddingEndpoint > => {
53
- if ( ! m . endpoints ) {
54
- return embeddingEndpointTransformersJS ( {
55
- type : "transformersjs" ,
56
- weight : 1 ,
57
- model : m ,
58
- } ) ;
59
- }
66
/**
 * Picks an embedding endpoint for the given model.
 *
 * When the model declares no `endpoints`, a `transformersjs` endpoint with
 * weight 1 is built for it. Otherwise one configured endpoint is chosen at
 * random, with probability proportional to its `weight`, and the factory
 * matching its `type` is invoked.
 *
 * @param embeddingModel - Model whose `endpoints` list (if present) is sampled.
 * @returns The endpoint produced by the factory for the selected entry.
 * @throws Error if the selected entry has an unrecognised `type`, or if no
 *   entry gets selected (for example the list is empty or all weights are zero).
 */
export const getEmbeddingEndpoint = async (embeddingModel: EmbeddingModel) => {
	if (!embeddingModel.endpoints) {
		return embeddingEndpointTransformersJS({
			type: "transformersjs",
			weight: 1,
			model: embeddingModel,
		});
	}

	const totalWeight = sum(embeddingModel.endpoints.map((e) => e.weight));

	// Weighted pick: walk the list, consuming the random budget one weight at a time.
	let random = Math.random() * totalWeight;

	for (const endpoint of embeddingModel.endpoints) {
		if (random < endpoint.weight) {
			const args = { ...endpoint, model: embeddingModel };

			switch (args.type) {
				case "tei":
					return embeddingEndpoints.tei(args);
				case "transformersjs":
					return embeddingEndpoints.transformersjs(args);
				case "openai":
					return embeddingEndpoints.openai(args);
				case "hfapi":
					return embeddingEndpoints.hfapi(args);
				default:
					// Report the offending type itself; interpolating the whole
					// object would only print "[object Object]".
					throw new Error(`Unknown endpoint type: ${String(endpoint.type)}`);
			}
		}

		random -= endpoint.weight;
	}

	throw new Error(`Failed to select embedding endpoint`);
};
99
103
100
- export const validateEmbeddingModelById = ( _models : EmbeddingBackendModel [ ] ) => {
101
- return validateEmbeddingModel ( _models , "id" ) ;
102
- } ;
104
/**
 * Resolves the default embedding model, i.e. the first configured one.
 *
 * The `embeddingModels` collection is queried for a document with the same
 * `_id`; when none is found, the in-memory configured model is returned.
 *
 * @returns The stored document if present, otherwise the configured model.
 * @throws Error if no embedding model is configured at all.
 */
export const getDefaultEmbeddingModel = async (): Promise<EmbeddingModel> => {
	const configuredDefault = embeddingModels[0];

	if (!configuredDefault) {
		throw new Error(`Failed to find default embedding endpoint`);
	}

	const storedDefault = await collections.embeddingModels.findOne({
		_id: configuredDefault._id,
	});

	return storedDefault ?? configuredDefault;
};
107
115
108
- export type EmbeddingBackendModel = typeof defaultEmbeddingModel ;
116
/**
 * Replaces the contents of the `embeddingModels` collection with the models
 * built from configuration, to mimic the current behavior of creating
 * embedding models from scratch during server start.
 *
 * Every existing document is deleted first, then the configured models are
 * inserted. The two steps are separate operations, not a single atomic one.
 *
 * NOTE(review): the function name is misspelled ("pupulate"); it is kept
 * as-is because callers import it under this name.
 */
export async function pupulateEmbeddingModel(): Promise<void> {
	await collections.embeddingModels.deleteMany({});

	// The MongoDB driver rejects an empty batch, so skip the insert when no
	// models are configured instead of failing after the collection was wiped.
	if (embeddingModels.length === 0) {
		return;
	}

	await collections.embeddingModels.insertMany(embeddingModels);
}
0 commit comments