diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py
index 952c7323..1e39f58d 100644
--- a/src/nrtk_explorer/app/core.py
+++ b/src/nrtk_explorer/app/core.py
@@ -23,6 +23,12 @@
config_features_to_enabled_features,
config_preset_to_enabled_features,
)
+from nrtk_explorer.app.images.image_ids import (
+ GROUND_TRUTH_MODEL,
+)
+from nrtk_explorer.app.images.stateful_annotations import (
+ make_stateful_annotations,
+)
from nrtk_explorer.app.images.images import Images
from nrtk_explorer.app.images.image_server import ImageServer
from nrtk_explorer.app.applet import Applet
@@ -132,6 +138,8 @@ def __init__(self, server=None, **kwargs):
images = Images(server=self.server)
self._image_server = ImageServer(server=self.server, images=images)
+ ground_truth_annotations = make_stateful_annotations(self.server, GROUND_TRUTH_MODEL)
+ # Name slot 0 with the same constant passed above so the two cannot drift apart.
+ self.state.inference_models_obj = {0: {"name": GROUND_TRUTH_MODEL}}
self._datasets_app = None
if self.datasets_enabled:
@@ -150,15 +158,23 @@ def __init__(self, server=None, **kwargs):
from nrtk_explorer.app.features.transforms import TransformsApp
self._transforms_app = TransformsApp(
- server=self.server.create_child_server(), images=images, **kwargs
+ server=self.server.create_child_server(),
+ images=images,
+ ground_truth_annotations=ground_truth_annotations,
+ **kwargs,
)
+ self.context.transforms_app = self._transforms_app
+ # NOTE(review): "tramsforms_app" is a misspelling; the alias is kept in case
+ # something already reads the old name — drop it once confirmed unused.
+ self.context.tramsforms_app = self._transforms_app
+
self._images_app = None
if self.images_enabled:
from nrtk_explorer.app.features.images import ImagesApp
self._images_app = ImagesApp(
- server=self.server.create_child_server(), images=images, **kwargs
+ server=self.server.create_child_server(),
+ images=images,
+ ground_truth_annotations=ground_truth_annotations,
+ **kwargs,
)
self._inference_app = None
@@ -207,6 +223,7 @@ def __init__(self, server=None, **kwargs):
self.state.dataset_ids = []
self.state.hovered_id = None
self.state.maximised_id = None
+ self.state.maximised_annotations = []
def clear_hovered(**kwargs):
self.state.hovered_id = None
diff --git a/src/nrtk_explorer/app/features/images.py b/src/nrtk_explorer/app/features/images.py
index 14476e50..183a1aaf 100644
--- a/src/nrtk_explorer/app/features/images.py
+++ b/src/nrtk_explorer/app/features/images.py
@@ -83,7 +83,7 @@ def __init__(
self.images = images or Images(server)
- ground_truth_annotations = ground_truth_annotations or make_stateful_annotations(
+ self.ground_truth_annotations = ground_truth_annotations or make_stateful_annotations(
server, GROUND_TRUTH_MODEL
)
- self.context.ground_truth_annotations = ground_truth_annotations.annotations_factory
+ # Read from the attribute: the bare argument is still None when the caller passed nothing.
+ self.context.ground_truth_annotations = self.ground_truth_annotations.annotations_factory
diff --git a/src/nrtk_explorer/app/features/inference.py b/src/nrtk_explorer/app/features/inference.py
index 8ba4c6c3..0bfa1fbd 100644
--- a/src/nrtk_explorer/app/features/inference.py
+++ b/src/nrtk_explorer/app/features/inference.py
@@ -46,7 +46,7 @@ def __init__(
self.state.inference_models = [self.state.inference_models_options[0]]
self.state.inference_multi_model = False
- inference_models_obj = {0: {"name": "ground-truth"}}
+ inference_models_obj = {0: {"name": "groundtruth"}}
for i, model in enumerate(self.state.inference_models):
inference_models_obj[i + 1] = {"name": model}
@@ -101,7 +101,7 @@ def update_inference_models(self, models):
"transformed_annotations": transformed_annotations.annotations_factory,
}
- models_obj = {0: {"name": "ground-truth"}}
+ models_obj = {0: {"name": "groundtruth"}}
for i, model in enumerate(models):
models_obj[i + 1] = {"name": model}
diff --git a/src/nrtk_explorer/app/features/parameters.py b/src/nrtk_explorer/app/features/parameters.py
index 3e666b43..01501e3b 100644
--- a/src/nrtk_explorer/app/features/parameters.py
+++ b/src/nrtk_explorer/app/features/parameters.py
@@ -115,6 +115,15 @@ def transform_apply_ui(self):
flat=True,
)
+    def transform_preview_ui(self):
+        """Render a full-width, flat "Preview Transforms" button.
+
+        Clicking it calls the server controller's ``preview_transform`` entry.
+        """
+        on_click = self.server.controller.preview_transform
+        with html.Div(trame_server=self.server):
+            quasar.QBtn(
+                "Preview Transforms",
+                click=on_click,
+                classes="full-width",
+                flat=True,
+            )
+
def transforms_ui(self):
with html.Div(trame_server=self.server):
TransformsWidget(
diff --git a/src/nrtk_explorer/app/features/preview.py b/src/nrtk_explorer/app/features/preview.py
new file mode 100644
index 00000000..ca9f360f
--- /dev/null
+++ b/src/nrtk_explorer/app/features/preview.py
@@ -0,0 +1,328 @@
+from typing import Dict
+
+from trame.ui.quasar import QLayout
+from trame.widgets import quasar
+from trame.widgets import html
+from trame.app import get_server
+from trame_client.widgets.trame import Getter
+
+from trame_annotations.widgets.annotations import ImageDetection
+from nrtk_explorer.app.images.image_server import ImageServer
+from nrtk_explorer.widgets.nrtk_explorer import AnnotationAggregator
+
+from nrtk_explorer.library.app_config import process_config
+from nrtk_explorer.library.dataset import (
+ get_dataset,
+ expand_hugging_face_datasets,
+ dataset_select_options,
+)
+
+import nrtk_explorer.library.transforms as trans
+import nrtk_explorer.library.yaml_transforms as nrtk_yaml
+import nrtk_explorer.library.serialization_helpers as serialization_helpers
+
+from nrtk_explorer.app.applet import Applet
+from nrtk_explorer.app.features.parameters import ParametersApp
+from nrtk_explorer.app.images.images import Images
+
+from nrtk_explorer.app.images.stateful_annotations import (
+ make_stateful_annotations,
+)
+
+from nrtk_explorer.app.images.image_ids import (
+ GROUND_TRUTH_MODEL,
+)
+
+
+class PreviewApp(Applet):
+    def __init__(
+        self,
+        server,
+        images=None,
+        ground_truth_annotations=None,
+        **kwargs,
+    ):
+        """Set up the preview applet: collaborators, state, and transform catalogue.
+
+        Args:
+            server: server this applet is attached to.
+            images: image store to reuse; ``Images(server)`` is created when omitted.
+            ground_truth_annotations: stateful annotations to reuse; created for
+                ``GROUND_TRUTH_MODEL`` when omitted.
+            **kwargs: accepted so callers can forward extra options; unused here.
+        """
+        super().__init__(server)
+
+        self.images = images or Images(server)
+
+        self._image_server = ImageServer(server=self.server, images=self.images)
+
+        self.ground_truth_annotations = ground_truth_annotations or make_stateful_annotations(
+            server, GROUND_TRUTH_MODEL
+        )
+        self.context.ground_truth_annotations = self.ground_truth_annotations.annotations_factory
+
+        # Reuse a parameters applet already published on the context; create one otherwise.
+        if self.context.parameters_app is None:
+            self.context.parameters_app = ParametersApp(
+                server=server,
+            )
+
+        self._parameters_app = self.context.parameters_app
+
+        self.state.dataset_ids = []
+        # Name the ground-truth entry via the shared constant so it cannot drift from
+        # the model key used for the annotations above.
+        self.state.preview_inference_models_obj = {0: {"name": GROUND_TRUTH_MODEL}}
+        self.state.preview_image_id = None
+        self.state.transformed_preview_image = None
+
+        self._ui = None
+
+        self._transform_classes: Dict[str, type[trans.ImageTransform]] = {
+            "blur": trans.GaussianBlurTransform,
+            "invert": trans.InvertTransform,
+            "downsample": trans.DownSampleTransform,
+            "identity": trans.IdentityTransform,
+        }
+
+        # Add the transforms generated from the YAML definition
+        self._transform_classes.update(nrtk_yaml.generate_transforms())
+
+        self._parameters_app._transform_classes = self._transform_classes
+
+        # Select "blur" as the default transform, then call on_add_transform()
+        # (presumably adds one instance of that default to the pipeline).
+        # NOTE(review): when the parameters applet is shared with a host app this
+        # reconfigures the host's instance as well — confirm that is intended.
+        self._parameters_app._default_transform = "blur"
+        self._parameters_app.on_add_transform()
+
+        self.server.controller.on_server_ready.add(self.on_server_ready)
+
+    def on_server_ready(self, *args, **kwargs):
+        """Register the preview handler once the server is ready, then run it once.
+
+        The same handler serves the controller's ``preview_transform`` entry and
+        changes of the ``preview_image_id`` state variable.
+        """
+        refresh_preview = self.on_preview_transform
+        self.server.controller.preview_transform.add(refresh_preview)
+        self.state.change("preview_image_id")(refresh_preview)
+        refresh_preview()
+
+    def on_preview_transform(self, **kwargs):
+        """Run the current transform chain on the preview image and publish the result.
+
+        Does nothing when ``state.preview_image_id`` is ``None``. Otherwise the
+        output of the chained transforms is JPEG-encoded and stored as a base64
+        data URL in ``state.transformed_preview_image``.
+        """
+        import base64
+        from io import BytesIO
+
+        preview_id = self.state.preview_image_id
+        if preview_id is None:
+            return
+
+        transforms = [entry["instance"] for entry in self.context.transforms]
+
+        for transform in transforms:
+            params = transform.get_parameters()
+            for key, spec in transform.get_parameters_description().items():
+                # A parameter may name a helper in serialization_helpers that
+                # converts its stored value before it is handed to the transform.
+                if "deserialize_func" in spec:
+                    deserialize = getattr(serialization_helpers, spec["deserialize_func"])
+                    params[key] = deserialize(params[key])
+            transform.set_parameters(params)
+
+        image = self.images.get_image_without_cache_eviction(preview_id)
+        transformed_image = trans.ChainedImageTransform(transforms).execute(image)
+
+        # JPEG cannot store alpha or palette data, so saving such an image raises.
+        # Assumes a PIL image, as the save(..., format="JPEG") call below already does.
+        if transformed_image.mode not in ("L", "RGB", "CMYK"):
+            transformed_image = transformed_image.convert("RGB")
+
+        buffered = BytesIO()
+        transformed_image.save(buffered, format="JPEG")
+        encoded = base64.b64encode(buffered.getvalue()).decode("ascii")
+        self.state.transformed_preview_image = f"data:image/jpeg;base64,{encoded}"
+
+    def settings_widget(self):
+        """Render the parameters applet's transforms UI inside a ``col`` container."""
+        parameters = self._parameters_app
+        with html.Div(classes="col"):
+            parameters.transforms_ui()
+
+    def preview_ui(self):
+        """Render the parameters applet's preview button inside a plain container."""
+        parameters = self._parameters_app
+        with html.Div():
+            parameters.transform_preview_ui()
+
+ def carousel_ui(self):
+ """Render a carousel with one slide per entry of ``dataset_ids``.
+
+ Each slide holds two ``ImageDetection`` views: one fed by the ``img_<id>``
+ getter and one fed by the ``transformed_preview_image`` state variable.
+ The carousel selection is bound to ``preview_image_id``.
+ """
+ with quasar.QCarousel(
+ v_model=("preview_image_id",),
+ style="padding-bottom: 6rem;",
+ classes="fit",
+ control_color="primary",
+ animated=True,
+ infinite=True,
+ swipeable=True,
+ navigation=True,
+ thumbnails=True,
+ arrows=True,
+ padding=True,
+ transition_prev="slide-right",
+ transition_next="slide-left",
+ ):
+ # Custom navigation icons: each one shows the image from the `img_<name>` getter
+ # and gets a red border while it is the active slide.
+ with html.Template(
+ raw_attrs=['v-slot:navigation-icon="{ index, name, active, btnProps, onClick }"']
+ ):
+ with Getter(name=("`img_${name}`",)):
+ quasar.QImg(
+ key=("name",),
+ classes="rounded-borders q-mr-md",
+ style=(
+ "active ? 'width: 6rem; height: 6rem; border-style: solid; border-width: 0.125rem; border-color: red;' : 'width: 6rem; height: 6rem;'",
+ ),
+ fit="cover",
+ src=("value",),
+ click="onClick",
+ )
+
+ # One slide per dataset id.
+ with html.Template(v_for=("identifier in dataset_ids",)):
+ with quasar.QCarouselSlide(
+ name=("identifier",),
+ key=("identifier",),
+ ):
+ with html.Div(
+ classes="row fit justify-start items-center q-gutter-xs q-col-gutter no-wrap",
+ ):
+ # First pane: image from the `img_<identifier>` getter, aggregator with transformed=False.
+ with AnnotationAggregator(
+ image_id=("identifier",),
+ transformed=False,
+ models=("preview_inference_models_obj",),
+ ):
+ with Getter(name=("`img_${identifier}`",)):
+ with html.Div(
+ classes="rounded-borders col-6 full-height",
+ ):
+ ImageDetection(
+ identifier=("identifier",),
+ src=("value",),
+ models=("preview_inference_models_obj",),
+ annotations=("aggregateAnnotations",),
+ categories=("annotation_categories",),
+ container_selector=".row",
+ score_threshold=("confidence_score_threshold", 0.8),
+ color_by="model",
+ )
+
+ # Second pane: the `transformed_preview_image` state variable (one value shared
+ # by every slide), aggregator with transformed=True.
+ with AnnotationAggregator(
+ image_id=("identifier",),
+ transformed=True,
+ models=("preview_inference_models_obj",),
+ ):
+ with html.Div(
+ classes="rounded-borders col-6 full-height",
+ ):
+ ImageDetection(
+ identifier=("identifier",),
+ src=("transformed_preview_image",),
+ models=("preview_inference_models_obj",),
+ annotations=("aggregateAnnotations",),
+ categories=("annotation_categories",),
+ container_selector=".row",
+ score_threshold=("confidence_score_threshold", 0.8),
+ color_by="model",
+ )
+
+    # Only used when this module (file) is executed as a standalone app.
+    @property
+    def ui(self):
+        """Return the standalone page layout, building and caching it on first access."""
+        if self._ui is not None:
+            return self._ui
+
+        with QLayout(
+            self.server, view="lhh LpR lff", classes="shadow-2 rounded-borders bg-grey-2"
+        ) as layout:
+            # Toolbar
+            with quasar.QHeader(), quasar.QToolbar(classes="shadow-4"):
+                quasar.QBtn(
+                    flat=True,
+                    click="drawerLeft = !drawerLeft",
+                    round=True,
+                    dense=False,
+                    icon="menu",
+                )
+                quasar.QToolbarTitle("Transforms")
+
+            # Main content: settings column on the left, carousel on the right
+            with quasar.QPageContainer(), quasar.QPage(classes="row"):
+                with html.Div(classes="col-2 q-pa-md"):
+                    self.settings_widget()
+                    self.preview_ui()
+
+                with html.Div(classes="col-10 q-pa-md bg-grey-3 shadow-2 rounded-borders"):
+                    self.carousel_ui()
+
+            self._ui = layout
+        return self._ui
+
+
+def load_dataset(server, **kwargs):
+    """Load the dataset chosen via ``--dataset`` and initialise the related state.
+
+    The option defaults to the COCO test file under ``nrtk_explorer.test_data``.
+    The dataset lists are stored on the state, the first dataset is loaded into
+    ``server.context.dataset`` and the sampling-related state variables are reset.
+    """
+    import os
+    import nrtk_explorer.test_data
+
+    test_data_dir = os.path.dirname(nrtk_explorer.test_data.__file__)
+    default_datasets = [
+        f"{test_data_dir}/coco-od-2017/test_val2017.json",
+    ]
+    default_num_images = 500
+
+    dataset_option = {
+        "flags": ["--dataset"],
+        "params": {
+            "nargs": "+",
+            "default": default_datasets,
+            "help": "Path to the JSON file describing the image dataset",
+        },
+    }
+    config = process_config(server.cli, {"dataset": dataset_option}, **kwargs)
+
+    state = server.state
+    state.input_datasets = expand_hugging_face_datasets(config["dataset"])
+
+    state.all_datasets = state.input_datasets
+    state.all_datasets_options = dataset_select_options(state.all_datasets)
+
+    with state:
+        state.num_images = default_num_images
+        state.num_images_max = 0
+        state.dataset_ids = []  # sampled images
+        state.user_selected_ids = []  # ensure image update in transforms app via image list
+        server.context.dataset = get_dataset(state.all_datasets[0])
+        state.num_images_max = len(server.context.dataset.imgs)
+        # Never ask for more images than the dataset holds.
+        state.num_images = min(state.num_images_max, state.num_images)
+        state.dirty("num_images")  # Trigger resample_images()
+        state.random_sampling_disabled = False
+        state.num_images_disabled = False
+
+    state.annotation_categories = {
+        category["id"]: category for category in server.context.dataset.cats.values()
+    }
+
+
+def resample_images(server, images, **kwargs):
+    """Choose which dataset images to show and publish their ids.
+
+    Takes at most ``state.num_images`` ids (all of them when that value is falsy),
+    sampled randomly when ``state.random_sampling`` is truthy and from the front
+    otherwise. The raw ids go to ``server.context.dataset_ids``; their string forms
+    go to ``state.dataset_ids`` and ``state.user_selected_ids``. The first id becomes
+    ``state.preview_image_id`` (``None`` when nothing was selected). Finally the
+    ground-truth annotations are requested for the selected ids.
+
+    ``images`` is accepted but not used.
+    """
+    import random
+
+    state = server.state
+    all_ids = [image["id"] for image in server.context.dataset.imgs.values()]
+
+    if not state.num_images:
+        chosen = all_ids
+    elif state.random_sampling:
+        chosen = random.sample(all_ids, min(len(all_ids), state.num_images))
+    else:
+        chosen = all_ids[: state.num_images]
+
+    with state:
+        server.context.dataset_ids = chosen
+        state.dataset_ids = [str(image_id) for image_id in server.context.dataset_ids]
+        state.user_selected_ids = state.dataset_ids
+        state.preview_image_id = state.dataset_ids[0] if len(state.dataset_ids) > 0 else None
+
+    server.context.ground_truth_annotations.get_annotations(state.dataset_ids)
+
+
+def main(server=None, *args, **kwargs):
+    """Run the preview applet as a standalone application.
+
+    Args:
+        server: server instance to run on, or a server name; a vue3 server is
+            obtained through ``get_server`` when this is ``None`` or a name.
+        *args: ignored.
+        **kwargs: forwarded to ``server.start``.
+    """
+    # Honour a server supplied by the caller instead of always replacing it.
+    if server is None or isinstance(server, str):
+        server = get_server(server, client_type="vue3")
+    images = Images(server)
+
+    load_dataset(server)
+
+    preview_app = PreviewApp(server, images)
+    preview_app.ui  # accessing the property builds the layout
+
+    resample_images(server, images)
+
+    server.start(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/nrtk_explorer/app/features/transforms.py b/src/nrtk_explorer/app/features/transforms.py
index 57ce4722..141164bd 100644
--- a/src/nrtk_explorer/app/features/transforms.py
+++ b/src/nrtk_explorer/app/features/transforms.py
@@ -12,6 +12,7 @@
from nrtk_explorer.app.applet import Applet
from nrtk_explorer.app.trame_utils import ProcessingStep
from nrtk_explorer.app.features.parameters import ParametersApp
+from nrtk_explorer.app.features.preview import PreviewApp
from nrtk_explorer.app.images.images import Images
from nrtk_explorer.app.ui.image_list import (
@@ -36,6 +37,10 @@ def __init__(
server=server,
)
+ self.context.parameters_app = self._parameters_app
+
+ self._preview_app = PreviewApp(self.server, self.images, **kwargs)
+
self._ui = None
self._transform_classes: Dict[str, type[trans.ImageTransform]] = {
@@ -66,6 +71,7 @@ def __init__(
)
self.server.controller.apply_transform.add(self.on_apply_transform)
+ self.server.controller.preview_transform.add(self.on_preview_transform)
self.server.controller.on_server_ready.add(self.on_server_ready)
def on_server_ready(self, *args, **kwargs):
@@ -73,6 +79,7 @@ def on_server_ready(self, *args, **kwargs):
def on_apply_transform(self, **kwargs):
# Turn on switch if user clicked lower apply button
+ self.state.show_preview = False
self.state.transform_enabled_switch = True
transforms = list(map(lambda t: t["instance"], self.context.transforms))
@@ -90,14 +97,45 @@ def on_apply_transform(self, **kwargs):
if self.ctrl.run_transform.exists():
self.ctrl.run_transform()
+    def on_preview_transform(self, **kwargs):
+        """Raise the ``show_preview`` flag, defaulting the preview image if unset.
+
+        When ``preview_image_id`` is ``None`` and ``dataset_ids`` is non-empty, the
+        first dataset id becomes the preview image.
+        """
+        state = self.state
+        needs_default = state.preview_image_id is None
+        if needs_default and len(state.dataset_ids) > 0:
+            state.preview_image_id = state.dataset_ids[0]
+
+        state.show_preview = True
+
def settings_widget(self):
with html.Div(classes="col"):
self._parameters_app.transforms_ui()
def apply_ui(self):
+ # Renders the preview and apply buttons, plus a dialog that repeats the transform
+ # controls next to the preview carousel.
with html.Div():
+ self._parameters_app.transform_preview_ui()
self._parameters_app.transform_apply_ui()
+ # Dialog visibility is bound to the `show_preview` state variable (initially False).
+ with quasar.QDialog(
+ full_width=True,
+ full_height=True,
+ transition_duration=0,
+ v_model=("show_preview", False),
+ ):
+ with html.Div(classes="row bg-grey-3 shadow-2 rounded-borders"):
+ # Close button: resets `show_preview` on the client side.
+ quasar.QBtn(
+ icon="close",
+ flat=True,
+ round=True,
+ dense=True,
+ click="show_preview = false;",
+ style="position: absolute; margin: 0.5rem;",
+ )
+ # Narrow column: transform settings with the preview and apply buttons.
+ with html.Div(classes="col-2 q-pa-md"):
+ html.H5("Transform Preview")
+ self._parameters_app.transforms_ui()
+ self._parameters_app.transform_preview_ui()
+ self._parameters_app.transform_apply_ui()
+
+ # Wide column: the preview applet's carousel.
+ with html.Div(classes="col-10 q-pa-md"):
+ self._preview_app.carousel_ui()
+
# This is only used within when this module (file) is executed as an Standalone app.
@property
def ui(self):
diff --git a/src/nrtk_explorer/app/images/image_ids.py b/src/nrtk_explorer/app/images/image_ids.py
index 2a1efa4a..a71ec3ac 100644
--- a/src/nrtk_explorer/app/images/image_ids.py
+++ b/src/nrtk_explorer/app/images/image_ids.py
@@ -1,6 +1,6 @@
from nrtk_explorer.app.images.image_meta import dataset_id_to_meta
-GROUND_TRUTH_MODEL = "ground-truth"
+GROUND_TRUTH_MODEL = "groundtruth"
def image_id_to_dataset_id(image_id: str):
@@ -15,6 +15,10 @@ def dataset_id_to_transformed_image_id(dataset_id: str):
return f"transformed_img_{dataset_id}"
+def dataset_id_to_preview_image_id(dataset_id: str):
+    """Return the preview-image id for ``dataset_id``: ``preview_img_<dataset_id>``."""
+    prefix = "preview_img_"
+    return f"{prefix}{dataset_id}"
+
+
def image_id_to_result_id(image_id: str, model_name: str):
return f"result_{image_id}_{model_name}"
@@ -27,6 +31,10 @@ def is_transformed(image_id: str):
return image_id.startswith("transformed_img_")
+def is_preview(image_id: str):
+    """Tell whether ``image_id`` is a preview-image id (starts with ``preview_img_``)."""
+    preview_prefix = "preview_img_"
+    return image_id.startswith(preview_prefix)
+
+
def get_image_state_keys(dataset_id: str, annotation_models: list[str]):
image_id = dataset_id_to_image_id(dataset_id)
transformed_image_id = dataset_id_to_transformed_image_id(dataset_id)
diff --git a/src/nrtk_explorer/app/ui/image_list.py b/src/nrtk_explorer/app/ui/image_list.py
index 69809440..a49f52a9 100644
--- a/src/nrtk_explorer/app/ui/image_list.py
+++ b/src/nrtk_explorer/app/ui/image_list.py
@@ -288,7 +288,7 @@ def __init__(self, **kwargs):
}, {});
const transformedAnnotations = Object.entries(inference_models_obj).reduce(function(acc, [model_id, model]){
- if (model.name == 'ground-truth') {
+ if (model.name == 'groundtruth') {
acc[model_id] = get(`result_${original_id}_${model.name}`);
} else {
acc[model_id] = get(`result_${transformed_id}_${model.name}`);
@@ -312,7 +312,7 @@ def __init__(self, **kwargs):
original_src: get(original_id).value,
transformed: transformed_id,
transformed_src: get(transformed_id).value,
- groundTruthAnnotations: get(`result_${original_id}_ground-truth`) || [],
+ groundTruthAnnotations: get(`result_${original_id}_groundtruth`) || [],
originalAnnotations: originalAnnotations || {},
transformedAnnotations: transformedAnnotations,
originalScores,
diff --git a/src/nrtk_explorer/library/embeddings_extractor.py b/src/nrtk_explorer/library/embeddings_extractor.py
index 7f4f75fb..a7e6dfcd 100644
--- a/src/nrtk_explorer/library/embeddings_extractor.py
+++ b/src/nrtk_explorer/library/embeddings_extractor.py
@@ -25,7 +25,13 @@ def __getitem__(self, i):
class EmbeddingsExtractor:
def __init__(self, model_name="resnet50d", force_cpu=False):
- self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu"
+ if force_cpu:
+ self.device = "cpu"
+ elif torch.cuda.is_available():
+ self.device = "cuda"
+ else:
+ self.device = "cpu"
+
self.model = model_name
self.reset()
diff --git a/src/nrtk_explorer/library/predictor.py b/src/nrtk_explorer/library/predictor.py
index 032ca2c7..329eb14e 100644
--- a/src/nrtk_explorer/library/predictor.py
+++ b/src/nrtk_explorer/library/predictor.py
@@ -27,7 +27,14 @@ def __init__(
force_cpu: bool = False,
):
self.task = task
- self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu"
+
+ if force_cpu:
+ self.device = "cpu"
+ elif torch.cuda.is_available():
+ self.device = "cuda"
+ else:
+ self.device = "cpu"
+
self.pipeline = model_name
self.reset()
diff --git a/src/nrtk_explorer/widgets/nrtk_explorer.py b/src/nrtk_explorer/widgets/nrtk_explorer.py
index ff36edeb..ff0325d1 100644
--- a/src/nrtk_explorer/widgets/nrtk_explorer.py
+++ b/src/nrtk_explorer/widgets/nrtk_explorer.py
@@ -119,3 +119,22 @@ def __init__(self, **kwargs):
"models",
]
self._event_names += []
+
+
+class AnnotationAggregator(HtmlElement):
+    """Python-side wrapper for the ``annotation-aggregator`` element.
+
+    Declares the ``image_id`` (listed together with ``imageId``), ``transformed``
+    and ``models`` attributes, and adds a ``v-slot`` attribute that destructures
+    ``aggregateAnnotations`` for nested widgets.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__("annotation-aggregator", **kwargs)
+        self._attr_names += [
+            ("image_id", "imageId"),
+            "transformed",
+            "models",
+        ]
+        self._event_names += []
+
+        # Names taken out of the slot scope, joined into the destructuring pattern.
+        exposed_props = ", ".join(["aggregateAnnotations"])
+        self._attributes["slot"] = f'v-slot="{{ {exposed_props} }}"'
diff --git a/vue-components/src/components/AnnotationAggregator.vue b/vue-components/src/components/AnnotationAggregator.vue
new file mode 100644
index 00000000..2107f6d0
--- /dev/null
+++ b/vue-components/src/components/AnnotationAggregator.vue
@@ -0,0 +1,54 @@
+
+
+
+
+
diff --git a/vue-components/src/components/index.js b/vue-components/src/components/index.js
index a78abfcb..ffb947e4 100644
--- a/vue-components/src/components/index.js
+++ b/vue-components/src/components/index.js
@@ -5,6 +5,7 @@ import FilterOptionsWidget from './FilterOptionsWidget.vue'
import FilterOperatorWidget from './FilterOperatorWidget.vue'
import ExportWidget from './ExportWidget.vue'
import ScoreTable from './ScoreTable.vue'
+import AnnotationAggregator from './AnnotationAggregator.vue'
export default {
scatterPlot: ScatterPlot,
@@ -13,5 +14,6 @@ export default {
filterOptionsWidget: FilterOptionsWidget,
filterOperatorWidget: FilterOperatorWidget,
exportWidget: ExportWidget,
- scoreTable: ScoreTable
+ scoreTable: ScoreTable,
+ annotationAggregator: AnnotationAggregator
}