Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
"""Pytest configuration and shared fixtures for TruShell tests."""

import sqlite3

import pytest

from trushell.core import database
from trushell.core.plugin_manager import PluginManager


@pytest.fixture
def in_memory_database(monkeypatch):
"""Use one in-memory SQLite connection for database CRUD tests."""
connection = sqlite3.connect(":memory:", check_same_thread=False)
connection.execute(
"""CREATE TABLE todos (
task TEXT,
category TEXT,
date_added TEXT,
date_completed TEXT,
status INTEGER,
position INTEGER
)"""
)
monkeypatch.setattr(database, "get_db_connection", lambda: connection)

yield connection

connection.close()


@pytest.fixture(autouse=True)
def reset_plugin_manager():
"""Reset PluginManager singleton after each test to ensure isolation."""
Expand Down
8 changes: 0 additions & 8 deletions tests/test_cli_argv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,8 @@


def test_app_with_lower_does_not_mutate_original_argv(monkeypatch):

# NOTE: `trushell.cli.app_with_lower()` references a module-level
# name `argv` that is not defined in the module. Tests exercising the
# early-return paths must inject this name into the module namespace
# (this mirrors how callers may set it in other environments).
original = ["trushell", "HeLp"]
monkeypatch.setattr(cli.sys, "argv", original)
# Inject a module-level `argv` name for the duration of the test so
# `app_with_lower()` can compare against the real `sys.argv`.
monkeypatch.setattr(cli, "argv", cli.sys.argv, raising=False)

calls: list[str] = []

Expand Down
55 changes: 55 additions & 0 deletions tests/test_cli_os_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

import os
import subprocess

from trushell import cli


def test_os_fallback_runs_argument_vector_without_shell(monkeypatch):
calls = {}

def fake_run(command, shell, check, cwd):
calls.update(
command=command,
shell=shell,
check=check,
cwd=cwd,
)
return subprocess.CompletedProcess(args=command, returncode=0)

monkeypatch.setattr(cli, "_run_external_command", fake_run)

assert cli._handle_os_fallback('echo "hello world"') is True
assert calls == {
"command": ["echo", "hello world"],
"shell": False,
"check": False,
"cwd": os.getcwd(),
}


def test_os_fallback_does_not_interpret_shell_metacharacters(monkeypatch):
calls = {}

def fake_run(command, shell, check, cwd):
calls.update(command=command, shell=shell)
return subprocess.CompletedProcess(args=command, returncode=0)

monkeypatch.setattr(cli, "_run_external_command", fake_run)

assert cli._handle_os_fallback("echo safe; touch injected") is True
assert calls == {
"command": ["echo", "safe;", "touch", "injected"],
"shell": False,
}


def test_os_fallback_handles_malformed_quoting(monkeypatch, capsys):
def fail_if_called(*args, **kwargs):
raise AssertionError("external command should not run")

monkeypatch.setattr(cli, "_run_external_command", fail_if_called)

assert cli._handle_os_fallback('echo "unterminated') is True
assert "OS fallback error" in capsys.readouterr().out
28 changes: 13 additions & 15 deletions tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from trushell.core.database import _ensure_initialized, get_all_todos, get_db_connection, insert_todo
from trushell.core.database import (
_ensure_initialized,
get_all_todos,
get_db_connection,
insert_todo,
)
from trushell.core.models import Todo


Expand All @@ -20,10 +25,7 @@ def test_get_db_connection_returns_fresh_connection(monkeypatch, tmp_path) -> No
conn_two.close()


def test_insert_todo_assigns_sequential_positions(monkeypatch, tmp_path) -> None:
_use_temp_database(monkeypatch, tmp_path)

_ensure_initialized()
def test_insert_todo_assigns_sequential_positions(in_memory_database) -> None:
insert_todo(Todo(task="first", category="work"))
insert_todo(Todo(task="second", category="work"))

Expand All @@ -33,20 +35,14 @@ def test_insert_todo_assigns_sequential_positions(monkeypatch, tmp_path) -> None
assert [task.position for task in tasks] == [0, 1]


def test_get_all_todos_works_with_local_connections(monkeypatch, tmp_path) -> None:
_use_temp_database(monkeypatch, tmp_path)

_ensure_initialized()
def test_get_all_todos_works_with_local_connections(in_memory_database) -> None:
insert_todo(Todo(task="alpha", category="study"))

assert len(get_all_todos()) == 1


def test_get_all_todos_returns_rows_ordered_by_position(monkeypatch, tmp_path) -> None:
_use_temp_database(monkeypatch, tmp_path)

_ensure_initialized()
with get_db_connection() as conn:
def test_get_all_todos_returns_rows_ordered_by_position(in_memory_database) -> None:
with in_memory_database as conn:
conn.execute(
"INSERT INTO todos VALUES (?, ?, ?, ?, ?, ?)",
("second", "work", "", None, 0, 1),
Expand All @@ -62,7 +58,9 @@ def test_get_all_todos_returns_rows_ordered_by_position(monkeypatch, tmp_path) -
assert [task.position for task in tasks] == [0, 1]


def test_ensure_initialized_skips_lock_when_already_initialized(monkeypatch, tmp_path) -> None:
def test_ensure_initialized_skips_lock_when_already_initialized(
monkeypatch, tmp_path
) -> None:
_use_temp_database(monkeypatch, tmp_path)
_ensure_initialized()

Expand Down
16 changes: 9 additions & 7 deletions trushell/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@


def app_with_lower() -> None:
"""Entry point that normalizes the first argument to lowercase for case-insensitive invocation."""
"""Normalize command names without changing the caller's argument list."""
if len(sys.argv) > 1:
# Create a local copy to avoid mutating the global sys.argv
argv_copy = sys.argv.copy()
if argv_copy[1].lower() not in {"--help", "-h", "version"}:
# Normalize the command name to lowercase for case-insensitive
Expand All @@ -39,8 +38,6 @@ def app_with_lower() -> None:
get_kernel().execute_command(raw)
return

if argv != sys.argv:
sys.argv = argv
app()


Expand Down Expand Up @@ -119,7 +116,7 @@ def action_quit_app(self) -> None:


def _run_external_command(
command: str,
command: str | list[str],
shell: bool = True,
check: bool = False,
cwd: str | None = None,
Expand Down Expand Up @@ -307,8 +304,13 @@ def _handle_os_fallback(raw_command: str) -> bool:
return False

try:
completed = _run_external_command(command, shell=True, check=False, cwd=os.getcwd())
except (OSError, subprocess.SubprocessError) as error:
completed = _run_external_command(
shlex.split(command),
shell=False,
check=False,
cwd=os.getcwd(),
)
except (OSError, subprocess.SubprocessError, ValueError) as error:
typer.secho("❓ Command not recognized by TruShell or your host OS.", fg=typer.colors.YELLOW)
typer.secho(f"OS fallback error: {error}", fg=typer.colors.RED)
return True
Expand Down
Loading