Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ repos:
- types-PyYAML
- types-cachetools
- types-requests
- types-python-dateutil
- types-aiobotocore[essential]
- boto3-stubs[essential]
exclude: ^(diracx-client/src/diracx/client/_generated|diracx-[a-z]+/tests/|diracx-testing/|build|extensions/gubbins/gubbins-client/src/gubbins/client/_generated)
Expand Down
2 changes: 2 additions & 0 deletions diracx-db/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ dependencies = [
"pydantic >=2.10",
"sqlalchemy[aiomysql,aiosqlite] >= 2",
"uuid-utils",
"python-dateutil",
]
dynamic = ["version"]

[project.optional-dependencies]
testing = ["diracx-testing", "freezegun"]
types = ["types-python-dateutil"]

[project.entry-points."diracx.dbs.sql"]
AuthDB = "diracx.db.sql:AuthDB"
Expand Down
1 change: 1 addition & 0 deletions diracx-db/src/diracx/db/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async def init_sql():
if db._db_url.startswith("sqlite"):
await conn.exec_driver_sql("PRAGMA foreign_keys=ON")
await conn.run_sync(db.metadata.create_all)
await db.post_create(conn)


async def init_os():
Expand Down
71 changes: 69 additions & 2 deletions diracx-db/src/diracx/db/sql/auth/db.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from __future__ import annotations

import logging
import secrets
from datetime import UTC, datetime
from itertools import pairwise

from sqlalchemy import insert, select, update
from dateutil.rrule import MONTHLY, rrule
from sqlalchemy import insert, select, text, update
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import AsyncConnection
from uuid_utils import UUID, uuid7

from diracx.core.exceptions import (
AuthorizationError,
TokenNotFoundError,
)
from diracx.db.sql.utils import BaseSQLDB, hash, substract_date
from diracx.db.sql.utils import BaseSQLDB, hash, substract_date, uuid7_from_datetime

from .schema import (
AuthorizationFlows,
Expand All @@ -25,10 +30,72 @@
USER_CODE_ALPHABET = "BCDFGHJKLMNPQRSTVWXZ"
MAX_RETRY = 5

logger = logging.getLogger(__name__)


class AuthDB(BaseSQLDB):
metadata = AuthDBBase.metadata

@classmethod
async def post_create(cls, conn: AsyncConnection) -> None:
"""Create partitions if it is a MySQL DB and it does not have
it yet and the table does not have any data yet.
We do this as a post_create step as sqlalchemy does not support
partition so well.
"""
if conn.dialect.name == "mysql":
check_partition_query = text(
"SELECT PARTITION_NAME FROM information_schema.partitions "
"WHERE TABLE_NAME = 'RefreshTokens' AND PARTITION_NAME is not NULL"
)
partition_names = (await conn.execute(check_partition_query)).all()

if not partition_names:
# Create a monthly partition from today until 2 years
# The partition are named p_<year>_<month>
start_date = datetime.now(tz=UTC).replace(
day=1, hour=0, minute=0, second=0, microsecond=0
)
end_date = start_date.replace(year=start_date.year + 2)

dates = [
dt for dt in rrule(MONTHLY, dtstart=start_date, until=end_date)
]

partition_list = []
for name, limit in pairwise(dates):
partition_list.append(
f"PARTITION p_{name.year}_{name.month} "
f"VALUES LESS THAN ('{str(uuid7_from_datetime(limit, randomize=False)).replace('-', '')}')"
)
partition_list.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)")

alter_query = text(
f"ALTER TABLE RefreshTokens PARTITION BY RANGE COLUMNS (JTI) ({','.join(partition_list)})"
)

check_table_empty_query = text("SELECT * FROM RefreshTokens LIMIT 1")
refresh_table_content = (
await conn.execute(check_table_empty_query)
).all()
if refresh_table_content:
logger.warning(
"RefreshTokens table not empty. Run the following query yourself"
)
logger.warning(alter_query)
return

await conn.execute(alter_query)

partition_names = (
await conn.execute(
check_partition_query, {"table_name": "RefreshTokens"}
)
).all()
assert partition_names, (
f"There should be partitions now {partition_names}"
)

async def device_flow_validate_user_code(
self, user_code: str, max_validity: int
) -> str:
Expand Down
5 changes: 5 additions & 0 deletions diracx-db/src/diracx/db/sql/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ def available_urls(cls) -> dict[str, str]:
raise
return db_urls

@classmethod
async def post_create(cls, conn: AsyncConnection) -> None:
"""Execute actions after the schema has been created."""
return

@classmethod
def transaction(cls) -> Self:
raise NotImplementedError("This should never be called")
Expand Down
Loading