
Commit 30e6c2b

[BB-9988] Add option to use conversation_id for custom llm service (#16)

* feat: Add option to use conversation_id for custom service

1 parent 98ea5b1 · commit 30e6c2b

File tree

6 files changed: +193 −25 lines

- README.md
- ai_eval/base.py
- ai_eval/llm.py
- ai_eval/llm_services.py
- ai_eval/shortanswer.py
- ai_eval/tests/test_ai_eval.py

README.md

Lines changed: 16 additions & 0 deletions
```diff
@@ -93,6 +93,22 @@ CUSTOM_LLM_CLIENT_SECRET = "your-client-secret"
 Your custom service must implement the expected OAuth2 client‑credentials flow and provide JSON endpoints
 for listing models, obtaining completions, and fetching tokens as used by `CustomLLMService`.
 
+#### Optional provider threads (conversation IDs)
+
+For deployments using a custom LLM service, you can enable provider‑side threads to cache context between turns. This is optional and disabled by default. When enabled, the LMS/XBlock remains the canonical chat history, which ensures vendor flexibility and continuity; provider threads are treated only as a cache.
+
+- Site configuration (under `ai_eval`):
+  - `PROVIDER_SUPPORTS_THREADS`: boolean, default `false`. When `true`, `CustomLLMService` attempts to reuse a provider conversation ID.
+- XBlock user state (managed automatically):
+  - `thread_map`: a dictionary mapping `tag -> conversation_id`, where `tag = provider:model:prompt_hash`. This allows multiple concurrent provider threads per learner per XBlock, one per distinct prompt/model context.
+
+Reset clears `thread_map`. If a provider ignores threads, behavior remains stateless.
+
+Compatibility and fallback:
+- Not all vendors/models support `conversation_id`. The default service path (via LiteLLM chat completions) does not use provider threads; calls remain stateless.
+- If threads are unsupported or ignored by a provider, the code still works and behaves statelessly.
+- With a custom provider that supports threads, the first turn sends full context and later turns send only the latest user input along with the cached `conversation_id`.
+
 ### Custom Code Execution Service (advanced)
 
 The Coding XBlock can route code execution to a third‑party service instead of Judge0. The service is expected to be asynchronous, exposing a submit endpoint that returns a submission identifier, and a results endpoint that returns the execution result when available. Configure this via Django settings:
```
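
To make the configuration concrete, here is a minimal sketch of what the site configuration and the resulting per-learner state might look like. The `PROVIDER_SUPPORTS_THREADS` flag and the `thread_map` field come from this commit; the model name, hash, and conversation ID below are invented for illustration.

```python
# Hypothetical SiteConfiguration values for the `ai_eval` namespace
# (normally set through the Open edX SiteConfiguration admin, shown here as a dict).
site_configuration = {
    "ai_eval": {
        "PROVIDER_SUPPORTS_THREADS": True,  # opt in to provider-side threads
    }
}

# Illustrative shape of the per-learner `thread_map` XBlock user state,
# keyed by tag = provider:model:prompt_hash (values are made up).
thread_map = {
    "custom:my-model:9f2c1a7d0b33": "conv-12345",
}
```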

ai_eval/base.py

Lines changed: 29 additions & 4 deletions
```diff
@@ -7,7 +7,7 @@
 
 from django.utils.translation import gettext_noop as _
 from xblock.core import XBlock
-from xblock.fields import String, Scope
+from xblock.fields import String, Scope, Dict
 from xblock.utils.resources import ResourceLoader
 from xblock.utils.studio_editable import StudioEditableXBlockMixin
 from xblock.validation import ValidationMessage
@@ -97,6 +97,11 @@ class AIEvalXBlock(StudioEditableXBlockMixin, XBlock):
         default="",
         values_provider=_get_model_choices,
     )
+    thread_map = Dict(
+        help=_("Map of provider thread IDs keyed by tag"),
+        default={},
+        scope=Scope.user_state,
+    )
 
     editable_fields = (
         "display_name",
@@ -234,6 +239,26 @@ def validate_field_data(self, validation, data):
                 )
             )
 
-    def get_llm_response(self, messages):
-        return get_llm_response(self.model, self.get_model_api_key(), messages,
-                                self.get_model_api_url())
+    def get_llm_response(self, messages, tag: str | None = None):
+        """
+        Call the shared LLM entrypoint and return only the response text.
+        """
+        prior_thread_id = None
+        if tag:
+            try:
+                prior_thread_id = (self.thread_map or {}).get(tag) or None
+            except Exception:  # pylint: disable=broad-exception-caught
+                prior_thread_id = None
+
+        text, new_thread_id = get_llm_response(
+            self.model,
+            self.get_model_api_key(),
+            messages,
+            self.get_model_api_url(),
+            thread_id=prior_thread_id,
+        )
+        if tag and new_thread_id:
+            tm = dict(getattr(self, "thread_map", {}) or {})
+            tm[tag] = new_thread_id
+            self.thread_map = tm
+        return text
```
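
As a rough usage sketch (not part of the diff), the new `tag` parameter is meant to be exercised from inside an XBlock handler roughly as follows; the messages and tag value are invented for illustration.

```python
# Inside an XBlock handler, where self is an AIEvalXBlock subclass -- illustrative only.
messages = [
    {"role": "system", "content": "You are a strict grader."},
    {"role": "user", "content": "My answer is 42."},
]
tag = "custom:my-model:abc123"  # provider:model:prompt_hash, invented here

# First call with this tag: thread_map has no entry, so thread_id=None is passed
# down. If the service returns a conversation ID, it is cached in self.thread_map[tag].
text = self.get_llm_response(messages, tag=tag)

# Later calls with the same tag pass the cached conversation ID as thread_id,
# so a thread-aware provider can rely on its own stored context.
text = self.get_llm_response(messages, tag=tag)
```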

ai_eval/llm.py

Lines changed: 19 additions & 5 deletions
```diff
@@ -50,8 +50,8 @@ def get_llm_service():
 
 
 def get_llm_response(
-    model: str, api_key: str, messages: list, api_base: str
-) -> str:
+    model: str, api_key: str, messages: list, api_base: str, thread_id: str | None = None
+) -> tuple[str, str | None]:
     """
     Get LLM response, using either the default or custom service based on site configuration.
 
@@ -80,8 +80,22 @@ def get_llm_response(
         API request URL. This is required only when using Llama which doesn't have an official provider.
 
     Returns:
-        str: The response text from the LLM. This is typically the generated output based on the provided
-        messages.
+        tuple[str, str | None]: The response text and a new thread id if a provider thread was created/used.
     """
     llm_service = get_llm_service()
-    return llm_service.get_response(model, api_key, messages, api_base)
+    allow_threads = False
+    try:
+        allow_threads = bool(llm_service.supports_threads())
+    except Exception:  # pylint: disable=broad-exception-caught
+        allow_threads = False
+
+    # Continue the threaded conversation when a thread_id is provided
+    if thread_id:
+        return llm_service.get_response(model, api_key, messages, api_base, thread_id=thread_id)
+
+    # Start a new thread only when allowed by the config setting
+    if allow_threads:
+        return llm_service.start_thread(model, api_key, messages, api_base)
+
+    # Stateless call - do not create or persist any conversation id
+    return llm_service.get_response(model, api_key, messages, api_base, thread_id=None)
```
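
The routing above boils down to three paths. The sketch below summarizes them and shows a stateless call; the model name and key are placeholders, not values used by the project.

```python
# Decision summary for the modified get_llm_response():
#
#   thread_id argument | service.supports_threads() | call made
#   ------------------ | -------------------------- | ----------------------------------------
#   "conv-123"         | (result unused)            | get_response(..., thread_id="conv-123")
#   None               | True                       | start_thread(...)
#   None               | False                      | get_response(..., thread_id=None)

from ai_eval.llm import get_llm_response  # as modified in this commit

text, new_thread_id = get_llm_response(
    model="my-model",        # placeholder model name
    api_key="dummy-key",     # placeholder key
    messages=[{"role": "user", "content": "Hello"}],
    api_base=None,
    thread_id=None,          # stateless unless the active service supports threads
)
```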

ai_eval/llm_services.py

Lines changed: 89 additions & 8 deletions
```diff
@@ -6,6 +6,7 @@
 
 from litellm import completion
 from .supported_models import SupportedModels
+from .compat import get_site_configuration_value
 
 logger = logging.getLogger(__name__)
 
@@ -18,27 +19,70 @@ class LLMServiceBase:
     """
     Base class for llm service.
     """
-    def get_response(self, model, api_key, messages, api_base):
+    # pylint: disable=too-many-positional-arguments
+    def get_response(self, model, api_key, messages, api_base, thread_id=None):
+        """
+        Get a response from the provider.
+
+        Args:
+            model (str): Model identifier.
+            api_key (str): API key (for default providers or passthrough).
+            messages (list[dict]): Chat messages.
+            api_base (str|None): Optional base URL (e.g., for llama/ollama).
+            thread_id (str|None): Optional provider-side conversation/thread id.
+
+        Returns:
+            tuple[str, str|None]: (response_text, optional_thread_id)
+        """
+        raise NotImplementedError
+
+    # pylint: disable=too-many-positional-arguments
+    def start_thread(self, model, api_key, messages, api_base):
+        """
+        Start a new provider-side thread and return its first response.
+
+        Return the provider-issued conversation/thread id if available.
+
+        Returns:
+            tuple[str, str|None]: (response_text, new_thread_id)
+        """
         raise NotImplementedError
 
     def get_available_models(self):
         raise NotImplementedError
 
+    def supports_threads(self) -> bool:
+        """
+        Check if this service supports provider-side threads.
+
+        Default is False; custom services can override to be flag-driven.
+        """
+        return False
+
 
 class DefaultLLMService(LLMServiceBase):
     """
     Default llm service.
     """
-    def get_response(self, model, api_key, messages, api_base):
+    # pylint: disable=too-many-positional-arguments
+    def get_response(
+        self,
+        model,
+        api_key,
+        messages,
+        api_base,
+        thread_id=None,
+    ):
         kwargs = {}
         if api_base:
             kwargs["api_base"] = api_base
         try:
-            return (
-                completion(model=model, api_key=api_key, messages=messages, timeout=30, **kwargs)
+            text = (
+                completion(model=model, api_key=api_key, messages=messages, **kwargs)
                 .choices[0]
                 .message.content
             )
+            return text, None
         except Exception as e:
             if "timeout" in str(e).lower():
                 raise Exception(TIMEOUT_ERROR_MESSAGE) from e
@@ -47,6 +91,9 @@ def get_response(self, model, api_key, messages, api_base):
     def get_available_models(self):
         return [str(m.value) for m in SupportedModels]
 
+    def supports_threads(self) -> bool:  # pragma: nocover - default is stateless
+        return False
+
 
 class CustomLLMService(LLMServiceBase):
     """
@@ -88,11 +135,34 @@ def _get_headers(self):
         self._ensure_token()
         return {'Authorization': f'Bearer {self._access_token}'}
 
-    def get_response(self, model, api_key, messages, api_base):
+    def get_response(
+        self,
+        model,
+        api_key,
+        messages,
+        api_base,
+        thread_id=None,
+    ):
         """
         Send completion request to custom LLM endpoint.
+        If thread_id is provided, include it and send only the latest user input.
+        If thread_id is None, send full context and return (text, None).
         """
         url = self.completions_url
+        # When reusing an existing thread, only send the latest user input and rely on
+        # the provider to apply prior context associated with the conversation_id.
+        if thread_id:
+            latest_user = None
+            for msg in reversed(messages):
+                if (msg.get('role') or '').lower() == 'user':
+                    latest_user = msg.get('content', '').strip()
+                    break
+            prompt = f"User: {latest_user}" if latest_user is not None else ""
+        else:
+            prompt = " ".join(
+                f"{msg.get('role', '').capitalize()}: {msg.get('content', '').strip()}"
+                for msg in messages
+            )
         # Adjust the payload structure based on custom API requirements
         prompt = " ".join(
             f"{msg.get('role', '').capitalize()}: {msg.get('content', '').strip()}"
@@ -103,11 +173,14 @@ def get_response(self, model, api_key, messages, api_base):
             "prompt": prompt,
         }
         try:
-            response = requests.post(url, json=payload, headers=self._get_headers(), timeout=30)
+            response = requests.post(url, json=payload, headers=self._get_headers(), timeout=10)
             response.raise_for_status()
             data = response.json()
-            # Adjust this if custom API returns the response differently
-            return data.get("response")
+            text = data.get("response")
+            new_thread_id = data.get("conversation_id")
+            if not new_thread_id and isinstance(data.get("data"), dict):
+                new_thread_id = data["data"].get("conversation_id")
+            return text, new_thread_id
         except requests.exceptions.Timeout:
             raise Exception(TIMEOUT_ERROR_MESSAGE)  # pylint: disable=raise-missing-from
 
@@ -163,3 +236,11 @@ def get_available_models(self):
                 exc_info=True,
             )
         return []
+
+    def supports_threads(self) -> bool:
+        """Return whether provider threads should be used, from site flag."""
+        try:
+            val = get_site_configuration_value("ai_eval", "PROVIDER_SUPPORTS_THREADS")
+            return bool(val)
+        except Exception:  # pylint: disable=broad-exception-caught
+            return False
```
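
From the request construction and response parsing above, the custom completions endpoint is expected to exchange JSON roughly as follows. The field names `model`, `prompt`, `response`, `conversation_id`, and the nested `data` object are taken from the code; the URL and concrete values are invented.

```python
# Illustrative request/response exchange with a custom completions endpoint.
import requests

payload = {
    "model": "my-model",                                  # placeholder
    "prompt": "System: Grade strictly. User: My answer",  # flattened chat, as built above
}
resp = requests.post(
    "https://llm.example.com/v1/completions",             # hypothetical URL
    json=payload,
    headers={"Authorization": "Bearer dummy-token"},
    timeout=10,
)
data = resp.json()
# Expected response shape (either top-level or nested under "data"):
#   {"response": "Good answer.", "conversation_id": "conv-12345"}
text = data.get("response")
thread_id = data.get("conversation_id")
if not thread_id and isinstance(data.get("data"), dict):
    thread_id = data["data"].get("conversation_id")
```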

ai_eval/shortanswer.py

Lines changed: 26 additions & 5 deletions
```diff
@@ -1,6 +1,7 @@
 """Short answers Xblock with AI evaluation."""
 
 import logging
+import hashlib
 import urllib.parse
 import urllib.request
 from multiprocessing.dummy import Pool
@@ -14,7 +15,8 @@
 from xblock.validation import ValidationMessage
 
 from .base import AIEvalXBlock
-from .llm_services import TIMEOUT_ERROR_MESSAGE
+from .llm import get_llm_service
+from .llm_services import CustomLLMService, TIMEOUT_ERROR_MESSAGE
 
 
 logger = logging.getLogger(__name__)
@@ -179,15 +181,33 @@ def get_response(self, data, suffix=""):  # pylint: disable=unused-argument
         user_submission = str(data["user_input"])
 
         attachments = []
+        attachment_hash_inputs = []
         for filename, contents in self._get_attachments():
+            # Build the system prompt attachment section (HTML-like) as before
             attachments.append(f"""
            <attachment>
            <filename>{saxutils.escape(filename)}</filename>
            <contents>{saxutils.escape(contents)}</contents>
            </attachment>
            """)
+            # For tagging, hash filename + contents
+            attachment_hash_inputs.append(f"{filename}|{contents}")
         attachments = '\n'.join(attachments)
 
+        # Compute a tag to identify compatible reuse across provider/model/prompt.
+        # Include the evaluation prompt, the question, and the attachment contents.
+        prompt_hasher = hashlib.sha256()
+        prompt_hasher.update((self.evaluation_prompt or "").strip().encode("utf-8"))
+        prompt_hasher.update((self.question or "").strip().encode("utf-8"))
+        for item in attachment_hash_inputs:
+            prompt_hasher.update(item.encode("utf-8"))
+        prompt_hash = prompt_hasher.hexdigest()
+
+        # Determine the provider tag based on the service type
+        llm_service = get_llm_service()
+        provider_tag = "custom" if isinstance(llm_service, CustomLLMService) else "default"
+        current_tag = f"{provider_tag}:{self.model}:{prompt_hash}"
+
         system_msg = {
             "role": "system",
             "content": f"""
@@ -211,7 +231,7 @@ def get_response(self, data, suffix=""):  # pylint: disable=unused-argument
         messages.append({"role": "user", "content": user_submission})
 
         try:
-            response = self.get_llm_response(messages)
+            text = self.get_llm_response(messages, tag=current_tag)
         except Exception as e:
             logger.error(
                 f"Failed while making LLM request using model {self.model}. Error: {e}",
@@ -221,10 +241,10 @@ def get_response(self, data, suffix=""):  # pylint: disable=unused-argument
                 raise JsonHandlerError(500, str(e)) from e
             raise JsonHandlerError(500, "A problem occurred. Please retry.") from e
 
-        if response:
+        if text:
             self.messages[self.USER_KEY].append(user_submission)
-            self.messages[self.LLM_KEY].append(response)
-            return {"response": response}
+            self.messages[self.LLM_KEY].append(text)
+            return {"response": text}
 
         raise JsonHandlerError(500, "A problem occurred. The LLM sent an empty response.")
 
@@ -236,6 +256,7 @@ def reset(self, data, suffix=""):
         if not self.allow_reset:
             raise JsonHandlerError(403, "Reset is disabled.")
         self.messages = {self.USER_KEY: [], self.LLM_KEY: []}
+        self.thread_map = {}
         return {}
 
     @staticmethod
```
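
For clarity, the tag built inline in `get_response` above can be read as the following standalone helper. This function does not exist in the module; it is only a paraphrase of the inline code, and the example values are invented.

```python
import hashlib


def compute_thread_tag(provider_tag, model, evaluation_prompt, question, attachments):
    """Hypothetical standalone version of the tag computed inline in get_response().

    provider_tag: "custom" or "default", depending on the active LLM service.
    attachments:  iterable of (filename, contents) pairs.
    """
    hasher = hashlib.sha256()
    hasher.update((evaluation_prompt or "").strip().encode("utf-8"))
    hasher.update((question or "").strip().encode("utf-8"))
    for filename, contents in attachments:
        hasher.update(f"{filename}|{contents}".encode("utf-8"))
    return f"{provider_tag}:{model}:{hasher.hexdigest()}"


# Example: the same prompt, question, model, and attachments always yield the same
# tag; learners are still isolated because thread_map is per-learner user state.
tag = compute_thread_tag(
    "custom", "my-model", "Grade strictly.", "What is 6 * 7?",
    [("notes.txt", "42 is the answer")],
)
```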

ai_eval/tests/test_ai_eval.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -102,8 +102,11 @@ def test_shortanswer_reset_allowed(shortanswer_block_data):
         "messages": {"USER": ["Hello"], "LLM": ["Hello"]},
     }
     block = ShortAnswerAIEvalXBlock(ToyRuntime(), DictFieldData(data), None)
+    # Pre-populate thread map to verify reset clears it
+    block.thread_map = {"provider:model:tag": "abc123"}
     block.reset.__wrapped__(block, data={})
     assert block.messages == {"USER": [], "LLM": []}
+    assert not block.thread_map
 
 
 def test_shortanswer_reset_forbidden(shortanswer_block_data):
@@ -140,9 +143,17 @@ def test_shortanswer_attachments(shortanswer_block_data):
     }
     block = ShortAnswerAIEvalXBlock(ToyRuntime(), DictFieldData(data), None)
     block._download_attachment = Mock(return_value="file contents <&>")
-    block.get_llm_response = Mock(return_value=".")
-    block.get_response.__wrapped__(block, data={"user_input": "."})
-    messages = block.get_llm_response.call_args.args[0]
+    with patch('ai_eval.shortanswer.get_llm_service') as mock_service, \
+            patch('ai_eval.llm.get_llm_service') as mock_llm_service, \
+            patch('ai_eval.base.get_site_configuration_value', return_value=None), \
+            patch('ai_eval.base.get_llm_response') as mocked:
+        mock_service.return_value = Mock()
+        mock_service.return_value.supports_threads.return_value = False
+        mock_llm_service.return_value = mock_service.return_value
+        mocked.return_value = (".", None)
+        block.get_response.__wrapped__(block, data={"user_input": "."})
+        # Extract the messages argument passed into get_llm_response
+        messages = mocked.call_args.kwargs.get('messages') or mocked.call_args.args[2]
     prompt = messages[0]["content"]
     assert "<filename>1.txt</filename>" in prompt
     assert "<contents>file contents &lt;&amp;&gt;</contents>" in prompt
```
