Skip to content

Commit 5c0ab7a

Browse files
[replit] The lesser of two evils (#167)
* [replit] The lesser of two evils Currently the replit library has a very gross quirk: it has a global in `replit.database.default_db.db`, and the mere action of importing this library causes side effects to run! (connects to the database, starts a thread to refresh the URL, and prints a warning to stdout, adding insult to injury). So we're trading that very gross quirk for a gross workaround to preserve backwards compatibility: the modules that somehow end up importing that module now have a `__getattr__` that _lazily_ calls the code that used to be invoked as a side-effect of importing the library. Maybe in the future we'll deploy a breaking version of the library where we're not beholden to this backwards-compatibility quirk. * Marking internal properties as private Providing accessors, to hint that we are accessing mutable state * Reintroduce refresh_db noop to avoid errors on upgrade * Reflow LazyDB back down into default_db module An issue with LazyDB is that the refresh_db timer would not get canceled if the user closes the database. Additionally, the db_url refresh logic relies on injection, whereas the Database should ideally be the thing requesting that information from the environment * Removing stale main.sh --------- Co-authored-by: Devon Stewart <[email protected]>
1 parent 2e7528a commit 5c0ab7a

File tree

8 files changed

+200
-51
lines changed

8 files changed

+200
-51
lines changed

main.sh

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/replit/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
"""The Replit Python module."""
44

5-
from . import web
5+
from typing import Any
6+
7+
from . import database, web
68
from .audio import Audio
79
from .database import (
8-
db,
910
Database,
1011
AsyncDatabase,
1112
make_database_proxy_blueprint,
@@ -23,3 +24,13 @@ def clear() -> None:
2324

2425

2526
audio = Audio()
27+
28+
29+
# Previous versions of this library would just have side-effects and always set
30+
# up a database unconditionally. That is very undesirable, so instead of doing
31+
# that, we are using this egregious hack to get the database / database URL
32+
# lazily.
33+
def __getattr__(name: str) -> Any:
34+
if name == "db":
35+
return database.db
36+
raise AttributeError(name)

src/replit/database/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Interface with the Replit Database."""
2+
from typing import Any
3+
4+
from . import default_db
25
from .database import AsyncDatabase, Database, DBJSONEncoder, dumps, to_primitive
3-
from .default_db import db, db_url
46
from .server import make_database_proxy_blueprint, start_database_proxy
57

68
__all__ = [
@@ -14,3 +16,15 @@
1416
"start_database_proxy",
1517
"to_primitive",
1618
]
19+
20+
21+
# Previous versions of this library would just have side-effects and always set
22+
# up a database unconditionally. That is very undesirable, so instead of doing
23+
# that, we are using this egregious hack to get the database / database URL
24+
# lazily.
25+
def __getattr__(name: str) -> Any:
26+
if name == "db":
27+
return default_db.db
28+
if name == "db_url":
29+
return default_db.db_url
30+
raise AttributeError(name)

src/replit/database/database.py

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
"""Async and dict-like interfaces for interacting with Repl.it Database."""
1+
"""Async and dict-like interfaces for interacting with Replit Database."""
22

33
from collections import abc
44
import json
5+
import threading
56
from typing import (
67
Any,
78
Callable,
@@ -61,24 +62,57 @@ def dumps(val: Any) -> str:
6162

6263

6364
class AsyncDatabase:
64-
"""Async interface for Repl.it Database."""
65+
"""Async interface for Replit Database.
6566
66-
__slots__ = ("db_url", "sess", "client")
67+
:param str db_url: The Database URL to connect to
68+
:param int retry_count: How many retry attempts we should make
69+
:param get_db_url Callable: A callback that returns the current db_url
70+
:param unbind Callable: Permit additional behavior after Database close
71+
"""
72+
73+
__slots__ = ("db_url", "sess", "client", "_get_db_url", "_unbind", "_refresh_timer")
74+
_refresh_timer: Optional[threading.Timer]
6775

68-
def __init__(self, db_url: str, retry_count: int = 5) -> None:
76+
def __init__(
77+
self,
78+
db_url: str,
79+
retry_count: int = 5,
80+
get_db_url: Optional[Callable[[], Optional[str]]] = None,
81+
unbind: Optional[Callable[[], None]] = None,
82+
) -> None:
6983
"""Initialize database. You shouldn't have to do this manually.
7084
7185
Args:
7286
db_url (str): Database url to use.
7387
retry_count (int): How many times to retry connecting
7488
(with exponential backoff)
89+
get_db_url (callable[[], str]): A function that will be called to refresh
90+
the db_url property
91+
unbind (callable[[], None]): A callback to clean up after .close() is called
7592
"""
7693
self.db_url = db_url
7794
self.sess = aiohttp.ClientSession()
95+
self._get_db_url = get_db_url
96+
self._unbind = unbind
7897

7998
retry_options = ExponentialRetry(attempts=retry_count)
8099
self.client = RetryClient(client_session=self.sess, retry_options=retry_options)
81100

101+
if self._get_db_url:
102+
self._refresh_timer = threading.Timer(3600, self._refresh_db)
103+
self._refresh_timer.start()
104+
105+
def _refresh_db(self) -> None:
106+
if self._refresh_timer:
107+
self._refresh_timer.cancel()
108+
self._refresh_timer = None
109+
if self._get_db_url:
110+
db_url = self._get_db_url()
111+
if db_url:
112+
self.update_db_url(db_url)
113+
self._refresh_timer = threading.Timer(3600, self._refresh_db)
114+
self._refresh_timer.start()
115+
82116
def update_db_url(self, db_url: str) -> None:
83117
"""Update the database url.
84118
@@ -239,6 +273,16 @@ async def items(self) -> Tuple[Tuple[str, str], ...]:
239273
"""
240274
return tuple((await self.to_dict()).items())
241275

276+
async def close(self) -> None:
277+
"""Closes the database client connection."""
278+
await self.sess.close()
279+
if self._refresh_timer:
280+
self._refresh_timer.cancel()
281+
self._refresh_timer = None
282+
if self._unbind:
283+
# Permit signaling to surrounding scopes that we have closed
284+
self._unbind()
285+
242286
def __repr__(self) -> str:
243287
"""A representation of the database.
244288
@@ -417,30 +461,62 @@ def item_to_observed(on_mutate: Callable[[Any], None], item: Any) -> Any:
417461

418462

419463
class Database(abc.MutableMapping):
420-
"""Dictionary-like interface for Repl.it Database.
464+
"""Dictionary-like interface for Replit Database.
421465
422466
This interface will coerce all values everything to and from JSON. If you
423467
don't want this, use AsyncDatabase instead.
468+
469+
:param str db_url: The Database URL to connect to
470+
:param int retry_count: How many retry attempts we should make
471+
:param get_db_url Callable: A callback that returns the current db_url
472+
:param unbind Callable: Permit additional behavior after Database close
424473
"""
425474

426-
__slots__ = ("db_url", "sess")
475+
__slots__ = ("db_url", "sess", "_get_db_url", "_unbind", "_refresh_timer")
476+
_refresh_timer: Optional[threading.Timer]
427477

428-
def __init__(self, db_url: str, retry_count: int = 5) -> None:
478+
def __init__(
479+
self,
480+
db_url: str,
481+
retry_count: int = 5,
482+
get_db_url: Optional[Callable[[], Optional[str]]] = None,
483+
unbind: Optional[Callable[[], None]] = None,
484+
) -> None:
429485
"""Initialize database. You shouldn't have to do this manually.
430486
431487
Args:
432488
db_url (str): Database url to use.
433489
retry_count (int): How many times to retry connecting
434490
(with exponential backoff)
491+
get_db_url (callable[[], str]): A function that will be called to refresh
492+
the db_url property
493+
unbind (callable[[], None]): A callback to clean up after .close() is called
435494
"""
436495
self.db_url = db_url
437496
self.sess = requests.Session()
497+
self._get_db_url = get_db_url
498+
self._unbind = unbind
438499
retries = Retry(
439500
total=retry_count, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]
440501
)
441502
self.sess.mount("http://", HTTPAdapter(max_retries=retries))
442503
self.sess.mount("https://", HTTPAdapter(max_retries=retries))
443504

505+
if self._get_db_url:
506+
self._refresh_timer = threading.Timer(3600, self._refresh_db)
507+
self._refresh_timer.start()
508+
509+
def _refresh_db(self) -> None:
510+
if self._refresh_timer:
511+
self._refresh_timer.cancel()
512+
self._refresh_timer = None
513+
if self._get_db_url:
514+
db_url = self._get_db_url()
515+
if db_url:
516+
self.update_db_url(db_url)
517+
self._refresh_timer = threading.Timer(3600, self._refresh_db)
518+
self._refresh_timer.start()
519+
444520
def update_db_url(self, db_url: str) -> None:
445521
"""Update the database url.
446522
@@ -627,3 +703,9 @@ def __repr__(self) -> str:
627703
def close(self) -> None:
628704
"""Closes the database client connection."""
629705
self.sess.close()
706+
if self._refresh_timer:
707+
self._refresh_timer.cancel()
708+
self._refresh_timer = None
709+
if self._unbind:
710+
# Permit signaling to surrounding scopes that we have closed
711+
self._unbind()

src/replit/database/default_db.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,58 @@
11
"""A module containing the default database."""
2-
from os import environ, path
3-
import threading
4-
from typing import Optional
5-
2+
import os
3+
import os.path
4+
from typing import Any, Optional
65

76
from .database import Database
87

98

10-
def get_db_url() -> str:
9+
def get_db_url() -> Optional[str]:
1110
"""Fetches the most up-to-date db url from the Repl environment."""
1211
# todo look into the security warning ignored below
1312
tmpdir = "/tmp/replitdb" # noqa: S108
14-
if path.exists(tmpdir):
13+
if os.path.exists(tmpdir):
1514
with open(tmpdir, "r") as file:
16-
db_url = file.read()
17-
else:
18-
db_url = environ.get("REPLIT_DB_URL")
15+
return file.read()
1916

20-
return db_url
17+
return os.environ.get("REPLIT_DB_URL")
2118

2219

2320
def refresh_db() -> None:
24-
"""Refresh the DB URL every hour."""
25-
global db
21+
"""Deprecated: refresh_db is now the responsibility of the Database instance."""
22+
pass
23+
24+
25+
def _unbind() -> None:
26+
global _db
27+
_db = None
28+
29+
30+
def _get_db() -> Optional[Database]:
31+
global _db
32+
if _db is not None:
33+
return _db
34+
2635
db_url = get_db_url()
27-
db.update_db_url(db_url)
28-
threading.Timer(3600, refresh_db).start()
2936

37+
if db_url:
38+
_db = Database(db_url, get_db_url=get_db_url, unbind=_unbind)
39+
else:
40+
# The user will see errors if they try to use the database.
41+
print("Warning: error initializing database. Replit DB is not configured.")
42+
_db = None
43+
return _db
44+
45+
46+
_db: Optional[Database] = None
3047

31-
db: Optional[Database]
32-
db_url = get_db_url()
33-
if db_url:
34-
db = Database(db_url)
35-
else:
36-
# The user will see errors if they try to use the database.
37-
print("Warning: error initializing database. Replit DB is not configured.")
38-
db = None
3948

40-
if db:
41-
refresh_db()
49+
# Previous versions of this library would just have side-effects and always set
50+
# up a database unconditionally. That is very undesirable, so instead of doing
51+
# that, we are using this egregious hack to get the database / database URL
52+
# lazily.
53+
def __getattr__(name: str) -> Any:
54+
if name == "db":
55+
return _get_db()
56+
if name == "db_url":
57+
return get_db_url()
58+
raise AttributeError(name)

src/replit/database/server.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from flask import Blueprint, Flask, request
66

7-
from .default_db import db
7+
from . import default_db
88

99

1010
def make_database_proxy_blueprint(view_only: bool, prefix: str = "") -> Blueprint:
@@ -20,21 +20,25 @@ def make_database_proxy_blueprint(view_only: bool, prefix: str = "") -> Blueprin
2020
app = Blueprint("database_proxy" + ("_view_only" if view_only else ""), __name__)
2121

2222
def list_keys() -> Any:
23-
user_prefix = request.args.get("prefix")
23+
if default_db.db is None:
24+
return "Database is not configured", 500
25+
user_prefix = request.args.get("prefix", "")
2426
encode = "encode" in request.args
25-
keys = db.prefix(prefix=prefix + user_prefix)
26-
keys = [k[len(prefix) :] for k in keys]
27+
raw_keys = default_db.db.prefix(prefix=prefix + user_prefix)
28+
keys = [k[len(prefix) :] for k in raw_keys]
2729

2830
if encode:
2931
return "\n".join(quote(k) for k in keys)
3032
else:
3133
return "\n".join(keys)
3234

3335
def set_key() -> Any:
36+
if default_db.db is None:
37+
return "Database is not configured", 500
3438
if view_only:
3539
return "Database is view only", 401
3640
for k, v in request.form.items():
37-
db[prefix + k] = v
41+
default_db.db[prefix + k] = v
3842
return ""
3943

4044
@app.route("/", methods=["GET", "POST"])
@@ -44,16 +48,20 @@ def index() -> Any:
4448
return set_key()
4549

4650
def get_key(key: str) -> Any:
51+
if default_db.db is None:
52+
return "Database is not configured", 500
4753
try:
48-
return db[prefix + key]
54+
return default_db.db[prefix + key]
4955
except KeyError:
5056
return "", 404
5157

5258
def delete_key(key: str) -> Any:
59+
if default_db.db is None:
60+
return "Database is not configured", 500
5361
if view_only:
5462
return "Database is view only", 401
5563
try:
56-
del db[prefix + key]
64+
del default_db.db[prefix + key]
5765
except KeyError:
5866
return "", 404
5967
return ""

src/replit/web/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99
from .app import debug, ReplitAuthContext, run
1010
from .user import User, UserStore
1111
from .utils import *
12-
from ..database import AsyncDatabase, Database, db
12+
from .. import database
13+
from ..database import AsyncDatabase, Database
1314

1415
auth = LocalProxy(lambda: ReplitAuthContext.from_headers(flask.request.headers))
16+
17+
18+
# Previous versions of this library would just have side-effects and always set
19+
# up a database unconditionally. That is very undesirable, so instead of doing
20+
# that, we are using this egregious hack to get the database / database URL
21+
# lazily.
22+
def __getattr__(name: str) -> Any:
23+
if name == "db":
24+
return database.db
25+
raise AttributeError(name)

0 commit comments

Comments
 (0)