diff --git a/tools/Complex_generative/cellOT_v1/Cellot_v1_sj.md b/tools/Complex_generative/cellOT_v1/Cellot_v1_sj.md new file mode 100644 index 0000000..3428808 --- /dev/null +++ b/tools/Complex_generative/cellOT_v1/Cellot_v1_sj.md @@ -0,0 +1,80 @@ +# Cellot Model Training and Evaluation: OOD Workflow + +This document provides a step-by-step guide to preparing data, training the Cellot model in Out-of-Distribution (OOD) mode, and evaluating model predictions. This process involves custom modifications to the Cellot codebase to address specific requirements and improve model functionality (loss ouputs and anndata with ctrl/stim/pred). + +--- + +## 1. Data Preparation + +The Cellot model requires an AnnData object that contains information for two conditions: + - **Control condition** (e.g., `ctrl`) + - **Perturbed condition** (e.g., `stim`) + +### Data Requirements +- **Data format**: The data should be in an AnnData structure compatible with single-cell analysis tools like Scanpy. Each observation in AnnData should include metadata in the `.obs` attribute, including cell type and condition. +- **Issue with Pertpy anndata objects**: You need to save the metadata and matrix counts of pertpy datasets (in the pertpy environment, anndata version 0.10.8) and then rebuild the anndata objects in the cellOT environment with anndata version 0.7.6, which does not read anndata objects saved in 0.10.8. + +- **Data Normalization**: + - Normalization is essential for consistent model performance. Using the `normalize_total` and `log1p` function in Scanpy. + - **Scaling**: its specific impact on Cellot is still being evaluated, though standardizing features across cells may be beneficial for model performance..don't know yet + +After preparing and normalizing data, save the AnnData object for input into the Cellot model training and evaluation scripts. + + +--- + +## 2. Training the Model with `cellot_train_v3_ood.py` + +### Environment Setup +Follow the environment setup instructions from the Cellot GitHub repository to ensure dependencies are properly installed. Specifically, a Conda environment is recommended for managing dependencies. + +### Custom Code Modifications +To handle specific issues encountered during model training, some modifications were made to the Cellot source code. These adjustments enhance compatibility with my OOD training and include: + - **Files modified**: + - `cellot.data.cell` + - `cellot.models.cellot` + - `cellot.networks.icnns` + - `cellot.train.train` + +### Training Configuration +The Cellot OOD training script (`cellot_train_v3_ood.py`) includes a loop to automatically train individual models for each cell type. The key training parameters include: +- **Condition column** (`condition`): Defines the grouping of data into control and perturbed conditions. +- **Source and target conditions**: These specify the training setup. For example, `source='ctrl'` and `target='stim'`. +- **Epochs and batch size**: Standard parameters for deep learning models. +- **Holdout cell type** (`datasplit_holdout`): Specifies the cell type excluded from training for each OOD model. + +### Running the Training Script +Run the `cellot_train_v3_ood.py` script after verifying all dependencies and data requirements. This script will: +1. Train the model for each specified cell type in OOD mode, using other cell types as training data. +2. Save models and training outputs, including loss tracking. + +**Note**: Loss curves for transport functions are recorded and can be plotted at the end of each training session, though further integration into the loop is in progress. + +--- + +## 3. Model Evaluation with `cellot_eval_v3_ood.py` + +Once models are trained, the `cellot_eval_v3_ood.py` script enables evaluation of each cell-type-specific model. Evaluation includes visualizing predictions and computing performance metrics. + +### Evaluation Outputs +1. **Dimensionality Reduction (PCA and UMAP)**: + - The script generates PCA and UMAP visualizations, specifically for the holdout cell type (excluded during training) for each trained model. + - These plots allow direct visual inspection of predicted cell distributions compared to actual data, providing insights into the model's performance in the OOD setting. + +2. **Performance Metrics**: + - **R² Score**: Calculating R² for the predicted versus actual values helps quantify the model’s prediction accuracy for each cell type. + - **Transport Distance**: The distances (euclidian, edistance and mmd), transport function metrics, assesses how accurately the model translates control cells into their perturbed states. + +The evaluation process allows detailed analysis of each model's performance per cell type, facilitating further adjustments and optimization of model parameters. + +--- + +### Summary of Parameters and Model Configurations + +- **`condition`**: Defines the control and target conditions in the dataset. +- **`datasplit_mode`**: When set to `ood`, this parameter splits the data to ensure the holdout cell type is excluded from the training data. +- **`datasplit_groupby`**: Controls grouping for the data split. For example, setting `['celltype','condition']` splits based on both cell type and condition. + +--- + +This document provides the foundational steps for leveraging Cellot’s OOD training mode. Adapting the model to specific datasets and further optimizing parameters will enhance performance, especially with custom data configurations. It's still ugly with redundancy. diff --git a/tools/Complex_generative/cellOT_v1/cellot_eval_v3_ood.py b/tools/Complex_generative/cellOT_v1/cellot_eval_v3_ood.py new file mode 100644 index 0000000..0fbcf53 --- /dev/null +++ b/tools/Complex_generative/cellOT_v1/cellot_eval_v3_ood.py @@ -0,0 +1,715 @@ +import torch +from pathlib import Path +import pandas as pd +import scanpy as sc +from cellot.models.cellot import load_cellot_model, compute_loss_g, compute_loss_f, compute_w2_distance +import anndata +import numpy as np +import os +import matplotlib.pyplot as plt +from types import SimpleNamespace +from sklearn.metrics import r2_score +print(os.getcwd()) + + + +class ConfigNamespace(SimpleNamespace): + def get(self, key, default=None): + return getattr(self, key, default) + + def to_dict(self): + """ + Recursively converts the ConfigNamespace object into a dictionary. + """ + result = {} + for key, value in self.__dict__.items(): + if isinstance(value, ConfigNamespace): + result[key] = value.to_dict() + else: + result[key] = value + return result + + def as_dict(self): + """Returns the instance as a dictionary for secure use in code""" + return self.to_dict() + + def __contains__(self, key): + return key in self.__dict__ + +# transform a dictionnary in ConfigNamespace +def dict_to_namespace(config_dict): + return ConfigNamespace(**{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in config_dict.items()}) + +# Function for converting ConfigNamespace objects into a dictionary before use +def convert_to_dict_if_namespace(obj): + """converting ConfigNamespace objects into a dictionary""" + if isinstance(obj, ConfigNamespace): + return obj.to_dict() + elif isinstance(obj, dict): + return {k: convert_to_dict_if_namespace(v) for k, v in obj.items()} + else: + return obj + + + +def load_test_data(test_data_path, config): + test_data = sc.read(test_data_path) + source_data = test_data[test_data.obs[config.data.condition] == config.data.source] + target_data = test_data[test_data.obs[config.data.condition] == config.data.target] + source_tensor = torch.tensor(source_data.X.toarray(), dtype=torch.float32, requires_grad=True) + target_tensor = torch.tensor(target_data.X.toarray(), dtype=torch.float32) + return list(zip(source_tensor, target_tensor)) + + +def create_anndata_with_predictions(config, model_path, original_data): + # Load the model + checkpoint = torch.load(model_path) + (f, g), _ = load_cellot_model(config) + f.load_state_dict(checkpoint['f_state']) + g.load_state_dict(checkpoint['g_state']) + + # Set the model to evaluation mode + f.eval() + g.eval() + + # Filter for source cells (ctrl condition) + source_data = original_data[original_data.obs[config.data.condition] == config.data.source] + + # Convert source data to tensor and set requires_grad for all tensors + source_tensor = torch.tensor( + source_data.X.toarray() if hasattr(source_data.X, "toarray") else source_data.X, + dtype=torch.float32, + requires_grad=True + ) + + # Step 1: Verify for NaNs in the source_tensor + print(f"Step 1: Nombre de NaN dans source_tensor : {torch.isnan(source_tensor).sum().item()}") + + # Store predicted cells + predicted_cells = [] + + # Step 2: Obtain predictions for each source cell and check for NaNs in the prediction + for i, source in enumerate(source_tensor): + with torch.set_grad_enabled(True): # Ensure grad tracking is enabled + source = source.unsqueeze(0) # Ensure correct shape + predicted = g.transport(source) # Transport function + + # Check if prediction contains NaNs + if torch.isnan(predicted).any(): + print(f"Step 2: NaNs detected in prediction for cell {i}") + else: + print(f"Step 2: Prediction successful for cell {i}") + + predicted_cells.append(predicted.detach().numpy()) # Detach after prediction + + # Stack predictions into an array + predicted_data_matrix = np.vstack(predicted_cells) + # Normalize predicted data to match original data's scale ? + # This assumes that the original data has been normalized with scanpy's `normalize_total` + predicted_adata = anndata.AnnData(X=predicted_data_matrix) + + + + # Step 3: Verify if predicted_data_matrix contains NaNs after prediction loop + print(f"Step 3: Nombre de NaN dans predicted_data_matrix : {np.isnan(predicted_data_matrix).sum()}") + + + # Combine predicted data matrix with the existing data matrix, ensuring no duplication with 'ctrl' + original_data_matrix = ( + original_data.X.toarray() if hasattr(original_data.X, "toarray") else original_data.X + ) + + # Step 4: Check for NaNs in the original data matrix + print(f"Step 4: Nomber of NaN in the original_data_matrix : {np.isnan(original_data_matrix).sum()}") + + # Step 5: Combine matrices and check for NaNs in the combined data + combined_data = np.vstack([original_data_matrix, predicted_data_matrix]) + print(f"Step 5: Nomber of NaN in combined_data : {np.isnan(combined_data).sum()}") + + # Copy original metadata and create labels for predictions + combined_obs = original_data.obs.copy() + + # Generate a new observation dataframe for predicted cells based on source cells but labeled as 'predicted' + predicted_obs = source_data.obs.copy() + predicted_obs[config.data.condition] = 'predicted' # Set new condition + predicted_obs.index = [f"pred_{i}" for i in range(len(predicted_cells))] # Unique indices for predictions + + # Concatenate the original observations with the newly created predicted observations + combined_obs = pd.concat([combined_obs, predicted_obs]) + + # Final AnnData object with original and predicted cells + anndata_with_predictions = anndata.AnnData( + X=combined_data, + obs=combined_obs, + var=original_data.var + ) + + + # Ensure observation names are unique + anndata_with_predictions.obs_names_make_unique() + + # Step 6: Check if the final AnnData object contains NaNs in X + print(f"Step 6: Nomber of NaN in anndata_with_predictions.X : {np.isnan(anndata_with_predictions.X).sum()}") + + # Optional: If desired, set raw attribute for the AnnData object + anndata_with_predictions.raw = anndata_with_predictions.copy() + + return anndata_with_predictions + + +# Load the dataset +dataset_path = "C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad" +adata = sc.read_h5ad(dataset_path) +cell_types = adata.obs['cell_type'].unique() +output_dir = ".\\output_ood_models_1" + +for cell_type in cell_types: + model_dir = os.path.join(output_dir, f"{cell_type}_ood") + model_path = Path(model_dir) / "cache" / "model.pt" + holdout = cell_type + + # Define task configuration for evaluation + task_config = { + 'dataset': dataset_path, + 'condition': 'condition', + 'source': 'ctrl', + 'target': 'stim', + 'type': 'cell', + 'batch_size': 128, + 'shuffle': True, + 'datasplit_groupby': ['cell_type','condition'], + 'datasplit_name': 'toggle_ood', + 'key' : 'cell_type', + 'datasplit_mode': 'ood', # Set mode to 'ood' + 'datasplit_holdout': holdout, # Specify holdout cell type + 'datasplit_test_size': 0.3, + 'datasplit_random_state': 0 + } + + model_config = { + 'input_dim': 1000, + 'name': 'cellot', + 'hidden_units': [64, 64, 64, 64], + 'latent_dim': 100, + 'softplus_W_kernels': False, + 'g': { + 'fnorm_penalty': 1 + }, + 'kernel_init_fxn': { + 'b': 0.1, + 'name': 'uniform' + }, + 'optim': { + 'optimizer': 'Adam', + 'lr': 0.0001, + 'beta1': 0.5, + 'beta2': 0.9, + 'weight_decay': 0 + }, + 'training': { + 'n_iters': 100000, + 'n_inner_iters': 1, + 'cache_freq': 50, + 'eval_freq': 20, + 'logs_freq': 10 + } + } + + config = { + 'training': model_config['training'], + 'data': task_config, + 'model': model_config, + 'datasplit': { + 'groupby': task_config['datasplit_groupby'], + 'name': task_config['datasplit_name'], + 'test_size': task_config['datasplit_test_size'], + 'random_state': task_config['datasplit_random_state'], + 'holdout': task_config.get('datasplit_holdout', None), + 'key': task_config.get('key', None), + 'mode': task_config.get('datasplit_mode', 'iid'), + 'subset': None + }, + 'dataloader': { + 'batch_size': task_config['batch_size'], + 'shuffle': task_config['shuffle'] + } + } + + config_ns = dict_to_namespace(config) + + # Evaluate model and create AnnData with predictions + anndata_with_predictions = create_anndata_with_predictions(config_ns, model_path, adata) + + + # Filter only for the cells of the holdout type and their predictions + holdout_cells = anndata_with_predictions[ + (anndata_with_predictions.obs['cell_type'] == cell_type) + ] + + # Visualization with PCA + print(f"Evaluating PCA and UMAP for holdout cell_type: {cell_type}") + sc.tl.pca(holdout_cells, svd_solver="arpack") + sc.pl.pca(holdout_cells, color="condition", title=f"PCA for {cell_type}") + + # Visualization with UMAP + sc.pp.neighbors(holdout_cells) + sc.tl.umap(holdout_cells) + sc.pl.umap(holdout_cells, color="condition", title=f"UMAP for {cell_type}") + +#---------------------- R² ----------------------- + +import seaborn as sns +import matplotlib.pyplot as plt +import pandas as pd +from scipy import sparse +import numpy as np + +# R² dictionnary initialization +r2_results = {} + +# Loop on cell types +for cell_type in anndata_with_predictions.obs['cell_type'].unique(): + + cell_data = anndata_with_predictions[anndata_with_predictions.obs['cell_type'] == cell_type] + + # Extract data on each condition 'stim' and 'predicted' for the specific cell type + stim_data = cell_data[cell_data.obs['condition'] == 'stim'].X + predicted_data = cell_data[cell_data.obs['condition'] == 'predicted'].X + + # if necessary convert sparce matrix + if sparse.issparse(stim_data): + stim_data = stim_data.toarray() + if sparse.issparse(predicted_data): + predicted_data = predicted_data.toarray() + + # expression mean on each genes + stim_mean = stim_data.mean(axis=0) + predicted_mean = predicted_data.mean(axis=0) + + # Pearson correlations + r = np.corrcoef(stim_mean, predicted_mean)[0, 1] + r2 = r ** 2 # R² basé sur la corrélation de Pearson + + # Store the result + r2_results[cell_type] = r2 + print(f"R² for {cell_type} between 'stim' and 'predicted' mean genes: {r2:.4f}") + + # Data process to visualization + df_plot = pd.DataFrame({ + 'Stim Mean Expression': stim_mean, + 'Predicted Mean Expression': predicted_mean + }) + + # RegPlot + plt.figure(figsize=(8, 6)) + sns.regplot( + x='Stim Mean Expression', + y='Predicted Mean Expression', + data=df_plot, + scatter_kws={'s': 10}, # Taille des points + line_kws={'color': 'red'} # Couleur de la ligne de régression + ) + plt.title(f'Regression Plot for {cell_type}\nR² = {r2:.4f}') + plt.xlabel('Stim Mean Expression') + plt.ylabel('Predicted Mean Expression') + plt.grid(True) + plt.show() + +# Print R² results +print("\nR² of mean expression genes for each cell type between 'stim' and 'predicted' cells:") +for cell_type, r2 in r2_results.items(): + print(f"{cell_type}: {r2:.4f}") + +#----------- edistance -------------- + +from scipy.spatial.distance import cdist +from scipy.sparse import issparse +import numpy as np +import pandas as pd +import scanpy as sc + +def compute_edistance(set1, set2): + """ + Compute the energy distance between two datasets. + """ + intra_dist1 = np.mean(cdist(set1, set1, metric="euclidean")) + intra_dist2 = np.mean(cdist(set2, set2, metric="euclidean")) + inter_dist = np.mean(cdist(set1, set2, metric="euclidean")) + return 2 * inter_dist - intra_dist1 - intra_dist2 + +def compute_perturbation_score_per_cell_type(anndata, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted"): + """ + Compute the perturbation score for each cell type. + + Parameters: + anndata: AnnData object containing gene expression data. + n_comps: Number of principal components to use. + condition_col: Column name in `.obs` that specifies the condition. + stim_key: Key for the stimulated condition. + ctrl_key: Key for the control condition. + pred_key: Key for the predicted condition. + + Returns: + A dictionary mapping each cell type to its perturbation score. + """ + perturbation_scores = {} + + # Perform PCA on the data + if n_comps > min(anndata.shape): + n_comps = min(anndata.shape) - 1 + + sc.tl.pca(anndata, svd_solver="arpack", n_comps=n_comps) + print(f"PCA with {n_comps} components computed.\n") + + # Iterate over each cell type + for cell_type in anndata.obs['cell_type'].unique(): + print(f"Processing cell type: {cell_type}") + + # Subset the data for the current cell type + cell_data = anndata[anndata.obs['cell_type'] == cell_type] + + # Extract the subsets for stimulated, control, and predicted data + stim_adata = cell_data[cell_data.obs[condition_col] == stim_key] + ctrl_adata = cell_data[cell_data.obs[condition_col] == ctrl_key] + pred_adata = cell_data[cell_data.obs[condition_col] == pred_key] + + # Skip if any subset is empty + if stim_adata.shape[0] == 0 or ctrl_adata.shape[0] == 0 or pred_adata.shape[0] == 0: + print(f"Skipping {cell_type} due to insufficient data.\n") + continue + + # Extract PCA embeddings + stim_pca = stim_adata.obsm["X_pca"] + ctrl_pca = ctrl_adata.obsm["X_pca"] + pred_pca = pred_adata.obsm["X_pca"] + + # Convert sparse matrices to dense + if issparse(stim_pca): stim_pca = stim_pca.toarray() + if issparse(ctrl_pca): ctrl_pca = ctrl_pca.toarray() + if issparse(pred_pca): pred_pca = pred_pca.toarray() + + # Compute energy distances + edistance_stim_pred = compute_edistance(stim_pca, pred_pca) # Perturbed vs Predicted + edistance_ctrl_pred = compute_edistance(ctrl_pca, pred_pca) # Control vs Predicted + + # Avoid division by zero + if edistance_ctrl_pred == 0: + perturbation_score = np.nan + else: + perturbation_score = edistance_stim_pred / edistance_ctrl_pred + + perturbation_scores[cell_type] = perturbation_score + print(f"Perturbation score for {cell_type}: {perturbation_score}\n") + + return perturbation_scores + + +perturbation_scores = compute_perturbation_score_per_cell_type( + anndata=anndata_with_predictions, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted" +) + +# Display the results +print("Scaled perturbation scores for all cell types:") +for cell_type, score in perturbation_scores.items(): + print(f"{cell_type}: {score:.4f}") + + +#-------------- mmd ------------------- + +from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel +import numpy as np +import pandas as pd +import scanpy as sc + +def compute_mmd(set1, set2, kernel="linear", **kernel_kwargs): + """ + Compute the Maximum Mean Discrepancy (MMD) between two datasets. + + Parameters: + set1: np.ndarray + First dataset (e.g., real perturbed data). + set2: np.ndarray + Second dataset (e.g., predicted data). + kernel: str + Type of kernel to use. Options are 'linear', 'rbf', and 'poly'. + **kernel_kwargs: + Additional arguments for the kernel function (e.g., gamma for RBF). + + Returns: + float + MMD score. + """ + if kernel == "linear": + XX = np.dot(set1, set1.T) + YY = np.dot(set2, set2.T) + XY = np.dot(set1, set2.T) + elif kernel == "rbf": + XX = rbf_kernel(set1, set1, **kernel_kwargs) + YY = rbf_kernel(set2, set2, **kernel_kwargs) + XY = rbf_kernel(set1, set2, **kernel_kwargs) + elif kernel == "poly": + XX = polynomial_kernel(set1, set1, **kernel_kwargs) + YY = polynomial_kernel(set2, set2, **kernel_kwargs) + XY = polynomial_kernel(set1, set2, **kernel_kwargs) + else: + raise ValueError(f"Unsupported kernel type: {kernel}") + + return XX.mean() + YY.mean() - 2 * XY.mean() + +def compute_mmd_per_cell_type(anndata, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted", + kernel="linear", + **kernel_kwargs): + """ + Compute the MMD for each cell type. + + Parameters: + anndata: AnnData + Annotated data matrix containing the data. + n_comps: int + Number of PCA components to use. + condition_col: str + Column in `.obs` specifying the condition of the cells. + stim_key: str + Key for the stimulated condition in `.obs`. + ctrl_key: str + Key for the control condition in `.obs`. + pred_key: str + Key for the predicted condition in `.obs`. + kernel: str + Kernel type for MMD. Options: 'linear', 'rbf', 'poly'. + **kernel_kwargs: + Additional parameters for the kernel function. + + Returns: + dict + A dictionary mapping each cell type to its MMD perturbation score. + """ + mmd_scores = {} + + # Perform PCA on the data + if n_comps > min(anndata.shape): + n_comps = min(anndata.shape) - 1 + + sc.tl.pca(anndata, svd_solver="arpack", n_comps=n_comps) + print(f"PCA with {n_comps} components computed.\n") + + # Iterate over each cell type + for cell_type in anndata.obs['cell_type'].unique(): + print(f"Processing cell type: {cell_type}") + + # Subset the data for the current cell type + cell_data = anndata[anndata.obs['cell_type'] == cell_type] + + # Extract subsets for stimulated, control, and predicted data + stim_adata = cell_data[cell_data.obs[condition_col] == stim_key] + ctrl_adata = cell_data[cell_data.obs[condition_col] == ctrl_key] + pred_adata = cell_data[cell_data.obs[condition_col] == pred_key] + + # Extract PCA embeddings + stim_pca = stim_adata.obsm["X_pca"] + ctrl_pca = ctrl_adata.obsm["X_pca"] + pred_pca = pred_adata.obsm["X_pca"] + + # Compute MMD scores + mmd_stim_pred = compute_mmd(stim_pca, pred_pca, kernel=kernel, **kernel_kwargs) # Stimulated vs Predicted + mmd_ctrl_pred = compute_mmd(ctrl_pca, pred_pca, kernel=kernel, **kernel_kwargs) # Control vs Predicted + + # Combine scores into a perturbation score + mmd_score = mmd_stim_pred / mmd_ctrl_pred + mmd_scores[cell_type] = mmd_score + print(f"MMD perturbation score for {cell_type}: {mmd_score}\n") + + return mmd_scores + +# Compute MMD scores for all cell types +mmd_scores = compute_mmd_per_cell_type( + anndata=anndata_with_predictions, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted", + kernel="linear"#, # Example: using RBF kernel +# gamma=1.0 # Example parameter for the RBF kernel +) + +# Display the results +print("MMD perturbation scores for all cell types:") +for cell_type, score in mmd_scores.items(): + print(f"{cell_type}: {score:.4f}") + +#------------------ euclidean distances ------------------- + +from sklearn.metrics.pairwise import euclidean_distances +import numpy as np +import pandas as pd +import scanpy as sc + +def compute_mean_euclidean_distance(set1, set2): + """ + Compute the mean Euclidean distance between two datasets. + + Parameters: + set1: np.ndarray + First dataset (e.g., real perturbed data). + set2: np.ndarray + Second dataset (e.g., predicted data). + + Returns: + float + Mean Euclidean distance between set1 and set2. + """ + pairwise_distances = euclidean_distances(set1, set2) + return pairwise_distances.mean() + +def compute_euclidean_distance_per_cell_type(anndata, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted"): + """ + Compute the mean Euclidean distance for each cell type. + + Parameters: + anndata: AnnData + Annotated data matrix containing the data. + n_comps: int + Number of PCA components to use. + condition_col: str + Column in `.obs` specifying the condition of the cells. + stim_key: str + Key for the stimulated condition in `.obs`. + ctrl_key: str + Key for the control condition in `.obs`. + pred_key: str + Key for the predicted condition in `.obs`. + + Returns: + dict + A dictionary mapping each cell type to its Euclidean perturbation score. + """ + euclidean_scores = {} + + # Perform PCA on the data + if n_comps > min(anndata.shape): + n_comps = min(anndata.shape) - 1 + + sc.tl.pca(anndata, svd_solver="arpack", n_comps=n_comps) + print(f"PCA with {n_comps} components computed.\n") + + # Iterate over each cell type + for cell_type in anndata.obs['cell_type'].unique(): + print(f"Processing cell type: {cell_type}") + + # Subset the data for the current cell type + cell_data = anndata[anndata.obs['cell_type'] == cell_type] + + # Extract subsets for stimulated, control, and predicted data + stim_adata = cell_data[cell_data.obs[condition_col] == stim_key] + ctrl_adata = cell_data[cell_data.obs[condition_col] == ctrl_key] + pred_adata = cell_data[cell_data.obs[condition_col] == pred_key] + + # Extract PCA embeddings + stim_pca = stim_adata.obsm["X_pca"] + ctrl_pca = ctrl_adata.obsm["X_pca"] + pred_pca = pred_adata.obsm["X_pca"] + + # Compute Euclidean distances + euclidean_stim_pred = compute_mean_euclidean_distance(stim_pca, pred_pca) # Stimulated vs Predicted + euclidean_ctrl_pred = compute_mean_euclidean_distance(ctrl_pca, pred_pca) # Control vs Predicted + + # Combine scores into a perturbation score + euclidean_score = euclidean_stim_pred / euclidean_ctrl_pred + euclidean_scores[cell_type] = euclidean_score + print(f"Euclidean perturbation score for {cell_type}: {euclidean_score}\n") + + return euclidean_scores + +# Compute Euclidean distance scores for all cell types +euclidean_scores = compute_euclidean_distance_per_cell_type( + anndata=anndata_with_predictions, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted" +) + +# Display the results +print("Euclidean perturbation scores for all cell types:") +for cell_type, score in euclidean_scores.items(): + print(f"{cell_type}: {score:.4f}") + + +#-------------------- Bar plot with each metrics ---------------- + + +import matplotlib.pyplot as plt +import numpy as np + + +cell_types = list(r2_results.keys()) # cell types +r2_scores = [float(val) for val in r2_results.values()] # Convert if necessary +edistances = [float(val) for val in perturbation_scores.values()] # Convert if necessary +mmd_res = [float(val) for val in mmd_scores.values()] # Convert if necessary +euclidean_dist = [float(val) for val in euclidean_scores.values()] # Convert if necessary + + + + +# Config of sub-graphs +fig, axes = plt.subplots(1, 4, figsize=(16, 8), sharey=True) + +# Graph of R² Scores +axes[0].barh(cell_types, r2_scores, color='blue', edgecolor='black') +axes[0].set_title("R² Scores") +axes[0].set_xlabel("Valeur") +axes[0].invert_yaxis() # Aligment of cell types on all graphs + +# Graph of Energy Distances +axes[1].barh(cell_types, edistances, color='green', edgecolor='black') +axes[1].set_title("Energy Distance") +axes[1].set_xlabel("Valeur") + +# Graph of MMD Scores +axes[2].barh(cell_types, mmd_res, color='orange', edgecolor='black') +axes[2].set_title("MMD Scores") +axes[2].set_xlabel("Valeur") + +# Graph of Euclidean Distance +axes[3].barh(cell_types, euclidean_dist, color='red', edgecolor='black') +axes[3].set_title("Euclidean Distance") +axes[3].set_xlabel("Valeur") + +# Spaces between sub-graphs +plt.tight_layout() + +# plot +plt.show() + + + + + + + + + + + diff --git a/tools/Complex_generative/cellOT_v1/cellot_train_v3_ood.py b/tools/Complex_generative/cellOT_v1/cellot_train_v3_ood.py new file mode 100644 index 0000000..772aefe --- /dev/null +++ b/tools/Complex_generative/cellOT_v1/cellot_train_v3_ood.py @@ -0,0 +1,240 @@ +import sys +from pathlib import Path +from cellot.train.train import train_cellot, train_auto_encoder, train_popalign +from types import SimpleNamespace +from cellot.utils.loaders import load +import scanpy as sc + +#from pathlib import Path +#from cellot_train import train_cellot, train_auto_encoder, train_popalign + +#mport csv +#from pathlib import Path + +#import torch +#import numpy as np +#import random +#import pickle +#from absl import logging +#from absl.flags import FLAGS +#from cellot import losses +#from cellot.utils.loaders import load +#from cellot.models.cellot import compute_loss_f, compute_loss_g, compute_w2_distance +#from cellot.train.summary import Logger +#from cellot.data.utils import cast_loader_to_iterator +#from cellot.models.ae import compute_scgen_shift +#from tqdm import trange + + + + + +class ConfigNamespace(SimpleNamespace): + def get(self, key, default=None): + return getattr(self, key, default) + + def to_dict(self): + """ + Recursively converts the ConfigNamespace object into a dictionary. + """ + result = {} + for key, value in self.__dict__.items(): + if isinstance(value, ConfigNamespace): + result[key] = value.to_dict() + else: + result[key] = value + return result + + def as_dict(self): + """Returns the instance as a dictionary for secure use in code""" + return self.to_dict() + + def __contains__(self, key): + return key in self.__dict__ + +# transform a dictionnary in ConfigNamespace +def dict_to_namespace(config_dict): + return ConfigNamespace(**{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in config_dict.items()}) + +# Function for converting ConfigNamespace objects into a dictionary before use +def convert_to_dict_if_namespace(obj): + """converting ConfigNamespace objects into a dictionary""" + if isinstance(obj, ConfigNamespace): + return obj.to_dict() + elif isinstance(obj, dict): + return {k: convert_to_dict_if_namespace(v) for k, v in obj.items()} + else: + return obj + + + +def run_cellot_training(task_config, model_config, train_type="cellot", outdir='./output'): + config = { + 'training': { + 'n_iters': task_config.get('epochs', 10), + 'logs_freq': 10, + 'eval_freq': 20, + 'cache_freq': 100, + 'n_inner_iters': 1 + }, + 'data': { + 'condition': task_config.get('condition', 'drug'), + 'source': task_config.get('source', 'control'), + 'target': task_config.get('target', 'stim'), + 'type': task_config.get('type', 'cell'), + 'path': task_config.get('dataset', '') + }, + 'model': model_config, + 'datasplit': { + 'groupby': task_config.get('datasplit_groupby', 'condition'), + 'name': task_config.get('datasplit_name', 'train_test'), + 'test_size': task_config.get('datasplit_test_size', 0.2), + 'random_state': task_config.get('datasplit_random_state', 0), + 'holdout': task_config.get('datasplit_holdout', None), + 'key': task_config.get('key', None), + 'mode': task_config.get('datasplit_mode', 'iid'), + 'subset': None + }, + 'dataloader': { + 'batch_size': task_config.get('batch_size', 64), + 'shuffle': task_config.get('shuffle', True) + } + } + + # Transform in a ConfigNamespace + config_ns = dict_to_namespace(config) + + # Convert outdir to Path object to ensure compatibility + outdir = Path(outdir) + + # Call the model function + if train_type == 'cellot': + train_cellot(outdir, config_ns) + elif train_type == 'auto_encoder': + train_auto_encoder(outdir, config_ns) + elif train_type == 'popalign': + train_popalign(outdir, config_ns) + else: + raise ValueError("Train type not supported: {}".format(train_type)) + + print(f"Training complete for {train_type} model at {outdir}.") + + +import os +import anndata +#from cellot_train_v2 import run_cellot_training # Ensure this is the path to the training function + +# Load the AnnData dataset +adata = anndata.read_h5ad("C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad") # Replace with the actual path + +# Extract unique cell types +cell_types = adata.obs['cell_type'].unique() + +# Directory to save the models +output_dir = '.\\output_ood_models_1' + + +for cell_type in cell_types: + # Set up holdout as the current cell type for OOD + holdout = cell_type + + # Define the output directory for the model + model_dir = os.path.join(output_dir, f"{cell_type}_ood") + os.makedirs(model_dir, exist_ok=True) + + # Define the task configuration for training + task_config = { + 'dataset': "C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad", + 'condition': 'condition', # Column defining the conditions + 'source': 'ctrl', # Control condition as source + 'target': 'stim', # Stimulated condition as target + 'type': 'cell', # Specifies type, assuming it's for cells + 'epochs': 100000, # Set the number of epochs + 'batch_size': 128, # Batch size + 'shuffle': True, # Shuffle the data + 'datasplit_groupby': ['cell_type','condition'], # Grouping by 'condition' for OOD + 'datasplit_name': 'toggle_ood', + 'key' : 'cell_type', + 'datasplit_mode': 'ood', # Set mode to 'ood' + 'datasplit_holdout': holdout, # Specify holdout cell type + 'datasplit_test_size': 0.3, # Test split size + 'datasplit_random_state': 0 # Random seed for reproducibility + } + + # Define the model configuration + model_config = { + 'input_dim': 1000, + 'name': 'cellot', + 'hidden_units': [64, 64, 64, 64], + 'latent_dim': 100, + 'softplus_W_kernels': False, + 'g': { + 'fnorm_penalty': 1 + }, + 'kernel_init_fxn': { + 'b': 0.1, + 'name': 'uniform' + }, + 'optim': { + 'optimizer': 'Adam', + 'lr': 0.0001, + 'beta1': 0.5, + 'beta2': 0.9, + 'weight_decay': 0 + }, + 'training': { + 'n_iters': 100000, # Total number of iterations + 'n_inner_iters': 1, # Number of inner iterations + 'cache_freq': 50, # Frequency to cache model state + 'eval_freq': 20, # Frequency for evaluations + 'logs_freq': 10 # Logging frequency + } + } + + print(f"Training model for holdout cell_type: {cell_type}") + + # Run the training with specified configurations and train_type='cellot' + run_cellot_training( + task_config=task_config, + model_config=model_config, + train_type='cellot', + outdir=model_dir # Save output to the model-specific directory + ) + + print(f"Completed training for holdout cell_type: {cell_type}") + + + +#exemple + + +test_data_path = "C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad" +adata = sc.read_h5ad(test_data_path) +input_dim = adata.shape[1] +print("Number of variables (input_dim) :", input_dim) + + + + + +import pandas as pd +import matplotlib.pyplot as plt + +# load losses data +loss_data = pd.read_csv('.\\output_ood_models\\loss_tracking.csv') + +# Check for column +if all(col in loss_data.columns for col in ['Step', 'Loss_G', 'Loss_F']): + # plot + plt.figure(figsize=(10, 6)) + plt.plot(loss_data['Step'], loss_data['Loss_G'], label='Loss_G') + plt.plot(loss_data['Step'], loss_data['Loss_F'], label='Loss_F') + plt.xlabel('Step') + plt.ylabel('Loss') + plt.yscale('log') + plt.legend() + plt.title('Evolution of Losses (Loss G and Loss F) During Training') + plt.show() +else: + print("No 'Step', 'Loss_G', ou 'Loss_F' in the CSV file.") + diff --git a/tools/Complex_generative/cellOT_v1/source_modif/cell.py b/tools/Complex_generative/cellOT_v1/source_modif/cell.py new file mode 100644 index 0000000..5f8129f --- /dev/null +++ b/tools/Complex_generative/cellOT_v1/source_modif/cell.py @@ -0,0 +1,358 @@ +#!/usr/bin/python3 + +import anndata +import numpy as np +import pandas as pd +from scipy import sparse +from pathlib import Path +import torch +from torch.utils.data import Dataset +from sklearn.model_selection import train_test_split +from cellot.models import load_autoencoder_model +from cellot.utils import load_config +from cellot.data.utils import cast_dataset_to_loader +from cellot.utils.helpers import nest_dict + + +class AnnDataDataset(Dataset): + def __init__( + self, adata, obs=None, categories=None, include_index=False, dim_red=None + ): + self.adata = adata + self.adata.X = self.adata.X.astype(np.float32) + self.obs = obs + self.categories = categories + self.include_index = include_index + + def __len__(self): + return len(self.adata) + + def __getitem__(self, idx): + value = self.adata.X[idx] + + if self.obs is not None: + meta = self.categories.index(self.adata.obs[self.obs].iloc[idx]) + value = value, int(meta) + + if self.include_index: + return self.adata.obs_names[idx], value + + return value + + +def read_list(arg): + + if isinstance(arg, str): + arg = Path(arg) + assert arg.exists() + lst = arg.read_text().split() + else: + lst = arg + + return list(lst) + + +def read_single_anndata(config, path=None): + if path is None: + path = config.data.path + + data = anndata.read(path) + + if hasattr(config.data, "features"): + features = read_list(config.data.features) + data = data[:, features].copy() + + # select subgroup of individuals + if hasattr(config.data, "individuals"): + data = data[ + data.obs[config.data.individuals[0]].isin(config.data.individuals[1]) + ] + + # label conditions as source/target distributions + # config.data.{source,target} can be a list now + transport_mapper = dict() + for value in ["source", "target"]: + key = getattr(config.data, value) + if isinstance(key, list): + for item in key: + transport_mapper[item] = value + else: + transport_mapper[key] = value + + data.obs["transport"] = data.obs[config.data.condition].apply(transport_mapper.get) + + if getattr(config.data, "target") == "all": + data.obs["transport"].fillna("target", inplace=True) + + mask = data.obs["transport"].notna() + assert not hasattr(config.data, "subset") + if config.datasplit.subset is not None: + for key, value in config.datasplit.subset.items(): + + if not isinstance(value, list): + value = [value] + mask = mask & data.obs[key].isin(value) + + # write train/test/valid into split column + data = data[mask].copy() + if hasattr(config, "datasplit"): + data.obs["split"] = split_cell_data(data, **config.datasplit.to_dict() if config.datasplit else {}) + + return data + + +def load_cell_data( + config, + data=None, + split_on=None, + return_as="loader", + include_model_kwargs=False, + pair_batch_on=None, + **kwargs +): + + if isinstance(return_as, str): + return_as = [return_as] + + assert set(return_as).issubset({"anndata", "dataset", "loader"}) + config.data.condition = config.data.get("condition", "drug") + condition = config.data.condition + + if data is None: + if config.data.type == "cell": + data = read_single_anndata(config, **kwargs) + else: + raise ValueError + + if config.data.get("select") is not None: + keep = pd.Series(False, index=data.obs_names) + for key, value in config.data.select.items(): + if not isinstance(value, list): + value = [value] + keep.loc[data.obs[key].isin(value)] = True + assert keep.sum() > 0 + + data = data[keep].copy() + + if "dimension_reduction" in config.data: + genes = data.var_names.to_list() + name = config.data.dimension_reduction.name + if name == "pca": + dims = config.data.dimension_reduction.get( + "dims", data.obsm["X_pca"].shape[1] + ) + + data = anndata.AnnData( + data.obsm["X_pca"][:, :dims], obs=data.obs.copy(), uns=data.uns.copy() + ) + data.uns["genes"] = genes + + if "ae_emb" in config.data: + # load path to autoencoder + assert config.get("model.name", "cellot") == "cellot" + path_ae = Path(config.data.ae_emb.path) + model_kwargs = {"input_dim": data.n_vars} + config_ae = load_config(path_ae / "config.yaml") + ae_model, _ = load_autoencoder_model( + config_ae, restore=path_ae / "cache/model.pt", **model_kwargs + ) + + inputs = torch.Tensor( + data.X if not sparse.issparse(data.X) else data.X.todense() + ) + + genes = data.var_names.to_list() + data = anndata.AnnData( + ae_model.eval().encode(inputs).detach().numpy(), + obs=data.obs.copy(), + uns=data.uns.copy(), + ) + data.uns["genes"] = genes + + # cast to dense and check for nans + if sparse.issparse(data.X): + data.X = data.X.todense() + assert not np.isnan(data.X).any() + + dataset_args = dict() + model_kwargs = {} + + model_kwargs["input_dim"] = data.n_vars + + if config.get("model.name") == "cae": + condition_labels = sorted(data.obs[condition].cat.categories) + model_kwargs["conditions"] = condition_labels + dataset_args["obs"] = condition + dataset_args["categories"] = condition_labels + + if "training" in config: + pair_batch_on = config.training.get("pair_batch_on", pair_batch_on) + + if split_on is None: + if config.model.name == "cellot": + # datasets & dataloaders accessed as loader.train.source + split_on = ["split", "transport"] + if pair_batch_on is not None: + split_on.append(pair_batch_on) + + elif (config.model.name == "scgen" or config.model.name == "cae" + or config.model.name == "popalign"): + split_on = ["split"] + + else: + raise ValueError + + if isinstance(split_on, str): + split_on = [split_on] + + for key in split_on: + assert key in data.obs.columns + + if len(split_on) > 0: + splits = { + (key if isinstance(key, str) else ".".join(key)): data[index] + for key, index in data.obs[split_on].groupby(split_on).groups.items() + } + + dataset = nest_dict( + { + key: AnnDataDataset(val.copy(), **dataset_args) + for key, val in splits.items() + }, + as_dot_dict=True, + ) + + else: + dataset = AnnDataDataset(data.copy(), **dataset_args) + + if "loader" in return_as: + kwargs = config.dataloader.to_dict() if hasattr(config.dataloader, "to_dict") else config.dataloader + kwargs.setdefault("drop_last", True) + loader = cast_dataset_to_loader(dataset, **kwargs) + + returns = list() + for key in return_as: + if key == "anndata": + returns.append(data) + + elif key == "dataset": + returns.append(dataset) + + elif key == "loader": + returns.append(loader) + + if include_model_kwargs: + returns.append(model_kwargs) + + if len(returns) == 1: + return returns[0] + + return tuple(returns) + + +def split_cell_data_train_test( + data, groupby=None, random_state=0, holdout=None, subset=None, **kwargs +): + + kwargs.pop("mode", None) # Delete "mode" if it dosen't exist in kwargs + kwargs.pop("key", None) + split = pd.Series(None, index=data.obs.index, dtype=object) + groups = {None: data.obs.index} + if groupby is not None: + groups = data.obs.groupby(groupby).groups + + for key, index in groups.items(): + trainobs, testobs = train_test_split(index, random_state=random_state, **kwargs) + split.loc[trainobs] = "train" + split.loc[testobs] = "test" + + if holdout is not None: + for key, value in holdout.items(): + if not isinstance(value, list): + value = [value] + split.loc[data.obs[key].isin(value)] = "ood" + + return split + + +def split_cell_data_train_test_eval( + data, + test_size=0.15, + eval_size=0.15, + groupby=None, + random_state=0, + holdout=None, + **kwargs +): + + split = pd.Series(None, index=data.obs.index, dtype=object) + + if holdout is not None: + for key, value in holdout.items(): + if not isinstance(value, list): + value = [value] + split.loc[data.obs[key].isin(value)] = "ood" + + groups = {None: data.obs.loc[split != "ood"].index} + if groupby is not None: + groups = data.obs.loc[split != "ood"].groupby(groupby).groups + + for key, index in groups.items(): + training, evalobs = train_test_split( + index, random_state=random_state, test_size=eval_size + ) + + trainobs, testobs = train_test_split( + training, random_state=random_state, test_size=test_size + ) + + split.loc[trainobs] = "train" + split.loc[testobs] = "test" + split.loc[evalobs] = "eval" + + return split + + +def split_cell_data_toggle_ood(data, holdout, key, mode, random_state=0, **kwargs): + + """Hold out ood sample, coordinated with iid split + + ood sample defined with key, value pair + + for ood mode: hold out all cells from a sample + for iid mode: include half of cells in split + """ + + split = split_cell_data_train_test(data, random_state=random_state, **kwargs) + + if not isinstance(holdout, list): + value = [holdout] + + ood = data.obs_names[data.obs[key].isin(value)] + trainobs, testobs = train_test_split(ood, random_state=random_state, test_size=0.5) + + if mode == "ood": + split.loc[trainobs] = "ignore" + split.loc[testobs] = "ood" + + elif mode == "iid": + split.loc[trainobs] = "train" + split.loc[testobs] = "ood" + + else: + raise ValueError + + return split + + +def split_cell_data(data, name="train_test", **kwargs): + if name == "train_test": + split = split_cell_data_train_test(data, **kwargs) + elif name == "toggle_ood": + split = split_cell_data_toggle_ood(data, **kwargs) + elif name == "train_test_eval": + split = split_cell_data_train_test_eval(data, **kwargs) + else: + raise ValueError + + return split.astype("category") diff --git a/tools/Complex_generative/cellOT_v1/source_modif/cellot.py b/tools/Complex_generative/cellOT_v1/source_modif/cellot.py new file mode 100644 index 0000000..31c370f --- /dev/null +++ b/tools/Complex_generative/cellOT_v1/source_modif/cellot.py @@ -0,0 +1,172 @@ +from pathlib import Path +import torch +from collections import namedtuple +from cellot.networks.icnns import ICNN + +from absl import flags + +FLAGS = flags.FLAGS + +FGPair = namedtuple("FGPair", "f g") + + +def load_networks(config, **kwargs): + def unpack_kernel_init_fxn(name="uniform", **kwargs): + if name == "normal": + def init(*args): + return torch.nn.init.normal_(*args, **kwargs) + elif name == "uniform": + def init(*args): + return torch.nn.init.uniform_(*args, **kwargs) + else: + raise ValueError("Unsupported kernel initialization function.") + return init + + # Exclude parameters not relevant to ICNN + model_params = config.get("model", {}).as_dict() + ignore_keys = ["name", "latent_dim", "optim", "training"] + for key in ignore_keys: + model_params.pop(key, None) + + # Check and define input_dim + input_dim = model_params.get("input_dim") or kwargs.get("input_dim") + if input_dim is None: + raise ValueError("`input_dim` must be specified in the model configuration or kwargs.") + + kwargs.setdefault("hidden_units", [64] * 4) + kwargs.update(model_params) + + # specific parameters for f et g + fupd = kwargs.pop("f", {}) + gupd = kwargs.pop("g", {}) + + # Configure fkwargs et gkwargs for ICNN + fkwargs = kwargs.copy() + fkwargs.update(fupd) + fkwargs["input_dim"] = input_dim # add input_dim for f + if "kernel_init_fxn" in fkwargs: + fkwargs["kernel_init_fxn"] = unpack_kernel_init_fxn(**fkwargs.pop("kernel_init_fxn")) + + gkwargs = kwargs.copy() + gkwargs.update(gupd) + gkwargs["input_dim"] = input_dim # add input_dim for g + if "kernel_init_fxn" in gkwargs: + gkwargs["kernel_init_fxn"] = unpack_kernel_init_fxn(**gkwargs.pop("kernel_init_fxn")) + + # Instantiate ICNN networks for f and g + f = ICNN(**fkwargs) + g = ICNN(**gkwargs) + + if "verbose" in FLAGS and FLAGS.verbose: + print("Network g configuration:", g) + print("Remaining kwargs:", kwargs) + + return f, g + + + +def load_opts(config, f, g): + # Access “optim” as a dictionary without using .as_dict() + optim_config = config.get("optim", {}) + + # optimizers `f` et `g` + fupd = optim_config.get("f", {}) + gupd = optim_config.get("g", {}) + + # Create parameters for optimizers by adjusting the "betas" + fkwargs = optim_config.copy() + fkwargs.update(fupd) + fkwargs["betas"] = (fkwargs.pop("beta1", 0.9), fkwargs.pop("beta2", 0.999)) + + gkwargs = optim_config.copy() + gkwargs.update(gupd) + gkwargs["betas"] = (gkwargs.pop("beta1", 0.9), gkwargs.pop("beta2", 0.999)) + + # Create optimizers for f et g + opts = FGPair( + f=torch.optim.Adam(f.parameters(), **fkwargs), + g=torch.optim.Adam(g.parameters(), **gkwargs), + ) + + return opts + + +def load_cellot_model(config, restore=None, **kwargs): + f, g = load_networks(config, **kwargs) + opts = load_opts(config, f, g) + + if restore is not None and Path(restore).exists(): + ckpt = torch.load(restore) + f.load_state_dict(ckpt["f_state"]) + opts.f.load_state_dict(ckpt["opt_f_state"]) + + g.load_state_dict(ckpt["g_state"]) + opts.g.load_state_dict(ckpt["opt_g_state"]) + + return (f, g), opts + + +def compute_loss_g(f, g, source, transport=None): + if transport is None: + transport = g.transport(source) + + return f(transport) - torch.multiply(source, transport).sum(-1, keepdim=True) + + +def compute_g_constraint(g, form=None, beta=0): + if form is None or form == "None": + return 0 + + if form == "clamp": + g.clamp_w() + return 0 + + elif form == "fnorm": + if beta == 0: + return 0 + + return beta * sum(map(lambda w: w.weight.norm(p="fro"), g.W)) + + raise ValueError + + +def compute_loss_f(f, g, source, target, transport=None): + if transport is None: + transport = g.transport(source) + + return -f(transport) + f(target) + + +def compute_w2_distance(f, g, source, target, transport=None): + if transport is None: + transport = g.transport(source).squeeze() + + with torch.no_grad(): + Cpq = (source * source).sum(1, keepdim=True) + (target * target).sum( + 1, keepdim=True + ) + Cpq = 0.5 * Cpq + + cost = ( + f(transport) + - torch.multiply(source, transport).sum(-1, keepdim=True) + - f(target) + + Cpq + ) + cost = cost.mean() + return cost + + +def numerical_gradient(param, fxn, *args, eps=1e-4): + with torch.no_grad(): + param += eps + plus = float(fxn(*args)) + + with torch.no_grad(): + param -= 2 * eps + minus = float(fxn(*args)) + + with torch.no_grad(): + param += eps + + return (plus - minus) / (2 * eps) diff --git a/tools/Complex_generative/cellOT_v1/source_modif/icnns.py b/tools/Complex_generative/cellOT_v1/source_modif/icnns.py new file mode 100644 index 0000000..9b54e96 --- /dev/null +++ b/tools/Complex_generative/cellOT_v1/source_modif/icnns.py @@ -0,0 +1,135 @@ +import torch +from torch import autograd +import numpy as np +from torch import nn +from numpy.testing import assert_allclose + + +ACTIVATIONS = { + "relu": nn.ReLU, + "leakyrelu": nn.LeakyReLU, +} + + +class NonNegativeLinear(nn.Linear): + def __init__(self, *args, beta=1.0, **kwargs): + super(NonNegativeLinear, self).__init__(*args, **kwargs) + self.beta = beta + return + + def forward(self, x): + return nn.functional.linear(x, self.kernel(), self.bias) + + def kernel(self): + return nn.functional.softplus(self.weight, beta=self.beta) + + +class ICNN(nn.Module): + def __init__( + self, + input_dim, + hidden_units, + activation="LeakyReLU", + softplus_W_kernels=False, + softplus_beta=1, + std=0.1, + fnorm_penalty=0, + kernel_init_fxn=None, + ): + + super(ICNN, self).__init__() + self.fnorm_penalty = fnorm_penalty + self.softplus_W_kernels = softplus_W_kernels + + if isinstance(activation, str): + activation = ACTIVATIONS[activation.lower().replace("_", "")] + self.sigma = activation + + units = hidden_units + [1] + + # z_{l+1} = \sigma_l(W_l*z_l + A_l*x + b_l) + # W_0 = 0 + if self.softplus_W_kernels: + + def WLinear(*args, **kwargs): + return NonNegativeLinear(*args, **kwargs, beta=softplus_beta) + + else: + WLinear = nn.Linear + + self.W = nn.ModuleList( + [ + WLinear(idim, odim, bias=False) + for idim, odim in zip(units[:-1], units[1:]) + ] + ) + + self.A = nn.ModuleList( + [nn.Linear(input_dim, odim, bias=True) for odim in units] + ) + + if kernel_init_fxn is not None: + + for layer in self.A: + kernel_init_fxn(layer.weight) + nn.init.zeros_(layer.bias) + + for layer in self.W: + kernel_init_fxn(layer.weight) + + return + + def forward(self, x): + + z = self.sigma(0.2)(self.A[0](x)) + z = z * z + + for W, A in zip(self.W[:-1], self.A[1:-1]): + z = self.sigma(0.2)(W(z) + A(x)) + + y = self.W[-1](z) + self.A[-1](x) + + return y + + def transport(self, x): + assert x.requires_grad + + (output,) = autograd.grad( + self.forward(x), + x, + create_graph=True, + only_inputs=True, + grad_outputs=torch.ones_like(self.forward(x)), + ) + return output + + def clamp_w(self): + if self.softplus_W_kernels: + return + + for w in self.W: + w.weight.data = w.weight.data.clamp(min=0) + return + + def penalize_w(self): + return self.fnorm_penalty * sum( + map(lambda x: torch.nn.functional.relu(-x.weight).norm(), self.W) + ) + + +def test_icnn_convexity(icnn): + data_dim = icnn.A[0].in_features + + zeros = np.zeros(100) + for _ in range(100): + x = torch.rand((100, data_dim)) + y = torch.rand((100, data_dim)) + + fx = icnn(x) + fy = icnn(y) + + for t in np.linspace(0, 1, 10): + fxy = icnn(t * x + (1 - t) * y) + res = (t * fx + (1 - t) * fy) - fxy + res = res.detach().numpy().squeeze() + assert_allclose(np.minimum(res, 0), zeros, atol=1e-6) diff --git a/tools/Complex_generative/cellOT_v1/source_modif/train.py b/tools/Complex_generative/cellOT_v1/source_modif/train.py new file mode 100644 index 0000000..a65f4e8 --- /dev/null +++ b/tools/Complex_generative/cellOT_v1/source_modif/train.py @@ -0,0 +1,348 @@ +from pathlib import Path +import csv +import torch +import numpy as np +import random +import pickle +from absl import logging +from absl.flags import FLAGS +from cellot import losses +from cellot.utils.loaders import load +from cellot.models.cellot import compute_loss_f, compute_loss_g, compute_w2_distance +from cellot.train.summary import Logger +from cellot.data.utils import cast_loader_to_iterator +from cellot.models.ae import compute_scgen_shift +from tqdm import trange + + +def load_lr_scheduler(optim, config): + if "scheduler" not in config: + return None + + return torch.optim.lr_scheduler.StepLR(optim, **config.scheduler) + + +def check_loss(*args): + for arg in args: + if torch.isnan(arg): + raise ValueError + + +def load_item_from_save(path, key, default): + path = Path(path) + if not path.exists(): + return default + + ckpt = torch.load(path) + if key not in ckpt: + logging.warn(f"'{key}' not found in ckpt: {str(path)}") + return default + + return ckpt[key] + + +def train_cellot(outdir, config): + def get_state_dict_for_saving(f, g, opts, **kwargs): + if not (hasattr(f, "state_dict") and callable(f.state_dict)): + raise TypeError("`f` n'est pas un modèle PyTorch valide.") + if not (hasattr(g, "state_dict") and callable(g.state_dict)): + raise TypeError("`g` n'est pas un modèle PyTorch valide.") + + state = { + "g_state": g.state_dict(), + "f_state": f.state_dict(), + "opt_g_state": opts.g.state_dict(), + "opt_f_state": opts.f.state_dict(), + } + state.update(kwargs) + return state + + def evaluate(): + target = next(iterator_test_target) + source = next(iterator_test_source) + source.requires_grad_(True) + transport = g.transport(source) + transport = transport.detach() + + with torch.no_grad(): + gl = compute_loss_g(f, g, source, transport).mean() + fl = compute_loss_f(f, g, source, target, transport).mean() + dist = compute_w2_distance(f, g, source, target, transport) + mmd = losses.compute_scalar_mmd( + target.detach().numpy(), transport.detach().numpy() + ) + + loss_tracking_records.append({"step": step, "loss_g": gl.item(), "loss_f": fl.item()}) + + logger.log( + "eval", + gloss=gl.item(), + floss=fl.item(), + jloss=dist.item(), + mmd=mmd, + step=step, + ) + check_loss(gl, gl, dist) + return mmd + + logger = Logger(outdir / "cache/scalars") + cachedir = outdir / "cache" + (f, g), opts, loader = load(config, restore=cachedir / "last.pt") + + # Checks of f and g + #print(f"initial type of f: {type(f)}") + #print(f"initial type of g: {type(g)}") + assert hasattr(f, "state_dict") and callable(f.state_dict), "f n'est pas un modèle PyTorch valide" + assert hasattr(g, "state_dict") and callable(g.state_dict), "g n'est pas un modèle PyTorch valide" + + iterator = cast_loader_to_iterator(loader, cycle_all=True) + n_iters = config.training.n_iters + step = load_item_from_save(cachedir / "last.pt", "step", 0) + minmmd = load_item_from_save(cachedir / "model.pt", "minmmd", np.inf) + mmd = minmmd + loss_tracking_records = [] + + if 'pair_batch_on' in config.training: + keys = list(iterator.train.target.keys()) + test_keys = list(iterator.test.target.keys()) + else: + keys = None + + ticker = trange(step, n_iters, initial=step, total=n_iters) + for step in ticker: + if 'pair_batch_on' in config.training: + assert keys is not None + key = random.choice(keys) + iterator_train_target = iterator.train.target[key] + iterator_train_source = iterator.train.source[key] + try: + iterator_test_target = iterator.test.target[key] + iterator_test_source = iterator.test.source[key] + except KeyError: + test_key = random.choice(test_keys) + iterator_test_target = iterator.test.target[test_key] + iterator_test_source = iterator.test.source[test_key] + else: + iterator_train_target = iterator.train.target + iterator_train_source = iterator.train.source + iterator_test_target = iterator.test.target + iterator_test_source = iterator.test.source + + target = next(iterator_train_target) + for _ in range(config.training.n_inner_iters): + source = next(iterator_train_source).requires_grad_(True) + + opts.g.zero_grad() + gl = compute_loss_g(f, g, source).mean() + if not g.softplus_W_kernels and g.fnorm_penalty > 0: + gl = gl + g.penalize_w() + + gl.backward() + opts.g.step() + + source = next(iterator_train_source).requires_grad_(True) + opts.f.zero_grad() + fl = compute_loss_f(f, g, source, target).mean() + fl.backward() + opts.f.step() + check_loss(gl, fl) + f.clamp_w() + + if step % config.training.logs_freq == 0: + logger.log("train", gloss=gl.item(), floss=fl.item(), step=step) + + if step % config.training.eval_freq == 0: + mmd = evaluate() + if mmd < minmmd: + minmmd = mmd + torch.save( + get_state_dict_for_saving(f, g, opts, step=step, minmmd=minmmd), + cachedir / "model.pt", + ) + + if step % config.training.cache_freq == 0: + # Check before save + #print("content of get_state_dict_for_saving before saving in last.pt :", get_state_dict_for_saving(f, g, opts, step=step)) + torch.save(get_state_dict_for_saving(f, g, opts, step=step), cachedir / "last.pt") + logger.flush() + + with open(outdir / "loss_tracking.csv", mode="w", newline="") as csv_file: + writer = csv.writer(csv_file) + writer.writerow(["Step", "Loss_G", "Loss_F"]) # CSV col.names + for record in loss_tracking_records: + writer.writerow([record["step"], record["loss_g"], record["loss_f"]]) + + torch.save(get_state_dict_for_saving(f, g, opts, step=step), cachedir / "last.pt") + logger.flush() + return + + + +def train_auto_encoder(outdir, config): + def state_dict(model, optim, **kwargs): + state = { + "model_state": model.state_dict(), + "optim_state": optim.state_dict(), + } + + if hasattr(model, "code_means"): + state["code_means"] = model.code_means + + state.update(kwargs) + + return state + + def evaluate(vinputs): + with torch.no_grad(): + loss, comps, _ = model(vinputs) + loss = loss.mean() + comps = {k: v.mean().item() for k, v in comps._asdict().items()} + check_loss(loss) + logger.log("eval", loss=loss.item(), step=step, **comps) + return loss + + logger = Logger(outdir / "cache/scalars") + cachedir = outdir / "cache" + model, optim, loader = load(config, restore=cachedir / "last.pt") + + iterator = cast_loader_to_iterator(loader, cycle_all=True) + scheduler = load_lr_scheduler(optim, config) + + n_iters = config.training.n_iters + step = load_item_from_save(cachedir / "last.pt", "step", 0) + if scheduler is not None and step > 0: + scheduler.last_epoch = step + + best_eval_loss = load_item_from_save( + cachedir / "model.pt", "best_eval_loss", np.inf + ) + + eval_loss = best_eval_loss + + ticker = trange(step, n_iters, initial=step, total=n_iters) + for step in ticker: + + model.train() + inputs = next(iterator.train) + optim.zero_grad() + loss, comps, _ = model(inputs) + loss = loss.mean() + comps = {k: v.mean().item() for k, v in comps._asdict().items()} + loss.backward() + optim.step() + check_loss(loss) + + if step % config.training.logs_freq == 0: + # log to logger object + logger.log("train", loss=loss.item(), step=step, **comps) + + if step % config.training.eval_freq == 0: + model.eval() + eval_loss = evaluate(next(iterator.test)) + if eval_loss < best_eval_loss: + best_eval_loss = eval_loss + sd = state_dict(model, optim, step=(step + 1), eval_loss=eval_loss) + + torch.save(sd, cachedir / "model.pt") + + if step % config.training.cache_freq == 0: + torch.save(state_dict(model, optim, step=(step + 1)), cachedir / "last.pt") + + logger.flush() + + if scheduler is not None: + scheduler.step() + + if config.model.name == "scgen" and config.get("compute_scgen_shift", True): + labels = loader.train.dataset.adata.obs[config.data.condition] + compute_scgen_shift(model, loader.train.dataset, labels=labels) + + torch.save(state_dict(model, optim, step=step), cachedir / "last.pt") + + logger.flush() + + +def train_popalign(outdir, config): + def evaluate(config, data, model): + + # Get control and treated subset of the data and projections. + idx_control_test = np.where(data.obs[ + config.data.condition] == config.data.source)[0] + idx_treated_test = np.where(data.obs[ + config.data.condition] == config.data.target)[0] + + predicted = transport_popalign(model, data[idx_control_test].X) + target = np.array(data[idx_treated_test].X) + + # Compute performance metrics. + mmd = losses.compute_scalar_mmd(target, predicted) + wst = losses.wasserstein_loss(target, predicted) + + # Log to logger object. + logger.log( + "eval", + mmd=mmd, + wst=wst, + step=1 + ) + + logger = Logger(outdir / "cache/scalars") + cachedir = outdir / "cache" + + # Load dataset and previous model parameters. + model, _, dataset = load(config, restore=cachedir / "last.pt", + return_as="dataset") + train_data = dataset["train"].adata + test_data = dataset["test"].adata + + if not all(k in model for k in ("dim_red", "gmm_control", "response")): + + if config.model.embedding == 'onmf': + # Find best low dimensional representation. + q, nfeats, errors = onmf(train_data.X.T) + W, proj = choose_featureset( + train_data.X.T, errors, q, nfeats, alpha=3, multiplier=3) + + else: + W = np.eye(train_data.X.shape[1]) + proj = train_data.X + + # Get control and treated subset of the data and projections. + idx_control_train = np.where(train_data.obs[ + config.data.condition] == config.data.source)[0] + idx_treated_train = np.where(train_data.obs[ + config.data.condition] == config.data.target)[0] + + # Compute probabilistic model for control and treated population. + gmm_control = build_gmm( + train_data.X[idx_control_train, :].T, + proj[idx_control_train], ks=(3), niters=2, + training=.8, criteria='aic') + gmm_treated = build_gmm( + train_data.X[idx_treated_train, :].T, + proj[idx_treated_train], ks=(3), niters=2, + training=.8, criteria='aic') + + # Compute alignment between components of both mixture models. + align, _ = align_components(gmm_control, gmm_treated, method="ref2test") + + # Compute perturbation response for each control component. + res = get_perturbation_response(align, gmm_control, gmm_treated) + + # Save all results to state dict. + model = {"dim_red": W, + "gmm_control": gmm_control, + "gmm_treated": gmm_treated, + "response": res} + state_dict = model + pickle.dump(state_dict, open(cachedir / "last.pt", 'wb')) + pickle.dump(state_dict, open(cachedir / "model.pt", 'wb')) + + else: + W = model["dim_red"] + gmm_control = model["gmm_control"] + gmm_treated = model["gmm_treated"] + res = model["response"] + + # Evaluate performance on test set. + evaluate(config, test_data, model)