diff --git a/requirements.txt b/requirements.txt index 87eac85..5a97c1c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,6 @@ pillow opencv-python PyYAML scipy -segment_anything @ git+https://github.com/facebookresearch/segment-anything.git@dca509fe793f601edb92606367a655c15ac00fdf +sam2 @ git+https://github.com/facebookresearch/sam2.git torch torchvision diff --git a/src/multilabeller/SAM/sam.py b/src/multilabeller/SAM/sam.py index f446d9d..d41d214 100644 --- a/src/multilabeller/SAM/sam.py +++ b/src/multilabeller/SAM/sam.py @@ -1,6 +1,7 @@ import cv2 import numpy as np -from segment_anything import SamAutomaticMaskGenerator, sam_model_registry +from sam2.build_sam import build_sam2 +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator from src.multilabeller.SAM_contour import SAM_Contour @@ -13,12 +14,11 @@ def __init__(self, config): self.contours = [] def initialize(self, config): - model_type = config["model"]["name"] + model_config = config["model"]["config"] sam_checkpoint = config["model"]["file"] device = config["device"] - sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) - sam.to(device=device) - self.mask_generator = SamAutomaticMaskGenerator(sam) + sam2 = build_sam2(model_config, sam_checkpoint, device=device) + self.mask_generator = SAM2AutomaticMaskGenerator(sam2) def apply(self, image_input): masks = self.mask_generator.generate(image_input) diff --git a/src/multilabeller/config.yml b/src/multilabeller/config.yml index 4e66d8a..dc9c961 100644 --- a/src/multilabeller/config.yml +++ b/src/multilabeller/config.yml @@ -30,8 +30,8 @@ left_mouse_click: linux: "" SAM: - device: "cpu" # "cpu" or "cuda" - model: "vit_b" # vit_l; vit_b or vit_h + device: "cuda" # "cpu" or "cuda" + model: "sam2.1_hiera_small" # sam2.1_hiera_tiny; sam2.1_hiera_small; sam2.1_hiera_base_plus; sam2.1_hiera_large shortcuts: circle_mode: "b" @@ -48,5 +48,5 @@ shortcuts: #TODO: It should work with any image size, but its not working. We'll check this out in the future #TODO: It works OK with a square dimension. image_viewer: - width: 400 - height: 400 + width: 1000 + height: 1000 diff --git a/src/multilabeller/image_viewer_app/image_viewer_app.py b/src/multilabeller/image_viewer_app/image_viewer_app.py index c08beb3..1a73ada 100644 --- a/src/multilabeller/image_viewer_app/image_viewer_app.py +++ b/src/multilabeller/image_viewer_app/image_viewer_app.py @@ -84,24 +84,30 @@ def initialize_SAM(self): cuda_available ), "PyTorch is having trouble with CUDA. Please change device to 'cpu'" - if self.config["SAM"]["model"] == "vit_b": - SAM_model_filename = "sam_vit_b_01ec64.pth" - elif self.config["SAM"]["model"] == "vit_l": - SAM_model_filename = "sam_vit_l_0b3195.pth" - elif self.config["SAM"]["model"] == "vit_h": - SAM_model_filename = "sam_vit_h_4b8939.pth" + model_name = self.config["SAM"]["model"] + SAM2_MODELS = { + "sam2.1_hiera_tiny": ("sam2.1_hiera_tiny.pt", "configs/sam2.1/sam2.1_hiera_t.yaml"), + "sam2.1_hiera_small": ("sam2.1_hiera_small.pt", "configs/sam2.1/sam2.1_hiera_s.yaml"), + "sam2.1_hiera_base_plus": ("sam2.1_hiera_base_plus.pt", "configs/sam2.1/sam2.1_hiera_b+.yaml"), + "sam2.1_hiera_large": ("sam2.1_hiera_large.pt", "configs/sam2.1/sam2.1_hiera_l.yaml"), + } + assert model_name in SAM2_MODELS, ( + f"Unknown SAM2 model '{model_name}'. " + f"Choose from: {list(SAM2_MODELS.keys())}" + ) + SAM_model_filename, SAM_model_config = SAM2_MODELS[model_name] SAM_model_file = Path("SAM", SAM_model_filename) message = ( - f"{os.path.basename(SAM_model_file)} not found at src/multilabeller/SAM folder." + f"{SAM_model_filename} not found at src/multilabeller/SAM folder." f"\nPlease download the model file at " - f"https://github.com/facebookresearch/segment-anything?tab=readme-ov-file#model-checkpoints" + f"https://github.com/facebookresearch/sam2?tab=readme-ov-file#model-checkpoints" ) assert os.path.isfile(SAM_model_file), message SAM_config = { "device": self.config["SAM"]["device"], - "model": {"name": self.config["SAM"]["model"], "file": SAM_model_file}, + "model": {"config": SAM_model_config, "file": SAM_model_file}, } self.SAM = SegmentAnything(SAM_config) diff --git a/src/multilabeller/test/h5read_test.py b/src/multilabeller/test/h5read_test.py index 43aa1c9..01688f9 100644 --- a/src/multilabeller/test/h5read_test.py +++ b/src/multilabeller/test/h5read_test.py @@ -3,7 +3,7 @@ from pathlib import Path # Path to .h5 file - Modify it accordingly -h5_file = Path(r"output/aiaiaia.h5") +h5_file = Path(r"..\test\output\train.h5") # Reading the .h5 file h5_dataset = h5py.File(h5_file, "r")