Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

__version__ = metadata.version("scib")

from . import integration, metrics, preprocessing, utils
from . import integration, metrics, plotting, preprocessing, utils
from ._package_tools import rename_func
from .metrics import clustering

Expand Down Expand Up @@ -33,3 +33,4 @@
ig = integration
me = metrics
cl = clustering
pl = plotting
141 changes: 141 additions & 0 deletions scib/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt


def metrics(
metrics_df,
method_column="method",
metric_column="metric",
value_column="value",
batch_metrics=None,
bio_metrics=None,
palette=None,
overall=True,
return_fig=False,
):
"""
:param metrics_df: dataframe with columns for methods, metrics and metric values
:param method_column: column in ``metrics_df`` of methods
:param metric_column: column in ``metrics_df`` of metrics
:param value_column: column in ``metrics_df`` with metric values
:param batch_metrics: list of batch correction metrics in metrics column for annotating metric type
:param bio_metrics: list of biological conservation metrics in the metrics column for annotating metric type
:param palette: color map as input for ``seaborn.scatterplot``
:param overall: whether to include a column for the overall score
:param return_fig: whether to return a fig object
"""
sns.set_context("paper")
sns.set_style("white")

if palette is None:
palette = "viridis_r"
# sns.color_palette("ch:s=.25,rot=-.25", as_cmap=True)

if batch_metrics is None:
batch_metrics = ["ASW_batch", "PCR_batch", "graph_conn", "kBET", "iLISI"]

if bio_metrics is None:
bio_metrics = [
"NMI_cluster",
"ARI_cluster",
"ASW_label",
"cell_cycle_conservation",
"isolated_label_F1",
"isolated_label_silhouette",
"cLISI",
"hvg_overlap",
"trajectory",
]

df = metrics_df.copy()

conditions = [
(df[metric_column].isin(batch_metrics)),
(df[metric_column].isin(bio_metrics)),
]
metric_type = ["Batch Correction", "Biological Conservation"]
df["metric_type"] = np.select(conditions, metric_type)
df[metric_column] = df[metric_column].str.replace("_", " ")

# overall score
df_list = [
df,
df.groupby([method_column, "metric_type"])[value_column]
.mean()
.reset_index()
.assign(metric="Overall"),
df.groupby(method_column)[value_column]
.mean()
.reset_index()
.assign(metric_type="Overall", metric="Overall"),
]
df = pd.concat(df_list)

# rank metrics
df["rank"] = (
df.groupby([metric_column, "metric_type"])[value_column]
.rank(
method="min",
ascending=False,
na_option="bottom",
)
.astype(int)
)
method_rank = df.query('metric_type == "Overall"').sort_values(
"rank", ascending=True
)[method_column]
df[method_column] = pd.Categorical(df[method_column], categories=method_rank)

# get plot dimensions
dims = (
df[["metric_type", metric_column]]
.drop_duplicates()["metric_type"]
.value_counts()
)
n_metric_types = dims.shape[0]
n_metrics = dims.sum()
n_methods = df[method_column].nunique()
metric_len = df[metric_column].str.len().max()
dim_x = np.max([4, (n_metrics + n_metric_types) * 0.4 + (metric_len / 10)])
dim_y = np.max([2.5, n_methods * 0.9])

# Build plot
fig, axs = plt.subplots(
nrows=1,
ncols=n_metric_types,
figsize=(dim_x, dim_y),
sharey=True,
gridspec_kw=dict(width_ratios=list(dims)),
)

for i, metric_type in enumerate(dims.index):
legend = "brief" if i == 0 else None
df_sub = df.query(f'metric_type == "{metric_type}"')
ax = axs if n_metric_types == 1 else axs[i]
sns.scatterplot(
data=df_sub,
x=metric_column,
y=method_column,
hue="rank",
palette=palette,
size=value_column,
sizes=(df_sub["value"].min() * 100, df_sub["value"].max() * 100),
# sizes={x: int(x * 200) for x in df_sub['value'].dropna().unique()},
edgecolor="black",
legend=legend,
ax=ax,
)
ax.set(title=metric_type, xlabel=None, ylabel=None)
ax.tick_params(axis="x", rotation=90)
if legend is not None:
ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)
for t in ax.legend_.texts:
t.set_text(t.get_text()[:5])
sns.despine(bottom=True, left=True)

fig.tight_layout()

if return_fig:
return fig
44 changes: 44 additions & 0 deletions tests/plots/test_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import pandas as pd

import scib


def test_plot():
data = {
"ARI_cluster": {0: 0.951112722518898, 1: 0.262192519680191, 2: 0.1, 3: 0.2},
"ASW_batch": {0: 0.9057019050549192, 1: 0.8448200803913499, 2: 0.7, 3: 0.3},
"ASW_label": {0: 0.617242477834225, 1: 0.564448088407517, 2: 0.4, 3: 0.7},
"NMI_cluster": {0: 0.9138665032024672, 1: 0.632615412598558, 2: 0.4, 3: 0.7},
"PCR_batch": {0: 0.855878437307926, 1: 0.7125446098053699, 2: 0.6, 3: 0.5},
"cLISI": {0: 1.0, 1: 0.9993835933509928, 2: 0.8, 3: 0.9},
"cell_cycle_conservation": {
0: 0.470498471863989,
1: 0.741363581608263,
2: 0.6,
3: 0.8,
},
"graph_conn": {0: 0.971955345243732, 1: 0.944989571511962, 2: 0.8, 3: 0.7},
"hvg_overlap": {0: 0.4772209890553079, 1: 0.2025893518406739, 2: 0.1, 3: 0.2},
"iLISI": {0: 0.07924053136125, 1: 0.004064867867098, 2: 0.1, 3: 0.2},
"isolated_label_F1": {
0: 0.107692307692308,
1: 0.106870229007634,
2: 0.1,
3: 0.3,
},
"isolated_label_silhouette": {
0: 0.520902156829834,
1: 0.550404392182827,
2: 0.4,
3: 0.6,
},
"kBET": {0: 0.3197709591957574, 1: 0.2183332674192387, 2: 0.1, 3: 0.2},
"method": {0: "method1", 1: "method2", 2: "method3", 3: "method4"},
"trajectory": {0: np.nan, 1: np.nan, 2: np.nan, 3: np.nan},
}

df = pd.DataFrame(data)
df = df.melt(id_vars=["method"], var_name="metric", value_name="value")
scib.pl.metrics(df)
scib.pl.metrics(df[0:1])
Loading