Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions dotpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,18 @@ def __init__(
device: Optional[str] = None,
):
# --- Gene alignment ---
common_genes = np.intersect1d(spatial['genes'], ref['genes'])
spatial_genes = np.asarray(spatial['genes'])
ref_genes = np.asarray(ref['genes'])
common_genes = np.intersect1d(spatial_genes, ref_genes)
if len(common_genes) == 0:
raise ValueError("No common genes found between spatial and reference data")

sp_idx = np.where(np.isin(spatial['genes'], common_genes))[0]
rf_idx = np.where(np.isin(ref['genes'], common_genes))[0]
# Index both matrices in the same explicit order. Filtering each array with
# np.isin() preserves its original order, which may differ after DE ranking.
sp_lookup = {gene: i for i, gene in enumerate(spatial_genes)}
rf_lookup = {gene: i for i, gene in enumerate(ref_genes)}
sp_idx = np.fromiter((sp_lookup[g] for g in common_genes), dtype=np.int64)
rf_idx = np.fromiter((rf_lookup[g] for g in common_genes), dtype=np.int64)

X_sp = spatial['X_sparse'][:, sp_idx] if issparse(spatial['X_sparse']) \
else spatial['X_sparse'][:, sp_idx]
Expand All @@ -36,7 +42,7 @@ def __init__(
self.spatial = {
'X_sparse': X_sp,
'coords': spatial['coords'],
'genes': spatial['genes'][sp_idx],
'genes': common_genes,
'device': device or 'cpu',
}
if 'pairs' in spatial:
Expand All @@ -46,7 +52,7 @@ def __init__(
'X_sparse': X_rf,
'clusters': ref['clusters'],
'ratios': ref['ratios'],
'genes': ref['genes'][rf_idx],
'genes': common_genes,
'device': device or 'cpu',
}

Expand Down Expand Up @@ -224,6 +230,9 @@ def _run_optimisation(
X_ref = torch.from_numpy(X_ref_np).to(device) # C × G
X_sp = torch.from_numpy(X_sp_np).to(device) # S × G
X_ref_norm = F.normalize(X_ref, p=2, dim=1) # C × G (L2-normed rows)
# R DOT computes ST_Xn <- normalize(ST_X) once and uses it in both
# spot-wise cosine terms.
X_sp_row_norm = F.normalize(X_sp, p=2, dim=1) # S × G

c2m = torch.from_numpy(cluster_to_major).to(device)
r_sc_t = torch.from_numpy(r_sc).to(device)
Expand Down Expand Up @@ -286,13 +295,12 @@ def _run_optimisation(
mix_weight = 0.1
Yt = Yt * mix_weight

X_sp_n = F.normalize(X_sp, p=2, dim=1)
linear_dcosine = 1 - X_ref_norm @ X_sp_n.T
linear_dcosine = 1 - X_ref_norm @ X_sp_row_norm.T
c_min = linear_dcosine.argmin(dim=0)
s_idx = torch.arange(S, device=device)
Yt[c_min, s_idx] += (1 - mix_weight) * r_st_t

del X_sp_n, linear_dcosine
del linear_dcosine

# ============================================================
# 4. Optimisation loop
Expand Down Expand Up @@ -349,6 +357,7 @@ def _run_optimisation(
s1 = min(s0 + batch, S)
Yt_b = Yt[:, s0:s1] # C × b
Xsp_b = X_sp[s0:s1] # b × G
Xsp_b_n = X_sp_row_norm[s0:s1] # b × G

# predicted expression b × G
if compute_dtype == torch.float16:
Expand All @@ -362,14 +371,14 @@ def _run_optimisation(
norms = st_xt.norm(dim=1, keepdim=True).clamp(min=1e-10)
st_xt_n = st_xt / norms

csi = (Xsp_b * st_xt_n).sum(dim=1)
csi = (Xsp_b_n * st_xt_n).sum(dim=1)
di = (1 - csi).clamp(min=0)
d_i_grad = _sqrt_env_grad(di)
di_sqrt = _sqrt_env(di)
dcosine_st += di_sqrt.sum().item()

coef = l_i * (1 - sparsity_coef)
st_de = coef * (Xsp_b - st_xt_n * csi.unsqueeze(1)) \
st_de = coef * (Xsp_b_n - st_xt_n * csi.unsqueeze(1)) \
* d_i_grad.unsqueeze(1) / norms

if compute_dtype == torch.float16:
Expand All @@ -379,7 +388,6 @@ def _run_optimisation(

# -- linear sparsity --
if l_sp > 0:
Xsp_b_n = F.normalize(Xsp_b, p=2, dim=1)
if compute_dtype == torch.float16:
lin_d = (1 - X_ref_norm.half() @ Xsp_b_n.half().T).float()
else:
Expand Down