From b751fc8c4ff3492d317157e9cc7ff55c51fdc109 Mon Sep 17 00:00:00 2001 From: Sabrina Date: Thu, 14 Nov 2024 18:12:46 +0100 Subject: [PATCH 1/7] Add first files in branche cellOT_v1 --- Cellot_v1_sj.md | 78 +++++++++ cell.py | 358 +++++++++++++++++++++++++++++++++++++++++ cellot.py | 172 ++++++++++++++++++++ cellot_eval_v3_ood.py | 266 ++++++++++++++++++++++++++++++ cellot_train_v3_ood.py | 256 +++++++++++++++++++++++++++++ icnns.py | 135 ++++++++++++++++ train.py | 348 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 1613 insertions(+) create mode 100644 Cellot_v1_sj.md create mode 100644 cell.py create mode 100644 cellot.py create mode 100644 cellot_eval_v3_ood.py create mode 100644 cellot_train_v3_ood.py create mode 100644 icnns.py create mode 100644 train.py diff --git a/Cellot_v1_sj.md b/Cellot_v1_sj.md new file mode 100644 index 0000000..23b7d25 --- /dev/null +++ b/Cellot_v1_sj.md @@ -0,0 +1,78 @@ +# Cellot Model Training and Evaluation: OOD Workflow + +This document provides a step-by-step guide to preparing data, training the Cellot model in Out-of-Distribution (OOD) mode, and evaluating model predictions. This process involves custom modifications to the Cellot codebase to address specific requirements and improve model functionality (loss ouputs and anndata with ctrl/stim/pred). + +--- + +## 1. Data Preparation + +The Cellot model requires an AnnData object that contains information for two conditions: + - **Control condition** (e.g., `ctrl`) + - **Perturbed condition** (e.g., `stim`) + +### Data Requirements +- **Data format**: The data should be in an AnnData structure compatible with single-cell analysis tools like Scanpy. Each observation in AnnData should include metadata in the `.obs` attribute, including cell type and condition. + +- **Data Normalization**: + - Normalization is essential for consistent model performance. Using the `normalize_total` and `log1p` function in Scanpy. + - **Scaling**: its specific impact on Cellot is still being evaluated, though standardizing features across cells may be beneficial for model performance..don't know yet + +After preparing and normalizing data, save the AnnData object for input into the Cellot model training and evaluation scripts. + +--- + +## 2. Training the Model with `cellot_train_v3_ood.py` + +### Environment Setup +Follow the environment setup instructions from the Cellot GitHub repository to ensure dependencies are properly installed. Specifically, a Conda environment is recommended for managing dependencies. + +### Custom Code Modifications +To handle specific issues encountered during model training, some modifications were made to the Cellot source code. These adjustments enhance compatibility with my OOD training and include: + - **Files modified**: + - `cellot.data.cell` + - `cellot.models.cellot` + - `cellot.networks.icnns` + - `cellot.train.train` + +### Training Configuration +The Cellot OOD training script (`cellot_train_v3_ood.py`) includes a loop to automatically train individual models for each cell type. The key training parameters include: +- **Condition column** (`condition`): Defines the grouping of data into control and perturbed conditions. +- **Source and target conditions**: These specify the training setup. For example, `source='ctrl'` and `target='stim'`. +- **Epochs and batch size**: Standard parameters for deep learning models. +- **Holdout cell type** (`datasplit_holdout`): Specifies the cell type excluded from training for each OOD model. + +### Running the Training Script +Run the `cellot_train_v3_ood.py` script after verifying all dependencies and data requirements. This script will: +1. Train the model for each specified cell type in OOD mode, using other cell types as training data. +2. Save models and training outputs, including loss tracking. + +**Note**: Loss curves for transport functions are recorded and can be plotted at the end of each training session, though further integration into the loop is in progress. + +--- + +## 3. Model Evaluation with `cellot_eval_v3_ood.py` + +Once models are trained, the `cellot_eval_v3_ood.py` script enables evaluation of each cell-type-specific model. Evaluation includes visualizing predictions and computing performance metrics. + +### Evaluation Outputs +1. **Dimensionality Reduction (PCA and UMAP)**: + - The script generates PCA and UMAP visualizations, specifically for the holdout cell type (excluded during training) for each trained model. + - These plots allow direct visual inspection of predicted cell distributions compared to actual data, providing insights into the model's performance in the OOD setting. + +2. **Performance Metrics (in progress)**: + - **R² Score**: Calculating R² for the predicted versus actual values helps quantify the model’s prediction accuracy for each cell type. + - **Transport Distance**: In progress !! The distances (euclidian, e and mmd), a transport function metric, assesses how accurately the model translates control cells into their perturbed states. + +The evaluation process allows detailed analysis of each model's performance per cell type, facilitating further adjustments and optimization of model parameters. + +--- + +### Summary of Parameters and Model Configurations + +- **`condition`**: Defines the control and target conditions in the dataset. +- **`datasplit_mode`**: When set to `ood`, this parameter splits the data to ensure the holdout cell type is excluded from the training data. +- **`datasplit_groupby`**: Controls grouping for the data split. For example, setting `['celltype','condition']` splits based on both cell type and condition. + +--- + +This document provides the foundational steps for leveraging Cellot’s OOD training mode. Adapting the model to specific datasets and further optimizing parameters will enhance performance, especially with custom data configurations. It's still ugly with redundancy. diff --git a/cell.py b/cell.py new file mode 100644 index 0000000..5f8129f --- /dev/null +++ b/cell.py @@ -0,0 +1,358 @@ +#!/usr/bin/python3 + +import anndata +import numpy as np +import pandas as pd +from scipy import sparse +from pathlib import Path +import torch +from torch.utils.data import Dataset +from sklearn.model_selection import train_test_split +from cellot.models import load_autoencoder_model +from cellot.utils import load_config +from cellot.data.utils import cast_dataset_to_loader +from cellot.utils.helpers import nest_dict + + +class AnnDataDataset(Dataset): + def __init__( + self, adata, obs=None, categories=None, include_index=False, dim_red=None + ): + self.adata = adata + self.adata.X = self.adata.X.astype(np.float32) + self.obs = obs + self.categories = categories + self.include_index = include_index + + def __len__(self): + return len(self.adata) + + def __getitem__(self, idx): + value = self.adata.X[idx] + + if self.obs is not None: + meta = self.categories.index(self.adata.obs[self.obs].iloc[idx]) + value = value, int(meta) + + if self.include_index: + return self.adata.obs_names[idx], value + + return value + + +def read_list(arg): + + if isinstance(arg, str): + arg = Path(arg) + assert arg.exists() + lst = arg.read_text().split() + else: + lst = arg + + return list(lst) + + +def read_single_anndata(config, path=None): + if path is None: + path = config.data.path + + data = anndata.read(path) + + if hasattr(config.data, "features"): + features = read_list(config.data.features) + data = data[:, features].copy() + + # select subgroup of individuals + if hasattr(config.data, "individuals"): + data = data[ + data.obs[config.data.individuals[0]].isin(config.data.individuals[1]) + ] + + # label conditions as source/target distributions + # config.data.{source,target} can be a list now + transport_mapper = dict() + for value in ["source", "target"]: + key = getattr(config.data, value) + if isinstance(key, list): + for item in key: + transport_mapper[item] = value + else: + transport_mapper[key] = value + + data.obs["transport"] = data.obs[config.data.condition].apply(transport_mapper.get) + + if getattr(config.data, "target") == "all": + data.obs["transport"].fillna("target", inplace=True) + + mask = data.obs["transport"].notna() + assert not hasattr(config.data, "subset") + if config.datasplit.subset is not None: + for key, value in config.datasplit.subset.items(): + + if not isinstance(value, list): + value = [value] + mask = mask & data.obs[key].isin(value) + + # write train/test/valid into split column + data = data[mask].copy() + if hasattr(config, "datasplit"): + data.obs["split"] = split_cell_data(data, **config.datasplit.to_dict() if config.datasplit else {}) + + return data + + +def load_cell_data( + config, + data=None, + split_on=None, + return_as="loader", + include_model_kwargs=False, + pair_batch_on=None, + **kwargs +): + + if isinstance(return_as, str): + return_as = [return_as] + + assert set(return_as).issubset({"anndata", "dataset", "loader"}) + config.data.condition = config.data.get("condition", "drug") + condition = config.data.condition + + if data is None: + if config.data.type == "cell": + data = read_single_anndata(config, **kwargs) + else: + raise ValueError + + if config.data.get("select") is not None: + keep = pd.Series(False, index=data.obs_names) + for key, value in config.data.select.items(): + if not isinstance(value, list): + value = [value] + keep.loc[data.obs[key].isin(value)] = True + assert keep.sum() > 0 + + data = data[keep].copy() + + if "dimension_reduction" in config.data: + genes = data.var_names.to_list() + name = config.data.dimension_reduction.name + if name == "pca": + dims = config.data.dimension_reduction.get( + "dims", data.obsm["X_pca"].shape[1] + ) + + data = anndata.AnnData( + data.obsm["X_pca"][:, :dims], obs=data.obs.copy(), uns=data.uns.copy() + ) + data.uns["genes"] = genes + + if "ae_emb" in config.data: + # load path to autoencoder + assert config.get("model.name", "cellot") == "cellot" + path_ae = Path(config.data.ae_emb.path) + model_kwargs = {"input_dim": data.n_vars} + config_ae = load_config(path_ae / "config.yaml") + ae_model, _ = load_autoencoder_model( + config_ae, restore=path_ae / "cache/model.pt", **model_kwargs + ) + + inputs = torch.Tensor( + data.X if not sparse.issparse(data.X) else data.X.todense() + ) + + genes = data.var_names.to_list() + data = anndata.AnnData( + ae_model.eval().encode(inputs).detach().numpy(), + obs=data.obs.copy(), + uns=data.uns.copy(), + ) + data.uns["genes"] = genes + + # cast to dense and check for nans + if sparse.issparse(data.X): + data.X = data.X.todense() + assert not np.isnan(data.X).any() + + dataset_args = dict() + model_kwargs = {} + + model_kwargs["input_dim"] = data.n_vars + + if config.get("model.name") == "cae": + condition_labels = sorted(data.obs[condition].cat.categories) + model_kwargs["conditions"] = condition_labels + dataset_args["obs"] = condition + dataset_args["categories"] = condition_labels + + if "training" in config: + pair_batch_on = config.training.get("pair_batch_on", pair_batch_on) + + if split_on is None: + if config.model.name == "cellot": + # datasets & dataloaders accessed as loader.train.source + split_on = ["split", "transport"] + if pair_batch_on is not None: + split_on.append(pair_batch_on) + + elif (config.model.name == "scgen" or config.model.name == "cae" + or config.model.name == "popalign"): + split_on = ["split"] + + else: + raise ValueError + + if isinstance(split_on, str): + split_on = [split_on] + + for key in split_on: + assert key in data.obs.columns + + if len(split_on) > 0: + splits = { + (key if isinstance(key, str) else ".".join(key)): data[index] + for key, index in data.obs[split_on].groupby(split_on).groups.items() + } + + dataset = nest_dict( + { + key: AnnDataDataset(val.copy(), **dataset_args) + for key, val in splits.items() + }, + as_dot_dict=True, + ) + + else: + dataset = AnnDataDataset(data.copy(), **dataset_args) + + if "loader" in return_as: + kwargs = config.dataloader.to_dict() if hasattr(config.dataloader, "to_dict") else config.dataloader + kwargs.setdefault("drop_last", True) + loader = cast_dataset_to_loader(dataset, **kwargs) + + returns = list() + for key in return_as: + if key == "anndata": + returns.append(data) + + elif key == "dataset": + returns.append(dataset) + + elif key == "loader": + returns.append(loader) + + if include_model_kwargs: + returns.append(model_kwargs) + + if len(returns) == 1: + return returns[0] + + return tuple(returns) + + +def split_cell_data_train_test( + data, groupby=None, random_state=0, holdout=None, subset=None, **kwargs +): + + kwargs.pop("mode", None) # Delete "mode" if it dosen't exist in kwargs + kwargs.pop("key", None) + split = pd.Series(None, index=data.obs.index, dtype=object) + groups = {None: data.obs.index} + if groupby is not None: + groups = data.obs.groupby(groupby).groups + + for key, index in groups.items(): + trainobs, testobs = train_test_split(index, random_state=random_state, **kwargs) + split.loc[trainobs] = "train" + split.loc[testobs] = "test" + + if holdout is not None: + for key, value in holdout.items(): + if not isinstance(value, list): + value = [value] + split.loc[data.obs[key].isin(value)] = "ood" + + return split + + +def split_cell_data_train_test_eval( + data, + test_size=0.15, + eval_size=0.15, + groupby=None, + random_state=0, + holdout=None, + **kwargs +): + + split = pd.Series(None, index=data.obs.index, dtype=object) + + if holdout is not None: + for key, value in holdout.items(): + if not isinstance(value, list): + value = [value] + split.loc[data.obs[key].isin(value)] = "ood" + + groups = {None: data.obs.loc[split != "ood"].index} + if groupby is not None: + groups = data.obs.loc[split != "ood"].groupby(groupby).groups + + for key, index in groups.items(): + training, evalobs = train_test_split( + index, random_state=random_state, test_size=eval_size + ) + + trainobs, testobs = train_test_split( + training, random_state=random_state, test_size=test_size + ) + + split.loc[trainobs] = "train" + split.loc[testobs] = "test" + split.loc[evalobs] = "eval" + + return split + + +def split_cell_data_toggle_ood(data, holdout, key, mode, random_state=0, **kwargs): + + """Hold out ood sample, coordinated with iid split + + ood sample defined with key, value pair + + for ood mode: hold out all cells from a sample + for iid mode: include half of cells in split + """ + + split = split_cell_data_train_test(data, random_state=random_state, **kwargs) + + if not isinstance(holdout, list): + value = [holdout] + + ood = data.obs_names[data.obs[key].isin(value)] + trainobs, testobs = train_test_split(ood, random_state=random_state, test_size=0.5) + + if mode == "ood": + split.loc[trainobs] = "ignore" + split.loc[testobs] = "ood" + + elif mode == "iid": + split.loc[trainobs] = "train" + split.loc[testobs] = "ood" + + else: + raise ValueError + + return split + + +def split_cell_data(data, name="train_test", **kwargs): + if name == "train_test": + split = split_cell_data_train_test(data, **kwargs) + elif name == "toggle_ood": + split = split_cell_data_toggle_ood(data, **kwargs) + elif name == "train_test_eval": + split = split_cell_data_train_test_eval(data, **kwargs) + else: + raise ValueError + + return split.astype("category") diff --git a/cellot.py b/cellot.py new file mode 100644 index 0000000..31c370f --- /dev/null +++ b/cellot.py @@ -0,0 +1,172 @@ +from pathlib import Path +import torch +from collections import namedtuple +from cellot.networks.icnns import ICNN + +from absl import flags + +FLAGS = flags.FLAGS + +FGPair = namedtuple("FGPair", "f g") + + +def load_networks(config, **kwargs): + def unpack_kernel_init_fxn(name="uniform", **kwargs): + if name == "normal": + def init(*args): + return torch.nn.init.normal_(*args, **kwargs) + elif name == "uniform": + def init(*args): + return torch.nn.init.uniform_(*args, **kwargs) + else: + raise ValueError("Unsupported kernel initialization function.") + return init + + # Exclude parameters not relevant to ICNN + model_params = config.get("model", {}).as_dict() + ignore_keys = ["name", "latent_dim", "optim", "training"] + for key in ignore_keys: + model_params.pop(key, None) + + # Check and define input_dim + input_dim = model_params.get("input_dim") or kwargs.get("input_dim") + if input_dim is None: + raise ValueError("`input_dim` must be specified in the model configuration or kwargs.") + + kwargs.setdefault("hidden_units", [64] * 4) + kwargs.update(model_params) + + # specific parameters for f et g + fupd = kwargs.pop("f", {}) + gupd = kwargs.pop("g", {}) + + # Configure fkwargs et gkwargs for ICNN + fkwargs = kwargs.copy() + fkwargs.update(fupd) + fkwargs["input_dim"] = input_dim # add input_dim for f + if "kernel_init_fxn" in fkwargs: + fkwargs["kernel_init_fxn"] = unpack_kernel_init_fxn(**fkwargs.pop("kernel_init_fxn")) + + gkwargs = kwargs.copy() + gkwargs.update(gupd) + gkwargs["input_dim"] = input_dim # add input_dim for g + if "kernel_init_fxn" in gkwargs: + gkwargs["kernel_init_fxn"] = unpack_kernel_init_fxn(**gkwargs.pop("kernel_init_fxn")) + + # Instantiate ICNN networks for f and g + f = ICNN(**fkwargs) + g = ICNN(**gkwargs) + + if "verbose" in FLAGS and FLAGS.verbose: + print("Network g configuration:", g) + print("Remaining kwargs:", kwargs) + + return f, g + + + +def load_opts(config, f, g): + # Access “optim” as a dictionary without using .as_dict() + optim_config = config.get("optim", {}) + + # optimizers `f` et `g` + fupd = optim_config.get("f", {}) + gupd = optim_config.get("g", {}) + + # Create parameters for optimizers by adjusting the "betas" + fkwargs = optim_config.copy() + fkwargs.update(fupd) + fkwargs["betas"] = (fkwargs.pop("beta1", 0.9), fkwargs.pop("beta2", 0.999)) + + gkwargs = optim_config.copy() + gkwargs.update(gupd) + gkwargs["betas"] = (gkwargs.pop("beta1", 0.9), gkwargs.pop("beta2", 0.999)) + + # Create optimizers for f et g + opts = FGPair( + f=torch.optim.Adam(f.parameters(), **fkwargs), + g=torch.optim.Adam(g.parameters(), **gkwargs), + ) + + return opts + + +def load_cellot_model(config, restore=None, **kwargs): + f, g = load_networks(config, **kwargs) + opts = load_opts(config, f, g) + + if restore is not None and Path(restore).exists(): + ckpt = torch.load(restore) + f.load_state_dict(ckpt["f_state"]) + opts.f.load_state_dict(ckpt["opt_f_state"]) + + g.load_state_dict(ckpt["g_state"]) + opts.g.load_state_dict(ckpt["opt_g_state"]) + + return (f, g), opts + + +def compute_loss_g(f, g, source, transport=None): + if transport is None: + transport = g.transport(source) + + return f(transport) - torch.multiply(source, transport).sum(-1, keepdim=True) + + +def compute_g_constraint(g, form=None, beta=0): + if form is None or form == "None": + return 0 + + if form == "clamp": + g.clamp_w() + return 0 + + elif form == "fnorm": + if beta == 0: + return 0 + + return beta * sum(map(lambda w: w.weight.norm(p="fro"), g.W)) + + raise ValueError + + +def compute_loss_f(f, g, source, target, transport=None): + if transport is None: + transport = g.transport(source) + + return -f(transport) + f(target) + + +def compute_w2_distance(f, g, source, target, transport=None): + if transport is None: + transport = g.transport(source).squeeze() + + with torch.no_grad(): + Cpq = (source * source).sum(1, keepdim=True) + (target * target).sum( + 1, keepdim=True + ) + Cpq = 0.5 * Cpq + + cost = ( + f(transport) + - torch.multiply(source, transport).sum(-1, keepdim=True) + - f(target) + + Cpq + ) + cost = cost.mean() + return cost + + +def numerical_gradient(param, fxn, *args, eps=1e-4): + with torch.no_grad(): + param += eps + plus = float(fxn(*args)) + + with torch.no_grad(): + param -= 2 * eps + minus = float(fxn(*args)) + + with torch.no_grad(): + param += eps + + return (plus - minus) / (2 * eps) diff --git a/cellot_eval_v3_ood.py b/cellot_eval_v3_ood.py new file mode 100644 index 0000000..5fa7d79 --- /dev/null +++ b/cellot_eval_v3_ood.py @@ -0,0 +1,266 @@ +import torch +from pathlib import Path +import pandas as pd +import scanpy as sc +from cellot.models.cellot import load_cellot_model, compute_loss_g, compute_loss_f, compute_w2_distance +import anndata +import numpy as np +import os +import matplotlib.pyplot as plt +from types import SimpleNamespace +from sklearn.metrics import r2_score +print(os.getcwd()) + + + +class ConfigNamespace(SimpleNamespace): + def get(self, key, default=None): + return getattr(self, key, default) + + def to_dict(self): + """ + Convertit récursivement l'objet ConfigNamespace en dictionnaire. + """ + result = {} + for key, value in self.__dict__.items(): + if isinstance(value, ConfigNamespace): + result[key] = value.to_dict() + else: + result[key] = value + return result + + def as_dict(self): + """Renvoie l'instance sous forme de dictionnaire pour une utilisation sécurisée dans le code.""" + return self.to_dict() + + def __contains__(self, key): + return key in self.__dict__ + +# Fonction utilitaire pour transformer un dictionnaire en ConfigNamespace +def dict_to_namespace(config_dict): + return ConfigNamespace(**{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in config_dict.items()}) + +# Fonction pour convertir les objets ConfigNamespace en dictionnaire avant l'utilisation +def convert_to_dict_if_namespace(obj): + """Convertit un ConfigNamespace en dictionnaire, récursivement si nécessaire.""" + if isinstance(obj, ConfigNamespace): + return obj.to_dict() + elif isinstance(obj, dict): + return {k: convert_to_dict_if_namespace(v) for k, v in obj.items()} + else: + return obj + + + +def load_test_data(test_data_path, config): + test_data = sc.read(test_data_path) + source_data = test_data[test_data.obs[config.data.condition] == config.data.source] + target_data = test_data[test_data.obs[config.data.condition] == config.data.target] + source_tensor = torch.tensor(source_data.X.toarray(), dtype=torch.float32, requires_grad=True) + target_tensor = torch.tensor(target_data.X.toarray(), dtype=torch.float32) + return list(zip(source_tensor, target_tensor)) + + +def create_anndata_with_predictions(config, model_path, original_data): + # Load the model + checkpoint = torch.load(model_path) + (f, g), _ = load_cellot_model(config) + f.load_state_dict(checkpoint['f_state']) + g.load_state_dict(checkpoint['g_state']) + + # Set the model to evaluation mode + f.eval() + g.eval() + + # Filter for source cells (ctrl condition) + source_data = original_data[original_data.obs[config.data.condition] == config.data.source] + + # Convert source data to tensor and set requires_grad for all tensors + source_tensor = torch.tensor( + source_data.X.toarray() if hasattr(source_data.X, "toarray") else source_data.X, + dtype=torch.float32, + requires_grad=True + ) + + # Step 1: Verify for NaNs in the source_tensor + print(f"Step 1: Nombre de NaN dans source_tensor : {torch.isnan(source_tensor).sum().item()}") + + # Store predicted cells + predicted_cells = [] + + # Step 2: Obtain predictions for each source cell and check for NaNs in the prediction + for i, source in enumerate(source_tensor): + with torch.set_grad_enabled(True): # Ensure grad tracking is enabled + source = source.unsqueeze(0) # Ensure correct shape + predicted = g.transport(source) # Transport function + + # Check if prediction contains NaNs + if torch.isnan(predicted).any(): + print(f"Step 2: NaNs detected in prediction for cell {i}") + else: + print(f"Step 2: Prediction successful for cell {i}") + + predicted_cells.append(predicted.detach().numpy()) # Detach after prediction + + # Stack predictions into an array + predicted_data_matrix = np.vstack(predicted_cells) + # Normalize predicted data to match original data's scale + # This assumes that the original data has been normalized with scanpy's `normalize_total` + predicted_adata = anndata.AnnData(X=predicted_data_matrix) + #sc.pp.normalize_total(predicted_adata, target_sum=1e4) + #sc.pp.log1p(predicted_adata) + #predicted_data_matrix = predicted_adata.X + + + # Step 3: Verify if predicted_data_matrix contains NaNs after prediction loop + print(f"Step 3: Nombre de NaN dans predicted_data_matrix : {np.isnan(predicted_data_matrix).sum()}") + + # Optional: Replace NaNs in `predicted_data_matrix` if needed + # predicted_data_matrix = np.nan_to_num(predicted_data_matrix) + # print(f"Step 3b: NaNs replaced. Nombre de NaN dans predicted_data_matrix après remplacement : {np.isnan(predicted_data_matrix).sum()}") + + # Combine predicted data matrix with the existing data matrix, ensuring no duplication with 'ctrl' + original_data_matrix = ( + original_data.X.toarray() if hasattr(original_data.X, "toarray") else original_data.X + ) + + # Step 4: Check for NaNs in the original data matrix + print(f"Step 4: Nombre de NaN dans original_data_matrix : {np.isnan(original_data_matrix).sum()}") + + # Step 5: Combine matrices and check for NaNs in the combined data + combined_data = np.vstack([original_data_matrix, predicted_data_matrix]) + print(f"Step 5: Nombre de NaN dans combined_data après concaténation : {np.isnan(combined_data).sum()}") + + # Copy original metadata and create labels for predictions + combined_obs = original_data.obs.copy() + + # Generate a new observation dataframe for predicted cells based on source cells but labeled as 'predicted' + predicted_obs = source_data.obs.copy() + predicted_obs[config.data.condition] = 'predicted' # Set new condition + predicted_obs.index = [f"pred_{i}" for i in range(len(predicted_cells))] # Unique indices for predictions + + # Concatenate the original observations with the newly created predicted observations + combined_obs = pd.concat([combined_obs, predicted_obs]) + + # Final AnnData object with original and predicted cells + anndata_with_predictions = anndata.AnnData( + X=combined_data, + obs=combined_obs, + var=original_data.var + ) + + + # Ensure observation names are unique + anndata_with_predictions.obs_names_make_unique() + + # Step 6: Check if the final AnnData object contains NaNs in X + print(f"Step 6: Nombre de NaN dans anndata_with_predictions.X : {np.isnan(anndata_with_predictions.X).sum()}") + + # Optional: If desired, set raw attribute for the AnnData object + anndata_with_predictions.raw = anndata_with_predictions.copy() + + return anndata_with_predictions + + +# Load the dataset +dataset_path = "C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad" +adata = sc.read_h5ad(dataset_path) +cell_types = adata.obs['cell_type'].unique() +output_dir = ".\\output_ood_models" + +for cell_type in cell_types: + model_dir = os.path.join(output_dir, f"{cell_type}_ood") + model_path = Path(model_dir) / "cache" / "model.pt" + holdout = cell_type + + # Define task configuration for evaluation + task_config = { + 'dataset': dataset_path, + 'condition': 'condition', + 'source': 'ctrl', + 'target': 'stim', + 'type': 'cell', + 'batch_size': 128, + 'shuffle': True, + 'datasplit_groupby': ['cell_type','condition'], + 'datasplit_name': 'toggle_ood', + 'key' : 'cell_type',# Name for this data split + 'datasplit_mode': 'ood', # Set mode to 'ood' + 'datasplit_holdout': holdout, # Specify holdout cell type + 'datasplit_test_size': 0.3, + 'datasplit_random_state': 0 + } + + model_config = { + 'input_dim': 1000, + 'name': 'cellot', + 'hidden_units': [64, 64, 64, 64], + 'latent_dim': 100, + 'softplus_W_kernels': False, + 'g': { + 'fnorm_penalty': 1 + }, + 'kernel_init_fxn': { + 'b': 0.1, + 'name': 'uniform' + }, + 'optim': { + 'optimizer': 'Adam', + 'lr': 0.0001, + 'beta1': 0.5, + 'beta2': 0.9, + 'weight_decay': 0 + }, + 'training': { + 'n_iters': 100000, + 'n_inner_iters': 1, + 'cache_freq': 50, + 'eval_freq': 20, + 'logs_freq': 10 + } + } + + config = { + 'training': model_config['training'], + 'data': task_config, + 'model': model_config, + 'datasplit': { + 'groupby': task_config['datasplit_groupby'], + 'name': task_config['datasplit_name'], + 'test_size': task_config['datasplit_test_size'], + 'random_state': task_config['datasplit_random_state'], + 'holdout': task_config.get('datasplit_holdout', None), + 'key': task_config.get('key', None), + 'mode': task_config.get('datasplit_mode', 'iid'), + 'subset': None + }, + 'dataloader': { + 'batch_size': task_config['batch_size'], + 'shuffle': task_config['shuffle'] + } + } + + config_ns = dict_to_namespace(config) + + # Evaluate model and create AnnData with predictions + anndata_with_predictions = create_anndata_with_predictions(config_ns, model_path, adata) + + + # Filter only for the cells of the holdout type and their predictions + holdout_cells = anndata_with_predictions[ + (anndata_with_predictions.obs['cell_type'] == cell_type) + ] + + # Visualization with PCA + print(f"Evaluating PCA and UMAP for holdout cell_type: {cell_type}") + sc.tl.pca(holdout_cells, svd_solver="arpack") + sc.pl.pca(holdout_cells, color="condition", title=f"PCA for {cell_type}") + + # Visualization with UMAP + sc.pp.neighbors(holdout_cells) + sc.tl.umap(holdout_cells) + sc.pl.umap(holdout_cells, color="condition", title=f"UMAP for {cell_type}") + + + + diff --git a/cellot_train_v3_ood.py b/cellot_train_v3_ood.py new file mode 100644 index 0000000..e3e1b33 --- /dev/null +++ b/cellot_train_v3_ood.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Nov 14 14:55:57 2024 + +@author: Shadow +""" + +# -*- coding: utf-8 -*- +""" +Created on Tue Nov 12 11:36:46 2024 + +@author: Shadow +""" + +import sys +from pathlib import Path +from cellot.train.train import train_cellot, train_auto_encoder, train_popalign +from types import SimpleNamespace +from cellot.utils.loaders import load +import scanpy as sc + +#from pathlib import Path +#from cellot_train import train_cellot, train_auto_encoder, train_popalign + +#mport csv +#from pathlib import Path + +#import torch +#import numpy as np +#import random +#import pickle +#from absl import logging +#from absl.flags import FLAGS +#from cellot import losses +#from cellot.utils.loaders import load +#from cellot.models.cellot import compute_loss_f, compute_loss_g, compute_w2_distance +#from cellot.train.summary import Logger +#from cellot.data.utils import cast_loader_to_iterator +#from cellot.models.ae import compute_scgen_shift +#from tqdm import trange + + + + + +class ConfigNamespace(SimpleNamespace): + def get(self, key, default=None): + return getattr(self, key, default) + + def to_dict(self): + """ + Convertit récursivement l'objet ConfigNamespace en dictionnaire. + """ + result = {} + for key, value in self.__dict__.items(): + if isinstance(value, ConfigNamespace): + result[key] = value.to_dict() + else: + result[key] = value + return result + + def as_dict(self): + """Renvoie l'instance sous forme de dictionnaire pour une utilisation sécurisée dans le code.""" + return self.to_dict() + + def __contains__(self, key): + return key in self.__dict__ + +# Fonction utilitaire pour transformer un dictionnaire en ConfigNamespace +def dict_to_namespace(config_dict): + return ConfigNamespace(**{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in config_dict.items()}) + +# Fonction pour convertir les objets ConfigNamespace en dictionnaire avant l'utilisation +def convert_to_dict_if_namespace(obj): + """Convertit un ConfigNamespace en dictionnaire, récursivement si nécessaire.""" + if isinstance(obj, ConfigNamespace): + return obj.to_dict() + elif isinstance(obj, dict): + return {k: convert_to_dict_if_namespace(v) for k, v in obj.items()} + else: + return obj + + + + +def run_cellot_training(task_config, model_config, train_type="cellot", outdir='./output'): + # Créer la structure de configuration avec tous les éléments nécessaires + config = { + 'training': { + 'n_iters': task_config.get('epochs', 10), + 'logs_freq': 10, + 'eval_freq': 20, + 'cache_freq': 100, + 'n_inner_iters': 1 + }, + 'data': { + 'condition': task_config.get('condition', 'drug'), + 'source': task_config.get('source', 'control'), + 'target': task_config.get('target', 'stim'), + 'type': task_config.get('type', 'cell'), + 'path': task_config.get('dataset', '') + }, + 'model': model_config, + 'datasplit': { + 'groupby': task_config.get('datasplit_groupby', 'condition'), + 'name': task_config.get('datasplit_name', 'train_test'), + 'test_size': task_config.get('datasplit_test_size', 0.2), + 'random_state': task_config.get('datasplit_random_state', 0), + 'holdout': task_config.get('datasplit_holdout', None), + 'key': task_config.get('key', None), + 'mode': task_config.get('datasplit_mode', 'iid'), + 'subset': None + }, + 'dataloader': { + 'batch_size': task_config.get('batch_size', 64), + 'shuffle': task_config.get('shuffle', True) + } + } + + # Convertir la configuration en un objet ConfigNamespace + config_ns = dict_to_namespace(config) + + # Convert outdir to Path object to ensure compatibility + outdir = Path(outdir) + + # Appeler la fonction d'entraînement en fonction du type de modèle choisi + if train_type == 'cellot': + train_cellot(outdir, config_ns) + elif train_type == 'auto_encoder': + train_auto_encoder(outdir, config_ns) + elif train_type == 'popalign': + train_popalign(outdir, config_ns) + else: + raise ValueError("Type d'entraînement non supporté : {}".format(train_type)) + + print(f"Training complete for {train_type} model at {outdir}.") + + +import os +import anndata +#from cellot_train_v2 import run_cellot_training # Ensure this is the path to the training function + +# Load the AnnData dataset +adata = anndata.read_h5ad("C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad") # Replace with the actual path + +# Extract unique cell types +cell_types = adata.obs['cell_type'].unique() + +# Directory to save the models +output_dir = '.\\output_ood_models' + + +for cell_type in cell_types: + # Set up holdout as the current cell type for OOD + holdout = cell_type + + # Define the output directory for the model + model_dir = os.path.join(output_dir, f"{cell_type}_ood") + os.makedirs(model_dir, exist_ok=True) + + # Define the task configuration for training + task_config = { + 'dataset': "C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad", + 'condition': 'condition', # Column defining the conditions + 'source': 'ctrl', # Control condition as source + 'target': 'stim', # Stimulated condition as target + 'type': 'cell', # Specifies type, assuming it's for cells + 'epochs': 100000, # Set the number of epochs + 'batch_size': 128, # Batch size + 'shuffle': True, # Shuffle the data + 'datasplit_groupby': ['cell_type','condition'], # Grouping by 'condition' for OOD + 'datasplit_name': 'toggle_ood', + 'key' : 'cell_type', + 'datasplit_mode': 'ood', # Set mode to 'ood' + 'datasplit_holdout': holdout, # Specify holdout cell type + 'datasplit_test_size': 0.3, # Test split size + 'datasplit_random_state': 0 # Random seed for reproducibility + } + + # Define the model configuration + model_config = { + 'input_dim': 1000, + 'name': 'cellot', + 'hidden_units': [64, 64, 64, 64], + 'latent_dim': 100, + 'softplus_W_kernels': False, + 'g': { + 'fnorm_penalty': 1 + }, + 'kernel_init_fxn': { + 'b': 0.1, + 'name': 'uniform' + }, + 'optim': { + 'optimizer': 'Adam', + 'lr': 0.0001, + 'beta1': 0.5, + 'beta2': 0.9, + 'weight_decay': 0 + }, + 'training': { + 'n_iters': 100000, # Total number of iterations + 'n_inner_iters': 1, # Number of inner iterations + 'cache_freq': 50, # Frequency to cache model state + 'eval_freq': 20, # Frequency for evaluations + 'logs_freq': 10 # Logging frequency + } + } + + print(f"Training model for holdout cell_type: {cell_type}") + + # Run the training with specified configurations and train_type='cellot' + run_cellot_training( + task_config=task_config, + model_config=model_config, + train_type='cellot', + outdir=model_dir # Save output to the model-specific directory + ) + + print(f"Completed training for holdout cell_type: {cell_type}") + + + +#exemple + + +test_data_path = "C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad" +adata = sc.read_h5ad(test_data_path) +input_dim = adata.shape[1] +print("Nombre de caractéristiques (input_dim) :", input_dim) + + + + + +import pandas as pd +import matplotlib.pyplot as plt + +# Charger les données de perte +loss_data = pd.read_csv('.\\output_ood_models\\loss_tracking.csv') + +# Vérifier que les colonnes attendues sont bien présentes +if all(col in loss_data.columns for col in ['Step', 'Loss_G', 'Loss_F']): + # Tracer les pertes + plt.figure(figsize=(10, 6)) + plt.plot(loss_data['Step'], loss_data['Loss_G'], label='Loss_G') + plt.plot(loss_data['Step'], loss_data['Loss_F'], label='Loss_F') + plt.xlabel('Step') + plt.ylabel('Loss') + plt.yscale('log') + plt.legend() + plt.title('Evolution of Losses (Loss G and Loss F) During Training') + plt.show() +else: + print("Les colonnes attendues 'step', 'loss_g', ou 'loss_f' ne sont pas présentes dans le fichier CSV.") +#AF echelle log !!! diff --git a/icnns.py b/icnns.py new file mode 100644 index 0000000..9b54e96 --- /dev/null +++ b/icnns.py @@ -0,0 +1,135 @@ +import torch +from torch import autograd +import numpy as np +from torch import nn +from numpy.testing import assert_allclose + + +ACTIVATIONS = { + "relu": nn.ReLU, + "leakyrelu": nn.LeakyReLU, +} + + +class NonNegativeLinear(nn.Linear): + def __init__(self, *args, beta=1.0, **kwargs): + super(NonNegativeLinear, self).__init__(*args, **kwargs) + self.beta = beta + return + + def forward(self, x): + return nn.functional.linear(x, self.kernel(), self.bias) + + def kernel(self): + return nn.functional.softplus(self.weight, beta=self.beta) + + +class ICNN(nn.Module): + def __init__( + self, + input_dim, + hidden_units, + activation="LeakyReLU", + softplus_W_kernels=False, + softplus_beta=1, + std=0.1, + fnorm_penalty=0, + kernel_init_fxn=None, + ): + + super(ICNN, self).__init__() + self.fnorm_penalty = fnorm_penalty + self.softplus_W_kernels = softplus_W_kernels + + if isinstance(activation, str): + activation = ACTIVATIONS[activation.lower().replace("_", "")] + self.sigma = activation + + units = hidden_units + [1] + + # z_{l+1} = \sigma_l(W_l*z_l + A_l*x + b_l) + # W_0 = 0 + if self.softplus_W_kernels: + + def WLinear(*args, **kwargs): + return NonNegativeLinear(*args, **kwargs, beta=softplus_beta) + + else: + WLinear = nn.Linear + + self.W = nn.ModuleList( + [ + WLinear(idim, odim, bias=False) + for idim, odim in zip(units[:-1], units[1:]) + ] + ) + + self.A = nn.ModuleList( + [nn.Linear(input_dim, odim, bias=True) for odim in units] + ) + + if kernel_init_fxn is not None: + + for layer in self.A: + kernel_init_fxn(layer.weight) + nn.init.zeros_(layer.bias) + + for layer in self.W: + kernel_init_fxn(layer.weight) + + return + + def forward(self, x): + + z = self.sigma(0.2)(self.A[0](x)) + z = z * z + + for W, A in zip(self.W[:-1], self.A[1:-1]): + z = self.sigma(0.2)(W(z) + A(x)) + + y = self.W[-1](z) + self.A[-1](x) + + return y + + def transport(self, x): + assert x.requires_grad + + (output,) = autograd.grad( + self.forward(x), + x, + create_graph=True, + only_inputs=True, + grad_outputs=torch.ones_like(self.forward(x)), + ) + return output + + def clamp_w(self): + if self.softplus_W_kernels: + return + + for w in self.W: + w.weight.data = w.weight.data.clamp(min=0) + return + + def penalize_w(self): + return self.fnorm_penalty * sum( + map(lambda x: torch.nn.functional.relu(-x.weight).norm(), self.W) + ) + + +def test_icnn_convexity(icnn): + data_dim = icnn.A[0].in_features + + zeros = np.zeros(100) + for _ in range(100): + x = torch.rand((100, data_dim)) + y = torch.rand((100, data_dim)) + + fx = icnn(x) + fy = icnn(y) + + for t in np.linspace(0, 1, 10): + fxy = icnn(t * x + (1 - t) * y) + res = (t * fx + (1 - t) * fy) - fxy + res = res.detach().numpy().squeeze() + assert_allclose(np.minimum(res, 0), zeros, atol=1e-6) diff --git a/train.py b/train.py new file mode 100644 index 0000000..a65f4e8 --- /dev/null +++ b/train.py @@ -0,0 +1,348 @@ +from pathlib import Path +import csv +import torch +import numpy as np +import random +import pickle +from absl import logging +from absl.flags import FLAGS +from cellot import losses +from cellot.utils.loaders import load +from cellot.models.cellot import compute_loss_f, compute_loss_g, compute_w2_distance +from cellot.train.summary import Logger +from cellot.data.utils import cast_loader_to_iterator +from cellot.models.ae import compute_scgen_shift +from tqdm import trange + + +def load_lr_scheduler(optim, config): + if "scheduler" not in config: + return None + + return torch.optim.lr_scheduler.StepLR(optim, **config.scheduler) + + +def check_loss(*args): + for arg in args: + if torch.isnan(arg): + raise ValueError + + +def load_item_from_save(path, key, default): + path = Path(path) + if not path.exists(): + return default + + ckpt = torch.load(path) + if key not in ckpt: + logging.warn(f"'{key}' not found in ckpt: {str(path)}") + return default + + return ckpt[key] + + +def train_cellot(outdir, config): + def get_state_dict_for_saving(f, g, opts, **kwargs): + if not (hasattr(f, "state_dict") and callable(f.state_dict)): + raise TypeError("`f` n'est pas un modèle PyTorch valide.") + if not (hasattr(g, "state_dict") and callable(g.state_dict)): + raise TypeError("`g` n'est pas un modèle PyTorch valide.") + + state = { + "g_state": g.state_dict(), + "f_state": f.state_dict(), + "opt_g_state": opts.g.state_dict(), + "opt_f_state": opts.f.state_dict(), + } + state.update(kwargs) + return state + + def evaluate(): + target = next(iterator_test_target) + source = next(iterator_test_source) + source.requires_grad_(True) + transport = g.transport(source) + transport = transport.detach() + + with torch.no_grad(): + gl = compute_loss_g(f, g, source, transport).mean() + fl = compute_loss_f(f, g, source, target, transport).mean() + dist = compute_w2_distance(f, g, source, target, transport) + mmd = losses.compute_scalar_mmd( + target.detach().numpy(), transport.detach().numpy() + ) + + loss_tracking_records.append({"step": step, "loss_g": gl.item(), "loss_f": fl.item()}) + + logger.log( + "eval", + gloss=gl.item(), + floss=fl.item(), + jloss=dist.item(), + mmd=mmd, + step=step, + ) + check_loss(gl, gl, dist) + return mmd + + logger = Logger(outdir / "cache/scalars") + cachedir = outdir / "cache" + (f, g), opts, loader = load(config, restore=cachedir / "last.pt") + + # Checks of f and g + #print(f"initial type of f: {type(f)}") + #print(f"initial type of g: {type(g)}") + assert hasattr(f, "state_dict") and callable(f.state_dict), "f n'est pas un modèle PyTorch valide" + assert hasattr(g, "state_dict") and callable(g.state_dict), "g n'est pas un modèle PyTorch valide" + + iterator = cast_loader_to_iterator(loader, cycle_all=True) + n_iters = config.training.n_iters + step = load_item_from_save(cachedir / "last.pt", "step", 0) + minmmd = load_item_from_save(cachedir / "model.pt", "minmmd", np.inf) + mmd = minmmd + loss_tracking_records = [] + + if 'pair_batch_on' in config.training: + keys = list(iterator.train.target.keys()) + test_keys = list(iterator.test.target.keys()) + else: + keys = None + + ticker = trange(step, n_iters, initial=step, total=n_iters) + for step in ticker: + if 'pair_batch_on' in config.training: + assert keys is not None + key = random.choice(keys) + iterator_train_target = iterator.train.target[key] + iterator_train_source = iterator.train.source[key] + try: + iterator_test_target = iterator.test.target[key] + iterator_test_source = iterator.test.source[key] + except KeyError: + test_key = random.choice(test_keys) + iterator_test_target = iterator.test.target[test_key] + iterator_test_source = iterator.test.source[test_key] + else: + iterator_train_target = iterator.train.target + iterator_train_source = iterator.train.source + iterator_test_target = iterator.test.target + iterator_test_source = iterator.test.source + + target = next(iterator_train_target) + for _ in range(config.training.n_inner_iters): + source = next(iterator_train_source).requires_grad_(True) + + opts.g.zero_grad() + gl = compute_loss_g(f, g, source).mean() + if not g.softplus_W_kernels and g.fnorm_penalty > 0: + gl = gl + g.penalize_w() + + gl.backward() + opts.g.step() + + source = next(iterator_train_source).requires_grad_(True) + opts.f.zero_grad() + fl = compute_loss_f(f, g, source, target).mean() + fl.backward() + opts.f.step() + check_loss(gl, fl) + f.clamp_w() + + if step % config.training.logs_freq == 0: + logger.log("train", gloss=gl.item(), floss=fl.item(), step=step) + + if step % config.training.eval_freq == 0: + mmd = evaluate() + if mmd < minmmd: + minmmd = mmd + torch.save( + get_state_dict_for_saving(f, g, opts, step=step, minmmd=minmmd), + cachedir / "model.pt", + ) + + if step % config.training.cache_freq == 0: + # Check before save + #print("content of get_state_dict_for_saving before saving in last.pt :", get_state_dict_for_saving(f, g, opts, step=step)) + torch.save(get_state_dict_for_saving(f, g, opts, step=step), cachedir / "last.pt") + logger.flush() + + with open(outdir / "loss_tracking.csv", mode="w", newline="") as csv_file: + writer = csv.writer(csv_file) + writer.writerow(["Step", "Loss_G", "Loss_F"]) # CSV col.names + for record in loss_tracking_records: + writer.writerow([record["step"], record["loss_g"], record["loss_f"]]) + + torch.save(get_state_dict_for_saving(f, g, opts, step=step), cachedir / "last.pt") + logger.flush() + return + + + +def train_auto_encoder(outdir, config): + def state_dict(model, optim, **kwargs): + state = { + "model_state": model.state_dict(), + "optim_state": optim.state_dict(), + } + + if hasattr(model, "code_means"): + state["code_means"] = model.code_means + + state.update(kwargs) + + return state + + def evaluate(vinputs): + with torch.no_grad(): + loss, comps, _ = model(vinputs) + loss = loss.mean() + comps = {k: v.mean().item() for k, v in comps._asdict().items()} + check_loss(loss) + logger.log("eval", loss=loss.item(), step=step, **comps) + return loss + + logger = Logger(outdir / "cache/scalars") + cachedir = outdir / "cache" + model, optim, loader = load(config, restore=cachedir / "last.pt") + + iterator = cast_loader_to_iterator(loader, cycle_all=True) + scheduler = load_lr_scheduler(optim, config) + + n_iters = config.training.n_iters + step = load_item_from_save(cachedir / "last.pt", "step", 0) + if scheduler is not None and step > 0: + scheduler.last_epoch = step + + best_eval_loss = load_item_from_save( + cachedir / "model.pt", "best_eval_loss", np.inf + ) + + eval_loss = best_eval_loss + + ticker = trange(step, n_iters, initial=step, total=n_iters) + for step in ticker: + + model.train() + inputs = next(iterator.train) + optim.zero_grad() + loss, comps, _ = model(inputs) + loss = loss.mean() + comps = {k: v.mean().item() for k, v in comps._asdict().items()} + loss.backward() + optim.step() + check_loss(loss) + + if step % config.training.logs_freq == 0: + # log to logger object + logger.log("train", loss=loss.item(), step=step, **comps) + + if step % config.training.eval_freq == 0: + model.eval() + eval_loss = evaluate(next(iterator.test)) + if eval_loss < best_eval_loss: + best_eval_loss = eval_loss + sd = state_dict(model, optim, step=(step + 1), eval_loss=eval_loss) + + torch.save(sd, cachedir / "model.pt") + + if step % config.training.cache_freq == 0: + torch.save(state_dict(model, optim, step=(step + 1)), cachedir / "last.pt") + + logger.flush() + + if scheduler is not None: + scheduler.step() + + if config.model.name == "scgen" and config.get("compute_scgen_shift", True): + labels = loader.train.dataset.adata.obs[config.data.condition] + compute_scgen_shift(model, loader.train.dataset, labels=labels) + + torch.save(state_dict(model, optim, step=step), cachedir / "last.pt") + + logger.flush() + + +def train_popalign(outdir, config): + def evaluate(config, data, model): + + # Get control and treated subset of the data and projections. + idx_control_test = np.where(data.obs[ + config.data.condition] == config.data.source)[0] + idx_treated_test = np.where(data.obs[ + config.data.condition] == config.data.target)[0] + + predicted = transport_popalign(model, data[idx_control_test].X) + target = np.array(data[idx_treated_test].X) + + # Compute performance metrics. + mmd = losses.compute_scalar_mmd(target, predicted) + wst = losses.wasserstein_loss(target, predicted) + + # Log to logger object. + logger.log( + "eval", + mmd=mmd, + wst=wst, + step=1 + ) + + logger = Logger(outdir / "cache/scalars") + cachedir = outdir / "cache" + + # Load dataset and previous model parameters. + model, _, dataset = load(config, restore=cachedir / "last.pt", + return_as="dataset") + train_data = dataset["train"].adata + test_data = dataset["test"].adata + + if not all(k in model for k in ("dim_red", "gmm_control", "response")): + + if config.model.embedding == 'onmf': + # Find best low dimensional representation. + q, nfeats, errors = onmf(train_data.X.T) + W, proj = choose_featureset( + train_data.X.T, errors, q, nfeats, alpha=3, multiplier=3) + + else: + W = np.eye(train_data.X.shape[1]) + proj = train_data.X + + # Get control and treated subset of the data and projections. + idx_control_train = np.where(train_data.obs[ + config.data.condition] == config.data.source)[0] + idx_treated_train = np.where(train_data.obs[ + config.data.condition] == config.data.target)[0] + + # Compute probabilistic model for control and treated population. + gmm_control = build_gmm( + train_data.X[idx_control_train, :].T, + proj[idx_control_train], ks=(3), niters=2, + training=.8, criteria='aic') + gmm_treated = build_gmm( + train_data.X[idx_treated_train, :].T, + proj[idx_treated_train], ks=(3), niters=2, + training=.8, criteria='aic') + + # Compute alignment between components of both mixture models. + align, _ = align_components(gmm_control, gmm_treated, method="ref2test") + + # Compute perturbation response for each control component. + res = get_perturbation_response(align, gmm_control, gmm_treated) + + # Save all results to state dict. + model = {"dim_red": W, + "gmm_control": gmm_control, + "gmm_treated": gmm_treated, + "response": res} + state_dict = model + pickle.dump(state_dict, open(cachedir / "last.pt", 'wb')) + pickle.dump(state_dict, open(cachedir / "model.pt", 'wb')) + + else: + W = model["dim_red"] + gmm_control = model["gmm_control"] + gmm_treated = model["gmm_treated"] + res = model["response"] + + # Evaluate performance on test set. + evaluate(config, test_data, model) From 837a8d6fac06cb2c9ba1b06112807714951fd474 Mon Sep 17 00:00:00 2001 From: Sabrina Date: Thu, 14 Nov 2024 18:14:08 +0100 Subject: [PATCH 2/7] Add first files in branche cellOT_v1 --- cellot_train_v3_ood.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/cellot_train_v3_ood.py b/cellot_train_v3_ood.py index e3e1b33..7dd78b4 100644 --- a/cellot_train_v3_ood.py +++ b/cellot_train_v3_ood.py @@ -1,17 +1,3 @@ -# -*- coding: utf-8 -*- -""" -Created on Thu Nov 14 14:55:57 2024 - -@author: Shadow -""" - -# -*- coding: utf-8 -*- -""" -Created on Tue Nov 12 11:36:46 2024 - -@author: Shadow -""" - import sys from pathlib import Path from cellot.train.train import train_cellot, train_auto_encoder, train_popalign From 369432ef11aa82c9cb0233252f321a2778b71ee2 Mon Sep 17 00:00:00 2001 From: Sabrina Date: Thu, 14 Nov 2024 18:49:29 +0100 Subject: [PATCH 3/7] add and delete some comments --- cellot_eval_v3_ood.py | 27 +++++++++++---------------- cellot_train_v3_ood.py | 30 ++++++++++++++---------------- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/cellot_eval_v3_ood.py b/cellot_eval_v3_ood.py index 5fa7d79..6a90738 100644 --- a/cellot_eval_v3_ood.py +++ b/cellot_eval_v3_ood.py @@ -19,7 +19,7 @@ def get(self, key, default=None): def to_dict(self): """ - Convertit récursivement l'objet ConfigNamespace en dictionnaire. + Recursively converts the ConfigNamespace object into a dictionary. """ result = {} for key, value in self.__dict__.items(): @@ -30,19 +30,19 @@ def to_dict(self): return result def as_dict(self): - """Renvoie l'instance sous forme de dictionnaire pour une utilisation sécurisée dans le code.""" + """Returns the instance as a dictionary for secure use in code""" return self.to_dict() def __contains__(self, key): return key in self.__dict__ -# Fonction utilitaire pour transformer un dictionnaire en ConfigNamespace +# transform a dictionnary in ConfigNamespace def dict_to_namespace(config_dict): return ConfigNamespace(**{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in config_dict.items()}) -# Fonction pour convertir les objets ConfigNamespace en dictionnaire avant l'utilisation +# Function for converting ConfigNamespace objects into a dictionary before use def convert_to_dict_if_namespace(obj): - """Convertit un ConfigNamespace en dictionnaire, récursivement si nécessaire.""" + """converting ConfigNamespace objects into a dictionary""" if isinstance(obj, ConfigNamespace): return obj.to_dict() elif isinstance(obj, dict): @@ -104,20 +104,15 @@ def create_anndata_with_predictions(config, model_path, original_data): # Stack predictions into an array predicted_data_matrix = np.vstack(predicted_cells) - # Normalize predicted data to match original data's scale + # Normalize predicted data to match original data's scale ? # This assumes that the original data has been normalized with scanpy's `normalize_total` predicted_adata = anndata.AnnData(X=predicted_data_matrix) - #sc.pp.normalize_total(predicted_adata, target_sum=1e4) - #sc.pp.log1p(predicted_adata) - #predicted_data_matrix = predicted_adata.X + # Step 3: Verify if predicted_data_matrix contains NaNs after prediction loop print(f"Step 3: Nombre de NaN dans predicted_data_matrix : {np.isnan(predicted_data_matrix).sum()}") - # Optional: Replace NaNs in `predicted_data_matrix` if needed - # predicted_data_matrix = np.nan_to_num(predicted_data_matrix) - # print(f"Step 3b: NaNs replaced. Nombre de NaN dans predicted_data_matrix après remplacement : {np.isnan(predicted_data_matrix).sum()}") # Combine predicted data matrix with the existing data matrix, ensuring no duplication with 'ctrl' original_data_matrix = ( @@ -125,11 +120,11 @@ def create_anndata_with_predictions(config, model_path, original_data): ) # Step 4: Check for NaNs in the original data matrix - print(f"Step 4: Nombre de NaN dans original_data_matrix : {np.isnan(original_data_matrix).sum()}") + print(f"Step 4: Nomber of NaN in the original_data_matrix : {np.isnan(original_data_matrix).sum()}") # Step 5: Combine matrices and check for NaNs in the combined data combined_data = np.vstack([original_data_matrix, predicted_data_matrix]) - print(f"Step 5: Nombre de NaN dans combined_data après concaténation : {np.isnan(combined_data).sum()}") + print(f"Step 5: Nomber of NaN in combined_data : {np.isnan(combined_data).sum()}") # Copy original metadata and create labels for predictions combined_obs = original_data.obs.copy() @@ -154,7 +149,7 @@ def create_anndata_with_predictions(config, model_path, original_data): anndata_with_predictions.obs_names_make_unique() # Step 6: Check if the final AnnData object contains NaNs in X - print(f"Step 6: Nombre de NaN dans anndata_with_predictions.X : {np.isnan(anndata_with_predictions.X).sum()}") + print(f"Step 6: Nomber of NaN in anndata_with_predictions.X : {np.isnan(anndata_with_predictions.X).sum()}") # Optional: If desired, set raw attribute for the AnnData object anndata_with_predictions.raw = anndata_with_predictions.copy() @@ -184,7 +179,7 @@ def create_anndata_with_predictions(config, model_path, original_data): 'shuffle': True, 'datasplit_groupby': ['cell_type','condition'], 'datasplit_name': 'toggle_ood', - 'key' : 'cell_type',# Name for this data split + 'key' : 'cell_type', 'datasplit_mode': 'ood', # Set mode to 'ood' 'datasplit_holdout': holdout, # Specify holdout cell type 'datasplit_test_size': 0.3, diff --git a/cellot_train_v3_ood.py b/cellot_train_v3_ood.py index 7dd78b4..50eee22 100644 --- a/cellot_train_v3_ood.py +++ b/cellot_train_v3_ood.py @@ -35,7 +35,7 @@ def get(self, key, default=None): def to_dict(self): """ - Convertit récursivement l'objet ConfigNamespace en dictionnaire. + Recursively converts the ConfigNamespace object into a dictionary. """ result = {} for key, value in self.__dict__.items(): @@ -46,19 +46,19 @@ def to_dict(self): return result def as_dict(self): - """Renvoie l'instance sous forme de dictionnaire pour une utilisation sécurisée dans le code.""" + """Returns the instance as a dictionary for secure use in code""" return self.to_dict() def __contains__(self, key): return key in self.__dict__ -# Fonction utilitaire pour transformer un dictionnaire en ConfigNamespace +# transform a dictionnary in ConfigNamespace def dict_to_namespace(config_dict): return ConfigNamespace(**{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in config_dict.items()}) -# Fonction pour convertir les objets ConfigNamespace en dictionnaire avant l'utilisation +# Function for converting ConfigNamespace objects into a dictionary before use def convert_to_dict_if_namespace(obj): - """Convertit un ConfigNamespace en dictionnaire, récursivement si nécessaire.""" + """converting ConfigNamespace objects into a dictionary""" if isinstance(obj, ConfigNamespace): return obj.to_dict() elif isinstance(obj, dict): @@ -68,9 +68,7 @@ def convert_to_dict_if_namespace(obj): - def run_cellot_training(task_config, model_config, train_type="cellot", outdir='./output'): - # Créer la structure de configuration avec tous les éléments nécessaires config = { 'training': { 'n_iters': task_config.get('epochs', 10), @@ -103,13 +101,13 @@ def run_cellot_training(task_config, model_config, train_type="cellot", outdir=' } } - # Convertir la configuration en un objet ConfigNamespace + # Transform in a ConfigNamespace config_ns = dict_to_namespace(config) # Convert outdir to Path object to ensure compatibility outdir = Path(outdir) - # Appeler la fonction d'entraînement en fonction du type de modèle choisi + # Call the model function if train_type == 'cellot': train_cellot(outdir, config_ns) elif train_type == 'auto_encoder': @@ -117,7 +115,7 @@ def run_cellot_training(task_config, model_config, train_type="cellot", outdir=' elif train_type == 'popalign': train_popalign(outdir, config_ns) else: - raise ValueError("Type d'entraînement non supporté : {}".format(train_type)) + raise ValueError("Train type not supported: {}".format(train_type)) print(f"Training complete for {train_type} model at {outdir}.") @@ -213,7 +211,7 @@ def run_cellot_training(task_config, model_config, train_type="cellot", outdir=' test_data_path = "C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad" adata = sc.read_h5ad(test_data_path) input_dim = adata.shape[1] -print("Nombre de caractéristiques (input_dim) :", input_dim) +print("Number of variables (input_dim) :", input_dim) @@ -222,12 +220,12 @@ def run_cellot_training(task_config, model_config, train_type="cellot", outdir=' import pandas as pd import matplotlib.pyplot as plt -# Charger les données de perte +# load losses data loss_data = pd.read_csv('.\\output_ood_models\\loss_tracking.csv') -# Vérifier que les colonnes attendues sont bien présentes +# Check for column if all(col in loss_data.columns for col in ['Step', 'Loss_G', 'Loss_F']): - # Tracer les pertes + # plot plt.figure(figsize=(10, 6)) plt.plot(loss_data['Step'], loss_data['Loss_G'], label='Loss_G') plt.plot(loss_data['Step'], loss_data['Loss_F'], label='Loss_F') @@ -238,5 +236,5 @@ def run_cellot_training(task_config, model_config, train_type="cellot", outdir=' plt.title('Evolution of Losses (Loss G and Loss F) During Training') plt.show() else: - print("Les colonnes attendues 'step', 'loss_g', ou 'loss_f' ne sont pas présentes dans le fichier CSV.") -#AF echelle log !!! + print("No 'Step', 'Loss_G', ou 'Loss_F' in the CSV file.") + From 179090e5a01b492ac9c3df97143f832bee215156 Mon Sep 17 00:00:00 2001 From: Sabrina Date: Thu, 12 Dec 2024 18:04:36 +0100 Subject: [PATCH 4/7] =?UTF-8?q?Adding=20r=C2=B2,=20edistance,=20mmd=20and?= =?UTF-8?q?=20euclidean=20distance=20metrics?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cellot_v1_sj.md | 4 +- cellot_eval_v3_ood.py | 457 ++++++++++++++++++++++++++++++++++++++++- cellot_train_v3_ood.py | 2 +- 3 files changed, 459 insertions(+), 4 deletions(-) diff --git a/Cellot_v1_sj.md b/Cellot_v1_sj.md index 23b7d25..08664aa 100644 --- a/Cellot_v1_sj.md +++ b/Cellot_v1_sj.md @@ -59,9 +59,9 @@ Once models are trained, the `cellot_eval_v3_ood.py` script enables evaluation o - The script generates PCA and UMAP visualizations, specifically for the holdout cell type (excluded during training) for each trained model. - These plots allow direct visual inspection of predicted cell distributions compared to actual data, providing insights into the model's performance in the OOD setting. -2. **Performance Metrics (in progress)**: +2. **Performance Metrics**: - **R² Score**: Calculating R² for the predicted versus actual values helps quantify the model’s prediction accuracy for each cell type. - - **Transport Distance**: In progress !! The distances (euclidian, e and mmd), a transport function metric, assesses how accurately the model translates control cells into their perturbed states. + - **Transport Distance**: In progress !! The distances (euclidian, edistance and mmd), transport function metrics, assesses how accurately the model translates control cells into their perturbed states. The evaluation process allows detailed analysis of each model's performance per cell type, facilitating further adjustments and optimization of model parameters. diff --git a/cellot_eval_v3_ood.py b/cellot_eval_v3_ood.py index 6a90738..4a2961c 100644 --- a/cellot_eval_v3_ood.py +++ b/cellot_eval_v3_ood.py @@ -161,7 +161,7 @@ def create_anndata_with_predictions(config, model_path, original_data): dataset_path = "C:\\Users\\Shadow\\Desktop\\BioHack24\\scPRAM\\processed_datasets_all\\datasets\\scrna-lupuspatients\\kang-hvg.h5ad" adata = sc.read_h5ad(dataset_path) cell_types = adata.obs['cell_type'].unique() -output_dir = ".\\output_ood_models" +output_dir = ".\\output_ood_models_1" for cell_type in cell_types: model_dir = os.path.join(output_dir, f"{cell_type}_ood") @@ -257,5 +257,460 @@ def create_anndata_with_predictions(config, model_path, original_data): sc.pl.umap(holdout_cells, color="condition", title=f"UMAP for {cell_type}") +import seaborn as sns +import matplotlib.pyplot as plt +import pandas as pd +from scipy import sparse +import numpy as np + +# Initialisation du dictionnaire pour stocker les résultats R² +r2_results = {} + +# Boucle sur chaque type cellulaire unique +for cell_type in anndata_with_predictions.obs['cell_type'].unique(): + # Filtrer les données pour ce type cellulaire + cell_data = anndata_with_predictions[anndata_with_predictions.obs['cell_type'] == cell_type] + + # Extraire les données pour les conditions 'stim' et 'predicted' + stim_data = cell_data[cell_data.obs['condition'] == 'stim'].X + predicted_data = cell_data[cell_data.obs['condition'] == 'predicted'].X + + # Convertir les matrices creuses en matrices denses si nécessaire + if sparse.issparse(stim_data): + stim_data = stim_data.toarray() + if sparse.issparse(predicted_data): + predicted_data = predicted_data.toarray() + + # Calculer la moyenne d'expression pour chaque gène + stim_mean = stim_data.mean(axis=0) + predicted_mean = predicted_data.mean(axis=0) + + # Calculer la corrélation de Pearson + r = np.corrcoef(stim_mean, predicted_mean)[0, 1] + r2 = r ** 2 # R² basé sur la corrélation de Pearson + + # Stocker le résultat + r2_results[cell_type] = r2 + print(f"R² pour {cell_type} entre les moyennes des gènes 'stim' et 'predicted' : {r2:.4f}") + + # Préparer les données pour la visualisation + df_plot = pd.DataFrame({ + 'Stim Mean Expression': stim_mean, + 'Predicted Mean Expression': predicted_mean + }) + + # Tracer le nuage de points avec la ligne de régression + plt.figure(figsize=(8, 6)) + sns.regplot( + x='Stim Mean Expression', + y='Predicted Mean Expression', + data=df_plot, + scatter_kws={'s': 10}, # Taille des points + line_kws={'color': 'red'} # Couleur de la ligne de régression + ) + plt.title(f'Regression Plot for {cell_type}\nR² = {r2:.4f}') + plt.xlabel('Stim Mean Expression') + plt.ylabel('Predicted Mean Expression') + plt.grid(True) + plt.show() + +# Afficher tous les résultats R² +print("\nR² entre les moyennes d'expression pour chaque type cellulaire entre 'stim' et 'predicted':") +for cell_type, r2 in r2_results.items(): + print(f"{cell_type}: {r2:.4f}") + + + +from scipy.spatial.distance import cdist +from scipy.sparse import issparse +import numpy as np +import pandas as pd +import scanpy as sc + +def compute_edistance(set1, set2): + """ + Compute the energy distance between two datasets. + """ + intra_dist1 = np.mean(cdist(set1, set1, metric="euclidean")) + intra_dist2 = np.mean(cdist(set2, set2, metric="euclidean")) + inter_dist = np.mean(cdist(set1, set2, metric="euclidean")) + return 2 * inter_dist - intra_dist1 - intra_dist2 + +def compute_perturbation_score_per_cell_type(anndata, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted"): + """ + Compute the perturbation score for each cell type. + + Parameters: + anndata: AnnData object containing gene expression data. + n_comps: Number of principal components to use. + condition_col: Column name in `.obs` that specifies the condition. + stim_key: Key for the stimulated condition. + ctrl_key: Key for the control condition. + pred_key: Key for the predicted condition. + + Returns: + A dictionary mapping each cell type to its perturbation score. + """ + perturbation_scores = {} + + # Perform PCA on the data + if n_comps > min(anndata.shape): + n_comps = min(anndata.shape) - 1 + + sc.tl.pca(anndata, svd_solver="arpack", n_comps=n_comps) + print(f"PCA with {n_comps} components computed.\n") + + # Iterate over each cell type + for cell_type in anndata.obs['cell_type'].unique(): + print(f"Processing cell type: {cell_type}") + + # Subset the data for the current cell type + cell_data = anndata[anndata.obs['cell_type'] == cell_type] + + # Extract the subsets for stimulated, control, and predicted data + stim_adata = cell_data[cell_data.obs[condition_col] == stim_key] + ctrl_adata = cell_data[cell_data.obs[condition_col] == ctrl_key] + pred_adata = cell_data[cell_data.obs[condition_col] == pred_key] + + # Skip if any subset is empty + if stim_adata.shape[0] == 0 or ctrl_adata.shape[0] == 0 or pred_adata.shape[0] == 0: + print(f"Skipping {cell_type} due to insufficient data.\n") + continue + + # Extract PCA embeddings + stim_pca = stim_adata.obsm["X_pca"] + ctrl_pca = ctrl_adata.obsm["X_pca"] + pred_pca = pred_adata.obsm["X_pca"] + + # Convert sparse matrices to dense + if issparse(stim_pca): stim_pca = stim_pca.toarray() + if issparse(ctrl_pca): ctrl_pca = ctrl_pca.toarray() + if issparse(pred_pca): pred_pca = pred_pca.toarray() + + # Compute energy distances + edistance_stim_pred = compute_edistance(stim_pca, pred_pca) # Perturbed vs Predicted + edistance_ctrl_pred = compute_edistance(ctrl_pca, pred_pca) # Control vs Predicted + + # Avoid division by zero + if edistance_ctrl_pred == 0: + perturbation_score = np.nan + else: + perturbation_score = edistance_stim_pred / edistance_ctrl_pred + + perturbation_scores[cell_type] = perturbation_score + print(f"Perturbation score for {cell_type}: {perturbation_score}\n") + + return perturbation_scores + + +perturbation_scores = compute_perturbation_score_per_cell_type( + anndata=anndata_with_predictions, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted" +) + +# Display the results +print("Scaled perturbation scores for all cell types:") +for cell_type, score in perturbation_scores.items(): + print(f"{cell_type}: {score:.4f}") + + +#-------------- mmd ------------------- + +from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel +import numpy as np +import pandas as pd +import scanpy as sc + +def compute_mmd(set1, set2, kernel="linear", **kernel_kwargs): + """ + Compute the Maximum Mean Discrepancy (MMD) between two datasets. + + Parameters: + set1: np.ndarray + First dataset (e.g., real perturbed data). + set2: np.ndarray + Second dataset (e.g., predicted data). + kernel: str + Type of kernel to use. Options are 'linear', 'rbf', and 'poly'. + **kernel_kwargs: + Additional arguments for the kernel function (e.g., gamma for RBF). + + Returns: + float + MMD score. + """ + if kernel == "linear": + XX = np.dot(set1, set1.T) + YY = np.dot(set2, set2.T) + XY = np.dot(set1, set2.T) + elif kernel == "rbf": + XX = rbf_kernel(set1, set1, **kernel_kwargs) + YY = rbf_kernel(set2, set2, **kernel_kwargs) + XY = rbf_kernel(set1, set2, **kernel_kwargs) + elif kernel == "poly": + XX = polynomial_kernel(set1, set1, **kernel_kwargs) + YY = polynomial_kernel(set2, set2, **kernel_kwargs) + XY = polynomial_kernel(set1, set2, **kernel_kwargs) + else: + raise ValueError(f"Unsupported kernel type: {kernel}") + + return XX.mean() + YY.mean() - 2 * XY.mean() + +def compute_mmd_per_cell_type(anndata, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted", + kernel="linear", + **kernel_kwargs): + """ + Compute the MMD for each cell type. + + Parameters: + anndata: AnnData + Annotated data matrix containing the data. + n_comps: int + Number of PCA components to use. + condition_col: str + Column in `.obs` specifying the condition of the cells. + stim_key: str + Key for the stimulated condition in `.obs`. + ctrl_key: str + Key for the control condition in `.obs`. + pred_key: str + Key for the predicted condition in `.obs`. + kernel: str + Kernel type for MMD. Options: 'linear', 'rbf', 'poly'. + **kernel_kwargs: + Additional parameters for the kernel function. + + Returns: + dict + A dictionary mapping each cell type to its MMD perturbation score. + """ + mmd_scores = {} + + # Perform PCA on the data + if n_comps > min(anndata.shape): + n_comps = min(anndata.shape) - 1 + + sc.tl.pca(anndata, svd_solver="arpack", n_comps=n_comps) + print(f"PCA with {n_comps} components computed.\n") + + # Iterate over each cell type + for cell_type in anndata.obs['cell_type'].unique(): + print(f"Processing cell type: {cell_type}") + + # Subset the data for the current cell type + cell_data = anndata[anndata.obs['cell_type'] == cell_type] + + # Extract subsets for stimulated, control, and predicted data + stim_adata = cell_data[cell_data.obs[condition_col] == stim_key] + ctrl_adata = cell_data[cell_data.obs[condition_col] == ctrl_key] + pred_adata = cell_data[cell_data.obs[condition_col] == pred_key] + + # Extract PCA embeddings + stim_pca = stim_adata.obsm["X_pca"] + ctrl_pca = ctrl_adata.obsm["X_pca"] + pred_pca = pred_adata.obsm["X_pca"] + + # Compute MMD scores + mmd_stim_pred = compute_mmd(stim_pca, pred_pca, kernel=kernel, **kernel_kwargs) # Stimulated vs Predicted + mmd_ctrl_pred = compute_mmd(ctrl_pca, pred_pca, kernel=kernel, **kernel_kwargs) # Control vs Predicted + + # Combine scores into a perturbation score + mmd_score = mmd_stim_pred / mmd_ctrl_pred + mmd_scores[cell_type] = mmd_score + print(f"MMD perturbation score for {cell_type}: {mmd_score}\n") + + return mmd_scores + +# Compute MMD scores for all cell types +mmd_scores = compute_mmd_per_cell_type( + anndata=anndata_with_predictions, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted", + kernel="linear"#, # Example: using RBF kernel +# gamma=1.0 # Example parameter for the RBF kernel +) + +# Display the results +print("MMD perturbation scores for all cell types:") +for cell_type, score in mmd_scores.items(): + print(f"{cell_type}: {score:.4f}") + +#------------------ euclidean distances ------------------- + +from sklearn.metrics.pairwise import euclidean_distances +import numpy as np +import pandas as pd +import scanpy as sc + +def compute_mean_euclidean_distance(set1, set2): + """ + Compute the mean Euclidean distance between two datasets. + + Parameters: + set1: np.ndarray + First dataset (e.g., real perturbed data). + set2: np.ndarray + Second dataset (e.g., predicted data). + + Returns: + float + Mean Euclidean distance between set1 and set2. + """ + pairwise_distances = euclidean_distances(set1, set2) + return pairwise_distances.mean() + +def compute_euclidean_distance_per_cell_type(anndata, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted"): + """ + Compute the mean Euclidean distance for each cell type. + + Parameters: + anndata: AnnData + Annotated data matrix containing the data. + n_comps: int + Number of PCA components to use. + condition_col: str + Column in `.obs` specifying the condition of the cells. + stim_key: str + Key for the stimulated condition in `.obs`. + ctrl_key: str + Key for the control condition in `.obs`. + pred_key: str + Key for the predicted condition in `.obs`. + + Returns: + dict + A dictionary mapping each cell type to its Euclidean perturbation score. + """ + euclidean_scores = {} + + # Perform PCA on the data + if n_comps > min(anndata.shape): + n_comps = min(anndata.shape) - 1 + + sc.tl.pca(anndata, svd_solver="arpack", n_comps=n_comps) + print(f"PCA with {n_comps} components computed.\n") + + # Iterate over each cell type + for cell_type in anndata.obs['cell_type'].unique(): + print(f"Processing cell type: {cell_type}") + + # Subset the data for the current cell type + cell_data = anndata[anndata.obs['cell_type'] == cell_type] + + # Extract subsets for stimulated, control, and predicted data + stim_adata = cell_data[cell_data.obs[condition_col] == stim_key] + ctrl_adata = cell_data[cell_data.obs[condition_col] == ctrl_key] + pred_adata = cell_data[cell_data.obs[condition_col] == pred_key] + + # Extract PCA embeddings + stim_pca = stim_adata.obsm["X_pca"] + ctrl_pca = ctrl_adata.obsm["X_pca"] + pred_pca = pred_adata.obsm["X_pca"] + + # Compute Euclidean distances + euclidean_stim_pred = compute_mean_euclidean_distance(stim_pca, pred_pca) # Stimulated vs Predicted + euclidean_ctrl_pred = compute_mean_euclidean_distance(ctrl_pca, pred_pca) # Control vs Predicted + + # Combine scores into a perturbation score + euclidean_score = euclidean_stim_pred / euclidean_ctrl_pred + euclidean_scores[cell_type] = euclidean_score + print(f"Euclidean perturbation score for {cell_type}: {euclidean_score}\n") + + return euclidean_scores + +# Compute Euclidean distance scores for all cell types +euclidean_scores = compute_euclidean_distance_per_cell_type( + anndata=anndata_with_predictions, + n_comps=50, + condition_col="condition", + stim_key="stim", + ctrl_key="ctrl", + pred_key="predicted" +) + +# Display the results +print("Euclidean perturbation scores for all cell types:") +for cell_type, score in euclidean_scores.items(): + print(f"{cell_type}: {score:.4f}") + + +# Bar plot with each metrics + + +import matplotlib.pyplot as plt +import numpy as np + +# Assurez-vous que les listes suivantes contiennent les données réelles générées +cell_types = list(r2_results.keys()) # Les types cellulaires + + +r2_scores = [float(val) for val in r2_results.values()] # Conversion si nécessaire +edistances = [float(val) for val in perturbation_scores.values()] # Conversion si nécessaire +mmd_res = [float(val) for val in mmd_scores.values()] # Conversion des ArrayView +euclidean_dist = [float(val) for val in euclidean_scores.values()] # Conversion si nécessaire + + + + +# Configuration des sous-graphiques +fig, axes = plt.subplots(1, 4, figsize=(16, 8), sharey=True) + +# Graphique pour R² Scores +axes[0].barh(cell_types, r2_scores, color='blue', edgecolor='black') +axes[0].set_title("R² Scores") +axes[0].set_xlabel("Valeur") +axes[0].invert_yaxis() # Alignement des types cellulaires sur tous les graphiques + +# Graphique pour Energy Distance +axes[1].barh(cell_types, edistances, color='green', edgecolor='black') +axes[1].set_title("Energy Distance") +axes[1].set_xlabel("Valeur") + +# Graphique pour MMD Scores +axes[2].barh(cell_types, mmd_res, color='orange', edgecolor='black') +axes[2].set_title("MMD Scores") +axes[2].set_xlabel("Valeur") + +# Graphique pour Euclidean Distance +axes[3].barh(cell_types, euclidean_dist, color='red', edgecolor='black') +axes[3].set_title("Euclidean Distance") +axes[3].set_xlabel("Valeur") + +# Ajuster l'espacement entre les sous-graphiques +plt.tight_layout() + +# Afficher le graphique +plt.show() + + + + + + + + + diff --git a/cellot_train_v3_ood.py b/cellot_train_v3_ood.py index 50eee22..772aefe 100644 --- a/cellot_train_v3_ood.py +++ b/cellot_train_v3_ood.py @@ -131,7 +131,7 @@ def run_cellot_training(task_config, model_config, train_type="cellot", outdir=' cell_types = adata.obs['cell_type'].unique() # Directory to save the models -output_dir = '.\\output_ood_models' +output_dir = '.\\output_ood_models_1' for cell_type in cell_types: From 26eab670bd861013751d38b9d32c96d6838cb616 Mon Sep 17 00:00:00 2001 From: Sabrina Date: Fri, 13 Dec 2024 17:29:01 +0100 Subject: [PATCH 5/7] add data preparation issue + comments --- Cellot_v1_sj.md | 6 +++-- cd | 0 cellot_eval_v3_ood.py | 59 +++++++++++++++++++++---------------------- git | 0 4 files changed, 33 insertions(+), 32 deletions(-) create mode 100644 cd create mode 100644 git diff --git a/Cellot_v1_sj.md b/Cellot_v1_sj.md index 08664aa..3428808 100644 --- a/Cellot_v1_sj.md +++ b/Cellot_v1_sj.md @@ -12,13 +12,15 @@ The Cellot model requires an AnnData object that contains information for two co ### Data Requirements - **Data format**: The data should be in an AnnData structure compatible with single-cell analysis tools like Scanpy. Each observation in AnnData should include metadata in the `.obs` attribute, including cell type and condition. - +- **Issue with Pertpy anndata objects**: You need to save the metadata and matrix counts of pertpy datasets (in the pertpy environment, anndata version 0.10.8) and then rebuild the anndata objects in the cellOT environment with anndata version 0.7.6, which does not read anndata objects saved in 0.10.8. + - **Data Normalization**: - Normalization is essential for consistent model performance. Using the `normalize_total` and `log1p` function in Scanpy. - **Scaling**: its specific impact on Cellot is still being evaluated, though standardizing features across cells may be beneficial for model performance..don't know yet After preparing and normalizing data, save the AnnData object for input into the Cellot model training and evaluation scripts. + --- ## 2. Training the Model with `cellot_train_v3_ood.py` @@ -61,7 +63,7 @@ Once models are trained, the `cellot_eval_v3_ood.py` script enables evaluation o 2. **Performance Metrics**: - **R² Score**: Calculating R² for the predicted versus actual values helps quantify the model’s prediction accuracy for each cell type. - - **Transport Distance**: In progress !! The distances (euclidian, edistance and mmd), transport function metrics, assesses how accurately the model translates control cells into their perturbed states. + - **Transport Distance**: The distances (euclidian, edistance and mmd), transport function metrics, assesses how accurately the model translates control cells into their perturbed states. The evaluation process allows detailed analysis of each model's performance per cell type, facilitating further adjustments and optimization of model parameters. diff --git a/cd b/cd new file mode 100644 index 0000000..e69de29 diff --git a/cellot_eval_v3_ood.py b/cellot_eval_v3_ood.py index 4a2961c..0fbcf53 100644 --- a/cellot_eval_v3_ood.py +++ b/cellot_eval_v3_ood.py @@ -256,6 +256,7 @@ def create_anndata_with_predictions(config, model_path, original_data): sc.tl.umap(holdout_cells) sc.pl.umap(holdout_cells, color="condition", title=f"UMAP for {cell_type}") +#---------------------- R² ----------------------- import seaborn as sns import matplotlib.pyplot as plt @@ -263,43 +264,43 @@ def create_anndata_with_predictions(config, model_path, original_data): from scipy import sparse import numpy as np -# Initialisation du dictionnaire pour stocker les résultats R² +# R² dictionnary initialization r2_results = {} -# Boucle sur chaque type cellulaire unique +# Loop on cell types for cell_type in anndata_with_predictions.obs['cell_type'].unique(): - # Filtrer les données pour ce type cellulaire + cell_data = anndata_with_predictions[anndata_with_predictions.obs['cell_type'] == cell_type] - # Extraire les données pour les conditions 'stim' et 'predicted' + # Extract data on each condition 'stim' and 'predicted' for the specific cell type stim_data = cell_data[cell_data.obs['condition'] == 'stim'].X predicted_data = cell_data[cell_data.obs['condition'] == 'predicted'].X - # Convertir les matrices creuses en matrices denses si nécessaire + # if necessary convert sparce matrix if sparse.issparse(stim_data): stim_data = stim_data.toarray() if sparse.issparse(predicted_data): predicted_data = predicted_data.toarray() - # Calculer la moyenne d'expression pour chaque gène + # expression mean on each genes stim_mean = stim_data.mean(axis=0) predicted_mean = predicted_data.mean(axis=0) - # Calculer la corrélation de Pearson + # Pearson correlations r = np.corrcoef(stim_mean, predicted_mean)[0, 1] r2 = r ** 2 # R² basé sur la corrélation de Pearson - # Stocker le résultat + # Store the result r2_results[cell_type] = r2 - print(f"R² pour {cell_type} entre les moyennes des gènes 'stim' et 'predicted' : {r2:.4f}") + print(f"R² for {cell_type} between 'stim' and 'predicted' mean genes: {r2:.4f}") - # Préparer les données pour la visualisation + # Data process to visualization df_plot = pd.DataFrame({ 'Stim Mean Expression': stim_mean, 'Predicted Mean Expression': predicted_mean }) - # Tracer le nuage de points avec la ligne de régression + # RegPlot plt.figure(figsize=(8, 6)) sns.regplot( x='Stim Mean Expression', @@ -314,12 +315,12 @@ def create_anndata_with_predictions(config, model_path, original_data): plt.grid(True) plt.show() -# Afficher tous les résultats R² -print("\nR² entre les moyennes d'expression pour chaque type cellulaire entre 'stim' et 'predicted':") +# Print R² results +print("\nR² of mean expression genes for each cell type between 'stim' and 'predicted' cells:") for cell_type, r2 in r2_results.items(): print(f"{cell_type}: {r2:.4f}") - +#----------- edistance -------------- from scipy.spatial.distance import cdist from scipy.sparse import issparse @@ -656,52 +657,50 @@ def compute_euclidean_distance_per_cell_type(anndata, print(f"{cell_type}: {score:.4f}") -# Bar plot with each metrics +#-------------------- Bar plot with each metrics ---------------- import matplotlib.pyplot as plt import numpy as np -# Assurez-vous que les listes suivantes contiennent les données réelles générées -cell_types = list(r2_results.keys()) # Les types cellulaires - -r2_scores = [float(val) for val in r2_results.values()] # Conversion si nécessaire -edistances = [float(val) for val in perturbation_scores.values()] # Conversion si nécessaire -mmd_res = [float(val) for val in mmd_scores.values()] # Conversion des ArrayView -euclidean_dist = [float(val) for val in euclidean_scores.values()] # Conversion si nécessaire +cell_types = list(r2_results.keys()) # cell types +r2_scores = [float(val) for val in r2_results.values()] # Convert if necessary +edistances = [float(val) for val in perturbation_scores.values()] # Convert if necessary +mmd_res = [float(val) for val in mmd_scores.values()] # Convert if necessary +euclidean_dist = [float(val) for val in euclidean_scores.values()] # Convert if necessary -# Configuration des sous-graphiques +# Config of sub-graphs fig, axes = plt.subplots(1, 4, figsize=(16, 8), sharey=True) -# Graphique pour R² Scores +# Graph of R² Scores axes[0].barh(cell_types, r2_scores, color='blue', edgecolor='black') axes[0].set_title("R² Scores") axes[0].set_xlabel("Valeur") -axes[0].invert_yaxis() # Alignement des types cellulaires sur tous les graphiques +axes[0].invert_yaxis() # Aligment of cell types on all graphs -# Graphique pour Energy Distance +# Graph of Energy Distances axes[1].barh(cell_types, edistances, color='green', edgecolor='black') axes[1].set_title("Energy Distance") axes[1].set_xlabel("Valeur") -# Graphique pour MMD Scores +# Graph of MMD Scores axes[2].barh(cell_types, mmd_res, color='orange', edgecolor='black') axes[2].set_title("MMD Scores") axes[2].set_xlabel("Valeur") -# Graphique pour Euclidean Distance +# Graph of Euclidean Distance axes[3].barh(cell_types, euclidean_dist, color='red', edgecolor='black') axes[3].set_title("Euclidean Distance") axes[3].set_xlabel("Valeur") -# Ajuster l'espacement entre les sous-graphiques +# Spaces between sub-graphs plt.tight_layout() -# Afficher le graphique +# plot plt.show() diff --git a/git b/git new file mode 100644 index 0000000..e69de29 From 9407024116065181bc5095b1938da4762bcba43c Mon Sep 17 00:00:00 2001 From: Sabrina Date: Wed, 7 May 2025 11:27:18 +0200 Subject: [PATCH 6/7] Organisation : cellOT sj branch --- .../Complex_generative/cellOT_v1/Cellot_v1_sj.md | 0 .../Complex_generative/cellOT_v1/cellot_eval_v3_ood.py | 0 .../Complex_generative/cellOT_v1/cellot_train_v3_ood.py | 0 .../Complex_generative/cellOT_v1/source_modif/cell.py | 0 .../Complex_generative/cellOT_v1/source_modif/cellot.py | 0 .../Complex_generative/cellOT_v1/source_modif/icnns.py | 0 .../Complex_generative/cellOT_v1/source_modif/train.py | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename Cellot_v1_sj.md => tools/Complex_generative/cellOT_v1/Cellot_v1_sj.md (100%) rename cellot_eval_v3_ood.py => tools/Complex_generative/cellOT_v1/cellot_eval_v3_ood.py (100%) rename cellot_train_v3_ood.py => tools/Complex_generative/cellOT_v1/cellot_train_v3_ood.py (100%) rename cell.py => tools/Complex_generative/cellOT_v1/source_modif/cell.py (100%) rename cellot.py => tools/Complex_generative/cellOT_v1/source_modif/cellot.py (100%) rename icnns.py => tools/Complex_generative/cellOT_v1/source_modif/icnns.py (100%) rename train.py => tools/Complex_generative/cellOT_v1/source_modif/train.py (100%) diff --git a/Cellot_v1_sj.md b/tools/Complex_generative/cellOT_v1/Cellot_v1_sj.md similarity index 100% rename from Cellot_v1_sj.md rename to tools/Complex_generative/cellOT_v1/Cellot_v1_sj.md diff --git a/cellot_eval_v3_ood.py b/tools/Complex_generative/cellOT_v1/cellot_eval_v3_ood.py similarity index 100% rename from cellot_eval_v3_ood.py rename to tools/Complex_generative/cellOT_v1/cellot_eval_v3_ood.py diff --git a/cellot_train_v3_ood.py b/tools/Complex_generative/cellOT_v1/cellot_train_v3_ood.py similarity index 100% rename from cellot_train_v3_ood.py rename to tools/Complex_generative/cellOT_v1/cellot_train_v3_ood.py diff --git a/cell.py b/tools/Complex_generative/cellOT_v1/source_modif/cell.py similarity index 100% rename from cell.py rename to tools/Complex_generative/cellOT_v1/source_modif/cell.py diff --git a/cellot.py b/tools/Complex_generative/cellOT_v1/source_modif/cellot.py similarity index 100% rename from cellot.py rename to tools/Complex_generative/cellOT_v1/source_modif/cellot.py diff --git a/icnns.py b/tools/Complex_generative/cellOT_v1/source_modif/icnns.py similarity index 100% rename from icnns.py rename to tools/Complex_generative/cellOT_v1/source_modif/icnns.py diff --git a/train.py b/tools/Complex_generative/cellOT_v1/source_modif/train.py similarity index 100% rename from train.py rename to tools/Complex_generative/cellOT_v1/source_modif/train.py From 74ba479d04ae0a1c1ad3d459efc4ac563a006afd Mon Sep 17 00:00:00 2001 From: Sabrina Date: Wed, 7 May 2025 11:32:27 +0200 Subject: [PATCH 7/7] Remove empty file --- cd | 0 git | 0 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 cd delete mode 100644 git diff --git a/cd b/cd deleted file mode 100644 index e69de29..0000000 diff --git a/git b/git deleted file mode 100644 index e69de29..0000000