Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/_static/custom.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/* Reduce the height of the piccolo_theme top navigation bar */
:root {
--navbarHeight: 3.25rem;
}

div#top_nav nav {
padding: 0.7rem 1rem;
}
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@

html_theme = "piccolo_theme"
html_static_path = ["_static"]
html_css_files = ["custom.css"]
13 changes: 13 additions & 0 deletions pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ class SEQopts:
:type plot_title: str
:param plot_type: Type of plot to show ["risk", "survival" or "incidence" if compevent is specified]
:type plot_type: str
:param risk_times: Followup times at which to report risk difference and risk ratio when ``km_curves = True``.
Each requested time is snapped to the latest available followup at or before it, and the maximum
followup is always included. Defaults to ``None`` (report at the maximum followup only).
:type risk_times: Optional[List[float]] or None
:param seed: RNG seed
:type seed: int
:param selection_first_trial: Boolean to only use first trial for analysis (similar to non-expanded)
Expand Down Expand Up @@ -150,6 +154,7 @@ class SEQopts:
plot_labels: List[str] = field(default_factory=lambda: [])
plot_title: str = None
plot_type: Literal["risk", "survival", "incidence"] = "survival"
risk_times: Optional[List[float]] = None
seed: Optional[int] = None
selection_first_trial: bool = False
selection_sample: float = 0.8
Expand Down Expand Up @@ -210,6 +215,14 @@ def _validate_ranges(self):
raise ValueError(
f"followup_min ({self.followup_min}) must be less than followup_max ({self.followup_max})."
)
if self.risk_times is not None:
times = (
self.risk_times
if isinstance(self.risk_times, (list, tuple))
else [self.risk_times]
)
if any(not isinstance(t, (int, float)) or t < 0 for t in times):
raise ValueError("risk_times values must be non-negative numbers.")

def _validate_choices(self):
if self.plot_type not in ["risk", "survival", "incidence"]:
Expand Down
32 changes: 32 additions & 0 deletions pySEQTarget/analysis/_outcome_fit.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import re

import numpy as np
import pandas as pd
import polars as pl
import statsmodels.api as sm
import statsmodels.formula.api as smf
from pandas.api.types import is_numeric_dtype


def _compute_spline_knots(followup_arr, df=3):
Expand Down Expand Up @@ -40,6 +42,20 @@ def _apply_spline_formula(formula, indicator_squared, spline_knots):
return spline


def _categorical_tv_columns(self, df_pd):
"""
Names of the categorical (non-numeric) time-varying covariate columns
present in ``df_pd``, including their baseline (``indicator_baseline``)
versions used by the outcome model.
"""
cols = []
for col in self.time_varying_cols or []:
for variant in (col, f"{col}{self.indicator_baseline}"):
if variant in df_pd.columns and not is_numeric_dtype(df_pd[variant]):
cols.append(variant)
return cols


def _cast_categories(self, df_pd):
if self.treatment_col in df_pd.columns:
df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category")
Expand All @@ -58,6 +74,22 @@ def _cast_categories(self, df_pd):
if col in df_pd.columns:
df_pd[col] = df_pd[col].astype("category")

# Stable factor encoding for categorical time-varying covariates: fix the
# level set from the full expanded data (captured on the non-bootstrap
# pass) so a bootstrap resample cannot realise a different set of levels —
# otherwise a level absent from the resample would be unknown to that fit
# and crash counterfactual prediction with NaNs.
tv_cat_cols = _categorical_tv_columns(self, df_pd)
if getattr(self, "_current_boot_idx", None) is None:
cats = getattr(self, "_covariate_categories", {})
for col in tv_cat_cols:
cats[col] = sorted(df_pd[col].dropna().unique().tolist())
self._covariate_categories = cats
cats = getattr(self, "_covariate_categories", {})
for col in tv_cat_cols:
if col in cats:
df_pd[col] = pd.Categorical(df_pd[col], categories=cats[col])

return df_pd


Expand Down
253 changes: 151 additions & 102 deletions pySEQTarget/analysis/_risk_estimates.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,42 @@ def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None):
return rd_comp, rr_comp


def _resolve_risk_times(grid, risk_times):
"""
Snap each requested risk time to the latest available followup at or before
it, always including the maximum followup. Returns a sorted list of followup
values that exist in ``grid``.
"""
grid = sorted(set(grid))
final = grid[-1]

if risk_times is None:
return [final]

req = risk_times if isinstance(risk_times, (list, tuple)) else [risk_times]
req = [float(t) for t in req if t is not None]
if not req:
return [final]

above = [t for t in req if t > final]
if above:
raise ValueError(
f"risk_times value(s) exceed the maximum followup ({final}): {above}"
)
below = [t for t in req if t < grid[0]]
if below:
raise ValueError(
f"risk_times value(s) below the minimum followup ({grid[0]}): {below}"
)

snapped = [max(g for g in grid if g <= t) for t in req]
return sorted(set(snapped + [final]))


def _risk_estimates(self):
last_followup = self.km_data["followup"].max()
risk = self.km_data.filter(
(pl.col("followup") == last_followup) & (pl.col("estimate") == "risk")
risk_all = self.km_data.filter(pl.col("estimate") == "risk")
report_times = _resolve_risk_times(
risk_all["followup"].unique().to_list(), self.risk_times
)

group_cols = [self.subgroup_colname] if self.subgroup_colname else []
Expand All @@ -101,115 +133,132 @@ def _risk_estimates(self):
z = None
alpha = None

risk_by_level = {}
for tx in self.treatment_level:
level_data = risk.filter(pl.col(self.treatment_col) == tx)
risk_by_level[tx] = {"pred": level_data.select(group_cols + ["pred"])}
if has_bootstrap and not use_paired:
risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"])

rd_comparisons = []
rr_comparisons = []

for tx_x in self.treatment_level:
for tx_y in self.treatment_level:
if tx_x == tx_y:
continue

if use_paired:
boot_x = (
self._boot_risks[tx_x]
.filter(pl.col("followup") == last_followup)
.select(["boot_idx", pl.col("risk").alias("risk_x")])
)
boot_y = (
self._boot_risks[tx_y]
.filter(pl.col("followup") == last_followup)
.select(["boot_idx", pl.col("risk").alias("risk_y")])
)
paired = boot_x.join(boot_y, on="boot_idx").with_columns(
(pl.col("risk_x") - pl.col("risk_y")).alias("RD")
)

risk_x_val = float(risk_by_level[tx_x]["pred"]["pred"][0])
risk_y_val = float(risk_by_level[tx_y]["pred"]["pred"][0])
rd_point = risk_x_val - risk_y_val
rr_point = risk_x_val / risk_y_val if risk_y_val != 0 else float("inf")

# Filter degenerate RR bootstrap values (risk_y == 0 or negative)
valid_rr = paired.filter(
(pl.col("risk_y") > 0) & (pl.col("risk_x") >= 0)
).with_columns((pl.col("risk_x") / pl.col("risk_y")).alias("RR"))

n_valid_rr = len(valid_rr)

if self.bootstrap_CI_method == "percentile":
rd_lci = float(paired["RD"].quantile(alpha / 2))
rd_uci = float(paired["RD"].quantile(1 - alpha / 2))
if n_valid_rr >= 2:
rr_lci = float(valid_rr["RR"].quantile(alpha / 2))
rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2))
else:
rr_lci = float("nan")
rr_uci = float("nan")
else:
rd_se = float(paired["RD"].std())
rd_lci = rd_point - z * rd_se
rd_uci = rd_point + z * rd_se
if n_valid_rr >= 2 and rr_point > 0:
log_rr_se = float(valid_rr["RR"].log().std())
rr_lci = math.exp(math.log(rr_point) - z * log_rr_se)
rr_uci = math.exp(math.log(rr_point) + z * log_rr_se)
for followup_t in report_times:
risk = risk_all.filter(pl.col("followup") == followup_t)

risk_by_level = {}
for tx in self.treatment_level:
level_data = risk.filter(pl.col(self.treatment_col) == tx)
risk_by_level[tx] = {"pred": level_data.select(group_cols + ["pred"])}
if has_bootstrap and not use_paired:
risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"])

for tx_x in self.treatment_level:
for tx_y in self.treatment_level:
if tx_x == tx_y:
continue

if use_paired:
boot_x = (
self._boot_risks[tx_x]
.filter(pl.col("followup") == followup_t)
.select(["boot_idx", pl.col("risk").alias("risk_x")])
)
boot_y = (
self._boot_risks[tx_y]
.filter(pl.col("followup") == followup_t)
.select(["boot_idx", pl.col("risk").alias("risk_y")])
)
paired = boot_x.join(boot_y, on="boot_idx").with_columns(
(pl.col("risk_x") - pl.col("risk_y")).alias("RD")
)

risk_x_val = float(risk_by_level[tx_x]["pred"]["pred"][0])
risk_y_val = float(risk_by_level[tx_y]["pred"]["pred"][0])
rd_point = risk_x_val - risk_y_val
rr_point = (
risk_x_val / risk_y_val if risk_y_val != 0 else float("inf")
)

# Filter degenerate RR bootstrap values (risk_y == 0 or negative)
valid_rr = paired.filter(
(pl.col("risk_y") > 0) & (pl.col("risk_x") >= 0)
).with_columns((pl.col("risk_x") / pl.col("risk_y")).alias("RR"))

n_valid_rr = len(valid_rr)

if self.bootstrap_CI_method == "percentile":
rd_lci = float(paired["RD"].quantile(alpha / 2))
rd_uci = float(paired["RD"].quantile(1 - alpha / 2))
if n_valid_rr >= 2:
rr_lci = float(valid_rr["RR"].quantile(alpha / 2))
rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2))
else:
rr_lci = float("nan")
rr_uci = float("nan")
else:
rr_lci = float("nan")
rr_uci = float("nan")

rd_comp = pl.DataFrame(
{
"A_x": [tx_x],
"A_y": [tx_y],
"Risk Difference": [rd_point],
"RD 95% LCI": [rd_lci],
"RD 95% UCI": [rd_uci],
}
)
rr_comp = pl.DataFrame(
{
"A_x": [tx_x],
"A_y": [tx_y],
"Risk Ratio": [rr_point],
"RR 95% LCI": [rr_lci],
"RR 95% UCI": [rr_uci],
}
)
else:
# Fall back to independent delta method
risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"})
risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"})

if group_cols:
comp = risk_x.join(risk_y, on=group_cols, how="left")
else:
comp = risk_x.join(risk_y, how="cross")
rd_se = float(paired["RD"].std())
rd_lci = rd_point - z * rd_se
rd_uci = rd_point + z * rd_se
if n_valid_rr >= 2 and rr_point > 0:
log_rr_se = float(valid_rr["RR"].log().std())
rr_lci = math.exp(math.log(rr_point) - z * log_rr_se)
rr_uci = math.exp(math.log(rr_point) + z * log_rr_se)
else:
rr_lci = float("nan")
rr_uci = float("nan")

comp = comp.with_columns(
[pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
)
rd_comp = pl.DataFrame(
{
"Followup": [followup_t],
"A_x": [tx_x],
"A_y": [tx_y],
"Risk Difference": [rd_point],
"RD 95% LCI": [rd_lci],
"RD 95% UCI": [rd_uci],
}
)
rr_comp = pl.DataFrame(
{
"Followup": [followup_t],
"A_x": [tx_x],
"A_y": [tx_y],
"Risk Ratio": [rr_point],
"RR 95% LCI": [rr_lci],
"RR 95% UCI": [rr_uci],
}
)
else:
# Fall back to independent delta method
risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"})
risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"})

if has_bootstrap:
se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"})
se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"})
if group_cols:
comp = comp.join(se_x, on=group_cols, how="left")
comp = comp.join(se_y, on=group_cols, how="left")
comp = risk_x.join(risk_y, on=group_cols, how="left")
else:
comp = comp.join(se_x, how="cross")
comp = comp.join(se_y, how="cross")
comp = risk_x.join(risk_y, how="cross")

comp = comp.with_columns(
[pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
)

if has_bootstrap:
se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"})
se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"})
if group_cols:
comp = comp.join(se_x, on=group_cols, how="left")
comp = comp.join(se_y, on=group_cols, how="left")
else:
comp = comp.join(se_x, how="cross")
comp = comp.join(se_y, how="cross")

rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols)
rd_comp, rr_comp = _compute_rd_rr(
comp, has_bootstrap, z, group_cols
)
rd_cols = rd_comp.columns
rr_cols = rr_comp.columns
rd_comp = rd_comp.with_columns(
pl.lit(followup_t).alias("Followup")
).select(["Followup"] + rd_cols)
rr_comp = rr_comp.with_columns(
pl.lit(followup_t).alias("Followup")
).select(["Followup"] + rr_cols)

rd_comparisons.append(rd_comp)
rr_comparisons.append(rr_comp)
rd_comparisons.append(rd_comp)
rr_comparisons.append(rr_comp)

risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame()
risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame()
Expand Down
Loading
Loading