diff --git a/src/basic_memory/config.py b/src/basic_memory/config.py index 6578bff5f..92ff504ab 100644 --- a/src/basic_memory/config.py +++ b/src/basic_memory/config.py @@ -4,6 +4,7 @@ import json import os import shutil +import threading from dataclasses import dataclass from datetime import datetime from enum import Enum @@ -1099,8 +1100,21 @@ def save_basic_memory_config(file_path: Path, config: BasicMemoryConfig) -> None _secure_config_dir(file_path.parent) # Use model_dump with mode='json' to serialize datetime objects properly config_dict = config.model_dump(mode="json") - file_path.write_text(json.dumps(config_dict, indent=2)) - _secure_config_file(file_path) + # Trigger: long-lived readers (MCP stdio server config reload, background + # auto-update threads) re-read config.json whenever its mtime changes, + # concurrently with CLI commands saving it. + # Why: writing the destination in place truncates it first, so a concurrent + # reader can observe empty/partial JSON and load_config() exits the process. + # Outcome: write a sibling temp file (unique per process/thread so parallel + # savers cannot interleave) and publish atomically via os.replace — readers + # always see either the old or the new complete document. (#940) + tmp_path = file_path.parent / f"{file_path.name}.{os.getpid()}.{threading.get_ident()}.tmp" + try: + tmp_path.write_text(json.dumps(config_dict, indent=2)) + _secure_config_file(tmp_path) + os.replace(tmp_path, file_path) + finally: + tmp_path.unlink(missing_ok=True) except Exception as e: # pragma: no cover logger.error(f"Failed to save config: {e}") diff --git a/src/basic_memory/db.py b/src/basic_memory/db.py index d9f56a05b..815005113 100644 --- a/src/basic_memory/db.py +++ b/src/basic_memory/db.py @@ -19,7 +19,7 @@ AsyncEngine, async_scoped_session, ) -from sqlalchemy.pool import NullPool +from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool from basic_memory.repository.postgres_search_repository import PostgresSearchRepository from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository @@ -216,19 +216,31 @@ def _create_sqlite_engine(db_url: str, db_type: DatabaseType) -> AsyncEngine: "isolation_level": None, # Use autocommit mode } ) + + if db_type == DatabaseType.MEMORY: + # Trigger: an in-memory SQLite URL would default to StaticPool, which hands the + # same DBAPI connection to every concurrently checked-out session. + # Why: concurrent asyncio tasks then share one transaction scope — a rollback + # issued by one session (scoped_session exception handling or the pool's + # reset-on-return) silently destroys another session's uncommitted writes (#940). + # Outcome: a single-connection blocking queue pool keeps the in-memory database + # alive for the engine's lifetime while serializing sessions at transaction + # granularity, restoring the isolation the repositories assume. + engine = create_async_engine( + db_url, + connect_args=connect_args, + poolclass=AsyncAdaptedQueuePool, + pool_size=1, + max_overflow=0, + ) + elif os.name == "nt": # Use NullPool for Windows filesystem databases to avoid connection pooling issues - # Important: Do NOT use NullPool for in-memory databases as it will destroy the database - # between connections - if db_type == DatabaseType.FILESYSTEM: - engine = create_async_engine( - db_url, - connect_args=connect_args, - poolclass=NullPool, # Disable connection pooling on Windows - echo=False, - ) - else: - # In-memory databases need connection pooling to maintain state - engine = create_async_engine(db_url, connect_args=connect_args) + engine = create_async_engine( + db_url, + connect_args=connect_args, + poolclass=NullPool, # Disable connection pooling on Windows + echo=False, + ) else: engine = create_async_engine(db_url, connect_args=connect_args) diff --git a/test-int/cli/test_routing_integration.py b/test-int/cli/test_routing_integration.py index c204dec8a..e63ceeecf 100644 --- a/test-int/cli/test_routing_integration.py +++ b/test-int/cli/test_routing_integration.py @@ -89,6 +89,27 @@ def test_tool_edit_note_both_flags_error(self): assert "Cannot specify both --local and --cloud" in result.output +def _stub_auto_update(monkeypatch, mcp_mod) -> None: + """Keep `bm mcp` stdio tests from running the real background auto-update. + + The command starts a daemon thread before mcp_server.run; unstubbed it hits + PyPI and rewrites config.json from a background thread, leaking into later + tests in the same process (#940's KeyError flake on test_mcp_sse_forces_local). + """ + from basic_memory.cli.auto_update import AutoUpdateResult, AutoUpdateStatus, InstallSource + + def skipped_auto_update(**kwargs) -> AutoUpdateResult: + return AutoUpdateResult( + status=AutoUpdateStatus.SKIPPED, + source=InstallSource.UNKNOWN, + checked=False, + update_available=False, + updated=False, + ) + + monkeypatch.setattr(mcp_mod, "run_auto_update", skipped_auto_update) + + class TestMcpCommandRouting: """Tests that MCP routing varies by transport.""" @@ -109,8 +130,10 @@ def mock_run(*args, **kwargs): monkeypatch.setattr(mcp_mod.mcp_server, "run", mock_run) monkeypatch.setattr(mcp_mod, "init_mcp_logging", lambda: None) + _stub_auto_update(monkeypatch, mcp_mod) - runner.invoke(cli_app, ["mcp"]) # default transport is stdio + result = runner.invoke(cli_app, ["mcp"]) # default transport is stdio + assert result.exit_code == 0, result.output # Command should not have set these vars assert env_at_run["FORCE_LOCAL"] is None @@ -132,8 +155,10 @@ def mock_run(*args, **kwargs): monkeypatch.setattr(mcp_mod.mcp_server, "run", mock_run) monkeypatch.setattr(mcp_mod, "init_mcp_logging", lambda: None) + _stub_auto_update(monkeypatch, mcp_mod) - runner.invoke(cli_app, ["mcp"]) + result = runner.invoke(cli_app, ["mcp"]) + assert result.exit_code == 0, result.output # Externally-set vars should be preserved assert env_at_run["FORCE_CLOUD"] == "true" @@ -153,7 +178,8 @@ def mock_run(*args, **kwargs): monkeypatch.setattr(mcp_mod.mcp_server, "run", mock_run) monkeypatch.setattr(mcp_mod, "init_mcp_logging", lambda: None) - runner.invoke(cli_app, ["mcp", "--transport", "streamable-http"]) + result = runner.invoke(cli_app, ["mcp", "--transport", "streamable-http"]) + assert result.exit_code == 0, result.output assert env_at_run["FORCE_LOCAL"] == "true" assert env_at_run["EXPLICIT"] == "true" @@ -172,7 +198,8 @@ def mock_run(*args, **kwargs): monkeypatch.setattr(mcp_mod.mcp_server, "run", mock_run) monkeypatch.setattr(mcp_mod, "init_mcp_logging", lambda: None) - runner.invoke(cli_app, ["mcp", "--transport", "sse"]) + result = runner.invoke(cli_app, ["mcp", "--transport", "sse"]) + assert result.exit_code == 0, result.output assert env_at_run["FORCE_LOCAL"] == "true" assert env_at_run["EXPLICIT"] == "true" diff --git a/tests/db/test_memory_db_session_isolation.py b/tests/db/test_memory_db_session_isolation.py new file mode 100644 index 000000000..1e25b4d93 --- /dev/null +++ b/tests/db/test_memory_db_session_isolation.py @@ -0,0 +1,93 @@ +"""Regression test for issue #940: lost writes on the in-memory SQLite engine. + +The in-memory SQLite URL (``sqlite+aiosqlite://``) used to fall back to +SQLAlchemy's StaticPool, which hands the same DBAPI connection to every +concurrently checked-out session. Concurrent asyncio tasks then share one +transaction scope: a rollback issued by one session — scoped_session's +exception handling, or the pool's reset-on-return at checkin — silently rolls +back another session's uncommitted writes. During sync this manifested as a +freshly inserted relation row vanishing without any error, which is how +``test_sync_entity_circular_relations`` failed on CI with +``len(entity_b.outgoing_relations) == 0``. + +This test pins the isolation invariant directly: a session that rolls back in +one task must never destroy an uncommitted write of a session in another task. +""" + +import asyncio +from pathlib import Path + +import pytest +from sqlalchemy import text + +from basic_memory import db +from basic_memory.models import Base + + +class _SimulatedIndexingFailure(Exception): + """Stand-in for any per-file error that _run_bounded swallows during sync.""" + + +@pytest.mark.asyncio +async def test_concurrent_session_rollback_does_not_destroy_uncommitted_writes(): + """A rolled-back session in one task must not erase another task's writes.""" + async with db.engine_session_factory( + db_path=Path("unused.db"), db_type=db.DatabaseType.MEMORY + ) as (engine, session_maker): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Seed a project and entity so a relation row satisfies its FK constraints. + async with db.scoped_session(session_maker) as session: + await session.execute( + text( + "INSERT INTO project (id, external_id, name, description, path, is_active," + " is_default, created_at, updated_at, permalink) " + "VALUES (1, 'px', 'p', '', '/tmp', 1, 1, '2024-01-01', '2024-01-01', 'p')" + ) + ) + await session.execute( + text( + "INSERT INTO entity (id, external_id, title, note_type, content_type," + " project_id, file_path, created_at, updated_at) " + "VALUES (1, 'ex', 'E', 'note', 'text/markdown', 1, 'e.md'," + " '2024-01-01', '2024-01-01')" + ) + ) + + write_in_flight = asyncio.Event() + + async def writer() -> None: + # Mirrors RelationRepository.add_all_ignore_duplicates: INSERT executed, + # commit only happens at scoped_session exit several awaits later. + async with db.scoped_session(session_maker) as session: + await session.execute( + text( + "INSERT INTO relation (project_id, from_id, to_id, to_name," + " relation_type) VALUES (1, 1, NULL, 'target', 'depends_on')" + ) + ) + write_in_flight.set() + # Real DB roundtrips (not sleeps) keep this transaction open across + # await points, exactly like the multi-statement sessions in sync. + for _ in range(10): + await session.execute(text("SELECT 1")) + + async def failing_reader() -> None: + # Mirrors any per-file failure during batch indexing: scoped_session + # rolls back on the exception path while sibling tasks are mid-write. + await write_in_flight.wait() + with pytest.raises(_SimulatedIndexingFailure): + async with db.scoped_session(session_maker) as session: + await session.execute(text("SELECT 1")) + raise _SimulatedIndexingFailure() + + await asyncio.gather(writer(), failing_reader()) + + async with db.scoped_session(session_maker) as session: + count = (await session.execute(text("SELECT count(*) FROM relation"))).scalar() + + assert count == 1, ( + "writer's committed INSERT was rolled back by a concurrent session — " + "the in-memory engine is sharing one transaction scope across sessions" + ) diff --git a/tests/repository/test_entity_repository.py b/tests/repository/test_entity_repository.py index 194722af5..8464b42c9 100644 --- a/tests/repository/test_entity_repository.py +++ b/tests/repository/test_entity_repository.py @@ -291,9 +291,11 @@ async def test_delete_entity_with_relations(entity_repository: EntityRepository, remaining_relations = result.scalars().all() assert len(remaining_relations) == 0 - # Verify target entity still exists - target_exists = await entity_repository.find_by_id(target.id) - assert target_exists is not None + # Verify target entity still exists. Runs outside the session block above: + # find_by_id opens its own scoped session, and the serialized in-memory pool + # (one connection, see #940) deadlocks on nested session checkouts. + target_exists = await entity_repository.find_by_id(target.id) + assert target_exists is not None @pytest.mark.asyncio diff --git a/tests/test_config.py b/tests/test_config.py index 4f3bab0b3..f88ad8399 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1600,3 +1600,62 @@ def test_auto_update_round_trip_persistence(self): assert loaded.auto_update is False assert loaded.update_check_interval == 7200 assert loaded.auto_update_last_checked_at == checked_at + + +class TestAtomicConfigSave: + """Regression tests for #940: saving config must never tear the published file. + + Long-lived readers (the MCP stdio server's mtime-based config reload, the CLI + background auto-update thread) re-read config.json while other code saves it. + An in-place write truncates the file first, so a concurrent reader can observe + empty/partial JSON — and load_config() raises SystemExit on invalid JSON. + """ + + def test_interrupted_save_preserves_published_config(self, config_home, monkeypatch): + """A write that dies mid-stream must leave the existing config untouched.""" + import json + + from basic_memory.config import save_basic_memory_config + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + config_file = temp_path / "config.json" + config = BasicMemoryConfig( + projects={"main": {"path": str(temp_path / "main")}}, + default_project="main", + ) + save_basic_memory_config(config_file, config) + published = config_file.read_text(encoding="utf-8") + json.loads(published) # sanity: complete, valid document + + def torn_write_text(self, content, *args, **kwargs): + # Fault injection: the write dies halfway through. For an in-place + # write this is exactly the truncated state a concurrent reader + # observes mid-save; an atomic save must confine it to a temp file. + with open(self, "w", encoding="utf-8") as fh: + fh.write(content[: len(content) // 2]) + raise OSError("simulated interrupted write") + + monkeypatch.setattr(Path, "write_text", torn_write_text) + # save_basic_memory_config logs write failures instead of raising + save_basic_memory_config(config_file, config) + monkeypatch.undo() + + assert config_file.read_text(encoding="utf-8") == published + + def test_save_leaves_no_temp_files(self, config_home): + """The atomic-write temp file must not survive a successful save.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + config_file = temp_path / "config.json" + config = BasicMemoryConfig( + projects={"main": {"path": str(temp_path / "main")}}, + default_project="main", + ) + + from basic_memory.config import save_basic_memory_config + + save_basic_memory_config(config_file, config) + + assert config_file.exists() + assert not list(temp_path.glob("*.tmp"))