
Commit cb1adf8

server : handle failures to restore host cache (#17078)
* server : handle failures to restore host cache
* server : add tests for the prompt cache
1 parent ef1d826 · commit cb1adf8

2 files changed: +45, -0 lines changed

tools/server/server.cpp

Lines changed: 3 additions & 0 deletions
@@ -1690,6 +1690,9 @@ struct server_slot {
         bool res = prompt_cache.load(prompt, tokens, ctx, id);
         if (!res) {
             SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
+
+            llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
+            prompt.tokens.clear();
         }
     }
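The intent of the hunk above, as far as the commit message and the code indicate, is that a failed prompt_cache.load() should leave the slot in a clean state: the sequence is removed from the model memory (the -1, -1 arguments appear to span all positions) and the slot's cached token list is cleared, so the next request re-processes the prompt from scratch instead of decoding against partially restored state. A minimal, self-contained Python sketch of that fall-back pattern is shown below; the names (HostCache, Slot, restore_or_reset) are illustrative only and are not llama.cpp APIs.

    # Illustrative sketch of the "discard partial state on failed restore" pattern.
    # HostCache and Slot are hypothetical stand-ins, not llama.cpp types.

    class Slot:
        def __init__(self):
            self.memory = []   # stands in for the sequence's entries in the model memory
            self.tokens = []   # stands in for prompt.tokens

    class HostCache:
        def load(self, slot: "Slot") -> bool:
            # simulate a restore that fails after partially populating the slot
            slot.memory = [101, 102, 103]
            slot.tokens = [101, 102, 103]
            return False

    def restore_or_reset(cache: HostCache, slot: Slot) -> None:
        if not cache.load(slot):
            print("failed to load prompt from cache")
            # mirror the patch: drop the partially restored sequence and the cached
            # tokens so the next request re-processes the full prompt
            slot.memory.clear()
            slot.tokens.clear()

    slot = Slot()
    restore_or_reset(HostCache(), slot)
    assert slot.memory == [] and slot.tokens == []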

tools/server/tests/unit/test_completion.py

Lines changed: 42 additions & 0 deletions
@@ -1,6 +1,8 @@
 import pytest
 import requests
 import time
+import random
+
 from openai import OpenAI
 from utils import *
 
@@ -564,3 +566,43 @@ def test_cancel_request():
     time.sleep(1) # wait for HTTP_POLLING_SECONDS
     res = server.make_request("GET", "/slots")
     assert res.body[0]["is_processing"] == False
+
+
+# this test exercises the host-memory prompt cache
+# ref: https://github.com/ggml-org/llama.cpp/pull/16391
+# ref: https://github.com/ggml-org/llama.cpp/pull/17078
+def test_completion_prompt_cache():
+    global server
+    server.n_slots = 2
+    server.kv_unified = True
+    server.start()
+
+    for _ in range(16):
+        # generate alternating random prompts with variable lengths in order to get them in and out of the cache
+        r = random.randint(0, 4)
+        prompt = (" Hello " + str(r)) * (40 + r)
+        n_prompt = (40 + r)*5 + 2
+        n_predict = random.randint(1, 8)
+
+        res = server.make_request(
+            "POST",
+            "/completion",
+            data={
+                "prompt": prompt,
+                "n_predict": n_predict,
+            },
+        )
+
+        assert res.status_code == 200
+        assert "content" in res.body
+        content = res.body["content"]
+        assert isinstance(content, str)
+        assert len(content) > 0
+
+        assert type(res.body["has_new_line"]) == bool
+        assert "timings" in res.body
+        timings = res.body["timings"]
+
+        assert "prompt_n" in timings and timings["prompt_n"] + timings["cache_n"] == n_prompt
+        assert "predicted_n" in timings and timings["predicted_n"] == n_predict
+        assert "tokens" in res.body and isinstance(res.body["tokens"], list)
