Skip to content

Commit fb267c5

Browse files
committed
refactor(embeddingref): replace STactor(embedding): replace STEmbedding withEmbedding with LocalEmbedding
LocalEmbedding UpdatedUpdated the the Embed EmbeddingEngine to use LocalEmbedding instead of STEmbedding for better consistency and clarity.dingEngine to use LocalEmbedding instead of STEmbedding for better consistency and clarity.
1 parent 769a258 commit fb267c5

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

llm-modules/local-embedding/src/main/kotlin/cc/unitmesh/cf/STEmbedding.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,10 @@ class STEmbedding(
1010
private val session: OrtSession,
1111
private val env: OrtEnvironment,
1212
) : LocalEmbedding(tokenizer, session, env) {
13+
14+
companion object {
15+
fun create(): LocalEmbedding {
16+
return LocalEmbedding.create()
17+
}
18+
}
1319
}

rag-modules/rag-script/src/main/kotlin/cc/unitmesh/rag/EmbeddingEngine.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package cc.unitmesh.rag
22

3-
import cc.unitmesh.cf.STEmbedding
3+
import cc.unitmesh.cf.LocalEmbedding
44
import cc.unitmesh.nlp.embedding.Embedding
55
import cc.unitmesh.nlp.embedding.EmbeddingProvider
66
import cc.unitmesh.nlp.embedding.text.EnglishTextEmbeddingProvider
@@ -13,14 +13,14 @@ enum class EngineType {
1313

1414
class EmbeddingEngine(private val engine: EngineType = EngineType.SentenceTransformers) {
1515
var provider: EmbeddingProvider = when (engine) {
16-
EngineType.SentenceTransformers -> SentenceTransformersEmbedding()
16+
EngineType.SentenceTransformers -> LocalTransformersEmbedding()
1717
EngineType.EnglishTextEmbedding -> EnglishTextEmbeddingProvider()
1818
EngineType.TextEmbeddingAda -> TODO()
1919
}
2020
}
2121

22-
class SentenceTransformersEmbedding : EmbeddingProvider {
23-
private val semantic = STEmbedding.create()
22+
class LocalTransformersEmbedding : EmbeddingProvider {
23+
private val semantic = LocalEmbedding.create()
2424
override fun embed(texts: List<String>): List<Embedding> {
2525
return texts.map {
2626
semantic.embed(it).toList()

server/src/test/kotlin/RagIntegrationTests.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import cc.unitmesh.cf.LocalEmbedding
12
import cc.unitmesh.cf.STEmbedding
23
import cc.unitmesh.cf.infrastructure.llms.embedding.SentenceTransformersEmbedding
34
import cc.unitmesh.nlp.embedding.Embedding
@@ -13,7 +14,7 @@ import io.kotest.matchers.shouldBe
1314
import org.junit.jupiter.api.Test
1415

1516
class RagIntegrationTests {
16-
val semantic = STEmbedding.create()
17+
val semantic = LocalEmbedding.create()
1718

1819
private val embeddingProvider = SentenceTransformersEmbedding()
1920

0 commit comments

Comments
 (0)