From 6a9561e1348671ad81adf67b8c2ec86c27335843 Mon Sep 17 00:00:00 2001 From: Luca Date: Wed, 3 Jun 2026 18:17:34 +0100 Subject: [PATCH] Allow to filter by params --- docs/reference/task_plugin.md | 8 ++--- fluid/db/crud.py | 6 ++++ fluid/scheduler/db.py | 51 ++++++++++++++++++++++++++----- tests/scheduler/test_db_plugin.py | 49 ++++++++++++++++++++++++++++- 4 files changed, 102 insertions(+), 12 deletions(-) diff --git a/docs/reference/task_plugin.md b/docs/reference/task_plugin.md index 477b437..797644e 100644 --- a/docs/reference/task_plugin.md +++ b/docs/reference/task_plugin.md @@ -31,13 +31,13 @@ called directly from within a task by passing `context.task_manager`: ```python from fluid.scheduler import TaskRun, task -from fluid.scheduler.db import get_db_plugin, HistoryQuery +from fluid.scheduler.db import get_db_plugin, TaskHistoryQuery @task() async def report(context: TaskRun) -> None: db_plugin = get_db_plugin(context.task_manager) - page = await db_plugin.get_history(HistoryQuery(task="my-task", limit=10)) + page = await db_plugin.get_history(TaskHistoryQuery(task="my-task", limit=10)) for run in page.data: print(run.id, run.state) ``` @@ -53,10 +53,10 @@ or the HTTP endpoints added by [with_task_history_router][fluid.scheduler.db.wit They can be imported from `fluid.scheduler.db`: ```python -from fluid.scheduler.db import HistoryQuery, TaskRunHistory, TaskRunHistoryPage +from fluid.scheduler.db import TaskHistoryQuery, TaskRunHistory, TaskRunHistoryPage ``` -::: fluid.scheduler.db.HistoryQuery +::: fluid.scheduler.db.TaskHistoryQuery ::: fluid.scheduler.db.TaskRunHistory diff --git a/fluid/db/crud.py b/fluid/db/crud.py index 5996bd4..5e646e8 100644 --- a/fluid/db/crud.py +++ b/fluid/db/crud.py @@ -3,6 +3,7 @@ from dateutil.parser import parse as parse_date from sqlalchemy import Column, Table, func, insert, select +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection @@ -300,6 +301,11 @@ def default_filter_column( ], ) -> Any: """Build a SQLAlchemy WHERE clause expression for a single column filter""" + if isinstance(column.type, JSONB) and isinstance(value, dict): + if op == "eq": + return column.contains(value) + return None + if multiple := isinstance(value, (list, tuple)): value = tuple(column_value_to_python(column, v) for v in value) else: diff --git a/fluid/scheduler/db.py b/fluid/scheduler/db.py index 4feba89..07db11f 100644 --- a/fluid/scheduler/db.py +++ b/fluid/scheduler/db.py @@ -1,11 +1,13 @@ from __future__ import annotations +import json from datetime import datetime from typing import Any, ClassVar import sqlalchemy as sa from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query -from pydantic import BaseModel, Field +from pydantic import BaseModel, BeforeValidator, Field, model_validator +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.exc import NoResultFound from typing_extensions import Annotated, Doc @@ -20,8 +22,18 @@ from .plugin import TaskManagerPlugin +def _parse_json_str(v: Any) -> Any: + """Parse a JSON string into a dict for JSONB query parameters.""" + if isinstance(v, str): + return json.loads(v) + return v + + +JsonDict = Annotated[dict[str, Any], BeforeValidator(_parse_json_str)] + + class TaskDbPlugin(TaskManagerPlugin): - """A plugin to store task runs in a database. + """A plugin to store [TaskRun][fluid.scheduler.TaskRun] in a postgresql database. This plugin listens to task state changes and updates the database accordingly. It requires a CrudDB instance to perform database operations and allows @@ -95,7 +107,7 @@ def register(self, task_manager: TaskManager) -> None: async def get_history( self, q: Annotated[ - HistoryQuery, Doc("Query parameters for fetching task run history") + TaskHistoryQuery, Doc("Query parameters for fetching task run history") ], ) -> TaskRunHistoryPage: """Get task run history based on the provided query parameters.""" @@ -169,10 +181,15 @@ def task_meta(meta: sa.MetaData, table_name: str = "tasks") -> None: nullable=False, index=True, ), - sa.Column("queued", sa.DateTime(timezone=True), nullable=False), + sa.Column("queued", sa.DateTime(timezone=True), nullable=False, index=True), sa.Column("start", sa.DateTime(timezone=True)), sa.Column("end", sa.DateTime(timezone=True)), - sa.Column("params", sa.JSON), + sa.Column("params", JSONB), + sa.Index( + f"ix_{table_name}_params", + "params", + postgresql_using="gin", + ), ) @@ -228,32 +245,52 @@ class TaskRunHistoryPage(BaseModel): cursor: str = Field(..., description="Pagination cursor to fetch the next page") -class HistoryQuery(BaseModel): +class TaskHistoryQuery(BaseModel): """Query parameters for fetching task run history.""" task: Annotated[ str | None, Query(description="Filter by task name"), + Doc("Filter by task name when provided"), ] = None start: Annotated[ datetime | None, Query(description="Filter runs queued at or after this time"), + Doc("Filter runs queued at or after this time when provided"), ] = None end: Annotated[ datetime | None, Query(description="Filter runs queued at or before this time"), + Doc("Filter runs queued at or before this time when provided"), ] = None state: Annotated[ TaskState | None, Query(description="Filter by task state"), + Doc("Filter by task state when provided"), + ] = None + params: Annotated[ + dict[str, Any] | str | None, + Query(description="Filter by params using JSON containment"), + Doc("Filter by params using JSON containment when provided"), ] = None + + @model_validator(mode="before") + @classmethod + def _parse_params_str(cls, data: Any) -> Any: + if isinstance(data, dict) and "params" in data: + data = {**data} + data["params"] = _parse_json_str(data["params"]) + return data + limit: Annotated[ int | None, Query(description="Maximum number of results to return", ge=1), + Doc("Maximum number of results to return when provided"), ] = None cursor: Annotated[ str, Query(description="Pagination cursor from a previous response"), + Doc("Pagination cursor from a previous response when provided"), ] = "" _filter_map: ClassVar[dict[str, str]] = { @@ -278,7 +315,7 @@ def filters(self) -> dict: ) async def get_history( db_plugin: TaskDbPluginDep, - q: Annotated[HistoryQuery, Depends()], + q: Annotated[TaskHistoryQuery, Query()], ) -> TaskRunHistoryPage: return await db_plugin.get_history(q) diff --git a/tests/scheduler/test_db_plugin.py b/tests/scheduler/test_db_plugin.py index bd0c2e9..0b07832 100644 --- a/tests/scheduler/test_db_plugin.py +++ b/tests/scheduler/test_db_plugin.py @@ -1,4 +1,5 @@ import asyncio +import json from datetime import datetime, timezone from typing import Any, AsyncIterator, cast @@ -10,7 +11,12 @@ from examples import tasks from fluid.scheduler import TaskState from fluid.scheduler.consumer import TaskConsumer -from fluid.scheduler.db import TaskDbPlugin, get_db_plugin, with_task_history_router +from fluid.scheduler.db import ( + TaskDbPlugin, + TaskHistoryQuery, + get_db_plugin, + with_task_history_router, +) from fluid.scheduler.endpoints import get_task_manager, task_manager_fastapi from fluid.utils.http_client import HttpResponseError from tests.scheduler.tasks import TaskClient, redis_broker, start_fastapi @@ -268,3 +274,44 @@ async def test_get_history_filter_by_end( assert any(item["id"] == task_run.id for item in data) data_empty = await get_history(cli_db, end="2000-01-01T00:00:00Z") assert data_empty == [] + + +async def test_get_history_filter_by_params_programmatic( + task_manager_db: TaskConsumer, db_plugin: TaskDbPlugin +) -> None: + task_run = await task_manager_db.queue_and_wait("add", timeout=5, a=7.0, b=8.0) + assert task_run.state == TaskState.success + + await wait_for_task_run(db_plugin, task_run.id) + page = await db_plugin.get_history(TaskHistoryQuery(params={"a": 7.0})) + assert len(page.data) >= 1 + assert any(r.id == task_run.id for r in page.data) + assert all(7.0 == r.params.get("a") for r in page.data) + + # Negative: filter that shouldn't match + page_empty = await db_plugin.get_history(TaskHistoryQuery(params={"a": 999.0})) + assert not any(r.id == task_run.id for r in page_empty.data) + + +async def test_get_history_filter_by_params_http( + cli_db: TaskClient, task_manager_db: TaskConsumer, db_plugin: TaskDbPlugin +) -> None: + task_run = await task_manager_db.queue_and_wait("add", timeout=5, a=9.0, b=10.0) + assert task_run.state == TaskState.success + + await wait_for_task_run(db_plugin, task_run.id) + + data = await get_history(cli_db, params=json.dumps({"a": 9.0})) + assert any(item["id"] == task_run.id for item in data) + assert all(9.0 == item["params"].get("a") for item in data) + + +async def test_get_history_filter_by_params_http_negative( + cli_db: TaskClient, task_manager_db: TaskConsumer, db_plugin: TaskDbPlugin +) -> None: + task_run = await task_manager_db.queue_and_wait("add", timeout=5, a=11.0, b=12.0) + assert task_run.state == TaskState.success + + await wait_for_task_run(db_plugin, task_run.id) + data_empty = await get_history(cli_db, params=json.dumps({"a": 999.0})) + assert not any(item["id"] == task_run.id for item in data_empty)