Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions packages/commons/src/zeroshot_commons/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,24 @@ def load_config(
config = _load_raw_config(config_path)
merged_config = deep_merge(config, parse_env_variables())

database_url = os.environ.get("DATABASE_URL")
if database_url:
from .postgres_connection import PostgresConnectionConfig

pg = PostgresConnectionConfig.from_url(database_url)
merged_config = deep_merge(
merged_config,
{
"postgres": {
"host": pg.host,
"port": pg.port,
"username": pg.username,
"password": pg.password,
"database": pg.database,
}
},
)

if not config_key:
return merged_config

Expand Down
17 changes: 16 additions & 1 deletion packages/commons/src/zeroshot_commons/postgres_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
from urllib.parse import quote_plus
from urllib.parse import quote_plus, unquote, urlparse

from .application_config import ApplicationConfig
from .config_utils import load_config
Expand All @@ -25,6 +25,21 @@ class PostgresConnectionConfig:

POSTGRES_CONFIG_KEY = "postgres"

@classmethod
def from_url(cls, url: str) -> PostgresConnectionConfig:
parsed = urlparse(url)
if parsed.scheme not in ("postgresql", "postgres"):
raise ValueError(
f"Unsupported scheme '{parsed.scheme}', expected 'postgresql' or 'postgres'"
)
return cls(
host=parsed.hostname or "localhost",
port=parsed.port or 5432,
username=unquote(parsed.username or ""),
password=unquote(parsed.password or ""),
database=(parsed.path or "/").lstrip("/"),
)

@classmethod
def from_mapping(cls, data: Mapping[str, Any]) -> PostgresConnectionConfig:
return cls(
Expand Down
36 changes: 36 additions & 0 deletions packages/commons/tests/unit/test_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,42 @@ def test_load_config_reads_json_and_applies_env_overrides(tmp_path: Path) -> Non
assert config["service"] == {"host": "base", "port": 8080}


def test_load_config_applies_database_url_override(tmp_path: Path) -> None:
package_root = tmp_path / "package"
main_dir = package_root / "src"
assets_dir = package_root / "assets"
main_dir.mkdir(parents=True)
assets_dir.mkdir()
(assets_dir / "config.json").write_text(
json.dumps(
{
"postgres": {
"host": "old-host",
"port": 5432,
"username": "old",
"password": "old",
"database": "old_db",
}
}
),
encoding="utf-8",
)

config = run_with_env(
lambda: load_config(
str(main_dir),
config_file_path="assets/config.json",
),
[("DATABASE_URL", "postgresql://newuser:newpass@newhost:6543/new_db")],
)

assert config["postgres"]["host"] == "newhost"
assert config["postgres"]["port"] == 6543
assert config["postgres"]["username"] == "newuser"
assert config["postgres"]["password"] == "newpass"
assert config["postgres"]["database"] == "new_db"


def test_load_config_reads_sub_config_and_application_config(tmp_path: Path) -> None:
package_root = tmp_path / "package"
main_dir = package_root / "src"
Expand Down
31 changes: 31 additions & 0 deletions packages/commons/tests/unit/test_postgres_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,37 @@ async def recovery() -> str:
assert recovered is True


def test_from_url_parses_standard_postgres_url() -> None:
config = PostgresConnectionConfig.from_url("postgresql://user:pass@localhost:5432/mydb")
assert config.host == "localhost"
assert config.port == 5432
assert config.username == "user"
assert config.password == "pass"
assert config.database == "mydb"


def test_from_url_handles_postgres_scheme() -> None:
config = PostgresConnectionConfig.from_url("postgres://u:p@host:1234/db")
assert config.host == "host"
assert config.port == 1234
assert config.username == "u"
assert config.password == "p"
assert config.database == "db"


def test_from_url_decodes_url_encoded_credentials() -> None:
config = PostgresConnectionConfig.from_url(
"postgresql://user%40domain:p%40ss%3Aword@host:5432/db"
)
assert config.username == "user@domain"
assert config.password == "p@ss:word"


def test_from_url_rejects_unsupported_scheme() -> None:
with pytest.raises(ValueError, match="Unsupported scheme"):
PostgresConnectionConfig.from_url("mysql://user:pass@host:3306/db")


def test_postgres_connection_config_generates_expected_urls() -> None:
config = PostgresConnectionConfig.from_mapping(
{
Expand Down
25 changes: 15 additions & 10 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.