Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 92 additions & 32 deletions src/pick/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Option:
KEYS_UP = (curses.KEY_UP, ord("k"))
KEYS_DOWN = (curses.KEY_DOWN, ord("j"))
KEYS_SELECT = (curses.KEY_RIGHT, ord(" "))
KEYS_BACKSPACE = (curses.KEY_BACKSPACE, 127, 8)

SYMBOL_CIRCLE_FILLED = "(x)"
SYMBOL_CIRCLE_EMPTY = "( )"
Expand All @@ -58,6 +59,8 @@ class Picker(Generic[OPTION_T]):
clear_screen: bool = True
quit_keys: Optional[Union[Container[int], Iterable[int]]] = None
backend: Union[str, Backend] = "curses"
enable_search: bool = False
_search_string: str = field(init=False, default="")

def __post_init__(self) -> None:
if len(self.options) == 0:
Expand All @@ -77,36 +80,71 @@ def __post_init__(self) -> None:
)

self.index = self.default_index
option = self.options[self.index]
self._ensure_valid_selection()

def _ensure_valid_selection(self) -> None:
filtered = self.get_filtered_options()
if not filtered:
return
if self.index >= len(filtered):
self.index = 0
_, option = filtered[self.index]
if isinstance(option, Option) and not option.enabled:
self.move_down()

def get_filtered_options(self) -> List[Tuple[int, OPTION_T]]:
if not self.enable_search or not self._search_string:
return list(enumerate(self.options))

query = self._search_string.lower()
filtered = []
for i, opt in enumerate(self.options):
label = opt.label if isinstance(opt, Option) else str(opt)
if query in label.lower():
filtered.append((i, opt))
return filtered

def move_up(self) -> None:
filtered = self.get_filtered_options()
if not filtered:
self.index = 0
return
while True:
self.index -= 1
if self.index < 0:
self.index = len(self.options) - 1
option = self.options[self.index]
self.index = len(filtered) - 1
_, option = filtered[self.index]
if not isinstance(option, Option) or option.enabled:
break

def move_down(self) -> None:
filtered = self.get_filtered_options()
if not filtered:
self.index = 0
return
while True:
self.index += 1
if self.index >= len(self.options):
if self.index >= len(filtered):
self.index = 0
option = self.options[self.index]
_, option = filtered[self.index]
if not isinstance(option, Option) or option.enabled:
break

def mark_index(self) -> None:
filtered = self.get_filtered_options()
if not filtered:
self.index = 0
return
if self.index >= len(filtered):
self.index = 0
orig_index, _ = filtered[self.index]
if self.multiselect:
if self.index in self.selected_indexes:
self.selected_indexes.remove(self.index)
if orig_index in self.selected_indexes:
self.selected_indexes.remove(orig_index)
else:
self.selected_indexes.append(self.index)
self.selected_indexes.append(orig_index)

def get_selected(self) -> Union[List[PICK_RETURN_T], PICK_RETURN_T]:
def get_selected(self) -> Union[List[PICK_RETURN_T], PICK_RETURN_T, None]:
"""return the current selected option as a tuple: (option, index)
or as a list of tuples (in case multiselect==True)
"""
Expand All @@ -116,21 +154,36 @@ def get_selected(self) -> Union[List[PICK_RETURN_T], PICK_RETURN_T]:
return_tuples.append((self.options[selected], selected))
return return_tuples
else:
return self.options[self.index], self.index
filtered = self.get_filtered_options()
if not filtered:
return None
if self.index >= len(filtered):
self.index = 0
orig_index, option = filtered[self.index]
return option, orig_index

def get_title_lines(self, *, max_width: int = 80) -> List[str]:
if not self.title:
return []
lines = []
if self.title:
if "\n" in self.title:
lines.extend(self.title.split("\n"))
else:
lines.extend(textwrap.fill(self.title, max_width - 2, drop_whitespace=False).split("\n"))
lines.append("")

if self.enable_search:
lines.append(f"Search: {self._search_string}")
lines.append("")

if "\n" in self.title:
lines = self.title.split("\n")
else:
lines = textwrap.fill(self.title, max_width - 2, drop_whitespace=False).split("\n")
return lines + [""]
return lines

def get_option_lines(self) -> List[str]:
lines: List[str] = []
for index, option in enumerate(self.options):
filtered = self.get_filtered_options()
if not filtered:
return ["No results"]

for index, (orig_index, option) in enumerate(filtered):
if index == self.index:
prefix = self.indicator
else:
Expand All @@ -139,7 +192,7 @@ def get_option_lines(self) -> List[str]:
if self.multiselect:
symbol = (
SYMBOL_CIRCLE_FILLED
if index in self.selected_indexes
if orig_index in self.selected_indexes
else SYMBOL_CIRCLE_EMPTY
)
prefix = f"{prefix} {symbol}"
Expand Down Expand Up @@ -184,24 +237,26 @@ def draw(self, screen: Backend) -> None:
title_length = len(self.get_title_lines(max_width=max_x))

for i, line in enumerate(lines_to_draw):
if description_present and i > title_length:
if description_present and i >= title_length:
screen.addnstr(y, x, line, max_x // 2 - 2)
else:
screen.addnstr(y, x, line, max_x - 2)
y += 1

option = self.options[self.index]
if isinstance(option, Option) and option.description is not None:
description_lines = textwrap.fill(option.description, max_x // 2 - 2).split('\n')
filtered = self.get_filtered_options()
if filtered and self.index < len(filtered):
_, option = filtered[self.index]
if isinstance(option, Option) and option.description is not None:
description_lines = textwrap.fill(option.description, max_x // 2 - 2).split('\n')

for i, line in enumerate(description_lines):
screen.addnstr(i + title_length, max_x // 2, line, max_x - 2)
for i, line in enumerate(description_lines):
screen.addnstr(i + title_length, max_x // 2, line, max_x - 2)

screen.refresh()

def run_loop(
self, screen: Backend, position: Position
) -> Union[List[PICK_RETURN_T], PICK_RETURN_T]:
) -> Union[List[PICK_RETURN_T], PICK_RETURN_T, None]:
while True:
self.draw(screen)
c = screen.getch()
Expand All @@ -223,6 +278,15 @@ def run_loop(
return self.get_selected()
elif c in KEYS_SELECT and self.multiselect:
self.mark_index()
elif self.enable_search:
if c in KEYS_BACKSPACE:
self._search_string = self._search_string[:-1]
self.index = 0
self._ensure_valid_selection()
elif 0 <= c <= 255 and chr(c).isprintable():
self._search_string += chr(c)
self.index = 0
self._ensure_valid_selection()

def _resolve_backend(self) -> Backend:
if isinstance(self.backend, Backend):
Expand All @@ -238,12 +302,9 @@ def _resolve_backend(self) -> Backend:

def config_curses(self) -> None:
try:
# use the default colors of the terminal
curses.use_default_colors()
# hide the cursor
curses.curs_set(0)
except Exception:
# Curses failed to initialize color support, eg. when TERM=vt100
curses.initscr()

def _start(self, screen: "curses._CursesWindow"):
Expand All @@ -253,21 +314,18 @@ def _start(self, screen: "curses._CursesWindow"):
def start(self):
backend = self._resolve_backend()
if isinstance(backend, CursesBackend) and backend._screen is not None:
# Embedded in an existing curses application (backward-compatible)
last_cur = curses.curs_set(0)
ret = self.run_loop(backend, self.position)
if last_cur:
curses.curs_set(last_cur)
return ret
elif isinstance(backend, CursesBackend):
# Standalone curses mode
def _curses_main(screen: "curses._CursesWindow"):
backend._screen = screen
backend.setup()
return self.run_loop(backend, self.position)
return curses.wrapper(_curses_main)
else:
# Other backends (e.g. blessed)
backend.setup()
try:
return self.run_loop(backend, self.position)
Expand All @@ -287,6 +345,7 @@ def pick(
clear_screen: bool = True,
quit_keys: Optional[Union[Container[int], Iterable[int]]] = None,
backend: Union[str, Backend] = "curses",
enable_search: bool = False,
):
picker: Picker = Picker(
options,
Expand All @@ -300,5 +359,6 @@ def pick(
clear_screen,
quit_keys,
backend,
enable_search,
)
return picker.start()