Skip to content

Commit ef52f92

Browse files
restore some magic
1 parent bbb29d0 commit ef52f92

File tree

8 files changed

+72
-16
lines changed

8 files changed

+72
-16
lines changed

langkit/count_regexes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414
pattern_loader = PatternLoader()
1515
response_pattern_loader = PatternLoader()
1616

17+
_initialized = False
18+
1719

1820
def count_patterns(group, text: str) -> int:
21+
if not _initialized:
22+
init()
1923
count = 0
2024
for expression in group["expressions"]:
2125
if expression.search(text):
@@ -68,6 +72,8 @@ def init(
6872
config: Optional[LangKitConfig] = None,
6973
response_pattern_file_path: Optional[str] = None,
7074
):
75+
global _initialized
76+
_initialized = True
7177
language = language or ""
7278
config = deepcopy(config or lang_config)
7379
if pattern_file_path:

langkit/injections.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
_index_embeddings = None
1616
_transformer_model = None
1717

18+
_initialized = False
1819

1920
def injection(prompt: Union[Dict[str, List], pd.DataFrame]) -> Union[List, pd.Series]:
21+
if not _initialized:
22+
init()
2023
global _transformer_model
2124
global _index_embeddings
2225
if _transformer_model is None:
@@ -45,6 +48,8 @@ def init(
4548
version: Optional[str] = None,
4649
config: Optional[LangKitConfig] = None,
4750
):
51+
global _initialized
52+
_initialized = True
4853
global _registered
4954
unregister_udfs(_registered)
5055
config = config or deepcopy(lang_config)

langkit/input_output.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414

1515
_transformer_model = None
1616

17+
_initialized = False
18+
1719
diagnostic_logger = getLogger(__name__)
1820

1921

2022
def prompt_response_similarity(text):
23+
if not _initialized:
24+
init()
2125
global _transformer_model
2226

2327
if _transformer_model is None:
@@ -49,6 +53,8 @@ def init(
4953
custom_encoder: Optional[Callable] = None,
5054
config: Optional[LangKitConfig] = None,
5155
):
56+
global _initialized
57+
_initialized = True
5258
global _registered
5359
unregister_udfs(_registered)
5460
if transformer_name and custom_encoder:

langkit/nlp_scores.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414

1515
diagnostic_logger = getLogger(__name__)
1616

17+
_initialized = False
1718

1819
_registered: Set[str] = set()
1920

2021

2122
def _register_score_udfs():
23+
if not _initialized:
24+
init()
2225
global _registered
2326
unregister_udfs(_registered)
2427
if _corpus:
@@ -92,6 +95,8 @@ def init(
9295
rouge_type: str = "",
9396
config: Optional[LangKitConfig] = None,
9497
):
98+
global _initialized
99+
_initialized = True
95100
config = config or deepcopy(lang_config)
96101
global _corpus
97102
global _scores

langkit/regexes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414
pattern_loader = PatternLoader()
1515
response_pattern_loader = PatternLoader()
1616

17+
_initialized = False
18+
1719

1820
def has_patterns(text, regex_groups):
21+
if not _initialized:
22+
init()
1923
if regex_groups:
2024
for group in regex_groups:
2125
for expression in group["expressions"]:
@@ -35,6 +39,8 @@ def wrappee(text):
3539

3640

3741
def _register_udfs(config: Optional[LangKitConfig] = None):
42+
global _initialized
43+
_initialized = True
3844
global _registered
3945
unregister_udfs(_registered)
4046
if config is None:
@@ -70,6 +76,8 @@ def init(
7076
config: Optional[LangKitConfig] = None,
7177
response_pattern_file_path: Optional[str] = None,
7278
):
79+
global _initialized
80+
_initialized = True
7381
config = deepcopy(config or lang_config)
7482
if pattern_file_path:
7583
config.pattern_file_path = pattern_file_path

langkit/sentiment.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,24 @@
66
from langkit.whylogs.unreg import unregister_udfs
77

88

9-
def sentiment_nltk(text: str, sentiment_analyzer) -> float:
9+
_registered: Set[str] = set()
10+
11+
12+
_nltk_downloaded = None
13+
_response_nltk_downloaded = None
14+
_sentiment_analyzer = None
15+
_response_sentiment_analyzer = None
16+
17+
_pipeline = None
18+
_response_pipeline = None
19+
20+
_initialized = False
21+
22+
23+
def sentiment_nltk(text: str, sentiment_analyzer=None) -> float:
24+
if not _initialized:
25+
init()
26+
sentiment_analyzer = sentiment_analyzer or _sentiment_analyzer
1027
if sentiment_analyzer is None:
1128
raise ValueError(
1229
"sentiment metrics must initialize sentiment analyzer before evaluation!"
@@ -30,7 +47,10 @@ def sentiment_nltk(text: str, sentiment_analyzer) -> float:
3047
}
3148

3249

33-
def sentiment_multilingual(text: str, pipeline) -> float:
50+
def sentiment_multilingual(text: str, pipeline=None) -> float:
51+
if not _initialized:
52+
init()
53+
pipeline = pipeline or _pipeline
3454
if pipeline is None:
3555
raise ValueError("sentiment score must initialize the pipeline first")
3656

@@ -47,15 +67,6 @@ def _wrappee(text):
4767
return _wrappee
4868

4969

50-
_registered: Set[str] = set()
51-
52-
53-
_nltk_downloaded = None
54-
_response_nltk_downloaded = None
55-
_sentiment_analyzer = None
56-
_response_sentiment_analyzer = None
57-
58-
5970
def configure_nltk(config, lexicon, response_lexicon):
6071
import nltk
6172
from nltk.sentiment import SentimentIntensityAnalyzer
@@ -83,10 +94,6 @@ def configure_nltk(config, lexicon, response_lexicon):
8394
_response_sentiment_analyzer = None
8495

8596

86-
_pipeline = None
87-
_response_pipeline = None
88-
89-
9097
def configure_hugging_face(config, sentiment_model_path, response_sentiment_model_path):
9198
from transformers import pipeline
9299

@@ -112,6 +119,9 @@ def init(
112119
sentiment_model_path: Optional[str] = None,
113120
response_sentiment_model_path: Optional[str] = None,
114121
):
122+
global _initialized
123+
_initialized = True
124+
115125
global _registered
116126
unregister_udfs(_registered)
117127

langkit/topics.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,14 @@
1717
_response_model_path: Optional[str] = None
1818
_response_classifier = None
1919

20+
_initialized = False
2021

21-
def closest_topic(text, classifier, topics):
22+
23+
def closest_topic(text, classifier=None, topics=None):
24+
if not _initialized:
25+
init()
26+
classifier = classifier or _classifier
27+
topics = topics or _topics
2228
if classifier is None:
2329
raise ValueError("Topics - classifier model not initialized")
2430
return classifier(text, topics, multi_label=False)["labels"][0]
@@ -41,6 +47,8 @@ def init(
4147
response_model_path: Optional[str] = None,
4248
response_topic_classifier: Optional[str] = None,
4349
):
50+
global _initialized
51+
_initialized = True
4452
global _registered
4553
unregister_udfs(_registered)
4654
config = config or deepcopy(lang_config)

langkit/toxicity.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,18 @@
1313
_response_toxicity_tokenizer = None
1414
_response_toxicity_pipeline = None
1515

16+
_initialized = False
17+
1618
PROMPT_TRANSLATOR: Optional[Translator] = None
1719
RESPONSE_TRANSLATOR: Optional[Translator] = None
1820
TRANSLATOR: Optional[Translator] = None
1921

2022

2123
def toxicity(text: str, pipeline, tokenizer) -> float:
24+
if not _initialized:
25+
init()
26+
pipeline = pipeline or _toxicity_pipeline
27+
tokenizer = tokenizer or _toxicity_tokenizer
2228
if pipeline is None or tokenizer is None:
2329
raise ValueError("toxicity score must initialize the pipeline first")
2430

@@ -41,6 +47,8 @@ def init(
4147
config: Optional[LangKitConfig] = None,
4248
response_model_path: Optional[str] = None,
4349
):
50+
global _initialized
51+
_initialized = True
4452
global _registered
4553
unregister_udfs(_registered)
4654
from transformers import (

0 commit comments

Comments
(0)