16 changes: 16 additions & 0 deletions README.md
@@ -93,6 +93,22 @@ CUSTOM_LLM_CLIENT_SECRET = "your-client-secret"
Your custom service must implement the expected OAuth2 client‑credentials flow and provide JSON endpoints
for listing models, obtaining completions, and fetching tokens as used by `CustomLLMService`.

#### Optional provider threads (conversation IDs)

For deployments using a custom LLM service, you can enable provider‑side threads to cache conversation context between turns. This is optional and disabled by default. Even when enabled, the LMS/XBlock remains the canonical chat history, which preserves vendor flexibility and continuity; provider threads are treated purely as a cache.

- Site configuration (under `ai_eval`):
- `PROVIDER_SUPPORTS_THREADS`: boolean, default `false`. When `true`, `CustomLLMService` attempts to reuse a provider conversation ID.
- XBlock user state (managed automatically):
- `thread_map`: a dictionary mapping `tag -> conversation_id`, where `tag = provider:model:prompt_hash`. This allows multiple concurrent provider threads per learner per XBlock, one per distinct prompt/model context.

Reset clears `thread_map`. If a provider ignores threads, behavior remains stateless.
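
For illustration, a learner's user state for one block might end up looking like the sketch below. The model name and conversation IDs are placeholders, and the keys follow the `provider:model:prompt_hash` tag format described above:

```python
# Per-learner, per-XBlock user state (managed automatically by the block).
# Two entries arise when the same learner hits two distinct prompt/model contexts.
thread_map = {
    "custom:my-model:<sha256-of-prompt-A>": "conv-123",
    "custom:my-model:<sha256-of-prompt-B>": "conv-456",
}
```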

Compatibility and fallback:
- Not all vendors/models support `conversation_id`. The default service path (via LiteLLM chat completions) does not use provider threads; calls remain stateless.
- If threads are unsupported or ignored by a provider, the code still works and behaves statelessly.
- With a custom provider that supports threads, the first turn sends full context and later turns send only the latest user input along with the cached `conversation_id` (see the sketch below).
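
As a rough sketch of the difference, using the payload fields from `CustomLLMService` in this PR (model name, prompt text, and conversation ID are placeholders):

```python
# First turn: the full conversation is flattened into the prompt; no conversation_id yet.
first_turn_payload = {
    "model": "my-model",
    "prompt": "System: <evaluation prompt and question> User: <first answer>",
}

# Later turns: only the latest user input, plus the conversation_id cached in thread_map.
follow_up_payload = {
    "model": "my-model",
    "prompt": "User: <follow-up answer>",
    "conversation_id": "conv-123",
}
```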

### Custom Code Execution Service (advanced)

The Coding XBlock can route code execution to a third‑party service instead of Judge0. The service is expected to be asynchronous, exposing a submit endpoint that returns a submission identifier, and a results endpoint that returns the execution result when available. Configure this via Django settings:
33 changes: 29 additions & 4 deletions ai_eval/base.py
@@ -6,7 +6,7 @@

from django.utils.translation import gettext_noop as _
from xblock.core import XBlock
from xblock.fields import String, Scope
from xblock.fields import String, Scope, Dict
from xblock.utils.resources import ResourceLoader
from xblock.utils.studio_editable import StudioEditableXBlockMixin
from xblock.validation import ValidationMessage
@@ -78,6 +78,11 @@ class AIEvalXBlock(StudioEditableXBlockMixin, XBlock):
default="",
values_provider=_get_model_choices,
)
thread_map = Dict(
help=_("Map of provider thread IDs keyed by tag"),
default={},
scope=Scope.user_state,
)

editable_fields = (
"display_name",
@@ -206,6 +211,26 @@ def validate_field_data(self, validation, data):
)
)

def get_llm_response(self, messages):
return get_llm_response(self.model, self.get_model_api_key(), messages,
self.get_model_api_url())
def get_llm_response(self, messages, tag: str | None = None):
"""
Call the shared LLM entrypoint and return only the response text.
"""
prior_thread_id = None
if tag:
try:
prior_thread_id = (self.thread_map or {}).get(tag) or None
except Exception: # pylint: disable=broad-exception-caught
prior_thread_id = None

text, new_thread_id = get_llm_response(
self.model,
self.get_model_api_key(),
messages,
self.get_model_api_url(),
thread_id=prior_thread_id,
)
if tag and new_thread_id:
tm = dict(getattr(self, "thread_map", {}) or {})
tm[tag] = new_thread_id
self.thread_map = tm
return text
24 changes: 19 additions & 5 deletions ai_eval/llm.py
@@ -50,8 +50,8 @@ def get_llm_service():


def get_llm_response(
model: str, api_key: str, messages: list, api_base: str
) -> str:
model: str, api_key: str, messages: list, api_base: str, thread_id: str | None = None
) -> tuple[str, str | None]:
"""
Get LLM response, using either the default or custom service based on site configuration.
@@ -80,8 +80,22 @@ def get_llm_response(
API request URL. This is required only when using Llama which doesn't have an official provider.
Returns:
str: The response text from the LLM. This is typically the generated output based on the provided
messages.
tuple[str, str | None]: The response text and a new thread id if a provider thread was created or used.
"""
llm_service = get_llm_service()
return llm_service.get_response(model, api_key, messages, api_base)
allow_threads = False
try:
allow_threads = bool(llm_service.supports_threads())
except Exception: # pylint: disable=broad-exception-caught
allow_threads = False

# Continue threaded conversation when a thread_id is provided
if thread_id:
return llm_service.get_response(model, api_key, messages, api_base, thread_id=thread_id)

# Start a new thread only when the site configuration flag allows it
if allow_threads:
return llm_service.start_thread(model, api_key, messages, api_base)

# Stateless call - do not create or persist any conversation id
return llm_service.get_response(model, api_key, messages, api_base, thread_id=None)
122 changes: 116 additions & 6 deletions ai_eval/llm_services.py
@@ -6,6 +6,7 @@

from litellm import completion
from .supported_models import SupportedModels
from .compat import get_site_configuration_value

logger = logging.getLogger(__name__)

@@ -16,30 +17,79 @@ class LLMServiceBase:
"""
Base class for llm service.
"""
def get_response(self, model, api_key, messages, api_base):
# pylint: disable=too-many-positional-arguments
def get_response(self, model, api_key, messages, api_base, thread_id=None):
"""
Get a response from the provider.

Args:
model (str): Model identifier.
api_key (str): API key (for default providers or passthrough).
messages (list[dict]): Chat messages.
api_base (str|None): Optional base URL (e.g., for llama/ollama).
thread_id (str|None): Optional provider-side conversation/thread id.

Returns:
tuple[str, str|None]: (response_text, optional_thread_id)
"""
raise NotImplementedError

# pylint: disable=too-many-positional-arguments
def start_thread(self, model, api_key, messages, api_base):
"""
Start a new provider-side thread and return its first response.

Return the provider-issued conversation/thread id if available.

Returns:
tuple[str, str|None]: (response_text, new_thread_id)
"""
raise NotImplementedError

def get_available_models(self):
raise NotImplementedError

def supports_threads(self) -> bool:
"""
Check if this service supports provider-side threads.

Default is False; custom services can override to be flag-driven.
"""
return False


class DefaultLLMService(LLMServiceBase):
"""
Default llm service.
"""
def get_response(self, model, api_key, messages, api_base):
# pylint: disable=too-many-positional-arguments
def get_response(
self,
model,
api_key,
messages,
api_base,
thread_id=None,
):
kwargs = {}
if api_base:
kwargs["api_base"] = api_base
return (
text = (
completion(model=model, api_key=api_key, messages=messages, **kwargs)
.choices[0]
.message.content
)
return text, None

def start_thread(self, model, api_key, messages, api_base): # pragma: no cover - trivial passthrough
return self.get_response(model, api_key, messages, api_base, thread_id=None)

def get_available_models(self):
return [str(m.value) for m in SupportedModels]

def supports_threads(self) -> bool: # pragma: no cover - default is stateless
return False


class CustomLLMService(LLMServiceBase):
"""
@@ -81,12 +131,61 @@ def _get_headers(self):
self._ensure_token()
return {'Authorization': f'Bearer {self._access_token}'}

def get_response(self, model, api_key, messages, api_base):
def get_response(
self,
model,
api_key,
messages,
api_base,
thread_id=None,
):
"""
Send completion request to custom LLM endpoint.
If thread_id is provided, include it and send only the latest user input.
If thread_id is None, send full context and return (text, None).
"""
url = self.completions_url
# When reusing an existing thread, only send the latest user input and rely on
# the provider to apply prior context associated with the conversation_id.
if thread_id:
latest_user = None
for msg in reversed(messages):
if (msg.get('role') or '').lower() == 'user':
latest_user = msg.get('content', '').strip()
break
prompt = f"User: {latest_user}" if latest_user is not None else ""
else:
prompt = " ".join(
f"{msg.get('role', '').capitalize()}: {msg.get('content', '').strip()}"
for msg in messages
)
# Adjust the payload structure based on custom API requirements
payload = {
"model": str(model),
"prompt": prompt,
}

if thread_id:
payload["conversation_id"] = thread_id

response = requests.post(url, json=payload, headers=self._get_headers(), timeout=10)
response.raise_for_status()
data = response.json()
# Adjust this if custom API returns the response differently
text = data.get("response")

if thread_id:
new_thread_id = data.get("conversation_id")
if not new_thread_id and isinstance(data.get("data"), dict):
new_thread_id = data["data"].get("conversation_id")
return text, new_thread_id
return text, None

def start_thread(self, model, api_key, messages, api_base):
"""
Start new thread.
"""
url = self.completions_url
prompt = " ".join(
f"{msg.get('role', '').capitalize()}: {msg.get('content', '').strip()}"
for msg in messages
@@ -98,8 +197,11 @@ def get_response(self, model, api_key, messages, api_base):
response = requests.post(url, json=payload, headers=self._get_headers(), timeout=10)
response.raise_for_status()
data = response.json()
# Adjust this if custom API returns the response differently
return data.get("response")
text = data.get("response")
new_thread_id = data.get("conversation_id")
if not new_thread_id and isinstance(data.get("data"), dict):
new_thread_id = data["data"].get("conversation_id")
return text, new_thread_id

def get_available_models(self):
url = self.models_url
@@ -153,3 +255,11 @@ def get_available_models(self):
exc_info=True,
)
return []

def supports_threads(self) -> bool:
"""Return whether provider threads should be used, from site flag."""
try:
val = get_site_configuration_value("ai_eval", "PROVIDER_SUPPORTS_THREADS")
return bool(val)
except Exception: # pylint: disable=broad-exception-caught
return False
30 changes: 26 additions & 4 deletions ai_eval/shortanswer.py
@@ -1,6 +1,7 @@
"""Short answers Xblock with AI evaluation."""

import logging
import hashlib
import urllib.parse
import urllib.request
from multiprocessing.dummy import Pool
@@ -14,6 +15,8 @@
from xblock.validation import ValidationMessage

from .base import AIEvalXBlock
from .llm import get_llm_service
from .llm_services import CustomLLMService


logger = logging.getLogger(__name__)
@@ -178,15 +181,33 @@ def get_response(self, data, suffix=""): # pylint: disable=unused-argument
user_submission = str(data["user_input"])

attachments = []
attachment_hash_inputs = []
for filename, contents in self._get_attachments():
# Build system prompt attachment section (HTML-like) as before
attachments.append(f"""
<attachment>
<filename>{saxutils.escape(filename)}</filename>
<contents>{saxutils.escape(contents)}</contents>
</attachment>
""")
# For tagging, hash filename + contents
attachment_hash_inputs.append(f"{filename}|{contents}")
attachments = '\n'.join(attachments)

# Compute a tag to identify compatible reuse across provider/model/prompt
# Include evaluation prompt, question, and attachment content hashes
prompt_hasher = hashlib.sha256()
prompt_hasher.update((self.evaluation_prompt or "").strip().encode("utf-8"))
prompt_hasher.update((self.question or "").strip().encode("utf-8"))
for item in attachment_hash_inputs:
prompt_hasher.update(item.encode("utf-8"))
prompt_hash = prompt_hasher.hexdigest()

# Determine provider tag based on service type
llm_service = get_llm_service()
provider_tag = "custom" if isinstance(llm_service, CustomLLMService) else "default"
current_tag = f"{provider_tag}:{self.model}:{prompt_hash}"

system_msg = {
"role": "system",
"content": f"""
@@ -210,18 +231,18 @@ def get_response(self, data, suffix=""): # pylint: disable=unused-argument
messages.append({"role": "user", "content": user_submission})

try:
response = self.get_llm_response(messages)
text = self.get_llm_response(messages, tag=current_tag)
except Exception as e:
logger.error(
f"Failed while making LLM request using model {self.model}. Error: {e}",
exc_info=True,
)
raise JsonHandlerError(500, "A problem occurred. Please retry.") from e

if response:
if text:
self.messages[self.USER_KEY].append(user_submission)
self.messages[self.LLM_KEY].append(response)
return {"response": response}
self.messages[self.LLM_KEY].append(text)
return {"response": text}

raise JsonHandlerError(500, "A problem occurred. The LLM sent an empty response.")

@@ -233,6 +254,7 @@ def reset(self, data, suffix=""):
if not self.allow_reset:
raise JsonHandlerError(403, "Reset is disabled.")
self.messages = {self.USER_KEY: [], self.LLM_KEY: []}
self.thread_map = {}
return {}

@staticmethod