Skip to content

Commit e8de761

Browse files
authored
Merge pull request #23 from AnswerDotAI/fix-cache-control
Fix cache control
2 parents bcba403 + 5ce06a0 commit e8de761

File tree

4 files changed

+677
-1384
lines changed

4 files changed

+677
-1384
lines changed

cachy.jsonl

Lines changed: 129 additions & 0 deletions
Large diffs are not rendered by default.

lisette/_modidx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
'lisette.core._add_cache_control': ('core.html#_add_cache_control', 'lisette/core.py'),
1818
'lisette.core._alite_call_func': ('core.html#_alite_call_func', 'lisette/core.py'),
1919
'lisette.core._clean_str': ('core.html#_clean_str', 'lisette/core.py'),
20+
'lisette.core._has_cache': ('core.html#_has_cache', 'lisette/core.py'),
2021
'lisette.core._has_search': ('core.html#_has_search', 'lisette/core.py'),
2122
'lisette.core._is_img': ('core.html#_is_img', 'lisette/core.py'),
2223
'lisette.core._lite_call_func': ('core.html#_lite_call_func', 'lisette/core.py'),

lisette/core.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,8 @@ def _is_img(data):
8383
return isinstance(data, bytes) and bool(imghdr.what(None, data))
8484

8585
def _add_cache_control(msg, # LiteLLM formatted msg
86-
cache=False, # Enable Anthropic caching
8786
ttl=None): # Cache TTL: '5m' (default) or '1h'
8887
"cache `msg` with default time-to-live (ttl) of 5minutes ('5m'), but can be set to '1h'."
89-
if not cache: return msg
9088
if isinstance(msg["content"], str):
9189
msg["content"] = [{"type": "text", "text": msg["content"]}]
9290
cache_control = {"type": "ephemeral"}
@@ -95,10 +93,12 @@ def _add_cache_control(msg, # LiteLLM formatted msg
9593
msg["content"][-1]["cache_control"] = cache_control
9694
return msg
9795

96+
def _has_cache(msg):
97+
return msg["content"] and isinstance(msg["content"], list) and ('cache_control' in msg["content"][-1])
98+
9899
def _remove_cache_ckpts(msg):
99-
"remove unnecessary cache checkpoints."
100-
if isinstance(msg["content"], list) and msg["content"]:
101-
msg["content"][-1].pop('cache_control', None)
100+
"remove cache checkpoints and return msg."
101+
if _has_cache(msg): msg["content"][-1].pop('cache_control', None)
102102
return msg
103103

104104
def _mk_content(o):
@@ -117,23 +117,24 @@ def mk_msg(content, # Content: str, bytes (image), list of mixed content, o
117117
if isinstance(content, list) and len(content) == 1 and isinstance(content[0], str): c = content[0]
118118
elif isinstance(content, list): c = [_mk_content(o) for o in content]
119119
else: c = content
120-
return _add_cache_control({"role": role, "content": c}, cache=cache, ttl=ttl)
120+
msg = {"role": role, "content": c}
121+
return _add_cache_control(msg, ttl=ttl) if cache else msg
121122

122123
# %% ../nbs/00_core.ipynb
123-
def mk_msgs(msgs, # List of messages (each: str, bytes, list, or dict w 'role' and 'content' fields)
124-
cache=False, # Enable Anthropic caching
125-
ttl=None, # Cache TTL: '5m' (default) or '1h'
126-
cache_last_ckpt_only=True # Only cache the last message
124+
def mk_msgs(msgs, # List of messages (each: str, bytes, list, or dict w 'role' and 'content' fields)
125+
cache=False, # Enable Anthropic caching
126+
ttl=None, # Cache TTL: '5m' (default) or '1h'
127127
):
128128
"Create a list of LiteLLM compatible messages."
129129
if not msgs: return []
130130
if not isinstance(msgs, list): msgs = [msgs]
131131
res,role = [],'user'
132132
for m in msgs:
133-
res.append(msg:=mk_msg(m, role=role,cache=cache))
133+
res.append(msg:=mk_msg(m, role=role))
134134
role = 'assistant' if msg['role'] in ('user','function', 'tool') else 'user'
135-
if cache_last_ckpt_only: res = [_remove_cache_ckpts(m) for m in res]
136-
if res and cache: res[-1] = _add_cache_control(res[-1], cache=cache, ttl=ttl)
135+
if cache:
136+
res[-1] = _add_cache_control(res[-1], ttl)
137+
res[-2] = _add_cache_control(res[-2], ttl)
137138
return res
138139

139140
# %% ../nbs/00_core.ipynb
@@ -194,11 +195,12 @@ def __init__(
194195
tools:list=None, # Add tools
195196
hist:list=None, # Chat history
196197
ns:Optional[dict]=None, # Custom namespace for tool calling
197-
cache=False # Anthropic prompt caching
198+
cache=False, # Anthropic prompt caching
199+
ttl=None, # Anthropic prompt caching ttl
198200
):
199201
"LiteLLM chat client."
200202
self.model = model
201-
hist,tools = mk_msgs(hist),listify(tools)
203+
hist,tools = mk_msgs(hist,cache,ttl),listify(tools)
202204
if ns is None and tools: ns = mk_ns(tools)
203205
elif ns is None: ns = globals()
204206
self.tool_schemas = [lite_mk_func(t) for t in tools] if tools else None
@@ -207,7 +209,7 @@ def __init__(
207209
def _prep_msg(self, msg=None, prefill=None):
208210
"Prepare the messages list for the API call"
209211
sp = [{"role": "system", "content": self.sp}] if self.sp else []
210-
if msg: self.hist = mk_msgs(self.hist+[msg], cache=self.cache)
212+
if msg: self.hist = mk_msgs(self.hist+[msg], self.cache, self.ttl)
211213
pf = [{"role":"assistant","content":prefill}] if prefill else []
212214
return sp + self.hist + pf
213215

@@ -369,7 +371,7 @@ def _trunc_str(s, mx=2000, replace="…"):
369371
return s[:mx]+replace if len(s)>mx else s
370372

371373
# %% ../nbs/00_core.ipynb
372-
async def aformat_stream(rs):
374+
async def aformat_stream(rs, include_usage=False):
373375
"Format the response stream for markdown display."
374376
think = False
375377
async for o in rs:
@@ -382,9 +384,11 @@ async def aformat_stream(rs):
382384
think = False
383385
yield '\n\n'
384386
if c := d.content: yield c
385-
elif isinstance(o, ModelResponse) and (c := getattr(o.choices[0].message, 'tool_calls', None)):
386-
fn = first(c).function
387-
yield f"\n<details class='tool-usage-details'>\n\n `{fn.name}({_trunc_str(fn.arguments)})`\n"
387+
elif isinstance(o, ModelResponse):
388+
if include_usage: yield f"\nUsage: {o.usage}"
389+
if (c := getattr(o.choices[0].message, 'tool_calls', None)):
390+
fn = first(c).function
391+
yield f"\n<details class='tool-usage-details'>\n\n `{fn.name}({_trunc_str(fn.arguments)})`\n"
388392
elif isinstance(o, dict) and 'tool_call_id' in o:
389393
yield f" - `{_trunc_str(_clean_str(o.get('content')))}`\n\n</details>\n\n"
390394

0 commit comments

Comments
 (0)