Skip to content

Commit 3fbd5f2

Browse files
apollo13pgjones
authored andcommitted
Properly set host header to ascii string in ProxyFixMiddleware.
1 parent bc39603 commit 3fbd5f2

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

src/hypercorn/middleware/proxy_fix.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ def __init__(
1818
self.trusted_hops = trusted_hops
1919

2020
async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
21-
if scope["type"] in {"http", "websocket"}:
21+
# Keep the `or` instead of `in {'http' …}` to allow type narrowing
22+
if scope["type"] == "http" or scope["type"] == "websocket":
2223
scope = deepcopy(scope)
23-
headers = scope["headers"] # type: ignore
24+
headers = scope["headers"]
2425
client: Optional[str] = None
2526
scheme: Optional[str] = None
2627
host: Optional[str] = None
@@ -44,19 +45,19 @@ async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> Non
4445
host = _get_trusted_value(b"x-forwarded-host", headers, self.trusted_hops)
4546

4647
if client is not None:
47-
scope["client"] = (client, 0) # type: ignore
48+
scope["client"] = (client, 0)
4849

4950
if scheme is not None:
50-
scope["scheme"] = scheme # type: ignore
51+
scope["scheme"] = scheme
5152

5253
if host is not None:
5354
headers = [
5455
(name, header_value)
5556
for name, header_value in headers
5657
if name.lower() != b"host"
5758
]
58-
headers.append((b"host", host))
59-
scope["headers"] = headers # type: ignore
59+
headers.append((b"host", host.encode()))
60+
scope["headers"] = headers
6061

6162
await self.app(scope, receive, send)
6263

tests/middleware/test_proxy_fix.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,19 @@ async def test_proxy_fix_legacy() -> None:
2626
(b"x-forwarded-for", b"127.0.0.1"),
2727
(b"x-forwarded-for", b"127.0.0.2"),
2828
(b"x-forwarded-proto", b"http,https"),
29+
(b"x-forwarded-host", b"example.com"),
2930
],
3031
"client": ("127.0.0.3", 80),
3132
"server": None,
3233
"extensions": {},
3334
}
3435
await app(scope, None, None)
3536
mock.assert_called()
36-
assert mock.call_args[0][0]["client"] == ("127.0.0.2", 0)
37-
assert mock.call_args[0][0]["scheme"] == "https"
37+
scope = mock.call_args[0][0]
38+
assert scope["client"] == ("127.0.0.2", 0)
39+
assert scope["scheme"] == "https"
40+
host_headers = [h for h in scope["headers"] if h[0].lower() == b"host"]
41+
assert host_headers == [(b"host", b"example.com")]
3842

3943

4044
@pytest.mark.asyncio
@@ -52,13 +56,16 @@ async def test_proxy_fix_modern() -> None:
5256
"query_string": b"",
5357
"root_path": "",
5458
"headers": [
55-
(b"forwarded", b"for=127.0.0.1;proto=http,for=127.0.0.2;proto=https"),
59+
(b"forwarded", b"for=127.0.0.1;proto=http,for=127.0.0.2;proto=https;host=example.com"),
5660
],
5761
"client": ("127.0.0.3", 80),
5862
"server": None,
5963
"extensions": {},
6064
}
6165
await app(scope, None, None)
6266
mock.assert_called()
63-
assert mock.call_args[0][0]["client"] == ("127.0.0.2", 0)
64-
assert mock.call_args[0][0]["scheme"] == "https"
67+
scope = mock.call_args[0][0]
68+
assert scope["client"] == ("127.0.0.2", 0)
69+
assert scope["scheme"] == "https"
70+
host_headers = [h for h in scope["headers"] if h[0].lower() == b"host"]
71+
assert host_headers == [(b"host", b"example.com")]

0 commit comments

Comments
 (0)