diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e6c5da..0d068b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.2.0] - 26-03-2026 + +### Added + +- **Enriched CSV output for post hoc analysis** (`designer/multiplexpanel.py`, `designer/primer.py`): `selected_multiplex.csv`, `top_panels.csv`, and `candidate_pairs.csv` now include 12 additional per-primer and per-amplicon columns: `Forward/Reverse_Self_Any_Th`, `Forward/Reverse_Self_End_Th`, `Forward/Reverse_Hairpin_Th`, `Forward/Reverse_End_Stability`, `Forward/Reverse_Penalty`, `Amplicon_GC`, and `Num_Candidate_Pairs`. These expose design-time thermodynamic features needed to correlate primer properties with wet-lab validation coverage. +- **Per-assay cross-dimer contribution** (`designer/multiplexpanel.py`): `selected_multiplex.csv` and `top_panels.csv` now include a `Cross_Dimer_Contribution` column quantifying the sum of cross-dimer interaction scores between each assay's primers and all other primers in the multiplex. Mirrors the scoring logic in `selector/cost.py` with cached pairwise alignments. + +- **Primer-level cross-dimer matrix** (`reporting/qc.py`): `generate_panel_qc()` now computes a full primer-to-primer dimer matrix for the selected multiplex, covering all C(2n, 2) individual primer interactions (including intra-junction pairs). The matrix is returned alongside the existing junction-level summary. Tail sequences (`forward_tail`, `reverse_tail`) are now included in the dimer alignment for realistic scoring. A new `save_primer_dimer_matrix_csv()` function exports the symmetric matrix as a CSV file (`primer_dimer_matrix.csv`), written automatically by the pipeline. 
+- **Primer-level dimer heatmap in HTML report** (`reporting/templates/panel_report.html.j2`): New interactive Plotly heatmap section showing individual primer-to-primer dimer scores, displayed below the existing junction-level heatmap. + +### Fixed + +- **Cross-reactivity heatmap not rendering** (`reporting/templates/plotly.min.js.gz`): The bundled Plotly.js was the "basic" partial build which only includes scatter, bar, and pie trace types. The heatmap trace type was missing, causing the cross-reactivity heatmap to silently fail. Replaced with the full Plotly.js v2.35.2 bundle. + +- **Target dropout for DFS selector** (`selector/selectors.py`): The DFS selector can now optionally drop targets whose primers cause extreme cross-dimer interactions that "poison" the panel. Enabled via `allow_target_dropping: true` in the multiplex picker config. An adaptive dropout penalty is computed from the greedy seed's marginal cross-dimer costs at a configurable percentile (`dropout_stringency`, default 0.8), so only true outliers are removed. A hard floor prevents excessive dropping: `max(minimum_plexity, ceil(min_target_fraction * n_input))`, clamped to the actual input count. +- **New config fields** (`config.py`, preset JSON files): `allow_target_dropping` (bool, default false), `dropout_stringency` (float 0-1, default 0.8), `min_target_fraction` (float 0-1, default 0.8). +- **Dropped targets in output** (`designer/multiplexpanel.py`): `panel_summary.json` includes a `dropped_targets` list. `top_panels.csv` includes `Num_Dropped` and `Dropped_Targets` columns. Dropped targets are also logged as warnings and surfaced in CLI output. +- **Plexity clamping** (`pipeline.py`): When the number of input targets is less than the configured `maximum_plexity` or `minimum_plexity`, the effective plexity values are clamped to the actual input count, preventing nonsensical constraints. 
+ ## [1.1.0] - 13-03-2026 ### Changed diff --git a/docker/DOCKERFILE b/docker/DOCKERFILE index 9511e99..8f2c684 100644 --- a/docker/DOCKERFILE +++ b/docker/DOCKERFILE @@ -1,6 +1,6 @@ ARG BLAST_VERSION=2.17.0 ARG BCFTOOLS_VERSION=1.23 -ARG PLEXUS_VERSION=1.0.0 +ARG PLEXUS_VERSION=1.2.0 # ── Stage 1: Python venv builder ───────────────────────────────────────────── FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder @@ -43,18 +43,18 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ # Download and extract exact BLAST+ binaries (NCBI prebuilt) RUN wget -qO /tmp/blast.tar.gz \ - "https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/${BLAST_VERSION}/ncbi-blast-${BLAST_VERSION}+-x64-linux.tar.gz" \ + "https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/${BLAST_VERSION}/ncbi-blast-${BLAST_VERSION}+-x64-linux.tar.gz" \ && tar -xzf /tmp/blast.tar.gz -C /tmp \ && mkdir -p /tools/bin \ && cp /tmp/ncbi-blast-${BLAST_VERSION}+/bin/blastn \ - /tmp/ncbi-blast-${BLAST_VERSION}+/bin/makeblastdb \ - /tmp/ncbi-blast-${BLAST_VERSION}+/bin/blast_formatter \ - /tools/bin/ \ + /tmp/ncbi-blast-${BLAST_VERSION}+/bin/makeblastdb \ + /tmp/ncbi-blast-${BLAST_VERSION}+/bin/blast_formatter \ + /tools/bin/ \ && rm -rf /tmp/blast.tar.gz /tmp/ncbi-blast-${BLAST_VERSION}+ # Compile exact bcftools from source RUN wget -qO /tmp/bcftools.tar.bz2 \ - "https://github.com/samtools/bcftools/releases/download/${BCFTOOLS_VERSION}/bcftools-${BCFTOOLS_VERSION}.tar.bz2" \ + "https://github.com/samtools/bcftools/releases/download/${BCFTOOLS_VERSION}/bcftools-${BCFTOOLS_VERSION}.tar.bz2" \ && tar -xjf /tmp/bcftools.tar.bz2 -C /tmp \ && cd /tmp/bcftools-${BCFTOOLS_VERSION} \ && make -j$(nproc) \ @@ -70,12 +70,12 @@ ARG PLEXUS_VERSION # Bake version and audit metadata into the image LABEL org.opencontainers.image.title="plexus" \ - org.opencontainers.image.version="${PLEXUS_VERSION}" \ - org.opencontainers.image.description="Multiplex PCR primer panel designer" \ - 
org.opencontainers.image.licenses="GPL-2.0-or-later" \ - dev.plexus.blast_version="${BLAST_VERSION}" \ - dev.plexus.bcftools_version="${BCFTOOLS_VERSION}" \ - dev.plexus.compliance_mode="true" + org.opencontainers.image.version="${PLEXUS_VERSION}" \ + org.opencontainers.image.description="Multiplex PCR primer panel designer" \ + org.opencontainers.image.licenses="GPL-2.0-or-later" \ + dev.plexus.blast_version="${BLAST_VERSION}" \ + dev.plexus.bcftools_version="${BCFTOOLS_VERSION}" \ + dev.plexus.compliance_mode="true" # Non-root runtime user RUN useradd --system --create-home --shell /bin/bash plexus diff --git a/src/plexus/config.py b/src/plexus/config.py index 0c3e735..20f203f 100644 --- a/src/plexus/config.py +++ b/src/plexus/config.py @@ -183,6 +183,31 @@ class MultiplexPickerParameters(BaseModel): plexity_wt_gt: float = Field(default=1.0, ge=0.0) force_plexity: bool = Field(default=False) + + allow_target_dropping: bool = Field( + default=False, + description="Allow the DFS selector to drop targets with extreme cross-dimer toxicity.", + ) + dropout_stringency: float = Field( + default=0.8, + ge=0.0, + le=1.0, + description=( + "Percentile (0-1) of marginal cross-dimer distribution used to set " + "the dropout penalty. Higher values = stricter = fewer drops. " + "0.8 means only targets worse than the 80th percentile are drop candidates." + ), + ) + min_target_fraction: float = Field( + default=0.8, + ge=0.0, + le=1.0, + description=( + "Minimum fraction of input targets to retain. " + "Effective floor = max(minimum_plexity, ceil(min_target_fraction * n_input))." 
+ ), + ) + allow_split_panel: bool = Field(default=False) max_splits: int = Field(default=2, ge=1, le=10) diff --git a/src/plexus/data/designer_default_config.json b/src/plexus/data/designer_default_config.json index f0cd00a..b26dc2f 100644 --- a/src/plexus/data/designer_default_config.json +++ b/src/plexus/data/designer_default_config.json @@ -94,6 +94,9 @@ "plexity_wt_lt": 1.0, "plexity_wt_gt": 1.0, "force_plexity": false, + "allow_target_dropping": false, + "dropout_stringency": 0.8, + "min_target_fraction": 0.8, "allow_split_panel": false, "max_splits": 2, "wt_pair_penalty": 1.0, diff --git a/src/plexus/data/designer_lenient_config.json b/src/plexus/data/designer_lenient_config.json index 037e371..4900e99 100644 --- a/src/plexus/data/designer_lenient_config.json +++ b/src/plexus/data/designer_lenient_config.json @@ -94,6 +94,9 @@ "plexity_wt_lt": 1.0, "plexity_wt_gt": 1.0, "force_plexity": false, + "allow_target_dropping": false, + "dropout_stringency": 0.8, + "min_target_fraction": 0.8, "allow_split_panel": false, "max_splits": 2, "wt_pair_penalty": 1.0, diff --git a/src/plexus/designer/multiplexpanel.py b/src/plexus/designer/multiplexpanel.py index e951b63..7197eaa 100644 --- a/src/plexus/designer/multiplexpanel.py +++ b/src/plexus/designer/multiplexpanel.py @@ -833,6 +833,20 @@ def save_candidate_pairs_to_csv(self, file_path: str): "Off_Target_Count": len(pair.off_target_products), "Specificity_Checked": pair.specificity_checked, "Selected": pair.selected, + "Forward_Self_Any_Th": pair.forward.self_any_th, + "Reverse_Self_Any_Th": pair.reverse.self_any_th, + "Forward_Self_End_Th": pair.forward.self_end_th, + "Reverse_Self_End_Th": pair.reverse.self_end_th, + "Forward_Hairpin_Th": pair.forward.hairpin_th, + "Reverse_Hairpin_Th": pair.reverse.hairpin_th, + "Forward_End_Stability": pair.forward.end_stability, + "Reverse_End_Stability": pair.reverse.end_stability, + "Forward_Penalty": pair.forward.penalty, + "Reverse_Penalty": pair.reverse.penalty, + 
"Amplicon_GC": pair.amplicon_gc, + "Num_Candidate_Pairs": ( + len(junction.primer_pairs) if junction.primer_pairs else 0 + ), } data.append(row) @@ -901,6 +915,20 @@ def _build_enriched_pair_row(self, junction: Junction, pair: PrimerPair) -> dict "SNP_Penalty": pair.snp_penalty, "Forward_SNP_Count": pair.forward.snp_count, "Reverse_SNP_Count": pair.reverse.snp_count, + "Forward_Self_Any_Th": pair.forward.self_any_th, + "Reverse_Self_Any_Th": pair.reverse.self_any_th, + "Forward_Self_End_Th": pair.forward.self_end_th, + "Reverse_Self_End_Th": pair.reverse.self_end_th, + "Forward_Hairpin_Th": pair.forward.hairpin_th, + "Reverse_Hairpin_Th": pair.reverse.hairpin_th, + "Forward_End_Stability": pair.forward.end_stability, + "Reverse_End_Stability": pair.reverse.end_stability, + "Forward_Penalty": pair.forward.penalty, + "Reverse_Penalty": pair.reverse.penalty, + "Amplicon_GC": pair.amplicon_gc, + "Num_Candidate_Pairs": ( + len(junction.primer_pairs) if junction.primer_pairs else 0 + ), } def _junction_for_pair(self, pair: PrimerPair) -> Junction | None: @@ -924,13 +952,41 @@ def save_selected_multiplex_csv(self, file_path: str, selected_pairs: list) -> N logger.warning("No selected pairs to save.") return + # Pre-compute cross-dimer scores for per-pair attribution + dimer_predictor = PrimerDimerPredictor() + dimer_cache: dict[tuple[str, str], float] = {} + all_primers: dict[str, list[tuple[str, str]]] = {} + for pair in selected_pairs: + all_primers[pair.pair_id] = [ + (pair.forward.seq, pair.forward.name), + (pair.reverse.seq, pair.reverse.name), + ] + data = [] for pair in selected_pairs: junction = self._junction_for_pair(pair) if junction is None: logger.warning(f"Could not find junction for pair {pair.pair_id}") continue - data.append(self._build_enriched_pair_row(junction, pair)) + row = self._build_enriched_pair_row(junction, pair) + + contribution = 0.0 + for other in selected_pairs: + if other.pair_id == pair.pair_id: + continue + for seq_a, name_a in 
all_primers[pair.pair_id]: + for seq_b, name_b in all_primers[other.pair_id]: + key = tuple(sorted((seq_a, seq_b))) + if key not in dimer_cache: + dimer_predictor.set_primers(seq_a, seq_b, name_a, name_b) + dimer_predictor.align() + dimer_cache[key] = dimer_predictor.score or 0.0 + score = dimer_cache[key] + if score < 0: + contribution += abs(score) + row["Cross_Dimer_Contribution"] = round(contribution, 4) + + data.append(row) df = pd.DataFrame(data) df.to_csv(file_path, index=False) @@ -948,19 +1004,56 @@ def save_top_panels_csv(self, file_path: str, solutions: list) -> None: return pair_lookup = self.build_pair_lookup() + dimer_predictor = PrimerDimerPredictor() + dimer_cache: dict[tuple[str, str], float] = {} data = [] for rank, solution in enumerate(solutions, start=1): + # Collect primers for this solution's pairs + sol_primers: dict[str, list[tuple[str, str]]] = {} + sol_pairs: list[PrimerPair] = [] for pair_id in solution.primer_pairs: pair = pair_lookup.get(pair_id) if pair is None: continue + sol_pairs.append(pair) + sol_primers[pair_id] = [ + (pair.forward.seq, pair.forward.name), + (pair.reverse.seq, pair.reverse.name), + ] + + for pair in sol_pairs: junction = self._junction_for_pair(pair) if junction is None: continue row = self._build_enriched_pair_row(junction, pair) + + contribution = 0.0 + for other in sol_pairs: + if other.pair_id == pair.pair_id: + continue + for seq_a, name_a in sol_primers[pair.pair_id]: + for seq_b, name_b in sol_primers[other.pair_id]: + key = tuple(sorted((seq_a, seq_b))) + if key not in dimer_cache: + dimer_predictor.set_primers( + seq_a, seq_b, name_a, name_b + ) + dimer_predictor.align() + dimer_cache[key] = dimer_predictor.score or 0.0 + score = dimer_cache[key] + if score < 0: + contribution += abs(score) + row["Cross_Dimer_Contribution"] = round(contribution, 4) + row["Solution_Rank"] = rank row["Solution_Cost"] = round(solution.cost, 4) + row["Num_Dropped"] = len(solution.dropped_targets) + row["Dropped_Targets"] = 
( + "; ".join(solution.dropped_targets) + if solution.dropped_targets + else "" + ) data.append(row) df = pd.DataFrame(data) @@ -1054,6 +1147,9 @@ def save_panel_summary_json( "num_junctions": len(self.junctions), "num_candidate_pairs": total_pairs, "num_selected_pairs": len(pipeline_result.selected_pairs), + "dropped_targets": pipeline_result.multiplex_solutions[0].dropped_targets + if pipeline_result.multiplex_solutions + else [], "best_multiplex_cost": round(pipeline_result.multiplex_solutions[0].cost, 4) if pipeline_result.multiplex_solutions else None, diff --git a/src/plexus/designer/primer.py b/src/plexus/designer/primer.py index b195b5e..9b686ab 100644 --- a/src/plexus/designer/primer.py +++ b/src/plexus/designer/primer.py @@ -71,6 +71,15 @@ class PrimerPair: snp_count: int = 0 snp_penalty: float = 0.0 + @property + def amplicon_gc(self) -> float | None: + """GC content of the amplicon sequence as a percentage.""" + seq = self.amplicon_sequence + if not seq: + return None + gc_count = seq.upper().count("G") + seq.upper().count("C") + return round(100.0 * gc_count / len(seq), 2) + @staticmethod def calculate_primer_pair_penalty_th( primer_left_penalty, diff --git a/src/plexus/pipeline.py b/src/plexus/pipeline.py index 6b3717d..c4ba6c1 100644 --- a/src/plexus/pipeline.py +++ b/src/plexus/pipeline.py @@ -757,7 +757,28 @@ def advance_step(label=None): seed=config.multiplex_picker_parameters.selector_seed, ) logger.info(f"Using '{selector}' selector algorithm.") - if selector in ("Greedy", "Random"): + if selector == "DFS": + import math as _math + + mpick = config.multiplex_picker_parameters + n_input_targets = selector_df["target_id"].nunique() + + # Clamp plexity to actual input count + effective_min = min(mpick.minimum_plexity, n_input_targets) + min_target_floor = min( + max( + effective_min, + _math.ceil(mpick.min_target_fraction * n_input_targets), + ), + n_input_targets, + ) + + solutions = selector_obj.run( + 
allow_target_dropping=mpick.allow_target_dropping, + min_target_floor=min_target_floor, + stringency=mpick.dropout_stringency, + ) + elif selector in ("Greedy", "Random"): solutions = selector_obj.run( N=config.multiplex_picker_parameters.initial_solutions ) @@ -783,6 +804,18 @@ def advance_step(label=None): f"Selected {len(selected)} primer pairs (best cost: {best.cost:.2f})" ) + if best.dropped_targets: + logger.warning( + f"Dropped {len(best.dropped_targets)} target(s) " + f"due to high interactivity: " + f"{', '.join(best.dropped_targets)}" + ) + result.warnings.append( + f"Target dropout: {len(best.dropped_targets)} " + f"target(s) dropped: " + f"{', '.join(best.dropped_targets)}" + ) + result.steps_completed.append("multiplex_optimized") except Exception as e: logger.error(f"Multiplex optimization failed: {e}") @@ -819,13 +852,37 @@ def advance_step(label=None): # Panel QC report (REPT-01) if result.selected_pairs: try: - from plexus.reporting.qc import generate_panel_qc + from plexus.reporting.qc import ( + generate_panel_qc, + save_primer_dimer_matrix_csv, + ) - qc_data = generate_panel_qc(panel.junctions) + fwd_tail = ( + panel.config.singleplex_design_parameters.forward_tail + if panel.config + else "" + ) + rev_tail = ( + panel.config.singleplex_design_parameters.reverse_tail + if panel.config + else "" + ) + qc_data = generate_panel_qc( + panel.junctions, + forward_tail=fwd_tail, + reverse_tail=rev_tail, + ) qc_path = output_dir / "panel_qc.json" with qc_path.open("w") as f: _json.dump(qc_data, f, indent=2) logger.info(f"Wrote panel QC report to {qc_path.name}") + + if qc_data.get("primer_dimer_matrix"): + csv_path = output_dir / "primer_dimer_matrix.csv" + save_primer_dimer_matrix_csv( + qc_data["primer_dimer_matrix"], str(csv_path) + ) + logger.info(f"Wrote primer dimer matrix to {csv_path.name}") except Exception as e: logger.warning(f"Could not write panel QC report: {e}") result.errors.append(f"Panel QC report failed: {e}") diff --git 
a/src/plexus/reporting/qc.py b/src/plexus/reporting/qc.py index c194158..8e73abd 100644 --- a/src/plexus/reporting/qc.py +++ b/src/plexus/reporting/qc.py @@ -2,6 +2,7 @@ from __future__ import annotations +import csv import re import statistics from itertools import combinations @@ -13,6 +14,11 @@ from plexus.designer.multiplexpanel import Junction +def _tailed(tail: str, seq: str) -> str: + """Prepend tail to sequence, replacing N with A for NN-model compatibility.""" + return tail.replace("N", "A") + seq + + def generate_panel_qc( junctions: list[Junction], *, @@ -20,6 +26,8 @@ def generate_panel_qc( gc_low_threshold: float = 30.0, homopolymer_min_run: int = 4, dimer_threshold: float = 0.0, + forward_tail: str = "", + reverse_tail: str = "", ) -> dict: """Generate panel QC metrics for the selected primer pairs.""" # Build working lists @@ -86,55 +94,79 @@ def generate_panel_qc( "flagged_primers": flagged_primers, } - # Cross-reactivity matrix + # --- Primer-level dimer matrix --- + # Build list of (label, tailed_sequence) for every primer in the panel. + primer_entries = [] # [(label, tailed_seq), ...] 
+ primer_labels = [] + for jname, pair in selected: + fwd_label = f"{jname}_forward" + rev_label = f"{jname}_reverse" + primer_labels.append(fwd_label) + primer_labels.append(rev_label) + primer_entries.append((fwd_label, _tailed(forward_tail, pair.forward.seq))) + primer_entries.append((rev_label, _tailed(reverse_tail, pair.reverse.seq))) + predictor = PrimerDimerPredictor() - matrix: dict[str, dict] = {} - for (jname_a, pair_a), (jname_b, pair_b) in combinations(selected, 2): - scores = [] - for seq_a, name_a, seq_b, name_b in [ - ( - pair_a.forward.seq, - pair_a.forward.name, - pair_b.forward.seq, - pair_b.forward.name, - ), - ( - pair_a.forward.seq, - pair_a.forward.name, - pair_b.reverse.seq, - pair_b.reverse.name, - ), - ( - pair_a.reverse.seq, - pair_a.reverse.name, - pair_b.forward.seq, - pair_b.forward.name, - ), - ( - pair_a.reverse.seq, - pair_a.reverse.name, - pair_b.reverse.seq, - pair_b.reverse.name, - ), - ]: - predictor.set_primers(seq_a, seq_b, name_a, name_b) - predictor.align() - scores.append(predictor.score or 0.0) + primer_matrix: dict[str, dict[str, float]] = {} + + for (label_a, seq_a), (label_b, seq_b) in combinations(primer_entries, 2): + predictor.set_primers(seq_a, seq_b, label_a, label_b) + predictor.align() + score = round(predictor.score or 0.0, 4) + primer_matrix.setdefault(label_a, {})[label_b] = score + primer_matrix.setdefault(label_b, {})[label_a] = score + primer_dimer_matrix = { + "primer_labels": primer_labels, + "dimer_threshold": dimer_threshold, + "matrix": primer_matrix, + } + + # --- Junction-level cross-reactivity matrix (derived) --- + junction_matrix: dict[str, dict] = {} + for (jname_a, _pair_a), (jname_b, _pair_b) in combinations(selected, 2): + scores = [] + for d_a in ("forward", "reverse"): + for d_b in ("forward", "reverse"): + la = f"{jname_a}_{d_a}" + lb = f"{jname_b}_{d_b}" + scores.append(primer_matrix.get(la, {}).get(lb, 0.0)) cell = { "min_dimer_score": round(min(scores), 4), "interaction_count": sum(1 for 
s in scores if s < dimer_threshold), } - matrix.setdefault(jname_a, {})[jname_b] = cell - matrix.setdefault(jname_b, {})[jname_a] = cell + junction_matrix.setdefault(jname_a, {})[jname_b] = cell + junction_matrix.setdefault(jname_b, {})[jname_a] = cell cross_reactivity_matrix = { "dimer_threshold": dimer_threshold, - "matrix": matrix, + "matrix": junction_matrix, } return { "tm_distribution": tm_distribution, "sequence_flags": sequence_flags, "cross_reactivity_matrix": cross_reactivity_matrix, + "primer_dimer_matrix": primer_dimer_matrix, } + + +def save_primer_dimer_matrix_csv( + primer_dimer_data: dict, + file_path: str, +) -> None: + """Write the primer-level dimer matrix as a symmetric CSV.""" + labels = primer_dimer_data["primer_labels"] + matrix = primer_dimer_data["matrix"] + + with open(file_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow([""] + labels) + for label_a in labels: + row = [label_a] + for label_b in labels: + if label_a == label_b: + row.append("") + else: + row.append(matrix.get(label_a, {}).get(label_b, "")) + writer.writerow(row) diff --git a/src/plexus/reporting/templates/panel_report.html.j2 b/src/plexus/reporting/templates/panel_report.html.j2 index 17775e2..de09fe8 100644 --- a/src/plexus/reporting/templates/panel_report.html.j2 +++ b/src/plexus/reporting/templates/panel_report.html.j2 @@ -178,6 +178,15 @@
{% endif %} + +{% if qc.primer_dimer_matrix and qc.primer_dimer_matrix.matrix %} +

Primer-Level Dimer Heatmap

+

+ Individual primer-to-primer dimer scores. Lower (more negative) scores indicate stronger dimer formation.

+
+{% endif %} + {% if qc.sequence_flags.flagged_primers %}

Sequence Flags

@@ -276,8 +285,8 @@
Run:
{{ provenance.get("run_timestamp", "—") }}
{% if provenance.get("completed_at") %}
Completed:
{{ provenance.completed_at }}
{% endif %} -
FASTA:
{{ provenance.get("fasta_sha256", "—")[:16] }}…
- {% if provenance.get("snp_vcf_sha256") %}
VCF:
{{ provenance.snp_vcf_sha256[:16] }}…
{% endif %} +
FASTA:
{{ (provenance.get("fasta_sha256") or "—")[:16] }}…
+ {% if provenance.get("snp_vcf_sha256") %}
VCF:
{{ (provenance.snp_vcf_sha256 or "—")[:16] }}…
{% endif %} {% if provenance.get("tool_versions") %} {% for tool, ver in provenance.tool_versions.items() %}
{{ tool }}:
{{ ver }}
@@ -399,7 +408,7 @@ } {% endif %} - // --- Cross-Reactivity Heatmap --- + // --- Cross-Reactivity Heatmap (Junction-Level) --- {% if qc.cross_reactivity_matrix and qc.cross_reactivity_matrix.matrix %} var matrix = qc.cross_reactivity_matrix.matrix; var labels = Object.keys(matrix); @@ -451,6 +460,61 @@ }, config); } {% endif %} + + // --- Primer-Level Dimer Heatmap --- + {% if qc.primer_dimer_matrix and qc.primer_dimer_matrix.matrix %} + var pMatrix = qc.primer_dimer_matrix.matrix; + var pLabels = qc.primer_dimer_matrix.primer_labels; + if (pLabels && pLabels.length > 1) { + var pz = []; + var pHover = []; + for (var i = 0; i < pLabels.length; i++) { + var pRow = []; + var phRow = []; + for (var j = 0; j < pLabels.length; j++) { + if (i === j) { + pRow.push(null); + phRow.push(pLabels[i] + " (self)"); + } else { + var score = (pMatrix[pLabels[i]] || {})[pLabels[j]]; + if (score !== undefined && score !== null) { + pRow.push(score); + phRow.push(pLabels[i] + " vs " + pLabels[j] + "
Score: " + score); + } else { + pRow.push(null); + phRow.push("N/A"); + } + } + } + pz.push(pRow); + pHover.push(phRow); + } + + // Short labels: gene_direction → gene (F/R) + var pShort = pLabels.map(function(l) { + var last = l.lastIndexOf("_"); + var gene = l.substring(0, last); + var dir = l.substring(last + 1); + return gene + " (" + dir.charAt(0).toUpperCase() + ")"; + }); + + Plotly.newPlot("primer-heatmap-chart", [{ + z: pz, x: pShort, y: pShort, + type: "heatmap", + colorscale: [[0, "#dc2626"], [0.5, "#fbbf24"], [1, "#ffffff"]], + zmin: -12, zmax: 0, + hoverongaps: false, + text: pHover, + hoverinfo: "text", + colorbar: {title: "Dimer Score", titleside: "right"} + }], { + xaxis: {tickangle: -45, side: "bottom"}, + yaxis: {autorange: "reversed"}, + margin: {b: 140, t: 30, l: 140, r: 30}, + height: Math.max(500, pLabels.length * 25 + 200) + }, config); + } + {% endif %} })(); diff --git a/src/plexus/reporting/templates/plotly.min.js.gz b/src/plexus/reporting/templates/plotly.min.js.gz index af85ffa..141a583 100644 Binary files a/src/plexus/reporting/templates/plotly.min.js.gz and b/src/plexus/reporting/templates/plotly.min.js.gz differ diff --git a/src/plexus/selector/multiplex.py b/src/plexus/selector/multiplex.py index f22d75a..69df4b8 100644 --- a/src/plexus/selector/multiplex.py +++ b/src/plexus/selector/multiplex.py @@ -7,3 +7,4 @@ class Multiplex: primer_pairs: list = field(default_factory=list) # list of pair_id strings cost: float = 0.0 + dropped_targets: list = field(default_factory=list) # target_ids skipped by dropout diff --git a/src/plexus/selector/selectors.py b/src/plexus/selector/selectors.py index 89ef647..2b8f720 100644 --- a/src/plexus/selector/selectors.py +++ b/src/plexus/selector/selectors.py @@ -308,6 +308,11 @@ class DepthFirstSearch(MultiplexSelector): whose partial cost plus a lower-bound on remaining cost exceeds the best known solution. 
+ Supports optional **target dropout**: when ``allow_target_dropping=True``, + each target may be skipped at the cost of a per-drop penalty. Only + targets whose cross-dimer contribution exceeds the adaptive penalty + threshold will be dropped. A hard floor (``min_target_floor``) prevents + excessive dropping. """ def run( @@ -317,12 +322,23 @@ def run( greedy_seed_iterations=100, max_nodes=10_000_000, target_ordering="ascending_candidates", + # Target dropout parameters + allow_target_dropping=False, + dropout_penalty=None, + min_target_floor=None, + stringency=0.8, ): target_pairs = { target_id: list(set(target_df["pair_name"])) for target_id, target_df in self.primer_df.groupby("target_id") } + # Reverse lookup: pair_id -> target_id + pair_to_target = {} + for target_id, target_df in self.primer_df.groupby("target_id"): + for pair_name in set(target_df["pair_name"]): + pair_to_target[pair_name] = target_id + # Order targets if target_ordering == "ascending_candidates": ordered_targets = sorted(target_pairs, key=lambda t: len(target_pairs[t])) @@ -341,21 +357,61 @@ def run( sorted_candidates[tid] = [p for p, _ in pair_costs] min_individual_cost[tid] = pair_costs[0][1] - # Precompute suffix sums of minimum individual costs for lower bounds - suffix_min = [0.0] * (n_targets + 1) - for i in range(n_targets - 1, -1, -1): - suffix_min[i] = suffix_min[i + 1] + min_individual_cost[ordered_targets[i]] + # Dropout budget + if allow_target_dropping: + if min_target_floor is None: + min_target_floor = n_targets + max_drops = max(0, n_targets - min_target_floor) + else: + max_drops = 0 + min_target_floor = n_targets + + # Precompute suffix lower-bound data structures. + # suffix_prefix_sums[d] = prefix sums of sorted min_individual_costs + # for targets at depths d..n_targets-1. Used to compute the tightest + # lower bound given a remaining drop budget. 
+ suffix_prefix_sums: list[list[float]] = [] + for d in range(n_targets + 1): + costs = sorted( + min_individual_cost[ordered_targets[i]] for i in range(d, n_targets) + ) + ps = [0.0] + for c in costs: + ps.append(ps[-1] + c) + suffix_prefix_sums.append(ps) # Optionally seed best_known_cost from greedy best_known_cost = float("inf") + greedy_results = [] if seed_with_greedy: greedy = GreedySearch(self.primer_df, self.cost_function) greedy_results = greedy.run(N=greedy_seed_iterations) if greedy_results: best_known_cost = min(m.cost for m in greedy_results) + # Compute adaptive dropout penalty from greedy seed + if allow_target_dropping and max_drops > 0: + if dropout_penalty is None: + if greedy_results: + best_greedy = min(greedy_results, key=lambda m: m.cost) + dropout_penalty = self._compute_dropout_penalty( + pair_to_target, ordered_targets, best_greedy, stringency + ) + logger.info( + f"Adaptive dropout penalty: {dropout_penalty:.4f} " + f"(stringency={stringency}, max_drops={max_drops})" + ) + else: + dropout_penalty = 1.0 + logger.warning( + "No greedy seed for adaptive penalty; using fallback=1.0" + ) + else: + if dropout_penalty is None: + dropout_penalty = 0.0 + # Iterative DFS with explicit stack - # Stack entries: (depth, assignment_list, partial_cost) + # Stack entries: (depth, included_pairs, dropped_indices) stored_multiplexes = [] stored_costs = [] nodes_visited = 0 @@ -365,14 +421,16 @@ def run( ) logger.info( f"Running DFS over {n_targets} targets ({total_combos} total combinations), " - f"max_nodes={max_nodes}..." + f"max_nodes={max_nodes}, max_drops={max_drops}..." 
) # Push initial candidates for depth 0 in reverse order (cheapest first via LIFO) - stack = [] + stack: list[tuple[int, list[str], frozenset[int]]] = [] tid0 = ordered_targets[0] for candidate in reversed(sorted_candidates[tid0]): - stack.append((0, [candidate])) + stack.append((0, [candidate], frozenset())) + if max_drops > 0: + stack.append((0, [], frozenset({0}))) while stack: if nodes_visited >= max_nodes: @@ -381,39 +439,70 @@ def run( ) break - depth, assignment = stack.pop() + depth, included_pairs, dropped_indices = stack.pop() nodes_visited += 1 + n_dropped = len(dropped_indices) - partial_cost = self.cost_function.calc_cost(assignment) + # Cost = base cost of included pairs + penalty per drop + partial_cost = ( + self.cost_function.calc_cost(included_pairs) + + n_dropped * dropout_penalty + ) - # Prune: if buffer is full, check lower bound + n_remaining = n_targets - (depth + 1) + n_included = len(included_pairs) + + # Feasibility prune: can we still reach min_target_floor? + if n_included + n_remaining < min_target_floor: + continue + + # Cost prune if len(stored_multiplexes) >= store_maximum: - lower_bound = partial_cost + suffix_min[depth + 1] - if lower_bound >= best_known_cost: + remaining_budget = min(max_drops - n_dropped, n_remaining) + n_must_keep = n_remaining - remaining_budget + lb_include = suffix_prefix_sums[depth + 1][n_must_keep] + lb_drop = remaining_budget * dropout_penalty + if partial_cost + lb_include + lb_drop >= best_known_cost: continue # Complete solution if depth + 1 == n_targets: + if n_included < min_target_floor: + continue + + dropped_target_ids = [ + ordered_targets[i] for i in sorted(dropped_indices) + ] + mx = Multiplex( + cost=partial_cost, + primer_pairs=list(included_pairs), + dropped_targets=dropped_target_ids, + ) + if len(stored_multiplexes) < store_maximum: - stored_multiplexes.append( - Multiplex(cost=partial_cost, primer_pairs=list(assignment)) - ) + stored_multiplexes.append(mx) stored_costs.append(partial_cost) 
def _compute_dropout_penalty(
    self,
    pair_to_target: dict[str, str],
    ordered_targets: list[str],
    greedy_solution: Multiplex,
    stringency: float,
) -> float:
    """Derive the per-drop penalty from the greedy seed's marginal costs.

    A target's marginal cost is ``calc_cost`` of all greedy pairs minus
    ``calc_cost`` of the same pairs with that target's pair removed. The
    marginal costs are sorted ascending and the value at position
    ``int(stringency * n)`` (clamped to a valid index) becomes the penalty.

    Args:
        pair_to_target: Maps each pair id to its target id. Must contain
            every pair id in ``greedy_solution.primer_pairs``.
        ordered_targets: Target ids; fixes the order in which the greedy
            pairs are passed to the cost function. Greedy pairs whose
            target is not listed here are ignored.
        greedy_solution: Solution whose ``primer_pairs`` seed the estimate.
        stringency: Fraction selecting the position in the sorted marginal
            costs. Values below 0 select the smallest cost and values of 1
            or more select the largest.

    Returns:
        The selected marginal cost, never lower than ``0.01``; ``1.0`` when
        fewer than two greedy pairs map onto ``ordered_targets``.

    Raises:
        KeyError: If a greedy pair id is missing from ``pair_to_target``.
    """
    greedy_by_target = {
        pair_to_target[pid]: pid for pid in greedy_solution.primer_pairs
    }
    all_pairs = [
        greedy_by_target[tid] for tid in ordered_targets if tid in greedy_by_target
    ]

    # With fewer than two pairs there is nothing to compare against.
    if len(all_pairs) < 2:
        return 1.0

    calc_cost = self.cost_function.calc_cost
    full_cost = calc_cost(all_pairs)

    # all_pairs holds one distinct pair per target and at least two entries,
    # so the leave-one-out list below is never empty.
    marginal_costs = sorted(
        full_cost - calc_cost([p for p in all_pairs if p != pid])
        for pid in all_pairs
    )

    # Clamp on both sides: a negative index would wrap around and silently
    # select one of the largest marginal costs.
    last = len(marginal_costs) - 1
    idx = max(0, min(int(stringency * len(marginal_costs)), last))
    return max(marginal_costs[idx], 0.01)
"Forward_Genomic_Start,Forward_Genomic_End,Reverse_Genomic_Start,Reverse_Genomic_End," "Amplicon_Length,Insert_Size,Pair_Penalty,Dimer_Score,Off_Target_Count," "Specificity_Checked,On_Target_Detected,SNP_Count,SNP_Penalty," - "Forward_SNP_Count,Reverse_SNP_Count\n" + "Forward_SNP_Count,Reverse_SNP_Count," + "Forward_Self_Any_Th,Reverse_Self_Any_Th," + "Forward_Self_End_Th,Reverse_Self_End_Th," + "Forward_Hairpin_Th,Reverse_Hairpin_Th," + "Forward_End_Stability,Reverse_End_Stability," + "Forward_Penalty,Reverse_Penalty," + "Amplicon_GC,Num_Candidate_Pairs,Cross_Dimer_Contribution\n" "GENE_A,chr1,100,100,GENE_A_fwd_rev,ATCG,GCTA,ATCG,GCTA," "59.5,60.5,1.0,80,70,50.0,55.0,20,20,80,99,101,120,80,40," - "100.0,-1.5,0,True,True,0,0.0,0,0\n" + "100.0,-1.5,0,True,True,0,0.0,0,0," + "30.0,25.0,20.0,18.0,15.0,12.0,3.5,3.2,0.5,0.6,50.0,5,1.5\n" "GENE_B,chr2,200,200,GENE_B_fwd_rev,TTTT,CCCC,TTTT,CCCC," "58.0,62.0,4.0,75,65,35.0,65.0,22,22,180,201,201,222,90,46," - "120.0,-2.0,1,True,True,0,0.0,0,0\n" + "120.0,-2.0,1,True,True,0,0.0,0,0," + "32.0,28.0,22.0,20.0,16.0,14.0,3.8,3.4,0.7,0.8,55.0,3,2.0\n" ) (qc_output_dir / "selected_multiplex.csv").write_text(csv) @@ -165,6 +204,12 @@ def test_html_contains_heatmap_data(self, qc_output_dir): assert "-5.2" in html assert "cross_reactivity_matrix" in html or "heatmap" in html.lower() + def test_html_contains_primer_heatmap(self, qc_output_dir): + path = generate_html_report(qc_output_dir) + html = path.read_text() + assert "primer-heatmap-chart" in html + assert "Primer-Level Dimer Heatmap" in html + def test_html_handles_missing_optional_files(self, tmp_path): (tmp_path / "panel_qc.json").write_text(json.dumps(MINIMAL_QC)) path = generate_html_report(tmp_path, panel_name="Bare") diff --git a/tests/test_reporting_qc.py b/tests/test_reporting_qc.py index 0edb76d..9c3b6c3 100644 --- a/tests/test_reporting_qc.py +++ b/tests/test_reporting_qc.py @@ -221,3 +221,125 @@ def test_only_selected_pairs_used(): assert "R_sel" in names assert 
"F_no" not in names assert "R_no" not in names + + +# --------------------------------------------------------------------------- +# Primer-level dimer matrix +# --------------------------------------------------------------------------- + + +def test_primer_dimer_matrix_basic(): + pair_a = _make_pair("FA", "ACGTACGT", 60.0, 50.0, "RA", "TGCATGCA", 60.0, 50.0) + pair_b = _make_pair("FB", "CCCCGGGG", 60.0, 75.0, "RB", "GGGGCCCC", 60.0, 75.0) + junc_a = _make_junction("JA", [pair_a]) + junc_b = _make_junction("JB", [pair_b]) + result = generate_panel_qc([junc_a, junc_b]) + pdm = result["primer_dimer_matrix"] + assert len(pdm["primer_labels"]) == 4 + matrix = pdm["matrix"] + # All 6 unique pairs should have scores (C(4,2) = 6) + all_scores = [] + for la in pdm["primer_labels"]: + for lb in pdm["primer_labels"]: + if la != lb: + assert lb in matrix.get(la, {}), f"Missing {la} vs {lb}" + all_scores.append(matrix[la][lb]) + assert len(all_scores) == 12 # 4*3 entries (symmetric) + + +def test_primer_dimer_matrix_symmetric(): + pair_a = _make_pair("FA", "ACGTACGT", 60.0, 50.0, "RA", "TGCATGCA", 60.0, 50.0) + pair_b = _make_pair("FB", "CCCCGGGG", 60.0, 75.0, "RB", "GGGGCCCC", 60.0, 75.0) + junc_a = _make_junction("JA", [pair_a]) + junc_b = _make_junction("JB", [pair_b]) + result = generate_panel_qc([junc_a, junc_b]) + matrix = result["primer_dimer_matrix"]["matrix"] + for la in matrix: + for lb in matrix[la]: + assert matrix[la][lb] == matrix[lb][la], f"Asymmetric: {la} vs {lb}" + + +def test_primer_dimer_matrix_labels_format(): + pair = _make_pair("F1", "ACGTACGT", 60.0, 50.0, "R1", "TGCATGCA", 60.0, 50.0) + junc = _make_junction("KRAS", [pair]) + result = generate_panel_qc([junc]) + labels = result["primer_dimer_matrix"]["primer_labels"] + assert labels == ["KRAS_forward", "KRAS_reverse"] + + +def test_primer_dimer_matrix_single_junction(): + """Single junction should still have F vs R score.""" + pair = _make_pair("F1", "ACGTACGT", 60.0, 50.0, "R1", "TGCATGCA", 60.0, 
def test_save_primer_dimer_matrix_csv(tmp_path):
    """Round-trip the primer dimer matrix through CSV and check its layout."""
    import csv

    from plexus.reporting.qc import save_primer_dimer_matrix_csv

    pair_a = _make_pair("FA", "ACGTACGT", 60.0, 50.0, "RA", "TGCATGCA", 60.0, 50.0)
    pair_b = _make_pair("FB", "CCCCGGGG", 60.0, 75.0, "RB", "GGGGCCCC", 60.0, 75.0)
    junc_a = _make_junction("JA", [pair_a])
    junc_b = _make_junction("JB", [pair_b])
    result = generate_panel_qc([junc_a, junc_b])
    pdm = result["primer_dimer_matrix"]
    labels = list(pdm["primer_labels"])

    csv_path = str(tmp_path / "dimer_matrix.csv")
    save_primer_dimer_matrix_csv(pdm, csv_path)

    with open(csv_path, newline="") as f:
        rows = list(csv.reader(f))

    # Header row + 4 data rows
    assert len(rows) == 5
    # Header: empty corner cell + 4 labels
    assert rows[0] == [""] + labels
    assert len(rows[0]) == 5

    for i, row in enumerate(rows[1:], start=1):
        # Each data row is keyed by its primer label, in header order
        assert row[0] == labels[i - 1]
        for j in range(1, len(row)):
            if i == j:
                # Diagonal should be empty
                assert row[j] == ""
            else:
                # Off-diagonal should be numeric: float() raises ValueError
                # on a blank or non-numeric cell
                float(row[j])
                # ...and the exported matrix should be symmetric
                assert row[j] == rows[j][i]
cost_fn = MagicMock() + cost_fn.calc_cost = toxic_cost + + results = DepthFirstSearch(df, cost_fn).run( + seed_with_greedy=False, + allow_target_dropping=True, + dropout_penalty=5.0, + min_target_floor=2, + ) + best = min(results, key=lambda m: m.cost) + # Best solution should drop T3 (cost 100) for penalty of 5 + assert len(best.primer_pairs) == 2 + assert "T3" in best.dropped_targets + + def test_min_target_floor_enforced(self): + """Cannot drop below min_target_floor.""" + df = pd.DataFrame( + { + "target_id": ["T1", "T2", "T3"], + "pair_name": ["P1", "P2", "P3"], + } + ) + cost_fn = MagicMock() + cost_fn.calc_cost = MagicMock(return_value=1.0) + + results = DepthFirstSearch(df, cost_fn).run( + seed_with_greedy=False, + allow_target_dropping=True, + dropout_penalty=0.01, + min_target_floor=3, + ) + for m in results: + assert len(m.dropped_targets) == 0 + assert len(m.primer_pairs) == 3 + + def test_multiplex_dropped_targets_default(self): + """Multiplex backward compat: dropped_targets defaults to [].""" + m = Multiplex() + assert m.dropped_targets == [] + m2 = Multiplex(cost=1.0, primer_pairs=["P1"]) + assert m2.dropped_targets == [] + + def test_solutions_include_both_dropped_and_full(self): + """Solutions buffer contains both full and reduced panels.""" + df = pd.DataFrame( + { + "target_id": ["T1", "T2", "T3"], + "pair_name": ["P1", "P2", "P3"], + } + ) + + def mild_cost(pairs): + base = len(pairs) * 1.0 + if any(p == "P3" for p in pairs): + base += 2.0 + return base + + cost_fn = MagicMock() + cost_fn.calc_cost = mild_cost + + results = DepthFirstSearch(df, cost_fn).run( + seed_with_greedy=False, + allow_target_dropping=True, + dropout_penalty=1.5, + min_target_floor=2, + store_maximum=200, + ) + # Should have solutions with 3 targets AND solutions with 2 targets + full_solutions = [m for m in results if len(m.dropped_targets) == 0] + dropped_solutions = [m for m in results if len(m.dropped_targets) > 0] + assert len(full_solutions) > 0 + assert 
def test_adaptive_penalty_computation(self):
    """_compute_dropout_penalty picks the stringency percentile of marginal costs."""
    df = pd.DataFrame(
        {
            "target_id": ["T1", "T2", "T3"],
            "pair_name": ["P1", "P2", "P3"],
        }
    )

    # Cost function where T3 contributes most of the cost
    def cost_with_t3_toxic(pairs):
        base = len(pairs) * 0.5
        if "P3" in pairs:
            base += 10.0
        return base

    cost_fn = MagicMock()
    cost_fn.calc_cost = cost_with_t3_toxic

    dfs = DepthFirstSearch(df, cost_fn)
    pair_to_target = {"P1": "T1", "P2": "T2", "P3": "T3"}
    ordered_targets = ["T1", "T2", "T3"]
    greedy_sol = Multiplex(primer_pairs=["P1", "P2", "P3"], cost=11.5)

    # Full cost is 11.5. Removing P1 or P2 leaves 11.0 (marginal 0.5);
    # removing P3 leaves 1.0 (marginal 10.5). Sorted: [0.5, 0.5, 10.5].
    penalty = dfs._compute_dropout_penalty(
        pair_to_target, ordered_targets, greedy_sol, stringency=0.5
    )
    # stringency=0.5 -> idx=int(1.5)=1 -> 0.5, well below the toxic level
    assert penalty == pytest.approx(0.5)

    high_penalty = dfs._compute_dropout_penalty(
        pair_to_target, ordered_targets, greedy_sol, stringency=0.9
    )
    # stringency=0.9 -> idx=int(2.7)=2 -> the toxic target's marginal cost
    assert high_penalty == pytest.approx(10.5)