Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.11-buster
FROM python:3.11-bullseye

LABEL maintainer="Penn Labs"

Expand Down
47 changes: 24 additions & 23 deletions src/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@
from src.config import settings


# The URL to the JWKS endpoint
JWKS_URL = settings.JWKS_URL


def get_jwk():
def get_jwks():
if settings.JWKS_CACHE:
key = settings.JWKS_CACHE
return key

# Check to make sure we have a cached JWK for each key, otherwise refetch all.
missing = False
for key in settings.JWKS_URL.keys():
if key not in settings.JWKS_CACHE:
missing = True
if not missing:
return settings.JWKS_CACHE
# Make a request to get the JWKS
try:
response = requests.get(JWKS_URL)
jwks = jwk.JWKSet.from_json(response.text)
settings.JWKS_CACHE = jwks
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

for key in settings.JWKS_URL:
try:
response = requests.get(settings.JWKS_URL[key])
jwks = jwk.JWKSet.from_json(response.text)
settings.JWKS_CACHE[key] = jwks
except Exception as e:
print(str(e))
raise HTTPException(status_code=500, detail=str(e))
return settings.JWKS_CACHE


Expand All @@ -41,11 +42,11 @@ def get_token_from_header(request: Request):


def verify_jwt(token: str = Depends(get_token_from_header)):
try:
# Load the public key
public_key = get_jwk()
# Decode and verify the JWT
decoded_token = jwt.JWT(key=public_key, jwt=token)
return decoded_token.claims
except Exception as e:
raise HTTPException(status_code=401, detail=str(e))
public_keys = get_jwks()
for key in public_keys:
try:
decoded_token = jwt.JWT(key=public_keys[key], jwt=token)
return decoded_token.claims
except Exception:
pass
raise HTTPException(status_code=401, detail="Failed to verify JWT token")
8 changes: 5 additions & 3 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ class Config(BaseSettings):
DATABASE_URL: PostgresDsn
REDIS_URL: RedisDsn

JWKS_CACHE: JWKSet | None = None
JWKS_URL: str = "https://platform.pennlabs.org/identity/jwks/"

JWKS_CACHE: dict[str, JWKSet] = {}
JWKS_URL: dict[str, str] = {
"b2b": "https://platform.pennlabs.org/identity/jwks/",
"user": "https://platform.pennlabs.org/accounts/.well-known/jwks.json",
}
SITE_DOMAIN: str = "analytics.pennlabs.org"

ENVIRONMENT: Environment = Environment.PRODUCTION
Expand Down
15 changes: 8 additions & 7 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from datetime import datetime

import requests
from test_token import get_tokens

from tests.test_token import get_tokens


# Runtime should be less that 3 seconds for most laptops
Expand All @@ -16,10 +17,8 @@
THREADS = 16


def make_request():
access_token, _ = get_tokens()

url = "http://localhost:8000/analytics"
def make_request(token: str):
url = "http://localhost:8000/analytics/"
payload = json.dumps(
{
"product": random.randint(1, 10),
Expand All @@ -41,7 +40,7 @@ def make_request():
)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
"Authorization": f"Bearer {token}",
}

try:
Expand All @@ -55,8 +54,10 @@ def make_request():

def run_threads():
with ThreadPoolExecutor(max_workers=THREADS) as executor:
# Fetch token once to improve performance
access_token, _ = get_tokens()
for _ in range(NUMBER_OF_REQUESTS):
executor.submit(make_request)
executor.submit(make_request, access_token)


def test_load():
Expand Down
Loading