Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aioesphomeapi/connection.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ cdef class APIConnection:
cdef bint _send_pending_ping
cdef public bint is_connected
cdef bint _handshake_complete
cdef bint _initial_time_sent
cdef bint _debug_enabled
cdef public str received_name
cdef public str connected_address
Expand Down
24 changes: 21 additions & 3 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ def _make_hello_request(client_info: str) -> HelloRequest:
)


def _make_time_response() -> GetTimeResponse:
"""Create a GetTimeResponse."""
resp = GetTimeResponse()
resp.epoch_seconds = int(time.time())
return resp


_cached_make_hello_request = lru_cache(maxsize=16)(_make_hello_request)
make_hello_request = _cached_make_hello_request

Expand Down Expand Up @@ -200,6 +207,7 @@ class APIConnection:
"_finish_connect_future",
"_frame_helper",
"_handshake_complete",
"_initial_time_sent",
"_keep_alive_interval",
"_keep_alive_timeout",
"_log_errors",
Expand Down Expand Up @@ -261,6 +269,7 @@ def __init__(
self._loop = asyncio.get_running_loop()
self.is_connected = False
self._handshake_complete = False
self._initial_time_sent = False
self._debug_enabled = debug_enabled
self.received_name: str = ""
self.connected_address: str | None = None
Expand Down Expand Up @@ -476,6 +485,13 @@ async def _connect_hello_login(self, login: bool) -> None:
# the device has a password but we don't expect it
msg_types.append(ConnectResponse)

# Send a GetTimeResponse proactively to reduce latency during reconnect.
# This avoids an additional round-trip for the GetTimeRequest.
# If the device doesn't have Home Assistant time enabled, it will
# simply ignore this response, but since it's included in the same
# packet, it's nearly free to send and reduces pressure during reconnect.
messages.append(_make_time_response())
self._initial_time_sent = True
responses = await self.send_messages_await_response_complex(
tuple(messages),
None,
Expand Down Expand Up @@ -1061,9 +1077,11 @@ def _handle_get_time_request_internal( # pylint: disable=unused-argument
self, _msg: GetTimeRequest
) -> None:
"""Handle a GetTimeRequest."""
resp = GetTimeResponse()
resp.epoch_seconds = int(time.time())
self.send_messages((resp,))
if self._initial_time_sent:
# Ignore the first time request since we already sent it proactively
self._initial_time_sent = False
return
self.send_messages((_make_time_response(),))

async def disconnect(self) -> None:
"""Disconnect from the API."""
Expand Down
6 changes: 6 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from aioesphomeapi._frame_helper.plain_text import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import (
ConnectResponse,
GetTimeRequest,
HelloResponse,
PingRequest,
PingResponse,
Expand Down Expand Up @@ -192,6 +193,11 @@ def send_ping_request(protocol: APIPlaintextFrameHelper) -> None:
protocol.data_received(generate_plaintext_packet(ping_request))


def send_time_request(protocol: APIPlaintextFrameHelper) -> None:
time_request: message.Message = GetTimeRequest()
protocol.data_received(generate_plaintext_packet(time_request))


def get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper(
connection=conn,
Expand Down
64 changes: 64 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
send_ping_response,
send_plaintext_connect_response,
send_plaintext_hello,
send_time_request,
utcnow,
)

Expand Down Expand Up @@ -1291,3 +1292,66 @@ async def test_report_fatal_error_with_log_errors_false(

# Verify the error is still stored internally
assert conn._fatal_exception is regular_error


async def test_time_request_response(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
"""Test that GetTimeResponse is sent proactively and first request is ignored."""
conn, transport, protocol, connect_task = plaintext_connect_task_with_login

# Verify that GetTimeResponse is sent proactively during initial handshake
# This happens before we even receive HelloResponse/ConnectResponse
initial_calls = transport.writelines.call_args_list
# Find the initial handshake packet that should contain Hello, Connect, and GetTimeResponse
handshake_found = False
for call_args in initial_calls:
full_data = b"".join(call_args[0][0])
# Check if this packet contains GetTimeResponse (message type 0x25)
if b"\x25" in full_data:
handshake_found = True
break
assert handshake_found, "GetTimeResponse was not sent proactively during handshake"

send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)

await connect_task
assert conn.is_connected

# Reset transport mock to check what gets sent after connection
transport.reset_mock()

# Send first GetTimeRequest - this should be ignored since we sent time proactively
send_time_request(protocol)
await asyncio.sleep(0)

# Verify no response was sent for the first request
assert transport.writelines.call_count == 0

# Send second GetTimeRequest - this should be answered
send_time_request(protocol)
await asyncio.sleep(0)

# Verify GetTimeResponse was sent
assert transport.writelines.call_count == 1
# GetTimeResponse message type is 37 (0x25)
# writelines is called with a list of bytes, check that we have the right message type
call_args = transport.writelines.call_args_list[0][0][0]
# Join all the bytes together to check
full_data = b"".join(call_args)
# Message type 37 is 0x25
assert b"\x25" in full_data

# Send third GetTimeRequest - this should also be answered
transport.reset_mock()
send_time_request(protocol)
await asyncio.sleep(0)

# Verify another GetTimeResponse was sent
assert transport.writelines.call_count == 1
call_args = transport.writelines.call_args_list[0][0][0]
full_data = b"".join(call_args)
assert b"\x25" in full_data
Loading