diff --git a/src/pick/__init__.py b/src/pick/__init__.py index 148b478..1ff8af6 100644 --- a/src/pick/__init__.py +++ b/src/pick/__init__.py @@ -33,6 +33,7 @@ class Option: KEYS_UP = (curses.KEY_UP, ord("k")) KEYS_DOWN = (curses.KEY_DOWN, ord("j")) KEYS_SELECT = (curses.KEY_RIGHT, ord(" ")) +KEYS_BACKSPACE = (curses.KEY_BACKSPACE, 127, 8) SYMBOL_CIRCLE_FILLED = "(x)" SYMBOL_CIRCLE_EMPTY = "( )" @@ -58,6 +59,8 @@ class Picker(Generic[OPTION_T]): clear_screen: bool = True quit_keys: Optional[Union[Container[int], Iterable[int]]] = None backend: Union[str, Backend] = "curses" + enable_search: bool = False + _search_string: str = field(init=False, default="") def __post_init__(self) -> None: if len(self.options) == 0: @@ -77,36 +80,71 @@ def __post_init__(self) -> None: ) self.index = self.default_index - option = self.options[self.index] + self._ensure_valid_selection() + + def _ensure_valid_selection(self) -> None: + filtered = self.get_filtered_options() + if not filtered: + return + if self.index >= len(filtered): + self.index = 0 + _, option = filtered[self.index] if isinstance(option, Option) and not option.enabled: self.move_down() + def get_filtered_options(self) -> List[Tuple[int, OPTION_T]]: + if not self.enable_search or not self._search_string: + return list(enumerate(self.options)) + + query = self._search_string.lower() + filtered = [] + for i, opt in enumerate(self.options): + label = opt.label if isinstance(opt, Option) else str(opt) + if query in label.lower(): + filtered.append((i, opt)) + return filtered + def move_up(self) -> None: + filtered = self.get_filtered_options() + if not filtered: + self.index = 0 + return while True: self.index -= 1 if self.index < 0: - self.index = len(self.options) - 1 - option = self.options[self.index] + self.index = len(filtered) - 1 + _, option = filtered[self.index] if not isinstance(option, Option) or option.enabled: break def move_down(self) -> None: + filtered = self.get_filtered_options() + if not filtered: + self.index = 0 + return while True: self.index += 1 - if self.index >= len(self.options): + if self.index >= len(filtered): self.index = 0 - option = self.options[self.index] + _, option = filtered[self.index] if not isinstance(option, Option) or option.enabled: break def mark_index(self) -> None: + filtered = self.get_filtered_options() + if not filtered: + self.index = 0 + return + if self.index >= len(filtered): + self.index = 0 + orig_index, _ = filtered[self.index] if self.multiselect: - if self.index in self.selected_indexes: - self.selected_indexes.remove(self.index) + if orig_index in self.selected_indexes: + self.selected_indexes.remove(orig_index) else: - self.selected_indexes.append(self.index) + self.selected_indexes.append(orig_index) - def get_selected(self) -> Union[List[PICK_RETURN_T], PICK_RETURN_T]: + def get_selected(self) -> Union[List[PICK_RETURN_T], PICK_RETURN_T, None]: """return the current selected option as a tuple: (option, index) or as a list of tuples (in case multiselect==True) """ @@ -116,21 +154,36 @@ def get_selected(self) -> Union[List[PICK_RETURN_T], PICK_RETURN_T]: return_tuples.append((self.options[selected], selected)) return return_tuples else: - return self.options[self.index], self.index + filtered = self.get_filtered_options() + if not filtered: + return None + if self.index >= len(filtered): + self.index = 0 + orig_index, option = filtered[self.index] + return option, orig_index def get_title_lines(self, *, max_width: int = 80) -> List[str]: - if not self.title: - return [] + lines = [] + if self.title: + if "\n" in self.title: + lines.extend(self.title.split("\n")) + else: + lines.extend(textwrap.fill(self.title, max_width - 2, drop_whitespace=False).split("\n")) + lines.append("") + + if self.enable_search: + lines.append(f"Search: {self._search_string}") + lines.append("") - if "\n" in self.title: - lines = self.title.split("\n") - else: - lines = textwrap.fill(self.title, max_width - 2, drop_whitespace=False).split("\n") - return lines + [""] + return lines def get_option_lines(self) -> List[str]: lines: List[str] = [] - for index, option in enumerate(self.options): + filtered = self.get_filtered_options() + if not filtered: + return ["No results"] + + for index, (orig_index, option) in enumerate(filtered): if index == self.index: prefix = self.indicator else: @@ -139,7 +192,7 @@ def get_option_lines(self) -> List[str]: if self.multiselect: symbol = ( SYMBOL_CIRCLE_FILLED - if index in self.selected_indexes + if orig_index in self.selected_indexes else SYMBOL_CIRCLE_EMPTY ) prefix = f"{prefix} {symbol}" @@ -184,24 +237,26 @@ def draw(self, screen: Backend) -> None: title_length = len(self.get_title_lines(max_width=max_x)) for i, line in enumerate(lines_to_draw): - if description_present and i > title_length: + if description_present and i >= title_length: screen.addnstr(y, x, line, max_x // 2 - 2) else: screen.addnstr(y, x, line, max_x - 2) y += 1 - option = self.options[self.index] - if isinstance(option, Option) and option.description is not None: - description_lines = textwrap.fill(option.description, max_x // 2 - 2).split('\n') + filtered = self.get_filtered_options() + if filtered and self.index < len(filtered): + _, option = filtered[self.index] + if isinstance(option, Option) and option.description is not None: + description_lines = textwrap.fill(option.description, max_x // 2 - 2).split('\n') - for i, line in enumerate(description_lines): - screen.addnstr(i + title_length, max_x // 2, line, max_x - 2) + for i, line in enumerate(description_lines): + screen.addnstr(i + title_length, max_x // 2, line, max_x - 2) screen.refresh() def run_loop( self, screen: Backend, position: Position - ) -> Union[List[PICK_RETURN_T], PICK_RETURN_T]: + ) -> Union[List[PICK_RETURN_T], PICK_RETURN_T, None]: while True: self.draw(screen) c = screen.getch() @@ -223,6 +278,15 @@ def run_loop( return self.get_selected() elif c in KEYS_SELECT and self.multiselect: self.mark_index() + elif self.enable_search: + if c in KEYS_BACKSPACE: + self._search_string = self._search_string[:-1] + self.index = 0 + self._ensure_valid_selection() + elif 0 <= c <= 255 and chr(c).isprintable(): + self._search_string += chr(c) + self.index = 0 + self._ensure_valid_selection() def _resolve_backend(self) -> Backend: if isinstance(self.backend, Backend): @@ -238,12 +302,9 @@ def _resolve_backend(self) -> Backend: def config_curses(self) -> None: try: - # use the default colors of the terminal curses.use_default_colors() - # hide the cursor curses.curs_set(0) except Exception: - # Curses failed to initialize color support, eg. when TERM=vt100 curses.initscr() def _start(self, screen: "curses._CursesWindow"): @@ -253,21 +314,18 @@ def _start(self, screen: "curses._CursesWindow"): def start(self): backend = self._resolve_backend() if isinstance(backend, CursesBackend) and backend._screen is not None: - # Embedded in an existing curses application (backward-compatible) last_cur = curses.curs_set(0) ret = self.run_loop(backend, self.position) if last_cur: curses.curs_set(last_cur) return ret elif isinstance(backend, CursesBackend): - # Standalone curses mode def _curses_main(screen: "curses._CursesWindow"): backend._screen = screen backend.setup() return self.run_loop(backend, self.position) return curses.wrapper(_curses_main) else: - # Other backends (e.g. blessed) backend.setup() try: return self.run_loop(backend, self.position) @@ -287,6 +345,7 @@ def pick( clear_screen: bool = True, quit_keys: Optional[Union[Container[int], Iterable[int]]] = None, backend: Union[str, Backend] = "curses", + enable_search: bool = False, ): picker: Picker = Picker( options, @@ -300,5 +359,6 @@ def pick( clear_screen, quit_keys, backend, + enable_search, ) return picker.start()