diff --git a/deepprofiler/__main__.py b/deepprofiler/__main__.py index 8d94086..552f743 100644 --- a/deepprofiler/__main__.py +++ b/deepprofiler/__main__.py @@ -1,116 +1,85 @@ -"""Command-line interface for DeepProfiler. - -Four subcommands are available, intended to be run in order: - -1. ``setup`` — create the project directory structure under ``--root``. -2. ``prepare`` — compute per-plate illumination statistics and compress images - to 8-bit PNG (optional but recommended for large datasets). -3. ``profile`` — extract per-cell deep learning features using the Cell - Painting CNN v1 checkpoint and write ``.npz`` files. -4. ``split`` — split the metadata index into N parts for parallel profiling - across multiple machines or jobs. - -Typical usage:: - - deepprofiler --root=/data/project --config=config.json --exp=run1 profile - -See README.md and the DeepProfiler Handbook for full configuration details. -""" - -import copy import json import os import click import deepprofiler.dataset.compression -import deepprofiler.dataset.illumination_statistics -import deepprofiler.dataset.image_dataset import deepprofiler.dataset.indexing +import deepprofiler.dataset.illumination_statistics import deepprofiler.dataset.metadata import deepprofiler.dataset.utils -import deepprofiler.profiling +import deepprofiler.dataset.image_dataset +import deepprofiler.learning.training +import deepprofiler.learning.profiling +import deepprofiler.learning.optimization +import deepprofiler.download.normalize_bbbc021_metadata # Main interaction point @click.group() @click.option("--root", prompt="Root directory for DeepProfiler experiment", help="Root directory for DeepProfiler experiment", - type=click.Path(exists=True)) + type=click.Path("r")) @click.option("--config", default=None, - help="Path to existing config file (filename in project_root/inputs/config/)", - type=click.STRING) + help="Path to existing config file", + type=click.Path("r")) @click.option("--cores", default=0, - help="Number of CPU cores for parallel processing (all=0) for prepare command", + help="Number of CPU cores for parallel processing (all=0)", type=click.INT) -@click.option("--gpu", default="0", - help="GPU device id (the id can be checked with nvidia-smi)", - type=click.STRING) -@click.option("--exp", default="results", - help="Name of experiment, this folder will be created in project_root/outputs/", - type=click.STRING) -@click.option("--metadata", default='index.csv', - help="Metadata index filename in project_root/inputs/metadata/", - type=click.STRING) @click.pass_context -def cli(context, root, config, exp, cores, gpu, metadata): - """Configure paths and load the experiment config, then dispatch to a subcommand.""" +def cli(context, root, config, cores): dirs = { "root": root, - "locations": root + "/inputs/locations/", # TODO: use os.path.join() - "config": root + "/inputs/config/", - "images": root + "/inputs/images/", - "metadata": root + "/inputs/metadata/", - "intensities": root + "/outputs/intensities/", - "compressed_images": root + "/outputs/compressed/images/", - "results": root + "/outputs/" + exp + "/", - "checkpoints": root + "/outputs/" + exp + "/checkpoint/", - "logs": root + "/outputs/" + exp + "/logs/", - "summaries": root + "/outputs/" + exp + "/summaries/", - "features": root + "/outputs/" + exp + "/features/" + "locations": os.path.join(root, "inputs", "locations"), + "config": os.path.join(root, "inputs", "config"), + "images": os.path.join(root, "inputs", "images"), + "metadata": os.path.join(root, "inputs", "metadata"), + "preprocessed": os.path.join(root, "inputs", "preprocessed"), + "pretrained": os.path.join(root, "inputs", "pretrained"), + "intensities": os.path.join(root, "outputs", "intensities"), + "compressed_images": os.path.join(root, "outputs", "compressed", "images"), + "compressed_metadata": os.path.join(root, "outputs", "compressed", "metadata"), + "training": os.path.join(root, "outputs", "training"), + "checkpoints": os.path.join(root, "outputs", "training", "checkpoint"), + "logs": os.path.join(root, "outputs", "training", "logs"), + "summaries": os.path.join(root, "outputs", "training", "summaries"), + "features": os.path.join(root, "outputs", "features") } - if context.invoked_subcommand == 'setup': - context.obj["dirs"] = dirs - return + if config is not None: + + context.obj["config"] = {} + context.obj["config"]["paths"] = {} + context.obj["config"]["paths"]["config"] = config + dirs["config"] = os.path.dirname(os.path.abspath(config)) + else: + config = os.path.join(dirs["config"], "config.json") - config = dirs["config"] + "/" + config context.obj["cores"] = cores - context.obj["gpu"] = gpu - os.environ["CUDA_VISIBLE_DEVICES"] = gpu - # Load configuration file - if config is not None and os.path.isfile(config): + + if os.path.isfile(config): with open(config, "r") as f: params = json.load(f) - - # Override paths defined by user if "paths" in params.keys(): for key, value in dirs.items(): if key not in params["paths"].keys(): - params["paths"][key] = dirs[key] + params["paths"][key] = os.path.join(root, dirs[key]) else: - dirs[key] = params["paths"][key] - else: - params["paths"] = copy.deepcopy(dirs) + dirs[key] = os.path.join(root, params["paths"][key]) - if os.path.isdir(dirs["root"]): - for k in ["results", "checkpoints", "logs", "summaries", "features"]: - os.makedirs(dirs[k], exist_ok=True) + else: + params["paths"] = dirs - # Update references - params["experiment_name"] = exp - params["paths"]["index"] = params["paths"]["metadata"] + metadata + params["paths"]["index"] = os.path.join(root, params["paths"]["metadata"], "index.csv") context.obj["config"] = params - else: - raise Exception("Config does not exists; make sure that the file exists in /inputs/config/") - + process = deepprofiler.dataset.utils.Parallel(context.obj["config"], numProcs=context.obj["cores"]) + context.obj["process"] = process context.obj["dirs"] = dirs # Optional tool: Create the support file and folder structure in a root directory -@cli.command(help='initialize folder structure of DeepProfiler project') +@cli.command() @click.pass_context def setup(context): - """Create the project directory tree under the configured root.""" for path in context.obj["dirs"].values(): if not os.path.isdir(path): print("Creating directory: ", path) @@ -121,49 +90,85 @@ def setup(context): context.obj["config"]["paths"] = context.obj["dirs"] +# Optional tool: Download and prepare the BBBC021 dataset +@cli.command() +@click.pass_context +def download_bbbc021(context): + context.invoke(setup) + deepprofiler.download.normalize_bbbc021_metadata.normalize_bbbc021_metadata(context) + print("BBBC021 download and preparation complete!") + + # First tool: Compute illumination statistics and compress images -@cli.command(help='Run illumination correction and compression') +@cli.command() @click.pass_context def prepare(context): - """Compute per-plate illumination statistics and compress images to 8-bit PNG.""" metadata = deepprofiler.dataset.metadata.read_plates(context.obj["config"]["paths"]["index"]) - process = deepprofiler.dataset.utils.Parallel(context.obj["config"], numProcs=context.obj["cores"]) + process = context.obj["process"] process.compute(deepprofiler.dataset.illumination_statistics.calculate_statistics, metadata) print("Illumination complete!") - metadata = deepprofiler.dataset.metadata.read_plates( - context.obj["config"]["paths"]["index"]) # reinitialize generator + metadata = deepprofiler.dataset.metadata.read_plates(context.obj["config"]["paths"]["index"]) # reinitialize generator process.compute(deepprofiler.dataset.compression.compress_plate, metadata) + deepprofiler.dataset.indexing.write_compression_index(context.obj["config"]) + context.parent.obj["config"]["paths"]["index"] = os.path.join(context.obj["config"]["paths"]["compressed_metadata"], "compressed.csv") print("Compression complete!") -# Second tool: Profile cells and extract features -@cli.command(help='run feature extraction') +# Optional learning tool: Optimize the hyperparameters of a model +@cli.command() +@click.option("--epoch", default=1) +@click.option("--seed", default=None) +@click.pass_context +def optimize(context, epoch, seed): + if context.parent.obj["config"]["prepare"]["compression"]["implement"]: + context.parent.obj["config"]["paths"]["index"] = os.path.join(context.obj["config"]["paths"]["compressed_metadata"], "compressed.csv") + context.parent.obj["config"]["paths"]["images"] = context.obj["config"]["paths"]["compressed_images"] + metadata = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"]) + optim = deepprofiler.learning.optimization.Optimize(context.obj["config"], metadata, epoch, seed) + optim.optimize() + + +# Second tool: Train a network +@cli.command() +@click.option("--epoch", default=1) +@click.option("--seed", default=None) +@click.pass_context +def train(context, epoch, seed): + if context.parent.obj["config"]["prepare"]["compression"]["implement"]: + context.parent.obj["config"]["paths"]["index"] = os.path.join(context.obj["config"]["paths"]["compressed_metadata"], "compressed.csv") + context.parent.obj["config"]["paths"]["images"] = context.obj["config"]["paths"]["compressed_images"] + metadata = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"]) + deepprofiler.learning.training.learn_model(context.obj["config"], metadata, epoch, seed) + + +# Third tool: Profile cells and extract features +@cli.command() @click.pass_context @click.option("--part", help="Part of index to process", default=-1, type=click.INT) def profile(context, part): - """Extract per-cell deep learning features and write .npz files.""" if context.parent.obj["config"]["prepare"]["compression"]["implement"]: + context.parent.obj["config"]["paths"]["index"] = os.path.join(context.obj["config"]["paths"]["compressed_metadata"], "compressed.csv") context.parent.obj["config"]["paths"]["images"] = context.obj["config"]["paths"]["compressed_images"] config = context.obj["config"] if part >= 0: partfile = "index-{0:03d}.csv".format(part) config["paths"]["index"] = context.obj["config"]["paths"]["index"].replace("index.csv", partfile) - dset = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"], mode='profile') - deepprofiler.profiling.profile(context.obj["config"], dset) + metadata = deepprofiler.dataset.image_dataset.read_dataset(context.obj["config"]) + deepprofiler.learning.profiling.profile(context.obj["config"], metadata) # Auxiliary tool: Split index in multiple parts -@cli.command(help='split metadata into multiple parts') +@cli.command() @click.pass_context @click.option("--parts", help="Number of parts to split the index", type=click.INT) def split(context, parts): - """Split the metadata index into N parts for parallel profiling jobs.""" if context.parent.obj["config"]["prepare"]["compression"]["implement"]: + context.parent.obj["config"]["paths"]["index"] = os.path.join(context.obj["config"]["paths"]["compressed_metadata"], "compressed.csv") context.parent.obj["config"]["paths"]["images"] = context.obj["config"]["paths"]["compressed_images"] deepprofiler.dataset.indexing.split_index(context.obj["config"], parts) diff --git a/deepprofiler/dataset/image_dataset.py b/deepprofiler/dataset/image_dataset.py index 22a9647..a3da902 100644 --- a/deepprofiler/dataset/image_dataset.py +++ b/deepprofiler/dataset/image_dataset.py @@ -1,272 +1,84 @@ -"""Dataset abstraction over a metadata index and per-channel image files. - -:class:`ImageDataset` is the central data container used throughout -DeepProfiler. It wraps a :class:`~deepprofiler.dataset.metadata.Metadata` -object and provides two key interfaces: - -- :meth:`ImageDataset.get_image_paths` — resolve per-channel file paths for - one metadata record. -- :meth:`ImageDataset.scan` — iterate over images and call a processing - function (used by :func:`~deepprofiler.profiling.profile` to drive feature - extraction). - -:func:`read_dataset` is the standard factory that reads a config dict and -returns a fully initialised :class:`ImageDataset`. - -:class:`ImageLocations` is a helper used internally to load all cell centroid -CSVs in parallel before training — it is not used during profiling. -""" - import os - import numpy as np import pandas as pd -import deepprofiler.dataset.metadata import deepprofiler.dataset.pixels -import deepprofiler.dataset.target import deepprofiler.dataset.utils -import deepprofiler.imaging.boxes - - -class ImageLocations(object): - """Pre-load cell locations for a set of images in parallel. - - Collects image keys, paths, and target labels from a metadata DataFrame, - then uses :class:`~deepprofiler.dataset.utils.Parallel` to read all - location CSVs concurrently. Used by - :meth:`ImageDataset.prepare_training_locations`. - - Args: - metadata_training: DataFrame slice (e.g. ``Metadata.train``) to index. - getImagePaths: Callable ``(row) -> (key, paths, outlines)``. - targets: List of - :class:`~deepprofiler.dataset.target.MetadataColumnTarget` objects. - """ - - def __init__(self, metadata_training, getImagePaths, targets): - self.keys = [] - self.images = [] - self.targets = [] - self.outlines = [] - for i, r in metadata_training.iterrows(): - key, image, outl = getImagePaths(r) - self.keys.append(key) - self.images.append(image) - self.targets.append([t.get_values(r) for t in targets]) - self.outlines.append(outl) - print("Reading single-cell locations") - - def load_loc(self, params): - """Load the locations CSV for one image (worker function for Parallel). - - Args: - params: ``[index, config]`` as passed by - :class:`~deepprofiler.dataset.utils.Parallel`. - - Returns: - DataFrame with centroid coordinates plus ``ID``, ``ImageKey``, - ``ImagePaths``, ``Target``, and ``Outlines`` columns appended. - """ - i, config = params - loc = deepprofiler.imaging.boxes.get_locations(self.keys[i], config) - loc["ID"] = loc.index - loc["ImageKey"] = self.keys[i] - loc["ImagePaths"] = "#".join(self.images[i]) - loc["Target"] = self.targets[i][0] - loc["Outlines"] = self.outlines[i] - print("Image", i, ":", len(loc), "cells", end="\r") - return loc - - def load_locations(self, config): - """Load all location CSVs in parallel and return a list of DataFrames. - - Args: - config: Experiment configuration dict (used for worker count and - passed through to :func:`~deepprofiler.imaging.boxes.get_locations`). - - Returns: - List of DataFrames, one per image. - """ - process = deepprofiler.dataset.utils.Parallel(config, numProcs=config["train"]["sampling"]["workers"]) - data = process.compute(self.load_loc, [x for x in range(len(self.keys))]) - process.close() - return data +import deepprofiler.dataset.metadata +import deepprofiler.dataset.target class ImageDataset(): - """Container for a metadata index and its associated image files. - - Provides path resolution, per-image scanning, and (for the training path) - location pre-loading and batch sampling. During profiling only - :meth:`get_image_paths`, :meth:`scan`, :meth:`add_target`, and - :meth:`number_of_records` are used. - Args: - metadata: :class:`~deepprofiler.dataset.metadata.Metadata` object. - sampling_field: Metadata column used as the classification label - (e.g. ``"Class"``). - channels: List of metadata column names, one per imaging channel, - whose values are filenames relative to ``dataRoot``. - dataRoot: Root directory containing image files. - keyGen: Callable ``(row) -> str`` that produces the image key used to - look up location CSVs (typically - ``"{Metadata_Plate}/{Metadata_Well}-{Metadata_Site}"``). - config: Full experiment configuration dict. - """ - - def __init__(self, metadata, sampling_field, channels, dataRoot, keyGen, config): - self.meta = metadata - self.channels = channels - self.root = dataRoot - self.keyGen = keyGen - self.sampling_field = sampling_field + def __init__(self, metadata, sampling_field, channels, dataRoot, keyGen): + self.meta = metadata # Metadata object with a valid dataframe + self.channels = channels # List of column names corresponding to each channel file + self.root = dataRoot # Path to the directory of images + self.keyGen = keyGen # Function that returns the image key given its record in the metadata + self.sampling_field = sampling_field # Field in the metadata used to sample images evenly self.sampling_values = metadata.data[sampling_field].unique() self.targets = [] self.outlines = None - self.config = config - - def get_image_paths(self, r): - """Resolve per-channel file paths and the image key for one metadata row. - - If a channel value is already an absolute directory path it is used - as-is; otherwise the filename is joined to ``self.root``. - Args: - r: A row from ``self.meta.data`` (Pandas Series or dict-like). - - Returns: - Tuple ``(key, image_paths, outlines)`` where ``key`` is the string - identifier (e.g. ``"Plate1/A01-1"``), ``image_paths`` is a list - of resolved file paths (one per channel), and ``outlines`` is - either ``None`` or the path to the outline image for this site. - """ + def getImagePaths(self, r): key = self.keyGen(r) - list_images = [r[ch] for ch in self.channels] - paths = [(os.path.split(r[ch]))[0] for ch in self.channels] - image = [list_images[ch] if os.path.isdir(paths[ch]) else self.root + "/" + list_images[ch] for ch in range(len(paths))] + image = [os.path.join(self.root, r[ch]) for ch in self.channels] outlines = self.outlines if outlines is not None: outlines = self.outlines + r["Outlines"] return (key, image, outlines) - def prepare_training_locations(self): - # Load single cell locations in one data frame - image_loc = ImageLocations(self.meta.train, self.get_image_paths, self.targets) - locations = image_loc.load_locations(self.config) - locations = pd.concat(locations) - - # Group by image and count the number of single cells per image in the column ID - self.training_images = locations.groupby(["ImageKey", "Target"])["ID"].count().reset_index() - - workers = self.config["train"]["sampling"]["workers"] - batch_size = self.config["train"]["model"]["params"]["batch_size"] - cache_size = self.config["train"]["sampling"]["cache_size"] - self.sampling_factor = self.config["train"]["sampling"]["factor"] - - # Count the total number of single cells - self.total_single_cells = len(locations) - # Median number of images per class - self.sample_images = int(np.median(self.training_images.groupby("Target").count()["ID"])) - # Number of classes - targets = len(self.training_images["Target"].unique()) - self.config["num_classes"] = targets - # Median number of single cells per image (column ID has counts as a result of groupby above) - self.sample_locations = int(np.median(self.training_images["ID"])) - # Set the target of single cells per epoch asuming a balanced set - self.cells_per_epoch = int(targets * self.sample_images * self.sample_locations * self.sampling_factor) - # Number of images that each worker should load at a time - self.images_per_worker = int(batch_size / workers) - # Percent of all cells that will be loaded in memory at a given moment in the queue - self.cache_coverage = 100*(cache_size / self.cells_per_epoch) - # Number of gradient updates required to approximately use all cells in an epoch - self.steps_per_epoch = int(self.cells_per_epoch / batch_size) - - self.data_rotation = 0 - self.cache_records = 0 - self.shuffle_training_images() - - - def show_setup(self): - print(" || => Total single cells:", self.total_single_cells) - print(" || => Median # of images per class:", self.sample_images) - print(" || => Number of classes:", len(self.training_images["Target"].unique())) - print(" || => Median # of cells per image:", self.sample_locations) - print(" || => Approx. cells per epoch (with balanced sampling):", self.cells_per_epoch) - print(" || => Images sampled per worker:", self.images_per_worker) - print(" || => Cache data coverage: {}%".format(int(self.cache_coverage))) - print(" || => Steps per epoch:", self.steps_per_epoch) - - - def show_stats(self): ## Deprecated? - # Proportion of images loaded by workers from all images that they should load in one epoch (recall) - worker_efficiency = int(100 * (self.data_rotation / self.training_sample.shape[0])) - # Proportion of single cells placed in the cache from all those that should be used in one epoch - cache_usage = int(100 * self.cache_records / self.cells_per_epoch) - #print("Training set coverage: {}% (worker efficiency). Data rotation: {}% (cache usage).".format( - # worker_efficiency, - # cache_usage) - #) - self.data_rotation = 0 - self.cache_records = 0 - return {'worker_efficiency': worker_efficiency, 'cache_usage': cache_usage} - - def shuffle_training_images(self): - # Images in the original metadata file are resampled at each epoch - sample = [] - for c in self.meta.train[self.sampling_field].unique(): - # Sample the same number of images per class. Oversample if the class has less images than needed + def sampleImages(self, sampling_values, nImgCat): + keys = [] + images = [] + targets = [] + outlines = [] + for c in sampling_values: mask = self.meta.train[self.sampling_field] == c - available = self.meta.train[mask].shape[0] - rec = self.meta.train[mask].sample(n=self.sample_images, replace=available < self.sample_images) - sample.append(rec) - - # Shuffle and restart pointers. Note that training sample has images instead of single cells. - self.training_sample = pd.concat(sample) - self.training_sample = self.training_sample.sample(frac=1.0).reset_index(drop=True) - self.batch_pointer = 0 - - def get_train_batch(self, lock): - # Select the next group of available images for cropping - lock.acquire() - df = self.training_sample[self.batch_pointer:self.batch_pointer + self.images_per_worker].copy() - self.batch_pointer += self.images_per_worker - self.data_rotation += self.images_per_worker - if self.batch_pointer > self.training_sample.shape[0]: - self.shuffle_training_images() - lock.release() - - # Prepare the batch and cropping information for these images - batch = {"keys": [], "images": [], "targets": [], "locations": []} - sample = max(1, int(self.sample_locations * self.sampling_factor)) - for k, r in df.iterrows(): - key, image, outl = self.get_image_paths(r) - batch["keys"].append(key) - batch["targets"].append([t.get_values(r) for t in self.targets]) - batch["images"].append(deepprofiler.dataset.pixels.openImage(image, outl)) - batch["locations"].append(deepprofiler.imaging.boxes.get_locations(key, self.config, random_sample=sample)) + rec = self.meta.train[mask].sample(n=nImgCat, replace=True) + for i, r in rec.iterrows(): + key, image, outl = self.getImagePaths(r) + keys.append(key) + images.append(image) + targets.append([t.get_values(r) for t in self.targets]) + outlines.append(outl) + return keys, images, targets, outlines + + def getTrainBatch(self, N): + #s = deepprofiler.dataset.utils.tic() + # Batch size is N + values = self.sampling_values.copy() + # 1. Sample categories + if len(values) > N: + np.random.shuffle(values) + values = values[0:N] + + # 2. Define images per category + nImgCat = int(N / len(values)) + residual = N % len(values) + + # 3. Select images per category + keys, images, targets, outlines = self.sampleImages(values, nImgCat) + if residual > 0: + np.random.shuffle(values) + rk, ri, rl, ro = self.sampleImages(values[0:residual], 1) + keys += rk + images += ri + targets += rl + outlines += ro + + # 4. Open images + batch = {"keys": keys, "images": [], "targets": targets} + for i in range(len(images)): + image_array = deepprofiler.dataset.pixels.openImage(images[i], outlines[i]) + # TODO: Implement pixel normalization using control statistics + #image_array -= 128.0 + batch["images"].append(image_array) + #dataset.utils.toc("Loading batch", s) return batch def scan(self, f, frame="train", check=lambda k: True): - """Iterate over images and call ``f`` for each one that passes ``check``. - - This is the primary driver for both profiling and compression. Images - are loaded sequentially (not in parallel) using - :func:`~deepprofiler.dataset.pixels.openImage`. - - Args: - f: Callable ``(index, image_array, meta_row)`` invoked for each - image. ``image_array`` is a ``(H, W, C)`` numpy array. - frame: Which subset to iterate: ``"all"`` for the full metadata, - ``"val"`` for the validation split, or ``"train"`` (default) - for the training split. - check: Optional predicate ``(meta_row) -> bool``. Images for - which ``check`` returns ``False`` are skipped. Defaults to - always returning ``True``. Used by - :meth:`~deepprofiler.profiling.Profile.check` to skip - already-profiled images. - """ if frame == "all": frame = self.meta.data.iterrows() elif frame == "val": @@ -274,8 +86,9 @@ def scan(self, f, frame="train", check=lambda k: True): else: frame = self.meta.train.iterrows() - images = [(i, self.get_image_paths(r), r) for i, r in frame] + images = [(i, self.getImagePaths(r), r) for i, r in frame] for img in images: + # img => [0] index key, [1] => [0:key, 1:paths, 2:outlines], [2] => metadata index = img[0] meta = img[2] if check(meta): @@ -284,14 +97,6 @@ def scan(self, f, frame="train", check=lambda k: True): return def number_of_records(self, dataset): - """Return the number of rows in the requested split. - - Args: - dataset: ``"all"``, ``"train"``, or ``"val"``. - - Returns: - Integer row count, or 0 for an unrecognised ``dataset`` value. - """ if dataset == "all": return len(self.meta.data) elif dataset == "val": @@ -302,73 +107,45 @@ def number_of_records(self, dataset): return 0 def add_target(self, new_target): - """Append a :class:`~deepprofiler.dataset.target.MetadataColumnTarget` to ``self.targets``.""" self.targets.append(new_target) - -def read_dataset(config, mode='train'): - """Build an :class:`ImageDataset` from a config dict. - - Reads the metadata index CSV, optionally replaces ``.tif``/``.tiff`` - extensions with ``.png`` if image compression was applied, merges outline - CSVs if specified, adds classification targets, and (for the training path) - pre-loads all cell locations. - - Args: - config: Experiment configuration dict. Must contain at minimum - ``paths.index``, ``dataset``, ``train.partition``, and - ``prepare.compression`` sections. - mode: ``"train"`` to split metadata and pre-load locations, or any - other value (e.g. ``"profile"``) to skip those steps. - - Returns: - Fully initialised :class:`ImageDataset`. - """ +def read_dataset(config): + # Read metadata and split dataset in training and validation metadata = deepprofiler.dataset.metadata.Metadata(config["paths"]["index"], dtype=None) - if config["prepare"]["compression"]["implement"]: - metadata.data.replace({'.tiff': '.png', '.tif': '.png'}, inplace=True, regex=True) # Add outlines if specified outlines = None if "outlines" in config["prepare"].keys() and config["prepare"]["outlines"] != "": - df = pd.read_csv(config["paths"]["metadata"] + "/outlines.csv") + df = pd.read_csv(os.path.join(config["paths"]["metadata"], "outlines.csv")) metadata.mergeOutlines(df) - outlines = config["paths"]["root"] + "inputs/outlines/" + outlines = os.path.join(config["paths"]["root"], "inputs", "outlines") print(metadata.data.info()) # Split training data - if mode == 'train' and config["train"]["model"]["crop_generator"] == 'crop_generator': - split_field = config["train"]["partition"]["split_field"] - trainingFilter = lambda df: df[split_field].isin(config["train"]["partition"]["training"]) - validationFilter = lambda df: df[split_field].isin(config["train"]["partition"]["validation"]) - metadata.splitMetadata(trainingFilter, validationFilter) - + split_field = config["train"]["dset"]["split_field"] + trainingFilter = lambda df: df[split_field].isin(config["train"]["dset"]["training_values"]) + validationFilter = lambda df: df[split_field].isin(config["train"]["dset"]["validation_values"]) + metadata.splitMetadata(trainingFilter, validationFilter) # Create a dataset - keyGen = lambda r: "{}/{}-{}".format(r["Metadata_Plate"], r["Metadata_Well"], r["Metadata_Site"]) + keyGen = lambda r: os.path.join(r["Metadata_Plate"], "{}-{}".format(r["Metadata_Well"], r["Metadata_Site"])) + dset = ImageDataset( - metadata, - config["dataset"]["metadata"]["label_field"], - config["dataset"]["images"]["channels"], - config["paths"]["images"], - keyGen, - config + metadata=metadata, + sampling_field=config["train"]["sampling"]["field"], + channels=config["dataset"]["images"]["channels"], + dataRoot=config["paths"]["images"], + keyGen=keyGen ) # Add training targets - for t in config["train"]["partition"]["targets"]: + for t in config["train"]["dset"]["targets"]: new_target = deepprofiler.dataset.target.MetadataColumnTarget(t, metadata.data[t].unique()) dset.add_target(new_target) # Activate outlines for masking if needed - if config["dataset"]["locations"]["mask_objects"]: + if config["train"]["dset"]["mask_objects"]: dset.outlines = outlines - # For training with sampled_crop_generator, no need to read locations again. - if mode == 'train' and config["train"]["model"]["crop_generator"] == 'crop_generator': - dset.prepare_training_locations() - return dset - - diff --git a/deepprofiler/imaging/boxes.py b/deepprofiler/imaging/boxes.py index d0180cc..9dc6899 100644 --- a/deepprofiler/imaging/boxes.py +++ b/deepprofiler/imaging/boxes.py @@ -1,196 +1,43 @@ -"""Cell bounding box construction from centroid CSV files. - -DeepProfiler does not perform cell segmentation. Cell locations must be -provided externally (e.g. from CellProfiler) as CSV files stored under -``config["paths"]["locations"]`` with the naming convention:: - - {plate}/{well}-{site}-Nuclei.csv - -Each CSV must contain at minimum two columns: - -- ``Nuclei_Location_Center_X`` — centroid X coordinate in pixels -- ``Nuclei_Location_Center_Y`` — centroid Y coordinate in pixels - -:func:`get_locations` reads the appropriate CSV and returns a DataFrame of -centroids. :func:`prepare_boxes` converts those centroids into normalised -``[y1, x1, y2, x2]`` bounding boxes suitable for -``tf.image.crop_and_resize``. -""" - import os import numpy as np import pandas as pd -X_KEY = "Nuclei_Location_Center_X" -Y_KEY = "Nuclei_Location_Center_Y" +################################################# +## BOUNDING BOX HANDLING +################################################# -def get_locations(image_key, config, random_sample=None, seed=None): - """Return cell centroid locations for one image. - - Dispatches to :func:`get_single_cell_locations` (``mode: single_cells``) - or :func:`get_full_image_locations` (``mode: full_image``) based on - ``config["dataset"]["locations"]["mode"]``. - - Args: - image_key: String key of the form ``"{plate}/{well}-{site}"`` that - identifies the image within the dataset. - config: Experiment configuration dict. - random_sample: If not ``None``, randomly sample this many locations. - seed: Random seed for reproducible sampling. - - Returns: - DataFrame with at least ``Nuclei_Location_Center_X`` and - ``Nuclei_Location_Center_Y`` columns, or an empty DataFrame if the - locations file is missing. - """ - if config["dataset"]["locations"]["mode"] == "single_cells": - return get_single_cell_locations(image_key, config, random_sample, seed) - elif config["dataset"]["locations"]["mode"] == "full_image": - return get_full_image_locations(image_key, config, random_sample, seed) - else: - return None - - -def get_single_cell_locations(image_key, config, random_sample=None, seed=None): - """Read per-cell centroids from the locations CSV for one image. - - Constructs the CSV path as:: - - {config["paths"]["locations"]}/{plate}/{well}-{site}-Nuclei.csv - - Returns an empty DataFrame (with the expected columns) if the file does - not exist, so the caller can safely check ``len(locations) == 0``. - - Args: - image_key: String key ``"{plate}/{well}-{site}"``. - config: Experiment configuration dict. - random_sample: If not ``None`` and smaller than the number of cells, - randomly sample this many rows. - seed: Random seed for reproducible sampling. - - Returns: - DataFrame of centroid coordinates. - """ +def get_locations(image_key, config, randomize=True, seed=None): keys = image_key.split("/") - locations_file = "{}/{}-{}.csv".format(keys[0], keys[1], "Nuclei") - locations_path = os.path.join(config["paths"]["locations"], locations_file) + locations_file = os.path.join(keys[0], "{}-{}.csv".format( + keys[1], + config["train"]["sampling"]["locations_field"] + )) + locations_path = os.path.join(config["paths"]["root"], + config["paths"]["locations"], + locations_file) if os.path.exists(locations_path): locations = pd.read_csv(locations_path) - if random_sample is not None and random_sample < len(locations): + random_sample = config["train"]["sampling"]["locations"] + if randomize and random_sample is not None and random_sample < len(locations): return locations.sample(random_sample, random_state=seed) else: return locations else: - return pd.DataFrame(columns=[X_KEY, Y_KEY]) - - -def get_full_image_locations(image_key, config, random_sample, seed): - """Generate a regular grid (or random sample) of crop centres for one image. + y_key = config["train"]["sampling"]["locations_field"] + "_Location_Center_Y" + x_key = config["train"]["sampling"]["locations_field"] + "_Location_Center_X" + return pd.DataFrame(columns=[x_key, y_key]) - Used when ``config["dataset"]["locations"]["mode"]`` is ``"full_image"``. - If the view covers the whole image a single centre point is returned. - Otherwise a grid of non-overlapping ``view_size × view_size`` tiles is - produced (or a random sample of that many centres when ``random_sample`` - is set). - Args: - image_key: Ignored — grid is derived from config dimensions. - config: Experiment configuration dict. Uses ``dataset.images.width``, - ``dataset.images.height``, and ``dataset.locations.view_size``. - random_sample: Number of random centres to generate. Pass ``None`` - for a deterministic grid. - seed: Unused (grid generation is deterministic or uses numpy default - RNG). - - Returns: - DataFrame with ``Nuclei_Location_Center_X`` and - ``Nuclei_Location_Center_Y`` columns. - """ - cols = config["dataset"]["images"]["width"] - rows = config["dataset"]["images"]["height"] - view = config["dataset"]["locations"]["view_size"] - assert (view <= cols) and (view <= rows) - cols_margin = cols - view - rows_margin = rows - view - - data = None - if view == cols: - data = [[cols / 2, rows / 2]] - else: - if random_sample is not None: - cols_pos = np.random.randint(low=-cols_margin / 2, high=cols_margin / 2, size=random_sample) + cols / 2 - rows_pos = np.random.randint(low=-rows_margin / 2, high=rows_margin / 2, size=random_sample) + rows / 2 - data = [[cols_pos[i], rows_pos[i]] for i in range(random_sample)] - elif random_sample is None: - cols_pos = np.linspace(view / 2, cols - view / 2, int(np.ceil(cols / view))) - rows_pos = np.linspace(view / 2, rows - view / 2, int(np.ceil(rows / view))) - grid = np.meshgrid(rows_pos, cols_pos) - rows_pos = grid[0].flatten() - cols_pos = grid[1].flatten() - data = [[rows_pos[i], cols_pos[i]] for i in range(len(cols_pos))] - - return pd.DataFrame(data=data, columns=[X_KEY, Y_KEY]) +def load_batch(dataset, config): + batch = dataset.getTrainBatch(config["train"]["sampling"]["images"]) + batch["locations"] = [ get_locations(x, config) for x in batch["keys"] ] + return batch def prepare_boxes(batch, config): - """Convert centroid locations to normalised bounding boxes for crop_and_resize. - - Dispatches to :func:`get_cropping_regions` with ``box_size`` set to - ``dataset.locations.box_size`` (single-cell mode) or - ``dataset.locations.view_size`` (full-image mode). - - Args: - batch: Dict with keys ``"images"``, ``"locations"``, and ``"targets"``. - ``locations`` is a list of DataFrames (one per image in the batch). - config: Experiment configuration dict. - - Returns: - Tuple ``(boxes, box_ind, targets, masks)`` ready to feed into the TF1 - crop graph placeholders. - """ - if config["dataset"]["locations"]["mode"] == "single_cells": - return get_cropping_regions(batch, config, config["dataset"]["locations"]["box_size"]) - elif config["dataset"]["locations"]["mode"] == "full_image": - view = config["dataset"]["locations"]["view_size"] - return get_cropping_regions(batch, config, view) - else: - return None - - -def get_cropping_regions(batch, config, box_size): - """Build normalised bounding boxes from centroid coordinates. - - For each centroid ``(x, y)``, computes a square bounding box:: - - [y - box_size/2, x - box_size/2, y + box_size/2, x + box_size/2] - - Coordinates are normalised to ``[0, 1]`` by dividing Y by image height - and X by image width, as required by ``tf.image.crop_and_resize``. Crops - that extend beyond the image boundary are automatically zero-padded by TF. - - Also reads the object mask label (the pixel value of the last channel at - the cell centroid) when ``mask_objects`` is enabled in config. - - Args: - batch: Dict with ``"images"``, ``"locations"``, and ``"targets"``. - config: Experiment configuration dict. - box_size: Side length of the bounding box in pixels. - - Returns: - Tuple of four arrays: - - - ``boxes``: float32 array of shape ``(total_cells, 4)`` with - normalised ``[y1, x1, y2, x2]`` coordinates. - - ``box_ind``: int32 array of shape ``(total_cells,)`` mapping each - box to its image index in the batch. - - ``targets``: list of int32 arrays, one per target, each of shape - ``(total_cells,)``. - - ``masks``: int32 array of shape ``(total_cells,)`` with the object - mask label for each cell (0 when masking is disabled). - """ - locations_batch = batch["locations"] + locationsBatch = batch["locations"] image_targets = batch["targets"] images = batch["images"] all_boxes = [] @@ -198,41 +45,41 @@ def get_cropping_regions(batch, config, box_size): all_targets = [[] for i in range(len(image_targets[0]))] all_masks = [] index = 0 - - for locations in locations_batch: + y_key = config["train"]["sampling"]["locations_field"] + "_Location_Center_Y" + x_key = config["train"]["sampling"]["locations_field"] + "_Location_Center_X" + for locations in locationsBatch: + # Collect and normalize boxes between 0 and 1 boxes = np.zeros((len(locations), 4), np.float32) - boxes[:, 0] = locations[Y_KEY] - box_size / 2 - boxes[:, 1] = locations[X_KEY] - box_size / 2 - boxes[:, 2] = locations[Y_KEY] + box_size / 2 - boxes[:, 3] = locations[X_KEY] + box_size / 2 - boxes[:, [0, 2]] /= config["dataset"]["images"]["height"] - boxes[:, [1, 3]] /= config["dataset"]["images"]["width"] - + boxes[:,0] = locations[y_key] - config["train"]["sampling"]["box_size"]/2 + boxes[:,1] = locations[x_key] - config["train"]["sampling"]["box_size"]/2 + boxes[:,2] = locations[y_key] + config["train"]["sampling"]["box_size"]/2 + boxes[:,3] = locations[x_key] + config["train"]["sampling"]["box_size"]/2 + boxes[:,[0,2]] /= config["train"]["dset"]["height"] + boxes[:,[1,3]] /= config["train"]["dset"]["width"] + # Create indicators for this set of boxes, belonging to the same image box_ind = index * np.ones((len(locations)), np.int32) - + # Propage the same labels to all crops for i in range(len(image_targets[index])): all_targets[i].append(image_targets[index][i] * np.ones((len(locations)), np.int32)) - + # Identify object mask for each crop masks = np.zeros(len(locations), np.int32) - if config["dataset"]["locations"]["mask_objects"]: + if config["train"]["dset"]["mask_objects"]: i = 0 for lkey in locations.index: - y = int(locations.loc[lkey, Y_KEY]) - x = int(locations.loc[lkey, X_KEY]) - patch = images[index][max(y - 5, 0):y + 5, max(x - 5, 0):x + 5, -1] + y = int(locations.loc[lkey, y_key]) + x = int(locations.loc[lkey, x_key]) + patch = images[index][max(y-5,0):y+5, max(x-5,0):x+5, -1] if np.size(patch) > 0: masks[i] = int(np.median(patch)) i += 1 - + # Pile up the resulting variables all_boxes.append(boxes) all_indices.append(box_ind) all_masks.append(masks) index += 1 - result = ( - np.concatenate(all_boxes), - np.concatenate(all_indices), - [np.concatenate(t) for t in all_targets], - np.concatenate(all_masks), - ) + result = (np.concatenate(all_boxes), + np.concatenate(all_indices), + [np.concatenate(t) for t in all_targets], + np.concatenate(all_masks)) return result