diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py index 952c7323..1e39f58d 100644 --- a/src/nrtk_explorer/app/core.py +++ b/src/nrtk_explorer/app/core.py @@ -23,6 +23,12 @@ config_features_to_enabled_features, config_preset_to_enabled_features, ) +from nrtk_explorer.app.images.image_ids import ( + GROUND_TRUTH_MODEL, +) +from nrtk_explorer.app.images.stateful_annotations import ( + make_stateful_annotations, +) from nrtk_explorer.app.images.images import Images from nrtk_explorer.app.images.image_server import ImageServer from nrtk_explorer.app.applet import Applet @@ -132,6 +138,8 @@ def __init__(self, server=None, **kwargs): images = Images(server=self.server) self._image_server = ImageServer(server=self.server, images=images) + ground_truth_annotations = make_stateful_annotations(self.server, GROUND_TRUTH_MODEL) + self.state.inference_models_obj = {0: {"name": "groundtruth"}} self._datasets_app = None if self.datasets_enabled: @@ -150,15 +158,23 @@ def __init__(self, server=None, **kwargs): from nrtk_explorer.app.features.transforms import TransformsApp self._transforms_app = TransformsApp( - server=self.server.create_child_server(), images=images, **kwargs + server=self.server.create_child_server(), + images=images, + ground_truth_annotations=ground_truth_annotations, + **kwargs, ) + self.context.tramsforms_app = self._transforms_app + self._images_app = None if self.images_enabled: from nrtk_explorer.app.features.images import ImagesApp self._images_app = ImagesApp( - server=self.server.create_child_server(), images=images, **kwargs + server=self.server.create_child_server(), + images=images, + ground_truth_annotations=ground_truth_annotations, + **kwargs, ) self._inference_app = None @@ -207,6 +223,7 @@ def __init__(self, server=None, **kwargs): self.state.dataset_ids = [] self.state.hovered_id = None self.state.maximised_id = None + self.state.maximised_annotations = [] def clear_hovered(**kwargs): self.state.hovered_id = None diff --git a/src/nrtk_explorer/app/features/images.py b/src/nrtk_explorer/app/features/images.py index 14476e50..183a1aaf 100644 --- a/src/nrtk_explorer/app/features/images.py +++ b/src/nrtk_explorer/app/features/images.py @@ -83,7 +83,7 @@ def __init__( self.images = images or Images(server) - ground_truth_annotations = ground_truth_annotations or make_stateful_annotations( + self.ground_truth_annotations = ground_truth_annotations or make_stateful_annotations( server, GROUND_TRUTH_MODEL ) self.context.ground_truth_annotations = ground_truth_annotations.annotations_factory diff --git a/src/nrtk_explorer/app/features/inference.py b/src/nrtk_explorer/app/features/inference.py index 8ba4c6c3..0bfa1fbd 100644 --- a/src/nrtk_explorer/app/features/inference.py +++ b/src/nrtk_explorer/app/features/inference.py @@ -46,7 +46,7 @@ def __init__( self.state.inference_models = [self.state.inference_models_options[0]] self.state.inference_multi_model = False - inference_models_obj = {0: {"name": "ground-truth"}} + inference_models_obj = {0: {"name": "groundtruth"}} for i, model in enumerate(self.state.inference_models): inference_models_obj[i + 1] = {"name": model} @@ -101,7 +101,7 @@ def update_inference_models(self, models): "transformed_annotations": transformed_annotations.annotations_factory, } - models_obj = {0: {"name": "ground-truth"}} + models_obj = {0: {"name": "groundtruth"}} for i, model in enumerate(models): models_obj[i + 1] = {"name": model} diff --git a/src/nrtk_explorer/app/features/parameters.py b/src/nrtk_explorer/app/features/parameters.py index 3e666b43..01501e3b 100644 --- a/src/nrtk_explorer/app/features/parameters.py +++ b/src/nrtk_explorer/app/features/parameters.py @@ -115,6 +115,15 @@ def transform_apply_ui(self): flat=True, ) + def transform_preview_ui(self): + with html.Div(trame_server=self.server): + quasar.QBtn( + "Preview Transforms", + click=(self.server.controller.preview_transform), + classes="full-width", + flat=True, + ) + def transforms_ui(self): with html.Div(trame_server=self.server): TransformsWidget( diff --git a/src/nrtk_explorer/app/features/preview.py b/src/nrtk_explorer/app/features/preview.py new file mode 100644 index 00000000..ca9f360f --- /dev/null +++ b/src/nrtk_explorer/app/features/preview.py @@ -0,0 +1,328 @@ +from typing import Dict + +from trame.ui.quasar import QLayout +from trame.widgets import quasar +from trame.widgets import html +from trame.app import get_server +from trame_client.widgets.trame import Getter + +from trame_annotations.widgets.annotations import ImageDetection +from nrtk_explorer.app.images.image_server import ImageServer +from nrtk_explorer.widgets.nrtk_explorer import AnnotationAggregator + +from nrtk_explorer.library.app_config import process_config +from nrtk_explorer.library.dataset import ( + get_dataset, + expand_hugging_face_datasets, + dataset_select_options, +) + +import nrtk_explorer.library.transforms as trans +import nrtk_explorer.library.yaml_transforms as nrtk_yaml +import nrtk_explorer.library.serialization_helpers as serialization_helpers + +from nrtk_explorer.app.applet import Applet +from nrtk_explorer.app.features.parameters import ParametersApp +from nrtk_explorer.app.images.images import Images + +from nrtk_explorer.app.images.stateful_annotations import ( + make_stateful_annotations, +) + +from nrtk_explorer.app.images.image_ids import ( + GROUND_TRUTH_MODEL, +) + + +class PreviewApp(Applet): + def __init__( + self, + server, + images=None, + ground_truth_annotations=None, + **kwargs, + ): + super().__init__(server) + + self.images = images or Images(server) + + self._image_server = ImageServer(server=self.server, images=self.images) + + self.ground_truth_annotations = ground_truth_annotations or make_stateful_annotations( + server, GROUND_TRUTH_MODEL + ) + self.context.ground_truth_annotations = self.ground_truth_annotations.annotations_factory + + if self.context.parameters_app is None: + self.context.parameters_app = ParametersApp( + server=server, + ) + + self._parameters_app = self.context.parameters_app + + self.state.dataset_ids = [] + self.state.preview_inference_models_obj = {0: {"name": "groundtruth"}} + self.state.preview_image_id = None + self.state.transformed_preview_image = None + + self._ui = None + + self._transform_classes: Dict[str, type[trans.ImageTransform]] = { + "blur": trans.GaussianBlurTransform, + "invert": trans.InvertTransform, + "downsample": trans.DownSampleTransform, + "identity": trans.IdentityTransform, + } + + # Add transform from YAML definition + self._transform_classes.update(nrtk_yaml.generate_transforms()) + + self._parameters_app._transform_classes = self._transform_classes + + # Initialize the transforms pipeline to the identity + self._parameters_app._default_transform = "blur" + self._parameters_app.on_add_transform() + + self.server.controller.on_server_ready.add(self.on_server_ready) + + def on_server_ready(self, *args, **kwargs): + self.server.controller.preview_transform.add(self.on_preview_transform) + self.state.change("preview_image_id")(self.on_preview_transform) + self.on_preview_transform() + + def on_preview_transform(self, **kwargs): + preview_id = self.state.preview_image_id + if preview_id is None: + return + + transforms = list(map(lambda t: t["instance"], self.context.transforms)) + + for transform in transforms: + params = transform.get_parameters() + for key, value in transform.get_parameters_description().items(): + if "deserialize_func" in value.keys(): + params[key] = getattr(serialization_helpers, value["deserialize_func"])( + params[key] + ) + transform.set_parameters(params) + + chained_transform = trans.ChainedImageTransform(transforms) + + image = self.images.get_image_without_cache_eviction(preview_id) + transformed_image = chained_transform.execute(image) + import base64 + from io import BytesIO + + buffered = BytesIO() + transformed_image.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()) + img_base64 = bytes("data:image/jpeg;base64,", encoding="utf-8") + img_str + img_base64_str = img_base64.decode("utf-8") + self.state.transformed_preview_image = img_base64_str + + def settings_widget(self): + with html.Div(classes="col"): + self._parameters_app.transforms_ui() + + def preview_ui(self): + with html.Div(): + self._parameters_app.transform_preview_ui() + + def carousel_ui(self): + with quasar.QCarousel( + v_model=("preview_image_id",), + style="padding-bottom: 6rem;", + classes="fit", + control_color="primary", + animated=True, + infinite=True, + swipeable=True, + navigation=True, + thumbnails=True, + arrows=True, + padding=True, + transition_prev="slide-right", + transition_next="slide-left", + ): + with html.Template( + raw_attrs=['v-slot:navigation-icon="{ index, name, active, btnProps, onClick }"'] + ): + with Getter(name=("`img_${name}`",)): + quasar.QImg( + key=("name",), + classes="rounded-borders q-mr-md", + style=( + "active ? 'width: 6rem; height: 6rem; border-style: solid; border-width: 0.125rem; border-color: red;' : 'width: 6rem; height: 6rem;'", + ), + fit="cover", + src=("value",), + click="onClick", + ) + + with html.Template(v_for=("identifier in dataset_ids",)): + with quasar.QCarouselSlide( + name=("identifier",), + key=("identifier",), + ): + with html.Div( + classes="row fit justify-start items-center q-gutter-xs q-col-gutter no-wrap", + ): + with AnnotationAggregator( + image_id=("identifier",), + transformed=False, + models=("preview_inference_models_obj",), + ): + with Getter(name=("`img_${identifier}`",)): + with html.Div( + classes="rounded-borders col-6 full-height", + ): + ImageDetection( + identifier=("identifier",), + src=("value",), + models=("preview_inference_models_obj",), + annotations=("aggregateAnnotations",), + categories=("annotation_categories",), + container_selector=".row", + score_threshold=("confidence_score_threshold", 0.8), + color_by="model", + ) + + with AnnotationAggregator( + image_id=("identifier",), + transformed=True, + models=("preview_inference_models_obj",), + ): + with html.Div( + classes="rounded-borders col-6 full-height", + ): + ImageDetection( + identifier=("identifier",), + src=("transformed_preview_image",), + models=("preview_inference_models_obj",), + # annotations=("groundtruth_annotations",), + annotations=("aggregateAnnotations",), + categories=("annotation_categories",), + container_selector=".row", + score_threshold=("confidence_score_threshold", 0.8), + color_by="model", + ) + + # This is only used within when this module (file) is executed as an Standalone app. + @property + def ui(self): + if self._ui is None: + with QLayout( + self.server, view="lhh LpR lff", classes="shadow-2 rounded-borders bg-grey-2" + ) as layout: + # # Toolbar + with quasar.QHeader(): + with quasar.QToolbar(classes="shadow-4"): + quasar.QBtn( + flat=True, + click="drawerLeft = !drawerLeft", + round=True, + dense=False, + icon="menu", + ) + quasar.QToolbarTitle("Transforms") + + # # Main content + with quasar.QPageContainer(): + with quasar.QPage(classes="row"): + with html.Div(classes="col-2 q-pa-md"): + self.settings_widget() + self.preview_ui() + + with html.Div(classes="col-10 q-pa-md bg-grey-3 shadow-2 rounded-borders"): + self.carousel_ui() + + self._ui = layout + return self._ui + + +def load_dataset(server, **kwargs): + import os + import nrtk_explorer.test_data + + DIR_NAME = os.path.dirname(nrtk_explorer.test_data.__file__) + DEFAULT_DATASETS = [ + f"{DIR_NAME}/coco-od-2017/test_val2017.json", + ] + NUM_IMAGES_DEFAULT = 500 + + config_options = { + "dataset": { + "flags": ["--dataset"], + "params": { + "nargs": "+", + "default": DEFAULT_DATASETS, + "help": "Path to the JSON file describing the image dataset", + }, + }, + } + config = process_config(server.cli, config_options, **kwargs) + + server.state.input_datasets = expand_hugging_face_datasets(config["dataset"]) + + server.state.all_datasets = server.state.input_datasets + server.state.all_datasets_options = dataset_select_options(server.state.all_datasets) + + with server.state: + server.state.num_images = NUM_IMAGES_DEFAULT + server.state.num_images_max = 0 + server.state.dataset_ids = [] # sampled images + server.state.user_selected_ids = [] # ensure image update in transforms app via image list + server.context.dataset = get_dataset(server.state.all_datasets[0]) + server.state.num_images_max = len(server.context.dataset.imgs) + server.state.num_images = min(server.state.num_images_max, server.state.num_images) + server.state.dirty("num_images") # Trigger resample_images() + server.state.random_sampling_disabled = False + server.state.num_images_disabled = False + + server.state.annotation_categories = { + category["id"]: category for category in server.context.dataset.cats.values() + } + + +def resample_images(server, images, **kwargs): + import random + + ids = [image["id"] for image in server.context.dataset.imgs.values()] + + selected_images = [] + if server.state.num_images: + if server.state.random_sampling: + selected_images = random.sample(ids, min(len(ids), server.state.num_images)) + else: + selected_images = ids[: server.state.num_images] + else: + selected_images = ids + + with server.state: + server.context.dataset_ids = selected_images + server.state.dataset_ids = [str(id) for id in server.context.dataset_ids] + server.state.user_selected_ids = server.state.dataset_ids + if len(server.state.dataset_ids) > 0: + server.state.preview_image_id = server.state.dataset_ids[0] + else: + server.state.preview_image_id = None + + server.context.ground_truth_annotations.get_annotations(server.state.dataset_ids) + + +def main(server=None, *args, **kwargs): + server = get_server(client_type="vue3") + images = Images(server) + + load_dataset(server) + + transforms_app = PreviewApp(server, images) + transforms_app.ui + + resample_images(server, images) + + server.start(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/src/nrtk_explorer/app/features/transforms.py b/src/nrtk_explorer/app/features/transforms.py index 57ce4722..141164bd 100644 --- a/src/nrtk_explorer/app/features/transforms.py +++ b/src/nrtk_explorer/app/features/transforms.py @@ -12,6 +12,7 @@ from nrtk_explorer.app.applet import Applet from nrtk_explorer.app.trame_utils import ProcessingStep from nrtk_explorer.app.features.parameters import ParametersApp +from nrtk_explorer.app.features.preview import PreviewApp from nrtk_explorer.app.images.images import Images from nrtk_explorer.app.ui.image_list import ( @@ -36,6 +37,10 @@ def __init__( server=server, ) + self.context.parameters_app = self._parameters_app + + self._preview_app = PreviewApp(self.server, self.images, **kwargs) + self._ui = None self._transform_classes: Dict[str, type[trans.ImageTransform]] = { @@ -66,6 +71,7 @@ def __init__( ) self.server.controller.apply_transform.add(self.on_apply_transform) + self.server.controller.preview_transform.add(self.on_preview_transform) self.server.controller.on_server_ready.add(self.on_server_ready) def on_server_ready(self, *args, **kwargs): @@ -73,6 +79,7 @@ def on_server_ready(self, *args, **kwargs): def on_apply_transform(self, **kwargs): # Turn on switch if user clicked lower apply button + self.state.show_preview = False self.state.transform_enabled_switch = True transforms = list(map(lambda t: t["instance"], self.context.transforms)) @@ -90,14 +97,45 @@ def on_apply_transform(self, **kwargs): if self.ctrl.run_transform.exists(): self.ctrl.run_transform() + def on_preview_transform(self, **kwargs): + if self.state.preview_image_id is None and len(self.state.dataset_ids) > 0: + self.state.preview_image_id = self.state.dataset_ids[0] + + self.state.show_preview = True + def settings_widget(self): with html.Div(classes="col"): self._parameters_app.transforms_ui() def apply_ui(self): with html.Div(): + self._parameters_app.transform_preview_ui() self._parameters_app.transform_apply_ui() + with quasar.QDialog( + full_width=True, + full_height=True, + transition_duration=0, + v_model=("show_preview", False), + ): + with html.Div(classes="row bg-grey-3 shadow-2 rounded-borders"): + quasar.QBtn( + icon="close", + flat=True, + round=True, + dense=True, + click="show_preview = false;", + style="position: absolute; margin: 0.5rem;", + ) + with html.Div(classes="col-2 q-pa-md"): + html.H5("Transform Preview") + self._parameters_app.transforms_ui() + self._parameters_app.transform_preview_ui() + self._parameters_app.transform_apply_ui() + + with html.Div(classes="col-10 q-pa-md"): + self._preview_app.carousel_ui() + # This is only used within when this module (file) is executed as an Standalone app. @property def ui(self): diff --git a/src/nrtk_explorer/app/images/image_ids.py b/src/nrtk_explorer/app/images/image_ids.py index 2a1efa4a..a71ec3ac 100644 --- a/src/nrtk_explorer/app/images/image_ids.py +++ b/src/nrtk_explorer/app/images/image_ids.py @@ -1,6 +1,6 @@ from nrtk_explorer.app.images.image_meta import dataset_id_to_meta -GROUND_TRUTH_MODEL = "ground-truth" +GROUND_TRUTH_MODEL = "groundtruth" def image_id_to_dataset_id(image_id: str): @@ -15,6 +15,10 @@ def dataset_id_to_transformed_image_id(dataset_id: str): return f"transformed_img_{dataset_id}" +def dataset_id_to_preview_image_id(dataset_id: str): + return f"preview_img_{dataset_id}" + + def image_id_to_result_id(image_id: str, model_name: str): return f"result_{image_id}_{model_name}" @@ -27,6 +31,10 @@ def is_transformed(image_id: str): return image_id.startswith("transformed_img_") +def is_preview(image_id: str): + return image_id.startswith("preview_img_") + + def get_image_state_keys(dataset_id: str, annotation_models: list[str]): image_id = dataset_id_to_image_id(dataset_id) transformed_image_id = dataset_id_to_transformed_image_id(dataset_id) diff --git a/src/nrtk_explorer/app/ui/image_list.py b/src/nrtk_explorer/app/ui/image_list.py index 69809440..a49f52a9 100644 --- a/src/nrtk_explorer/app/ui/image_list.py +++ b/src/nrtk_explorer/app/ui/image_list.py @@ -288,7 +288,7 @@ def __init__(self, **kwargs): }, {}); const transformedAnnotations = Object.entries(inference_models_obj).reduce(function(acc, [model_id, model]){ - if (model.name == 'ground-truth') { + if (model.name == 'groundtruth') { acc[model_id] = get(`result_${original_id}_${model.name}`); } else { acc[model_id] = get(`result_${transformed_id}_${model.name}`); @@ -312,7 +312,7 @@ def __init__(self, **kwargs): original_src: get(original_id).value, transformed: transformed_id, transformed_src: get(transformed_id).value, - groundTruthAnnotations: get(`result_${original_id}_ground-truth`) || [], + groundTruthAnnotations: get(`result_${original_id}_groundtruth`) || [], originalAnnotations: originalAnnotations || {}, transformedAnnotations: transformedAnnotations, originalScores, diff --git a/src/nrtk_explorer/library/embeddings_extractor.py b/src/nrtk_explorer/library/embeddings_extractor.py index 7f4f75fb..a7e6dfcd 100644 --- a/src/nrtk_explorer/library/embeddings_extractor.py +++ b/src/nrtk_explorer/library/embeddings_extractor.py @@ -25,7 +25,13 @@ def __getitem__(self, i): class EmbeddingsExtractor: def __init__(self, model_name="resnet50d", force_cpu=False): - self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu" + if force_cpu: + self.device = "cpu" + elif torch.cuda.is_available(): + self.device = "cuda" + else: + self.device = "cpu" + self.model = model_name self.reset() diff --git a/src/nrtk_explorer/library/predictor.py b/src/nrtk_explorer/library/predictor.py index 032ca2c7..329eb14e 100644 --- a/src/nrtk_explorer/library/predictor.py +++ b/src/nrtk_explorer/library/predictor.py @@ -27,7 +27,14 @@ def __init__( force_cpu: bool = False, ): self.task = task - self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu" + + if force_cpu: + self.device = "cpu" + elif torch.cuda.is_available(): + self.device = "cuda" + else: + self.device = "cpu" + self.pipeline = model_name self.reset() diff --git a/src/nrtk_explorer/widgets/nrtk_explorer.py b/src/nrtk_explorer/widgets/nrtk_explorer.py index ff36edeb..ff0325d1 100644 --- a/src/nrtk_explorer/widgets/nrtk_explorer.py +++ b/src/nrtk_explorer/widgets/nrtk_explorer.py @@ -119,3 +119,22 @@ def __init__(self, **kwargs): "models", ] self._event_names += [] + + +class AnnotationAggregator(HtmlElement): + def __init__(self, **kwargs): + super().__init__( + "annotation-aggregator", + **kwargs, + ) + self._attr_names += [ + ("image_id", "imageId"), + "transformed", + "models", + ] + self._event_names += [] + + slot_props = [ + "aggregateAnnotations", + ] + self._attributes["slot"] = f'v-slot="{{ {", ".join(slot_props)} }}"' diff --git a/vue-components/src/components/AnnotationAggregator.vue b/vue-components/src/components/AnnotationAggregator.vue new file mode 100644 index 00000000..2107f6d0 --- /dev/null +++ b/vue-components/src/components/AnnotationAggregator.vue @@ -0,0 +1,54 @@ + + + diff --git a/vue-components/src/components/index.js b/vue-components/src/components/index.js index a78abfcb..ffb947e4 100644 --- a/vue-components/src/components/index.js +++ b/vue-components/src/components/index.js @@ -5,6 +5,7 @@ import FilterOptionsWidget from './FilterOptionsWidget.vue' import FilterOperatorWidget from './FilterOperatorWidget.vue' import ExportWidget from './ExportWidget.vue' import ScoreTable from './ScoreTable.vue' +import AnnotationAggregator from './AnnotationAggregator.vue' export default { scatterPlot: ScatterPlot, @@ -13,5 +14,6 @@ export default { filterOptionsWidget: FilterOptionsWidget, filterOperatorWidget: FilterOperatorWidget, exportWidget: ExportWidget, - scoreTable: ScoreTable + scoreTable: ScoreTable, + annotationAggregator: AnnotationAggregator }