diff --git a/dotpy/core.py b/dotpy/core.py index 9810a1e..e8732e6 100644 --- a/dotpy/core.py +++ b/dotpy/core.py @@ -21,12 +21,18 @@ def __init__( device: Optional[str] = None, ): # --- Gene alignment --- - common_genes = np.intersect1d(spatial['genes'], ref['genes']) + spatial_genes = np.asarray(spatial['genes']) + ref_genes = np.asarray(ref['genes']) + common_genes = np.intersect1d(spatial_genes, ref_genes) if len(common_genes) == 0: raise ValueError("No common genes found between spatial and reference data") - sp_idx = np.where(np.isin(spatial['genes'], common_genes))[0] - rf_idx = np.where(np.isin(ref['genes'], common_genes))[0] + # Index both matrices in the same explicit order. Filtering each array with + # np.isin() preserves its original order, which may differ after DE ranking. + sp_lookup = {gene: i for i, gene in enumerate(spatial_genes)} + rf_lookup = {gene: i for i, gene in enumerate(ref_genes)} + sp_idx = np.fromiter((sp_lookup[g] for g in common_genes), dtype=np.int64) + rf_idx = np.fromiter((rf_lookup[g] for g in common_genes), dtype=np.int64) X_sp = spatial['X_sparse'][:, sp_idx] if issparse(spatial['X_sparse']) \ else spatial['X_sparse'][:, sp_idx] @@ -36,7 +42,7 @@ def __init__( self.spatial = { 'X_sparse': X_sp, 'coords': spatial['coords'], - 'genes': spatial['genes'][sp_idx], + 'genes': common_genes, 'device': device or 'cpu', } if 'pairs' in spatial: @@ -46,7 +52,7 @@ def __init__( 'X_sparse': X_rf, 'clusters': ref['clusters'], 'ratios': ref['ratios'], - 'genes': ref['genes'][rf_idx], + 'genes': common_genes, 'device': device or 'cpu', } @@ -224,6 +230,9 @@ def _run_optimisation( X_ref = torch.from_numpy(X_ref_np).to(device) # C × G X_sp = torch.from_numpy(X_sp_np).to(device) # S × G X_ref_norm = F.normalize(X_ref, p=2, dim=1) # C × G (L2-normed rows) + # R DOT computes ST_Xn <- normalize(ST_X) once and uses it in both + # spot-wise cosine terms. + X_sp_row_norm = F.normalize(X_sp, p=2, dim=1) # S × G c2m = torch.from_numpy(cluster_to_major).to(device) r_sc_t = torch.from_numpy(r_sc).to(device) @@ -286,13 +295,12 @@ def _run_optimisation( mix_weight = 0.1 Yt = Yt * mix_weight - X_sp_n = F.normalize(X_sp, p=2, dim=1) - linear_dcosine = 1 - X_ref_norm @ X_sp_n.T + linear_dcosine = 1 - X_ref_norm @ X_sp_row_norm.T c_min = linear_dcosine.argmin(dim=0) s_idx = torch.arange(S, device=device) Yt[c_min, s_idx] += (1 - mix_weight) * r_st_t - del X_sp_n, linear_dcosine + del linear_dcosine # ============================================================ # 4. Optimisation loop @@ -349,6 +357,7 @@ def _run_optimisation( s1 = min(s0 + batch, S) Yt_b = Yt[:, s0:s1] # C × b Xsp_b = X_sp[s0:s1] # b × G + Xsp_b_n = X_sp_row_norm[s0:s1] # b × G # predicted expression b × G if compute_dtype == torch.float16: @@ -362,14 +371,14 @@ def _run_optimisation( norms = st_xt.norm(dim=1, keepdim=True).clamp(min=1e-10) st_xt_n = st_xt / norms - csi = (Xsp_b * st_xt_n).sum(dim=1) + csi = (Xsp_b_n * st_xt_n).sum(dim=1) di = (1 - csi).clamp(min=0) d_i_grad = _sqrt_env_grad(di) di_sqrt = _sqrt_env(di) dcosine_st += di_sqrt.sum().item() coef = l_i * (1 - sparsity_coef) - st_de = coef * (Xsp_b - st_xt_n * csi.unsqueeze(1)) \ + st_de = coef * (Xsp_b_n - st_xt_n * csi.unsqueeze(1)) \ * d_i_grad.unsqueeze(1) / norms if compute_dtype == torch.float16: @@ -379,7 +388,6 @@ def _run_optimisation( # -- linear sparsity -- if l_sp > 0: - Xsp_b_n = F.normalize(Xsp_b, p=2, dim=1) if compute_dtype == torch.float16: lin_d = (1 - X_ref_norm.half() @ Xsp_b_n.half().T).float() else: