diff --git a/openavmkit/vertical_equity_study.py b/openavmkit/vertical_equity_study.py index b5e3728..37191f3 100644 --- a/openavmkit/vertical_equity_study.py +++ b/openavmkit/vertical_equity_study.py @@ -107,7 +107,8 @@ def __init__( sales = df_sales[field_sales].to_numpy() results = calc_ratio_stats_bootstrap(predictions, sales, confidence_interval, iterations=iterations, seed=seed) - + self.prd = results["prd"] + prb_point, prb_low, prb_high = calc_prb(predictions, sales, confidence_interval) self.prb = ConfidenceStat(prb_point, confidence_interval, prb_low, prb_high) @@ -232,6 +233,7 @@ def _calc_quantiles(df: pd.DataFrame, field: str): def _calc_grouped_quantiles(df_in: pd.DataFrame, value_field: str, group_field: str): df = df_in.copy() df["quantile"] = _calc_quantiles(df, value_field) + df_group_to_quantile = df.groupby(group_field)["quantile"].agg(lambda x: pd.Series.mode(x)[0]).reset_index() df2 = df_in.merge(df_group_to_quantile, on=group_field, how="left") return df2["quantile"] @@ -269,4 +271,4 @@ def _assemble_quantile_df(df_in: pd.DataFrame, field_sales: str, field_predictio df = pd.DataFrame(data=data) df = df.sort_values(by="quantile", key=lambda col: col.astype(int)) - return df \ No newline at end of file + return df