From cc0aeb7cbce21f96cc0e5e94d16cafb8b83051a1 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 21 Mar 2026 21:58:00 +0000 Subject: [PATCH 1/6] refactor: replace OCI-specific infrastructure with configurable runner backend - Replace OciProvider with RunnerProvider supporting both OCI and K8s backends - Add K8sConfig for Kubernetes-specific settings (namespace, PVC, etc.) - Add RunnerConfig with backend selection ("oci" or "k8s") - Remove runner field from SourceDefinition (now global config) - Add deposition_srn to HookInputs and convention_srn to SourceInputs - Extract shared runner utilities (memory parsing, progress parsing) - Add K8s Job-based runners with security hardening and orphan handling - Update storage adapter for cross-device compatibility (S3 CSI) - Add kubernetes-asyncio as optional dependency with [k8s] extra - Add proper error classification and health checks for K8s API refactor: extract utility functions from OciHookRunner to shared module Move parse_memory, parse_progress_file, and detect_rejection functions to runner_utils module to improve code reusability and testability. Update tests to use extracted functions directly and add missing deposition_srn parameter to HookInputs test instances. --- server/osa/application/di.py | 4 +- server/osa/config.py | 21 +- server/osa/domain/shared/model/source.py | 1 - .../osa/domain/source/port/source_runner.py | 1 + server/osa/domain/source/service/source.py | 1 + .../osa/domain/validation/port/hook_runner.py | 1 + .../domain/validation/service/validation.py | 1 + .../infrastructure/auth/role_repository.py | 5 +- server/osa/infrastructure/k8s/__init__.py | 6 + server/osa/infrastructure/k8s/di.py | 132 +++ server/osa/infrastructure/k8s/errors.py | 29 + server/osa/infrastructure/k8s/health.py | 63 ++ server/osa/infrastructure/k8s/naming.py | 41 + server/osa/infrastructure/k8s/runner.py | 503 ++++++++++ .../osa/infrastructure/k8s/source_runner.py | 428 ++++++++ server/osa/infrastructure/oci/runner.py | 77 +- .../osa/infrastructure/oci/source_runner.py | 60 +- .../persistence/adapter/spreadsheet.py | 4 +- .../persistence/adapter/storage.py | 28 +- server/osa/infrastructure/persistence/di.py | 4 +- .../persistence/repository/auth.py | 7 +- .../persistence/repository/event.py | 5 +- server/osa/infrastructure/runner_utils.py | 100 ++ server/pyproject.toml | 5 + .../domain/validation/test_hook_runner.py | 17 +- .../validation/test_validation_service.py | 1 + .../tests/unit/infrastructure/k8s/__init__.py | 0 .../k8s/test_classify_api_error.py | 46 + .../unit/infrastructure/k8s/test_health.py | 65 ++ .../k8s/test_k8s_hook_runner.py | 913 ++++++++++++++++++ .../k8s/test_k8s_source_runner.py | 423 ++++++++ .../unit/infrastructure/k8s/test_naming.py | 52 + .../persistence/adapter/test_file_storage.py | 8 +- .../infrastructure/test_file_storage_hooks.py | 22 +- .../infrastructure/test_file_storage_move.py | 139 +++ .../infrastructure/test_oci_hook_runner.py | 107 +- server/uv.lock | 30 +- 37 files changed, 3157 insertions(+), 193 deletions(-) create mode 100644 server/osa/infrastructure/k8s/__init__.py create mode 100644 server/osa/infrastructure/k8s/di.py create mode 100644 server/osa/infrastructure/k8s/errors.py create mode 100644 server/osa/infrastructure/k8s/health.py create mode 100644 server/osa/infrastructure/k8s/naming.py create mode 100644 server/osa/infrastructure/k8s/runner.py create mode 100644 server/osa/infrastructure/k8s/source_runner.py create mode 100644 server/osa/infrastructure/runner_utils.py create mode 100644 server/tests/unit/infrastructure/k8s/__init__.py create mode 100644 server/tests/unit/infrastructure/k8s/test_classify_api_error.py create mode 100644 server/tests/unit/infrastructure/k8s/test_health.py create mode 100644 server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py create mode 100644 server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py create mode 100644 server/tests/unit/infrastructure/k8s/test_naming.py create mode 100644 server/tests/unit/infrastructure/test_file_storage_move.py diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 014740d..1a96763 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -15,7 +15,7 @@ from osa.infrastructure.event.di import EventProvider from osa.infrastructure.http.di import HttpProvider from osa.infrastructure.index.di import IndexProvider -from osa.infrastructure.oci import OciProvider +from osa.infrastructure.k8s.di import RunnerProvider from osa.infrastructure.persistence import PersistenceProvider from osa.infrastructure.source.di import SourceProvider from osa.util.di.scope import Scope @@ -42,7 +42,7 @@ def create_container( return make_async_container( PersistenceProvider(), - OciProvider(), + RunnerProvider(), IndexProvider(), SourceProvider(), EventProvider(extra_handlers=extra_handlers), diff --git a/server/osa/config.py b/server/osa/config.py index c5158f3..00cba5a 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -3,7 +3,7 @@ import re import sys from pathlib import Path -from typing import Any +from typing import Any, Literal import yaml from pydantic import BaseModel, field_validator, model_validator @@ -83,6 +83,24 @@ class WorkerConfig(BaseModel): batch_size: int = 100 # Maximum events to fetch per poll cycle +class K8sConfig(BaseModel): + """Kubernetes-specific runner settings, required when runner.backend == "k8s".""" + + namespace: str = "osa" + service_account: str | None = None + data_pvc_name: str = "" + data_mount_path: str = "/data" + image_pull_secrets: list[str] = [] + job_ttl_seconds: int = 300 + + +class RunnerConfig(BaseModel): + """Runner backend selection and Kubernetes configuration.""" + + backend: Literal["oci", "k8s"] = "oci" + k8s: K8sConfig = K8sConfig() + + # ============================================================================= # Authentication Configuration # ============================================================================= @@ -193,6 +211,7 @@ class Config(BaseSettings): logging: LoggingConfig = LoggingConfig() worker: WorkerConfig = WorkerConfig() # Background worker settings auth: AuthConfig # Required - set via OSA_AUTH__JWT__SECRET env var + runner: RunnerConfig = RunnerConfig() host_data_dir: str | None = None # Host path for OSA_DATA_DIR (sibling container mounts) model_config = { diff --git a/server/osa/domain/shared/model/source.py b/server/osa/domain/shared/model/source.py index cf45743..e2e50fc 100644 --- a/server/osa/domain/shared/model/source.py +++ b/server/osa/domain/shared/model/source.py @@ -33,7 +33,6 @@ class SourceDefinition(ValueObject): image: str digest: str - runner: str = "oci" config: dict[str, Any] | None = None limits: SourceLimits = Field(default_factory=SourceLimits) schedule: SourceScheduleConfig | None = None diff --git a/server/osa/domain/source/port/source_runner.py b/server/osa/domain/source/port/source_runner.py index 8e25ddc..2e97858 100644 --- a/server/osa/domain/source/port/source_runner.py +++ b/server/osa/domain/source/port/source_runner.py @@ -14,6 +14,7 @@ class SourceInputs: """Inputs for a source container run.""" + convention_srn: str config: dict[str, Any] | None = None since: datetime | None = None limit: int | None = None diff --git a/server/osa/domain/source/service/source.py b/server/osa/domain/source/service/source.py index b5410b8..ef73660 100644 --- a/server/osa/domain/source/service/source.py +++ b/server/osa/domain/source/service/source.py @@ -71,6 +71,7 @@ async def run_source( # Build inputs inputs = SourceInputs( + convention_srn=str(convention_srn), config=source.config, since=since, limit=limit, diff --git a/server/osa/domain/validation/port/hook_runner.py b/server/osa/domain/validation/port/hook_runner.py index 97429d3..8665cd5 100644 --- a/server/osa/domain/validation/port/hook_runner.py +++ b/server/osa/domain/validation/port/hook_runner.py @@ -15,6 +15,7 @@ class HookInputs: """Inputs to pass to a hook container.""" record_json: dict + deposition_srn: str files_dir: Path | None = None config: dict | None = None diff --git a/server/osa/domain/validation/service/validation.py b/server/osa/domain/validation/service/validation.py index 2cac870..c63220d 100644 --- a/server/osa/domain/validation/service/validation.py +++ b/server/osa/domain/validation/service/validation.py @@ -107,6 +107,7 @@ async def validate_deposition( record_json = {"srn": str(deposition_srn), "metadata": metadata} inputs = HookInputs( record_json=record_json, + deposition_srn=str(deposition_srn), files_dir=Path(files_dir) if files_dir else None, ) diff --git a/server/osa/infrastructure/auth/role_repository.py b/server/osa/infrastructure/auth/role_repository.py index 0dad962..2c45c18 100644 --- a/server/osa/infrastructure/auth/role_repository.py +++ b/server/osa/infrastructure/auth/role_repository.py @@ -2,10 +2,11 @@ from uuid import UUID -from sqlalchemy import delete, insert, select +from sqlalchemy import CursorResult, delete, insert, select from sqlalchemy.ext.asyncio import AsyncSession from osa.domain.auth.model.role import Role +from osa.domain.shared.error import InfrastructureError from osa.domain.auth.model.role_assignment import RoleAssignment, RoleAssignmentId from osa.domain.auth.model.value import UserId from osa.domain.auth.port.role_repository import RoleAssignmentRepository @@ -61,6 +62,8 @@ async def delete(self, user_id: UserId, role: Role) -> bool: ) result = await self.session.execute(stmt) await self.session.flush() + if not isinstance(result, CursorResult): + raise InfrastructureError(f"Expected CursorResult, got {type(result).__name__}") return result.rowcount > 0 async def get(self, user_id: UserId, role: Role) -> RoleAssignment | None: diff --git a/server/osa/infrastructure/k8s/__init__.py b/server/osa/infrastructure/k8s/__init__.py new file mode 100644 index 0000000..2511c39 --- /dev/null +++ b/server/osa/infrastructure/k8s/__init__.py @@ -0,0 +1,6 @@ +"""Kubernetes runner infrastructure. + +kubernetes-asyncio is an optional dependency. Modules that require it +(di.py, runner.py, source_runner.py, health.py) perform lazy imports +and raise ConfigurationError if the package is not installed. +""" diff --git a/server/osa/infrastructure/k8s/di.py b/server/osa/infrastructure/k8s/di.py new file mode 100644 index 0000000..3e3f757 --- /dev/null +++ b/server/osa/infrastructure/k8s/di.py @@ -0,0 +1,132 @@ +"""Dishka DI provider for runner infrastructure (OCI or Kubernetes). + +Uses Dishka's conditional activation (Marker + when=) to register only +the factories needed for the configured backend. When backend is "oci", +only Docker-related factories activate. When "k8s", only K8s factories +activate. No None placeholders, no unused dependencies resolved. +""" + +import logging +from typing import AsyncIterable + +import aiodocker +from dishka import Marker, activate, provide + +from osa.config import Config +from osa.domain.source.port.source_runner import SourceRunner +from osa.domain.validation.port.hook_runner import HookRunner +from osa.infrastructure.oci.runner import OciHookRunner +from osa.infrastructure.oci.source_runner import OciSourceRunner +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + +try: + from kubernetes_asyncio.client import ApiClient +except ImportError: + ApiClient = object # type: ignore[misc,assignment] + +logger = logging.getLogger(__name__) + +K8S = Marker("k8s") + + +class RunnerProvider(Provider): + """Config-driven runner provider. + + Uses Dishka conditional activation: factories decorated with + ``when=K8S`` only activate when the activator returns True + (i.e. ``config.runner.backend == "k8s"``). Undecorated factories + serve as the default OCI path. + """ + + @activate(K8S) + def is_k8s(self, config: Config) -> bool: + return config.runner.backend == "k8s" + + # ------------------------------------------------------------------ + # OCI backend (default — no when= condition) + # ------------------------------------------------------------------ + + @provide(scope=Scope.APP) + async def get_docker(self, config: Config) -> AsyncIterable[aiodocker.Docker]: + docker = aiodocker.Docker() + yield docker + await docker.close() + + @provide(scope=Scope.UOW) + def get_hook_runner_oci( + self, + docker: aiodocker.Docker, + config: Config, + ) -> HookRunner: + return OciHookRunner(docker=docker, host_data_dir=config.host_data_dir) + + @provide(scope=Scope.UOW) + def get_source_runner_oci( + self, + docker: aiodocker.Docker, + config: Config, + ) -> SourceRunner: + return OciSourceRunner(docker=docker, host_data_dir=config.host_data_dir) + + # ------------------------------------------------------------------ + # K8s backend (activated when config.runner.backend == "k8s") + # ------------------------------------------------------------------ + + @provide(when=K8S, scope=Scope.APP) + async def get_k8s_api_client(self, config: Config) -> AsyncIterable[ApiClient]: + from osa.domain.shared.error import ConfigurationError + + try: + import kubernetes_asyncio # noqa: F401 + except ImportError: + raise ConfigurationError( + "kubernetes-asyncio is required for K8s runner. Install with: pip install osa[k8s]" + ) + + from kubernetes_asyncio import client as k8s_client + from kubernetes_asyncio import config as k8s_config + + try: + k8s_config.load_incluster_config() + except k8s_config.ConfigException: + await k8s_config.load_kube_config() + + api_client = k8s_client.ApiClient() + + # Startup health check + from osa.infrastructure.k8s.health import check_k8s_health + + k8s_cfg = config.runner.k8s + batch_api = k8s_client.BatchV1Api(api_client) + core_api = k8s_client.CoreV1Api(api_client) + await check_k8s_health( + batch_api, + core_api, + namespace=k8s_cfg.namespace, + pvc_name=k8s_cfg.data_pvc_name, + ) + + logger.info("K8s API client initialized (namespace=%s)", k8s_cfg.namespace) + yield api_client + await api_client.close() + + @provide(when=K8S, scope=Scope.UOW) + def get_hook_runner_k8s( + self, + k8s_api_client: ApiClient, + config: Config, + ) -> HookRunner: + from osa.infrastructure.k8s.runner import K8sHookRunner + + return K8sHookRunner(api_client=k8s_api_client, config=config.runner.k8s) + + @provide(when=K8S, scope=Scope.UOW) + def get_source_runner_k8s( + self, + k8s_api_client: ApiClient, + config: Config, + ) -> SourceRunner: + from osa.infrastructure.k8s.source_runner import K8sSourceRunner + + return K8sSourceRunner(api_client=k8s_api_client, config=config.runner.k8s) diff --git a/server/osa/infrastructure/k8s/errors.py b/server/osa/infrastructure/k8s/errors.py new file mode 100644 index 0000000..3387478 --- /dev/null +++ b/server/osa/infrastructure/k8s/errors.py @@ -0,0 +1,29 @@ +"""K8s API error classification. + +Maps kubernetes-asyncio ApiException status codes to OSA error types. +""" + +from osa.domain.shared.error import ConfigurationError, InfrastructureError, OSAError + + +def classify_api_error(exc: Exception) -> OSAError: + """Classify a K8s API error by HTTP status code. + + - 403 → ConfigurationError (RBAC misconfiguration, not retried) + - 404 → ConfigurationError (namespace/resource missing, not retried) + - 500, 503 → InfrastructureError (transient, retried by outbox) + - Other → InfrastructureError + """ + status = getattr(exc, "status", 0) + reason = getattr(exc, "reason", str(exc)) + + if status == 403: + return ConfigurationError( + f"K8s RBAC permission denied: {reason}. " + "Check ServiceAccount permissions for the OSA namespace." + ) + if status == 404: + return ConfigurationError( + f"K8s resource not found: {reason}. Check that the namespace and resources exist." + ) + return InfrastructureError(f"K8s API error ({status}): {reason}") diff --git a/server/osa/infrastructure/k8s/health.py b/server/osa/infrastructure/k8s/health.py new file mode 100644 index 0000000..7cdb543 --- /dev/null +++ b/server/osa/infrastructure/k8s/health.py @@ -0,0 +1,63 @@ +"""Startup health check for K8s infrastructure.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from osa.domain.shared.error import ConfigurationError + +if TYPE_CHECKING: + from kubernetes_asyncio.client import BatchV1Api, CoreV1Api + +logger = logging.getLogger(__name__) + + +async def check_k8s_health( + batch_api: BatchV1Api, + core_api: CoreV1Api, + *, + namespace: str, + pvc_name: str, +) -> None: + """Verify K8s infrastructure is ready for running Jobs. + + Checks: + 1. K8s API reachable and namespace exists (list_namespaced_job) + 2. RBAC permissions correct (same call) + 3. Data PVC exists (read_namespaced_persistent_volume_claim) + + Raises ConfigurationError with actionable message on failure. + """ + # Check API reachability, namespace, and RBAC + try: + await batch_api.list_namespaced_job(namespace, limit=1) + except Exception as exc: + status = getattr(exc, "status", None) + if status == 403: + raise ConfigurationError( + f"K8s RBAC permission denied in namespace '{namespace}'. " + "Ensure the ServiceAccount can create/list/delete Jobs." + ) from exc + if status == 404: + raise ConfigurationError( + f"K8s namespace '{namespace}' not found. " + "Create the namespace or update OSA_RUNNER__K8S__NAMESPACE." + ) from exc + raise ConfigurationError( + f"K8s API unreachable: {exc}. Check cluster connectivity and kubeconfig." + ) from exc + + # Check PVC existence + try: + await core_api.read_namespaced_persistent_volume_claim(pvc_name, namespace) + except Exception as exc: + status = getattr(exc, "status", None) + if status == 404: + raise ConfigurationError( + f"PVC '{pvc_name}' not found in namespace '{namespace}'. " + "Create the PVC or update OSA_RUNNER__K8S__DATA_PVC_NAME." + ) from exc + raise ConfigurationError(f"Failed to verify PVC '{pvc_name}': {exc}") from exc + + logger.info("K8s health check passed: namespace=%s, pvc=%s", namespace, pvc_name) diff --git a/server/osa/infrastructure/k8s/naming.py b/server/osa/infrastructure/k8s/naming.py new file mode 100644 index 0000000..df5ad43 --- /dev/null +++ b/server/osa/infrastructure/k8s/naming.py @@ -0,0 +1,41 @@ +"""SRN-to-Job-name sanitization for DNS-1035 compliance.""" + +import re +import secrets + + +def job_name(prefix: str, hook_name: str, deposition_srn: str) -> str: + """Generate a K8s Job name from prefix, hook name, and deposition SRN. + + Output conforms to DNS-1035: lowercase alphanumeric + hyphens, + starts with a letter, max 63 characters. A 4-char random suffix + ensures uniqueness. + + Examples: + job_name("hook", "validate-dna", "urn:osa:localhost:dep:abc123") + → "osa-hook-validate-dna-abc123-x7k2" + """ + suffix = secrets.token_hex(2) # 4 hex chars + + # Extract the ID fragment from the SRN (last component) + srn_parts = deposition_srn.split(":") + dep_fragment = srn_parts[-1] if srn_parts else deposition_srn + + raw = f"osa-{prefix}-{hook_name}-{dep_fragment}-{suffix}" + + # Sanitize: lowercase, replace non-DNS chars with hyphens + sanitized = raw.lower() + sanitized = re.sub(r"[^a-z0-9-]", "-", sanitized) + # Collapse multiple hyphens + sanitized = re.sub(r"-+", "-", sanitized) + # Strip leading/trailing hyphens + sanitized = sanitized.strip("-") + + # Ensure starts with a letter + if sanitized and not sanitized[0].isalpha(): + sanitized = "osa-" + sanitized + + # Truncate to 63 chars, strip trailing hyphen after truncation + sanitized = sanitized[:63].rstrip("-") + + return sanitized diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py new file mode 100644 index 0000000..d9e4e84 --- /dev/null +++ b/server/osa/infrastructure/k8s/runner.py @@ -0,0 +1,503 @@ +"""Kubernetes Job-based hook runner.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +from pathlib import Path +from typing import TYPE_CHECKING + +from osa.config import K8sConfig +from osa.domain.shared.error import InfrastructureError +from osa.domain.shared.model.hook import HookDefinition +from osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.port.hook_runner import HookInputs, HookRunner +from osa.infrastructure.k8s.errors import classify_api_error +from osa.infrastructure.k8s.naming import job_name +from osa.infrastructure.runner_utils import detect_rejection, parse_progress_file + +if TYPE_CHECKING: + from kubernetes_asyncio.client import ApiClient, BatchV1Api, CoreV1Api, V1Job + +logger = logging.getLogger(__name__) + +SCHEDULING_TIMEOUT = 120 # seconds to wait for pod to leave Pending + + +class K8sHookRunner(HookRunner): + """Executes hooks as Kubernetes Jobs. + + Mirrors OciHookRunner's security posture using K8s-native equivalents: + - Network isolation via dnsPolicy=None + NetworkPolicy labels + - Read-only rootfs, dropped capabilities, non-root user + - Resource limits via K8s resources.limits + - Timeout via activeDeadlineSeconds + """ + + def __init__(self, api_client: ApiClient, config: K8sConfig) -> None: + self._api_client = api_client + self._config = config + + async def run( + self, + hook: HookDefinition, + inputs: HookInputs, + work_dir: Path, + ) -> HookResult: + try: + from kubernetes_asyncio.client import BatchV1Api, CoreV1Api + except ImportError: + from osa.domain.shared.error import ConfigurationError + + raise ConfigurationError( + "kubernetes-asyncio is required for K8s runner. Install with: pip install osa[k8s]" + ) + + batch_api = BatchV1Api(self._api_client) + core_api = CoreV1Api(self._api_client) + + # Write input files + input_dir = work_dir / "input" + input_dir.mkdir(parents=True, exist_ok=True) + output_dir = work_dir / "output" + output_dir.mkdir(parents=True, exist_ok=True) + + (input_dir / "record.json").write_text(json.dumps(inputs.record_json)) + if inputs.config or hook.runtime.config: + config = {**hook.runtime.config, **(inputs.config or {})} + (input_dir / "config.json").write_text(json.dumps(config)) + + return await self._run_job( + batch_api, + core_api, + hook, + inputs, + work_dir, + deposition_srn=inputs.deposition_srn, + ) + + async def _run_job( + self, + batch_api: BatchV1Api, + core_api: CoreV1Api, + hook: HookDefinition, + inputs: HookInputs, + work_dir: Path, + *, + deposition_srn: str = "", + ) -> HookResult: + """Core Job lifecycle: check orphans → create → schedule → execute → parse → cleanup.""" + namespace = self._config.namespace + start_time = time.monotonic() + + # Check for existing Jobs (orphan handling) + job_name_to_watch = None + + try: + existing = await self._check_existing_job( + batch_api, namespace, hook.name, deposition_srn + ) + + if existing == "succeeded": + # Read output from completed Job + return self._parse_hook_result(hook, work_dir, start_time) + + if existing and existing.startswith("active:"): + # Attach to running Job + job_name_to_watch = existing.split(":", 1)[1] + else: + # Create new Job (no existing or failed) + spec = self._build_job_spec( + hook, + work_dir, + deposition_srn=deposition_srn, + files_dir=inputs.files_dir, + ) + job_name_to_watch = spec.metadata.name + + await batch_api.create_namespaced_job(namespace, spec) + logger.info( + "Created K8s Job", + extra={ + "job_name": job_name_to_watch, + "namespace": namespace, + "image": f"{hook.runtime.image}@{hook.runtime.digest}", + "hook_name": hook.name, + "deposition_srn": deposition_srn, + }, + ) + + # Phase 1: Wait for scheduling + await self._wait_for_scheduling(core_api, job_name_to_watch, namespace) + + # Phase 2: Wait for completion + result = await self._wait_for_completion( + batch_api, + core_api, + job_name_to_watch, + namespace, + timeout_seconds=hook.runtime.limits.timeout_seconds + 30, + ) + + if result == "succeeded": + return self._parse_hook_result(hook, work_dir, start_time) + + # Job failed — determine why + return await self._diagnose_failure( + core_api, job_name_to_watch, namespace, hook, start_time, result + ) + + finally: + if job_name_to_watch: + await self._cleanup_job(batch_api, job_name_to_watch, namespace) + + def _parse_hook_result( + self, hook: HookDefinition, work_dir: Path, start_time: float + ) -> HookResult: + """Parse output from a completed Job.""" + output_dir = work_dir / "output" + progress = parse_progress_file(output_dir) + duration = time.monotonic() - start_time + + rejected, reason = detect_rejection(progress) + if rejected: + return HookResult( + hook_name=hook.name, + status=HookStatus.REJECTED, + rejection_reason=reason, + progress=progress, + duration_seconds=duration, + ) + + return HookResult( + hook_name=hook.name, + status=HookStatus.PASSED, + progress=progress, + duration_seconds=duration, + ) + + async def _check_existing_job( + self, + batch_api: BatchV1Api, + namespace: str, + hook_name: str, + deposition_srn: str, + ) -> str | None: + """Check for existing Jobs with matching labels. + + Returns: + "succeeded" if a completed Job exists + "active:{job_name}" if a running Job exists + None if no Job or only failed Jobs exist + """ + label_selector = f"osa.io/hook={hook_name},osa.io/deposition={deposition_srn}" + try: + job_list = await batch_api.list_namespaced_job(namespace, label_selector=label_selector) + except Exception as exc: + raise classify_api_error(exc) from exc + + for job in job_list.items: + if job.status.succeeded: + return "succeeded" + if job.status.active: + return f"active:{job.metadata.name}" + + return None + + def _build_job_spec( + self, + hook: HookDefinition, + work_dir: Path, + *, + deposition_srn: str = "", + files_dir: Path | None = None, + ) -> V1Job: + """Build a K8s Job manifest for a hook execution.""" + from kubernetes_asyncio.client import ( + V1Capabilities, + V1Container, + V1EmptyDirVolumeSource, + V1EnvVar, + V1Job, + V1JobSpec, + V1LocalObjectReference, + V1ObjectMeta, + V1PersistentVolumeClaimVolumeSource, + V1PodDNSConfig, + V1PodSecurityContext, + V1PodSpec, + V1PodTemplateSpec, + V1ResourceRequirements, + V1SecurityContext, + V1Volume, + V1VolumeMount, + ) + + name = job_name("hook", hook.name, deposition_srn) + relative_work = self._relative_path(work_dir) + input_subpath = f"{relative_work}/input" + output_subpath = f"{relative_work}/output" + + labels = { + "osa.io/role": "hook", + "osa.io/hook": hook.name, + "osa.io/deposition": deposition_srn, + } + + mounts = [ + V1VolumeMount( + name="data", mount_path="/osa/in", sub_path=input_subpath, read_only=True + ), + V1VolumeMount( + name="data", mount_path="/osa/out", sub_path=output_subpath, read_only=False + ), + V1VolumeMount(name="tmp", mount_path="/tmp"), + ] + + if files_dir: + relative_files = self._relative_path(files_dir) + mounts.append( + V1VolumeMount( + name="data", mount_path="/osa/in/files", sub_path=relative_files, read_only=True + ) + ) + + volumes = [ + V1Volume( + name="data", + persistent_volume_claim=V1PersistentVolumeClaimVolumeSource( + claim_name=self._config.data_pvc_name + ), + ), + V1Volume(name="tmp", empty_dir=V1EmptyDirVolumeSource(size_limit="512Mi")), + ] + + container = V1Container( + name="hook", + image=f"{hook.runtime.image}@{hook.runtime.digest}", + env=[ + V1EnvVar(name="OSA_IN", value="/osa/in"), + V1EnvVar(name="OSA_OUT", value="/osa/out"), + V1EnvVar(name="OSA_HOOK_NAME", value=hook.name), + ], + resources=V1ResourceRequirements( + limits={"memory": hook.runtime.limits.memory, "cpu": hook.runtime.limits.cpu}, + ), + security_context=V1SecurityContext( + read_only_root_filesystem=True, + capabilities=V1Capabilities(drop=["ALL"]), + allow_privilege_escalation=False, + run_as_user=65534, + run_as_group=65534, + ), + volume_mounts=mounts, + ) + + pod_spec = V1PodSpec( + restart_policy="Never", + automount_service_account_token=False, + security_context=V1PodSecurityContext(run_as_non_root=True), + dns_policy="None", + dns_config=V1PodDNSConfig(nameservers=[]), + containers=[container], + volumes=volumes, + image_pull_secrets=[ + V1LocalObjectReference(name=s) for s in self._config.image_pull_secrets + ] + or None, + service_account_name=self._config.service_account, + ) + + return V1Job( + api_version="batch/v1", + kind="Job", + metadata=V1ObjectMeta(name=name, namespace=self._config.namespace, labels=labels), + spec=V1JobSpec( + backoff_limit=0, + active_deadline_seconds=SCHEDULING_TIMEOUT + hook.runtime.limits.timeout_seconds, + ttl_seconds_after_finished=self._config.job_ttl_seconds, + template=V1PodTemplateSpec( + metadata=V1ObjectMeta(labels=labels), + spec=pod_spec, + ), + ), + ) + + def _relative_path(self, path: Path) -> str: + """Strip the data mount prefix to get a PVC-relative subpath.""" + mount = self._config.data_mount_path.rstrip("/") + path_str = str(path) + if not path_str.startswith(mount): + raise ValueError(f"Path {path} is outside the data mount prefix {mount}") + return path_str[len(mount) :].lstrip("/") + + async def _wait_for_scheduling( + self, + core_api: CoreV1Api, + job_name: str, + namespace: str, + *, + timeout_seconds: float = SCHEDULING_TIMEOUT, + poll_interval: float = 2.0, + ) -> None: + """Wait for the Job's pod to leave Pending (Phase 1).""" + deadline = time.monotonic() + timeout_seconds + label_selector = f"job-name={job_name}" + + while time.monotonic() < deadline: + try: + pod_list = await core_api.list_namespaced_pod( + namespace, label_selector=label_selector + ) + except Exception as exc: + raise classify_api_error(exc) from exc + + for pod in pod_list.items: + phase = pod.status.phase + + # Check for eviction + if phase == "Failed": + reason = getattr(pod.status, "reason", None) or "Unknown" + raise InfrastructureError(f"Pod evicted or failed during scheduling: {reason}") + + # Check for image pull errors + if phase == "Pending" and pod.status.container_statuses: + for cs in pod.status.container_statuses: + waiting = getattr(cs.state, "waiting", None) + if waiting and waiting.reason in ("ImagePullBackOff", "ErrImagePull"): + message = getattr(waiting, "message", "") + raise InfrastructureError( + f"Image pull failed: {waiting.reason}: {message}" + ) + + if phase in ("Running", "Succeeded", "Failed"): + return # Pod scheduled + + await asyncio.sleep(poll_interval) + + raise InfrastructureError( + f"Pod scheduling timeout after {timeout_seconds}s for Job {job_name}" + ) + + async def _wait_for_completion( + self, + batch_api: BatchV1Api, + core_api: CoreV1Api, + job_name: str, + namespace: str, + *, + timeout_seconds: float = 330, + poll_interval: float = 5.0, + ) -> str: + """Wait for Job to complete (Phase 2). Returns 'succeeded' or 'failed'.""" + deadline = time.monotonic() + timeout_seconds + + while time.monotonic() < deadline: + try: + job = await batch_api.read_namespaced_job(job_name, namespace) + except Exception as exc: + raise classify_api_error(exc) from exc + + if job.status.succeeded: + return "succeeded" + + if job.status.conditions: + for condition in job.status.conditions: + if condition.type == "Failed" and condition.status == "True": + return f"failed:{getattr(condition, 'reason', 'Unknown')}" + if condition.type == "Complete" and condition.status == "True": + return "succeeded" + + if job.status.failed: + return "failed:BackoffLimitExceeded" + + await asyncio.sleep(poll_interval) + + # Timed out — poll once more + try: + job = await batch_api.read_namespaced_job(job_name, namespace) + if job.status.succeeded: + return "succeeded" + except Exception: + pass + + return "failed:WatchTimeout" + + async def _diagnose_failure( + self, + core_api: CoreV1Api, + job_name: str, + namespace: str, + hook: HookDefinition, + start_time: float, + failure_info: str, + ) -> HookResult: + """Determine failure reason from pod status.""" + duration = time.monotonic() - start_time + + # Check if DeadlineExceeded + if "DeadlineExceeded" in failure_info: + return HookResult( + hook_name=hook.name, + status=HookStatus.FAILED, + error_message="Hook timed out (deadline exceeded)", + duration_seconds=duration, + ) + + # Check pod for OOM or exit code + try: + label_selector = f"job-name={job_name}" + pod_list = await core_api.list_namespaced_pod(namespace, label_selector=label_selector) + for pod in pod_list.items: + if pod.status.container_statuses: + for cs in pod.status.container_statuses: + terminated = getattr(cs.state, "terminated", None) + if terminated: + if getattr(terminated, "reason", None) == "OOMKilled": + return HookResult( + hook_name=hook.name, + status=HookStatus.FAILED, + error_message="Hook killed by OOM", + duration_seconds=duration, + ) + exit_code = getattr(terminated, "exit_code", -1) + if exit_code != 0: + return HookResult( + hook_name=hook.name, + status=HookStatus.FAILED, + error_message=f"Hook exited with code {exit_code}", + duration_seconds=duration, + ) + except Exception: + pass + + return HookResult( + hook_name=hook.name, + status=HookStatus.FAILED, + error_message=f"Hook failed: {failure_info}", + duration_seconds=duration, + ) + + async def _cleanup_job( + self, + batch_api: BatchV1Api, + job_name: str, + namespace: str, + ) -> None: + """Delete a Job and its pods. Ignores 404 (already cleaned up).""" + try: + await batch_api.delete_namespaced_job( + job_name, + namespace, + propagation_policy="Background", + ) + logger.info("Cleaned up K8s Job", extra={"job_name": job_name}) + except Exception as exc: + if getattr(exc, "status", None) == 404: + return # Already gone + logger.warning( + "Failed to clean up K8s Job", + extra={"job_name": job_name, "error": str(exc)}, + ) diff --git a/server/osa/infrastructure/k8s/source_runner.py b/server/osa/infrastructure/k8s/source_runner.py new file mode 100644 index 0000000..bd40240 --- /dev/null +++ b/server/osa/infrastructure/k8s/source_runner.py @@ -0,0 +1,428 @@ +"""Kubernetes Job-based source runner.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +from pathlib import Path +from typing import TYPE_CHECKING + +from osa.config import K8sConfig +from osa.domain.shared.error import ExternalServiceError, InfrastructureError +from osa.domain.shared.model.source import SourceDefinition +from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner +from osa.infrastructure.k8s.errors import classify_api_error +from osa.infrastructure.k8s.naming import job_name +from osa.infrastructure.runner_utils import parse_records_file, parse_session_file + +if TYPE_CHECKING: + from kubernetes_asyncio.client import ApiClient, BatchV1Api, CoreV1Api, V1Job + +logger = logging.getLogger(__name__) + +SCHEDULING_TIMEOUT = 120 + + +class K8sSourceRunner(SourceRunner): + """Executes sources as Kubernetes Jobs. + + Key differences from K8sHookRunner: + - Network enabled (normal DNS, no dnsPolicy override) + - Writable rootfs (no readOnlyRootFilesystem) + - Three volume mounts: input (ro), output (rw), files (rw) + - Higher resource defaults (3600s, 4g) + - Source-specific env vars (OSA_FILES, OSA_SINCE, etc.) + - Errors raise ExternalServiceError (not returned as result values) + """ + + def __init__(self, api_client: ApiClient, config: K8sConfig) -> None: + self._api_client = api_client + self._config = config + + async def run( + self, + source: SourceDefinition, + inputs: SourceInputs, + files_dir: Path, + work_dir: Path, + ) -> SourceOutput: + try: + from kubernetes_asyncio.client import BatchV1Api, CoreV1Api + except ImportError: + from osa.domain.shared.error import ConfigurationError + + raise ConfigurationError( + "kubernetes-asyncio is required for K8s runner. Install with: pip install osa[k8s]" + ) + + batch_api = BatchV1Api(self._api_client) + core_api = CoreV1Api(self._api_client) + + # Write input files + input_dir = work_dir / "input" + input_dir.mkdir(parents=True, exist_ok=True) + output_dir = work_dir / "output" + output_dir.mkdir(parents=True, exist_ok=True) + files_dir.mkdir(parents=True, exist_ok=True) + + if inputs.config or source.config: + config = {**(source.config or {}), **(inputs.config or {})} + (input_dir / "config.json").write_text(json.dumps(config)) + + if inputs.session: + (input_dir / "session.json").write_text(json.dumps(inputs.session)) + + return await self._run_job( + batch_api, + core_api, + source, + inputs, + work_dir, + files_dir, + convention_srn=inputs.convention_srn, + ) + + async def _run_job( + self, + batch_api: BatchV1Api, + core_api: CoreV1Api, + source: SourceDefinition, + inputs: SourceInputs, + work_dir: Path, + files_dir: Path, + *, + convention_srn: str = "", + ) -> SourceOutput: + """Core Job lifecycle for source execution.""" + namespace = self._config.namespace + job_name_to_watch = None + + try: + # Check for existing Jobs + existing = await self._check_existing_job(batch_api, namespace, convention_srn) + + if existing == "succeeded": + return self._parse_source_output(work_dir, files_dir) + + if existing and existing.startswith("active:"): + job_name_to_watch = existing.split(":", 1)[1] + else: + spec = self._build_job_spec( + source, + work_dir=work_dir, + files_dir=files_dir, + inputs=inputs, + convention_srn=convention_srn, + ) + job_name_to_watch = spec.metadata.name + + await batch_api.create_namespaced_job(namespace, spec) + logger.info( + "Created K8s source Job", + extra={ + "job_name": job_name_to_watch, + "namespace": namespace, + "image": f"{source.image}@{source.digest}", + }, + ) + + # Phase 1: Scheduling + await self._wait_for_scheduling(core_api, job_name_to_watch, namespace) + + # Phase 2: Completion + result = await self._wait_for_completion( + batch_api, + core_api, + job_name_to_watch, + namespace, + timeout_seconds=source.limits.timeout_seconds + 30, + ) + + if result == "succeeded": + output = self._parse_source_output(work_dir, files_dir) + logger.info( + "Source completed", + extra={ + "job_name": job_name_to_watch, + "record_count": len(output.records), + "has_session": output.session is not None, + }, + ) + return output + + # Failed — diagnose and raise + await self._diagnose_and_raise(core_api, job_name_to_watch, namespace, source, result) + # unreachable but satisfies type checker + raise ExternalServiceError("Source failed") + + finally: + if job_name_to_watch: + await self._cleanup_job(batch_api, job_name_to_watch, namespace) + + def _parse_source_output(self, work_dir: Path, files_dir: Path) -> SourceOutput: + output_dir = work_dir / "output" + records = parse_records_file(output_dir) + session = parse_session_file(output_dir) + return SourceOutput(records=records, session=session, files_dir=files_dir) + + async def _check_existing_job( + self, batch_api: BatchV1Api, namespace: str, convention_srn: str + ) -> str | None: + label_parts = ["osa.io/role=source"] + if convention_srn: + label_parts.append(f"osa.io/convention={convention_srn}") + label_selector = ",".join(label_parts) + + try: + job_list = await batch_api.list_namespaced_job(namespace, label_selector=label_selector) + except Exception as exc: + raise classify_api_error(exc) from exc + + for job in job_list.items: + if job.status.succeeded: + return "succeeded" + if job.status.active: + return f"active:{job.metadata.name}" + return None + + def _build_job_spec( + self, + source: SourceDefinition, + *, + work_dir: Path, + files_dir: Path, + inputs: SourceInputs | None = None, + convention_srn: str = "", + ) -> V1Job: + from kubernetes_asyncio.client import ( + V1Capabilities, + V1Container, + V1EnvVar, + V1Job, + V1JobSpec, + V1LocalObjectReference, + V1ObjectMeta, + V1PersistentVolumeClaimVolumeSource, + V1PodSecurityContext, + V1PodSpec, + V1PodTemplateSpec, + V1ResourceRequirements, + V1SecurityContext, + V1Volume, + V1VolumeMount, + ) + + name = job_name("source", "src", convention_srn or "unknown") + relative_work = self._relative_path(work_dir) + input_subpath = f"{relative_work}/input" + output_subpath = f"{relative_work}/output" + relative_files = self._relative_path(files_dir) + + labels: dict[str, str] = { + "osa.io/role": "source", + } + if convention_srn: + labels["osa.io/convention"] = convention_srn + + env = [ + V1EnvVar(name="OSA_IN", value="/osa/in"), + V1EnvVar(name="OSA_OUT", value="/osa/out"), + V1EnvVar(name="OSA_FILES", value="/osa/files"), + ] + if inputs: + if inputs.since is not None: + env.append(V1EnvVar(name="OSA_SINCE", value=inputs.since.isoformat())) + if inputs.limit is not None: + env.append(V1EnvVar(name="OSA_LIMIT", value=str(inputs.limit))) + if inputs.offset: + env.append(V1EnvVar(name="OSA_OFFSET", value=str(inputs.offset))) + + mounts = [ + V1VolumeMount( + name="data", mount_path="/osa/in", sub_path=input_subpath, read_only=True + ), + V1VolumeMount(name="data", mount_path="/osa/out", sub_path=output_subpath), + V1VolumeMount(name="data", mount_path="/osa/files", sub_path=relative_files), + ] + + volumes = [ + V1Volume( + name="data", + persistent_volume_claim=V1PersistentVolumeClaimVolumeSource( + claim_name=self._config.data_pvc_name + ), + ), + ] + + container = V1Container( + name="source", + image=f"{source.image}@{source.digest}", + env=env, + resources=V1ResourceRequirements( + limits={"memory": source.limits.memory, "cpu": source.limits.cpu}, + ), + security_context=V1SecurityContext( + capabilities=V1Capabilities(drop=["ALL"]), + allow_privilege_escalation=False, + run_as_user=65534, + run_as_group=65534, + ), + volume_mounts=mounts, + ) + + pod_spec = V1PodSpec( + restart_policy="Never", + automount_service_account_token=False, + security_context=V1PodSecurityContext(run_as_non_root=True), + containers=[container], + volumes=volumes, + image_pull_secrets=[ + V1LocalObjectReference(name=s) for s in self._config.image_pull_secrets + ] + or None, + service_account_name=self._config.service_account, + ) + + return V1Job( + api_version="batch/v1", + kind="Job", + metadata=V1ObjectMeta(name=name, namespace=self._config.namespace, labels=labels), + spec=V1JobSpec( + backoff_limit=0, + active_deadline_seconds=SCHEDULING_TIMEOUT + source.limits.timeout_seconds, + ttl_seconds_after_finished=self._config.job_ttl_seconds, + template=V1PodTemplateSpec( + metadata=V1ObjectMeta(labels=labels), + spec=pod_spec, + ), + ), + ) + + def _relative_path(self, path: Path) -> str: + mount = self._config.data_mount_path.rstrip("/") + path_str = str(path) + if not path_str.startswith(mount): + raise ValueError(f"Path {path} is outside the data mount prefix {mount}") + return path_str[len(mount) :].lstrip("/") + + async def _wait_for_scheduling( + self, + core_api: CoreV1Api, + job_name: str, + namespace: str, + *, + timeout_seconds: float = SCHEDULING_TIMEOUT, + poll_interval: float = 2.0, + ) -> None: + deadline = time.monotonic() + timeout_seconds + label_selector = f"job-name={job_name}" + + while time.monotonic() < deadline: + try: + pod_list = await core_api.list_namespaced_pod( + namespace, label_selector=label_selector + ) + except Exception as exc: + raise classify_api_error(exc) from exc + + for pod in pod_list.items: + phase = pod.status.phase + if phase == "Failed": + reason = getattr(pod.status, "reason", None) or "Unknown" + raise InfrastructureError(f"Pod failed during scheduling: {reason}") + + if phase == "Pending" and pod.status.container_statuses: + for cs in pod.status.container_statuses: + waiting = getattr(cs.state, "waiting", None) + if waiting and waiting.reason in ("ImagePullBackOff", "ErrImagePull"): + raise InfrastructureError( + f"Image pull failed: {waiting.reason}: {getattr(waiting, 'message', '')}" + ) + + if phase in ("Running", "Succeeded", "Failed"): + return + + await asyncio.sleep(poll_interval) + + raise InfrastructureError( + f"Pod scheduling timeout after {timeout_seconds}s for Job {job_name}" + ) + + async def _wait_for_completion( + self, + batch_api: BatchV1Api, + core_api: CoreV1Api, + job_name: str, + namespace: str, + *, + timeout_seconds: float = 3630, + poll_interval: float = 5.0, + ) -> str: + deadline = time.monotonic() + timeout_seconds + + while time.monotonic() < deadline: + try: + job = await batch_api.read_namespaced_job(job_name, namespace) + except Exception as exc: + raise classify_api_error(exc) from exc + + if job.status.succeeded: + return "succeeded" + if job.status.conditions: + for condition in job.status.conditions: + if condition.type == "Failed" and condition.status == "True": + return f"failed:{getattr(condition, 'reason', 'Unknown')}" + if condition.type == "Complete" and condition.status == "True": + return "succeeded" + if job.status.failed: + return "failed:BackoffLimitExceeded" + + await asyncio.sleep(poll_interval) + + return "failed:WatchTimeout" + + async def _diagnose_and_raise( + self, + core_api: CoreV1Api, + job_name: str, + namespace: str, + source: SourceDefinition, + failure_info: str, + ) -> None: + """Determine failure reason and raise appropriate error.""" + if "DeadlineExceeded" in failure_info: + raise ExternalServiceError(f"Source timed out after {source.limits.timeout_seconds}s") + + try: + label_selector = f"job-name={job_name}" + pod_list = await core_api.list_namespaced_pod(namespace, label_selector=label_selector) + for pod in pod_list.items: + if pod.status.container_statuses: + for cs in pod.status.container_statuses: + terminated = getattr(cs.state, "terminated", None) + if terminated: + if getattr(terminated, "reason", None) == "OOMKilled": + raise ExternalServiceError("Source killed by OOM") + exit_code = getattr(terminated, "exit_code", -1) + if exit_code != 0: + raise ExternalServiceError(f"Source exited with code {exit_code}") + except ExternalServiceError: + raise + except Exception: + pass + + raise ExternalServiceError(f"Source failed: {failure_info}") + + async def _cleanup_job(self, batch_api: BatchV1Api, job_name: str, namespace: str) -> None: + try: + await batch_api.delete_namespaced_job( + job_name, + namespace, + propagation_policy="Background", + ) + except Exception as exc: + if getattr(exc, "status", None) == 404: + return + logger.warning("Failed to clean up K8s source Job", extra={"job_name": job_name}) diff --git a/server/osa/infrastructure/oci/runner.py b/server/osa/infrastructure/oci/runner.py index 78e3c20..d51ac68 100644 --- a/server/osa/infrastructure/oci/runner.py +++ b/server/osa/infrastructure/oci/runner.py @@ -3,7 +3,6 @@ import asyncio import json import os -import re import stat import time from pathlib import Path @@ -13,8 +12,13 @@ import logfire from osa.domain.shared.model.hook import HookDefinition -from osa.domain.validation.model.hook_result import HookResult, HookStatus, ProgressEntry +from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner +from osa.infrastructure.runner_utils import ( + detect_rejection, + parse_memory, + parse_progress_file, +) def _force_remove(func, path, exc): @@ -83,7 +87,12 @@ async def _resolve_and_run(): ) except asyncio.TimeoutError: duration = time.monotonic() - start_time - logfire.error("Hook timed out", hook=hook.name, timeout=timeout) + logfire.error( + "Hook timed out", + hook=hook.name, + deposition_srn=inputs.deposition_srn, + timeout=timeout, + ) return HookResult( hook_name=hook.name, status=HookStatus.FAILED, @@ -122,8 +131,8 @@ async def _run_container( "User": "65534:65534", "HostConfig": { "Binds": binds, - "Memory": self._parse_memory(hook.runtime.limits.memory), - "MemorySwap": self._parse_memory(hook.runtime.limits.memory), + "Memory": parse_memory(hook.runtime.limits.memory), + "MemorySwap": parse_memory(hook.runtime.limits.memory), "NanoCpus": int(float(hook.runtime.limits.cpu) * 1e9), "NetworkMode": "none", "ReadonlyRootfs": True, @@ -151,11 +160,11 @@ async def _run_container( } # Parse progress file - progress = self._parse_progress(output_dir) + progress = parse_progress_file(output_dir) # Check for rejection in progress - rejection = self._check_rejection(progress) - if rejection: + rejected, rejection = detect_rejection(progress) + if rejected: return { "status": HookStatus.REJECTED, "rejection_reason": rejection, @@ -223,55 +232,3 @@ async def _resolve_image(self, image: str, digest: str) -> str: logfire.info("Pulling hook image", image=image) await self._docker.images.pull(image) return image - - def _parse_progress(self, osa_out: Path) -> list[ProgressEntry]: - """Parse progress.jsonl from hook output.""" - progress_file = osa_out / "progress.jsonl" - if not progress_file.exists(): - return [] - - entries = [] - for line in progress_file.read_text().strip().split("\n"): - if not line.strip(): - continue - try: - data = json.loads(line) - entries.append( - ProgressEntry( - step=data.get("step"), - status=data.get("status", "unknown"), - message=data.get("message"), - ) - ) - except json.JSONDecodeError: - continue - return entries - - def _check_rejection(self, progress: list[ProgressEntry]) -> str | None: - """Check if any progress entry indicates rejection.""" - for entry in reversed(progress): - if entry.status == "rejected": - return entry.message - return None - - def _parse_memory(self, memory: str) -> int: - """Parse memory string like '2g' or '512m' to bytes.""" - memory = memory.strip().lower() - match = re.match(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$", memory) - if not match: - raise ValueError(f"Invalid memory format: {memory}") - - amount = float(match.group(1)) - unit = match.group(2) - - match unit: - case "g": - return int(amount * 1024 * 1024 * 1024) - case "m": - return int(amount * 1024 * 1024) - case "k": - return int(amount * 1024) - case None: - return int(amount) - case _: - raise ValueError(f"Unknown memory unit: {unit}") diff --git a/server/osa/infrastructure/oci/source_runner.py b/server/osa/infrastructure/oci/source_runner.py index 7a8e040..2f261ce 100644 --- a/server/osa/infrastructure/oci/source_runner.py +++ b/server/osa/infrastructure/oci/source_runner.py @@ -3,7 +3,6 @@ import asyncio import json import os -import re import stat import time from pathlib import Path @@ -14,6 +13,11 @@ from osa.domain.shared.error import ExternalServiceError from osa.domain.shared.model.source import SourceDefinition from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner +from osa.infrastructure.runner_utils import ( + parse_memory, + parse_records_file, + parse_session_file, +) class OciSourceRunner(SourceRunner): @@ -133,8 +137,8 @@ async def _run_container( "Env": env, "HostConfig": { "Binds": binds, - "Memory": self._parse_memory(source.limits.memory), - "MemorySwap": self._parse_memory(source.limits.memory), + "Memory": parse_memory(source.limits.memory), + "MemorySwap": parse_memory(source.limits.memory), "NanoCpus": int(float(source.limits.cpu) * 1e9), # No NetworkMode: "none" — sources need network access # No ReadonlyRootfs — sources may need writable FS @@ -162,7 +166,9 @@ async def _run_container( logs_str = "".join(logs) if logs else "" raise ExternalServiceError(f"Source exited with code {exit_code}: {logs_str[:500]}") - return self._parse_output(output_dir, files_dir) + records = parse_records_file(output_dir) + session = parse_session_file(output_dir) + return SourceOutput(records=records, session=session, files_dir=files_dir) except aiodocker.DockerError as e: logfire.error("Docker error running source", error=str(e)) @@ -207,49 +213,3 @@ async def _resolve_image(self, image: str, digest: str) -> str: logfire.info("Pulling source image", image=image) await self._docker.images.pull(image) return image - - def _parse_output(self, output_dir: Path, files_dir: Path) -> SourceOutput: - """Parse records.jsonl and session.json from the output directory.""" - records: list[dict] = [] - records_file = output_dir / "records.jsonl" - if records_file.exists(): - for line in records_file.read_text().strip().split("\n"): - if not line.strip(): - continue - try: - records.append(json.loads(line)) - except json.JSONDecodeError: - logfire.warn("Skipping invalid JSON line in records.jsonl") - continue - - session = None - session_file = output_dir / "session.json" - if session_file.exists(): - try: - session = json.loads(session_file.read_text()) - except json.JSONDecodeError: - logfire.warn("Invalid session.json") - - return SourceOutput(records=records, session=session, files_dir=files_dir) - - def _parse_memory(self, memory: str) -> int: - """Parse memory string like '4g' or '512m' to bytes.""" - memory = memory.strip().lower() - match = re.match(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$", memory) - if not match: - raise ValueError(f"Invalid memory format: {memory}") - - amount = float(match.group(1)) - unit = match.group(2) - - match unit: - case "g": - return int(amount * 1024 * 1024 * 1024) - case "m": - return int(amount * 1024 * 1024) - case "k": - return int(amount * 1024) - case None: - return int(amount) - case _: - raise ValueError(f"Unknown memory unit: {unit}") diff --git a/server/osa/infrastructure/persistence/adapter/spreadsheet.py b/server/osa/infrastructure/persistence/adapter/spreadsheet.py index 400c832..6048352 100644 --- a/server/osa/infrastructure/persistence/adapter/spreadsheet.py +++ b/server/osa/infrastructure/persistence/adapter/spreadsheet.py @@ -13,7 +13,7 @@ SpreadsheetPort, ) from osa.domain.semantics.model.schema import Schema -from osa.domain.semantics.model.value import FieldDefinition, FieldType +from osa.domain.semantics.model.value import FieldDefinition, FieldType, TermConstraints # Ontologies with <=20 terms get dropdown validation; others get an instruction note. _MAX_DROPDOWN_TERMS = 20 @@ -44,7 +44,7 @@ def generate_template( desc_cell.font = _DESC_FONT # Add dropdown for term fields with small ontologies - if field.type == FieldType.TERM and field.constraints: + if field.type == FieldType.TERM and isinstance(field.constraints, TermConstraints): onto_srn_str = str(field.constraints.ontology_srn) terms = ontology_terms_by_srn.get(onto_srn_str, []) if terms and len(terms) <= _MAX_DROPDOWN_TERMS: diff --git a/server/osa/infrastructure/persistence/adapter/storage.py b/server/osa/infrastructure/persistence/adapter/storage.py index 1b2573b..31a0b6d 100644 --- a/server/osa/infrastructure/persistence/adapter/storage.py +++ b/server/osa/infrastructure/persistence/adapter/storage.py @@ -1,5 +1,6 @@ import hashlib import json +import logging import shutil import tempfile from collections.abc import AsyncIterator @@ -9,10 +10,13 @@ from osa.domain.deposition.model.value import DepositionFile from osa.domain.deposition.port.storage import FileStoragePort +from osa.domain.shared.error import InfrastructureError from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN +logger = logging.getLogger(__name__) -class LocalFileStorageAdapter(FileStoragePort): + +class FilesystemStorageAdapter(FileStoragePort): """Local filesystem adapter satisfying all domain storage ports. Implements FileStoragePort (deposition files), SourceStoragePort, @@ -80,12 +84,16 @@ async def save_file( files_dir = self._files_dir(deposition_id) target = self._safe_path(files_dir, filename) - # Atomic write: write to temp file then rename + # Atomic write: write to temp file then rename (copy+delete on S3 CSI) fd, tmp_path = tempfile.mkstemp(dir=files_dir) try: with open(fd, "wb") as f: f.write(content) - Path(tmp_path).rename(target) + try: + Path(tmp_path).rename(target) + except OSError: + shutil.copy2(tmp_path, target) + Path(tmp_path).unlink(missing_ok=True) except Exception: Path(tmp_path).unlink(missing_ok=True) raise @@ -158,9 +166,17 @@ def move_source_files_to_deposition( if not source_files_dir.exists(): return files_dir = self._files_dir(deposition_srn) - # Rename entire source_id directory contents into deposition files dir + # Move files into deposition dir (copy+delete fallback for S3 CSI) for f in source_files_dir.iterdir(): target = files_dir / f.name - f.rename(target) + try: + f.rename(target) + except OSError: + try: + shutil.copy2(f, target) + f.unlink() + except OSError as e: + raise InfrastructureError(f"Failed to copy file {f.name}: {e}") from e # Clean up empty source_id directory - source_files_dir.rmdir() + if source_files_dir.exists(): + source_files_dir.rmdir() diff --git a/server/osa/infrastructure/persistence/di.py b/server/osa/infrastructure/persistence/di.py index e897002..12fa207 100644 --- a/server/osa/infrastructure/persistence/di.py +++ b/server/osa/infrastructure/persistence/di.py @@ -35,7 +35,7 @@ OntologyReaderAdapter, SchemaReaderAdapter, ) -from osa.infrastructure.persistence.adapter.storage import LocalFileStorageAdapter +from osa.infrastructure.persistence.adapter.storage import FilesystemStorageAdapter from osa.infrastructure.persistence.database import ( create_db_engine, create_session_factory, @@ -120,7 +120,7 @@ def get_feature_store(self, engine: AsyncEngine, session: AsyncSession) -> Featu # File storage @provide(scope=Scope.APP) def get_file_storage(self, paths: "OSAPaths") -> FileStoragePort: - return LocalFileStorageAdapter(base_path=str(paths.data_dir / "files")) + return FilesystemStorageAdapter(base_path=str(paths.data_dir / "files")) @provide(scope=Scope.APP) def get_source_storage(self, file_storage: FileStoragePort) -> SourceStoragePort: diff --git a/server/osa/infrastructure/persistence/repository/auth.py b/server/osa/infrastructure/persistence/repository/auth.py index b6f7ce9..1ed308a 100644 --- a/server/osa/infrastructure/persistence/repository/auth.py +++ b/server/osa/infrastructure/persistence/repository/auth.py @@ -3,7 +3,7 @@ from datetime import UTC, datetime from uuid import UUID -from sqlalchemy import delete, insert, select, update +from sqlalchemy import CursorResult, delete, insert, select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -22,6 +22,7 @@ UserCode, UserId, ) +from osa.domain.shared.error import InfrastructureError from osa.domain.auth.port.repository import ( DeviceAuthorizationRepository, LinkedAccountRepository, @@ -228,6 +229,8 @@ async def revoke_family(self, family_id: TokenFamilyId) -> int: ) result = await self.session.execute(stmt) await self.session.flush() + if not isinstance(result, CursorResult): + raise InfrastructureError(f"Expected CursorResult, got {type(result).__name__}") return result.rowcount @@ -349,4 +352,6 @@ async def delete_expired_before(self, cutoff: datetime) -> int: ) result = await self.session.execute(stmt) await self.session.flush() + if not isinstance(result, CursorResult): + raise InfrastructureError(f"Expected CursorResult, got {type(result).__name__}") return result.rowcount diff --git a/server/osa/infrastructure/persistence/repository/event.py b/server/osa/infrastructure/persistence/repository/event.py index b4c728d..3e09df5 100644 --- a/server/osa/infrastructure/persistence/repository/event.py +++ b/server/osa/infrastructure/persistence/repository/event.py @@ -5,11 +5,12 @@ from typing import TypeVar from uuid import uuid4 -from sqlalchemy import func, insert, or_, select, update +from sqlalchemy import CursorResult, func, insert, or_, select, update from sqlalchemy.dialects.postgresql import INTERVAL from sqlalchemy.sql import literal from sqlalchemy.ext.asyncio import AsyncSession +from osa.domain.shared.error import InfrastructureError from osa.domain.shared.event import ClaimResult, Delivery, Event, EventId from osa.domain.shared.port.event_repository import EventRepository from osa.infrastructure.persistence.tables import deliveries_table, events_table @@ -278,6 +279,8 @@ async def reset_stale_deliveries(self, timeout_seconds: float) -> int: ) result = await self._session.execute(stmt) + if not isinstance(result, CursorResult): + raise InfrastructureError(f"Expected CursorResult, got {type(result).__name__}") count = result.rowcount if count > 0: logger.info(f"Reset {count} stale deliveries (older than {timeout_seconds}s)") diff --git a/server/osa/infrastructure/runner_utils.py b/server/osa/infrastructure/runner_utils.py new file mode 100644 index 0000000..4b108b0 --- /dev/null +++ b/server/osa/infrastructure/runner_utils.py @@ -0,0 +1,100 @@ +"""Shared result-parsing utilities for OCI and K8s runners.""" + +import json +import re +from pathlib import Path +from typing import Any + +from osa.domain.validation.model.hook_result import ProgressEntry + + +def parse_progress_file(output_dir: Path) -> list[ProgressEntry]: + """Parse progress.jsonl from hook output directory.""" + progress_file = output_dir / "progress.jsonl" + if not progress_file.exists(): + return [] + + entries = [] + for line in progress_file.read_text().strip().split("\n"): + if not line.strip(): + continue + try: + data = json.loads(line) + entries.append( + ProgressEntry( + step=data.get("step"), + status=data.get("status", "unknown"), + message=data.get("message"), + ) + ) + except json.JSONDecodeError: + continue + return entries + + +def detect_rejection(progress: list[ProgressEntry]) -> tuple[bool, str | None]: + """Check if any progress entry indicates rejection. + + Returns (is_rejected, rejection_reason). + """ + for entry in reversed(progress): + if entry.status == "rejected": + return True, entry.message + return False, None + + +def parse_memory(memory: str) -> int: + """Parse memory string like '2g' or '512m' to bytes.""" + memory = memory.strip().lower() + match = re.match(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$", memory) + if not match: + raise ValueError(f"Invalid memory format: {memory}") + + amount = float(match.group(1)) + unit = match.group(2) + + match unit: + case "g": + return int(amount * 1024 * 1024 * 1024) + case "m": + return int(amount * 1024 * 1024) + case "k": + return int(amount * 1024) + case None: + return int(amount) + case _: + raise ValueError(f"Unknown memory unit: {unit}") + + +def parse_records_file(output_dir: Path) -> list[dict[str, Any]]: + """Parse records.jsonl from source output directory.""" + import logfire + + records: list[dict[str, Any]] = [] + records_file = output_dir / "records.jsonl" + if not records_file.exists(): + return records + + for line in records_file.read_text().strip().split("\n"): + if not line.strip(): + continue + try: + records.append(json.loads(line)) + except json.JSONDecodeError: + logfire.warn("Skipping invalid JSON line in records.jsonl") + continue + return records + + +def parse_session_file(output_dir: Path) -> dict[str, Any] | None: + """Parse session.json from source output directory.""" + import logfire + + session_file = output_dir / "session.json" + if not session_file.exists(): + return None + try: + return json.loads(session_file.read_text()) + except json.JSONDecodeError: + logfire.warn("Invalid session.json") + return None diff --git a/server/pyproject.toml b/server/pyproject.toml index 425f613..47c3a4c 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -37,6 +37,11 @@ geo-entrez = "sources.geo_entrez:GEOEntrezSource" requires = ["hatchling"] build-backend = "hatchling.build" +[project.optional-dependencies] +k8s = [ + "kubernetes-asyncio>=31.0", +] + [dependency-groups] dev = [ "coverage>=7.12.0", diff --git a/server/tests/unit/domain/validation/test_hook_runner.py b/server/tests/unit/domain/validation/test_hook_runner.py index b305b27..bcf5ecb 100644 --- a/server/tests/unit/domain/validation/test_hook_runner.py +++ b/server/tests/unit/domain/validation/test_hook_runner.py @@ -11,19 +11,27 @@ class TestHookInputs: def test_minimal_construction(self): - inputs = HookInputs(record_json={"srn": "urn:osa:localhost:rec:123"}) + inputs = HookInputs( + record_json={"srn": "urn:osa:localhost:rec:123"}, + deposition_srn="urn:osa:localhost:dep:test123", + ) assert inputs.record_json == {"srn": "urn:osa:localhost:rec:123"} assert inputs.files_dir is None assert inputs.config is None def test_with_files_dir(self): files = Path("/tmp/files") - inputs = HookInputs(record_json={"srn": "test"}, files_dir=files) + inputs = HookInputs( + record_json={"srn": "test"}, + deposition_srn="urn:osa:localhost:dep:test123", + files_dir=files, + ) assert inputs.files_dir == files def test_with_config(self): inputs = HookInputs( record_json={"srn": "test"}, + deposition_srn="urn:osa:localhost:dep:test123", config={"r_min": 3.0, "threshold": 0.5}, ) assert inputs.config == {"r_min": 3.0, "threshold": 0.5} @@ -32,6 +40,7 @@ def test_full_construction(self): files = Path("/tmp/data/files") inputs = HookInputs( record_json={"srn": "urn:osa:localhost:rec:456", "name": "test"}, + deposition_srn="urn:osa:localhost:dep:test456", files_dir=files, config={"key": "value"}, ) @@ -40,7 +49,9 @@ def test_full_construction(self): assert inputs.config == {"key": "value"} def test_is_frozen(self): - inputs = HookInputs(record_json={"srn": "test"}) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + ) with pytest.raises(AttributeError): inputs.record_json = {} # type: ignore[misc] diff --git a/server/tests/unit/domain/validation/test_validation_service.py b/server/tests/unit/domain/validation/test_validation_service.py index b020dfd..c04b355 100644 --- a/server/tests/unit/domain/validation/test_validation_service.py +++ b/server/tests/unit/domain/validation/test_validation_service.py @@ -62,6 +62,7 @@ def _make_service( def _make_inputs() -> HookInputs: return HookInputs( record_json={"srn": "urn:osa:localhost:dep:test123", "metadata": {"name": "test"}}, + deposition_srn="urn:osa:localhost:dep:test123", ) diff --git a/server/tests/unit/infrastructure/k8s/__init__.py b/server/tests/unit/infrastructure/k8s/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/infrastructure/k8s/test_classify_api_error.py b/server/tests/unit/infrastructure/k8s/test_classify_api_error.py new file mode 100644 index 0000000..c9d3096 --- /dev/null +++ b/server/tests/unit/infrastructure/k8s/test_classify_api_error.py @@ -0,0 +1,46 @@ +"""Tests for K8s API error classification.""" + +from osa.domain.shared.error import ConfigurationError, InfrastructureError +from osa.infrastructure.k8s.errors import classify_api_error + + +class _FakeApiException(Exception): + """Stand-in for kubernetes_asyncio.client.ApiException.""" + + def __init__(self, status: int, reason: str = ""): + self.status = status + self.reason = reason + super().__init__(f"{status}: {reason}") + + +class TestClassifyApiError: + def test_403_returns_configuration_error(self): + exc = _FakeApiException(403, "Forbidden") + result = classify_api_error(exc) + assert isinstance(result, ConfigurationError) + assert "RBAC" in result.message or "permission" in result.message.lower() + + def test_404_returns_configuration_error(self): + exc = _FakeApiException(404, "Not Found") + result = classify_api_error(exc) + assert isinstance(result, ConfigurationError) + + def test_500_returns_infrastructure_error(self): + exc = _FakeApiException(500, "Internal Server Error") + result = classify_api_error(exc) + assert isinstance(result, InfrastructureError) + + def test_503_returns_infrastructure_error(self): + exc = _FakeApiException(503, "Service Unavailable") + result = classify_api_error(exc) + assert isinstance(result, InfrastructureError) + + def test_409_returns_infrastructure_error(self): + exc = _FakeApiException(409, "Conflict") + result = classify_api_error(exc) + assert isinstance(result, InfrastructureError) + + def test_unknown_status_returns_infrastructure_error(self): + exc = _FakeApiException(429, "Too Many Requests") + result = classify_api_error(exc) + assert isinstance(result, InfrastructureError) diff --git a/server/tests/unit/infrastructure/k8s/test_health.py b/server/tests/unit/infrastructure/k8s/test_health.py new file mode 100644 index 0000000..180842b --- /dev/null +++ b/server/tests/unit/infrastructure/k8s/test_health.py @@ -0,0 +1,65 @@ +"""Tests for K8s startup health check.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from osa.domain.shared.error import ConfigurationError +from osa.infrastructure.k8s.health import check_k8s_health + + +class _FakeApiException(Exception): + def __init__(self, status: int, reason: str = ""): + self.status = status + self.reason = reason + super().__init__(f"{status}: {reason}") + + +class TestCheckK8sHealth: + @pytest.mark.asyncio + async def test_healthy_cluster(self): + batch_api = AsyncMock() + core_api = AsyncMock() + batch_api.list_namespaced_job.return_value = MagicMock(items=[]) + core_api.read_namespaced_persistent_volume_claim.return_value = MagicMock() + + await check_k8s_health(batch_api, core_api, namespace="osa", pvc_name="data-pvc") + + @pytest.mark.asyncio + async def test_api_unreachable(self): + batch_api = AsyncMock() + core_api = AsyncMock() + batch_api.list_namespaced_job.side_effect = Exception("Connection refused") + + with pytest.raises(ConfigurationError, match="K8s API"): + await check_k8s_health(batch_api, core_api, namespace="osa", pvc_name="data-pvc") + + @pytest.mark.asyncio + async def test_namespace_not_found(self): + batch_api = AsyncMock() + core_api = AsyncMock() + batch_api.list_namespaced_job.side_effect = _FakeApiException(404, "Not Found") + + with pytest.raises(ConfigurationError, match="osa"): + await check_k8s_health(batch_api, core_api, namespace="osa", pvc_name="data-pvc") + + @pytest.mark.asyncio + async def test_rbac_forbidden(self): + batch_api = AsyncMock() + core_api = AsyncMock() + batch_api.list_namespaced_job.side_effect = _FakeApiException(403, "Forbidden") + + with pytest.raises(ConfigurationError, match="permission"): + await check_k8s_health(batch_api, core_api, namespace="osa", pvc_name="data-pvc") + + @pytest.mark.asyncio + async def test_pvc_missing(self): + batch_api = AsyncMock() + core_api = AsyncMock() + batch_api.list_namespaced_job.return_value = MagicMock(items=[]) + core_api.read_namespaced_persistent_volume_claim.side_effect = _FakeApiException( + 404, "Not Found" + ) + + with pytest.raises(ConfigurationError, match="data-pvc"): + await check_k8s_health(batch_api, core_api, namespace="osa", pvc_name="data-pvc") diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py new file mode 100644 index 0000000..8d7fcd3 --- /dev/null +++ b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py @@ -0,0 +1,913 @@ +"""Unit tests for K8sHookRunner — Job spec, scheduling, execution, orphans, cleanup.""" + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from osa.config import K8sConfig +from osa.domain.shared.error import InfrastructureError +from osa.domain.shared.model.hook import ( + ColumnDef, + HookDefinition, + OciConfig, + OciLimits, + TableFeatureSpec, +) +from osa.domain.validation.model.hook_result import HookStatus +from osa.domain.validation.port.hook_runner import HookInputs +from osa.infrastructure.k8s.runner import K8sHookRunner + + +def _make_hook( + name: str = "validate_dna", + timeout: int = 300, + memory: str = "2g", + cpu: str = "2.0", + config: dict | None = None, + image: str = "ghcr.io/example/hook:v1", + digest: str = "sha256:abc123", +) -> HookDefinition: + return HookDefinition( + name=name, + runtime=OciConfig( + image=image, + digest=digest, + config=config or {}, + limits=OciLimits(timeout_seconds=timeout, memory=memory, cpu=cpu), + ), + feature=TableFeatureSpec( + cardinality="many", + columns=[ColumnDef(name="score", json_type="number", required=True)], + ), + ) + + +def _make_config(**overrides) -> K8sConfig: + defaults = { + "namespace": "osa", + "data_pvc_name": "osa-data-pvc", + "data_mount_path": "/data", + "job_ttl_seconds": 300, + } + defaults.update(overrides) + return K8sConfig(**defaults) + + +def _make_runner(config: K8sConfig | None = None) -> K8sHookRunner: + api_client = MagicMock() + return K8sHookRunner(api_client=api_client, config=config or _make_config()) + + +# --------------------------------------------------------------------------- +# Job spec generation (T014) +# --------------------------------------------------------------------------- + + +class TestJobSpecGeneration: + def test_correct_image(self): + runner = _make_runner() + hook = _make_hook(image="ghcr.io/org/hook:v2", digest="sha256:def456") + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + container = spec.spec.template.spec.containers[0] + assert container.image == "ghcr.io/org/hook:v2@sha256:def456" + + def test_security_context(self): + runner = _make_runner() + hook = _make_hook() + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + pod_spec = spec.spec.template.spec + container = pod_spec.containers[0] + sec = container.security_context + + assert sec.read_only_root_filesystem is True + assert sec.capabilities.drop == ["ALL"] + assert sec.allow_privilege_escalation is False + assert sec.run_as_user == 65534 + assert sec.run_as_group == 65534 + + # Pod-level security context + assert pod_spec.security_context.run_as_non_root is True + + def test_resource_limits(self): + runner = _make_runner() + hook = _make_hook(memory="4g", cpu="2.0") + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + resources = spec.spec.template.spec.containers[0].resources + assert resources.limits["memory"] == "4g" + assert resources.limits["cpu"] == "2.0" + + def test_volume_mounts(self): + runner = _make_runner() + hook = _make_hook() + work_dir = Path("/data/depositions/localhost_abc/hooks/validate_dna") + spec = runner._build_job_spec(hook, work_dir) + + volumes = spec.spec.template.spec.volumes + pvc_vol = next(v for v in volumes if v.name == "data") + assert pvc_vol.persistent_volume_claim.claim_name == "osa-data-pvc" + + tmp_vol = next(v for v in volumes if v.name == "tmp") + assert tmp_vol.empty_dir is not None + + mounts = spec.spec.template.spec.containers[0].volume_mounts + mount_paths = {m.mount_path for m in mounts} + assert "/osa/in" in mount_paths + assert "/osa/out" in mount_paths + assert "/tmp" in mount_paths + + def test_env_vars(self): + runner = _make_runner() + hook = _make_hook(name="pocket_detect") + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/pocket_detect") + ) + + env = spec.spec.template.spec.containers[0].env + env_dict = {e.name: e.value for e in env} + assert env_dict["OSA_IN"] == "/osa/in" + assert env_dict["OSA_OUT"] == "/osa/out" + assert env_dict["OSA_HOOK_NAME"] == "pocket_detect" + + def test_backoff_limit_zero(self): + runner = _make_runner() + hook = _make_hook() + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + assert spec.spec.backoff_limit == 0 + + def test_active_deadline_seconds(self): + runner = _make_runner() + hook = _make_hook(timeout=300) + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + # scheduling_timeout (120) + hook timeout (300) + assert spec.spec.active_deadline_seconds == 420 + + def test_dns_policy_none(self): + runner = _make_runner() + hook = _make_hook() + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + pod_spec = spec.spec.template.spec + assert pod_spec.dns_policy == "None" + assert pod_spec.dns_config.nameservers == [] + + def test_labels(self): + runner = _make_runner() + hook = _make_hook(name="validate_dna") + spec = runner._build_job_spec( + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + labels = spec.spec.template.metadata.labels + assert labels["osa.io/role"] == "hook" + assert labels["osa.io/hook"] == "validate_dna" + assert labels["osa.io/deposition"] == "urn:osa:localhost:dep:abc123" + + def test_human_readable_job_name(self): + runner = _make_runner() + hook = _make_hook(name="validate_dna") + spec = runner._build_job_spec( + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + name = spec.metadata.name + assert name.startswith("osa-hook-") + assert len(name) <= 63 + + def test_empty_dir_at_tmp(self): + runner = _make_runner() + hook = _make_hook() + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + volumes = spec.spec.template.spec.volumes + tmp = next(v for v in volumes if v.name == "tmp") + assert tmp.empty_dir.size_limit == "512Mi" + + def test_automount_service_account_false(self): + runner = _make_runner() + hook = _make_hook() + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + pod_spec = spec.spec.template.spec + assert pod_spec.automount_service_account_token is False + + def test_ttl_seconds_after_finished(self): + runner = _make_runner(config=_make_config(job_ttl_seconds=600)) + hook = _make_hook() + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + assert spec.spec.ttl_seconds_after_finished == 600 + + def test_files_mount_when_files_dir_provided(self): + runner = _make_runner() + hook = _make_hook() + spec = runner._build_job_spec( + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + files_dir=Path("/data/depositions/localhost_abc/files"), + ) + + mounts = spec.spec.template.spec.containers[0].volume_mounts + files_mount = next((m for m in mounts if m.mount_path == "/osa/in/files"), None) + assert files_mount is not None + assert files_mount.read_only is True + + def test_image_pull_secrets(self): + runner = _make_runner(config=_make_config(image_pull_secrets=["ghcr-secret"])) + hook = _make_hook() + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + secrets = spec.spec.template.spec.image_pull_secrets + assert len(secrets) == 1 + assert secrets[0].name == "ghcr-secret" + + def test_service_account(self): + runner = _make_runner(config=_make_config(service_account="osa-runner")) + hook = _make_hook() + spec = runner._build_job_spec( + hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + ) + + assert spec.spec.template.spec.service_account_name == "osa-runner" + + +# --------------------------------------------------------------------------- +# Path coordination (T015) +# --------------------------------------------------------------------------- + + +class TestPathCoordination: + def test_relative_path_strips_prefix(self): + runner = _make_runner(config=_make_config(data_mount_path="/data")) + result = runner._relative_path(Path("/data/depositions/localhost_abc/hooks/validate")) + assert result == "depositions/localhost_abc/hooks/validate" + + def test_relative_path_raises_outside_prefix(self): + runner = _make_runner(config=_make_config(data_mount_path="/data")) + with pytest.raises(ValueError, match="outside"): + runner._relative_path(Path("/other/path")) + + def test_relative_path_handles_trailing_slash(self): + runner = _make_runner(config=_make_config(data_mount_path="/data/")) + result = runner._relative_path(Path("/data/depositions/test")) + assert result == "depositions/test" + + +# --------------------------------------------------------------------------- +# Scheduling watch (T016) +# --------------------------------------------------------------------------- + + +class TestSchedulingWatch: + @pytest.mark.asyncio + async def test_pod_leaves_pending_quickly(self): + runner = _make_runner() + core_api = AsyncMock() + + # Pod transitions from Pending to Running + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + await runner._wait_for_scheduling(core_api, "test-job", "osa") + + @pytest.mark.asyncio + async def test_pod_stuck_scheduling_timeout(self): + runner = _make_runner() + core_api = AsyncMock() + + # Pod stays in Pending + pod = MagicMock() + pod.status.phase = "Pending" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + with pytest.raises(InfrastructureError, match="scheduling"): + await runner._wait_for_scheduling( + core_api, "test-job", "osa", timeout_seconds=0.1, poll_interval=0.05 + ) + + @pytest.mark.asyncio + async def test_image_pull_backoff_fails_fast(self): + runner = _make_runner() + core_api = AsyncMock() + + pod = MagicMock() + pod.status.phase = "Pending" + container_status = MagicMock() + container_status.state.waiting.reason = "ImagePullBackOff" + container_status.state.waiting.message = "pull access denied" + pod.status.container_statuses = [container_status] + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + with pytest.raises(InfrastructureError, match="[Ii]mage pull"): + await runner._wait_for_scheduling(core_api, "test-job", "osa") + + @pytest.mark.asyncio + async def test_err_image_pull_fails_fast(self): + runner = _make_runner() + core_api = AsyncMock() + + pod = MagicMock() + pod.status.phase = "Pending" + container_status = MagicMock() + container_status.state.waiting.reason = "ErrImagePull" + container_status.state.waiting.message = "not found" + pod.status.container_statuses = [container_status] + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + with pytest.raises(InfrastructureError, match="[Ii]mage pull"): + await runner._wait_for_scheduling(core_api, "test-job", "osa") + + @pytest.mark.asyncio + async def test_pod_evicted(self): + runner = _make_runner() + core_api = AsyncMock() + + pod = MagicMock() + pod.status.phase = "Failed" + pod.status.reason = "Evicted" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + with pytest.raises(InfrastructureError, match="[Ee]vict"): + await runner._wait_for_scheduling(core_api, "test-job", "osa") + + +# --------------------------------------------------------------------------- +# Execution watch + orphan handling + cleanup (T017) +# --------------------------------------------------------------------------- + + +class TestExecutionAndCleanup: + @pytest.mark.asyncio + async def test_successful_run(self, tmp_path: Path): + """Full lifecycle: create → schedule → complete → parse → cleanup.""" + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + # No existing jobs (orphan check) + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + + # Job creation succeeds + batch_api.create_namespaced_job.return_value = MagicMock() + + # Pod leaves Pending + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + # Job completes successfully + completed_job = MagicMock() + condition = MagicMock() + condition.type = "Complete" + condition.status = "True" + completed_job.status.conditions = [condition] + completed_job.status.succeeded = 1 + completed_job.status.failed = None + batch_api.read_namespaced_job.return_value = completed_job + + # Create output directory with progress + hook = _make_hook() + work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + (output_dir / "progress.jsonl").write_text( + '{"step":"Check","status":"completed","message":"OK"}\n' + ) + + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" + ) + result = await runner._run_job( + batch_api, + core_api, + hook, + inputs, + work_dir, + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + assert result.status == HookStatus.PASSED + assert len(result.progress) == 1 + # Job should be cleaned up + batch_api.delete_namespaced_job.assert_called_once() + + @pytest.mark.asyncio + async def test_timeout_deadline_exceeded(self, tmp_path: Path): + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + batch_api.create_namespaced_job.return_value = MagicMock() + + # Pod Running + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + # Job failed with DeadlineExceeded + failed_job = MagicMock() + condition = MagicMock() + condition.type = "Failed" + condition.status = "True" + condition.reason = "DeadlineExceeded" + failed_job.status.conditions = [condition] + failed_job.status.succeeded = None + failed_job.status.failed = 1 + batch_api.read_namespaced_job.return_value = failed_job + + hook = _make_hook() + work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" + work_dir.mkdir(parents=True) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" + ) + + result = await runner._run_job( + batch_api, + core_api, + hook, + inputs, + work_dir, + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + assert result.status == HookStatus.FAILED + assert ( + "timed out" in result.error_message.lower() + or "deadline" in result.error_message.lower() + ) + batch_api.delete_namespaced_job.assert_called_once() + + @pytest.mark.asyncio + async def test_oom_exit_137(self, tmp_path: Path): + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + batch_api.create_namespaced_job.return_value = MagicMock() + + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + # Job failed + failed_job = MagicMock() + condition = MagicMock() + condition.type = "Failed" + condition.status = "True" + condition.reason = "BackoffLimitExceeded" + failed_job.status.conditions = [condition] + failed_job.status.succeeded = None + failed_job.status.failed = 1 + batch_api.read_namespaced_job.return_value = failed_job + + # Pod has OOMKilled container + oom_pod = MagicMock() + oom_pod.status.phase = "Failed" + terminated = MagicMock() + terminated.reason = "OOMKilled" + terminated.exit_code = 137 + container_status = MagicMock() + container_status.state.terminated = terminated + oom_pod.status.container_statuses = [container_status] + oom_pod_list = MagicMock() + oom_pod_list.items = [oom_pod] + # Second call to list_namespaced_pod returns the OOM pod + core_api.list_namespaced_pod.side_effect = [pod_list, oom_pod_list] + + hook = _make_hook() + work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" + work_dir.mkdir(parents=True) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" + ) + + result = await runner._run_job( + batch_api, + core_api, + hook, + inputs, + work_dir, + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + assert result.status == HookStatus.FAILED + assert "oom" in result.error_message.lower() + + @pytest.mark.asyncio + async def test_nonzero_exit(self, tmp_path: Path): + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + batch_api.create_namespaced_job.return_value = MagicMock() + + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + failed_job = MagicMock() + condition = MagicMock() + condition.type = "Failed" + condition.status = "True" + condition.reason = "BackoffLimitExceeded" + failed_job.status.conditions = [condition] + failed_job.status.succeeded = None + failed_job.status.failed = 1 + batch_api.read_namespaced_job.return_value = failed_job + + # Pod with exit code 1 + exit_pod = MagicMock() + exit_pod.status.phase = "Failed" + terminated = MagicMock() + terminated.reason = None + terminated.exit_code = 1 + container_status = MagicMock() + container_status.state.terminated = terminated + exit_pod.status.container_statuses = [container_status] + exit_pod_list = MagicMock() + exit_pod_list.items = [exit_pod] + core_api.list_namespaced_pod.side_effect = [pod_list, exit_pod_list] + + hook = _make_hook() + work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" + work_dir.mkdir(parents=True) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" + ) + + result = await runner._run_job( + batch_api, + core_api, + hook, + inputs, + work_dir, + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + assert result.status == HookStatus.FAILED + assert "exit" in result.error_message.lower() + + @pytest.mark.asyncio + async def test_orphan_running_job_attaches(self, tmp_path: Path): + """Existing running Job → attach and wait for it.""" + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + # Existing active job + existing_job = MagicMock() + existing_job.metadata.name = "osa-hook-existing" + existing_job.status.succeeded = None + existing_job.status.failed = None + existing_job.status.active = 1 + job_list = MagicMock() + job_list.items = [existing_job] + batch_api.list_namespaced_job.return_value = job_list + + # Pod Running (scheduling check) + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + # Job completes + completed_job = MagicMock() + condition = MagicMock() + condition.type = "Complete" + condition.status = "True" + completed_job.status.conditions = [condition] + completed_job.status.succeeded = 1 + completed_job.status.failed = None + batch_api.read_namespaced_job.return_value = completed_job + + hook = _make_hook() + work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" + ) + + result = await runner._run_job( + batch_api, + core_api, + hook, + inputs, + work_dir, + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + assert result.status == HookStatus.PASSED + # Should NOT have created a new job + batch_api.create_namespaced_job.assert_not_called() + + @pytest.mark.asyncio + async def test_orphan_completed_job_reads_output(self, tmp_path: Path): + """Existing completed Job → read its output.""" + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + existing_job = MagicMock() + existing_job.metadata.name = "osa-hook-existing" + existing_job.status.succeeded = 1 + existing_job.status.failed = None + existing_job.status.active = None + job_list = MagicMock() + job_list.items = [existing_job] + batch_api.list_namespaced_job.return_value = job_list + + hook = _make_hook() + work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" + ) + + result = await runner._run_job( + batch_api, + core_api, + hook, + inputs, + work_dir, + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + assert result.status == HookStatus.PASSED + batch_api.create_namespaced_job.assert_not_called() + + @pytest.mark.asyncio + async def test_orphan_failed_job_creates_new(self, tmp_path: Path): + """Existing failed Job → create new one.""" + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + existing_job = MagicMock() + existing_job.metadata.name = "osa-hook-existing" + existing_job.status.succeeded = None + existing_job.status.failed = 1 + existing_job.status.active = None + job_list = MagicMock() + job_list.items = [existing_job] + batch_api.list_namespaced_job.return_value = job_list + + batch_api.create_namespaced_job.return_value = MagicMock() + + # Pod Running + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + # New job completes + completed_job = MagicMock() + condition = MagicMock() + condition.type = "Complete" + condition.status = "True" + completed_job.status.conditions = [condition] + completed_job.status.succeeded = 1 + completed_job.status.failed = None + batch_api.read_namespaced_job.return_value = completed_job + + hook = _make_hook() + work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" + ) + + result = await runner._run_job( + batch_api, + core_api, + hook, + inputs, + work_dir, + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + assert result.status == HookStatus.PASSED + batch_api.create_namespaced_job.assert_called_once() + + @pytest.mark.asyncio + async def test_cleanup_404_ignored(self, tmp_path: Path): + """404 on Job delete is ignored (already cleaned up).""" + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + + class FakeNotFound(Exception): + status = 404 + reason = "Not Found" + + batch_api.delete_namespaced_job.side_effect = FakeNotFound() + + # Should not raise + await runner._cleanup_job(batch_api, "test-job", "osa") + + @pytest.mark.asyncio + async def test_rejection_via_progress(self, tmp_path: Path): + """Hook with rejected progress entry returns REJECTED.""" + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + batch_api.create_namespaced_job.return_value = MagicMock() + + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + completed_job = MagicMock() + condition = MagicMock() + condition.type = "Complete" + condition.status = "True" + completed_job.status.conditions = [condition] + completed_job.status.succeeded = 1 + completed_job.status.failed = None + batch_api.read_namespaced_job.return_value = completed_job + + hook = _make_hook() + work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + (output_dir / "progress.jsonl").write_text( + '{"step":"Validate","status":"rejected","message":"Missing atoms"}\n' + ) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" + ) + + result = await runner._run_job( + batch_api, + core_api, + hook, + inputs, + work_dir, + deposition_srn="urn:osa:localhost:dep:abc123", + ) + + assert result.status == HookStatus.REJECTED + assert result.rejection_reason == "Missing atoms" + + +# --------------------------------------------------------------------------- +# Identity threading from HookInputs +# --------------------------------------------------------------------------- + + +class TestDepositionSrnFromInputs: + """Verify run() uses inputs.deposition_srn for Job labels, not path parsing.""" + + @pytest.mark.asyncio + async def test_run_uses_deposition_srn_from_inputs(self, tmp_path: Path): + """The deposition SRN in Job labels comes from inputs, not the work_dir path.""" + from unittest.mock import patch + + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sHookRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + # No existing jobs + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + batch_api.create_namespaced_job.return_value = MagicMock() + + # Pod scheduled + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + # Job completes + completed_job = MagicMock() + completed_job.status.succeeded = 1 + completed_job.status.conditions = [] + completed_job.status.failed = None + batch_api.read_namespaced_job.return_value = completed_job + + # Work dir does NOT follow depositions path convention — proves + # we're not parsing the path to extract the SRN + work_dir = tmp_path / "arbitrary" / "path" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + (output_dir / "progress.jsonl").write_text("") + + hook = _make_hook() + inputs = HookInputs( + record_json={"srn": "test"}, + deposition_srn="urn:osa:localhost:dep:my-real-srn", + ) + + with ( + patch("kubernetes_asyncio.client.BatchV1Api", return_value=batch_api), + patch("kubernetes_asyncio.client.CoreV1Api", return_value=core_api), + ): + await runner.run(hook, inputs, work_dir) + + # Verify the Job was created with the SRN from inputs + call_args = batch_api.create_namespaced_job.call_args + spec = call_args[0][1] # positional arg: (namespace, spec) + labels = spec.metadata.labels + assert labels["osa.io/deposition"] == "urn:osa:localhost:dep:my-real-srn" diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py new file mode 100644 index 0000000..51c7db7 --- /dev/null +++ b/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py @@ -0,0 +1,423 @@ +"""Unit tests for K8sSourceRunner — Job spec differences, source lifecycle.""" + +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from osa.config import K8sConfig +from osa.domain.shared.error import ExternalServiceError +from osa.domain.shared.model.source import SourceDefinition, SourceLimits +from osa.domain.source.port.source_runner import SourceInputs +from osa.infrastructure.k8s.source_runner import K8sSourceRunner + + +def _make_source( + image: str = "ghcr.io/example/source:v1", + digest: str = "sha256:abc123", + timeout: int = 3600, + memory: str = "4g", + cpu: str = "2.0", + config: dict[str, Any] | None = None, +) -> SourceDefinition: + return SourceDefinition( + image=image, + digest=digest, + config=config, + limits=SourceLimits(timeout_seconds=timeout, memory=memory, cpu=cpu), + ) + + +def _make_config(**overrides) -> K8sConfig: + defaults = { + "namespace": "osa", + "data_pvc_name": "osa-data-pvc", + "data_mount_path": "/data", + "job_ttl_seconds": 300, + } + defaults.update(overrides) + return K8sConfig(**defaults) + + +def _make_runner(config: K8sConfig | None = None) -> K8sSourceRunner: + api_client = MagicMock() + return K8sSourceRunner(api_client=api_client, config=config or _make_config()) + + +# --------------------------------------------------------------------------- +# Job spec differences (T021) +# --------------------------------------------------------------------------- + + +class TestSourceJobSpec: + def test_network_enabled(self): + """Source Jobs have normal DNS policy (network access).""" + runner = _make_runner() + source = _make_source() + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + ) + pod_spec = spec.spec.template.spec + assert pod_spec.dns_policy is None or pod_spec.dns_policy != "None" + + def test_writable_rootfs(self): + """Source containers do not have readOnlyRootFilesystem.""" + runner = _make_runner() + source = _make_source() + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + ) + sec = spec.spec.template.spec.containers[0].security_context + assert sec is None or sec.read_only_root_filesystem is not True + + def test_higher_defaults(self): + """Source Jobs use higher defaults (3600s, 4g).""" + runner = _make_runner() + source = _make_source(timeout=3600, memory="4g") + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + ) + resources = spec.spec.template.spec.containers[0].resources + assert resources.limits["memory"] == "4g" + # activeDeadlineSeconds = scheduling_timeout + source timeout + assert spec.spec.active_deadline_seconds == 120 + 3600 + + def test_three_volume_mounts(self): + """Source Jobs have input, output, and files mounts.""" + runner = _make_runner() + source = _make_source() + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + ) + mounts = spec.spec.template.spec.containers[0].volume_mounts + mount_paths = {m.mount_path for m in mounts} + assert "/osa/in" in mount_paths + assert "/osa/out" in mount_paths + assert "/osa/files" in mount_paths + + def test_files_mount_writable(self): + """Source files mount is writable.""" + runner = _make_runner() + source = _make_source() + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + ) + mounts = spec.spec.template.spec.containers[0].volume_mounts + files_mount = next(m for m in mounts if m.mount_path == "/osa/files") + assert files_mount.read_only is not True + + def test_env_vars(self): + runner = _make_runner() + source = _make_source() + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + inputs=SourceInputs(convention_srn="urn:osa:localhost:conv:test", limit=100, offset=50), + ) + env = spec.spec.template.spec.containers[0].env + env_dict = {e.name: e.value for e in env} + assert env_dict["OSA_IN"] == "/osa/in" + assert env_dict["OSA_OUT"] == "/osa/out" + assert env_dict["OSA_FILES"] == "/osa/files" + assert env_dict["OSA_LIMIT"] == "100" + assert env_dict["OSA_OFFSET"] == "50" + + def test_since_env_var(self): + from datetime import datetime, UTC + + runner = _make_runner() + source = _make_source() + since = datetime(2026, 1, 1, tzinfo=UTC) + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + inputs=SourceInputs(convention_srn="urn:osa:localhost:conv:test", since=since), + ) + env = spec.spec.template.spec.containers[0].env + env_dict = {e.name: e.value for e in env} + assert "OSA_SINCE" in env_dict + + def test_source_role_label(self): + runner = _make_runner() + source = _make_source() + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + ) + labels = spec.spec.template.metadata.labels + assert labels["osa.io/role"] == "source" + + def test_human_readable_name(self): + runner = _make_runner() + source = _make_source() + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + convention_srn="urn:osa:localhost:conv:conv1", + ) + name = spec.metadata.name + assert name.startswith("osa-source-") + assert len(name) <= 63 + + def test_convention_srn_in_labels(self): + runner = _make_runner() + source = _make_source() + spec = runner._build_job_spec( + source, + work_dir=Path("/data/sources/localhost_conv1/staging/run1"), + files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), + convention_srn="urn:osa:localhost:conv:conv1", + ) + labels = spec.spec.template.metadata.labels + assert labels["osa.io/convention"] == "urn:osa:localhost:conv:conv1" + + +# --------------------------------------------------------------------------- +# Source lifecycle (T022) +# --------------------------------------------------------------------------- + + +class TestSourceLifecycle: + @pytest.mark.asyncio + async def test_successful_run_with_records(self, tmp_path: Path): + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sSourceRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + # No existing jobs + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + batch_api.create_namespaced_job.return_value = MagicMock() + + # Pod running + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + # Job completes + completed_job = MagicMock() + condition = MagicMock() + condition.type = "Complete" + condition.status = "True" + completed_job.status.conditions = [condition] + completed_job.status.succeeded = 1 + completed_job.status.failed = None + batch_api.read_namespaced_job.return_value = completed_job + + source = _make_source() + work_dir = tmp_path / "sources" / "localhost_conv1" / "staging" / "run1" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + files_dir = work_dir / "files" + files_dir.mkdir(parents=True) + + # Write records output + (output_dir / "records.jsonl").write_text( + '{"id":"r1","metadata":{"title":"Test"}}\n{"id":"r2","metadata":{"title":"Test2"}}\n' + ) + (output_dir / "session.json").write_text('{"cursor":"abc"}') + + inputs = SourceInputs(convention_srn="urn:osa:localhost:conv:test") + result = await runner._run_job( + batch_api, + core_api, + source, + inputs, + work_dir, + files_dir, + ) + + assert len(result.records) == 2 + assert result.session == {"cursor": "abc"} + assert result.files_dir == files_dir + batch_api.delete_namespaced_job.assert_called_once() + + @pytest.mark.asyncio + async def test_timeout_raises_external_service_error(self, tmp_path: Path): + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sSourceRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + batch_api.create_namespaced_job.return_value = MagicMock() + + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + # Job failed with DeadlineExceeded + failed_job = MagicMock() + condition = MagicMock() + condition.type = "Failed" + condition.status = "True" + condition.reason = "DeadlineExceeded" + failed_job.status.conditions = [condition] + failed_job.status.succeeded = None + failed_job.status.failed = 1 + batch_api.read_namespaced_job.return_value = failed_job + + source = _make_source() + work_dir = tmp_path / "sources" / "localhost_conv1" / "staging" / "run1" + work_dir.mkdir(parents=True) + files_dir = work_dir / "files" + files_dir.mkdir(parents=True) + inputs = SourceInputs(convention_srn="urn:osa:localhost:conv:test") + + with pytest.raises(ExternalServiceError, match="[Tt]imed out|[Dd]eadline"): + await runner._run_job( + batch_api, + core_api, + source, + inputs, + work_dir, + files_dir, + ) + + @pytest.mark.asyncio + async def test_oom_raises_external_service_error(self, tmp_path: Path): + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sSourceRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + batch_api.create_namespaced_job.return_value = MagicMock() + + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + + failed_job = MagicMock() + condition = MagicMock() + condition.type = "Failed" + condition.status = "True" + condition.reason = "BackoffLimitExceeded" + failed_job.status.conditions = [condition] + failed_job.status.succeeded = None + failed_job.status.failed = 1 + batch_api.read_namespaced_job.return_value = failed_job + + # OOMKilled pod + oom_pod = MagicMock() + oom_pod.status.phase = "Failed" + terminated = MagicMock() + terminated.reason = "OOMKilled" + terminated.exit_code = 137 + container_status = MagicMock() + container_status.state.terminated = terminated + oom_pod.status.container_statuses = [container_status] + oom_pod_list = MagicMock() + oom_pod_list.items = [oom_pod] + + core_api.list_namespaced_pod.side_effect = [pod_list, oom_pod_list] + + source = _make_source() + work_dir = tmp_path / "sources" / "localhost_conv1" / "staging" / "run1" + work_dir.mkdir(parents=True) + files_dir = work_dir / "files" + files_dir.mkdir(parents=True) + inputs = SourceInputs(convention_srn="urn:osa:localhost:conv:test") + + with pytest.raises(ExternalServiceError, match="[Oo]OM"): + await runner._run_job( + batch_api, + core_api, + source, + inputs, + work_dir, + files_dir, + ) + + +# --------------------------------------------------------------------------- +# Identity threading from SourceInputs +# --------------------------------------------------------------------------- + + +class TestConventionSrnFromInputs: + """Verify run() threads convention_srn from inputs to Job labels.""" + + @pytest.mark.asyncio + async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): + from unittest.mock import patch + + config = _make_config(data_mount_path=str(tmp_path)) + runner = K8sSourceRunner(api_client=MagicMock(), config=config) + + batch_api = AsyncMock() + core_api = AsyncMock() + + # No existing jobs + job_list = MagicMock() + job_list.items = [] + batch_api.list_namespaced_job.return_value = job_list + batch_api.create_namespaced_job.return_value = MagicMock() + + # Pod scheduled + pod = MagicMock() + pod.status.phase = "Running" + pod.status.container_statuses = None + pod_list = MagicMock() + pod_list.items = [pod] + core_api.list_namespaced_pod.return_value = pod_list + + # Job completes + completed_job = MagicMock() + completed_job.status.succeeded = 1 + completed_job.status.conditions = [] + completed_job.status.failed = None + batch_api.read_namespaced_job.return_value = completed_job + + source = _make_source() + work_dir = tmp_path / "sources" / "run1" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + files_dir = work_dir / "files" + files_dir.mkdir(parents=True) + + inputs = SourceInputs(convention_srn="urn:osa:localhost:conv:my-conv") + + with ( + patch("kubernetes_asyncio.client.BatchV1Api", return_value=batch_api), + patch("kubernetes_asyncio.client.CoreV1Api", return_value=core_api), + ): + await runner.run(source, inputs, files_dir, work_dir) + + # Verify convention_srn from inputs ends up in the Job labels + call_args = batch_api.create_namespaced_job.call_args + spec = call_args[0][1] + labels = spec.metadata.labels + assert labels["osa.io/convention"] == "urn:osa:localhost:conv:my-conv" diff --git a/server/tests/unit/infrastructure/k8s/test_naming.py b/server/tests/unit/infrastructure/k8s/test_naming.py new file mode 100644 index 0000000..9c5b462 --- /dev/null +++ b/server/tests/unit/infrastructure/k8s/test_naming.py @@ -0,0 +1,52 @@ +"""Tests for SRN-to-Job-name sanitization.""" + +import re + + +from osa.infrastructure.k8s.naming import job_name + + +class TestJobName: + def test_basic_format(self): + name = job_name("hook", "validate-dna", "urn:osa:localhost:dep:abc123") + assert name.startswith("osa-hook-") + assert "validate-dna" in name + assert len(name) <= 63 + + def test_dns_1035_compliant(self): + """Output matches DNS-1035 label: lowercase alpha, digits, hyphens.""" + name = job_name("hook", "my_hook", "urn:osa:localhost:dep:test") + assert re.match(r"^[a-z][a-z0-9-]*[a-z0-9]$", name), f"Invalid DNS-1035: {name}" + + def test_colons_replaced(self): + name = job_name("hook", "validate", "urn:osa:archive.org:dep:abc123") + assert ":" not in name + + def test_long_names_truncated_to_63(self): + long_hook = "a" * 100 + long_srn = "urn:osa:very-long-domain.example.com:dep:" + "b" * 100 + name = job_name("hook", long_hook, long_srn) + assert len(name) <= 63 + + def test_random_suffix_for_uniqueness(self): + name1 = job_name("hook", "validate", "urn:osa:localhost:dep:abc") + name2 = job_name("hook", "validate", "urn:osa:localhost:dep:abc") + # Names should differ due to random suffix + assert name1 != name2 + + def test_source_prefix(self): + name = job_name("source", "geo-entrez", "urn:osa:localhost:dep:abc123") + assert name.startswith("osa-source-") + + def test_unicode_stripped(self): + name = job_name("hook", "validat\u00e9", "urn:osa:localhost:dep:abc") + assert re.match(r"^[a-z][a-z0-9-]*[a-z0-9]$", name) + + def test_no_trailing_hyphen(self): + name = job_name("hook", "test", "urn:osa:localhost:dep:abc") + assert not name.endswith("-") + + def test_no_leading_digit(self): + """DNS-1035 labels must start with a letter.""" + name = job_name("hook", "123test", "urn:osa:localhost:dep:abc") + assert name[0].isalpha() diff --git a/server/tests/unit/infrastructure/persistence/adapter/test_file_storage.py b/server/tests/unit/infrastructure/persistence/adapter/test_file_storage.py index 1dfd67e..eeed2c2 100644 --- a/server/tests/unit/infrastructure/persistence/adapter/test_file_storage.py +++ b/server/tests/unit/infrastructure/persistence/adapter/test_file_storage.py @@ -1,9 +1,9 @@ -"""Unit tests for LocalFileStorageAdapter — path traversal prevention.""" +"""Unit tests for FilesystemStorageAdapter — path traversal prevention.""" import pytest from osa.domain.shared.model.srn import DepositionSRN -from osa.infrastructure.persistence.adapter.storage import LocalFileStorageAdapter +from osa.infrastructure.persistence.adapter.storage import FilesystemStorageAdapter def _make_dep_srn() -> DepositionSRN: @@ -17,7 +17,7 @@ def setup_method(self): import tempfile self._tmpdir = tempfile.mkdtemp() - self.adapter = LocalFileStorageAdapter(base_path=self._tmpdir) + self.adapter = FilesystemStorageAdapter(base_path=self._tmpdir) self.dep_srn = _make_dep_srn() @pytest.mark.asyncio @@ -64,7 +64,7 @@ def setup_method(self): import tempfile self._tmpdir = tempfile.mkdtemp() - self.adapter = LocalFileStorageAdapter(base_path=self._tmpdir) + self.adapter = FilesystemStorageAdapter(base_path=self._tmpdir) self.dep_srn = _make_dep_srn() def test_get_files_dir_returns_files_subdirectory(self): diff --git a/server/tests/unit/infrastructure/test_file_storage_hooks.py b/server/tests/unit/infrastructure/test_file_storage_hooks.py index df41930..62b04e4 100644 --- a/server/tests/unit/infrastructure/test_file_storage_hooks.py +++ b/server/tests/unit/infrastructure/test_file_storage_hooks.py @@ -1,4 +1,4 @@ -"""Unit tests for LocalFileStorageAdapter hook output methods.""" +"""Unit tests for FilesystemStorageAdapter hook output methods.""" import json from pathlib import Path @@ -6,7 +6,7 @@ import pytest from osa.domain.shared.model.srn import DepositionSRN -from osa.infrastructure.persistence.adapter.storage import LocalFileStorageAdapter +from osa.infrastructure.persistence.adapter.storage import FilesystemStorageAdapter def _make_dep_srn() -> DepositionSRN: @@ -15,7 +15,7 @@ def _make_dep_srn() -> DepositionSRN: class TestGetHookOutputDir: def test_returns_hooks_subdirectory(self, tmp_path: Path): - adapter = LocalFileStorageAdapter(base_path=str(tmp_path)) + adapter = FilesystemStorageAdapter(base_path=str(tmp_path)) dep_srn = _make_dep_srn() output_dir = adapter.get_hook_output_dir(dep_srn, "pocket_detect") @@ -25,7 +25,7 @@ def test_returns_hooks_subdirectory(self, tmp_path: Path): assert output_dir.exists() def test_creates_directory(self, tmp_path: Path): - adapter = LocalFileStorageAdapter(base_path=str(tmp_path)) + adapter = FilesystemStorageAdapter(base_path=str(tmp_path)) dep_srn = _make_dep_srn() output_dir = adapter.get_hook_output_dir(dep_srn, "my_hook") @@ -33,7 +33,7 @@ def test_creates_directory(self, tmp_path: Path): assert output_dir.is_dir() def test_idempotent(self, tmp_path: Path): - adapter = LocalFileStorageAdapter(base_path=str(tmp_path)) + adapter = FilesystemStorageAdapter(base_path=str(tmp_path)) dep_srn = _make_dep_srn() dir1 = adapter.get_hook_output_dir(dep_srn, "hook_a") @@ -45,7 +45,7 @@ def test_idempotent(self, tmp_path: Path): class TestReadHookFeatures: @pytest.mark.asyncio async def test_reads_features_list(self, tmp_path: Path): - adapter = LocalFileStorageAdapter(base_path=str(tmp_path)) + adapter = FilesystemStorageAdapter(base_path=str(tmp_path)) dep_srn = _make_dep_srn() # Write features.json in the output/ subdirectory @@ -60,7 +60,7 @@ async def test_reads_features_list(self, tmp_path: Path): @pytest.mark.asyncio async def test_reads_features_dict(self, tmp_path: Path): - adapter = LocalFileStorageAdapter(base_path=str(tmp_path)) + adapter = FilesystemStorageAdapter(base_path=str(tmp_path)) dep_srn = _make_dep_srn() output_dir = tmp_path / "depositions" / "localhost_test-dep" / "hooks" / "detect" / "output" @@ -74,7 +74,7 @@ async def test_reads_features_dict(self, tmp_path: Path): @pytest.mark.asyncio async def test_returns_empty_when_missing(self, tmp_path: Path): - adapter = LocalFileStorageAdapter(base_path=str(tmp_path)) + adapter = FilesystemStorageAdapter(base_path=str(tmp_path)) dep_srn = _make_dep_srn() features = await adapter.read_hook_features(dep_srn, "nonexistent") @@ -85,7 +85,7 @@ async def test_returns_empty_when_missing(self, tmp_path: Path): class TestHookFeaturesExist: @pytest.mark.asyncio async def test_true_when_file_exists(self, tmp_path: Path): - adapter = LocalFileStorageAdapter(base_path=str(tmp_path)) + adapter = FilesystemStorageAdapter(base_path=str(tmp_path)) dep_srn = _make_dep_srn() output_dir = tmp_path / "depositions" / "localhost_test-dep" / "hooks" / "detect" / "output" @@ -96,7 +96,7 @@ async def test_true_when_file_exists(self, tmp_path: Path): @pytest.mark.asyncio async def test_false_when_missing(self, tmp_path: Path): - adapter = LocalFileStorageAdapter(base_path=str(tmp_path)) + adapter = FilesystemStorageAdapter(base_path=str(tmp_path)) dep_srn = _make_dep_srn() assert await adapter.hook_features_exist(dep_srn, "nonexistent") is False @@ -105,7 +105,7 @@ async def test_false_when_missing(self, tmp_path: Path): class TestDeleteCleansHookOutputs: @pytest.mark.asyncio async def test_rmtree_removes_hooks_dir(self, tmp_path: Path): - adapter = LocalFileStorageAdapter(base_path=str(tmp_path)) + adapter = FilesystemStorageAdapter(base_path=str(tmp_path)) dep_srn = _make_dep_srn() # Create hook output diff --git a/server/tests/unit/infrastructure/test_file_storage_move.py b/server/tests/unit/infrastructure/test_file_storage_move.py new file mode 100644 index 0000000..b9735f1 --- /dev/null +++ b/server/tests/unit/infrastructure/test_file_storage_move.py @@ -0,0 +1,139 @@ +"""Tests for FilesystemStorageAdapter move and save fallback behavior. + +Tests that move_source_files_to_deposition and save_file fall back to +copy+delete when rename() raises OSError (e.g., cross-device or S3 CSI mount). +""" + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from osa.domain.shared.error import InfrastructureError +from osa.domain.shared.model.srn import DepositionSRN +from osa.infrastructure.persistence.adapter.storage import FilesystemStorageAdapter + + +def _make_dep_srn() -> DepositionSRN: + return DepositionSRN.parse("urn:osa:localhost:dep:test123") + + +class TestMoveSourceFilesFallback: + """move_source_files_to_deposition falls back to copy+delete on OSError.""" + + def test_rename_works_on_local_filesystem(self, tmp_path: Path): + """rename() still works on local filesystem (no fallback needed).""" + adapter = FilesystemStorageAdapter(str(tmp_path)) + dep_srn = _make_dep_srn() + + staging_dir = tmp_path / "staging" + source_id = "src1" + source_files = staging_dir / source_id + source_files.mkdir(parents=True) + (source_files / "data.csv").write_text("a,b,c") + + adapter.move_source_files_to_deposition(staging_dir, source_id, dep_srn) + + files_dir = adapter.get_files_dir(dep_srn) + assert (files_dir / "data.csv").read_text() == "a,b,c" + assert not source_files.exists() + + def test_fallback_copy_delete_on_oserror(self, tmp_path: Path): + """Falls back to shutil.copy2 + unlink when rename() raises OSError.""" + adapter = FilesystemStorageAdapter(str(tmp_path)) + dep_srn = _make_dep_srn() + + staging_dir = tmp_path / "staging" + source_id = "src1" + source_files = staging_dir / source_id + source_files.mkdir(parents=True) + (source_files / "data.csv").write_text("a,b,c") + + def failing_rename(self_path, target): + raise OSError("Cross-device link") + + with patch.object(Path, "rename", failing_rename): + adapter.move_source_files_to_deposition(staging_dir, source_id, dep_srn) + + files_dir = adapter.get_files_dir(dep_srn) + assert (files_dir / "data.csv").read_text() == "a,b,c" + assert not (source_files / "data.csv").exists() + + def test_fallback_is_idempotent_on_retry(self, tmp_path: Path): + """Retrying copy+delete after a crash works (file already at target).""" + adapter = FilesystemStorageAdapter(str(tmp_path)) + dep_srn = _make_dep_srn() + + staging_dir = tmp_path / "staging" + source_id = "src1" + source_files = staging_dir / source_id + source_files.mkdir(parents=True) + (source_files / "data.csv").write_text("a,b,c") + + # First move (simulating crash after copy but before delete) + files_dir = adapter.get_files_dir(dep_srn) + (files_dir / "data.csv").write_text("a,b,c") # Pre-existing copy + + def failing_rename(self_path, target): + raise OSError("Cross-device link") + + with patch.object(Path, "rename", failing_rename): + adapter.move_source_files_to_deposition(staging_dir, source_id, dep_srn) + + assert (files_dir / "data.csv").read_text() == "a,b,c" + assert not (source_files / "data.csv").exists() + + def test_copy_failure_raises_infrastructure_error(self, tmp_path: Path): + """Copy failure wraps OSError in InfrastructureError with file context.""" + adapter = FilesystemStorageAdapter(str(tmp_path)) + dep_srn = _make_dep_srn() + + staging_dir = tmp_path / "staging" + source_id = "src1" + source_files = staging_dir / source_id + source_files.mkdir(parents=True) + (source_files / "data.csv").write_text("a,b,c") + + def failing_rename(self_path, target): + raise OSError("Cross-device link") + + with ( + patch.object(Path, "rename", failing_rename), + patch("shutil.copy2", side_effect=OSError("No space left on device")), + pytest.raises(InfrastructureError, match="data.csv"), + ): + adapter.move_source_files_to_deposition(staging_dir, source_id, dep_srn) + + +class TestSaveFileFallback: + """save_file atomic write falls back to copy+delete on OSError.""" + + @pytest.mark.asyncio + async def test_save_file_rename_works(self, tmp_path: Path): + """save_file uses rename for atomic write on local filesystem.""" + adapter = FilesystemStorageAdapter(str(tmp_path)) + dep_srn = _make_dep_srn() + content = b"hello world" + + result = await adapter.save_file(dep_srn, "test.txt", content, len(content)) + + files_dir = adapter.get_files_dir(dep_srn) + assert (files_dir / "test.txt").read_bytes() == content + assert result.name == "test.txt" + + @pytest.mark.asyncio + async def test_save_file_fallback_on_oserror(self, tmp_path: Path): + """save_file falls back to copy+delete when rename() raises OSError.""" + adapter = FilesystemStorageAdapter(str(tmp_path)) + dep_srn = _make_dep_srn() + content = b"hello world" + + def failing_rename(self_path, target): + raise OSError("Cross-device link") + + with patch.object(Path, "rename", failing_rename): + result = await adapter.save_file(dep_srn, "test.txt", content, len(content)) + + files_dir = adapter.get_files_dir(dep_srn) + assert (files_dir / "test.txt").read_bytes() == content + assert result.name == "test.txt" diff --git a/server/tests/unit/infrastructure/test_oci_hook_runner.py b/server/tests/unit/infrastructure/test_oci_hook_runner.py index 568fe79..e901c9b 100644 --- a/server/tests/unit/infrastructure/test_oci_hook_runner.py +++ b/server/tests/unit/infrastructure/test_oci_hook_runner.py @@ -15,6 +15,11 @@ from osa.domain.validation.model.hook_result import HookStatus, ProgressEntry from osa.domain.validation.port.hook_runner import HookInputs from osa.infrastructure.oci.runner import OciHookRunner +from osa.infrastructure.runner_utils import ( + detect_rejection, + parse_memory, + parse_progress_file, +) def _make_hook( @@ -47,43 +52,34 @@ def _make_runner(docker: AsyncMock | None = None) -> OciHookRunner: class TestParseMemory: def test_gigabytes(self): - runner = _make_runner() - assert runner._parse_memory("2g") == 2 * 1024 * 1024 * 1024 + assert parse_memory("2g") == 2 * 1024 * 1024 * 1024 def test_megabytes(self): - runner = _make_runner() - assert runner._parse_memory("512m") == 512 * 1024 * 1024 + assert parse_memory("512m") == 512 * 1024 * 1024 def test_kilobytes(self): - runner = _make_runner() - assert runner._parse_memory("1024k") == 1024 * 1024 + assert parse_memory("1024k") == 1024 * 1024 def test_bare_bytes(self): - runner = _make_runner() - assert runner._parse_memory("1048576") == 1048576 + assert parse_memory("1048576") == 1048576 def test_fractional(self): - runner = _make_runner() - assert runner._parse_memory("1.5g") == int(1.5 * 1024 * 1024 * 1024) + assert parse_memory("1.5g") == int(1.5 * 1024 * 1024 * 1024) def test_case_insensitive(self): - runner = _make_runner() - assert runner._parse_memory("2G") == 2 * 1024 * 1024 * 1024 + assert parse_memory("2G") == 2 * 1024 * 1024 * 1024 def test_with_i_suffix(self): - runner = _make_runner() - assert runner._parse_memory("2gi") == 2 * 1024 * 1024 * 1024 + assert parse_memory("2gi") == 2 * 1024 * 1024 * 1024 def test_invalid_format(self): - runner = _make_runner() with pytest.raises(ValueError, match="Invalid memory format"): - runner._parse_memory("abc") + parse_memory("abc") class TestParseProgress: def test_empty_when_no_file(self, tmp_path: Path): - runner = _make_runner() - entries = runner._parse_progress(tmp_path) + entries = parse_progress_file(tmp_path) assert entries == [] def test_parses_valid_jsonl(self, tmp_path: Path): @@ -92,8 +88,7 @@ def test_parses_valid_jsonl(self, tmp_path: Path): '{"step":"Loading","status":"completed","message":"Done"}\n' '{"step":"Analyzing","status":"completed","message":"Finished"}\n' ) - runner = _make_runner() - entries = runner._parse_progress(tmp_path) + entries = parse_progress_file(tmp_path) assert len(entries) == 2 assert entries[0].step == "Loading" assert entries[0].status == "completed" @@ -106,8 +101,7 @@ def test_skips_invalid_json_lines(self, tmp_path: Path): "not valid json\n" '{"step":"AlsoGood","status":"completed"}\n' ) - runner = _make_runner() - entries = runner._parse_progress(tmp_path) + entries = parse_progress_file(tmp_path) assert len(entries) == 2 def test_skips_blank_lines(self, tmp_path: Path): @@ -115,15 +109,13 @@ def test_skips_blank_lines(self, tmp_path: Path): progress_file.write_text( '{"step":"A","status":"completed"}\n\n{"step":"B","status":"completed"}\n' ) - runner = _make_runner() - entries = runner._parse_progress(tmp_path) + entries = parse_progress_file(tmp_path) assert len(entries) == 2 def test_handles_missing_optional_fields(self, tmp_path: Path): progress_file = tmp_path / "progress.jsonl" progress_file.write_text('{"status":"completed"}\n') - runner = _make_runner() - entries = runner._parse_progress(tmp_path) + entries = parse_progress_file(tmp_path) assert len(entries) == 1 assert entries[0].step is None assert entries[0].message is None @@ -131,34 +123,35 @@ def test_handles_missing_optional_fields(self, tmp_path: Path): class TestCheckRejection: def test_no_rejection(self): - runner = _make_runner() entries = [ ProgressEntry(step="Load", status="completed", message="OK"), ProgressEntry(step="Process", status="completed", message="Done"), ] - assert runner._check_rejection(entries) is None + rejected, reason = detect_rejection(entries) + assert not rejected + assert reason is None def test_detects_rejection(self): - runner = _make_runner() entries = [ ProgressEntry(step="Load", status="completed", message="OK"), ProgressEntry(step="Validate", status="rejected", message="Missing atoms"), ] - result = runner._check_rejection(entries) - assert result == "Missing atoms" + rejected, reason = detect_rejection(entries) + assert rejected + assert reason == "Missing atoms" def test_empty_progress(self): - runner = _make_runner() - assert runner._check_rejection([]) is None + rejected, reason = detect_rejection([]) + assert not rejected def test_returns_last_rejection(self): """When multiple rejections exist, returns the most recent.""" - runner = _make_runner() entries = [ ProgressEntry(step="A", status="rejected", message="First rejection"), ProgressEntry(step="B", status="rejected", message="Second rejection"), ] - assert runner._check_rejection(entries) == "Second rejection" + rejected, reason = detect_rejection(entries) + assert reason == "Second rejection" class TestContainerLifecycle: @@ -172,7 +165,9 @@ async def test_successful_hook_returns_passed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() - inputs = HookInputs(record_json={"srn": "test"}) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + ) output_dir = tmp_path / "output" output_dir.mkdir() @@ -195,7 +190,9 @@ async def test_nonzero_exit_returns_failed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() - inputs = HookInputs(record_json={"srn": "test"}) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + ) output_dir = tmp_path / "output" output_dir.mkdir() @@ -215,7 +212,9 @@ async def test_oom_killed_returns_failed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() - inputs = HookInputs(record_json={"srn": "test"}) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + ) output_dir = tmp_path / "output" output_dir.mkdir() @@ -242,7 +241,9 @@ async def hang(): runner = OciHookRunner(docker=docker) hook = _make_hook(timeout=1) # 1 second timeout - inputs = HookInputs(record_json={"srn": "test"}) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + ) output_dir = tmp_path / "output" output_dir.mkdir() @@ -262,7 +263,9 @@ async def test_rejection_via_progress(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() - inputs = HookInputs(record_json={"srn": "test"}) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + ) work_dir = tmp_path / "hook_work" work_dir.mkdir() @@ -292,7 +295,9 @@ async def test_security_hardening(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook(memory="4g", cpu="4.0") - inputs = HookInputs(record_json={"srn": "test"}) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + ) output_dir = tmp_path / "output" output_dir.mkdir() @@ -323,7 +328,9 @@ async def test_env_vars_set(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() - inputs = HookInputs(record_json={"srn": "test"}) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + ) output_dir = tmp_path / "output" output_dir.mkdir() @@ -350,7 +357,11 @@ async def test_nested_bind_mounts(self, tmp_path: Path): hook = _make_hook() files_dir = tmp_path / "files" files_dir.mkdir() - inputs = HookInputs(record_json={"srn": "test"}, files_dir=files_dir) + inputs = HookInputs( + record_json={"srn": "test"}, + deposition_srn="urn:osa:localhost:dep:test123", + files_dir=files_dir, + ) work_dir = tmp_path / "hook_work" work_dir.mkdir() @@ -382,7 +393,11 @@ async def test_no_files_bind_when_no_files_dir(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() - inputs = HookInputs(record_json={"srn": "test"}, files_dir=None) + inputs = HookInputs( + record_json={"srn": "test"}, + deposition_srn="urn:osa:localhost:dep:test123", + files_dir=None, + ) output_dir = tmp_path / "output" output_dir.mkdir() @@ -407,7 +422,9 @@ async def test_container_deleted_on_failure(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() - inputs = HookInputs(record_json={"srn": "test"}) + inputs = HookInputs( + record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + ) output_dir = tmp_path / "output" output_dir.mkdir() diff --git a/server/uv.lock b/server/uv.lock index 937793e..906bb47 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -464,11 +464,11 @@ wheels = [ [[package]] name = "dishka" -version = "1.7.2" +version = "1.9.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/40/d7/1be31f5ef32387059190353f9fa493ff4d07a1c75fa856c7566ca45e0800/dishka-1.7.2.tar.gz", hash = "sha256:47d4cb5162b28c61bf5541860e605ed5eaf5c667122299c7ef657c86fc8d5a49", size = 68132, upload-time = "2025-09-24T21:23:05.135Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/97/18d4a9bd44f6baa975cd8d54ed3a1a86b341a43c9c077e647d351c9d4573/dishka-1.9.1.tar.gz", hash = "sha256:973f19dc65160a97370181106764ae076052af4489e94b0cedb3eb4e47fe13bf", size = 274962, upload-time = "2026-03-08T09:43:47.298Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/b9/89381173b4f336e986d72471198614806cd313e0f85c143ccb677c310223/dishka-1.7.2-py3-none-any.whl", hash = "sha256:f6faa6ab321903926b825b3337d77172ee693450279b314434864978d01fbad3", size = 94774, upload-time = "2025-09-24T21:23:03.246Z" }, + { url = "https://files.pythonhosted.org/packages/33/98/c8f80be83fbd92f5f9d4bdb5d619a9c9901fb1523c0b02a448b942e532e6/dishka-1.9.1-py3-none-any.whl", hash = "sha256:5080a46bf40bd403aee396aac81f999f679078655f9a6f2062111d62e94e7b18", size = 114327, upload-time = "2026-03-08T09:43:46.097Z" }, ] [[package]] @@ -926,6 +926,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/70/05b685ea2dffcb2adbf3cdcea5d8865b7bc66f67249084cf845012a0ff13/kubernetes-35.0.0-py2.py3-none-any.whl", hash = "sha256:39e2b33b46e5834ef6c3985ebfe2047ab39135d41de51ce7641a7ca5b372a13d", size = 2017602, upload-time = "2026-01-16T01:05:25.991Z" }, ] +[[package]] +name = "kubernetes-asyncio" +version = "35.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "certifi", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "python-dateutil", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyyaml", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "six", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "urllib3", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/67/b9/f3b9fb2d3ef4550918b83c328dc720a58f65bc66732d9438e06469573ad1/kubernetes_asyncio-35.0.1.tar.gz", hash = "sha256:975870e3097b647c265a59b9175ab0841f0de06cd2162268273ca210b1fa672e", size = 1320250, upload-time = "2026-02-25T20:40:42.87Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/b3/a8917d253763095fb8dcaaefc6a135ed31abbd13f681e78752e226e252fe/kubernetes_asyncio-35.0.1-py3-none-any.whl", hash = "sha256:244ef45943e89c5c5104276a646bfcbf1a9dc3d060876c2094aa601e932f1c03", size = 2868606, upload-time = "2026-02-25T20:40:41.191Z" }, +] + [[package]] name = "logfire" version = "4.21.0" @@ -1583,6 +1600,11 @@ dependencies = [ { name = "uvicorn", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] +[package.optional-dependencies] +k8s = [ + { name = "kubernetes-asyncio", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + [package.dev-dependencies] dev = [ { name = "coverage", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1606,6 +1628,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.121.1" }, { name = "greenlet", specifier = ">=3.2.4" }, { name = "httpx", specifier = ">=0.28.1" }, + { name = "kubernetes-asyncio", marker = "extra == 'k8s'", specifier = ">=31.0" }, { name = "logfire", extras = ["fastapi", "httpx"], specifier = ">=4.15.1" }, { name = "openpyxl", specifier = ">=3.1.5" }, { name = "psycopg2-binary", specifier = ">=2.9.11" }, @@ -1617,6 +1640,7 @@ requires-dist = [ { name = "sqlalchemy", specifier = ">=2.0.44" }, { name = "uvicorn", specifier = ">=0.38.0" }, ] +provides-extras = ["k8s"] [package.metadata.requires-dev] dev = [ From 2e05628749e410737fbbcdb6ac12c6bef82c3207 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 21 Mar 2026 22:18:17 +0000 Subject: [PATCH 2/6] ci: add k8s extra dependencies to all workflow jobs Install k8s extra dependencies across all CI jobs to ensure Kubernetes-related functionality is available during testing, type checking, and deployment processes. --- .github/workflows/ci.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 58ca2bf..72f8391 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,7 @@ jobs: run: uv python install ${{ env.PYTHON_VERSION }} - name: Install dependencies - run: uv sync --frozen + run: uv sync --frozen --extra k8s - name: Check formatting run: uv run ruff format --check . @@ -79,7 +79,7 @@ jobs: run: uv python install ${{ env.PYTHON_VERSION }} - name: Install dependencies - run: uv sync --frozen + run: uv sync --frozen --extra k8s - name: Run type checker run: uv run ty check osa @@ -107,7 +107,7 @@ jobs: run: uv python install ${{ env.PYTHON_VERSION }} - name: Install dependencies - run: uv sync --frozen + run: uv sync --frozen --extra k8s - name: Run unit tests with coverage run: uv run pytest tests/unit -v --tb=short --cov=osa --cov-report=xml --cov-report=term-missing @@ -148,7 +148,7 @@ jobs: run: uv python install ${{ env.PYTHON_VERSION }} - name: Install dependencies - run: uv sync --frozen + run: uv sync --frozen --extra k8s - name: Run contract tests run: uv run pytest tests/contract -v --tb=short @@ -187,7 +187,7 @@ jobs: run: uv python install ${{ env.PYTHON_VERSION }} - name: Install dependencies - run: uv sync --frozen + run: uv sync --frozen --extra k8s - name: Run migrations run: uv run alembic upgrade head From b789de6872bd6ed034b0b7ab8c2190a6d9d88b17 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 21 Mar 2026 23:11:33 +0000 Subject: [PATCH 3/6] refactor: replace string SRN parameters with typed SRN objects Replace string-based SRN handling with proper SRN type objects in source and validation runners to improve type safety and enable better K8s label generation. Add K8s label utilities for converting SRNs to DNS-compliant label values within 63-character limit. --- .../osa/domain/source/port/source_runner.py | 3 +- server/osa/domain/source/service/source.py | 2 +- .../osa/domain/validation/port/hook_runner.py | 3 +- .../domain/validation/service/validation.py | 2 +- server/osa/infrastructure/k8s/naming.py | 33 +++++- server/osa/infrastructure/k8s/runner.py | 15 +-- .../osa/infrastructure/k8s/source_runner.py | 38 ++++-- .../domain/validation/test_hook_runner.py | 12 +- .../validation/test_validation_service.py | 2 +- .../k8s/test_k8s_hook_runner.py | 112 ++++++++++-------- .../k8s/test_k8s_source_runner.py | 25 ++-- .../unit/infrastructure/k8s/test_naming.py | 49 +++++++- .../infrastructure/test_oci_hook_runner.py | 29 +++-- 13 files changed, 224 insertions(+), 101 deletions(-) diff --git a/server/osa/domain/source/port/source_runner.py b/server/osa/domain/source/port/source_runner.py index 2e97858..44ad49f 100644 --- a/server/osa/domain/source/port/source_runner.py +++ b/server/osa/domain/source/port/source_runner.py @@ -8,13 +8,14 @@ from typing import Any, Protocol from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.srn import ConventionSRN @dataclass(frozen=True) class SourceInputs: """Inputs for a source container run.""" - convention_srn: str + convention_srn: ConventionSRN config: dict[str, Any] | None = None since: datetime | None = None limit: int | None = None diff --git a/server/osa/domain/source/service/source.py b/server/osa/domain/source/service/source.py index ef73660..03d2040 100644 --- a/server/osa/domain/source/service/source.py +++ b/server/osa/domain/source/service/source.py @@ -71,7 +71,7 @@ async def run_source( # Build inputs inputs = SourceInputs( - convention_srn=str(convention_srn), + convention_srn=convention_srn, config=source.config, since=since, limit=limit, diff --git a/server/osa/domain/validation/port/hook_runner.py b/server/osa/domain/validation/port/hook_runner.py index 8665cd5..23c0583 100644 --- a/server/osa/domain/validation/port/hook_runner.py +++ b/server/osa/domain/validation/port/hook_runner.py @@ -6,6 +6,7 @@ from typing import Protocol, runtime_checkable from osa.domain.shared.model.hook import HookDefinition +from osa.domain.shared.model.srn import DepositionSRN from osa.domain.shared.port import Port from osa.domain.validation.model.hook_result import HookResult @@ -15,7 +16,7 @@ class HookInputs: """Inputs to pass to a hook container.""" record_json: dict - deposition_srn: str + deposition_srn: DepositionSRN files_dir: Path | None = None config: dict | None = None diff --git a/server/osa/domain/validation/service/validation.py b/server/osa/domain/validation/service/validation.py index c63220d..07efb88 100644 --- a/server/osa/domain/validation/service/validation.py +++ b/server/osa/domain/validation/service/validation.py @@ -107,7 +107,7 @@ async def validate_deposition( record_json = {"srn": str(deposition_srn), "metadata": metadata} inputs = HookInputs( record_json=record_json, - deposition_srn=str(deposition_srn), + deposition_srn=deposition_srn, files_dir=Path(files_dir) if files_dir else None, ) diff --git a/server/osa/infrastructure/k8s/naming.py b/server/osa/infrastructure/k8s/naming.py index df5ad43..149da9c 100644 --- a/server/osa/infrastructure/k8s/naming.py +++ b/server/osa/infrastructure/k8s/naming.py @@ -1,8 +1,39 @@ -"""SRN-to-Job-name sanitization for DNS-1035 compliance.""" +"""K8s naming utilities: Job names (DNS-1035) and label values.""" import re import secrets +from osa.domain.shared.model.srn import SRN + + +def sanitize_label(raw: str) -> str: + """Sanitize a raw string for use as a K8s label value. + + K8s label values must match [a-zA-Z0-9._-], max 63 chars. + Replaces invalid characters with dots and collapses runs. + """ + sanitized = re.sub(r"[^a-zA-Z0-9._-]", ".", raw) + sanitized = re.sub(r"[._-]{2,}", ".", sanitized) + return sanitized[:63].strip("-._") + + +def label_value(srn: SRN) -> str: + """Convert an SRN to a K8s-safe label value. + + Strips the constant ``urn:osa:`` prefix to save space within the + 63-char K8s label limit, then sanitizes for label compliance. + + Examples: + label_value(DepositionSRN.parse("urn:osa:localhost:dep:abc123")) + → "localhost.dep.abc123" + """ + # Format: urn:osa:{domain}:{type}:{id}[@version] + # Strip "urn:osa:" prefix — it's constant and wastes label budget + compact = f"{srn.domain.root}.{srn.type.value}.{srn.id.root}" + if srn.version is not None: + compact += f".{srn.version}" + return sanitize_label(compact) + def job_name(prefix: str, hook_name: str, deposition_srn: str) -> str: """Generate a K8s Job name from prefix, hook name, and deposition SRN. diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py index d9e4e84..b8d446c 100644 --- a/server/osa/infrastructure/k8s/runner.py +++ b/server/osa/infrastructure/k8s/runner.py @@ -12,10 +12,11 @@ from osa.config import K8sConfig from osa.domain.shared.error import InfrastructureError from osa.domain.shared.model.hook import HookDefinition +from osa.domain.shared.model.srn import DepositionSRN from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner from osa.infrastructure.k8s.errors import classify_api_error -from osa.infrastructure.k8s.naming import job_name +from osa.infrastructure.k8s.naming import job_name, label_value from osa.infrastructure.runner_utils import detect_rejection, parse_progress_file if TYPE_CHECKING: @@ -86,7 +87,7 @@ async def _run_job( inputs: HookInputs, work_dir: Path, *, - deposition_srn: str = "", + deposition_srn: DepositionSRN, ) -> HookResult: """Core Job lifecycle: check orphans → create → schedule → execute → parse → cleanup.""" namespace = self._config.namespace @@ -183,7 +184,7 @@ async def _check_existing_job( batch_api: BatchV1Api, namespace: str, hook_name: str, - deposition_srn: str, + deposition_srn: DepositionSRN, ) -> str | None: """Check for existing Jobs with matching labels. @@ -192,7 +193,7 @@ async def _check_existing_job( "active:{job_name}" if a running Job exists None if no Job or only failed Jobs exist """ - label_selector = f"osa.io/hook={hook_name},osa.io/deposition={deposition_srn}" + label_selector = f"osa.io/hook={hook_name},osa.io/deposition={label_value(deposition_srn)}" try: job_list = await batch_api.list_namespaced_job(namespace, label_selector=label_selector) except Exception as exc: @@ -211,7 +212,7 @@ def _build_job_spec( hook: HookDefinition, work_dir: Path, *, - deposition_srn: str = "", + deposition_srn: DepositionSRN, files_dir: Path | None = None, ) -> V1Job: """Build a K8s Job manifest for a hook execution.""" @@ -235,7 +236,7 @@ def _build_job_spec( V1VolumeMount, ) - name = job_name("hook", hook.name, deposition_srn) + name = job_name("hook", hook.name, str(deposition_srn)) relative_work = self._relative_path(work_dir) input_subpath = f"{relative_work}/input" output_subpath = f"{relative_work}/output" @@ -243,7 +244,7 @@ def _build_job_spec( labels = { "osa.io/role": "hook", "osa.io/hook": hook.name, - "osa.io/deposition": deposition_srn, + "osa.io/deposition": label_value(deposition_srn), } mounts = [ diff --git a/server/osa/infrastructure/k8s/source_runner.py b/server/osa/infrastructure/k8s/source_runner.py index bd40240..19b4aa2 100644 --- a/server/osa/infrastructure/k8s/source_runner.py +++ b/server/osa/infrastructure/k8s/source_runner.py @@ -12,9 +12,10 @@ from osa.config import K8sConfig from osa.domain.shared.error import ExternalServiceError, InfrastructureError from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.srn import ConventionSRN from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner from osa.infrastructure.k8s.errors import classify_api_error -from osa.infrastructure.k8s.naming import job_name +from osa.infrastructure.k8s.naming import job_name, label_value, sanitize_label from osa.infrastructure.runner_utils import parse_records_file, parse_session_file if TYPE_CHECKING: @@ -93,7 +94,7 @@ async def _run_job( work_dir: Path, files_dir: Path, *, - convention_srn: str = "", + convention_srn: ConventionSRN | None = None, ) -> SourceOutput: """Core Job lifecycle for source execution.""" namespace = self._config.namespace @@ -101,7 +102,9 @@ async def _run_job( try: # Check for existing Jobs - existing = await self._check_existing_job(batch_api, namespace, convention_srn) + existing = await self._check_existing_job( + batch_api, namespace, convention_srn, source.digest + ) if existing == "succeeded": return self._parse_source_output(work_dir, files_dir) @@ -168,11 +171,17 @@ def _parse_source_output(self, work_dir: Path, files_dir: Path) -> SourceOutput: return SourceOutput(records=records, session=session, files_dir=files_dir) async def _check_existing_job( - self, batch_api: BatchV1Api, namespace: str, convention_srn: str + self, + batch_api: BatchV1Api, + namespace: str, + convention_srn: ConventionSRN | None, + digest: str = "", ) -> str | None: label_parts = ["osa.io/role=source"] - if convention_srn: - label_parts.append(f"osa.io/convention={convention_srn}") + if convention_srn is not None: + label_parts.append(f"osa.io/convention={label_value(convention_srn)}") + if digest: + label_parts.append(f"osa.io/digest={sanitize_label(digest)}") label_selector = ",".join(label_parts) try: @@ -194,7 +203,7 @@ def _build_job_spec( work_dir: Path, files_dir: Path, inputs: SourceInputs | None = None, - convention_srn: str = "", + convention_srn: ConventionSRN | None = None, ) -> V1Job: from kubernetes_asyncio.client import ( V1Capabilities, @@ -214,7 +223,7 @@ def _build_job_spec( V1VolumeMount, ) - name = job_name("source", "src", convention_srn or "unknown") + name = job_name("source", "src", str(convention_srn) if convention_srn else "unknown") relative_work = self._relative_path(work_dir) input_subpath = f"{relative_work}/input" output_subpath = f"{relative_work}/output" @@ -222,9 +231,10 @@ def _build_job_spec( labels: dict[str, str] = { "osa.io/role": "source", + "osa.io/digest": sanitize_label(source.digest), } - if convention_srn: - labels["osa.io/convention"] = convention_srn + if convention_srn is not None: + labels["osa.io/convention"] = label_value(convention_srn) env = [ V1EnvVar(name="OSA_IN", value="/osa/in"), @@ -381,6 +391,14 @@ async def _wait_for_completion( await asyncio.sleep(poll_interval) + # Timed out — poll once more to catch last-millisecond completions + try: + job = await batch_api.read_namespaced_job(job_name, namespace) + if job.status.succeeded: + return "succeeded" + except Exception: + pass + return "failed:WatchTimeout" async def _diagnose_and_raise( diff --git a/server/tests/unit/domain/validation/test_hook_runner.py b/server/tests/unit/domain/validation/test_hook_runner.py index bcf5ecb..16d5f26 100644 --- a/server/tests/unit/domain/validation/test_hook_runner.py +++ b/server/tests/unit/domain/validation/test_hook_runner.py @@ -5,6 +5,7 @@ import pytest from osa.domain.shared.model.hook import HookDefinition +from osa.domain.shared.model.srn import DepositionSRN from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner @@ -13,7 +14,7 @@ class TestHookInputs: def test_minimal_construction(self): inputs = HookInputs( record_json={"srn": "urn:osa:localhost:rec:123"}, - deposition_srn="urn:osa:localhost:dep:test123", + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) assert inputs.record_json == {"srn": "urn:osa:localhost:rec:123"} assert inputs.files_dir is None @@ -23,7 +24,7 @@ def test_with_files_dir(self): files = Path("/tmp/files") inputs = HookInputs( record_json={"srn": "test"}, - deposition_srn="urn:osa:localhost:dep:test123", + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), files_dir=files, ) assert inputs.files_dir == files @@ -31,7 +32,7 @@ def test_with_files_dir(self): def test_with_config(self): inputs = HookInputs( record_json={"srn": "test"}, - deposition_srn="urn:osa:localhost:dep:test123", + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), config={"r_min": 3.0, "threshold": 0.5}, ) assert inputs.config == {"r_min": 3.0, "threshold": 0.5} @@ -40,7 +41,7 @@ def test_full_construction(self): files = Path("/tmp/data/files") inputs = HookInputs( record_json={"srn": "urn:osa:localhost:rec:456", "name": "test"}, - deposition_srn="urn:osa:localhost:dep:test456", + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test456"), files_dir=files, config={"key": "value"}, ) @@ -50,7 +51,8 @@ def test_full_construction(self): def test_is_frozen(self): inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + record_json={"srn": "test"}, + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) with pytest.raises(AttributeError): inputs.record_json = {} # type: ignore[misc] diff --git a/server/tests/unit/domain/validation/test_validation_service.py b/server/tests/unit/domain/validation/test_validation_service.py index c04b355..d8b598e 100644 --- a/server/tests/unit/domain/validation/test_validation_service.py +++ b/server/tests/unit/domain/validation/test_validation_service.py @@ -62,7 +62,7 @@ def _make_service( def _make_inputs() -> HookInputs: return HookInputs( record_json={"srn": "urn:osa:localhost:dep:test123", "metadata": {"name": "test"}}, - deposition_srn="urn:osa:localhost:dep:test123", + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py index 8d7fcd3..1bb5fbb 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py @@ -14,10 +14,13 @@ OciLimits, TableFeatureSpec, ) +from osa.domain.shared.model.srn import DepositionSRN from osa.domain.validation.model.hook_result import HookStatus from osa.domain.validation.port.hook_runner import HookInputs from osa.infrastructure.k8s.runner import K8sHookRunner +_DEP_SRN = DepositionSRN.parse("urn:osa:localhost:dep:abc123") + def _make_hook( name: str = "validate_dna", @@ -69,7 +72,9 @@ def test_correct_image(self): runner = _make_runner() hook = _make_hook(image="ghcr.io/org/hook:v2", digest="sha256:def456") spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) container = spec.spec.template.spec.containers[0] @@ -79,7 +84,9 @@ def test_security_context(self): runner = _make_runner() hook = _make_hook() spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) pod_spec = spec.spec.template.spec @@ -99,7 +106,9 @@ def test_resource_limits(self): runner = _make_runner() hook = _make_hook(memory="4g", cpu="2.0") spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) resources = spec.spec.template.spec.containers[0].resources @@ -110,7 +119,7 @@ def test_volume_mounts(self): runner = _make_runner() hook = _make_hook() work_dir = Path("/data/depositions/localhost_abc/hooks/validate_dna") - spec = runner._build_job_spec(hook, work_dir) + spec = runner._build_job_spec(hook, work_dir, deposition_srn=_DEP_SRN) volumes = spec.spec.template.spec.volumes pvc_vol = next(v for v in volumes if v.name == "data") @@ -129,7 +138,9 @@ def test_env_vars(self): runner = _make_runner() hook = _make_hook(name="pocket_detect") spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/pocket_detect") + hook, + Path("/data/depositions/localhost_abc/hooks/pocket_detect"), + deposition_srn=_DEP_SRN, ) env = spec.spec.template.spec.containers[0].env @@ -142,7 +153,9 @@ def test_backoff_limit_zero(self): runner = _make_runner() hook = _make_hook() spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) assert spec.spec.backoff_limit == 0 @@ -151,7 +164,9 @@ def test_active_deadline_seconds(self): runner = _make_runner() hook = _make_hook(timeout=300) spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) # scheduling_timeout (120) + hook timeout (300) @@ -161,7 +176,9 @@ def test_dns_policy_none(self): runner = _make_runner() hook = _make_hook() spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) pod_spec = spec.spec.template.spec @@ -174,13 +191,13 @@ def test_labels(self): spec = runner._build_job_spec( hook, Path("/data/depositions/localhost_abc/hooks/validate_dna"), - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) labels = spec.spec.template.metadata.labels assert labels["osa.io/role"] == "hook" assert labels["osa.io/hook"] == "validate_dna" - assert labels["osa.io/deposition"] == "urn:osa:localhost:dep:abc123" + assert labels["osa.io/deposition"] == "localhost.dep.abc123" def test_human_readable_job_name(self): runner = _make_runner() @@ -188,7 +205,7 @@ def test_human_readable_job_name(self): spec = runner._build_job_spec( hook, Path("/data/depositions/localhost_abc/hooks/validate_dna"), - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) name = spec.metadata.name @@ -199,7 +216,9 @@ def test_empty_dir_at_tmp(self): runner = _make_runner() hook = _make_hook() spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) volumes = spec.spec.template.spec.volumes @@ -210,7 +229,9 @@ def test_automount_service_account_false(self): runner = _make_runner() hook = _make_hook() spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) pod_spec = spec.spec.template.spec @@ -220,7 +241,9 @@ def test_ttl_seconds_after_finished(self): runner = _make_runner(config=_make_config(job_ttl_seconds=600)) hook = _make_hook() spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) assert spec.spec.ttl_seconds_after_finished == 600 @@ -231,6 +254,7 @@ def test_files_mount_when_files_dir_provided(self): spec = runner._build_job_spec( hook, Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, files_dir=Path("/data/depositions/localhost_abc/files"), ) @@ -243,7 +267,9 @@ def test_image_pull_secrets(self): runner = _make_runner(config=_make_config(image_pull_secrets=["ghcr-secret"])) hook = _make_hook() spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) secrets = spec.spec.template.spec.image_pull_secrets @@ -254,7 +280,9 @@ def test_service_account(self): runner = _make_runner(config=_make_config(service_account="osa-runner")) hook = _make_hook() spec = runner._build_job_spec( - hook, Path("/data/depositions/localhost_abc/hooks/validate_dna") + hook, + Path("/data/depositions/localhost_abc/hooks/validate_dna"), + deposition_srn=_DEP_SRN, ) assert spec.spec.template.spec.service_account_name == "osa-runner" @@ -424,16 +452,14 @@ async def test_successful_run(self, tmp_path: Path): '{"step":"Check","status":"completed","message":"OK"}\n' ) - inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" - ) + inputs = HookInputs(record_json={"srn": "test"}, deposition_srn=_DEP_SRN) result = await runner._run_job( batch_api, core_api, hook, inputs, work_dir, - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) assert result.status == HookStatus.PASSED @@ -476,9 +502,7 @@ async def test_timeout_deadline_exceeded(self, tmp_path: Path): hook = _make_hook() work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" work_dir.mkdir(parents=True) - inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" - ) + inputs = HookInputs(record_json={"srn": "test"}, deposition_srn=_DEP_SRN) result = await runner._run_job( batch_api, @@ -486,7 +510,7 @@ async def test_timeout_deadline_exceeded(self, tmp_path: Path): hook, inputs, work_dir, - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) assert result.status == HookStatus.FAILED @@ -544,9 +568,7 @@ async def test_oom_exit_137(self, tmp_path: Path): hook = _make_hook() work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" work_dir.mkdir(parents=True) - inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" - ) + inputs = HookInputs(record_json={"srn": "test"}, deposition_srn=_DEP_SRN) result = await runner._run_job( batch_api, @@ -554,7 +576,7 @@ async def test_oom_exit_137(self, tmp_path: Path): hook, inputs, work_dir, - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) assert result.status == HookStatus.FAILED @@ -606,9 +628,7 @@ async def test_nonzero_exit(self, tmp_path: Path): hook = _make_hook() work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" work_dir.mkdir(parents=True) - inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" - ) + inputs = HookInputs(record_json={"srn": "test"}, deposition_srn=_DEP_SRN) result = await runner._run_job( batch_api, @@ -616,7 +636,7 @@ async def test_nonzero_exit(self, tmp_path: Path): hook, inputs, work_dir, - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) assert result.status == HookStatus.FAILED @@ -663,9 +683,7 @@ async def test_orphan_running_job_attaches(self, tmp_path: Path): work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" output_dir = work_dir / "output" output_dir.mkdir(parents=True) - inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" - ) + inputs = HookInputs(record_json={"srn": "test"}, deposition_srn=_DEP_SRN) result = await runner._run_job( batch_api, @@ -673,7 +691,7 @@ async def test_orphan_running_job_attaches(self, tmp_path: Path): hook, inputs, work_dir, - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) assert result.status == HookStatus.PASSED @@ -702,9 +720,7 @@ async def test_orphan_completed_job_reads_output(self, tmp_path: Path): work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" output_dir = work_dir / "output" output_dir.mkdir(parents=True) - inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" - ) + inputs = HookInputs(record_json={"srn": "test"}, deposition_srn=_DEP_SRN) result = await runner._run_job( batch_api, @@ -712,7 +728,7 @@ async def test_orphan_completed_job_reads_output(self, tmp_path: Path): hook, inputs, work_dir, - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) assert result.status == HookStatus.PASSED @@ -760,9 +776,7 @@ async def test_orphan_failed_job_creates_new(self, tmp_path: Path): work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" output_dir = work_dir / "output" output_dir.mkdir(parents=True) - inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" - ) + inputs = HookInputs(record_json={"srn": "test"}, deposition_srn=_DEP_SRN) result = await runner._run_job( batch_api, @@ -770,7 +784,7 @@ async def test_orphan_failed_job_creates_new(self, tmp_path: Path): hook, inputs, work_dir, - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) assert result.status == HookStatus.PASSED @@ -830,9 +844,7 @@ async def test_rejection_via_progress(self, tmp_path: Path): (output_dir / "progress.jsonl").write_text( '{"step":"Validate","status":"rejected","message":"Missing atoms"}\n' ) - inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:abc123" - ) + inputs = HookInputs(record_json={"srn": "test"}, deposition_srn=_DEP_SRN) result = await runner._run_job( batch_api, @@ -840,7 +852,7 @@ async def test_rejection_via_progress(self, tmp_path: Path): hook, inputs, work_dir, - deposition_srn="urn:osa:localhost:dep:abc123", + deposition_srn=_DEP_SRN, ) assert result.status == HookStatus.REJECTED @@ -897,7 +909,7 @@ async def test_run_uses_deposition_srn_from_inputs(self, tmp_path: Path): hook = _make_hook() inputs = HookInputs( record_json={"srn": "test"}, - deposition_srn="urn:osa:localhost:dep:my-real-srn", + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:my-real-srn"), ) with ( @@ -910,4 +922,4 @@ async def test_run_uses_deposition_srn_from_inputs(self, tmp_path: Path): call_args = batch_api.create_namespaced_job.call_args spec = call_args[0][1] # positional arg: (namespace, spec) labels = spec.metadata.labels - assert labels["osa.io/deposition"] == "urn:osa:localhost:dep:my-real-srn" + assert labels["osa.io/deposition"] == "localhost.dep.my-real-srn" diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py index 51c7db7..3935945 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py @@ -9,9 +9,12 @@ from osa.config import K8sConfig from osa.domain.shared.error import ExternalServiceError from osa.domain.shared.model.source import SourceDefinition, SourceLimits +from osa.domain.shared.model.srn import ConventionSRN from osa.domain.source.port.source_runner import SourceInputs from osa.infrastructure.k8s.source_runner import K8sSourceRunner +_CONV_SRN = ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0") + def _make_source( image: str = "ghcr.io/example/source:v1", @@ -124,7 +127,7 @@ def test_env_vars(self): source, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), - inputs=SourceInputs(convention_srn="urn:osa:localhost:conv:test", limit=100, offset=50), + inputs=SourceInputs(convention_srn=_CONV_SRN, limit=100, offset=50), ) env = spec.spec.template.spec.containers[0].env env_dict = {e.name: e.value for e in env} @@ -144,7 +147,7 @@ def test_since_env_var(self): source, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), - inputs=SourceInputs(convention_srn="urn:osa:localhost:conv:test", since=since), + inputs=SourceInputs(convention_srn=_CONV_SRN, since=since), ) env = spec.spec.template.spec.containers[0].env env_dict = {e.name: e.value for e in env} @@ -168,7 +171,7 @@ def test_human_readable_name(self): source, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), - convention_srn="urn:osa:localhost:conv:conv1", + convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:conv1@1.0.0"), ) name = spec.metadata.name assert name.startswith("osa-source-") @@ -181,10 +184,10 @@ def test_convention_srn_in_labels(self): source, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), - convention_srn="urn:osa:localhost:conv:conv1", + convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:conv1@1.0.0"), ) labels = spec.spec.template.metadata.labels - assert labels["osa.io/convention"] == "urn:osa:localhost:conv:conv1" + assert labels["osa.io/convention"] == "localhost.conv.conv1.1.0.0" # --------------------------------------------------------------------------- @@ -238,7 +241,7 @@ async def test_successful_run_with_records(self, tmp_path: Path): ) (output_dir / "session.json").write_text('{"cursor":"abc"}') - inputs = SourceInputs(convention_srn="urn:osa:localhost:conv:test") + inputs = SourceInputs(convention_srn=_CONV_SRN) result = await runner._run_job( batch_api, core_api, @@ -289,7 +292,7 @@ async def test_timeout_raises_external_service_error(self, tmp_path: Path): work_dir.mkdir(parents=True) files_dir = work_dir / "files" files_dir.mkdir(parents=True) - inputs = SourceInputs(convention_srn="urn:osa:localhost:conv:test") + inputs = SourceInputs(convention_srn=_CONV_SRN) with pytest.raises(ExternalServiceError, match="[Tt]imed out|[Dd]eadline"): await runner._run_job( @@ -349,7 +352,7 @@ async def test_oom_raises_external_service_error(self, tmp_path: Path): work_dir.mkdir(parents=True) files_dir = work_dir / "files" files_dir.mkdir(parents=True) - inputs = SourceInputs(convention_srn="urn:osa:localhost:conv:test") + inputs = SourceInputs(convention_srn=_CONV_SRN) with pytest.raises(ExternalServiceError, match="[Oo]OM"): await runner._run_job( @@ -408,7 +411,9 @@ async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): files_dir = work_dir / "files" files_dir.mkdir(parents=True) - inputs = SourceInputs(convention_srn="urn:osa:localhost:conv:my-conv") + inputs = SourceInputs( + convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:my-conv@1.0.0") + ) with ( patch("kubernetes_asyncio.client.BatchV1Api", return_value=batch_api), @@ -420,4 +425,4 @@ async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): call_args = batch_api.create_namespaced_job.call_args spec = call_args[0][1] labels = spec.metadata.labels - assert labels["osa.io/convention"] == "urn:osa:localhost:conv:my-conv" + assert labels["osa.io/convention"] == "localhost.conv.my-conv.1.0.0" diff --git a/server/tests/unit/infrastructure/k8s/test_naming.py b/server/tests/unit/infrastructure/k8s/test_naming.py index 9c5b462..cf5d6d5 100644 --- a/server/tests/unit/infrastructure/k8s/test_naming.py +++ b/server/tests/unit/infrastructure/k8s/test_naming.py @@ -1,9 +1,9 @@ -"""Tests for SRN-to-Job-name sanitization.""" +"""Tests for K8s naming utilities: Job names and label values.""" import re - -from osa.infrastructure.k8s.naming import job_name +from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN +from osa.infrastructure.k8s.naming import job_name, label_value, sanitize_label class TestJobName: @@ -50,3 +50,46 @@ def test_no_leading_digit(self): """DNS-1035 labels must start with a letter.""" name = job_name("hook", "123test", "urn:osa:localhost:dep:abc") assert name[0].isalpha() + + +class TestSanitizeLabel: + def test_replaces_colons(self): + assert ":" not in sanitize_label("sha256:abc123def") + + def test_preserves_valid_chars(self): + assert sanitize_label("hello-world_1.0") == "hello-world_1.0" + + def test_truncates_to_63(self): + assert len(sanitize_label("a" * 100)) <= 63 + + def test_strips_edge_chars(self): + result = sanitize_label(".leading-and-trailing.") + assert not result.startswith(".") + assert not result.endswith(".") + + def test_collapses_runs(self): + assert ".." not in sanitize_label("a::b") + + +class TestLabelValue: + def test_deposition_srn(self): + srn = DepositionSRN.parse("urn:osa:localhost:dep:abc123") + result = label_value(srn) + assert result == "localhost.dep.abc123" + assert ":" not in result + + def test_convention_srn_with_version(self): + srn = ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0") + result = label_value(srn) + assert result == "localhost.conv.test.1.0.0" + + def test_no_colons_in_output(self): + srn = DepositionSRN.parse("urn:osa:archive.university.edu:dep:xyz789") + result = label_value(srn) + assert ":" not in result + assert re.match(r"^[a-zA-Z0-9._-]+$", result) + + def test_max_63_chars(self): + long_id = "a" * 60 # LocalId max is 64; with "localhost.dep." prefix this exceeds 63 + srn = DepositionSRN.parse(f"urn:osa:localhost:dep:{long_id}") + assert len(label_value(srn)) <= 63 diff --git a/server/tests/unit/infrastructure/test_oci_hook_runner.py b/server/tests/unit/infrastructure/test_oci_hook_runner.py index e901c9b..27c145b 100644 --- a/server/tests/unit/infrastructure/test_oci_hook_runner.py +++ b/server/tests/unit/infrastructure/test_oci_hook_runner.py @@ -5,6 +5,7 @@ import pytest +from osa.domain.shared.model.srn import DepositionSRN from osa.domain.shared.model.hook import ( ColumnDef, HookDefinition, @@ -166,7 +167,8 @@ async def test_successful_hook_returns_passed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + record_json={"srn": "test"}, + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) output_dir = tmp_path / "output" @@ -191,7 +193,8 @@ async def test_nonzero_exit_returns_failed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + record_json={"srn": "test"}, + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) output_dir = tmp_path / "output" @@ -213,7 +216,8 @@ async def test_oom_killed_returns_failed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + record_json={"srn": "test"}, + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) output_dir = tmp_path / "output" @@ -242,7 +246,8 @@ async def hang(): runner = OciHookRunner(docker=docker) hook = _make_hook(timeout=1) # 1 second timeout inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + record_json={"srn": "test"}, + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) output_dir = tmp_path / "output" @@ -264,7 +269,8 @@ async def test_rejection_via_progress(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + record_json={"srn": "test"}, + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) work_dir = tmp_path / "hook_work" @@ -296,7 +302,8 @@ async def test_security_hardening(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook(memory="4g", cpu="4.0") inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + record_json={"srn": "test"}, + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) output_dir = tmp_path / "output" @@ -329,7 +336,8 @@ async def test_env_vars_set(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + record_json={"srn": "test"}, + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) output_dir = tmp_path / "output" @@ -359,7 +367,7 @@ async def test_nested_bind_mounts(self, tmp_path: Path): files_dir.mkdir() inputs = HookInputs( record_json={"srn": "test"}, - deposition_srn="urn:osa:localhost:dep:test123", + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), files_dir=files_dir, ) @@ -395,7 +403,7 @@ async def test_no_files_bind_when_no_files_dir(self, tmp_path: Path): hook = _make_hook() inputs = HookInputs( record_json={"srn": "test"}, - deposition_srn="urn:osa:localhost:dep:test123", + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), files_dir=None, ) @@ -423,7 +431,8 @@ async def test_container_deleted_on_failure(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, deposition_srn="urn:osa:localhost:dep:test123" + record_json={"srn": "test"}, + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:test123"), ) output_dir = tmp_path / "output" From be379ae1c3b5211b0fd28ef07bc032364c54da7a Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 21 Mar 2026 23:36:38 +0000 Subject: [PATCH 4/6] refactor: remove unused core_api parameter from _wait_for_completion method Remove unused CoreV1Api parameter from _wait_for_completion method in both K8sHookRunner and K8sSourceRunner classes to clean up the method signature and improve code maintainability. --- server/osa/infrastructure/k8s/runner.py | 2 -- server/osa/infrastructure/k8s/source_runner.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py index b8d446c..113b03e 100644 --- a/server/osa/infrastructure/k8s/runner.py +++ b/server/osa/infrastructure/k8s/runner.py @@ -136,7 +136,6 @@ async def _run_job( # Phase 2: Wait for completion result = await self._wait_for_completion( batch_api, - core_api, job_name_to_watch, namespace, timeout_seconds=hook.runtime.limits.timeout_seconds + 30, @@ -385,7 +384,6 @@ async def _wait_for_scheduling( async def _wait_for_completion( self, batch_api: BatchV1Api, - core_api: CoreV1Api, job_name: str, namespace: str, *, diff --git a/server/osa/infrastructure/k8s/source_runner.py b/server/osa/infrastructure/k8s/source_runner.py index 19b4aa2..778ef74 100644 --- a/server/osa/infrastructure/k8s/source_runner.py +++ b/server/osa/infrastructure/k8s/source_runner.py @@ -137,7 +137,6 @@ async def _run_job( # Phase 2: Completion result = await self._wait_for_completion( batch_api, - core_api, job_name_to_watch, namespace, timeout_seconds=source.limits.timeout_seconds + 30, @@ -363,7 +362,6 @@ async def _wait_for_scheduling( async def _wait_for_completion( self, batch_api: BatchV1Api, - core_api: CoreV1Api, job_name: str, namespace: str, *, From 17edba4405911233bb6e2fd5794beeb434b0693c Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sun, 22 Mar 2026 00:35:51 +0000 Subject: [PATCH 5/6] feat: add K8s config validation and improve error handling - Add model validator to ensure data_pvc_name is set when using K8s backend - Extract relative_path utility function to reduce code duplication - Improve file storage error handling with proper exception wrapping - Add comprehensive tests for K8s config validation and file operations --- server/osa/config.py | 10 +++++++ server/osa/infrastructure/k8s/runner.py | 8 ++--- .../osa/infrastructure/k8s/source_runner.py | 8 ++--- .../persistence/adapter/storage.py | 7 +++-- server/osa/infrastructure/runner_utils.py | 12 ++++++++ .../unit/infrastructure/k8s/test_config.py | 30 +++++++++++++++++++ .../infrastructure/test_file_storage_move.py | 20 +++++++++++++ 7 files changed, 81 insertions(+), 14 deletions(-) create mode 100644 server/tests/unit/infrastructure/k8s/test_config.py diff --git a/server/osa/config.py b/server/osa/config.py index 00cba5a..4c10d17 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -100,6 +100,16 @@ class RunnerConfig(BaseModel): backend: Literal["oci", "k8s"] = "oci" k8s: K8sConfig = K8sConfig() + @model_validator(mode="after") + def validate_k8s_required_fields(self) -> Self: + """Validate that required K8s fields are set when backend is 'k8s'.""" + if self.backend == "k8s" and not self.k8s.data_pvc_name: + raise ValueError( + "runner.k8s.data_pvc_name is required when runner.backend == 'k8s'. " + "Set OSA_RUNNER__K8S__DATA_PVC_NAME." + ) + return self + # ============================================================================= # Authentication Configuration diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py index 113b03e..0f6659d 100644 --- a/server/osa/infrastructure/k8s/runner.py +++ b/server/osa/infrastructure/k8s/runner.py @@ -17,7 +17,7 @@ from osa.domain.validation.port.hook_runner import HookInputs, HookRunner from osa.infrastructure.k8s.errors import classify_api_error from osa.infrastructure.k8s.naming import job_name, label_value -from osa.infrastructure.runner_utils import detect_rejection, parse_progress_file +from osa.infrastructure.runner_utils import detect_rejection, parse_progress_file, relative_path if TYPE_CHECKING: from kubernetes_asyncio.client import ApiClient, BatchV1Api, CoreV1Api, V1Job @@ -327,11 +327,7 @@ def _build_job_spec( def _relative_path(self, path: Path) -> str: """Strip the data mount prefix to get a PVC-relative subpath.""" - mount = self._config.data_mount_path.rstrip("/") - path_str = str(path) - if not path_str.startswith(mount): - raise ValueError(f"Path {path} is outside the data mount prefix {mount}") - return path_str[len(mount) :].lstrip("/") + return relative_path(path, self._config.data_mount_path) async def _wait_for_scheduling( self, diff --git a/server/osa/infrastructure/k8s/source_runner.py b/server/osa/infrastructure/k8s/source_runner.py index 778ef74..192631e 100644 --- a/server/osa/infrastructure/k8s/source_runner.py +++ b/server/osa/infrastructure/k8s/source_runner.py @@ -16,7 +16,7 @@ from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner from osa.infrastructure.k8s.errors import classify_api_error from osa.infrastructure.k8s.naming import job_name, label_value, sanitize_label -from osa.infrastructure.runner_utils import parse_records_file, parse_session_file +from osa.infrastructure.runner_utils import parse_records_file, parse_session_file, relative_path if TYPE_CHECKING: from kubernetes_asyncio.client import ApiClient, BatchV1Api, CoreV1Api, V1Job @@ -310,11 +310,7 @@ def _build_job_spec( ) def _relative_path(self, path: Path) -> str: - mount = self._config.data_mount_path.rstrip("/") - path_str = str(path) - if not path_str.startswith(mount): - raise ValueError(f"Path {path} is outside the data mount prefix {mount}") - return path_str[len(mount) :].lstrip("/") + return relative_path(path, self._config.data_mount_path) async def _wait_for_scheduling( self, diff --git a/server/osa/infrastructure/persistence/adapter/storage.py b/server/osa/infrastructure/persistence/adapter/storage.py index 31a0b6d..22890ca 100644 --- a/server/osa/infrastructure/persistence/adapter/storage.py +++ b/server/osa/infrastructure/persistence/adapter/storage.py @@ -92,8 +92,11 @@ async def save_file( try: Path(tmp_path).rename(target) except OSError: - shutil.copy2(tmp_path, target) - Path(tmp_path).unlink(missing_ok=True) + try: + shutil.copy2(tmp_path, target) + Path(tmp_path).unlink(missing_ok=True) + except OSError as e: + raise InfrastructureError(f"Failed to write file {filename}: {e}") from e except Exception: Path(tmp_path).unlink(missing_ok=True) raise diff --git a/server/osa/infrastructure/runner_utils.py b/server/osa/infrastructure/runner_utils.py index 4b108b0..ef6aa69 100644 --- a/server/osa/infrastructure/runner_utils.py +++ b/server/osa/infrastructure/runner_utils.py @@ -66,6 +66,18 @@ def parse_memory(memory: str) -> int: raise ValueError(f"Unknown memory unit: {unit}") +def relative_path(path: Path, data_mount_path: str) -> str: + """Strip the data mount prefix to get a PVC-relative subpath. + + Used by K8s runners to convert absolute paths into PVC sub_path values. + """ + mount = data_mount_path.rstrip("/") + path_str = str(path) + if not path_str.startswith(mount): + raise ValueError(f"Path {path} is outside the data mount prefix {mount}") + return path_str[len(mount) :].lstrip("/") + + def parse_records_file(output_dir: Path) -> list[dict[str, Any]]: """Parse records.jsonl from source output directory.""" import logfire diff --git a/server/tests/unit/infrastructure/k8s/test_config.py b/server/tests/unit/infrastructure/k8s/test_config.py new file mode 100644 index 0000000..136fe03 --- /dev/null +++ b/server/tests/unit/infrastructure/k8s/test_config.py @@ -0,0 +1,30 @@ +"""Tests for RunnerConfig cross-field validation.""" + +import pytest +from pydantic import ValidationError + +from osa.config import K8sConfig, RunnerConfig + + +class TestRunnerConfigValidation: + """RunnerConfig validates required K8s fields when backend == 'k8s'.""" + + def test_oci_backend_allows_empty_pvc(self): + """OCI backend does not require K8s fields.""" + config = RunnerConfig(backend="oci") + assert config.backend == "oci" + + def test_k8s_backend_requires_data_pvc_name(self): + """K8s backend rejects empty data_pvc_name at config parse time.""" + with pytest.raises(ValidationError, match="data_pvc_name"): + RunnerConfig(backend="k8s", k8s=K8sConfig(data_pvc_name="")) + + def test_k8s_backend_accepts_valid_pvc(self): + """K8s backend accepts a non-empty data_pvc_name.""" + config = RunnerConfig(backend="k8s", k8s=K8sConfig(data_pvc_name="osa-data")) + assert config.k8s.data_pvc_name == "osa-data" + + def test_k8s_backend_default_pvc_rejected(self): + """K8s backend with default K8sConfig (empty pvc) is rejected.""" + with pytest.raises(ValidationError, match="data_pvc_name"): + RunnerConfig(backend="k8s") diff --git a/server/tests/unit/infrastructure/test_file_storage_move.py b/server/tests/unit/infrastructure/test_file_storage_move.py index b9735f1..5d67f1f 100644 --- a/server/tests/unit/infrastructure/test_file_storage_move.py +++ b/server/tests/unit/infrastructure/test_file_storage_move.py @@ -137,3 +137,23 @@ def failing_rename(self_path, target): files_dir = adapter.get_files_dir(dep_srn) assert (files_dir / "test.txt").read_bytes() == content assert result.name == "test.txt" + + @pytest.mark.asyncio + async def test_save_file_copy_failure_raises_infrastructure_error(self, tmp_path: Path): + """Copy failure wraps OSError in InfrastructureError with filename context.""" + adapter = FilesystemStorageAdapter(str(tmp_path)) + dep_srn = _make_dep_srn() + content = b"hello world" + + def failing_rename(self_path, target): + raise OSError("Cross-device link") + + with ( + patch.object(Path, "rename", failing_rename), + patch( + "osa.infrastructure.persistence.adapter.storage.shutil.copy2", + side_effect=OSError("No space left on device"), + ), + pytest.raises(InfrastructureError, match="test.txt"), + ): + await adapter.save_file(dep_srn, "test.txt", content, len(content)) From 4ba4474342a8d2a7e5e2a78c3919083ade4f0397 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sun, 22 Mar 2026 10:47:39 +0000 Subject: [PATCH 6/6] feat: add k8s memory quantity conversion for proper resource limits Convert Docker-style memory strings to K8s resource quantities to ensure proper memory limit specification in Kubernetes jobs fix: improve temp file cleanup in storage adapter Move temp file cleanup outside exception handler and add warning logging when cleanup fails to prevent masking original errors --- server/osa/infrastructure/k8s/runner.py | 12 +++++- .../osa/infrastructure/k8s/source_runner.py | 12 +++++- .../persistence/adapter/storage.py | 5 ++- server/osa/infrastructure/runner_utils.py | 39 +++++++++++++++++++ .../k8s/test_k8s_hook_runner.py | 2 +- .../k8s/test_k8s_source_runner.py | 2 +- .../infrastructure/test_file_storage_move.py | 28 +++++++++++++ .../infrastructure/test_oci_hook_runner.py | 33 ++++++++++++++++ 8 files changed, 126 insertions(+), 7 deletions(-) diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py index 0f6659d..93b1967 100644 --- a/server/osa/infrastructure/k8s/runner.py +++ b/server/osa/infrastructure/k8s/runner.py @@ -17,7 +17,12 @@ from osa.domain.validation.port.hook_runner import HookInputs, HookRunner from osa.infrastructure.k8s.errors import classify_api_error from osa.infrastructure.k8s.naming import job_name, label_value -from osa.infrastructure.runner_utils import detect_rejection, parse_progress_file, relative_path +from osa.infrastructure.runner_utils import ( + detect_rejection, + parse_progress_file, + relative_path, + to_k8s_quantity, +) if TYPE_CHECKING: from kubernetes_asyncio.client import ApiClient, BatchV1Api, CoreV1Api, V1Job @@ -283,7 +288,10 @@ def _build_job_spec( V1EnvVar(name="OSA_HOOK_NAME", value=hook.name), ], resources=V1ResourceRequirements( - limits={"memory": hook.runtime.limits.memory, "cpu": hook.runtime.limits.cpu}, + limits={ + "memory": to_k8s_quantity(hook.runtime.limits.memory), + "cpu": hook.runtime.limits.cpu, + }, ), security_context=V1SecurityContext( read_only_root_filesystem=True, diff --git a/server/osa/infrastructure/k8s/source_runner.py b/server/osa/infrastructure/k8s/source_runner.py index 192631e..6cafecd 100644 --- a/server/osa/infrastructure/k8s/source_runner.py +++ b/server/osa/infrastructure/k8s/source_runner.py @@ -16,7 +16,12 @@ from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner from osa.infrastructure.k8s.errors import classify_api_error from osa.infrastructure.k8s.naming import job_name, label_value, sanitize_label -from osa.infrastructure.runner_utils import parse_records_file, parse_session_file, relative_path +from osa.infrastructure.runner_utils import ( + parse_records_file, + parse_session_file, + relative_path, + to_k8s_quantity, +) if TYPE_CHECKING: from kubernetes_asyncio.client import ApiClient, BatchV1Api, CoreV1Api, V1Job @@ -270,7 +275,10 @@ def _build_job_spec( image=f"{source.image}@{source.digest}", env=env, resources=V1ResourceRequirements( - limits={"memory": source.limits.memory, "cpu": source.limits.cpu}, + limits={ + "memory": to_k8s_quantity(source.limits.memory), + "cpu": source.limits.cpu, + }, ), security_context=V1SecurityContext( capabilities=V1Capabilities(drop=["ALL"]), diff --git a/server/osa/infrastructure/persistence/adapter/storage.py b/server/osa/infrastructure/persistence/adapter/storage.py index 22890ca..ed7772d 100644 --- a/server/osa/infrastructure/persistence/adapter/storage.py +++ b/server/osa/infrastructure/persistence/adapter/storage.py @@ -94,9 +94,12 @@ async def save_file( except OSError: try: shutil.copy2(tmp_path, target) - Path(tmp_path).unlink(missing_ok=True) except OSError as e: raise InfrastructureError(f"Failed to write file {filename}: {e}") from e + try: + Path(tmp_path).unlink() + except OSError: + logger.warning("Failed to clean up temp file: %s", tmp_path) except Exception: Path(tmp_path).unlink(missing_ok=True) raise diff --git a/server/osa/infrastructure/runner_utils.py b/server/osa/infrastructure/runner_utils.py index ef6aa69..dae921d 100644 --- a/server/osa/infrastructure/runner_utils.py +++ b/server/osa/infrastructure/runner_utils.py @@ -66,6 +66,45 @@ def parse_memory(memory: str) -> int: raise ValueError(f"Unknown memory unit: {unit}") +_MEMORY_RE = re.compile(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$") + + +def to_k8s_quantity(memory: str) -> str: + """Convert a Docker-style memory string to a K8s resource quantity. + + Docker uses lowercase units where 'm' means megabytes. + K8s uses IEC binary units where 'Mi' means mebibytes and lowercase 'm' + means *milli* (10⁻³). This function bridges the two conventions. + + Fractional values are converted down one unit to produce an integer + quantity (e.g. "1.5g" → "1536Mi") since K8s quantities must be integers + when using binary suffixes. + """ + raw = memory.strip().lower() + match = _MEMORY_RE.match(raw) + if not match: + raise ValueError(f"Invalid memory format: {memory}") + + amount = float(match.group(1)) + unit = match.group(2) + + match unit: + case "g": + if amount == int(amount): + return f"{int(amount)}Gi" + return f"{int(amount * 1024)}Mi" + case "m": + if amount == int(amount): + return f"{int(amount)}Mi" + return f"{int(amount * 1024)}Ki" + case "k": + return f"{int(amount)}Ki" + case None: + return str(int(amount)) + case _: + raise ValueError(f"Unknown memory unit: {unit}") + + def relative_path(path: Path, data_mount_path: str) -> str: """Strip the data mount prefix to get a PVC-relative subpath. diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py index 1bb5fbb..003a8e7 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py @@ -112,7 +112,7 @@ def test_resource_limits(self): ) resources = spec.spec.template.spec.containers[0].resources - assert resources.limits["memory"] == "4g" + assert resources.limits["memory"] == "4Gi" assert resources.limits["cpu"] == "2.0" def test_volume_mounts(self): diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py index 3935945..6072c4a 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py @@ -88,7 +88,7 @@ def test_higher_defaults(self): files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) resources = spec.spec.template.spec.containers[0].resources - assert resources.limits["memory"] == "4g" + assert resources.limits["memory"] == "4Gi" # activeDeadlineSeconds = scheduling_timeout + source timeout assert spec.spec.active_deadline_seconds == 120 + 3600 diff --git a/server/tests/unit/infrastructure/test_file_storage_move.py b/server/tests/unit/infrastructure/test_file_storage_move.py index 5d67f1f..08eb201 100644 --- a/server/tests/unit/infrastructure/test_file_storage_move.py +++ b/server/tests/unit/infrastructure/test_file_storage_move.py @@ -157,3 +157,31 @@ def failing_rename(self_path, target): pytest.raises(InfrastructureError, match="test.txt"), ): await adapter.save_file(dep_srn, "test.txt", content, len(content)) + + @pytest.mark.asyncio + async def test_save_file_unlink_failure_after_copy_succeeds(self, tmp_path: Path): + """If copy2 succeeds but temp unlink fails, the write still succeeds.""" + adapter = FilesystemStorageAdapter(str(tmp_path)) + dep_srn = _make_dep_srn() + content = b"hello world" + + def failing_rename(self_path, target): + raise OSError("Cross-device link") + + original_unlink = Path.unlink + + def selective_unlink(self_path, *, missing_ok=False): + # Only fail for temp files (inside the fallback path) + if "tmp" in str(self_path) or str(self_path).startswith("/tmp"): + raise OSError("Permission denied") + original_unlink(self_path, missing_ok=missing_ok) + + with ( + patch.object(Path, "rename", failing_rename), + patch.object(Path, "unlink", selective_unlink), + ): + result = await adapter.save_file(dep_srn, "test.txt", content, len(content)) + + files_dir = adapter.get_files_dir(dep_srn) + assert (files_dir / "test.txt").read_bytes() == content + assert result.name == "test.txt" diff --git a/server/tests/unit/infrastructure/test_oci_hook_runner.py b/server/tests/unit/infrastructure/test_oci_hook_runner.py index 27c145b..268b3a2 100644 --- a/server/tests/unit/infrastructure/test_oci_hook_runner.py +++ b/server/tests/unit/infrastructure/test_oci_hook_runner.py @@ -20,6 +20,7 @@ detect_rejection, parse_memory, parse_progress_file, + to_k8s_quantity, ) @@ -78,6 +79,38 @@ def test_invalid_format(self): parse_memory("abc") +class TestToK8sQuantity: + """Convert Docker-style memory strings to K8s resource quantities.""" + + def test_gigabytes(self): + assert to_k8s_quantity("2g") == "2Gi" + + def test_megabytes(self): + assert to_k8s_quantity("512m") == "512Mi" + + def test_kilobytes(self): + assert to_k8s_quantity("1024k") == "1024Ki" + + def test_bare_bytes(self): + assert to_k8s_quantity("1048576") == "1048576" + + def test_fractional_gigabytes(self): + assert to_k8s_quantity("1.5g") == "1536Mi" + + def test_fractional_megabytes(self): + assert to_k8s_quantity("1.5m") == "1536Ki" + + def test_case_insensitive(self): + assert to_k8s_quantity("2G") == "2Gi" + + def test_with_i_suffix(self): + assert to_k8s_quantity("2gi") == "2Gi" + + def test_invalid_format(self): + with pytest.raises(ValueError, match="Invalid memory format"): + to_k8s_quantity("abc") + + class TestParseProgress: def test_empty_when_no_file(self, tmp_path: Path): entries = parse_progress_file(tmp_path)