From c48ad245ff343fc1df9154ee37c7f59842373b41 Mon Sep 17 00:00:00 2001 From: Stephen Date: Thu, 17 Jul 2025 14:49:34 -0400 Subject: [PATCH] feat(embeddings): update embeddings to filter images properly --- src/nrtk_explorer/app/features/embeddings.py | 22 ++++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/nrtk_explorer/app/features/embeddings.py b/src/nrtk_explorer/app/features/embeddings.py index eadb0cee..ae48c469 100644 --- a/src/nrtk_explorer/app/features/embeddings.py +++ b/src/nrtk_explorer/app/features/embeddings.py @@ -51,12 +51,12 @@ def add_images(self, dataset_id_to_image: IdToImage): self.transformed_features.update(id_to_feature) self.emit_update() - @change("dataset_ids") - def on_dataset_ids(self, **kwargs): + @change("user_selected_ids") + def on_user_selected_ids(self, **kwargs): self.transformed_features = { id: features for id, features in self.transformed_features.items() - if image_id_to_dataset_id(id) in self.server.state.dataset_ids + if image_id_to_dataset_id(id) in self.server.state.user_selected_ids } self.emit_update() @@ -87,7 +87,7 @@ def __init__( # Local initialization if standalone self.is_standalone_app = self.server.root_server == self.server if self.is_standalone_app and datasets: - self.state.dataset_ids = [] + self.state.user_selected_ids = [] self.state.current_dataset = datasets[0] self.context.dataset = get_dataset(self.state.current_dataset) @@ -115,7 +115,7 @@ def on_server_ready(self, *args, **kwargs): self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change) self.save_embedding_params() self.update_points() - self.state.change("dataset_ids")(self.update_points) + self.state.change("user_selected_ids")(self.update_points) self.ctrl.apply_transform.add(self.clear_points_transformations) self.ctrl.apply_transform.add(self.transformed_images.clear) self.state.change("transform_enabled_switch")(self.update_points_transformations_state) @@ -127,7 +127,7 @@ def on_feature_extraction_model_change(self, **kwargs): self.transformed_images.set_extractor(self.extractor) def compute_points(self, fit_features, features): - if len(features) == 0: + if len(features) <= 1: # reduce will fail if no features return [] @@ -171,14 +171,14 @@ def update_points_transformations_state(self, **kwargs): def compute_source_points(self): images = ( - self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids + self.images.get_image_without_cache_eviction(id) for id in self.state.user_selected_ids ) self.features = self.extractor.extract(images) points = self.compute_points(self.features, self.features) self.state.points_sources = { - id: point for id, point in zip(self.state.dataset_ids, points) + id: point for id, point in zip(self.state.user_selected_ids, points) } self.state.camera_position = [] @@ -243,15 +243,15 @@ def update_transformed_points(self, id_to_features): self.update_points_transformations_state() def on_scatter_select(self, image_ids): - self.state.user_selected_ids = image_ids or self.state.dataset_ids + self.state.user_selected_ids = image_ids or self.state.user_selected_ids def on_move(self, camera_position): self.state.camera_position = camera_position def get_dataset_id_index(self, point_index): - if point_index < len(self.state.dataset_ids): + if point_index < len(self.state.user_selected_ids): return point_index - return point_index - len(self.state.dataset_ids) + return point_index - len(self.state.user_selected_ids) def on_point_hover(self, event): self.state.highlighted_image = event