Skip to content

Commit be4551e

Browse files
committed
[python] Support session based sticky routing in async mode
1 parent ea2502d commit be4551e

File tree

11 files changed

+442
-23
lines changed

11 files changed

+442
-23
lines changed

.github/workflows/integration.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ jobs:
173173
- test: TestCorrectnessTrtLlm
174174
instance: g6
175175
failure-prefix: trtllm
176-
176+
- test: TestStatefulModel
177+
instance: g6
178+
failure-prefix: lmi
177179
outputs:
178180
failure_cpu: ${{ steps.test-failure.outputs.failure_cpu }}
179181
failure_gpu: ${{ steps.test-failure.outputs.failure_gpu }}

engines/python/setup/djl_python/async_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818

1919

2020
def create_non_stream_output(data: Union[str, dict],
21+
properties: Optional[dict] = None,
2122
error: Optional[str] = None,
2223
code: Optional[int] = None) -> Output:
2324
return _create_output(
2425
data,
2526
True,
2627
"application/json",
28+
properties=properties,
2729
error=error,
2830
code=code,
2931
)
@@ -46,6 +48,7 @@ def _create_output(
4648
data: Union[str, dict],
4749
last_chunk: bool,
4850
content_type: str,
51+
properties: Optional[dict] = None,
4952
error: Optional[str] = None,
5053
code: Optional[int] = None,
5154
) -> Output:
@@ -65,6 +68,9 @@ def _create_output(
6568
response_dict["code"] = code
6669
output = Output()
6770
output.add_property("Content-Type", content_type)
71+
if properties:
72+
for k, v in properties.items():
73+
output.add_property(k, v)
6874
output.add(Output.binary_encode(response_dict))
6975
return output
7076

engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@
3131
from djl_python.outputs import Output
3232
from djl_python.encode_decode import decode
3333
from djl_python.async_utils import handle_streaming_response, create_non_stream_output
34-
from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
35-
36-
from .request_response_utils import (
34+
from djl_python.custom_formatter_handling import CustomFormatterError, CustomFormatterHandler
35+
from djl_python.lmi_vllm.request_response_utils import (
3736
ProcessedRequest,
3837
vllm_stream_output_formatter,
3938
vllm_non_stream_output_formatter,
@@ -43,9 +42,15 @@
4342
lmi_with_details_non_stream_output_formatter,
4443
lmi_non_stream_output_formatter,
4544
)
45+
from djl_python.session_manager import SessionManager
46+
from djl_python.session_utils import (create_session, close_session,
47+
get_session,
48+
session_non_stream_output_formatter)
4649

4750
logger = logging.getLogger(__name__)
4851

52+
SESSION_REQUESTS = {"NEW_SESSION": create_session, "CLOSE": close_session}
53+
4954

5055
class VLLMHandler(CustomFormatterHandler):
5156

@@ -119,12 +124,15 @@ async def initialize(self, properties: dict):
119124
tool_parser=self.vllm_properties.tool_call_parser,
120125
reasoning_parser=self.vllm_properties.reasoning_parser,
121126
)
127+
if properties.get("enable_stateful_sessions", "true") == "true":
128+
self.session_manager: SessionManager = SessionManager(properties)
122129
self.initialized = True
123130

124131
def preprocess_request(self, inputs: Input) -> ProcessedRequest:
125132
batch = inputs.get_batches()
126133
assert len(batch) == 1, "only one request per batch allowed"
127134
raw_request = batch[0]
135+
session = get_session(self.session_manager, raw_request)
128136
content_type = raw_request.get_property("Content-Type")
129137
decoded_payload = decode(raw_request, content_type)
130138

@@ -160,6 +168,20 @@ def preprocess_request(self, inputs: Input) -> ProcessedRequest:
160168
vllm_invoke_function = self.chat_completion_service.create_chat_completion
161169
non_stream_output_formatter = vllm_non_stream_output_formatter
162170
stream_output_formatter = vllm_stream_output_formatter
171+
elif "requestType" in decoded_payload:
172+
request_type = decoded_payload["requestType"]
173+
if request_type not in SESSION_REQUESTS.keys():
174+
raise RuntimeError(
175+
f"invalid payload. request type must be one of {SESSION_REQUESTS.keys()}"
176+
)
177+
if self.session_manager is None:
178+
raise RuntimeError(
179+
f"invalid payload. stateful sessions not enabled, {request_type} not supported"
180+
)
181+
vllm_request = self.session_manager, inputs
182+
vllm_invoke_function = SESSION_REQUESTS[request_type]
183+
non_stream_output_formatter = session_non_stream_output_formatter
184+
stream_output_formatter = vllm_stream_output_formatter
163185
else:
164186
raise RuntimeError(
165187
"invalid payload. must contain prompt, inputs, or messages")

engines/python/setup/djl_python/session_manager.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,11 @@ def get(self, key: str, d=None):
5353
return pickle.load(f)
5454

5555
def remove(self):
56-
if os.path.exists(self.files_path):
57-
logging.info(f"closing session: {self.session_id}")
58-
shutil.rmtree(self.files_path)
59-
return True
60-
else:
61-
logging.warning(f"session not found: {self.session_id}")
62-
return False
56+
if not os.path.exists(self.files_path):
57+
raise ValueError(f"session not found: {self.session_id}")
58+
logging.info(f"closing session: {self.session_id}")
59+
shutil.rmtree(self.files_path)
60+
return True
6361

6462
def _path(self, key: str):
6563
return os.path.join(self.files_path, key.replace("/", "-"))
@@ -92,18 +90,32 @@ def create_session(self) -> Session:
9290
session_id = str(uuid.uuid4())
9391
session = Session(session_id, self.sessions_path)
9492
os.makedirs(session.files_path)
95-
session.put(".creation_time", time.time())
93+
expiration_time = time.time() + self.expiration
94+
session.put(".expiration_time", expiration_time)
9695

9796
self.cloud_watch.post("create_session")
9897
return session
9998

10099
def get_session(self, session_id: str) -> Optional[Session]:
101-
if not session_id or not UUID_PATTERN.match(session_id):
102-
raise ValueError(f"invalid session_id: {session_id}")
100+
if session_id == "NEW_SESSION" or not session_id:
101+
return None
103102

103+
if not UUID_PATTERN.match(session_id):
104+
logging.warning(f"invalid session_id: {session_id}")
105+
return None
106+
107+
# Session expired
104108
session = Session(session_id, self.sessions_path)
109+
if session.get(".expiration_time") is not None \
110+
and time.time() > session.get(".expiration_time"):
111+
return None
112+
113+
# Session not found, try to recover from s3 bucket
105114
if not os.path.exists(session.files_path):
106-
return self._recover_from_s3(session)
115+
session = self._recover_from_s3(session)
116+
if session is None:
117+
raise ValueError(f"session not found: {session_id}")
118+
return session
107119

108120
return session
109121

@@ -119,7 +131,8 @@ def _clean_expired_session(self):
119131
sessions = os.listdir(self.sessions_path)
120132
for session_id in sessions:
121133
session = Session(session_id, self.sessions_path)
122-
if time.time() - session.get(".creation_time") > self.expiration:
134+
if session.get(".expiration_time") is None \
135+
or time.time() > session.get(".expiration_time"):
123136
self.close_session(session_id)
124137

125138
def _recover_from_s3(self, session: Session) -> Optional[Session]:
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
import datetime
14+
import logging
15+
16+
from djl_python.async_utils import create_non_stream_output
17+
from djl_python.outputs import Output
18+
19+
logger = logging.getLogger(__name__)
20+
21+
HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id"
22+
HEADER_SAGEMAKER_CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id"
23+
24+
25+
async def create_session(request):
26+
session_manager, inputs = request
27+
try:
28+
session = session_manager.create_session()
29+
expiration_ts = datetime.datetime.fromtimestamp(
30+
session.get(".expiration_time")).strftime("%Y-%m-%dT%H:%M:%SZ")
31+
logger.info(f"Session {session.session_id} created")
32+
return {
33+
"data": {
34+
"result": f"Session {session.session_id} created"
35+
},
36+
"properties": {
37+
HEADER_SAGEMAKER_SESSION_ID:
38+
f"{session.session_id}; Expires={expiration_ts}"
39+
}
40+
}
41+
except Exception as e:
42+
return {"error": f"Failed to create session: {str(e)}", "code": 424}
43+
44+
45+
async def close_session(request):
46+
session_manager, inputs = request
47+
session_id = inputs.get_property(HEADER_SAGEMAKER_SESSION_ID)
48+
try:
49+
session_manager.close_session(session_id)
50+
logger.info(f"Session {session_id} closed")
51+
return {
52+
"data": {
53+
"result": f"Session {session_id} closed"
54+
},
55+
"properties": {
56+
HEADER_SAGEMAKER_CLOSED_SESSION_ID: f"{session_id}"
57+
}
58+
}
59+
except Exception as e:
60+
return {"error": f"Failed to close session: {str(e)}", "code": 424}
61+
62+
63+
def get_session(session_manager, request):
64+
session_id = request.get_property(HEADER_SAGEMAKER_SESSION_ID)
65+
if session_manager is None:
66+
if session_id is not None:
67+
raise RuntimeError(
68+
f"invalid payload. stateful sessions not enabled, {HEADER_SAGEMAKER_SESSION_ID} header not supported"
69+
)
70+
return None
71+
session = session_manager.get_session(session_id)
72+
return session
73+
74+
75+
def session_non_stream_output_formatter(
76+
response: dict,
77+
**_,
78+
) -> Output:
79+
if "error" in response:
80+
return create_non_stream_output("",
81+
error=response["error"],
82+
code=response["code"])
83+
84+
return create_non_stream_output(response["data"],
85+
properties=response.get("properties"))

engines/python/src/test/java/ai/djl/python/engine/PySessionsTests.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,23 @@ public void testLocalLoadSave()
8484

8585
// test session timeout
8686
Thread.sleep(1000);
87+
regular = new Input();
88+
regular.addProperty("Content-Type", "application/json");
89+
regular.addProperty("X-Amzn-SageMaker-Session-Id", sessionId);
90+
regular.add(BytesSupplier.wrapAsJson(Map.of("action", "regular")));
91+
ret = predictor.predict(regular);
92+
Assert.assertEquals(ret.getProperty("Content-Type", null), "application/json");
93+
Assert.assertTrue(ret.getAsString(0).contains("session not found"));
94+
long count;
95+
try (Stream<Path> files = Files.list(path)) {
96+
count = files.count();
97+
}
98+
Assert.assertEquals(count, 1);
99+
100+
// create a new session
87101
ret = predictor.predict(createSession);
88102
sessionId = ret.getProperty("X-Amzn-SageMaker-Session-Id", null);
89103
Assert.assertNotNull(sessionId);
90-
long count;
91104
try (Stream<Path> files = Files.list(path)) {
92105
count = files.count();
93106
}

0 commit comments

Comments
 (0)