diff --git a/Dockerfile b/Dockerfile index fd9522e..f6c122d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.11-buster +FROM python:3.11-bullseye LABEL maintainer="Penn Labs" diff --git a/src/auth.py b/src/auth.py index 5a7d9cf..7d43ce1 100644 --- a/src/auth.py +++ b/src/auth.py @@ -5,23 +5,24 @@ from src.config import settings -# The URL to the JWKS endpoint -JWKS_URL = settings.JWKS_URL - - -def get_jwk(): +def get_jwks(): if settings.JWKS_CACHE: - key = settings.JWKS_CACHE - return key - + # Check to make sure we have a cached JWK for each key, otherwise refetch all. + missing = False + for key in settings.JWKS_URL.keys(): + if key not in settings.JWKS_CACHE: + missing = True + if not missing: + return settings.JWKS_CACHE # Make a request to get the JWKS - try: - response = requests.get(JWKS_URL) - jwks = jwk.JWKSet.from_json(response.text) - settings.JWKS_CACHE = jwks - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - + for key in settings.JWKS_URL: + try: + response = requests.get(settings.JWKS_URL[key]) + jwks = jwk.JWKSet.from_json(response.text) + settings.JWKS_CACHE[key] = jwks + except Exception as e: + print(str(e)) + raise HTTPException(status_code=500, detail=str(e)) return settings.JWKS_CACHE @@ -41,11 +42,11 @@ def get_token_from_header(request: Request): def verify_jwt(token: str = Depends(get_token_from_header)): - try: - # Load the public key - public_key = get_jwk() - # Decode and verify the JWT - decoded_token = jwt.JWT(key=public_key, jwt=token) - return decoded_token.claims - except Exception as e: - raise HTTPException(status_code=401, detail=str(e)) + public_keys = get_jwks() + for key in public_keys: + try: + decoded_token = jwt.JWT(key=public_keys[key], jwt=token) + return decoded_token.claims + except Exception: + pass + raise HTTPException(status_code=401, detail="Failed to verify JWT token") diff --git a/src/config.py b/src/config.py index f20c759..253ba21 100644 --- a/src/config.py +++ b/src/config.py @@ -11,9 +11,11 @@ class Config(BaseSettings): DATABASE_URL: PostgresDsn REDIS_URL: RedisDsn - JWKS_CACHE: JWKSet | None = None - JWKS_URL: str = "https://platform.pennlabs.org/identity/jwks/" - + JWKS_CACHE: dict[str, JWKSet] = {} + JWKS_URL: dict[str, str] = { + "b2b": "https://platform.pennlabs.org/identity/jwks/", + "user": "https://platform.pennlabs.org/accounts/.well-known/jwks.json", + } SITE_DOMAIN: str = "analytics.pennlabs.org" ENVIRONMENT: Environment = Environment.PRODUCTION diff --git a/tests/test_load.py b/tests/test_load.py index 13eb980..11fac0a 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -5,7 +5,8 @@ from datetime import datetime import requests -from test_token import get_tokens + +from tests.test_token import get_tokens # Runtime should be less that 3 seconds for most laptops @@ -16,10 +17,8 @@ THREADS = 16 -def make_request(): - access_token, _ = get_tokens() - - url = "http://localhost:8000/analytics" +def make_request(token: str): + url = "http://localhost:8000/analytics/" payload = json.dumps( { "product": random.randint(1, 10), @@ -41,7 +40,7 @@ def make_request(): ) headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", + "Authorization": f"Bearer {token}", } try: @@ -55,8 +54,10 @@ def make_request(): def run_threads(): with ThreadPoolExecutor(max_workers=THREADS) as executor: + # Fetch token once to improve performance + access_token, _ = get_tokens() for _ in range(NUMBER_OF_REQUESTS): - executor.submit(make_request) + executor.submit(make_request, access_token) def test_load():