diff --git a/blebox_uniapi/box.py b/blebox_uniapi/box.py index e57f88c..fb18b85 100644 --- a/blebox_uniapi/box.py +++ b/blebox_uniapi/box.py @@ -14,6 +14,7 @@ from .binary_sensor import BinarySensor from .session import ApiHost from .switch import Switch +from .update import Update from .error import ( UnsupportedBoxResponse, @@ -27,7 +28,6 @@ HttpError, ) - DEFAULT_PORT = 80 @@ -105,6 +105,8 @@ def __init__( info, f"{location} has no hardware version" ) from ex + available_firmware_version = info.get("availableFv") + level = int(info.get("apiLevel", _DEFAULT_API_LEVEL)) self._data_path = config["api_path"] @@ -114,6 +116,7 @@ def __init__( self._address = address self._firmware_version = firmware_version self._hardware_version = hardware_version + self._available_firmware_version = available_firmware_version self._api_version = level self._model = config.get("model", type) self._api = config.get("api", {}) @@ -143,6 +146,7 @@ def create_features( ) except KeyError: raise UnsupportedBoxResponse("Failed to initialize:", info) + features["updates"] = [Update(self, "update", {})] return features @classmethod @@ -153,7 +157,9 @@ async def async_from_host(cls, api_host: ApiHost) -> Box: except HttpError: path = "/info" data = await api_host.async_api_get(path) - info = data.get("device", data) # type: ignore + if data is None: + raise UnsupportedBoxResponse("Device returned non-JSON response") + info = data.get("device", data) extended_state = None config = cls._match_device_config(info) @@ -218,6 +224,10 @@ def firmware_version(self) -> Any: def hardware_version(self) -> Any: return self._hardware_version + @property + def available_firmware_version(self) -> Any: + return self._available_firmware_version + @property def api_version(self) -> int: return self._api_version @@ -366,3 +376,19 @@ async def _async_api( response = await self._session.async_api_post(path, post_data) self._update_last_data(response) self._last_real_update = time.time() + + async def async_ota_check(self) -> None: + await self._session.async_api_get_ota("/api/ota/check") + for _ in range(3): + await asyncio.sleep(1) + response = await self._session.async_api_get("/info") + if response is None: + raise UnsupportedBoxResponse("Device returned non-JSON response") + response = response.get("device", response) + if response.get("availableFv") is not None: + self._firmware_version = response.get("fv", self._firmware_version) + self._available_firmware_version = response["availableFv"] + return + + async def async_ota_update(self) -> None: + await self._session.async_api_get("/api/ota/update") diff --git a/blebox_uniapi/session.py b/blebox_uniapi/session.py index d7b80c2..cd72cc4 100644 --- a/blebox_uniapi/session.py +++ b/blebox_uniapi/session.py @@ -46,7 +46,11 @@ def __init__( self._loop = loop async def async_request( - self, path: str, async_method: Any, data: Union[dict, str, None] = None + self, + path: str, + async_method: Any, + data: Union[dict, str, None] = None, + allow_ota_check_response: bool = False, ) -> Optional[dict]: # TODO: check timeout client_timeout = self._timeout @@ -57,7 +61,8 @@ async def async_request( else: response = await async_method(url, timeout=client_timeout) - if response.status != 200: + accepted_statuses = (200, 202, 204) if allow_ota_check_response else (200,) + if response.status not in accepted_statuses: if response.status == 401: raise error.UnauthorizedRequest( f"Request to {url} failed with HTTP {response.status}, UNAUTHORISED" @@ -66,6 +71,9 @@ async def async_request( f"Request to {url} failed with HTTP {response.status}" ) + if response.content_type != "application/json": + return None + return await response.json() except asyncio.TimeoutError as ex: @@ -81,6 +89,11 @@ async def async_request( except aiohttp.ClientError as ex: raise error.ClientError(f"API request {url} failed: {ex}") from ex + except UnicodeDecodeError as ex: + raise error.ConnectionError( + f"Invalid response encoding from {url}: {ex}" + ) from ex + async def async_api_get(self, path: str) -> Optional[dict]: try: return await self.async_request(path, self._session.get) @@ -88,6 +101,15 @@ async def async_api_get(self, path: str) -> Optional[dict]: logger.error(f"EXCEPTION DURING API CALL: {ex}") raise ex + async def async_api_get_ota(self, path: str) -> Optional[dict]: + try: + return await self.async_request( + path, self._session.get, allow_ota_check_response=True + ) + except Exception as ex: + logger.error(f"EXCEPTION DURING API CALL: {ex}") + raise ex + async def async_api_post( self, path: str, data: Union[dict, str, None] ) -> Optional[dict]: diff --git a/blebox_uniapi/update.py b/blebox_uniapi/update.py new file mode 100644 index 0000000..28dfe47 --- /dev/null +++ b/blebox_uniapi/update.py @@ -0,0 +1,23 @@ +from typing import Optional + +from .feature import Feature + + +class Update(Feature): + @property + def installed_version(self) -> Optional[str]: + return self._product.firmware_version + + @property + def latest_version(self) -> Optional[str]: + return self._product.available_firmware_version + + async def async_update(self) -> None: + # OTA state is stored as Box attributes, not in last_data. async_update_data() won't trigger the OTA check. + await self._product.async_ota_check() + + async def async_install(self) -> None: + await self._product.async_ota_update() + + def after_update(self) -> None: + pass diff --git a/tests/conftest.py b/tests/conftest.py index f1ae061..16a10fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,7 +94,9 @@ def __call__(self, url, **kwargs): data = HTTP_MOCKS[self._key][url] response = _json.dumps(data).encode("utf-8") status = 200 - return AiohttpClientMockResponse("GET", url, status, response) + return AiohttpClientMockResponse( + "GET", url, status, response, headers={"content-type": "application/json"} + ) mock.get = AsyncMock(side_effect=EffectWhenGet(mock)) @@ -123,7 +125,9 @@ def __call__(self, url, **kwargs): data = HTTP_MOCKS[self._key][url][params] response = _json.dumps(data).encode("utf-8") status = 200 - return AiohttpClientMockResponse("POST", url, status, response) + return AiohttpClientMockResponse( + "POST", url, status, response, headers={"content-type": "application/json"} + ) mock.post = AsyncMock(side_effect=EffectWhenPost(mock)) diff --git a/tests/test_box.py b/tests/test_box.py index c132816..69e6f6e 100644 --- a/tests/test_box.py +++ b/tests/test_box.py @@ -1,5 +1,6 @@ import pytest from unittest import mock +from unittest.mock import AsyncMock, patch from blebox_uniapi.box import Box from blebox_uniapi import error from blebox_uniapi.jfollow import follow @@ -167,3 +168,62 @@ async def test_field_validations(mock_session, sample_data, config): error.BadFieldNotRGBW, match=r"foobar.field1 is 123 which is not a rgbw string" ): box.check_rgbw("123", "field1") + + +async def test_available_firmware_version_none(mock_session, sample_data, config): + box = Box(mock_session, sample_data, config, None) + assert box.available_firmware_version is None + + +async def test_available_firmware_version_set(mock_session, sample_data, config): + sample_data["availableFv"] = "2.0" + box = Box(mock_session, sample_data, config, None) + assert box.available_firmware_version == "2.0" + + +async def test_async_ota_update_calls_session(mock_session, sample_data, config): + box = Box(mock_session, sample_data, config, None) + mock_session.async_api_get = AsyncMock(return_value=None) + await box.async_ota_update() + mock_session.async_api_get.assert_called_once_with("/api/ota/update") + + +async def test_async_ota_check_updates_firmware_versions(mock_session, sample_data, config): + box = Box(mock_session, sample_data, config, None) + mock_session.async_api_get_ota = AsyncMock(return_value=None) + mock_session.async_api_get = AsyncMock( + return_value={"availableFv": "2.0", "fv": "1.5"} + ) + with patch("asyncio.sleep", new=AsyncMock()): + await box.async_ota_check() + assert box.available_firmware_version == "2.0" + assert box.firmware_version == "1.5" + + +async def test_async_ota_check_unwraps_device_key(mock_session, sample_data, config): + box = Box(mock_session, sample_data, config, None) + mock_session.async_api_get_ota = AsyncMock(return_value=None) + mock_session.async_api_get = AsyncMock( + return_value={"device": {"availableFv": "3.0", "fv": "2.5"}} + ) + with patch("asyncio.sleep", new=AsyncMock()): + await box.async_ota_check() + assert box.available_firmware_version == "3.0" + + +async def test_async_ota_check_raises_on_none_info_response(mock_session, sample_data, config): + box = Box(mock_session, sample_data, config, None) + mock_session.async_api_get_ota = AsyncMock(return_value=None) + mock_session.async_api_get = AsyncMock(return_value=None) + with patch("asyncio.sleep", new=AsyncMock()): + with pytest.raises(error.UnsupportedBoxResponse): + await box.async_ota_check() + + +async def test_async_ota_check_returns_silently_when_no_available_fv(mock_session, sample_data, config): + box = Box(mock_session, sample_data, config, None) + mock_session.async_api_get_ota = AsyncMock(return_value=None) + mock_session.async_api_get = AsyncMock(return_value={"fv": "1.0"}) + with patch("asyncio.sleep", new=AsyncMock()): + await box.async_ota_check() + assert box.available_firmware_version is None diff --git a/tests/test_session.py b/tests/test_session.py index fbea369..a21d8ad 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -29,6 +29,7 @@ def client(): def valid_response(): response = Mock(spec_set=aiohttp.ClientResponse) response.status = 200 + response.content_type = "application/json" response.text = AsyncMock(return_value="foobar") response.json = AsyncMock(return_value=123) return response @@ -153,3 +154,76 @@ async def test_session_provides_a_logger(logger, client): api_session = Session("127.0.0.4", "88", 2, client, None, logger) api_session.logger.debug("foobar") logger.debug.assert_called_once_with("foobar") + + +def ota_accepted_response(status=202): + response = Mock(spec_set=aiohttp.ClientResponse) + response.status = status + response.content_type = None + return response + + +def unauthorized_response(): + response = Mock(spec_set=aiohttp.ClientResponse) + response.status = 401 + return response + + +def non_json_response(): + response = Mock(spec_set=aiohttp.ClientResponse) + response.status = 200 + response.content_type = None + return response + + +async def test_session_api_get_ota_accepts_202_without_body(logger, client): + client.get = AsyncMock(return_value=ota_accepted_response(status=202)) + api_session = Session("127.0.0.4", "88", 2, client, None, logger) + result = await api_session.async_api_get_ota("/api/ota/check") + assert result is None + + +async def test_session_api_get_ota_accepts_204_no_content(logger, client): + client.get = AsyncMock(return_value=ota_accepted_response(status=204)) + api_session = Session("127.0.0.4", "88", 2, client, None, logger) + result = await api_session.async_api_get_ota("/api/ota/check") + assert result is None + + +async def test_session_api_get_ota_rejects_400(logger, client): + client.get = AsyncMock(return_value=bad_http_response()) + api_session = Session("127.0.0.4", "88", 2, client, None, logger) + with pytest.raises(error.HttpError): + await api_session.async_api_get_ota("/api/ota/check") + + +async def test_session_api_get_unauthorized(logger, client): + client.get = AsyncMock(return_value=unauthorized_response()) + api_session = Session("127.0.0.4", "88", 2, client, None, logger) + with pytest.raises(error.UnauthorizedRequest): + await api_session.async_api_get("/api/device/state") + + +async def test_session_api_get_ota_unauthorized(logger, client): + client.get = AsyncMock(return_value=unauthorized_response()) + api_session = Session("127.0.0.4", "88", 2, client, None, logger) + with pytest.raises(error.UnauthorizedRequest): + await api_session.async_api_get_ota("/api/ota/check") + + +async def test_session_api_get_non_json_returns_none(logger, client): + client.get = AsyncMock(return_value=non_json_response()) + api_session = Session("127.0.0.4", "88", 2, client, None, logger) + result = await api_session.async_api_get("/api/device/state") + assert result is None + + +async def test_session_api_get_unicode_decode_error_raises_connection_error(logger, client): + response = Mock(spec_set=aiohttp.ClientResponse) + response.status = 200 + response.content_type = "application/json" + response.json = AsyncMock(side_effect=UnicodeDecodeError("utf-8", b"", 0, 1, "invalid")) + client.get = AsyncMock(return_value=response) + api_session = Session("127.0.0.4", "88", 2, client, None, logger) + with pytest.raises(error.ConnectionError, match="Invalid response encoding"): + await api_session.async_api_get("/api/device/state") diff --git a/tests/test_update.py b/tests/test_update.py new file mode 100644 index 0000000..53c7fcf --- /dev/null +++ b/tests/test_update.py @@ -0,0 +1,66 @@ +"""Tests for the Update feature.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from blebox_uniapi.update import Update +from blebox_uniapi import error + + +@pytest.fixture +def mock_product(): + product = MagicMock() + product.firmware_version = "0.176" + product.available_firmware_version = None + return product + + +@pytest.fixture +def update(mock_product): + return Update(mock_product, "update", {}) + + +class TestInstalledVersion: + def test_returns_firmware_version(self, update, mock_product): + assert update.installed_version == "0.176" + + def test_returns_none_when_not_set(self, update, mock_product): + mock_product.firmware_version = None + assert update.installed_version is None + + +class TestLatestVersion: + def test_returns_none_when_not_set(self, update, mock_product): + assert update.latest_version is None + + def test_returns_available_firmware_version(self, update, mock_product): + mock_product.available_firmware_version = "1.2.3" + assert update.latest_version == "1.2.3" + + +class TestAsyncUpdate: + async def test_calls_ota_check_on_product(self, update, mock_product): + mock_product.async_ota_check = AsyncMock() + await update.async_update() + mock_product.async_ota_check.assert_called_once() + + async def test_propagates_connection_error_from_ota_check(self, update, mock_product): + mock_product.async_ota_check = AsyncMock( + side_effect=error.ConnectionError("connection refused") + ) + with pytest.raises(error.ConnectionError): + await update.async_update() + + +class TestAsyncInstall: + async def test_calls_ota_update_on_product(self, update, mock_product): + mock_product.async_ota_update = AsyncMock() + await update.async_install() + mock_product.async_ota_update.assert_called_once() + + async def test_propagates_connection_error_from_ota_update(self, update, mock_product): + mock_product.async_ota_update = AsyncMock( + side_effect=error.ConnectionError("connection refused") + ) + with pytest.raises(error.ConnectionError): + await update.async_install()