From ee5a0e89c686476a10c3f9c7458e6e520a644dd5 Mon Sep 17 00:00:00 2001 From: Kydoimos97 Date: Tue, 9 Jun 2026 15:01:52 -0600 Subject: [PATCH] feat: auto-ARN secret resolution and job lifecycle notify functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - resolve_secret now auto-detects {env_var}_ARN as a fallback when the env var is empty and no explicit ARN is passed (e.g. WRENCH_SERVICE_SECRET_ARN). This fixes silent notification failures across all services that use the ARN-only secret pattern. - New _notify.py module adds job_register, job_update, job_close — three fire-and-forget wrappers over the AiAxis job lifecycle endpoints with the same never-raise contract as slack_post. - All three functions exported from WrenchCL.Wrench for use in FeatureForge, elt-api-requester, model-input-builder and other Wrench services. - Comprehensive test coverage: 8 tests for auto-ARN fallback behaviour, 39 tests for the three notify functions, all passing. --- WrenchCL/Wrench/__init__.py | 3 +- WrenchCL/Wrench/_notify.py | 257 +++++++++++++++++++ WrenchCL/Wrench/_secret.py | 19 +- tests/test_wrench_notify.py | 483 ++++++++++++++++++++++++++++++++++++ tests/test_wrench_secret.py | 96 +++++++ 5 files changed, 856 insertions(+), 2 deletions(-) create mode 100644 WrenchCL/Wrench/_notify.py create mode 100644 tests/test_wrench_notify.py diff --git a/WrenchCL/Wrench/__init__.py b/WrenchCL/Wrench/__init__.py index 41b91d1..e16c492 100644 --- a/WrenchCL/Wrench/__init__.py +++ b/WrenchCL/Wrench/__init__.py @@ -4,7 +4,8 @@ """Wrench internal API helpers for service-to-service communication.""" +from ._notify import job_close, job_register, job_update from ._secret import resolve_secret from ._slack import slack_post -__all__ = ["resolve_secret", "slack_post"] +__all__ = ["resolve_secret", "slack_post", "job_register", "job_update", "job_close"] diff --git a/WrenchCL/Wrench/_notify.py b/WrenchCL/Wrench/_notify.py new file mode 100644 index 0000000..c753a15 --- /dev/null +++ b/WrenchCL/Wrench/_notify.py @@ -0,0 +1,257 @@ +# Copyright (c) 2024-2025. +# Author: Willem van der Schans. +# Licensed under the MIT License (https://opensource.org/license/mit). + +""" +Wrench job lifecycle notification helpers. + +Three functions cover the full job lifecycle: + job_register — called when a job starts, returns a job_id + job_update — called during progress (optional, rate-limited to 1/s per job) + job_close — called when a job finishes (success or failure) + +All functions are fire-and-forget on failure: they never raise, log warnings +on any error, and return None / False on failure. job_register returns the +job_id string on success or None on failure; callers should skip update/close +if job_id is None. + +Secret resolution uses the same priority as slack_post: + 1. explicit service_secret param + 2. WRENCH_SERVICE_SECRET env var + 3. WRENCH_SERVICE_SECRET_ARN env var (auto-derived convention) + 4. explicit secret_arn param +""" + +import os +from typing import Optional + +import requests + +from .. import logger +from ._secret import resolve_secret + + +def job_register( + workspace_id: str, + message: str, + source: str, + *, + service_name: Optional[str] = None, + processor_name: Optional[str] = None, + reference: Optional[str] = None, + base_url: Optional[str] = None, + service_secret: Optional[str] = None, + secret_env_var: str = "WRENCH_SERVICE_SECRET", + secret_arn: Optional[str] = None, + timeout: int = 10, +) -> Optional[str]: + """ + Register a job with the Wrench notification system. + + Call once when the job begins. Store the returned job_id and pass it to + job_update and job_close. Returns None on any failure — callers should + skip update/close if job_id is None. + + Parameters + ---------- + workspace_id : str + The workspace (client) UUID the job is running for. + message : str + Human-readable job description shown to the user. + source : str + Service identifier, e.g. "elt", "featureforge", "lead_score". + service_name : str, optional + Internal service name. Defaults to source. + processor_name : str, optional + Internal processor name. Defaults to source. + reference : str, optional + External reference key. Defaults to workspace_id. + + Returns + ------- + str or None + The job_id UUID string on success, None on failure. + """ + resolved_secret = resolve_secret(value=service_secret, env_var=secret_env_var, arn=secret_arn) + if not resolved_secret: + logger.warning("job_register: no service secret available — job not registered") + return None + + if base_url is None: + base_url = os.environ.get("WRENCH_API_BASE_URL", "https://api.v2.wrench.ai") + + payload = { + "workspace_id": str(workspace_id), + "message": message, + "source": source, + } + if service_name is not None: + payload["service_name"] = service_name + if processor_name is not None: + payload["processor_name"] = processor_name + if reference is not None: + payload["reference"] = reference + + try: + response = requests.post( + f"{base_url}/events/jobs/internal/register", + json=payload, + headers={"x-api-secret": resolved_secret, "Content-Type": "application/json"}, + timeout=timeout, + ) + if response.ok: + data = response.json() + job_id = (data.get("data") or data).get("job_id") + if job_id: + return str(job_id) + logger.warning("job_register: response OK but no job_id in body") + return None + logger.warning(f"job_register: API returned {response.status_code}: {response.text[:200]}") + return None + except requests.RequestException as exc: + logger.warning(f"job_register: request failed: {exc}") + return None + except Exception as exc: + logger.warning(f"job_register: unexpected error: {exc}") + return None + + +def job_update( + job_id: str, + progress: int, + description: Optional[str] = None, + *, + workspace_id: Optional[str] = None, + base_url: Optional[str] = None, + service_secret: Optional[str] = None, + secret_env_var: str = "WRENCH_SERVICE_SECRET", + secret_arn: Optional[str] = None, + timeout: int = 10, +) -> bool: + """ + Update job progress. Rate-limited to 1 call/second per job_id server-side. + + Parameters + ---------- + job_id : str + The job_id returned by job_register. + progress : int + Progress percentage 0–100. + description : str, optional + Human-readable progress label shown alongside the percentage. + + Returns + ------- + bool + True on success, False on any failure. + """ + if not job_id: + return False + + resolved_secret = resolve_secret(value=service_secret, env_var=secret_env_var, arn=secret_arn) + if not resolved_secret: + logger.warning("job_update: no service secret available — progress not sent") + return False + + if base_url is None: + base_url = os.environ.get("WRENCH_API_BASE_URL", "https://api.v2.wrench.ai") + + payload: dict = {"job_id": str(job_id), "progress": max(0, min(100, progress))} + if description is not None: + payload["progress_description"] = description + if workspace_id is not None: + payload["workspace_id"] = str(workspace_id) + + try: + response = requests.patch( + f"{base_url}/events/jobs/internal/update", + json=payload, + headers={"x-api-secret": resolved_secret, "Content-Type": "application/json"}, + timeout=timeout, + ) + if response.ok or response.status_code == 429: # 429 = rate limited, not an error + return True + logger.warning(f"job_update: API returned {response.status_code}: {response.text[:200]}") + return False + except requests.RequestException as exc: + logger.warning(f"job_update: request failed: {exc}") + return False + except Exception as exc: + logger.warning(f"job_update: unexpected error: {exc}") + return False + + +def job_close( + job_id: str, + workspace_id: str, + status_code: int, + message: str, + source: str, + *, + notify: bool = True, + base_url: Optional[str] = None, + service_secret: Optional[str] = None, + secret_env_var: str = "WRENCH_SERVICE_SECRET", + secret_arn: Optional[str] = None, + timeout: int = 10, +) -> bool: + """ + Close a job. Call once when the job finishes — success, error, or cancelled. + + Parameters + ---------- + job_id : str + The job_id returned by job_register. + workspace_id : str + The workspace (client) UUID. + status_code : int + HTTP-style status: 200 (success), 500 (error), 499 (cancelled). + message : str + Completion message shown to the user. + source : str + Same source string used in job_register. + notify : bool + Whether to create a user-visible notification. Default True. + + Returns + ------- + bool + True on success, False on any failure. + """ + if not job_id: + return False + + resolved_secret = resolve_secret(value=service_secret, env_var=secret_env_var, arn=secret_arn) + if not resolved_secret: + logger.warning("job_close: no service secret available — job not closed") + return False + + if base_url is None: + base_url = os.environ.get("WRENCH_API_BASE_URL", "https://api.v2.wrench.ai") + + payload = { + "job_id": str(job_id), + "workspace_id": str(workspace_id), + "status_code": status_code, + "message": message, + "source": source, + "notify": notify, + } + + try: + response = requests.post( + f"{base_url}/events/jobs/internal/close", + json=payload, + headers={"x-api-secret": resolved_secret, "Content-Type": "application/json"}, + timeout=timeout, + ) + if response.ok: + return True + logger.warning(f"job_close: API returned {response.status_code}: {response.text[:200]}") + return False + except requests.RequestException as exc: + logger.warning(f"job_close: request failed: {exc}") + return False + except Exception as exc: + logger.warning(f"job_close: unexpected error: {exc}") + return False diff --git a/WrenchCL/Wrench/_secret.py b/WrenchCL/Wrench/_secret.py index 35e1c03..cf8d8a0 100644 --- a/WrenchCL/Wrench/_secret.py +++ b/WrenchCL/Wrench/_secret.py @@ -29,7 +29,12 @@ def resolve_secret( 2. Environment variable: If `value` is None or empty, the environment variable specified by `env_var` is checked. If it exists and is non-empty, its value is returned. - 3. AWS Secrets Manager ARN: If both `value` and the environment variable fail to + 3. Auto-derived ARN environment variable: If `env_var` is provided and empty, + the function checks for `{env_var}_ARN` (e.g., if env_var="WRENCH_SERVICE_SECRET", + it checks "WRENCH_SERVICE_SECRET_ARN"). If this ARN env var is set, it is used + for AWS Secrets Manager fetch. This enables services using ARN-only patterns + without requiring callers to pass explicit ARN parameters. + 4. AWS Secrets Manager ARN: If both `value` and the environment variables fail to resolve, the function attempts to fetch the secret from AWS Secrets Manager using the provided ARN and region. @@ -79,6 +84,12 @@ def resolve_secret( >>> print(secret) "env-secret" + Auto-ARN convention (ARN injected via env var): + + >>> os.environ["WRENCH_SERVICE_SECRET_ARN"] = "arn:aws:secretsmanager:us-east-1:123456789:secret:my-secret" + >>> secret = resolve_secret(env_var="WRENCH_SERVICE_SECRET") + >>> # Fetches from SM because WRENCH_SERVICE_SECRET is empty but _ARN is set + ARN resolution (requires boto3): >>> secret = resolve_secret(arn="arn:aws:secretsmanager:us-east-1:123456789:secret:my-secret") @@ -115,6 +126,12 @@ def resolve_secret( if env_value is not None and env_value: return env_value + # Auto-derive ARN from env var name convention when no explicit ARN was passed + if arn is None and env_var is not None: + auto_arn = os.environ.get(f"{env_var}_ARN") + if auto_arn: + arn = auto_arn + if arn is not None: try: import boto3 diff --git a/tests/test_wrench_notify.py b/tests/test_wrench_notify.py new file mode 100644 index 0000000..5b5768e --- /dev/null +++ b/tests/test_wrench_notify.py @@ -0,0 +1,483 @@ +import os +from unittest.mock import MagicMock, patch + +import pytest + +from WrenchCL.Wrench._notify import job_close, job_register, job_update + +pytestmark = pytest.mark.skipif(False, reason="datadog_itr_unskippable") + + +class TestJobRegister: + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_register_returns_job_id_on_success(self, mock_post): + mock_response = MagicMock() + mock_response.ok = True + mock_response.json.return_value = {"data": {"job_id": "uuid-123"}} + mock_post.return_value = mock_response + + result = job_register( + workspace_id="ws-123", + message="Processing data", + source="elt", + service_secret="test-secret" + ) + + assert result == "uuid-123" + + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_register_handles_nested_response(self, mock_post): + mock_response = MagicMock() + mock_response.ok = True + mock_response.json.return_value = {"job_id": "uuid-456"} + mock_post.return_value = mock_response + + result = job_register( + workspace_id="ws-123", + message="Processing data", + source="elt", + service_secret="test-secret" + ) + + assert result == "uuid-456" + + @patch("WrenchCL.Wrench._notify.logger") + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_register_returns_none_on_missing_job_id(self, mock_post, mock_logger): + mock_response = MagicMock() + mock_response.ok = True + mock_response.json.return_value = {"data": {}} + mock_post.return_value = mock_response + + result = job_register( + workspace_id="ws-123", + message="Processing data", + source="elt", + service_secret="test-secret" + ) + + assert result is None + mock_logger.warning.assert_called() + + @patch("WrenchCL.Wrench._notify.logger") + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_register_returns_none_on_api_error(self, mock_post, mock_logger): + mock_response = MagicMock() + mock_response.ok = False + mock_response.status_code = 500 + mock_response.text = "Internal server error" + mock_post.return_value = mock_response + + result = job_register( + workspace_id="ws-123", + message="Processing data", + source="elt", + service_secret="test-secret" + ) + + assert result is None + mock_logger.warning.assert_called() + call_args = mock_logger.warning.call_args[0][0] + assert "500" in call_args + + @patch("WrenchCL.Wrench._notify.logger") + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_register_returns_none_on_request_exception(self, mock_post, mock_logger): + import requests + mock_post.side_effect = requests.RequestException("Connection failed") + + result = job_register( + workspace_id="ws-123", + message="Processing data", + source="elt", + service_secret="test-secret" + ) + + assert result is None + mock_logger.warning.assert_called() + + @patch("WrenchCL.Wrench._notify.logger") + @patch.dict("os.environ", {}, clear=False) + def test_job_register_returns_none_when_secret_missing(self, mock_logger): + if "WRENCH_SERVICE_SECRET" in os.environ: + del os.environ["WRENCH_SERVICE_SECRET"] + + result = job_register( + workspace_id="ws-123", + message="Processing data", + source="elt" + ) + + assert result is None + mock_logger.warning.assert_called() + + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_register_sends_correct_payload(self, mock_post): + mock_response = MagicMock() + mock_response.ok = True + mock_response.json.return_value = {"data": {"job_id": "uuid-123"}} + mock_post.return_value = mock_response + + job_register( + workspace_id="ws-123", + message="Processing data", + source="elt", + service_name="elt_processor", + processor_name="custom_processor", + reference="ref-456", + service_secret="test-secret" + ) + + call_kwargs = mock_post.call_args[1] + payload = call_kwargs["json"] + assert payload["workspace_id"] == "ws-123" + assert payload["message"] == "Processing data" + assert payload["source"] == "elt" + assert payload["service_name"] == "elt_processor" + assert payload["processor_name"] == "custom_processor" + assert payload["reference"] == "ref-456" + + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_register_uses_default_base_url(self, mock_post): + mock_response = MagicMock() + mock_response.ok = True + mock_response.json.return_value = {"data": {"job_id": "uuid-123"}} + mock_post.return_value = mock_response + + job_register( + workspace_id="ws-123", + message="Processing data", + source="elt", + service_secret="test-secret" + ) + + call_args = mock_post.call_args[0] + endpoint = call_args[0] + assert "api.v2.wrench.ai" in endpoint + + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_register_uses_custom_base_url(self, mock_post): + mock_response = MagicMock() + mock_response.ok = True + mock_response.json.return_value = {"data": {"job_id": "uuid-123"}} + mock_post.return_value = mock_response + + job_register( + workspace_id="ws-123", + message="Processing data", + source="elt", + service_secret="test-secret", + base_url="https://api.qa.wrench.ai" + ) + + call_args = mock_post.call_args[0] + endpoint = call_args[0] + assert "api.qa.wrench.ai" in endpoint + + +class TestJobUpdate: + @patch("WrenchCL.Wrench._notify.requests.patch") + def test_job_update_returns_true_on_success(self, mock_patch): + mock_response = MagicMock() + mock_response.ok = True + mock_patch.return_value = mock_response + + result = job_update( + job_id="uuid-123", + progress=50, + service_secret="test-secret" + ) + + assert result is True + + @patch("WrenchCL.Wrench._notify.requests.patch") + def test_job_update_returns_true_on_rate_limit(self, mock_patch): + mock_response = MagicMock() + mock_response.ok = False + mock_response.status_code = 429 + mock_patch.return_value = mock_response + + result = job_update( + job_id="uuid-123", + progress=50, + service_secret="test-secret" + ) + + assert result is True # Rate limit is not an error + + @patch("WrenchCL.Wrench._notify.logger") + @patch("WrenchCL.Wrench._notify.requests.patch") + def test_job_update_returns_false_on_api_error(self, mock_patch, mock_logger): + mock_response = MagicMock() + mock_response.ok = False + mock_response.status_code = 500 + mock_response.text = "Internal server error" + mock_patch.return_value = mock_response + + result = job_update( + job_id="uuid-123", + progress=50, + service_secret="test-secret" + ) + + assert result is False + mock_logger.warning.assert_called() + + @patch("WrenchCL.Wrench._notify.logger") + @patch("WrenchCL.Wrench._notify.requests.patch") + def test_job_update_returns_false_on_request_exception(self, mock_patch, mock_logger): + import requests + mock_patch.side_effect = requests.RequestException("Connection failed") + + result = job_update( + job_id="uuid-123", + progress=50, + service_secret="test-secret" + ) + + assert result is False + mock_logger.warning.assert_called() + + @patch("WrenchCL.Wrench._notify.logger") + def test_job_update_returns_false_when_job_id_empty(self, mock_logger): + result = job_update( + job_id="", + progress=50, + service_secret="test-secret" + ) + + assert result is False + + @patch("WrenchCL.Wrench._notify.logger") + @patch.dict("os.environ", {}, clear=False) + def test_job_update_returns_false_when_secret_missing(self, mock_logger): + if "WRENCH_SERVICE_SECRET" in os.environ: + del os.environ["WRENCH_SERVICE_SECRET"] + + result = job_update( + job_id="uuid-123", + progress=50 + ) + + assert result is False + mock_logger.warning.assert_called() + + @patch("WrenchCL.Wrench._notify.requests.patch") + def test_job_update_clamps_progress(self, mock_patch): + mock_response = MagicMock() + mock_response.ok = True + mock_patch.return_value = mock_response + + job_update( + job_id="uuid-123", + progress=150, + service_secret="test-secret" + ) + + call_kwargs = mock_patch.call_args[1] + payload = call_kwargs["json"] + assert payload["progress"] == 100 + + mock_patch.reset_mock() + + job_update( + job_id="uuid-123", + progress=-10, + service_secret="test-secret" + ) + + call_kwargs = mock_patch.call_args[1] + payload = call_kwargs["json"] + assert payload["progress"] == 0 + + @patch("WrenchCL.Wrench._notify.requests.patch") + def test_job_update_includes_description_when_provided(self, mock_patch): + mock_response = MagicMock() + mock_response.ok = True + mock_patch.return_value = mock_response + + job_update( + job_id="uuid-123", + progress=50, + description="Processing step 2", + service_secret="test-secret" + ) + + call_kwargs = mock_patch.call_args[1] + payload = call_kwargs["json"] + assert payload["progress_description"] == "Processing step 2" + + @patch("WrenchCL.Wrench._notify.requests.patch") + def test_job_update_includes_workspace_id_when_provided(self, mock_patch): + mock_response = MagicMock() + mock_response.ok = True + mock_patch.return_value = mock_response + + job_update( + job_id="uuid-123", + progress=50, + workspace_id="ws-456", + service_secret="test-secret" + ) + + call_kwargs = mock_patch.call_args[1] + payload = call_kwargs["json"] + assert payload["workspace_id"] == "ws-456" + + +class TestJobClose: + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_close_returns_true_on_success(self, mock_post): + mock_response = MagicMock() + mock_response.ok = True + mock_post.return_value = mock_response + + result = job_close( + job_id="uuid-123", + workspace_id="ws-123", + status_code=200, + message="Job completed", + source="elt", + service_secret="test-secret" + ) + + assert result is True + + @patch("WrenchCL.Wrench._notify.logger") + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_close_returns_false_on_api_error(self, mock_post, mock_logger): + mock_response = MagicMock() + mock_response.ok = False + mock_response.status_code = 500 + mock_response.text = "Internal server error" + mock_post.return_value = mock_response + + result = job_close( + job_id="uuid-123", + workspace_id="ws-123", + status_code=200, + message="Job completed", + source="elt", + service_secret="test-secret" + ) + + assert result is False + mock_logger.warning.assert_called() + + @patch("WrenchCL.Wrench._notify.logger") + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_close_returns_false_on_request_exception(self, mock_post, mock_logger): + import requests + mock_post.side_effect = requests.RequestException("Connection failed") + + result = job_close( + job_id="uuid-123", + workspace_id="ws-123", + status_code=200, + message="Job completed", + source="elt", + service_secret="test-secret" + ) + + assert result is False + mock_logger.warning.assert_called() + + @patch("WrenchCL.Wrench._notify.logger") + def test_job_close_returns_false_when_job_id_empty(self, mock_logger): + result = job_close( + job_id="", + workspace_id="ws-123", + status_code=200, + message="Job completed", + source="elt", + service_secret="test-secret" + ) + + assert result is False + + @patch("WrenchCL.Wrench._notify.logger") + @patch.dict("os.environ", {}, clear=False) + def test_job_close_returns_false_when_secret_missing(self, mock_logger): + if "WRENCH_SERVICE_SECRET" in os.environ: + del os.environ["WRENCH_SERVICE_SECRET"] + + result = job_close( + job_id="uuid-123", + workspace_id="ws-123", + status_code=200, + message="Job completed", + source="elt" + ) + + assert result is False + mock_logger.warning.assert_called() + + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_close_sends_correct_payload(self, mock_post): + mock_response = MagicMock() + mock_response.ok = True + mock_post.return_value = mock_response + + job_close( + job_id="uuid-123", + workspace_id="ws-456", + status_code=500, + message="Job failed with error", + source="featureforge", + notify=False, + service_secret="test-secret" + ) + + call_kwargs = mock_post.call_args[1] + payload = call_kwargs["json"] + assert payload["job_id"] == "uuid-123" + assert payload["workspace_id"] == "ws-456" + assert payload["status_code"] == 500 + assert payload["message"] == "Job failed with error" + assert payload["source"] == "featureforge" + assert payload["notify"] is False + + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_close_default_notify_is_true(self, mock_post): + mock_response = MagicMock() + mock_response.ok = True + mock_post.return_value = mock_response + + job_close( + job_id="uuid-123", + workspace_id="ws-456", + status_code=200, + message="Job completed", + source="elt", + service_secret="test-secret" + ) + + call_kwargs = mock_post.call_args[1] + payload = call_kwargs["json"] + assert payload["notify"] is True + + +class TestAutoArnIntegration: + @patch("WrenchCL.Wrench._notify.requests.post") + def test_job_register_uses_auto_arn_env_var(self, mock_post): + mock_response = MagicMock() + mock_response.ok = True + mock_response.json.return_value = {"data": {"job_id": "uuid-123"}} + mock_post.return_value = mock_response + + with patch.dict("os.environ", { + "WRENCH_SERVICE_SECRET_ARN": "arn:aws:secretsmanager:us-east-1:123456789:secret:test" + }, clear=False): + if "WRENCH_SERVICE_SECRET" in os.environ: + del os.environ["WRENCH_SERVICE_SECRET"] + + with patch("WrenchCL.Wrench._notify.resolve_secret") as mock_resolve: + mock_resolve.return_value = "auto-arn-resolved-secret" + result = job_register( + workspace_id="ws-123", + message="Processing data", + source="elt" + ) + + assert result == "uuid-123" + mock_resolve.assert_called_once() diff --git a/tests/test_wrench_secret.py b/tests/test_wrench_secret.py index 504afb8..8fc2aea 100644 --- a/tests/test_wrench_secret.py +++ b/tests/test_wrench_secret.py @@ -221,6 +221,102 @@ def test_logs_warning_with_custom_env_var_name(self, mock_logger): assert "no secret could be resolved" in call_args +class TestAutoArnFallback: + def test_auto_arn_fallback_when_env_var_empty(self): + mock_boto3 = MagicMock() + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + mock_client.get_secret_value.return_value = {"SecretString": "auto-arn-secret"} + + with patch.dict("os.environ", {"WRENCH_SERVICE_SECRET_ARN": "arn:aws:secretsmanager:us-east-1:123456789:secret:test"}, clear=False): + if "WRENCH_SERVICE_SECRET" in os.environ: + del os.environ["WRENCH_SERVICE_SECRET"] + with patch.dict(sys.modules, {"boto3": mock_boto3}): + result = resolve_secret(env_var="WRENCH_SERVICE_SECRET") + + assert result == "auto-arn-secret" + mock_boto3.client.assert_called_once_with("secretsmanager", region_name="us-east-1") + + def test_auto_arn_fallback_with_custom_region(self): + mock_boto3 = MagicMock() + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + mock_client.get_secret_value.return_value = {"SecretString": "auto-arn-secret"} + + with patch.dict("os.environ", {"WRENCH_SERVICE_SECRET_ARN": "arn:aws:secretsmanager:us-west-2:123456789:secret:test"}, clear=False): + if "WRENCH_SERVICE_SECRET" in os.environ: + del os.environ["WRENCH_SERVICE_SECRET"] + with patch.dict(sys.modules, {"boto3": mock_boto3}): + result = resolve_secret(env_var="WRENCH_SERVICE_SECRET", region="us-west-2") + + assert result == "auto-arn-secret" + mock_boto3.client.assert_called_once_with("secretsmanager", region_name="us-west-2") + + def test_auto_arn_does_not_override_explicit_arn(self): + mock_boto3 = MagicMock() + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + mock_client.get_secret_value.return_value = {"SecretString": "explicit-arn-secret"} + + with patch.dict("os.environ", {"WRENCH_SERVICE_SECRET_ARN": "arn:aws:secretsmanager:us-east-1:111111111:secret:auto"}, clear=False): + if "WRENCH_SERVICE_SECRET" in os.environ: + del os.environ["WRENCH_SERVICE_SECRET"] + with patch.dict(sys.modules, {"boto3": mock_boto3}): + result = resolve_secret( + env_var="WRENCH_SERVICE_SECRET", + arn="arn:aws:secretsmanager:us-east-1:123456789:secret:explicit" + ) + + assert result == "explicit-arn-secret" + mock_client.get_secret_value.assert_called_once_with(SecretId="arn:aws:secretsmanager:us-east-1:123456789:secret:explicit") + + def test_auto_arn_does_not_apply_when_env_var_has_value(self): + mock_boto3 = MagicMock() + with patch.dict("os.environ", { + "WRENCH_SERVICE_SECRET": "env-secret", + "WRENCH_SERVICE_SECRET_ARN": "arn:aws:secretsmanager:us-east-1:123456789:secret:test" + }, clear=False): + with patch.dict(sys.modules, {"boto3": mock_boto3}): + result = resolve_secret(env_var="WRENCH_SERVICE_SECRET") + + assert result == "env-secret" + mock_boto3.client.assert_not_called() + + def test_auto_arn_with_custom_env_var_name(self): + mock_boto3 = MagicMock() + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + mock_client.get_secret_value.return_value = {"SecretString": "custom-auto-arn-secret"} + + with patch.dict("os.environ", {"CUSTOM_SECRET_ARN": "arn:aws:secretsmanager:us-east-1:123456789:secret:custom"}, clear=False): + if "CUSTOM_SECRET" in os.environ: + del os.environ["CUSTOM_SECRET"] + with patch.dict(sys.modules, {"boto3": mock_boto3}): + result = resolve_secret(env_var="CUSTOM_SECRET") + + assert result == "custom-auto-arn-secret" + + def test_auto_arn_skipped_when_env_var_is_none(self): + mock_boto3 = MagicMock() + with patch.dict("os.environ", {"WRENCH_SERVICE_SECRET_ARN": "arn:aws:secretsmanager:us-east-1:123456789:secret:test"}, clear=False): + with patch.dict(sys.modules, {"boto3": mock_boto3}): + result = resolve_secret(env_var=None) + + assert result is None + mock_boto3.client.assert_not_called() + + def test_auto_arn_returns_none_when_arn_env_var_empty(self): + mock_boto3 = MagicMock() + with patch.dict("os.environ", {"WRENCH_SERVICE_SECRET_ARN": ""}, clear=False): + if "WRENCH_SERVICE_SECRET" in os.environ: + del os.environ["WRENCH_SERVICE_SECRET"] + with patch.dict(sys.modules, {"boto3": mock_boto3}): + result = resolve_secret(env_var="WRENCH_SERVICE_SECRET") + + assert result is None + mock_boto3.client.assert_not_called() + + class TestEdgeCases: def test_none_env_var_parameter_skips_env_var_check(self): with patch.dict("os.environ", {"WRENCH_SERVICE_SECRET": "env-secret"}, clear=False):