Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_extract_json_from_response(self):
result = extract_json_from_response(case["input"])
self.assertEqual(result, case["expected"])

@unittest.skipIf(not os.getenv("OPENAI_API_KEY"), "OpenAI API key required")
def test_extract(self):
# provide an explicit client so we cover the new parameter
client = OpenAI()
Expand Down
65 changes: 64 additions & 1 deletion tests/test_scraper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
import tempfile
import threading
from typing import cast
import unittest
import os
import socket
import sys
import zipfile
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from PIL import Image
import pandas as pd

Expand Down Expand Up @@ -32,7 +35,9 @@ def setUp(self):
self.files_directory = os.path.join(os.path.dirname(__file__), "files")
self.outputs_directory = "outputs"
# create a client we can re-use for ai_extraction scenarios
self.client = OpenAI() if OpenAI is not None else None
self.client = (
OpenAI() if OpenAI is not None and os.getenv("OPENAI_API_KEY") else None
)

def tearDown(self):
# clean up outputs
Expand Down Expand Up @@ -77,6 +82,64 @@ def test_scrape_directory_inclusion_exclusion(self):
text = cast(str, chunks[0].text)
self.assertIn("Y", text)

def test_scrape_url_rejects_file_scheme(self):
    # scrape_url is expected to raise ValueError for a file:// URL, with a
    # message matching the scheme restriction below.
    local_file_url = "file:///tmp/secret.html"
    expected_message = "Only http:// and https:// URLs"
    with self.assertRaisesRegex(ValueError, expected_message):
        scraper.scrape_url(local_file_url)

def test_scrape_url_rejects_localhost_html(self):
    # A loopback URL ending in .html is expected to raise ValueError when
    # scrape_url is called without opting in to local URLs.
    loopback_page = "http://127.0.0.1:8000/internal.html"
    expected_message = "Local and private-network URLs are blocked by default"
    with self.assertRaisesRegex(ValueError, expected_message):
        scraper.scrape_url(loopback_page)

def test_scrape_url_rejects_localhost_download(self):
    # Same expectation as the HTML case, but for a loopback URL ending in
    # .txt: scrape_url should raise ValueError without the opt-in.
    loopback_file = "http://127.0.0.1:8000/secret.txt"
    expected_message = "Local and private-network URLs are blocked by default"
    with self.assertRaisesRegex(ValueError, expected_message):
        scraper.scrape_url(loopback_file)

def test_scrape_url_allows_localhost_with_opt_in(self):
    # With allow_local_urls=True, scrape_url should actually fetch from a
    # loopback HTTP server: exactly one GET for /secret.txt is expected, and
    # the canary body must show up in the single returned chunk.
    canary = "LOCALHOST_DOWNLOAD_ALLOWED"
    request_log = []

    class Handler(BaseHTTPRequestHandler):
        def do_GET(self):
            # Record the requested path, then answer with the canary text.
            request_log.append(self.path)
            body = canary.encode("utf-8")
            self.send_response(200)
            self.send_header("Content-Type", "text/plain; charset=utf-8")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def log_message(self, format, *args):
            # Keep the test output free of per-request log lines.
            return

    # Bind the server itself to port 0 so the OS assigns a free port that the
    # server already owns. Probing with a throwaway socket and re-binding the
    # same number afterwards leaves a window in which another process can
    # take the port.
    server = ThreadingHTTPServer(("127.0.0.1", 0), Handler)
    port = cast(int, server.server_address[1])
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    try:
        chunks = scraper.scrape_url(
            f"http://127.0.0.1:{port}/secret.txt",
            allow_local_urls=True,
        )
    finally:
        # Always stop the server, even when scrape_url raises.
        server.shutdown()
        thread.join(timeout=5)
        server.server_close()

    self.assertEqual(request_log, ["/secret.txt"])
    self.assertEqual(len(chunks), 1)
    self.assertIn(canary, cast(str, chunks[0].text))

def test_scrape_github_rejects_confusable_host(self):
    # "github.com" sits in the userinfo part of this URL; the host is
    # evil.example. scrape_github is expected to raise ValueError naming it.
    spoofed_url = "https://github.com@evil.example/owner/repo"
    expected_message = "hostname 'evil.example'"
    with self.assertRaisesRegex(ValueError, expected_message):
        scraper.scrape_github(spoofed_url)

def test_scrape_html(self):
filepath = os.path.join(self.files_directory, "example.html")
chunks = scraper.scrape_file(filepath, verbose=True)
Expand Down
55 changes: 19 additions & 36 deletions thepipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,28 @@

from openai import OpenAI

from .scraper import scrape_directory, scrape_file, scrape_url
from .core import DEFAULT_AI_MODEL, save_outputs
from .scraper import scrape_directory, scrape_file, scrape_url


# Argument parsing
def parse_arguments() -> argparse.Namespace: # noqa: D401 – imperative is fine here
"""
Parse CLI flags.
def parse_arguments() -> argparse.Namespace: # noqa: D401
"""Parse CLI flags."""

Returns
-------
argparse.Namespace
Parsed arguments.
"""
parser = argparse.ArgumentParser(
prog="thepipe",
description="Universal document/Web scraper with optional OpenAI extraction.",
)

# Required source (file, directory, or URL)
parser.add_argument(
"source",
help="File path, directory, or URL to scrape.",
)

# Optional flags
parser.add_argument(
"-i",
"--inclusion-pattern",
dest="inclusion_pattern",
default=None,
help="Regex pattern – only files whose *full path* matches are scraped "
"(applies to directory/zip scraping).",
help="Regex pattern - only files whose full path matches are scraped (applies to directory/zip scraping).",
)
parser.add_argument(
"-v",
Expand All @@ -51,15 +40,19 @@ def parse_arguments() -> argparse.Namespace: # noqa: D401 – imperative is fin
"--text-only",
dest="text_only",
action="store_true",
help="Suppress images – output only extracted text.",
help="Suppress images - output only extracted text.",
)
parser.add_argument(
"--allow-local-urls",
dest="allow_local_urls",
action="store_true",
help="Allow scraping localhost and private-network HTTP(S) URLs. Disabled by default for security.",
)

# OpenAI-related flags
parser.add_argument(
"--openai-api-key",
dest="openai_api_key",
default=os.getenv("OPENAI_API_KEY"),
help="OpenAI API key. If omitted, env variable OPENAI_API_KEY is used.",
help="OpenAI API key. If omitted, env variable OPENAI_API_KEY is used.",
)
parser.add_argument(
"--openai-base-url",
Expand All @@ -73,61 +66,53 @@ def parse_arguments() -> argparse.Namespace: # noqa: D401 – imperative is fin
default=DEFAULT_AI_MODEL,
help=f"Chat/VLM model to use (default: {DEFAULT_AI_MODEL}).",
)

# Legacy flag (will be removed in future versions)
parser.add_argument(
"--ai-extraction",
action="store_true",
help=argparse.SUPPRESS, # hidden but still accepted
help=argparse.SUPPRESS,
)

return parser.parse_args()


# OpenAI client factory
def create_openai_client(
    *,
    api_key: Optional[str],
    base_url: str,
    enable_vlm: bool,
) -> Optional[OpenAI]:
    """Build the OpenAI client used by the CLI, or return None.

    Args:
        api_key: Key handed to ``OpenAI``. Any truthy value wins over
            ``enable_vlm``.
        base_url: Passed through to ``OpenAI`` as ``base_url``.
        enable_vlm: Legacy ``--ai-extraction`` switch. Only consulted when
            ``api_key`` is falsy; emits a ``DeprecationWarning`` and builds a
            client from the ``OPENAI_API_KEY`` environment variable.

    Returns:
        An ``OpenAI`` client, or ``None`` when ``api_key`` is falsy and
        ``enable_vlm`` is false.
    """
    if api_key:
        return OpenAI(api_key=api_key, base_url=base_url)

    if not enable_vlm:
        return None

    # Deprecated path: the message is given exactly once (a single implicit
    # string concatenation) so the category and stacklevel arguments line up.
    warnings.warn(
        "--ai-extraction is deprecated; please use --openai-api-key and "
        "--openai-model (and optionally --openai-base-url) instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return OpenAI(base_url=base_url, api_key=os.getenv("OPENAI_API_KEY"))


def main() -> None:
"""CLI entry point"""
args = parse_arguments()
"""CLI entry point."""

# Instantiate the OpenAI client if requested
args = parse_arguments()
openai_client = create_openai_client(
api_key=args.openai_api_key,
base_url=args.openai_base_url,
enable_vlm=args.ai_extraction,
)

# Delegate scraping based on source type
if args.source.startswith(("http://", "https://")):
chunks = scrape_url(
args.source,
verbose=args.verbose,
openai_client=openai_client,
model=args.openai_model,
allow_local_urls=args.allow_local_urls,
)
elif os.path.isdir(args.source):
chunks = scrape_directory(
Expand All @@ -146,7 +131,6 @@ def main() -> None:
else:
raise ValueError(f"Invalid source: {args.source}")

# Persist results
save_outputs(
chunks=chunks,
verbose=args.verbose,
Expand All @@ -155,9 +139,8 @@ def main() -> None:
)

if args.verbose:
print(f"Scraping complete. Outputs saved to 'thepipe_output/'.")
print("Scraping complete. Outputs saved to 'thepipe_output/'.")


# Entry-point shim
if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions thepipe/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def extract_from_url(
verbose: bool = False,
chunking_method: Callable[[List[Chunk]], List[Chunk]] = chunk_by_page,
openai_client: Optional[OpenAI] = None,
allow_local_urls: bool = False,
) -> Tuple[List[Dict], int]:
print(
f"[thepipe] Extract functions will be deprecated in future versions. See the README for more information"
Expand All @@ -227,6 +228,7 @@ def extract_from_url(
verbose=verbose,
chunking_method=chunking_method,
openai_client=openai_client,
allow_local_urls=allow_local_urls,
)
extracted_chunks, tokens_used = extract(
chunks=chunks,
Expand Down
Loading
Loading