diff --git a/.github/workflows/Human.ONT_simulated.TSS.yml b/.github/workflows/Human.ONT_simulated.TSS.yml
new file mode 100644
index 00000000..24c4670a
--- /dev/null
+++ b/.github/workflows/Human.ONT_simulated.TSS.yml
@@ -0,0 +1,76 @@
+name: Human ONT R10 simulated TSS (discovery + prediction)
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '0 5 * * 2'
+
+env:
+ RUN_NAME: Human.ONT_simulated.TSS
+ LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py
+ CFG_DIR: /abga/work/andreyp/ci_isoquant/data
+ BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/
+ OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/
+
+concurrency:
+ group: ${{github.workflow}}
+ cancel-in-progress: false
+
+jobs:
+ check-changes:
+ runs-on:
+ labels: [isoquant]
+ name: 'Check for recent changes'
+ outputs:
+ has_changes: ${{steps.check.outputs.has_changes}}
+ steps:
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+
+ - name: 'Check for commits in last 7 days'
+ id: check
+ run: |
+ # Always run on manual trigger
+ if [ "${{github.event_name}}" = "workflow_dispatch" ]; then
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ exit 0
+ fi
+ # Check for commits in last 7 days
+ COMMITS=$(git log --oneline --since="7 days ago" | wc -l)
+ if [ "$COMMITS" -gt 0 ]; then
+ echo "Found $COMMITS commits in last 7 days"
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ else
+ echo "No commits in last 7 days, skipping"
+ echo "has_changes=false" >> $GITHUB_OUTPUT
+ fi
+
+ launch-runner:
+ needs: check-changes
+ if: needs.check-changes.outputs.has_changes == 'true'
+ runs-on:
+ labels: [isoquant]
+ name: 'Running IsoQuant and QC'
+
+ steps:
+ - name: 'Cleanup'
+ run: >
+ set -e &&
+ shopt -s dotglob &&
+ rm -rf *
+
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 1
+
+ - name: 'IsoQuant discovery + TSS'
+ if: always()
+ shell: bash
+ env:
+ STEP_NAME: Human.ONT_simulated.TSS
+ run: |
+ export PATH=${{env.BIN_PATH}}:$PATH
+ python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}}
diff --git a/.github/workflows/Human.ONT_simulated.polyA_1.yml b/.github/workflows/Human.ONT_simulated.polyA_1.yml
new file mode 100644
index 00000000..e1235e1f
--- /dev/null
+++ b/.github/workflows/Human.ONT_simulated.polyA_1.yml
@@ -0,0 +1,76 @@
+name: Human ONT R10 simulated polyA 1 (discovery + prediction)
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '0 5 * * 1'
+
+env:
+ RUN_NAME: Human.ONT_simulated.polyA_1
+ LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py
+ CFG_DIR: /abga/work/andreyp/ci_isoquant/data
+ BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/
+ OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/
+
+concurrency:
+ group: ${{github.workflow}}
+ cancel-in-progress: false
+
+jobs:
+ check-changes:
+ runs-on:
+ labels: [isoquant]
+ name: 'Check for recent changes'
+ outputs:
+ has_changes: ${{steps.check.outputs.has_changes}}
+ steps:
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+
+ - name: 'Check for commits in last 7 days'
+ id: check
+ run: |
+ # Always run on manual trigger
+ if [ "${{github.event_name}}" = "workflow_dispatch" ]; then
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ exit 0
+ fi
+ # Check for commits in last 7 days
+ COMMITS=$(git log --oneline --since="7 days ago" | wc -l)
+ if [ "$COMMITS" -gt 0 ]; then
+ echo "Found $COMMITS commits in last 7 days"
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ else
+ echo "No commits in last 7 days, skipping"
+ echo "has_changes=false" >> $GITHUB_OUTPUT
+ fi
+
+ launch-runner:
+ needs: check-changes
+ if: needs.check-changes.outputs.has_changes == 'true'
+ runs-on:
+ labels: [isoquant]
+ name: 'Running IsoQuant and QC'
+
+ steps:
+ - name: 'Cleanup'
+ run: >
+ set -e &&
+ shopt -s dotglob &&
+ rm -rf *
+
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 1
+
+ - name: 'IsoQuant discovery + polyA'
+ if: always()
+ shell: bash
+ env:
+ STEP_NAME: Human.ONT_simulated.polyA_1
+ run: |
+ export PATH=${{env.BIN_PATH}}:$PATH
+ python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}}
diff --git a/.github/workflows/Human.ONT_simulated.polyA_2.no_gtf.yml b/.github/workflows/Human.ONT_simulated.polyA_2.no_gtf.yml
new file mode 100644
index 00000000..78c13528
--- /dev/null
+++ b/.github/workflows/Human.ONT_simulated.polyA_2.no_gtf.yml
@@ -0,0 +1,76 @@
+name: Human ONT simulated polyA 2 no annotation
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '0 5 * * 0'
+
+env:
+ RUN_NAME: Human.ONT_simulated.polyA_2.no_gtf
+ LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py
+ CFG_DIR: /abga/work/andreyp/ci_isoquant/data
+ BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/
+ OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/
+
+concurrency:
+ group: ${{github.workflow}}
+ cancel-in-progress: false
+
+jobs:
+ check-changes:
+ runs-on:
+ labels: [isoquant]
+ name: 'Check for recent changes'
+ outputs:
+ has_changes: ${{steps.check.outputs.has_changes}}
+ steps:
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+
+ - name: 'Check for commits in last 7 days'
+ id: check
+ run: |
+ # Always run on manual trigger
+ if [ "${{github.event_name}}" = "workflow_dispatch" ]; then
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ exit 0
+ fi
+ # Check for commits in last 7 days
+ COMMITS=$(git log --oneline --since="7 days ago" | wc -l)
+ if [ "$COMMITS" -gt 0 ]; then
+ echo "Found $COMMITS commits in last 7 days"
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ else
+ echo "No commits in last 7 days, skipping"
+ echo "has_changes=false" >> $GITHUB_OUTPUT
+ fi
+
+ launch-runner:
+ needs: check-changes
+ if: needs.check-changes.outputs.has_changes == 'true'
+ runs-on:
+ labels: [isoquant]
+ name: 'Running IsoQuant and QC'
+
+ steps:
+ - name: 'Cleanup'
+ run: >
+ set -e &&
+ shopt -s dotglob &&
+ rm -rf *
+
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 1
+
+ - name: 'IsoQuant de-novo discovery (no annotation)'
+ if: always()
+ shell: bash
+ env:
+ STEP_NAME: Human.ONT_simulated.polyA_2.no_gtf
+ run: |
+ export PATH=${{env.BIN_PATH}}:$PATH
+ python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}}
diff --git a/.github/workflows/Human.ONT_simulated.polyA_2.yml b/.github/workflows/Human.ONT_simulated.polyA_2.yml
new file mode 100644
index 00000000..a73263b1
--- /dev/null
+++ b/.github/workflows/Human.ONT_simulated.polyA_2.yml
@@ -0,0 +1,76 @@
+name: Human ONT R10 simulated polyA 2 (discovery + prediction)
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '0 7 * * 1'
+
+env:
+ RUN_NAME: Human.ONT_simulated.polyA_2
+ LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py
+ CFG_DIR: /abga/work/andreyp/ci_isoquant/data
+ BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/
+ OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/
+
+concurrency:
+ group: ${{github.workflow}}
+ cancel-in-progress: false
+
+jobs:
+ check-changes:
+ runs-on:
+ labels: [isoquant]
+ name: 'Check for recent changes'
+ outputs:
+ has_changes: ${{steps.check.outputs.has_changes}}
+ steps:
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+
+ - name: 'Check for commits in last 7 days'
+ id: check
+ run: |
+ # Always run on manual trigger
+ if [ "${{github.event_name}}" = "workflow_dispatch" ]; then
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ exit 0
+ fi
+ # Check for commits in last 7 days
+ COMMITS=$(git log --oneline --since="7 days ago" | wc -l)
+ if [ "$COMMITS" -gt 0 ]; then
+ echo "Found $COMMITS commits in last 7 days"
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ else
+ echo "No commits in last 7 days, skipping"
+ echo "has_changes=false" >> $GITHUB_OUTPUT
+ fi
+
+ launch-runner:
+ needs: check-changes
+ if: needs.check-changes.outputs.has_changes == 'true'
+ runs-on:
+ labels: [isoquant]
+ name: 'Running IsoQuant and QC'
+
+ steps:
+ - name: 'Cleanup'
+ run: >
+ set -e &&
+ shopt -s dotglob &&
+ rm -rf *
+
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 1
+
+ - name: 'IsoQuant discovery + polyA'
+ if: always()
+ shell: bash
+ env:
+ STEP_NAME: Human.ONT_simulated.polyA_2
+ run: |
+ export PATH=${{env.BIN_PATH}}:$PATH
+ python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}}
diff --git a/.github/workflows/Mouse.ONT_simulated.TSS.yml b/.github/workflows/Mouse.ONT_simulated.TSS.yml
new file mode 100644
index 00000000..527f060f
--- /dev/null
+++ b/.github/workflows/Mouse.ONT_simulated.TSS.yml
@@ -0,0 +1,76 @@
+name: Mouse ONT R10 simulated TSS (discovery + prediction)
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '0 7 * * 2'
+
+env:
+ RUN_NAME: Mouse.ONT_simulated.TSS
+ LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py
+ CFG_DIR: /abga/work/andreyp/ci_isoquant/data
+ BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/
+ OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/
+
+concurrency:
+ group: ${{github.workflow}}
+ cancel-in-progress: false
+
+jobs:
+ check-changes:
+ runs-on:
+ labels: [isoquant]
+ name: 'Check for recent changes'
+ outputs:
+ has_changes: ${{steps.check.outputs.has_changes}}
+ steps:
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+
+ - name: 'Check for commits in last 7 days'
+ id: check
+ run: |
+ # Always run on manual trigger
+ if [ "${{github.event_name}}" = "workflow_dispatch" ]; then
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ exit 0
+ fi
+ # Check for commits in last 7 days
+ COMMITS=$(git log --oneline --since="7 days ago" | wc -l)
+ if [ "$COMMITS" -gt 0 ]; then
+ echo "Found $COMMITS commits in last 7 days"
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ else
+ echo "No commits in last 7 days, skipping"
+ echo "has_changes=false" >> $GITHUB_OUTPUT
+ fi
+
+ launch-runner:
+ needs: check-changes
+ if: needs.check-changes.outputs.has_changes == 'true'
+ runs-on:
+ labels: [isoquant]
+ name: 'Running IsoQuant and QC'
+
+ steps:
+ - name: 'Cleanup'
+ run: >
+ set -e &&
+ shopt -s dotglob &&
+ rm -rf *
+
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 1
+
+ - name: 'IsoQuant discovery + TSS'
+ if: always()
+ shell: bash
+ env:
+ STEP_NAME: Mouse.ONT_simulated.TSS
+ run: |
+ export PATH=${{env.BIN_PATH}}:$PATH
+ python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}}
diff --git a/.github/workflows/SIRVs.simulated_perfect_polya.yml b/.github/workflows/SIRVs.simulated_perfect_polya.yml
new file mode 100644
index 00000000..c32005ae
--- /dev/null
+++ b/.github/workflows/SIRVs.simulated_perfect_polya.yml
@@ -0,0 +1,76 @@
+name: SIRVs simulated perfect polyA (discovery)
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: '0 9 * * 2'
+
+env:
+ RUN_NAME: SIRVs.simulated_perfect_polya
+ LAUNCHER: ${{github.workspace}}/isoquant_tests/github/run_pipeline.py
+ CFG_DIR: /abga/work/andreyp/ci_isoquant/data
+ BIN_PATH: /abga/work/andreyp/ci_isoquant/bin/
+ OUTPUT_BASE: /abga/work/andreyp/ci_isoquant/output/${{github.ref_name}}/
+
+concurrency:
+ group: ${{github.workflow}}
+ cancel-in-progress: false
+
+jobs:
+ check-changes:
+ runs-on:
+ labels: [isoquant]
+ name: 'Check for recent changes'
+ outputs:
+ has_changes: ${{steps.check.outputs.has_changes}}
+ steps:
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+
+ - name: 'Check for commits in last 7 days'
+ id: check
+ run: |
+ # Always run on manual trigger
+ if [ "${{github.event_name}}" = "workflow_dispatch" ]; then
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ exit 0
+ fi
+ # Check for commits in last 7 days
+ COMMITS=$(git log --oneline --since="7 days ago" | wc -l)
+ if [ "$COMMITS" -gt 0 ]; then
+ echo "Found $COMMITS commits in last 7 days"
+ echo "has_changes=true" >> $GITHUB_OUTPUT
+ else
+ echo "No commits in last 7 days, skipping"
+ echo "has_changes=false" >> $GITHUB_OUTPUT
+ fi
+
+ launch-runner:
+ needs: check-changes
+ if: needs.check-changes.outputs.has_changes == 'true'
+ runs-on:
+ labels: [isoquant]
+ name: 'Running IsoQuant and QC'
+
+ steps:
+ - name: 'Cleanup'
+ run: >
+ set -e &&
+ shopt -s dotglob &&
+ rm -rf *
+
+ - name: 'Checkout'
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 1
+
+ - name: 'IsoQuant discovery'
+ if: always()
+ shell: bash
+ env:
+ STEP_NAME: SIRVs.simulated_perfect_polya
+ run: |
+ export PATH=${{env.BIN_PATH}}:$PATH
+ python3 ${{env.LAUNCHER}} ${{env.CFG_DIR}}/${{env.STEP_NAME}}.yaml -o ${{env.OUTPUT_BASE}}
diff --git a/isoquant.py b/isoquant.py
index 30179a57..cf3e3f6b 100755
--- a/isoquant.py
+++ b/isoquant.py
@@ -360,6 +360,12 @@ def add_hidden_option(*args, **kwargs): # show command only with --full-help
add_hidden_option("--collect_tss_training", type=str, default=None,
help="Developer: dump per-peak features + true_peak label for TSS training to this CSV path.")
+ # Alternative-polyA/TSS isoform discovery is default ON when an annotation is
+ # given (--genedb). --novel_apa (default off) extends it to novel
+ # (non-reference) transcripts, not only known ones.
+ add_hidden_option("--novel_apa", action="store_true", default=False,
+ help="Developer: extend alternative-end isoform creation to novel transcripts too.")
+
isoquant_version = "3.12.0"
try:
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "VERSION")) as version_f:
diff --git a/isoquant_lib/common.py b/isoquant_lib/common.py
index 46373f6f..6f304c75 100644
--- a/isoquant_lib/common.py
+++ b/isoquant_lib/common.py
@@ -13,6 +13,7 @@
import math
from collections import defaultdict
from enum import Enum
+from typing import List, Tuple
logger = logging.getLogger('IsoQuant')
@@ -882,7 +883,9 @@ def get_strand(introns, reference_region, ref_region_start=1):
# binary search of a coordinate in ordered non-overlapping intervals
-def interval_bin_search(ordered_intervals, pos):
+def interval_bin_search(ordered_intervals: List[Tuple[int, int]], pos: int) -> int:
+ if not ordered_intervals:
+ return -1
if pos > ordered_intervals[-1][1] or pos < ordered_intervals[0][0]:
return -1
@@ -901,7 +904,9 @@ def interval_bin_search(ordered_intervals, pos):
return ind
-def interval_bin_search_rev(ordered_intervals, pos):
+def interval_bin_search_rev(ordered_intervals: List[Tuple[int, int]], pos: int) -> int:
+ if not ordered_intervals:
+ return -1
if pos > ordered_intervals[-1][1] or pos < ordered_intervals[0][0]:
return -1
diff --git a/isoquant_lib/graph_based_model_construction.py b/isoquant_lib/graph_based_model_construction.py
index e7030b2b..7a20921e 100644
--- a/isoquant_lib/graph_based_model_construction.py
+++ b/isoquant_lib/graph_based_model_construction.py
@@ -9,6 +9,7 @@
from collections import defaultdict
from functools import cmp_to_key
from enum import unique, Enum
+from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from .common import (
TranscriptNaming,
@@ -33,6 +34,10 @@
from .long_read_assigner import LongReadAssigner
from .long_read_profiles import CombinedProfileConstructor
from .polya_finder import PolyAInfo
+from .terminal_peaks import detect_peaks, get_polya_model, get_tss_model
+
+if TYPE_CHECKING:
+ from xgboost import XGBClassifier
logger = logging.getLogger('IsoQuant')
@@ -53,17 +58,39 @@ class GraphBasedModelConstructor:
def __init__(self, gene_info, chr_record, args, transcript_counter, gene_counter, id_distributor,
grouping_strategy_names=None,
use_technical_replicas=False,
- string_pools=None):
+ string_pools=None,
+ polya_predictions=None,
+ tss_predictions=None):
self.gene_info = gene_info
self.chr_record = chr_record
self.args = args
self.id_distributor = id_distributor
self.string_pools = string_pools
+ # Predicted polyA / TSS sites for this gene (genomic positions), reused
+ # from the per-gene terminal-position counters to refine intron-graph
+ # terminal vertices.
+ self.polya_predictions = polya_predictions
+ self.tss_predictions = tss_predictions
self.grouping_strategy_names = grouping_strategy_names if grouping_strategy_names else []
self.use_technical_replicas = use_technical_replicas
# Find file_name group index for technical replicas check
self.file_name_group_idx = self.grouping_strategy_names.index("file_name") if "file_name" in self.grouping_strategy_names else -1
+ # Constructed (novel) model ends are always refined to the trained
+ # polyA/TSS peaks (see correct_novel_transcript_ends). Runs with or
+ # without an annotation (needs only the shipped detector + the model's
+ # own reads); never creates or drops models.
+ # Alternative-end NIC *creation* (graph-level reclassification +
+ # post-pass known->NIC) requires an annotation to refine against, so it
+ # is gated on --genedb. Keeps the with-genedb behaviour unchanged.
+ self.create_nics = bool(getattr(args, 'genedb', None))
+ # TSS model is used (for polishing and for NIC creation) only with
+ # full-length evidence; read starts are unreliable otherwise.
+ self.use_tss_model = bool(getattr(args, 'fl_data', False))
+ # Opt-in (--novel_apa, default off): also create alternative-end isoforms
+ # for novel (non-reference) transcripts; works with or without --genedb.
+ self.novel_apa = bool(getattr(args, 'novel_apa', False))
+
self.strand_detector = StrandDetector(self.chr_record)
self.intron_genes = defaultdict(set)
self.set_gene_properties()
@@ -125,7 +152,9 @@ def select_reference_gene(self, transcript_introns, transcript_range, transcript
return None
def process(self, read_assignment_storage):
- self.intron_graph = IntronGraph(self.args, self.gene_info, read_assignment_storage)
+ self.intron_graph = IntronGraph(self.args, self.gene_info, read_assignment_storage,
+ polya_predictions=self.polya_predictions,
+ tss_predictions=self.tss_predictions)
self.path_processor = IntronPathProcessor(self.args, self.intron_graph)
self.path_storage = IntronPathStorage(self.args, self.path_processor)
self.path_storage.fill(read_assignment_storage)
@@ -366,6 +395,9 @@ def filter_transcripts(self):
filtered_storage.append(model)
self.transcript_model_storage = filtered_storage
+ if self.create_nics or self.novel_apa:
+ self._add_known_alternative_end_models()
+ self._drop_duplicate_alt_end_models()
def mapping_quality(self, transcript_id):
mapq = 0
@@ -420,6 +452,50 @@ def save_assigned_read(self, read_assignment, transcript_id):
self.internal_counter[transcript_id] += 1
self.read_assignment_counts[read_id] += 1
+ def _reference_isoform_for_path(self, assignment: ReadAssignment, path: tuple, intron_path: tuple,
+ transcript_range: Tuple[int, int]) -> Optional[str]:
+ # Reference isoform whose intron chain matches this path. It is reported
+ # as the annotated known UNLESS a *detected polyA* terminal vertex
+ # (VERTEX_polya / VERTEX_polyt) disagrees with the annotated end by more
+ # than apa_delta -> then the caller emits a novel-in-catalog isoform with
+ # the refined ends. A bare read_end / read_start (no polyA evidence, e.g.
+ # a degraded ONT terminus) never triggers reclassification, so known
+ # transcripts are not lost to noise.
+ if is_matching_assignment(assignment):
+ matched_reference_id = assignment.isoform_matches[0].assigned_transcript
+ elif intron_path in self.known_isoforms_in_graph:
+ matched_reference_id = self.known_isoforms_in_graph[intron_path]
+ else:
+ return None
+
+ if matched_reference_id is None or \
+ matched_reference_id not in self.gene_info.all_isoforms_exons:
+ return None
+
+ ref_exons = self.gene_info.all_isoforms_exons[matched_reference_id]
+ left_diff = abs(transcript_range[0] - ref_exons[0][0]) > self.args.apa_delta
+ right_diff = abs(transcript_range[1] - ref_exons[-1][1]) > self.args.apa_delta
+
+ # PolyA side: a detected polyA vertex (VERTEX_polyt on the genomic-left
+ # 3' end of a '-' transcript, VERTEX_polya on the genomic-right 3' end of
+ # a '+' transcript) that disagrees with the annotation -> alternative
+ # polyA NIC. Bare read termini never trigger here (degraded-end safety).
+ if path[0][0] == VERTEX_polyt and left_diff:
+ return None
+ if path[-1][0] == VERTEX_polya and right_diff:
+ return None
+
+ # TSS side: only with full-length evidence (--fl_data). The 5' end is the
+ # genomic-left read_start for '+' and the genomic-right read_end for '-'.
+ if self.use_tss_model:
+ strand = self.gene_info.isoform_strands.get(matched_reference_id, '.')
+ if strand == '+' and path[0][0] == VERTEX_read_start and left_diff:
+ return None
+ if strand == '-' and path[-1][0] == VERTEX_read_end and right_diff:
+ return None
+
+ return matched_reference_id
+
def construct_fl_isoforms(self):
# a minor trick to compare tuples of pairs, whose starting and terminating elements have different type
logger.debug("Total FL paths %d" % len(self.path_storage.fl_paths))
@@ -449,7 +525,13 @@ def construct_fl_isoforms(self):
# logger.debug("uuu Checking novel transcript %s: %s; assignment type %s" %
# (new_transcript_id, str(novel_exons), str(assignment.assignment_type)))
- if is_matching_assignment(assignment):
+ if self.create_nics:
+ # use the path's goal-1-refined terminal vertices: keep the
+ # reference only when both ends agree with the annotation within
+ # apa_delta, otherwise fall through to the novel branch below to
+ # emit an alternative-end NIC built from the refined exons
+ reference_isoform = self._reference_isoform_for_path(assignment, path, intron_path, transcript_range)
+ elif is_matching_assignment(assignment):
reference_isoform = assignment.isoform_matches[0].assigned_transcript
# logger.debug("uuu Substituting with known isoform %s" % reference_isoform)
elif intron_path in self.known_isoforms_in_graph:
@@ -796,49 +878,292 @@ def assign_reads_to_models(self, read_assignments):
else:
self.read_assignment_counts[read_id] = 0
- def correct_novel_transcript_ends(self, transcript_model, assigned_reads):
+ def correct_novel_transcript_ends(self, transcript_model: TranscriptModel,
+ assigned_reads: List[ReadAssignment]) -> None:
+ # Part 1: detector-based per-transcript end refinement. Build read-terminus
+ # histograms from this model's own reads and snap each unsupported terminal
+ # exon end to the dominant confident polyA/TSS peak (XGBoost detect_peaks);
+ # both ends are refined in one pass over the same reads so the chosen
+ # 5'/3' stay concordant. polyA always; TSS only when use_tss_model
+ # (--fl_data). No confident peak / no model -> positional fallback, then
+ # leave as is. Never creates or drops a model.
logger.debug("Verifying ends for transcript %s" % transcript_model.transcript_id)
- transcript_end = transcript_model.exon_blocks[-1][1]
transcript_start = transcript_model.exon_blocks[0][0]
- start_supported = False
- read_starts = set()
- end_supported = False
- read_ends = defaultdict(int)
+ transcript_end = transcript_model.exon_blocks[-1][1]
+ first_exon_right = transcript_model.exon_blocks[0][1]
+ last_exon_left = transcript_model.exon_blocks[-1][0]
+ strand = transcript_model.strand
+
+ start_hist, end_hist = self._terminal_histograms(transcript_model, assigned_reads)
+ start_supported = any(abs(p - transcript_start) <= self.args.apa_delta for p in start_hist)
+ end_supported = any(abs(p - transcript_end) <= self.args.apa_delta for p in end_hist)
- for assignment in assigned_reads:
- read_exons = assignment.corrected_exons
- if abs(read_exons[0][0] - transcript_start) <= self.args.apa_delta:
- start_supported = True
- if not start_supported and read_exons[0][0] < transcript_model.exon_blocks[0][1]:
- read_starts.add(read_exons[0][0])
- if abs(read_exons[-1][1] - transcript_end) <= self.args.apa_delta:
- end_supported = True
- if not end_supported and read_exons[-1][1] > transcript_model.exon_blocks[-1][0]:
- read_ends[read_exons[-1][1]] += 1
-
- new_transcript_start = None
if not start_supported:
- read_starts = sorted(read_starts)
- for read_start in read_starts:
- if read_start > transcript_start:
- new_transcript_start = read_start
- break
- if new_transcript_start and new_transcript_start < transcript_model.exon_blocks[0][1]:
- logger.debug("Changed start for transcript %s: from %d to %d" %
- (transcript_model.transcript_id, transcript_model.exon_blocks[0][0], new_transcript_start))
- transcript_model.exon_blocks[0] = (new_transcript_start, transcript_model.exon_blocks[0][1])
+ model = self._terminal_model(strand, left=True)
+ if model is not None:
+ new_start = self._peak_boundary(start_hist, model, lambda pos: pos < first_exon_right)
+ else:
+ new_start = self._closest_inward(sorted(start_hist.keys()), transcript_start, greater=True)
+ if new_start is not None and new_start != transcript_start and new_start < first_exon_right:
+ logger.debug("Changed start for transcript %s: from %d to %d" %
+ (transcript_model.transcript_id, transcript_start, new_start))
+ transcript_model.exon_blocks[0] = (new_start, first_exon_right)
- new_transcript_end = None
if not end_supported:
- read_ends = sorted(read_ends, reverse=True)
- for read_end in read_ends:
- if read_end < transcript_end:
- new_transcript_end = read_end
- break
- if new_transcript_end and new_transcript_end > transcript_model.exon_blocks[-1][0]:
- logger.debug("Changed end for transcript %s: from %d to %d" %
- (transcript_model.transcript_id, transcript_model.exon_blocks[-1][1], new_transcript_end))
- transcript_model.exon_blocks[-1] = (transcript_model.exon_blocks[-1][0], new_transcript_end)
+ model = self._terminal_model(strand, left=False)
+ if model is not None:
+ new_end = self._peak_boundary(end_hist, model, lambda pos: pos > last_exon_left)
+ else:
+ new_end = self._closest_inward(sorted(end_hist.keys(), reverse=True), transcript_end, greater=False)
+ if new_end is not None and new_end != transcript_end and new_end > last_exon_left:
+ logger.debug("Changed end for transcript %s: from %d to %d" %
+ (transcript_model.transcript_id, transcript_end, new_end))
+ transcript_model.exon_blocks[-1] = (last_exon_left, new_end)
+
+ def _terminal_model(self, strand: str, left: bool) -> Optional["XGBClassifier"]:
+ # Which trained model applies to a genomic-side boundary. For '+' the
+ # right end is the polyA site and the left is the TSS; for '-' reversed.
+ # Returns None (-> positional fallback) for unstranded, or a side whose
+ # model is disabled: polyA needs the flag, TSS additionally needs fl_data.
+ if strand == '+':
+ is_polya_side = not left
+ elif strand == '-':
+ is_polya_side = left
+ else:
+ return None
+ if is_polya_side:
+ # polyA end refinement is always on (polishing needs no annotation).
+ return get_polya_model()
+ return get_tss_model() if self.use_tss_model else None
+
+ @staticmethod
+ def _peak_boundary(histogram: Dict[int, int], model: "XGBClassifier",
+ valid: Callable[[int], bool]) -> Optional[int]:
+ # Best-supported detected peak satisfying the terminal-exon clamp.
+ if not histogram:
+ return None
+ peaks = [p for p in detect_peaks(histogram, model) if valid(p.position)]
+ if not peaks:
+ return None
+ return max(peaks, key=lambda p: p.count).position
+
+ @staticmethod
+ def _closest_inward(sorted_positions: List[int], boundary: int, greater: bool) -> Optional[int]:
+ # Positional fallback: nearest read terminus on the inward side of the
+ # current boundary (matches the original end-correction behaviour).
+ for pos in sorted_positions:
+ if greater and pos > boundary:
+ return pos
+ if not greater and pos < boundary:
+ return pos
+ return None
+
+ @staticmethod
+ def _confirmed_polya_pos(assignment: ReadAssignment, strand: str) -> Optional[int]:
+ # Detected polyA cleavage site of a read, matching the trained-model
+ # domain (PolyACounter): only polyA-confirmed reads, at external_polya_pos
+ # ('+') / external_polyt_pos ('-'). None otherwise.
+ info = getattr(assignment, 'polya_info', None)
+ if not getattr(assignment, 'polyA_found', False) or info is None:
+ return None
+ if strand == '+' and info.external_polya_pos != -1:
+ return info.external_polya_pos
+ if strand == '-' and info.external_polyt_pos != -1:
+ return info.external_polyt_pos
+ return None
+
+ def _terminal_histograms(self, source_model: TranscriptModel,
+ assigned_reads: List[ReadAssignment]) -> Tuple[Dict[int, int], Dict[int, int]]:
+ # Build genomic-left (start) and genomic-right (end) read-terminus
+ # histograms matching what each trained model was fit on:
+ # - 3' polyA side: polyA-CONFIRMED reads at the detected cleavage site
+ # (so a low overall polyA rate naturally raises the support bar);
+ # - 5' TSS side: all stranded reads' alignment ends.
+ # Unstranded -> alignment ends on both sides (no polyA orientation).
+ strand = source_model.strand
+ first_exon_right = source_model.exon_blocks[0][1]
+ last_exon_left = source_model.exon_blocks[-1][0]
+ start_hist = defaultdict(int)
+ end_hist = defaultdict(int)
+ for a in assigned_reads:
+ ex = a.corrected_exons
+ if strand == '+':
+ if ex[0][0] < first_exon_right: # 5' TSS (all reads)
+ start_hist[ex[0][0]] += 1
+ pos = self._confirmed_polya_pos(a, '+') # 3' polyA (confirmed)
+ if pos is not None and pos > last_exon_left:
+ end_hist[pos] += 1
+ elif strand == '-':
+ pos = self._confirmed_polya_pos(a, '-') # 3' polyA (confirmed)
+ if pos is not None and pos < first_exon_right:
+ start_hist[pos] += 1
+ if ex[-1][1] > last_exon_left: # 5' TSS (all reads)
+ end_hist[ex[-1][1]] += 1
+ else:
+ if ex[0][0] < first_exon_right:
+ start_hist[ex[0][0]] += 1
+ if ex[-1][1] > last_exon_left:
+ end_hist[ex[-1][1]] += 1
+ return start_hist, end_hist
+
+ @staticmethod
+ def _intron_chain_key(model: TranscriptModel) -> Tuple[Tuple[int, int], ...]:
+ # Terminal-end-independent identity of a transcript: the tuple of its
+ # introns derived from exon blocks. Monoexon -> empty tuple.
+ eb = model.exon_blocks
+ return tuple((eb[i][1], eb[i + 1][0]) for i in range(len(eb) - 1))
+
+ def _add_known_alternative_end_models(self) -> None:
+ # Part 2: for each known (reference) model, peak-call its own assigned
+ # reads and emit a NIC for every confident alternative polyA/TSS end,
+ # keeping the known (union). Deduplicated against everything already in
+ # storage (notably the graph-level NICs) by intron chain + both ends
+ # within apa_delta, so the same alternative end is never emitted twice.
+ existing_pairs = defaultdict(list)
+ # Seed with every reference isoform too: an alternative-end peak that
+ # lands on another annotated isoform of the same chain is that known, not
+ # a novel end, so it must be suppressed even if that reference was not
+ # emitted as a model in this locus.
+ for ref_exons in self.gene_info.all_isoforms_exons.values():
+ ck = tuple((ref_exons[i][1], ref_exons[i + 1][0]) for i in range(len(ref_exons) - 1))
+ existing_pairs[ck].append((ref_exons[0][0], ref_exons[-1][1]))
+ for m in self.transcript_model_storage:
+ existing_pairs[self._intron_chain_key(m)].append(
+ (m.exon_blocks[0][0], m.exon_blocks[-1][1]))
+
+ new_models = []
+ for model in self.transcript_model_storage:
+ if model.transcript_type == TranscriptModelType.known:
+ pass
+ elif not self.novel_apa:
+ # Part 2 default: known transcripts only. Part 3 (--novel_apa)
+ # also spins off alternative-end siblings for novel chains.
+ continue
+ for nic in self.derive_alternative_end_models(
+ model, self.transcript_read_ids[model.transcript_id]):
+ ck = self._intron_chain_key(nic)
+ ns, ne = nic.exon_blocks[0][0], nic.exon_blocks[-1][1]
+ if any(abs(ns - s) <= self.args.apa_delta and abs(ne - e) <= self.args.apa_delta
+ for (s, e) in existing_pairs[ck]):
+ continue
+ existing_pairs[ck].append((ns, ne))
+ new_models.append(nic)
+ if new_models:
+ logger.debug("Added %d known alternative-end NICs" % len(new_models))
+ self.transcript_model_storage.extend(new_models)
+
+ def _drop_duplicate_alt_end_models(self) -> None:
+ # Final dedup over the whole storage (catches graph-level NICs too): drop
+ # any non-known model that is structurally identical (intron chain + both
+ # ends within apa_delta) to a reference isoform -> it is that annotated
+ # transcript, not a novel end; and collapse non-known models that
+ # duplicate an already-kept one.
+ ref_pairs = defaultdict(list)
+ for ref_exons in self.gene_info.all_isoforms_exons.values():
+ ck = tuple((ref_exons[i][1], ref_exons[i + 1][0]) for i in range(len(ref_exons) - 1))
+ ref_pairs[ck].append((ref_exons[0][0], ref_exons[-1][1]))
+
+ d = self.args.apa_delta
+ kept = []
+ kept_pairs = defaultdict(list)
+ dropped = 0
+ for m in self.transcript_model_storage:
+ if m.transcript_type == TranscriptModelType.known:
+ kept.append(m)
+ continue
+ ck = self._intron_chain_key(m)
+ s, e = m.exon_blocks[0][0], m.exon_blocks[-1][1]
+ if any(abs(s - rs) <= d and abs(e - re) <= d for rs, re in ref_pairs[ck]) or \
+ any(abs(s - ks) <= d and abs(e - ke) <= d for ks, ke in kept_pairs[ck]):
+ dropped += 1
+ # free the model's reads for reassignment (graph-level NICs are
+ # already counted; post-pass NICs are not yet in the counters)
+ if m.transcript_id in self.internal_counter:
+ self.delete_from_storage(m.transcript_id)
+ continue
+ kept.append(m)
+ kept_pairs[ck].append((s, e))
+ if dropped:
+ logger.debug("Dropped %d duplicate / reference-matching alt-end models" % dropped)
+ self.transcript_model_storage = kept
+
+ def derive_alternative_end_models(self, source_model: TranscriptModel,
+ assigned_reads: List[ReadAssignment]) -> List[TranscriptModel]:
+ # Confident alternative polyA/TSS ends for a transcript, from its own
+ # reads (per-transcript, so 5'/3' stay concordant). Returns a list of new
+ # alternative-end TranscriptModel objects (one per alternative end, each
+ # changing a single terminal); the source model is left untouched. polyA
+ # always; TSS only when use_tss_model (--fl_data). A known source yields
+ # novel-in-catalog siblings; a novel source keeps its own type.
+ if not assigned_reads:
+ return []
+ exon_blocks = source_model.exon_blocks
+ strand = source_model.strand
+ annotated_start = exon_blocks[0][0]
+ annotated_end = exon_blocks[-1][1]
+ first_exon_right = exon_blocks[0][1]
+ last_exon_left = exon_blocks[-1][0]
+ sibling_type = (TranscriptModelType.novel_in_catalog
+ if source_model.transcript_type == TranscriptModelType.known
+ else source_model.transcript_type)
+
+ start_hist, end_hist = self._terminal_histograms(source_model, assigned_reads)
+
+ new_models = []
+ start_model = self._terminal_model(strand, left=True)
+ if start_model is not None:
+ for pos in self._alternative_end_positions(start_hist, start_model, annotated_start,
+ lambda p: p < first_exon_right):
+ new_models.append(self._nic_model_with_boundary(source_model, sibling_type, start=pos))
+ end_model = self._terminal_model(strand, left=False)
+ if end_model is not None:
+ for pos in self._alternative_end_positions(end_hist, end_model, annotated_end,
+ lambda p: p > last_exon_left):
+ new_models.append(self._nic_model_with_boundary(source_model, sibling_type, end=pos))
+ return new_models
+
+ def _alternative_end_positions(self, histogram: Dict[int, int], model: "XGBClassifier",
+ annotated_pos: int, clamp: Callable[[int], bool]) -> List[int]:
+ # Detected peaks representing a genuine alternative end: inside the
+ # terminal exon, > apa_delta from the annotated end, and clearing BOTH an
+ # absolute (min_novel_count) and a RELATIVE support cutoff. The relative
+ # cutoff (terminal_position_rel x the transcript's dominant terminal peak,
+ # reusing the intron-graph terminal threshold) rejects minor secondary
+ # peaks that are not real co-expressed isoforms and self-normalizes per
+ # transcript, so it does not depend on absolute depth or fit one isoform.
+ if not histogram:
+ return []
+ peaks = [p for p in detect_peaks(histogram, model) if clamp(p.position)]
+ if not peaks:
+ return []
+ dominant = max(p.count for p in peaks)
+ cutoff = max(self.args.min_novel_count, self.args.terminal_position_rel * dominant)
+ return [p.position for p in peaks
+ if abs(p.position - annotated_pos) > self.args.apa_delta and p.count >= cutoff]
+
+ def _nic_model_with_boundary(self, source_model: TranscriptModel, transcript_type: TranscriptModelType,
+ start: Optional[int] = None, end: Optional[int] = None) -> TranscriptModel:
+ # Build an alternative-end model from a source one by replacing a single
+ # terminal coordinate; the intron chain (internal exons) is unchanged.
+ # transcript_type is novel_in_catalog for a known source, or the source's
+ # own type for a novel source.
+ exon_blocks = [tuple(e) for e in source_model.exon_blocks]
+ if start is not None:
+ exon_blocks[0] = (start, exon_blocks[0][1])
+ if end is not None:
+ exon_blocks[-1] = (exon_blocks[-1][0], end)
+ suffix = (TranscriptNaming.nic_transcript_suffix
+ if transcript_type == TranscriptModelType.novel_in_catalog
+ else TranscriptNaming.nnic_transcript_suffix)
+ new_transcript_id = TranscriptNaming.transcript_prefix + str(self.get_transcript_id())
+ new_model = TranscriptModel(
+ self.gene_info.chr_id, source_model.strand,
+ new_transcript_id + ".%s" % self.gene_info.chr_id + suffix,
+ source_model.gene_id, exon_blocks, transcript_type)
+ new_model.intron_path = source_model.intron_path
+ logger.debug("Adding alternative-end model %s from %s (start=%s, end=%s)" %
+ (new_model.transcript_id, source_model.transcript_id, str(start), str(end)))
+ return new_model
class IntronPathStorage:
diff --git a/isoquant_lib/intron_graph.py b/isoquant_lib/intron_graph.py
index aea142cd..ccf7d969 100644
--- a/isoquant_lib/intron_graph.py
+++ b/isoquant_lib/intron_graph.py
@@ -8,6 +8,7 @@
import logging
import queue
from collections import defaultdict
+from typing import Dict, List, Optional
from .common import find_closest, overlaps
@@ -135,11 +136,21 @@ def simplify_correction_map(self):
class IntronGraph:
- def __init__(self, params, gene_info, read_assignments):
+ def __init__(self, params, gene_info, read_assignments,
+ polya_predictions: Optional[List[int]] = None,
+ tss_predictions: Optional[List[int]] = None):
self.params = params
self.gene_info = gene_info
self.read_assignments = read_assignments
+ # Predicted polyA / TSS sites for this gene (genomic positions), reused
+ # from the per-gene terminal-position counters. Clustering stays the
+ # terminal-vertex SOURCE; these only refine vertex coordinates (snap to
+ # the nearest predicted site). None when unavailable (no annotation for
+ # polyA, no --fl_data for TSS) -> no refinement, pure clustering.
+ self.polya_predictions = polya_predictions
+ self.tss_predictions = tss_predictions
+
self.incoming_edges = defaultdict(set)
self.outgoing_edges = defaultdict(set)
self.intron_collector = IntronCollector(gene_info, params.delta)
@@ -411,63 +422,84 @@ def collapse_vertex_set(self, vertex_set):
def attach_terminal_positions(self):
# logger.debug("Setting terminal positions paths for %s" % self.gene_info.gene_db_list[0].id)
polya_ends, read_ends, polyt_starts, read_starts = self.collect_terminal_positions()
-
- for intron in sorted(self.intron_collector.clustered_introns):
- self.attach_transcpt_ends(intron, polya_ends, read_ends, read_end=True)
- self.attach_transcpt_ends(intron, polyt_starts, read_starts, read_end=False)
-
- def attach_transcpt_ends(self, intron, polya_confirmed_positions, read_terminal_positions, read_end=True):
- read_ends_cutoff = self.params.terminal_position_abs
- logger.debug(str(intron) + " => " + str(polya_confirmed_positions[intron]))
- clustered_polyas = self.cluster_polya_positions(polya_confirmed_positions[intron], intron, read_end)
- if clustered_polyas:
- read_ends_cutoff = max(read_ends_cutoff, max(clustered_polyas.values()) * self.params.terminal_position_rel)
- extra_end_positions = {}
- furtherst_confirmed_position = max(clustered_polyas.keys()) if read_end else min(clustered_polyas.keys())
- for position, count in read_terminal_positions[intron].items():
- if read_end and position >= furtherst_confirmed_position + self.params.apa_delta:
- extra_end_positions[position] = count
- elif not read_end and position <= furtherst_confirmed_position - self.params.apa_delta :
- extra_end_positions[position] = count
- else:
- extra_end_positions = read_terminal_positions[intron]
-
- if read_end and intron in self.outgoing_edges and len(self.outgoing_edges[intron]) > 0:
- # intron has outgoing edges, hard cut off
- neighboring_cov = max(self.intron_collector.clustered_introns[i] for i in self.outgoing_edges[intron])
- read_ends_cutoff = max(read_ends_cutoff, neighboring_cov * self.params.terminal_internal_position_rel)
- elif not read_end and intron in self.incoming_edges and len(self.incoming_edges[intron]) > 0:
- # intron has incoming edges, hard cut off
- neighboring_cov = max(self.intron_collector.clustered_introns[i] for i in self.incoming_edges[intron])
- read_ends_cutoff = max(read_ends_cutoff, neighboring_cov * self.params.terminal_internal_position_rel)
-
- logger.debug(str(intron) + " +> " + str(extra_end_positions))
-
- terminal_positions = self.cluster_terminal_positions(extra_end_positions,
- read_end=read_end,
- cutoff=read_ends_cutoff)
- logger.debug("POLYAs clustered:")
- logger.debug(clustered_polyas)
- logger.debug("Teminal clustered:")
- logger.debug(terminal_positions)
- if read_end:
- # if intron in self.terminal_known_positions:
- # logger.debug("Annotated terminal positions: " + str(sorted(self.terminal_known_positions[intron])))
- # logger.debug("PolyA terminal positions: " + str(sorted(clustered_polyas.keys())))
- # logger.debug("Simple terminal positions: " + str(sorted(terminal_positions.keys())))
- for pos in clustered_polyas.keys():
- self.outgoing_edges[intron].add((VERTEX_polya, pos))
- for pos in terminal_positions.keys():
- self.outgoing_edges[intron].add((VERTEX_read_end, pos))
- else:
- # if intron in self.starting_known_positions:
- # logger.debug("Annotated terminal positions: " + str(sorted(self.starting_known_positions[intron])))
- # logger.debug("PolyA terminal positions: " + str(sorted(clustered_polyas.keys())))
- # logger.debug("Simple terminal positions: " + str(sorted(terminal_positions.keys())))
- for pos in clustered_polyas.keys():
- self.incoming_edges[intron].add((VERTEX_polyt, pos))
- for pos in terminal_positions.keys():
- self.incoming_edges[intron].add((VERTEX_read_start, pos))
+ introns = sorted(self.intron_collector.clustered_introns)
+ self._attach_side(introns, polya_ends, read_ends, read_end=True)
+ self._attach_side(introns, polyt_starts, read_starts, read_end=False)
+
+ def _attach_side(self, introns: List, polya_confirmed_positions: dict,
+ read_terminal_positions: dict, read_end: bool) -> None:
+ # Clustering stays the vertex SOURCE (recall identical to the ad-hoc
+ # method); the predicted polyA / TSS sites (reused from the per-gene
+ # terminal-position counters) only REFINE vertex coordinates, so we
+ # never emit fewer terminal vertices than clustering alone.
+
+ # Step 1: polyA / polyT confirmed terminal positions per intron.
+ polya_positions = {}
+ for intron in introns:
+ clustered = self.cluster_polya_positions(polya_confirmed_positions[intron], intron, read_end)
+ polya_positions[intron] = self._refine_positions(clustered, self.polya_predictions)
+
+ # Step 2: per-intron read-end cutoff and the extra (non-polyA) positions.
+ cutoffs = {}
+ extra_positions = {}
+ for intron in introns:
+ clustered_polyas = polya_positions[intron]
+ cutoff = self.params.terminal_position_abs
+ if clustered_polyas:
+ cutoff = max(cutoff, max(clustered_polyas.values()) * self.params.terminal_position_rel)
+ extra = {}
+ furthest = max(clustered_polyas.keys()) if read_end else min(clustered_polyas.keys())
+ for position, count in read_terminal_positions[intron].items():
+ if read_end and position >= furthest + self.params.apa_delta:
+ extra[position] = count
+ elif not read_end and position <= furthest - self.params.apa_delta:
+ extra[position] = count
+ else:
+ extra = read_terminal_positions[intron]
+
+ neighbors = self.outgoing_edges if read_end else self.incoming_edges
+ if intron in neighbors and len(neighbors[intron]) > 0:
+ # intron has neighboring edges -> hard cutoff relative to their coverage
+ neighboring_cov = max(self.intron_collector.clustered_introns[i] for i in neighbors[intron])
+ cutoff = max(cutoff, neighboring_cov * self.params.terminal_internal_position_rel)
+
+ cutoffs[intron] = cutoff
+ extra_positions[intron] = extra
+
+ # Step 3: read-end / TSS terminal positions per intron. Refine toward
+ # the prediction set for THIS genomic side: polyA (3') for read ends
+ # (read_end=True), TSS (5') for read starts (read_end=False). Using TSS
+ # for both sides would snap 3' read ends onto 5' sites.
+ terminal_predictions = self.polya_predictions if read_end else self.tss_predictions
+ terminal_positions = {}
+ for intron in introns:
+ clustered = self.cluster_terminal_positions(extra_positions[intron],
+ read_end=read_end, cutoff=cutoffs[intron])
+ terminal_positions[intron] = self._refine_positions(clustered, terminal_predictions)
+
+ # Step 4: attach terminal vertices.
+ polya_vertex = VERTEX_polya if read_end else VERTEX_polyt
+ read_vertex = VERTEX_read_end if read_end else VERTEX_read_start
+ edges = self.outgoing_edges if read_end else self.incoming_edges
+ for intron in introns:
+ for pos in polya_positions[intron].keys():
+ edges[intron].add((polya_vertex, pos))
+ for pos in terminal_positions[intron].keys():
+ edges[intron].add((read_vertex, pos))
+
+ def _refine_positions(self, clustered: Dict[int, int],
+ predicted_positions: Optional[List[int]]) -> Dict[int, int]:
+ # Snap each clustered terminal position to the nearest predicted polyA /
+ # TSS site within apa_delta. Keeps every clustering vertex (recall-safe),
+ # only nudging coordinates toward the predicted site.
+ if not clustered or not predicted_positions:
+ return clustered
+ refined = {}
+ for pos, count in clustered.items():
+ nearest, diff = find_closest(pos, predicted_positions)
+ new_pos = nearest if (nearest is not None and diff <= self.params.apa_delta) else pos
+ refined[new_pos] = refined.get(new_pos, 0) + count
+ return refined
def collect_terminal_positions(self):
polya_ends = defaultdict(lambda: defaultdict(int))
diff --git a/isoquant_lib/long_read_counter.py b/isoquant_lib/long_read_counter.py
index 3c61386d..a9b93156 100644
--- a/isoquant_lib/long_read_counter.py
+++ b/isoquant_lib/long_read_counter.py
@@ -186,6 +186,11 @@ def add_confirmed_features(self, features):
def dump(self):
raise NotImplementedError()
+ def flush(self) -> None:
+ # Per-gene incremental emission hook. No-op for most counters; the
+ # terminal (polyA/TSS) counters override it to predict per gene.
+ pass
+
def finalize(self, args=None):
raise NotImplementedError()
@@ -225,6 +230,10 @@ def dump(self):
for p in self.counters:
p.dump()
+ def flush(self) -> None:
+ for p in self.counters:
+ p.flush()
+
def finalize(self, args=None):
for p in self.counters:
p.finalize(args)
diff --git a/isoquant_lib/parallel_workers.py b/isoquant_lib/parallel_workers.py
index b59a9762..057c64dd 100644
--- a/isoquant_lib/parallel_workers.py
+++ b/isoquant_lib/parallel_workers.py
@@ -275,15 +275,27 @@ def construct_models_in_parallel(sample, chr_id, chr_ids, saves_prefix, args, re
aggregator.global_printer.add_read_info(read_assignment)
aggregator.global_counter.add_read_info(read_assignment)
+ # Per-gene flush of terminal (polyA/TSS) predictions: predict on this
+ # gene's read accumulation, then clear it. Other counters no-op here.
+ aggregator.global_counter.flush()
+
if construct_models:
strategy_names = aggregator.grouping_strategy_names if hasattr(aggregator, 'grouping_strategy_names') else []
+ # Reuse this gene's polyA/TSS predictions (computed just above by
+ # flush()) to refine the intron-graph terminal vertices.
+ polya_counter = getattr(aggregator, 'polya_counter', None)
+ tss_counter = getattr(aggregator, 'tss_counter', None)
+ gene_polya_predictions = polya_counter.last_gene_predictions if polya_counter else None
+ gene_tss_predictions = tss_counter.last_gene_predictions if tss_counter else None
model_constructor = GraphBasedModelConstructor(gene_info, loader.chr_record, args,
aggregator.transcript_model_global_counter,
aggregator.gene_model_global_counter,
transcript_id_distributor,
grouping_strategy_names=strategy_names,
use_technical_replicas=sample.use_technical_replicas,
- string_pools=string_pools)
+ string_pools=string_pools,
+ polya_predictions=gene_polya_predictions,
+ tss_predictions=gene_tss_predictions)
model_constructor.process(assignment_storage)
if args.check_canonical:
io_support.add_canonical_info(model_constructor.transcript_model_storage, gene_info)
diff --git a/isoquant_lib/terminal_counter.py b/isoquant_lib/terminal_counter.py
index 0879ed66..efad998e 100644
--- a/isoquant_lib/terminal_counter.py
+++ b/isoquant_lib/terminal_counter.py
@@ -33,21 +33,18 @@
)
from .long_read_counter import AbstractCounter
from .read_groups import AbstractReadGrouper
+from .terminal_peaks import (
+ ANNOTATION_TOLERANCE,
+ FEATURE_COLUMNS,
+ HISTOGRAM_PAD,
+ PEAK_DISTANCE,
+ PEAK_REL_HEIGHT,
+ POLYA_MODEL_PATH,
+ TSS_MODEL_PATH,
+)
logger = logging.getLogger('IsoQuant')
-_DATA_DIR = Path(__file__).parent / "data"
-POLYA_MODEL_PATH = _DATA_DIR / "model_polya.json"
-TSS_MODEL_PATH = _DATA_DIR / "model_tss.json"
-
-# Peak finding + classification parameters. Match values used to train the
-# shipped XGBoost models -- changing these without retraining will break the
-# feature alignment between fit time and inference time.
-PEAK_DISTANCE = 10
-PEAK_REL_HEIGHT = 0.98
-HISTOGRAM_PAD = 10
-ANNOTATION_TOLERANCE = 10
-
_ACCEPTED_ASSIGNMENT_TYPES = frozenset({
ReadAssignmentType.unique,
ReadAssignmentType.unique_minor_difference,
@@ -55,10 +52,6 @@
ReadAssignmentType.inconsistent_non_intronic,
})
-# XGBoost feature columns (must match training)
-FEATURE_COLUMNS = ['var', 'skew', 'peak_count', 'peak_width', 'entropy',
- 'mean_height', 'peak_heights', 'relative_height']
-
EMPTY_COLUMNS = ['chromosome', 'transcript_id', 'gene_id', 'prediction',
'counts', 'flag']
EMPTY_COLUMNS_GROUPED = EMPTY_COLUMNS + ['counts_byGroup', 'group_id']
@@ -123,6 +116,16 @@ def __init__(self, args, output_prefix: str, model_path: Path,
# transcript_id -> {'chr', 'gene_id', 'data', 'annotated',
# int_group_id -> list[int]}
self.transcripts: dict = {}
+ # Chromosome-wide accumulator for the ungrouped counter: flush() merges
+ # each gene's per-gene buffer (self.transcripts) into this, and dump()
+ # predicts once over the full per-transcript histogram. Keeps the
+ # emitted output independent of loader granularity (a transcript whose
+ # reads span several gene blocks must not be split into partial rows).
+ self._all_transcripts: dict = {}
+ # Genomic positions predicted for the most recently flushed gene, so
+ # transcript discovery can reuse them to refine intron-graph terminal
+ # vertices (set by flush()).
+ self.last_gene_predictions: list = []
@property
def model(self) -> XGBClassifier:
@@ -212,25 +215,71 @@ def _read_group_id(self, read_assignment: ReadAssignment) -> int:
# -- emission -------------------------------------------------------------
+ def flush(self) -> None:
+ """Predict the current gene's transcripts for reuse, then fold them into
+ the chromosome-wide buffer.
+
+ Called once per gene from the worker loop. The per-gene prediction is
+ exposed as :attr:`last_gene_predictions` so transcript discovery can
+ refine intron-graph terminal vertices with this gene's polyA/TSS sites;
+ the accumulated reads are then merged into :attr:`_all_transcripts` so
+ the emitted output is still computed once over each transcript's full
+ histogram in :meth:`dump`. Grouped and training counters keep
+ accumulating across the chromosome and emit only in :meth:`dump`."""
+ if not self.ignore_read_groups or self._collecting_training:
+ return
+ rows = self._predict_rows()
+ self.last_gene_predictions = rows['prediction'].tolist() if rows is not None else []
+ self._merge_into_all()
+ self.transcripts = {}
+
+ def _merge_into_all(self) -> None:
+ """Fold the current per-gene buffer into the chromosome-wide one,
+ concatenating the read-position lists of transcripts seen before."""
+ for transcript_id, entry in self.transcripts.items():
+ existing = self._all_transcripts.get(transcript_id)
+ if existing is None:
+ self._all_transcripts[transcript_id] = entry
+ else:
+ existing['data'].extend(entry['data'])
+
def dump(self) -> None:
- if not self.transcripts:
- self._write_empty()
+ if self._collecting_training:
+ self._dump_training()
return
- df = self._build_peak_dataframe()
- zero_peaks, peaks = self._split_zero_and_real_peaks(df)
- if not peaks.empty:
- peaks = self._rank_and_explode(peaks)
+ if not self.ignore_read_groups:
+ # Grouped output needs the full per-group positions, so it is
+ # computed once here with self.transcripts kept intact.
+ result = self._predict_rows()
+ if result is None:
+ self._write_empty()
+ else:
+ self._write_grouped(result)
+ return
- # Developer training-data collection: dump features + true_peak label
- # for every detected peak, skip the XGBoost filter, emit an empty
- # prediction TSV so the rest of the pipeline still sees the file.
- if self._collecting_training:
- self._dump_training_features(zero_peaks, peaks)
+ # Ungrouped: flush the last gene into the chromosome-wide buffer, then
+ # predict once over each transcript's full histogram (matching the
+ # grouped path and master; no per-gene splitting).
+ self.flush()
+ self.transcripts = self._all_transcripts
+ result = self._predict_rows()
+ if result is None:
self._write_empty()
- return
+ else:
+ self._write_ungrouped(result)
+ def _predict_rows(self) -> Optional[pd.DataFrame]:
+ """Peak detection + XGBoost filter over the currently accumulated
+ transcripts. Returns the per-peak prediction rows (genomic position,
+ counts, Known/Novel flag) as a DataFrame, or None if there are none."""
+ if not self.transcripts:
+ return None
+
+ df = self._build_peak_dataframe()
+ zero_peaks, peaks = self._split_zero_and_real_peaks(df)
if not peaks.empty:
+ peaks = self._rank_and_explode(peaks)
peaks[FEATURE_COLUMNS] = peaks[FEATURE_COLUMNS].astype(float, errors='ignore')
predicted = self.model.predict(peaks[FEATURE_COLUMNS].astype(float))
peaks = peaks[predicted == 1].copy()
@@ -238,8 +287,7 @@ def dump(self) -> None:
frames = [f for f in (zero_peaks, peaks) if not f.empty]
if not frames:
- self._write_empty()
- return
+ return None
result = (pd.concat(frames, axis=0, ignore_index=True, sort=False)
if len(frames) > 1 else frames[0].reset_index(drop=True))
@@ -250,11 +298,20 @@ def dump(self) -> None:
(result['prediction'] - result['annotated']).abs()
.gt(ANNOTATION_TOLERANCE).map({True: 'Novel', False: 'Known'}))
result['counts'] = result.apply(self._counts_for_peak, axis=1)
+ return result
- if self.ignore_read_groups:
- self._write_ungrouped(result)
- else:
- self._write_grouped(result)
+ def _dump_training(self) -> None:
+ """Training-data collection: emit per-peak features + true_peak label
+ (whole chromosome) and a header-only prediction TSV."""
+ if not self.transcripts:
+ self._write_empty()
+ return
+ df = self._build_peak_dataframe()
+ zero_peaks, peaks = self._split_zero_and_real_peaks(df)
+ if not peaks.empty:
+ peaks = self._rank_and_explode(peaks)
+ self._dump_training_features(zero_peaks, peaks)
+ self._write_empty()
def _build_peak_dataframe(self) -> pd.DataFrame:
"""One row per transcript with histogram, summary stats, and raw peaks."""
diff --git a/isoquant_lib/terminal_peaks.py b/isoquant_lib/terminal_peaks.py
new file mode 100644
index 00000000..4125b523
--- /dev/null
+++ b/isoquant_lib/terminal_peaks.py
@@ -0,0 +1,221 @@
+############################################################################
+# Copyright (c) 2022-2026 University of Helsinki
+# # All Rights Reserved
+# See file LICENSE for details.
+############################################################################
+
+"""Shared polyA / TSS peak detection.
+
+Single source of truth for the peak-detection tunables and a reusable
+detector that turns a per-feature histogram of terminal read positions into a
+set of accepted peaks. Both the side-output counters (`terminal_counter.py`)
+and transcript discovery (`intron_graph.py`,
+`graph_based_model_construction.py`) consume this module so the feature
+computation cannot drift between training-time and the two inference sites.
+
+The detection mirrors `TerminalCounter`'s pipeline for a single histogram:
+build a zero-padded base-1 histogram, compute summary features, run
+`scipy.signal.find_peaks`, fall back to the distribution mode when no peak is
+found, rank peaks by height, and keep the ones the trained XGBoost classifier
+accepts. Coordinates returned are genomic.
+"""
+
+import logging
+from pathlib import Path
+from typing import Dict, List, NamedTuple, Optional
+
+import numpy as np
+import pandas as pd
+from scipy import stats
+from scipy.signal import find_peaks, peak_widths
+from xgboost import XGBClassifier
+
+logger = logging.getLogger('IsoQuant')
+
+_DATA_DIR = Path(__file__).parent / "data"
+POLYA_MODEL_PATH = _DATA_DIR / "model_polya.json"
+TSS_MODEL_PATH = _DATA_DIR / "model_tss.json"
+
+# Peak finding + classification parameters. Match values used to train the
+# shipped XGBoost models -- changing these without retraining will break the
+# feature alignment between fit time and inference time.
+PEAK_DISTANCE = 10
+PEAK_REL_HEIGHT = 0.98
+HISTOGRAM_PAD = 10
+ANNOTATION_TOLERANCE = 10
+
+# XGBoost feature columns (must match training).
+FEATURE_COLUMNS = ['var', 'skew', 'peak_count', 'peak_width', 'entropy',
+ 'mean_height', 'peak_heights', 'relative_height']
+
+# Process-level model cache. The models are loaded lazily on first use so the
+# XGBoost / OpenMP runtime is only initialized inside the fork() workers that
+# actually call detect_peaks -- never in the parent, where the inherited
+# OpenMP state would deadlock the pool. One instance per worker process.
+_MODEL_CACHE: Dict[str, XGBClassifier] = {}
+
+
+def _get_model(model_path: Path) -> XGBClassifier:
+ key = str(model_path)
+ model = _MODEL_CACHE.get(key)
+ if model is None:
+ model = XGBClassifier()
+ model.load_model(key)
+ _MODEL_CACHE[key] = model
+ return model
+
+
+def get_polya_model() -> XGBClassifier:
+ return _get_model(POLYA_MODEL_PATH)
+
+
+def get_tss_model() -> XGBClassifier:
+ return _get_model(TSS_MODEL_PATH)
+
+
+class Peak(NamedTuple):
+ """An accepted terminal-position peak in genomic coordinates."""
+ position: int # predicted site
+ count: int # reads supporting the peak window [left, right]
+ left: int # genomic left bound of the peak window
+ right: int # genomic right bound of the peak window
+
+
+def detect_peaks(position_counts: Dict[int, int],
+ model: XGBClassifier) -> List[Peak]:
+ """Detect accepted peaks in a single ``{genomic_position: count}`` histogram."""
+ return detect_peaks_batch([position_counts], model)[0]
+
+
+def detect_peaks_batch(histograms: List[Optional[Dict[int, int]]],
+ model: XGBClassifier) -> List[List[Peak]]:
+ """Detect peaks for several histograms with a single batched ``predict``.
+
+ Returns a list parallel to ``histograms``; each element is the list of
+ accepted :class:`Peak` for that histogram (empty when the histogram is
+ empty or every peak is rejected).
+ """
+ results: List[List[Peak]] = [[] for _ in histograms]
+ records: Dict[int, dict] = {}
+ feature_rows: List[list] = []
+ row_owner: List[tuple] = [] # (histogram_index, peak_dict)
+
+ for i, counts in enumerate(histograms):
+ if not counts:
+ continue
+ record = _histogram_record(counts)
+ records[i] = record
+
+ if record['peak_count'] == 0:
+ # No detected peak: fall back to the distribution mode (always
+ # accepted, never scored by the classifier).
+ mode = record['mode']
+ position = record['start'] + mode
+ count = _window_count(record['histogram'], mode, mode)
+ results[i].append(Peak(position, count, position, position))
+ continue
+
+ for peak in _rank_peaks(record):
+ feature_rows.append([record['var'], record['skew'],
+ record['peak_count'], peak['peak_width'],
+ record['entropy'], record['mean_height'],
+ peak['peak_heights'], peak['relative_height']])
+ row_owner.append((i, peak))
+
+ if feature_rows:
+ features = pd.DataFrame(feature_rows, columns=FEATURE_COLUMNS).astype(float)
+ predicted = model.predict(features)
+ for (i, peak), keep in zip(row_owner, predicted):
+ if int(keep) != 1:
+ continue
+ record = records[i]
+ start = record['start']
+ position = start + int(peak['peak_location'])
+ count = _window_count(record['histogram'],
+ peak['peak_left'], peak['peak_right'])
+ results[i].append(Peak(position, count,
+ start + int(peak['peak_left']),
+ start + int(peak['peak_right'])))
+ return results
+
+
+def _histogram_record(position_counts: Dict[int, int]) -> dict:
+ """Build the per-histogram features and raw peaks (mirrors
+ ``TerminalCounter._build_peak_dataframe`` for one transcript)."""
+ data: List[int] = []
+ for position, count in position_counts.items():
+ if count > 0:
+ data.extend([int(position)] * int(count))
+
+ data_min = min(data)
+ data_max = max(data)
+ hist_counts = np.histogram(
+ data, bins=(data_max - data_min) + 1, range=(data_min, data_max))[0]
+ padded = [0] * HISTOGRAM_PAD + list(hist_counts) + [0] * HISTOGRAM_PAD
+
+ var = float(np.var(data))
+ record = {
+ 'start': data_min,
+ 'mode': int(stats.mode(data, keepdims=False).mode) - data_min,
+ 'var': var,
+ 'skew': (float(stats.skew(data))
+ if len(data) > 3 and var >= 1e-12 else float('nan')),
+ 'entropy': float(stats.entropy(padded)),
+ 'mean_height': float(np.mean(padded)),
+ 'histogram': padded,
+ }
+
+ # find_peaks returns padded-array indices; subtract HISTOGRAM_PAD to get
+ # start-relative coordinates.
+ peaks_idx = find_peaks(padded, distance=PEAK_DISTANCE)[0]
+ record['peak_count'] = len(peaks_idx)
+ if record['peak_count'] == 0:
+ return record
+
+ widths = peak_widths(padded, peaks_idx, rel_height=PEAK_REL_HEIGHT)
+ peak_location = [int(j - HISTOGRAM_PAD) for j in peaks_idx]
+ record['raw_peaks'] = {
+ 'peak_location': peak_location,
+ 'peak_width': list(widths[0]),
+ 'peak_left': [int(j - HISTOGRAM_PAD) for j in widths[2]],
+ 'peak_right': [int(j - HISTOGRAM_PAD) for j in widths[3]],
+ 'peak_heights': [padded[int(p) + HISTOGRAM_PAD] for p in peak_location],
+ }
+ return record
+
+
+def _rank_peaks(record: dict) -> List[dict]:
+ """Order a transcript's peaks by descending height and attach
+ ``relative_height`` (mirrors ``TerminalCounter._rank_peaks``)."""
+ raw = record['raw_peaks']
+ heights = np.asarray(raw['peak_heights'], dtype=float)
+ if heights.size > 1:
+ order = np.argsort(-heights)
+ top = float(heights[order][0])
+ else:
+ order = [0]
+ top = None
+
+ ranked = []
+ for idx in order:
+ height = raw['peak_heights'][idx]
+ ranked.append({
+ 'peak_location': raw['peak_location'][idx],
+ 'peak_width': raw['peak_width'][idx],
+ 'peak_left': raw['peak_left'][idx],
+ 'peak_right': raw['peak_right'][idx],
+ 'peak_heights': height,
+ 'relative_height': (1.0 if top is None
+ else (height / top if top else 0.0)),
+ })
+ return ranked
+
+
+def _window_count(histogram: List[int], left: int, right: int) -> int:
+ """Reads inside the padded-histogram window ``[left, right]`` (mirrors
+ ``TerminalCounter._counts_for_peak``); left/right are start-relative."""
+ low = max(0, int(left) + HISTOGRAM_PAD)
+ high = min(len(histogram), int(right) + HISTOGRAM_PAD + 1)
+ if high <= low:
+ return 0
+ return int(np.asarray(histogram[low:high]).sum())
diff --git a/isoquant_tests/github/run_pipeline.py b/isoquant_tests/github/run_pipeline.py
index 7f02c3af..86cfd76d 100755
--- a/isoquant_tests/github/run_pipeline.py
+++ b/isoquant_tests/github/run_pipeline.py
@@ -429,33 +429,43 @@ def run_transcript_quality(args, config_dict, baselines=None):
log.error("Transcript evaluation exited with non-zero status: %d" % result.returncode)
return -21
- # Get etalon values from YAML baselines or external file
- if baselines and "transcripts" in baselines:
- etalon_qaulity_dict = {k: str(v) for k, v in baselines["transcripts"].items()}
- elif "etalon" in config_dict:
- etalon_qaulity_dict = load_tsv_config(fix_path(config_file, config_dict["etalon"]))
- else:
+ # Transcript-level accuracy is scored at three terminal-end tolerances:
+ # the default end-agnostic match plus --terminal-delta 50/10 (end-sensitive,
+ # produced by reduced_db_gffcompare.py via the gffcompare fork). Each maps
+ # to its own baseline block. See .claude/GFFCOMPARE.md. A variant is checked
+ # only if its baseline block exists and its stats file was produced (so a
+ # stock gffcompare without --terminal-delta degrades to the default only).
+ delta_variants = [("", "transcripts"),
+ (".td50", "transcripts_td50"),
+ (".td10", "transcripts_td10")]
+ if not (baselines and any(k in baselines for _, k in delta_variants)) \
+ and "etalon" not in config_dict:
return 0
log.info('== Checking quality metrics ==')
exit_code = 0
new_etalon_outf = open(os.path.join(quality_output, "new_gtf_etalon.tsv"), "w")
- for gtf_type in ['full', 'known', 'novel']:
- recall, precision = parse_gffcomapre(os.path.join(quality_output, "isoquant." + gtf_type + ".stats"))
- metric_name = gtf_type + "_recall"
- if metric_name in etalon_qaulity_dict:
- new_etalon_outf.write("%s\t%.2f\n" % (metric_name, recall))
- etalon_recall = float(etalon_qaulity_dict[metric_name])
- err_code = check_value(etalon_recall, recall , metric_name)
- if err_code != 0:
- exit_code = err_code
- metric_name = gtf_type + "_precision"
- if metric_name in etalon_qaulity_dict:
- new_etalon_outf.write("%s\t%.2f\n" % (metric_name, precision))
- etalon_precision = float(etalon_qaulity_dict[metric_name])
- err_code = check_value(etalon_precision, precision, metric_name)
- if err_code != 0:
- exit_code = err_code
+ for stats_suffix, baseline_key in delta_variants:
+ if baselines and baseline_key in baselines:
+ etalon_qaulity_dict = {k: str(v) for k, v in baselines[baseline_key].items()}
+ elif baseline_key == "transcripts" and "etalon" in config_dict:
+ etalon_qaulity_dict = load_tsv_config(fix_path(config_file, config_dict["etalon"]))
+ else:
+ continue
+ for gtf_type in ['full', 'known', 'novel']:
+ stats_path = os.path.join(quality_output, "isoquant.%s%s.stats" % (gtf_type, stats_suffix))
+ if not os.path.exists(stats_path):
+ continue
+ recall, precision = parse_gffcomapre(stats_path)
+ for metric_value, suffix in ((recall, "_recall"), (precision, "_precision")):
+ metric_name = gtf_type + suffix
+ if metric_name not in etalon_qaulity_dict:
+ continue
+ new_etalon_outf.write("%s.%s\t%.2f\n" % (baseline_key, metric_name, metric_value))
+ err_code = check_value(float(etalon_qaulity_dict[metric_name]), metric_value,
+ baseline_key + "." + metric_name)
+ if err_code != 0:
+ exit_code = err_code
new_etalon_outf.close()
return exit_code
diff --git a/isoquant_tests/test_common_utilities.py b/isoquant_tests/test_common_utilities.py
index 34318f68..3657844a 100644
--- a/isoquant_tests/test_common_utilities.py
+++ b/isoquant_tests/test_common_utilities.py
@@ -31,3 +31,13 @@ def test_get_path_to_program(self):
# Test existing program (assuming 'python' exists)
path = get_path_to_program("python")
self.assertTrue(os.path.exists(path))
+
+ def test_interval_bin_search(self):
+ intervals = [(10, 20), (30, 40), (50, 60)]
+ self.assertEqual(interval_bin_search(intervals, 35), 1)
+ self.assertEqual(interval_bin_search(intervals, 5), -1)
+ self.assertEqual(interval_bin_search(intervals, 100), -1)
+ # Empty intervals must not raise (regression: novel transcript in a
+ # locus with no annotated exons + a real polyA position).
+ self.assertEqual(interval_bin_search([], 100), -1)
+ self.assertEqual(interval_bin_search_rev([], 100), -1)
diff --git a/isoquant_tests/test_graph_alt_ends.py b/isoquant_tests/test_graph_alt_ends.py
new file mode 100644
index 00000000..5afa8f62
--- /dev/null
+++ b/isoquant_tests/test_graph_alt_ends.py
@@ -0,0 +1,219 @@
+############################################################################
+# Copyright (c) 2022-2026 University of Helsinki
+# # All Rights Reserved
+# See file LICENSE for details.
+############################################################################
+
+"""Unit tests for alternative-end NIC discovery in
+``graph_based_model_construction`` (branch ``transcript_model_ends``).
+
+The constructor is heavy to build (it wires up an assigner + profile
+constructor), so the helpers under test are exercised on a bare instance
+created with ``__new__`` and only the attributes each method touches.
+``detect_peaks`` / ``get_polya_model`` / ``get_tss_model`` are stubbed so the
+tests pin the gating/wiring logic, not scipy's peak finder or the shipped
+XGBoost model.
+"""
+
+import types
+from argparse import Namespace
+
+import pytest
+
+from isoquant_lib import graph_based_model_construction as gbmc
+from isoquant_lib.gene_info import TranscriptModel, TranscriptModelType
+from isoquant_lib.terminal_peaks import Peak
+
+# Reuse the polyA-counter harness (stub model + read-assignment builder).
+from isoquant_tests.test_polya_prediction import _StubModel, _make_read_assignment
+
+
+class _IdDist:
+ def __init__(self):
+ self.n = 0
+
+ def increment(self):
+ self.n += 1
+ return self.n
+
+
+def _make_constructor(use_tss_model=True, apa_delta=10, min_novel_count=2,
+ terminal_position_rel=0.1, chr_id="chr1", all_isoforms_exons=None):
+ c = gbmc.GraphBasedModelConstructor.__new__(gbmc.GraphBasedModelConstructor)
+ c.args = Namespace(apa_delta=apa_delta, min_novel_count=min_novel_count,
+ terminal_position_rel=terminal_position_rel)
+ c.use_tss_model = use_tss_model
+ c.create_nics = True
+ c.novel_apa = False
+ c.id_distributor = _IdDist()
+ c.gene_info = types.SimpleNamespace(chr_id=chr_id,
+ all_isoforms_exons=all_isoforms_exons or {})
+ c.internal_counter = {}
+ return c
+
+
+def _model(exon_blocks, strand="+", ttype=TranscriptModelType.known, tid="T1",
+ gene_id="G1", chr_id="chr1"):
+ return TranscriptModel(chr_id, strand, tid, gene_id,
+ [tuple(e) for e in exon_blocks], ttype)
+
+
+@pytest.fixture
+def stub_models(monkeypatch):
+ model = _StubModel(accept=True)
+ monkeypatch.setattr(gbmc, "get_polya_model", lambda: model)
+ monkeypatch.setattr(gbmc, "get_tss_model", lambda: model)
+ return model
+
+
+# -- static helpers -----------------------------------------------------------
+
+def test_intron_chain_key_multi_exon():
+ m = _model([(100, 200), (300, 400), (500, 600)])
+ assert gbmc.GraphBasedModelConstructor._intron_chain_key(m) == ((200, 300), (400, 500))
+
+
+def test_intron_chain_key_monoexon_is_empty():
+ assert gbmc.GraphBasedModelConstructor._intron_chain_key(_model([(100, 600)])) == ()
+
+
+def test_closest_inward_greater_and_less():
+ f = gbmc.GraphBasedModelConstructor._closest_inward
+ assert f([10, 20, 30], 15, True) == 20 # first strictly greater (ascending)
+ assert f([10, 20, 30], 30, True) is None
+ assert f([30, 20, 10], 25, False) == 20 # first strictly less (descending)
+ assert f([30, 20, 10], 5, False) is None
+
+
+def test_confirmed_polya_pos_plus_and_minus():
+ f = gbmc.GraphBasedModelConstructor._confirmed_polya_pos
+ assert f(_make_read_assignment(strand="+", polya_pos=500), "+") == 500
+ assert f(_make_read_assignment(strand="-", polya_pos=120), "-") == 120
+
+
+def test_confirmed_polya_pos_requires_polya_found():
+ a = _make_read_assignment(strand="+", polya_pos=500, polyA_found=False)
+ assert gbmc.GraphBasedModelConstructor._confirmed_polya_pos(a, "+") is None
+
+
+def test_confirmed_polya_pos_missing_strand_field_returns_none():
+ # '+' read carries external_polya_pos but external_polyt_pos == -1
+ a = _make_read_assignment(strand="+", polya_pos=500)
+ assert gbmc.GraphBasedModelConstructor._confirmed_polya_pos(a, "-") is None
+
+
+# -- _terminal_model ----------------------------------------------------------
+
+def test_terminal_model_polya_side_always_available(stub_models):
+ c = _make_constructor(use_tss_model=True)
+ assert c._terminal_model("+", left=False) is stub_models # 3' polyA of '+'
+ assert c._terminal_model("-", left=True) is stub_models # 3' polyA of '-'
+
+
+def test_terminal_model_tss_side_needs_fl_data(stub_models):
+ assert _make_constructor(use_tss_model=True)._terminal_model("+", left=True) is stub_models
+ assert _make_constructor(use_tss_model=False)._terminal_model("+", left=True) is None
+
+
+def test_terminal_model_dot_strand_is_none(stub_models):
+ c = _make_constructor()
+ assert c._terminal_model(".", left=True) is None
+ assert c._terminal_model(".", left=False) is None
+
+
+# -- _alternative_end_positions (relative-support + apa_delta gate) ------------
+
+def test_alternative_end_positions_gate(monkeypatch):
+ c = _make_constructor(apa_delta=10, min_novel_count=2, terminal_position_rel=0.1)
+ peaks = [
+ Peak(position=400, count=50, left=395, right=405), # dominant, == annotated -> drop
+ Peak(position=500, count=8, left=495, right=505), # strong alt -> keep (>= cutoff 5)
+ Peak(position=600, count=3, left=595, right=605), # minor -> below cutoff -> drop
+ Peak(position=405, count=40, left=400, right=410), # within apa_delta of annotated -> drop
+ ]
+ monkeypatch.setattr(gbmc, "detect_peaks", lambda hist, model: peaks)
+ out = c._alternative_end_positions({1: 1}, _StubModel(), annotated_pos=400, clamp=lambda p: True)
+ assert out == [500]
+
+
+def test_alternative_end_positions_respects_clamp(monkeypatch):
+ c = _make_constructor(apa_delta=10, min_novel_count=2, terminal_position_rel=0.1)
+ peaks = [Peak(400, 50, 395, 405), Peak(500, 8, 495, 505)]
+ monkeypatch.setattr(gbmc, "detect_peaks", lambda hist, model: peaks)
+ # clamp rejects the 500 peak -> nothing left
+ out = c._alternative_end_positions({1: 1}, _StubModel(), annotated_pos=400, clamp=lambda p: p < 450)
+ assert out == []
+
+
+# -- derive_alternative_end_models + _nic_model_with_boundary ------------------
+
+def _fake_detect(hist, model):
+ # end histogram carries the alternative polyA at 500; start histogram carries
+ # only the annotated start at 100 (filtered out by the apa_delta gate).
+ if 500 in hist:
+ return [Peak(500, 30, 495, 505)]
+ if 100 in hist:
+ return [Peak(100, 30, 95, 105)]
+ return []
+
+
+def test_derive_alternative_end_models_known_source(monkeypatch, stub_models):
+ c = _make_constructor(use_tss_model=True, apa_delta=10, min_novel_count=2,
+ terminal_position_rel=0.1)
+ monkeypatch.setattr(gbmc, "detect_peaks", _fake_detect)
+ source = _model([(100, 200), (300, 400)], strand="+",
+ ttype=TranscriptModelType.known, tid="ENST1")
+ reads = [_make_read_assignment(strand="+", polya_pos=500, exons=[(100, 200), (300, 500)])
+ for _ in range(20)]
+
+ out = c.derive_alternative_end_models(source, reads)
+
+ assert len(out) == 1
+ nic = out[0]
+ assert nic.exon_blocks[0] == (100, 200) # start unchanged
+ assert nic.exon_blocks[-1] == (300, 500) # 3' end moved to alt polyA
+ assert nic.transcript_type == TranscriptModelType.novel_in_catalog
+ assert nic.transcript_id.endswith(".nic")
+ assert nic.gene_id == source.gene_id and nic.strand == "+"
+ # source model is left untouched
+ assert source.exon_blocks[-1] == (300, 400)
+
+
+def test_derive_alternative_end_models_empty_reads():
+ assert _make_constructor().derive_alternative_end_models(
+ _model([(100, 200), (300, 400)]), []) == []
+
+
+def test_nic_model_with_boundary_replaces_single_end():
+ c = _make_constructor()
+ source = _model([(100, 200), (300, 400), (500, 600)], strand="-", tid="ENST2", gene_id="G2")
+ nic = c._nic_model_with_boundary(source, TranscriptModelType.novel_in_catalog, end=650)
+ assert nic.exon_blocks == [(100, 200), (300, 400), (500, 650)]
+ assert nic.transcript_id.endswith(".nic") and nic.strand == "-" and nic.gene_id == "G2"
+ assert source.exon_blocks[-1] == (500, 600) # source untouched
+
+
+def test_nic_model_with_boundary_nnic_suffix_for_novel_source():
+ c = _make_constructor()
+ source = _model([(100, 200), (300, 400)], ttype=TranscriptModelType.novel_not_in_catalog)
+ m = c._nic_model_with_boundary(source, TranscriptModelType.novel_not_in_catalog, start=50)
+ assert m.exon_blocks[0] == (50, 200)
+ assert m.transcript_id.endswith(".nnic")
+
+
+# -- _drop_duplicate_alt_end_models -------------------------------------------
+
+def test_drop_duplicate_alt_end_models():
+ ref_exons = {"REF1": [(100, 200), (300, 400)]}
+ c = _make_constructor(apa_delta=10, all_isoforms_exons=ref_exons)
+ known = _model([(100, 200), (300, 400)], ttype=TranscriptModelType.known, tid="REF1")
+ # same intron chain + both ends within apa_delta of REF1 -> it IS the known -> drop
+ dup = _model([(105, 200), (300, 398)], ttype=TranscriptModelType.novel_in_catalog, tid="DUP")
+ # genuinely different 3' end -> a real alt-end NIC -> keep
+ keep = _model([(100, 200), (300, 800)], ttype=TranscriptModelType.novel_in_catalog, tid="KEEP")
+ c.transcript_model_storage = [known, dup, keep]
+
+ c._drop_duplicate_alt_end_models()
+
+ ids = {m.transcript_id for m in c.transcript_model_storage}
+ assert ids == {"REF1", "KEEP"}
diff --git a/isoquant_tests/test_intron_graph_refine.py b/isoquant_tests/test_intron_graph_refine.py
new file mode 100644
index 00000000..a3e42b38
--- /dev/null
+++ b/isoquant_tests/test_intron_graph_refine.py
@@ -0,0 +1,80 @@
+############################################################################
+# Copyright (c) 2022-2026 University of Helsinki
+# # All Rights Reserved
+# See file LICENSE for details.
+############################################################################
+
+"""Unit tests for intron-graph terminal-vertex refinement
+(``IntronGraph._refine_positions`` / ``_attach_side``) on branch
+``transcript_model_ends``.
+
+The graph is heavy to build, so the methods under test run on a bare instance
+created with ``__new__`` plus the few attributes/stubs they touch.
+"""
+
+import types
+from collections import defaultdict
+
+from isoquant_lib.intron_graph import IntronGraph, VERTEX_read_end
+
+
+def _bare_graph(apa_delta=10, polya_predictions=None, tss_predictions=None):
+ g = IntronGraph.__new__(IntronGraph)
+ g.params = types.SimpleNamespace(
+ apa_delta=apa_delta,
+ terminal_position_abs=1,
+ terminal_position_rel=0.0,
+ terminal_internal_position_rel=0.0,
+ )
+ g.polya_predictions = polya_predictions
+ g.tss_predictions = tss_predictions
+ g.outgoing_edges = defaultdict(set)
+ g.incoming_edges = defaultdict(set)
+ return g
+
+
+# -- _refine_positions --------------------------------------------------------
+
+def test_refine_positions_snaps_within_delta():
+ g = _bare_graph(apa_delta=10)
+ assert g._refine_positions({105: 3}, [100, 500]) == {100: 3}
+
+
+def test_refine_positions_leaves_when_far():
+ g = _bare_graph(apa_delta=10)
+ assert g._refine_positions({130: 3}, [100, 500]) == {130: 3}
+
+
+def test_refine_positions_merges_counts_on_collision():
+ g = _bare_graph(apa_delta=10)
+ assert g._refine_positions({98: 2, 103: 5}, [100]) == {100: 7}
+
+
+def test_refine_positions_identity_without_predictions():
+ g = _bare_graph(apa_delta=10)
+ assert g._refine_positions({105: 3}, None) == {105: 3}
+ assert g._refine_positions({105: 3}, []) == {105: 3}
+
+
+# -- _attach_side side-selection (regression for review fix #1) ----------------
+
+def test_attach_side_3prime_readend_refines_with_polya_not_tss():
+ # read_end=True is the genomic 3' side: a VERTEX_read_end position must be
+ # refined toward the polyA predictions, never the (5') TSS predictions.
+ # Place the read-end cluster at 1005, equidistant from polya (1000) and tss
+ # (1010); the chosen target tells us which set was used.
+ intron = (100, 200)
+ g = _bare_graph(apa_delta=10, polya_predictions=[1000], tss_predictions=[1010])
+ g.intron_collector = types.SimpleNamespace(clustered_introns={intron: 10})
+ # No polyA-confirmed vertices; one read-end cluster at 1005.
+ g.cluster_polya_positions = lambda positions, i, read_end: {}
+ g.cluster_terminal_positions = lambda extra, read_end, cutoff: {1005: 4}
+
+ polya_confirmed = {intron: {}}
+ read_terminal = {intron: {1005: 4}}
+ g._attach_side([intron], polya_confirmed, read_terminal, read_end=True)
+
+ vertices = g.outgoing_edges[intron]
+ assert (VERTEX_read_end, 1000) in vertices # snapped to polyA prediction
+ assert (VERTEX_read_end, 1010) not in vertices # NOT the TSS prediction
+ assert (VERTEX_read_end, 1005) not in vertices # and it was refined, not left raw
diff --git a/isoquant_tests/test_polya_prediction.py b/isoquant_tests/test_polya_prediction.py
index a5c9b04e..bd3b02a5 100644
--- a/isoquant_tests/test_polya_prediction.py
+++ b/isoquant_tests/test_polya_prediction.py
@@ -162,6 +162,62 @@ def test_dump_ungrouped_emits_one_row_per_peak(stub_model, tmp_path):
assert df["flag"].iloc[0] == "Novel"
+def test_per_gene_flush_does_not_split_transcript(stub_model, tmp_path):
+ # A transcript whose reads arrive across several per-gene flush() batches
+ # (loader yields its reads in more than one gene block) must still be
+ # emitted as a single row per peak with the full read count -- not one
+ # partial row per batch. Regression for the per-gene-flush splitting bug.
+ out = tmp_path / "split.tsv"
+ counter = tc.PolyACounter(_make_args(), str(out))
+ exons = [(100, 200), (300, 400)]
+ # Batch 1 (10 reads), flush; batch 2 (20 reads), flush -- same peak ~420.
+ for offset in range(10):
+ counter.add_read_info(_make_read_assignment(polya_pos=420 + offset % 3, exons=exons))
+ counter.flush()
+ assert counter.last_gene_predictions # per-gene reuse still produced
+ for offset in range(20):
+ counter.add_read_info(_make_read_assignment(polya_pos=420 + offset % 3, exons=exons))
+ counter.flush()
+ counter.dump()
+
+ df = _read_tsv(out)
+ assert (df["transcript_id"] == "T1").all()
+ # No duplicate (transcript_id, prediction) rows, and every read counted once.
+ assert not df.duplicated(subset=["transcript_id", "prediction"]).any()
+ assert df["counts"].sum() == 30
+
+
+def test_flush_sets_predictions_and_clears_buffer(stub_model, tmp_path):
+ # flush() exposes the per-gene predictions (reused by intron-graph
+ # refinement), folds the gene into the chromosome-wide buffer, and clears
+ # the per-gene buffer for the next gene.
+ counter = tc.PolyACounter(_make_args(), str(tmp_path / "f.tsv"))
+ exons = [(100, 200), (300, 400)]
+ for offset in range(30):
+ counter.add_read_info(_make_read_assignment(polya_pos=420 + offset % 3, exons=exons))
+ assert counter.transcripts # buffered before flush
+ counter.flush()
+ assert counter.last_gene_predictions # per-gene predictions exposed for reuse
+ assert counter.transcripts == {} # per-gene buffer cleared
+ assert counter._all_transcripts # merged into chromosome-wide buffer
+
+
+def test_flush_is_noop_for_grouped_counter(stub_model, tmp_path):
+ # A grouped counter accumulates across the whole chromosome (whole-chr
+ # prediction in dump); flush() must not touch its buffer or predictions.
+ pool = types.SimpleNamespace(get_str=lambda i: f"g{i}")
+ string_pools = types.SimpleNamespace(get_read_group_pool=lambda _idx: pool)
+ counter = tc.PolyACounter(_make_args(), str(tmp_path / "g.tsv"),
+ string_pools=string_pools, group_index=0)
+ exons = [(100, 200), (300, 400)]
+ for offset in range(30):
+ counter.add_read_info(_make_read_assignment(
+ polya_pos=420 + offset % 3, exons=exons, read_group_ids=[offset % 2]))
+ counter.flush()
+ assert counter.transcripts # NOT cleared (flush no-op for grouped)
+ assert counter.last_gene_predictions == []
+
+
def test_dump_flags_peak_as_known_within_tolerance(stub_model, tmp_path):
out = tmp_path / "known.tsv"
counter = tc.PolyACounter(_make_args(), str(out))
diff --git a/isoquant_tests/test_terminal_peaks.py b/isoquant_tests/test_terminal_peaks.py
new file mode 100644
index 00000000..c069bdab
--- /dev/null
+++ b/isoquant_tests/test_terminal_peaks.py
@@ -0,0 +1,91 @@
+############################################################################
+# Copyright (c) 2022-2026 University of Helsinki
+# # All Rights Reserved
+# See file LICENSE for details.
+############################################################################
+
+"""Unit tests for the shared polyA / TSS peak detector (terminal_peaks).
+
+The detector is the single source of truth used both by the side-output
+prediction counters (terminal_counter) and by transcript discovery
+(intron_graph / graph_based_model_construction). These tests pin its output
+against terminal_counter's own pandas pipeline so the feature computation can
+not silently drift between the two call sites.
+"""
+
+from collections import Counter
+
+import numpy as np
+import pytest
+
+from isoquant_lib import terminal_counter as tc
+from isoquant_lib import terminal_peaks as tp
+from isoquant_lib.isoform_assignment import ReadAssignmentType
+
+# Reuse the counter test harness (stub model + read-assignment builders).
+from isoquant_tests.test_polya_prediction import (
+ _StubModel,
+ _make_args,
+ _make_read_assignment,
+ _read_tsv,
+ stub_model, # noqa: F401 (pytest fixture)
+)
+
+
+def _counter_peaks(positions, tmp_path, name):
+ """Run the real PolyACounter pipeline over reads at ``positions`` and
+ return the set of (prediction, counts) it emits."""
+ out = tmp_path / name
+ counter = tc.PolyACounter(_make_args(), str(out))
+ exons = [(100, 200), (300, 400)]
+ for pos in positions:
+ counter.add_read_info(_make_read_assignment(polya_pos=pos, exons=exons))
+ counter.dump()
+ df = _read_tsv(out)
+ if df.empty:
+ return set()
+ return set(zip(df["prediction"].astype(int), df["counts"].astype(int)))
+
+
+def _detect_peaks(positions, accept=True):
+ histogram = dict(Counter(positions))
+ peaks = tp.detect_peaks(histogram, _StubModel(accept=accept))
+ return {(p.position, p.count) for p in peaks}
+
+
+@pytest.mark.parametrize("positions", [
+ [420], # single read
+ [420 + i % 3 for i in range(30)], # one tight peak
+ [420 + i % 3 for i in range(30)] +
+ [500 + i % 3 for i in range(20)], # two peaks (ranking)
+ [400, 401], # adjacent-bin plateau
+])
+def test_detect_peaks_parity_with_counter(stub_model, tmp_path, positions):
+ counter_peaks = _counter_peaks(positions, tmp_path, "parity.tsv")
+ assert _detect_peaks(positions) == counter_peaks
+ # And the union of supporting counts is conserved either way.
+ assert sum(c for _, c in _detect_peaks(positions)) == \
+ sum(c for _, c in counter_peaks)
+
+
+def test_detect_peaks_empty_returns_empty():
+ assert tp.detect_peaks({}, _StubModel(accept=True)) == []
+
+
+def test_detect_peaks_batch_matches_single():
+ histograms = [
+ {420 + i % 3: 1 for i in range(1)},
+ {420 + i % 3: 10 for i in range(3)},
+ {},
+ ]
+ model = _StubModel(accept=True)
+ batched = tp.detect_peaks_batch(histograms, model)
+ singles = [tp.detect_peaks(h, model) for h in histograms]
+ assert batched == singles
+ assert batched[2] == []
+
+
+def test_detect_peaks_rejected_peak_dropped():
+ # A clear peak is dropped when the model rejects it (no zero-peak fallback).
+ positions = [420 + i % 3 for i in range(30)]
+ assert _detect_peaks(positions, accept=False) == set()
diff --git a/misc/prepare_simulated_reduced_db.py b/misc/prepare_simulated_reduced_db.py
new file mode 100644
index 00000000..8f8de360
--- /dev/null
+++ b/misc/prepare_simulated_reduced_db.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+############################################################################
+# Copyright (c) 2022-2026 University of Helsinki
+# # All Rights Reserved
+# See file LICENSE for details.
+############################################################################
+
+"""Split a simulation GTF into the reduced-db trio used by
+``reduced_db_gffcompare.py`` for transcript-discovery assessment of
+**alternative TSS/polyA**.
+
+The simulation GTF is the full reference annotation (plain transcript IDs)
+plus alternative-end variants whose ID is ``_`` (same intron
+chain, different terminal exon). Mapping alt-end variants onto the
+reduced-db "novel" split lets the existing 3-terminal-delta assessment
+measure how well alternative ends are recovered:
+
+ .expressed.gtf all expressed transcripts (-> full)
+ .expressed_kept.gtf expressed plain-ID transcripts (-> known)
+ .excluded.gtf expressed alt-end variants (-> novel)
+
+"Expressed" = count > 0 in the simulated counts TSV. Alt-end IDs are matched
+by ``_`` at the end (the simulation appends the alternative terminal
+coordinate); plain GENCODE IDs (incl. ``_PAR_Y``) do not match.
+"""
+import argparse
+import re
+import sys
+from traceback import print_exc
+
+ALT_END = re.compile(r'_\d+$')
+
+
+def parse_args():
+ p = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)
+ p.add_argument("--sim_gtf", required=True, help="simulation GTF (full annotation + alt-end variants)")
+ p.add_argument("--counts", required=True, help="simulated counts TSV (transcript_idcounts[...])")
+ p.add_argument("--output_prefix", required=True, help="output prefix for the .expressed/.expressed_kept/.excluded GTFs")
+ return p.parse_args()
+
+
+def load_expressed(counts_path):
+ expressed = set()
+ for line in open(counts_path):
+ if line.startswith("#") or not line.strip():
+ continue
+ f = line.rstrip("\n").split("\t")
+ if len(f) < 2:
+ continue
+ try:
+ if float(f[1]) > 0:
+ expressed.add(f[0])
+ except ValueError:
+ continue
+ return expressed
+
+
+def get_transcript_id(attrs):
+ i = attrs.find('transcript_id "')
+ if i < 0:
+ return None
+ return attrs[i + 15:].split('"', 1)[0]
+
+
+def main():
+ args = parse_args()
+ expressed = load_expressed(args.counts)
+ out_all = open(args.output_prefix + ".expressed.gtf", "w")
+ out_kept = open(args.output_prefix + ".expressed_kept.gtf", "w")
+ out_excl = open(args.output_prefix + ".excluded.gtf", "w")
+
+ tx = {"all": set(), "kept": set(), "excl": set()}
+ for line in open(args.sim_gtf):
+ if line.startswith("#"):
+ continue
+ f = line.rstrip("\n").split("\t")
+ if len(f) < 9 or f[2] not in ("transcript", "exon"):
+ continue
+ # Some upstream synthetic GTFs (the lrgasp human ones) carry a bogus
+ # extra column 8 (literal "") before the attributes; the attributes
+ # are always the last field. Read from there and write a normalized
+ # 9-column line so gffcompare parses it (9-col input is unchanged).
+ attrs = f[-1]
+ tid = get_transcript_id(attrs)
+ if tid is None or tid not in expressed:
+ continue
+ out_line = "\t".join(f[:8] + [attrs]) + "\n"
+ out_all.write(out_line)
+ is_tx = f[2] == "transcript"
+ if is_tx:
+ tx["all"].add(tid)
+ if ALT_END.search(tid):
+ out_excl.write(out_line)
+ if is_tx:
+ tx["excl"].add(tid)
+ else:
+ out_kept.write(out_line)
+ if is_tx:
+ tx["kept"].add(tid)
+
+ for fh in (out_all, out_kept, out_excl):
+ fh.close()
+ print("Expressed transcripts written: all=%d kept/plain(known)=%d excluded/alt-end(novel)=%d"
+ % (len(tx["all"]), len(tx["kept"]), len(tx["excl"])))
+
+
+if __name__ == "__main__":
+ try:
+ main()
+ except SystemExit:
+ raise
+ except Exception:
+ print_exc()
+ sys.exit(-1)
diff --git a/misc/reduced_db_gffcompare.py b/misc/reduced_db_gffcompare.py
index 71561835..965e0374 100755
--- a/misc/reduced_db_gffcompare.py
+++ b/misc/reduced_db_gffcompare.py
@@ -15,6 +15,11 @@
from common import *
+# Terminal-end tolerances (bp) for transcript-level matching. None = default
+# end-agnostic match; the integers use the gffcompare fork's --terminal-delta.
+TERMINAL_DELTAS = [None, 50, 10]
+
+
def parse_args():
parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("--output", "-o", type=str, help="output folder", default="gtf_stats")
@@ -41,15 +46,26 @@ def main():
print("Seprating known and novel transcripts")
separator = SEPARATE_FUNCTORS[args.tool](args.gtf)
split_gtf(args.gtf, separator, out_full_path, out_known_path, out_novel_path)
- print("Running gffcompare for entire GTF")
- expressed_gtf = args.genedb + ".expressed.gtf"
- run_gff_compare(expressed_gtf, out_full_path, os.path.join(args.output, args.tool + ".full.stats"))
- print("Running gffcompare for known transcripts")
- expressed_gtf = args.genedb + ".expressed_kept.gtf"
- run_gff_compare(expressed_gtf, out_known_path, os.path.join(args.output, args.tool + ".known.stats"))
- print("Running gffcompare for novel transcripts")
- expressed_gtf = args.genedb + ".excluded.gtf"
- run_gff_compare(expressed_gtf, out_novel_path, os.path.join(args.output, args.tool + ".novel.stats"))
+
+ # (split, reference subset, query split GTF)
+ splits = [
+ ("full", args.genedb + ".expressed.gtf", out_full_path),
+ ("known", args.genedb + ".expressed_kept.gtf", out_known_path),
+ ("novel", args.genedb + ".excluded.gtf", out_novel_path),
+ ]
+ # Score each split at several terminal-end tolerances. None = default
+ # end-agnostic transcript match (-> "..stats"); the integer
+ # deltas use the gffcompare fork's --terminal-delta so the transcript-level
+ # metric becomes end-sensitive (-> "..td.stats").
+ # See .claude/GFFCOMPARE.md. Requires the gffcompare fork for the deltas;
+ # the default run works with stock gffcompare too.
+ for split, reference_gtf, compared_gtf in splits:
+ for delta in TERMINAL_DELTAS:
+ suffix = "" if delta is None else (".td%d" % delta)
+ option = "" if delta is None else ("--terminal-delta=%d" % delta)
+ out_stats = os.path.join(args.output, "%s.%s%s.stats" % (args.tool, split, suffix))
+ print("Running gffcompare for %s transcripts (terminal-delta=%s)" % (split, delta))
+ run_gff_compare(reference_gtf, compared_gtf, out_stats, additional_option=option)
if __name__ == "__main__":