Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 65 additions & 34 deletions sentry_sdk/_werkzeug.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,58 +41,89 @@


#
# `_get_headers` is adapted from `werkzeug.datastructures.headers.EnvironHeaders.__iter__`
# https://github.com/pallets/werkzeug/blob/3.1.3/src/werkzeug/datastructures/headers.py#L644
#
# We need this function because Django does not give us a "pure" http header
# dict. So we might as well use it for all WSGI integrations.
#
def _get_headers(environ):
# type: (Dict[str, str]) -> Iterator[Tuple[str, str]]
"""
Returns only proper HTTP headers.
"""
for key, value in environ.items():
key = str(key)
if key.startswith("HTTP_") and key not in (
if key.startswith("HTTP_") and key not in {
"HTTP_CONTENT_TYPE",
"HTTP_CONTENT_LENGTH",
):
}:
yield key[5:].replace("_", "-").title(), value
elif key in ("CONTENT_TYPE", "CONTENT_LENGTH"):
elif key in {"CONTENT_TYPE", "CONTENT_LENGTH"} and value:
yield key.replace("_", "-").title(), value


#
# `get_host` comes from `werkzeug.wsgi.get_host`
# https://github.com/pallets/werkzeug/blob/3.1.3/src/werkzeug/wsgi.py#L86
#
def get_host(environ, use_x_forwarded_for=False):
    # type: (Dict[str, str], bool) -> str
    """
    Return the host for the given WSGI environment.

    The host header is taken from ``HTTP_X_FORWARDED_HOST`` when
    ``use_x_forwarded_for`` is true and that key holds a non-empty value;
    otherwise ``HTTP_HOST`` is used (which may be absent, i.e. ``None``).
    The header, the URL scheme and the ``(SERVER_NAME, SERVER_PORT)`` pair
    from ``_get_server`` are then handed to ``_get_host``, which builds the
    final host string.

    :param environ: The WSGI environment.
    :param use_x_forwarded_for: Whether ``HTTP_X_FORWARDED_HOST`` may
        override ``HTTP_HOST``.
    :raises KeyError: If ``wsgi.url_scheme`` is missing from ``environ``.
    """
    host_header = environ.get("HTTP_HOST")
    if use_x_forwarded_for and environ.get("HTTP_X_FORWARDED_HOST"):
        # Only a non-empty forwarded host replaces the regular Host header.
        host_header = environ["HTTP_X_FORWARDED_HOST"]

    return _get_host(
        environ["wsgi.url_scheme"],
        host_header,
        _get_server(environ),
    )


# `_get_host` comes from `werkzeug.sansio.utils`
# https://github.com/pallets/werkzeug/blob/3.1.3/src/werkzeug/sansio/utils.py#L49
def _get_host(
scheme,
host_header,
server=None,
):
# type: (str, str | None, Tuple[str, int | None] | None) -> str
"""
Return the host for the given parameters.
"""
host = ""

if host_header is not None:
host = host_header
elif server is not None:
host = server[0]

# If SERVER_NAME is IPv6, wrap it in [] to match Host header.
# Check for : because domain or IPv4 can't have that.
if ":" in host and host[0] != "[":
host = f"[{host}]"

if server[1] is not None:
host = f"{host}:{server[1]}" # noqa: E231

if scheme in {"http", "ws"} and host.endswith(":80"):
host = host[:-3]
elif scheme in {"https", "wss"} and host.endswith(":443"):
host = host[:-4]

return host


def _get_server(environ):
# type: (Dict[str, str]) -> Tuple[str, int | None] | None
name = environ.get("SERVER_NAME")

if name is None:
return None

try:
port = int(environ.get("SERVER_PORT", None)) # type: ignore[arg-type]
except (TypeError, ValueError):
# unix socket
port = None

return name, port
67 changes: 67 additions & 0 deletions tests/integrations/wsgi/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sentry_sdk
from sentry_sdk import capture_message
from sentry_sdk.integrations.wsgi import SentryWsgiMiddleware
from sentry_sdk._werkzeug import get_host


@pytest.fixture
Expand Down Expand Up @@ -39,6 +40,50 @@ def next(self):
return type(self).__next__(self)


@pytest.mark.parametrize(
    ("environ", "expect"),
    (
        pytest.param({"HTTP_HOST": "spam"}, "spam", id="host"),
        pytest.param({"HTTP_HOST": "spam:80"}, "spam", id="host, strip http port"),
        pytest.param(
            {"wsgi.url_scheme": "https", "HTTP_HOST": "spam:443"},
            "spam",
            id="host, strip https port",
        ),
        pytest.param({"HTTP_HOST": "spam:8080"}, "spam:8080", id="host, custom port"),
        pytest.param(
            {"HTTP_HOST": "spam", "SERVER_NAME": "eggs", "SERVER_PORT": "80"},
            "spam",
            id="prefer host",
        ),
        pytest.param(
            {"SERVER_NAME": "eggs", "SERVER_PORT": "80"},
            "eggs",
            id="name, ignore http port",
        ),
        pytest.param(
            {"wsgi.url_scheme": "https", "SERVER_NAME": "eggs", "SERVER_PORT": "443"},
            "eggs",
            id="name, ignore https port",
        ),
        pytest.param(
            {"SERVER_NAME": "eggs", "SERVER_PORT": "8080"},
            "eggs:8080",
            id="name, custom port",
        ),
        pytest.param(
            {"HTTP_HOST": "ham", "HTTP_X_FORWARDED_HOST": "eggs"},
            "ham",
            id="ignore x-forwarded-host",
        ),
    ),
)
# https://github.com/pallets/werkzeug/blob/main/tests/test_wsgi.py#L60
def test_get_host(environ, expect):
    """
    Check ``get_host`` against a table of WSGI environments.

    Each case supplies a partial environ and the expected host string.
    ``get_host`` is called without ``use_x_forwarded_for``, so the last
    case expects ``HTTP_X_FORWARDED_HOST`` to be ignored.
    """
    # Cases that do not care about the scheme default to plain http.
    environ.setdefault("wsgi.url_scheme", "http")
    assert get_host(environ) == expect


def test_basic(sentry_init, crashing_app, capture_events):
sentry_init(send_default_pii=True)
app = SentryWsgiMiddleware(crashing_app)
Expand All @@ -61,6 +106,28 @@ def test_basic(sentry_init, crashing_app, capture_events):
}


def test_basic_forwarded_host(sentry_init, crashing_app, capture_events):
    """
    With ``use_x_forwarded_for=True`` the event URL uses the forwarded host.

    The request carries ``X-Forwarded-Host: foobarbaz:80``; the captured
    event is expected to report the URL as ``http://foobarbaz/`` (port 80
    dropped) while the raw headers are recorded unchanged.
    """
    sentry_init(send_default_pii=True)
    app = SentryWsgiMiddleware(crashing_app, use_x_forwarded_for=True)
    client = Client(app)
    events = capture_events()

    # The wrapped app is expected to fail; the middleware should still
    # record an event before the error propagates.
    with pytest.raises(ZeroDivisionError):
        client.get("/", environ_overrides={"HTTP_X_FORWARDED_HOST": "foobarbaz:80"})

    # Exactly one event must have been captured.
    (event,) = events

    assert event["transaction"] == "generic WSGI request"

    assert event["request"] == {
        "env": {"SERVER_NAME": "localhost", "SERVER_PORT": "80"},
        "headers": {"Host": "localhost", "X-Forwarded-Host": "foobarbaz:80"},
        "method": "GET",
        "query_string": "",
        "url": "http://foobarbaz/",
    }


@pytest.mark.parametrize("path_info", ("bark/", "/bark/"))
@pytest.mark.parametrize("script_name", ("woof/woof", "woof/woof/"))
def test_script_name_is_respected(
Expand Down