Skip to content

Commit dd6269b

Browse files
pre-commit
1 parent 368468b commit dd6269b

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

langkit/response_hallucination.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from logging import getLogger
33
from typing import List, Optional
44
from whylogs.experimental.core.udf_schema import register_dataset_udf
5-
from langkit import lang_config, prompt_column, response_column
5+
from langkit import LangKitConfig, lang_config, prompt_column, response_column
66
from nltk.tokenize import sent_tokenize
77
from langkit.openai.openai import LLMInvocationParams, Conversation, ChatLog
88
from langkit.transformer import Encoder
@@ -238,7 +238,7 @@ def consistency_check(
238238

239239
def response_hallucination(text):
240240
series_result = []
241-
for prompt, response in zip(text[_prompt], text[_response]):
241+
for prompt, response in zip(text[prompt_column], text[response_column]):
242242
result: ConsistencyResult = checker.consistency_check(prompt, response)
243243
series_result.append(result.final_score)
244244
return series_result
@@ -250,11 +250,16 @@ def consistency_check(prompt: str, response: Optional[str] = None):
250250
else:
251251
raise Exception("You need to call init() before using this function")
252252

253-
253+
254254
checker: Optional[ConsistencyChecker] = None
255255

256256

257-
def init(config: Optional[LangKitConfig] = None, llm: LLMInvocationParams, num_samples=1):
257+
def init(
258+
config: Optional[LangKitConfig] = None,
259+
llm: LLMInvocationParams = LLMInvocationParams(),
260+
num_samples=1,
261+
):
262+
config = config or lang_config
258263
global checker, embeddings_encoder
259264
import nltk
260265

@@ -264,6 +269,6 @@ def init(config: Optional[LangKitConfig] = None, llm: LLMInvocationParams, num_s
264269
)
265270
embeddings_encoder = Encoder(config.response_transformer_name, custom_encoder=None)
266271
checker = ConsistencyChecker(llm, num_samples, embeddings_encoder)
267-
register_dataset_udf([prompt_column, response_column], f"{response_column}.hallucination")(
268-
response_hallucination
269-
)
272+
register_dataset_udf(
273+
[prompt_column, response_column], f"{response_column}.hallucination"
274+
)(response_hallucination)

0 commit comments

Comments
 (0)