|
1 | | -"""Async and dict-like interfaces for interacting with Repl.it Database.""" |
| 1 | +"""Async and dict-like interfaces for interacting with Replit Database.""" |
2 | 2 |
|
3 | 3 | from collections import abc |
4 | 4 | import json |
| 5 | +import threading |
5 | 6 | from typing import ( |
6 | 7 | Any, |
7 | 8 | Callable, |
@@ -61,24 +62,57 @@ def dumps(val: Any) -> str: |
61 | 62 |
|
62 | 63 |
|
63 | 64 | class AsyncDatabase: |
64 | | - """Async interface for Repl.it Database.""" |
| 65 | + """Async interface for Replit Database. |
65 | 66 |
|
66 | | - __slots__ = ("db_url", "sess", "client") |
| 67 | + :param str db_url: The Database URL to connect to |
| 68 | + :param int retry_count: How many retry attempts we should make |
| 69 | + :param get_db_url Callable: A callback that returns the current db_url |
| 70 | + :param unbind Callable: Permit additional behavior after Database close |
| 71 | + """ |
| 72 | + |
| 73 | + __slots__ = ("db_url", "sess", "client", "_get_db_url", "_unbind", "_refresh_timer") |
| 74 | + _refresh_timer: Optional[threading.Timer] |
67 | 75 |
|
68 | | - def __init__(self, db_url: str, retry_count: int = 5) -> None: |
| 76 | + def __init__( |
| 77 | + self, |
| 78 | + db_url: str, |
| 79 | + retry_count: int = 5, |
| 80 | + get_db_url: Optional[Callable[[], Optional[str]]] = None, |
| 81 | + unbind: Optional[Callable[[], None]] = None, |
| 82 | + ) -> None: |
69 | 83 | """Initialize database. You shouldn't have to do this manually. |
70 | 84 |
|
71 | 85 | Args: |
72 | 86 | db_url (str): Database url to use. |
73 | 87 | retry_count (int): How many times to retry connecting |
74 | 88 | (with exponential backoff) |
| 89 | + get_db_url (callable[[], str]): A function that will be called to refresh |
| 90 | + the db_url property |
| 91 | + unbind (callable[[], None]): A callback to clean up after .close() is called |
75 | 92 | """ |
76 | 93 | self.db_url = db_url |
77 | 94 | self.sess = aiohttp.ClientSession() |
| 95 | + self._get_db_url = get_db_url |
| 96 | + self._unbind = unbind |
78 | 97 |
|
79 | 98 | retry_options = ExponentialRetry(attempts=retry_count) |
80 | 99 | self.client = RetryClient(client_session=self.sess, retry_options=retry_options) |
81 | 100 |
|
| 101 | + if self._get_db_url: |
| 102 | + self._refresh_timer = threading.Timer(3600, self._refresh_db) |
| 103 | + self._refresh_timer.start() |
| 104 | + |
| 105 | + def _refresh_db(self) -> None: |
| 106 | + if self._refresh_timer: |
| 107 | + self._refresh_timer.cancel() |
| 108 | + self._refresh_timer = None |
| 109 | + if self._get_db_url: |
| 110 | + db_url = self._get_db_url() |
| 111 | + if db_url: |
| 112 | + self.update_db_url(db_url) |
| 113 | + self._refresh_timer = threading.Timer(3600, self._refresh_db) |
| 114 | + self._refresh_timer.start() |
| 115 | + |
82 | 116 | def update_db_url(self, db_url: str) -> None: |
83 | 117 | """Update the database url. |
84 | 118 |
|
@@ -239,6 +273,16 @@ async def items(self) -> Tuple[Tuple[str, str], ...]: |
239 | 273 | """ |
240 | 274 | return tuple((await self.to_dict()).items()) |
241 | 275 |
|
| 276 | + async def close(self) -> None: |
| 277 | + """Closes the database client connection.""" |
| 278 | + await self.sess.close() |
| 279 | + if self._refresh_timer: |
| 280 | + self._refresh_timer.cancel() |
| 281 | + self._refresh_timer = None |
| 282 | + if self._unbind: |
| 283 | + # Permit signaling to surrounding scopes that we have closed |
| 284 | + self._unbind() |
| 285 | + |
242 | 286 | def __repr__(self) -> str: |
243 | 287 | """A representation of the database. |
244 | 288 |
|
@@ -417,30 +461,62 @@ def item_to_observed(on_mutate: Callable[[Any], None], item: Any) -> Any: |
417 | 461 |
|
418 | 462 |
|
419 | 463 | class Database(abc.MutableMapping): |
420 | | - """Dictionary-like interface for Repl.it Database. |
| 464 | + """Dictionary-like interface for Replit Database. |
421 | 465 |
|
422 | 466 | This interface will coerce all values everything to and from JSON. If you |
423 | 467 | don't want this, use AsyncDatabase instead. |
| 468 | +
|
| 469 | + :param str db_url: The Database URL to connect to |
| 470 | + :param int retry_count: How many retry attempts we should make |
| 471 | + :param get_db_url Callable: A callback that returns the current db_url |
| 472 | + :param unbind Callable: Permit additional behavior after Database close |
424 | 473 | """ |
425 | 474 |
|
426 | | - __slots__ = ("db_url", "sess") |
| 475 | + __slots__ = ("db_url", "sess", "_get_db_url", "_unbind", "_refresh_timer") |
| 476 | + _refresh_timer: Optional[threading.Timer] |
427 | 477 |
|
428 | | - def __init__(self, db_url: str, retry_count: int = 5) -> None: |
| 478 | + def __init__( |
| 479 | + self, |
| 480 | + db_url: str, |
| 481 | + retry_count: int = 5, |
| 482 | + get_db_url: Optional[Callable[[], Optional[str]]] = None, |
| 483 | + unbind: Optional[Callable[[], None]] = None, |
| 484 | + ) -> None: |
429 | 485 | """Initialize database. You shouldn't have to do this manually. |
430 | 486 |
|
431 | 487 | Args: |
432 | 488 | db_url (str): Database url to use. |
433 | 489 | retry_count (int): How many times to retry connecting |
434 | 490 | (with exponential backoff) |
| 491 | + get_db_url (callable[[], str]): A function that will be called to refresh |
| 492 | + the db_url property |
| 493 | + unbind (callable[[], None]): A callback to clean up after .close() is called |
435 | 494 | """ |
436 | 495 | self.db_url = db_url |
437 | 496 | self.sess = requests.Session() |
| 497 | + self._get_db_url = get_db_url |
| 498 | + self._unbind = unbind |
438 | 499 | retries = Retry( |
439 | 500 | total=retry_count, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504] |
440 | 501 | ) |
441 | 502 | self.sess.mount("http://", HTTPAdapter(max_retries=retries)) |
442 | 503 | self.sess.mount("https://", HTTPAdapter(max_retries=retries)) |
443 | 504 |
|
| 505 | + if self._get_db_url: |
| 506 | + self._refresh_timer = threading.Timer(3600, self._refresh_db) |
| 507 | + self._refresh_timer.start() |
| 508 | + |
| 509 | + def _refresh_db(self) -> None: |
| 510 | + if self._refresh_timer: |
| 511 | + self._refresh_timer.cancel() |
| 512 | + self._refresh_timer = None |
| 513 | + if self._get_db_url: |
| 514 | + db_url = self._get_db_url() |
| 515 | + if db_url: |
| 516 | + self.update_db_url(db_url) |
| 517 | + self._refresh_timer = threading.Timer(3600, self._refresh_db) |
| 518 | + self._refresh_timer.start() |
| 519 | + |
444 | 520 | def update_db_url(self, db_url: str) -> None: |
445 | 521 | """Update the database url. |
446 | 522 |
|
@@ -627,3 +703,9 @@ def __repr__(self) -> str: |
627 | 703 | def close(self) -> None: |
628 | 704 | """Closes the database client connection.""" |
629 | 705 | self.sess.close() |
| 706 | + if self._refresh_timer: |
| 707 | + self._refresh_timer.cancel() |
| 708 | + self._refresh_timer = None |
| 709 | + if self._unbind: |
| 710 | + # Permit signaling to surrounding scopes that we have closed |
| 711 | + self._unbind() |
0 commit comments