Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 101 additions & 5 deletions src/authsome/server/credential_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import json
import re
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Self
Expand Down Expand Up @@ -40,6 +41,7 @@
CredentialMissingError,
InvalidProviderSchemaError,
OperationNotAllowedError,
ProviderNotFoundError,
RefreshFailedError,
TokenExpiredError,
UnsupportedFlowError,
Expand Down Expand Up @@ -172,6 +174,28 @@ async def register_provider(self, definition: ProviderDefinition, *, force: bool
)
logger.info("Registered provider: {}", definition.name)

async def update_provider(self, provider: str, definition: ProviderDefinition) -> None:
"""Update an existing custom provider definition."""
self._require_admin("update", "update requires an admin principal", provider)
if provider != definition.name:
raise InvalidProviderSchemaError(
f"Provider name '{definition.name}' must match route provider '{provider}'",
provider=provider,
)
if not await self.is_custom_provider(provider):
raise ProviderNotFoundError(provider)
self._validate_provider(definition)
await self._providers.save_custom(definition, force=True)
audit.emit_event(
"provider.updated",
provider=definition.name,
identity=self._identity,
principal_id=self._principal_id,
status="success",
auth_type=definition.auth_type.value if definition.auth_type else None,
)
logger.info("Updated provider: {}", definition.name)

def _require_admin(self, operation: str, message: str, provider: str) -> None:
"""Allow an operation only for admin principals."""
if self._principal_role == PrincipalRole.ADMIN:
Expand All @@ -180,20 +204,85 @@ def _require_admin(self, operation: str, message: str, provider: str) -> None:

def _validate_provider(self, definition: ProviderDefinition) -> None:
validate_provider_definition(definition)
self._validate_api_targets(definition.api_urls(), "api_url", definition.name)
self._validate_optional_url(definition.docs_url, "docs_url", definition.name)
if definition.oauth:
for field_name in ("authorization_url", "token_url"):
for field_name in (
"authorization_url",
"token_url",
"revocation_url",
"device_authorization_url",
"base_url",
):
url = getattr(definition.oauth, field_name, None)
if url:
self._validate_url(url, field_name, definition.name)
self._validate_optional_url(url, field_name, definition.name, allow_base_url_template=True)
if definition.registration:
self._validate_optional_url(
definition.registration.registration_endpoint,
"registration.registration_endpoint",
definition.name,
)
if definition.browser:
self._validate_optional_url(definition.browser.entry_url, "browser.entry_url", definition.name)
self._validate_optional_url(definition.browser.validate_url, "browser.validate_url", definition.name)

@staticmethod
def _validate_url(url: str, field_name: str, provider_name: str) -> None:
if "{base_url}" in url:
def _validate_optional_url(
url: str | None,
field_name: str,
provider_name: str,
*,
allow_base_url_template: bool = False,
) -> None:
if not url:
return
CredentialService._validate_url(url, field_name, provider_name, allow_base_url_template=allow_base_url_template)

@staticmethod
def _validate_url(
url: str,
field_name: str,
provider_name: str,
*,
allow_base_url_template: bool = False,
) -> None:
if allow_base_url_template and "{base_url}" in url:
return
if "{base_url}" in url:
raise InvalidProviderSchemaError(
f"Invalid URL for '{field_name}': {url}",
provider=provider_name,
)
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise InvalidProviderSchemaError(f"Invalid URL for '{field_name}': {url}", provider=provider_name)

@staticmethod
def _validate_api_targets(targets: tuple[str, ...], field_name: str, provider_name: str) -> None:
for target in targets:
cleaned = target.strip()
if not cleaned or any(char.isspace() for char in cleaned):
raise InvalidProviderSchemaError(
f"Invalid API target for '{field_name}': {target}",
provider=provider_name,
)
if cleaned.startswith("regex:"):
try:
re.compile(cleaned.removeprefix("regex:"))
except re.error as exc:
raise InvalidProviderSchemaError(
f"Invalid API target for '{field_name}': {target}",
provider=provider_name,
) from exc
continue
parsed = urlparse(cleaned if "://" in cleaned else f"https://{cleaned}")
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise InvalidProviderSchemaError(
f"Invalid API target for '{field_name}': {target}",
provider=provider_name,
)

# ── Connection operations ─────────────────────────────────────────────

async def list_connections(self) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -904,6 +993,13 @@ async def remove(self, provider: str) -> None:
await self.revoke(provider)
if await self.is_custom_provider(provider):
await self._providers.delete_custom(provider)
audit.emit_event(
"provider.deleted",
provider=provider,
identity=self._identity,
principal_id=self._principal_id,
status="success",
)
logger.info("Removed local provider definition: {}", provider)
else:
logger.info("Revoked bundled provider: {} (definition kept)", provider)
Expand Down
8 changes: 0 additions & 8 deletions src/authsome/server/routes/_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,6 @@ async def get_protected_auth_service(
return _build_service(request, ownership)


async def get_admin_auth_service(
auth: CredentialService = Depends(get_protected_auth_service),
) -> CredentialService:
if auth.principal_role != PrincipalRole.ADMIN:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required")
return auth


async def get_daemon_or_browser_auth_service(request: Request) -> CredentialService:
"""Resolve auth from PoP headers or an existing browser dashboard session."""
if request.headers.get("Authorization"):
Expand Down
39 changes: 34 additions & 5 deletions src/authsome/server/routes/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from authsome.server.credential_service import CredentialService
from authsome.server.routes._deps import (
build_auth_service,
get_admin_auth_service,
get_daemon_or_browser_auth_service,
get_protected_auth_service,
get_server_base_url,
Expand Down Expand Up @@ -168,12 +167,14 @@ async def update_provider_configuration(


@router.post("")
async def register_provider(body: dict, auth: CredentialService = Depends(get_admin_auth_service)):
async def register_provider(body: dict, auth: CredentialService = Depends(get_daemon_or_browser_auth_service)):
if auth.principal_role != PrincipalRole.ADMIN:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required")
definition_payload = body.get("definition", body)
definition = ProviderDefinition.model_validate(definition_payload)
await auth.register_provider(definition, force=bool(body.get("force", False)))
capture_event(
auth.require_identity(),
_actor(auth),
"provider registered",
{
"provider": definition.name,
Expand All @@ -184,11 +185,39 @@ async def register_provider(body: dict, auth: CredentialService = Depends(get_ad
return {"status": "ok", "provider": definition.name}


@router.put("/{provider}")
async def update_provider(
provider: str,
body: dict,
auth: CredentialService = Depends(get_daemon_or_browser_auth_service),
):
if auth.principal_role != PrincipalRole.ADMIN:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required")
definition_payload = body.get("definition", body)
definition = ProviderDefinition.model_validate(definition_payload)
await auth.update_provider(provider, definition)
capture_event(
_actor(auth),
"provider updated",
{
"provider": definition.name,
"auth_type": definition.auth_type.value if definition.auth_type else None,
"principal_id": auth.principal_id,
},
)
return {"status": "ok", "provider": definition.name}


@router.delete("/{provider}")
async def delete_provider(provider: str, auth: CredentialService = Depends(get_admin_auth_service)):
async def delete_provider(
provider: str,
auth: CredentialService = Depends(get_daemon_or_browser_auth_service),
):
if auth.principal_role != PrincipalRole.ADMIN:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required")
await auth.remove(provider)
capture_event(
auth.require_identity(),
_actor(auth),
"provider deleted",
{
"provider": provider,
Expand Down
116 changes: 116 additions & 0 deletions tests/server/test_provider_operation_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,119 @@ def test_first_principal_admin_can_register_provider(monkeypatch, tmp_path: Path

assert response.status_code == status.HTTP_200_OK
assert response.json()["status"] == "ok"


def test_browser_session_admin_can_register_provider(monkeypatch, tmp_path: Path) -> None:
monkeypatch.setenv("AUTHSOME_HOME", str(tmp_path))
payload = {
"definition": {
"name": "custom-api",
"display_name": "Custom API",
"auth_type": "api_key",
"flow": "api_key",
"api_key": {"header_name": "Authorization"},
}
}

with create_server_test_client() as client:
registered = client.post(
"/api/auth/register",
data={"email": "admin@example.com", "password": "password-1", "next": "/providers"},
follow_redirects=False,
)
response = client.post("/api/providers", json=payload)

assert registered.status_code == status.HTTP_303_SEE_OTHER
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"status": "ok", "provider": "custom-api"}


def test_admin_can_update_custom_provider(monkeypatch, tmp_path: Path) -> None:
monkeypatch.setenv("AUTHSOME_HOME", str(tmp_path))
create_payload = {
"definition": {
"name": "custom-api",
"display_name": "Custom API",
"auth_type": "api_key",
"flow": "api_key",
"api_url": "api.example.com",
"api_key": {"header_name": "Authorization", "header_prefix": "Bearer"},
}
}
update_payload = {
"definition": {
"name": "custom-api",
"display_name": "Updated API",
"auth_type": "api_key",
"flow": "api_key",
"api_url": "https://api.example.com/v2",
"api_key": {"header_name": "x-api-key", "header_prefix": ""},
}
}
create_body = json.dumps(create_payload, separators=(",", ":"), sort_keys=True).encode("utf-8")
update_body = json.dumps(update_payload, separators=(",", ":"), sort_keys=True).encode("utf-8")

with create_server_test_client() as client:
_register_identity(client, tmp_path, "steady-wisely-boldly-0042")
created = client.post(
"/api/providers",
content=create_body,
headers={
**_auth_header(tmp_path, "POST", "/api/providers", body=create_body),
"Content-Type": "application/json",
},
)
response = client.put(
"/api/providers/custom-api",
content=update_body,
headers={
**_auth_header(tmp_path, "PUT", "/api/providers/custom-api", body=update_body),
"Content-Type": "application/json",
},
)
fetched = client.get(
"/api/providers/custom-api",
headers=_auth_header(tmp_path, "GET", "/api/providers/custom-api"),
)

assert created.status_code == status.HTTP_200_OK
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"status": "ok", "provider": "custom-api"}
assert fetched.json()["display_name"] == "Updated API"
assert fetched.json()["api_key"]["header_name"] == "x-api-key"


def test_provider_registration_rejects_invalid_url_fields(monkeypatch, tmp_path: Path) -> None:
monkeypatch.setenv("AUTHSOME_HOME", str(tmp_path))
payload = {
"definition": {
"name": "custom-api",
"display_name": "Custom API",
"auth_type": "api_key",
"flow": "api_key",
"api_url": "https://api.example.com",
"docs_url": "not a url",
"api_key": {"header_name": "Authorization"},
}
}
body = json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8")

with create_server_test_client() as client:
_register_identity(client, tmp_path, "steady-wisely-boldly-0042")
response = client.post(
"/api/providers",
content=body,
headers={
**_auth_header(tmp_path, "POST", "/api/providers", body=body),
"Content-Type": "application/json",
},
)
fetched = client.get(
"/api/providers/custom-api",
headers=_auth_header(tmp_path, "GET", "/api/providers/custom-api"),
)

assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["error"] == "InvalidProviderSchemaError"
assert "docs_url" in response.json()["message"]
assert fetched.status_code == status.HTTP_404_NOT_FOUND
10 changes: 8 additions & 2 deletions ui/src/app/(authenticated)/providers/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ import { ProvidersView } from "@/components/dashboard/provider-views";
import { fetchDashboard } from "@/lib/authsome-api";

export default function ProvidersPage() {
const { data } = useSWR("authsome-dashboard", fetchDashboard);
const { data, mutate } = useSWR("authsome-dashboard", fetchDashboard);
if (!data) return null;
return <ProvidersView providers={data.providers} />;
return (
<ProvidersView
isAdmin={data.account.isAdmin}
onRefresh={() => void mutate()}
providers={data.providers}
/>
);
}
4 changes: 3 additions & 1 deletion ui/src/components/authsome-dashboard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ function ActiveView({
onRefresh: () => void;
view: View;
}) {
if (view === "providers") return <ProvidersView providers={data.providers} />;
if (view === "providers") {
return <ProvidersView isAdmin={data.account.isAdmin} onRefresh={onRefresh} providers={data.providers} />;
}
if (view === "connections") {
return (
<ConnectionsView
Expand Down
Loading
Loading