Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions src/nwp500/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,20 @@ def _on_connection_resumed_internal(
)
)

# Send any queued commands
if self.config.enable_command_queue and self._command_queue:
# When the broker starts a clean session (session_present=False), all
# previous subscriptions have been dropped server-side. We must
# re-establish them before any device data can flow. This covers the
# common case where the AWS IoT SDK auto-reconnects internally before
# the MqttReconnectionHandler fires its own reconnect path — in that
# scenario the reconnect handler sees _connected==True and exits early,
# so resubscribe_all() would never be called without this block.
#
# When session_present=False, we must resubscribe before sending queued
# commands to ensure subscriptions are restored before device responses
# are processed. Use a composite coroutine to enforce ordering.
if not session_present and self._subscription_manager:
self._schedule_coroutine(self._handle_clean_session_resume())
elif self.config.enable_command_queue and self._command_queue:
self._schedule_coroutine(self._send_queued_commands_internal())
Comment on lines +368 to 381
Comment on lines +367 to 381

async def _send_queued_commands_internal(self) -> None:
Expand All @@ -377,6 +389,29 @@ async def _send_queued_commands_internal(self) -> None:
self._connection_manager.publish, lambda: self._connected
)

async def _handle_clean_session_resume(self) -> None:
"""
Handle clean session reconnection with ordered resubscription.

When session_present=False (clean session), the broker has dropped all
subscriptions. This method ensures subscriptions are restored BEFORE
sending any queued commands, preventing commands from being processed
before their subscriptions are re-established.
"""
if not self._subscription_manager or not self._connection_manager:
return

if not self._connection_manager.connection:
return

self._subscription_manager.update_connection(
self._connection_manager.connection
)
await self._subscription_manager.resubscribe_all()

if self.config.enable_command_queue and self._command_queue:
await self._send_queued_commands_internal()

async def _active_reconnect(self) -> None:
"""
Actively trigger a reconnection attempt.
Expand Down
6 changes: 0 additions & 6 deletions src/nwp500/mqtt/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,6 @@ async def periodic_request() -> None:
await self._request_device_info(device)
elif request_type == PeriodicRequestType.DEVICE_STATUS:
await self._request_device_status(device)
else:
_logger.error(
"Unknown periodic request type: %s",
request_type,
)
break

_logger.debug(
"Sent periodic %s request for %s",
Expand Down
190 changes: 190 additions & 0 deletions tests/test_mqtt_clean_session_resume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""Tests for MQTT client clean session reconnection handling."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from nwp500.auth import AuthenticationResponse, AuthTokens, UserInfo
from nwp500.mqtt import NavienMqttClient


@pytest.fixture
def auth_client_with_valid_tokens():
"""Create an auth client with valid tokens."""
from nwp500.auth import NavienAuthClient

auth_client = NavienAuthClient("test@example.com", "password")
valid_tokens = AuthTokens(
id_token="test_id",
access_token="test_access",
refresh_token="test_refresh",
authentication_expires_in=3600,
access_key_id="test_key_id",
secret_key="test_secret_key",
session_token="test_session",
authorization_expires_in=3600,
)
auth_client._auth_response = AuthenticationResponse(
user_info=UserInfo(user_first_name="Test", user_last_name="User"),
tokens=valid_tokens,
)
return auth_client


class TestMqttCleanSessionResume:
"""Tests for clean session (session_present=False) reconnection handling."""

@pytest.mark.asyncio(loop_scope="function")
async def test_on_connection_resumed_with_clean_session_resubscribes(
self, auth_client_with_valid_tokens
):
"""Resubscribe when session_present=False on connection resume."""
client = NavienMqttClient(auth_client_with_valid_tokens)

# Mock the components
mock_subscription_manager = AsyncMock()
mock_subscription_manager.resubscribe_all = AsyncMock()
client._subscription_manager = mock_subscription_manager

mock_connection_manager = MagicMock()
mock_connection = MagicMock()
mock_connection_manager.connection = mock_connection
client._connection_manager = mock_connection_manager

# Mock the event emitter and diagnostics
client.emit = AsyncMock()
client._diagnostics = MagicMock()
client._diagnostics.record_connection_success = AsyncMock()

# Call with session_present=False (clean session)
client._on_connection_resumed_internal(
connection=mock_connection, return_code=0, session_present=False
)

# Give the scheduled coroutine time to run
import asyncio

await asyncio.sleep(0.1)

# Verify resubscribe_all was called
mock_subscription_manager.update_connection.assert_called_once_with(
mock_connection
)
# The resubscribe should be scheduled via _schedule_coroutine
# We need to wait for it or check the internal state

@pytest.mark.asyncio(loop_scope="function")
async def test_resubscribe_before_queued_commands(
self, auth_client_with_valid_tokens
):
"""Resubscribe completes before queued commands are sent."""
client = NavienMqttClient(auth_client_with_valid_tokens)

# Track call order
call_order = []

# Mock the components
mock_subscription_manager = MagicMock()
mock_subscription_manager.resubscribe_all = AsyncMock(
side_effect=lambda: call_order.append("resubscribe")
)
client._subscription_manager = mock_subscription_manager

mock_connection_manager = MagicMock()
mock_connection = MagicMock()
mock_connection_manager.connection = mock_connection
client._connection_manager = mock_connection_manager

# Mock command queue
client._command_queue = AsyncMock()
client.config.enable_command_queue = True

# Mock send_queued_commands to track it's called after resubscribe
original_send = client._send_queued_commands_internal

async def mock_send():
call_order.append("send_queued")
await original_send()

client._send_queued_commands_internal = mock_send

# Call the method
await client._handle_clean_session_resume()

# Verify subscription manager was updated with connection
mock_subscription_manager.update_connection.assert_called_once_with(
mock_connection
)

# Verify resubscribe was called before queued commands
assert call_order == ["resubscribe", "send_queued"]

@pytest.mark.asyncio(loop_scope="function")
async def test_skip_when_no_subscription_manager(
self, auth_client_with_valid_tokens
):
"""Return early if subscription_manager is None."""
client = NavienMqttClient(auth_client_with_valid_tokens)
client._subscription_manager = None

# Should not raise
await client._handle_clean_session_resume()

@pytest.mark.asyncio(loop_scope="function")
async def test_handle_clean_session_resume_skips_when_no_connection(
self, auth_client_with_valid_tokens
):
"""Return early if connection is None."""
client = NavienMqttClient(auth_client_with_valid_tokens)

mock_subscription_manager = MagicMock()
client._subscription_manager = mock_subscription_manager

mock_connection_manager = MagicMock()
mock_connection_manager.connection = None
client._connection_manager = mock_connection_manager

# Should not raise
await client._handle_clean_session_resume()

# Should not try to update connection
mock_subscription_manager.update_connection.assert_not_called()

@pytest.mark.asyncio(loop_scope="function")
async def test_on_connection_resumed_with_session_sends_queued_commands(
self, auth_client_with_valid_tokens
):
"""Send queued commands normally when session_present=True."""
client = NavienMqttClient(auth_client_with_valid_tokens)

# Mock the components
mock_command_queue = AsyncMock()
client._command_queue = mock_command_queue
client.config.enable_command_queue = True

# Mock the event emitter and diagnostics
client.emit = AsyncMock()
client._diagnostics = MagicMock()
client._diagnostics.record_connection_success = AsyncMock()

# Mock connection
mock_connection = MagicMock()

# Patch _send_queued_commands_internal to track if called
with patch.object(
client, "_send_queued_commands_internal", new_callable=AsyncMock
):
# Call with session_present=True (session resumed)
client._on_connection_resumed_internal(
connection=mock_connection, return_code=0, session_present=True
)

# Give the scheduled coroutine time to run
import asyncio

await asyncio.sleep(0.1)

# Verify send_queued_commands_internal was scheduled
# (it will be called through _schedule_coroutine)
Loading