diff --git a/.github/workflows/Human.ONT_simulated.TSS.yml b/.github/workflows/Human.ONT_simulated.TSS.yml new file mode 100644 index 00000000..24c4670a --- /dev/null +++ b/.github/workflows/Human.ONT_simulated.TSS.yml @@ -0,0 +1,76 @@ +name: Human ONT R10 simulated TSS (discovery + prediction) + +on: + workflow_dispatch: + schedule: + - cron: '0 5 * * 2' + +env: + RUN_NAME: Human.ONT_simulated.TSS + LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py + CFG_DIR: /abga/work/andreyp/ci_isoquant/data + BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/ + OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/ + +concurrency: + group: ${{github.workflow}} + cancel-in-progress: false + +jobs: + check-changes: + runs-on: + labels: [isoquant] + name: 'Check for recent changes' + outputs: + has_changes: ${{steps.check.outputs.has_changes}} + steps: + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: 'Check for commits in last 7 days' + id: check + run: | + # Always run on manual trigger + if [ "${{github.event_name}}" = "workflow_dispatch" ]; then + echo "has_changes=true" >> $GITHUB_OUTPUT + exit 0 + fi + # Check for commits in last 7 days + COMMITS=$(git log --oneline --since="7 days ago" | wc -l) + if [ "$COMMITS" -gt 0 ]; then + echo "Found $COMMITS commits in last 7 days" + echo "has_changes=true" >> $GITHUB_OUTPUT + else + echo "No commits in last 7 days, skipping" + echo "has_changes=false" >> $GITHUB_OUTPUT + fi + + launch-runner: + needs: check-changes + if: needs.check-changes.outputs.has_changes == 'true' + runs-on: + labels: [isoquant] + name: 'Running IsoQuant and QC' + + steps: + - name: 'Cleanup' + run: > + set -e && + shopt -s dotglob && + rm -rf * + + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: 'IsoQuant discovery + TSS' + if: always() + shell: bash + env: + STEP_NAME: Human.ONT_simulated.TSS + run: | + export PATH=${{env.BIN_PATH}}:$PATH + python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}} diff --git a/.github/workflows/Human.ONT_simulated.polyA_1.yml b/.github/workflows/Human.ONT_simulated.polyA_1.yml new file mode 100644 index 00000000..e1235e1f --- /dev/null +++ b/.github/workflows/Human.ONT_simulated.polyA_1.yml @@ -0,0 +1,76 @@ +name: Human ONT R10 simulated polyA 1 (discovery + prediction) + +on: + workflow_dispatch: + schedule: + - cron: '0 5 * * 1' + +env: + RUN_NAME: Human.ONT_simulated.polyA_1 + LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py + CFG_DIR: /abga/work/andreyp/ci_isoquant/data + BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/ + OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/ + +concurrency: + group: ${{github.workflow}} + cancel-in-progress: false + +jobs: + check-changes: + runs-on: + labels: [isoquant] + name: 'Check for recent changes' + outputs: + has_changes: ${{steps.check.outputs.has_changes}} + steps: + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: 'Check for commits in last 7 days' + id: check + run: | + # Always run on manual trigger + if [ "${{github.event_name}}" = "workflow_dispatch" ]; then + echo "has_changes=true" >> $GITHUB_OUTPUT + exit 0 + fi + # Check for commits in last 7 days + COMMITS=$(git log --oneline --since="7 days ago" | wc -l) + if [ "$COMMITS" -gt 0 ]; then + echo "Found $COMMITS commits in last 7 days" + echo "has_changes=true" >> $GITHUB_OUTPUT + else + echo "No commits in last 7 days, skipping" + echo "has_changes=false" >> $GITHUB_OUTPUT + fi + + launch-runner: + needs: check-changes + if: needs.check-changes.outputs.has_changes == 'true' + runs-on: + labels: [isoquant] + name: 'Running IsoQuant and QC' + + steps: + - name: 'Cleanup' + run: > + set -e && + shopt -s dotglob && + rm -rf * + + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: 'IsoQuant discovery + polyA' + if: always() + shell: bash + env: + STEP_NAME: Human.ONT_simulated.polyA_1 + run: | + export PATH=${{env.BIN_PATH}}:$PATH + python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}} diff --git a/.github/workflows/Human.ONT_simulated.polyA_2.no_gtf.yml b/.github/workflows/Human.ONT_simulated.polyA_2.no_gtf.yml new file mode 100644 index 00000000..78c13528 --- /dev/null +++ b/.github/workflows/Human.ONT_simulated.polyA_2.no_gtf.yml @@ -0,0 +1,76 @@ +name: Human ONT simulated polyA 2 no annotation + +on: + workflow_dispatch: + schedule: + - cron: '0 5 * * 0' + +env: + RUN_NAME: Human.ONT_simulated.polyA_2.no_gtf + LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py + CFG_DIR: /abga/work/andreyp/ci_isoquant/data + BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/ + OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/ + +concurrency: + group: ${{github.workflow}} + cancel-in-progress: false + +jobs: + check-changes: + runs-on: + labels: [isoquant] + name: 'Check for recent changes' + outputs: + has_changes: ${{steps.check.outputs.has_changes}} + steps: + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: 'Check for commits in last 7 days' + id: check + run: | + # Always run on manual trigger + if [ "${{github.event_name}}" = "workflow_dispatch" ]; then + echo "has_changes=true" >> $GITHUB_OUTPUT + exit 0 + fi + # Check for commits in last 7 days + COMMITS=$(git log --oneline --since="7 days ago" | wc -l) + if [ "$COMMITS" -gt 0 ]; then + echo "Found $COMMITS commits in last 7 days" + echo "has_changes=true" >> $GITHUB_OUTPUT + else + echo "No commits in last 7 days, skipping" + echo "has_changes=false" >> $GITHUB_OUTPUT + fi + + launch-runner: + needs: check-changes + if: needs.check-changes.outputs.has_changes == 'true' + runs-on: + labels: [isoquant] + name: 'Running IsoQuant and QC' + + steps: + - name: 'Cleanup' + run: > + set -e && + shopt -s dotglob && + rm -rf * + + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: 'IsoQuant de-novo discovery (no annotation)' + if: always() + shell: bash + env: + STEP_NAME: Human.ONT_simulated.polyA_2.no_gtf + run: | + export PATH=${{env.BIN_PATH}}:$PATH + python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}} diff --git a/.github/workflows/Human.ONT_simulated.polyA_2.yml b/.github/workflows/Human.ONT_simulated.polyA_2.yml new file mode 100644 index 00000000..a73263b1 --- /dev/null +++ b/.github/workflows/Human.ONT_simulated.polyA_2.yml @@ -0,0 +1,76 @@ +name: Human ONT R10 simulated polyA 2 (discovery + prediction) + +on: + workflow_dispatch: + schedule: + - cron: '0 7 * * 1' + +env: + RUN_NAME: Human.ONT_simulated.polyA_2 + LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py + CFG_DIR: /abga/work/andreyp/ci_isoquant/data + BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/ + OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/ + +concurrency: + group: ${{github.workflow}} + cancel-in-progress: false + +jobs: + check-changes: + runs-on: + labels: [isoquant] + name: 'Check for recent changes' + outputs: + has_changes: ${{steps.check.outputs.has_changes}} + steps: + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: 'Check for commits in last 7 days' + id: check + run: | + # Always run on manual trigger + if [ "${{github.event_name}}" = "workflow_dispatch" ]; then + echo "has_changes=true" >> $GITHUB_OUTPUT + exit 0 + fi + # Check for commits in last 7 days + COMMITS=$(git log --oneline --since="7 days ago" | wc -l) + if [ "$COMMITS" -gt 0 ]; then + echo "Found $COMMITS commits in last 7 days" + echo "has_changes=true" >> $GITHUB_OUTPUT + else + echo "No commits in last 7 days, skipping" + echo "has_changes=false" >> $GITHUB_OUTPUT + fi + + launch-runner: + needs: check-changes + if: needs.check-changes.outputs.has_changes == 'true' + runs-on: + labels: [isoquant] + name: 'Running IsoQuant and QC' + + steps: + - name: 'Cleanup' + run: > + set -e && + shopt -s dotglob && + rm -rf * + + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: 'IsoQuant discovery + polyA' + if: always() + shell: bash + env: + STEP_NAME: Human.ONT_simulated.polyA_2 + run: | + export PATH=${{env.BIN_PATH}}:$PATH + python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}} diff --git a/.github/workflows/Mouse.ONT_simulated.TSS.yml b/.github/workflows/Mouse.ONT_simulated.TSS.yml new file mode 100644 index 00000000..527f060f --- /dev/null +++ b/.github/workflows/Mouse.ONT_simulated.TSS.yml @@ -0,0 +1,76 @@ +name: Mouse ONT R10 simulated TSS (discovery + prediction) + +on: + workflow_dispatch: + schedule: + - cron: '0 7 * * 2' + +env: + RUN_NAME: Mouse.ONT_simulated.TSS + LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py + CFG_DIR: /abga/work/andreyp/ci_isoquant/data + BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/ + OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/ + +concurrency: + group: ${{github.workflow}} + cancel-in-progress: false + +jobs: + check-changes: + runs-on: + labels: [isoquant] + name: 'Check for recent changes' + outputs: + has_changes: ${{steps.check.outputs.has_changes}} + steps: + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: 'Check for commits in last 7 days' + id: check + run: | + # Always run on manual trigger + if [ "${{github.event_name}}" = "workflow_dispatch" ]; then + echo "has_changes=true" >> $GITHUB_OUTPUT + exit 0 + fi + # Check for commits in last 7 days + COMMITS=$(git log --oneline --since="7 days ago" | wc -l) + if [ "$COMMITS" -gt 0 ]; then + echo "Found $COMMITS commits in last 7 days" + echo "has_changes=true" >> $GITHUB_OUTPUT + else + echo "No commits in last 7 days, skipping" + echo "has_changes=false" >> $GITHUB_OUTPUT + fi + + launch-runner: + needs: check-changes + if: needs.check-changes.outputs.has_changes == 'true' + runs-on: + labels: [isoquant] + name: 'Running IsoQuant and QC' + + steps: + - name: 'Cleanup' + run: > + set -e && + shopt -s dotglob && + rm -rf * + + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: 'IsoQuant discovery + TSS' + if: always() + shell: bash + env: + STEP_NAME: Mouse.ONT_simulated.TSS + run: | + export PATH=${{env.BIN_PATH}}:$PATH + python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}} diff --git a/.github/workflows/SIRVs.simulated_perfect_polya.yml b/.github/workflows/SIRVs.simulated_perfect_polya.yml new file mode 100644 index 00000000..c32005ae --- /dev/null +++ b/.github/workflows/SIRVs.simulated_perfect_polya.yml @@ -0,0 +1,76 @@ +name: SIRVs simulated perfect polyA (discovery) + +on: + workflow_dispatch: + schedule: + - cron: '0 9 * * 2' + +env: + RUN_NAME: SIRVs.simulated_perfect_polya + LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py + CFG_DIR: /abga/work/andreyp/ci_isoquant/data + BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/ + OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/ + +concurrency: + group: ${{github.workflow}} + cancel-in-progress: false + +jobs: + check-changes: + runs-on: + labels: [isoquant] + name: 'Check for recent changes' + outputs: + has_changes: ${{steps.check.outputs.has_changes}} + steps: + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: 'Check for commits in last 7 days' + id: check + run: | + # Always run on manual trigger + if [ "${{github.event_name}}" = "workflow_dispatch" ]; then + echo "has_changes=true" >> $GITHUB_OUTPUT + exit 0 + fi + # Check for commits in last 7 days + COMMITS=$(git log --oneline --since="7 days ago" | wc -l) + if [ "$COMMITS" -gt 0 ]; then + echo "Found $COMMITS commits in last 7 days" + echo "has_changes=true" >> $GITHUB_OUTPUT + else + echo "No commits in last 7 days, skipping" + echo "has_changes=false" >> $GITHUB_OUTPUT + fi + + launch-runner: + needs: check-changes + if: needs.check-changes.outputs.has_changes == 'true' + runs-on: + labels: [isoquant] + name: 'Running IsoQuant and QC' + + steps: + - name: 'Cleanup' + run: > + set -e && + shopt -s dotglob && + rm -rf * + + - name: 'Checkout' + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: 'IsoQuant discovery' + if: always() + shell: bash + env: + STEP_NAME: SIRVs.simulated_perfect_polya + run: | + export PATH=${{env.BIN_PATH}}:$PATH + python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}} diff --git a/isoquant.py b/isoquant.py index 30179a57..cf3e3f6b 100755 --- a/isoquant.py +++ b/isoquant.py @@ -360,6 +360,12 @@ def add_hidden_option(*args, **kwargs): # show command only with --full-help add_hidden_option("--collect_tss_training", type=str, default=None, help="Developer: dump per-peak features + true_peak label for TSS training to this CSV path.") + # Alternative-polyA/TSS isoform discovery is default ON when an annotation is + # given (--genedb). --novel_apa (default off) extends it to novel + # (non-reference) transcripts, not only known ones. + add_hidden_option("--novel_apa", action="store_true", default=False, + help="Developer: extend alternative-end isoform creation to novel transcripts too.") + isoquant_version = "3.12.0" try: with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "VERSION")) as version_f: diff --git a/isoquant_lib/common.py b/isoquant_lib/common.py index 46373f6f..6f304c75 100644 --- a/isoquant_lib/common.py +++ b/isoquant_lib/common.py @@ -13,6 +13,7 @@ import math from collections import defaultdict from enum import Enum +from typing import List, Tuple logger = logging.getLogger('IsoQuant') @@ -882,7 +883,9 @@ def get_strand(introns, reference_region, ref_region_start=1): # binary search of a coordinate in ordered non-overlapping intervals -def interval_bin_search(ordered_intervals, pos): +def interval_bin_search(ordered_intervals: List[Tuple[int, int]], pos: int) -> int: + if not ordered_intervals: + return -1 if pos > ordered_intervals[-1][1] or pos < ordered_intervals[0][0]: return -1 @@ -901,7 +904,9 @@ def interval_bin_search(ordered_intervals, pos): return ind -def interval_bin_search_rev(ordered_intervals, pos): +def interval_bin_search_rev(ordered_intervals: List[Tuple[int, int]], pos: int) -> int: + if not ordered_intervals: + return -1 if pos > ordered_intervals[-1][1] or pos < ordered_intervals[0][0]: return -1 diff --git a/isoquant_lib/graph_based_model_construction.py b/isoquant_lib/graph_based_model_construction.py index e7030b2b..7a20921e 100644 --- a/isoquant_lib/graph_based_model_construction.py +++ b/isoquant_lib/graph_based_model_construction.py @@ -9,6 +9,7 @@ from collections import defaultdict from functools import cmp_to_key from enum import unique, Enum +from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING from .common import ( TranscriptNaming, @@ -33,6 +34,10 @@ from .long_read_assigner import LongReadAssigner from .long_read_profiles import CombinedProfileConstructor from .polya_finder import PolyAInfo +from .terminal_peaks import detect_peaks, get_polya_model, get_tss_model + +if TYPE_CHECKING: + from xgboost import XGBClassifier logger = logging.getLogger('IsoQuant') @@ -53,17 +58,39 @@ class GraphBasedModelConstructor: def __init__(self, gene_info, chr_record, args, transcript_counter, gene_counter, id_distributor, grouping_strategy_names=None, use_technical_replicas=False, - string_pools=None): + string_pools=None, + polya_predictions=None, + tss_predictions=None): self.gene_info = gene_info self.chr_record = chr_record self.args = args self.id_distributor = id_distributor self.string_pools = string_pools + # Predicted polyA / TSS sites for this gene (genomic positions), reused + # from the per-gene terminal-position counters to refine intron-graph + # terminal vertices. + self.polya_predictions = polya_predictions + self.tss_predictions = tss_predictions self.grouping_strategy_names = grouping_strategy_names if grouping_strategy_names else [] self.use_technical_replicas = use_technical_replicas # Find file_name group index for technical replicas check self.file_name_group_idx = self.grouping_strategy_names.index("file_name") if "file_name" in self.grouping_strategy_names else -1 + # Constructed (novel) model ends are always refined to the trained + # polyA/TSS peaks (see correct_novel_transcript_ends). Runs with or + # without an annotation (needs only the shipped detector + the model's + # own reads); never creates or drops models. + # Alternative-end NIC *creation* (graph-level reclassification + + # post-pass known->NIC) requires an annotation to refine against, so it + # is gated on --genedb. Keeps the with-genedb behaviour unchanged. + self.create_nics = bool(getattr(args, 'genedb', None)) + # TSS model is used (for polishing and for NIC creation) only with + # full-length evidence; read starts are unreliable otherwise. + self.use_tss_model = bool(getattr(args, 'fl_data', False)) + # Opt-in (--novel_apa, default off): also create alternative-end isoforms + # for novel (non-reference) transcripts; works with or without --genedb. + self.novel_apa = bool(getattr(args, 'novel_apa', False)) + self.strand_detector = StrandDetector(self.chr_record) self.intron_genes = defaultdict(set) self.set_gene_properties() @@ -125,7 +152,9 @@ def select_reference_gene(self, transcript_introns, transcript_range, transcript return None def process(self, read_assignment_storage): - self.intron_graph = IntronGraph(self.args, self.gene_info, read_assignment_storage) + self.intron_graph = IntronGraph(self.args, self.gene_info, read_assignment_storage, + polya_predictions=self.polya_predictions, + tss_predictions=self.tss_predictions) self.path_processor = IntronPathProcessor(self.args, self.intron_graph) self.path_storage = IntronPathStorage(self.args, self.path_processor) self.path_storage.fill(read_assignment_storage) @@ -366,6 +395,9 @@ def filter_transcripts(self): filtered_storage.append(model) self.transcript_model_storage = filtered_storage + if self.create_nics or self.novel_apa: + self._add_known_alternative_end_models() + self._drop_duplicate_alt_end_models() def mapping_quality(self, transcript_id): mapq = 0 @@ -420,6 +452,50 @@ def save_assigned_read(self, read_assignment, transcript_id): self.internal_counter[transcript_id] += 1 self.read_assignment_counts[read_id] += 1 + def _reference_isoform_for_path(self, assignment: ReadAssignment, path: tuple, intron_path: tuple, + transcript_range: Tuple[int, int]) -> Optional[str]: + # Reference isoform whose intron chain matches this path. It is reported + # as the annotated known UNLESS a *detected polyA* terminal vertex + # (VERTEX_polya / VERTEX_polyt) disagrees with the annotated end by more + # than apa_delta -> then the caller emits a novel-in-catalog isoform with + # the refined ends. A bare read_end / read_start (no polyA evidence, e.g. + # a degraded ONT terminus) never triggers reclassification, so known + # transcripts are not lost to noise. + if is_matching_assignment(assignment): + matched_reference_id = assignment.isoform_matches[0].assigned_transcript + elif intron_path in self.known_isoforms_in_graph: + matched_reference_id = self.known_isoforms_in_graph[intron_path] + else: + return None + + if matched_reference_id is None or \ + matched_reference_id not in self.gene_info.all_isoforms_exons: + return None + + ref_exons = self.gene_info.all_isoforms_exons[matched_reference_id] + left_diff = abs(transcript_range[0] - ref_exons[0][0]) > self.args.apa_delta + right_diff = abs(transcript_range[1] - ref_exons[-1][1]) > self.args.apa_delta + + # PolyA side: a detected polyA vertex (VERTEX_polyt on the genomic-left + # 3' end of a '-' transcript, VERTEX_polya on the genomic-right 3' end of + # a '+' transcript) that disagrees with the annotation -> alternative + # polyA NIC. Bare read termini never trigger here (degraded-end safety). + if path[0][0] == VERTEX_polyt and left_diff: + return None + if path[-1][0] == VERTEX_polya and right_diff: + return None + + # TSS side: only with full-length evidence (--fl_data). The 5' end is the + # genomic-left read_start for '+' and the genomic-right read_end for '-'. + if self.use_tss_model: + strand = self.gene_info.isoform_strands.get(matched_reference_id, '.') + if strand == '+' and path[0][0] == VERTEX_read_start and left_diff: + return None + if strand == '-' and path[-1][0] == VERTEX_read_end and right_diff: + return None + + return matched_reference_id + def construct_fl_isoforms(self): # a minor trick to compare tuples of pairs, whose starting and terminating elements have different type logger.debug("Total FL paths %d" % len(self.path_storage.fl_paths)) @@ -449,7 +525,13 @@ def construct_fl_isoforms(self): # logger.debug("uuu Checking novel transcript %s: %s; assignment type %s" % # (new_transcript_id, str(novel_exons), str(assignment.assignment_type))) - if is_matching_assignment(assignment): + if self.create_nics: + # use the path's goal-1-refined terminal vertices: keep the + # reference only when both ends agree with the annotation within + # apa_delta, otherwise fall through to the novel branch below to + # emit an alternative-end NIC built from the refined exons + reference_isoform = self._reference_isoform_for_path(assignment, path, intron_path, transcript_range) + elif is_matching_assignment(assignment): reference_isoform = assignment.isoform_matches[0].assigned_transcript # logger.debug("uuu Substituting with known isoform %s" % reference_isoform) elif intron_path in self.known_isoforms_in_graph: @@ -796,49 +878,292 @@ def assign_reads_to_models(self, read_assignments): else: self.read_assignment_counts[read_id] = 0 - def correct_novel_transcript_ends(self, transcript_model, assigned_reads): + def correct_novel_transcript_ends(self, transcript_model: TranscriptModel, + assigned_reads: List[ReadAssignment]) -> None: + # Part 1: detector-based per-transcript end refinement. Build read-terminus + # histograms from this model's own reads and snap each unsupported terminal + # exon end to the dominant confident polyA/TSS peak (XGBoost detect_peaks); + # both ends are refined in one pass over the same reads so the chosen + # 5'/3' stay concordant. polyA always; TSS only when use_tss_model + # (--fl_data). No confident peak / no model -> positional fallback, then + # leave as is. Never creates or drops a model. logger.debug("Verifying ends for transcript %s" % transcript_model.transcript_id) - transcript_end = transcript_model.exon_blocks[-1][1] transcript_start = transcript_model.exon_blocks[0][0] - start_supported = False - read_starts = set() - end_supported = False - read_ends = defaultdict(int) + transcript_end = transcript_model.exon_blocks[-1][1] + first_exon_right = transcript_model.exon_blocks[0][1] + last_exon_left = transcript_model.exon_blocks[-1][0] + strand = transcript_model.strand + + start_hist, end_hist = self._terminal_histograms(transcript_model, assigned_reads) + start_supported = any(abs(p - transcript_start) <= self.args.apa_delta for p in start_hist) + end_supported = any(abs(p - transcript_end) <= self.args.apa_delta for p in end_hist) - for assignment in assigned_reads: - read_exons = assignment.corrected_exons - if abs(read_exons[0][0] - transcript_start) <= self.args.apa_delta: - start_supported = True - if not start_supported and read_exons[0][0] < transcript_model.exon_blocks[0][1]: - read_starts.add(read_exons[0][0]) - if abs(read_exons[-1][1] - transcript_end) <= self.args.apa_delta: - end_supported = True - if not end_supported and read_exons[-1][1] > transcript_model.exon_blocks[-1][0]: - read_ends[read_exons[-1][1]] += 1 - - new_transcript_start = None if not start_supported: - read_starts = sorted(read_starts) - for read_start in read_starts: - if read_start > transcript_start: - new_transcript_start = read_start - break - if new_transcript_start and new_transcript_start < transcript_model.exon_blocks[0][1]: - logger.debug("Changed start for transcript %s: from %d to %d" % - (transcript_model.transcript_id, transcript_model.exon_blocks[0][0], new_transcript_start)) - transcript_model.exon_blocks[0] = (new_transcript_start, transcript_model.exon_blocks[0][1]) + model = self._terminal_model(strand, left=True) + if model is not None: + new_start = self._peak_boundary(start_hist, model, lambda pos: pos < first_exon_right) + else: + new_start = self._closest_inward(sorted(start_hist.keys()), transcript_start, greater=True) + if new_start is not None and new_start != transcript_start and new_start < first_exon_right: + logger.debug("Changed start for transcript %s: from %d to %d" % + (transcript_model.transcript_id, transcript_start, new_start)) + transcript_model.exon_blocks[0] = (new_start, first_exon_right) - new_transcript_end = None if not end_supported: - read_ends = sorted(read_ends, reverse=True) - for read_end in read_ends: - if read_end < transcript_end: - new_transcript_end = read_end - break - if new_transcript_end and new_transcript_end > transcript_model.exon_blocks[-1][0]: - logger.debug("Changed end for transcript %s: from %d to %d" % - (transcript_model.transcript_id, transcript_model.exon_blocks[-1][1], new_transcript_end)) - transcript_model.exon_blocks[-1] = (transcript_model.exon_blocks[-1][0], new_transcript_end) + model = self._terminal_model(strand, left=False) + if model is not None: + new_end = self._peak_boundary(end_hist, model, lambda pos: pos > last_exon_left) + else: + new_end = self._closest_inward(sorted(end_hist.keys(), reverse=True), transcript_end, greater=False) + if new_end is not None and new_end != transcript_end and new_end > last_exon_left: + logger.debug("Changed end for transcript %s: from %d to %d" % + (transcript_model.transcript_id, transcript_end, new_end)) + transcript_model.exon_blocks[-1] = (last_exon_left, new_end) + + def _terminal_model(self, strand: str, left: bool) -> Optional["XGBClassifier"]: + # Which trained model applies to a genomic-side boundary. For '+' the + # right end is the polyA site and the left is the TSS; for '-' reversed. + # Returns None (-> positional fallback) for unstranded, or a side whose + # model is disabled: polyA needs the flag, TSS additionally needs fl_data. + if strand == '+': + is_polya_side = not left + elif strand == '-': + is_polya_side = left + else: + return None + if is_polya_side: + # polyA end refinement is always on (polishing needs no annotation). + return get_polya_model() + return get_tss_model() if self.use_tss_model else None + + @staticmethod + def _peak_boundary(histogram: Dict[int, int], model: "XGBClassifier", + valid: Callable[[int], bool]) -> Optional[int]: + # Best-supported detected peak satisfying the terminal-exon clamp. + if not histogram: + return None + peaks = [p for p in detect_peaks(histogram, model) if valid(p.position)] + if not peaks: + return None + return max(peaks, key=lambda p: p.count).position + + @staticmethod + def _closest_inward(sorted_positions: List[int], boundary: int, greater: bool) -> Optional[int]: + # Positional fallback: nearest read terminus on the inward side of the + # current boundary (matches the original end-correction behaviour). + for pos in sorted_positions: + if greater and pos > boundary: + return pos + if not greater and pos < boundary: + return pos + return None + + @staticmethod + def _confirmed_polya_pos(assignment: ReadAssignment, strand: str) -> Optional[int]: + # Detected polyA cleavage site of a read, matching the trained-model + # domain (PolyACounter): only polyA-confirmed reads, at external_polya_pos + # ('+') / external_polyt_pos ('-'). None otherwise. + info = getattr(assignment, 'polya_info', None) + if not getattr(assignment, 'polyA_found', False) or info is None: + return None + if strand == '+' and info.external_polya_pos != -1: + return info.external_polya_pos + if strand == '-' and info.external_polyt_pos != -1: + return info.external_polyt_pos + return None + + def _terminal_histograms(self, source_model: TranscriptModel, + assigned_reads: List[ReadAssignment]) -> Tuple[Dict[int, int], Dict[int, int]]: + # Build genomic-left (start) and genomic-right (end) read-terminus + # histograms matching what each trained model was fit on: + # - 3' polyA side: polyA-CONFIRMED reads at the detected cleavage site + # (so a low overall polyA rate naturally raises the support bar); + # - 5' TSS side: all stranded reads' alignment ends. + # Unstranded -> alignment ends on both sides (no polyA orientation). + strand = source_model.strand + first_exon_right = source_model.exon_blocks[0][1] + last_exon_left = source_model.exon_blocks[-1][0] + start_hist = defaultdict(int) + end_hist = defaultdict(int) + for a in assigned_reads: + ex = a.corrected_exons + if strand == '+': + if ex[0][0] < first_exon_right: # 5' TSS (all reads) + start_hist[ex[0][0]] += 1 + pos = self._confirmed_polya_pos(a, '+') # 3' polyA (confirmed) + if pos is not None and pos > last_exon_left: + end_hist[pos] += 1 + elif strand == '-': + pos = self._confirmed_polya_pos(a, '-') # 3' polyA (confirmed) + if pos is not None and pos < first_exon_right: + start_hist[pos] += 1 + if ex[-1][1] > last_exon_left: # 5' TSS (all reads) + end_hist[ex[-1][1]] += 1 + else: + if ex[0][0] < first_exon_right: + start_hist[ex[0][0]] += 1 + if ex[-1][1] > last_exon_left: + end_hist[ex[-1][1]] += 1 + return start_hist, end_hist + + @staticmethod + def _intron_chain_key(model: TranscriptModel) -> Tuple[Tuple[int, int], ...]: + # Terminal-end-independent identity of a transcript: the tuple of its + # introns derived from exon blocks. Monoexon -> empty tuple. + eb = model.exon_blocks + return tuple((eb[i][1], eb[i + 1][0]) for i in range(len(eb) - 1)) + + def _add_known_alternative_end_models(self) -> None: + # Part 2: for each known (reference) model, peak-call its own assigned + # reads and emit a NIC for every confident alternative polyA/TSS end, + # keeping the known (union). Deduplicated against everything already in + # storage (notably the graph-level NICs) by intron chain + both ends + # within apa_delta, so the same alternative end is never emitted twice. + existing_pairs = defaultdict(list) + # Seed with every reference isoform too: an alternative-end peak that + # lands on another annotated isoform of the same chain is that known, not + # a novel end, so it must be suppressed even if that reference was not + # emitted as a model in this locus. + for ref_exons in self.gene_info.all_isoforms_exons.values(): + ck = tuple((ref_exons[i][1], ref_exons[i + 1][0]) for i in range(len(ref_exons) - 1)) + existing_pairs[ck].append((ref_exons[0][0], ref_exons[-1][1])) + for m in self.transcript_model_storage: + existing_pairs[self._intron_chain_key(m)].append( + (m.exon_blocks[0][0], m.exon_blocks[-1][1])) + + new_models = [] + for model in self.transcript_model_storage: + if model.transcript_type == TranscriptModelType.known: + pass + elif not self.novel_apa: + # Part 2 default: known transcripts only. Part 3 (--novel_apa) + # also spins off alternative-end siblings for novel chains. + continue + for nic in self.derive_alternative_end_models( + model, self.transcript_read_ids[model.transcript_id]): + ck = self._intron_chain_key(nic) + ns, ne = nic.exon_blocks[0][0], nic.exon_blocks[-1][1] + if any(abs(ns - s) <= self.args.apa_delta and abs(ne - e) <= self.args.apa_delta + for (s, e) in existing_pairs[ck]): + continue + existing_pairs[ck].append((ns, ne)) + new_models.append(nic) + if new_models: + logger.debug("Added %d known alternative-end NICs" % len(new_models)) + self.transcript_model_storage.extend(new_models) + + def _drop_duplicate_alt_end_models(self) -> None: + # Final dedup over the whole storage (catches graph-level NICs too): drop + # any non-known model that is structurally identical (intron chain + both + # ends within apa_delta) to a reference isoform -> it is that annotated + # transcript, not a novel end; and collapse non-known models that + # duplicate an already-kept one. + ref_pairs = defaultdict(list) + for ref_exons in self.gene_info.all_isoforms_exons.values(): + ck = tuple((ref_exons[i][1], ref_exons[i + 1][0]) for i in range(len(ref_exons) - 1)) + ref_pairs[ck].append((ref_exons[0][0], ref_exons[-1][1])) + + d = self.args.apa_delta + kept = [] + kept_pairs = defaultdict(list) + dropped = 0 + for m in self.transcript_model_storage: + if m.transcript_type == TranscriptModelType.known: + kept.append(m) + continue + ck = self._intron_chain_key(m) + s, e = m.exon_blocks[0][0], m.exon_blocks[-1][1] + if any(abs(s - rs) <= d and abs(e - re) <= d for rs, re in ref_pairs[ck]) or \ + any(abs(s - ks) <= d and abs(e - ke) <= d for ks, ke in kept_pairs[ck]): + dropped += 1 + # free the model's reads for reassignment (graph-level NICs are + # already counted; post-pass NICs are not yet in the counters) + if m.transcript_id in self.internal_counter: + self.delete_from_storage(m.transcript_id) + continue + kept.append(m) + kept_pairs[ck].append((s, e)) + if dropped: + logger.debug("Dropped %d duplicate / reference-matching alt-end models" % dropped) + self.transcript_model_storage = kept + + def derive_alternative_end_models(self, source_model: TranscriptModel, + assigned_reads: List[ReadAssignment]) -> List[TranscriptModel]: + # Confident alternative polyA/TSS ends for a transcript, from its own + # reads (per-transcript, so 5'/3' stay concordant). Returns a list of new + # alternative-end TranscriptModel objects (one per alternative end, each + # changing a single terminal); the source model is left untouched. polyA + # always; TSS only when use_tss_model (--fl_data). A known source yields + # novel-in-catalog siblings; a novel source keeps its own type. + if not assigned_reads: + return [] + exon_blocks = source_model.exon_blocks + strand = source_model.strand + annotated_start = exon_blocks[0][0] + annotated_end = exon_blocks[-1][1] + first_exon_right = exon_blocks[0][1] + last_exon_left = exon_blocks[-1][0] + sibling_type = (TranscriptModelType.novel_in_catalog + if source_model.transcript_type == TranscriptModelType.known + else source_model.transcript_type) + + start_hist, end_hist = self._terminal_histograms(source_model, assigned_reads) + + new_models = [] + start_model = self._terminal_model(strand, left=True) + if start_model is not None: + for pos in self._alternative_end_positions(start_hist, start_model, annotated_start, + lambda p: p < first_exon_right): + new_models.append(self._nic_model_with_boundary(source_model, sibling_type, start=pos)) + end_model = self._terminal_model(strand, left=False) + if end_model is not None: + for pos in self._alternative_end_positions(end_hist, end_model, annotated_end, + lambda p: p > last_exon_left): + new_models.append(self._nic_model_with_boundary(source_model, sibling_type, end=pos)) + return new_models + + def _alternative_end_positions(self, histogram: Dict[int, int], model: "XGBClassifier", + annotated_pos: int, clamp: Callable[[int], bool]) -> List[int]: + # Detected peaks representing a genuine alternative end: inside the + # terminal exon, > apa_delta from the annotated end, and clearing BOTH an + # absolute (min_novel_count) and a RELATIVE support cutoff. The relative + # cutoff (terminal_position_rel x the transcript's dominant terminal peak, + # reusing the intron-graph terminal threshold) rejects minor secondary + # peaks that are not real co-expressed isoforms and self-normalizes per + # transcript, so it does not depend on absolute depth or fit one isoform. + if not histogram: + return [] + peaks = [p for p in detect_peaks(histogram, model) if clamp(p.position)] + if not peaks: + return [] + dominant = max(p.count for p in peaks) + cutoff = max(self.args.min_novel_count, self.args.terminal_position_rel * dominant) + return [p.position for p in peaks + if abs(p.position - annotated_pos) > self.args.apa_delta and p.count >= cutoff] + + def _nic_model_with_boundary(self, source_model: TranscriptModel, transcript_type: TranscriptModelType, + start: Optional[int] = None, end: Optional[int] = None) -> TranscriptModel: + # Build an alternative-end model from a source one by replacing a single + # terminal coordinate; the intron chain (internal exons) is unchanged. + # transcript_type is novel_in_catalog for a known source, or the source's + # own type for a novel source. + exon_blocks = [tuple(e) for e in source_model.exon_blocks] + if start is not None: + exon_blocks[0] = (start, exon_blocks[0][1]) + if end is not None: + exon_blocks[-1] = (exon_blocks[-1][0], end) + suffix = (TranscriptNaming.nic_transcript_suffix + if transcript_type == TranscriptModelType.novel_in_catalog + else TranscriptNaming.nnic_transcript_suffix) + new_transcript_id = TranscriptNaming.transcript_prefix + str(self.get_transcript_id()) + new_model = TranscriptModel( + self.gene_info.chr_id, source_model.strand, + new_transcript_id + ".%s" % self.gene_info.chr_id + suffix, + source_model.gene_id, exon_blocks, transcript_type) + new_model.intron_path = source_model.intron_path + logger.debug("Adding alternative-end model %s from %s (start=%s, end=%s)" % + (new_model.transcript_id, source_model.transcript_id, str(start), str(end))) + return new_model class IntronPathStorage: diff --git a/isoquant_lib/intron_graph.py b/isoquant_lib/intron_graph.py index aea142cd..ccf7d969 100644 --- a/isoquant_lib/intron_graph.py +++ b/isoquant_lib/intron_graph.py @@ -8,6 +8,7 @@ import logging import queue from collections import defaultdict +from typing import Dict, List, Optional from .common import find_closest, overlaps @@ -135,11 +136,21 @@ def simplify_correction_map(self): class IntronGraph: - def __init__(self, params, gene_info, read_assignments): + def __init__(self, params, gene_info, read_assignments, + polya_predictions: Optional[List[int]] = None, + tss_predictions: Optional[List[int]] = None): self.params = params self.gene_info = gene_info self.read_assignments = read_assignments + # Predicted polyA / TSS sites for this gene (genomic positions), reused + # from the per-gene terminal-position counters. Clustering stays the + # terminal-vertex SOURCE; these only refine vertex coordinates (snap to + # the nearest predicted site). None when unavailable (no annotation for + # polyA, no --fl_data for TSS) -> no refinement, pure clustering. + self.polya_predictions = polya_predictions + self.tss_predictions = tss_predictions + self.incoming_edges = defaultdict(set) self.outgoing_edges = defaultdict(set) self.intron_collector = IntronCollector(gene_info, params.delta) @@ -411,63 +422,84 @@ def collapse_vertex_set(self, vertex_set): def attach_terminal_positions(self): # logger.debug("Setting terminal positions paths for %s" % self.gene_info.gene_db_list[0].id) polya_ends, read_ends, polyt_starts, read_starts = self.collect_terminal_positions() - - for intron in sorted(self.intron_collector.clustered_introns): - self.attach_transcpt_ends(intron, polya_ends, read_ends, read_end=True) - self.attach_transcpt_ends(intron, polyt_starts, read_starts, read_end=False) - - def attach_transcpt_ends(self, intron, polya_confirmed_positions, read_terminal_positions, read_end=True): - read_ends_cutoff = self.params.terminal_position_abs - logger.debug(str(intron) + " => " + str(polya_confirmed_positions[intron])) - clustered_polyas = self.cluster_polya_positions(polya_confirmed_positions[intron], intron, read_end) - if clustered_polyas: - read_ends_cutoff = max(read_ends_cutoff, max(clustered_polyas.values()) * self.params.terminal_position_rel) - extra_end_positions = {} - furtherst_confirmed_position = max(clustered_polyas.keys()) if read_end else min(clustered_polyas.keys()) - for position, count in read_terminal_positions[intron].items(): - if read_end and position >= furtherst_confirmed_position + self.params.apa_delta: - extra_end_positions[position] = count - elif not read_end and position <= furtherst_confirmed_position - self.params.apa_delta : - extra_end_positions[position] = count - else: - extra_end_positions = read_terminal_positions[intron] - - if read_end and intron in self.outgoing_edges and len(self.outgoing_edges[intron]) > 0: - # intron has outgoing edges, hard cut off - neighboring_cov = max(self.intron_collector.clustered_introns[i] for i in self.outgoing_edges[intron]) - read_ends_cutoff = max(read_ends_cutoff, neighboring_cov * self.params.terminal_internal_position_rel) - elif not read_end and intron in self.incoming_edges and len(self.incoming_edges[intron]) > 0: - # intron has incoming edges, hard cut off - neighboring_cov = max(self.intron_collector.clustered_introns[i] for i in self.incoming_edges[intron]) - read_ends_cutoff = max(read_ends_cutoff, neighboring_cov * self.params.terminal_internal_position_rel) - - logger.debug(str(intron) + " +> " + str(extra_end_positions)) - - terminal_positions = self.cluster_terminal_positions(extra_end_positions, - read_end=read_end, - cutoff=read_ends_cutoff) - logger.debug("POLYAs clustered:") - logger.debug(clustered_polyas) - logger.debug("Teminal clustered:") - logger.debug(terminal_positions) - if read_end: - # if intron in self.terminal_known_positions: - # logger.debug("Annotated terminal positions: " + str(sorted(self.terminal_known_positions[intron]))) - # logger.debug("PolyA terminal positions: " + str(sorted(clustered_polyas.keys()))) - # logger.debug("Simple terminal positions: " + str(sorted(terminal_positions.keys()))) - for pos in clustered_polyas.keys(): - self.outgoing_edges[intron].add((VERTEX_polya, pos)) - for pos in terminal_positions.keys(): - self.outgoing_edges[intron].add((VERTEX_read_end, pos)) - else: - # if intron in self.starting_known_positions: - # logger.debug("Annotated terminal positions: " + str(sorted(self.starting_known_positions[intron]))) - # logger.debug("PolyA terminal positions: " + str(sorted(clustered_polyas.keys()))) - # logger.debug("Simple terminal positions: " + str(sorted(terminal_positions.keys()))) - for pos in clustered_polyas.keys(): - self.incoming_edges[intron].add((VERTEX_polyt, pos)) - for pos in terminal_positions.keys(): - self.incoming_edges[intron].add((VERTEX_read_start, pos)) + introns = sorted(self.intron_collector.clustered_introns) + self._attach_side(introns, polya_ends, read_ends, read_end=True) + self._attach_side(introns, polyt_starts, read_starts, read_end=False) + + def _attach_side(self, introns: List, polya_confirmed_positions: dict, + read_terminal_positions: dict, read_end: bool) -> None: + # Clustering stays the vertex SOURCE (recall identical to the ad-hoc + # method); the predicted polyA / TSS sites (reused from the per-gene + # terminal-position counters) only REFINE vertex coordinates, so we + # never emit fewer terminal vertices than clustering alone. + + # Step 1: polyA / polyT confirmed terminal positions per intron. + polya_positions = {} + for intron in introns: + clustered = self.cluster_polya_positions(polya_confirmed_positions[intron], intron, read_end) + polya_positions[intron] = self._refine_positions(clustered, self.polya_predictions) + + # Step 2: per-intron read-end cutoff and the extra (non-polyA) positions. + cutoffs = {} + extra_positions = {} + for intron in introns: + clustered_polyas = polya_positions[intron] + cutoff = self.params.terminal_position_abs + if clustered_polyas: + cutoff = max(cutoff, max(clustered_polyas.values()) * self.params.terminal_position_rel) + extra = {} + furthest = max(clustered_polyas.keys()) if read_end else min(clustered_polyas.keys()) + for position, count in read_terminal_positions[intron].items(): + if read_end and position >= furthest + self.params.apa_delta: + extra[position] = count + elif not read_end and position <= furthest - self.params.apa_delta: + extra[position] = count + else: + extra = read_terminal_positions[intron] + + neighbors = self.outgoing_edges if read_end else self.incoming_edges + if intron in neighbors and len(neighbors[intron]) > 0: + # intron has neighboring edges -> hard cutoff relative to their coverage + neighboring_cov = max(self.intron_collector.clustered_introns[i] for i in neighbors[intron]) + cutoff = max(cutoff, neighboring_cov * self.params.terminal_internal_position_rel) + + cutoffs[intron] = cutoff + extra_positions[intron] = extra + + # Step 3: read-end / TSS terminal positions per intron. Refine toward + # the prediction set for THIS genomic side: polyA (3') for read ends + # (read_end=True), TSS (5') for read starts (read_end=False). Using TSS + # for both sides would snap 3' read ends onto 5' sites. + terminal_predictions = self.polya_predictions if read_end else self.tss_predictions + terminal_positions = {} + for intron in introns: + clustered = self.cluster_terminal_positions(extra_positions[intron], + read_end=read_end, cutoff=cutoffs[intron]) + terminal_positions[intron] = self._refine_positions(clustered, terminal_predictions) + + # Step 4: attach terminal vertices. + polya_vertex = VERTEX_polya if read_end else VERTEX_polyt + read_vertex = VERTEX_read_end if read_end else VERTEX_read_start + edges = self.outgoing_edges if read_end else self.incoming_edges + for intron in introns: + for pos in polya_positions[intron].keys(): + edges[intron].add((polya_vertex, pos)) + for pos in terminal_positions[intron].keys(): + edges[intron].add((read_vertex, pos)) + + def _refine_positions(self, clustered: Dict[int, int], + predicted_positions: Optional[List[int]]) -> Dict[int, int]: + # Snap each clustered terminal position to the nearest predicted polyA / + # TSS site within apa_delta. Keeps every clustering vertex (recall-safe), + # only nudging coordinates toward the predicted site. + if not clustered or not predicted_positions: + return clustered + refined = {} + for pos, count in clustered.items(): + nearest, diff = find_closest(pos, predicted_positions) + new_pos = nearest if (nearest is not None and diff <= self.params.apa_delta) else pos + refined[new_pos] = refined.get(new_pos, 0) + count + return refined def collect_terminal_positions(self): polya_ends = defaultdict(lambda: defaultdict(int)) diff --git a/isoquant_lib/long_read_counter.py b/isoquant_lib/long_read_counter.py index 3c61386d..a9b93156 100644 --- a/isoquant_lib/long_read_counter.py +++ b/isoquant_lib/long_read_counter.py @@ -186,6 +186,11 @@ def add_confirmed_features(self, features): def dump(self): raise NotImplementedError() + def flush(self) -> None: + # Per-gene incremental emission hook. No-op for most counters; the + # terminal (polyA/TSS) counters override it to predict per gene. + pass + def finalize(self, args=None): raise NotImplementedError() @@ -225,6 +230,10 @@ def dump(self): for p in self.counters: p.dump() + def flush(self) -> None: + for p in self.counters: + p.flush() + def finalize(self, args=None): for p in self.counters: p.finalize(args) diff --git a/isoquant_lib/parallel_workers.py b/isoquant_lib/parallel_workers.py index b59a9762..057c64dd 100644 --- a/isoquant_lib/parallel_workers.py +++ b/isoquant_lib/parallel_workers.py @@ -275,15 +275,27 @@ def construct_models_in_parallel(sample, chr_id, chr_ids, saves_prefix, args, re aggregator.global_printer.add_read_info(read_assignment) aggregator.global_counter.add_read_info(read_assignment) + # Per-gene flush of terminal (polyA/TSS) predictions: predict on this + # gene's read accumulation, then clear it. Other counters no-op here. + aggregator.global_counter.flush() + if construct_models: strategy_names = aggregator.grouping_strategy_names if hasattr(aggregator, 'grouping_strategy_names') else [] + # Reuse this gene's polyA/TSS predictions (computed just above by + # flush()) to refine the intron-graph terminal vertices. + polya_counter = getattr(aggregator, 'polya_counter', None) + tss_counter = getattr(aggregator, 'tss_counter', None) + gene_polya_predictions = polya_counter.last_gene_predictions if polya_counter else None + gene_tss_predictions = tss_counter.last_gene_predictions if tss_counter else None model_constructor = GraphBasedModelConstructor(gene_info, loader.chr_record, args, aggregator.transcript_model_global_counter, aggregator.gene_model_global_counter, transcript_id_distributor, grouping_strategy_names=strategy_names, use_technical_replicas=sample.use_technical_replicas, - string_pools=string_pools) + string_pools=string_pools, + polya_predictions=gene_polya_predictions, + tss_predictions=gene_tss_predictions) model_constructor.process(assignment_storage) if args.check_canonical: io_support.add_canonical_info(model_constructor.transcript_model_storage, gene_info) diff --git a/isoquant_lib/terminal_counter.py b/isoquant_lib/terminal_counter.py index 0879ed66..efad998e 100644 --- a/isoquant_lib/terminal_counter.py +++ b/isoquant_lib/terminal_counter.py @@ -33,21 +33,18 @@ ) from .long_read_counter import AbstractCounter from .read_groups import AbstractReadGrouper +from .terminal_peaks import ( + ANNOTATION_TOLERANCE, + FEATURE_COLUMNS, + HISTOGRAM_PAD, + PEAK_DISTANCE, + PEAK_REL_HEIGHT, + POLYA_MODEL_PATH, + TSS_MODEL_PATH, +) logger = logging.getLogger('IsoQuant') -_DATA_DIR = Path(__file__).parent / "data" -POLYA_MODEL_PATH = _DATA_DIR / "model_polya.json" -TSS_MODEL_PATH = _DATA_DIR / "model_tss.json" - -# Peak finding + classification parameters. Match values used to train the -# shipped XGBoost models -- changing these without retraining will break the -# feature alignment between fit time and inference time. -PEAK_DISTANCE = 10 -PEAK_REL_HEIGHT = 0.98 -HISTOGRAM_PAD = 10 -ANNOTATION_TOLERANCE = 10 - _ACCEPTED_ASSIGNMENT_TYPES = frozenset({ ReadAssignmentType.unique, ReadAssignmentType.unique_minor_difference, @@ -55,10 +52,6 @@ ReadAssignmentType.inconsistent_non_intronic, }) -# XGBoost feature columns (must match training) -FEATURE_COLUMNS = ['var', 'skew', 'peak_count', 'peak_width', 'entropy', - 'mean_height', 'peak_heights', 'relative_height'] - EMPTY_COLUMNS = ['chromosome', 'transcript_id', 'gene_id', 'prediction', 'counts', 'flag'] EMPTY_COLUMNS_GROUPED = EMPTY_COLUMNS + ['counts_byGroup', 'group_id'] @@ -123,6 +116,16 @@ def __init__(self, args, output_prefix: str, model_path: Path, # transcript_id -> {'chr', 'gene_id', 'data', 'annotated', # int_group_id -> list[int]} self.transcripts: dict = {} + # Chromosome-wide accumulator for the ungrouped counter: flush() merges + # each gene's per-gene buffer (self.transcripts) into this, and dump() + # predicts once over the full per-transcript histogram. Keeps the + # emitted output independent of loader granularity (a transcript whose + # reads span several gene blocks must not be split into partial rows). + self._all_transcripts: dict = {} + # Genomic positions predicted for the most recently flushed gene, so + # transcript discovery can reuse them to refine intron-graph terminal + # vertices (set by flush()). + self.last_gene_predictions: list = [] @property def model(self) -> XGBClassifier: @@ -212,25 +215,71 @@ def _read_group_id(self, read_assignment: ReadAssignment) -> int: # -- emission ------------------------------------------------------------- + def flush(self) -> None: + """Predict the current gene's transcripts for reuse, then fold them into + the chromosome-wide buffer. + + Called once per gene from the worker loop. The per-gene prediction is + exposed as :attr:`last_gene_predictions` so transcript discovery can + refine intron-graph terminal vertices with this gene's polyA/TSS sites; + the accumulated reads are then merged into :attr:`_all_transcripts` so + the emitted output is still computed once over each transcript's full + histogram in :meth:`dump`. Grouped and training counters keep + accumulating across the chromosome and emit only in :meth:`dump`.""" + if not self.ignore_read_groups or self._collecting_training: + return + rows = self._predict_rows() + self.last_gene_predictions = rows['prediction'].tolist() if rows is not None else [] + self._merge_into_all() + self.transcripts = {} + + def _merge_into_all(self) -> None: + """Fold the current per-gene buffer into the chromosome-wide one, + concatenating the read-position lists of transcripts seen before.""" + for transcript_id, entry in self.transcripts.items(): + existing = self._all_transcripts.get(transcript_id) + if existing is None: + self._all_transcripts[transcript_id] = entry + else: + existing['data'].extend(entry['data']) + def dump(self) -> None: - if not self.transcripts: - self._write_empty() + if self._collecting_training: + self._dump_training() return - df = self._build_peak_dataframe() - zero_peaks, peaks = self._split_zero_and_real_peaks(df) - if not peaks.empty: - peaks = self._rank_and_explode(peaks) + if not self.ignore_read_groups: + # Grouped output needs the full per-group positions, so it is + # computed once here with self.transcripts kept intact. + result = self._predict_rows() + if result is None: + self._write_empty() + else: + self._write_grouped(result) + return - # Developer training-data collection: dump features + true_peak label - # for every detected peak, skip the XGBoost filter, emit an empty - # prediction TSV so the rest of the pipeline still sees the file. - if self._collecting_training: - self._dump_training_features(zero_peaks, peaks) + # Ungrouped: flush the last gene into the chromosome-wide buffer, then + # predict once over each transcript's full histogram (matching the + # grouped path and master; no per-gene splitting). + self.flush() + self.transcripts = self._all_transcripts + result = self._predict_rows() + if result is None: self._write_empty() - return + else: + self._write_ungrouped(result) + def _predict_rows(self) -> Optional[pd.DataFrame]: + """Peak detection + XGBoost filter over the currently accumulated + transcripts. Returns the per-peak prediction rows (genomic position, + counts, Known/Novel flag) as a DataFrame, or None if there are none.""" + if not self.transcripts: + return None + + df = self._build_peak_dataframe() + zero_peaks, peaks = self._split_zero_and_real_peaks(df) if not peaks.empty: + peaks = self._rank_and_explode(peaks) peaks[FEATURE_COLUMNS] = peaks[FEATURE_COLUMNS].astype(float, errors='ignore') predicted = self.model.predict(peaks[FEATURE_COLUMNS].astype(float)) peaks = peaks[predicted == 1].copy() @@ -238,8 +287,7 @@ def dump(self) -> None: frames = [f for f in (zero_peaks, peaks) if not f.empty] if not frames: - self._write_empty() - return + return None result = (pd.concat(frames, axis=0, ignore_index=True, sort=False) if len(frames) > 1 else frames[0].reset_index(drop=True)) @@ -250,11 +298,20 @@ def dump(self) -> None: (result['prediction'] - result['annotated']).abs() .gt(ANNOTATION_TOLERANCE).map({True: 'Novel', False: 'Known'})) result['counts'] = result.apply(self._counts_for_peak, axis=1) + return result - if self.ignore_read_groups: - self._write_ungrouped(result) - else: - self._write_grouped(result) + def _dump_training(self) -> None: + """Training-data collection: emit per-peak features + true_peak label + (whole chromosome) and a header-only prediction TSV.""" + if not self.transcripts: + self._write_empty() + return + df = self._build_peak_dataframe() + zero_peaks, peaks = self._split_zero_and_real_peaks(df) + if not peaks.empty: + peaks = self._rank_and_explode(peaks) + self._dump_training_features(zero_peaks, peaks) + self._write_empty() def _build_peak_dataframe(self) -> pd.DataFrame: """One row per transcript with histogram, summary stats, and raw peaks.""" diff --git a/isoquant_lib/terminal_peaks.py b/isoquant_lib/terminal_peaks.py new file mode 100644 index 00000000..4125b523 --- /dev/null +++ b/isoquant_lib/terminal_peaks.py @@ -0,0 +1,221 @@ +############################################################################ +# Copyright (c) 2022-2026 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +"""Shared polyA / TSS peak detection. + +Single source of truth for the peak-detection tunables and a reusable +detector that turns a per-feature histogram of terminal read positions into a +set of accepted peaks. Both the side-output counters (`terminal_counter.py`) +and transcript discovery (`intron_graph.py`, +`graph_based_model_construction.py`) consume this module so the feature +computation cannot drift between training-time and the two inference sites. + +The detection mirrors `TerminalCounter`'s pipeline for a single histogram: +build a zero-padded base-1 histogram, compute summary features, run +`scipy.signal.find_peaks`, fall back to the distribution mode when no peak is +found, rank peaks by height, and keep the ones the trained XGBoost classifier +accepts. Coordinates returned are genomic. +""" + +import logging +from pathlib import Path +from typing import Dict, List, NamedTuple, Optional + +import numpy as np +import pandas as pd +from scipy import stats +from scipy.signal import find_peaks, peak_widths +from xgboost import XGBClassifier + +logger = logging.getLogger('IsoQuant') + +_DATA_DIR = Path(__file__).parent / "data" +POLYA_MODEL_PATH = _DATA_DIR / "model_polya.json" +TSS_MODEL_PATH = _DATA_DIR / "model_tss.json" + +# Peak finding + classification parameters. Match values used to train the +# shipped XGBoost models -- changing these without retraining will break the +# feature alignment between fit time and inference time. +PEAK_DISTANCE = 10 +PEAK_REL_HEIGHT = 0.98 +HISTOGRAM_PAD = 10 +ANNOTATION_TOLERANCE = 10 + +# XGBoost feature columns (must match training). +FEATURE_COLUMNS = ['var', 'skew', 'peak_count', 'peak_width', 'entropy', + 'mean_height', 'peak_heights', 'relative_height'] + +# Process-level model cache. The models are loaded lazily on first use so the +# XGBoost / OpenMP runtime is only initialized inside the fork() workers that +# actually call detect_peaks -- never in the parent, where the inherited +# OpenMP state would deadlock the pool. One instance per worker process. +_MODEL_CACHE: Dict[str, XGBClassifier] = {} + + +def _get_model(model_path: Path) -> XGBClassifier: + key = str(model_path) + model = _MODEL_CACHE.get(key) + if model is None: + model = XGBClassifier() + model.load_model(key) + _MODEL_CACHE[key] = model + return model + + +def get_polya_model() -> XGBClassifier: + return _get_model(POLYA_MODEL_PATH) + + +def get_tss_model() -> XGBClassifier: + return _get_model(TSS_MODEL_PATH) + + +class Peak(NamedTuple): + """An accepted terminal-position peak in genomic coordinates.""" + position: int # predicted site + count: int # reads supporting the peak window [left, right] + left: int # genomic left bound of the peak window + right: int # genomic right bound of the peak window + + +def detect_peaks(position_counts: Dict[int, int], + model: XGBClassifier) -> List[Peak]: + """Detect accepted peaks in a single ``{genomic_position: count}`` histogram.""" + return detect_peaks_batch([position_counts], model)[0] + + +def detect_peaks_batch(histograms: List[Optional[Dict[int, int]]], + model: XGBClassifier) -> List[List[Peak]]: + """Detect peaks for several histograms with a single batched ``predict``. + + Returns a list parallel to ``histograms``; each element is the list of + accepted :class:`Peak` for that histogram (empty when the histogram is + empty or every peak is rejected). + """ + results: List[List[Peak]] = [[] for _ in histograms] + records: Dict[int, dict] = {} + feature_rows: List[list] = [] + row_owner: List[tuple] = [] # (histogram_index, peak_dict) + + for i, counts in enumerate(histograms): + if not counts: + continue + record = _histogram_record(counts) + records[i] = record + + if record['peak_count'] == 0: + # No detected peak: fall back to the distribution mode (always + # accepted, never scored by the classifier). + mode = record['mode'] + position = record['start'] + mode + count = _window_count(record['histogram'], mode, mode) + results[i].append(Peak(position, count, position, position)) + continue + + for peak in _rank_peaks(record): + feature_rows.append([record['var'], record['skew'], + record['peak_count'], peak['peak_width'], + record['entropy'], record['mean_height'], + peak['peak_heights'], peak['relative_height']]) + row_owner.append((i, peak)) + + if feature_rows: + features = pd.DataFrame(feature_rows, columns=FEATURE_COLUMNS).astype(float) + predicted = model.predict(features) + for (i, peak), keep in zip(row_owner, predicted): + if int(keep) != 1: + continue + record = records[i] + start = record['start'] + position = start + int(peak['peak_location']) + count = _window_count(record['histogram'], + peak['peak_left'], peak['peak_right']) + results[i].append(Peak(position, count, + start + int(peak['peak_left']), + start + int(peak['peak_right']))) + return results + + +def _histogram_record(position_counts: Dict[int, int]) -> dict: + """Build the per-histogram features and raw peaks (mirrors + ``TerminalCounter._build_peak_dataframe`` for one transcript).""" + data: List[int] = [] + for position, count in position_counts.items(): + if count > 0: + data.extend([int(position)] * int(count)) + + data_min = min(data) + data_max = max(data) + hist_counts = np.histogram( + data, bins=(data_max - data_min) + 1, range=(data_min, data_max))[0] + padded = [0] * HISTOGRAM_PAD + list(hist_counts) + [0] * HISTOGRAM_PAD + + var = float(np.var(data)) + record = { + 'start': data_min, + 'mode': int(stats.mode(data, keepdims=False).mode) - data_min, + 'var': var, + 'skew': (float(stats.skew(data)) + if len(data) > 3 and var >= 1e-12 else float('nan')), + 'entropy': float(stats.entropy(padded)), + 'mean_height': float(np.mean(padded)), + 'histogram': padded, + } + + # find_peaks returns padded-array indices; subtract HISTOGRAM_PAD to get + # start-relative coordinates. + peaks_idx = find_peaks(padded, distance=PEAK_DISTANCE)[0] + record['peak_count'] = len(peaks_idx) + if record['peak_count'] == 0: + return record + + widths = peak_widths(padded, peaks_idx, rel_height=PEAK_REL_HEIGHT) + peak_location = [int(j - HISTOGRAM_PAD) for j in peaks_idx] + record['raw_peaks'] = { + 'peak_location': peak_location, + 'peak_width': list(widths[0]), + 'peak_left': [int(j - HISTOGRAM_PAD) for j in widths[2]], + 'peak_right': [int(j - HISTOGRAM_PAD) for j in widths[3]], + 'peak_heights': [padded[int(p) + HISTOGRAM_PAD] for p in peak_location], + } + return record + + +def _rank_peaks(record: dict) -> List[dict]: + """Order a transcript's peaks by descending height and attach + ``relative_height`` (mirrors ``TerminalCounter._rank_peaks``).""" + raw = record['raw_peaks'] + heights = np.asarray(raw['peak_heights'], dtype=float) + if heights.size > 1: + order = np.argsort(-heights) + top = float(heights[order][0]) + else: + order = [0] + top = None + + ranked = [] + for idx in order: + height = raw['peak_heights'][idx] + ranked.append({ + 'peak_location': raw['peak_location'][idx], + 'peak_width': raw['peak_width'][idx], + 'peak_left': raw['peak_left'][idx], + 'peak_right': raw['peak_right'][idx], + 'peak_heights': height, + 'relative_height': (1.0 if top is None + else (height / top if top else 0.0)), + }) + return ranked + + +def _window_count(histogram: List[int], left: int, right: int) -> int: + """Reads inside the padded-histogram window ``[left, right]`` (mirrors + ``TerminalCounter._counts_for_peak``); left/right are start-relative.""" + low = max(0, int(left) + HISTOGRAM_PAD) + high = min(len(histogram), int(right) + HISTOGRAM_PAD + 1) + if high <= low: + return 0 + return int(np.asarray(histogram[low:high]).sum()) diff --git a/isoquant_tests/github/run_pipeline.py b/isoquant_tests/github/run_pipeline.py index 7f02c3af..86cfd76d 100755 --- a/isoquant_tests/github/run_pipeline.py +++ b/isoquant_tests/github/run_pipeline.py @@ -429,33 +429,43 @@ def run_transcript_quality(args, config_dict, baselines=None): log.error("Transcript evaluation exited with non-zero status: %d" % result.returncode) return -21 - # Get etalon values from YAML baselines or external file - if baselines and "transcripts" in baselines: - etalon_qaulity_dict = {k: str(v) for k, v in baselines["transcripts"].items()} - elif "etalon" in config_dict: - etalon_qaulity_dict = load_tsv_config(fix_path(config_file, config_dict["etalon"])) - else: + # Transcript-level accuracy is scored at three terminal-end tolerances: + # the default end-agnostic match plus --terminal-delta 50/10 (end-sensitive, + # produced by reduced_db_gffcompare.py via the gffcompare fork). Each maps + # to its own baseline block. See .claude/GFFCOMPARE.md. A variant is checked + # only if its baseline block exists and its stats file was produced (so a + # stock gffcompare without --terminal-delta degrades to the default only). + delta_variants = [("", "transcripts"), + (".td50", "transcripts_td50"), + (".td10", "transcripts_td10")] + if not (baselines and any(k in baselines for _, k in delta_variants)) \ + and "etalon" not in config_dict: return 0 log.info('== Checking quality metrics ==') exit_code = 0 new_etalon_outf = open(os.path.join(quality_output, "new_gtf_etalon.tsv"), "w") - for gtf_type in ['full', 'known', 'novel']: - recall, precision = parse_gffcomapre(os.path.join(quality_output, "isoquant." + gtf_type + ".stats")) - metric_name = gtf_type + "_recall" - if metric_name in etalon_qaulity_dict: - new_etalon_outf.write("%s\t%.2f\n" % (metric_name, recall)) - etalon_recall = float(etalon_qaulity_dict[metric_name]) - err_code = check_value(etalon_recall, recall , metric_name) - if err_code != 0: - exit_code = err_code - metric_name = gtf_type + "_precision" - if metric_name in etalon_qaulity_dict: - new_etalon_outf.write("%s\t%.2f\n" % (metric_name, precision)) - etalon_precision = float(etalon_qaulity_dict[metric_name]) - err_code = check_value(etalon_precision, precision, metric_name) - if err_code != 0: - exit_code = err_code + for stats_suffix, baseline_key in delta_variants: + if baselines and baseline_key in baselines: + etalon_qaulity_dict = {k: str(v) for k, v in baselines[baseline_key].items()} + elif baseline_key == "transcripts" and "etalon" in config_dict: + etalon_qaulity_dict = load_tsv_config(fix_path(config_file, config_dict["etalon"])) + else: + continue + for gtf_type in ['full', 'known', 'novel']: + stats_path = os.path.join(quality_output, "isoquant.%s%s.stats" % (gtf_type, stats_suffix)) + if not os.path.exists(stats_path): + continue + recall, precision = parse_gffcomapre(stats_path) + for metric_value, suffix in ((recall, "_recall"), (precision, "_precision")): + metric_name = gtf_type + suffix + if metric_name not in etalon_qaulity_dict: + continue + new_etalon_outf.write("%s.%s\t%.2f\n" % (baseline_key, metric_name, metric_value)) + err_code = check_value(float(etalon_qaulity_dict[metric_name]), metric_value, + baseline_key + "." + metric_name) + if err_code != 0: + exit_code = err_code new_etalon_outf.close() return exit_code diff --git a/isoquant_tests/test_common_utilities.py b/isoquant_tests/test_common_utilities.py index 34318f68..3657844a 100644 --- a/isoquant_tests/test_common_utilities.py +++ b/isoquant_tests/test_common_utilities.py @@ -31,3 +31,13 @@ def test_get_path_to_program(self): # Test existing program (assuming 'python' exists) path = get_path_to_program("python") self.assertTrue(os.path.exists(path)) + + def test_interval_bin_search(self): + intervals = [(10, 20), (30, 40), (50, 60)] + self.assertEqual(interval_bin_search(intervals, 35), 1) + self.assertEqual(interval_bin_search(intervals, 5), -1) + self.assertEqual(interval_bin_search(intervals, 100), -1) + # Empty intervals must not raise (regression: novel transcript in a + # locus with no annotated exons + a real polyA position). + self.assertEqual(interval_bin_search([], 100), -1) + self.assertEqual(interval_bin_search_rev([], 100), -1) diff --git a/isoquant_tests/test_graph_alt_ends.py b/isoquant_tests/test_graph_alt_ends.py new file mode 100644 index 00000000..5afa8f62 --- /dev/null +++ b/isoquant_tests/test_graph_alt_ends.py @@ -0,0 +1,219 @@ +############################################################################ +# Copyright (c) 2022-2026 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +"""Unit tests for alternative-end NIC discovery in +``graph_based_model_construction`` (branch ``transcript_model_ends``). + +The constructor is heavy to build (it wires up an assigner + profile +constructor), so the helpers under test are exercised on a bare instance +created with ``__new__`` and only the attributes each method touches. +``detect_peaks`` / ``get_polya_model`` / ``get_tss_model`` are stubbed so the +tests pin the gating/wiring logic, not scipy's peak finder or the shipped +XGBoost model. +""" + +import types +from argparse import Namespace + +import pytest + +from isoquant_lib import graph_based_model_construction as gbmc +from isoquant_lib.gene_info import TranscriptModel, TranscriptModelType +from isoquant_lib.terminal_peaks import Peak + +# Reuse the polyA-counter harness (stub model + read-assignment builder). +from isoquant_tests.test_polya_prediction import _StubModel, _make_read_assignment + + +class _IdDist: + def __init__(self): + self.n = 0 + + def increment(self): + self.n += 1 + return self.n + + +def _make_constructor(use_tss_model=True, apa_delta=10, min_novel_count=2, + terminal_position_rel=0.1, chr_id="chr1", all_isoforms_exons=None): + c = gbmc.GraphBasedModelConstructor.__new__(gbmc.GraphBasedModelConstructor) + c.args = Namespace(apa_delta=apa_delta, min_novel_count=min_novel_count, + terminal_position_rel=terminal_position_rel) + c.use_tss_model = use_tss_model + c.create_nics = True + c.novel_apa = False + c.id_distributor = _IdDist() + c.gene_info = types.SimpleNamespace(chr_id=chr_id, + all_isoforms_exons=all_isoforms_exons or {}) + c.internal_counter = {} + return c + + +def _model(exon_blocks, strand="+", ttype=TranscriptModelType.known, tid="T1", + gene_id="G1", chr_id="chr1"): + return TranscriptModel(chr_id, strand, tid, gene_id, + [tuple(e) for e in exon_blocks], ttype) + + +@pytest.fixture +def stub_models(monkeypatch): + model = _StubModel(accept=True) + monkeypatch.setattr(gbmc, "get_polya_model", lambda: model) + monkeypatch.setattr(gbmc, "get_tss_model", lambda: model) + return model + + +# -- static helpers ----------------------------------------------------------- + +def test_intron_chain_key_multi_exon(): + m = _model([(100, 200), (300, 400), (500, 600)]) + assert gbmc.GraphBasedModelConstructor._intron_chain_key(m) == ((200, 300), (400, 500)) + + +def test_intron_chain_key_monoexon_is_empty(): + assert gbmc.GraphBasedModelConstructor._intron_chain_key(_model([(100, 600)])) == () + + +def test_closest_inward_greater_and_less(): + f = gbmc.GraphBasedModelConstructor._closest_inward + assert f([10, 20, 30], 15, True) == 20 # first strictly greater (ascending) + assert f([10, 20, 30], 30, True) is None + assert f([30, 20, 10], 25, False) == 20 # first strictly less (descending) + assert f([30, 20, 10], 5, False) is None + + +def test_confirmed_polya_pos_plus_and_minus(): + f = gbmc.GraphBasedModelConstructor._confirmed_polya_pos + assert f(_make_read_assignment(strand="+", polya_pos=500), "+") == 500 + assert f(_make_read_assignment(strand="-", polya_pos=120), "-") == 120 + + +def test_confirmed_polya_pos_requires_polya_found(): + a = _make_read_assignment(strand="+", polya_pos=500, polyA_found=False) + assert gbmc.GraphBasedModelConstructor._confirmed_polya_pos(a, "+") is None + + +def test_confirmed_polya_pos_missing_strand_field_returns_none(): + # '+' read carries external_polya_pos but external_polyt_pos == -1 + a = _make_read_assignment(strand="+", polya_pos=500) + assert gbmc.GraphBasedModelConstructor._confirmed_polya_pos(a, "-") is None + + +# -- _terminal_model ---------------------------------------------------------- + +def test_terminal_model_polya_side_always_available(stub_models): + c = _make_constructor(use_tss_model=True) + assert c._terminal_model("+", left=False) is stub_models # 3' polyA of '+' + assert c._terminal_model("-", left=True) is stub_models # 3' polyA of '-' + + +def test_terminal_model_tss_side_needs_fl_data(stub_models): + assert _make_constructor(use_tss_model=True)._terminal_model("+", left=True) is stub_models + assert _make_constructor(use_tss_model=False)._terminal_model("+", left=True) is None + + +def test_terminal_model_dot_strand_is_none(stub_models): + c = _make_constructor() + assert c._terminal_model(".", left=True) is None + assert c._terminal_model(".", left=False) is None + + +# -- _alternative_end_positions (relative-support + apa_delta gate) ------------ + +def test_alternative_end_positions_gate(monkeypatch): + c = _make_constructor(apa_delta=10, min_novel_count=2, terminal_position_rel=0.1) + peaks = [ + Peak(position=400, count=50, left=395, right=405), # dominant, == annotated -> drop + Peak(position=500, count=8, left=495, right=505), # strong alt -> keep (>= cutoff 5) + Peak(position=600, count=3, left=595, right=605), # minor -> below cutoff -> drop + Peak(position=405, count=40, left=400, right=410), # within apa_delta of annotated -> drop + ] + monkeypatch.setattr(gbmc, "detect_peaks", lambda hist, model: peaks) + out = c._alternative_end_positions({1: 1}, _StubModel(), annotated_pos=400, clamp=lambda p: True) + assert out == [500] + + +def test_alternative_end_positions_respects_clamp(monkeypatch): + c = _make_constructor(apa_delta=10, min_novel_count=2, terminal_position_rel=0.1) + peaks = [Peak(400, 50, 395, 405), Peak(500, 8, 495, 505)] + monkeypatch.setattr(gbmc, "detect_peaks", lambda hist, model: peaks) + # clamp rejects the 500 peak -> nothing left + out = c._alternative_end_positions({1: 1}, _StubModel(), annotated_pos=400, clamp=lambda p: p < 450) + assert out == [] + + +# -- derive_alternative_end_models + _nic_model_with_boundary ------------------ + +def _fake_detect(hist, model): + # end histogram carries the alternative polyA at 500; start histogram carries + # only the annotated start at 100 (filtered out by the apa_delta gate). + if 500 in hist: + return [Peak(500, 30, 495, 505)] + if 100 in hist: + return [Peak(100, 30, 95, 105)] + return [] + + +def test_derive_alternative_end_models_known_source(monkeypatch, stub_models): + c = _make_constructor(use_tss_model=True, apa_delta=10, min_novel_count=2, + terminal_position_rel=0.1) + monkeypatch.setattr(gbmc, "detect_peaks", _fake_detect) + source = _model([(100, 200), (300, 400)], strand="+", + ttype=TranscriptModelType.known, tid="ENST1") + reads = [_make_read_assignment(strand="+", polya_pos=500, exons=[(100, 200), (300, 500)]) + for _ in range(20)] + + out = c.derive_alternative_end_models(source, reads) + + assert len(out) == 1 + nic = out[0] + assert nic.exon_blocks[0] == (100, 200) # start unchanged + assert nic.exon_blocks[-1] == (300, 500) # 3' end moved to alt polyA + assert nic.transcript_type == TranscriptModelType.novel_in_catalog + assert nic.transcript_id.endswith(".nic") + assert nic.gene_id == source.gene_id and nic.strand == "+" + # source model is left untouched + assert source.exon_blocks[-1] == (300, 400) + + +def test_derive_alternative_end_models_empty_reads(): + assert _make_constructor().derive_alternative_end_models( + _model([(100, 200), (300, 400)]), []) == [] + + +def test_nic_model_with_boundary_replaces_single_end(): + c = _make_constructor() + source = _model([(100, 200), (300, 400), (500, 600)], strand="-", tid="ENST2", gene_id="G2") + nic = c._nic_model_with_boundary(source, TranscriptModelType.novel_in_catalog, end=650) + assert nic.exon_blocks == [(100, 200), (300, 400), (500, 650)] + assert nic.transcript_id.endswith(".nic") and nic.strand == "-" and nic.gene_id == "G2" + assert source.exon_blocks[-1] == (500, 600) # source untouched + + +def test_nic_model_with_boundary_nnic_suffix_for_novel_source(): + c = _make_constructor() + source = _model([(100, 200), (300, 400)], ttype=TranscriptModelType.novel_not_in_catalog) + m = c._nic_model_with_boundary(source, TranscriptModelType.novel_not_in_catalog, start=50) + assert m.exon_blocks[0] == (50, 200) + assert m.transcript_id.endswith(".nnic") + + +# -- _drop_duplicate_alt_end_models ------------------------------------------- + +def test_drop_duplicate_alt_end_models(): + ref_exons = {"REF1": [(100, 200), (300, 400)]} + c = _make_constructor(apa_delta=10, all_isoforms_exons=ref_exons) + known = _model([(100, 200), (300, 400)], ttype=TranscriptModelType.known, tid="REF1") + # same intron chain + both ends within apa_delta of REF1 -> it IS the known -> drop + dup = _model([(105, 200), (300, 398)], ttype=TranscriptModelType.novel_in_catalog, tid="DUP") + # genuinely different 3' end -> a real alt-end NIC -> keep + keep = _model([(100, 200), (300, 800)], ttype=TranscriptModelType.novel_in_catalog, tid="KEEP") + c.transcript_model_storage = [known, dup, keep] + + c._drop_duplicate_alt_end_models() + + ids = {m.transcript_id for m in c.transcript_model_storage} + assert ids == {"REF1", "KEEP"} diff --git a/isoquant_tests/test_intron_graph_refine.py b/isoquant_tests/test_intron_graph_refine.py new file mode 100644 index 00000000..a3e42b38 --- /dev/null +++ b/isoquant_tests/test_intron_graph_refine.py @@ -0,0 +1,80 @@ +############################################################################ +# Copyright (c) 2022-2026 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +"""Unit tests for intron-graph terminal-vertex refinement +(``IntronGraph._refine_positions`` / ``_attach_side``) on branch +``transcript_model_ends``. + +The graph is heavy to build, so the methods under test run on a bare instance +created with ``__new__`` plus the few attributes/stubs they touch. +""" + +import types +from collections import defaultdict + +from isoquant_lib.intron_graph import IntronGraph, VERTEX_read_end + + +def _bare_graph(apa_delta=10, polya_predictions=None, tss_predictions=None): + g = IntronGraph.__new__(IntronGraph) + g.params = types.SimpleNamespace( + apa_delta=apa_delta, + terminal_position_abs=1, + terminal_position_rel=0.0, + terminal_internal_position_rel=0.0, + ) + g.polya_predictions = polya_predictions + g.tss_predictions = tss_predictions + g.outgoing_edges = defaultdict(set) + g.incoming_edges = defaultdict(set) + return g + + +# -- _refine_positions -------------------------------------------------------- + +def test_refine_positions_snaps_within_delta(): + g = _bare_graph(apa_delta=10) + assert g._refine_positions({105: 3}, [100, 500]) == {100: 3} + + +def test_refine_positions_leaves_when_far(): + g = _bare_graph(apa_delta=10) + assert g._refine_positions({130: 3}, [100, 500]) == {130: 3} + + +def test_refine_positions_merges_counts_on_collision(): + g = _bare_graph(apa_delta=10) + assert g._refine_positions({98: 2, 103: 5}, [100]) == {100: 7} + + +def test_refine_positions_identity_without_predictions(): + g = _bare_graph(apa_delta=10) + assert g._refine_positions({105: 3}, None) == {105: 3} + assert g._refine_positions({105: 3}, []) == {105: 3} + + +# -- _attach_side side-selection (regression for review fix #1) ---------------- + +def test_attach_side_3prime_readend_refines_with_polya_not_tss(): + # read_end=True is the genomic 3' side: a VERTEX_read_end position must be + # refined toward the polyA predictions, never the (5') TSS predictions. + # Place the read-end cluster at 1005, equidistant from polya (1000) and tss + # (1010); the chosen target tells us which set was used. + intron = (100, 200) + g = _bare_graph(apa_delta=10, polya_predictions=[1000], tss_predictions=[1010]) + g.intron_collector = types.SimpleNamespace(clustered_introns={intron: 10}) + # No polyA-confirmed vertices; one read-end cluster at 1005. + g.cluster_polya_positions = lambda positions, i, read_end: {} + g.cluster_terminal_positions = lambda extra, read_end, cutoff: {1005: 4} + + polya_confirmed = {intron: {}} + read_terminal = {intron: {1005: 4}} + g._attach_side([intron], polya_confirmed, read_terminal, read_end=True) + + vertices = g.outgoing_edges[intron] + assert (VERTEX_read_end, 1000) in vertices # snapped to polyA prediction + assert (VERTEX_read_end, 1010) not in vertices # NOT the TSS prediction + assert (VERTEX_read_end, 1005) not in vertices # and it was refined, not left raw diff --git a/isoquant_tests/test_polya_prediction.py b/isoquant_tests/test_polya_prediction.py index a5c9b04e..bd3b02a5 100644 --- a/isoquant_tests/test_polya_prediction.py +++ b/isoquant_tests/test_polya_prediction.py @@ -162,6 +162,62 @@ def test_dump_ungrouped_emits_one_row_per_peak(stub_model, tmp_path): assert df["flag"].iloc[0] == "Novel" +def test_per_gene_flush_does_not_split_transcript(stub_model, tmp_path): + # A transcript whose reads arrive across several per-gene flush() batches + # (loader yields its reads in more than one gene block) must still be + # emitted as a single row per peak with the full read count -- not one + # partial row per batch. Regression for the per-gene-flush splitting bug. + out = tmp_path / "split.tsv" + counter = tc.PolyACounter(_make_args(), str(out)) + exons = [(100, 200), (300, 400)] + # Batch 1 (10 reads), flush; batch 2 (20 reads), flush -- same peak ~420. + for offset in range(10): + counter.add_read_info(_make_read_assignment(polya_pos=420 + offset % 3, exons=exons)) + counter.flush() + assert counter.last_gene_predictions # per-gene reuse still produced + for offset in range(20): + counter.add_read_info(_make_read_assignment(polya_pos=420 + offset % 3, exons=exons)) + counter.flush() + counter.dump() + + df = _read_tsv(out) + assert (df["transcript_id"] == "T1").all() + # No duplicate (transcript_id, prediction) rows, and every read counted once. + assert not df.duplicated(subset=["transcript_id", "prediction"]).any() + assert df["counts"].sum() == 30 + + +def test_flush_sets_predictions_and_clears_buffer(stub_model, tmp_path): + # flush() exposes the per-gene predictions (reused by intron-graph + # refinement), folds the gene into the chromosome-wide buffer, and clears + # the per-gene buffer for the next gene. + counter = tc.PolyACounter(_make_args(), str(tmp_path / "f.tsv")) + exons = [(100, 200), (300, 400)] + for offset in range(30): + counter.add_read_info(_make_read_assignment(polya_pos=420 + offset % 3, exons=exons)) + assert counter.transcripts # buffered before flush + counter.flush() + assert counter.last_gene_predictions # per-gene predictions exposed for reuse + assert counter.transcripts == {} # per-gene buffer cleared + assert counter._all_transcripts # merged into chromosome-wide buffer + + +def test_flush_is_noop_for_grouped_counter(stub_model, tmp_path): + # A grouped counter accumulates across the whole chromosome (whole-chr + # prediction in dump); flush() must not touch its buffer or predictions. + pool = types.SimpleNamespace(get_str=lambda i: f"g{i}") + string_pools = types.SimpleNamespace(get_read_group_pool=lambda _idx: pool) + counter = tc.PolyACounter(_make_args(), str(tmp_path / "g.tsv"), + string_pools=string_pools, group_index=0) + exons = [(100, 200), (300, 400)] + for offset in range(30): + counter.add_read_info(_make_read_assignment( + polya_pos=420 + offset % 3, exons=exons, read_group_ids=[offset % 2])) + counter.flush() + assert counter.transcripts # NOT cleared (flush no-op for grouped) + assert counter.last_gene_predictions == [] + + def test_dump_flags_peak_as_known_within_tolerance(stub_model, tmp_path): out = tmp_path / "known.tsv" counter = tc.PolyACounter(_make_args(), str(out)) diff --git a/isoquant_tests/test_terminal_peaks.py b/isoquant_tests/test_terminal_peaks.py new file mode 100644 index 00000000..c069bdab --- /dev/null +++ b/isoquant_tests/test_terminal_peaks.py @@ -0,0 +1,91 @@ +############################################################################ +# Copyright (c) 2022-2026 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +"""Unit tests for the shared polyA / TSS peak detector (terminal_peaks). + +The detector is the single source of truth used both by the side-output +prediction counters (terminal_counter) and by transcript discovery +(intron_graph / graph_based_model_construction). These tests pin its output +against terminal_counter's own pandas pipeline so the feature computation can +not silently drift between the two call sites. +""" + +from collections import Counter + +import numpy as np +import pytest + +from isoquant_lib import terminal_counter as tc +from isoquant_lib import terminal_peaks as tp +from isoquant_lib.isoform_assignment import ReadAssignmentType + +# Reuse the counter test harness (stub model + read-assignment builders). +from isoquant_tests.test_polya_prediction import ( + _StubModel, + _make_args, + _make_read_assignment, + _read_tsv, + stub_model, # noqa: F401 (pytest fixture) +) + + +def _counter_peaks(positions, tmp_path, name): + """Run the real PolyACounter pipeline over reads at ``positions`` and + return the set of (prediction, counts) it emits.""" + out = tmp_path / name + counter = tc.PolyACounter(_make_args(), str(out)) + exons = [(100, 200), (300, 400)] + for pos in positions: + counter.add_read_info(_make_read_assignment(polya_pos=pos, exons=exons)) + counter.dump() + df = _read_tsv(out) + if df.empty: + return set() + return set(zip(df["prediction"].astype(int), df["counts"].astype(int))) + + +def _detect_peaks(positions, accept=True): + histogram = dict(Counter(positions)) + peaks = tp.detect_peaks(histogram, _StubModel(accept=accept)) + return {(p.position, p.count) for p in peaks} + + +@pytest.mark.parametrize("positions", [ + [420], # single read + [420 + i % 3 for i in range(30)], # one tight peak + [420 + i % 3 for i in range(30)] + + [500 + i % 3 for i in range(20)], # two peaks (ranking) + [400, 401], # adjacent-bin plateau +]) +def test_detect_peaks_parity_with_counter(stub_model, tmp_path, positions): + counter_peaks = _counter_peaks(positions, tmp_path, "parity.tsv") + assert _detect_peaks(positions) == counter_peaks + # And the union of supporting counts is conserved either way. + assert sum(c for _, c in _detect_peaks(positions)) == \ + sum(c for _, c in counter_peaks) + + +def test_detect_peaks_empty_returns_empty(): + assert tp.detect_peaks({}, _StubModel(accept=True)) == [] + + +def test_detect_peaks_batch_matches_single(): + histograms = [ + {420 + i % 3: 1 for i in range(1)}, + {420 + i % 3: 10 for i in range(3)}, + {}, + ] + model = _StubModel(accept=True) + batched = tp.detect_peaks_batch(histograms, model) + singles = [tp.detect_peaks(h, model) for h in histograms] + assert batched == singles + assert batched[2] == [] + + +def test_detect_peaks_rejected_peak_dropped(): + # A clear peak is dropped when the model rejects it (no zero-peak fallback). + positions = [420 + i % 3 for i in range(30)] + assert _detect_peaks(positions, accept=False) == set() diff --git a/misc/prepare_simulated_reduced_db.py b/misc/prepare_simulated_reduced_db.py new file mode 100644 index 00000000..8f8de360 --- /dev/null +++ b/misc/prepare_simulated_reduced_db.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +############################################################################ +# Copyright (c) 2022-2026 University of Helsinki +# # All Rights Reserved +# See file LICENSE for details. +############################################################################ + +"""Split a simulation GTF into the reduced-db trio used by +``reduced_db_gffcompare.py`` for transcript-discovery assessment of +**alternative TSS/polyA**. + +The simulation GTF is the full reference annotation (plain transcript IDs) +plus alternative-end variants whose ID is ``_`` (same intron +chain, different terminal exon). Mapping alt-end variants onto the +reduced-db "novel" split lets the existing 3-terminal-delta assessment +measure how well alternative ends are recovered: + + .expressed.gtf all expressed transcripts (-> full) + .expressed_kept.gtf expressed plain-ID transcripts (-> known) + .excluded.gtf expressed alt-end variants (-> novel) + +"Expressed" = count > 0 in the simulated counts TSV. Alt-end IDs are matched +by ``_`` at the end (the simulation appends the alternative terminal +coordinate); plain GENCODE IDs (incl. ``_PAR_Y``) do not match. +""" +import argparse +import re +import sys +from traceback import print_exc + +ALT_END = re.compile(r'_\d+$') + + +def parse_args(): + p = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--sim_gtf", required=True, help="simulation GTF (full annotation + alt-end variants)") + p.add_argument("--counts", required=True, help="simulated counts TSV (transcript_idcounts[...])") + p.add_argument("--output_prefix", required=True, help="output prefix for the .expressed/.expressed_kept/.excluded GTFs") + return p.parse_args() + + +def load_expressed(counts_path): + expressed = set() + for line in open(counts_path): + if line.startswith("#") or not line.strip(): + continue + f = line.rstrip("\n").split("\t") + if len(f) < 2: + continue + try: + if float(f[1]) > 0: + expressed.add(f[0]) + except ValueError: + continue + return expressed + + +def get_transcript_id(attrs): + i = attrs.find('transcript_id "') + if i < 0: + return None + return attrs[i + 15:].split('"', 1)[0] + + +def main(): + args = parse_args() + expressed = load_expressed(args.counts) + out_all = open(args.output_prefix + ".expressed.gtf", "w") + out_kept = open(args.output_prefix + ".expressed_kept.gtf", "w") + out_excl = open(args.output_prefix + ".excluded.gtf", "w") + + tx = {"all": set(), "kept": set(), "excl": set()} + for line in open(args.sim_gtf): + if line.startswith("#"): + continue + f = line.rstrip("\n").split("\t") + if len(f) < 9 or f[2] not in ("transcript", "exon"): + continue + # Some upstream synthetic GTFs (the lrgasp human ones) carry a bogus + # extra column 8 (literal "") before the attributes; the attributes + # are always the last field. Read from there and write a normalized + # 9-column line so gffcompare parses it (9-col input is unchanged). + attrs = f[-1] + tid = get_transcript_id(attrs) + if tid is None or tid not in expressed: + continue + out_line = "\t".join(f[:8] + [attrs]) + "\n" + out_all.write(out_line) + is_tx = f[2] == "transcript" + if is_tx: + tx["all"].add(tid) + if ALT_END.search(tid): + out_excl.write(out_line) + if is_tx: + tx["excl"].add(tid) + else: + out_kept.write(out_line) + if is_tx: + tx["kept"].add(tid) + + for fh in (out_all, out_kept, out_excl): + fh.close() + print("Expressed transcripts written: all=%d kept/plain(known)=%d excluded/alt-end(novel)=%d" + % (len(tx["all"]), len(tx["kept"]), len(tx["excl"]))) + + +if __name__ == "__main__": + try: + main() + except SystemExit: + raise + except Exception: + print_exc() + sys.exit(-1) diff --git a/misc/reduced_db_gffcompare.py b/misc/reduced_db_gffcompare.py index 71561835..965e0374 100755 --- a/misc/reduced_db_gffcompare.py +++ b/misc/reduced_db_gffcompare.py @@ -15,6 +15,11 @@ from common import * +# Terminal-end tolerances (bp) for transcript-level matching. None = default +# end-agnostic match; the integers use the gffcompare fork's --terminal-delta. +TERMINAL_DELTAS = [None, 50, 10] + + def parse_args(): parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument("--output", "-o", type=str, help="output folder", default="gtf_stats") @@ -41,15 +46,26 @@ def main(): print("Seprating known and novel transcripts") separator = SEPARATE_FUNCTORS[args.tool](args.gtf) split_gtf(args.gtf, separator, out_full_path, out_known_path, out_novel_path) - print("Running gffcompare for entire GTF") - expressed_gtf = args.genedb + ".expressed.gtf" - run_gff_compare(expressed_gtf, out_full_path, os.path.join(args.output, args.tool + ".full.stats")) - print("Running gffcompare for known transcripts") - expressed_gtf = args.genedb + ".expressed_kept.gtf" - run_gff_compare(expressed_gtf, out_known_path, os.path.join(args.output, args.tool + ".known.stats")) - print("Running gffcompare for novel transcripts") - expressed_gtf = args.genedb + ".excluded.gtf" - run_gff_compare(expressed_gtf, out_novel_path, os.path.join(args.output, args.tool + ".novel.stats")) + + # (split, reference subset, query split GTF) + splits = [ + ("full", args.genedb + ".expressed.gtf", out_full_path), + ("known", args.genedb + ".expressed_kept.gtf", out_known_path), + ("novel", args.genedb + ".excluded.gtf", out_novel_path), + ] + # Score each split at several terminal-end tolerances. None = default + # end-agnostic transcript match (-> "..stats"); the integer + # deltas use the gffcompare fork's --terminal-delta so the transcript-level + # metric becomes end-sensitive (-> "..td.stats"). + # See .claude/GFFCOMPARE.md. Requires the gffcompare fork for the deltas; + # the default run works with stock gffcompare too. + for split, reference_gtf, compared_gtf in splits: + for delta in TERMINAL_DELTAS: + suffix = "" if delta is None else (".td%d" % delta) + option = "" if delta is None else ("--terminal-delta=%d" % delta) + out_stats = os.path.join(args.output, "%s.%s%s.stats" % (args.tool, split, suffix)) + print("Running gffcompare for %s transcripts (terminal-delta=%s)" % (split, delta)) + run_gff_compare(reference_gtf, compared_gtf, out_stats, additional_option=option) if __name__ == "__main__":