diff --git a/linux_voice_assistant/__main__.py b/linux_voice_assistant/__main__.py index 8ff6b7e..f0a3dc3 100644 --- a/linux_voice_assistant/__main__.py +++ b/linux_voice_assistant/__main__.py @@ -21,6 +21,7 @@ ListEntitiesRequest, MediaPlayerCommandRequest, SubscribeHomeAssistantStatesRequest, + SwitchCommandRequest, VoiceAssistantAnnounceFinished, VoiceAssistantAnnounceRequest, VoiceAssistantAudio, @@ -40,7 +41,7 @@ from google.protobuf import message from .api_server import APIServer -from .entity import ESPHomeEntity, MediaPlayerEntity +from .entity import ESPHomeEntity, MediaPlayerEntity, MuteSwitchEntity from .microwakeword import MicroWakeWord from .mpv_player import MpvMediaPlayer from .util import call_all, get_mac, is_arm @@ -80,6 +81,7 @@ class ServerState: timer_finished_sound: str media_player_entity: Optional[MediaPlayerEntity] = None satellite: "Optional[VoiceSatelliteProtocol]" = None + muted: bool = False # ----------------------------------------------------------------------------- @@ -104,12 +106,44 @@ def __init__(self, state: ServerState) -> None: ) self.state.entities.append(self.state.media_player_entity) + # Add mute switch entity (like ESPHome Voice PE) + mute_switch = MuteSwitchEntity( + server=self, + key=len(state.entities), + name="Mute", + object_id="mute", + get_muted=lambda: self.state.muted, + set_muted=self._set_muted, + ) + self.state.entities.append(mute_switch) + self._is_streaming_audio = False self._tts_url: Optional[str] = None self._tts_played = False self._continue_conversation = False self._timer_finished = False + def _set_muted(self, new_state: bool) -> None: + """Set mute state - behaves like ESPHome Voice PE mute switch. + + When muted (True): Behaves like voice_assistant.stop + When unmuted (False): Behaves like voice_assistant.start_continuous + """ + self.state.muted = bool(new_state) + + if self.state.muted: + # voice_assistant.stop behavior + _LOGGER.debug("Muting voice assistant (voice_assistant.stop)") + self._is_streaming_audio = False + self.state.tts_player.stop() + # Stop any ongoing voice processing + self.state.stop_word.is_active = False + else: + # voice_assistant.start_continuous behavior + _LOGGER.debug("Unmuting voice assistant (voice_assistant.start_continuous)") + # Resume normal operation - wake word detection will be active again + pass + def handle_voice_event( self, event_type: VoiceAssistantEventType, data: Dict[str, str] ) -> None: @@ -203,6 +237,7 @@ def handle_message(self, msg: message.Message) -> Iterable[message.Message]: ListEntitiesRequest, SubscribeHomeAssistantStatesRequest, MediaPlayerCommandRequest, + SwitchCommandRequest, ), ): for entity in self.state.entities: @@ -246,7 +281,7 @@ def handle_message(self, msg: message.Message) -> Iterable[message.Message]: def handle_audio(self, audio_chunk: bytes) -> None: - if not self._is_streaming_audio: + if not self._is_streaming_audio or self.state.muted: return self.send_messages([VoiceAssistantAudio(data=audio_chunk)]) @@ -259,6 +294,10 @@ def wakeup(self) -> None: _LOGGER.debug("Stopping timer finished sound") return + if self.state.muted: + # Don't respond to wake words when muted (voice_assistant.stop behavior) + return + wake_word_phrase = self.state.wake_word.wake_word _LOGGER.debug("Detected wake word: %s", wake_word_phrase) self.send_messages( @@ -341,15 +380,17 @@ def process_audio(state: ServerState): try: state.satellite.handle_audio(audio_chunk) - if state.wake_word.is_active and state.wake_word.process_streaming( - audio_chunk - ): - state.satellite.wakeup() - - if state.stop_word.is_active and state.stop_word.process_streaming( - audio_chunk - ): - state.satellite.stop() + # Skip wake word and stop word processing when muted (voice_assistant.stop behavior) + if not state.muted: + if state.wake_word.is_active and state.wake_word.process_streaming( + audio_chunk + ): + state.satellite.wakeup() + + if state.stop_word.is_active and state.stop_word.process_streaming( + audio_chunk + ): + state.satellite.stop() except Exception: _LOGGER.exception("Unexpected error handling audio") diff --git a/linux_voice_assistant/entity.py b/linux_voice_assistant/entity.py index d2e455c..b2f3360 100644 --- a/linux_voice_assistant/entity.py +++ b/linux_voice_assistant/entity.py @@ -6,8 +6,11 @@ from aioesphomeapi.api_pb2 import ( # type: ignore[attr-defined] ListEntitiesMediaPlayerResponse, ListEntitiesRequest, + ListEntitiesSwitchResponse, MediaPlayerCommandRequest, MediaPlayerStateResponse, + SwitchCommandRequest, + SwitchStateResponse, SubscribeHomeAssistantStatesRequest, ) from aioesphomeapi.model import MediaPlayerCommand, MediaPlayerState @@ -131,3 +134,49 @@ def _get_state_message(self) -> MediaPlayerStateResponse: volume=self.volume, muted=self.muted, ) + + +# ----------------------------------------------------------------------------- + + +class MuteSwitchEntity(ESPHomeEntity): + """Mute switch entity that behaves like ESPHome Voice PE mute switch. + + This switch maintains its own state and triggers voice_assistant.stop/start_continuous actions. + """ + + def __init__( + self, + server: APIServer, + key: int, + name: str, + object_id: str, + get_muted: Callable[[], bool], + set_muted: Callable[[bool], None], + ) -> None: + ESPHomeEntity.__init__(self, server) + + self.key = key + self.name = name + self.object_id = object_id + self._get_muted = get_muted + self._set_muted = set_muted + self._switch_state = False # Internal switch state + + def handle_message(self, msg: message.Message) -> Iterable[message.Message]: + if isinstance(msg, SwitchCommandRequest) and (msg.key == self.key): + # User toggled the switch - update our internal state and trigger actions + new_state = bool(msg.state) + self._switch_state = new_state + self._set_muted(new_state) + # Return the new state immediately + yield SwitchStateResponse(key=self.key, state=self._switch_state) + elif isinstance(msg, ListEntitiesRequest): + yield ListEntitiesSwitchResponse( + object_id=self.object_id, + key=self.key, + name=self.name, + ) + elif isinstance(msg, SubscribeHomeAssistantStatesRequest): + # Always return our internal switch state + yield SwitchStateResponse(key=self.key, state=self._switch_state)