diff --git a/README.md b/README.md
index c762927..c915d28 100644
--- a/README.md
+++ b/README.md
@@ -93,6 +93,22 @@ CUSTOM_LLM_CLIENT_SECRET = "your-client-secret"
Your custom service must implement the expected OAuth2 client‑credentials flow and provide JSON endpoints
for listing models, obtaining completions, and fetching tokens as used by `CustomLLMService`.
+#### Optional provider threads (conversation IDs)
+
+For deployments using a custom LLM service, you can enable provider‑side threads so the provider caches conversation context between turns. This is optional and disabled by default. Even when enabled, the LMS/XBlock remains the canonical chat history, which preserves vendor flexibility and continuity; provider threads are treated only as a cache.
+
+- Site configuration (under `ai_eval`):
+ - `PROVIDER_SUPPORTS_THREADS`: boolean, default `false`. When `true`, `CustomLLMService` attempts to reuse a provider conversation ID.
+- XBlock user state (managed automatically):
+ - `thread_map`: a dictionary mapping `tag -> conversation_id`, where `tag = provider:model:prompt_hash`. This allows multiple concurrent provider threads per learner per XBlock, one per distinct prompt/model context.
+
+Resetting the XBlock clears `thread_map`. If a provider ignores thread IDs, behavior remains stateless.
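+
+A minimal sketch of the corresponding site configuration, assuming the `ai_eval` key holds a dictionary of settings (adjust to however your deployment stores site configuration values):
+
+```json
+{
+  "ai_eval": {
+    "PROVIDER_SUPPORTS_THREADS": true
+  }
+}
+```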
+
+**Compatibility and fallback**
+- Not all vendors/models support `conversation_id`. The default service path (via LiteLLM chat completions) does not use provider threads; calls remain stateless.
+- If a provider does not support threads or simply ignores them, the code still works and remains stateless.
+- With a custom provider that supports threads, the first turn sends the full context and later turns send only the latest user input along with the cached `conversation_id` (see the example below).
+
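+For illustration, using the payload and response keys this integration assumes (`model`, `prompt`, `conversation_id`, `response`; adjust them to your provider's actual schema), a turn after the first might send only the latest input plus the cached id:
+
+```json
+{"model": "my-model", "prompt": "User: <latest submission>", "conversation_id": "abc123"}
+```
+
+and expect a response of the form:
+
+```json
+{"response": "<feedback text>", "conversation_id": "abc123"}
+```
+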
### Custom Code Execution Service (advanced)
The Coding XBlock can route code execution to a third‑party service instead of Judge0. The service is expected to be asynchronous, exposing a submit endpoint that returns a submission identifier, and a results endpoint that returns the execution result when available. Configure this via Django settings:
diff --git a/ai_eval/base.py b/ai_eval/base.py
index 43db5ac..fc0f036 100644
--- a/ai_eval/base.py
+++ b/ai_eval/base.py
@@ -6,7 +6,7 @@
from django.utils.translation import gettext_noop as _
from xblock.core import XBlock
-from xblock.fields import String, Scope
+from xblock.fields import String, Scope, Dict
from xblock.utils.resources import ResourceLoader
from xblock.utils.studio_editable import StudioEditableXBlockMixin
from xblock.validation import ValidationMessage
@@ -78,6 +78,11 @@ class AIEvalXBlock(StudioEditableXBlockMixin, XBlock):
default="",
values_provider=_get_model_choices,
)
+ thread_map = Dict(
+ help=_("Map of provider thread IDs keyed by tag"),
+ default={},
+ scope=Scope.user_state,
+ )
editable_fields = (
"display_name",
@@ -206,6 +211,26 @@ def validate_field_data(self, validation, data):
)
)
- def get_llm_response(self, messages):
- return get_llm_response(self.model, self.get_model_api_key(), messages,
- self.get_model_api_url())
+ def get_llm_response(self, messages, tag: str | None = None):
+ """
+        Call the shared LLM entrypoint and return only the response text.
+
+        When ``tag`` is provided, reuse any provider thread id cached in ``thread_map``
+        under that tag and persist a newly issued id back to it.
+ """
+ prior_thread_id = None
+ if tag:
+ try:
+ prior_thread_id = (self.thread_map or {}).get(tag) or None
+ except Exception: # pylint: disable=broad-exception-caught
+ prior_thread_id = None
+
+ text, new_thread_id = get_llm_response(
+ self.model,
+ self.get_model_api_key(),
+ messages,
+ self.get_model_api_url(),
+ thread_id=prior_thread_id,
+ )
+ if tag and new_thread_id:
+ tm = dict(getattr(self, "thread_map", {}) or {})
+ tm[tag] = new_thread_id
+ self.thread_map = tm
+ return text
diff --git a/ai_eval/llm.py b/ai_eval/llm.py
index 98a55e8..307f716 100644
--- a/ai_eval/llm.py
+++ b/ai_eval/llm.py
@@ -50,8 +50,8 @@ def get_llm_service():
def get_llm_response(
- model: str, api_key: str, messages: list, api_base: str
-) -> str:
+ model: str, api_key: str, messages: list, api_base: str, thread_id: str | None = None
+) -> tuple[str, str | None]:
"""
Get LLM response, using either the default or custom service based on site configuration.
@@ -80,8 +80,22 @@ def get_llm_response(
API request URL. This is required only when using Llama which doesn't have an official provider.
Returns:
- str: The response text from the LLM. This is typically the generated output based on the provided
- messages.
+        tuple[str, str | None]: The response text, and the provider thread id if a thread was created or reused.
"""
llm_service = get_llm_service()
- return llm_service.get_response(model, api_key, messages, api_base)
+ allow_threads = False
+ try:
+ allow_threads = bool(llm_service.supports_threads())
+ except Exception: # pylint: disable=broad-exception-caught
+ allow_threads = False
+
+ # Continue threaded conversation when a thread_id is provided
+ if thread_id:
+ return llm_service.get_response(model, api_key, messages, api_base, thread_id=thread_id)
+
+    # Start a new thread only when the service reports thread support (driven by site configuration)
+ if allow_threads:
+ return llm_service.start_thread(model, api_key, messages, api_base)
+
+ # Stateless call - do not create or persist any conversation id
+ return llm_service.get_response(model, api_key, messages, api_base, thread_id=None)
diff --git a/ai_eval/llm_services.py b/ai_eval/llm_services.py
index c573873..83a2cba 100644
--- a/ai_eval/llm_services.py
+++ b/ai_eval/llm_services.py
@@ -6,6 +6,7 @@
from litellm import completion
from .supported_models import SupportedModels
+from .compat import get_site_configuration_value
logger = logging.getLogger(__name__)
@@ -16,30 +17,79 @@ class LLMServiceBase:
"""
Base class for llm service.
"""
- def get_response(self, model, api_key, messages, api_base):
+ # pylint: disable=too-many-positional-arguments
+ def get_response(self, model, api_key, messages, api_base, thread_id=None):
+ """
+ Get a response from the provider.
+
+ Args:
+ model (str): Model identifier.
+ api_key (str): API key (for default providers or passthrough).
+ messages (list[dict]): Chat messages.
+ api_base (str|None): Optional base URL (e.g., for llama/ollama).
+ thread_id (str|None): Optional provider-side conversation/thread id.
+
+ Returns:
+ tuple[str, str|None]: (response_text, optional_thread_id)
+ """
+ raise NotImplementedError
+
+ # pylint: disable=too-many-positional-arguments
+ def start_thread(self, model, api_key, messages, api_base):
+ """
+ Start a new provider-side thread and return its first response.
+
+ Return the provider-issued conversation/thread id if available.
+
+ Returns:
+ tuple[str, str|None]: (response_text, new_thread_id)
+ """
raise NotImplementedError
def get_available_models(self):
raise NotImplementedError
+ def supports_threads(self) -> bool:
+ """
+ Check if this service supports provider-side threads.
+
+ Default is False; custom services can override to be flag-driven.
+ """
+ return False
+
class DefaultLLMService(LLMServiceBase):
"""
Default llm service.
"""
- def get_response(self, model, api_key, messages, api_base):
+ # pylint: disable=too-many-positional-arguments
+ def get_response(
+ self,
+ model,
+ api_key,
+ messages,
+ api_base,
+ thread_id=None,
+ ):
kwargs = {}
if api_base:
kwargs["api_base"] = api_base
- return (
+ text = (
completion(model=model, api_key=api_key, messages=messages, **kwargs)
.choices[0]
.message.content
)
+ return text, None
+
+ def start_thread(self, model, api_key, messages, api_base): # pragma: no cover - trivial passthrough
+ return self.get_response(model, api_key, messages, api_base, thread_id=None)
def get_available_models(self):
return [str(m.value) for m in SupportedModels]
+    def supports_threads(self) -> bool: # pragma: no cover - default is stateless
+ return False
+
class CustomLLMService(LLMServiceBase):
"""
@@ -81,12 +131,61 @@ def _get_headers(self):
self._ensure_token()
return {'Authorization': f'Bearer {self._access_token}'}
- def get_response(self, model, api_key, messages, api_base):
+ def get_response(
+ self,
+ model,
+ api_key,
+ messages,
+ api_base,
+ thread_id=None,
+ ):
"""
Send completion request to custom LLM endpoint.
+
+        If thread_id is provided, include it in the payload and send only the latest
+        user input. If thread_id is None, send the full context and return (text, None).
"""
url = self.completions_url
+ # When reusing an existing thread, only send the latest user input and rely on
+ # the provider to apply prior context associated with the conversation_id.
+ if thread_id:
+ latest_user = None
+ for msg in reversed(messages):
+ if (msg.get('role') or '').lower() == 'user':
+ latest_user = msg.get('content', '').strip()
+ break
+ prompt = f"User: {latest_user}" if latest_user is not None else ""
+ else:
+ prompt = " ".join(
+ f"{msg.get('role', '').capitalize()}: {msg.get('content', '').strip()}"
+ for msg in messages
+ )
# Adjust the payload structure based on custom API requirements
+ payload = {
+ "model": str(model),
+ "prompt": prompt,
+ }
+
+ if thread_id:
+ payload["conversation_id"] = thread_id
+
+ response = requests.post(url, json=payload, headers=self._get_headers(), timeout=10)
+ response.raise_for_status()
+ data = response.json()
+ # Adjust this if custom API returns the response differently
+ text = data.get("response")
+
+ if thread_id:
+ new_thread_id = data.get("conversation_id")
+ if not new_thread_id and isinstance(data.get("data"), dict):
+ new_thread_id = data["data"].get("conversation_id")
+ return text, new_thread_id
+ return text, None
+
+ def start_thread(self, model, api_key, messages, api_base):
+ """
+        Start a new provider-side thread by sending the full conversation context and
+        return (response_text, new_thread_id).
+ """
+ url = self.completions_url
prompt = " ".join(
f"{msg.get('role', '').capitalize()}: {msg.get('content', '').strip()}"
for msg in messages
@@ -98,8 +197,11 @@ def get_response(self, model, api_key, messages, api_base):
response = requests.post(url, json=payload, headers=self._get_headers(), timeout=10)
response.raise_for_status()
data = response.json()
- # Adjust this if custom API returns the response differently
- return data.get("response")
+ text = data.get("response")
+ new_thread_id = data.get("conversation_id")
+ if not new_thread_id and isinstance(data.get("data"), dict):
+ new_thread_id = data["data"].get("conversation_id")
+ return text, new_thread_id
def get_available_models(self):
url = self.models_url
@@ -153,3 +255,11 @@ def get_available_models(self):
exc_info=True,
)
return []
+
+ def supports_threads(self) -> bool:
+ """Return whether provider threads should be used, from site flag."""
+ try:
+ val = get_site_configuration_value("ai_eval", "PROVIDER_SUPPORTS_THREADS")
+ return bool(val)
+ except Exception: # pylint: disable=broad-exception-caught
+ return False
diff --git a/ai_eval/shortanswer.py b/ai_eval/shortanswer.py
index bbc301b..fe40666 100644
--- a/ai_eval/shortanswer.py
+++ b/ai_eval/shortanswer.py
@@ -1,6 +1,7 @@
"""Short answers Xblock with AI evaluation."""
import logging
+import hashlib
import urllib.parse
import urllib.request
from multiprocessing.dummy import Pool
@@ -14,6 +15,8 @@
from xblock.validation import ValidationMessage
from .base import AIEvalXBlock
+from .llm import get_llm_service
+from .llm_services import CustomLLMService
logger = logging.getLogger(__name__)
@@ -178,15 +181,33 @@ def get_response(self, data, suffix=""): # pylint: disable=unused-argument
user_submission = str(data["user_input"])
attachments = []
+ attachment_hash_inputs = []
for filename, contents in self._get_attachments():
+            # Build the escaped attachment section that is embedded in the system prompt
attachments.append(f"""
{saxutils.escape(filename)}
{saxutils.escape(contents)}
""")
+ # For tagging, hash filename + contents
+ attachment_hash_inputs.append(f"{filename}|{contents}")
attachments = '\n'.join(attachments)
+ # Compute a tag to identify compatible reuse across provider/model/prompt
+ # Include evaluation prompt, question, and attachment content hashes
+ prompt_hasher = hashlib.sha256()
+ prompt_hasher.update((self.evaluation_prompt or "").strip().encode("utf-8"))
+ prompt_hasher.update((self.question or "").strip().encode("utf-8"))
+ for item in attachment_hash_inputs:
+ prompt_hasher.update(item.encode("utf-8"))
+ prompt_hash = prompt_hasher.hexdigest()
+
+ # Determine provider tag based on service type
+ llm_service = get_llm_service()
+ provider_tag = "custom" if isinstance(llm_service, CustomLLMService) else "default"
+ current_tag = f"{provider_tag}:{self.model}:{prompt_hash}"
+
system_msg = {
"role": "system",
"content": f"""
@@ -210,7 +231,7 @@ def get_response(self, data, suffix=""): # pylint: disable=unused-argument
messages.append({"role": "user", "content": user_submission})
try:
- response = self.get_llm_response(messages)
+ text = self.get_llm_response(messages, tag=current_tag)
except Exception as e:
logger.error(
f"Failed while making LLM request using model {self.model}. Error: {e}",
@@ -218,10 +239,10 @@ def get_response(self, data, suffix=""): # pylint: disable=unused-argument
)
raise JsonHandlerError(500, "A probem occured. Please retry.") from e
- if response:
+ if text:
self.messages[self.USER_KEY].append(user_submission)
- self.messages[self.LLM_KEY].append(response)
- return {"response": response}
+ self.messages[self.LLM_KEY].append(text)
+ return {"response": text}
raise JsonHandlerError(500, "A probem occured. The LLM sent an empty response.")
@@ -233,6 +254,7 @@ def reset(self, data, suffix=""):
if not self.allow_reset:
raise JsonHandlerError(403, "Reset is disabled.")
self.messages = {self.USER_KEY: [], self.LLM_KEY: []}
+ self.thread_map = {}
return {}
@staticmethod
diff --git a/ai_eval/tests/test_ai_eval.py b/ai_eval/tests/test_ai_eval.py
index 2243bf9..148736c 100644
--- a/ai_eval/tests/test_ai_eval.py
+++ b/ai_eval/tests/test_ai_eval.py
@@ -102,8 +102,11 @@ def test_shortanswer_reset_allowed(shortanswer_block_data):
"messages": {"USER": ["Hello"], "LLM": ["Hello"]},
}
block = ShortAnswerAIEvalXBlock(ToyRuntime(), DictFieldData(data), None)
+ # Pre-populate thread map to verify reset clears it
+ block.thread_map = {"provider:model:tag": "abc123"}
block.reset.__wrapped__(block, data={})
assert block.messages == {"USER": [], "LLM": []}
+ assert not block.thread_map
def test_shortanswer_reset_forbidden(shortanswer_block_data):
@@ -140,9 +143,17 @@ def test_shortanswer_attachments(shortanswer_block_data):
}
block = ShortAnswerAIEvalXBlock(ToyRuntime(), DictFieldData(data), None)
block._download_attachment = Mock(return_value="file contents <&>")
- block.get_llm_response = Mock(return_value=".")
- block.get_response.__wrapped__(block, data={"user_input": "."})
- messages = block.get_llm_response.call_args.args[0]
+ with patch('ai_eval.shortanswer.get_llm_service') as mock_service, \
+ patch('ai_eval.llm.get_llm_service') as mock_llm_service, \
+ patch('ai_eval.base.get_site_configuration_value', return_value=None), \
+ patch('ai_eval.base.get_llm_response') as mocked:
+ mock_service.return_value = Mock()
+ mock_service.return_value.supports_threads.return_value = False
+ mock_llm_service.return_value = mock_service.return_value
+ mocked.return_value = (".", None)
+ block.get_response.__wrapped__(block, data={"user_input": "."})
+ # Extract the messages argument passed into get_llm_response
+ messages = mocked.call_args.kwargs.get('messages') or mocked.call_args.args[2]
prompt = messages[0]["content"]
assert "1.txt" in prompt
assert "file contents <&>" in prompt