Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ class Settings(BaseSettings):

# FASTEMBED CACHE CONFIGURATION
FASTEMBED_CACHE_DIR: str = str(_default_data_dir / "models" / "fastembed")
FASTEMBED_LOCAL_FILES_ONLY: bool = False # When True, FastEmbed only loads models from local cache

"""Pydantic Configuration"""

Expand Down
15 changes: 13 additions & 2 deletions app/repositories/embeddings/embedding_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import time
from typing import Protocol

from fastembed import TextEmbedding
from google import genai
from google.genai import types
from openai import AzureOpenAI, OpenAI

from app.config.settings import settings
from app.repositories.embeddings.fastembed_offline import load_fastembed_model

logger = logging.getLogger(__name__)

Expand All @@ -28,9 +28,11 @@ def __init__(self):
})

start_time = time.time()
self.model = TextEmbedding(
self.model = load_fastembed_model(
model_role="embedding",
model_name=settings.EMBEDDING_MODEL,
cache_dir=settings.FASTEMBED_CACHE_DIR,
factory=lambda fastembed_kwargs: self._create_text_embedding(fastembed_kwargs),
)
elapsed = time.time() - start_time

Expand All @@ -39,6 +41,15 @@ def __init__(self):
"cache_dir": settings.FASTEMBED_CACHE_DIR,
})

def _create_text_embedding(self, fastembed_kwargs: dict[str, bool]):
    """Construct the FastEmbed ``TextEmbedding`` model for this adapter.

    Args:
        fastembed_kwargs: Extra keyword arguments forwarded verbatim to the
            ``TextEmbedding`` constructor (e.g. ``{"local_files_only": True}``).

    Returns:
        A ``TextEmbedding`` built from ``settings.EMBEDDING_MODEL`` and
        ``settings.FASTEMBED_CACHE_DIR``.
    """
    # Resolved at call time rather than at module import time.
    # NOTE(review): presumably so offline env vars can be set before
    # fastembed is first imported - confirm against fastembed_offline.
    from fastembed import TextEmbedding

    base_options = {
        "model_name": settings.EMBEDDING_MODEL,
        "cache_dir": settings.FASTEMBED_CACHE_DIR,
    }
    return TextEmbedding(**base_options, **fastembed_kwargs)

async def generate_embedding(self, text: str) -> list[float]:
try:
embeddings = list(self.model.embed(text))
Expand Down
40 changes: 40 additions & 0 deletions app/repositories/embeddings/fastembed_offline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Helpers for FastEmbed offline/cache-only startup."""
import os
from collections.abc import Callable

from app.config.settings import settings


def get_fastembed_kwargs() -> dict[str, bool]:
    """Build the extra constructor kwargs that apply to every FastEmbed model.

    Returns:
        ``{"local_files_only": True}`` when
        ``settings.FASTEMBED_LOCAL_FILES_ONLY`` is enabled, otherwise ``{}``.

    Side effects:
        In local-files-only mode, ``HF_HUB_OFFLINE`` is set to ``"1"`` in the
        process environment unless the variable already has a value.
    """
    if settings.FASTEMBED_LOCAL_FILES_ONLY:
        # Hugging Face Hub reads this env var during model resolution. It has
        # to be in place before FastEmbed models are imported/constructed so
        # startup never falls back to the network. An existing value wins.
        if "HF_HUB_OFFLINE" not in os.environ:
            os.environ["HF_HUB_OFFLINE"] = "1"
        return {"local_files_only": True}

    return {}


def load_fastembed_model[T](
*,
model_role: str,
model_name: str,
cache_dir: str,
factory: Callable[[dict[str, bool]], T],
) -> T:
"""Load a FastEmbed model and add actionable offline guidance on failure."""
kwargs = get_fastembed_kwargs()
try:
return factory(kwargs)
except Exception as exc:
if not settings.FASTEMBED_LOCAL_FILES_ONLY:
raise

raise RuntimeError(
f"Could not load FastEmbed {model_role} model from local cache. "
f"FASTEMBED_LOCAL_FILES_ONLY=true prevents downloading missing models. "
f"model={model_name!r}, cache_dir={cache_dir!r}. "
"Pre-populate FASTEMBED_CACHE_DIR with the required model files, "
"disable FASTEMBED_LOCAL_FILES_ONLY, or choose a remote embedding/reranking provider.",
) from exc
30 changes: 27 additions & 3 deletions app/repositories/embeddings/reranker_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Protocol

import httpx
from fastembed.rerank.cross_encoder import TextCrossEncoder

from app.config.settings import settings
from app.repositories.embeddings.fastembed_offline import load_fastembed_model


class RerankAdapter(Protocol):
Expand All @@ -30,12 +30,36 @@ def __init__(
self.threads = threads
self.cache_dir = cache_dir

self._model = TextCrossEncoder(
effective_cache_dir = cache_dir or settings.FASTEMBED_CACHE_DIR
self._model = load_fastembed_model(
model_role="reranking",
model_name=model,
cache_dir=effective_cache_dir,
factory=lambda fastembed_kwargs: self._create_text_cross_encoder(
model=model,
threads=threads,
cache_dir=cache_dir,
fastembed_kwargs=fastembed_kwargs,
),
)
self._executor = ThreadPoolExecutor(max_workers=1)

def _create_text_cross_encoder(
    self,
    *,
    model: str,
    threads: int,
    cache_dir: str | None,
    fastembed_kwargs: dict[str, bool],
):
    """Construct the FastEmbed ``TextCrossEncoder`` used for reranking.

    Args:
        model: Model name passed as ``model_name``.
        threads: Thread count passed through to the encoder.
        cache_dir: Cache directory passed through as-is (may be ``None``).
        fastembed_kwargs: Extra keyword arguments forwarded verbatim to the
            ``TextCrossEncoder`` constructor.

    Returns:
        The constructed ``TextCrossEncoder`` instance.
    """
    # Resolved at call time rather than at module import time.
    from fastembed.rerank.cross_encoder import TextCrossEncoder

    encoder_options = {
        "model_name": model,
        "threads": threads,
        "cache_dir": cache_dir,
    }
    return TextCrossEncoder(**encoder_options, **fastembed_kwargs)
self._executor = ThreadPoolExecutor(max_workers=1)

async def rerank(
self,
Expand Down
4 changes: 4 additions & 0 deletions docker/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ MEMORY_NUM_AUTO_LINK=3 # number of memories to automatically link (set
EMBEDDING_PROVIDER=FastEmbed
EMBEDDING_MODEL=BAAI/bge-small-en-v1.5
EMBEDDING_DIMENSIONS=384
# Set to true to require pre-cached FastEmbed models and fail fast instead of
# attempting HuggingFace/GCS downloads (applies to both the embedding model
# and the FastEmbed reranker). See docs/OFFLINE_SETUP.md.
FASTEMBED_LOCAL_FILES_ONLY=false

# Azure Embedding Provider Configuration - Only required if EMBEDDING_PROVIDER == Azure
# EMBEDDING_PROVIDER=Azure
Expand Down
13 changes: 13 additions & 0 deletions docs/OFFLINE_SETUP.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ Then install and run forgetful normally:
uvx forgetful-ai
```

To fail fast instead of attempting network downloads when a model is missing,
enable local-files-only mode:

```bash
FASTEMBED_LOCAL_FILES_ONLY=true uvx forgetful-ai
```

In this mode, both the embedding model and FastEmbed reranker must already be
present under `FASTEMBED_CACHE_DIR`. If not, Forgetful exits with an error that
points to the cache directory and configured model names.

### Option 3: Transfer from Another Machine

If you have the models cached on another machine, you can transfer them:
Expand Down Expand Up @@ -114,9 +125,11 @@ uvx forgetful-ai
```

Or create a `.env` file:

```bash
# .env
FASTEMBED_CACHE_DIR=/opt/models/fastembed
FASTEMBED_LOCAL_FILES_ONLY=true
```

## Verification
Expand Down
8 changes: 8 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,14 @@ ACTIVITY_TRACK_READS=false
- **Note**: Can be customized for offline deployments (see [Offline Setup Guide](OFFLINE_SETUP.md))
- **Example**: `FASTEMBED_CACHE_DIR=/app/data/models/fastembed`

#### `FASTEMBED_LOCAL_FILES_ONLY`

- **Default**: `false`
- **Description**: Restrict FastEmbed embedding and reranking models to files already present in `FASTEMBED_CACHE_DIR`
- **Purpose**: Prevents HuggingFace/GCS download attempts in locked-down or offline environments
- **Note**: When enabled, pre-populate the cache first or startup fails with a local-cache error
- **Example**: `FASTEMBED_LOCAL_FILES_ONLY=true`

---

## Configuration Hierarchy
Expand Down
6 changes: 6 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def _check_first_run_models():
"""Log message on first run when models need to be downloaded."""
cache_dir = Path(settings.FASTEMBED_CACHE_DIR)
if not cache_dir.exists() or not any(cache_dir.iterdir()):
if settings.FASTEMBED_LOCAL_FILES_ONLY:
logger.info(
"FastEmbed local-files-only mode enabled - models must already exist in cache",
extra={"cache_dir": settings.FASTEMBED_CACHE_DIR},
)
return
logger.info("First run detected - downloading embedding models. This may take a minute...")


Expand Down
164 changes: 164 additions & 0 deletions tests/integration/test_fastembed_offline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Network-free tests for FastEmbed offline/cache-only startup."""
import os
from unittest.mock import MagicMock, patch

import pytest

from app.config.settings import settings
from app.repositories.embeddings.fastembed_offline import (
get_fastembed_kwargs,
load_fastembed_model,
)


@pytest.fixture(autouse=True)
def clean_hf_hub_offline(monkeypatch):
    """Keep HF_HUB_OFFLINE changes isolated to each test.

    The variable is removed before the test runs and removed again afterwards,
    so a value written directly to ``os.environ`` by the code under test does
    not leak into later tests.
    """
    monkeypatch.delenv("HF_HUB_OFFLINE", raising=False)
    yield
    # monkeypatch only restores what it recorded: when the variable was unset
    # beforehand, delenv(raising=False) records nothing, so a value set through
    # os.environ.setdefault() during the test would survive. Drop it here; this
    # teardown runs before monkeypatch's own undo, which then restores any
    # pre-existing value.
    os.environ.pop("HF_HUB_OFFLINE", None)


@pytest.fixture
def fastembed_cache_dir(tmp_path):
    """Return a per-test FastEmbed cache path as a string.

    The path lives under pytest's ``tmp_path``; the directory itself is not
    created here.
    """
    cache_path = tmp_path / "fastembed"
    return str(cache_path)


def test_get_fastembed_kwargs_default_does_not_force_offline(monkeypatch):
    """With the flag off, no kwargs are produced and the env is untouched."""
    monkeypatch.setattr(settings, "FASTEMBED_LOCAL_FILES_ONLY", False)

    produced = get_fastembed_kwargs()

    assert produced == {}
    assert os.environ.get("HF_HUB_OFFLINE") is None


def test_get_fastembed_kwargs_local_files_only_sets_hf_offline(monkeypatch):
    """With the flag on, the offline kwarg is returned and HF_HUB_OFFLINE is set."""
    monkeypatch.setattr(settings, "FASTEMBED_LOCAL_FILES_ONLY", True)

    produced = get_fastembed_kwargs()

    assert produced == {"local_files_only": True}
    assert os.environ["HF_HUB_OFFLINE"] == "1"


def test_load_fastembed_model_wraps_errors_in_local_files_only_mode(
    monkeypatch,
    fastembed_cache_dir,
):
    """In cache-only mode a factory failure becomes a chained RuntimeError."""
    monkeypatch.setattr(settings, "FASTEMBED_LOCAL_FILES_ONLY", True)

    def always_fails(_kwargs):
        raise ValueError("raw fastembed failure")

    load_args = {
        "model_role": "embedding",
        "model_name": "BAAI/bge-small-en-v1.5",
        "cache_dir": fastembed_cache_dir,
        "factory": always_fails,
    }

    with pytest.raises(RuntimeError, match="FASTEMBED_LOCAL_FILES_ONLY=true") as excinfo:
        load_fastembed_model(**load_args)

    # The original failure must stay reachable through exception chaining.
    assert isinstance(excinfo.value.__cause__, ValueError)


def test_load_fastembed_model_preserves_default_errors(monkeypatch, fastembed_cache_dir):
    """With the flag off, the factory's own exception propagates unwrapped."""
    monkeypatch.setattr(settings, "FASTEMBED_LOCAL_FILES_ONLY", False)

    def always_fails(_kwargs):
        raise ValueError("raw fastembed failure")

    load_args = {
        "model_role": "embedding",
        "model_name": "BAAI/bge-small-en-v1.5",
        "cache_dir": fastembed_cache_dir,
        "factory": always_fails,
    }

    with pytest.raises(ValueError, match="raw fastembed failure"):
        load_fastembed_model(**load_args)


def test_fast_embedding_adapter_passes_local_files_only(monkeypatch, fastembed_cache_dir):
    """FastEmbeddingAdapter forwards ``local_files_only=True`` in cache-only mode."""
    monkeypatch.setattr(settings, "FASTEMBED_LOCAL_FILES_ONLY", True)
    monkeypatch.setattr(settings, "EMBEDDING_MODEL", "BAAI/bge-small-en-v1.5")
    monkeypatch.setattr(settings, "FASTEMBED_CACHE_DIR", fastembed_cache_dir)

    # Replace the real class with a mock so no model is loaded; the adapter is
    # imported and constructed while the patch is active.
    with patch("fastembed.TextEmbedding") as mock_text_embedding:
        from app.repositories.embeddings.embedding_adapter import FastEmbeddingAdapter

        adapter = FastEmbeddingAdapter()

    mock_text_embedding.assert_called_once_with(
        model_name="BAAI/bge-small-en-v1.5",
        cache_dir=fastembed_cache_dir,
        local_files_only=True,
    )
    # The adapter must keep exactly the object the constructor returned.
    assert adapter.model is mock_text_embedding.return_value


def test_fast_embedding_adapter_keeps_default_constructor(monkeypatch, fastembed_cache_dir):
    """With the flag off, TextEmbedding is built without ``local_files_only``."""
    monkeypatch.setattr(settings, "FASTEMBED_LOCAL_FILES_ONLY", False)
    monkeypatch.setattr(settings, "EMBEDDING_MODEL", "BAAI/bge-small-en-v1.5")
    monkeypatch.setattr(settings, "FASTEMBED_CACHE_DIR", fastembed_cache_dir)

    # The adapter is imported and constructed while the mock is in place.
    with patch("fastembed.TextEmbedding") as mock_text_embedding:
        from app.repositories.embeddings.embedding_adapter import FastEmbeddingAdapter

        FastEmbeddingAdapter()

    # Exactly the two baseline kwargs - no offline kwarg is added.
    mock_text_embedding.assert_called_once_with(
        model_name="BAAI/bge-small-en-v1.5",
        cache_dir=fastembed_cache_dir,
    )


def test_fast_embedding_adapter_wraps_offline_load_error(monkeypatch, fastembed_cache_dir):
    """A constructor failure in cache-only mode surfaces as the guidance RuntimeError."""
    monkeypatch.setattr(settings, "FASTEMBED_LOCAL_FILES_ONLY", True)
    monkeypatch.setattr(settings, "EMBEDDING_MODEL", "BAAI/bge-small-en-v1.5")
    monkeypatch.setattr(settings, "FASTEMBED_CACHE_DIR", fastembed_cache_dir)

    # The mocked constructor raises, standing in for a model missing from cache.
    with patch("fastembed.TextEmbedding", side_effect=ValueError("missing model")):
        from app.repositories.embeddings.embedding_adapter import FastEmbeddingAdapter

        # Must run inside the patch context so the failing mock is the one used.
        with pytest.raises(RuntimeError, match="Could not load FastEmbed embedding model"):
            FastEmbeddingAdapter()


def test_fastembed_reranker_passes_local_files_only(monkeypatch, fastembed_cache_dir):
    """FastEmbedCrossEncoderAdapter forwards ``local_files_only=True`` in cache-only mode."""
    monkeypatch.setattr(settings, "FASTEMBED_LOCAL_FILES_ONLY", True)
    monkeypatch.setattr(settings, "FASTEMBED_CACHE_DIR", fastembed_cache_dir)
    mock_encoder = MagicMock()

    # Replace the real cross-encoder class so no model is loaded; the adapter
    # is imported and constructed while the patch is active.
    with patch(
        "fastembed.rerank.cross_encoder.TextCrossEncoder",
        return_value=mock_encoder,
    ) as mock_cls:
        from app.repositories.embeddings.reranker_adapter import (
            FastEmbedCrossEncoderAdapter,
        )

        adapter = FastEmbedCrossEncoderAdapter(
            model="Xenova/ms-marco-MiniLM-L-12-v2",
            threads=2,
            cache_dir=fastembed_cache_dir,
        )

    mock_cls.assert_called_once_with(
        model_name="Xenova/ms-marco-MiniLM-L-12-v2",
        threads=2,
        cache_dir=fastembed_cache_dir,
        local_files_only=True,
    )
    # The adapter must keep exactly the object the constructor returned.
    assert adapter._model is mock_encoder


def test_fastembed_reranker_wraps_offline_load_error(monkeypatch, fastembed_cache_dir):
    """A cross-encoder failure in cache-only mode surfaces as the guidance RuntimeError."""
    monkeypatch.setattr(settings, "FASTEMBED_LOCAL_FILES_ONLY", True)
    monkeypatch.setattr(settings, "FASTEMBED_CACHE_DIR", fastembed_cache_dir)

    # The mocked constructor raises, standing in for a model missing from cache.
    with patch(
        "fastembed.rerank.cross_encoder.TextCrossEncoder",
        side_effect=ValueError("missing model"),
    ):
        from app.repositories.embeddings.reranker_adapter import (
            FastEmbedCrossEncoderAdapter,
        )

        # Must run inside the patch context so the failing mock is the one used.
        with pytest.raises(RuntimeError, match="Could not load FastEmbed reranking model"):
            FastEmbedCrossEncoderAdapter(
                model="Xenova/ms-marco-MiniLM-L-12-v2",
                cache_dir=fastembed_cache_dir,
            )
Loading