diff --git a/README.md b/README.md index c762927..c915d28 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,22 @@ CUSTOM_LLM_CLIENT_SECRET = "your-client-secret" Your custom service must implement the expected OAuth2 client‑credentials flow and provide JSON endpoints for listing models, obtaining completions, and fetching tokens as used by `CustomLLMService`. +#### Optional provider threads (conversation IDs) + +For deployments using a custom LLM service, you can enable provider‑side threads to cache context between turns. This is optional and disabled by default. When enabled, the LMS/XBlock remains the canonical chat history as that ensures vendor flexibility and continuity; provider threads are treated as a cache. + +- Site configuration (under `ai_eval`): + - `PROVIDER_SUPPORTS_THREADS`: boolean, default `false`. When `true`, `CustomLLMService` attempts to reuse a provider conversation ID. +- XBlock user state (managed automatically): + - `thread_map`: a dictionary mapping `tag -> conversation_id`, where `tag = provider:model:prompt_hash`. This allows multiple concurrent provider threads per learner per XBlock, one per distinct prompt/model context. + +Reset clears `thread_map`. If a provider ignores threads, behavior remains stateless. + +Compatibility and fallback +- Not all vendors/models support `conversation_id`. The default service path (via LiteLLM chat completions) does not use provider threads; calls remain stateless. +- If threads are unsupported or ignored by a provider, the code still works and behaves statelessly. +- With a custom provider that supports threads, the first turn sends full context and later turns send only the latest user input along with the cached `conversation_id`. + ### Custom Code Execution Service (advanced) The Coding XBlock can route code execution to a third‑party service instead of Judge0. The service is expected to be asynchronous, exposing a submit endpoint that returns a submission identifier, and a results endpoint that returns the execution result when available. Configure this via Django settings: diff --git a/ai_eval/base.py b/ai_eval/base.py index 43db5ac..fc0f036 100644 --- a/ai_eval/base.py +++ b/ai_eval/base.py @@ -6,7 +6,7 @@ from django.utils.translation import gettext_noop as _ from xblock.core import XBlock -from xblock.fields import String, Scope +from xblock.fields import String, Scope, Dict from xblock.utils.resources import ResourceLoader from xblock.utils.studio_editable import StudioEditableXBlockMixin from xblock.validation import ValidationMessage @@ -78,6 +78,11 @@ class AIEvalXBlock(StudioEditableXBlockMixin, XBlock): default="", values_provider=_get_model_choices, ) + thread_map = Dict( + help=_("Map of provider thread IDs keyed by tag"), + default={}, + scope=Scope.user_state, + ) editable_fields = ( "display_name", @@ -206,6 +211,26 @@ def validate_field_data(self, validation, data): ) ) - def get_llm_response(self, messages): - return get_llm_response(self.model, self.get_model_api_key(), messages, - self.get_model_api_url()) + def get_llm_response(self, messages, tag: str | None = None): + """ + Call the shared LLM entrypoint and return only the response text. 
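+
+        If ``tag`` is given, any provider thread id previously cached in
+        ``thread_map`` for that tag is passed through to the provider, and a
+        newly issued thread id is stored back under the same tag.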
+ """ + prior_thread_id = None + if tag: + try: + prior_thread_id = (self.thread_map or {}).get(tag) or None + except Exception: # pylint: disable=broad-exception-caught + prior_thread_id = None + + text, new_thread_id = get_llm_response( + self.model, + self.get_model_api_key(), + messages, + self.get_model_api_url(), + thread_id=prior_thread_id, + ) + if tag and new_thread_id: + tm = dict(getattr(self, "thread_map", {}) or {}) + tm[tag] = new_thread_id + self.thread_map = tm + return text diff --git a/ai_eval/llm.py b/ai_eval/llm.py index 98a55e8..307f716 100644 --- a/ai_eval/llm.py +++ b/ai_eval/llm.py @@ -50,8 +50,8 @@ def get_llm_service(): def get_llm_response( - model: str, api_key: str, messages: list, api_base: str -) -> str: + model: str, api_key: str, messages: list, api_base: str, thread_id: str | None = None +) -> tuple[str, str | None]: """ Get LLM response, using either the default or custom service based on site configuration. @@ -80,8 +80,22 @@ def get_llm_response( API request URL. This is required only when using Llama which doesn't have an official provider. Returns: - str: The response text from the LLM. This is typically the generated output based on the provided - messages. + tuple[str, Optional[str]]: The response text and a new thread id if a provider thread was created/used. """ llm_service = get_llm_service() - return llm_service.get_response(model, api_key, messages, api_base) + allow_threads = False + try: + allow_threads = bool(llm_service.supports_threads()) + except Exception: # pylint: disable=broad-exception-caught + allow_threads = False + + # Continue threaded conversation when a thread_id is provided + if thread_id: + return llm_service.get_response(model, api_key, messages, api_base, thread_id=thread_id) + + # Start a new thread only when allowed in config setting + if allow_threads: + return llm_service.start_thread(model, api_key, messages, api_base) + + # Stateless call - do not create or persist any conversation id + return llm_service.get_response(model, api_key, messages, api_base, thread_id=None) diff --git a/ai_eval/llm_services.py b/ai_eval/llm_services.py index c573873..83a2cba 100644 --- a/ai_eval/llm_services.py +++ b/ai_eval/llm_services.py @@ -6,6 +6,7 @@ from litellm import completion from .supported_models import SupportedModels +from .compat import get_site_configuration_value logger = logging.getLogger(__name__) @@ -16,30 +17,79 @@ class LLMServiceBase: """ Base class for llm service. """ - def get_response(self, model, api_key, messages, api_base): + # pylint: disable=too-many-positional-arguments + def get_response(self, model, api_key, messages, api_base, thread_id=None): + """ + Get a response from the provider. + + Args: + model (str): Model identifier. + api_key (str): API key (for default providers or passthrough). + messages (list[dict]): Chat messages. + api_base (str|None): Optional base URL (e.g., for llama/ollama). + thread_id (str|None): Optional provider-side conversation/thread id. + + Returns: + tuple[str, str|None]: (response_text, optional_thread_id) + """ + raise NotImplementedError + + # pylint: disable=too-many-positional-arguments + def start_thread(self, model, api_key, messages, api_base): + """ + Start a new provider-side thread and return its first response. + + Return the provider-issued conversation/thread id if available. 
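+        Implementations that do not support provider threads may simply fall
+        back to a stateless ``get_response`` call and return ``None`` as the
+        thread id.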
+ + Returns: + tuple[str, str|None]: (response_text, new_thread_id) + """ raise NotImplementedError def get_available_models(self): raise NotImplementedError + def supports_threads(self) -> bool: + """ + Check if this service supports provider-side threads. + + Default is False; custom services can override to be flag-driven. + """ + return False + class DefaultLLMService(LLMServiceBase): """ Default llm service. """ - def get_response(self, model, api_key, messages, api_base): + # pylint: disable=too-many-positional-arguments + def get_response( + self, + model, + api_key, + messages, + api_base, + thread_id=None, + ): kwargs = {} if api_base: kwargs["api_base"] = api_base - return ( + text = ( completion(model=model, api_key=api_key, messages=messages, **kwargs) .choices[0] .message.content ) + return text, None + + def start_thread(self, model, api_key, messages, api_base): # pragma: no cover - trivial passthrough + return self.get_response(model, api_key, messages, api_base, thread_id=None) def get_available_models(self): return [str(m.value) for m in SupportedModels] + def supports_threads(self) -> bool: # pragma: nocover - default is stateless + return False + class CustomLLMService(LLMServiceBase): """ @@ -81,12 +131,61 @@ def _get_headers(self): self._ensure_token() return {'Authorization': f'Bearer {self._access_token}'} - def get_response(self, model, api_key, messages, api_base): + def get_response( + self, + model, + api_key, + messages, + api_base, + thread_id=None, + ): """ Send completion request to custom LLM endpoint. + If thread_id is provided, include it and send only the latest user input. + If thread_id is None, send full context and return (text, None). """ url = self.completions_url + # When reusing an existing thread, only send the latest user input and rely on + # the provider to apply prior context associated with the conversation_id. + if thread_id: + latest_user = None + for msg in reversed(messages): + if (msg.get('role') or '').lower() == 'user': + latest_user = msg.get('content', '').strip() + break + prompt = f"User: {latest_user}" if latest_user is not None else "" + else: + prompt = " ".join( + f"{msg.get('role', '').capitalize()}: {msg.get('content', '').strip()}" + for msg in messages + ) # Adjust the payload structure based on custom API requirements + payload = { + "model": str(model), + "prompt": prompt, + } + + if thread_id: + payload["conversation_id"] = thread_id + + response = requests.post(url, json=payload, headers=self._get_headers(), timeout=10) + response.raise_for_status() + data = response.json() + # Adjust this if custom API returns the response differently + text = data.get("response") + + if thread_id: + new_thread_id = data.get("conversation_id") + if not new_thread_id and isinstance(data.get("data"), dict): + new_thread_id = data["data"].get("conversation_id") + return text, new_thread_id + return text, None + + def start_thread(self, model, api_key, messages, api_base): + """ + Start new thread. 
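+
+        Sends the full message history as a single prompt and returns the
+        provider-issued ``conversation_id`` (when present in the response) so
+        that later turns can send only the latest user input with that id.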
+ """ + url = self.completions_url prompt = " ".join( f"{msg.get('role', '').capitalize()}: {msg.get('content', '').strip()}" for msg in messages @@ -98,8 +197,11 @@ def get_response(self, model, api_key, messages, api_base): response = requests.post(url, json=payload, headers=self._get_headers(), timeout=10) response.raise_for_status() data = response.json() - # Adjust this if custom API returns the response differently - return data.get("response") + text = data.get("response") + new_thread_id = data.get("conversation_id") + if not new_thread_id and isinstance(data.get("data"), dict): + new_thread_id = data["data"].get("conversation_id") + return text, new_thread_id def get_available_models(self): url = self.models_url @@ -153,3 +255,11 @@ def get_available_models(self): exc_info=True, ) return [] + + def supports_threads(self) -> bool: + """Return whether provider threads should be used, from site flag.""" + try: + val = get_site_configuration_value("ai_eval", "PROVIDER_SUPPORTS_THREADS") + return bool(val) + except Exception: # pylint: disable=broad-exception-caught + return False diff --git a/ai_eval/shortanswer.py b/ai_eval/shortanswer.py index bbc301b..fe40666 100644 --- a/ai_eval/shortanswer.py +++ b/ai_eval/shortanswer.py @@ -1,6 +1,7 @@ """Short answers Xblock with AI evaluation.""" import logging +import hashlib import urllib.parse import urllib.request from multiprocessing.dummy import Pool @@ -14,6 +15,8 @@ from xblock.validation import ValidationMessage from .base import AIEvalXBlock +from .llm import get_llm_service +from .llm_services import CustomLLMService logger = logging.getLogger(__name__) @@ -178,15 +181,33 @@ def get_response(self, data, suffix=""): # pylint: disable=unused-argument user_submission = str(data["user_input"]) attachments = [] + attachment_hash_inputs = [] for filename, contents in self._get_attachments(): + # Build system prompt attachment section (HTML-like) as before attachments.append(f""" {saxutils.escape(filename)} {saxutils.escape(contents)} """) + # For tagging, hash filename + contents + attachment_hash_inputs.append(f"{filename}|{contents}") attachments = '\n'.join(attachments) + # Compute a tag to identify compatible reuse across provider/model/prompt + # Include evaluation prompt, question, and attachment content hashes + prompt_hasher = hashlib.sha256() + prompt_hasher.update((self.evaluation_prompt or "").strip().encode("utf-8")) + prompt_hasher.update((self.question or "").strip().encode("utf-8")) + for item in attachment_hash_inputs: + prompt_hasher.update(item.encode("utf-8")) + prompt_hash = prompt_hasher.hexdigest() + + # Determine provider tag based on service type + llm_service = get_llm_service() + provider_tag = "custom" if isinstance(llm_service, CustomLLMService) else "default" + current_tag = f"{provider_tag}:{self.model}:{prompt_hash}" + system_msg = { "role": "system", "content": f""" @@ -210,7 +231,7 @@ def get_response(self, data, suffix=""): # pylint: disable=unused-argument messages.append({"role": "user", "content": user_submission}) try: - response = self.get_llm_response(messages) + text = self.get_llm_response(messages, tag=current_tag) except Exception as e: logger.error( f"Failed while making LLM request using model {self.model}. Error: {e}", @@ -218,10 +239,10 @@ def get_response(self, data, suffix=""): # pylint: disable=unused-argument ) raise JsonHandlerError(500, "A probem occured. 
Please retry.") from e - if response: + if text: self.messages[self.USER_KEY].append(user_submission) - self.messages[self.LLM_KEY].append(response) - return {"response": response} + self.messages[self.LLM_KEY].append(text) + return {"response": text} raise JsonHandlerError(500, "A probem occured. The LLM sent an empty response.") @@ -233,6 +254,7 @@ def reset(self, data, suffix=""): if not self.allow_reset: raise JsonHandlerError(403, "Reset is disabled.") self.messages = {self.USER_KEY: [], self.LLM_KEY: []} + self.thread_map = {} return {} @staticmethod diff --git a/ai_eval/tests/test_ai_eval.py b/ai_eval/tests/test_ai_eval.py index 2243bf9..148736c 100644 --- a/ai_eval/tests/test_ai_eval.py +++ b/ai_eval/tests/test_ai_eval.py @@ -102,8 +102,11 @@ def test_shortanswer_reset_allowed(shortanswer_block_data): "messages": {"USER": ["Hello"], "LLM": ["Hello"]}, } block = ShortAnswerAIEvalXBlock(ToyRuntime(), DictFieldData(data), None) + # Pre-populate thread map to verify reset clears it + block.thread_map = {"provider:model:tag": "abc123"} block.reset.__wrapped__(block, data={}) assert block.messages == {"USER": [], "LLM": []} + assert not block.thread_map def test_shortanswer_reset_forbidden(shortanswer_block_data): @@ -140,9 +143,17 @@ def test_shortanswer_attachments(shortanswer_block_data): } block = ShortAnswerAIEvalXBlock(ToyRuntime(), DictFieldData(data), None) block._download_attachment = Mock(return_value="file contents <&>") - block.get_llm_response = Mock(return_value=".") - block.get_response.__wrapped__(block, data={"user_input": "."}) - messages = block.get_llm_response.call_args.args[0] + with patch('ai_eval.shortanswer.get_llm_service') as mock_service, \ + patch('ai_eval.llm.get_llm_service') as mock_llm_service, \ + patch('ai_eval.base.get_site_configuration_value', return_value=None), \ + patch('ai_eval.base.get_llm_response') as mocked: + mock_service.return_value = Mock() + mock_service.return_value.supports_threads.return_value = False + mock_llm_service.return_value = mock_service.return_value + mocked.return_value = (".", None) + block.get_response.__wrapped__(block, data={"user_input": "."}) + # Extract the messages argument passed into get_llm_response + messages = mocked.call_args.kwargs.get('messages') or mocked.call_args.args[2] prompt = messages[0]["content"] assert "1.txt" in prompt assert "file contents <&>" in prompt