Skip to content

Commit 416de87

Browse files
committed
Change PubSub usage and add make use of global connection object
1 parent fc3ce1c commit 416de87

File tree

4 files changed

+133
-123
lines changed

4 files changed

+133
-123
lines changed

pulpcore/tasking/pubsub.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,35 @@
22
from pulpcore.constants import TASK_PUBSUB
33
import os
44
import logging
5+
from django.db import connection
56
from contextlib import suppress
67

78
logger = logging.getLogger(__name__)
89

910

1011
class BasePubSubBackend:
1112
# Utils
12-
def wakeup_worker(self, reason="unknown"):
13-
self.publish(TASK_PUBSUB.WAKEUP_WORKER, reason)
13+
@classmethod
14+
def wakeup_worker(cls, reason="unknown"):
15+
cls.publish(TASK_PUBSUB.WAKEUP_WORKER, reason)
1416

15-
def cancel_task(self, task_pk):
16-
self.publish(TASK_PUBSUB.CANCEL_TASK, str(task_pk))
17+
@classmethod
18+
def cancel_task(cls, task_pk):
19+
cls.publish(TASK_PUBSUB.CANCEL_TASK, str(task_pk))
1720

18-
def record_worker_metrics(self, now):
19-
self.publish(TASK_PUBSUB.WORKER_METRICS, str(now))
21+
@classmethod
22+
def record_worker_metrics(cls, now):
23+
cls.publish(TASK_PUBSUB.WORKER_METRICS, str(now))
2024

2125
# Interface
22-
def subscribe(self, channel, callback):
26+
def subscribe(self, channel):
2327
raise NotImplementedError()
2428

2529
def unsubscribe(self, channel):
2630
raise NotImplementedError()
2731

28-
def publish(self, channel, message=None):
32+
@staticmethod
33+
def publish(channel, payload=None):
2934
raise NotImplementedError()
3035

3136
def fileno(self):
@@ -53,19 +58,19 @@ def drain_non_blocking_fd(fd):
5358

5459
class PostgresPubSub(BasePubSubBackend):
5560

56-
def __init__(self, connection):
57-
self.cursor = connection.cursor()
58-
self.connection = connection.connection
59-
assert self.cursor.connection is self.connection
61+
def __init__(self):
6062
self.subscriptions = []
6163
self.message_buffer = []
62-
self.connection.add_notify_handler(self._store_messages)
64+
# Ensures a connection is established
65+
if not connection.connection:
66+
with connection.cursor():
67+
pass
68+
connection.connection.add_notify_handler(self._store_messages)
6369
# Handle message readiness
6470
# We can use os.evenfd in python >= 3.10
6571
self.sentinel_r, self.sentinel_w = os.pipe()
6672
os.set_blocking(self.sentinel_r, False)
6773
os.set_blocking(self.sentinel_w, False)
68-
logger.debug(f"Initialized pubsub. Conn={self.connection}")
6974

7075
def _store_messages(self, notification):
7176
self.message_buffer.append(
@@ -74,20 +79,25 @@ def _store_messages(self, notification):
7479

7580
def subscribe(self, channel):
7681
self.subscriptions.append(channel)
77-
self.connection.execute(f"LISTEN {channel}")
82+
with connection.cursor() as cursor:
83+
cursor.execute(f"LISTEN {channel}")
7884

7985
def unsubscribe(self, channel):
8086
self.subscriptions.remove(channel)
8187
for i in range(0, len(self.message_buffer), -1):
8288
if self.message_buffer[i].channel == channel:
8389
self.message_buffer.pop(i)
84-
self.connection.execute(f"UNLISTEN {channel}")
85-
86-
def publish(self, channel, message=None):
87-
if not message:
88-
self.cursor.execute(f"NOTIFY {channel}")
90+
with connection.cursor() as cursor:
91+
cursor.execute(f"UNLISTEN {channel}")
92+
93+
@staticmethod
94+
def publish(channel, payload=None):
95+
if not payload:
96+
with connection.cursor() as cursor:
97+
cursor.execute(f"NOTIFY {channel}")
8998
else:
90-
self.cursor.execute("SELECT pg_notify(%s, %s)", (channel, message))
99+
with connection.cursor() as cursor:
100+
cursor.execute("SELECT pg_notify(%s, %s)", (channel, str(payload)))
91101

92102
def fileno(self) -> int:
93103
if self.message_buffer:
@@ -97,18 +107,25 @@ def fileno(self) -> int:
97107
return self.sentinel_r
98108

99109
def fetch(self) -> list[PubsubMessage]:
100-
self.connection.execute("SELECT 1").fetchone()
110+
with connection.cursor() as cursor:
111+
cursor.execute("SELECT 1").fetchone()
101112
result = self.message_buffer.copy()
102113
self.message_buffer.clear()
103114
return result
104115

105116
def close(self):
106117
os.close(self.sentinel_r)
107118
os.close(self.sentinel_w)
108-
self.cursor.close()
119+
self.message_buffer.clear()
120+
connection.connection.remove_notify_handler(self._store_messages)
121+
for channel in self.subscriptions:
122+
self.unsubscribe(channel)
109123

110124
def __enter__(self):
111125
return self
112126

113127
def __exit__(self, exc_type, exc_value, traceback):
114128
self.close()
129+
130+
131+
backend = PostgresPubSub

pulpcore/tasking/tasks.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,7 @@ def dispatch(
304304
task.set_canceling()
305305
task.set_canceled(TASK_STATES.CANCELED, "Resources temporarily unavailable.")
306306
if notify_workers:
307-
with pubsub.PostgresPubSub(connection) as pubsub_client:
308-
pubsub_client.wakeup_worker(reason=TASK_WAKEUP_UNBLOCK)
307+
pubsub.backend.wakeup_worker(reason=TASK_WAKEUP_UNBLOCK)
309308
return task
310309

311310

@@ -339,9 +338,8 @@ def cancel_task(task_id):
339338
# This is the only valid transition without holding the task lock
340339
task.set_canceling()
341340
# Notify the worker that might be running that task and other workers to clean up
342-
with pubsub.PostgresPubSub(connection) as pubsub_client:
343-
pubsub_client.cancel_task(task_pk=task.pk)
344-
pubsub_client.wakeup_worker()
341+
pubsub.backend.cancel_task(task_pk=task.pk)
342+
pubsub.backend.wakeup_worker()
345343
return task
346344

347345

pulpcore/tasking/worker.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from packaging.version import parse as parse_version
1414

1515
from django.conf import settings
16-
from django.db import connection
1716
from django.db.models import Case, Count, F, Max, Value, When
1817
from django.db.models.functions import Random
1918
from django.utils import timezone
@@ -83,7 +82,7 @@ def __init__(self, auxiliary=False):
8382
int(WORKER_CLEANUP_INTERVAL / 10), WORKER_CLEANUP_INTERVAL
8483
)
8584
# Pubsub handling
86-
self.pubsub_client = pubsub.PostgresPubSub(connection)
85+
self.pubsub_client = pubsub.backend()
8786
self.pubsub_channel_callback = {}
8887

8988
# Add a file descriptor to trigger select on signals
@@ -576,9 +575,7 @@ def metric_callback(message):
576575
self.pubsub_channel_callback[TASK_PUBSUB.WORKER_METRICS] = metric_callback
577576

578577
def pubsub_teardown(self):
579-
self.pubsub_client.unsubscribe(TASK_PUBSUB.WAKEUP_WORKER)
580-
self.pubsub_client.unsubscribe(TASK_PUBSUB.CANCEL_TASK)
581-
self.pubsub_client.unsubscribe(TASK_PUBSUB.WORKER_METRICS)
578+
self.pubsub_client.close()
582579

583580
def run(self, burst=False):
584581
with WorkerDirectory(self.name):
Lines changed: 88 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,116 @@
11
from django.db import connection
22
from pulpcore.tasking import pubsub
33
from types import SimpleNamespace
4+
from datetime import datetime
45
import select
56
import pytest
67

78

89
def test_postgres_pubsub():
10+
"""Testing postgres low-level implementation."""
911
state = SimpleNamespace()
10-
state.got_first_message = False
11-
state.got_second_message = False
12+
state.got_message = False
1213
with connection.cursor() as cursor:
1314
assert connection.connection is cursor.connection
1415
conn = cursor.connection
16+
# Listen and Notify
1517
conn.execute("LISTEN abc")
1618
conn.add_notify_handler(lambda notification: setattr(state, "got_message", True))
1719
cursor.execute("NOTIFY abc, 'foo'")
20+
assert state.got_message is True
1821
conn.execute("SELECT 1")
22+
assert state.got_message is True
23+
24+
# Reset and retry
25+
state.got_message = False
1926
conn.execute("UNLISTEN abc")
20-
assert state.got_message is True
27+
cursor.execute("NOTIFY abc, 'foo'")
28+
assert state.got_message is False
2129

2230

2331
M = pubsub.PubsubMessage
2432

33+
PUBSUB_BACKENDS = [
34+
pytest.param(pubsub.PostgresPubSub, id="and-using-postgres-backend"),
35+
]
36+
37+
38+
@pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
39+
class TestPublish:
40+
41+
@pytest.mark.parametrize(
42+
"payload",
43+
(
44+
pytest.param(None, id="none"),
45+
pytest.param("", id="empty-string"),
46+
pytest.param("payload", id="non-empty-string"),
47+
pytest.param(123, id="int"),
48+
pytest.param(datetime.now(), id="datetime"),
49+
pytest.param(True, id="bool"),
50+
),
51+
)
52+
def test_with_payload_as(self, pubsub_backend, payload):
53+
pubsub_backend.publish("channel", payload=payload)
54+
2555

56+
@pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
2657
@pytest.mark.parametrize(
2758
"messages",
2859
(
29-
[M("channel_a", "A1")],
30-
[M("channel_a", "A1"), M("channel_a", "A2")],
31-
[M("channel_a", "A1"), M("channel_a", "A2"), M("channel_b", "B1"), M("channel_c", "C1")],
60+
pytest.param([M("a", "A1")], id="single-message"),
61+
pytest.param([M("a", "A1"), M("a", "A2")], id="two-messages-in-same-channel"),
62+
pytest.param(
63+
[M("a", "A1"), M("a", "A2"), M("b", "B1"), M("c", "C1")],
64+
id="tree-msgs-in-different-channels",
65+
),
3266
),
3367
)
34-
@pytest.mark.parametrize("same_client", (True, False), ids=("same-clients", "different-clients"))
35-
class TestPubSub:
36-
37-
def test_subscribe_publish_fetch(self, same_client, messages):
38-
"""
39-
GIVEN a publisher and a subscriber (which may be the same)
40-
AND a queue of messages Q with mixed channels and payloads
41-
WHEN the subscriber subscribes to all the channels in Q
42-
AND the publisher publishes all the messages in Q
43-
THEN the subscriber fetch() call returns a queue equivalent to Q
44-
AND calling fetch() a second time returns an empty queue
45-
"""
46-
# Given
47-
publisher = pubsub.PostgresPubSub(connection)
48-
subscriber = publisher if same_client else pubsub.PostgresPubSub(connection)
49-
50-
# When
51-
for message in messages:
52-
subscriber.subscribe(message.channel)
53-
for message in messages:
54-
publisher.publish(message.channel, message=message.payload)
55-
56-
# Then
57-
assert subscriber.fetch() == messages
58-
assert subscriber.fetch() == []
59-
60-
def test_unsubscribe(self, same_client, messages):
61-
"""
62-
GIVEN a publisher and a subscriber (which may be the same)
63-
AND a queue of messages Q with mixed channels and payloads
64-
WHEN the subscriber subscribes and unsubscribes to all the channels in Q
65-
AND the publisher publishes all the messages in Q
66-
THEN the subscriber fetch() call returns an empty queue
67-
"""
68-
# Given
69-
publisher = pubsub.PostgresPubSub(connection)
70-
subscriber = publisher if same_client else pubsub.PostgresPubSub(connection)
71-
72-
# When
73-
for message in messages:
74-
subscriber.subscribe(message.channel)
75-
for message in messages:
76-
subscriber.unsubscribe(message.channel)
77-
for message in messages:
78-
publisher.publish(message.channel, message=message.payload)
79-
80-
# Then
81-
assert subscriber.fetch() == []
82-
83-
def test_select_loop(self, same_client, messages):
84-
"""
85-
GIVEN a publisher and a subscriber (which may be the same)
86-
AND a queue of messages Q with mixed channels and payloads
87-
AND the subscriber is subscribed to all the channels in Q
88-
WHEN the publisher has NOT published anything yet
89-
THEN the select loop won't detect the subscriber readiness
90-
AND the subscriber fetch() call returns an empty queue
91-
BUT WHEN the publisher does publish all messages in Q
92-
THEN the select loop detects the subscriber readiness
93-
AND the subscriber fetch() call returns a queue equivalent to Q
94-
"""
68+
class TestSubscribeFetch:
69+
def unsubscribe_all(self, channels, subscriber):
70+
for channel in channels:
71+
subscriber.unsubscribe(channel)
72+
73+
def subscribe_all(self, channels, subscriber):
74+
for channel in channels:
75+
subscriber.subscribe(channel)
76+
77+
def publish_all(self, messages, publisher):
78+
for channel, payload in messages:
79+
publisher.publish(channel, payload=payload)
80+
81+
def test_with(
82+
self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage]
83+
):
84+
channels = {m.channel for m in messages}
85+
publisher = pubsub_backend
86+
with pubsub_backend() as subscriber:
87+
self.subscribe_all(channels, subscriber)
88+
self.publish_all(messages, publisher)
89+
assert subscriber.fetch() == messages
90+
91+
self.unsubscribe_all(channels, subscriber)
92+
assert subscriber.fetch() == []
93+
94+
def test_select_readiness_with(
95+
self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage]
96+
):
9597
TIMEOUT = 0.1
96-
97-
# Given
98-
publisher = pubsub.PostgresPubSub(connection)
99-
subscriber = publisher if same_client else pubsub.PostgresPubSub(connection)
100-
101-
# When
102-
for message in messages:
103-
subscriber.subscribe(message.channel)
104-
r, w, x = select.select([subscriber], [], [], TIMEOUT)
105-
106-
# Then
107-
assert subscriber not in r
108-
assert subscriber.fetch() == []
109-
110-
# But When
111-
for message in messages:
112-
publisher.publish(message.channel, message=message.payload)
113-
r, w, x = select.select([subscriber], [], [], TIMEOUT)
114-
115-
# Then
116-
assert subscriber in r
117-
assert subscriber.fetch() == messages
118-
assert subscriber.fetch() == []
98+
channels = {m.channel for m in messages}
99+
publisher = pubsub_backend
100+
with pubsub_backend() as subscriber:
101+
self.subscribe_all(channels, subscriber)
102+
r, w, x = select.select([subscriber], [], [], TIMEOUT)
103+
assert subscriber not in r
104+
assert subscriber.fetch() == []
105+
106+
self.publish_all(messages, publisher)
107+
r, w, x = select.select([subscriber], [], [], TIMEOUT)
108+
assert subscriber in r
109+
assert subscriber.fetch() == messages
110+
assert subscriber.fetch() == []
111+
112+
self.unsubscribe_all(channels, subscriber)
113+
self.publish_all(messages, publisher)
114+
r, w, x = select.select([subscriber], [], [], TIMEOUT)
115+
assert subscriber not in r
116+
assert subscriber.fetch() == []

0 commit comments

Comments
 (0)