diff --git a/.gitignore b/.gitignore index 88c2c59..964beba 100755 --- a/.gitignore +++ b/.gitignore @@ -38,5 +38,8 @@ repos/ # OAuth token cache for mcp-remote bridges (contains live access/refresh tokens) .mcp-auth/ +# OAuth token cache for REST providers' authorization_code flow (live tokens) +.rest-auth/ + # Legacy Playwright MCP output directory (replaced by ./files in docker-compose) .playwright-mcp/ diff --git a/README.md b/README.md index 7c31a0b..3a8ddcf 100755 --- a/README.md +++ b/README.md @@ -114,6 +114,7 @@ Click **+ New Provider** and choose a provider type: | **Python code** | Write `async def` functions; the UI lists the ones it finds as you type. Each becomes a tool entry. | | **Package** | Enter any command that launches a stdio MCP server (`npx`, `uvx`, `python -m`, or an installed binary). When you click **Next**, mcpproxy auto-introspects the command and pre-populates the tool list; if introspection fails you can still proceed and add tools by hand. | | **Repository** | Provide a git URL and a list of build commands. mcpproxy clones the repo, runs the build commands, then introspects the resulting stdio MCP server. The URL and build commands are persisted in YAML so the repo can be re-cloned and re-built automatically on every container restart. | +| **REST / OAuth API** | Point at a REST API: a base URL plus an OpenAPI spec (imported into tools automatically) or hand-entered endpoints, with optional OAuth. Each endpoint becomes an MCP tool. See [REST / OAuth providers](#rest--oauth-providers). | After the provider step, the wizard shows a **Secrets** step: any `secrets.env` entries in the provider are listed, and you can fill in their values to save them directly to `.env`. @@ -143,6 +144,106 @@ ready when needed. > **After editing and saving** a provider's command or setup steps, click **Restart MCP Server** > (the yellow bar that appears after saving) to apply the changes. +## REST / OAuth providers + +A **REST provider** wraps an HTTP/REST API directly — no Python and no separate MCP +server needed. A provider YAML with a `rest:` block declares a base URL, an `auth:` +block, and a set of endpoints; each endpoint becomes an MCP tool. mcpproxy builds the +HTTP request (path/query/body), attaches authentication, and returns the JSON response. + +Create one through the **+ New Provider → REST / OAuth API** wizard. You can **import an +OpenAPI spec** (URL or file — OpenAPI 3.x or Swagger 2.0) to generate the endpoints and tools automatically, or +enter endpoints by hand. OpenAPI specs are expanded into concrete endpoints when the +provider is created, so startup stays fast and offline. + +After creation, the editor lets you **edit everything inline** — the base URL, the auth +block, default headers (sent on every request), and the endpoint list (method, path, and +which params go in the path / query / body). Adding or removing an endpoint keeps its +paired tool in sync (endpoints map 1:1 to tools by name), and **⟳ Sync params to tool +schema** regenerates a tool's input schema from its endpoint's params. + +Large responses are **truncated** to a bounded preview (with a `truncated` flag) so a +single call can't flood the model's context — tune or disable via `MCPPROXY_REST_MAX_BYTES`. + +### Authentication + +The `auth.type` field selects how requests are authenticated. Secrets are referenced by +**environment-variable name** (the `*_env` fields) and filled in via the Secrets UI / `.env` — +never written into the YAML. + +| `auth.type` | Fields | Behaviour | +|---|---|---| +| `none` | — | No authentication. | +| `bearer` | `token_env` | Sends `Authorization: Bearer `. | +| `api_key` | `value_env`, plus either `header` (default `X-Api-Key`) or `in: query` + `name` | Sends the secret in a custom header, or as a query parameter when `in: query`. | +| `client_credentials` | `token_url`, `client_id_env`, `client_secret_env`, `scopes` | OAuth2 client-credentials. Token is fetched, cached, and auto-refreshed on expiry/401. | +| `authorization_code` | `authorize_url`, `token_url`, `client_id_env`, `client_secret_env` (optional for PKCE), `scopes` | Interactive OAuth2 + PKCE. Click **🔐 Authorize** in the editor to complete the browser flow; tokens are cached and refreshed automatically. | + +For `authorization_code`, register the redirect URI **`/oauth/callback`** +(default `http://localhost:8889/oauth/callback`) with your OAuth provider. Tokens are cached +under `MCPPROXY_REST_AUTH_DIR` (default `/app/.rest-auth`, gitignored). + +### Example + +```yaml +rest: + base_url: https://api.example.com/v1 + headers: + Accept: application/json + auth: + type: client_credentials + token_url: https://auth.example.com/oauth/token + client_id_env: EXAMPLE_CLIENT_ID + client_secret_env: EXAMPLE_CLIENT_SECRET + scopes: [read, write] + endpoints: + - name: get_user + method: GET + path: /users/{user_id} + path_params: [user_id] + query_params: [include] + body_params: [] + - name: create_item + method: POST + path: /items + path_params: [] + query_params: [] + body_params: [title, body] + +requirements: [httpx] + +tools: + - name: get_user + description: Fetch a user by id. + input_schema: + type: object + properties: + user_id: {type: string} + include: {type: string} + required: [user_id] + - name: create_item + description: Create an item. + input_schema: + type: object + properties: + title: {type: string} + body: {type: string} + required: [title] +``` + +Each tool's `name` maps 1:1 to an endpoint's `name`. REST providers depend on `httpx` +(installed by default). + +At startup, OAuth-backed REST providers are **warmed**: `client_credentials` tokens are +fetched and cached, and `authorization_code` providers that have no usable token surface +their **🔐 Authorize** link in the banner immediately, rather than only after the first +failed tool call. (Disable with `MCPPROXY_WARM_REMOTE=0`.) + +Config knobs: `MCPPROXY_REST_AUTH_DIR`, `MCPPROXY_OAUTH_REDIRECT_BASE`, +`MCPPROXY_REST_TIMEOUT` (per-request HTTP timeout), `MCPPROXY_REST_MAX_BYTES` (max +response size before truncation; 0 disables), and `MCPPROXY_OAUTH_FLOW_TTL` (seconds an +in-flight authorization attempt stays valid; default 600). + ## Secrets Each tool provider YAML declares its required environment variables under `secrets.env`: diff --git a/config.py b/config.py index e27a737..95baaab 100644 --- a/config.py +++ b/config.py @@ -19,6 +19,20 @@ # Override with MCPPROXY_REPOS_DIR. REPOS_DIR = Path(os.environ.get("MCPPROXY_REPOS_DIR", "/app/repos")) +# Directory where REST providers cache OAuth tokens (authorization_code flow). +# One JSON file per provider (e.g. /app/.rest-auth/.json) holding the +# access/refresh tokens and expiry. Gitignored. Override with +# MCPPROXY_REST_AUTH_DIR (run_local.sh points it at ./.rest-auth for local runs). +REST_AUTH_DIR = Path(os.environ.get("MCPPROXY_REST_AUTH_DIR", "/app/.rest-auth")) + +# Public base URL the OAuth provider redirects back to after the user authorizes +# a REST provider's authorization_code flow. The callback route is served by the +# UI app at "/oauth/callback", so this must match a redirect URI registered +# with the OAuth provider. Override with MCPPROXY_OAUTH_REDIRECT_BASE. +OAUTH_REDIRECT_BASE = os.environ.get( + "MCPPROXY_OAUTH_REDIRECT_BASE", "http://localhost:8889" +).rstrip("/") + UI_HOST = os.environ.get("MCP_UI_HOST", "0.0.0.0") UI_PORT = int(os.environ.get("MCP_UI_PORT", "8889")) diff --git a/docker-compose.yml b/docker-compose.yml index f69a1ff..e03f533 100755 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -51,6 +51,8 @@ services: # Where mcp-remote caches OAuth tokens (access + refresh). Persisted via # the mcpproxy-mcp-auth volume so you authorize once and refresh silently. MCP_REMOTE_CONFIG_DIR: "/app/.mcp-auth" + # Where REST providers cache OAuth tokens for the authorization_code flow. + MCPPROXY_REST_AUTH_DIR: "/app/.rest-auth" volumes: - mcpproxy-tools:/app/tools - mcpproxy-files:/app/files @@ -59,6 +61,7 @@ services: - mcpproxy-npm:/root/.npm - mcpproxy-uv-tools:/root/.local/share/uv - mcpproxy-mcp-auth:/app/.mcp-auth + - mcpproxy-rest-auth:/app/.rest-auth - ./.env:/app/.env volumes: @@ -69,3 +72,4 @@ volumes: mcpproxy-npm: mcpproxy-uv-tools: mcpproxy-mcp-auth: + mcpproxy-rest-auth: diff --git a/frontend/app.py b/frontend/app.py index 3ce3e8f..5118778 100644 --- a/frontend/app.py +++ b/frontend/app.py @@ -21,6 +21,7 @@ import ast import asyncio +import html import fcntl import json import os @@ -43,7 +44,7 @@ from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse -from config import CONFIG_DIR, ENV_FILE, REPOS_DIR +from config import CONFIG_DIR, ENV_FILE, FILES_DIR, REPOS_DIR # --------------------------------------------------------------------------- @@ -115,9 +116,21 @@ def _extract_secret_env_keys(spec: dict[str, Any]) -> list[str]: for key in (spec.get("repository") or {}).get("env_keys") or []: if key and key not in keys: keys.append(key) + # REST providers reference auth secrets by env-var name (``*_env`` keys) in + # the auth block, so surface those for the Secrets UI / missing-secrets badge. + for key in _rest_auth_env_keys(spec): + if key and key not in keys: + keys.append(key) return keys +def _rest_auth_env_keys(spec: dict[str, Any]) -> list[str]: + """Return the env-var names referenced by a REST provider's auth block.""" + auth = (spec.get("rest") or {}).get("auth") or {} + candidates = ("token_env", "value_env", "client_id_env", "client_secret_env") + return [auth[k] for k in candidates if auth.get(k)] + + _ENV_EXAMPLE_CANDIDATES = (".env.example", ".env.sample", ".env.template") @@ -184,6 +197,11 @@ def _get_repository_spec(spec: dict[str, Any]) -> dict[str, Any] | None: return spec.get("repository") or None +def _get_rest_spec(spec: dict[str, Any]) -> dict[str, Any] | None: + """Return the rest sub-dict (rest:), or None for non-REST providers.""" + return spec.get("rest") or None + + def _safe_provider_dirname(name: str) -> str: """Normalize a provider name into a safe single-segment directory name.""" safe = re.sub(r"[^a-zA-Z0-9_-]", "-", name or "").strip("-") @@ -229,7 +247,24 @@ def _provider_to_structured(name: str, spec: dict[str, Any]) -> dict[str, Any]: pkg_sub = _get_package_spec(spec) repo_sub = _get_repository_spec(spec) - if repo_sub is not None: + rest_sub = _get_rest_spec(spec) + rest_out: dict[str, Any] = {} + if rest_sub is not None: + ptype = "rest" + command = "" + repo_url = "" + repo_ref = "" + build_commands = [] + workdir = "" + repo_env_keys = [] + rest_out = { + "base_url": (rest_sub.get("base_url") or "").strip(), + "headers": dict(rest_sub.get("headers") or {}), + "auth": dict(rest_sub.get("auth") or {"type": "none"}), + "openapi": (rest_sub.get("openapi") or "").strip(), + "endpoints": list(rest_sub.get("endpoints") or []), + } + elif repo_sub is not None: ptype = "repository" command = (pkg_sub.get("command") if pkg_sub else "") or "" command = command.strip() @@ -268,6 +303,7 @@ def _provider_to_structured(name: str, spec: dict[str, Any]) -> dict[str, Any]: "build_commands": build_commands, "repo_env_keys": repo_env_keys, "workdir": workdir, + "rest": rest_out, "tools": tools_out, } @@ -282,7 +318,25 @@ def _structured_to_yaml(provider: dict[str, Any]) -> str: ptype = provider.get("type", "code") - if ptype == "package": + if ptype == "rest": + rest_in = provider.get("rest") or {} + rest_block: dict[str, Any] = { + "base_url": (rest_in.get("base_url") or "").strip(), + } + headers = {k: v for k, v in (rest_in.get("headers") or {}).items() if k} + if headers: + rest_block["headers"] = headers + auth = dict(rest_in.get("auth") or {"type": "none"}) + auth.setdefault("type", "none") + rest_block["auth"] = auth + openapi = (rest_in.get("openapi") or "").strip() + if openapi: + rest_block["openapi"] = openapi + endpoints = [e for e in (rest_in.get("endpoints") or []) if e.get("name")] + if endpoints: + rest_block["endpoints"] = endpoints + spec["rest"] = rest_block + elif ptype == "package": spec["package"] = {"command": (provider.get("command") or "").strip()} elif ptype == "repository": spec["package"] = {"command": (provider.get("command") or "").strip()} @@ -354,11 +408,52 @@ def _structured_to_yaml(provider: dict[str, Any]) -> str: # Validation # --------------------------------------------------------------------------- +_REST_AUTH_TYPES = {"none", "bearer", "api_key", "client_credentials", "authorization_code"} + + +def _validate_rest(provider: dict[str, Any]) -> list[str]: + """Return validation errors for a REST provider's ``rest`` block.""" + errors: list[str] = [] + rest = provider.get("rest") or {} + if not (rest.get("base_url") or "").strip(): + errors.append("base_url is required for REST providers") + + auth = rest.get("auth") or {} + atype = (auth.get("type") or "none").strip() + if atype not in _REST_AUTH_TYPES: + errors.append(f"auth.type must be one of {sorted(_REST_AUTH_TYPES)}") + if atype == "bearer" and not (auth.get("token_env") or "").strip(): + errors.append("auth.token_env is required for bearer auth") + if atype == "api_key" and not (auth.get("value_env") or "").strip(): + errors.append("auth.value_env is required for api_key auth") + if atype == "client_credentials": + for key in ("token_url", "client_id_env", "client_secret_env"): + if not (auth.get(key) or "").strip(): + errors.append(f"auth.{key} is required for client_credentials auth") + if atype == "authorization_code": + for key in ("authorize_url", "token_url", "client_id_env"): + if not (auth.get(key) or "").strip(): + errors.append(f"auth.{key} is required for authorization_code auth") + + openapi = (rest.get("openapi") or "").strip() + endpoints = rest.get("endpoints") or [] + if not openapi and not endpoints: + errors.append("REST providers need either an openapi source or at least one endpoint") + for i, ep in enumerate(endpoints): + if not (ep.get("method") or "").strip(): + errors.append(f"rest.endpoints[{i}]: method is required") + if not (ep.get("path") or "").strip(): + errors.append(f"rest.endpoints[{i}]: path is required") + return errors + + def _validate_provider(provider: dict[str, Any]) -> dict[str, Any]: errors: list[str] = [] ptype = provider.get("type", "code") - if ptype == "package": + if ptype == "rest": + errors.extend(_validate_rest(provider)) + elif ptype == "package": if not (provider.get("command") or "").strip(): errors.append("command is required for package providers") elif ptype == "repository": @@ -438,6 +533,26 @@ def _extract_functions(code: str) -> dict[str, Any]: # FastAPI app factory # --------------------------------------------------------------------------- +def _safe_local_openapi_path(source: str) -> str: + """Resolve a local OpenAPI file path, restricted to FILES_DIR. + + Stops the introspection endpoint from being used to read arbitrary files + (e.g. ``/app/.env``). Returns the resolved absolute path, or raises + ``ValueError`` if the path escapes the files directory or does not exist. + """ + base = FILES_DIR.resolve() + candidate = Path(source) + candidate = candidate.resolve() if candidate.is_absolute() else (base / candidate).resolve() + if candidate != base and base not in candidate.parents: + raise ValueError( + f"Local OpenAPI files must live under the files directory ({base}). " + "Use an http(s) URL, or place the spec in that directory." + ) + if not candidate.is_file(): + raise ValueError(f"OpenAPI file not found in the files directory: {source}") + return str(candidate) + + def create_app(config_dir: Path | None = None, env_file: Path | None = None) -> "FastAPI": _config_dir = config_dir or CONFIG_DIR _env_file = env_file or ENV_FILE @@ -462,6 +577,7 @@ async def list_tools() -> list[dict]: validation = _validate_provider(structured) is_package = bool(_get_package_spec(spec)) is_repository = bool(_get_repository_spec(spec)) + is_rest = bool(_get_rest_spec(spec)) out.append({ "name": path.stem, "file": path.name, @@ -470,6 +586,7 @@ async def list_tools() -> list[dict]: "provider_type": structured["type"], "is_package": is_package, "is_repository": is_repository, + "is_rest": is_rest, "secret_keys": secret_keys, "missing_secrets": missing_secrets, "validation_errors": validation["errors"], @@ -596,11 +713,105 @@ async def pending_auth(command: str = "") -> dict: bridge refreshes silently and this stays empty. With no `command`, returns every pending URL keyed by spawn command. + REST providers' authorization_code flows publish their URLs the same way + (keyed by provider name) and are merged into the ``pending`` map. """ from process_runner import pending_auth_urls + from rest_provider import pending_rest_auth if command: return {"ok": True, "auth_url": pending_auth_urls.get(command.strip())} - return {"ok": True, "pending": dict(pending_auth_urls)} + merged = {**pending_auth_urls, **pending_rest_auth} + return {"ok": True, "pending": merged, "rest_pending": dict(pending_rest_auth)} + + # ── REST / OpenAPI ─────────────────────────────────────────────────────── + + @app.post("/api/introspect-openapi") + async def introspect_openapi_endpoint(request: Request) -> dict: + """Parse an OpenAPI 3.x / Swagger 2.0 spec (URL or file) into endpoints + tools. + + Body: { openapi: }. Returns ``{ok, endpoints, tools}`` (or + ``{ok: False, error}``). The wizard calls this to expand an OpenAPI source + into concrete endpoints before saving, so the server never fetches the spec + at registration time. + + Local file sources are restricted to the files directory (FILES_DIR) so the + endpoint can't read arbitrary files (e.g. ``.env``). Parsing performs + blocking network/file I/O, so it runs in a worker thread to avoid blocking + the UI event loop. + """ + body = await request.json() + source = (body.get("openapi") or "").strip() + if not source: + raise HTTPException(400, "openapi (URL or file path) is required") + if not (source.startswith("http://") or source.startswith("https://")): + try: + source = _safe_local_openapi_path(source) + except ValueError as exc: + return {"ok": False, "error": str(exc), "endpoints": [], "tools": []} + try: + from rest_provider import introspect_openapi + endpoints, tools = await asyncio.to_thread(introspect_openapi, source) + return {"ok": True, "endpoints": endpoints, "tools": tools} + except Exception as exc: + traceback.print_exc() + return {"ok": False, "error": str(exc), "endpoints": [], "tools": []} + + @app.post("/api/rest-authorize") + async def rest_authorize(request: Request) -> dict: + """Begin an authorization_code flow for a saved REST provider. + + Body: { name: }. Reads the provider's auth block from its + YAML, builds the PKCE authorize URL, publishes it to ``pending_rest_auth``, + and returns ``{ok, auth_url, redirect_uri}`` for the UI to open. + """ + body = await request.json() + name = (body.get("name") or "").strip() + _guard_name(name) + path = _config_dir / f"{name}.yaml" + if not path.exists(): + raise HTTPException(404, f"Provider '{name}' not found") + spec = yaml.safe_load(path.read_text(encoding="utf-8")) or {} + auth = (spec.get("rest") or {}).get("auth") or {} + if (auth.get("type") or "").strip() != "authorization_code": + raise HTTPException(400, "Provider does not use authorization_code auth") + try: + from rest_provider import AuthCodeTokenStore, oauth_redirect_uri + store = AuthCodeTokenStore(name, auth) + auth_url = store.begin_authorization() + return {"ok": True, "auth_url": auth_url, "redirect_uri": oauth_redirect_uri()} + except Exception as exc: + traceback.print_exc() + return {"ok": False, "error": str(exc)} + + @app.get("/oauth/callback") + async def oauth_callback( + code: str = "", state: str = "", error: str = "" + ) -> "HTMLResponse": + """OAuth redirect target for REST providers' authorization_code flow. + + Exchanges ``code`` for tokens (using the PKCE verifier registered under + ``state``) and persists them, then renders a small close-the-tab page. + """ + # Escape every interpolated value: these come from the redirect query + # string / upstream errors and must not be able to inject HTML or script. + if error: + return HTMLResponse( + f"

Authorization failed

{html.escape(error)}

", status_code=400 + ) + if not code or not state: + return HTMLResponse("

Missing code or state

", status_code=400) + try: + from rest_provider import AuthCodeTokenStore + await AuthCodeTokenStore.complete_authorization(state, code) + return HTMLResponse( + "

Authorization complete

" + "

You may close this tab and return to mcpproxy.

" + ) + except Exception as exc: + traceback.print_exc() + return HTMLResponse( + f"

Authorization error

{html.escape(str(exc))}

", status_code=400 + ) # ── Repository clone-and-build ─────────────────────────────────────────── @@ -1200,6 +1411,48 @@ async def index():
+ + + +
+
+
+
🔌
+
REST / OAuth API
+ Point at a REST API (base URL + endpoints, or an OpenAPI spec) with optional OAuth — each endpoint becomes an MCP tool +
+
+
@@ -1406,6 +1668,100 @@ async def index():
+ +
+
+ + +
+
+ + +
Requests are sent to <base URL><endpoint path>.
+
+
+ + +
+ +
+ + +
+ +
+ + +
+
Parses the spec into endpoints + tools when you click Introspect or Next.
+
+ +
+
+
+
@@ -1485,9 +1841,11 @@ async def index(): let secretsModal = null, wizModal = null, termModal = null; let webTerminalEnabled = false; let term = null, termFit = null, termSock = null; // xterm.js terminal state -let wzType = null; // 'code' | 'package' | 'repository' | 'remote' +let wzType = null; // 'code' | 'package' | 'repository' | 'remote' | 'rest' let wzStep = 'type'; let wzIntrospectedTools = []; // tools returned by introspect +let wzRestEndpoints = []; // REST wizard: concrete endpoint specs +let wzRestEndpointTools = {}; // REST wizard: endpoint name → tool spec (from OpenAPI) let wzRepoCtx = null; // repository-wizard state carried across steps // {name, command, repo_url, repo_ref, // build_commands, workdir, env_keys, tools, @@ -1668,6 +2026,16 @@ async def index(): // Failure (or absence of input) is silent — the editor falls back to free-form text. async function discoverFunctions() { if (!currentProvider) return; + // REST providers derive their tool names from the configured endpoints; no + // subprocess introspection or code analysis is involved. + if (currentProvider.type === 'rest') { + knownFunctions = ((currentProvider.rest || {}).endpoints || []).map(e => e.name).filter(Boolean); + knownFnStatus = 'ok'; + knownFnMessage = `Found ${knownFunctions.length} endpoint${knownFunctions.length === 1 ? '' : 's'}`; + _renderKnownFnStatus(); + _refreshToolDropdowns(); + return; + } const isRepo = currentProvider.type === 'repository'; const isPkg = currentProvider.type === 'package' || isRepo; knownFnStatus = 'busy'; @@ -1719,7 +2087,7 @@ async def index(): // reset focus / scroll position). Each dropdown's inner option list is replaced. function _refreshToolDropdowns() { if (!currentProvider) return; - const isPkg = currentProvider.type === 'package' || currentProvider.type === 'repository'; + const isPkg = currentProvider.type === 'package' || currentProvider.type === 'repository' || currentProvider.type === 'rest'; const field = isPkg ? 'name' : 'function'; (currentProvider.tools || []).forEach((t, i) => { const sel = document.getElementById(`fn-pick-${i}`); @@ -1763,14 +2131,18 @@ async def index(): function renderProvider(p) { const isRepo = p.type === 'repository'; + const isRest = p.type === 'rest'; const isPkg = p.type === 'package' || isRepo; // repo also uses package.command const isCode = p.type === 'code'; - const label = isRepo ? ' (repository)' : isPkg ? ' (package)' : ' (code)'; + // REST tools (like package tools) are selected by endpoint name, not function. + const nameDriven = isPkg || isRest; + const label = isRepo ? ' (repository)' : isRest ? ' (rest)' : isPkg ? ' (package)' : ' (code)'; document.getElementById('editor-title').textContent = p.name + label; document.getElementById('f-documentation').value = p.documentation || ''; document.getElementById('package-box').style.display = isPkg ? '' : 'none'; document.getElementById('repository-box').style.display = isRepo ? '' : 'none'; + document.getElementById('rest-box').style.display = isRest ? '' : 'none'; document.getElementById('code-box').style.display = isCode ? '' : 'none'; if (isPkg) { @@ -1786,6 +2158,9 @@ async def index(): renderBuildCommands(p.build_commands || []); renderEnvKeys(p.repo_env_keys || []); } + if (isRest) { + renderRestEditor(p); + } if (isCode) { codeEditor.setValue(p.code || ''); setTimeout(() => codeEditor.refresh(), 50); @@ -1793,7 +2168,247 @@ async def index(): renderRequirements(p.requirements || []); renderSetupCommands(p.setup_commands || []); - renderTools(p.tools || [], isPkg); + renderTools(p.tools || [], nameDriven); +} + +function updateRestBaseUrl(val) { + ensureProvider(); + if (!currentProvider.rest) currentProvider.rest = {}; + currentProvider.rest.base_url = val; +} + +// ── REST editor: auth + endpoints (inline editing) ─────────────────────────── + +function renderRestEditor(p) { + const rest = p.rest || (p.rest = {}); + rest.auth = rest.auth || {type: 'none'}; + rest.endpoints = rest.endpoints || []; + rest.headers = rest.headers || {}; + // Edit headers as an ordered array of {key,value}; serialized back to an object + // in collectProvider so the YAML stays a plain mapping. + rest.headerRows = Object.entries(rest.headers).map(([key, value]) => ({key, value})); + document.getElementById('f-rest-base-url').value = rest.base_url || ''; + document.getElementById('f-rest-auth-type').value = rest.auth.type || 'none'; + document.getElementById('rest-authorize-btn').style.display = + (rest.auth.type === 'authorization_code') ? '' : 'none'; + document.getElementById('rest-auth-status').textContent = ''; + renderRestAuthFields(rest.auth); + renderRestHeaders(rest.headerRows); + renderRestEndpoints(rest.endpoints); +} + +function renderRestHeaders(rows) { + const c = document.getElementById('rest-headers-container'); + if (!rows.length) { c.innerHTML = '
(none)
'; return; } + c.innerHTML = rows.map((h, i) => ` +
+ + + +
`).join(''); +} + +function addRestHeader() { + ensureProvider(); + const rest = currentProvider.rest || (currentProvider.rest = {}); + rest.headerRows = rest.headerRows || []; + rest.headerRows.push({key: '', value: ''}); + renderRestHeaders(rest.headerRows); +} + +function removeRestHeader(i) { + ensureProvider(); + currentProvider.rest.headerRows.splice(i, 1); + renderRestHeaders(currentProvider.rest.headerRows); +} + +function updateRestHeader(i, which, val) { + ensureProvider(); + currentProvider.rest.headerRows[i][which] = val; +} + +function updateRestAuthType(val) { + ensureProvider(); + const rest = currentProvider.rest || (currentProvider.rest = {}); + rest.auth = rest.auth || {}; + rest.auth.type = val; + document.getElementById('rest-authorize-btn').style.display = + (val === 'authorization_code') ? '' : 'none'; + renderRestAuthFields(rest.auth); +} + +function updateRestAuthField(key, val) { + ensureProvider(); + const auth = currentProvider.rest.auth; + if (key === 'scopes') auth.scopes = val.split(/\s+/).filter(Boolean); + else auth[key] = val.trim(); +} + +function _restAuthRow(label, key, value, placeholder) { + return `
+ + +
`; +} + +function renderRestAuthFields(auth) { + const c = document.getElementById('f-rest-auth-fields'); + const t = auth.type || 'none'; + let html = ''; + if (t === 'bearer') { + html = _restAuthRow('Token env var', 'token_env', auth.token_env, 'EXAMPLE_TOKEN'); + } else if (t === 'api_key') { + const loc = auth.in === 'query' ? 'query' : 'header'; + html = `
+
`; + if (loc === 'query') + html += _restAuthRow('Query param name', 'name', auth.name, 'api_key'); + else + html += _restAuthRow('Header name', 'header', auth.header, 'X-Api-Key'); + html += _restAuthRow('Value env var', 'value_env', auth.value_env, 'EXAMPLE_API_KEY'); + } else if (t === 'client_credentials' || t === 'authorization_code') { + if (t === 'authorization_code') + html += _restAuthRow('Authorize URL', 'authorize_url', auth.authorize_url, 'https://auth.example.com/oauth/authorize'); + html += _restAuthRow('Token URL', 'token_url', auth.token_url, 'https://auth.example.com/oauth/token'); + html += _restAuthRow('Client ID env var', 'client_id_env', auth.client_id_env, 'EXAMPLE_CLIENT_ID'); + const secretLabel = 'Client secret env var' + (t === 'authorization_code' ? ' (optional — PKCE)' : ''); + html += _restAuthRow(secretLabel, 'client_secret_env', auth.client_secret_env, 'EXAMPLE_CLIENT_SECRET'); + html += _restAuthRow('Scopes (space-separated)', 'scopes', (auth.scopes || []).join(' '), 'read write'); + if (t === 'authorization_code') + html += `
Redirect URI to register with your OAuth provider: ${esc((window.location.origin || 'http://localhost:8889') + '/oauth/callback')}
`; + } + c.innerHTML = html; +} + +function renderRestEndpoints(endpoints) { + const c = document.getElementById('rest-endpoints-container'); + if (!endpoints.length) { + c.innerHTML = '
No endpoints yet — click + Add endpoint.
'; + return; + } + const methods = ['GET', 'POST', 'PUT', 'PATCH', 'DELETE']; + c.innerHTML = endpoints.map((ep, i) => ` +
+
+
+
+ +
+
+
+
+
+
+
+
+
+
+
`).join(''); +} + +function _uniqueEndpointName() { + const used = new Set((currentProvider.rest.endpoints || []).map(e => e.name)); + let n = (currentProvider.rest.endpoints || []).length + 1; + let name = `endpoint_${n}`; + while (used.has(name)) { n++; name = `endpoint_${n}`; } + return name; +} + +function addRestEndpoint() { + ensureProvider(); + const rest = currentProvider.rest || (currentProvider.rest = {}); + rest.endpoints = rest.endpoints || []; + const name = _uniqueEndpointName(); + rest.endpoints.push({name, method: 'GET', path: '/', path_params: [], query_params: [], body_params: []}); + // Pair a tool with the same name so the 1:1 invariant holds and it shows in Tools. + currentProvider.tools = currentProvider.tools || []; + if (!currentProvider.tools.some(t => t.name === name)) { + currentProvider.tools.push({name, function: '', description: name, documentation: '', enabled: true, parameters: [], secrets: []}); + } + renderRestEndpoints(rest.endpoints); + renderTools(currentProvider.tools, true); + discoverFunctions().catch(() => {}); +} + +function removeRestEndpoint(i) { + ensureProvider(); + const ep = currentProvider.rest.endpoints[i]; + currentProvider.rest.endpoints.splice(i, 1); + if (ep && ep.name) { + currentProvider.tools = (currentProvider.tools || []).filter(t => t.name !== ep.name); + } + renderRestEndpoints(currentProvider.rest.endpoints); + renderTools(currentProvider.tools, true); + discoverFunctions().catch(() => {}); +} + +function updateRestEndpoint(i, field, val) { + ensureProvider(); + const ep = currentProvider.rest.endpoints[i]; + if (field === 'name') { + const oldName = ep.name; + const newName = val.trim(); + ep.name = newName; + // Keep the paired tool's name in sync so the endpoint↔tool link is preserved. + (currentProvider.tools || []).forEach(t => { if (t.name === oldName) t.name = newName; }); + renderTools(currentProvider.tools, true); + discoverFunctions().catch(() => {}); + } else { + ep[field] = val.trim(); + } +} + +function updateRestEndpointParams(i, field, val) { + ensureProvider(); + currentProvider.rest.endpoints[i][field] = val.split(',').map(s => s.trim()).filter(Boolean); +} + +// Regenerate the matching tool's parameters from an endpoint's param routing. +// Preserves any existing param's type/description; path params default to required. +function syncRestEndpointToTool(i) { + ensureProvider(); + const ep = currentProvider.rest.endpoints[i]; + const tool = (currentProvider.tools || []).find(t => t.name === ep.name); + if (!tool) { toast('No matching tool for this endpoint', false); return; } + const names = [...(ep.path_params || []), ...(ep.query_params || []), ...(ep.body_params || [])]; + const pathSet = new Set(ep.path_params || []); + const existing = {}; + (tool.parameters || []).forEach(p => { existing[p.name] = p; }); + tool.parameters = names.map(n => existing[n] || {name: n, type: 'string', description: '', required: pathSet.has(n), default: null}); + renderTools(currentProvider.tools, true); + toast(`Synced ${names.length} param(s) to ${ep.name}`); +} + +async function authorizeRestProvider() { + if (!currentName) return; + const status = document.getElementById('rest-auth-status'); + status.className = 'fn-status busy'; + status.textContent = 'Starting authorization…'; + try { + const r = await api('POST', '/api/rest-authorize', {name: currentName}); + if (!r.ok) throw new Error(r.error || 'authorization failed'); + window.open(r.auth_url, '_blank', 'noopener'); + status.className = 'fn-status ok'; + status.innerHTML = `Opened the authorization page. After approving, tokens are cached automatically. ` + + `Re-open`; + } catch(e) { + status.className = 'fn-status error'; + status.textContent = e.message || 'authorization failed'; + } } // Build commands list (repository providers) @@ -2156,6 +2771,15 @@ async def index(): p.repo_env_keys = (currentProvider.repo_env_keys || []).filter(k => k.trim()); } else if (p.type === 'package') { p.command = document.getElementById('f-command').value.trim(); + } else if (p.type === 'rest') { + // auth + endpoints are carried in currentProvider.rest; base URL comes from + // the field, and the header rows are serialized back into a plain object. + p.rest = currentProvider.rest || {}; + p.rest.base_url = document.getElementById('f-rest-base-url').value.trim(); + const headers = {}; + (p.rest.headerRows || []).forEach(h => { if (h.key && h.key.trim()) headers[h.key.trim()] = h.value; }); + p.rest.headers = headers; + delete p.rest.headerRows; } else { p.code = codeEditor.getValue(); } @@ -2182,7 +2806,7 @@ async def index(): function addTool() { ensureProvider(); - const isPkg = currentProvider.type === 'package' || currentProvider.type === 'repository'; + const isPkg = currentProvider.type === 'package' || currentProvider.type === 'repository' || currentProvider.type === 'rest'; currentProvider.tools.push({ name: '', function: '', description: '', documentation: '', enabled: true, parameters: [], secrets: [], @@ -2193,31 +2817,31 @@ async def index(): function removeTool(i) { ensureProvider(); currentProvider.tools.splice(i, 1); - renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository'); + renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository' || currentProvider.type === 'rest'); } function addParam(ti) { ensureProvider(); currentProvider.tools[ti].parameters.push({name:'',type:'string',description:'',required:false,default:null}); - renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository'); + renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository' || currentProvider.type === 'rest'); } function removeParam(ti, pi) { ensureProvider(); currentProvider.tools[ti].parameters.splice(pi, 1); - renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository'); + renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository' || currentProvider.type === 'rest'); } function addSecret(ti) { ensureProvider(); currentProvider.tools[ti].secrets.push({arg:'',env:''}); - renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository'); + renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository' || currentProvider.type === 'rest'); } function removeSecret(ti, si) { ensureProvider(); currentProvider.tools[ti].secrets.splice(si, 1); - renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository'); + renderTools(currentProvider.tools, currentProvider.type === 'package' || currentProvider.type === 'repository' || currentProvider.type === 'rest'); } // ───────────────────────────────────────────────────────────────────────────── @@ -2376,7 +3000,7 @@ async def index(): // ───────────────────────────────────────────────────────────────────────────── // Wizard // ───────────────────────────────────────────────────────────────────────────── -const WZ_STEPS = ['type','remote','package','repository','code','secrets']; +const WZ_STEPS = ['type','remote','package','repository','rest','code','secrets']; function wzShowStep(step) { WZ_STEPS.forEach(s => { @@ -2414,10 +3038,26 @@ async def index(): document.getElementById('wz-remote-name').value = ''; document.getElementById('wz-remote-url').value = ''; document.getElementById('wz-remote-result').innerHTML = ''; + wzRestReset(); wzShowStep('type'); wizModal.show(); } +function wzRestReset() { + wzRestEndpoints = []; + wzRestEndpointTools = {}; + ['wz-rest-name','wz-rest-base-url','wz-rest-token-env','wz-rest-header','wz-rest-value-env', + 'wz-rest-authorize-url','wz-rest-token-url','wz-rest-client-id-env','wz-rest-client-secret-env', + 'wz-rest-scopes','wz-rest-openapi'].forEach(id => { const el = document.getElementById(id); if (el) el.value = ''; }); + const at = document.getElementById('wz-rest-auth-type'); if (at) at.value = 'none'; + const ec = document.getElementById('wz-rest-endpoints-container'); if (ec) ec.innerHTML = ''; + const res = document.getElementById('wz-rest-result'); if (res) res.innerHTML = ''; + wzRestAuthChanged(); + wzRestTab('openapi'); + const ru = document.getElementById('wz-rest-redirect-uri'); + if (ru) ru.textContent = (window.location.origin || 'http://localhost:8889') + '/oauth/callback'; +} + function wzAddRepoBuild() { _wzListAdd('wz-repo-builds-container', 'npm install'); } // One-click Node/TypeScript defaults — covers the common case (e.g. the @@ -2443,6 +3083,123 @@ async def index(): setTimeout(() => wzShowStep(type), 120); } +// ── REST wizard helpers ────────────────────────────────────────────────────── + +function wzRestAuthChanged() { + const type = document.getElementById('wz-rest-auth-type').value; + const wrap = document.getElementById('wz-rest-auth-fields'); + wrap.style.display = type === 'none' ? 'none' : ''; + document.querySelectorAll('#wz-rest-auth-fields .wz-rest-auth').forEach(el => el.style.display = 'none'); + if (type === 'bearer') document.querySelector('.wz-rest-auth-bearer').style.display = ''; + else if (type === 'api_key') document.querySelector('.wz-rest-auth-api_key').style.display = ''; + else if (type === 'client_credentials' || type === 'authorization_code') { + document.querySelector('.wz-rest-auth-oauth').style.display = ''; + const authCodeOnly = type === 'authorization_code'; + document.querySelectorAll('.wz-rest-auth-authcode-only').forEach(el => el.style.display = authCodeOnly ? '' : 'none'); + document.getElementById('wz-rest-secret-optional').textContent = authCodeOnly ? 'optional (PKCE)' : 'required'; + } +} + +function wzRestTab(which) { + const openapi = which === 'openapi'; + document.getElementById('wz-rest-tab-openapi').classList.toggle('active', openapi); + document.getElementById('wz-rest-tab-manual').classList.toggle('active', !openapi); + document.getElementById('wz-rest-openapi-pane').style.display = openapi ? '' : 'none'; + document.getElementById('wz-rest-manual-pane').style.display = openapi ? 'none' : ''; +} + +function wzRestCollectAuth() { + const type = document.getElementById('wz-rest-auth-type').value; + const auth = { type }; + const g = id => (document.getElementById(id).value || '').trim(); + if (type === 'bearer') auth.token_env = g('wz-rest-token-env'); + else if (type === 'api_key') { auth.header = g('wz-rest-header') || 'X-Api-Key'; auth.value_env = g('wz-rest-value-env'); } + else if (type === 'client_credentials' || type === 'authorization_code') { + auth.token_url = g('wz-rest-token-url'); + auth.client_id_env = g('wz-rest-client-id-env'); + const sec = g('wz-rest-client-secret-env'); if (sec) auth.client_secret_env = sec; + const scopes = g('wz-rest-scopes'); if (scopes) auth.scopes = scopes.split(/\s+/).filter(Boolean); + if (type === 'authorization_code') auth.authorize_url = g('wz-rest-authorize-url'); + } + return auth; +} + +function wzRestValidateAuth(auth) { + if (auth.type === 'bearer' && !auth.token_env) return 'Bearer auth needs a token env var.'; + if (auth.type === 'api_key' && !auth.value_env) return 'API-key auth needs a value env var.'; + if (auth.type === 'client_credentials') { + if (!auth.token_url || !auth.client_id_env || !auth.client_secret_env) + return 'Client-credentials needs token URL, client ID env, and client secret env.'; + } + if (auth.type === 'authorization_code') { + if (!auth.authorize_url || !auth.token_url || !auth.client_id_env) + return 'Authorization-code needs authorize URL, token URL, and client ID env.'; + } + return ''; +} + +async function wzRestIntrospect() { + const source = document.getElementById('wz-rest-openapi').value.trim(); + const resEl = document.getElementById('wz-rest-result'); + if (!source) { resEl.innerHTML = 'Enter an OpenAPI URL or file path first.'; return; } + resEl.innerHTML = 'Parsing OpenAPI spec…'; + try { + const r = await api('POST', '/api/introspect-openapi', {openapi: source}); + if (!r.ok) throw new Error(r.error || 'introspection failed'); + wzRestEndpoints = r.endpoints || []; + wzRestEndpointTools = {}; + (r.tools || []).forEach(t => { wzRestEndpointTools[t.name] = t; }); + resEl.innerHTML = `
✓ Found ${wzRestEndpoints.length} endpoint(s)
`; + } catch(e) { + wzRestEndpoints = []; wzRestEndpointTools = {}; + resEl.innerHTML = `
✗ ${esc(e.message)}
`; + } +} + +function wzRestAddEndpoint() { + const c = document.getElementById('wz-rest-endpoints-container'); + const idx = c.children.length; + const div = document.createElement('div'); + div.className = 'border rounded p-2 mt-1'; + div.innerHTML = ` +
+
+
+ +
+
+
+
+
+
+
+
+
`; + c.appendChild(div); +} + +function wzRestCollectManualEndpoints() { + const eps = []; + const tools = {}; + document.querySelectorAll('#wz-rest-endpoints-container > .border').forEach(div => { + const name = (div.querySelector('.wz-ep-name').value || '').trim(); + const path = (div.querySelector('.wz-ep-path').value || '').trim(); + if (!name || !path) return; + const csv = sel => (div.querySelector(sel).value || '').split(',').map(s => s.trim()).filter(Boolean); + const path_params = csv('.wz-ep-pathp'); + const query_params = csv('.wz-ep-queryp'); + const body_params = csv('.wz-ep-bodyp'); + eps.push({name, method: div.querySelector('.wz-ep-method').value, path, path_params, query_params, body_params}); + const props = {}; const required = []; + [...path_params, ...query_params, ...body_params].forEach(pn => { props[pn] = {type:'string'}; }); + path_params.forEach(pn => required.push(pn)); + tools[name] = {name, description: name, input_schema: {type:'object', properties: props, required}}; + }); + return {eps, tools}; +} + // Wizard requirement/setup-command list helpers function _wzListAdd(containerId, placeholder) { const c = document.getElementById(containerId); @@ -2614,6 +3371,54 @@ async def index(): return; } + if (wzStep === 'rest') { + const name = document.getElementById('wz-rest-name').value.trim(); + const baseUrl = document.getElementById('wz-rest-base-url').value.trim(); + if (!name) { errEl.textContent = 'Provider name is required.'; return; } + if (!baseUrl) { errEl.textContent = 'Base URL is required.'; return; } + const openapi = document.getElementById('wz-rest-openapi').value.trim(); + const nextBtn = document.getElementById('wz-next-btn'); + const manualActive = document.getElementById('wz-rest-manual-pane').style.display !== 'none'; + if (manualActive) { + const m = wzRestCollectManualEndpoints(); + wzRestEndpoints = m.eps; wzRestEndpointTools = m.tools; + } else if (openapi && !wzRestEndpoints.length) { + // OpenAPI source given but not yet introspected — do it now. + nextBtn.disabled = true; const t = nextBtn.textContent; nextBtn.textContent = '⏳ Introspecting…'; + try { await wzRestIntrospect(); } + finally { nextBtn.disabled = false; nextBtn.textContent = t; } + } + if (!wzRestEndpoints.length) { + errEl.textContent = 'Add at least one endpoint, or import an OpenAPI spec.'; return; + } + const auth = wzRestCollectAuth(); + const authErr = wzRestValidateAuth(auth); + if (authErr) { errEl.textContent = authErr; return; } + const tools = wzRestEndpoints.map(ep => { + const t = wzRestEndpointTools[ep.name]; + return { + name: ep.name, function: '', + description: (t && t.description) || ep.name, + documentation: '', enabled: true, + parameters: _schemaToParams((t && (t.input_schema || t.inputSchema)) || {}), + secrets: [], + }; + }); + const provider = { + name, type: 'rest', command: '', code: '', documentation: '', + requirements: ['httpx'], setup_commands: [], + rest: { base_url: baseUrl, headers: {}, auth, openapi: '', endpoints: wzRestEndpoints }, + tools, + }; + try { + const r = await api('POST', '/api/tools', {name, provider}); + currentName = name; currentProvider = provider; + loadList(); + await wzGoSecrets(r.secret_keys || []); + } catch(e) { errEl.textContent = e.message; } + return; + } + if (wzStep === 'code') { const name = document.getElementById('wz-code-name').value.trim(); const code = document.getElementById('wz-code-input').value; @@ -2649,7 +3454,7 @@ async def index(): } function wzBack() { - const map = {remote:'type', package:'type', repository:'type', code:'type', secrets: wzType||'type'}; + const map = {remote:'type', package:'type', repository:'type', rest:'type', code:'type', secrets: wzType||'type'}; wzShowStep(map[wzStep] || 'type'); } diff --git a/requirements.txt b/requirements.txt index 7458a5f..abad7c9 100755 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ pyyaml personalcapital2 fastapi uvicorn[standard] +httpx diff --git a/rest_provider.py b/rest_provider.py new file mode 100644 index 0000000..af7b515 --- /dev/null +++ b/rest_provider.py @@ -0,0 +1,731 @@ +""" +rest_provider.py — Wrap an arbitrary REST API as MCP tools. + +A provider YAML with a ``rest:`` block declares a base URL, an ``auth:`` block, +and either an ``openapi:`` source (expanded into endpoints at create time by the +frontend) or an explicit list of ``endpoints:``. Each entry in the provider's +``tools:`` list maps 1:1 to an endpoint by ``name``. + +This module supplies: + + * ``_make_rest_handler`` — an async handler (the analogue of + ``server._make_process_handler``) that builds and issues the HTTP request and + returns parsed JSON, suitable for ``server.register_tool``. + * ``OAuthTokenManager`` — client_credentials token cache (fetch/cache/refresh). + * ``AuthCodeTokenStore`` — authorization_code + PKCE token store (on-disk cache, + interactive browser flow, refresh-token rotation). + * ``resolve_rest_auth`` — turn an ``auth:`` block into a resolver that mutates + outgoing request headers. + * ``introspect_openapi`` — parse an OpenAPI 3.0 document into endpoints + tools. + +Secrets (tokens, client id/secret) are referenced by environment-variable name in +the YAML (``*_env`` keys) and read from ``os.environ`` here, so they ride the +existing ``.env`` / Secrets-UI mechanism without ever being written to YAML. +""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import json +import os +import secrets as _secrets +import time +import traceback +from pathlib import Path +from typing import Any, Callable +from urllib.parse import urlencode + +import httpx + +from config import OAUTH_REDIRECT_BASE, REST_AUTH_DIR + +# Authorization URLs a REST provider is currently waiting on, keyed by provider +# name. The UI polls this (alongside ``process_runner.pending_auth_urls``) so an +# interactive authorization_code flow surfaces a clickable "Authorize" link. +pending_rest_auth: dict[str, str] = {} + +# Seconds of slack subtracted from a token's lifetime so we refresh slightly +# before the real expiry rather than racing it. +_EXPIRY_SKEW = 30.0 + +# How long an in-flight authorization_code attempt (state + PKCE verifier) stays +# valid before it is pruned. The user has this long to complete the browser flow. +_FLOW_TTL = float(os.environ.get("MCPPROXY_OAUTH_FLOW_TTL", "600")) + +# Default timeout (seconds) for every outbound HTTP request. +HTTP_TIMEOUT = float(os.environ.get("MCPPROXY_REST_TIMEOUT", "30")) + +# Maximum size (bytes) of a response body returned to the caller. Responses +# larger than this are truncated to a bounded preview so a single REST call can't +# flood the model's context with megabytes of JSON. Set to 0 to disable. +MAX_RESPONSE_BYTES = int(os.environ.get("MCPPROXY_REST_MAX_BYTES", "100000")) + + +class NeedsAuthorization(Exception): + """Raised when an authorization_code provider has no usable token. + + Carries the authorization URL the user must visit (also published into + ``pending_rest_auth``) so the caller can surface it. + """ + + def __init__(self, provider: str, auth_url: str) -> None: + self.provider = provider + self.auth_url = auth_url + super().__init__( + f"Authorization required for REST provider '{provider}'. " + f"Visit: {auth_url}" + ) + + +# --------------------------------------------------------------------------- +# Secret resolution +# --------------------------------------------------------------------------- + +def _require_env(env_name: str) -> str: + value = os.environ.get(env_name) + if not value: + raise RuntimeError(f"Missing required secret environment variable: {env_name}") + return value + + +def _loop_lock(holder: Any) -> asyncio.Lock: + """Return an ``asyncio.Lock`` bound to the *current* running loop. + + The token managers/stores are cached in module-level state and may be touched + from more than one event loop over a process's lifetime (the startup warm-up + thread first, then the MCP server loop). An ``asyncio.Lock`` binds to the + loop of its first ``await``, so a single shared lock would later raise + "bound to a different loop". Recreating the lock when the running loop + changes keeps each loop's use correctly serialized (uses are single-loop at + runtime) without the cross-loop crash. + """ + loop = asyncio.get_running_loop() + lock = getattr(holder, "_lock", None) + if lock is None or getattr(holder, "_lock_loop", None) is not loop: + lock = asyncio.Lock() + holder._lock = lock + holder._lock_loop = loop + return lock + + +# --------------------------------------------------------------------------- +# OAuth2 client_credentials +# --------------------------------------------------------------------------- + +class OAuthTokenManager: + """Fetch/cache/refresh an OAuth2 ``client_credentials`` access token.""" + + def __init__( + self, + token_url: str, + client_id_env: str, + client_secret_env: str, + scopes: list[str] | None = None, + extra: dict[str, str] | None = None, + ) -> None: + self.token_url = token_url + self.client_id_env = client_id_env + self.client_secret_env = client_secret_env + self.scopes = list(scopes or []) + self.extra = dict(extra or {}) + self._access_token: str | None = None + self._expires_at: float = 0.0 + self._lock: asyncio.Lock | None = None + self._lock_loop: Any = None + + def _is_expired(self) -> bool: + return (not self._access_token) or (time.time() >= self._expires_at - _EXPIRY_SKEW) + + async def get_token(self, *, force_refresh: bool = False) -> str: + async with _loop_lock(self): + if force_refresh or self._is_expired(): + await self._fetch() + assert self._access_token is not None + return self._access_token + + async def _fetch(self) -> None: + data = { + "grant_type": "client_credentials", + "client_id": _require_env(self.client_id_env), + "client_secret": _require_env(self.client_secret_env), + } + if self.scopes: + data["scope"] = " ".join(self.scopes) + data.update(self.extra) + async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: + resp = await client.post(self.token_url, data=data) + resp.raise_for_status() + payload = resp.json() + token = payload.get("access_token") + if not token: + raise RuntimeError( + f"Token endpoint {self.token_url} returned no access_token" + ) + self._access_token = token + expires_in = float(payload.get("expires_in", 3600)) + self._expires_at = time.time() + expires_in + + +# One manager per (token_url, client_id_env, scopes) so all endpoints of a +# provider share a single cached token (parallels process_runner._sessions). +_token_managers: dict[tuple, OAuthTokenManager] = {} + + +def get_token_manager(auth: dict[str, Any]) -> OAuthTokenManager: + key = ( + auth.get("token_url", ""), + auth.get("client_id_env", ""), + tuple(auth.get("scopes") or ()), + ) + mgr = _token_managers.get(key) + if mgr is None: + mgr = OAuthTokenManager( + token_url=auth.get("token_url", ""), + client_id_env=auth.get("client_id_env", ""), + client_secret_env=auth.get("client_secret_env", ""), + scopes=list(auth.get("scopes") or []), + extra=dict(auth.get("extra") or {}), + ) + _token_managers[key] = mgr + return mgr + + +# --------------------------------------------------------------------------- +# OAuth2 authorization_code + PKCE +# --------------------------------------------------------------------------- + +def _b64url(raw: bytes) -> str: + return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + + +def oauth_redirect_uri() -> str: + """The redirect URI the OAuth provider must call back; user registers it.""" + return f"{OAUTH_REDIRECT_BASE}/oauth/callback" + + +class AuthCodeTokenStore: + """On-disk cache + interactive flow for an authorization_code provider. + + One instance per provider name. Tokens persist under + ``REST_AUTH_DIR/.json`` so they survive restarts. + """ + + # In-flight authorization attempts keyed by the OAuth ``state`` value, shared + # across all instances (the callback route only has the state to go on). + _pending_flows: dict[str, dict[str, Any]] = {} + + def __init__(self, provider: str, auth: dict[str, Any]) -> None: + self.provider = provider + self.auth = auth + self._lock: asyncio.Lock | None = None + self._lock_loop: Any = None + + # ── persistence ───────────────────────────────────────────────────────── + + def _cache_path(self) -> Path: + return REST_AUTH_DIR / f"{self.provider}.json" + + def _load(self) -> dict[str, Any]: + path = self._cache_path() + if not path.exists(): + return {} + try: + return json.loads(path.read_text(encoding="utf-8")) + except Exception: + traceback.print_exc() + return {} + + def _save(self, data: dict[str, Any]) -> None: + path = self._cache_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data), encoding="utf-8") + + # ── token access ──────────────────────────────────────────────────────── + + async def get_token(self, *, force_refresh: bool = False) -> str: + async with _loop_lock(self): + data = self._load() + access = data.get("access_token") + expires_at = float(data.get("expires_at", 0)) + fresh = access and time.time() < expires_at - _EXPIRY_SKEW + if fresh and not force_refresh: + return access + refresh_token = data.get("refresh_token") + if refresh_token: + try: + return await self._refresh(refresh_token) + except Exception: + traceback.print_exc() + # No token, or refresh failed → user must (re)authorize. + auth_url = self.begin_authorization() + raise NeedsAuthorization(self.provider, auth_url) + + async def _refresh(self, refresh_token: str) -> str: + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": _require_env(self.auth["client_id_env"]), + } + secret_env = self.auth.get("client_secret_env") + if secret_env: + data["client_secret"] = _require_env(secret_env) + async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: + resp = await client.post(self.auth["token_url"], data=data) + resp.raise_for_status() + payload = resp.json() + return self._persist_token_response(payload, prior_refresh=refresh_token) + + def _persist_token_response( + self, payload: dict[str, Any], prior_refresh: str | None = None + ) -> str: + access = payload.get("access_token") + if not access: + raise RuntimeError("Token endpoint returned no access_token") + expires_in = float(payload.get("expires_in", 3600)) + record = { + "access_token": access, + "refresh_token": payload.get("refresh_token") or prior_refresh, + "expires_at": time.time() + expires_in, + } + self._save(record) + return access + + # ── interactive authorization ───────────────────────────────────────────── + + def begin_authorization(self) -> str: + """Build the authorize URL (with PKCE), register the in-flight flow, and + publish the URL into ``pending_rest_auth``. Returns the URL. + """ + code_verifier = _b64url(_secrets.token_bytes(48)) + code_challenge = _b64url(hashlib.sha256(code_verifier.encode("ascii")).digest()) + state = _b64url(_secrets.token_bytes(24)) + redirect_uri = oauth_redirect_uri() + params = { + "response_type": "code", + "client_id": _require_env(self.auth["client_id_env"]), + "redirect_uri": redirect_uri, + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + scopes = self.auth.get("scopes") or [] + if scopes: + params["scope"] = " ".join(scopes) + auth_url = f"{self.auth['authorize_url']}?{urlencode(params)}" + self._prune_flows() + AuthCodeTokenStore._pending_flows[state] = { + "provider": self.provider, + "auth": self.auth, + "code_verifier": code_verifier, + "redirect_uri": redirect_uri, + "created": time.time(), + } + pending_rest_auth[self.provider] = auth_url + print( + f"[mcpproxy] authorization required for REST provider " + f"'{self.provider}' — visit: {auth_url}", + flush=True, + ) + return auth_url + + @classmethod + def _prune_flows(cls) -> None: + """Drop in-flight authorization attempts older than ``_FLOW_TTL``.""" + cutoff = time.time() - _FLOW_TTL + stale = [s for s, f in cls._pending_flows.items() if f.get("created", 0) < cutoff] + for state in stale: + cls._pending_flows.pop(state, None) + + @classmethod + async def complete_authorization(cls, state: str, code: str) -> str: + """Exchange ``code`` for tokens using the flow registered under ``state``.""" + cls._prune_flows() + flow = cls._pending_flows.pop(state, None) + if flow is None: + raise RuntimeError("Unknown or expired authorization state") + auth = flow["auth"] + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": flow["redirect_uri"], + "client_id": _require_env(auth["client_id_env"]), + "code_verifier": flow["code_verifier"], + } + secret_env = auth.get("client_secret_env") + if secret_env: + data["client_secret"] = _require_env(secret_env) + async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: + resp = await client.post(auth["token_url"], data=data) + resp.raise_for_status() + payload = resp.json() + store = cls(flow["provider"], auth) + access = store._persist_token_response(payload) + pending_rest_auth.pop(flow["provider"], None) + return access + + +# --------------------------------------------------------------------------- +# Auth resolver +# --------------------------------------------------------------------------- + +class _AuthResolver: + """Applies a provider's auth to outgoing request headers.""" + + def __init__(self, provider_name: str, auth: dict[str, Any]) -> None: + self.provider_name = provider_name + self.auth = auth or {} + self.type = (self.auth.get("type") or "none").strip() + self.supports_retry = self.type in ("client_credentials", "authorization_code") + self._auth_code_store: AuthCodeTokenStore | None = None + if self.type == "authorization_code": + self._auth_code_store = AuthCodeTokenStore(provider_name, self.auth) + + def apply_query(self, params: dict[str, Any]) -> None: + """Add auth that travels in the query string (api_key with ``in: query``).""" + if self.type == "api_key" and self.auth.get("in") == "query": + name = self.auth.get("name") or self.auth.get("header") or "api_key" + params[name] = _require_env(self.auth["value_env"]) + + async def apply(self, headers: dict[str, str], *, force_refresh: bool = False) -> None: + if self.type == "none": + return + if self.type == "bearer": + headers["Authorization"] = f"Bearer {_require_env(self.auth['token_env'])}" + elif self.type == "api_key": + if self.auth.get("in") == "query": + return # handled by apply_query + header_name = self.auth.get("header", "X-Api-Key") + prefix = self.auth.get("prefix", "") + value = _require_env(self.auth["value_env"]) + headers[header_name] = f"{prefix}{value}" if prefix else value + elif self.type == "client_credentials": + token = await get_token_manager(self.auth).get_token(force_refresh=force_refresh) + headers["Authorization"] = f"Bearer {token}" + elif self.type == "authorization_code": + assert self._auth_code_store is not None + token = await self._auth_code_store.get_token(force_refresh=force_refresh) + headers["Authorization"] = f"Bearer {token}" + else: + raise RuntimeError(f"Unsupported auth type: {self.type!r}") + + +def resolve_rest_auth(provider_name: str, rest_config: dict[str, Any]) -> _AuthResolver: + return _AuthResolver(provider_name, rest_config.get("auth") or {}) + + +# --------------------------------------------------------------------------- +# Request handler +# --------------------------------------------------------------------------- + +def _split_kwargs( + endpoint_spec: dict[str, Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Partition kwargs into (path_params, query_params, body) per the endpoint.""" + path_names = set(endpoint_spec.get("path_params") or []) + query_names = set(endpoint_spec.get("query_params") or []) + body_names = set(endpoint_spec.get("body_params") or []) + path: dict[str, Any] = {} + query: dict[str, Any] = {} + body: dict[str, Any] = {} + for key, value in kwargs.items(): + if value is None: + continue + if key in path_names: + path[key] = value + elif key in query_names: + query[key] = value + elif key in body_names: + body[key] = value + else: + # Unclassified args: assume query for GET/DELETE, body otherwise. + method = (endpoint_spec.get("method") or "GET").upper() + if method in ("GET", "DELETE", "HEAD"): + query[key] = value + else: + body[key] = value + return path, query, body + + +def _cap_response(resp: httpx.Response, tool_name: str) -> Any: + """Parse the response, truncating bodies larger than ``MAX_RESPONSE_BYTES``. + + Oversized bodies are returned as a bounded text preview with a ``truncated`` + flag rather than handed back whole, so one call can't flood the model's + context. Set MCPPROXY_REST_MAX_BYTES=0 to disable. + """ + text = resp.text + if MAX_RESPONSE_BYTES and len(text) > MAX_RESPONSE_BYTES: + return { + "ok": True, + "status": resp.status_code, + "truncated": True, + "total_bytes": len(text), + "preview": text[:MAX_RESPONSE_BYTES], + "note": ( + f"Response was {len(text)} bytes, truncated to {MAX_RESPONSE_BYTES}. " + "Narrow the request (e.g. query params / pagination) for the full result." + ), + } + try: + return resp.json() + except (json.JSONDecodeError, ValueError): + return {"ok": True, "status": resp.status_code, "text": text} + + +def _make_rest_handler( + endpoint_spec: dict[str, Any], + rest_config: dict[str, Any], + provider_name: str, +) -> Callable[..., Any]: + """Return an async handler that calls one REST endpoint. + + Signature matches what ``server.register_tool`` expects: + ``async handler(context=..., **kwargs)``. + """ + base_url = (rest_config.get("base_url") or "").rstrip("/") + default_headers = dict(rest_config.get("headers") or {}) + method = (endpoint_spec.get("method") or "GET").upper() + path_template = endpoint_spec.get("path") or "/" + tool_name = endpoint_spec.get("name", "") + resolver = resolve_rest_auth(provider_name, rest_config) + + async def rest_handler(context: dict[str, Any], **kwargs: Any) -> Any: + try: + path_params, query, body = _split_kwargs(endpoint_spec, kwargs) + resolver.apply_query(query) # api_key-in-query auth + path = path_template.format(**path_params) + url = f"{base_url}{path}" + + async def _do(force_refresh: bool) -> httpx.Response: + headers = dict(default_headers) + await resolver.apply(headers, force_refresh=force_refresh) + async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: + return await client.request( + method, + url, + params=query or None, + json=body or None, + headers=headers, + ) + + resp = await _do(force_refresh=False) + if resp.status_code == 401 and resolver.supports_retry: + resp = await _do(force_refresh=True) + + if resp.status_code >= 400: + return { + "ok": False, + "error": f"HTTP {resp.status_code}: {resp.text[:500]}", + "status": resp.status_code, + "tool": tool_name, + } + + return _cap_response(resp, tool_name) + except NeedsAuthorization as exc: + return {"ok": False, "error": str(exc), "auth_url": exc.auth_url, "tool": tool_name} + except Exception as exc: # noqa: BLE001 + traceback.print_exc() + return {"ok": False, "error": str(exc), "tool": tool_name} + + rest_handler.__name__ = tool_name + return rest_handler + + +# --------------------------------------------------------------------------- +# OpenAPI introspection +# --------------------------------------------------------------------------- + +_HTTP_METHODS = ("get", "put", "post", "delete", "patch", "head", "options") + +_JSON_DEFAULT_TYPE = "string" + + +def _resolve_ref(doc: dict[str, Any], node: Any) -> Any: + """Resolve a single local ``$ref`` (one level) within ``doc``.""" + if isinstance(node, dict) and "$ref" in node: + ref = node["$ref"] + if ref.startswith("#/"): + target: Any = doc + for part in ref[2:].split("/"): + if not isinstance(target, dict): + return {} + target = target.get(part, {}) + return target + return node + + +def _param_schema_type(schema: dict[str, Any]) -> str: + t = schema.get("type") + if isinstance(t, str): + return t + return _JSON_DEFAULT_TYPE + + +def _object_props(doc: dict[str, Any], schema: Any) -> tuple[dict[str, Any], set[str]]: + """Resolve an object schema into (properties, required), merging ``allOf``. + + Handles both OpenAPI 3.x (``#/components/...``) and Swagger 2.0 + (``#/definitions/...``) refs via ``_resolve_ref``. + """ + schema = _resolve_ref(doc, schema or {}) + if not isinstance(schema, dict): + return {}, set() + props: dict[str, Any] = {} + required: set[str] = set() + for sub in schema.get("allOf") or []: + sub_props, sub_required = _object_props(doc, sub) + props.update(sub_props) + required |= sub_required + for pname, pschema in (schema.get("properties") or {}).items(): + props[pname] = _resolve_ref(doc, pschema) + required |= set(schema.get("required") or []) + return props, required + + +def introspect_openapi( + source: str, +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Parse an OpenAPI 3.x / Swagger 2.0 document into (endpoints, tools). + + ``source`` is a URL (fetched via httpx) or a local file path. Returns a list + of endpoint specs (method/path/param classification) and a parallel list of + tool specs (name/description/input_schema) ready to drop into the provider. + The API's base URL is configured separately on the provider, so it is not + derived here. + """ + raw = _load_openapi_source(source) + doc = _parse_openapi_text(raw) + endpoints: list[dict[str, Any]] = [] + tools: list[dict[str, Any]] = [] + used_names: set[str] = set() + + paths = doc.get("paths") or {} + for path, path_item in paths.items(): + if not isinstance(path_item, dict): + continue + shared_params = path_item.get("parameters") or [] + for method in _HTTP_METHODS: + op = path_item.get(method) + if not isinstance(op, dict): + continue + name = _operation_name(op, method, path, used_names) + used_names.add(name) + params = list(shared_params) + list(op.get("parameters") or []) + endpoint, tool = _build_endpoint_and_tool(doc, name, method, path, op, params) + endpoints.append(endpoint) + tools.append(tool) + + return endpoints, tools + + +def _load_openapi_source(source: str) -> str: + if source.startswith("http://") or source.startswith("https://"): + resp = httpx.get(source, timeout=HTTP_TIMEOUT, follow_redirects=True) + resp.raise_for_status() + return resp.text + return Path(source).read_text(encoding="utf-8") + + +def _parse_openapi_text(text: str) -> dict[str, Any]: + try: + return json.loads(text) + except (json.JSONDecodeError, ValueError): + import yaml # local import; pyyaml is always installed + + return yaml.safe_load(text) or {} + + +def _operation_name(op: dict[str, Any], method: str, path: str, used: set[str]) -> str: + name = (op.get("operationId") or "").strip() + if not name: + # Derive from method + path: POST /users/{id}/items → post_users_id_items + slug = path.strip("/").replace("/", "_").replace("{", "").replace("}", "") + slug = "".join(c if (c.isalnum() or c == "_") else "_" for c in slug) + name = f"{method}_{slug}".strip("_") or method + # Sanitize to a tool-safe identifier. + name = "".join(c if (c.isalnum() or c in "_-") else "_" for c in name) + candidate = name + n = 2 + while candidate in used: + candidate = f"{name}_{n}" + n += 1 + return candidate + + +def _build_endpoint_and_tool( + doc: dict[str, Any], + name: str, + method: str, + path: str, + op: dict[str, Any], + params: list[Any], +) -> tuple[dict[str, Any], dict[str, Any]]: + path_params: list[str] = [] + query_params: list[str] = [] + body_params: list[str] = [] + properties: dict[str, Any] = {} + required: list[str] = [] + + def _add_body_props(body_props: dict[str, Any], body_required: set[str]) -> None: + for bname, bschema in body_props.items(): + bschema = _resolve_ref(doc, bschema) + properties[bname] = { + "type": _param_schema_type(bschema), + "description": bschema.get("description", ""), + } + body_params.append(bname) + if bname in body_required: + required.append(bname) + + for raw_param in params: + param = _resolve_ref(doc, raw_param) + if not isinstance(param, dict): + continue + location = param.get("in") + # Swagger 2.0 body parameter: its schema's properties become body params. + if location == "body": + bp, br = _object_props(doc, param.get("schema") or {}) + _add_body_props(bp, br) + continue + pname = param.get("name") + if not pname: + continue + # Type lives on param.schema (OpenAPI 3.x) or directly on param (Swagger 2.0). + schema = _resolve_ref(doc, param.get("schema") or {}) + ptype = _param_schema_type(schema) if schema else _param_schema_type(param) + properties[pname] = {"type": ptype, "description": param.get("description", "")} + if param.get("required") or location == "path": + required.append(pname) + if location == "path": + path_params.append(pname) + elif location in ("query", "formData"): + (body_params if location == "formData" else query_params).append(pname) + + # OpenAPI 3.x requestBody → body params (application/json schema, allOf merged). + request_body = _resolve_ref(doc, op.get("requestBody") or {}) + if isinstance(request_body, dict) and request_body.get("content"): + json_media = (request_body.get("content") or {}).get("application/json") or {} + body_props, body_required = _object_props(doc, json_media.get("schema") or {}) + _add_body_props(body_props, body_required) + + endpoint = { + "name": name, + "method": method.upper(), + "path": path, + "path_params": path_params, + "query_params": query_params, + "body_params": body_params, + } + description = (op.get("summary") or op.get("description") or name).strip() + tool = { + "name": name, + "description": description or name, + "input_schema": { + "type": "object", + "properties": properties, + "required": required, + }, + } + return endpoint, tool diff --git a/run_local.sh b/run_local.sh index 707409e..d175f60 100755 --- a/run_local.sh +++ b/run_local.sh @@ -344,6 +344,8 @@ export MCP_ENV_FILE="$ENV_FILE" # default to /app/files and /app/repos (see Dockerfile + docker-compose.yml). export MCPPROXY_FILES_DIR="${MCPPROXY_FILES_DIR:-$ROOT_DIR/files}" export MCPPROXY_REPOS_DIR="${MCPPROXY_REPOS_DIR:-$ROOT_DIR/repos}" +# OAuth token cache for REST providers' authorization_code flow (gitignored). +export MCPPROXY_REST_AUTH_DIR="${MCPPROXY_REST_AUTH_DIR:-$ROOT_DIR/.rest-auth}" unset MCP_REPOS_DIR # no longer used # ───────────────────────────────────────────────────────────────────────────── diff --git a/server.py b/server.py index bcee447..e8e387e 100755 --- a/server.py +++ b/server.py @@ -281,6 +281,11 @@ def _get_package_command(spec: dict[str, Any]) -> str | None: return None +def _get_rest_config(spec: dict[str, Any]) -> dict[str, Any] | None: + """Return the ``rest:`` sub-dict for REST providers, or None otherwise.""" + return spec.get("rest") or None + + def _make_process_handler( command: str, tool_name: str, @@ -417,6 +422,7 @@ def register_provider(spec: dict[str, Any]) -> None: source_path = spec.get("_config_path", "") provider_name = Path(source_path).stem if source_path != "" else "" try: + rest_config = _get_rest_config(spec) command = _get_package_command(spec) # Repository providers piggy-back on the package code path; the only # difference is that their subprocess is spawned with cwd= @@ -424,7 +430,35 @@ def register_provider(spec: dict[str, Any]) -> None: cwd = repository_workdir(provider_name, spec) env_keys = list((spec.get("repository") or {}).get("env_keys") or []) - if command is not None: + if rest_config is not None: + # ── REST provider ───────────────────────────────────────────────── + # Each tool maps 1:1 to an endpoint (matched by name). Endpoints are + # concrete by this point (OpenAPI specs are expanded into endpoints at + # create time by the frontend), so registration is network-free. + from rest_provider import _make_rest_handler + + endpoints = { + e.get("name"): e for e in (rest_config.get("endpoints") or []) + } + for tool_spec in spec.get("tools", []): + tool_name = tool_spec.get("name", "") + if not tool_is_enabled(tool_spec): + print(f"Skipping disabled tool: {advertised_tool_name(provider_name, tool_name)}") + continue + endpoint_spec = endpoints.get(tool_name) + if endpoint_spec is None: + raise ValueError( + f"REST tool '{tool_name}' in {source_path} has no matching " + f"endpoint (rest.endpoints[].name must equal the tool name)" + ) + handler = _make_rest_handler(endpoint_spec, rest_config, provider_name) + register_tool( + tool_spec, + handler, + advertised_name=advertised_tool_name(provider_name, tool_name), + ) + + elif command is not None: # ── package provider (npx / uvx / python -m / any binary) ────────── for tool_spec in spec.get("tools", []): tool_name = tool_spec.get("name", "") @@ -699,6 +733,55 @@ def _warm_remote_enabled() -> bool: ) +def _rest_oauth_providers() -> list[tuple[str, dict[str, Any]]]: + """Return (provider_name, rest_config) for every OAuth-backed REST provider.""" + out: list[tuple[str, dict[str, Any]]] = [] + for spec in load_provider_specs(CONFIG_DIR): + rest_config = _get_rest_config(spec) + if not rest_config: + continue + auth_type = ((rest_config.get("auth") or {}).get("type") or "none").strip() + if auth_type in ("client_credentials", "authorization_code"): + name = Path(spec.get("_config_path", "")).stem or "rest" + out.append((name, rest_config)) + return out + + +def _warm_rest_providers() -> None: + """Warm OAuth tokens for REST providers once at startup. + + For ``client_credentials`` this fetches and caches the token (validating the + client id/secret early). For ``authorization_code`` it checks the on-disk + cache and, when no usable token exists, surfaces the authorization URL via + ``pending_rest_auth`` (so the UI banner shows it before the first failed tool + call) instead of raising. Disable with MCPPROXY_WARM_REMOTE=0. + """ + providers = _rest_oauth_providers() + if not providers: + return + import asyncio + + from rest_provider import NeedsAuthorization, resolve_rest_auth + + async def _warm_all() -> None: + for name, rest_config in providers: + print(f"[mcpproxy] warming REST OAuth provider: {name}") + try: + resolver = resolve_rest_auth(name, rest_config) + await resolver.apply({}) # fetch/refresh the token (or publish auth URL) + print(f"[mcpproxy] token ready for REST provider: {name}") + except NeedsAuthorization as exc: + print(f"[mcpproxy] REST provider '{name}' needs authorization: {exc.auth_url}") + except Exception as exc: # noqa: BLE001 — best-effort warm-up + print(f"[mcpproxy] warm-up for REST provider '{name}' did not complete: {exc}") + + try: + asyncio.run(_warm_all()) + except Exception as exc: # noqa: BLE001 + print(f"_warm_rest_providers error: {exc}") + traceback.print_exc() + + # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- @@ -723,6 +806,10 @@ def _run_ui() -> None: target=_warm_remote_providers, daemon=True, name="remote-warmup" ) warm_thread.start() + rest_warm_thread = threading.Thread( + target=_warm_rest_providers, daemon=True, name="rest-warmup" + ) + rest_warm_thread.start() mcp.run(transport="streamable-http", host=MCP_HOST, port=MCP_PORT) except Exception as exc: print(f"main error: {exc}") diff --git a/tests/test_frontend.py b/tests/test_frontend.py index 24adf79..a8af607 100644 --- a/tests/test_frontend.py +++ b/tests/test_frontend.py @@ -1,4 +1,5 @@ """Unit tests for the HTTP frontend (frontend/app.py).""" +import json from pathlib import Path from unittest.mock import AsyncMock, patch @@ -15,6 +16,7 @@ _read_env_file, _structured_to_yaml, _validate_provider, + _validate_rest, _write_env_file, _write_workdir_env_file, create_app, @@ -100,6 +102,42 @@ def client(app): }], } +REST_PROVIDER = { + "name": "weather", + "type": "rest", + "documentation": "", + "command": "", + "code": "", + "requirements": ["httpx"], + "setup_commands": [], + "rest": { + "base_url": "https://api.example.com/v1", + "headers": {"Accept": "application/json"}, + "auth": { + "type": "authorization_code", + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + "client_id_env": "WEATHER_CLIENT_ID", + "client_secret_env": "WEATHER_CLIENT_SECRET", + "scopes": ["read"], + }, + "openapi": "", + "endpoints": [ + {"name": "get_forecast", "method": "GET", "path": "/forecast/{city}", + "path_params": ["city"], "query_params": ["units"], "body_params": []}, + ], + }, + "tools": [{ + "name": "get_forecast", "function": "", "description": "Get the forecast", + "documentation": "", "enabled": True, + "parameters": [ + {"name": "city", "type": "string", "description": "City", "required": True, "default": None}, + {"name": "units", "type": "string", "description": "Units", "required": False, "default": None}, + ], + "secrets": [], + }], +} + # --------------------------------------------------------------------------- # GET /api/tools @@ -1140,3 +1178,337 @@ def test_disabled_gate_closes_with_message(self, client, monkeypatch): with client.websocket_connect("/ws/terminal") as ws: msg = ws.receive_text() assert "disabled" in msg.lower() + + +# --------------------------------------------------------------------------- +# REST providers +# --------------------------------------------------------------------------- + +class TestRestSpecConversion: + def test_structured_to_yaml_emits_rest_block(self): + out = _structured_to_yaml(REST_PROVIDER) + spec = yaml.safe_load(out) + assert spec["rest"]["base_url"] == "https://api.example.com/v1" + assert spec["rest"]["auth"]["type"] == "authorization_code" + assert spec["rest"]["endpoints"][0]["name"] == "get_forecast" + assert "package" not in spec and "code" not in spec + + def test_rest_tool_has_no_function_field(self): + spec = yaml.safe_load(_structured_to_yaml(REST_PROVIDER)) + assert "function" not in spec["tools"][0] + + def test_provider_to_structured_round_trips_rest(self): + spec = yaml.safe_load(_structured_to_yaml(REST_PROVIDER)) + structured = _provider_to_structured("weather", spec) + assert structured["type"] == "rest" + assert structured["rest"]["base_url"] == "https://api.example.com/v1" + assert structured["rest"]["auth"]["client_id_env"] == "WEATHER_CLIENT_ID" + assert structured["rest"]["endpoints"][0]["path"] == "/forecast/{city}" + + def test_default_headers_and_query_api_key_round_trip(self, app, tools_dir): + """Editor-set default headers and an api_key-in-query auth survive save.""" + provider = { + **REST_PROVIDER, + "rest": { + **REST_PROVIDER["rest"], + "headers": {"Accept": "application/json", "X-Trace": "1"}, + "auth": {"type": "api_key", "in": "query", "name": "apikey", "value_env": "DEMO_KEY"}, + }, + } + r = TestClient(app).post("/api/tools", json={"name": "weather2", "provider": provider}) + assert r.status_code == 200, r.text + spec = yaml.safe_load((tools_dir / "weather2.yaml").read_text()) + assert spec["rest"]["headers"] == {"Accept": "application/json", "X-Trace": "1"} + assert spec["rest"]["auth"] == {"type": "api_key", "in": "query", "name": "apikey", "value_env": "DEMO_KEY"} + # api_key value env surfaces as a secret key + assert "DEMO_KEY" in r.json()["secret_keys"] + # and it round-trips back into the structured editor form + structured = _provider_to_structured("weather2", spec) + assert structured["rest"]["headers"]["X-Trace"] == "1" + assert structured["rest"]["auth"]["in"] == "query" + + def test_editor_update_preserves_edited_endpoints(self, app, tools_dir): + """Simulate the inline editor saving a REST provider with an added + endpoint + renamed tool — auth and endpoints must survive the PUT.""" + (tools_dir / "weather.yaml").write_text(_structured_to_yaml(REST_PROVIDER)) + edited = {**REST_PROVIDER} + edited["rest"] = { + **REST_PROVIDER["rest"], + "base_url": "https://api.example.com/v2", + "endpoints": REST_PROVIDER["rest"]["endpoints"] + [ + {"name": "list_alerts", "method": "GET", "path": "/alerts", + "path_params": [], "query_params": ["region"], "body_params": []}, + ], + } + edited["tools"] = REST_PROVIDER["tools"] + [ + {"name": "list_alerts", "function": "", "description": "List alerts", + "documentation": "", "enabled": True, "parameters": [], "secrets": []}, + ] + r = TestClient(app).put("/api/tools/weather", json={"provider": edited}) + assert r.status_code == 200 + spec = yaml.safe_load((tools_dir / "weather.yaml").read_text()) + assert spec["rest"]["base_url"] == "https://api.example.com/v2" + names = {e["name"] for e in spec["rest"]["endpoints"]} + assert names == {"get_forecast", "list_alerts"} + assert spec["rest"]["auth"]["type"] == "authorization_code" + + +class TestValidateRest: + def test_valid_rest_provider_ok(self): + assert _validate_provider(REST_PROVIDER)["ok"] is True + + def test_missing_base_url_fails(self): + bad = {**REST_PROVIDER, "rest": {**REST_PROVIDER["rest"], "base_url": ""}} + result = _validate_provider(bad) + assert result["ok"] is False + assert any("base_url" in e for e in result["errors"]) + + def test_client_credentials_requires_token_url(self): + provider = { + "type": "rest", + "rest": {"base_url": "https://x", "auth": {"type": "client_credentials"}, + "endpoints": [{"name": "t", "method": "GET", "path": "/"}]}, + "tools": [{"name": "t", "description": "d"}], + } + errors = _validate_rest(provider) + assert any("token_url" in e for e in errors) + assert any("client_id_env" in e for e in errors) + + def test_authorization_code_requires_authorize_url(self): + provider = { + "type": "rest", + "rest": {"base_url": "https://x", "auth": {"type": "authorization_code"}, + "endpoints": [{"name": "t", "method": "GET", "path": "/"}]}, + "tools": [{"name": "t", "description": "d"}], + } + errors = _validate_rest(provider) + assert any("authorize_url" in e for e in errors) + + def test_requires_openapi_or_endpoints(self): + provider = { + "type": "rest", + "rest": {"base_url": "https://x", "auth": {"type": "none"}, + "openapi": "", "endpoints": []}, + "tools": [{"name": "t", "description": "d"}], + } + errors = _validate_rest(provider) + assert any("openapi" in e or "endpoint" in e for e in errors) + + def test_unknown_auth_type_fails(self): + provider = { + "type": "rest", + "rest": {"base_url": "https://x", "auth": {"type": "wat"}, + "endpoints": [{"name": "t", "method": "GET", "path": "/"}]}, + "tools": [{"name": "t", "description": "d"}], + } + errors = _validate_rest(provider) + assert any("auth.type" in e for e in errors) + + +class TestExtractSecretEnvKeysRest: + def test_rest_auth_env_keys_extracted(self): + spec = yaml.safe_load(_structured_to_yaml(REST_PROVIDER)) + keys = _extract_secret_env_keys(spec) + assert "WEATHER_CLIENT_ID" in keys + assert "WEATHER_CLIENT_SECRET" in keys + + +class TestListToolsRest: + def test_lists_rest_provider_is_rest_true(self, app, tools_dir): + (tools_dir / "weather.yaml").write_text(_structured_to_yaml(REST_PROVIDER)) + data = TestClient(app).get("/api/tools").json() + assert data[0]["is_rest"] is True + assert data[0]["provider_type"] == "rest" + + +class TestIntrospectOpenAPIEndpoint: + def test_returns_endpoints_and_tools(self, client): + fake = ( + [{"name": "op", "method": "GET", "path": "/x", + "path_params": [], "query_params": [], "body_params": []}], + [{"name": "op", "description": "d", "input_schema": {"type": "object", "properties": {}, "required": []}}], + ) + with patch("rest_provider.introspect_openapi", return_value=fake): + r = client.post("/api/introspect-openapi", json={"openapi": "https://x/openapi.json"}) + body = r.json() + assert body["ok"] is True + assert body["endpoints"][0]["name"] == "op" + assert body["tools"][0]["name"] == "op" + + def test_error_returns_ok_false(self, client): + with patch("rest_provider.introspect_openapi", side_effect=RuntimeError("boom")): + r = client.post("/api/introspect-openapi", json={"openapi": "https://x"}) + body = r.json() + assert body["ok"] is False and "boom" in body["error"] + + def test_missing_source_is_400(self, client): + r = client.post("/api/introspect-openapi", json={}) + assert r.status_code == 400 + + def test_local_path_outside_files_dir_rejected(self, client, tmp_path, monkeypatch): + import frontend.app as app_module + monkeypatch.setattr(app_module, "FILES_DIR", tmp_path / "files") + (tmp_path / "files").mkdir() + # An absolute path outside the files dir (would otherwise be a file read). + r = client.post("/api/introspect-openapi", json={"openapi": "/etc/hostname"}) + body = r.json() + assert body["ok"] is False + assert "files directory" in body["error"] + + def test_local_path_traversal_rejected(self, client, tmp_path, monkeypatch): + import frontend.app as app_module + files = tmp_path / "files" + files.mkdir() + monkeypatch.setattr(app_module, "FILES_DIR", files) + (tmp_path / "secret.json").write_text("{}") + r = client.post("/api/introspect-openapi", json={"openapi": "../secret.json"}) + assert r.json()["ok"] is False + + def test_local_path_inside_files_dir_allowed(self, client, tmp_path, monkeypatch): + import frontend.app as app_module + files = tmp_path / "files" + files.mkdir() + monkeypatch.setattr(app_module, "FILES_DIR", files) + (files / "spec.json").write_text(json.dumps({ + "openapi": "3.0.0", + "paths": {"/ping": {"get": {"operationId": "ping"}}}, + })) + r = client.post("/api/introspect-openapi", json={"openapi": "spec.json"}) + body = r.json() + assert body["ok"] is True + assert body["endpoints"][0]["name"] == "ping" + + +class TestRestAuthorizeAndCallback: + def test_rest_authorize_begins_flow(self, app, tools_dir, monkeypatch): + monkeypatch.setenv("WEATHER_CLIENT_ID", "cid") + (tools_dir / "weather.yaml").write_text(_structured_to_yaml(REST_PROVIDER)) + r = TestClient(app).post("/api/rest-authorize", json={"name": "weather"}) + body = r.json() + assert body["ok"] is True + assert body["auth_url"].startswith("https://auth.example.com/authorize?") + assert "/oauth/callback" in body["redirect_uri"] + + def test_rest_authorize_rejects_non_auth_code(self, app, tools_dir): + provider = {**REST_PROVIDER, "rest": {**REST_PROVIDER["rest"], "auth": {"type": "none"}}} + (tools_dir / "weather.yaml").write_text(_structured_to_yaml(provider)) + r = TestClient(app).post("/api/rest-authorize", json={"name": "weather"}) + assert r.status_code == 400 + + def test_callback_missing_code_is_400(self, client): + r = client.get("/oauth/callback") + assert r.status_code == 400 + + def test_callback_escapes_error_param(self, client): + r = client.get("/oauth/callback", params={"error": ""}) + assert r.status_code == 400 + assert "" not in r.text + assert "<script>" in r.text + + def test_callback_completes_authorization(self, client): + with patch("rest_provider.AuthCodeTokenStore.complete_authorization", + new=AsyncMock(return_value="tok")): + r = client.get("/oauth/callback?code=c&state=s") + assert r.status_code == 200 + assert "complete" in r.text.lower() + + +class TestRestWizardFlowIntegration: + """Drive the exact backend API sequence the REST wizard JS performs: + introspect OpenAPI → assemble provider → POST /api/tools → GET it back. + + Uses the real OpenAPI parser (not mocked), exercising the full path a user + walks through the wizard, then asserts a valid, reloadable provider results. + """ + + OPENAPI = { + "openapi": "3.0.0", + "info": {"title": "Demo", "version": "1.0"}, + "paths": { + "/users/{user_id}": { + "get": { + "operationId": "get_user", + "summary": "Fetch a user", + "parameters": [ + {"name": "user_id", "in": "path", "required": True, "schema": {"type": "string"}}, + {"name": "expand", "in": "query", "schema": {"type": "string"}}, + ], + } + }, + "/users": { + "post": { + "operationId": "create_user", + "requestBody": { + "required": True, + "content": {"application/json": {"schema": { + "type": "object", "required": ["name"], + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + }}}, + }, + } + }, + }, + } + + def test_full_wizard_sequence(self, app, tools_dir, tmp_path, monkeypatch): + import frontend.app as app_module + files = tmp_path / "files" + files.mkdir() + monkeypatch.setattr(app_module, "FILES_DIR", files) + client = TestClient(app) + spec_file = files / "openapi.json" + spec_file.write_text(json.dumps(self.OPENAPI)) + + # 1. Wizard step: introspect the OpenAPI spec (real parser, file in FILES_DIR). + r = client.post("/api/introspect-openapi", json={"openapi": "openapi.json"}) + body = r.json() + assert body["ok"] is True + endpoints = body["endpoints"] + tools_from_spec = {t["name"]: t for t in body["tools"]} + assert {e["name"] for e in endpoints} == {"get_user", "create_user"} + + # 2. Wizard assembles the provider exactly like wzNext() does. + provider = { + "name": "demo", "type": "rest", "command": "", "code": "", + "documentation": "", "requirements": ["httpx"], "setup_commands": [], + "rest": { + "base_url": "https://api.demo.test/v1", "headers": {}, + "auth": { + "type": "client_credentials", + "token_url": "https://auth.demo.test/token", + "client_id_env": "DEMO_ID", "client_secret_env": "DEMO_SECRET", + "scopes": ["read"], + }, + "openapi": "", "endpoints": endpoints, + }, + "tools": [{ + "name": e["name"], "function": "", + "description": tools_from_spec[e["name"]]["description"], + "documentation": "", "enabled": True, + "parameters": [ + {"name": pn, "type": pdef.get("type", "string"), + "description": pdef.get("description", ""), + "required": pn in tools_from_spec[e["name"]]["input_schema"].get("required", []), + "default": None} + for pn, pdef in tools_from_spec[e["name"]]["input_schema"]["properties"].items() + ], + "secrets": [], + } for e in endpoints], + } + + # 3. Create it. + r = client.post("/api/tools", json={"name": "demo", "provider": provider}) + assert r.status_code == 200, r.text + + # 4. Read it back as the editor would, and verify the on-disk YAML. + got = client.get("/api/tools/demo").json() + assert got["type"] == "rest" + assert got["rest"]["base_url"] == "https://api.demo.test/v1" + assert {e["name"] for e in got["rest"]["endpoints"]} == {"get_user", "create_user"} + + spec = yaml.safe_load((tools_dir / "demo.yaml").read_text()) + create = next(e for e in spec["rest"]["endpoints"] if e["name"] == "create_user") + assert create["method"] == "POST" + assert set(create["body_params"]) == {"name", "age"} + # Secret env keys surface for the wizard's Secrets step. + assert set(r.json()["secret_keys"]) >= {"DEMO_ID", "DEMO_SECRET"} diff --git a/tests/test_rest_provider.py b/tests/test_rest_provider.py new file mode 100644 index 0000000..38dc2ff --- /dev/null +++ b/tests/test_rest_provider.py @@ -0,0 +1,640 @@ +"""Unit tests for rest_provider — handlers, OAuth managers, OpenAPI introspection. + +HTTP is faked by patching ``rest_provider.httpx`` with a recording stub, so no +network is touched. +""" +import asyncio +import json +from pathlib import Path + +import pytest + +import rest_provider +from rest_provider import ( + AuthCodeTokenStore, + NeedsAuthorization, + OAuthTokenManager, + _make_rest_handler, + _split_kwargs, + introspect_openapi, + resolve_rest_auth, +) + + +# --------------------------------------------------------------------------- +# httpx fakes +# --------------------------------------------------------------------------- + +class FakeResponse: + def __init__(self, status_code=200, json_data=None, text=None): + self.status_code = status_code + self._json = json_data + self.text = text if text is not None else json.dumps(json_data or {}) + + def json(self): + if self._json is None: + raise json.JSONDecodeError("no json", "", 0) + return self._json + + def raise_for_status(self): + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + +class FakeAsyncClient: + """Records calls and returns queued responses. Shared via a factory.""" + + def __init__(self, recorder, **kwargs): + self._recorder = recorder + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + return False + + async def request(self, method, url, params=None, json=None, headers=None): + self._recorder["calls"].append( + {"method": method, "url": url, "params": params, "json": json, "headers": headers} + ) + return self._recorder["responses"].pop(0) + + async def post(self, url, data=None, **kwargs): + self._recorder["calls"].append({"method": "POST", "url": url, "data": data}) + return self._recorder["responses"].pop(0) + + +@pytest.fixture() +def http_recorder(monkeypatch): + """Patch rest_provider.httpx.AsyncClient with a recording fake.""" + recorder = {"calls": [], "responses": []} + + def factory(**kwargs): + return FakeAsyncClient(recorder, **kwargs) + + monkeypatch.setattr(rest_provider.httpx, "AsyncClient", factory) + return recorder + + +@pytest.fixture(autouse=True) +def _clear_token_state(): + rest_provider._token_managers.clear() + rest_provider.pending_rest_auth.clear() + AuthCodeTokenStore._pending_flows.clear() + yield + rest_provider._token_managers.clear() + rest_provider.pending_rest_auth.clear() + AuthCodeTokenStore._pending_flows.clear() + + +# --------------------------------------------------------------------------- +# _split_kwargs +# --------------------------------------------------------------------------- + +class TestSplitKwargs: + def test_classifies_by_metadata(self): + ep = { + "method": "POST", + "path_params": ["id"], + "query_params": ["q"], + "body_params": ["title"], + } + path, query, body = _split_kwargs(ep, {"id": "1", "q": "x", "title": "t"}) + assert path == {"id": "1"} + assert query == {"q": "x"} + assert body == {"title": "t"} + + def test_drops_none_values(self): + ep = {"method": "GET", "query_params": ["q"]} + _, query, _ = _split_kwargs(ep, {"q": None}) + assert query == {} + + def test_unclassified_get_goes_to_query(self): + ep = {"method": "GET"} + _, query, body = _split_kwargs(ep, {"extra": "v"}) + assert query == {"extra": "v"} and body == {} + + def test_unclassified_post_goes_to_body(self): + ep = {"method": "POST"} + _, query, body = _split_kwargs(ep, {"extra": "v"}) + assert body == {"extra": "v"} and query == {} + + +# --------------------------------------------------------------------------- +# OAuthTokenManager (client_credentials) +# --------------------------------------------------------------------------- + +CC_AUTH = { + "type": "client_credentials", + "token_url": "https://auth/token", + "client_id_env": "CC_ID", + "client_secret_env": "CC_SECRET", + "scopes": ["read"], +} + + +class TestOAuthTokenManager: + def _mgr(self): + return OAuthTokenManager("https://auth/token", "CC_ID", "CC_SECRET", ["read"]) + + def test_fetches_token_on_first_call(self, http_recorder, monkeypatch): + monkeypatch.setenv("CC_ID", "id") + monkeypatch.setenv("CC_SECRET", "secret") + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "abc", "expires_in": 3600})) + token = asyncio.run(self._mgr().get_token()) + assert token == "abc" + assert http_recorder["calls"][0]["data"]["grant_type"] == "client_credentials" + assert http_recorder["calls"][0]["data"]["scope"] == "read" + + def test_caches_token_until_expiry(self, http_recorder, monkeypatch): + monkeypatch.setenv("CC_ID", "id") + monkeypatch.setenv("CC_SECRET", "secret") + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "abc", "expires_in": 3600})) + mgr = self._mgr() + + async def go(): + t1 = await mgr.get_token() + t2 = await mgr.get_token() + return t1, t2 + + t1, t2 = asyncio.run(go()) + assert t1 == t2 == "abc" + assert len(http_recorder["calls"]) == 1 # only one fetch + + def test_refreshes_after_expiry(self, http_recorder, monkeypatch): + monkeypatch.setenv("CC_ID", "id") + monkeypatch.setenv("CC_SECRET", "secret") + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "a", "expires_in": 0})) + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "b", "expires_in": 3600})) + mgr = self._mgr() + + async def go(): + return await mgr.get_token(), await mgr.get_token() + + t1, t2 = asyncio.run(go()) + assert t1 == "a" and t2 == "b" + assert len(http_recorder["calls"]) == 2 + + def test_force_refresh_bypasses_cache(self, http_recorder, monkeypatch): + monkeypatch.setenv("CC_ID", "id") + monkeypatch.setenv("CC_SECRET", "secret") + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "a", "expires_in": 3600})) + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "b", "expires_in": 3600})) + mgr = self._mgr() + + async def go(): + return await mgr.get_token(), await mgr.get_token(force_refresh=True) + + t1, t2 = asyncio.run(go()) + assert t1 == "a" and t2 == "b" + + def test_concurrent_calls_fetch_once(self, http_recorder, monkeypatch): + monkeypatch.setenv("CC_ID", "id") + monkeypatch.setenv("CC_SECRET", "secret") + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "abc", "expires_in": 3600})) + mgr = self._mgr() + + async def go(): + return await asyncio.gather(mgr.get_token(), mgr.get_token(), mgr.get_token()) + + tokens = asyncio.run(go()) + assert tokens == ["abc", "abc", "abc"] + assert len(http_recorder["calls"]) == 1 + + def test_missing_secret_raises(self, http_recorder, monkeypatch): + monkeypatch.delenv("CC_ID", raising=False) + with pytest.raises(RuntimeError, match="Missing required secret"): + asyncio.run(self._mgr().get_token()) + + def test_get_token_manager_shares_instance(self): + a = rest_provider.get_token_manager(CC_AUTH) + b = rest_provider.get_token_manager(CC_AUTH) + assert a is b + + def test_reused_across_event_loops(self, http_recorder, monkeypatch): + """A cached manager warmed in one loop (startup thread) must still work + from a different loop (MCP server) — the lock must not bind to one loop.""" + monkeypatch.setenv("CC_ID", "id") + monkeypatch.setenv("CC_SECRET", "secret") + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "a", "expires_in": 3600})) + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "b", "expires_in": 3600})) + mgr = self._mgr() + # loop A (e.g. warm-up thread), then loop B (e.g. MCP server) — no RuntimeError. + t1 = asyncio.run(mgr.get_token()) + t2 = asyncio.run(mgr.get_token(force_refresh=True)) + assert t1 == "a" and t2 == "b" + + +# --------------------------------------------------------------------------- +# AuthCodeTokenStore (authorization_code + PKCE) +# --------------------------------------------------------------------------- + +AC_AUTH = { + "type": "authorization_code", + "authorize_url": "https://auth/authorize", + "token_url": "https://auth/token", + "client_id_env": "AC_ID", + "client_secret_env": "AC_SECRET", + "scopes": ["read", "write"], +} + + +@pytest.fixture() +def rest_auth_dir(tmp_path, monkeypatch): + monkeypatch.setattr(rest_provider, "REST_AUTH_DIR", tmp_path / "rest-auth") + return tmp_path / "rest-auth" + + +class TestAuthCodeTokenStore: + def test_begin_authorization_builds_pkce_url_and_publishes(self, monkeypatch, rest_auth_dir): + monkeypatch.setenv("AC_ID", "id") + store = AuthCodeTokenStore("prov", AC_AUTH) + url = store.begin_authorization() + assert url.startswith("https://auth/authorize?") + assert "code_challenge=" in url and "code_challenge_method=S256" in url + assert "response_type=code" in url + assert rest_provider.pending_rest_auth["prov"] == url + assert len(AuthCodeTokenStore._pending_flows) == 1 + + def test_complete_authorization_persists_tokens(self, http_recorder, monkeypatch, rest_auth_dir): + monkeypatch.setenv("AC_ID", "id") + monkeypatch.setenv("AC_SECRET", "secret") + store = AuthCodeTokenStore("prov", AC_AUTH) + store.begin_authorization() + state = next(iter(AuthCodeTokenStore._pending_flows)) + http_recorder["responses"].append( + FakeResponse(json_data={"access_token": "tok", "refresh_token": "ref", "expires_in": 3600}) + ) + access = asyncio.run(AuthCodeTokenStore.complete_authorization(state, "thecode")) + assert access == "tok" + # token persisted to disk and pending cleared + assert (rest_auth_dir / "prov.json").exists() + assert "prov" not in rest_provider.pending_rest_auth + data = json.loads((rest_auth_dir / "prov.json").read_text()) + assert data["refresh_token"] == "ref" + # the exchange POSTed the PKCE verifier + auth code + exch = http_recorder["calls"][-1] + assert exch["data"]["grant_type"] == "authorization_code" + assert exch["data"]["code"] == "thecode" + assert "code_verifier" in exch["data"] + + def test_complete_authorization_unknown_state_raises(self, rest_auth_dir): + with pytest.raises(RuntimeError, match="Unknown or expired"): + asyncio.run(AuthCodeTokenStore.complete_authorization("nope", "code")) + + def test_stale_pending_flow_is_pruned(self, monkeypatch, rest_auth_dir): + monkeypatch.setenv("AC_ID", "id") + monkeypatch.setattr(rest_provider, "_FLOW_TTL", 100) + store = AuthCodeTokenStore("prov", AC_AUTH) + store.begin_authorization() + state = next(iter(AuthCodeTokenStore._pending_flows)) + # Age the flow past the TTL; the next begin prunes it. + AuthCodeTokenStore._pending_flows[state]["created"] -= 200 + AuthCodeTokenStore("prov2", AC_AUTH).begin_authorization() + assert state not in AuthCodeTokenStore._pending_flows + + def test_fresh_pending_flow_survives_prune(self, monkeypatch, rest_auth_dir): + monkeypatch.setenv("AC_ID", "id") + store = AuthCodeTokenStore("prov", AC_AUTH) + store.begin_authorization() + first = next(iter(AuthCodeTokenStore._pending_flows)) + AuthCodeTokenStore("prov2", AC_AUTH).begin_authorization() + assert first in AuthCodeTokenStore._pending_flows # not stale → kept + + def test_get_token_returns_cached(self, monkeypatch, rest_auth_dir): + monkeypatch.setenv("AC_ID", "id") + rest_auth_dir.mkdir(parents=True) + (rest_auth_dir / "prov.json").write_text( + json.dumps({"access_token": "cached", "refresh_token": "r", "expires_at": 9_999_999_999}) + ) + store = AuthCodeTokenStore("prov", AC_AUTH) + assert asyncio.run(store.get_token()) == "cached" + + def test_get_token_refreshes_when_expired(self, http_recorder, monkeypatch, rest_auth_dir): + monkeypatch.setenv("AC_ID", "id") + monkeypatch.setenv("AC_SECRET", "secret") + rest_auth_dir.mkdir(parents=True) + (rest_auth_dir / "prov.json").write_text( + json.dumps({"access_token": "old", "refresh_token": "r", "expires_at": 0}) + ) + http_recorder["responses"].append( + FakeResponse(json_data={"access_token": "new", "refresh_token": "r2", "expires_in": 3600}) + ) + store = AuthCodeTokenStore("prov", AC_AUTH) + assert asyncio.run(store.get_token()) == "new" + assert http_recorder["calls"][-1]["data"]["grant_type"] == "refresh_token" + + def test_get_token_no_cache_raises_needs_authorization(self, monkeypatch, rest_auth_dir): + monkeypatch.setenv("AC_ID", "id") + store = AuthCodeTokenStore("prov", AC_AUTH) + with pytest.raises(NeedsAuthorization): + asyncio.run(store.get_token()) + assert "prov" in rest_provider.pending_rest_auth + + +# --------------------------------------------------------------------------- +# resolve_rest_auth +# --------------------------------------------------------------------------- + +class TestResolveRestAuth: + def _apply(self, auth): + resolver = resolve_rest_auth("prov", {"auth": auth}) + headers: dict = {} + asyncio.run(resolver.apply(headers)) + return headers, resolver + + def test_none_adds_no_header(self): + headers, resolver = self._apply({"type": "none"}) + assert headers == {} + assert resolver.supports_retry is False + + def test_bearer_sets_authorization_from_env(self, monkeypatch): + monkeypatch.setenv("MY_TOKEN", "xyz") + headers, _ = self._apply({"type": "bearer", "token_env": "MY_TOKEN"}) + assert headers["Authorization"] == "Bearer xyz" + + def test_api_key_sets_custom_header_from_env(self, monkeypatch): + monkeypatch.setenv("MY_KEY", "k1") + headers, _ = self._apply({"type": "api_key", "header": "X-Api-Key", "value_env": "MY_KEY"}) + assert headers["X-Api-Key"] == "k1" + + def test_client_credentials_sets_bearer(self, http_recorder, monkeypatch): + monkeypatch.setenv("CC_ID", "id") + monkeypatch.setenv("CC_SECRET", "secret") + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "cc", "expires_in": 3600})) + headers, resolver = self._apply(CC_AUTH) + assert headers["Authorization"] == "Bearer cc" + assert resolver.supports_retry is True + + def test_api_key_in_query_uses_apply_query_not_header(self, monkeypatch): + monkeypatch.setenv("MY_KEY", "qk") + resolver = resolve_rest_auth("prov", {"auth": { + "type": "api_key", "in": "query", "name": "apikey", "value_env": "MY_KEY"}}) + headers: dict = {} + asyncio.run(resolver.apply(headers)) + assert headers == {} # nothing in headers + params: dict = {} + resolver.apply_query(params) + assert params == {"apikey": "qk"} + + def test_apply_query_noop_for_header_api_key(self, monkeypatch): + monkeypatch.setenv("MY_KEY", "k") + resolver = resolve_rest_auth("prov", {"auth": { + "type": "api_key", "header": "X-Api-Key", "value_env": "MY_KEY"}}) + params: dict = {} + resolver.apply_query(params) + assert params == {} + + +# --------------------------------------------------------------------------- +# _make_rest_handler +# --------------------------------------------------------------------------- + +REST_CONFIG = { + "base_url": "https://api.example.com/v1", + "headers": {"Accept": "application/json"}, + "auth": {"type": "none"}, +} + + +class TestMakeRestHandler: + def _call(self, endpoint, rest_config, kwargs): + handler = _make_rest_handler(endpoint, rest_config, "prov") + return asyncio.run(handler(context={}, **kwargs)) + + def test_builds_url_with_path_params(self, http_recorder): + http_recorder["responses"].append(FakeResponse(json_data={"id": "7"})) + ep = {"name": "get_user", "method": "GET", "path": "/users/{user_id}", + "path_params": ["user_id"], "query_params": [], "body_params": []} + result = self._call(ep, REST_CONFIG, {"user_id": "7"}) + assert result == {"id": "7"} + assert http_recorder["calls"][0]["url"] == "https://api.example.com/v1/users/7" + + def test_sends_query_and_body_separately(self, http_recorder): + http_recorder["responses"].append(FakeResponse(json_data={"ok": True})) + ep = {"name": "create", "method": "POST", "path": "/items", + "path_params": [], "query_params": ["dry"], "body_params": ["title"]} + self._call(ep, REST_CONFIG, {"dry": "1", "title": "hi"}) + call = http_recorder["calls"][0] + assert call["params"] == {"dry": "1"} + assert call["json"] == {"title": "hi"} + + def test_merges_default_headers(self, http_recorder, monkeypatch): + monkeypatch.setenv("T", "tk") + http_recorder["responses"].append(FakeResponse(json_data={})) + cfg = {**REST_CONFIG, "auth": {"type": "bearer", "token_env": "T"}} + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": [], "body_params": []} + self._call(ep, cfg, {}) + headers = http_recorder["calls"][0]["headers"] + assert headers["Accept"] == "application/json" + assert headers["Authorization"] == "Bearer tk" + + def test_returns_parsed_json(self, http_recorder): + http_recorder["responses"].append(FakeResponse(json_data={"v": 1})) + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": [], "body_params": []} + assert self._call(ep, REST_CONFIG, {}) == {"v": 1} + + def test_non_json_returns_text(self, http_recorder): + http_recorder["responses"].append(FakeResponse(status_code=200, json_data=None, text="hello")) + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": [], "body_params": []} + result = self._call(ep, REST_CONFIG, {}) + assert result == {"ok": True, "status": 200, "text": "hello"} + + def test_http_error_returns_error_dict(self, http_recorder): + http_recorder["responses"].append(FakeResponse(status_code=404, json_data=None, text="missing")) + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": [], "body_params": []} + result = self._call(ep, REST_CONFIG, {}) + assert result["ok"] is False and result["status"] == 404 and result["tool"] == "g" + + def test_401_triggers_refresh_and_retry_once(self, http_recorder, monkeypatch): + monkeypatch.setenv("CC_ID", "id") + monkeypatch.setenv("CC_SECRET", "secret") + # token fetch, then a 401, then token refresh, then a 200 + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "t1", "expires_in": 3600})) + http_recorder["responses"].append(FakeResponse(status_code=401, json_data=None, text="unauth")) + http_recorder["responses"].append(FakeResponse(json_data={"access_token": "t2", "expires_in": 3600})) + http_recorder["responses"].append(FakeResponse(json_data={"ok": True})) + cfg = {**REST_CONFIG, "auth": CC_AUTH} + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": [], "body_params": []} + result = self._call(ep, cfg, {}) + assert result == {"ok": True} + # second request used the refreshed token + request_calls = [c for c in http_recorder["calls"] if c.get("url", "").endswith("/x")] + assert request_calls[-1]["headers"]["Authorization"] == "Bearer t2" + + def test_needs_authorization_surfaced_in_result(self, http_recorder, monkeypatch, rest_auth_dir): + monkeypatch.setenv("AC_ID", "id") + cfg = {**REST_CONFIG, "auth": AC_AUTH} + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": [], "body_params": []} + result = self._call(ep, cfg, {}) + assert result["ok"] is False and "auth_url" in result + + def test_api_key_query_added_to_request(self, http_recorder, monkeypatch): + monkeypatch.setenv("QK", "secretkey") + http_recorder["responses"].append(FakeResponse(json_data={"ok": True})) + cfg = {**REST_CONFIG, "auth": {"type": "api_key", "in": "query", "name": "key", "value_env": "QK"}} + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": ["q"], "body_params": []} + self._call(ep, cfg, {"q": "term"}) + assert http_recorder["calls"][0]["params"] == {"q": "term", "key": "secretkey"} + + def test_large_response_is_truncated(self, http_recorder, monkeypatch): + monkeypatch.setattr(rest_provider, "MAX_RESPONSE_BYTES", 50) + big = "x" * 500 + http_recorder["responses"].append(FakeResponse(status_code=200, json_data=None, text=big)) + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": [], "body_params": []} + result = self._call(ep, REST_CONFIG, {}) + assert result["truncated"] is True + assert result["total_bytes"] == 500 + assert len(result["preview"]) == 50 + + def test_small_response_not_truncated(self, http_recorder, monkeypatch): + monkeypatch.setattr(rest_provider, "MAX_RESPONSE_BYTES", 100000) + http_recorder["responses"].append(FakeResponse(json_data={"v": 1})) + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": [], "body_params": []} + assert self._call(ep, REST_CONFIG, {}) == {"v": 1} + + def test_truncation_disabled_with_zero(self, http_recorder, monkeypatch): + monkeypatch.setattr(rest_provider, "MAX_RESPONSE_BYTES", 0) + http_recorder["responses"].append(FakeResponse(json_data={"data": "y" * 500})) + ep = {"name": "g", "method": "GET", "path": "/x", "path_params": [], "query_params": [], "body_params": []} + result = self._call(ep, REST_CONFIG, {}) + assert result == {"data": "y" * 500} + + +# --------------------------------------------------------------------------- +# introspect_openapi +# --------------------------------------------------------------------------- + +OPENAPI_DOC = { + "openapi": "3.0.0", + "paths": { + "/users/{user_id}": { + "get": { + "operationId": "get_user", + "summary": "Fetch a user", + "parameters": [ + {"name": "user_id", "in": "path", "required": True, "schema": {"type": "string"}}, + {"name": "include", "in": "query", "schema": {"type": "string"}}, + ], + } + }, + "/items": { + "post": { + "operationId": "create_item", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["title"], + "properties": { + "title": {"type": "string"}, + "count": {"type": "integer"}, + }, + } + } + }, + }, + } + }, + }, +} + + +class TestIntrospectOpenAPI: + def _introspect(self, tmp_path): + path = tmp_path / "openapi.json" + path.write_text(json.dumps(OPENAPI_DOC)) + return introspect_openapi(str(path)) + + def test_parses_paths_into_endpoints(self, tmp_path): + endpoints, tools = self._introspect(tmp_path) + names = {e["name"] for e in endpoints} + assert names == {"get_user", "create_item"} + assert len(tools) == 2 + + def test_operationid_becomes_tool_name(self, tmp_path): + _, tools = self._introspect(tmp_path) + assert {t["name"] for t in tools} == {"get_user", "create_item"} + + def test_param_classification(self, tmp_path): + endpoints, _ = self._introspect(tmp_path) + get_user = next(e for e in endpoints if e["name"] == "get_user") + assert get_user["path_params"] == ["user_id"] + assert get_user["query_params"] == ["include"] + create = next(e for e in endpoints if e["name"] == "create_item") + assert set(create["body_params"]) == {"title", "count"} + assert create["method"] == "POST" + + def test_builds_input_schema_with_required(self, tmp_path): + _, tools = self._introspect(tmp_path) + get_user = next(t for t in tools if t["name"] == "get_user") + assert "user_id" in get_user["input_schema"]["properties"] + assert get_user["input_schema"]["required"] == ["user_id"] + create = next(t for t in tools if t["name"] == "create_item") + assert "title" in create["input_schema"]["required"] + + def test_derives_name_when_no_operation_id(self, tmp_path): + doc = {"openapi": "3.0.0", "paths": {"/a/b": {"get": {}}}} + path = tmp_path / "o.json" + path.write_text(json.dumps(doc)) + endpoints, _ = introspect_openapi(str(path)) + assert endpoints[0]["name"] == "get_a_b" + + def test_resolves_local_ref_for_param(self, tmp_path): + doc = { + "openapi": "3.0.0", + "components": {"parameters": {"Id": {"name": "id", "in": "path", "required": True, "schema": {"type": "string"}}}}, + "paths": {"/x/{id}": {"get": {"operationId": "getx", "parameters": [{"$ref": "#/components/parameters/Id"}]}}}, + } + path = tmp_path / "o.json" + path.write_text(json.dumps(doc)) + endpoints, _ = introspect_openapi(str(path)) + assert endpoints[0]["path_params"] == ["id"] + + def test_swagger_2_body_and_query_params(self, tmp_path): + doc = { + "swagger": "2.0", + "definitions": {"NewItem": { + "type": "object", "required": ["title"], + "properties": {"title": {"type": "string"}, "qty": {"type": "integer"}}}}, + "paths": {"/items/{id}": {"post": { + "operationId": "create_item", + "parameters": [ + {"name": "id", "in": "path", "required": True, "type": "string"}, + {"name": "verbose", "in": "query", "type": "boolean"}, + {"name": "body", "in": "body", "schema": {"$ref": "#/definitions/NewItem"}}, + ], + }}}, + } + path = tmp_path / "v2.json" + path.write_text(json.dumps(doc)) + endpoints, tools = introspect_openapi(str(path)) + ep = endpoints[0] + assert ep["path_params"] == ["id"] + assert ep["query_params"] == ["verbose"] + assert set(ep["body_params"]) == {"title", "qty"} + schema = tools[0]["input_schema"] + assert schema["properties"]["verbose"]["type"] == "boolean" + assert "title" in schema["required"] + + def test_allof_merged_in_request_body(self, tmp_path): + doc = { + "openapi": "3.0.0", + "components": {"schemas": { + "Base": {"type": "object", "required": ["a"], "properties": {"a": {"type": "string"}}}}}, + "paths": {"/x": {"post": { + "operationId": "mk", + "requestBody": {"content": {"application/json": {"schema": { + "allOf": [ + {"$ref": "#/components/schemas/Base"}, + {"type": "object", "properties": {"b": {"type": "integer"}}}, + ]}}}}, + }}}, + } + path = tmp_path / "allof.json" + path.write_text(json.dumps(doc)) + endpoints, tools = introspect_openapi(str(path)) + assert set(endpoints[0]["body_params"]) == {"a", "b"} + assert "a" in tools[0]["input_schema"]["required"] diff --git a/tests/test_server.py b/tests/test_server.py index 3b7356c..4ff46fd 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -21,6 +21,8 @@ SUBPROCESS_KEYS, _build_typed_signature, _get_package_command, + _get_rest_config, + _rest_oauth_providers, advertised_tool_name, build_runtime_context, exec_provider_code, @@ -687,6 +689,121 @@ def test_disabled_code_tool_skipped_without_loading_handler(self, tmp_path: Path assert names == ["p__alive"] +# --------------------------------------------------------------------------- +# register_provider — REST providers +# --------------------------------------------------------------------------- + +class TestRegisterProviderRest: + def _capture_registered(self, spec): + names: list[str] = [] + + def fake_decorator(**kwargs): + names.append(kwargs.get("name")) + return lambda fn: fn + + with patch("server.mcp") as mock_mcp: + mock_mcp.tool.side_effect = fake_decorator + register_provider(spec) + return names + + def _rest_spec(self, tmp_path, enabled=True): + return { + "_config_path": str(tmp_path / "weather.yaml"), + "rest": { + "base_url": "https://api.example.com", + "auth": {"type": "none"}, + "endpoints": [ + {"name": "get_forecast", "method": "GET", "path": "/forecast", + "path_params": [], "query_params": ["city"], "body_params": []}, + ], + }, + "tools": [{ + "name": "get_forecast", + "description": "Get the forecast", + "enabled": enabled, + "input_schema": {"type": "object", + "properties": {"city": {"type": "string"}}, "required": ["city"]}, + }], + } + + def test_get_rest_config_helper(self): + assert _get_rest_config({"rest": {"base_url": "x"}}) == {"base_url": "x"} + assert _get_rest_config({"package": {"command": "x"}}) is None + assert _get_rest_config({}) is None + + def test_rest_branch_detected_and_prefixed(self, tmp_path: Path): + names = self._capture_registered(self._rest_spec(tmp_path)) + assert names == ["weather__get_forecast"] + + def test_rest_tool_registered_into_tool_registry(self, tmp_path: Path): + tool_registry.clear() + try: + self._capture_registered(self._rest_spec(tmp_path)) + entry = tool_registry.get("weather__get_forecast") + assert entry is not None + assert entry["spec"]["name"] == "get_forecast" + finally: + tool_registry.clear() + + def test_disabled_rest_tool_skipped(self, tmp_path: Path): + names = self._capture_registered(self._rest_spec(tmp_path, enabled=False)) + assert names == [] + + def test_rest_tool_missing_endpoint_raises(self, tmp_path: Path): + spec = self._rest_spec(tmp_path) + spec["rest"]["endpoints"] = [] # no endpoint matching the tool + with pytest.raises(ValueError, match="no matching"): + self._capture_registered(spec) + + def test_rest_checked_before_package(self, tmp_path: Path): + # A spec with both rest and (nonsense) package should take the rest path. + spec = self._rest_spec(tmp_path) + names = self._capture_registered(spec) + assert names == ["weather__get_forecast"] + + +class TestWarmRestProviders: + """_rest_oauth_providers — discovery of OAuth-backed REST providers.""" + + def _write(self, config_dir: Path, name: str, auth_type: str): + body = f""" +rest: + base_url: https://api.example.com + auth: + type: {auth_type} + token_url: https://auth/token + authorize_url: https://auth/authorize + client_id_env: X_ID + client_secret_env: X_SECRET + endpoints: + - {{name: t, method: GET, path: /, path_params: [], query_params: [], body_params: []}} +tools: + - name: t + description: d + input_schema: {{type: object, properties: {{}}, required: []}} +""" + (config_dir / f"{name}.yaml").write_text(body) + + def test_discovers_only_oauth_rest_providers(self, tmp_path: Path, monkeypatch): + import server + self._write(tmp_path, "cc", "client_credentials") + self._write(tmp_path, "ac", "authorization_code") + self._write(tmp_path, "plain", "none") + # a non-rest provider must be ignored + (tmp_path / "pkg.yaml").write_text( + "package: {command: echo hi}\ntools: []\n" + ) + monkeypatch.setattr(server, "CONFIG_DIR", tmp_path) + found = {name for name, _ in _rest_oauth_providers()} + assert found == {"cc", "ac"} + + def test_empty_when_no_oauth_rest(self, tmp_path: Path, monkeypatch): + import server + self._write(tmp_path, "plain", "bearer") + monkeypatch.setattr(server, "CONFIG_DIR", tmp_path) + assert _rest_oauth_providers() == [] + + # --------------------------------------------------------------------------- # tool_registry module # ---------------------------------------------------------------------------