Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 141 additions & 1 deletion tests/test_cell_annotation_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

from __future__ import annotations

import json
import importlib.util
import json
import inspect
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from ueler.viewer.interfaces import FlowsomParamsProvider, HeatmapStateProvider, SelectionSpec
from ueler.viewer.plugin.cell_annotation.manifest import Manifest
from ueler.viewer.plugin.cell_annotation.plugin import CellAnnotationPlugin, _flag_enabled
from ueler.viewer.plugin.cell_annotation.selection_spec import MaterializedSelectionSpec
Expand Down Expand Up @@ -214,6 +217,31 @@ def test_union_requires_matching_dataset(self):
left.union(right)


class TestCrossPluginInterfaces(unittest.TestCase):
def test_compatibility_interfaces_are_importable(self):
compatibility_module = importlib.import_module("viewer.interfaces")

self.assertIs(compatibility_module.SelectionSpec, SelectionSpec)
self.assertTrue(hasattr(SelectionSpec, "cardinality"))
self.assertTrue(hasattr(HeatmapStateProvider, "export_heatmap_state"))
self.assertTrue(hasattr(FlowsomParamsProvider, "run_flowsom"))

def test_flowsom_protocol_exposes_expected_signature(self):
signature = inspect.signature(FlowsomParamsProvider.run_flowsom)
self.assertEqual(
list(signature.parameters),
[
"self",
"selection",
"params",
"training_markers",
"extra_markers",
"imputation",
"projection",
],
)


class TestFeatureFlagAndPluginLifecycle(unittest.TestCase):
def test_flag_defaults_to_disabled(self):
with patch.dict(os.environ, {}, clear=True):
Expand Down Expand Up @@ -287,8 +315,73 @@ def test_plugin_rebuilds_manifest_when_missing(self):
)
self.assertTrue(plugin.manifest.path.exists())

def test_register_loaded_providers_logs_each_provider_once(self):
plugin = CellAnnotationPlugin(MagicMock())
heatmap = MagicMock(name="heatmap")
flowsom = MagicMock(name="flowsom")
side_plots = SimpleNamespace(
heatmap_output=SimpleNamespace(_register_cell_annotation_provider=lambda: plugin.register_heatmap(heatmap)),
flowsom_output=SimpleNamespace(_register_cell_annotation_provider=lambda: plugin.register_flowsom(flowsom)),
)

with self.assertLogs("ueler.viewer.plugin.cell_annotation.plugin", level="INFO") as logs:
plugin.register_loaded_providers(side_plots)
plugin.register_loaded_providers(side_plots)

self.assertIs(plugin.heatmap_provider, heatmap)
self.assertIs(plugin.flowsom_provider, flowsom)
self.assertEqual(sum("registered Heatmap provider" in message for message in logs.output), 1)
self.assertEqual(sum("registered FlowSOM provider" in message for message in logs.output), 1)

def test_register_loaded_providers_allows_missing_flowsom(self):
plugin = CellAnnotationPlugin(MagicMock())
heatmap = MagicMock(name="heatmap")
side_plots = SimpleNamespace(
heatmap_output=SimpleNamespace(_register_cell_annotation_provider=lambda: plugin.register_heatmap(heatmap)),
other_output=object(),
)

plugin.register_loaded_providers(side_plots)

self.assertIs(plugin.heatmap_provider, heatmap)
self.assertIsNone(plugin.flowsom_provider)


class TestProviderStubMethods(unittest.TestCase):
def test_heatmap_provider_registration_calls_cell_annotation_hook(self):
heatmap = types.SimpleNamespace()
register_heatmap = MagicMock()
heatmap.main_viewer = types.SimpleNamespace(
cell_annotation_plugin=types.SimpleNamespace(register_heatmap=register_heatmap)
)

heatmap_stubs = {
"ipywidgets": _widget_module(),
"pandas": types.ModuleType("pandas"),
"scipy.cluster.hierarchy": types.SimpleNamespace(dendrogram=lambda *_a, **_k: None),
"ueler.viewer.observable": types.SimpleNamespace(Observable=object),
"ueler.viewer.plugin.plugin_base": types.SimpleNamespace(
PluginBase=type("PluginBase", (), {"__init__": lambda self, *_args, **_kwargs: None})
),
"ueler.viewer.plugin.heatmap_adapter": types.SimpleNamespace(
HeatmapModeAdapter=type("HeatmapModeAdapter", (), {"__init__": lambda self, *_args, **_kwargs: None})
),
"ueler.viewer.plugin.heatmap_layers": types.SimpleNamespace(
DataLayer=type("DataLayer", (), {}),
InteractionLayer=type("InteractionLayer", (), {}),
DisplayLayer=type("DisplayLayer", (), {}),
),
}
module = _load_module_from_file(
"test_heatmap_module_registration",
REPO_ROOT / "ueler/viewer/plugin/heatmap.py",
heatmap_stubs,
)

module.HeatmapDisplay._register_cell_annotation_provider(heatmap)

register_heatmap.assert_called_once_with(heatmap)

def test_heatmap_import_stub_records_last_path(self):
heatmap_stubs = {
"ipywidgets": _widget_module(),
Expand Down Expand Up @@ -361,6 +454,53 @@ def test_flowsom_selection_context_stub_is_stored(self):

self.assertIs(flowsom._selection_context, selection)

def test_flowsom_provider_registration_calls_cell_annotation_hook(self):
flowsom = types.SimpleNamespace()
register_flowsom = MagicMock()
flowsom.main_viewer = types.SimpleNamespace(
cell_annotation_plugin=types.SimpleNamespace(register_flowsom=register_flowsom)
)

numpy_stub = types.ModuleType("numpy")
numpy_stub.inf = float("inf")
flowsom_stubs = {
"numpy": numpy_stub,
"pandas": types.ModuleType("pandas"),
"seaborn": types.ModuleType("seaborn"),
"ipywidgets": _widget_module(),
"matplotlib.font_manager": types.ModuleType("matplotlib.font_manager"),
"matplotlib.pyplot": types.ModuleType("matplotlib.pyplot"),
"matplotlib.backend_bases": types.SimpleNamespace(MouseButton=object),
"matplotlib.text": types.SimpleNamespace(Annotation=object),
"IPython.display": types.SimpleNamespace(display=lambda *_a, **_k: None),
"mpl_toolkits.axes_grid1": types.SimpleNamespace(make_axes_locatable=lambda *_a, **_k: None),
"mpl_toolkits.axes_grid1.anchored_artists": types.SimpleNamespace(AnchoredSizeBar=object),
"scipy.cluster.hierarchy": types.SimpleNamespace(
cut_tree=lambda *_a, **_k: None,
dendrogram=lambda *_a, **_k: None,
linkage=lambda *_a, **_k: None,
),
"ueler.image_utils": types.SimpleNamespace(
color_one_image=lambda *_a, **_k: None,
estimate_color_range=lambda *_a, **_k: None,
process_single_crop=lambda *_a, **_k: None,
),
"ueler.viewer.decorators": types.SimpleNamespace(update_status_bar=lambda func: func),
"ueler.viewer.observable": types.SimpleNamespace(Observable=object),
"ueler.viewer.plugin.plugin_base": types.SimpleNamespace(
PluginBase=type("PluginBase", (), {"__init__": lambda self, *_args, **_kwargs: None})
),
}
module = _load_module_from_file(
"test_flowsom_module_registration",
REPO_ROOT / "ueler/viewer/plugin/run_flowsom.py",
flowsom_stubs,
)

module.RunFlowsom._register_cell_annotation_provider(flowsom)

register_flowsom.assert_called_once_with(flowsom)


if __name__ == "__main__": # pragma: no cover
unittest.main()
11 changes: 11 additions & 0 deletions ueler/viewer/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Mapping, Protocol, runtime_checkable


Expand Down Expand Up @@ -51,3 +52,13 @@ def import_flowsom_params(self, params: Mapping[str, Any]) -> None:
def set_selection_context(self, selection: SelectionSpec) -> None:
"""Constrain the FlowSOM plugin to a Cell Annotation selection."""

def run_flowsom(
self,
selection: SelectionSpec | None,
params: Mapping[str, Any] | None,
training_markers: Sequence[str] | None,
extra_markers: Sequence[str] | None,
imputation: Any,
projection: Any,
) -> Mapping[str, Any]:
"""Run FlowSOM for the current selection and return execution metadata."""
1 change: 1 addition & 0 deletions ueler/viewer/main_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,7 @@ def _register_cell_annotation_plugin(self) -> None:
plugin = CellAnnotationPlugin(self)
setattr(self, CellAnnotationPlugin.REGISTRY_KEY, plugin)
plugin.on_dataset_opened(self.base_folder)
plugin.register_loaded_providers(self.SidePlots)
if self._debug:
print(f"[CellAnnotation] plugin registered: {plugin.store and plugin.store.store_path}")

Expand Down
16 changes: 16 additions & 0 deletions ueler/viewer/plugin/cell_annotation/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,23 @@ def on_dataset_closed(self) -> None:
self._manifest = None

def register_heatmap(self, provider: HeatmapStateProvider) -> None:
if self._heatmap_provider is provider:
return
self._heatmap_provider = provider
logger.info("[CellAnnotation] registered Heatmap provider: %s", type(provider).__name__)

def register_flowsom(self, provider: FlowsomParamsProvider) -> None:
if self._flowsom_provider is provider:
return
self._flowsom_provider = provider
logger.info("[CellAnnotation] registered FlowSOM provider: %s", type(provider).__name__)

def register_loaded_providers(self, side_plots: object) -> None:
"""Replay provider self-registration for already-instantiated side-plot plugins."""
for attr_name in dir(side_plots):
if attr_name.startswith("_"):
continue
provider = getattr(side_plots, attr_name, None)
register = getattr(provider, "_register_cell_annotation_provider", None)
if callable(register):
register()
2 changes: 0 additions & 2 deletions ueler/viewer/plugin/heatmap.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# viewer/cell_gallery.py

from ipywidgets import (SelectMultiple, FloatSlider, Dropdown, VBox, Output, TagsInput,
Checkbox, IntText, Text, Button, HBox, Layout, IntSlider, Tab, RadioButtons, HTML)
from scipy.cluster.hierarchy import dendrogram
Expand Down
41 changes: 38 additions & 3 deletions ueler/viewer/plugin/run_flowsom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import pickle
from collections import OrderedDict
from collections.abc import Sequence
from typing import Any, Mapping

import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -90,8 +92,30 @@ def on_cell_table_change(self):
self.ui_component.channel_selector.allowed_tags = self.main_viewer.cell_table.columns.tolist()
self.ui_component.subset_on_dropdown.options = self.main_viewer.cell_table.select_dtypes(include=['int', 'int64', 'object']).columns.tolist()

def _run_flowsom_button_click(self, _button) -> None:
self.run_flowsom(
selection=self._selection_context,
params=None,
training_markers=None,
extra_markers=None,
imputation=None,
projection=None,
)

@update_status_bar
def run_flowsom(self, b):
def run_flowsom(
self,
selection: SelectionSpec | None,
params: Mapping[str, Any] | None,
training_markers: Sequence[str] | None,
extra_markers: Sequence[str] | None,
imputation: Any,
projection: Any,
) -> Mapping[str, Any]:
if selection is not None:
self.set_selection_context(selection)
if params is not None:
self.import_flowsom_params(params)
# First, subset the data based on the selected high-level clusters
subset_on = self.ui_component.subset_on_dropdown.value
subset = list(self.ui_component.subset_selector.value)
Expand Down Expand Up @@ -134,6 +158,17 @@ def run_flowsom(self, b):
self.main_viewer.inform_plugins("on_cell_table_change")

print(f"FlowSOM clustering completed. The labels are saved in the column {column_name_text}")
selected_markers = list(self.ui_component.channel_selector.value or ())
resolved_training_markers = selected_markers if training_markers is None else list(training_markers)
return {
"column_name": column_name_text,
"params": self.export_flowsom_params(),
"selection": self._selection_context,
"training_markers": resolved_training_markers,
"extra_markers": list(extra_markers or ()),
"imputation": imputation,
"projection": projection,
}

def _register_cell_annotation_provider(self) -> None:
plugin = getattr(self.main_viewer, "cell_annotation_plugin", None)
Expand All @@ -159,7 +194,7 @@ def export_flowsom_params(self) -> dict:
"seed": self.ui_component.seed_input.value,
}

def import_flowsom_params(self, params):
def import_flowsom_params(self, params: Mapping[str, Any]) -> None:
mapping = {
"subset_on": self.ui_component.subset_on_dropdown,
"column_name": self.ui_component.column_name_text,
Expand Down Expand Up @@ -312,7 +347,7 @@ def __init__(self, parent):
tooltip='Run FlowSOM clustering',
icon='play'
)
self.run_button.on_click(parent.run_flowsom)
self.run_button.on_click(parent._run_flowsom_button_click)

class Data:
def __init__(self):
Expand Down
Loading