Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

# Import middleware
from app.util.auth_dependencies import Authentication, CurrentMember, sign_redirect_url, verify_redirect_url
from app.util.database import get_session, init_db
from app.util.database import get_session, init_db, engine
from app.util.discord import Discord

# Import error handling
Expand Down Expand Up @@ -71,6 +71,41 @@
logger = logging.getLogger(__name__)


def update_discord_model_for_existing_user(user_id: str, discord_data: dict):
"""
Background task to update Discord model for existing users during login.
Only updates the Discord model, does not affect user creation flow for new users.
"""
try:
# Create a new session for the background task
with Session(engine) as session:
# Get the user with their Discord model
statement = select(UserModel).where(UserModel.id == uuid.UUID(user_id)).options(selectinload(UserModel.discord))
user = session.exec(statement).one_or_none()

if not user or not user.discord:
logger.warning(f"Could not find user or Discord model for user {user_id}")
return

# Update Discord model with fresh data from OAuth
discord_model = user.discord
discord_model.email = discord_data.get("email")
discord_model.mfa = discord_data.get("mfa_enabled")
discord_model.avatar = f"https://cdn.discordapp.com/avatars/{discord_data['id']}/{discord_data['avatar']}.png?size=512" if discord_data.get("avatar") else None
discord_model.banner = f"https://cdn.discordapp.com/banners/{discord_data['id']}/{discord_data['banner']}.png?size=1536" if discord_data.get("banner") else None
discord_model.color = discord_data.get("accent_color")
discord_model.nitro = discord_data.get("premium_type")
discord_model.locale = discord_data.get("locale")
discord_model.username = discord_data.get("username")

session.add(discord_model)
session.commit()
logger.info(f"Updated Discord model for existing user {user_id}")

except Exception as e:
logger.error(f"Failed to update Discord model for user {user_id}: {e}")


# Initiate FastAPI.
app = FastAPI()
templates = Jinja2Templates(directory="app/templates")
Expand Down Expand Up @@ -275,6 +310,7 @@ async def oauth_transformer(redir: str = None):
async def oauth_transformer_new(
request: Request,
response: Response,
background_tasks: BackgroundTasks,
code: str = None,
state: str = None,
redir_endpoint: Optional[str] = Cookie(None),
Expand Down Expand Up @@ -359,6 +395,9 @@ async def oauth_transformer_new(
session.add(user)
session.commit()
session.refresh(user)
else:
# Existing user - update their Discord model in background
background_tasks.add_task(update_discord_model_for_existing_user, str(user.id), discordData)

# Create JWT. This should be the only way to issue JWTs.
bearer = Authentication.create_jwt(user)
Expand Down
89 changes: 89 additions & 0 deletions tests/test_discord_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2024 Collegiate Cyber Defense Club
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session

from app.main import update_discord_model_for_existing_user
from app.models.user import DiscordModel, UserModel


def test_update_discord_model_for_existing_user(session: Session, test_user: UserModel):
"""Test that the background task updates Discord model for existing users"""

# Setup test data - simulate Discord API response
original_username = test_user.discord.username
original_avatar = test_user.discord.avatar

discord_data = {
"id": test_user.discord_id,
"email": "[email protected]",
"username": "updated_username",
"avatar": "new_avatar_hash",
"banner": "new_banner_hash",
"mfa_enabled": True,
"accent_color": 123456,
"premium_type": 2,
"locale": "en_GB"
}

# Mock the engine and session creation for the background task
with patch('app.main.engine') as mock_engine:
mock_session = MagicMock()
mock_engine.__enter__ = MagicMock(return_value=mock_session)
mock_engine.__exit__ = MagicMock(return_value=None)

# Mock the session context manager
mock_engine_instance = MagicMock()
mock_engine_instance.__enter__ = MagicMock(return_value=mock_session)
mock_engine_instance.__exit__ = MagicMock(return_value=None)

with patch('app.main.Session') as mock_session_class:
mock_session_class.return_value = mock_engine_instance

# Mock the database query
mock_session.exec.return_value.one_or_none.return_value = test_user

# Call the background task function
update_discord_model_for_existing_user(str(test_user.id), discord_data)

# Verify the session was used correctly
mock_session_class.assert_called_once()
mock_session.exec.assert_called_once()
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()


def test_update_discord_model_user_not_found():
"""Test background task handles missing user gracefully"""

fake_user_id = str(uuid.uuid4())
discord_data = {"id": "123", "username": "test"}

with patch('app.main.Session') as mock_session_class:
mock_session = MagicMock()
mock_engine_instance = MagicMock()
mock_engine_instance.__enter__ = MagicMock(return_value=mock_session)
mock_engine_instance.__exit__ = MagicMock(return_value=None)
mock_session_class.return_value = mock_engine_instance

# Mock user not found
mock_session.exec.return_value.one_or_none.return_value = None

# Should not raise exception, just log warning
update_discord_model_for_existing_user(fake_user_id, discord_data)

# Should not call add or commit
mock_session.add.assert_not_called()
mock_session.commit.assert_not_called()


@pytest.mark.skip(reason="Requires full app configuration setup")
def test_oauth_calls_background_task_for_existing_user():
"""Test that OAuth endpoint adds background task for existing users"""
# This test would require full OAuth flow mocking
# Skipping for now as it requires extensive setup
pass
Loading