Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/basic_memory/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import os
import shutil
import threading
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -1099,8 +1100,21 @@ def save_basic_memory_config(file_path: Path, config: BasicMemoryConfig) -> None
_secure_config_dir(file_path.parent)
# Use model_dump with mode='json' to serialize datetime objects properly
config_dict = config.model_dump(mode="json")
file_path.write_text(json.dumps(config_dict, indent=2))
_secure_config_file(file_path)
# Trigger: long-lived readers (MCP stdio server config reload, background
# auto-update threads) re-read config.json whenever its mtime changes,
# concurrently with CLI commands saving it.
# Why: writing the destination in place truncates it first, so a concurrent
# reader can observe empty/partial JSON and load_config() exits the process.
# Outcome: write a sibling temp file (unique per process/thread so parallel
# savers cannot interleave) and publish atomically via os.replace — readers
# always see either the old or the new complete document. (#940)
tmp_path = file_path.parent / f"{file_path.name}.{os.getpid()}.{threading.get_ident()}.tmp"
try:
tmp_path.write_text(json.dumps(config_dict, indent=2))
_secure_config_file(tmp_path)
os.replace(tmp_path, file_path)
finally:
tmp_path.unlink(missing_ok=True)
except Exception as e: # pragma: no cover
logger.error(f"Failed to save config: {e}")

Expand Down
38 changes: 25 additions & 13 deletions src/basic_memory/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
AsyncEngine,
async_scoped_session,
)
from sqlalchemy.pool import NullPool
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool

from basic_memory.repository.postgres_search_repository import PostgresSearchRepository
from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository
Expand Down Expand Up @@ -216,19 +216,31 @@ def _create_sqlite_engine(db_url: str, db_type: DatabaseType) -> AsyncEngine:
"isolation_level": None, # Use autocommit mode
}
)

if db_type == DatabaseType.MEMORY:
# Trigger: an in-memory SQLite URL would default to StaticPool, which hands the
# same DBAPI connection to every concurrently checked-out session.
# Why: concurrent asyncio tasks then share one transaction scope — a rollback
# issued by one session (scoped_session exception handling or the pool's
# reset-on-return) silently destroys another session's uncommitted writes (#940).
# Outcome: a single-connection blocking queue pool keeps the in-memory database
# alive for the engine's lifetime while serializing sessions at transaction
# granularity, restoring the isolation the repositories assume.
engine = create_async_engine(
db_url,
connect_args=connect_args,
poolclass=AsyncAdaptedQueuePool,
pool_size=1,
max_overflow=0,
)
elif os.name == "nt":
# Use NullPool for Windows filesystem databases to avoid connection pooling issues
# Important: Do NOT use NullPool for in-memory databases as it will destroy the database
# between connections
if db_type == DatabaseType.FILESYSTEM:
engine = create_async_engine(
db_url,
connect_args=connect_args,
poolclass=NullPool, # Disable connection pooling on Windows
echo=False,
)
else:
# In-memory databases need connection pooling to maintain state
engine = create_async_engine(db_url, connect_args=connect_args)
engine = create_async_engine(
db_url,
connect_args=connect_args,
poolclass=NullPool, # Disable connection pooling on Windows
echo=False,
)
else:
engine = create_async_engine(db_url, connect_args=connect_args)

Expand Down
35 changes: 31 additions & 4 deletions test-int/cli/test_routing_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,27 @@ def test_tool_edit_note_both_flags_error(self):
assert "Cannot specify both --local and --cloud" in result.output


def _stub_auto_update(monkeypatch, mcp_mod) -> None:
    """Swap the `bm mcp` background auto-update for an inert stand-in.

    The command starts a daemon thread before mcp_server.run; left unstubbed it
    hits PyPI and rewrites config.json from a background thread, which leaks
    into later tests in the same process (#940's KeyError flake on
    test_mcp_sse_forces_local).

    Args:
        monkeypatch: pytest's monkeypatch fixture, used to install the stub.
        mcp_mod: the module object whose ``run_auto_update`` attribute is replaced.
    """
    from basic_memory.cli.auto_update import AutoUpdateResult, AutoUpdateStatus, InstallSource

    def _inert_auto_update(**_ignored) -> AutoUpdateResult:
        # Accept and discard any keyword arguments; report "nothing checked,
        # nothing updated" with a fresh result object on every call.
        outcome = AutoUpdateResult(
            status=AutoUpdateStatus.SKIPPED,
            source=InstallSource.UNKNOWN,
            checked=False,
            update_available=False,
            updated=False,
        )
        return outcome

    monkeypatch.setattr(mcp_mod, "run_auto_update", _inert_auto_update)


class TestMcpCommandRouting:
"""Tests that MCP routing varies by transport."""

Expand All @@ -109,8 +130,10 @@ def mock_run(*args, **kwargs):

monkeypatch.setattr(mcp_mod.mcp_server, "run", mock_run)
monkeypatch.setattr(mcp_mod, "init_mcp_logging", lambda: None)
_stub_auto_update(monkeypatch, mcp_mod)

runner.invoke(cli_app, ["mcp"]) # default transport is stdio
result = runner.invoke(cli_app, ["mcp"]) # default transport is stdio
assert result.exit_code == 0, result.output

# Command should not have set these vars
assert env_at_run["FORCE_LOCAL"] is None
Expand All @@ -132,8 +155,10 @@ def mock_run(*args, **kwargs):

monkeypatch.setattr(mcp_mod.mcp_server, "run", mock_run)
monkeypatch.setattr(mcp_mod, "init_mcp_logging", lambda: None)
_stub_auto_update(monkeypatch, mcp_mod)

runner.invoke(cli_app, ["mcp"])
result = runner.invoke(cli_app, ["mcp"])
assert result.exit_code == 0, result.output

# Externally-set vars should be preserved
assert env_at_run["FORCE_CLOUD"] == "true"
Expand All @@ -153,7 +178,8 @@ def mock_run(*args, **kwargs):
monkeypatch.setattr(mcp_mod.mcp_server, "run", mock_run)
monkeypatch.setattr(mcp_mod, "init_mcp_logging", lambda: None)

runner.invoke(cli_app, ["mcp", "--transport", "streamable-http"])
result = runner.invoke(cli_app, ["mcp", "--transport", "streamable-http"])
assert result.exit_code == 0, result.output

assert env_at_run["FORCE_LOCAL"] == "true"
assert env_at_run["EXPLICIT"] == "true"
Expand All @@ -172,7 +198,8 @@ def mock_run(*args, **kwargs):
monkeypatch.setattr(mcp_mod.mcp_server, "run", mock_run)
monkeypatch.setattr(mcp_mod, "init_mcp_logging", lambda: None)

runner.invoke(cli_app, ["mcp", "--transport", "sse"])
result = runner.invoke(cli_app, ["mcp", "--transport", "sse"])
assert result.exit_code == 0, result.output

assert env_at_run["FORCE_LOCAL"] == "true"
assert env_at_run["EXPLICIT"] == "true"
Expand Down
93 changes: 93 additions & 0 deletions tests/db/test_memory_db_session_isolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Regression test for issue #940: lost writes on the in-memory SQLite engine.

The in-memory SQLite URL (``sqlite+aiosqlite://``) used to fall back to
SQLAlchemy's StaticPool, which hands the same DBAPI connection to every
concurrently checked-out session. Concurrent asyncio tasks then share one
transaction scope: a rollback issued by one session — scoped_session's
exception handling, or the pool's reset-on-return at checkin — silently rolls
back another session's uncommitted writes. During sync this manifested as a
freshly inserted relation row vanishing without any error, which is how
``test_sync_entity_circular_relations`` failed on CI with
``len(entity_b.outgoing_relations) == 0``.

This test pins the isolation invariant directly: a session that rolls back in
one task must never destroy an uncommitted write of a session in another task.
"""

import asyncio
from pathlib import Path

import pytest
from sqlalchemy import text

from basic_memory import db
from basic_memory.models import Base


class _SimulatedIndexingFailure(Exception):
    """Stand-in for any per-file error that _run_bounded swallows during sync.

    Raised inside a scoped session by the test in this module to force the
    session's exception (rollback) path while another task has a write open.
    """


@pytest.mark.asyncio
async def test_concurrent_session_rollback_does_not_destroy_uncommitted_writes():
    """A rolled-back session in one task must not erase another task's writes.

    Flow: create the schema on an in-memory engine, seed one project and one
    entity, then run two tasks concurrently — a writer that inserts a relation
    row and keeps its session open across several awaits, and a reader that
    raises inside its own session once the writer's INSERT has executed.
    Afterwards exactly one relation row must be visible from a fresh session.
    """
    async with db.engine_session_factory(
        db_path=Path("unused.db"), db_type=db.DatabaseType.MEMORY
    ) as (engine, session_maker):
        # Create all tables up front; the raw INSERTs below need them to exist.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        # Seed a project and entity so a relation row satisfies its FK constraints.
        async with db.scoped_session(session_maker) as session:
            await session.execute(
                text(
                    "INSERT INTO project (id, external_id, name, description, path, is_active,"
                    " is_default, created_at, updated_at, permalink) "
                    "VALUES (1, 'px', 'p', '', '/tmp', 1, 1, '2024-01-01', '2024-01-01', 'p')"
                )
            )
            await session.execute(
                text(
                    "INSERT INTO entity (id, external_id, title, note_type, content_type,"
                    " project_id, file_path, created_at, updated_at) "
                    "VALUES (1, 'ex', 'E', 'note', 'text/markdown', 1, 'e.md',"
                    " '2024-01-01', '2024-01-01')"
                )
            )

        # Gate: set by the writer once its INSERT has executed, so the reader
        # only opens its session while the writer's transaction is in flight.
        write_in_flight = asyncio.Event()

        async def writer() -> None:
            # Mirrors RelationRepository.add_all_ignore_duplicates: INSERT executed,
            # commit only happens at scoped_session exit several awaits later.
            async with db.scoped_session(session_maker) as session:
                await session.execute(
                    text(
                        "INSERT INTO relation (project_id, from_id, to_id, to_name,"
                        " relation_type) VALUES (1, 1, NULL, 'target', 'depends_on')"
                    )
                )
                write_in_flight.set()
                # Real DB roundtrips (not sleeps) keep this transaction open across
                # await points, exactly like the multi-statement sessions in sync.
                for _ in range(10):
                    await session.execute(text("SELECT 1"))

        async def failing_reader() -> None:
            # Mirrors any per-file failure during batch indexing: scoped_session
            # rolls back on the exception path while sibling tasks are mid-write.
            await write_in_flight.wait()
            with pytest.raises(_SimulatedIndexingFailure):
                async with db.scoped_session(session_maker) as session:
                    await session.execute(text("SELECT 1"))
                    raise _SimulatedIndexingFailure()

        # Run both tasks on the same event loop; gather returns once both finish.
        await asyncio.gather(writer(), failing_reader())

        # Count from a fresh session so the check does not reuse either task's session.
        async with db.scoped_session(session_maker) as session:
            count = (await session.execute(text("SELECT count(*) FROM relation"))).scalar()

        assert count == 1, (
            "writer's committed INSERT was rolled back by a concurrent session — "
            "the in-memory engine is sharing one transaction scope across sessions"
        )
8 changes: 5 additions & 3 deletions tests/repository/test_entity_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,11 @@ async def test_delete_entity_with_relations(entity_repository: EntityRepository,
remaining_relations = result.scalars().all()
assert len(remaining_relations) == 0

# Verify target entity still exists
target_exists = await entity_repository.find_by_id(target.id)
assert target_exists is not None
# Verify target entity still exists. Runs outside the session block above:
# find_by_id opens its own scoped session, and the serialized in-memory pool
# (one connection, see #940) deadlocks on nested session checkouts.
target_exists = await entity_repository.find_by_id(target.id)
assert target_exists is not None


@pytest.mark.asyncio
Expand Down
59 changes: 59 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,3 +1600,62 @@ def test_auto_update_round_trip_persistence(self):
assert loaded.auto_update is False
assert loaded.update_check_interval == 7200
assert loaded.auto_update_last_checked_at == checked_at


class TestAtomicConfigSave:
    """Regression tests for #940: saving config must never tear the published file.

    Long-lived readers (the MCP stdio server's mtime-based config reload, the CLI
    background auto-update thread) re-read config.json while other code saves it.
    An in-place write truncates the file first, so a concurrent reader can observe
    empty/partial JSON — and load_config() raises SystemExit on invalid JSON.
    """

    def test_interrupted_save_preserves_published_config(self, config_home, monkeypatch):
        """A write that dies mid-stream must leave the existing config untouched."""
        import json

        from basic_memory.config import save_basic_memory_config

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            config_file = temp_path / "config.json"
            config = BasicMemoryConfig(
                projects={"main": {"path": str(temp_path / "main")}},
                default_project="main",
            )
            # First save goes through the real write path and publishes a
            # complete document that the second (faulty) save must not disturb.
            save_basic_memory_config(config_file, config)
            published = config_file.read_text(encoding="utf-8")
            json.loads(published)  # sanity: complete, valid document

            def torn_write_text(self, content, *args, **kwargs):
                # Fault injection: the write dies halfway through. For an in-place
                # write this is exactly the truncated state a concurrent reader
                # observes mid-save; an atomic save must confine it to a temp file.
                with open(self, "w", encoding="utf-8") as fh:
                    fh.write(content[: len(content) // 2])
                raise OSError("simulated interrupted write")

            # Scope the patch with context() rather than calling monkeypatch.undo():
            # undo() reverts every change registered on the shared fixture,
            # including any made by other fixtures of this test, whereas
            # context() restores only what was patched inside the block.
            with monkeypatch.context() as scoped_patch:
                scoped_patch.setattr(Path, "write_text", torn_write_text)
                # save_basic_memory_config logs write failures instead of raising
                save_basic_memory_config(config_file, config)

            assert config_file.read_text(encoding="utf-8") == published

    def test_save_leaves_no_temp_files(self, config_home):
        """The atomic-write temp file must not survive a successful save."""
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            config_file = temp_path / "config.json"
            config = BasicMemoryConfig(
                projects={"main": {"path": str(temp_path / "main")}},
                default_project="main",
            )

            from basic_memory.config import save_basic_memory_config

            save_basic_memory_config(config_file, config)

            # Only the published file may remain; any "*.tmp" sibling is a leak.
            assert config_file.exists()
            assert not list(temp_path.glob("*.tmp"))
Loading