diff --git a/detect/__main__.py b/detect/__main__.py deleted file mode 100644 index 1e3fdbb..0000000 --- a/detect/__main__.py +++ /dev/null @@ -1,203 +0,0 @@ -import collections -import math -import operator -import os -import sys -import time - -from http import HTTPStatus - -import requests - -from .detector import Detector - - -QUEUE_NAME = os.environ.get("QUEUE_NAME", "detection") -WAIT_TIMEOUT = os.environ.get("WAIT_TIMEOUT", 30) -VIZAR_SERVER = os.environ.get("VIZAR_SERVER", "localhost:5000") -MIN_RETRY_INTERVAL = 5 - -MODEL_REPO = "yolov8" -MODEL_NAME = "yolov8m-seg-nms" - -MARK_ALL_OBJECTS = True -MARK_LABELS = set(["door", "dining table", "desk", "table"]) - -# Rename some of the labels from the detector before marking them as map features. -LABELS_TO_FEATURE_NAMES = { -# "dining table": "table", -# "desk": "table" -} - - -def try_create_features(location_id, item, info): - # Get current list of features for the location - features_url = "http://{}/locations/{}/features".format(VIZAR_SERVER, location_id) - response = requests.get(features_url) - if not response.ok: - return - - features = response.json() - features_by_label = collections.defaultdict(list) - - # Organize the existing features by name. - # We are only interested in objects of the configured types. - for feature in features: - if feature.get("type") == "object" and feature.get("name") in MARK_LABELS: - features_by_label[feature['name']].append(feature) - - for obj in info.get("annotations", []): - label = obj['label'] - if not MARK_ALL_OBJECTS and label not in MARK_LABELS: - continue - if label in LABELS_TO_FEATURE_NAMES: - label = LABELS_TO_FEATURE_NAMES[label] - - pos = obj.get("position") - if pos is None: - continue - - pos_error = obj.get("position_error", 10.0) - - # Check if there is any already existing feature within a certain - # radius. We use the combined position_error values for the two - # objects, which give a rough estimate of how wide they are. If they - # are too close, avoid creating another feature. - duplicate = False - for other in features_by_label[label]: - sq_dist = sum( (pos[d] - other['position'][d])**2 for d in ["x", "y", "z"] ) - dist = math.sqrt(sq_dist) - - # Other point's radius may not be set, which means we do not - # know the position_error value for that other feature. - # Just use the new object's position_error twice, then. - other_radius = other.get("radius") - if other_radius is None: - other_radius = pos_error - - threshold = pos_error + other_radius - - if dist < threshold: - duplicate = True - break - - if duplicate: - continue - - # Create a new feature on the map. We are abusing the radius field - # here to store the position error / spread. The radius attribute was - # meant to control when the feature should be displayed in AR, only - # when the user is within a certain radius of the feature position. - # However, it is not really used. - new_feature = { - "name": label, - "position": pos, - "style": { - "placement": "point", - "radius": pos_error - }, - "type": "object" - } - response = requests.post(features_url, json=new_feature) - if response.ok: - new_feature = response.json() - features_by_label[label].append(new_feature) - - -def get_queue_names(): - url = "http://{}/photos/queues".format(VIZAR_SERVER) - response = requests.get(url) - if response.ok and response.status_code == HTTPStatus.OK: - items = response.json() - return set(x['name'] for x in items) - else: - return set([QUEUE_NAME, "done"]) - - -def get_next_queue(detection_result, supported_queue_names): - annotations = detection_result.info.get('annotations', []) - has_person = any(x['label'] == "person" for x in annotations) - - if has_person and "identification" in supported_queue_names: - return "identification" - elif len(annotations) > 0 and "detection-3d" in supported_queue_names: - return "detection-3d" - else: - return "done" - - -def main(): - detector = Detector(MODEL_REPO, MODEL_NAME) - detector.initialize_model() - - while True: - sys.stdout.flush() - - # Set of photo queues supported by the server - supported_queue_names = get_queue_names() - - query_url = "http://{}/photos?queue_name={}&wait={}".format(VIZAR_SERVER, QUEUE_NAME, WAIT_TIMEOUT) - start_time = time.time() - - items = [] - - try: - response = requests.get(query_url) - if response.ok and response.status_code == HTTPStatus.OK: - items = response.json() - except requests.exceptions.RequestException as error: - # Most common case is if the API server is restarting, - # then we see a connection error temporarily. - print(error) - - # Check if the empty/error response from the server was sooner than - # expected. If so, add an extra delay to avoid spamming the server. - # We need this in case long-polling is not working as expected. - if len(items) == 0: - elapsed = time.time() - start_time - if elapsed < MIN_RETRY_INTERVAL: - time.sleep(MIN_RETRY_INTERVAL - elapsed) - continue - - for item in items: - # Sort by priority level (descending), then creation time (ascending) - item['priority_tuple'] = (-1 * item.get("priority", 0), item.get("created")) - - items.sort(key=operator.itemgetter("priority_tuple")) - for item in items: - try: - result = detector.run(item) - except Exception as error: - print(error) - result = None - - url = "http://{}/photos/{}".format(VIZAR_SERVER, item['id']) - if result is not None: - # Determine the next queue for this photo, then send update - result.info['status'] = get_next_queue(result, supported_queue_names) - requests.patch(url, json=result.info) - - annotated_png, mask_png = result.apply_masks() - headers = { - "Content-Type": "image/png" - } - annotated_url = "{}/annotated.png".format(url) - req = requests.put(annotated_url, data=annotated_png, headers=headers) - mask_url = "{}/mask.png".format(url) - req = requests.put(mask_url, data=mask_png, headers=headers) - - geom_url = "{}/geometry.png".format(url) - if result.try_localize_objects(geom_url): - requests.patch(url, json=result.info) - - camera_location_id = item.get("camera_location_id") - if camera_location_id is not None: - try_create_features(camera_location_id, item, result.info) - - else: - info = {"status": "error"} - requests.patch(url, json=info) - - -if __name__ == "__main__": - main() diff --git a/detect/__init__.py b/ocr/__init__.py similarity index 100% rename from detect/__init__.py rename to ocr/__init__.py diff --git a/ocr/__main__.py b/ocr/__main__.py new file mode 100644 index 0000000..1d9d722 --- /dev/null +++ b/ocr/__main__.py @@ -0,0 +1,166 @@ +import operator +import os +import sys +import time +from http import HTTPStatus + +import requests + +from .ocr_engine import OCREngine + + +QUEUE_NAME = os.environ.get("QUEUE_NAME", "ocr") +WAIT_TIMEOUT = int(os.environ.get("WAIT_TIMEOUT", 30)) +VIZAR_SERVER = os.environ.get("VIZAR_SERVER", "easyvizar.wings.cs.wisc.edu:5001") +MIN_RETRY_INTERVAL = 5 + +API_TOKEN = os.environ.get("VIZAR_API_TOKEN", "") +API_KEY = os.environ.get("VIZAR_API_KEY", "") + + +def build_headers(extra=None): + headers = {} + if API_TOKEN: + headers["Authorization"] = f"Bearer {API_TOKEN}" + if API_KEY: + headers["X-API-Key"] = API_KEY + if extra: + headers.update(extra) + + return headers + + +def get_queue_names(): + url = f"http://{VIZAR_SERVER}/photos/queues" + try: + response = requests.get(url, headers=build_headers()) + if response.ok and response.status_code == HTTPStatus.OK: + items = response.json() + queues = set(x["name"] for x in items) + print("Available queues:", queues) + return queues + else: + print("Queue request failed with status:", response.status_code) + print("Queue response body:", response.text[:300]) + except requests.exceptions.RequestException as error: + print("Queue fetch error:", error) + + return {QUEUE_NAME, "done"} + + +def get_next_queue(result, supported_queue_names): + return "done" + + +def main(): + engine = OCREngine() + engine.initialize_model() + print("OCR worker started. Waiting for queue items...") + print("Using server:", VIZAR_SERVER) + print("Listening on queue:", QUEUE_NAME) + + while True: + sys.stdout.flush() + + supported_queue_names = get_queue_names() + + query_url = f"http://{VIZAR_SERVER}/photos?queue_name={QUEUE_NAME}&wait={WAIT_TIMEOUT}" + start_time = time.time() + items = [] + + try: + response = requests.get(query_url, headers=build_headers()) + + if response.status_code == HTTPStatus.NO_CONTENT: + print("No items currently in queue", QUEUE_NAME) + items = [] + + elif response.ok and response.status_code == HTTPStatus.OK: + items = response.json() + print("Polled", len(items), "items from queue", QUEUE_NAME) + + else: + print("Poll request failed with status:", response.status_code) + print("Poll response body:", response.text[:300]) + + except requests.exceptions.RequestException as error: + print("Polling error:", error) + + if len(items) == 0: + elapsed = time.time() - start_time + if elapsed < MIN_RETRY_INTERVAL: + time.sleep(MIN_RETRY_INTERVAL - elapsed) + continue + + for item in items: + item["priority_tuple"] = (-1 * item.get("priority", 0), item.get("created")) + + items.sort(key=operator.itemgetter("priority_tuple")) + + for item in items: + print("Got item:", item.get("id")) + + try: + result = engine.run(item) + except Exception as error: + print("Run error:", error) + result = None + + url = f"http://{VIZAR_SERVER}/photos/{item['id']}" + + if result is not None: + result.info["status"] = get_next_queue(result, supported_queue_names) + + try: + print("PATCH payload:", result.info) + + patch_response = requests.patch(url, json=result.info, headers=build_headers()) + print("Patched photo info with status:", patch_response.status_code) + + if patch_response.status_code >= 400: + print("PATCH response body:", patch_response.text) + print("PATCH failed, stopping worker so item does not loop forever.") + return + + except requests.exceptions.RequestException as error: + print("Patch error:", error) + return + + try: + annotated_png, mask_png = result.apply_masks() + content_headers = build_headers({"Content-Type": "image/png"}) + + annotated_url = f"{url}/annotated.png" + annotated_response = requests.put( + annotated_url, + data=annotated_png, + headers=content_headers, + ) + print("Uploaded annotated image with status:", annotated_response.status_code) + + mask_url = f"{url}/mask.png" + mask_response = requests.put( + mask_url, + data=mask_png, + headers=content_headers, + ) + print("Uploaded mask image with status:", mask_response.status_code) + + except requests.exceptions.RequestException as error: + print("Upload error:", error) + except Exception as error: + print("Mask/apply error:", error) + else: + try: + error_response = requests.patch( + url, + json={"status": "error"}, + headers=build_headers(), + ) + print("Marked item as error with status:", error_response.status_code) + except requests.exceptions.RequestException as error: + print("Error status patch failed:", error) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ocr/__pycache__/__init__.cpython-310.pyc b/ocr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..f33522a Binary files /dev/null and b/ocr/__pycache__/__init__.cpython-310.pyc differ diff --git a/ocr/__pycache__/__init__.cpython-312.pyc b/ocr/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..65c53ff Binary files /dev/null and b/ocr/__pycache__/__init__.cpython-312.pyc differ diff --git a/ocr/__pycache__/__init__.cpython-313.pyc b/ocr/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..a8c3fb5 Binary files /dev/null and b/ocr/__pycache__/__init__.cpython-313.pyc differ diff --git a/ocr/__pycache__/__main__.cpython-310.pyc b/ocr/__pycache__/__main__.cpython-310.pyc new file mode 100644 index 0000000..6cd4a1b Binary files /dev/null and b/ocr/__pycache__/__main__.cpython-310.pyc differ diff --git a/ocr/__pycache__/__main__.cpython-312.pyc b/ocr/__pycache__/__main__.cpython-312.pyc new file mode 100644 index 0000000..bee3d6c Binary files /dev/null and b/ocr/__pycache__/__main__.cpython-312.pyc differ diff --git a/ocr/__pycache__/__main__.cpython-313.pyc b/ocr/__pycache__/__main__.cpython-313.pyc new file mode 100644 index 0000000..54b652b Binary files /dev/null and b/ocr/__pycache__/__main__.cpython-313.pyc differ diff --git a/ocr/__pycache__/ocr_engine.cpython-310.pyc b/ocr/__pycache__/ocr_engine.cpython-310.pyc new file mode 100644 index 0000000..8ce36b3 Binary files /dev/null and b/ocr/__pycache__/ocr_engine.cpython-310.pyc differ diff --git a/ocr/__pycache__/ocr_engine.cpython-313.pyc b/ocr/__pycache__/ocr_engine.cpython-313.pyc new file mode 100644 index 0000000..7d3ee35 Binary files /dev/null and b/ocr/__pycache__/ocr_engine.cpython-313.pyc differ diff --git a/detect/detector.py b/ocr/detector.py similarity index 100% rename from detect/detector.py rename to ocr/detector.py diff --git a/ocr/ocr_engine.py b/ocr/ocr_engine.py new file mode 100644 index 0000000..80fe798 --- /dev/null +++ b/ocr/ocr_engine.py @@ -0,0 +1,261 @@ +import io +import os +import re +import time + +import easyocr +import imageio.v3 as iio +import numpy as np +from PIL import Image, ImageDraw, ImageOps + + +DATA_PATH = os.environ.get("DATA_PATH", "./") +VIZAR_SERVER = os.environ.get("VIZAR_SERVER", "easyvizar.wings.cs.wisc.edu:5001") +OCR_LANGS = [x.strip() for x in os.environ.get("OCR_LANGS", "en").split(",") if x.strip()] +OCR_GPU = os.environ.get("OCR_GPU", "false").lower() == "true" +MIN_CONFIDENCE = float(os.environ.get("MIN_CONFIDENCE", "0.20")) + + +def encode_png(data): + buffer = io.BytesIO() + iio.imwrite(buffer, data, extension=".png") + return buffer.getvalue() + + +def normalize_room_text(text): + return re.sub(r"[^A-Za-z0-9]", "", text.upper()) + + +def matches_room_number(text): + patterns = [ + r"^[A-Z]?\d{2,4}[A-Z]?$", # 221, B145, 314A + r"^\d[A-Z]\d{2,3}$", # 2W14 + r"^[A-Z]{1,4}\d{2,4}[A-Z]?$", # RM221, CS314A + r"^\d{2,4}$", # 221, 901 + ] + return any(re.match(pattern, text) for pattern in patterns) + + +def looks_number_like(text): + digit_count = sum(ch.isdigit() for ch in text) + return digit_count >= 1 and len(text) <= 12 + + +class OCRResult: + def __init__(self, info, image, display_annotations=None): + self.info = info + self.image = image + self.display_annotations = display_annotations or [] + + def apply_masks(self): + if self.image.ndim == 2: + rgb = np.stack([self.image] * 3, axis=-1) + else: + rgb = self.image[:, :, :3] + + pil = Image.fromarray(rgb).convert("RGB") + draw = ImageDraw.Draw(pil) + + h, w = rgb.shape[:2] + + for ann in self.display_annotations: + boundary = ann["boundary"] + + left = int(boundary["left"] * w) + top = int(boundary["top"] * h) + right = int((boundary["left"] + boundary["width"]) * w) + bottom = int((boundary["top"] + boundary["height"]) * h) + + label = ann.get("label", "text") + sublabel = ann.get("sublabel", "") + + if label == "room-number": + color = "red" + text_y = max(0, top - 15) + else: + color = "yellow" + text_y = min(h - 15, bottom + 2) + + display_text = sublabel if sublabel else label + + draw.rectangle([left, top, right, bottom], outline=color, width=3) + draw.text((left, text_y), display_text, fill=color) + + annotated = np.array(pil) + annotated_png = encode_png(annotated) + + mask = np.zeros((h, w, 4), dtype=np.uint8) + mask_png = encode_png(mask) + + return annotated_png, mask_png + + +class OCREngine: + def __init__(self): + self.reader = None + + def initialize_model(self): + if self.reader is None: + self.reader = easyocr.Reader(OCR_LANGS, gpu=OCR_GPU) + + def choose_source(self, item): + path = item.get("imagePath") + url = item.get("imageUrl") + + if path not in [None, ""]: + full_path = os.path.join(DATA_PATH, path) + if os.path.isfile(full_path): + return full_path + + if isinstance(url, str) and url.startswith("http"): + return url + + if isinstance(url, str) and url.startswith("/"): + return "http://" + VIZAR_SERVER + url + + raise Exception(f"Cannot load image path ({path}) or URL ({url})") + + def preprocess(self, image): + if image.ndim == 2: + return image + + rgb = image[:, :, :3] + pil = Image.fromarray(rgb).convert("L") + pil = ImageOps.autocontrast(pil) + return np.array(pil) + + def _bbox_to_boundary(self, bbox, w, h): + xs = [point[0] for point in bbox] + ys = [point[1] for point in bbox] + + min_x = max(0.0, min(xs)) + max_x = min(float(w), max(xs)) + min_y = max(0.0, min(ys)) + max_y = min(float(h), max(ys)) + + if max_x <= min_x or max_y <= min_y: + return None + + return { + "left": float(min_x / w), + "top": float(min_y / h), + "width": float((max_x - min_x) / w), + "height": float((max_y - min_y) / h), + } + + def run(self, item): + self.initialize_model() + + source = self.choose_source(item) + print(f"Processing image from {source}...") + + image = iio.imread(source) + h, w = image.shape[:2] + + preprocess_start = time.time() + processed = self.preprocess(image) + + inference_start = time.time() + raw_results = self.reader.readtext(processed) + postprocess_start = time.time() + + print("Raw OCR results:") + for entry in raw_results: + print(entry) + + confirmed_annotations = [] + fallback_annotations = [] + + for entry in raw_results: + bbox, text, confidence = entry + cleaned_text = normalize_room_text(text) + + print("OCR text:", text, "-> cleaned:", cleaned_text, "confidence:", confidence) + + boundary = self._bbox_to_boundary(bbox, w, h) + if boundary is None: + print("Rejected because bounding box is invalid") + continue + + if confidence < MIN_CONFIDENCE: + print("Rejected because confidence too low") + continue + + if not cleaned_text: + print("Rejected because cleaned text is empty") + continue + + if matches_room_number(cleaned_text): + annotation = { + "boundary": boundary, + "confidence": float(confidence), + "label": "room-number", + "sublabel": cleaned_text, + } + print("Accepted room-number annotation:", annotation) + confirmed_annotations.append(annotation) + continue + + if looks_number_like(cleaned_text): + fallback = { + "boundary": boundary, + "confidence": float(confidence), + "label": "possible-room-text", + "sublabel": cleaned_text, + } + print("Stored fallback annotation:", fallback) + fallback_annotations.append(fallback) + else: + print("Rejected because text does not match room-number pattern") + + postprocess_end = time.time() + + if len(raw_results) == 0: + ocr_status = "no-text-detected" + ocr_message = "Could not extract any text from image." + elif len(confirmed_annotations) > 0: + ocr_status = "room-number-detected" + ocr_message = "Room number text extracted successfully." + elif len(fallback_annotations) > 0: + ocr_status = "possible-number-detected-no-room-match" + ocr_message = "Detected number-like text regions, but could not confidently identify a room number." + else: + ocr_status = "text-detected-no-room-match" + ocr_message = "Text was detected, but no room number could be confidently identified." + + print("OCR found", len(confirmed_annotations), "room-number annotations") + print("OCR summary:", ocr_status, "-", ocr_message) + + if len(confirmed_annotations) > 0: + annotations_for_server = confirmed_annotations + elif len(fallback_annotations) > 0: + annotations_for_server = fallback_annotations + else: + annotations_for_server = [] + + info = { + "status": "done", + "annotations": annotations_for_server, + "detector": { + "model_repo": "ocr", + "model_name": "easyocr", + "engine_name": "easyocr", + "preprocess_duration": inference_start - preprocess_start, + "inference_duration": postprocess_start - inference_start, + "postprocess_duration": postprocess_end - postprocess_start, + }, + "ocr_summary": { + "status": ocr_status, + "message": ocr_message, + "raw_text_count": len(raw_results), + "matched_room_count": len(confirmed_annotations), + "fallback_region_count": len(fallback_annotations), + "raw_text": [entry[1] for entry in raw_results], + }, + } + + return OCRResult( + info=info, + image=image, + display_annotations=annotations_for_server, + ) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index abe4d6b..cf41bea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,8 @@ -certifi==2023.5.7 -charset-normalizer -coloredlogs==15.0.1 -flatbuffers==23.5.26 -humanfriendly==10.0 -idna==3.4 +easyocr imageio -mpmath==1.3.0 -networkx>=2.5.1 -numpy>=1.19.5 -packaging>=21.3 -Pillow>=8.4.0 -protobuf>=4.21.0 -PyWavelets>=1.1.1 -requests>=2.27.1 -scikit-image>=0.17.2 -scipy>=1.5.4 -sympy>=1.9 -tifffile>=2020.9.3 -urllib3>=1.26.19 +numpy +Pillow +requests +opencv-python-headless +torch +torchvision \ No newline at end of file diff --git a/tests/test_detector.py b/tests/test_detector.py index 27f79cc..8d54fdf 100644 --- a/tests/test_detector.py +++ b/tests/test_detector.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock -from detect.detector import Detector +from ocr.detector import Detector def test_detector_choose_source():