1
1
from __future__ import annotations
2
2
3
3
from functools import partial
4
- from typing import Awaitable , Callable , Dict , Optional , Tuple
4
+ from typing import Awaitable , Callable , Dict , Optional , Set , Tuple
5
5
6
6
from aioquic .buffer import Buffer
7
7
from aioquic .h3 .connection import H3_ALPN
22
22
from .h3 import H3Protocol
23
23
from ..config import Config
24
24
from ..events import Closed , Event , RawData
25
- from ..typing import AppWrapper , TaskGroup , WorkerContext
25
+ from ..typing import AppWrapper , TaskGroup , WorkerContext , Timer
26
+
27
+
28
+ class ConnectionState :
29
+ def __init__ (self , connection : QuicConnection ):
30
+ self .connection = connection
31
+ self .timer : Optional [Timer ] = None
32
+ self .cids : Set [bytes ] = set ()
33
+ self .h3_protocol : Optional [H3Protocol ] = None
34
+
35
+ def add_cid (self , cid : bytes ) -> None :
36
+ self .cids .add (cid )
37
+
38
+ def remove_cid (self , cid : bytes ) -> None :
39
+ self .cids .remove (cid )
26
40
27
41
28
42
class QuicProtocol :
@@ -38,8 +52,7 @@ def __init__(
38
52
self .app = app
39
53
self .config = config
40
54
self .context = context
41
- self .connections : Dict [bytes , QuicConnection ] = {}
42
- self .http_connections : Dict [QuicConnection , H3Protocol ] = {}
55
+ self .connections : Dict [bytes , ConnectionState ] = {}
43
56
self .send = send
44
57
self .server = server
45
58
self .task_group = task_group
@@ -49,7 +62,7 @@ def __init__(
49
62
50
63
@property
51
64
def idle (self ) -> bool :
52
- return len (self .connections ) == 0 and len ( self . http_connections ) == 0
65
+ return len (self .connections ) == 0
53
66
54
67
async def handle (self , event : Event ) -> None :
55
68
if isinstance (event , RawData ):
@@ -69,9 +82,13 @@ async def handle(self, event: Event) -> None:
69
82
await self .send (RawData (data = data , address = event .address ))
70
83
return
71
84
72
- connection = self .connections .get (header .destination_cid )
85
+ state = self .connections .get (header .destination_cid )
86
+ if state is not None :
87
+ connection = state .connection
88
+ else :
89
+ connection = None
73
90
if (
74
- connection is None
91
+ state is None
75
92
and len (event .data ) >= 1200
76
93
and header .packet_type == PACKET_TYPE_INITIAL
77
94
and not self .context .terminated .is_set ()
@@ -80,12 +97,18 @@ async def handle(self, event: Event) -> None:
80
97
configuration = self .quic_config ,
81
98
original_destination_connection_id = header .destination_cid ,
82
99
)
83
- self .connections [header .destination_cid ] = connection
84
- self .connections [connection .host_cid ] = connection
100
+ # This partial() needs python >= 3.8
101
+ state = ConnectionState (connection )
102
+ timer = self .task_group .create_timer (partial (self ._timeout , state ))
103
+ state .timer = timer
104
+ state .add_cid (header .destination_cid )
105
+ self .connections [header .destination_cid ] = state
106
+ state .add_cid (connection .host_cid )
107
+ self .connections [connection .host_cid ] = state
85
108
86
109
if connection is not None :
87
110
connection .receive_datagram (event .data , event .address , now = self .context .time ())
88
- await self ._handle_events ( connection , event . address )
111
+ await self ._wake_up_timer ( state )
89
112
elif isinstance (event , Closed ):
90
113
pass
91
114
@@ -94,42 +117,50 @@ async def send_all(self, connection: QuicConnection) -> None:
94
117
await self .send (RawData (data = data , address = address ))
95
118
96
119
async def _handle_events (
97
- self , connection : QuicConnection , client : Optional [Tuple [str , int ]] = None
120
+ self , state : ConnectionState , client : Optional [Tuple [str , int ]] = None
98
121
) -> None :
122
+ connection = state .connection
99
123
event = connection .next_event ()
100
124
while event is not None :
101
125
if isinstance (event , ConnectionTerminated ):
102
- pass
126
+ await state .timer .stop ()
127
+ for cid in state .cids :
128
+ del self .connections [cid ]
129
+ state .cids = set ()
103
130
elif isinstance (event , ProtocolNegotiated ):
104
- self . http_connections [ connection ] = H3Protocol (
131
+ state . h3_protocol = H3Protocol (
105
132
self .app ,
106
133
self .config ,
107
134
self .context ,
108
135
self .task_group ,
109
136
client ,
110
137
self .server ,
111
138
connection ,
112
- partial (self .send_all , connection ),
139
+ partial (self ._wake_up_timer , state ),
113
140
)
114
141
elif isinstance (event , ConnectionIdIssued ):
115
- self .connections [event .connection_id ] = connection
142
+ state .add_cid (event .connection_id )
143
+ self .connections [event .connection_id ] = state
116
144
elif isinstance (event , ConnectionIdRetired ):
145
+ state .remove_cid (event .connection_id )
117
146
del self .connections [event .connection_id ]
118
147
119
- if connection in self . http_connections :
120
- await self . http_connections [ connection ] .handle (event )
148
+ elif state . h3_protocol is not None :
149
+ await state . h3_protocol .handle (event )
121
150
122
151
event = connection .next_event ()
123
152
153
+ async def _wake_up_timer (self , state : ConnectionState ) -> None :
154
+ # When new output is send, or new input is received, we
155
+ # fire the timer right away so we update our state.
156
+ await state .timer .schedule (0.0 )
157
+
158
+ async def _timeout (self , state : ConnectionState ) -> None :
159
+ connection = state .connection
160
+ now = self .context .time ()
161
+ when = connection .get_timer ()
162
+ if when is not None and now > when :
163
+ connection .handle_timer (now )
164
+ await self ._handle_events (state , None )
124
165
await self .send_all (connection )
125
-
126
- timer = connection .get_timer ()
127
- if timer is not None :
128
- self .task_group .spawn (self ._handle_timer , timer , connection )
129
-
130
- async def _handle_timer (self , timer : float , connection : QuicConnection ) -> None :
131
- wait = max (0 , timer - self .context .time ())
132
- await self .context .sleep (wait )
133
- if connection ._close_at is not None :
134
- connection .handle_timer (now = self .context .time ())
135
- await self ._handle_events (connection , None )
166
+ await state .timer .schedule (connection .get_timer ())
0 commit comments