diff --git a/README.md b/README.md index 7aa6d2d..fd5f342 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,39 @@ pip install thepipe-api[all] By default, thepipe uses the [OpenAI API](https://platform.openai.com/docs/overview), so VLM features will work out-of-the-box provided you pass in an OpenAI client. +### MiniMax setup + +[MiniMax](https://www.minimaxi.com/) is supported as a first-class provider. Its OpenAI-compatible API works with all of thepipe's LLM features (PDF VLM scraping, webpage analysis, agentic chunking, and structured extraction). + +```python +from thepipe.provider import create_provider_client +from thepipe.scraper import scrape_file + +client, preset = create_provider_client("minimax") + +chunks = scrape_file( + filepath="paper.pdf", + openai_client=client, + model="MiniMax-M2.7", # or MiniMax-M2.5, MiniMax-M2.5-highspeed +) +``` + +Or via the CLI: + +```bash +export MINIMAX_API_KEY=your-key +thepipe paper.pdf --provider minimax --verbose +``` + +Available MiniMax models: + +| Model | Context | Notes | +|---|---|---| +| `MiniMax-M2.7` | 1M tokens | Latest, recommended | +| `MiniMax-M2.7-highspeed` | 1M tokens | Faster inference | +| `MiniMax-M2.5` | 204K tokens | Previous generation | +| `MiniMax-M2.5-highspeed` | 204K tokens | Fast inference | + ### Custom VLM server setup (OpenRouter, OpenLLM, etc.) If you wish to use a local vision-language model or a different cloud provider, you can provide a custom OpenAI client, for example, by setting the base url to `https://openrouter.ai/api/v1` for [OpenRouter](https://openrouter.ai/), or `http://localhost:3000/v1` for a local server such as [OpenLLM](https://github.com/bentoml/OpenLLM). Note that uou must also pass the api key to your non-OpenAI cloud provider into the OpenAI client. The model name can be changed with the `model` parameter. By default, the model will be `gpt-4o`. @@ -277,6 +310,12 @@ export GITHUB_TOKEN=... `thepipe [options]` +### Provider selection + +`--provider=NAME` LLM provider to use (`openai`, `minimax`). Auto-detected from environment variables if omitted. + +`--api-key=KEY` API key for the selected provider. Falls back to the provider's environment variable. + ### AI scraping options `--openai-api-key=KEY` To enable VLM scraping, pass in your OpenAI API key diff --git a/tests/test_provider.py b/tests/test_provider.py new file mode 100644 index 0000000..55e0f3e --- /dev/null +++ b/tests/test_provider.py @@ -0,0 +1,315 @@ +"""Unit and integration tests for thepipe.provider (MiniMax support).""" + +import json +import os +import unittest +from unittest.mock import MagicMock, patch + +from thepipe.provider import ( + MINIMAX_PRESET, + OPENAI_PRESET, + PROVIDER_PRESETS, + ProviderPreset, + clamp_temperature, + create_provider_client, + detect_provider, + get_provider_preset, + strip_think_tags, +) + + +# --------------------------------------------------------------------------- +# Unit tests +# --------------------------------------------------------------------------- + + +class TestProviderPreset(unittest.TestCase): + """Tests for the ProviderPreset dataclass.""" + + def test_openai_preset_exists(self): + self.assertIn("openai", PROVIDER_PRESETS) + self.assertEqual(OPENAI_PRESET.name, "openai") + self.assertEqual(OPENAI_PRESET.base_url, "https://api.openai.com/v1") + + def test_minimax_preset_exists(self): + self.assertIn("minimax", PROVIDER_PRESETS) + self.assertEqual(MINIMAX_PRESET.name, "minimax") + self.assertEqual(MINIMAX_PRESET.base_url, "https://api.minimax.io/v1") + self.assertEqual(MINIMAX_PRESET.default_model, "MiniMax-M2.7") + self.assertEqual(MINIMAX_PRESET.api_key_env, "MINIMAX_API_KEY") + + def test_minimax_models(self): + models = MINIMAX_PRESET.models + self.assertIn("MiniMax-M2.7", models) + self.assertIn("MiniMax-M2.7-highspeed", models) + self.assertIn("MiniMax-M2.5", models) + self.assertIn("MiniMax-M2.5-highspeed", models) + + def test_minimax_temperature_range(self): + self.assertEqual(MINIMAX_PRESET.temperature_min, 0.0) + self.assertEqual(MINIMAX_PRESET.temperature_max, 1.0) + + +class TestGetProviderPreset(unittest.TestCase): + """Tests for get_provider_preset().""" + + def test_known_provider_openai(self): + preset = get_provider_preset("openai") + self.assertEqual(preset.name, "openai") + + def test_known_provider_minimax(self): + preset = get_provider_preset("minimax") + self.assertEqual(preset.name, "minimax") + + def test_case_insensitive(self): + preset = get_provider_preset("MiniMax") + self.assertEqual(preset.name, "minimax") + + def test_unknown_provider_raises(self): + with self.assertRaises(ValueError) as ctx: + get_provider_preset("nonexistent") + self.assertIn("nonexistent", str(ctx.exception)) + self.assertIn("Available providers", str(ctx.exception)) + + +class TestDetectProvider(unittest.TestCase): + """Tests for detect_provider().""" + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}, clear=False) + def test_detects_minimax_when_only_minimax_key(self): + with patch.dict(os.environ, {}, clear=False): + # Remove OPENAI_API_KEY if present + env = os.environ.copy() + env.pop("OPENAI_API_KEY", None) + with patch.dict(os.environ, env, clear=True): + os.environ["MINIMAX_API_KEY"] = "test-key" + self.assertEqual(detect_provider(), "minimax") + + @patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False) + def test_defaults_to_openai(self): + self.assertEqual(detect_provider(), "openai") + + @patch.dict(os.environ, {}, clear=True) + def test_defaults_to_openai_when_no_keys(self): + self.assertEqual(detect_provider(), "openai") + + +class TestClampTemperature(unittest.TestCase): + """Tests for clamp_temperature().""" + + def test_none_returns_none(self): + self.assertIsNone(clamp_temperature(None, MINIMAX_PRESET)) + + def test_within_range(self): + self.assertEqual(clamp_temperature(0.5, MINIMAX_PRESET), 0.5) + + def test_above_max_clamped(self): + self.assertEqual(clamp_temperature(1.5, MINIMAX_PRESET), 1.0) + + def test_below_min_clamped(self): + self.assertEqual(clamp_temperature(-0.1, MINIMAX_PRESET), 0.0) + + def test_zero_accepted(self): + self.assertEqual(clamp_temperature(0.0, MINIMAX_PRESET), 0.0) + + def test_openai_wider_range(self): + self.assertEqual(clamp_temperature(1.5, OPENAI_PRESET), 1.5) + self.assertEqual(clamp_temperature(2.5, OPENAI_PRESET), 2.0) + + +class TestStripThinkTags(unittest.TestCase): + """Tests for strip_think_tags().""" + + def test_no_think_tags(self): + text = "Hello world" + self.assertEqual(strip_think_tags(text), "Hello world") + + def test_single_think_tag(self): + text = "reasoning hereThe answer is 42." + self.assertEqual(strip_think_tags(text), "The answer is 42.") + + def test_multiline_think_tag(self): + text = "\nStep 1: analyze\nStep 2: compute\n\nResult: done" + self.assertEqual(strip_think_tags(text), "Result: done") + + def test_multiple_think_tags(self): + text = "aHello bworld" + self.assertEqual(strip_think_tags(text), "Hello world") + + def test_empty_think_tag(self): + text = "Just the output" + self.assertEqual(strip_think_tags(text), "Just the output") + + +class TestCreateProviderClient(unittest.TestCase): + """Tests for create_provider_client().""" + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-minimax-key"}, clear=False) + def test_creates_minimax_client(self): + client, preset = create_provider_client("minimax") + self.assertEqual(preset.name, "minimax") + self.assertEqual(preset.default_model, "MiniMax-M2.7") + + @patch.dict(os.environ, {"OPENAI_API_KEY": "test-openai-key"}, clear=False) + def test_creates_openai_client(self): + client, preset = create_provider_client("openai") + self.assertEqual(preset.name, "openai") + + def test_explicit_api_key(self): + client, preset = create_provider_client("minimax", api_key="explicit-key") + self.assertEqual(preset.name, "minimax") + + @patch.dict(os.environ, {}, clear=True) + def test_no_api_key_raises(self): + with self.assertRaises(ValueError) as ctx: + create_provider_client("minimax") + self.assertIn("MINIMAX_API_KEY", str(ctx.exception)) + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}, clear=False) + def test_custom_base_url(self): + client, preset = create_provider_client( + "minimax", base_url="https://custom.api.example.com/v1" + ) + self.assertEqual(preset.name, "minimax") + + +class TestChunkerAgenticFallback(unittest.TestCase): + """Tests for chunk_agentic() MiniMax fallback.""" + + def test_agentic_json_fallback(self): + """Verify chunk_agentic() falls back to json_object mode when + .beta.chat.completions.parse() is unavailable.""" + from thepipe.chunker import chunk_agentic + from thepipe.core import Chunk + + chunks = [Chunk(path="test.txt", text="Line one\nLine two\nLine three")] + + # Mock an OpenAI client whose .beta.chat.completions.parse() fails + mock_client = MagicMock() + mock_client.beta.chat.completions.parse.side_effect = Exception( + "Not supported" + ) + + # Mock fallback .chat.completions.create() to return valid JSON + sections_json = json.dumps( + { + "sections": [ + {"title": "All", "start_line": 1, "end_line": 3}, + ] + } + ) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = sections_json + mock_client.chat.completions.create.return_value = mock_response + + result = chunk_agentic(chunks, openai_client=mock_client, model="MiniMax-M2.7") + + self.assertIsInstance(result, list) + self.assertGreater(len(result), 0) + self.assertIn("Line one", result[0].text) + + def test_agentic_fallback_with_think_tags(self): + """Verify think tags are stripped in agentic chunking fallback.""" + from thepipe.chunker import chunk_agentic + from thepipe.core import Chunk + + chunks = [Chunk(path="test.txt", text="Hello world")] + + mock_client = MagicMock() + mock_client.beta.chat.completions.parse.side_effect = Exception( + "Not supported" + ) + + sections_json = ( + 'reasoning about sections' + + json.dumps( + {"sections": [{"title": "Intro", "start_line": 1, "end_line": 1}]} + ) + ) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = sections_json + mock_client.chat.completions.create.return_value = mock_response + + result = chunk_agentic(chunks, openai_client=mock_client, model="MiniMax-M2.5") + + self.assertIsInstance(result, list) + self.assertGreater(len(result), 0) + + +class TestProviderPresetImmutability(unittest.TestCase): + """Verify presets are frozen dataclasses.""" + + def test_preset_is_frozen(self): + with self.assertRaises(AttributeError): + MINIMAX_PRESET.name = "changed" # type: ignore[misc] + + def test_preset_has_correct_fields(self): + self.assertIsInstance(MINIMAX_PRESET, ProviderPreset) + self.assertIsInstance(MINIMAX_PRESET.models, dict) + + +# --------------------------------------------------------------------------- +# Integration tests (require MINIMAX_API_KEY) +# --------------------------------------------------------------------------- + +HAS_MINIMAX_KEY = bool(os.getenv("MINIMAX_API_KEY")) + + +@unittest.skipUnless(HAS_MINIMAX_KEY, "MINIMAX_API_KEY not set") +class TestMiniMaxIntegration(unittest.TestCase): + """Integration tests that hit the real MiniMax API.""" + + def setUp(self): + self.client, self.preset = create_provider_client("minimax") + self.files_directory = os.path.join(os.path.dirname(__file__), "files") + + def test_minimax_chat_completion(self): + """Verify basic chat completion works with MiniMax.""" + response = self.client.chat.completions.create( + model=self.preset.default_model, + messages=[{"role": "user", "content": "Say hello in one word."}], + ) + content = response.choices[0].message.content + self.assertIsNotNone(content) + self.assertGreater(len(content), 0) + + def test_minimax_json_mode(self): + """Verify json_object response_format works with MiniMax.""" + response = self.client.chat.completions.create( + model=self.preset.default_model, + messages=[ + { + "role": "user", + "content": 'Return a JSON object with key "status" and value "ok".', + } + ], + response_format={"type": "json_object"}, + ) + content = response.choices[0].message.content + self.assertIsNotNone(content) + data = json.loads(strip_think_tags(content)) + self.assertIn("status", data) + + def test_minimax_scrape_pdf_text_only(self): + """Verify MiniMax can be used for PDF scraping.""" + from thepipe.scraper import scrape_file + + pdf_path = os.path.join(self.files_directory, "example.pdf") + if not os.path.exists(pdf_path): + self.skipTest("example.pdf not found in test files") + + chunks = scrape_file( + filepath=pdf_path, + openai_client=self.client, + model=self.preset.default_model, + include_input_images=False, + ) + self.assertIsInstance(chunks, list) + self.assertGreater(len(chunks), 0) + self.assertTrue(any(c.text for c in chunks)) + + +if __name__ == "__main__": + unittest.main() diff --git a/thepipe/__init__.py b/thepipe/__init__.py index 93702da..c988ef4 100644 --- a/thepipe/__init__.py +++ b/thepipe/__init__.py @@ -9,6 +9,12 @@ from .scraper import scrape_directory, scrape_file, scrape_url from .core import DEFAULT_AI_MODEL, save_outputs +from .provider import ( + PROVIDER_PRESETS, + create_provider_client, + detect_provider, + get_provider_preset, +) # Argument parsing @@ -23,7 +29,7 @@ def parse_arguments() -> argparse.Namespace: # noqa: D401 – imperative is fin """ parser = argparse.ArgumentParser( prog="thepipe", - description="Universal document/Web scraper with optional OpenAI extraction.", + description="Universal document/Web scraper with optional LLM extraction.", ) # Required source (file, directory, or URL) @@ -54,7 +60,23 @@ def parse_arguments() -> argparse.Namespace: # noqa: D401 – imperative is fin help="Suppress images – output only extracted text.", ) - # OpenAI-related flags + # Provider selection + available_providers = ", ".join(sorted(PROVIDER_PRESETS)) + parser.add_argument( + "--provider", + default=None, + help=f"LLM provider to use ({available_providers}). " + "Auto-detected from environment variables if omitted.", + ) + parser.add_argument( + "--api-key", + dest="api_key", + default=None, + help="API key for the selected provider. " + "Falls back to the provider's environment variable (e.g. OPENAI_API_KEY, MINIMAX_API_KEY).", + ) + + # OpenAI-related flags (kept for backwards compatibility) parser.add_argument( "--openai-api-key", dest="openai_api_key", @@ -114,34 +136,46 @@ def main() -> None: """CLI entry point""" args = parse_arguments() - # Instantiate the OpenAI client if requested - openai_client = create_openai_client( - api_key=args.openai_api_key, - base_url=args.openai_base_url, - enable_vlm=args.ai_extraction, - ) + # Determine model and client based on provider selection + if args.provider or args.api_key: + # New provider-based path + provider_name = args.provider or detect_provider() + preset = get_provider_preset(provider_name) + client, _ = create_provider_client( + provider=provider_name, + api_key=args.api_key, + ) + model = args.openai_model if args.openai_model != DEFAULT_AI_MODEL else preset.default_model + else: + # Legacy OpenAI-only path + client = create_openai_client( + api_key=args.openai_api_key, + base_url=args.openai_base_url, + enable_vlm=args.ai_extraction, + ) + model = args.openai_model # Delegate scraping based on source type if args.source.startswith(("http://", "https://")): chunks = scrape_url( args.source, verbose=args.verbose, - openai_client=openai_client, - model=args.openai_model, + openai_client=client, + model=model, ) elif os.path.isdir(args.source): chunks = scrape_directory( dir_path=args.source, inclusion_pattern=args.inclusion_pattern, verbose=args.verbose, - openai_client=openai_client, + openai_client=client, ) elif os.path.isfile(args.source): chunks = scrape_file( filepath=args.source, verbose=args.verbose, - openai_client=openai_client, - model=args.openai_model, + openai_client=client, + model=model, ) else: raise ValueError(f"Invalid source: {args.source}") diff --git a/thepipe/chunker.py b/thepipe/chunker.py index 3510dbb..733eebf 100644 --- a/thepipe/chunker.py +++ b/thepipe/chunker.py @@ -1,3 +1,4 @@ +import json import re from typing import Dict, List, Optional, Tuple, Union from .core import ( @@ -6,6 +7,7 @@ DEFAULT_AI_MODEL, DEFAULT_EMBEDDING_MODEL, ) +from .provider import strip_think_tags import numpy as np from pydantic import BaseModel from openai import OpenAI @@ -303,19 +305,46 @@ def chunk_agentic( ) user_prompt = numbered - completion = openai_client.beta.chat.completions.parse( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - response_format=SectionList, - ) - - if not completion.choices[0].message.parsed: - raise ValueError( - "LLM did not return a valid response during agentic chunking." + # Try structured output first (.beta.chat.completions.parse); + # fall back to json_object mode for providers that don't support it + # (e.g. MiniMax). + try: + completion = openai_client.beta.chat.completions.parse( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + response_format=SectionList, + ) + if not completion.choices[0].message.parsed: + raise ValueError( + "LLM did not return a valid response during agentic chunking." + ) + except Exception: + # Fallback: use json_object mode and parse manually + fallback = openai_client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, ) + raw = fallback.choices[0].message.content or "" + raw = strip_think_tags(raw) + parsed_data = json.loads(raw) + completion = type("_Stub", (), { + "choices": [type("_Choice", (), { + "message": type("_Msg", (), { + "parsed": SectionList(**parsed_data), + })(), + })()], + })() + if not completion.choices[0].message.parsed: + raise ValueError( + "LLM did not return a valid response during agentic chunking." + ) sections: List[Section] = completion.choices[0].message.parsed.sections diff --git a/thepipe/provider.py b/thepipe/provider.py new file mode 100644 index 0000000..7bc18ee --- /dev/null +++ b/thepipe/provider.py @@ -0,0 +1,148 @@ +"""LLM provider presets and client factory for thepipe. + +Provides a clean abstraction over different OpenAI-compatible LLM providers, +allowing users to switch between OpenAI, MiniMax, and others via a single +``--provider`` flag or ``LLM_PROVIDER`` environment variable. +""" + +from __future__ import annotations + +import os +import re +from dataclasses import dataclass, field +from typing import Dict, Optional + +from openai import OpenAI + + +@dataclass(frozen=True) +class ProviderPreset: + """Immutable configuration for an OpenAI-compatible LLM provider.""" + + name: str + base_url: str + default_model: str + api_key_env: str + models: Dict[str, str] = field(default_factory=dict) + temperature_min: float = 0.0 + temperature_max: float = 2.0 + + +# --------------------------------------------------------------------------- +# Built-in provider presets +# --------------------------------------------------------------------------- + +OPENAI_PRESET = ProviderPreset( + name="openai", + base_url="https://api.openai.com/v1", + default_model="gpt-4o", + api_key_env="OPENAI_API_KEY", + models={ + "gpt-4o": "GPT-4o (latest)", + "gpt-4o-mini": "GPT-4o Mini", + "gpt-4-turbo": "GPT-4 Turbo", + }, +) + +MINIMAX_PRESET = ProviderPreset( + name="minimax", + base_url="https://api.minimax.io/v1", + default_model="MiniMax-M2.7", + api_key_env="MINIMAX_API_KEY", + models={ + "MiniMax-M2.7": "MiniMax M2.7 (latest, 1M context)", + "MiniMax-M2.7-highspeed": "MiniMax M2.7 High-Speed", + "MiniMax-M2.5": "MiniMax M2.5 (204K context)", + "MiniMax-M2.5-highspeed": "MiniMax M2.5 High-Speed (204K context)", + }, + temperature_min=0.0, + temperature_max=1.0, +) + +PROVIDER_PRESETS: Dict[str, ProviderPreset] = { + "openai": OPENAI_PRESET, + "minimax": MINIMAX_PRESET, +} + + +def get_provider_preset(name: str) -> ProviderPreset: + """Return a :class:`ProviderPreset` by name (case-insensitive). + + Raises ``ValueError`` for unknown providers. + """ + key = name.lower() + if key not in PROVIDER_PRESETS: + available = ", ".join(sorted(PROVIDER_PRESETS)) + raise ValueError( + f"Unknown provider '{name}'. Available providers: {available}" + ) + return PROVIDER_PRESETS[key] + + +def detect_provider() -> str: + """Auto-detect provider from available environment variables. + + Returns ``"minimax"`` if ``MINIMAX_API_KEY`` is set (and ``OPENAI_API_KEY`` + is not), otherwise ``"openai"``. + """ + if os.getenv("MINIMAX_API_KEY") and not os.getenv("OPENAI_API_KEY"): + return "minimax" + return "openai" + + +def clamp_temperature( + temperature: Optional[float], preset: ProviderPreset +) -> Optional[float]: + """Clamp *temperature* to the provider's valid range, or return ``None``.""" + if temperature is None: + return None + return max(preset.temperature_min, min(temperature, preset.temperature_max)) + + +def strip_think_tags(text: str) -> str: + """Remove ```` blocks from LLM output. + + Some MiniMax models emit reasoning traces wrapped in ```` tags that + should be stripped before returning results to the caller. + """ + return re.sub(r"[\s\S]*?", "", text).strip() + + +def create_provider_client( + provider: Optional[str] = None, + *, + api_key: Optional[str] = None, + base_url: Optional[str] = None, +) -> tuple[OpenAI, ProviderPreset]: + """Create an :class:`OpenAI` client configured for the given *provider*. + + Parameters + ---------- + provider: + Provider name (``"openai"`` or ``"minimax"``). When ``None`` the + provider is auto-detected from environment variables. + api_key: + Explicit API key. Falls back to the provider's ``api_key_env``. + base_url: + Override the provider's default base URL. + + Returns + ------- + tuple[OpenAI, ProviderPreset] + A configured client and the resolved preset. + """ + if provider is None: + provider = detect_provider() + + preset = get_provider_preset(provider) + resolved_key = api_key or os.getenv(preset.api_key_env, "") + resolved_url = base_url or preset.base_url + + if not resolved_key: + raise ValueError( + f"No API key found for provider '{preset.name}'. " + f"Set the {preset.api_key_env} environment variable or pass --api-key." + ) + + client = OpenAI(api_key=resolved_key, base_url=resolved_url) + return client, preset