diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 0000000..23a7d8b --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,8 @@ +/* Reduce the height of the piccolo_theme top navigation bar */ +:root { + --navbarHeight: 3.25rem; +} + +div#top_nav nav { + padding: 0.7rem 1rem; +} diff --git a/docs/conf.py b/docs/conf.py index eaa3ba4..aa8fa75 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,3 +50,4 @@ html_theme = "piccolo_theme" html_static_path = ["_static"] +html_css_files = ["custom.css"] diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 6a6ae88..4392b7e 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -78,6 +78,10 @@ class SEQopts: :type plot_title: str :param plot_type: Type of plot to show ["risk", "survival" or "incidence" if compevent is specified] :type plot_type: str + :param risk_times: Followup times at which to report risk difference and risk ratio when ``km_curves = True``. + Each requested time is snapped to the latest available followup at or before it, and the maximum + followup is always included. Defaults to ``None`` (report at the maximum followup only). + :type risk_times: Optional[List[float]] or None :param seed: RNG seed :type seed: int :param selection_first_trial: Boolean to only use first trial for analysis (similar to non-expanded) @@ -150,6 +154,7 @@ class SEQopts: plot_labels: List[str] = field(default_factory=lambda: []) plot_title: str = None plot_type: Literal["risk", "survival", "incidence"] = "survival" + risk_times: Optional[List[float]] = None seed: Optional[int] = None selection_first_trial: bool = False selection_sample: float = 0.8 @@ -210,6 +215,14 @@ def _validate_ranges(self): raise ValueError( f"followup_min ({self.followup_min}) must be less than followup_max ({self.followup_max})." ) + if self.risk_times is not None: + times = ( + self.risk_times + if isinstance(self.risk_times, (list, tuple)) + else [self.risk_times] + ) + if any(not isinstance(t, (int, float)) or t < 0 for t in times): + raise ValueError("risk_times values must be non-negative numbers.") def _validate_choices(self): if self.plot_type not in ["risk", "survival", "incidence"]: diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index ed049a4..af91176 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -1,9 +1,11 @@ import re import numpy as np +import pandas as pd import polars as pl import statsmodels.api as sm import statsmodels.formula.api as smf +from pandas.api.types import is_numeric_dtype def _compute_spline_knots(followup_arr, df=3): @@ -40,6 +42,20 @@ def _apply_spline_formula(formula, indicator_squared, spline_knots): return spline +def _categorical_tv_columns(self, df_pd): + """ + Names of the categorical (non-numeric) time-varying covariate columns + present in ``df_pd``, including their baseline (``indicator_baseline``) + versions used by the outcome model. + """ + cols = [] + for col in self.time_varying_cols or []: + for variant in (col, f"{col}{self.indicator_baseline}"): + if variant in df_pd.columns and not is_numeric_dtype(df_pd[variant]): + cols.append(variant) + return cols + + def _cast_categories(self, df_pd): if self.treatment_col in df_pd.columns: df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category") @@ -58,6 +74,22 @@ def _cast_categories(self, df_pd): if col in df_pd.columns: df_pd[col] = df_pd[col].astype("category") + # Stable factor encoding for categorical time-varying covariates: fix the + # level set from the full expanded data (captured on the non-bootstrap + # pass) so a bootstrap resample cannot realise a different set of levels — + # otherwise a level absent from the resample would be unknown to that fit + # and crash counterfactual prediction with NaNs. + tv_cat_cols = _categorical_tv_columns(self, df_pd) + if getattr(self, "_current_boot_idx", None) is None: + cats = getattr(self, "_covariate_categories", {}) + for col in tv_cat_cols: + cats[col] = sorted(df_pd[col].dropna().unique().tolist()) + self._covariate_categories = cats + cats = getattr(self, "_covariate_categories", {}) + for col in tv_cat_cols: + if col in cats: + df_pd[col] = pd.Categorical(df_pd[col], categories=cats[col]) + return df_pd diff --git a/pySEQTarget/analysis/_risk_estimates.py b/pySEQTarget/analysis/_risk_estimates.py index 561179c..d8a6a0a 100644 --- a/pySEQTarget/analysis/_risk_estimates.py +++ b/pySEQTarget/analysis/_risk_estimates.py @@ -73,10 +73,42 @@ def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None): return rd_comp, rr_comp +def _resolve_risk_times(grid, risk_times): + """ + Snap each requested risk time to the latest available followup at or before + it, always including the maximum followup. Returns a sorted list of followup + values that exist in ``grid``. + """ + grid = sorted(set(grid)) + final = grid[-1] + + if risk_times is None: + return [final] + + req = risk_times if isinstance(risk_times, (list, tuple)) else [risk_times] + req = [float(t) for t in req if t is not None] + if not req: + return [final] + + above = [t for t in req if t > final] + if above: + raise ValueError( + f"risk_times value(s) exceed the maximum followup ({final}): {above}" + ) + below = [t for t in req if t < grid[0]] + if below: + raise ValueError( + f"risk_times value(s) below the minimum followup ({grid[0]}): {below}" + ) + + snapped = [max(g for g in grid if g <= t) for t in req] + return sorted(set(snapped + [final])) + + def _risk_estimates(self): - last_followup = self.km_data["followup"].max() - risk = self.km_data.filter( - (pl.col("followup") == last_followup) & (pl.col("estimate") == "risk") + risk_all = self.km_data.filter(pl.col("estimate") == "risk") + report_times = _resolve_risk_times( + risk_all["followup"].unique().to_list(), self.risk_times ) group_cols = [self.subgroup_colname] if self.subgroup_colname else [] @@ -101,115 +133,132 @@ def _risk_estimates(self): z = None alpha = None - risk_by_level = {} - for tx in self.treatment_level: - level_data = risk.filter(pl.col(self.treatment_col) == tx) - risk_by_level[tx] = {"pred": level_data.select(group_cols + ["pred"])} - if has_bootstrap and not use_paired: - risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"]) - rd_comparisons = [] rr_comparisons = [] - for tx_x in self.treatment_level: - for tx_y in self.treatment_level: - if tx_x == tx_y: - continue - - if use_paired: - boot_x = ( - self._boot_risks[tx_x] - .filter(pl.col("followup") == last_followup) - .select(["boot_idx", pl.col("risk").alias("risk_x")]) - ) - boot_y = ( - self._boot_risks[tx_y] - .filter(pl.col("followup") == last_followup) - .select(["boot_idx", pl.col("risk").alias("risk_y")]) - ) - paired = boot_x.join(boot_y, on="boot_idx").with_columns( - (pl.col("risk_x") - pl.col("risk_y")).alias("RD") - ) - - risk_x_val = float(risk_by_level[tx_x]["pred"]["pred"][0]) - risk_y_val = float(risk_by_level[tx_y]["pred"]["pred"][0]) - rd_point = risk_x_val - risk_y_val - rr_point = risk_x_val / risk_y_val if risk_y_val != 0 else float("inf") - - # Filter degenerate RR bootstrap values (risk_y == 0 or negative) - valid_rr = paired.filter( - (pl.col("risk_y") > 0) & (pl.col("risk_x") >= 0) - ).with_columns((pl.col("risk_x") / pl.col("risk_y")).alias("RR")) - - n_valid_rr = len(valid_rr) - - if self.bootstrap_CI_method == "percentile": - rd_lci = float(paired["RD"].quantile(alpha / 2)) - rd_uci = float(paired["RD"].quantile(1 - alpha / 2)) - if n_valid_rr >= 2: - rr_lci = float(valid_rr["RR"].quantile(alpha / 2)) - rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2)) - else: - rr_lci = float("nan") - rr_uci = float("nan") - else: - rd_se = float(paired["RD"].std()) - rd_lci = rd_point - z * rd_se - rd_uci = rd_point + z * rd_se - if n_valid_rr >= 2 and rr_point > 0: - log_rr_se = float(valid_rr["RR"].log().std()) - rr_lci = math.exp(math.log(rr_point) - z * log_rr_se) - rr_uci = math.exp(math.log(rr_point) + z * log_rr_se) + for followup_t in report_times: + risk = risk_all.filter(pl.col("followup") == followup_t) + + risk_by_level = {} + for tx in self.treatment_level: + level_data = risk.filter(pl.col(self.treatment_col) == tx) + risk_by_level[tx] = {"pred": level_data.select(group_cols + ["pred"])} + if has_bootstrap and not use_paired: + risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"]) + + for tx_x in self.treatment_level: + for tx_y in self.treatment_level: + if tx_x == tx_y: + continue + + if use_paired: + boot_x = ( + self._boot_risks[tx_x] + .filter(pl.col("followup") == followup_t) + .select(["boot_idx", pl.col("risk").alias("risk_x")]) + ) + boot_y = ( + self._boot_risks[tx_y] + .filter(pl.col("followup") == followup_t) + .select(["boot_idx", pl.col("risk").alias("risk_y")]) + ) + paired = boot_x.join(boot_y, on="boot_idx").with_columns( + (pl.col("risk_x") - pl.col("risk_y")).alias("RD") + ) + + risk_x_val = float(risk_by_level[tx_x]["pred"]["pred"][0]) + risk_y_val = float(risk_by_level[tx_y]["pred"]["pred"][0]) + rd_point = risk_x_val - risk_y_val + rr_point = ( + risk_x_val / risk_y_val if risk_y_val != 0 else float("inf") + ) + + # Filter degenerate RR bootstrap values (risk_y == 0 or negative) + valid_rr = paired.filter( + (pl.col("risk_y") > 0) & (pl.col("risk_x") >= 0) + ).with_columns((pl.col("risk_x") / pl.col("risk_y")).alias("RR")) + + n_valid_rr = len(valid_rr) + + if self.bootstrap_CI_method == "percentile": + rd_lci = float(paired["RD"].quantile(alpha / 2)) + rd_uci = float(paired["RD"].quantile(1 - alpha / 2)) + if n_valid_rr >= 2: + rr_lci = float(valid_rr["RR"].quantile(alpha / 2)) + rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2)) + else: + rr_lci = float("nan") + rr_uci = float("nan") else: - rr_lci = float("nan") - rr_uci = float("nan") - - rd_comp = pl.DataFrame( - { - "A_x": [tx_x], - "A_y": [tx_y], - "Risk Difference": [rd_point], - "RD 95% LCI": [rd_lci], - "RD 95% UCI": [rd_uci], - } - ) - rr_comp = pl.DataFrame( - { - "A_x": [tx_x], - "A_y": [tx_y], - "Risk Ratio": [rr_point], - "RR 95% LCI": [rr_lci], - "RR 95% UCI": [rr_uci], - } - ) - else: - # Fall back to independent delta method - risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"}) - risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"}) - - if group_cols: - comp = risk_x.join(risk_y, on=group_cols, how="left") - else: - comp = risk_x.join(risk_y, how="cross") + rd_se = float(paired["RD"].std()) + rd_lci = rd_point - z * rd_se + rd_uci = rd_point + z * rd_se + if n_valid_rr >= 2 and rr_point > 0: + log_rr_se = float(valid_rr["RR"].log().std()) + rr_lci = math.exp(math.log(rr_point) - z * log_rr_se) + rr_uci = math.exp(math.log(rr_point) + z * log_rr_se) + else: + rr_lci = float("nan") + rr_uci = float("nan") - comp = comp.with_columns( - [pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")] - ) + rd_comp = pl.DataFrame( + { + "Followup": [followup_t], + "A_x": [tx_x], + "A_y": [tx_y], + "Risk Difference": [rd_point], + "RD 95% LCI": [rd_lci], + "RD 95% UCI": [rd_uci], + } + ) + rr_comp = pl.DataFrame( + { + "Followup": [followup_t], + "A_x": [tx_x], + "A_y": [tx_y], + "Risk Ratio": [rr_point], + "RR 95% LCI": [rr_lci], + "RR 95% UCI": [rr_uci], + } + ) + else: + # Fall back to independent delta method + risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"}) + risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"}) - if has_bootstrap: - se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"}) - se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"}) if group_cols: - comp = comp.join(se_x, on=group_cols, how="left") - comp = comp.join(se_y, on=group_cols, how="left") + comp = risk_x.join(risk_y, on=group_cols, how="left") else: - comp = comp.join(se_x, how="cross") - comp = comp.join(se_y, how="cross") + comp = risk_x.join(risk_y, how="cross") + + comp = comp.with_columns( + [pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")] + ) + + if has_bootstrap: + se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"}) + se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"}) + if group_cols: + comp = comp.join(se_x, on=group_cols, how="left") + comp = comp.join(se_y, on=group_cols, how="left") + else: + comp = comp.join(se_x, how="cross") + comp = comp.join(se_y, how="cross") - rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols) + rd_comp, rr_comp = _compute_rd_rr( + comp, has_bootstrap, z, group_cols + ) + rd_cols = rd_comp.columns + rr_cols = rr_comp.columns + rd_comp = rd_comp.with_columns( + pl.lit(followup_t).alias("Followup") + ).select(["Followup"] + rd_cols) + rr_comp = rr_comp.with_columns( + pl.lit(followup_t).alias("Followup") + ).select(["Followup"] + rr_cols) - rd_comparisons.append(rd_comp) - rr_comparisons.append(rr_comp) + rd_comparisons.append(rd_comp) + rr_comparisons.append(rr_comp) risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame() risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame() diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index ac9a4ef..c9e17a7 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -132,8 +132,13 @@ def _calculate_risk(self, data, idx=None, val=None): lci = a / 2 uci = 1 - lci - # Pre-compute the followup range once (starts at 1, not 0) - followup_range = list(range(1, self.followup_max + 1)) + # Predict the hazard on the full followup grid starting at 0 — the first + # interval of every trial, where an event can already occur. Curve labels + # are shifted +1 after the cumulative product (below) so that followup=k + # means "survival/risk after k elapsed intervals", giving rows + # 0..followup_max+1. This matches SEQTaRget (R); starting the grid at 1 + # silently dropped the first interval's hazard and ended one step short. + followup_range = list(range(0, self.followup_max + 1)) SDT = ( data.with_columns( @@ -223,6 +228,7 @@ def _calculate_risk(self, data, idx=None, val=None): TxDT.group_by("followup") .agg([pl.col(col).mean() for col in surv_names + inc_names]) .sort("followup") + .with_columns(pl.col("followup") + 1) ) main_col = "surv" boot_cols = [col for col in surv_names if col != "surv"] @@ -242,6 +248,7 @@ def _calculate_risk(self, data, idx=None, val=None): .agg([pl.col(col).mean() for col in outcome_names]) .sort("followup") .with_columns([(1 - pl.col(col)).alias(col) for col in outcome_names]) + .with_columns(pl.col("followup") + 1) ) main_col = "pred_outcome" boot_cols = [col for col in outcome_names if col != "pred_outcome"] diff --git a/pyproject.toml b/pyproject.toml index 0eb1343..0f4c31f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.13.4" +version = "0.13.5" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} diff --git a/tests/test_categorical_covariates.py b/tests/test_categorical_covariates.py new file mode 100644 index 0000000..5aa4019 --- /dev/null +++ b/tests/test_categorical_covariates.py @@ -0,0 +1,72 @@ +import numpy as np +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _model(data, **opts): + return SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P", "grp"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, **opts), + ) + + +def test_string_time_varying_covariate_bootstrap(): + """A categorical (string) time-varying covariate should run through the full + bootstrap pipeline and produce risk estimates.""" + data = load_data("SEQdata") + rng = np.random.RandomState(1) + data = data.with_columns( + pl.Series("grp", rng.choice(["a", "b", "c"], size=data.height)) + ) + + s = _model(data, bootstrap_nboot=3, seed=42) + s.expand() + s.bootstrap() + s.fit() + s.survival() + + assert "grp_bas" in s.DT.columns + rd = s.risk_estimates["risk_difference"] + assert rd.height > 0 + assert rd.select(["RD 95% LCI", "RD 95% UCI"]).null_count().to_series().sum() == 0 + + +def test_rare_level_not_dropped_by_bootstrap_resample(): + """A rare categorical level absent from some bootstrap resamples must not + crash counterfactual prediction. The full-data level set is fixed at fit + time so every resample shares a stable factor encoding.""" + data = load_data("SEQdata") + ids = data["ID"].unique().to_list() + rng = np.random.RandomState(0) + # Level "c" appears for a single ID only, so aggressive subsampling will + # produce resamples that omit it entirely. + grp = pl.Series( + "grp", + np.where( + np.isin(data["ID"].to_numpy(), [ids[0]]), + "c", + rng.choice(["a", "b"], size=data.height), + ), + ) + data = data.with_columns(grp) + + s = _model(data, bootstrap_nboot=8, bootstrap_sample=0.5, seed=7) + s.expand() + s.bootstrap() + s.fit() + s.survival() # previously raised ValueError on NaN predictions + + rd = s.risk_estimates["risk_difference"] + assert rd.height > 0 + assert not rd["RD 95% LCI"].is_nan().any() + assert not rd["RD 95% UCI"].is_nan().any() diff --git a/tests/test_survival.py b/tests/test_survival.py index d9a577f..183dc32 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -1,11 +1,125 @@ import os +import polars as pl import pytest from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data +def _final_followup(s): + return s.km_data.filter(pl.col("estimate") == "risk")["followup"].max() + + +def test_risk_times_reports_requested_followups(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + km_curves=True, risk_times=[2, 5], bootstrap_nboot=3, seed=42 + ), + ) + s.expand() + s.bootstrap() + s.fit() + s.survival() + + final = _final_followup(s) + rd = s.risk_estimates["risk_difference"] + rr = s.risk_estimates["risk_ratio"] + + assert "Followup" in rd.columns + assert set(rd["Followup"].to_list()) == {2, 5, final} + assert set(rr["Followup"].to_list()) == {2, 5, final} + for col in ["RD 95% LCI", "RD 95% UCI"]: + assert rd[col].null_count() == 0 + + +def test_risk_times_default_reports_only_final(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True), + ) + s.expand() + s.fit() + s.survival() + + assert set(s.risk_estimates["risk_difference"]["Followup"].to_list()) == { + _final_followup(s) + } + + +def test_risk_times_snaps_to_grid(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, risk_times=[2.5]), + ) + s.expand() + s.fit() + s.survival() + + # 2.5 snaps down to 2; final followup is always included + assert set(s.risk_estimates["risk_difference"]["Followup"].to_list()) == { + 2, + _final_followup(s), + } + + +def test_risk_times_exceeding_max_raises(): + data = load_data("SEQdata") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, risk_times=[1e6]), + ) + s.expand() + s.fit() + with pytest.raises(ValueError, match="maximum followup"): + s.survival() + + +def test_risk_times_negative_rejected(): + with pytest.raises(ValueError, match="non-negative"): + SEQopts(km_curves=True, risk_times=[-1]) + + def test_regular_survival(): data = load_data("SEQdata")