Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 41 additions & 23 deletions src/plexus/blast/specificity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

import pandas as pd
from loguru import logger

from plexus.blast.annotator import BlastResultsAnnotator
Expand Down Expand Up @@ -109,14 +110,8 @@ def run_specificity_check(
return

# 6. Map results back to PrimerPairs
amplicon_map = {}
for _, row in all_amplicons_df.iterrows():
f_id = row["F_primer"] # SEQ_X
r_id = row["R_primer"] # SEQ_Y

if (f_id, r_id) not in amplicon_map:
amplicon_map[(f_id, r_id)] = []
amplicon_map[(f_id, r_id)].append(row.to_dict())
amplicon_groups = all_amplicons_df.groupby(["F_primer", "R_primer"])
_empty = pd.DataFrame()

# Use the panel's unique map to look up IDs
seq_to_id = panel.unique_primer_map
Expand All @@ -138,29 +133,40 @@ def run_specificity_check(
pair.specificity_checked = True
n_checked += 1

potential_products = amplicon_map.get((f_id, r_id), [])
potential_products += amplicon_map.get((r_id, f_id), [])

off_targets = []
on_targets = []
for prod in potential_products:
if _is_on_target(prod, junction, pair, tolerance=ontarget_tolerance):
on_targets.append(prod)
else:
off_targets.append(prod)

pair.off_target_products = off_targets
pair.on_target_detected = len(on_targets) > 0
fwd = (
amplicon_groups.get_group((f_id, r_id))
if (f_id, r_id) in amplicon_groups.groups
else _empty
)
rev = (
amplicon_groups.get_group((r_id, f_id))
if (r_id, f_id) in amplicon_groups.groups
else _empty
)
potential_products = pd.concat([fwd, rev]) if not rev.empty else fwd

if potential_products.empty:
pair.off_target_products = []
pair.on_target_detected = False
else:
on_mask = _is_on_target_vec(
potential_products, junction, pair, tolerance=ontarget_tolerance
)
pair.on_target_detected = bool(on_mask.any())
off_df = potential_products[~on_mask]
pair.off_target_products = (
off_df.to_dict("records") if not off_df.empty else []
)

if not pair.on_target_detected:
n_missing_on_target += 1
logger.debug(
f"Pair {pair.pair_id}: on-target amplicon not detected by BLAST."
)

if off_targets:
if pair.off_target_products:
logger.debug(
f"Pair {pair.pair_id} has {len(off_targets)} off-target products."
f"Pair {pair.pair_id} has {len(pair.off_target_products)} off-target products."
)

if n_missing_on_target > 0:
Expand Down Expand Up @@ -232,6 +238,18 @@ def filter_offtarget_pairs(panel: MultiplexPanel) -> tuple[int, list[str]]:
return total_removed, fallback_junctions


def _is_on_target_vec(df, junction, pair, tolerance: int = 5):
"""Vectorized on-target classification for a DataFrame of amplicons."""
design_start = getattr(junction, "design_start", None) or 0
expected_fwd = design_start + pair.forward.start
expected_rev = design_start + pair.reverse.start + pair.reverse.length - 1
return (
(df["chrom"] == junction.chrom)
& ((df["F_start"] - expected_fwd).abs() <= tolerance)
& ((df["R_start"] - expected_rev).abs() <= tolerance)
)


def _is_on_target(prod: dict, junction, pair, tolerance: int = 5) -> bool:
"""Check if a BLAST amplicon overlaps the intended target region.

Expand Down
2 changes: 1 addition & 1 deletion src/plexus/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.1.0"
__version__ = "1.1.1"
Loading