Skip to content

Commit 7f9d643

Browse files
committed
Stop multiple authentications on single websocket
1 parent 4eb057e commit 7f9d643

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

oioioi/notifications/server/server.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@ def __init__(self, port: int, amqp_url: str, auth_url: str) -> None:
1717
self.logger = logging.getLogger('oioioi')
1818

1919
self.app.on_start(self.on_start)
20-
self.app.ws("/", {
21-
"upgrade": self.on_ws_upgrade,
22-
"message": self.on_ws_message,
23-
"close": self.on_ws_close,
24-
})
20+
self.app.ws(
21+
"/",
22+
{
23+
"upgrade": self.on_ws_upgrade,
24+
"message": self.on_ws_message,
25+
"close": self.on_ws_close,
26+
},
27+
)
2528

2629
def run(self) -> None:
2730
"""Start the notification server."""
@@ -34,9 +37,9 @@ async def on_start(self) -> None:
3437
await self.queue.connect()
3538

3639
def on_ws_upgrade(self, res, req, socket_context):
37-
"""
40+
"""
3841
Taken from socketify's documentation.
39-
This method allows for storing extra data inside the websocket object.
42+
This method allows for storing extra data inside the websocket object.
4043
"""
4144

4245
key = req.get_header("sec-websocket-key")
@@ -47,7 +50,9 @@ def on_ws_upgrade(self, res, req, socket_context):
4750

4851
res.upgrade(key, protocol, extensions, socket_context, user_data)
4952

50-
async def on_ws_message(self, ws: WebSocket, msg: Union[bytes, str], opcode: OpCode) -> None:
53+
async def on_ws_message(
54+
self, ws: WebSocket, msg: Union[bytes, str], opcode: OpCode
55+
) -> None:
5156
"""Handle incoming WebSocket messages."""
5257
try:
5358
data = json.loads(msg)
@@ -61,15 +66,17 @@ async def on_ws_message(self, ws: WebSocket, msg: Union[bytes, str], opcode: OpC
6166
except Exception as e:
6267
self.logger.error(f"Error processing message: {str(e)}")
6368

64-
async def on_ws_close(self, ws: WebSocket, code: int, msg: Union[bytes, str]) -> None:
69+
async def on_ws_close(
70+
self, ws: WebSocket, code: int, msg: Union[bytes, str]
71+
) -> None:
6572
"""Handle WebSocket connection closure."""
6673
try:
6774
user_id = ws.get_user_data()["user_id"]
6875

6976
# If there are no more active connections for this user, unsubscribe from the RabbitMQ queue
7077
if user_id and self.app.num_subscribers(user_id) == 0:
7178
await self.queue.unsubscribe(user_id)
72-
79+
7380
self.logger.debug(f"WebSocket closed for user {user_id}")
7481

7582
except Exception as e:
@@ -85,6 +92,10 @@ def on_rabbit_message(self, user_name: str, msg: str) -> None:
8592

8693
async def on_ws_auth_message(self, ws: WebSocket, session_id: str) -> None:
8794
try:
95+
current_user_id = ws.get_user_data()["user_id"]
96+
if current_user_id:
97+
raise RuntimeError("Socket is already authenticated.")
98+
8899
user_id = await self.auth.authenticate(session_id)
89100

90101
ws.subscribe(user_id)
@@ -96,6 +107,8 @@ async def on_ws_auth_message(self, ws: WebSocket, session_id: str) -> None:
96107

97108
except Exception as e:
98109
self.logger.error(
99-
f"Authentication error for session {session_id}: {str(e)}")
100-
ws.send({"type": "SOCKET_AUTH_RESULT",
101-
"status": "ERR_AUTH_FAILED"}, OpCode.TEXT)
110+
f"Authentication error for session {session_id}: {str(e)}"
111+
)
112+
ws.send(
113+
{"type": "SOCKET_AUTH_RESULT", "status": "ERR_AUTH_FAILED"}, OpCode.TEXT
114+
)

oioioi/notifications/server/tests.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ async def test_on_ws_message_auth_success(self, mock_auth, mock_queue, mock_app)
199199
mock_auth.return_value.authenticate.return_value = "user_id"
200200

201201
mock_ws = MagicMock()
202+
mock_ws.get_user_data.return_value = {"user_id": None}
202203
message = json.dumps({"type": "SOCKET_AUTH", "session_id": "session_id"})
203204
await server.on_ws_message(mock_ws, message, OpCode.TEXT)
204205

@@ -217,6 +218,7 @@ async def test_on_ws_message_auth_failure(self, mock_auth, mock_queue, mock_app)
217218
)
218219

219220
mock_ws = MagicMock()
221+
mock_ws.get_user_data.return_value = {"user_id": None}
220222
message = json.dumps({"type": "SOCKET_AUTH", "session_id": "session_id"})
221223
await server.on_ws_message(mock_ws, message, OpCode.TEXT)
222224

@@ -226,6 +228,24 @@ async def test_on_ws_message_auth_failure(self, mock_auth, mock_queue, mock_app)
226228
mock_ws.send.assert_called_once_with(
227229
{"type": "SOCKET_AUTH_RESULT", "status": "ERR_AUTH_FAILED"}, OpCode.TEXT
228230
)
231+
232+
async def test_on_ws_message_auth_multiple_times(self, mock_auth, mock_queue, mock_app):
233+
server = Server(self.port, self.amqp_url, self.auth_url)
234+
235+
mock_ws = MagicMock()
236+
# Set up the mock websocket to return user data indicating it's already authenticated
237+
mock_ws.get_user_data.return_value = {"user_id": "user_id"}
238+
239+
message = json.dumps({"type": "SOCKET_AUTH", "session_id": "session_id"})
240+
await server.on_ws_message(mock_ws, message, OpCode.TEXT)
241+
242+
# Verify authentication was not attempted again
243+
mock_auth.return_value.authenticate.assert_not_called()
244+
mock_queue.return_value.subscribe.assert_not_called()
245+
mock_ws.subscribe.assert_not_called()
246+
mock_ws.send.assert_called_once_with(
247+
{"type": "SOCKET_AUTH_RESULT", "status": "ERR_AUTH_FAILED"}, OpCode.TEXT
248+
)
229249

230250
async def test_on_ws_close(self, mock_auth, mock_queue, mock_app):
231251
server = Server(self.port, self.amqp_url, self.auth_url)

0 commit comments

Comments
 (0)