Skip to content

Commit 98c5e17

Browse files
committed
Improved H3 for hypercorn.
1 parent 3fbd5f2 commit 98c5e17

File tree

5 files changed

+161
-31
lines changed

5 files changed

+161
-31
lines changed

src/hypercorn/asyncio/task_group.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Awaitable, Callable, Optional
77

88
from ..config import Config
9-
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope
9+
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope, Timer
1010

1111
try:
1212
from asyncio import TaskGroup as AsyncioTaskGroup
@@ -33,6 +33,44 @@ async def _handle(
3333
await send(None)
3434

3535

36+
LONG_SLEEP = 86400.0
37+
38+
class AsyncioTimer(Timer):
39+
def __init__(self, action: Callable) -> None:
40+
self._action = action
41+
self._done = False
42+
self._wake_up = asyncio.Condition()
43+
self._when: Optional[float] = None
44+
45+
async def schedule(self, when: Optional[float]) -> None:
46+
self._when = when
47+
async with self._wake_up:
48+
self._wake_up.notify()
49+
50+
async def stop(self) -> None:
51+
self._done = True
52+
async with self._wake_up:
53+
self._wake_up.notify()
54+
55+
async def _wait_for_wake_up(self) -> None:
56+
async with self._wake_up:
57+
await self._wake_up.wait()
58+
59+
async def run(self) -> None:
60+
while not self._done:
61+
if self._when is not None and asyncio.get_event_loop().time() >= self._when:
62+
self._when = None
63+
await self._action()
64+
if self._when is not None:
65+
timeout = max(self._when - asyncio.get_event_loop().time(), 0.0)
66+
else:
67+
timeout = LONG_SLEEP
68+
if not self._done:
69+
try:
70+
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
71+
except TimeoutError:
72+
pass
73+
3674
class TaskGroup:
3775
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
3876
self._loop = loop
@@ -66,6 +104,11 @@ def _call_soon(func: Callable, *args: Any) -> Any:
66104
def spawn(self, func: Callable, *args: Any) -> None:
67105
self._task_group.create_task(func(*args))
68106

107+
def create_timer(self, action: Callable) -> Timer:
108+
timer = AsyncioTimer(action)
109+
self._task_group.create_task(timer.run())
110+
return timer
111+
69112
async def __aenter__(self) -> "TaskGroup":
70113
await self._task_group.__aenter__()
71114
return self

src/hypercorn/protocol/quic.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from functools import partial
4-
from typing import Awaitable, Callable, Dict, Optional, Tuple
4+
from typing import Awaitable, Callable, Dict, Optional, Set, Tuple
55

66
from aioquic.buffer import Buffer
77
from aioquic.h3.connection import H3_ALPN
@@ -22,7 +22,21 @@
2222
from .h3 import H3Protocol
2323
from ..config import Config
2424
from ..events import Closed, Event, RawData
25-
from ..typing import AppWrapper, TaskGroup, WorkerContext
25+
from ..typing import AppWrapper, TaskGroup, WorkerContext, Timer
26+
27+
28+
class ConnectionState:
29+
def __init__(self, connection: QuicConnection):
30+
self.connection = connection
31+
self.timer: Optional[Timer] = None
32+
self.cids: Set[bytes] = set()
33+
self.h3_protocol: Optional[H3Protocol] = None
34+
35+
def add_cid(self, cid: bytes) -> None:
36+
self.cids.add(cid)
37+
38+
def remove_cid(self, cid: bytes) -> None:
39+
self.cids.remove(cid)
2640

2741

2842
class QuicProtocol:
@@ -38,8 +52,7 @@ def __init__(
3852
self.app = app
3953
self.config = config
4054
self.context = context
41-
self.connections: Dict[bytes, QuicConnection] = {}
42-
self.http_connections: Dict[QuicConnection, H3Protocol] = {}
55+
self.connections: Dict[bytes, ConnectionState] = {}
4356
self.send = send
4457
self.server = server
4558
self.task_group = task_group
@@ -49,7 +62,7 @@ def __init__(
4962

5063
@property
5164
def idle(self) -> bool:
52-
return len(self.connections) == 0 and len(self.http_connections) == 0
65+
return len(self.connections) == 0
5366

5467
async def handle(self, event: Event) -> None:
5568
if isinstance(event, RawData):
@@ -69,9 +82,13 @@ async def handle(self, event: Event) -> None:
6982
await self.send(RawData(data=data, address=event.address))
7083
return
7184

72-
connection = self.connections.get(header.destination_cid)
85+
state = self.connections.get(header.destination_cid)
86+
if state is not None:
87+
connection = state.connection
88+
else:
89+
connection = None
7390
if (
74-
connection is None
91+
state is None
7592
and len(event.data) >= 1200
7693
and header.packet_type == PACKET_TYPE_INITIAL
7794
and not self.context.terminated.is_set()
@@ -80,12 +97,18 @@ async def handle(self, event: Event) -> None:
8097
configuration=self.quic_config,
8198
original_destination_connection_id=header.destination_cid,
8299
)
83-
self.connections[header.destination_cid] = connection
84-
self.connections[connection.host_cid] = connection
100+
# This partial() needs python >= 3.8
101+
state = ConnectionState(connection)
102+
timer = self.task_group.create_timer(partial(self._timeout, state))
103+
state.timer = timer
104+
state.add_cid(header.destination_cid)
105+
self.connections[header.destination_cid] = state
106+
state.add_cid(connection.host_cid)
107+
self.connections[connection.host_cid] = state
85108

86109
if connection is not None:
87110
connection.receive_datagram(event.data, event.address, now=self.context.time())
88-
await self._handle_events(connection, event.address)
111+
await self._wake_up_timer(state)
89112
elif isinstance(event, Closed):
90113
pass
91114

@@ -94,42 +117,50 @@ async def send_all(self, connection: QuicConnection) -> None:
94117
await self.send(RawData(data=data, address=address))
95118

96119
async def _handle_events(
97-
self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
120+
self, state: ConnectionState, client: Optional[Tuple[str, int]] = None
98121
) -> None:
122+
connection = state.connection
99123
event = connection.next_event()
100124
while event is not None:
101125
if isinstance(event, ConnectionTerminated):
102-
pass
126+
await state.timer.stop()
127+
for cid in state.cids:
128+
del self.connections[cid]
129+
state.cids = set()
103130
elif isinstance(event, ProtocolNegotiated):
104-
self.http_connections[connection] = H3Protocol(
131+
state.h3_protocol = H3Protocol(
105132
self.app,
106133
self.config,
107134
self.context,
108135
self.task_group,
109136
client,
110137
self.server,
111138
connection,
112-
partial(self.send_all, connection),
139+
partial(self._wake_up_timer, state),
113140
)
114141
elif isinstance(event, ConnectionIdIssued):
115-
self.connections[event.connection_id] = connection
142+
state.add_cid(event.connection_id)
143+
self.connections[event.connection_id] = state
116144
elif isinstance(event, ConnectionIdRetired):
145+
state.remove_cid(event.connection_id)
117146
del self.connections[event.connection_id]
118147

119-
if connection in self.http_connections:
120-
await self.http_connections[connection].handle(event)
148+
elif state.h3_protocol is not None:
149+
await state.h3_protocol.handle(event)
121150

122151
event = connection.next_event()
123152

153+
async def _wake_up_timer(self, state: ConnectionState) -> None:
154+
# When new output is send, or new input is received, we
155+
# fire the timer right away so we update our state.
156+
await state.timer.schedule(0.0)
157+
158+
async def _timeout(self, state: ConnectionState) -> None:
159+
connection = state.connection
160+
now = self.context.time()
161+
when = connection.get_timer()
162+
if when is not None and now > when:
163+
connection.handle_timer(now)
164+
await self._handle_events(state, None)
124165
await self.send_all(connection)
125-
126-
timer = connection.get_timer()
127-
if timer is not None:
128-
self.task_group.spawn(self._handle_timer, timer, connection)
129-
130-
async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
131-
wait = max(0, timer - self.context.time())
132-
await self.context.sleep(wait)
133-
if connection._close_at is not None:
134-
connection.handle_timer(now=self.context.time())
135-
await self._handle_events(connection, None)
166+
await state.timer.schedule(connection.get_timer())

src/hypercorn/trio/task_group.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import trio
88

99
from ..config import Config
10-
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope
10+
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope, Timer
1111

1212
if sys.version_info < (3, 11):
1313
from exceptiongroup import BaseExceptionGroup
@@ -39,6 +39,40 @@ async def _handle(
3939
await send(None)
4040

4141

42+
LONG_SLEEP = 86400.0
43+
44+
class TrioTimer(Timer):
45+
def __init__(self, action: Callable) -> None:
46+
self._action = action
47+
self._done = False
48+
self._wake_up = trio.Condition()
49+
self._when: Optional[float] = None
50+
51+
async def schedule(self, when: Optional[float]) -> None:
52+
self._when = when
53+
async with self._wake_up:
54+
self._wake_up.notify()
55+
56+
async def stop(self) -> None:
57+
self._done = True
58+
async with self._wake_up:
59+
self._wake_up.notify()
60+
61+
async def run(self) -> None:
62+
while not self._done:
63+
if self._when is not None and trio.current_time() >= self._when:
64+
self._when = None
65+
await self._action()
66+
if self._when is not None:
67+
timeout = max(self._when - trio.current_time(), 0.0)
68+
else:
69+
timeout = LONG_SLEEP
70+
if not self._done:
71+
with trio.move_on_after(timeout):
72+
async with self._wake_up:
73+
await self._wake_up.wait()
74+
75+
4276
class TaskGroup:
4377
def __init__(self) -> None:
4478
self._nursery: Optional[trio._core._run.Nursery] = None
@@ -67,6 +101,11 @@ async def spawn_app(
67101
def spawn(self, func: Callable, *args: Any) -> None:
68102
self._nursery.start_soon(func, *args)
69103

104+
def create_timer(self, action: Callable) -> Timer:
105+
timer = TrioTimer(action)
106+
self._nursery.start_soon(timer.run)
107+
return timer
108+
70109
async def __aenter__(self) -> TaskGroup:
71110
self._nursery_manager = trio.open_nursery()
72111
self._nursery = await self._nursery_manager.__aenter__()

src/hypercorn/trio/worker_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Optional, Type, Union
3+
from typing import Awaitable, Optional, Type, Union
44

55
import trio
66

src/hypercorn/typing.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,20 @@ def is_set(self) -> bool:
288288
...
289289

290290

291+
class Timer:
292+
def __init__(self, action: Callable) -> None:
293+
...
294+
295+
async def schedule(self, when: float) -> None:
296+
...
297+
298+
async def stop(self) -> None:
299+
...
300+
301+
async def run(self) -> None:
302+
...
303+
304+
291305
class WorkerContext(Protocol):
292306
event_class: Type[Event]
293307
terminate: Event
@@ -318,6 +332,9 @@ async def spawn_app(
318332
def spawn(self, func: Callable, *args: Any) -> None:
319333
...
320334

335+
def create_timer(self, action: Callable) -> Timer:
336+
...
337+
321338
async def __aenter__(self) -> TaskGroup:
322339
...
323340

0 commit comments

Comments
 (0)