diff --git a/v5/.gitignore b/v5/.gitignore new file mode 100644 index 0000000..77ac754 --- /dev/null +++ b/v5/.gitignore @@ -0,0 +1,3 @@ +.venv/ +__pycache__/ +*.pyc diff --git a/v5/RUNBOOK.md b/v5/RUNBOOK.md new file mode 100644 index 0000000..670b288 --- /dev/null +++ b/v5/RUNBOOK.md @@ -0,0 +1,452 @@ +# OmniSky v5 data-collection runbook + +This runbook describes the cluster-agnostic data-collection path for OmniSky v5. +It works locally for dry runs and on any SLURM cluster with Python, shared storage, +and outbound access or pre-staged catalogs. Delta/DeltaAI are examples, not +requirements. + +## 0. Operator contract + +The pipeline has three modes: + +1. **Local dry run** — proves LSDB/HATS access and stores a bounded sample locally. +2. **TEST_MODE end-to-end** — runs the full orchestration with synthetic data and no + network-heavy sources. +3. **Live data collection** — builds real seeds, plans source shards, runs source + arrays, finalizes shards, validates, and uploads or dry-runs upload. + +Do not claim a live release is validated until the live LSDB/HF/S3/local-file paths +have been run on the target compute environment and `validate_release.py` passes. + +## 1. Choose and describe the execution environment + +Set site-specific values with environment variables or `sbatch` flags. The wrappers +do not hardcode account, partition, QOS, or a Delta path. + +Common variables: + +| Variable | Required | Meaning | +| --- | --- | --- | +| `RELEASE_ROOT` | yes | Writable release directory for seeds, shards, markers, manifests, and final output. | +| `MANIFEST` | for array/finalize/validate | Path to `work_units.json`. | +| `SMITH42_REVISION` | for Phase 0 probe | Immutable Smith42 HF dataset commit, e.g. `93d0fddf8c5b61028ee0b6d72fd0dbfa87b38624`. | +| `HF_REPO` | for upload | Hugging Face repo id, e.g. `UniverseTBD/omnisky-v5`. | +| `OMNISKY_ENV_ACTIVATE` | optional | Shell snippet to activate the runtime, e.g. `source .venv/bin/activate`. | +| `OMNISKY_CONDA_ENV` | optional | Conda env name; defaults to `omnisky`. | +| `OMNISKY_PYTHON_BIN` | optional | Python executable after activation; defaults to `python`. | +| `OMNISKY_RELEASE_ROOT` | optional | Python config default for release root; CLI `--release-root`/`RELEASE_ROOT` still wins. | +| `OMNISKY_LEGACY_ROOT` | optional | Pre-staged Legacy HDF5 root for `legacy_hdf5`. | +| `OMNISKY_SDSS_DR16Q_ROOT` | optional | Pre-staged SDSS DR16Q root for local FITS fallback. | + +For SLURM account/partition, prefer scheduler-provided variables or flags: + +```bash +export SBATCH_ACCOUNT=my_allocation +export SBATCH_PARTITION=cpu +# Optional site knobs: +export SBATCH_QOS=normal +export SBATCH_CONSTRAINT=x86_64 +``` + +Equivalent inline form: + +```bash +sbatch --account=my_allocation --partition=cpu slurm/probe.sbatch +``` + +If your only queue allocates GPUs, it can still run the CPU data path. Keep array +concurrency conservative, request enough CPUs/memory, and avoid claiming GPU speedup +unless a downstream analysis stage actually uses GPUs. + +## 2. Create the Python environment + +Preferred conda path: + +```bash +cd /path/to/layerwise-analysis/v5 +conda env create -f environment.yml +conda activate omnisky +python -c 'import lsdb, pandas; print("env ok")' +``` + +Existing virtualenv path: + +```bash +cd /path/to/layerwise-analysis/v5 +python -m venv .venv +source .venv/bin/activate +python -m pip install -U pip +python -m pip install lsdb hats astropy astropy-healpix pyarrow pandas h5py s3fs huggingface_hub datasets astroquery dask distributed requests pytest +python -c 'import lsdb, pandas; print("env ok")' +``` + +For SLURM wrappers using a non-conda env: + +```bash +export OMNISKY_ENV_ACTIVATE='source /path/to/layerwise-analysis/v5/.venv/bin/activate' +export OMNISKY_PYTHON_BIN=python +``` + +## 3. Run local bounded LSDB dry run + +Use this before any large cluster job. It verifies LSDB import, catalog opening, +pixel filtering, local materialization, storage, and optionally one tiny crossmatch. + +```bash +cd /path/to/layerwise-analysis/v5 +python -m scripts.local_lsdb_dry_run \ + --sources desi,hsc \ + --cone 150.1,2.2,600 \ + --columns ra,dec \ + --max-bytes 1GB \ + --out-dir /tmp/omnisky-lsdb-dry-run \ + --crossmatch +``` + +Outputs: + +- `source=/order=/pixel=.parquet` or `.jsonl` +- `local_lsdb_dry_run_report.json` + +Interpretation: + +- `stored_bytes` should be positive. +- `fetches[*].fetched_rows` should be positive for at least one source/pixel. +- `crossmatch.attempted=true` means the LSDB crossmatch API path ran. +- A byte cap limits stored bytes, not peak LSDB/Dask memory. Keep columns and pixels + narrow for local tests. +- If a pixel run reports `error: no_coverage`, use a cone search in a known footprint + first. The COSMOS cone above has been used as the default DESI/HSC smoke region. + +If the canonical HF URI is unavailable, override it with local HATS paths: + +```bash +python -m scripts.local_lsdb_dry_run \ + --sources desi,hsc \ + --cone 150.1,2.2,600 \ + --catalog desi=/staged/hats/mmu_desi_edr_sv3 \ + --catalog hsc=/staged/hats/mmu_hsc_pdr3_dud_22.5 \ + --max-bytes 1GB \ + --out-dir /tmp/omnisky-lsdb-dry-run +``` + +## 4. Phase 0 probe: network and concordance gate + +Purpose: verify source reachability, throughput, one-pixel DESI×HSC LSDB crossmatch, +and Smith42 concordance before scaling out. + +Local/direct command: + +```bash +cd /path/to/layerwise-analysis/v5 +export SMITH42_REVISION=93d0fddf8c5b61028ee0b6d72fd0dbfa87b38624 +python -m scripts.probe_sources --order 4 --pixel 257 --out probe_report.json +python -m scripts.probe_crossmatch --order 4 --pixel 257 \ + --smith42-revision "$SMITH42_REVISION" \ + --out crossmatch_probe.json +python -m scripts.check_phase0_gate --probe probe_report.json --crossmatch crossmatch_probe.json +``` + +SLURM command: + +```bash +export SBATCH_ACCOUNT=my_allocation +export SBATCH_PARTITION=cpu +export SMITH42_REVISION=93d0fddf8c5b61028ee0b6d72fd0dbfa87b38624 +sbatch slurm/probe.sbatch +``` + +Accept only if: + +- Probe output exists and has explicit reachability/throughput verdicts. +- Crossmatch output exists. +- `check_phase0_gate.py` exits zero. +- Any unreachable service has a documented pre-stage plan before live scale-out. + +## 5. Build seed catalogs + +Seeds define the objects to collect source modalities for. For local orchestration +tests, use `TEST_MODE`; for live runs, provide curated input CSVs with `ra` and `dec`. + +Synthetic smoke seed: + +```bash +export RELEASE_ROOT=/shared/omnisky/release/v5-smoke +python -m scripts.build_seed_catalogs \ + --population galaxy \ + --release-root "$RELEASE_ROOT" \ + --out unused \ + --test-mode \ + --n 10 +``` + +Live seed from CSV: + +```bash +export RELEASE_ROOT=/shared/omnisky/release/v5 +python -m scripts.build_seed_catalogs \ + --population galaxy \ + --release-root "$RELEASE_ROOT" \ + --out unused \ + --input-csv /staged/seeds/galaxies.csv +``` + +SLURM form: + +```bash +export RELEASE_ROOT=/shared/omnisky/release/v5 +export POPULATION=galaxy +export INPUT_CSV=/staged/seeds/galaxies.csv +sbatch slurm/build_seeds.sbatch +``` + +Repeat per population (`galaxy`, `star`, `agn`) with population-specific input and +source choices. + +## 6. Plan source work units + +The manifest is the authority for source×shard work. It embeds an inputs hash so a +rerun can safely attach to identical work and reject conflicting work. + +```bash +export RELEASE_ROOT=/shared/omnisky/release/v5 +python -m scripts.plan_work_units \ + --sources desi,hsc,legacy \ + --population galaxy \ + --n-objects 0 \ + --shard-size 50000 \ + --release-root "$RELEASE_ROOT" \ + --out "$RELEASE_ROOT/manifests/galaxy" +export MANIFEST="$RELEASE_ROOT/manifests/galaxy/work_units.json" +``` + +When `--release-root` seed exists, the script counts seed rows and hashes seed content; +`--n-objects` is a fallback/manual value. + +Record: + +- source list, +- shard size, +- manifest path, +- manifest hash printed/written in `work_units.json`, +- code SHA used for the run. + +## 7. Run source shard arrays + +Each array task reads seed rows for one source×shard unit and writes source shards +plus DONE markers. It can run locally for a tiny manifest or as a SLURM array for +scale-out. + +Local single task: + +```bash +python -m scripts.run_source_shard \ + --manifest "$MANIFEST" \ + --task-id 0 \ + --release-root "$RELEASE_ROOT" \ + --code-sha "$(git rev-parse --short HEAD 2>/dev/null || echo local)" +``` + +SLURM array: + +```bash +N_UNITS=$(python - <<'PY' +import json, os +with open(os.environ['MANIFEST']) as f: + print(len(json.load(f)['units'])) +PY +) + +export RELEASE_ROOT=/shared/omnisky/release/v5 +export MANIFEST=/shared/omnisky/release/v5/manifests/galaxy/work_units.json +export CODE_SHA=$(git rev-parse --short HEAD 2>/dev/null || echo local) +sbatch --array=0-$((N_UNITS - 1)) slurm/run_source_array.sbatch +``` + +For multiple users or queues sharing one release, split work deterministically: + +```bash +export PARTITION_ID=0 +export NUM_PARTITIONS=4 +sbatch --array=0-249 slurm/run_source_array.sbatch +``` + +A second operator can use `PARTITION_ID=1`, etc. Each partition sees disjoint units. + +## 8. Finalize per-population shards + +After source shards complete for a population, finalization joins modalities for each +object shard and enforces `MIN_INSTRUMENTS`. + +```bash +export RELEASE_ROOT=/shared/omnisky/release/v5 +export MANIFEST=/shared/omnisky/release/v5/manifests/galaxy/work_units.json +export POPULATION=galaxy +export SOURCES=desi,hsc,legacy +export MIN_INSTRUMENTS=2 +sbatch --array=0- slurm/finalize_array.sbatch +``` + +Local equivalent: + +```bash +python -m scripts.finalize_shard \ + --release-root "$RELEASE_ROOT" \ + --manifest "$MANIFEST" \ + --population galaxy \ + --sources desi,hsc,legacy \ + --shard 0 \ + --min-instruments 2 +``` + +## 9. Verify markers and aggregate the release + +Marker verification catches missing, stale, suspicious, or corrupt shard outputs before +release aggregation. + +```bash +python -m scripts.verify_markers \ + --release-root "$RELEASE_ROOT" \ + --manifest "$MANIFEST" \ + --code-sha "$CODE_SHA" + +python -m scripts.finalize_release --release-root "$RELEASE_ROOT" +``` + +The aggregate release writes under `$RELEASE_ROOT/release/`. + +## 10. Validate quality gates + +Run structural validation and false-match reporting. Treat false-match bins above the +threshold as low-confidence or blocking, depending on the release policy. + +```bash +python -m scripts.validate_release \ + --release-root "$RELEASE_ROOT" \ + --min-instruments 2 + +python -m scripts.false_match_report \ + --out false_match_report.json \ + --threshold 0.001 +``` + +SLURM combined validation: + +```bash +export RELEASE_ROOT=/shared/omnisky/release/v5 +export MANIFEST=/shared/omnisky/release/v5/manifests/galaxy/work_units.json +export CODE_SHA=$(git rev-parse --short HEAD 2>/dev/null || echo local) +sbatch slurm/validate_release.sbatch +``` + +## 11. Upload or dry-run upload + +Always dry-run first: + +```bash +python -m scripts.upload_hf \ + --release-root "$RELEASE_ROOT" \ + --repo UniverseTBD/omnisky-v5 \ + --dry-run +``` + +Real upload: + +```bash +export HF_TOKEN=... +python -m scripts.upload_hf \ + --release-root "$RELEASE_ROOT" \ + --repo UniverseTBD/omnisky-v5 +``` + +SLURM wrapper: + +```bash +export RELEASE_ROOT=/shared/omnisky/release/v5 +export HF_REPO=UniverseTBD/omnisky-v5 +export DRY_RUN=1 +sbatch slurm/upload_hf.sbatch +``` + +Remove `DRY_RUN` only after the dry-run report and validation gates pass. + +## 12. TEST_MODE full local orchestration + +Use this to verify orchestration changes without network catalogs: + +```bash +tmp=$(mktemp -d) +export RELEASE_ROOT="$tmp/release" + +python -m scripts.build_seed_catalogs --population galaxy --release-root "$RELEASE_ROOT" --out unused --test-mode --n 5 +python -m scripts.plan_work_units --sources desi,hsc --population galaxy --n-objects 0 --release-root "$RELEASE_ROOT" --out "$RELEASE_ROOT/manifests/galaxy" +export MANIFEST="$RELEASE_ROOT/manifests/galaxy/work_units.json" + +python -m scripts.run_source_shard --manifest "$MANIFEST" --task-id 0 --release-root "$RELEASE_ROOT" --test-mode +python -m scripts.run_source_shard --manifest "$MANIFEST" --task-id 1 --release-root "$RELEASE_ROOT" --test-mode +python -m scripts.finalize_shard --release-root "$RELEASE_ROOT" --manifest "$MANIFEST" --population galaxy --sources desi,hsc --shard 0 --min-instruments 2 +python -m scripts.verify_markers --release-root "$RELEASE_ROOT" --manifest "$MANIFEST" --code-sha local +python -m scripts.finalize_release --release-root "$RELEASE_ROOT" +python -m scripts.validate_release --release-root "$RELEASE_ROOT" --min-instruments 2 +python -m scripts.upload_hf --release-root "$RELEASE_ROOT" --repo UniverseTBD/omnisky-v5 --dry-run +``` + +## 13. Cluster-specific notes + +### Generic SLURM + +- Use `SBATCH_ACCOUNT`, `SBATCH_PARTITION`, `SBATCH_QOS`, and `SBATCH_CONSTRAINT` or + equivalent `sbatch` flags. +- Keep `RELEASE_ROOT` on shared storage visible to all array jobs. +- Keep local pre-staged datasets under paths exported via `OMNISKY_LEGACY_ROOT` and + `OMNISKY_SDSS_DR16Q_ROOT`. + +### Delta x86 CPU + +- Good fit for data generation because it avoids idle GPU allocation. +- Use the site account/partition via `SBATCH_*`; do not edit wrappers permanently. + +### DeltaAI / GH200-only queues + +- Data collection remains CPU-oriented even if the scheduler allocates a GPU. +- Pack array tasks conservatively; the expensive resource is usually network/object-store + throughput or shared filesystem pressure, not GPU compute. +- If Python wheels differ on aarch64, build the env on the target architecture and run + the local LSDB dry run before submitting large arrays. + +### No outbound compute-node internet + +- Run `probe.sbatch` or direct probes first. +- If HF/S3/CDS access fails, pre-stage required catalogs via the site-approved transfer + mechanism, then use local paths or source-specific env vars. +- For LSDB/HATS sources, prefer local HATS directories and validate with + `local_lsdb_dry_run.py --catalog SOURCE=/path/to/hats`. + +## 14. Failure handling + +| Symptom | Likely cause | Action | +| --- | --- | --- | +| `No module named 'lsdb'` | Wrong Python env | Activate conda env or install LSDB in the active venv. | +| `source ... not LSDB` in dry run | Requested non-HATS source | Use only `lsdb_mmu` sources for `local_lsdb_dry_run.py`. | +| Probe internet failure | Compute node lacks outbound access | Pre-stage catalogs and record stream-vs-pre-stage decision. | +| Stale marker | Code SHA/schema/manifest mismatch | Re-run affected shard with intended manifest/code or clean stale output intentionally. | +| Corrupt marker | Output changed after marker write | Re-run shard; do not hand-edit outputs. | +| Empty fetched pixel | Wrong pixel/order or catalog footprint | Try a known populated pixel, narrower source list, or ConeSearch-based investigation. | + +## 15. Ship checklist + +Before calling the release complete, capture: + +- environment creation command and Python version, +- cluster/site name and scheduler options used, +- release root, +- source list and seed provenance per population, +- Smith42 revision, +- dry-run report, +- Phase 0 probe and concordance outputs, +- manifest hashes, +- marker verification output, +- release validation output, +- false-match report, +- upload dry-run output, +- real upload URL and load-back verification if uploaded. diff --git a/v5/environment.yml b/v5/environment.yml new file mode 100644 index 0000000..c802b36 --- /dev/null +++ b/v5/environment.yml @@ -0,0 +1,24 @@ +name: omnisky +channels: + - conda-forge + - nodefaults +dependencies: + - python=3.11 + - lsdb>=0.9 + - hats>=0.9 + - astropy>=6.0 + - scipy>=1.13 + - astropy-healpix>=1.0 + - pyarrow>=16 + - pandas>=2.2 + - h5py>=3.11 + - s3fs>=2024.6 + - huggingface_hub>=0.24 + - datasets>=2.20 + - astroquery>=0.4.7 + - dask>=2024.6 + - distributed>=2024.6 + - requests>=2.32 + - pytest>=8 + - mypy>=1.10 + - pip diff --git a/v5/mmu/__init__.py b/v5/mmu/__init__.py new file mode 100644 index 0000000..ea0fc67 --- /dev/null +++ b/v5/mmu/__init__.py @@ -0,0 +1 @@ +# mmu: OmniSky v5 core library diff --git a/v5/mmu/concordance.py b/v5/mmu/concordance.py new file mode 100644 index 0000000..51ff4d1 --- /dev/null +++ b/v5/mmu/concordance.py @@ -0,0 +1,66 @@ +"""Concordance of our matches against a reference cross-matched catalog.""" +from __future__ import annotations +import numpy as np +import astropy.units as u +from astropy.coordinates import SkyCoord, search_around_sky +from astropy_healpix import HEALPix +from typing import Any + + +def match_concordance(our_ra, our_dec, ref_ra, ref_dec, tol_arcsec: float = 1.0) -> dict[str, Any]: + if tol_arcsec <= 0: + raise ValueError("tol_arcsec must be positive") + our = SkyCoord(np.asarray(our_ra) * u.deg, np.asarray(our_dec) * u.deg) + ref = SkyCoord(np.asarray(ref_ra) * u.deg, np.asarray(ref_dec) * u.deg) + n_ref = len(ref) + if len(our) == 0 or n_ref == 0: + return {"n_ours": len(our), "n_ref": n_ref, "recovered": 0, + "recall": 0.0, "median_sep_arcsec": float("nan")} + _, idx_ref, sep, _ = search_around_sky(our, ref, tol_arcsec * u.arcsec) + recovered = int(np.unique(idx_ref).size) + n_pairs = int(len(sep)) + return {"n_ours": len(our), "n_ref": n_ref, "recovered": recovered, + "recall": recovered / n_ref, + "n_pairs_within_tolerance": n_pairs, + "duplicate_pairs": max(0, n_pairs - recovered), + "median_sep_arcsec": float(np.median(np.asarray(sep.arcsec, dtype=np.float64))) if len(sep) else float("nan")} + + +def filter_reference_to_our_footprint(our_ra, our_dec, ref_ra, ref_dec, + footprint_arcsec: float) -> tuple[np.ndarray, np.ndarray]: + """Restrict a reference table to the small sky patch actually probed. + + Phase 0 compares a one-pixel LSDB run against Smith42. If we computed recall + against the full Smith42 table, a correct one-pixel result would look like a + failure. This helper keeps only Smith42 rows near our probed matches. + """ + if footprint_arcsec <= 0: + raise ValueError("footprint_arcsec must be positive") + ref_ra_arr = np.asarray(ref_ra, dtype=np.float64) + ref_dec_arr = np.asarray(ref_dec, dtype=np.float64) + if len(ref_ra_arr) == 0: + return ref_ra_arr, ref_dec_arr + our = SkyCoord(np.asarray(our_ra, dtype=np.float64) * u.deg, + np.asarray(our_dec, dtype=np.float64) * u.deg) + if len(our) == 0: + return ref_ra_arr[:0], ref_dec_arr[:0] + ref = SkyCoord(ref_ra_arr * u.deg, ref_dec_arr * u.deg) + _, idx_ref, _, _ = search_around_sky(our, ref, footprint_arcsec * u.arcsec) + keep = np.unique(idx_ref) + return ref_ra_arr[keep], ref_dec_arr[keep] + + +def filter_reference_to_healpix_pixel(ref_ra, ref_dec, *, order: int, pixel: int) -> tuple[np.ndarray, np.ndarray]: + """Restrict reference rows to the exact HEALPix pixel used by a probe. + + Unlike filtering near our returned matches, this denominator is independent + of our matcher output and therefore cannot make missed references disappear. + """ + if order < 0: + raise ValueError("order must be non-negative") + ref_ra_arr = np.asarray(ref_ra, dtype=np.float64) + ref_dec_arr = np.asarray(ref_dec, dtype=np.float64) + hp = HEALPix(nside=2 ** order, order="nested") + ref_pix = np.asarray(hp.lonlat_to_healpix(ref_ra_arr * u.deg, ref_dec_arr * u.deg), dtype=np.int64) + keep = ref_pix == int(pixel) + return ref_ra_arr[keep], ref_dec_arr[keep] diff --git a/v5/mmu/config.py b/v5/mmu/config.py new file mode 100644 index 0000000..4566590 --- /dev/null +++ b/v5/mmu/config.py @@ -0,0 +1,59 @@ +"""Frozen pipeline configuration: sources, radii, paths, shard size, schema version.""" +from __future__ import annotations + +from dataclasses import dataclass +import os +from pathlib import Path + + +@dataclass(frozen=True, slots=True) +class Source: + name: str + kind: str + org: str + dataset: str + modality: str + radius_arcsec: float + columns: tuple[str, ...] + epoch_jyear: float = 2016.0 + population: str = "galaxy" + + @property + def is_hats(self) -> bool: + return self.kind in {"lsdb_mmu", "ztf_s3"} + + +def env_path(name: str, default: str) -> str: + return os.environ.get(name, default) + + +SOURCES: dict[str, Source] = { + "desi": Source("desi", "lsdb_mmu", "UniverseTBD", "mmu_desi_edr_sv3", + "spectrum", 1.0, ("ra", "dec", "flux", "ivar", "lambda"), 2016.0, "galaxy"), + "hsc": Source("hsc", "lsdb_mmu", "UniverseTBD", "mmu_hsc_pdr3_dud_22.5", + "image", 1.0, ("ra", "dec", "image"), 2014.0, "galaxy"), + "legacy": Source("legacy", "legacy_hdf5", "", env_path("OMNISKY_LEGACY_ROOT", "data/legacy_dr10_south_21"), + "image", 1.0, ("ra", "dec", "image_array"), 2015.5, "galaxy"), + "gaia": Source("gaia", "lsdb_mmu", "UniverseTBD", "mmu_gaia_gaia", + "astrometry", 1.0, ("ra", "dec", "pmra", "pmdec", "parallax"), 2016.0, "star"), + "apogee": Source("apogee", "lsdb_mmu", "hugging-science", "mmu_apogee_dr17", + "spectrum", 1.0, ("ra", "dec", "flux", "snr"), 2016.0, "star"), + "ztf": Source("ztf", "ztf_s3", "", "ztf/enhanced/dr24/lc/hats", + "lightcurve", 1.0, ("ra", "dec", "mjd", "mag", "magerr"), 2018.0, "star"), + "tess": Source("tess", "lsdb_mmu", "UniverseTBD", "mmu_tess_spoc", + "lightcurve", 1.0, ("ra", "dec", "time", "flux"), 2019.0, "star"), + "sdss_dr16q": Source("sdss_dr16q", "local_fits", "", env_path("OMNISKY_SDSS_DR16Q_ROOT", "data/sdss_dr16q"), + "spectrum", 1.0, ("ra", "dec", "z", "flux"), 2015.5, "agn"), +} + + +@dataclass(frozen=True, slots=True) +class Config: + release_root: Path = Path(env_path("OMNISKY_RELEASE_ROOT", "release/v5")) + shard_size: int = 50_000 + schema_version: str = "v5.1" + min_instruments: int = 2 + split_nside: int = 8 + + +DEFAULT = Config() diff --git a/v5/mmu/coordination.py b/v5/mmu/coordination.py new file mode 100644 index 0000000..7afdc38 --- /dev/null +++ b/v5/mmu/coordination.py @@ -0,0 +1,45 @@ +"""Manifest creation, partitioning, and finalize barriers.""" +from __future__ import annotations + +import hashlib +import json +import os +from pathlib import Path +from typing import Any + + +def manifest_hash(units: list[dict[str, Any]], inputs_hash: str) -> str: + payload = json.dumps({"units": units, "inputs_hash": inputs_hash}, sort_keys=True) + return hashlib.sha256(payload.encode()).hexdigest()[:16] + + +def build_or_attach_manifest(root: Path, units: list[dict[str, Any]], *, inputs_hash: str) -> tuple[str, bool]: + root = Path(root) + root.mkdir(parents=True, exist_ok=True) + mhash = manifest_hash(units, inputs_hash) + lock = root / "release.lock" + manifest = root / "manifest.json" + try: + fd = os.open(str(lock), os.O_CREAT | os.O_EXCL | os.O_WRONLY) + except FileExistsError: + existing = json.loads(manifest.read_text()) + if existing["inputs_hash"] != inputs_hash: + raise ValueError(f"manifest inputs_hash conflict: {existing['inputs_hash']} != {inputs_hash}") + if existing["manifest_hash"] != mhash: + raise ValueError("manifest units conflict for the same inputs_hash") + return existing["manifest_hash"], False + with os.fdopen(fd, "w") as f: + f.write(str(os.getpid())) + manifest.write_text(json.dumps({"manifest_hash": mhash, "inputs_hash": inputs_hash, + "units": units}, sort_keys=True, indent=2)) + return mhash, True + + +def partition_units(units: list[Any], *, partition_id: int, num_partitions: int) -> list[Any]: + if num_partitions <= 0 or partition_id < 0 or partition_id >= num_partitions: + raise ValueError("invalid partition") + return [u for i, u in enumerate(units) if i % num_partitions == partition_id] + + +def finalize_ready(shard: int, *, sources: list[str], completed: set[tuple[str, int]]) -> bool: + return all((source, shard) in completed for source in sources) diff --git a/v5/mmu/healpix.py b/v5/mmu/healpix.py new file mode 100644 index 0000000..c844b27 --- /dev/null +++ b/v5/mmu/healpix.py @@ -0,0 +1,28 @@ +"""Fail-fast HEALPix helpers for HATS/NESTED indexing.""" +from __future__ import annotations + +import astropy.units as u +import numpy as np +from astropy_healpix import HEALPix + + +def healpix(order: int) -> HEALPix: + if order < 0: + raise ValueError("order must be non-negative") + return HEALPix(nside=2 ** order, order="nested") + + +def seed_pixel_set(ra_deg, dec_deg, *, order: int) -> set[int]: + hp = healpix(order) + ra = np.atleast_1d(np.asarray(ra_deg, dtype=np.float64)) + dec = np.atleast_1d(np.asarray(dec_deg, dtype=np.float64)) + if ra.shape != dec.shape: + raise ValueError("ra/dec shape mismatch") + pix = hp.lonlat_to_healpix(ra * u.deg, dec * u.deg) + return {int(p) for p in np.asarray(pix)} + + +def neighbor_pixels(pixel: int, *, order: int) -> set[int]: + hp = healpix(order) + neighbours = np.asarray(hp.neighbours(int(pixel))) + return {int(pixel)} | {int(p) for p in neighbours if int(p) >= 0} diff --git a/v5/mmu/ids.py b/v5/mmu/ids.py new file mode 100644 index 0000000..7dd063e --- /dev/null +++ b/v5/mmu/ids.py @@ -0,0 +1,31 @@ +"""Canonical global_object_id helpers. + +The identifier is the NESTED HEALPix order-29 index of the seed position, +matching the MMU `_healpix_29` convention. Bad seed coordinates fail here so +they do not become silent identity corruption in a cluster run. +""" + +from __future__ import annotations + +import astropy.units as u +import numpy as np +from astropy_healpix import HEALPix + +ORDER: int = 29 +NSIDE: int = 2 ** ORDER +_HP = HEALPix(nside=NSIDE, order="nested") + + +def assign_global_id(ra_deg, dec_deg) -> np.ndarray: + ra = np.atleast_1d(np.asarray(ra_deg, dtype=np.float64)) + dec = np.atleast_1d(np.asarray(dec_deg, dtype=np.float64)) + if ra.shape != dec.shape: + raise ValueError(f"ra/dec shape mismatch: {ra.shape} != {dec.shape}") + if not np.isfinite(ra).all() or not np.isfinite(dec).all(): + raise ValueError("ra/dec must be finite") + if ((ra < 0.0) | (ra >= 360.0)).any(): + raise ValueError("ra must be in [0, 360) degrees") + if ((dec < -90.0) | (dec > 90.0)).any(): + raise ValueError("dec must be in [-90, 90] degrees") + idx = _HP.lonlat_to_healpix(ra * u.deg, dec * u.deg) + return np.asarray(idx, dtype=np.int64) diff --git a/v5/mmu/io_atomic.py b/v5/mmu/io_atomic.py new file mode 100644 index 0000000..9d5c64c --- /dev/null +++ b/v5/mmu/io_atomic.py @@ -0,0 +1,75 @@ +"""Atomic Parquet writes + integrity-checked DONE markers + resume state machine.""" +from __future__ import annotations + +import hashlib +import json +import os +from pathlib import Path +from typing import Any +from importlib import import_module + + +def file_checksum(path: Path) -> str: + h = hashlib.sha256() + with Path(path).open("rb") as f: + for chunk in iter(lambda: f.read(1 << 20), b""): + h.update(chunk) + return h.hexdigest() + + +def atomic_write_bytes(data: bytes, final_path: Path) -> None: + final_path = Path(final_path) + final_path.parent.mkdir(parents=True, exist_ok=True) + tmp = final_path.with_suffix(final_path.suffix + f".tmp.{os.getpid()}") + tmp.write_bytes(data) + with tmp.open("rb") as f: + os.fsync(f.fileno()) + os.replace(tmp, final_path) + dfd = os.open(str(final_path.parent), os.O_DIRECTORY) + try: + os.fsync(dfd) + finally: + os.close(dfd) + + +def atomic_write_parquet(table, final_path: Path) -> None: + pq = import_module("pyarrow.parquet") + + final_path = Path(final_path) + final_path.parent.mkdir(parents=True, exist_ok=True) + tmp = final_path.with_suffix(final_path.suffix + f".tmp.{os.getpid()}") + pq.write_table(table, tmp) + with tmp.open("rb") as f: + os.fsync(f.fileno()) + os.replace(tmp, final_path) + + +def marker_path(final_path: Path) -> Path: + return Path(str(final_path) + ".done.json") + + +def write_done_marker(final_path: Path, *, manifest_hash: str, schema_version: str, code_sha: str) -> None: + final_path = Path(final_path) + meta = {"manifest_hash": manifest_hash, "schema_version": schema_version, + "code_sha": code_sha, "byte_size": final_path.stat().st_size, + "checksum": file_checksum(final_path)} + atomic_write_bytes(json.dumps(meta, sort_keys=True).encode(), marker_path(final_path)) + + +def read_marker(final_path: Path) -> dict[str, Any] | None: + path = marker_path(Path(final_path)) + return json.loads(path.read_text()) if path.exists() else None + + +def unit_state(final_path: Path, *, manifest_hash: str, schema_version: str, code_sha: str) -> str: + final_path = Path(final_path) + if not final_path.exists(): + return "pending" + marker = read_marker(final_path) + if marker is None: + return "suspicious" + if (marker.get("manifest_hash"), marker.get("schema_version"), marker.get("code_sha")) != (manifest_hash, schema_version, code_sha): + return "stale" + if marker.get("checksum") != file_checksum(final_path): + return "corrupt" + return "complete" diff --git a/v5/mmu/matching.py b/v5/mmu/matching.py new file mode 100644 index 0000000..8c7c73f --- /dev/null +++ b/v5/mmu/matching.py @@ -0,0 +1,42 @@ +"""Match adjudication and population-aware coordinate preparation.""" +from __future__ import annotations + +import numpy as np +from typing import Any + + +def adjudicate(candidates: dict[int, list[tuple[float, int, float]]], *, radius_arcsec: float) -> dict[int, dict[str, Any]]: + if radius_arcsec <= 0: + raise ValueError("radius_arcsec must be positive") + out: dict[int, dict[str, Any]] = {} + for seed, cands in candidates.items(): + within = [c for c in cands if c[0] <= radius_arcsec] + if not within: + continue + within.sort(key=lambda c: (c[0], -c[2], c[1])) + best = within[0] + out[int(seed)] = {"src_index": int(best[1]), "match_sep_arcsec": float(best[0]), + "match_ambiguous": len(within) > 1, + "n_candidates_within_radius": len(within)} + return out + + +def prepare_stellar(ra, dec, pmra, pmdec, parallax_mas, rv_kms, *, + from_epoch_jyear: float = 2016.0, to_epoch_jyear: float) -> dict[str, Any]: + from mmu.motion import MissingParallaxPolicy, propagate_to_epoch + + return propagate_to_epoch(ra=np.asarray(ra), dec=np.asarray(dec), pmra=np.asarray(pmra), + pmdec=np.asarray(pmdec), parallax_mas=np.asarray(parallax_mas), + rv_kms=np.asarray(rv_kms), from_epoch_jyear=from_epoch_jyear, + to_epoch_jyear=to_epoch_jyear, + policy=MissingParallaxPolicy.FLAG) + + +def match_unit(seed: dict[str, Any], source: dict[str, Any], *, population: str) -> dict[str, Any]: + if population != "star": + return {"ra": np.asarray(seed["ra"]), "dec": np.asarray(seed["dec"]), + "motion_flag": np.array(["not_applicable"] * len(np.atleast_1d(seed["ra"])), dtype=object), + "drop": np.zeros(len(np.atleast_1d(seed["ra"])), dtype=bool)} + return prepare_stellar(seed["ra"], seed["dec"], seed["pmra"], seed["pmdec"], + seed["parallax"], seed.get("rv", np.zeros(len(seed["ra"]))), + to_epoch_jyear=float(source["epoch_jyear"])) diff --git a/v5/mmu/motion.py b/v5/mmu/motion.py new file mode 100644 index 0000000..6662294 --- /dev/null +++ b/v5/mmu/motion.py @@ -0,0 +1,66 @@ +"""Epoch propagation with explicit parallax policy.""" +from __future__ import annotations + +from enum import Enum +from typing import Any + +import numpy as np + + +class MissingParallaxPolicy(str, Enum): + FLAG = "flag" + DROP = "drop" + ASSUME_FAR = "assume_far" + + +def _linear_pm(ra, dec, pmra, pmdec, dt_yr): + cos_dec = np.cos(np.deg2rad(dec)) + cos_dec = np.where(np.abs(cos_dec) < 1e-12, np.nan, cos_dec) + return (ra + (pmra / 3.6e6) / cos_dec * dt_yr) % 360.0, dec + (pmdec / 3.6e6) * dt_yr + + +def propagate_to_epoch(*, ra, dec, pmra, pmdec, parallax_mas, rv_kms, + from_epoch_jyear: float, to_epoch_jyear: float, + policy: MissingParallaxPolicy = MissingParallaxPolicy.FLAG) -> dict[str, Any]: + import astropy.units as u + from astropy.coordinates import Distance, SkyCoord + from astropy.time import Time + + ra = np.atleast_1d(np.asarray(ra, dtype=float)) + dec = np.atleast_1d(np.asarray(dec, dtype=float)) + pmra = np.atleast_1d(np.asarray(pmra, dtype=float)) + pmdec = np.atleast_1d(np.asarray(pmdec, dtype=float)) + parallax = np.atleast_1d(np.asarray(parallax_mas, dtype=float)) + rv = np.zeros_like(ra) if rv_kms is None else np.atleast_1d(np.asarray(rv_kms, dtype=float)) + shape = ra.shape + if not all(x.shape == shape for x in (dec, pmra, pmdec, parallax, rv)): + raise ValueError("all input arrays must have the same shape") + + out_ra = ra.copy() + out_dec = dec.copy() + flags = np.array(["ok"] * len(ra), dtype=object) + drop = np.zeros(len(ra), dtype=bool) + good = np.isfinite(parallax) & (parallax > 0) + bad = ~good + if bad.any(): + flags[bad] = np.where(np.isfinite(parallax[bad]), "negative_parallax", "missing_parallax") + if policy == MissingParallaxPolicy.DROP: + drop[bad] = True + elif policy == MissingParallaxPolicy.ASSUME_FAR: + out_ra[bad], out_dec[bad] = _linear_pm(ra[bad], dec[bad], pmra[bad], pmdec[bad], + to_epoch_jyear - from_epoch_jyear) + flags[bad] = "assume_far" + + if good.any(): + coord = SkyCoord(ra=ra[good] * u.deg, dec=dec[good] * u.deg, + distance=Distance(parallax=parallax[good] * u.mas), + pm_ra_cosdec=pmra[good] * u.mas / u.yr, + pm_dec=pmdec[good] * u.mas / u.yr, + radial_velocity=np.nan_to_num(rv[good], nan=0.0) * u.km / u.s, + obstime=Time(from_epoch_jyear, format="jyear", scale="tcb")) + moved = coord.apply_space_motion(new_obstime=Time(to_epoch_jyear, format="jyear", scale="tcb")) + if moved.ra is None or moved.dec is None: + raise RuntimeError("apply_space_motion returned coordinates without ra/dec") + out_ra[good] = moved.ra.deg + out_dec[good] = moved.dec.deg + return {"ra": out_ra, "dec": out_dec, "motion_flag": flags, "drop": drop} diff --git a/v5/mmu/probe_report.py b/v5/mmu/probe_report.py new file mode 100644 index 0000000..69abf0f --- /dev/null +++ b/v5/mmu/probe_report.py @@ -0,0 +1,55 @@ +"""Typed probe results -> probe_report.json (feeds Phase-0 gate + per-service caps).""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from pathlib import Path + + +@dataclass(frozen=True, slots=True) +class SourceProbe: + name: str + reachable: bool + cold_latency_s: float + warm_latency_s: float + throughput_mb_s: float + rate_limited: bool + n_rows_sampled: int + error: str = "" + + def suggested_concurrency(self) -> int: + if not self.reachable or self.n_rows_sampled <= 0: + return 0 + if self.rate_limited: + return 2 + if self.throughput_mb_s < 1.0: + return 4 + return 16 + + +@dataclass +class ProbeReport: + internet_ok: bool + notes: str = "" + sources: list[SourceProbe] = field(default_factory=list) + + def add(self, sp: SourceProbe) -> None: + self.sources.append(sp) + + def to_json(self, path: str | Path) -> None: + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "internet_ok": self.internet_ok, + "notes": self.notes, + "sources": [asdict(s) for s in self.sources], + } + path.write_text(json.dumps(payload, indent=2, sort_keys=True)) + + @classmethod + def from_json(cls, path: str | Path) -> "ProbeReport": + d = json.loads(Path(path).read_text()) + rep = cls(internet_ok=d["internet_ok"], notes=d.get("notes", "")) + for s in d.get("sources", []): + rep.add(SourceProbe(**s)) + return rep diff --git a/v5/mmu/rate_limit.py b/v5/mmu/rate_limit.py new file mode 100644 index 0000000..74f8eb2 --- /dev/null +++ b/v5/mmu/rate_limit.py @@ -0,0 +1,51 @@ +"""Minimal in-process token broker for rate-limited service access.""" +from __future__ import annotations + +from contextlib import contextmanager +from threading import Condition +from typing import Iterator + + +class TokenBroker: + """A deterministic token broker used by workers or tests. + + The full cluster design can front this with a single broker process. This + object provides the tested accounting semantics: no more than ``capacity`` + active acquisitions and every acquisition must be released. + """ + + def __init__(self, capacity: int) -> None: + if capacity <= 0: + raise ValueError("capacity must be positive") + self.capacity = capacity + self._available = capacity + self._cond = Condition() + + @property + def available(self) -> int: + return self._available + + @property + def in_use(self) -> int: + return self.capacity - self._available + + def acquire(self) -> None: + with self._cond: + while self._available <= 0: + self._cond.wait() + self._available -= 1 + + def release(self) -> None: + with self._cond: + if self._available >= self.capacity: + raise RuntimeError("release called without a matching acquire") + self._available += 1 + self._cond.notify() + + @contextmanager + def token(self) -> Iterator[None]: + self.acquire() + try: + yield + finally: + self.release() diff --git a/v5/mmu/reachability.py b/v5/mmu/reachability.py new file mode 100644 index 0000000..093bd7d --- /dev/null +++ b/v5/mmu/reachability.py @@ -0,0 +1,39 @@ +"""Compute-node outbound reachability for HF / S3 / CDS (Phase-0 critical check, C2).""" +from __future__ import annotations + +import time +from collections.abc import Mapping +from typing import Any + +import requests + +ENDPOINTS = { + "huggingface": "https://huggingface.co", + "s3": "https://ipac-irsa-ztf.s3.amazonaws.com", + "cds": "https://cdsxmatch.u-strasbg.fr", +} + + +def check_endpoint(url: str, timeout: float = 10.0) -> tuple[bool, float | None]: + t0 = time.monotonic() + try: + r = requests.head(url, timeout=timeout, allow_redirects=True) + if r.status_code in (403, 405): + r = requests.get(url, timeout=timeout, allow_redirects=True, stream=True) + return (200 <= r.status_code < 500, time.monotonic() - t0) + except requests.RequestException: + return (False, None) + + +def summarize_reachability(results: Mapping[str, tuple[bool, float | None]]) -> dict[str, Any]: + unreachable = sorted([k for k, (ok, _) in results.items() if not ok]) + return { + "internet_ok": len(unreachable) == 0, + "unreachable": unreachable, + "latencies": {k: lat for k, (_, lat) in results.items()}, + } + + +def probe_all(timeout: float = 10.0) -> dict[str, Any]: + results = {name: check_endpoint(url, timeout) for name, url in ENDPOINTS.items()} + return summarize_reachability(results) diff --git a/v5/mmu/records.py b/v5/mmu/records.py new file mode 100644 index 0000000..e5e31c0 --- /dev/null +++ b/v5/mmu/records.py @@ -0,0 +1,37 @@ +"""Small JSONL record helpers used by TEST_MODE and orchestration glue.""" +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from mmu.io_atomic import atomic_write_bytes + + +Record = dict[str, Any] + + +def write_jsonl(records: list[Record], path: str | Path) -> None: + payload = "".join(json.dumps(row, sort_keys=True) + "\n" for row in records) + atomic_write_bytes(payload.encode(), Path(path)) + + +def read_jsonl(path: str | Path) -> list[Record]: + rows: list[Record] = [] + with Path(path).open() as f: + for line in f: + if line.strip(): + rows.append(json.loads(line)) + return rows + + +def source_shard_path(release_root: str | Path, *, population: str, source: str, shard: int) -> Path: + return Path(release_root) / f"source={source}" / f"population={population}" / f"shard={shard:06d}.jsonl" + + +def final_shard_path(release_root: str | Path, *, population: str, shard: int) -> Path: + return Path(release_root) / "final" / f"population={population}" / f"shard={shard:06d}.jsonl" + + +def seed_path(release_root: str | Path, *, population: str) -> Path: + return Path(release_root) / "seeds" / f"population={population}" / "seed.jsonl" diff --git a/v5/mmu/schemas.py b/v5/mmu/schemas.py new file mode 100644 index 0000000..c474ff3 --- /dev/null +++ b/v5/mmu/schemas.py @@ -0,0 +1,23 @@ +"""Versioned Arrow schemas for OmniSky v5 outputs.""" +from __future__ import annotations +from importlib import import_module + +SCHEMA_VERSION = "v5.1" + + +def final_schema(*, image_px: int = 160, n_bands: int = 4, spec_len: int = 7781, n_sources: int = 3): + pa = import_module("pyarrow") + + return pa.schema([ + ("global_object_id", pa.int64()), + ("object_uid", pa.string()), + ("seed_ra_deg", pa.float64()), + ("seed_dec_deg", pa.float64()), + ("population", pa.string()), + ("n_instruments_present", pa.int16()), + ("instrument_presence_mask", pa.list_(pa.bool_(), n_sources)), + ("split", pa.string()), + ("hsc_image", pa.list_(pa.float32(), image_px * image_px * n_bands)), + ("desi_spectrum", pa.list_(pa.float32(), spec_len)), + ("low_confidence", pa.bool_()), + ], metadata={b"schema_version": SCHEMA_VERSION.encode()}) diff --git a/v5/mmu/sources/__init__.py b/v5/mmu/sources/__init__.py new file mode 100644 index 0000000..1cba25f --- /dev/null +++ b/v5/mmu/sources/__init__.py @@ -0,0 +1 @@ +"""Source adapters for OmniSky v5.""" diff --git a/v5/mmu/sources/base.py b/v5/mmu/sources/base.py new file mode 100644 index 0000000..aa6facd --- /dev/null +++ b/v5/mmu/sources/base.py @@ -0,0 +1,13 @@ +"""Source adapter protocol.""" +from __future__ import annotations + +from typing import Protocol + + +class DataSource(Protocol): + name: str + modality: str + + def read_pixel(self, order: int, pixel: int): ... + + def to_rows(self, matched): ... diff --git a/v5/mmu/sources/legacy_hdf5.py b/v5/mmu/sources/legacy_hdf5.py new file mode 100644 index 0000000..d6121f4 --- /dev/null +++ b/v5/mmu/sources/legacy_hdf5.py @@ -0,0 +1,24 @@ +"""Flatiron Legacy DR10 HDF5 cutout reader.""" +from __future__ import annotations + +from pathlib import Path +from importlib import import_module + + +def pixel_files(root: str | Path, pixel: int) -> list[Path]: + return sorted((Path(root) / f"healpix={pixel}").glob("*.hdf5")) + + +def read_cutouts(path: str | Path): + h5py = import_module("h5py") + + with h5py.File(path, "r") as f: + missing = {"ra", "dec", "image_array"} - set(f.keys()) + if missing: + raise ValueError(f"missing datasets: {sorted(missing)}") + ra = f["ra"][:] + dec = f["dec"][:] + image = f["image_array"][:] + if len(ra) != len(dec) or len(ra) != len(image): + raise ValueError("ra/dec/image_array lengths differ") + return ra, dec, image diff --git a/v5/mmu/sources/local_fits.py b/v5/mmu/sources/local_fits.py new file mode 100644 index 0000000..78c2893 --- /dev/null +++ b/v5/mmu/sources/local_fits.py @@ -0,0 +1,35 @@ +"""Local FITS fallback readers for APOGEE/SDSS-style sources.""" +from __future__ import annotations + +import numpy as np +from typing import Any, cast + + +def dedup_highest_snr(ids, snr) -> np.ndarray: + ids_arr = np.asarray(ids) + snr_arr = np.asarray(snr, dtype=float) + keep: dict[object, int] = {} + for i, ident in enumerate(ids_arr): + if ident not in keep or snr_arr[i] > snr_arr[keep[ident]]: + keep[ident] = i + return np.array(sorted(keep.values()), dtype=int) + + +def crop_or_pad_flux(flux, length: int) -> np.ndarray: + arr = np.asarray(flux, dtype=np.float32).reshape(-1) + out = np.full(length, np.nan, dtype=np.float32) + n = min(length, len(arr)) + out[:n] = arr[:n] + return out + + +def read_apogee(path, *, snr_min: float = 50.0, flux_len: int = 7514): + from astropy.io import fits + + data = cast(Any, fits.getdata(path)) + mask = np.asarray(data["snr"], dtype=float) >= snr_min + data = data[mask] + keep = dedup_highest_snr(data["apogee_id"], data["snr"]) + data = data[keep] + return {"ra": np.asarray(data["ra"], dtype=float), "dec": np.asarray(data["dec"], dtype=float), + "flux": np.stack([crop_or_pad_flux(f, flux_len) for f in data["flux"]])} diff --git a/v5/mmu/sources/lsdb_mmu.py b/v5/mmu/sources/lsdb_mmu.py new file mode 100644 index 0000000..4d5ff91 --- /dev/null +++ b/v5/mmu/sources/lsdb_mmu.py @@ -0,0 +1,25 @@ +"""Generic MMU/HATS reader helpers over lsdb.""" +from __future__ import annotations +from importlib import import_module + + +def collection_uri(org: str, dataset: str) -> str: + if not org or not dataset: + raise ValueError("org and dataset are required") + return f"hf://datasets/{org}/{dataset}" + + +def candidate_dict_from_crossmatch(seed_idx, src_idx, sep_arcsec, quality): + out: dict[int, list[tuple[float, int, float]]] = {} + for s, c, d, q in zip(seed_idx, src_idx, sep_arcsec, quality): + out.setdefault(int(s), []).append((float(d), int(c), float(q))) + return out + + +def read_pixel_crossmatch(seed_cat, org: str, dataset: str, columns, order: int, pixel: int, radius_arcsec: float): + lsdb = import_module("lsdb") + + src = lsdb.open_catalog(collection_uri(org, dataset), + search_filter=lsdb.PixelSearch([(order, pixel)]), + columns=list(columns)) + return seed_cat.crossmatch(src, radius_arcsec=radius_arcsec, n_neighbors=5).compute() diff --git a/v5/mmu/sources/ztf_s3.py b/v5/mmu/sources/ztf_s3.py new file mode 100644 index 0000000..99a600e --- /dev/null +++ b/v5/mmu/sources/ztf_s3.py @@ -0,0 +1,34 @@ +"""ZTF public S3/HATS helpers.""" +from __future__ import annotations + +import numpy as np +from importlib import import_module + + +def s3_hats_uri(dr: str = "dr24", *, kind: str = "lc") -> str: + if kind not in {"lc", "objects"}: + raise ValueError("kind must be lc or objects") + return f"s3://ipac-irsa-ztf/ztf/enhanced/{dr}/{kind}/hats" + + +def lightcurve_to_fixed(times, mags, errs, *, max_len: int) -> dict[str, np.ndarray]: + if max_len <= 0: + raise ValueError("max_len must be positive") + arrays = [np.asarray(x, dtype=np.float32) for x in (times, mags, errs)] + n = min(max_len, *(len(a) for a in arrays)) + out = {} + for name, arr in zip(("time", "mag", "err"), arrays): + fixed = np.full(max_len, np.nan, dtype=np.float32) + fixed[:n] = arr[:n] + out[name] = fixed + mask = np.zeros(max_len, dtype=bool) + mask[:n] = True + out["valid"] = mask + return out + + +def read_pixel(order: int, pixel: int, *, dr: str = "dr24"): + lsdb = import_module("lsdb") + + return lsdb.open_catalog(s3_hats_uri(dr), storage_options={"anon": True}, + search_filter=lsdb.PixelSearch([(order, pixel)])).compute() diff --git a/v5/pyproject.toml b/v5/pyproject.toml new file mode 100644 index 0000000..1395386 --- /dev/null +++ b/v5/pyproject.toml @@ -0,0 +1,20 @@ +[build-system] +requires = ["setuptools>=68"] +build-backend = "setuptools.build_meta" + +[project] +name = "omnisky" +version = "0.0.0" +requires-python = ">=3.11" + +[tool.setuptools.packages.find] +include = ["mmu*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-q" +markers = ["network: requires outbound internet (skipped by default)"] + +[tool.mypy] +files = ["mmu"] +strict = true diff --git a/v5/scripts/__init__.py b/v5/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/v5/scripts/build_seed_catalogs.py b/v5/scripts/build_seed_catalogs.py new file mode 100644 index 0000000..0cedb95 --- /dev/null +++ b/v5/scripts/build_seed_catalogs.py @@ -0,0 +1,107 @@ +"""Build population seed catalogs.""" +from __future__ import annotations + +import argparse +import csv +from pathlib import Path +from typing import Any + +import numpy as np + +from mmu.ids import assign_global_id +from mmu.records import Record, seed_path, write_jsonl +from mmu.sources.local_fits import dedup_highest_snr + + +POPULATION_NAMESPACE = {"galaxy": 0, "star": 1, "agn": 2} + + +def namespace_ids(ids, population: str) -> np.ndarray: + """Return a collision-free population-scoped object UID. + + Order-29 HEALPix IDs use roughly 62 bits. Three populations cannot be + packed injectively into a signed int64 without losing HEALPix fidelity, so + the namespaced release key is a string while ``global_object_id`` remains + the raw MMU-compatible int64 HEALPix index. + """ + if population not in POPULATION_NAMESPACE: + raise ValueError(f"unknown population: {population}") + base = np.asarray(ids, dtype=np.int64) + return np.asarray([f"{population}:{int(value)}" for value in base], dtype=object) + + +def assign_ids_and_assert_unique(ra, dec, *, population: str = "galaxy") -> np.ndarray: + ids = namespace_ids(assign_global_id(ra, dec), population) + if len(np.unique(ids)) != len(ids): + raise ValueError("global_object_id collision in seed") + return ids + + +def build_synthetic_seed(*, population: str, n: int) -> list[Record]: + if n <= 0: + raise ValueError("n must be positive") + offsets = {"galaxy": 10.0, "star": 110.0, "agn": 210.0} + if population not in offsets: + raise ValueError(f"unknown population: {population}") + ra = offsets[population] + np.arange(n, dtype=float) * 0.01 + dec = np.full(n, 2.0 if population != "agn" else -2.0, dtype=float) + global_ids = assign_global_id(ra, dec) + object_uids = namespace_ids(global_ids, population) + rows: list[Record] = [] + for i in range(n): + row: Record = { + "object_uid": str(object_uids[i]), + "global_object_id": int(global_ids[i]), + "seed_ra_deg": float(ra[i]), + "seed_dec_deg": float(dec[i]), + "population": population, + "native_id": f"{population}-{i:04d}", + } + if population == "star": + row.update({"pmra": 1000.0 if i == 0 else 5.0, "pmdec": 0.0, + "parallax": 50.0 if i == 0 else 1.0, "rv": 0.0}) + rows.append(row) + return rows + + +def read_seed_csv(path: str | Path, *, population: str) -> list[Record]: + with Path(path).open(newline="") as f: + raw = list(csv.DictReader(f)) + ra = np.asarray([float(row["ra"]) for row in raw], dtype=float) + dec = np.asarray([float(row["dec"]) for row in raw], dtype=float) + global_ids = assign_global_id(ra, dec) + object_uids = namespace_ids(global_ids, population) + rows: list[Record] = [] + for i, row in enumerate(raw): + out: Record = dict(row) + out.update({"object_uid": str(object_uids[i]), "global_object_id": int(global_ids[i]), + "seed_ra_deg": float(ra[i]), "seed_dec_deg": float(dec[i]), + "population": population}) + rows.append(out) + if len({row["object_uid"] for row in rows}) != len(rows): + raise ValueError("object_uid collision in seed") + return rows + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--population", choices=["galaxy", "star", "agn"], required=True) + ap.add_argument("--out", required=True) + ap.add_argument("--release-root", default=None) + ap.add_argument("--test-mode", action="store_true") + ap.add_argument("--n", type=int, default=5) + ap.add_argument("--input-csv", default=None) + args = ap.parse_args() + if args.test_mode: + records = build_synthetic_seed(population=args.population, n=args.n) + elif args.input_csv: + records = read_seed_csv(args.input_csv, population=args.population) + else: + raise SystemExit("provide --test-mode or --input-csv for seed construction") + out = seed_path(args.release_root, population=args.population) if args.release_root else Path(args.out) + write_jsonl(records, out) + print(f"wrote {len(records)} {args.population} seeds to {out}") + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/check_phase0_gate.py b/v5/scripts/check_phase0_gate.py new file mode 100644 index 0000000..47bd237 --- /dev/null +++ b/v5/scripts/check_phase0_gate.py @@ -0,0 +1,44 @@ +"""Phase-0 acceptance gate: internet reachable + sources reachable + crossmatch reproduces +Smith42. Exit non-zero if the gate fails (so the sbatch job surfaces failure).""" +from __future__ import annotations +import argparse +import json +import sys +from typing import Any + +def evaluate_gate(probe: dict[str, Any], crossmatch: dict[str, Any], min_recall: float = 0.8) -> dict[str, Any]: + reasons: list[str] = [] + if not probe.get("internet_ok"): + reasons.append("compute-node internet NOT reachable (switch to Globus pre-stage)") + if not probe.get("sources"): + reasons.append("no sources probed") + if any(not s["reachable"] for s in probe.get("sources", [])): + bad = [s["name"] for s in probe["sources"] if not s["reachable"]] + reasons.append(f"unreachable sources: {bad}") + if any(s.get("n_rows_sampled", 0) <= 0 for s in probe.get("sources", [])): + empty = [s["name"] for s in probe["sources"] if s.get("n_rows_sampled", 0) <= 0] + reasons.append(f"sources sampled zero rows: {empty}") + if crossmatch.get("n_ref_footprint", crossmatch.get("n_ref", 0)) <= 0: + reasons.append("zero Smith42 reference rows in probed footprint") + if crossmatch.get("recall", 0.0) < min_recall: + reasons.append(f"crossmatch recall {crossmatch.get('recall')} < {min_recall}") + if crossmatch.get("n_matched_pixel", 0) <= 0: + reasons.append("zero matches in probe pixel") + return {"passed": len(reasons) == 0, "reasons": reasons} + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--probe", required=True) + ap.add_argument("--crossmatch", required=True) + ap.add_argument("--min-recall", type=float, default=0.8) + args = ap.parse_args() + with open(args.probe) as f: + probe = json.load(f) + with open(args.crossmatch) as f: + crossmatch = json.load(f) + res = evaluate_gate(probe, crossmatch, args.min_recall) + print(json.dumps(res, indent=2)) + sys.exit(0 if res["passed"] else 1) + +if __name__ == "__main__": + main() diff --git a/v5/scripts/convert_to_hats.py b/v5/scripts/convert_to_hats.py new file mode 100644 index 0000000..1c7f13f --- /dev/null +++ b/v5/scripts/convert_to_hats.py @@ -0,0 +1,22 @@ +"""Convert non-HATS sources to HATS or explicitly flag them.""" +from __future__ import annotations + +from dataclasses import asdict +from typing import Any + +from mmu.config import Source + + +def needs_conversion(source: Source) -> bool: + return source.kind not in {"lsdb_mmu", "ztf_s3"} + + +def conversion_report(results: list[dict[str, Any]]) -> dict[str, Any]: + converted = sorted([r["name"] for r in results if r.get("converted")]) + flagged = sorted([{ "name": r["name"], "reason": r.get("reason", "unknown")} + for r in results if not r.get("converted")], key=lambda r: r["name"]) + return {"converted": converted, "flagged_unconvertible": flagged} + + +def source_to_dict(source: Source) -> dict[str, Any]: + return asdict(source) diff --git a/v5/scripts/false_match_report.py b/v5/scripts/false_match_report.py new file mode 100644 index 0000000..bc6d8cc --- /dev/null +++ b/v5/scripts/false_match_report.py @@ -0,0 +1,70 @@ +"""False-match helpers for galaxy and stellar null tests.""" +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any + +import numpy as np + + +def random_offsets(*, n: int, r_min_arcsec: float, r_max_arcsec: float, seed: int) -> tuple[np.ndarray, np.ndarray]: + if n < 0 or r_min_arcsec < 0 or r_max_arcsec < r_min_arcsec: + raise ValueError("invalid annulus") + rng = np.random.default_rng(seed) + theta = rng.uniform(0, 2 * np.pi, n) + r = np.sqrt(rng.uniform(r_min_arcsec ** 2, r_max_arcsec ** 2, n)) + return r * np.cos(theta), r * np.sin(theta) + + +def gate_false_match(fmr_by_bin: dict[str, float], *, threshold: float = 0.001) -> dict[str, Any]: + passed = sorted([k for k, v in fmr_by_bin.items() if v <= threshold]) + low = sorted([k for k, v in fmr_by_bin.items() if v > threshold]) + return {"passed_bins": passed, "low_confidence_bins": low} + + +def pm_scramble(pmra, pmdec, *, seed: int) -> tuple[np.ndarray, np.ndarray]: + pmra = np.asarray(pmra, dtype=float) + pmdec = np.asarray(pmdec, dtype=float) + mag = np.hypot(pmra, pmdec) + rng = np.random.default_rng(seed) + theta = rng.uniform(0, 2 * np.pi, len(mag)) + return mag * np.cos(theta), mag * np.sin(theta) + + +def confusion_radius(match_radius, pm_mas_yr, dt_yr, parallax_mas, sigma_astro) -> np.ndarray: + return np.sqrt(np.asarray(match_radius) ** 2 + (np.asarray(pm_mas_yr) * dt_yr / 1000.0) ** 2 + + (np.asarray(parallax_mas) / 1000.0) ** 2 + np.asarray(sigma_astro) ** 2) + + +def parse_bin_values(values: list[str]) -> dict[str, float]: + out: dict[str, float] = {} + for value in values: + if "=" not in value: + raise ValueError(f"bin values must be name=value, got {value!r}") + name, raw = value.split("=", 1) + out[name] = float(raw) + return out + + +def build_report(fmr_by_bin: dict[str, float], *, threshold: float) -> dict[str, Any]: + gate = gate_false_match(fmr_by_bin, threshold=threshold) + return {"threshold": threshold, "fmr_by_bin": fmr_by_bin, **gate, + "passed": len(gate["low_confidence_bins"]) == 0} + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--out", required=True) + ap.add_argument("--threshold", type=float, default=0.001) + ap.add_argument("--bin", action="append", default=[], help="false-match bin as name=value") + args = ap.parse_args() + fmr_by_bin = parse_bin_values(args.bin) if args.bin else {"all": 0.0} + report = build_report(fmr_by_bin, threshold=args.threshold) + Path(args.out).write_text(json.dumps(report, indent=2, sort_keys=True)) + print(json.dumps(report, indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/finalize_release.py b/v5/scripts/finalize_release.py new file mode 100644 index 0000000..e1aa7f8 --- /dev/null +++ b/v5/scripts/finalize_release.py @@ -0,0 +1,46 @@ +"""Aggregate finalized shards into a release manifest.""" +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any + +from mmu.records import read_jsonl, write_jsonl + + +def summarize_final_shards(final_paths: list[Path]) -> dict[str, Any]: + total = 0 + populations: dict[str, int] = {} + splits: dict[str, int] = {} + for path in final_paths: + rows = read_jsonl(path) + total += len(rows) + for row in rows: + populations[row["population"]] = populations.get(row["population"], 0) + 1 + splits[row.get("split", "unknown")] = splits.get(row.get("split", "unknown"), 0) + 1 + return {"n_rows": total, "populations": populations, "splits": splits, + "shards": [str(p) for p in final_paths]} + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--release-root", required=True) + args = ap.parse_args() + root = Path(args.release_root) + final_paths = sorted((root / "final").glob("population=*/shard=*.jsonl")) + if not final_paths: + raise SystemExit(f"no final shards under {root / 'final'}") + manifest = summarize_final_shards(final_paths) + release_dir = root / "release" + release_dir.mkdir(parents=True, exist_ok=True) + (release_dir / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True)) + combined: list[dict[str, Any]] = [] + for path in final_paths: + combined.extend(read_jsonl(path)) + write_jsonl(combined, release_dir / "data.jsonl") + print(f"wrote release manifest with {manifest['n_rows']} rows to {release_dir}") + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/finalize_shard.py b/v5/scripts/finalize_shard.py new file mode 100644 index 0000000..85fd56d --- /dev/null +++ b/v5/scripts/finalize_shard.py @@ -0,0 +1,99 @@ +"""Finalize one seed shard by joining source shards and enforcing release criteria.""" +from __future__ import annotations + +import argparse +import hashlib +from pathlib import Path +from typing import Any + +import numpy as np + +from mmu.io_atomic import unit_state, write_done_marker +from mmu.records import final_shard_path, read_jsonl, source_shard_path, write_jsonl + + +def enforce_min_instruments(mask, *, min_instruments: int) -> np.ndarray: + return np.asarray(mask, dtype=bool).sum(axis=1) >= min_instruments + + +def assign_split(ra, dec, *, nside: int = 8, seed: int = 42) -> np.ndarray: + from astropy_healpix import HEALPix + import astropy.units as u + + hp = HEALPix(nside=nside, order="nested") + pix = np.asarray(hp.lonlat_to_healpix(np.asarray(ra) * u.deg, np.asarray(dec) * u.deg)) + labels = [] + for p in pix: + digest = hashlib.sha256(f"{seed}:{int(p)}".encode()).digest()[0] / 255.0 + labels.append("train" if digest < 0.8 else "val" if digest < 0.9 else "test") + return np.asarray(labels, dtype=object) + + +def finalize_records(source_rows_by_source: dict[str, list[dict[str, Any]]], *, min_instruments: int = 2) -> list[dict[str, Any]]: + by_uid: dict[str, dict[str, Any]] = {} + for source, rows in source_rows_by_source.items(): + for row in rows: + uid = str(row["object_uid"]) + current = by_uid.setdefault(uid, { + "object_uid": uid, + "global_object_id": int(row["global_object_id"]), + "population": row["population"], + "seed_ra_deg": float(row["match_ra_deg"]), + "seed_dec_deg": float(row["match_dec_deg"]), + "sources": [], + "instrument_presence_mask": [], + "match_sep_arcsec": {}, + "match_ambiguous": {}, + "n_candidates_within_radius": {}, + }) + current["sources"].append(source) + current["instrument_presence_mask"].append(bool(row.get("instrument_present", True))) + current["match_sep_arcsec"][source] = float(row.get("match_sep_arcsec", 0.0)) + current["match_ambiguous"][source] = bool(row.get("match_ambiguous", False)) + current["n_candidates_within_radius"][source] = int(row.get("n_candidates_within_radius", 1)) + out: list[dict[str, Any]] = [] + for row in by_uid.values(): + row["n_instruments_present"] = int(sum(row["instrument_presence_mask"])) + if row["n_instruments_present"] >= min_instruments: + out.append(row) + if out: + split = assign_split(np.asarray([r["seed_ra_deg"] for r in out]), + np.asarray([r["seed_dec_deg"] for r in out])) + for row, label in zip(out, split): + row["split"] = str(label) + row["low_confidence"] = False + return out + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--release-root", required=True) + ap.add_argument("--shard", type=int, required=True) + ap.add_argument("--population", default="galaxy") + ap.add_argument("--sources", required=True) + ap.add_argument("--manifest", required=True) + ap.add_argument("--schema-version", default="v5.1") + ap.add_argument("--code-sha", default="local") + ap.add_argument("--min-instruments", type=int, default=2) + args = ap.parse_args() + with open(args.manifest) as f: + manifest = __import__("json").load(f) + sources = [s for s in args.sources.split(",") if s] + source_rows: dict[str, list[dict[str, Any]]] = {} + completed: set[str] = set() + for source in sources: + path = source_shard_path(args.release_root, population=args.population, source=source, shard=args.shard) + state = unit_state(path, manifest_hash=manifest["manifest_hash"], schema_version=args.schema_version, code_sha=args.code_sha) + if state != "complete": + raise SystemExit(f"source shard not complete: {path} state={state}") + source_rows[source] = read_jsonl(path) + completed.add(source) + rows = finalize_records(source_rows, min_instruments=args.min_instruments) + out = final_shard_path(args.release_root, population=args.population, shard=args.shard) + write_jsonl(rows, out) + write_done_marker(out, manifest_hash=manifest["manifest_hash"], schema_version=args.schema_version, code_sha=args.code_sha) + print(f"wrote {len(rows)} final rows to {out}") + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/local_lsdb_dry_run.py b/v5/scripts/local_lsdb_dry_run.py new file mode 100644 index 0000000..85e18de --- /dev/null +++ b/v5/scripts/local_lsdb_dry_run.py @@ -0,0 +1,403 @@ +"""Bounded local LSDB smoke test for HATS catalog access. + +This script is intentionally a dry run: it proves that local/HF LSDB catalogs can +be opened, spatially filtered, materialized, and stored without requiring Delta. +It caps bytes written to disk; callers should still keep pixels and columns small +because LSDB may transiently materialize more than the final stored slice. +""" +from __future__ import annotations + +import argparse +import json +from dataclasses import dataclass +from importlib import import_module +from pathlib import Path +from typing import Any, Iterable + +from mmu.config import SOURCES, Source +from mmu.io_atomic import atomic_write_bytes +from mmu.sources.lsdb_mmu import collection_uri + + +DEFAULT_SMITH42_REVISION = "93d0fddf8c5b61028ee0b6d72fd0dbfa87b38624" +DEFAULT_DRY_RUN_COLUMNS = ("ra", "dec") + + +@dataclass(frozen=True, slots=True) +class PixelSpec: + order: int + pixel: int + + +@dataclass(frozen=True, slots=True) +class SearchSpec: + kind: str + label: str + order: int | None = None + pixel: int | None = None + ra: float | None = None + dec: float | None = None + radius_arcsec: float | None = None + + +def parse_byte_budget(value: str) -> int: + """Parse byte budgets like ``1000000000``, ``1GB``, or ``512MiB``.""" + raw = value.strip().lower().replace("_", "") + if not raw: + raise ValueError("byte budget cannot be empty") + suffixes = { + "kib": 1024, + "mib": 1024**2, + "gib": 1024**3, + "kb": 1000, + "mb": 1000**2, + "gb": 1000**3, + "b": 1, + } + for suffix, multiplier in suffixes.items(): + if raw.endswith(suffix): + number = raw[: -len(suffix)] + break + else: + number = raw + multiplier = 1 + try: + budget = int(float(number) * multiplier) + except ValueError as exc: + raise ValueError(f"invalid byte budget: {value!r}") from exc + if budget <= 0: + raise ValueError("byte budget must be positive") + return budget + + +def parse_pixels(values: Iterable[str], *, default_order: int | None = None) -> list[PixelSpec]: + """Parse pixel specs from ``ORDER:PIXEL`` strings or bare pixels.""" + pixels: list[PixelSpec] = [] + for value in values: + raw = value.strip() + if not raw: + continue + if ":" in raw: + order_s, pixel_s = raw.split(":", 1) + order = int(order_s) + pixel = int(pixel_s) + else: + if default_order is None: + raise ValueError("bare pixels require --order") + order = default_order + pixel = int(raw) + if order < 0 or pixel < 0: + raise ValueError("order and pixel must be non-negative") + pixels.append(PixelSpec(order=order, pixel=pixel)) + if not pixels: + raise ValueError("at least one pixel is required") + return pixels + + +def parse_cones(values: Iterable[str]) -> list[SearchSpec]: + """Parse cone specs from ``RA,DEC,RADIUS_ARCSEC`` strings.""" + cones: list[SearchSpec] = [] + for value in values: + parts = [part.strip() for part in value.split(",")] + if len(parts) != 3 or any(not part for part in parts): + raise ValueError("cone specs must use RA,DEC,RADIUS_ARCSEC") + ra, dec, radius_arcsec = (float(part) for part in parts) + if not (0.0 <= ra < 360.0): + raise ValueError("cone RA must be in [0, 360)") + if not (-90.0 <= dec <= 90.0): + raise ValueError("cone Dec must be in [-90, 90]") + if radius_arcsec <= 0: + raise ValueError("cone radius must be positive") + label = f"cone_ra={ra:.6f}_dec={dec:.6f}_radius_arcsec={radius_arcsec:g}" + cones.append(SearchSpec(kind="cone", label=label, ra=ra, dec=dec, radius_arcsec=radius_arcsec)) + if not cones: + raise ValueError("at least one cone is required") + return cones + + +def pixel_searches(pixels: Iterable[PixelSpec]) -> list[SearchSpec]: + return [SearchSpec(kind="pixel", label=f"order={pixel.order}/pixel={pixel.pixel:06d}", + order=pixel.order, pixel=pixel.pixel) + for pixel in pixels] + + +def parse_columns(value: str | None) -> list[str]: + if value is None or not value.strip(): + return list(DEFAULT_DRY_RUN_COLUMNS) + columns = [part.strip() for part in value.split(",") if part.strip()] + if not columns: + raise ValueError("at least one column is required") + if "ra" not in columns or "dec" not in columns: + columns = ["ra", "dec", *[col for col in columns if col not in {"ra", "dec"}]] + return columns + + +def select_sources(names: str) -> list[Source]: + selected: list[Source] = [] + for name in [part.strip() for part in names.split(",") if part.strip()]: + if name not in SOURCES: + raise ValueError(f"unknown source: {name}") + source = SOURCES[name] + if source.kind != "lsdb_mmu": + raise ValueError(f"source {name!r} is {source.kind!r}, not an LSDB MMU HATS source") + selected.append(source) + if not selected: + raise ValueError("at least one source is required") + return selected + + +def source_catalog_uri(source: Source, overrides: dict[str, str]) -> str: + return overrides.get(source.name, collection_uri(source.org, source.dataset)) + + +def import_lsdb_or_exit() -> Any: + try: + return import_module("lsdb") + except ModuleNotFoundError as exc: + if exc.name != "lsdb": + raise + raise SystemExit( + "LSDB is not installed in this Python environment.\n" + "Create the project conda env with `conda env create -f environment.yml` " + "or install it into the active env with `python -m pip install lsdb`.\n" + "Then rerun with that env's Python, for example: " + "`python -m scripts.local_lsdb_dry_run --out-dir /tmp/omnisky-lsdb-dry-run`." + ) from exc + + +def parse_catalog_overrides(values: Iterable[str]) -> dict[str, str]: + overrides: dict[str, str] = {} + for value in values: + if "=" not in value: + raise ValueError("catalog overrides must use SOURCE=URI") + name, uri = value.split("=", 1) + name = name.strip() + uri = uri.strip() + if not name or not uri: + raise ValueError("catalog override source and URI must be non-empty") + overrides[name] = uri + return overrides + + +def output_path(out_dir: str | Path, *, source: str, pixel: PixelSpec, suffix: str) -> Path: + return Path(out_dir) / f"source={source}" / f"order={pixel.order}" / f"pixel={pixel.pixel:06d}.{suffix}" + + +def output_search_path(out_dir: str | Path, *, source: str, search: SearchSpec, suffix: str) -> Path: + if search.kind == "pixel": + if search.order is None or search.pixel is None: + raise ValueError("pixel search missing order/pixel") + return output_path(out_dir, source=source, pixel=PixelSpec(search.order, search.pixel), suffix=suffix) + safe_label = search.label.replace("/", "_").replace("=", "-").replace(",", "_") + return Path(out_dir) / f"source={source}" / f"{safe_label}.{suffix}" + + +def dataframe_bytes(df: Any) -> int: + usage = df.memory_usage(index=True, deep=True) + return int(usage.sum() if hasattr(usage, "sum") else usage) + + +def rows_that_fit(*, n_rows: int, frame_bytes: int, remaining_bytes: int) -> int: + if n_rows <= 0 or frame_bytes <= 0 or remaining_bytes <= 0: + return 0 + if frame_bytes <= remaining_bytes: + return n_rows + approx_bytes_per_row = max(frame_bytes / n_rows, 1.0) + return max(0, min(n_rows, int(remaining_bytes / approx_bytes_per_row))) + + +def dataframe_records(df: Any) -> list[dict[str, Any]]: + records = df.to_dict(orient="records") + return [dict(row) for row in records] + + +def write_dataframe(df: Any, path: Path, *, prefer_parquet: bool) -> tuple[str, int]: + path.parent.mkdir(parents=True, exist_ok=True) + if prefer_parquet: + try: + import_module("pyarrow") + except ModuleNotFoundError: + prefer_parquet = False + if prefer_parquet: + parquet_path = path.with_suffix(".parquet") + tmp = parquet_path.with_suffix(parquet_path.suffix + ".tmp") + df.to_parquet(tmp) + tmp.replace(parquet_path) + return str(parquet_path), parquet_path.stat().st_size + jsonl_path = path.with_suffix(".jsonl") + payload = "".join(json.dumps(row, sort_keys=True, default=str) + "\n" for row in dataframe_records(df)) + atomic_write_bytes(payload.encode(), jsonl_path) + return str(jsonl_path), jsonl_path.stat().st_size + + +def is_no_coverage_error(exc: Exception) -> bool: + return isinstance(exc, ValueError) and "no coverage" in str(exc).lower() + + +def make_search_filter(*, lsdb: Any, search: SearchSpec) -> Any: + if search.kind == "pixel": + if search.order is None or search.pixel is None: + raise ValueError("pixel search missing order/pixel") + return lsdb.PixelSearch([(search.order, search.pixel)]) + if search.kind == "cone": + if search.ra is None or search.dec is None or search.radius_arcsec is None: + raise ValueError("cone search missing ra/dec/radius") + return lsdb.ConeSearch(ra=search.ra, dec=search.dec, radius_arcsec=search.radius_arcsec) + raise ValueError(f"unknown search kind: {search.kind}") + + +def fetch_search_dataframe(*, lsdb: Any, uri: str, search: SearchSpec, columns: list[str]) -> Any: + search_filter = make_search_filter(lsdb=lsdb, search=search) + catalog = lsdb.open_catalog(uri, search_filter=search_filter, columns=columns) + return catalog.compute() + + +def run_crossmatch_smoke(*, lsdb: Any, seed_df: Any, target_uri: str, columns: list[str], radius_arcsec: float, max_seed_rows: int, search: SearchSpec | None) -> dict[str, Any]: + if len(seed_df) == 0: + return {"attempted": False, "reason": "no seed rows"} + seed = seed_df[["ra", "dec"]].head(max_seed_rows).copy() + seed["dry_seed_id"] = range(len(seed)) + seed_catalog = lsdb.from_dataframe(seed, ra_column="ra", dec_column="dec") + try: + if search is None: + target = lsdb.open_catalog(target_uri, columns=columns) + else: + target = lsdb.open_catalog(target_uri, columns=columns, search_filter=make_search_filter(lsdb=lsdb, search=search)) + matched = seed_catalog.crossmatch(target, radius_arcsec=radius_arcsec).compute() + except ValueError as exc: + if is_no_coverage_error(exc): + return {"attempted": False, "reason": "target search region has no coverage"} + raise + return {"attempted": True, "seed_rows": int(len(seed)), "matched_rows": int(len(matched))} + + +def build_report(*, max_bytes: int, stored_bytes: int, fetches: list[dict[str, Any]], crossmatch: dict[str, Any] | None) -> dict[str, Any]: + return { + "max_bytes": int(max_bytes), + "stored_bytes": int(stored_bytes), + "remaining_bytes": int(max_bytes - stored_bytes), + "cap_reached": stored_bytes >= max_bytes, + "fetches": fetches, + "crossmatch": crossmatch or {"attempted": False}, + } + + +def main() -> None: + ap = argparse.ArgumentParser(description="Bounded local LSDB HATS dry run") + ap.add_argument("--sources", default="desi,hsc", help="comma-separated LSDB source names from mmu.config") + ap.add_argument("--order", type=int, default=4, help="default order for bare --pixels values") + ap.add_argument("--pixels", nargs="+", default=["257"], help="pixel specs as PIXEL or ORDER:PIXEL") + ap.add_argument("--cone", action="append", default=[], help="use cone search RA,DEC,RADIUS_ARCSEC; can be repeated and overrides --pixels") + ap.add_argument("--columns", default=None, help="comma-separated columns; ra,dec are added if omitted") + ap.add_argument("--max-bytes", default="1GB", help="maximum bytes to store locally, e.g. 1GB or 512MiB") + ap.add_argument("--out-dir", required=True) + ap.add_argument("--catalog", action="append", default=[], help="override catalog URI as SOURCE=URI") + ap.add_argument("--jsonl", action="store_true", help="write JSONL instead of Parquet") + ap.add_argument("--crossmatch", action="store_true", help="run a tiny first-source to second-source LSDB crossmatch smoke") + ap.add_argument("--crossmatch-radius-arcsec", type=float, default=1.0) + ap.add_argument("--max-crossmatch-seed-rows", type=int, default=100) + ap.add_argument("--smith42-revision", default=DEFAULT_SMITH42_REVISION, + help="document the Smith42 revision paired with this dry run") + args = ap.parse_args() + + max_bytes = parse_byte_budget(args.max_bytes) + searches = parse_cones(args.cone) if args.cone else pixel_searches(parse_pixels(args.pixels, default_order=args.order)) + columns = parse_columns(args.columns) + sources = select_sources(args.sources) + overrides = parse_catalog_overrides(args.catalog) + lsdb = import_lsdb_or_exit() + + stored_bytes = 0 + fetches: list[dict[str, Any]] = [] + first_frame: Any | None = None + first_search: SearchSpec | None = None + first_uri: str | None = None + second_uri: str | None = None + stop = False + for source in sources: + uri = source_catalog_uri(source, overrides) + if first_uri is None: + first_uri = uri + elif second_uri is None: + second_uri = uri + for search in searches: + if stop: + break + try: + df = fetch_search_dataframe(lsdb=lsdb, uri=uri, search=search, columns=columns) + except ValueError as exc: + if not is_no_coverage_error(exc): + raise + fetches.append({ + "source": source.name, + "uri": uri, + "search": search.label, + "search_kind": search.kind, + "order": search.order, + "pixel": search.pixel, + "columns": columns, + "fetched_rows": 0, + "fetched_memory_bytes": 0, + "stored_rows": 0, + "stored_bytes": 0, + "output": None, + "truncated_to_fit_budget": False, + "error": "no_coverage", + }) + continue + frame_bytes = dataframe_bytes(df) + fit_rows = rows_that_fit(n_rows=len(df), frame_bytes=frame_bytes, remaining_bytes=max_bytes - stored_bytes) + truncated = fit_rows < len(df) + out_file = None + written_bytes = 0 + if fit_rows > 0: + to_write = df.head(fit_rows).copy() + if first_frame is None: + first_frame = to_write + first_search = search + base = output_search_path(args.out_dir, source=source.name, search=search, suffix="parquet") + out_file, written_bytes = write_dataframe(to_write, base, prefer_parquet=not args.jsonl) + stored_bytes += written_bytes + fetches.append({ + "source": source.name, + "uri": uri, + "search": search.label, + "search_kind": search.kind, + "order": search.order, + "pixel": search.pixel, + "columns": columns, + "fetched_rows": int(len(df)), + "fetched_memory_bytes": int(frame_bytes), + "stored_rows": int(fit_rows), + "stored_bytes": int(written_bytes), + "output": out_file, + "truncated_to_fit_budget": truncated, + }) + stop = stored_bytes >= max_bytes or (truncated and fit_rows == 0) + if stop: + break + + crossmatch = {"attempted": False} + if args.crossmatch: + if first_frame is None or second_uri is None: + crossmatch = {"attempted": False, "reason": "need stored rows and at least two sources"} + else: + crossmatch = run_crossmatch_smoke( + lsdb=lsdb, + seed_df=first_frame, + target_uri=second_uri, + columns=columns, + radius_arcsec=args.crossmatch_radius_arcsec, + max_seed_rows=args.max_crossmatch_seed_rows, + search=first_search, + ) + + report = build_report(max_bytes=max_bytes, stored_bytes=stored_bytes, fetches=fetches, crossmatch=crossmatch) + report["smith42_revision"] = args.smith42_revision + report_path = Path(args.out_dir) / "local_lsdb_dry_run_report.json" + atomic_write_bytes(json.dumps(report, indent=2, sort_keys=True).encode(), report_path) + print(f"stored={stored_bytes} max={max_bytes} report={report_path}") + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/plan_work_units.py b/v5/scripts/plan_work_units.py new file mode 100644 index 0000000..4b7b0e8 --- /dev/null +++ b/v5/scripts/plan_work_units.py @@ -0,0 +1,60 @@ +"""Create the authoritative source × shard manifest.""" +from __future__ import annotations + +import argparse +import hashlib +import json +from pathlib import Path +from typing import Any + +from mmu.coordination import build_or_attach_manifest +from mmu.records import read_jsonl, seed_path + + +def enumerate_units(*, sources: list[str], n_objects: int, shard_size: int, population: str = "galaxy") -> list[dict[str, Any]]: + if n_objects < 0 or shard_size <= 0: + raise ValueError("n_objects must be >=0 and shard_size >0") + n_shards = (n_objects + shard_size - 1) // shard_size + return [{"population": population, "source": source, "shard": shard, + "row_start": shard * shard_size, "row_end": min((shard + 1) * shard_size, n_objects)} + for shard in range(n_shards) for source in sources] + + +def inputs_hash_for_seed(seed_file: Path, *, sources: list[str], shard_size: int) -> str: + h = hashlib.sha256() + h.update(seed_file.read_bytes()) + h.update(json.dumps({"sources": sources, "shard_size": shard_size}, sort_keys=True).encode()) + return h.hexdigest()[:16] + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--sources", required=True) + ap.add_argument("--n-objects", type=int, required=True) + ap.add_argument("--shard-size", type=int, default=50_000) + ap.add_argument("--population", default="galaxy") + ap.add_argument("--out", required=True) + ap.add_argument("--release-root", default=None) + ap.add_argument("--seed", default=None) + ap.add_argument("--inputs-hash", default="manual") + args = ap.parse_args() + seed_file = Path(args.seed) if args.seed else (seed_path(args.release_root, population=args.population) if args.release_root else None) + n_objects = args.n_objects + inputs_hash = args.inputs_hash + sources = [s for s in args.sources.split(",") if s] + if seed_file is not None and seed_file.exists(): + n_objects = len(read_jsonl(seed_file)) + if args.inputs_hash == "manual": + inputs_hash = inputs_hash_for_seed(seed_file, sources=sources, shard_size=args.shard_size) + units = enumerate_units(sources=args.sources.split(","), n_objects=args.n_objects, + shard_size=args.shard_size, population=args.population) + units = enumerate_units(sources=sources, n_objects=n_objects, + shard_size=args.shard_size, population=args.population) + out_dir = Path(args.out) + mhash, _ = build_or_attach_manifest(out_dir, units, inputs_hash=inputs_hash) + (out_dir / "work_units.json").write_text(json.dumps({"manifest_hash": mhash, "units": units}, indent=2, sort_keys=True)) + print(f"wrote {len(units)} units to {out_dir / 'work_units.json'}") + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/probe_crossmatch.py b/v5/scripts/probe_crossmatch.py new file mode 100644 index 0000000..5a51e09 --- /dev/null +++ b/v5/scripts/probe_crossmatch.py @@ -0,0 +1,51 @@ +"""Phase-0 proof: one-pixel DESI x HSC LSDB crossmatch + Smith42 concordance.""" +from __future__ import annotations +import argparse +import json +from importlib import import_module + +import numpy as np +from mmu.concordance import filter_reference_to_healpix_pixel, match_concordance + +# NOTE: column suffixes (*_desi/*_hsc) and Smith42 column names must be verified on first cluster run. + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--order", type=int, default=4) + ap.add_argument("--pixel", type=int, default=257) + ap.add_argument("--radius-arcsec", type=float, default=1.0) + ap.add_argument("--tol-arcsec", type=float, default=1.0) + ap.add_argument("--smith42-revision", required=True, + help="immutable Hugging Face dataset revision/commit for Smith42 validation") + ap.add_argument("--out", default="crossmatch_probe.json") + args = ap.parse_args() + lsdb = import_module("lsdb") + load_dataset = import_module("datasets").load_dataset + px = lsdb.PixelSearch([(args.order, args.pixel)]) + + desi = lsdb.open_catalog("hf://datasets/UniverseTBD/mmu_desi_edr_sv3", + search_filter=px, columns=["ra", "dec"]) + hsc = lsdb.open_catalog("hf://datasets/UniverseTBD/mmu_hsc_pdr3_dud_22.5", + search_filter=px, columns=["ra", "dec"]) + matched = desi.crossmatch(hsc, radius_arcsec=args.radius_arcsec, n_neighbors=5).compute() + our_ra = np.asarray(matched["ra_desi"]); our_dec = np.asarray(matched["dec_desi"]) + + ref = load_dataset("Smith42/desi_hsc_crossmatched", split="train", revision=args.smith42_revision) + ra_col = "ra" if "ra" in ref.column_names else "desi_ra" + dec_col = "dec" if "dec" in ref.column_names else "desi_dec" + ref_ra = np.asarray(ref[ra_col]); ref_dec = np.asarray(ref[dec_col]) + + ref_ra_px, ref_dec_px = filter_reference_to_healpix_pixel( + ref_ra, ref_dec, order=args.order, pixel=args.pixel + ) + conc = match_concordance(our_ra, our_dec, ref_ra_px, ref_dec_px, tol_arcsec=args.tol_arcsec) + conc["n_matched_pixel"] = int(len(matched)) + conc["n_ref_full"] = int(len(ref_ra)) + conc["n_ref_footprint"] = int(len(ref_ra_px)) + with open(args.out, "w") as f: + json.dump(conc, f, indent=2, sort_keys=True) + print(f"matched={len(matched)} recall={conc['recall']:.3f} " + f"median_sep={conc['median_sep_arcsec']:.3f}\" -> {args.out}") + +if __name__ == "__main__": + main() diff --git a/v5/scripts/probe_sources.py b/v5/scripts/probe_sources.py new file mode 100644 index 0000000..a1b1bae --- /dev/null +++ b/v5/scripts/probe_sources.py @@ -0,0 +1,58 @@ +"""Phase-0 probe: reachability + cold/warm latency + throughput for MMU HATS. + +Run on a compute node, not a login node. The output feeds the Phase-0 gate and +later per-service concurrency defaults. +""" +from __future__ import annotations +import argparse +import time +from importlib import import_module + +from mmu.probe_report import ProbeReport, SourceProbe +from mmu.reachability import probe_all + +# (catalog, org) — org namespaces differ; see spec sec 3.2 +SOURCES = { + "mmu_desi_edr_sv3": "UniverseTBD", + "mmu_hsc_pdr3_dud_22.5": "UniverseTBD", +} + +def time_open(uri: str, pixel: tuple[int, int], columns: list[str]): + lsdb = import_module("lsdb") + t0 = time.monotonic() + cat = lsdb.open_catalog(uri, search_filter=lsdb.PixelSearch([pixel]), columns=columns) + df = cat.compute() + return df, time.monotonic() - t0 + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--order", type=int, default=4) + ap.add_argument("--pixel", type=int, default=257) + ap.add_argument("--out", default="probe_report.json") + ap.add_argument("--sources", default="all", help="comma-separated source keys, or all") + args = ap.parse_args() + + reach = probe_all() + rep = ProbeReport(internet_ok=reach["internet_ok"], notes=f"reachability={reach}") + px = (args.order, args.pixel) + selected = SOURCES if args.sources == "all" else {s: SOURCES[s] for s in args.sources.split(",")} + for name, org in selected.items(): + uri = f"hf://datasets/{org}/{name}" + try: + df_cold, cold = time_open(uri, px, ["ra", "dec"]) + _, warm = time_open(uri, px, ["ra", "dec"]) + nbytes = int(df_cold.memory_usage(deep=True).sum()) + tput = (nbytes / 1e6) / cold if cold > 0 else 0.0 + rep.add(SourceProbe(name=name, reachable=True, cold_latency_s=round(cold, 2), + warm_latency_s=round(warm, 2), throughput_mb_s=round(tput, 2), + rate_limited=False, n_rows_sampled=len(df_cold))) + except Exception as e: + rep.add(SourceProbe(name=name, reachable=False, cold_latency_s=-1.0, + warm_latency_s=-1.0, throughput_mb_s=0.0, + rate_limited=False, n_rows_sampled=0, error=repr(e))) + print(f"PROBE FAIL {name}: {e}") + rep.to_json(args.out) + print(f"wrote {args.out}; internet_ok={rep.internet_ok}; sources={len(rep.sources)}") + +if __name__ == "__main__": + main() diff --git a/v5/scripts/run_source_shard.py b/v5/scripts/run_source_shard.py new file mode 100644 index 0000000..1fa7a27 --- /dev/null +++ b/v5/scripts/run_source_shard.py @@ -0,0 +1,179 @@ +"""SLURM array workhorse helpers.""" +from __future__ import annotations + +import argparse +import json +from importlib import import_module +from pathlib import Path +from typing import Any + +import numpy as np + +from mmu.config import SOURCES, Source +from mmu.coordination import partition_units +from mmu.io_atomic import unit_state, write_done_marker +from mmu.records import read_jsonl, seed_path, source_shard_path, write_jsonl + + +def select_unit(units: list[dict[str, Any]], *, task_id: int, partition_id: int, num_partitions: int) -> dict[str, Any]: + selected = partition_units(units, partition_id=partition_id, num_partitions=num_partitions) + if task_id < 0 or task_id >= len(selected): + raise IndexError("task_id outside selected partition") + return selected[task_id] + + +def synthetic_source_rows(seed_rows: list[dict[str, Any]], *, source: str, source_epoch: float = 2016.0) -> list[dict[str, Any]]: + rows: list[dict[str, Any]] = [] + for row in seed_rows: + match_ra = float(row["seed_ra_deg"]) + match_dec = float(row["seed_dec_deg"]) + if row.get("population") == "star" and "pmra" in row: + from mmu.matching import match_unit + prepared = match_unit({"ra": [match_ra], "dec": [match_dec], "pmra": [float(row.get("pmra", 0.0))], + "pmdec": [float(row.get("pmdec", 0.0))], + "parallax": [float(row.get("parallax", 1.0))], + "rv": [float(row.get("rv", 0.0))]}, + {"epoch_jyear": source_epoch}, population="star") + match_ra = float(prepared["ra"][0]) + match_dec = float(prepared["dec"][0]) + rows.append({ + "object_uid": row["object_uid"], + "global_object_id": int(row["global_object_id"]), + "population": row["population"], + "source": source, + "match_ra_deg": match_ra, + "match_dec_deg": match_dec, + "match_sep_arcsec": 0.1, + "match_ambiguous": False, + "n_candidates_within_radius": 1, + "instrument_present": True, + "payload_ref": f"test://{source}/{row['object_uid']}", + }) + return rows + + +def _row_value(row: Any, names: list[str], default: Any = None) -> Any: + for name in names: + if name in row: + return row[name] + return default + + +def lsdb_source_rows(seed_rows: list[dict[str, Any]], source: Source) -> list[dict[str, Any]]: + pandas = import_module("pandas") + lsdb = import_module("lsdb") + from mmu.sources.lsdb_mmu import collection_uri + + seed_df = pandas.DataFrame([ + {"object_uid": row["object_uid"], "global_object_id": row["global_object_id"], + "population": row["population"], "ra": row["seed_ra_deg"], "dec": row["seed_dec_deg"]} + for row in seed_rows + ]) + seed_cat = lsdb.from_dataframe(seed_df, ra_column="ra", dec_column="dec") + src = lsdb.open_catalog(collection_uri(source.org, source.dataset), columns=list(source.columns)) + matched = seed_cat.crossmatch(src, radius_arcsec=source.radius_arcsec, suffixes=("_seed", f"_{source.name}")).compute() + rows: list[dict[str, Any]] = [] + for _, row in matched.iterrows(): + uid = _row_value(row, ["object_uid_seed", "object_uid"]) + if uid is None: + continue + rows.append({ + "object_uid": str(uid), + "global_object_id": int(_row_value(row, ["global_object_id_seed", "global_object_id"])), + "population": str(_row_value(row, ["population_seed", "population"])), + "source": source.name, + "match_ra_deg": float(_row_value(row, [f"ra_{source.name}", "ra"], np.nan)), + "match_dec_deg": float(_row_value(row, [f"dec_{source.name}", "dec"], np.nan)), + "match_sep_arcsec": float(_row_value(row, ["_dist_arcsec"], 0.0)), + "match_ambiguous": False, + "n_candidates_within_radius": 1, + "instrument_present": True, + "payload_ref": f"{source.name}:{uid}", + }) + return rows + + +def legacy_hdf5_source_rows(seed_rows: list[dict[str, Any]], source: Source) -> list[dict[str, Any]]: + import astropy.units as u + from astropy.coordinates import SkyCoord, search_around_sky + from mmu.healpix import seed_pixel_set + from mmu.sources.legacy_hdf5 import pixel_files, read_cutouts + + candidates: list[tuple[float, float, str]] = [] + pixels = seed_pixel_set([row["seed_ra_deg"] for row in seed_rows], + [row["seed_dec_deg"] for row in seed_rows], order=4) + for pixel in pixels: + for path in pixel_files(source.dataset, pixel): + ra, dec, _image = read_cutouts(path) + candidates.extend((float(r), float(d), str(path)) for r, d in zip(ra, dec)) + if not candidates: + raise RuntimeError(f"no Legacy HDF5 candidates found under {source.dataset}") + seed_coord = SkyCoord(np.asarray([r["seed_ra_deg"] for r in seed_rows]) * u.deg, + np.asarray([r["seed_dec_deg"] for r in seed_rows]) * u.deg) + cand_coord = SkyCoord(np.asarray([c[0] for c in candidates]) * u.deg, + np.asarray([c[1] for c in candidates]) * u.deg) + idx_seed, idx_cand, sep, _ = search_around_sky(seed_coord, cand_coord, source.radius_arcsec * u.arcsec) + best: dict[int, tuple[int, float]] = {} + sep_values = np.asarray(sep.arcsec, dtype=np.float64) + for s, c, dist in zip(idx_seed, idx_cand, sep_values): + si = int(s); ci = int(c); dd = float(dist) + if si not in best or dd < best[si][1]: + best[si] = (ci, dd) + rows: list[dict[str, Any]] = [] + for seed_index, (cand_index, sep_arcsec) in best.items(): + seed = seed_rows[seed_index] + cand = candidates[cand_index] + rows.append({"object_uid": seed["object_uid"], "global_object_id": int(seed["global_object_id"]), + "population": seed["population"], "source": source.name, + "match_ra_deg": cand[0], "match_dec_deg": cand[1], + "match_sep_arcsec": sep_arcsec, "match_ambiguous": False, + "n_candidates_within_radius": 1, "instrument_present": True, + "payload_ref": cand[2]}) + return rows + + +def live_source_rows(seed_rows: list[dict[str, Any]], *, source_name: str) -> list[dict[str, Any]]: + if source_name not in SOURCES: + raise ValueError(f"unknown source: {source_name}") + source = SOURCES[source_name] + if source.kind == "lsdb_mmu": + return lsdb_source_rows(seed_rows, source) + if source.kind == "legacy_hdf5": + return legacy_hdf5_source_rows(seed_rows, source) + raise ValueError(f"source kind {source.kind!r} is not supported by run_source_shard yet") + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--manifest", required=True) + ap.add_argument("--task-id", type=int, required=True) + ap.add_argument("--partition-id", type=int, default=0) + ap.add_argument("--num-partitions", type=int, default=1) + ap.add_argument("--release-root", required=True) + ap.add_argument("--schema-version", default="v5.1") + ap.add_argument("--code-sha", default="local") + ap.add_argument("--test-mode", action="store_true") + args = ap.parse_args() + with open(args.manifest) as f: + payload = json.load(f) + unit = select_unit(payload["units"], task_id=args.task_id, + partition_id=args.partition_id, num_partitions=args.num_partitions) + out = source_shard_path(args.release_root, population=unit["population"], + source=unit["source"], shard=int(unit["shard"])) + state = unit_state(out, manifest_hash=payload["manifest_hash"], + schema_version=args.schema_version, code_sha=args.code_sha) + if state == "complete": + print(f"skip complete {out}") + return + seed_rows = read_jsonl(seed_path(args.release_root, population=unit["population"])) + shard_rows = seed_rows[int(unit["row_start"]):int(unit["row_end"])] + rows = (synthetic_source_rows(shard_rows, source=unit["source"]) + if args.test_mode else live_source_rows(shard_rows, source_name=unit["source"])) + write_jsonl(rows, out) + write_done_marker(out, manifest_hash=payload["manifest_hash"], + schema_version=args.schema_version, code_sha=args.code_sha) + print(f"wrote {len(rows)} source rows to {out}") + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/upload_hf.py b/v5/scripts/upload_hf.py new file mode 100644 index 0000000..cd70392 --- /dev/null +++ b/v5/scripts/upload_hf.py @@ -0,0 +1,56 @@ +"""Hugging Face upload helpers.""" +from __future__ import annotations + +import argparse +from importlib import import_module +from pathlib import Path +from typing import Any + + +def build_repo_id(name: str, *, org: str = "UniverseTBD") -> str: + return f"{org}/{name}" + + +def verify_load_back(repo: str) -> dict[str, Any]: + """Verify both metadata and at least one JSON row can be read back from HF.""" + hf = import_module("huggingface_hub") + datasets = import_module("datasets") + hf.hf_hub_download(repo_id=repo, repo_type="dataset", filename="manifest.json") + dataset = datasets.load_dataset( + "json", + data_files=f"hf://datasets/{repo}/data.jsonl", + split="train", + streaming=True, + ) + first = next(iter(dataset), None) + if first is None: + raise RuntimeError("load-back verification found zero rows") + return {"ok": True, "first_keys": sorted(first.keys())} + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--release-root", required=True) + ap.add_argument("--repo", required=True) + ap.add_argument("--dry-run", action="store_true") + ap.add_argument("--skip-load-back", action="store_true") + args = ap.parse_args() + release_dir = Path(args.release_root) / "release" + data = release_dir / "data.jsonl" + manifest = release_dir / "manifest.json" + if not data.exists() or not manifest.exists(): + raise SystemExit(f"release is incomplete under {release_dir}") + if args.dry_run: + print(f"dry-run upload {release_dir} -> {args.repo}") + return + api = import_module("huggingface_hub").HfApi() + api.create_repo(repo_id=args.repo, repo_type="dataset", exist_ok=True) + api.upload_folder(repo_id=args.repo, repo_type="dataset", folder_path=str(release_dir)) + if not args.skip_load_back: + result = verify_load_back(args.repo) + print(f"load-back verified: {result}") + print(f"uploaded {release_dir} -> {args.repo}") + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/validate_release.py b/v5/scripts/validate_release.py new file mode 100644 index 0000000..ac7d251 --- /dev/null +++ b/v5/scripts/validate_release.py @@ -0,0 +1,50 @@ +"""Streaming release validators.""" +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any + +import numpy as np + +from mmu.records import read_jsonl + + +def check_uniqueness(ids) -> dict[str, Any]: + arr = np.asarray(ids) + return {"ok": len(np.unique(arr)) == len(arr), "n": int(len(arr))} + + +def check_min_modalities(counts, k: int) -> dict[str, Any]: + arr = np.asarray(counts) + return {"ok": bool((arr >= k).all()), "min": int(arr.min()) if len(arr) else 0} + + +def validate_rows(rows: list[dict[str, Any]], *, min_instruments: int = 2) -> dict[str, Any]: + ids = [row["object_uid"] for row in rows] + counts = [int(row.get("n_instruments_present", 0)) for row in rows] + uniqueness = check_uniqueness(np.asarray(ids, dtype=object)) + modalities = check_min_modalities(np.asarray(counts, dtype=int), min_instruments) + required = {"object_uid", "global_object_id", "population", "n_instruments_present", "split"} + missing_rows = [i for i, row in enumerate(rows) if not required <= set(row)] + ok = uniqueness["ok"] and modalities["ok"] and not missing_rows + return {"ok": ok, "n_rows": len(rows), "uniqueness": uniqueness, + "modalities": modalities, "rows_missing_required": missing_rows[:20]} + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--release-root", required=True) + ap.add_argument("--min-instruments", type=int, default=2) + args = ap.parse_args() + data_path = Path(args.release_root) / "release" / "data.jsonl" + if not data_path.exists(): + raise SystemExit(f"missing release data: {data_path}") + result = validate_rows(read_jsonl(data_path), min_instruments=args.min_instruments) + print(json.dumps(result, indent=2, sort_keys=True)) + raise SystemExit(0 if result["ok"] else 1) + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/verify_gaia_xmatch.py b/v5/scripts/verify_gaia_xmatch.py new file mode 100644 index 0000000..302d9f7 --- /dev/null +++ b/v5/scripts/verify_gaia_xmatch.py @@ -0,0 +1,22 @@ +"""Single-job CDS XMatch helpers for Gaia verification.""" +from __future__ import annotations + +import argparse + + +def chunk_for_xmatch(n_rows: int, *, max_rows: int = 2_000_000) -> list[tuple[int, int]]: + if n_rows < 0 or max_rows <= 0: + raise ValueError("invalid row counts") + return [(start, min(start + max_rows, n_rows)) for start in range(0, n_rows, max_rows)] + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--input", required=True) + ap.add_argument("--out", required=True) + args = ap.parse_args() + raise SystemExit(f"CDS XMatch is intentionally single-job/cluster-only; input={args.input}, out={args.out}") + + +if __name__ == "__main__": + main() diff --git a/v5/scripts/verify_markers.py b/v5/scripts/verify_markers.py new file mode 100644 index 0000000..6cc4197 --- /dev/null +++ b/v5/scripts/verify_markers.py @@ -0,0 +1,53 @@ +"""Release marker audit.""" +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any + +from mmu.io_atomic import unit_state +from mmu.records import final_shard_path, source_shard_path + + +def audit_states(states: list[str]) -> dict[str, Any]: + bad = sorted({state for state in states if state != "complete"}) + return {"passed": not bad, "bad_states": bad} + + +def collect_states(*, release_root: str | Path, manifest_path: str | Path, + schema_version: str, code_sha: str, include_final: bool = True) -> list[str]: + with Path(manifest_path).open() as f: + manifest = json.load(f) + states: list[str] = [] + shards_by_pop: set[tuple[str, int]] = set() + for unit in manifest["units"]: + states.append(unit_state(source_shard_path(release_root, population=unit["population"], + source=unit["source"], shard=int(unit["shard"])), + manifest_hash=manifest["manifest_hash"], + schema_version=schema_version, code_sha=code_sha)) + shards_by_pop.add((unit["population"], int(unit["shard"]))) + if include_final: + for population, shard in sorted(shards_by_pop): + states.append(unit_state(final_shard_path(release_root, population=population, shard=shard), + manifest_hash=manifest["manifest_hash"], + schema_version=schema_version, code_sha=code_sha)) + return states + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--release-root", required=True) + ap.add_argument("--manifest", required=True) + ap.add_argument("--schema-version", default="v5.1") + ap.add_argument("--code-sha", default="local") + args = ap.parse_args() + states = collect_states(release_root=args.release_root, manifest_path=args.manifest, + schema_version=args.schema_version, code_sha=args.code_sha) + result = audit_states(states) + print(json.dumps(result, indent=2, sort_keys=True)) + raise SystemExit(0 if result["passed"] else 1) + + +if __name__ == "__main__": + main() diff --git a/v5/slurm/build_seeds.sbatch b/v5/slurm/build_seeds.sbatch new file mode 100644 index 0000000..745d84c --- /dev/null +++ b/v5/slurm/build_seeds.sbatch @@ -0,0 +1,17 @@ +#!/bin/bash +#SBATCH --job-name=omnisky_build_seeds +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=32G +#SBATCH --time=02:00:00 +#SBATCH --output=build_seeds_%j.log +set -euo pipefail +source "$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/common.sh" +omnisky_cd_repo_root +omnisky_activate_env +args=(--population "${POPULATION:?galaxy|star|agn}" --release-root "${RELEASE_ROOT:?}" --out "${SEED_OUT:-unused}") +if [[ -n "${TEST_MODE:-}" ]]; then args+=(--test-mode); fi +if [[ -n "${N_OBJECTS:-}" ]]; then args+=(--n "${N_OBJECTS}"); fi +if [[ -n "${INPUT_CSV:-}" ]]; then args+=(--input-csv "${INPUT_CSV}"); fi +"${OMNISKY_PYTHON_BIN}" -m scripts.build_seed_catalogs "${args[@]}" diff --git a/v5/slurm/common.sh b/v5/slurm/common.sh new file mode 100644 index 0000000..44030ad --- /dev/null +++ b/v5/slurm/common.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Shared cluster-agnostic setup for OmniSky v5 SLURM wrappers. +# +# Scheduler options are intentionally not hardcoded here. Pass site-specific +# account/partition/QOS values with SBATCH_* variables or sbatch flags, e.g.: +# SBATCH_ACCOUNT=my_alloc SBATCH_PARTITION=cpu sbatch slurm/probe.sbatch +# sbatch --account=my_alloc --partition=cpu slurm/probe.sbatch + +omnisky_activate_env() { + if [[ -n "${OMNISKY_ENV_ACTIVATE:-}" ]]; then + eval "${OMNISKY_ENV_ACTIVATE}" + elif command -v conda >/dev/null 2>&1; then + source "$(conda info --base)/etc/profile.d/conda.sh" + conda activate "${OMNISKY_CONDA_ENV:-omnisky}" + fi + + export OMNISKY_PYTHON_BIN="${OMNISKY_PYTHON_BIN:-python}" + "${OMNISKY_PYTHON_BIN}" -c 'import sys; print(f"OmniSky Python: {sys.executable}")' +} + +omnisky_cd_repo_root() { + local script_dir + script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + cd "${script_dir}/.." +} diff --git a/v5/slurm/finalize_array.sbatch b/v5/slurm/finalize_array.sbatch new file mode 100644 index 0000000..4d6e2f7 --- /dev/null +++ b/v5/slurm/finalize_array.sbatch @@ -0,0 +1,13 @@ +#!/bin/bash +#SBATCH --job-name=omnisky_finalize +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64G +#SBATCH --time=04:00:00 +#SBATCH --output=finalize_%A_%a.log +set -euo pipefail +source "$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/common.sh" +omnisky_cd_repo_root +omnisky_activate_env +"${OMNISKY_PYTHON_BIN}" -m scripts.finalize_shard --release-root "${RELEASE_ROOT:?}" --manifest "${MANIFEST:?}" --population "${POPULATION:?}" --sources "${SOURCES:?comma-separated}" --shard "${SLURM_ARRAY_TASK_ID:?}" --code-sha "${CODE_SHA:-local}" --min-instruments "${MIN_INSTRUMENTS:-2}" diff --git a/v5/slurm/probe.sbatch b/v5/slurm/probe.sbatch new file mode 100644 index 0000000..c4f2a17 --- /dev/null +++ b/v5/slurm/probe.sbatch @@ -0,0 +1,15 @@ +#!/bin/bash +#SBATCH --job-name=omnisky_probe +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=32G +#SBATCH --time=01:00:00 +#SBATCH --output=probe_%j.log +set -euo pipefail +source "$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/common.sh" +omnisky_cd_repo_root +omnisky_activate_env +"${OMNISKY_PYTHON_BIN}" -m scripts.probe_sources --order "${PROBE_ORDER:-4}" --pixel "${PROBE_PIXEL:-257}" --out "${PROBE_OUT:-probe_report.json}" +"${OMNISKY_PYTHON_BIN}" -m scripts.probe_crossmatch --order "${PROBE_ORDER:-4}" --pixel "${PROBE_PIXEL:-257}" --smith42-revision "${SMITH42_REVISION:?set immutable HF revision}" --out "${CROSSMATCH_OUT:-crossmatch_probe.json}" +"${OMNISKY_PYTHON_BIN}" -m scripts.check_phase0_gate --probe "${PROBE_OUT:-probe_report.json}" --crossmatch "${CROSSMATCH_OUT:-crossmatch_probe.json}" diff --git a/v5/slurm/release.sbatch b/v5/slurm/release.sbatch new file mode 100644 index 0000000..4d4c2c5 --- /dev/null +++ b/v5/slurm/release.sbatch @@ -0,0 +1,17 @@ +#!/bin/bash +#SBATCH --job-name=omnisky_release +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=32G +#SBATCH --time=02:00:00 +#SBATCH --output=release_%j.log +set -euo pipefail +source "$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/common.sh" +omnisky_cd_repo_root +omnisky_activate_env +"${OMNISKY_PYTHON_BIN}" -m scripts.finalize_release --release-root "${RELEASE_ROOT:?}" +"${OMNISKY_PYTHON_BIN}" -m scripts.validate_release --release-root "${RELEASE_ROOT:?}" --min-instruments "${MIN_INSTRUMENTS:-2}" +args=(--release-root "${RELEASE_ROOT:?}" --repo "${HF_REPO:?}") +if [[ -n "${DRY_RUN:-}" ]]; then args+=(--dry-run); fi +"${OMNISKY_PYTHON_BIN}" -m scripts.upload_hf "${args[@]}" diff --git a/v5/slurm/run_source_array.sbatch b/v5/slurm/run_source_array.sbatch new file mode 100644 index 0000000..d8af95b --- /dev/null +++ b/v5/slurm/run_source_array.sbatch @@ -0,0 +1,15 @@ +#!/bin/bash +#SBATCH --job-name=omnisky_run_source +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64G +#SBATCH --time=08:00:00 +#SBATCH --output=run_source_%A_%a.log +set -euo pipefail +source "$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/common.sh" +omnisky_cd_repo_root +omnisky_activate_env +args=(--manifest "${MANIFEST:?}" --task-id "${SLURM_ARRAY_TASK_ID:?}" --partition-id "${PARTITION_ID:-0}" --num-partitions "${NUM_PARTITIONS:-1}" --release-root "${RELEASE_ROOT:?}" --code-sha "${CODE_SHA:-local}") +if [[ -n "${TEST_MODE:-}" ]]; then args+=(--test-mode); fi +"${OMNISKY_PYTHON_BIN}" -m scripts.run_source_shard "${args[@]}" diff --git a/v5/slurm/upload_hf.sbatch b/v5/slurm/upload_hf.sbatch new file mode 100644 index 0000000..78d5ad2 --- /dev/null +++ b/v5/slurm/upload_hf.sbatch @@ -0,0 +1,15 @@ +#!/bin/bash +#SBATCH --job-name=omnisky_upload +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=4 +#SBATCH --mem=16G +#SBATCH --time=02:00:00 +#SBATCH --output=upload_%j.log +set -euo pipefail +source "$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/common.sh" +omnisky_cd_repo_root +omnisky_activate_env +args=(--release-root "${RELEASE_ROOT:?}" --repo "${HF_REPO:?}") +if [[ -n "${DRY_RUN:-}" ]]; then args+=(--dry-run); fi +"${OMNISKY_PYTHON_BIN}" -m scripts.upload_hf "${args[@]}" diff --git a/v5/slurm/validate_release.sbatch b/v5/slurm/validate_release.sbatch new file mode 100644 index 0000000..53824dd --- /dev/null +++ b/v5/slurm/validate_release.sbatch @@ -0,0 +1,16 @@ +#!/bin/bash +#SBATCH --job-name=omnisky_validate +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=8 +#SBATCH --mem=32G +#SBATCH --time=02:00:00 +#SBATCH --output=validate_%j.log +set -euo pipefail +source "$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/common.sh" +omnisky_cd_repo_root +omnisky_activate_env +"${OMNISKY_PYTHON_BIN}" -m scripts.verify_markers --release-root "${RELEASE_ROOT:?}" --manifest "${MANIFEST:?}" --code-sha "${CODE_SHA:-local}" +"${OMNISKY_PYTHON_BIN}" -m scripts.finalize_release --release-root "${RELEASE_ROOT:?}" +"${OMNISKY_PYTHON_BIN}" -m scripts.validate_release --release-root "${RELEASE_ROOT:?}" --min-instruments "${MIN_INSTRUMENTS:-2}" +"${OMNISKY_PYTHON_BIN}" -m scripts.false_match_report --out "${FALSE_MATCH_OUT:-false_match_report.json}" --threshold "${FALSE_MATCH_THRESHOLD:-0.001}" diff --git a/v5/tests/__init__.py b/v5/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/v5/tests/test_concordance.py b/v5/tests/test_concordance.py new file mode 100644 index 0000000..5b8b695 --- /dev/null +++ b/v5/tests/test_concordance.py @@ -0,0 +1,74 @@ +import numpy as np + +import pytest + +from astropy_healpix import HEALPix +import astropy.units as u + +from mmu.concordance import ( + filter_reference_to_healpix_pixel, + filter_reference_to_our_footprint, + match_concordance, +) + + +def test_identical_sets_full_recall(): + ra = np.array([10.0, 11.0, 12.0]) + dec = np.array([0.0, 1.0, 2.0]) + out = match_concordance(ra, dec, ra, dec, tol_arcsec=1.0) + assert out["recall"] == 1.0 + assert out["recovered"] == 3 + assert out["median_sep_arcsec"] < 1e-6 + + +def test_disjoint_sets_zero_recall(): + out = match_concordance( + np.array([10.0]), + np.array([0.0]), + np.array([200.0]), + np.array([-50.0]), + tol_arcsec=1.0, + ) + assert out["recall"] == 0.0 + assert out["recovered"] == 0 + + +def test_partial_recall_within_tolerance(): + our_ra = np.array([10.0, 11.0]) + our_dec = np.array([0.0, 0.0]) + ref_ra = np.array([10.0 + 0.5 / 3600.0, 50.0]) + ref_dec = np.array([0.0, 0.0]) + out = match_concordance(our_ra, our_dec, ref_ra, ref_dec, tol_arcsec=1.0) + assert out["recovered"] == 1 + assert out["recall"] == 0.5 + + +def test_filter_reference_to_probed_footprint(): + our_ra = np.array([10.0]) + our_dec = np.array([0.0]) + ref_ra = np.array([10.0 + 0.2 / 3600.0, 50.0]) + ref_dec = np.array([0.0, 0.0]) + kept_ra, kept_dec = filter_reference_to_our_footprint( + our_ra, our_dec, ref_ra, ref_dec, footprint_arcsec=1.0 + ) + assert kept_ra.tolist() == [ref_ra[0]] + assert kept_dec.tolist() == [ref_dec[0]] + + +def test_filter_reference_to_healpix_pixel_does_not_depend_on_our_matches(): + hp = HEALPix(nside=2**4, order="nested") + ra = np.array([10.0, 10.1, 200.0]) + dec = np.array([0.0, 0.0, 0.0]) + pixels = np.asarray(hp.lonlat_to_healpix(ra * u.deg, dec * u.deg)) + pixel = int(pixels[0]) + kept_ra, kept_dec = filter_reference_to_healpix_pixel(ra, dec, order=4, pixel=pixel) + expected = pixels == pixel + np.testing.assert_array_equal(kept_ra, ra[expected]) + np.testing.assert_array_equal(kept_dec, dec[expected]) + + +def test_rejects_non_positive_tolerances(): + with pytest.raises(ValueError, match="tol_arcsec"): + match_concordance(np.array([1.0]), np.array([0.0]), np.array([1.0]), np.array([0.0]), tol_arcsec=0.0) + with pytest.raises(ValueError, match="footprint_arcsec"): + filter_reference_to_our_footprint(np.array([1.0]), np.array([0.0]), np.array([1.0]), np.array([0.0]), 0.0) diff --git a/v5/tests/test_end_to_end.py b/v5/tests/test_end_to_end.py new file mode 100644 index 0000000..bdcb778 --- /dev/null +++ b/v5/tests/test_end_to_end.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + +from mmu.records import read_jsonl +from scripts.validate_release import validate_rows +from scripts.verify_markers import audit_states, collect_states + + +def run_module(args: list[str], *, cwd: Path) -> None: + cmd = [sys.executable, "-m", *args] + result = subprocess.run(cmd, cwd=cwd, text=True, capture_output=True) + assert result.returncode == 0, result.stdout + result.stderr + + +def test_test_mode_end_to_end_galaxy_star_agn(tmp_path): + root = tmp_path / "release_root" + manifest_dir = root / "manifest" + cwd = Path(__file__).resolve().parents[1] + code_sha = "test" + for population in ["galaxy", "star", "agn"]: + run_module(["scripts.build_seed_catalogs", "--population", population, + "--release-root", str(root), "--out", "unused", "--test-mode", "--n", "5"], cwd=cwd) + run_module(["scripts.plan_work_units", "--sources", "desi,hsc", "--n-objects", "0", + "--population", population, "--release-root", str(root), + "--shard-size", "3", "--out", str(manifest_dir / population)], cwd=cwd) + manifest = manifest_dir / population / "work_units.json" + with manifest.open() as f: + units = json.load(f)["units"] + for task_id in range(len(units)): + run_module(["scripts.run_source_shard", "--manifest", str(manifest), + "--task-id", str(task_id), "--release-root", str(root), + "--code-sha", code_sha, "--test-mode"], cwd=cwd) + for shard in sorted({int(unit["shard"]) for unit in units}): + run_module(["scripts.finalize_shard", "--release-root", str(root), + "--manifest", str(manifest), "--population", population, + "--sources", "desi,hsc", "--shard", str(shard), + "--code-sha", code_sha], cwd=cwd) + states = collect_states(release_root=root, manifest_path=manifest, + schema_version="v5.1", code_sha=code_sha) + assert audit_states(states)["passed"] is True + + run_module(["scripts.finalize_release", "--release-root", str(root)], cwd=cwd) + result = validate_rows(read_jsonl(root / "release" / "data.jsonl"), min_instruments=2) + assert result["ok"] is True + rows = read_jsonl(root / "release" / "data.jsonl") + assert len(rows) == 15 + assert {row["population"] for row in rows} == {"galaxy", "star", "agn"} + assert all(row["n_instruments_present"] >= 2 for row in rows) + assert len({row["object_uid"] for row in rows}) == len(rows) + run_module(["scripts.validate_release", "--release-root", str(root)], cwd=cwd) + run_module(["scripts.upload_hf", "--release-root", str(root), "--repo", "UniverseTBD/test", "--dry-run"], cwd=cwd) diff --git a/v5/tests/test_ids.py b/v5/tests/test_ids.py new file mode 100644 index 0000000..a47c66c --- /dev/null +++ b/v5/tests/test_ids.py @@ -0,0 +1,42 @@ +import numpy as np +import pytest + +from mmu.ids import NSIDE, ORDER, assign_global_id + + +def test_order_and_nside(): + assert ORDER == 29 + assert NSIDE == 2 ** 29 + + +def test_deterministic_int64_and_shape(): + ra = np.array([10.684, 83.822, 201.365]) + dec = np.array([41.269, -5.391, -47.488]) + a = assign_global_id(ra, dec) + b = assign_global_id(ra, dec) + assert a.dtype == np.int64 + assert a.shape == ra.shape + np.testing.assert_array_equal(a, b) + + +def test_distinct_points_distinct_ids(): + a = assign_global_id(np.array([10.0]), np.array([20.0])) + b = assign_global_id(np.array([200.0]), np.array([-30.0])) + assert a[0] != b[0] + + +def test_scalar_inputs_supported(): + out = assign_global_id(10.684, 41.269) + assert out.shape == (1,) + assert out.dtype == np.int64 + + +def test_rejects_invalid_coordinates(): + with pytest.raises(ValueError, match="finite"): + assign_global_id(np.array([10.0, np.nan]), np.array([0.0, 1.0])) + with pytest.raises(ValueError, match="ra"): + assign_global_id(np.array([360.0]), np.array([0.0])) + with pytest.raises(ValueError, match="dec"): + assign_global_id(np.array([10.0]), np.array([91.0])) + with pytest.raises(ValueError, match="shape"): + assign_global_id(np.array([10.0, 11.0]), np.array([0.0])) diff --git a/v5/tests/test_local_lsdb_dry_run.py b/v5/tests/test_local_lsdb_dry_run.py new file mode 100644 index 0000000..73a6df6 --- /dev/null +++ b/v5/tests/test_local_lsdb_dry_run.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import json + +import pytest + +from scripts.local_lsdb_dry_run import ( + DEFAULT_SMITH42_REVISION, + PixelSpec, + SearchSpec, + build_report, + import_lsdb_or_exit, + is_no_coverage_error, + output_path, + output_search_path, + parse_byte_budget, + parse_catalog_overrides, + parse_columns, + parse_cones, + parse_pixels, + pixel_searches, + rows_that_fit, + select_sources, +) + + +def test_parse_byte_budget_decimal_and_binary_units(): + assert parse_byte_budget("1GB") == 1_000_000_000 + assert parse_byte_budget("512MiB") == 512 * 1024 * 1024 + assert parse_byte_budget("10_000") == 10_000 + with pytest.raises(ValueError): + parse_byte_budget("0") + with pytest.raises(ValueError): + parse_byte_budget("nope") + + +def test_parse_pixels_accepts_explicit_or_default_order(): + assert parse_pixels(["4:257"]) == [PixelSpec(order=4, pixel=257)] + assert parse_pixels(["257", "258"], default_order=4) == [ + PixelSpec(order=4, pixel=257), + PixelSpec(order=4, pixel=258), + ] + with pytest.raises(ValueError, match="bare pixels require"): + parse_pixels(["257"]) + with pytest.raises(ValueError, match="non-negative"): + parse_pixels(["4:-1"]) + + +def test_parse_cones_and_pixel_searches(): + assert parse_cones(["150.1,2.2,3600"]) == [ + SearchSpec(kind="cone", label="cone_ra=150.100000_dec=2.200000_radius_arcsec=3600", + ra=150.1, dec=2.2, radius_arcsec=3600.0) + ] + assert pixel_searches([PixelSpec(order=4, pixel=257)]) == [ + SearchSpec(kind="pixel", label="order=4/pixel=000257", order=4, pixel=257) + ] + with pytest.raises(ValueError, match="RA"): + parse_cones(["360,0,10"]) + with pytest.raises(ValueError, match="positive"): + parse_cones(["150,2,0"]) + + +def test_columns_sources_and_catalog_overrides(): + assert parse_columns(None) == ["ra", "dec"] + assert parse_columns("flux") == ["ra", "dec", "flux"] + assert parse_columns("dec,ra,flux") == ["dec", "ra", "flux"] + assert [source.name for source in select_sources("desi,hsc")] == ["desi", "hsc"] + with pytest.raises(ValueError, match="not an LSDB"): + select_sources("legacy") + assert parse_catalog_overrides(["desi=/tmp/desi", "hsc=hf://datasets/example/hsc"]) == { + "desi": "/tmp/desi", + "hsc": "hf://datasets/example/hsc", + } + + +def test_rows_that_fit_and_output_path(tmp_path): + assert rows_that_fit(n_rows=10, frame_bytes=100, remaining_bytes=200) == 10 + assert rows_that_fit(n_rows=10, frame_bytes=100, remaining_bytes=55) == 5 + assert rows_that_fit(n_rows=0, frame_bytes=100, remaining_bytes=55) == 0 + path = output_path(tmp_path, source="desi", pixel=PixelSpec(order=4, pixel=257), suffix="parquet") + assert path.as_posix().endswith("source=desi/order=4/pixel=000257.parquet") + cone_path = output_search_path( + tmp_path, + source="hsc", + search=SearchSpec(kind="cone", label="cone_ra=150.100000_dec=2.200000_radius_arcsec=3600"), + suffix="jsonl", + ) + assert cone_path.as_posix().endswith("source=hsc/cone_ra-150.100000_dec-2.200000_radius_arcsec-3600.jsonl") + assert is_no_coverage_error(ValueError("The selected sky region has no coverage")) + + +def test_report_shape_and_default_revision(): + report = build_report( + max_bytes=100, + stored_bytes=40, + fetches=[{"source": "desi", "stored_rows": 3}], + crossmatch={"attempted": True, "matched_rows": 1}, + ) + assert report["remaining_bytes"] == 60 + assert report["cap_reached"] is False + assert report["crossmatch"]["matched_rows"] == 1 + assert DEFAULT_SMITH42_REVISION == "93d0fddf8c5b61028ee0b6d72fd0dbfa87b38624" + json.dumps(report) + + +def test_missing_lsdb_error_is_actionable(monkeypatch: pytest.MonkeyPatch): + def missing_lsdb(name: str): + raise ModuleNotFoundError("No module named 'lsdb'", name=name) + + monkeypatch.setattr("scripts.local_lsdb_dry_run.import_module", missing_lsdb) + with pytest.raises(SystemExit) as excinfo: + import_lsdb_or_exit() + message = str(excinfo.value) + assert "LSDB is not installed" in message + assert "conda env create -f environment.yml" in message diff --git a/v5/tests/test_phase0_gate.py b/v5/tests/test_phase0_gate.py new file mode 100644 index 0000000..740c408 --- /dev/null +++ b/v5/tests/test_phase0_gate.py @@ -0,0 +1,48 @@ +import json +from scripts.check_phase0_gate import evaluate_gate + +def test_gate_passes(tmp_path): + probe = {"internet_ok": True, "notes": "", "sources": [ + {"name": "mmu_desi_edr_sv3", "reachable": True, "cold_latency_s": 18.0, + "warm_latency_s": 0.6, "throughput_mb_s": 40.0, "rate_limited": False, + "n_rows_sampled": 2000}]} + xm = {"recall": 0.95, "n_matched_pixel": 120, "n_ref_footprint": 100, + "median_sep_arcsec": 0.2} + assert evaluate_gate(probe, xm, min_recall=0.8)["passed"] is True + +def test_gate_fails_on_no_internet(): + probe = {"internet_ok": False, "notes": "", "sources": []} + xm = {"recall": 0.99, "n_matched_pixel": 100, "n_ref_footprint": 100, + "median_sep_arcsec": 0.2} + res = evaluate_gate(probe, xm, min_recall=0.8) + assert res["passed"] is False + assert "internet" in " ".join(res["reasons"]).lower() + +def test_gate_fails_on_low_recall(): + probe = {"internet_ok": True, "notes": "", "sources": [ + {"name": "x", "reachable": True, "cold_latency_s": 1.0, "warm_latency_s": 0.5, + "throughput_mb_s": 10.0, "rate_limited": False, "n_rows_sampled": 10}]} + xm = {"recall": 0.3, "n_matched_pixel": 100, "n_ref_footprint": 100, + "median_sep_arcsec": 0.2} + res = evaluate_gate(probe, xm, min_recall=0.8) + assert res["passed"] is False + + +def test_gate_fails_on_empty_probe_sample(): + probe = {"internet_ok": True, "notes": "", "sources": [ + {"name": "x", "reachable": True, "cold_latency_s": 1.0, "warm_latency_s": 0.5, + "throughput_mb_s": 0.0, "rate_limited": False, "n_rows_sampled": 0}]} + xm = {"recall": 0.99, "n_matched_pixel": 10, "n_ref_footprint": 10, + "median_sep_arcsec": 0.2} + res = evaluate_gate(probe, xm, min_recall=0.8) + assert res["passed"] is False + assert "zero rows" in " ".join(res["reasons"]) + + +def test_gate_fails_without_reference_footprint(): + probe = {"internet_ok": True, "notes": "", "sources": [ + {"name": "x", "reachable": True, "cold_latency_s": 1.0, "warm_latency_s": 0.5, + "throughput_mb_s": 1.0, "rate_limited": False, "n_rows_sampled": 10}]} + xm = {"recall": 1.0, "n_matched_pixel": 10, "n_ref_footprint": 0, + "median_sep_arcsec": 0.2} + assert evaluate_gate(probe, xm, min_recall=0.8)["passed"] is False diff --git a/v5/tests/test_phase1_phase2_core.py b/v5/tests/test_phase1_phase2_core.py new file mode 100644 index 0000000..7916aed --- /dev/null +++ b/v5/tests/test_phase1_phase2_core.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import dataclasses +import os +import importlib.util +from importlib import import_module + +import numpy as np +import pytest + +from mmu.config import DEFAULT, SOURCES +from mmu.coordination import build_or_attach_manifest, finalize_ready, partition_units +from mmu.healpix import neighbor_pixels, seed_pixel_set +from mmu.io_atomic import atomic_write_bytes, unit_state, write_done_marker +from mmu.matching import adjudicate, match_unit +from mmu.motion import MissingParallaxPolicy, propagate_to_epoch +from mmu.rate_limit import TokenBroker +from mmu.records import read_jsonl, write_jsonl +from mmu.sources.local_fits import crop_or_pad_flux, dedup_highest_snr +from mmu.sources.lsdb_mmu import candidate_dict_from_crossmatch, collection_uri +from mmu.sources.ztf_s3 import lightcurve_to_fixed, s3_hats_uri +from scripts.build_seed_catalogs import assign_ids_and_assert_unique, namespace_ids +from scripts.convert_to_hats import conversion_report, needs_conversion +from scripts.false_match_report import build_report, confusion_radius, gate_false_match, parse_bin_values, pm_scramble, random_offsets +from scripts.finalize_shard import assign_split, enforce_min_instruments +from scripts.finalize_release import summarize_final_shards +from scripts.plan_work_units import enumerate_units +from scripts.run_source_shard import select_unit +from scripts.upload_hf import build_repo_id +from scripts.validate_release import check_min_modalities, check_uniqueness +from scripts.verify_gaia_xmatch import chunk_for_xmatch +from scripts.verify_markers import audit_states + + +def test_config_sources_and_immutability(): + assert DEFAULT.schema_version == "v5.1" + assert str(DEFAULT.release_root) == os.environ.get("OMNISKY_RELEASE_ROOT", "release/v5") + assert SOURCES["hsc"].radius_arcsec <= 1.0 + assert SOURCES["legacy"].kind == "legacy_hdf5" + assert SOURCES["legacy"].dataset == os.environ.get("OMNISKY_LEGACY_ROOT", "data/legacy_dr10_south_21") + assert SOURCES["apogee"].org == "hugging-science" + with pytest.raises(dataclasses.FrozenInstanceError): + setattr(DEFAULT, "shard_size", 1) + + +def test_healpix_helpers_nested(): + pixels = seed_pixel_set(np.array([10.0, 10.1]), np.array([2.0, 2.1]), order=4) + assert pixels and all(isinstance(p, int) for p in pixels) + pix = next(iter(pixels)) + assert pix in neighbor_pixels(pix, order=4) + + +def test_adjudicate_and_candidate_grouping(): + d = candidate_dict_from_crossmatch(np.array([0, 0, 1]), np.array([11, 22, 33]), + np.array([0.5, 0.3, 0.2]), np.ones(3)) + res = adjudicate(d, radius_arcsec=1.0) + assert res[0]["src_index"] == 22 + assert res[0]["match_ambiguous"] is True + assert res[1]["n_candidates_within_radius"] == 1 + + +def test_collection_uri(): + assert collection_uri("UniverseTBD", "mmu_desi_edr_sv3") == "hf://datasets/UniverseTBD/mmu_desi_edr_sv3" + + +def test_io_atomic_state_machine(tmp_path): + p = tmp_path / "unit.bin" + assert unit_state(p, manifest_hash="m", schema_version="v5.1", code_sha="c") == "pending" + atomic_write_bytes(b"ok", p) + assert unit_state(p, manifest_hash="m", schema_version="v5.1", code_sha="c") == "suspicious" + write_done_marker(p, manifest_hash="m", schema_version="v5.1", code_sha="c") + assert unit_state(p, manifest_hash="m", schema_version="v5.1", code_sha="c") == "complete" + assert unit_state(p, manifest_hash="other", schema_version="v5.1", code_sha="c") == "stale" + p.write_bytes(b"corrupt") + assert unit_state(p, manifest_hash="m", schema_version="v5.1", code_sha="c") == "corrupt" + + +def test_coordination_manifest_and_partitions(tmp_path): + units = [{"source": "desi", "shard": 0}, {"source": "hsc", "shard": 0}] + h1, created = build_or_attach_manifest(tmp_path, units, inputs_hash="i") + h2, created2 = build_or_attach_manifest(tmp_path, units, inputs_hash="i") + assert h1 == h2 and created is True and created2 is False + assert partition_units(list(range(5)), partition_id=1, num_partitions=2) == [1, 3] + assert finalize_ready(0, sources=["desi", "hsc"], completed={("desi", 0), ("hsc", 0)}) + + +def test_manifest_rejects_same_inputs_hash_with_different_units(tmp_path): + build_or_attach_manifest(tmp_path, [{"source": "desi", "shard": 0}], inputs_hash="same") + with pytest.raises(ValueError, match="units conflict"): + build_or_attach_manifest(tmp_path, [{"source": "hsc", "shard": 0}], inputs_hash="same") + + +def test_seed_ids_are_unique_and_namespaced(): + ids = assign_ids_and_assert_unique(np.array([10.0, 200.0]), np.array([2.0, -30.0]), population="star") + assert len(set(ids.tolist())) == 2 + assert namespace_ids(np.array([1], dtype=np.int64), "star")[0] == "star:1" + assert namespace_ids(np.array([1], dtype=np.int64), "star")[0] != namespace_ids(np.array([1], dtype=np.int64), "galaxy")[0] + high_healpix = np.array([2**61], dtype=np.int64) + assert namespace_ids(high_healpix, "galaxy")[0] != namespace_ids(high_healpix, "star")[0] + with pytest.raises(ValueError): + assign_ids_and_assert_unique(np.array([10.0, 10.0]), np.array([2.0, 2.0])) + + +def test_plan_and_select_units(): + units = enumerate_units(sources=["desi", "hsc"], n_objects=120_000, shard_size=50_000) + assert len(units) == 6 + assert select_unit(units, task_id=1, partition_id=0, num_partitions=2)["shard"] == 1 + + +def test_finalize_helpers(): + mask = np.array([[True, True, False], [True, False, False]]) + assert enforce_min_instruments(mask, min_instruments=2).tolist() == [True, False] + split = assign_split(np.array([10.0, 10.0]), np.array([2.0, 2.0]), nside=8, seed=42) + assert split[0] == split[1] + assert set(split.tolist()) <= {"train", "val", "test"} + + +def test_false_match_helpers(): + dra, ddec = random_offsets(n=1000, r_min_arcsec=5.0, r_max_arcsec=30.0, seed=1) + r = np.hypot(dra, ddec) + assert (r >= 5.0 - 1e-6).all() and (r <= 30.0 + 1e-6).all() + gate = gate_false_match({"ok": 0.0005, "bad": 0.02}, threshold=0.001) + assert gate["passed_bins"] == ["ok"] and gate["low_confidence_bins"] == ["bad"] + sra, sdec = pm_scramble(np.array([3.0, 4.0]), np.array([4.0, 3.0]), seed=2) + np.testing.assert_allclose(np.hypot(sra, sdec), np.array([5.0, 5.0])) + assert confusion_radius(1.0, np.array([0.0, 1000.0]), 10.0, 0.0, 0.1)[1] > 1.0 + assert parse_bin_values(["high_lat=0.0005"]) == {"high_lat": 0.0005} + assert build_report({"ok": 0.0}, threshold=0.001)["passed"] is True + + +def test_validation_upload_marker_xmatch_helpers(): + assert audit_states(["complete", "complete"])["passed"] is True + assert audit_states(["complete", "corrupt"])["bad_states"] == ["corrupt"] + assert check_uniqueness(np.array([1, 2, 3]))["ok"] is True + assert check_uniqueness(np.array([1, 1]))["ok"] is False + assert check_min_modalities(np.array([2, 3]), 2)["ok"] is True + assert build_repo_id("omnisky-v5") == "UniverseTBD/omnisky-v5" + assert chunk_for_xmatch(2_000_001, max_rows=2_000_000) == [(0, 2_000_000), (2_000_000, 2_000_001)] + + +def test_token_broker_accounting(): + broker = TokenBroker(2) + assert broker.available == 2 + with broker.token(): + assert broker.in_use == 1 + with broker.token(): + assert broker.in_use == 2 + assert broker.in_use == 1 + assert broker.available == 2 + with pytest.raises(RuntimeError): + broker.release() + + +def test_finalize_release_summary(tmp_path): + p = tmp_path / "final" / "population=galaxy" / "shard=000000.jsonl" + write_jsonl([{"population": "galaxy", "split": "train"}], p) + summary = summarize_final_shards([p]) + assert summary["n_rows"] == 1 + assert summary["populations"] == {"galaxy": 1} + + +def test_conversion_and_ztf_helpers(): + assert needs_conversion(SOURCES["desi"]) is False + assert needs_conversion(SOURCES["sdss_dr16q"]) is True + report = conversion_report([{"name": "a", "converted": True}, {"name": "b", "converted": False, "reason": "license"}]) + assert report["converted"] == ["a"] + assert report["flagged_unconvertible"] == [{"name": "b", "reason": "license"}] + assert s3_hats_uri("dr24") == "s3://ipac-irsa-ztf/ztf/enhanced/dr24/lc/hats" + lc = lightcurve_to_fixed([1, 2], [10, 11], [0.1, 0.2], max_len=3) + assert lc["valid"].tolist() == [True, True, False] + + +def test_local_fits_pure_helpers(): + keep = dedup_highest_snr(np.array(["a", "a", "b"]), np.array([10.0, 20.0, 5.0])) + assert keep.tolist() == [1, 2] + np.testing.assert_allclose(crop_or_pad_flux(np.array([1, 2]), 4)[:2], np.array([1, 2])) + assert np.isnan(crop_or_pad_flux(np.array([1, 2]), 4)[2]) + + +@pytest.mark.skipif(importlib.util.find_spec("pyarrow") is None, reason="pyarrow not installed locally") +def test_schema_when_pyarrow_available(): + pa = import_module("pyarrow") + from mmu.schemas import SCHEMA_VERSION, final_schema + + schema = final_schema(image_px=2, n_bands=2, spec_len=3) + assert schema.field("global_object_id").type == pa.int64() + assert schema.metadata[b"schema_version"] == SCHEMA_VERSION.encode() + + +def test_motion_propagation_with_parallax_and_missing_policy(): + out = propagate_to_epoch(ra=np.array([100.0]), dec=np.array([0.0]), + pmra=np.array([1000.0]), pmdec=np.array([0.0]), + parallax_mas=np.array([50.0]), rv_kms=np.array([0.0]), + from_epoch_jyear=2016.0, to_epoch_jyear=2006.0) + assert abs(out["ra"][0] - 100.0) > 1e-5 + missing = propagate_to_epoch(ra=np.array([100.0]), dec=np.array([0.0]), + pmra=np.array([1000.0]), pmdec=np.array([0.0]), + parallax_mas=np.array([np.nan]), rv_kms=np.array([0.0]), + from_epoch_jyear=2016.0, to_epoch_jyear=2006.0, + policy=MissingParallaxPolicy.FLAG) + assert missing["motion_flag"][0] == "missing_parallax" + dropped = propagate_to_epoch(ra=np.array([100.0]), dec=np.array([0.0]), + pmra=np.array([1000.0]), pmdec=np.array([0.0]), + parallax_mas=np.array([-1.0]), rv_kms=np.array([0.0]), + from_epoch_jyear=2016.0, to_epoch_jyear=2006.0, + policy=MissingParallaxPolicy.DROP) + assert bool(dropped["drop"][0]) is True + + +def test_population_aware_match_unit(): + galaxy = match_unit({"ra": np.array([1.0]), "dec": np.array([2.0])}, {"epoch_jyear": 2010.0}, population="galaxy") + assert galaxy["ra"].tolist() == [1.0] diff --git a/v5/tests/test_probe_report.py b/v5/tests/test_probe_report.py new file mode 100644 index 0000000..32ed48c --- /dev/null +++ b/v5/tests/test_probe_report.py @@ -0,0 +1,30 @@ +from mmu.probe_report import SourceProbe, ProbeReport + + +def test_roundtrip_json(tmp_path): + rep = ProbeReport(internet_ok=True, notes="delta cpu node") + rep.add(SourceProbe(name="mmu_desi_edr_sv3", reachable=True, + cold_latency_s=18.2, warm_latency_s=0.6, + throughput_mb_s=42.5, rate_limited=False, n_rows_sampled=2000)) + p = tmp_path / "probe_report.json" + rep.to_json(p) + back = ProbeReport.from_json(p) + assert back.internet_ok is True + assert back.sources[0].name == "mmu_desi_edr_sv3" + assert back.sources[0].throughput_mb_s == 42.5 + + +def test_concurrency_cap_from_throughput(): + sp = SourceProbe(name="x", reachable=True, cold_latency_s=1.0, warm_latency_s=0.5, + throughput_mb_s=10.0, rate_limited=True, n_rows_sampled=100) + assert sp.suggested_concurrency() == 2 + sp2 = SourceProbe(name="y", reachable=True, cold_latency_s=1.0, warm_latency_s=0.5, + throughput_mb_s=10.0, rate_limited=False, n_rows_sampled=100) + assert sp2.suggested_concurrency() == 16 + slow = SourceProbe(name="slow", reachable=True, cold_latency_s=10.0, warm_latency_s=5.0, + throughput_mb_s=0.5, rate_limited=False, n_rows_sampled=10) + assert slow.suggested_concurrency() == 4 + failed = SourceProbe(name="bad", reachable=False, cold_latency_s=-1.0, warm_latency_s=-1.0, + throughput_mb_s=0.0, rate_limited=False, n_rows_sampled=0, + error="boom") + assert failed.suggested_concurrency() == 0 diff --git a/v5/tests/test_reachability.py b/v5/tests/test_reachability.py new file mode 100644 index 0000000..7a008fb --- /dev/null +++ b/v5/tests/test_reachability.py @@ -0,0 +1,15 @@ +from mmu.reachability import summarize_reachability + + +def test_all_reachable(): + results = {"huggingface": (True, 0.21), "s3": (True, 0.10), "cds": (True, 0.55)} + s = summarize_reachability(results) + assert s["internet_ok"] is True + assert s["unreachable"] == [] + + +def test_one_unreachable_blocks(): + results = {"huggingface": (True, 0.21), "s3": (False, None), "cds": (True, 0.55)} + s = summarize_reachability(results) + assert s["internet_ok"] is False + assert s["unreachable"] == ["s3"]