2
2
from logging import getLogger
3
3
from typing import List , Optional
4
4
from whylogs .experimental .core .udf_schema import register_dataset_udf
5
- from langkit import lang_config , prompt_column , response_column
5
+ from langkit import LangKitConfig , lang_config , prompt_column , response_column
6
6
from nltk .tokenize import sent_tokenize
7
7
from langkit .openai .openai import LLMInvocationParams , Conversation , ChatLog
8
8
from langkit .transformer import Encoder
@@ -238,7 +238,7 @@ def consistency_check(
238
238
239
239
def response_hallucination (text ):
240
240
series_result = []
241
- for prompt , response in zip (text [_prompt ], text [_response ]):
241
+ for prompt , response in zip (text [prompt_column ], text [response_column ]):
242
242
result : ConsistencyResult = checker .consistency_check (prompt , response )
243
243
series_result .append (result .final_score )
244
244
return series_result
@@ -250,11 +250,16 @@ def consistency_check(prompt: str, response: Optional[str] = None):
250
250
else :
251
251
raise Exception ("You need to call init() before using this function" )
252
252
253
-
253
+
254
254
checker : Optional [ConsistencyChecker ] = None
255
255
256
256
257
- def init (config : Optional [LangKitConfig ] = None , llm : LLMInvocationParams , num_samples = 1 ):
257
+ def init (
258
+ config : Optional [LangKitConfig ] = None ,
259
+ llm : LLMInvocationParams = LLMInvocationParams (),
260
+ num_samples = 1 ,
261
+ ):
262
+ config = config or lang_config
258
263
global checker , embeddings_encoder
259
264
import nltk
260
265
@@ -264,6 +269,6 @@ def init(config: Optional[LangKitConfig] = None, llm: LLMInvocationParams, num_s
264
269
)
265
270
embeddings_encoder = Encoder (config .response_transformer_name , custom_encoder = None )
266
271
checker = ConsistencyChecker (llm , num_samples , embeddings_encoder )
267
- register_dataset_udf ([ prompt_column , response_column ], f" { response_column } .hallucination" )(
268
- response_hallucination
269
- )
272
+ register_dataset_udf (
273
+ [ prompt_column , response_column ], f" { response_column } .hallucination"
274
+ )( response_hallucination )
0 commit comments