Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions sdk/py/osa/cli/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Credential storage for OSA CLI.

Stores and retrieves authentication tokens keyed by server URL.
Credentials file: ~/.config/osa/credentials.json (0600 permissions).
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any

_DEFAULT_PATH = Path.home() / ".config" / "osa" / "credentials.json"


def _normalize_server(server: str) -> str:
"""Normalize server URL by stripping trailing slashes."""
return server.rstrip("/")


def _read_file(path: Path) -> dict[str, Any]:
"""Read the credentials file, returning empty dict if missing or invalid."""
if not path.exists():
return {}
try:
return json.loads(path.read_text())
except (json.JSONDecodeError, OSError):
return {}


def _write_file(path: Path, data: dict[str, Any]) -> None:
"""Write data to credentials file with 0600 permissions.

Uses os.open with O_CREAT to create the file with 0600 from the start,
avoiding a TOCTOU window where the file is briefly world-readable.
"""
path.parent.mkdir(parents=True, exist_ok=True)
content = json.dumps(data, indent=2).encode()
fd = os.open(str(path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
try:
os.write(fd, content)
finally:
os.close(fd)


def write_credentials(
server: str,
*,
access_token: str,
Comment thread
rorybyrne marked this conversation as resolved.
refresh_token: str,
path: Path = _DEFAULT_PATH,
) -> None:
"""Store credentials for a server URL.

Creates the file if it doesn't exist. Overwrites existing entry
for the same server. Preserves entries for other servers.
"""
server = _normalize_server(server)
data = _read_file(path)
data[server] = {
"access_token": access_token,
"refresh_token": refresh_token,
}
_write_file(path, data)


def read_credentials(
server: str,
*,
path: Path = _DEFAULT_PATH,
) -> dict[str, str] | None:
"""Read credentials for a server URL.

Returns dict with access_token and refresh_token, or None if not found.
"""
server = _normalize_server(server)
data = _read_file(path)
entry = data.get(server)
if entry and "access_token" in entry:
return entry
return None


def remove_credentials(
server: str,
*,
path: Path = _DEFAULT_PATH,
) -> bool:
"""Remove credentials for a server URL.

Returns True if credentials were removed, False if not found.
"""
server = _normalize_server(server)
data = _read_file(path)
if server not in data:
return False
del data[server]
_write_file(path, data)
return True


def refresh_access_token(
server: str,
*,
path: Path = _DEFAULT_PATH,
) -> str | None:
"""Attempt to refresh the access token using the stored refresh token.

On success, updates stored credentials and returns the new access token.
On failure, returns None.
"""
import httpx

creds = read_credentials(server, path=path)
if creds is None or "refresh_token" not in creds:
return None

url = f"{_normalize_server(server)}/api/v1/auth/refresh"
try:
resp = httpx.post(
url,
json={"refresh_token": creds["refresh_token"]},
timeout=10.0,
)
if resp.status_code != 200:
return None

data = resp.json()
write_credentials(
server,
access_token=data["access_token"],
refresh_token=data["refresh_token"],
path=path,
)
return data["access_token"]
except (httpx.HTTPError, ValueError, KeyError):
return None


def resolve_token(
server: str,
*,
path: Path = _DEFAULT_PATH,
) -> str | None:
"""Resolve an access token for a server URL.

Resolution chain:
1. OSA_TOKEN environment variable (for CI/CD)
2. Stored credentials — refresh first so we return a fresh access token
3. None (not authenticated)
"""
env_token = os.environ.get("OSA_TOKEN")
if env_token:
return env_token

creds = read_credentials(server, path=path)
if creds:
# Attempt refresh to get a fresh access token.
# The refresh endpoint is cheap and idempotent.
refreshed = refresh_access_token(server, path=path)
if refreshed is not None:
return refreshed
# Refresh failed — return stored token (server will reject if expired)
return creds["access_token"]

return None
Comment thread
rorybyrne marked this conversation as resolved.
71 changes: 71 additions & 0 deletions sdk/py/osa/cli/link.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Per-directory project linking for OSA CLI.

Stores server URL in .osa/config.json so commands don't need --server every time.
Resolution chain: --server flag → OSA_SERVER env → .osa/config.json → error.
"""

from __future__ import annotations

import json
import os
import sys
from pathlib import Path


def write_link(server: str, *, project_dir: Path | None = None) -> Path:
"""Write .osa/config.json in project_dir (default: cwd).

Returns path to the config file.
"""
project_dir = project_dir or Path.cwd()
server = server.rstrip("/")

config_dir = project_dir / ".osa"
config_dir.mkdir(parents=True, exist_ok=True)

config_path = config_dir / "config.json"
config_path.write_text(json.dumps({"server": server}, indent=2) + "\n")

return config_path


def read_link(*, project_dir: Path | None = None) -> str | None:
"""Read server URL from .osa/config.json.

Returns the server URL or None if not found/invalid.
"""
project_dir = project_dir or Path.cwd()
config_path = project_dir / ".osa" / "config.json"

if not config_path.exists():
return None

try:
data = json.loads(config_path.read_text())
server = data.get("server")
if isinstance(server, str) and server:
return server
return None
except (json.JSONDecodeError, OSError):
return None


def resolve_server(*, flag: str | None = None, project_dir: Path | None = None) -> str:
"""Resolve server URL: --server flag → OSA_SERVER env → .osa/config.json → error."""
if flag:
return flag.rstrip("/")

env = os.environ.get("OSA_SERVER")
if env:
return env.rstrip("/")

linked = read_link(project_dir=project_dir)
if linked:
return linked

print(
"Error: No server specified. Use --server <url>, set OSA_SERVER, "
"or run `osa link --server <url>` in your project directory.",
file=sys.stderr,
)
sys.exit(1)
154 changes: 154 additions & 0 deletions sdk/py/osa/cli/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""OSA CLI login command — device flow authentication."""

from __future__ import annotations

import logging
import sys
import time
import webbrowser
from pathlib import Path
from typing import Any

import httpx

from osa.cli.credentials import _DEFAULT_PATH, write_credentials

logger = logging.getLogger(__name__)

DEVICE_CODE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"


def _poll_for_token(
*,
client: httpx.Client,
server: str,
device_code: str,
interval: int,
expires_in: int,
) -> dict[str, Any] | None:
"""Poll the device token endpoint until authorized, expired, or timed out.

Returns token dict on success, None on expiry/timeout.
"""
url = f"{server.rstrip('/')}/api/v1/auth/device/token"
payload = {
"device_code": device_code,
"grant_type": DEVICE_CODE_GRANT_TYPE,
}

start = time.monotonic()
backoff = interval

while True:
elapsed = time.monotonic() - start
if elapsed >= expires_in:
return None

try:
resp = client.post(url, json=payload)
except httpx.HTTPError:
# Transient network error — backoff and retry
time.sleep(min(backoff, 30))
backoff = min(backoff * 2, 30)
continue

if resp.status_code == 200:
return resp.json()

if resp.status_code >= 500:
# Server error — backoff and retry
time.sleep(min(backoff, 30))
backoff = min(backoff * 2, 30)
continue

# 400-level: check error code
data = resp.json()
error = data.get("error", "")
Comment on lines +55 to +66

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resp.json() unguarded — crashes login on non-JSON responses

Both the 200 path (line 56) and the 400 path (line 65) call resp.json() without exception handling. If a WAF, reverse-proxy, or the server itself returns a non-JSON body (e.g., an HTML error page) for either response code, Python raises json.JSONDecodeError (a ValueError), which is not caught anywhere in the call stack. The polling loop exits with an unhandled traceback instead of a graceful failure message.

Suggested change
if resp.status_code == 200:
return resp.json()
if resp.status_code >= 500:
# Server error — backoff and retry
time.sleep(min(backoff, 30))
backoff = min(backoff * 2, 30)
continue
# 400-level: check error code
data = resp.json()
error = data.get("error", "")
if resp.status_code == 200:
try:
return resp.json()
except ValueError:
logger.error("Device token endpoint returned non-JSON 200 response")
return None
if resp.status_code >= 500:
# Server error — backoff and retry
time.sleep(min(backoff, 30))
backoff = min(backoff * 2, 30)
continue
# 400-level: check error code
try:
data = resp.json()
except ValueError:
logger.error("Device token endpoint returned non-JSON %d response", resp.status_code)
return None


if error == "authorization_pending":
time.sleep(interval)
backoff = interval # Reset backoff on normal pending
continue

if error == "slow_down":
interval = interval + 5 # RFC 8628: increase interval
backoff = interval # Sync backoff with new interval
time.sleep(interval)
continue
Comment thread
rorybyrne marked this conversation as resolved.

if error == "expired_token":
return None

# Unknown error
logger.error(
"Device token error: %s — %s", error, data.get("error_description", "")
)
return None


def login(
server: str,
*,
cred_path: Path = _DEFAULT_PATH,
) -> bool:
"""Run the device flow login.

Returns True on success, False on failure.
"""
server = server.rstrip("/")

with httpx.Client(timeout=30.0) as client:
# Step 1: Initiate device authorization
try:
resp = client.post(f"{server}/api/v1/auth/device")
resp.raise_for_status()
except httpx.HTTPError as e:
print(f"Error: Could not reach server at {server}", file=sys.stderr)
logger.debug("Initiation failed: %s", e)
return False

data = resp.json()
device_code = data["device_code"]
user_code = data["user_code"]
verification_uri = data["verification_uri"]
expires_in = data["expires_in"]
Comment on lines +109 to +114

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unguarded resp.json() and bare key access on device initiation response

resp.json() can raise json.JSONDecodeError (a ValueError) if the 200 response body is not valid JSON (e.g., an HTML error page from a WAF or reverse proxy). Additionally, the subsequent direct key accesses (data["device_code"], data["user_code"], etc.) will raise KeyError if any field is absent.

The same concern applies to the write step at result["access_token"] / result["refresh_token"] ~line 142, where result comes from the poll endpoint.

Both _poll_for_token (noted in a prior review comment at line 55) and login() share this pattern. Consider wrapping in a try/except and printing a helpful diagnostic message on failure:

try:
    data = resp.json()
    device_code = data["device_code"]
    user_code = data["user_code"]
    verification_uri = data["verification_uri"]
    expires_in = data["expires_in"]
    interval = data["interval"]
except (ValueError, KeyError) as e:
    print(f"Error: unexpected response from server: {e}", file=sys.stderr)
    return False

interval = data["interval"]

# Step 2: Display code and URL
print(f"Open this URL in your browser: {verification_uri}")
print(f"Enter code: {user_code}")
print()

# Try to open browser
try:
webbrowser.open(verification_uri)
except Exception:
pass # Non-critical — user can open manually

# Step 3: Poll for token
print("Waiting for authorization...", end=" ", flush=True)
result = _poll_for_token(
client=client,
server=server,
device_code=device_code,
interval=interval,
expires_in=expires_in,
)

if result is None:
print("failed")
print("Device code expired. Please try again.", file=sys.stderr)
return False

print("done")

# Step 4: Store credentials
write_credentials(
server,
access_token=result["access_token"],
refresh_token=result["refresh_token"],
path=cred_path,
)

print("Token stored.")
return True
Loading
Loading