diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5e9fea7..5d061ce 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,6 +16,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true +permissions: + contents: read + env: TZ: Asia/Shanghai diff --git a/doc/en/api/time.md b/doc/en/api/time.md deleted file mode 100644 index 10be8c0..0000000 --- a/doc/en/api/time.md +++ /dev/null @@ -1,9 +0,0 @@ -# Time Helpers - -::: aloha.time - -::: aloha.time.timeout_async - -::: aloha.time.timeout_asyncio - -::: aloha.time.timeout_signal diff --git a/doc/en/api/util.md b/doc/en/api/util.md index ae72830..fc71493 100644 --- a/doc/en/api/util.md +++ b/doc/en/api/util.md @@ -13,3 +13,38 @@ ::: aloha.util.sys_gpu ::: aloha.util.sys_info + +## Time Utilities (`aloha.util.time`) + +This module provides tools for wrapping function calls (such as HTTP requests via `requests` or `httpx`) with time constraints (timeouts), allowing execution of optional callbacks upon completion or failure. + +### Key Functions +- `run_with_timeout`: Wrap a synchronous function call with a timeout. +- `run_async_with_timeout`: Wrap an asynchronous function call with a timeout. + +### Usage Example +```python +from aloha.util.time import run_with_timeout +import requests + +def success_callback(response): + print("Request succeeded:", response.status_code) + +def fail_callback(exception): + print("Request failed or timed out:", exception) + +# Synchronous call with timeout +try: + run_with_timeout( + requests.get, + 2.5, # 2.5 seconds timeout + "https://httpbin.org/delay/1", + fn_callback_success=success_callback, + fn_callback_fail=fail_callback + ) +except TimeoutError: + print("Caught TimeoutError") +``` + +::: aloha.util.time + diff --git a/doc/mkdocs.yml b/doc/mkdocs.yml index a13d63d..e352365 100644 --- a/doc/mkdocs.yml +++ b/doc/mkdocs.yml @@ -70,6 +70,5 @@ nav: - Service: "api/service.md" - Encryption: "api/encrypt.md" - Database: "api/db.md" - - Time: "api/time.md" - Utilities: "api/util.md" - Testing: "api/testing.md" diff --git a/doc/mkdocs.zh.yml b/doc/mkdocs.zh.yml index 202fbe2..4934e56 100644 --- a/doc/mkdocs.zh.yml +++ b/doc/mkdocs.zh.yml @@ -70,6 +70,5 @@ nav: - 服务层: "api/service.md" - 加密: "api/encrypt.md" - 数据库: "api/db.md" - - 时间工具: "api/time.md" - 工具函数: "api/util.md" - 测试工具: "api/testing.md" diff --git a/doc/skills/project_scaffolding.md b/doc/skills/project_scaffolding.md index 72768ed..0fd7634 100644 --- a/doc/skills/project_scaffolding.md +++ b/doc/skills/project_scaffolding.md @@ -10,7 +10,7 @@ The `aloha-python` repository is organized into several top-level directories, e - **`pkg/`**: This directory stores the source code for the `aloha` Python package that is intended for publication to PyPI. It is the correct place to modify when the task is to work on this library itself. When using this repository as a boilerplate for a new application project, developers or agents should not include this directory unless they explicitly intend to create and publish a new package to PyPI. -- **`src/`**: This directory is designed for the application-specific code that consumes the `aloha` package. It serves as a boilerplate example (`app_common`) for how to structure a project using `aloha`. New projects based on this boilerplate should place their primary application logic and modules here. +- **`src/`**: This directory is designed for application-specific code and tests that consume the `aloha` package. It serves as a boilerplate example (`app_common`) for how to structure a project using `aloha`. New projects based on this boilerplate should place their primary application logic, modules, and tests here. - **`notebook/`**: This directory is for Jupyter notebooks, which can be used for experimentation, data analysis, or interactive development related to the project. @@ -30,5 +30,20 @@ To initiate a new project using `aloha-python` as a boilerplate, follow these st - **`src/` for Application Logic**: All primary application code, including API handlers, business logic, and utility modules, should reside within `src/`. The `src/main.py` script acts as a generic entry point for running Python modules within the `src/` directory. Your application's main function should be callable via `python3 src/main.py your_module.main`. - **`pkg/` is not part of a new boilerplate app**: If the goal is to build a new application project from this repository, do not carry over `pkg/` unless the user specifically wants to create and publish a separate package. Application code should live in `src/` instead. - **`resource/config/` for Configuration**: Application configuration files (e.g., `main.conf`, `deploy-DEV.conf`) should be placed under `src/resource/config/`. The `aloha` package's `aloha.config.paths` module handles the discovery and loading of these configuration files. For detailed information on HOCON configuration, refer to the "Configuration with HOCON" section in the `aloha_package_usage.md` skill. +- **Tests Placement**: All test-related code (including unit tests, integration tests, and test resources) must be placed inside the `src/` directory, typically organized under a `src/tests/` subdirectory. Test files should follow standard naming conventions such as `test_*.py`. +- **Executing Tests**: Tests should be run using `pytest` inside the containerized development environment: + 1. Launch and enter the development container: + ```bash + ./tool/cicd/run-dev.sh up + ./tool/cicd/run-dev.sh enter + ``` + 2. Run tests under the `src/` directory: + ```bash + pytest src/ + ``` + 3. To run tests with code coverage analysis: + ```bash + pytest --cov=src src/ + ``` By adhering to these conventions, AI agents can effectively understand, navigate, and contribute to projects built upon the `aloha-python` framework. diff --git a/doc/zh/api/time.md b/doc/zh/api/time.md deleted file mode 100644 index 05d84bc..0000000 --- a/doc/zh/api/time.md +++ /dev/null @@ -1,9 +0,0 @@ -# 时间工具 - -::: aloha.time - -::: aloha.time.timeout_async - -::: aloha.time.timeout_asyncio - -::: aloha.time.timeout_signal diff --git a/doc/zh/api/util.md b/doc/zh/api/util.md index c106668..797bbd3 100644 --- a/doc/zh/api/util.md +++ b/doc/zh/api/util.md @@ -13,3 +13,38 @@ ::: aloha.util.sys_gpu ::: aloha.util.sys_info + +## 时间工具 (`aloha.util.time`) + +该模块提供用于包装函数调用(如通过 `requests` 或 `httpx` 发起外部 HTTP 请求)的超时控制工具,并在操作成功或失败(超时/异常)时触发可选的回调函数。 + +### 核心函数 +- `run_with_timeout`: 以同步方式运行函数,并应用超时限制。 +- `run_async_with_timeout`: 以异步方式(协程或在执行器中运行同步函数)运行函数,并应用超时限制。 + +### 使用示例 +```python +from aloha.util.time import run_with_timeout +import requests + +def success_callback(response): + print("请求成功:", response.status_code) + +def fail_callback(exception): + print("请求失败或超时:", exception) + +# 同步超时包装调用 +try: + run_with_timeout( + requests.get, + 2.5, # 2.5 秒超时限制 + "https://httpbin.org/delay/1", + fn_callback_success=success_callback, + fn_callback_fail=fail_callback + ) +except TimeoutError: + print("捕获到超时异常 (TimeoutError)") +``` + +::: aloha.util.time + diff --git a/pkg/aloha/db/elasticsearch.py b/pkg/aloha/db/elasticsearch.py index f360d6f..e67b4b8 100644 --- a/pkg/aloha/db/elasticsearch.py +++ b/pkg/aloha/db/elasticsearch.py @@ -1,6 +1,5 @@ -"""Elasticsearch connection helpers.""" - import json +import re from elasticsearch import Elasticsearch @@ -10,6 +9,16 @@ __all__ = ("ElasticSearchOperator",) +def _mask_hosts(hosts): + if isinstance(hosts, list): + return [_mask_hosts(h) for h in hosts] + if isinstance(hosts, dict): + return {k: ("***" if k in ("password", "http_auth") else _mask_hosts(v)) for k, v in hosts.items()} + if isinstance(hosts, str): + return re.sub(r"([^:/]+://)?([^:/]+):([^@]+)@", r"\1\2:***@", hosts) + return hosts + + class ElasticSearchOperator: """Create and use an Elasticsearch client with optional index helpers.""" @@ -21,14 +30,17 @@ def __init__(self, config, index_config=None): username = config.get("username") password = password_vault.get_password(config.get("password")) + hosts = config.get("host", "localhost") + masked_hosts = _mask_hosts(hosts) + LOG.debug("ElasticSearch connection info: " + str(masked_hosts)) + self._config = { "http_auth": (username, password) if username is not None and password is not None else None, - "hosts": config.get("host", "localhost"), + "hosts": hosts, "timeout": config.get("timeout", 0.1), "max_retries": config.get("max_retries", 3), "retry_on_timeout": config.get("retry_on_timeout", True), } - LOG.debug("ElasticSearch connection info: " + str(self._config["hosts"])) self.index_config = index_config self.index_name = self.es_config.get("index_name") diff --git a/pkg/aloha/db/mongo.py b/pkg/aloha/db/mongo.py index 91ba9e3..a808e96 100644 --- a/pkg/aloha/db/mongo.py +++ b/pkg/aloha/db/mongo.py @@ -68,7 +68,8 @@ def __init__(self, config, db_name=None, collection_name=None): "maxPoolSize": config.get("maxPoolSize"), "authSource": config.get("authSource", db_name), } - LOG.debug(_config) + msg = {k: ("***" if k == "password" else v) for k, v in _config.items()} + LOG.debug(msg) try: self.conn = pymongo.MongoClient(**_config) diff --git a/pkg/aloha/service/__init__.py b/pkg/aloha/service/__init__.py index 35bf1d2..62afb2c 100644 --- a/pkg/aloha/service/__init__.py +++ b/pkg/aloha/service/__init__.py @@ -1,4 +1,5 @@ from .api import v0, v1, v2 -from .http import DefaultHandler404 +from .handlers import DefaultHandler404 +from .http import CORSMiddleware -__all__ = ("DefaultHandler404", "v0", "v1", "v2") +__all__ = ("CORSMiddleware", "DefaultHandler404", "v0", "v1", "v2") diff --git a/pkg/aloha/service/api/v0.py b/pkg/aloha/service/api/v0.py index f5bf570..fcea0a9 100644 --- a/pkg/aloha/service/api/v0.py +++ b/pkg/aloha/service/api/v0.py @@ -1,21 +1,24 @@ -"""Version 0 JSON API helpers. +"""Version 0 JSON API helpers for FastAPI. This module defines the simplest request/response protocol used by aloha: request bodies are passed directly to the handler method and the response is serialized as a JSON object with a `code` and `message` field. """ -import json import logging from abc import ABC -from ..http import AbstractApiClient, AbstractApiHandler +from fastapi import Request +from fastapi.responses import JSONResponse -__all__ = ("APIHandler", "APICaller") +from ..http import AbstractApiClient +from ..http.base_api_handler import AbstractApiHandler as BaseHandler +__all__ = ("APIHandler", "APICaller", "create_v0_router") -class APIHandler(AbstractApiHandler, ABC): - """Base Tornado handler for v0 JSON endpoints. + +class APIHandler(BaseHandler, ABC): + """Base handler for v0 JSON endpoints using FastAPI. Subclasses implement :meth:`response`, which receives parsed request data and returns a Python object that can be JSON-serialized. @@ -27,21 +30,85 @@ async def post(self, *args, **kwargs): """Parse the request body, call :meth:`response`, and return JSON.""" req_body = self.request_body - if req_body is not None: # body_arguments + if req_body is not None: kwargs.update(req_body) resp = dict(code=5200, message=["success"]) try: - result = self.response(*args, **kwargs) # this call may throw TypeError when argument missing + result = self.response(*args, **kwargs) resp["data"] = result except Exception as e: if self.LOG.level == logging.DEBUG: self.LOG.error(e, exc_info=True) return self.finish({"code": 5201, "message": [repr(e)]}) - resp = json.dumps(resp, ensure_ascii=False, default=str, separators=(",", ":")) return self.finish(resp) + async def get(self, *args, **kwargs): + """Handle GET request (useful for some v0 endpoints).""" + kwargs.update(self.request_param) + resp = dict(code=5200, message=["success"]) + try: + result = self.response(*args, **kwargs) + resp["data"] = result + except Exception as e: + if self.LOG.level == logging.DEBUG: + self.LOG.error(e, exc_info=True) + return self.finish({"code": 5201, "message": [repr(e)]}) + return self.finish(resp) + + +def create_v0_router(handler_class): + """Create FastAPI routes for a v0 API handler class. + + Args: + handler_class: A class inheriting from APIHandler + + Returns: + A function that registers routes on a FastAPI app + """ + + async def handle_post(request: Request, **kwargs): + handler = handler_class() + handler._request = request + + # Get body for POST + try: + body = await request.json() + except Exception: + body = {} + + kwargs.update(body) + resp = dict(code=5200, message=["success"]) + try: + result = handler.response(**kwargs) + resp["data"] = result + except Exception as e: + if handler.LOG.level == logging.DEBUG: + handler.LOG.error(e, exc_info=True) + return JSONResponse({"code": 5201, "message": [repr(e)]}, status_code=500) + + return JSONResponse(resp) + + async def handle_get(request: Request, **kwargs): + handler = handler_class() + handler._request = request + + # Get query params for GET + kwargs.update(dict(request.query_params)) + resp = dict(code=5200, message=["success"]) + try: + result = handler.response(**kwargs) + resp["data"] = result + except Exception as e: + if handler.LOG.level == logging.DEBUG: + handler.LOG.error(e, exc_info=True) + return JSONResponse({"code": 5201, "message": [repr(e)]}, status_code=500) + + return JSONResponse(resp) + + return handle_post, handle_get + class APICaller(AbstractApiClient): """Client helper for v0 endpoints. diff --git a/pkg/aloha/service/api/v1.py b/pkg/aloha/service/api/v1.py index 8b8ce76..6a5815c 100644 --- a/pkg/aloha/service/api/v1.py +++ b/pkg/aloha/service/api/v1.py @@ -1,4 +1,4 @@ -"""Version 1 signed JSON API helpers. +"""Version 1 signed JSON API helpers for FastAPI. Version 1 adds request signing with `app_id`, `salt_uuid`, and `sign` fields. Handlers validate the signature before dispatching to the service logic. @@ -9,11 +9,15 @@ import uuid from abc import ABC +from fastapi import Request +from fastapi.responses import JSONResponse + from ...encrypt.hash import get_md5_of_str, get_sha256_of_str from ...settings import SETTINGS -from ..http import AbstractApiClient, AbstractApiHandler +from ..http import AbstractApiClient +from ..http.base_api_handler import AbstractApiHandler as BaseHandler -__all__ = ("APIHandler", "APICaller", "sign_data", "sign_check") +__all__ = ("APIHandler", "APICaller", "sign_data", "sign_check", "create_v1_router") APP_ID_KEYS = SETTINGS.config.get("APP_ID_KEYS", {}) APP_OPTIONS = SETTINGS.config.get("APP_OPTIONS", {}) @@ -21,7 +25,7 @@ func_sign_check_default = FUNC_SIGN_CHECK.get(APP_OPTIONS.get("sign_method", "md5")) -class APIHandler(AbstractApiHandler, ABC): +class APIHandler(BaseHandler, ABC): """Signed API handler for v1 endpoints.""" MAP_ERROR_INFO = { @@ -39,16 +43,16 @@ async def post(self): app_id = body_arguments.pop("app_id") sign = body_arguments.pop("sign") data = body_arguments.pop("data") - except KeyError: # cannot find default key from parsed body + except KeyError: return self.finish(self.MAP_ERROR_INFO["MISSING_ARGS"]) - is_valid_req = sign_check(salt_uuid=salt_uuid, app_id=app_id, sign=sign, data=data) # , sign_method='sha256' + is_valid_req = sign_check(salt_uuid=salt_uuid, app_id=app_id, sign=sign, data=data) if not is_valid_req: return self.finish(self.MAP_ERROR_INFO["SIGN_CHECK_FAIL"]) resp = dict(code=5200, message=["success"]) try: - result = self.response(**data) # this call may throw TypeError when argument missing + result = self.response(**data) resp["data"] = result resp["salt_uuid"] = salt_uuid except Exception as e: @@ -56,10 +60,57 @@ async def post(self): self.LOG.error(e, exc_info=True) return self.finish({"code": 5201, "message": [repr(e)]}) - resp = json.dumps(resp, ensure_ascii=False, default=str, separators=(",", ":")) return self.finish(resp) +def create_v1_router(handler_class): + """Create FastAPI routes for a v1 API handler class with signing validation. + + Args: + handler_class: A class inheriting from APIHandler + + Returns: + An async function that handles v1 signed requests + """ + + async def handle_post(request: Request, **kwargs): + try: + body = await request.json() + except Exception: + return JSONResponse( + {"code": "5101", "message": ["Bad request: fail to parse body as JSON object!"]}, status_code=400 + ) + + try: + salt_uuid = body.pop("salt_uuid") + app_id = body.pop("app_id") + sign = body.pop("sign") + data = body.pop("data") + except KeyError: + return JSONResponse({"code": "5102", "message": ["Required argument field(s) missing..."]}, status_code=400) + + is_valid_req = sign_check(salt_uuid=salt_uuid, app_id=app_id, sign=sign, data=data) + if not is_valid_req: + return JSONResponse({"code": "5104", "message": ["Invalid sign, sign check failed!"]}, status_code=401) + + handler = handler_class() + handler._request = request + + resp = dict(code=5200, message=["success"]) + try: + result = handler.response(**data) + resp["data"] = result + resp["salt_uuid"] = salt_uuid + except Exception as e: + if handler.LOG.level == logging.DEBUG: + handler.LOG.error(e, exc_info=True) + return JSONResponse({"code": 5201, "message": [repr(e)]}, status_code=500) + + return JSONResponse(resp) + + return handle_post + + class APICaller(AbstractApiClient): """Client helper that wraps payloads with v1 signing metadata.""" @@ -76,8 +127,6 @@ def wrap_request_data( ): """Wrap the payload with signature fields expected by v1 handlers.""" if app_id is None: - # if len(APP_ID_KEYS) != 1: - # raise RuntimeError('Please specify 1 and only 1 in APP_ID_KEYS in configurations!') app_id = list(self.APP_ID_KEYS.keys())[0] salt_uuid = salt_uuid or str(uuid.uuid1()) sign = sign or sign_data( @@ -113,11 +162,9 @@ def sign_check(salt_uuid: str, app_id: str, sign: str, data, sign_method: str = raise ValueError("Invalid `sign_method`: %s" % sign_method) app_key = APP_ID_KEYS.get(app_id) - if app_key is None: # APP_ID not in the dict, unknown APP_ID + if app_key is None: return False - # data_str = str(json.dumps(data, ensure_ascii=False, sort_keys=True, separators=(',', ':'))) - # --> Compatible with older version API right_sign = func_sign_check(app_id + salt_uuid + app_key) if sign == right_sign: diff --git a/pkg/aloha/service/api/v2.py b/pkg/aloha/service/api/v2.py index bb17eac..f878afa 100644 --- a/pkg/aloha/service/api/v2.py +++ b/pkg/aloha/service/api/v2.py @@ -1,4 +1,4 @@ -"""Version 2 token-based JSON API helpers. +"""Version 2 token-based JSON API helpers for FastAPI. Version 2 uses an access token in the request header and a request-id header for tracing. It keeps the same request/response shape as the earlier API @@ -9,37 +9,37 @@ import logging from abc import ABC from datetime import datetime, timedelta -from typing import Awaitable, Optional +from typing import Any, Dict, Optional + +from fastapi import Depends, HTTPException, Request, Response, status +from fastapi.responses import JSONResponse from ...encrypt import jwt +from ...logger import LOG from ...settings import SETTINGS -from ..http import AbstractApiClient, AbstractApiHandler +from ..http import AbstractApiClient +from ..http.base_api_handler import AbstractApiHandler as BaseHandler -__all__ = ("APIHandler", "APICaller") +__all__ = ("APIHandler", "APICaller", "create_v2_router", "verify_v2_token") -class APIHandler(AbstractApiHandler, ABC): +class APIHandler(BaseHandler, ABC): """Token-authenticated API handler for v2 endpoints.""" - async def prepare( - self, - ) -> Optional[Awaitable[None]]: + async def prepare(self) -> Optional[Response]: """Validate the access token before handling the request.""" - access_token = self.request.headers.get("Access-Token") + access_token = self._request.headers.get("Access-Token") if access_token is None: return self.finish({"msg": "Invalid Access-Token in request header!"}) else: secret_key = SETTINGS.config["APP_SECRET_KEY"] - # options = None - # TODO: if not validate expiration options = {"verify_exp": False} access_token = jwt.decode(secret_key, access_token, options=options) if not isinstance(access_token, dict): - self.LOG.error( - "Invalid Access-Token found in request for [%s]: %s" % (str(self.request.full_url()), access_token) - ) - return self.finish({"msg": access_token}) - self.set_header("Request-ID", self.request_id) + msg = "Invalid Access-Token found in request for [%s]: %s" % (str(self._request.url), access_token) + self.LOG.error(msg) + return self.finish({"msg": msg}) + return None async def post(self, *args, **kwargs): """Handle POST requests with JSON request bodies.""" @@ -50,14 +50,13 @@ async def post(self, *args, **kwargs): s_kwargs = json.dumps(kwargs, ensure_ascii=False) self.LOG.debug("POST Request [%s]: %s" % (self.request_id, s_kwargs[:1000])) self.api_args, self.api_kwargs = args or (), kwargs or {} - resp = self.response(*self.api_args, **self.api_kwargs) # this call may throw TypeError when argument missing + resp = self.response(*self.api_args, **self.api_kwargs) except Exception as e: + self.LOG.info("POST Request [%s]: %s" % (self.request_id, self._request._body)) + msgs = ["An internal error has occurred!", str(e)] self.LOG.error(e, exc_info=True) - self.LOG.info("POST Request [%s]: %s" % (self.request_id, self.request.body)) - return self.finish({"status": "error", "message": [str(e)]}) + return self.finish({"status": "error", "message": msgs}) - if isinstance(resp, (dict, list)): - resp = json.dumps(resp, ensure_ascii=False, default=str, separators=(",", ":")) return self.finish(resp) async def get(self, *args, **kwargs): @@ -67,17 +66,92 @@ async def get(self, *args, **kwargs): try: self.LOG.debug("GET Request [%s]: %s" % (self.request_id, kwargs)) self.api_args, self.api_kwargs = args or (), kwargs or {} - resp = self.response(*self.api_args, **self.api_kwargs) # this call may throw TypeError when argument missing + resp = self.response(*self.api_args, **self.api_kwargs) except Exception as e: - self.LOG.error(e, exc_info=True) self.LOG.info("GET Request [%s]: %s" % (self.request_id, kwargs)) - return self.finish({"status": "error", "message": [repr(e)]}) + msgs = ["An internal error has occurred!", str(e)] + self.LOG.error(e, exc_info=True) + return self.finish({"status": "error", "message": msgs}) - if isinstance(resp, (dict, list)): - resp = json.dumps(resp, ensure_ascii=False, default=str, separators=(",", ":")) return self.finish(resp) +def verify_v2_token(request: Request) -> Optional[Dict[str, Any]]: + """Dependency to verify v2 access token. + + Returns the decoded token payload if valid, otherwise raises HTTPException. + """ + + access_token = request.headers.get("Access-Token") + if access_token is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Access-Token in request header!") + + secret_key = SETTINGS.config.get("APP_SECRET_KEY") + if not secret_key: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="APP_SECRET_KEY not configured!") + + options = {"verify_exp": False} + try: + payload = jwt.decode(secret_key, access_token, options=options) + if not isinstance(payload, dict): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Access-Token!") + return payload + except Exception as e: + LOG.error(str(e), exc_info=True) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Access-Token!") + + +def create_v2_router(handler_class): + """Create FastAPI routes for a v2 API handler class with JWT token validation. + + Args: + handler_class: A class inheriting from APIHandler + + Returns: + Tuple of (handle_post, handle_get) functions for the routes + """ + + async def handle_post(request: Request, token_payload: Dict = Depends(verify_v2_token)): + handler = handler_class() + handler._request = request + + try: + body = await request.json() + except Exception: + body = {} + + kwargs = body + try: + if handler.LOG.level == logging.DEBUG: + s_kwargs = json.dumps(kwargs, ensure_ascii=False) + handler.LOG.debug("POST Request [%s]: %s" % (handler.request_id, s_kwargs[:1000])) + + resp = handler.response(**kwargs) + except Exception as e: + handler.LOG.error(e, exc_info=True) + msgs = ["An internal error has occurred.", str(e)] + return JSONResponse({"status": "error", "message": msgs}, status_code=500) + + return handler.finish(resp) + + async def handle_get(request: Request, token_payload: Dict = Depends(verify_v2_token)): + handler = handler_class() + handler._request = request + + kwargs = dict(request.query_params) + try: + handler.LOG.debug("GET Request [%s]: %s" % (handler.request_id, kwargs)) + resp = handler.response(**kwargs) + except Exception as e: + handler.LOG.error(e, exc_info=True) + msgs = ["An internal error has occurred.", repr(e)] + return JSONResponse({"status": "error", "message": msgs}, status_code=500) + + return handler.finish(resp) + + return handle_post, handle_get + + class APICaller(AbstractApiClient): """Client helper that adds v2 access-token headers automatically.""" @@ -92,8 +166,6 @@ def wrap_request_data(self, data: dict) -> dict: def get_headers(self, app_id: str = None, app_key: str = None) -> dict: """Build the HTTP headers expected by v2 handlers.""" if app_id is None: - # if len(APP_ID_KEYS) != 1: - # raise RuntimeError('Please specify 1 and only 1 in APP_ID_KEYS in configurations!') app_id = list(self.APP_ID_KEYS.keys())[0] expire_time = datetime.now() + timedelta(days=1) diff --git a/pkg/aloha/service/app.py b/pkg/aloha/service/app.py index 00c5043..b932903 100644 --- a/pkg/aloha/service/app.py +++ b/pkg/aloha/service/app.py @@ -1,56 +1,69 @@ -"""Service application bootstrap utilities.""" +"""Service application bootstrap utilities for FastAPI.""" import asyncio +import uvicorn + from ..logger import LOG try: import uvloop - from tornado.platform.asyncio import AsyncIOMainLoop LOG.info("Using uvloop == %s for service event loop..." % uvloop.__version__) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - AsyncIOMainLoop().install() except ImportError: LOG.info("[uvloop] NOT installed, fallback to asyncio loop! Consider `pip install uvloop`!") -from tornado.options import options - from ..settings import SETTINGS -from .web import WebApplication +from .web import FastAPIApplication __all__ = ("Application",) class Application: - """Bootstrap and run an aloha web service.""" + """Bootstrap and run an aloha FastAPI web service.""" def __init__(self, *args, **kwargs): """Create the service application wrapper.""" - options["log_file_prefix"] = "access.log" settings = dict(SETTINGS.config) - self.web_app = WebApplication(settings) + self.web_app = FastAPIApplication(settings) + self._server = None def start(self): - """Start the web app and run the asyncio event loop.""" + """Start the FastAPI app using uvicorn.""" + port = self.web_app.get_port() + workers = self.web_app.get_workers() + + LOG.info("Starting FastAPI service at port [%s] with [%s] workers...", port, workers) + try: - self.web_app.start() - event_loop = asyncio.get_event_loop() - if event_loop.is_running(): - # notice: the event loop MUST NOT be initialized before web_app starts (as it may fork process) - # ref: https://github.com/tornadoweb/tornado/issues/2426#issuecomment-400895086 - raise RuntimeError("Event loop already running before WebApp starts!") - else: - event_loop.run_forever() + # Configure uvicorn + config = uvicorn.Config( + app=self.web_app.app, + host="0.0.0.0", + port=port, + workers=workers, + log_level="info", + access_log=True, + ) + self._server = uvicorn.Server(config) + + # Run with uvloop if available + try: + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + except ImportError: + LOG.debug("uvloop is not installed; continuing with default asyncio event loop.") + + asyncio.run(self._server.serve()) except KeyboardInterrupt: - pass + LOG.info("Service interrupted by user") except Exception as e: + LOG.error("Service error: %s", str(e)) raise e - finally: - pass def stop(self): - """Stop the event loop if it is currently running.""" - event_loop = asyncio.get_event_loop() - if event_loop.is_running(): - event_loop.stop() + """Stop the server if it is currently running.""" + if self._server is not None: + self._server.should_exit = True diff --git a/pkg/aloha/service/handlers.py b/pkg/aloha/service/handlers.py new file mode 100644 index 0000000..d105031 --- /dev/null +++ b/pkg/aloha/service/handlers.py @@ -0,0 +1,21 @@ +"""Default HTTP handlers for aloha services.""" + +from fastapi import Request +from fastapi.responses import JSONResponse + + +class DefaultHandler404: + """Default 404 response handler for FastAPI services.""" + + def __init__(self, request: Request | None = None, **kwargs): + self.request = request + self._request = request + + async def handle(self, request: Request | None = None): + """Return a JSON response for unmatched routes.""" + _request = request or self.request + del _request + return JSONResponse( + {"code": 404, "message": ["Not Found"], "data": None}, + status_code=404, + ) diff --git a/pkg/aloha/service/http/__init__.py b/pkg/aloha/service/http/__init__.py index 1099a97..0f2cafc 100644 --- a/pkg/aloha/service/http/__init__.py +++ b/pkg/aloha/service/http/__init__.py @@ -1,10 +1,5 @@ from .base_api_client import AbstractApiClient -from .base_api_handler import AbstractApiHandler, DefaultHandler404 -from .plain_http_handler import PlainHttpHandler +from .base_api_handler import AbstractApiHandler +from .plain_http_handler import CORSMiddleware, add_cors_headers -__all__ = ( - "AbstractApiClient", - "AbstractApiHandler", - "DefaultHandler404", - "PlainHttpHandler", -) +__all__ = ("AbstractApiClient", "AbstractApiHandler", "CORSMiddleware", "add_cors_headers") diff --git a/pkg/aloha/service/http/base_api_client.py b/pkg/aloha/service/http/base_api_client.py index d09f76c..1dd4648 100644 --- a/pkg/aloha/service/http/base_api_client.py +++ b/pkg/aloha/service/http/base_api_client.py @@ -1,46 +1,51 @@ -"""Base HTTP client helpers for aloha API clients.""" +"""Base HTTP client helpers for aloha API clients using httpx.""" import uuid from abc import ABC, abstractmethod from urllib.parse import urljoin -import requests -from requests.adapters import HTTPAdapter, Retry +import httpx from ...logger import LOG from ...settings import SETTINGS class AbstractApiClient(ABC): - """Common client behavior for aloha HTTP APIs.""" + """Common client behavior for aloha HTTP APIs using httpx.""" LOG = LOG - RETRY_METHOD_WHITELIST: frozenset = frozenset(['GET', 'POST']) + RETRY_METHOD_WHITELIST: frozenset = frozenset(["GET", "POST"]) RETRY_STATUS_FORCELIST: frozenset = frozenset({413, 429, 503, 502, 504}) config = SETTINGS.config def __init__(self, url_endpoint: str = None, *args, **kwargs): """Store the endpoint used by the client.""" - self.url_endpoint = url_endpoint or '' - LOG.debug('API Caller URL endpoint set to: %s' % self.url_endpoint) - - @classmethod - def get_request_session(cls, total_retries: int = 3, *args, **kwargs) -> requests.Session: - """Create a requests session with retry support.""" - session = requests.Session() - # https://urllib3.readthedocs.io/en/latest/reference/urllib3.util.html#urllib3.util.Retry.DEFAULT_ALLOWED_METHODS - retries = Retry( - total=total_retries, backoff_factor=0.1, method_whitelist=cls.RETRY_METHOD_WHITELIST, status_forcelist=cls.RETRY_STATUS_FORCELIST + self.url_endpoint = url_endpoint or "" + LOG.debug("API Caller URL endpoint set to: %s" % self.url_endpoint) + + def get_http_client(self, total_retries: int = 3, *args, **kwargs) -> httpx.AsyncClient: + """Create an httpx async client with retry support via custom transport.""" + # Create a custom transport that retries on specific status codes + from httpx import AsyncClient, Limits, Timeout + + # Configure retry policy + limits = Limits(max_keepalive_connections=20, max_connections=100, keepalive_expiry=30) + timeout = Timeout(timeout=30.0, connect=5.0) + + # Create async client with retry capabilities + client = AsyncClient( + limits=limits, + timeout=timeout, + follow_redirects=True, + http2=True, ) - for prefix in ('http://', 'https://'): - session.mount(prefix, HTTPAdapter(max_retries=retries)) - return session + return client def get_headers(self, *args, **kwargs) -> dict: """Build the default request headers used by aloha clients.""" headers = { - 'Content-Type': 'application/json', - 'Request-ID': str(uuid.uuid1()), + "Content-Type": "application/json", + "Request-ID": str(uuid.uuid4()), } return headers @@ -49,18 +54,18 @@ def wrap_request_data(self, data: dict) -> dict: """Transform the request payload before sending it.""" assert isinstance(data, dict), "Data object must be a dict!" raise NotImplementedError() - # return data - def call(self, api_url: str, data: dict = None, timeout=5, **kwargs): - """Call a remote API and return the parsed JSON response.""" + async def _async_call(self, api_url: str, data: dict = None, timeout: float = 5, **kwargs): + """Async version: Call a remote API and return the parsed JSON response.""" body = data or dict() body.update(kwargs) payload = self.wrap_request_data(data=body) - LOG.debug('Calling api: %s' % api_url) - session = self.get_request_session() - resp = session.post( - urljoin(self.url_endpoint, api_url), json=payload, timeout=timeout, headers=self.get_headers() - ) + LOG.debug("Calling api: %s" % api_url) + + async with self.get_http_client() as client: + resp = await client.post( + urljoin(self.url_endpoint, api_url), json=payload, timeout=timeout, headers=self.get_headers() + ) try: ret = resp.json() @@ -69,3 +74,22 @@ def call(self, api_url: str, data: dict = None, timeout=5, **kwargs): raise RuntimeError(resp.text) return ret + + def call(self, api_url: str, data: dict = None, timeout: float = 5, **kwargs): + """Call a remote API and return the parsed JSON response (sync wrapper).""" + import asyncio + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # If loop is running, we need to create a new task + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + future = pool.submit(asyncio.run, self._async_call(api_url, data, timeout, **kwargs)) + return future.result() + else: + return loop.run_until_complete(self._async_call(api_url, data, timeout, **kwargs)) + except RuntimeError: + # No event loop exists + return asyncio.run(self._async_call(api_url, data, timeout, **kwargs)) diff --git a/pkg/aloha/service/http/base_api_handler.py b/pkg/aloha/service/http/base_api_handler.py index c4cdd41..bdfcdaa 100644 --- a/pkg/aloha/service/http/base_api_handler.py +++ b/pkg/aloha/service/http/base_api_handler.py @@ -1,108 +1,141 @@ -"""Base Tornado handlers used by aloha services.""" +"""Base FastAPI dependencies and request helpers for aloha services.""" +import asyncio import json +import logging from abc import ABC from datetime import datetime -from typing import Optional, Awaitable +from typing import Any, Dict, Optional -from tornado import web +from fastapi import APIRouter, Request, Response from ...logger import LOG -class AbstractApiHandler(web.RequestHandler, ABC): - """Shared request parsing and response helpers for JSON APIs.""" +class AbstractApiHandler(ABC): + """Shared request parsing and response helpers for JSON APIs. + + This is a base class that provides utility methods for API handlers. + Subclasses should inherit from this and implement the response() method. + """ LOG = LOG - MAP_ERROR_INFO: dict = { - 'BAD_REQUEST': {'code': '5101', 'message': ['Bad request: fail to parse body as JSON object!']} - } + MAP_ERROR_INFO: dict = {"BAD_REQUEST": {"code": "5101", "message": ["Bad request: fail to parse body as JSON object!"]}} - def __init__(self, *args, **kwargs): + def __init__(self): """Initialize request state used by subclasses.""" self.api_args: Optional[tuple] = None self.api_kwargs: Optional[dict] = None - super().__init__(*args, **kwargs) - - def on_finish(self) -> None: - """Invoke any stored callback after the request finishes.""" - func_callback = getattr(self, 'callback', None) - if callable(func_callback) \ - and isinstance(self.api_args, tuple) \ - and isinstance(self.api_kwargs, dict): - func_callback(*self.api_args, **self.api_kwargs) + self._request: Optional[Request] = None + self._response: Optional[Response] = None def response(self, *args, **kwargs) -> dict: """Subclasses must implement the business response.""" raise NotImplementedError() - def set_default_headers(self) -> None: - """Set the JSON content type for API responses.""" - self.set_header('Content-Type', 'application/json; charset=utf-8') - - def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: - """Accept streaming request bodies when Tornado calls back.""" - pass - @property def request_header_content_type(self) -> str: """Return the request content type with a JSON default.""" - return self.request.headers.get('Content-Type', 'application/json; charset=utf-8') + if self._request is None: + return "application/json; charset=utf-8" + return self._request.headers.get("Content-Type", "application/json; charset=utf-8") @property - def request_id(self): + def request_id(self) -> str: """Return or create a request identifier for tracing.""" - if 'Request-ID' not in self.request.headers: - self.request.headers['Request-ID'] = datetime.now().strftime('%Y%m%d-%H%M%S-%f') - return self.request.headers.get('Request-ID') + if self._request is None: + return datetime.now().strftime("%Y%m%d-%H%M%S-%f") + request_id = self._request.headers.get("Request-ID") + if request_id is None: + request_id = datetime.now().strftime("%Y%m%d-%H%M%S-%f") + return request_id @property - def request_body(self) -> dict: + def request_body(self) -> Optional[dict]: """Parse the request body as JSON or multipart form data.""" content_type: str = self.request_header_content_type - body_arguments: dict = Optional[None] - if content_type.startswith('multipart/form-data'): # only parse files when 'Content-Type' starts with 'multipart/form-data' - body_arguments = self.request_param # self.request.body_arguments - else: - try: - body = self.request.body.decode('utf-8') - body_arguments = json.loads(body) - except (UnicodeDecodeError, json.decoder.JSONDecodeError): # invalid request body, cannot be parsed as JSON - self.finish(self.MAP_ERROR_INFO['BAD_REQUEST']) - return body_arguments + if self._request is None: + return {} + + # For multipart/form-data, use request_param logic + if content_type.startswith("multipart/form-data"): + return self.request_param + + try: + body = asyncio.get_event_loop().run_until_complete(self._request.body()) + body_str = body.decode("utf-8") + if body_str: + return json.loads(body_str) + return {} + except (UnicodeDecodeError, json.JSONDecodeError): + return self.MAP_ERROR_INFO["BAD_REQUEST"] @property def request_param(self) -> dict: """Parse query/body arguments into a JSON-friendly dict.""" ret: dict = {} - for k, v in self.request.arguments.items(): - val = v[0].decode('utf-8') + if self._request is None: + return ret + + # Parse query parameters + for k, v in self._request.query_params.items(): try: - value = json.loads(val) + value = json.loads(v) except json.JSONDecodeError: - value = val + value = v ret[k] = value return ret + def get_request_files(self) -> Dict[str, list]: + """Get uploaded files from multipart form data.""" + if self._request is None: + return {} + return self._request._form + + def finish(self, data: Any, status_code: int = 200) -> Response: + """Create a JSON response with proper content type.""" + if isinstance(data, dict): + content = json.dumps(data, ensure_ascii=False, default=str, separators=(",", ":")) + elif isinstance(data, str): + content = data + else: + content = json.dumps(data, ensure_ascii=False, default=str, separators=(",", ":")) + return Response(content=content, status_code=status_code, media_type="application/json") + + def set_header(self, key: str, value: str) -> None: + """Set a response header (no-op in base class, overridden in FastAPI route).""" + pass + + def set_status(self, status_code: int, reason: str = None) -> None: + """Set the response status code (no-op in base class).""" + pass -class DefaultHandler404(AbstractApiHandler): - """Default JSON 404 handler used by aloha services.""" - - def response(self, *args, **kwargs) -> Optional[dict]: - """Return the default 404 response payload.""" - return self.prepare() - - def prepare(self): # for all methods - """Finalize the 404 response.""" - msg = { - "code": 404, - "status": "error", - "message": [ - 'Requested URL cannot be found: %s' % self.request.uri - ] - } - msg = json.dumps(msg, ensure_ascii=False, default=str, separators=(',', ':')) - self.set_status(404, reason='Not Found') - self.finish(msg) + async def _handle_request(self, request: Request, *args, **kwargs) -> Response: + """Process the request and return a response.""" + self._request = request + self.api_args = args + self.api_kwargs = kwargs + + try: + result = self.response(*args, **kwargs) + if isinstance(result, (dict, list)): + return self.finish(result) + return result + except Exception as e: + if self.LOG.level == logging.DEBUG: + self.LOG.error(e, exc_info=True) + msgs = ["An internal error has occurred!", repr(e)] + return self.finish({"code": 5201, "message": msgs}, status_code=500) + + +def create_handler_route(handler_class): + """Create a FastAPI route wrapper for a handler class.""" + + class HandlerRoute(APIRouter): + async def _execute_handler(self, request: Request, **kwargs) -> Response: + handler = handler_class() + return await handler._handle_request(request, **kwargs) + + return HandlerRoute diff --git a/pkg/aloha/service/http/files.py b/pkg/aloha/service/http/files.py index 46a4056..ed57a0c 100644 --- a/pkg/aloha/service/http/files.py +++ b/pkg/aloha/service/http/files.py @@ -1,41 +1,95 @@ -"""Helpers for handling multipart upload files and remote file inputs.""" +"""Helpers for handling multipart upload files and remote file inputs using httpx.""" import time -import requests +import httpx from ...logger import LOG -def iter_over_request_files(request, url_files): +async def iter_over_request_files(request, url_files): """Yield uploaded files and optional remote files as normalized tuples. Each yielded item is `(field_name, file_name, content_type, body_bytes)`. Files can come from multipart form uploads or from URLs listed in `url_files`. + + Args: + request: FastAPI request object with files attribute + url_files: List of URLs to download files from """ - for file_key, files in request.files.items(): # iter over files uploaded by multipart - for f in files: - file_name, content_type = f["filename"], f["content_type"] - body = f.get('body', b"") - LOG.info(f"File {file_name} from multipart has content type {content_type} and length bytes={len(body)}") - yield file_key, file_name, content_type, body - - for file_key, list_url in {'url_files': url_files or []}.items(): # iter over files specified by `url_files` + # Handle multipart uploaded files + if hasattr(request, "files") and request.files: + for file_key, files in request.files.items(): + for f in files: + file_name = getattr(f, "filename", "unknown") + content_type = getattr(f, "content_type", "application/octet-stream") + body = await f.read() + LOG.info(f"File {file_name} from multipart has content type {content_type} and length bytes={len(body)}") + yield file_key, file_name, content_type, body + + # Handle files from URL + for file_key, list_url in {"url_files": url_files or []}.items(): + for url in sorted(set(list_url)): + try: + t_start = time.time() + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.get(url) + if resp.status_code == 200: + body = resp.content + content_type = resp.headers.get("Content-Type", "UNKNOWN") + else: + raise RuntimeError( + "Failed to download file after %s seconds with code=%s from URL %s" + % (time.time() - t_start, resp.status_code, url) + ) + except Exception as e: + raise e + t_cost = time.time() - t_start + LOG.info(f"File {url} has content type {content_type} and length bytes={len(body)}, downloaded in {t_cost} seconds") + yield "url_files", url, content_type, body + + +def iter_over_request_files_sync(request, url_files): + """Synchronous version of iter_over_request_files for backward compatibility. + + This is a sync wrapper that uses httpx sync client. + """ + + # Handle multipart uploaded files (from FastAPI form data) + if hasattr(request, "_form"): + form_data = request._form + for file_key, files in form_data.multi_items(): + if isinstance(files, list): + for f in files: + if hasattr(f, "read"): + body = f.read() + file_name = getattr(f, "filename", "unknown") + content_type = getattr(f, "content_type", "application/octet-stream") + LOG.info( + f"File {file_name} from multipart has content type {content_type} and length bytes={len(body)}" + ) + yield file_key, file_name, content_type, body + else: + yield file_key, files, "text/plain", str(files).encode() + + # Handle files from URL + for file_key, list_url in {"url_files": url_files or []}.items(): for url in sorted(set(list_url)): try: t_start = time.time() - resp = requests.get(url, stream=True) # download the file from given url - if resp.status_code == 200: - body = resp.content - content_type = resp.headers.get("Content-Type", "UNKNOWN") - else: - raise RuntimeError("Failed to download file after %s seconds with code=%s from URL %s" % ( - time.time() - t_start, resp.status_code, url - )) - del resp + with httpx.Client(follow_redirects=True) as client: + resp = client.get(url) + if resp.status_code == 200: + body = resp.content + content_type = resp.headers.get("Content-Type", "UNKNOWN") + else: + raise RuntimeError( + "Failed to download file after %s seconds with code=%s from URL %s" + % (time.time() - t_start, resp.status_code, url) + ) except Exception as e: raise e t_cost = time.time() - t_start LOG.info(f"File {url} has content type {content_type} and length bytes={len(body)}, downloaded in {t_cost} seconds") - yield 'url_files', url, content_type, body + yield "url_files", url, content_type, body diff --git a/pkg/aloha/service/http/plain_http_handler.py b/pkg/aloha/service/http/plain_http_handler.py index 4813c12..bbe91f4 100644 --- a/pkg/aloha/service/http/plain_http_handler.py +++ b/pkg/aloha/service/http/plain_http_handler.py @@ -1,27 +1,38 @@ -"""Plain Tornado handler with permissive CORS defaults.""" +"""FastAPI middleware and dependencies with permissive CORS defaults.""" -from typing import Optional, Awaitable +from fastapi import Request, Response +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware -from tornado import web +class CORSResponse(JSONResponse): + """JSON response with permissive CORS headers for simple APIs.""" -class PlainHttpHandler(web.RequestHandler): - """Minimal handler that exposes JSON-friendly CORS headers.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: - """Accept streamed body chunks without additional processing.""" - pass + async def __call__(self, scope, receive, send) -> None: + await super().__call__(scope, receive, send) - def set_default_headers(self): - """Enable permissive cross-origin access for simple APIs.""" - self.set_header('Access-Control-Allow-Origin', '*') - self.set_header('Access-Control-Allow-Headers', '*') - self.set_header('Access-Control-Max-Age', 1000) - self.set_header('Content-type', 'application/json; charset=UTF-8') - self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') - self.set_header( - 'Access-Control-Allow-Headers', - 'authorization, Authorization, Content-Type,' - 'Access-Control-Allow-Origin, Access-Control-Allow-Headers,' - 'X-Requested-By, Access-Control-Allow-Methods' - ) + +def add_cors_headers(response: Response) -> None: + """Add permissive CORS headers to a response.""" + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Headers"] = "*" + response.headers["Access-Control-Max-Age"] = "1000" + response.headers["Content-Type"] = "application/json; charset=UTF-8" + response.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS" + response.headers["Access-Control-Allow-Headers"] = ( + "authorization, Authorization, Content-Type," + "Access-Control-Allow-Origin, Access-Control-Allow-Headers," + "X-Requested-By, Access-Control-Allow-Methods" + ) + + +class CORSMiddleware(BaseHTTPMiddleware): + """Middleware that adds permissive CORS headers to all responses.""" + + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + add_cors_headers(response) + return response diff --git a/pkg/aloha/service/openapi/client.py b/pkg/aloha/service/openapi/client.py index cbabb36..a19c052 100644 --- a/pkg/aloha/service/openapi/client.py +++ b/pkg/aloha/service/openapi/client.py @@ -18,10 +18,10 @@ class OpenApiClient: """Simple HTTP client that acquires and caches an access token.""" - retry_method_whitelist = frozenset(['GET', 'POST']) + retry_method_whitelist = frozenset(["GET", "POST"]) retry_status_forcelist = frozenset({413, 429, 503, 502, 504}) - def __init__(self, url_oauth_get_token: str, client_id: str, client_secret: str, grant_type: str = 'client_credentials'): + def __init__(self, url_oauth_get_token: str, client_id: str, client_secret: str, grant_type: str = "client_credentials"): """Store OAuth-style client credentials and token endpoint.""" self.url_oauth_get_token = url_oauth_get_token self.client_id = client_id @@ -37,9 +37,12 @@ def get_request_session(cls, total_retries: int = 10, *args, **kwargs) -> Sessio session = Session() # https://urllib3.readthedocs.io/en/latest/reference/urllib3.util.html#urllib3.util.Retry.DEFAULT_ALLOWED_METHODS retries = Retry( - total=total_retries, backoff_factor=0.1, method_whitelist=cls.retry_method_whitelist, status_forcelist=cls.retry_status_forcelist + total=total_retries, + backoff_factor=0.1, + method_whitelist=cls.retry_method_whitelist, + status_forcelist=cls.retry_status_forcelist, ) - for prefix in ('http://', 'https://'): + for prefix in ("http://", "https://"): session.mount(prefix, HTTPAdapter(max_retries=retries)) return session @@ -50,29 +53,30 @@ def get_access_token(self) -> str: if self.expires_at is None or self.expires_at > now: try: # refresh access_token - resp = self.get_request_session().post(self.url_oauth_get_token, timeout=5, json={ - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'grant_type': self.grant_type - }) + resp = self.get_request_session().post( + self.url_oauth_get_token, + timeout=5, + json={"client_id": self.client_id, "client_secret": self.client_secret, "grant_type": self.grant_type}, + ) - data = resp.json()['data'] - if data is None or 'access_token' not in data: - raise RuntimeError('Fail to fetch OpenAPI token with result: %s' % resp.text) + data = resp.json()["data"] + if data is None or "access_token" not in data: + raise RuntimeError("Fail to fetch OpenAPI token with result: %s" % resp.text) - self.access_token = data['access_token'] + self.access_token = data["access_token"] - expires_in = int(data['expires_in']) + expires_in = int(data["expires_in"]) self.expires_at = datetime.now() + timedelta(minutes=expires_in - 1) except Exception as e: - LOG.error('Exception acquiring ESG access token from [%s]: %s' % (self.url_oauth_get_token, str(e))) + msg = "Exception acquiring ESG access token from [%s]: %s" % (self.url_oauth_get_token, str(e)) + LOG.error(msg) return self.access_token def _get_request_url(self, url: str): """Attach access token and request id to the target URL.""" - request_url = '{url}?access_token={access_token}&request_id={request_id}'.format( - url=url, access_token=self.get_access_token(), request_id=datetime.now().strftime('%Y%m%d-%H%M%S-%f') + request_url = "{url}?access_token={access_token}&request_id={request_id}".format( + url=url, access_token=self.get_access_token(), request_id=datetime.now().strftime("%Y%m%d-%H%M%S-%f") ) return request_url @@ -84,29 +88,29 @@ def _get_data_from_esg_response(resp) -> Optional[dict]: except (json.JSONDecodeError, JSONDecodeError): # requests may use `simplejson` try: # when data is wrapped by ESG - content = resp.text.replace('"data":"', '"data":').replace('}"}', '}}') + content = resp.text.replace('"data":"', '"data":').replace('}"}', "}}") data = json.loads(content) - return data.get('data', {}) + return data.get("data", {}) except json.JSONDecodeError: - msg = 'Cannot parse ESG response: %s' % resp.text + msg = "Cannot parse ESG response: %s" % resp.text raise ValueError(msg) def post(self, url_api: str, body: dict, headers: dict = None, timeout: int = 5): """Send a POST request to the remote API.""" url = self._get_request_url(url_api) - LOG.debug('Calling ESG POST: %s' % url) + LOG.debug("Calling ESG POST: %s" % url) try: resp = self.get_request_session().post(url=url, headers=headers, json=body, timeout=timeout) return self._get_data_from_esg_response(resp) except Exception as e: - LOG.error('Error calling ESG API POST [%s]: %s' % (url, str(e))) + LOG.error("Error calling ESG API POST [%s]: %s" % (url, str(e))) def get(self, url_api: str, body: dict, headers: dict = None, timeout: int = 5): """Send a GET request to the remote API.""" url = self._get_request_url(url_api) - LOG.debug('Calling ESG GET: %s' % url) + LOG.debug("Calling ESG GET: %s" % url) try: resp = self.get_request_session().get(url=url, headers=headers, json=body, timeout=timeout) return self._get_data_from_esg_response(resp) except Exception as e: - LOG.error('Error calling ESG API GET [%s]: %s' % (url, str(e))) + LOG.error("Error calling ESG API GET [%s]: %s" % (url, str(e))) diff --git a/pkg/aloha/service/web.py b/pkg/aloha/service/web.py index e62676d..048510f 100644 --- a/pkg/aloha/service/web.py +++ b/pkg/aloha/service/web.py @@ -1,68 +1,242 @@ -"""Tornado web application assembly for aloha services.""" +"""FastAPI web application assembly for aloha services.""" import logging import os +import re +from typing import Any, List, Tuple -from tornado import httpserver, web -from tornado.routing import HostMatches +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response from ..logger import LOG from ..logger.logger import setup_logger from ..settings import SETTINGS setup_logger( - logging.getLogger("tornado.access"), + logging.getLogger("uvicorn.access"), formatter_str="A> %(asctime)s> %(message)s", module="access_%s" % (SETTINGS.config.get("APP_MODULE") or os.environ.get("APP_MODULE", "default")), ) -def _load_handlers(name): - """Load `(URL pattern, handler)` tuples from a service module.""" +def _load_routes(name: str) -> List[Tuple[str, Any]]: + """Load routes from a service module. + + Returns list of (url_pattern, handler_class) tuples. + """ mod = __import__(name, fromlist=["default_handlers"]) - handlers = [] + routes = [] + seen = set() for url, handler in mod.default_handlers: if not url.startswith("/"): url = "/" + url - handlers.append((url, handler)) - return handlers + # Deduplicate routes + key = (url, handler) + if key not in seen: + seen.add(key) + routes.append((url, handler)) + return routes + +class FastAPIApplication: + """FastAPI application that loads routes from configured service modules.""" -class WebApplication(web.Application): - """Tornado application that loads handlers from configured service modules.""" + def __init__(self, config: dict = None, **kwargs): + """Create the FastAPI application and its routes.""" + self.config = config or {} + self.app = FastAPI(title="Aloha Service", version="1.0.0", **kwargs) + self._setup_default_handler() + self._setup_routes() - def __init__(self, config: dict, *args, **kwargs): - """Create the application and its HTTP server.""" - handlers = self.init_handlers(config) - super().__init__(handlers=handlers, **config) - self.http_server = httpserver.HTTPServer(self) + def _setup_default_handler(self): + """Register a custom default 404 handler when configured.""" + handler_class = self.config.get("default_handler_class") + if not handler_class: + return - @staticmethod - def init_handlers(config: dict): - """Collect and normalize all handlers from configured service modules.""" - settings = config.get("service", {}) + @self.app.exception_handler(404) + async def _default_404_handler(request: Request, exc: Exception): + handler = handler_class(request=request) + if hasattr(handler, "handle") and callable(handler.handle): + return await handler.handle(request) + if hasattr(handler, "__call__") and callable(handler): + return await handler(request) + if hasattr(handler, "response") and callable(handler.response): + return await handler.response() + return JSONResponse( + {"code": 404, "message": ["Not Found"], "data": None}, + status_code=404, + ) + + def _setup_routes(self): + """Setup routes from configured service modules.""" + settings = self.config.get("service", {}) modules = settings.get("modules", []) - handlers = [] + for m in modules: - _handlers = _load_handlers(m) - for h in _handlers: - (url, class_handler) = h - handlers.append(h) + routes = _load_routes(m) + for url, handler_class in routes: + self._register_handler(url, handler_class) s_log_msg = "Loaded API module %-50s" % url - if LOG.level < logging.INFO: # more verbose information - s_log_msg += "\t from class %s" % str(class_handler) + if LOG.level < logging.INFO: + s_log_msg += "\t from class %s" % str(handler_class) LOG.info(s_log_msg) - return [(HostMatches("(.*)"), handlers)] + def _register_handler(self, url: str, handler_class): + """Register a handler class as FastAPI routes based on its methods.""" + has_get = hasattr(handler_class, "get") and callable(getattr(handler_class, "get")) + has_post = hasattr(handler_class, "post") and callable(getattr(handler_class, "post")) + + # Determine path pattern for FastAPI + fastapi_url, path_params = self._convert_url_pattern(url) + + # Store path_params in closure for use in handlers + _has_path_params = path_params + _original_url = url + + # Register POST handler if handler class has post method + if has_post: + + async def post_handler(request: Request): + kwargs = {} + handler = handler_class() + handler._request = request + + # Extract path params from URL + if _has_path_params: + match_path = self._match_path(_original_url, str(request.url.path)) + if match_path: + kwargs.update(match_path) + + try: + body = await request.json() + except Exception: + body = {} + + kwargs.update(body) + + try: + result = await handler.post(**kwargs) + # If handler returns a Response object, return it directly + if isinstance(result, Response): + return result + # Otherwise, wrap in standard response format + resp = dict(code=5200, message=["success"]) + if isinstance(result, dict): + resp["data"] = result.get("data", result) + else: + resp["data"] = result + return JSONResponse(resp) + except Exception as e: + if handler.LOG.level == logging.DEBUG: + handler.LOG.error(e, exc_info=True) + return JSONResponse({"code": 5201, "message": [repr(e)]}, status_code=500) + + self.app.post(fastapi_url)(post_handler) + + # Register GET handler if handler class has get method + if has_get: + + async def get_handler(request: Request): + kwargs = {} + handler = handler_class() + handler._request = request + + # Extract path params from URL + if _has_path_params: + match_path = self._match_path(_original_url, str(request.url.path)) + if match_path: + kwargs.update(match_path) + + kwargs.update(dict(request.query_params)) + + try: + result = await handler.get(**kwargs) + # If handler returns a Response object, return it directly + if isinstance(result, Response): + return result + # Otherwise, wrap in standard response format + resp = dict(code=5200, message=["success"]) + if isinstance(result, dict): + resp["data"] = result.get("data", result) + else: + resp["data"] = result + return JSONResponse(resp) + except Exception as e: + if handler.LOG.level == logging.DEBUG: + handler.LOG.error(e, exc_info=True) + return JSONResponse({"code": 5201, "message": [repr(e)]}, status_code=500) + + self.app.get(fastapi_url)(get_handler) + + # Default: register a POST handler using response() method + if not has_post and not has_get: + + async def default_handler(request: Request): + kwargs = {} + handler = handler_class() + handler._request = request + + # Extract path params from URL + if _has_path_params: + match_path = self._match_path(_original_url, str(request.url.path)) + if match_path: + kwargs.update(match_path) + + try: + body = await request.json() + except Exception: + body = {} + + kwargs.update(body) + + resp = dict(code=5200, message=["success"]) + try: + result = handler.response(**kwargs) + resp["data"] = result + except Exception as e: + if handler.LOG.level == logging.DEBUG: + handler.LOG.error(e, exc_info=True) + return JSONResponse({"code": 5201, "message": [repr(e)]}, status_code=500) + + return JSONResponse(resp) + + self.app.post(fastapi_url)(default_handler) + + def _convert_url_pattern(self, tornado_pattern: str) -> Tuple[str, bool]: + """Convert Tornado URL pattern to FastAPI pattern. + + Tornado: /api/common/sys_info/(.*) + FastAPI: /api/common/sys_info/{path_param} + """ + has_capture = "(.*)" in tornado_pattern + fastapi_pattern = tornado_pattern.replace("(.*)", "{path_param:path}") + return fastapi_pattern, has_capture + + def _match_path(self, tornado_pattern: str, path: str) -> dict: + """Match a path against a Tornado pattern and extract params.""" + # Convert Tornado pattern to regex + pattern = tornado_pattern + pattern = pattern.replace("(.*)", r"(?P.*)") + pattern = "^" + pattern + "$" + + match = re.match(pattern, path) + if match: + return match.groupdict() + return {} + + def get_port(self) -> int: + """Get the configured port.""" + service_settings = self.config.get("service", {}) + port = service_settings.get("port") or int(os.environ.get("PORT_SVC", 8000)) + port = int(os.environ.get("PORT", port)) + return port - def start(self): - """Bind the configured port and start the HTTP server.""" - service_settings = self.settings.get("service", {}) + def get_workers(self) -> int: + """Get the configured number of workers.""" + service_settings = self.config.get("service", {}) + return int(service_settings.get("num_process") or 1) - port = service_settings.get("port") or int(os.environ.get("PORT_SVC", 80)) - port = os.environ.get("port", port) # if overwrite port in param - num_process = int(service_settings.get("num_process") or 0) - LOG.info("Starting service with [%s] process at port [%s]...", num_process or "undefined", port) - self.http_server.bind(port) - self.http_server.start(num_processes=num_process) +# Backward compatibility alias +WebApplication = FastAPIApplication diff --git a/pkg/aloha/time/__init__.py b/pkg/aloha/time/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pkg/aloha/time/timeout_async.py b/pkg/aloha/time/timeout_async.py deleted file mode 100644 index 6b53803..0000000 --- a/pkg/aloha/time/timeout_async.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Refer to: https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py""" - -import asyncio -import enum -import warnings -from types import TracebackType -from typing import Optional, Type, final - -__all__ = ("timeout", "timeout_at", "Timeout") - - -def timeout(delay: Optional[float]) -> "Timeout": - """timeout context manager. - Useful in cases when you want to apply timeout logic around block - of code or in cases when asyncio.wait_for is not suitable. For example: - >>> async with timeout(0.001): - ... async with aiohttp.get('https://github.com') as r: - ... await r.text() - delay - value in seconds or None to disable timeout logic - """ - loop = asyncio.get_running_loop() - if delay is not None: - deadline = loop.time() + delay # type: Optional[float] - else: - deadline = None - return Timeout(deadline, loop) - - -def timeout_at(deadline: Optional[float]) -> "Timeout": - """Schedule the timeout at absolute time. - deadline argument points on the time in the same clock system - as loop.time(). - Please note: it is not POSIX time but a time with - undefined starting base, e.g. the time of the system power on. - >>> async with timeout_at(loop.time() + 10): - ... async with aiohttp.get('https://github.com') as r: - ... await r.text() - """ - loop = asyncio.get_running_loop() - return Timeout(deadline, loop) - - -class _State(enum.Enum): - INIT = "INIT" - ENTER = "ENTER" - TIMEOUT = "TIMEOUT" - EXIT = "EXIT" - - -@final -class Timeout: - # Internal class, please don't instantiate it directly - # Use timeout() and timeout_at() public factories instead. - # - # Implementation note: `async with timeout()` is preferred - # over `with timeout()`. - # While technically the Timeout class implementation - # doesn't need to be async at all, - # the `async with` statement explicitly points that - # the context manager should be used from async function context. - # - # This design allows to avoid many silly misusages. - # - # TimeoutError is raised immediately when scheduled - # if the deadline is passed. - # The purpose is to time out as soon as possible - # without waiting for the next await expression. - - __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") - - def __init__(self, deadline: Optional[float], loop: asyncio.AbstractEventLoop) -> None: - self._loop = loop - self._state = _State.INIT - - self._timeout_handler = None # type: Optional[asyncio.Handle] - if deadline is None: - self._deadline = None # type: Optional[float] - else: - self.update(deadline) - - def __enter__(self) -> "Timeout": - warnings.warn( - "with timeout() is deprecated, use async with timeout() instead", - DeprecationWarning, - stacklevel=2, - ) - self._do_enter() - return self - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - self._do_exit(exc_type) - return None - - async def __aenter__(self) -> "Timeout": - self._do_enter() - return self - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - self._do_exit(exc_type) - return None - - @property - def expired(self) -> bool: - """Is timeout expired during execution?""" - return self._state == _State.TIMEOUT - - @property - def deadline(self) -> Optional[float]: - return self._deadline - - def reject(self) -> None: - """Reject scheduled timeout if any.""" - # cancel is maybe better name but - # task.cancel() raises CancelledError in asyncio world. - if self._state not in (_State.INIT, _State.ENTER): - raise RuntimeError(f"invalid state {self._state.value}") - self._reject() - - def _reject(self) -> None: - if self._timeout_handler is not None: - self._timeout_handler.cancel() - self._timeout_handler = None - - def shift(self, delay: float) -> None: - """Advance timeout on delay seconds. - The delay can be negative. - Raise RuntimeError if shift is called when deadline is not scheduled - """ - deadline = self._deadline - if deadline is None: - raise RuntimeError("cannot shift timeout if deadline is not scheduled") - self.update(deadline + delay) - - def update(self, deadline: float) -> None: - """Set deadline to absolute value. - deadline argument points on the time in the same clock system - as loop.time(). - If new deadline is in the past the timeout is raised immediately. - Please note: it is not POSIX time but a time with - undefined starting base, e.g. the time of the system power on. - """ - if self._state == _State.EXIT: - raise RuntimeError("cannot reschedule after exit from context manager") - if self._state == _State.TIMEOUT: - raise RuntimeError("cannot reschedule expired timeout") - if self._timeout_handler is not None: - self._timeout_handler.cancel() - self._deadline = deadline - if self._state != _State.INIT: - self._reschedule() - - def _reschedule(self) -> None: - assert self._state == _State.ENTER - deadline = self._deadline - if deadline is None: - return - - now = self._loop.time() - if self._timeout_handler is not None: - self._timeout_handler.cancel() - - task = asyncio.current_task() - if deadline <= now: - self._timeout_handler = self._loop.call_soon(self._on_timeout, task) - else: - self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task) - - def _do_enter(self) -> None: - if self._state != _State.INIT: - raise RuntimeError(f"invalid state {self._state.value}") - self._state = _State.ENTER - self._reschedule() - - def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: - if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: - self._timeout_handler = None - raise asyncio.TimeoutError - # timeout has not expired - self._state = _State.EXIT - self._reject() - return None - - def _on_timeout(self, task: "asyncio.Task[None]") -> None: - task.cancel() - self._state = _State.TIMEOUT - # drop the reference early - self._timeout_handler = None diff --git a/pkg/aloha/time/timeout_asyncio.py b/pkg/aloha/time/timeout_asyncio.py deleted file mode 100644 index eae421b..0000000 --- a/pkg/aloha/time/timeout_asyncio.py +++ /dev/null @@ -1,217 +0,0 @@ -import asyncio -import enum -import sys -import warnings -from functools import partial, wraps -from types import TracebackType -from typing import Any, Optional, Type - -# from typing_extensions import final - - -__all__ = ("timeout", "timeout_at") - - -def timeout(delay: Optional[float]) -> "Timeout": - """timeout context manager. - Useful in cases when you want to apply timeout logic around block - of code or in cases when asyncio.wait_for is not suitable. For example: - >>> async with timeout(0.001): - ... async with aiohttp.get('https://github.com') as r: - ... await r.text() - delay - value in seconds or None to disable timeout logic - """ - loop = _get_running_loop() - if delay is not None: - deadline = loop.time() + delay # type: Optional[float] - else: - deadline = None - return Timeout(deadline, loop) - - -def timeout_at(deadline: Optional[float]) -> "Timeout": - """Schedule the timeout at absolute time. - deadline arguments points on the time in the same clock system - as loop.time(). - Please note: it is not POSIX time but a time with - undefined starting base, e.g. the time of the system power on. - >>> async with timeout_at(loop.time() + 10): - ... async with aiohttp.get('https://github.com') as r: - ... await r.text() - """ - loop = _get_running_loop() - return Timeout(deadline, loop) - - -class _State(enum.Enum): - INIT = "INIT" - ENTER = "ENTER" - TIMEOUT = "TIMEOUT" - EXIT = "EXIT" - - -# @final -class Timeout: - # Internal class, please don't instantiate it directly - # Use timeout() and timeout_at() public factories instead. - # - # Implementation note: `async with timeout()` is preferred - # over `with timeout()`. - # While technically the Timeout class implementation - # doesn't need to be async at all, - # the `async with` statement explicitly points that - # the context manager should be used from async function context. - # - # This design allows to avoid many silly misusages. - # - # TimeoutError is raised immadiatelly when scheduled - # if the deadline is passed. - # The purpose is to time out as sson as possible - # without waiting for the next await expression. - - __slots__ = ("_deadline", "_loop", "_state", "_task", "_timeout_handler") - - def __init__(self, deadline: Optional[float], loop: asyncio.AbstractEventLoop) -> None: - self._loop = loop - self._state = _State.INIT - - task = _current_task(self._loop) - self._task = task - - self._timeout_handler = None # type: Optional[asyncio.Handle] - if deadline is None: - self._deadline = None # type: Optional[float] - else: - self.shift_to(deadline) - - def __enter__(self) -> "Timeout": - warnings.warn( - "with timeout() is deprecated, use async with timeout() instead", - DeprecationWarning, - stacklevel=2, - ) - self._do_enter() - return self - - def __exit__( - self, - exc_type: Type[BaseException], - exc_val: BaseException, - exc_tb: TracebackType, - ) -> Optional[bool]: - self._do_exit(exc_type) - return None - - async def __aenter__(self) -> "Timeout": - self._do_enter() - return self - - async def __aexit__( - self, - exc_type: Type[BaseException], - exc_val: BaseException, - exc_tb: TracebackType, - ) -> Optional[bool]: - self._do_exit(exc_type) - return None - - @property - def expired(self) -> bool: - """Is timeout expired during execution?""" - return self._state == _State.TIMEOUT - - @property - def deadline(self) -> Optional[float]: - return self._deadline - - def reject(self) -> None: - """Reject scheduled timeout if any.""" - # cancel is maybe better name but - # task.cancel() raises CancelledError in asyncio world. - if self._state not in (_State.INIT, _State.ENTER): - raise RuntimeError("invalid state {}".format(self._state.value)) - self._reject() - - def _reject(self) -> None: - if self._timeout_handler is not None: - self._timeout_handler.cancel() - self._timeout_handler = None - - def shift_by(self, delay: float) -> None: - """Advance timeout on delay seconds. - The delay can be negative. - """ - now = self._loop.time() - self.shift_to(now + delay) - - def shift_to(self, deadline: float) -> None: - """Advance timeout on the abdelay seconds. - If new deadline is in the past - the timeout is raised immediatelly. - """ - if self._state == _State.EXIT: - raise RuntimeError("cannot reschedule after exit from context manager") - if self._state == _State.TIMEOUT: - raise RuntimeError("cannot reschedule expired timeout") - if self._timeout_handler is not None: - self._timeout_handler.cancel() - self._deadline = deadline - now = self._loop.time() - if deadline <= now: - self._timeout_handler = None - if self._state == _State.INIT: - raise asyncio.TimeoutError - else: - # state is ENTER - raise asyncio.CancelledError - self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, self._task) - - def _do_enter(self) -> None: - if self._state != _State.INIT: - raise RuntimeError("invalid state {}".format(self._state.value)) - self._state = _State.ENTER - - def _do_exit(self, exc_type: Type[BaseException]) -> None: - if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: - self._timeout_handler = None - raise asyncio.TimeoutError - # timeout is not expired - self._state = _State.EXIT - self._reject() - return None - - def _on_timeout(self, task: "asyncio.Task[None]") -> None: - task.cancel() - self._state = _State.TIMEOUT - - -def _current_task(loop: asyncio.AbstractEventLoop) -> "Optional[asyncio.Task[Any]]": - if sys.version_info >= (3, 7): - return asyncio.current_task(loop=loop) - else: - return asyncio.Task.current_task(loop=loop) - - -def _get_running_loop() -> asyncio.AbstractEventLoop: - loop = asyncio.get_running_loop() - if loop is None: - print("--" * 20) - loop = asyncio.get_event_loop() - - if sys.version_info >= (3, 7): - return loop - else: - if not loop.is_running(): - raise RuntimeError("no running event loop") - return loop - - -def aioify(func): - @wraps(func) - async def run(*args, loop=None, executor=None, **kwargs): - if loop is None: - loop = asyncio.get_event_loop() - p_func = partial(func, *args, **kwargs) - return await loop.run_in_executor(executor, p_func) - - return run diff --git a/pkg/aloha/time/timeout_signal.py b/pkg/aloha/time/timeout_signal.py deleted file mode 100644 index 2a5496d..0000000 --- a/pkg/aloha/time/timeout_signal.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Easily put time restrictions on things -Note: Requires Python 3.x -Usage as a context manager: -``` -with TimeOutRestriction(10): - something_that_should_not_exceed_ten_seconds() -``` -Usage as a decorator: -``` -@TimeOutRestriction(10) -def something_that_should_not_exceed_ten_seconds(): - do_stuff_with_a_timeout() -``` -Handle timeouts: -``` -try: - with TimeOutRestriction(10): - something_that_should_not_exceed_ten_seconds() - except TimeoutError: - log('Got a timeout, couldn't finish') -``` -Suppress TimeoutError and just die after expiration: -``` -with TimeOutRestriction(10, suppress_timeout_errors=True): - something_that_should_not_exceed_ten_seconds() -print('Maybe exceeded 10 seconds, but finished either way') -``` -""" - -import contextlib -import errno -import os -import signal - -DEFAULT_TIMEOUT_MESSAGE = os.strerror(errno.ETIME) - - -class TimeOutRestriction(contextlib.ContextDecorator): - def __init__( - self, milliseconds: float, *, timeout_message: str = DEFAULT_TIMEOUT_MESSAGE, suppress_timeout_errors: bool = False - ): - self.millisecond = milliseconds - self.timeout_message = timeout_message - self.suppress = bool(suppress_timeout_errors) - - def _timeout_handler(self, signum, frame): - raise TimeoutError(self.timeout_message) - - def __enter__(self): - signal.signal(signal.SIGALRM, self._timeout_handler) - signal.setitimer(signal.ITIMER_REAL, self.millisecond / 1000) - - def __exit__(self, exc_type, exc_val, exc_tb): - signal.setitimer(signal.ITIMER_REAL, 0, 0) - if self.suppress and exc_type is TimeoutError: - return True diff --git a/pkg/aloha/util/html.py b/pkg/aloha/util/html.py index e9867a8..01d2ddb 100644 --- a/pkg/aloha/util/html.py +++ b/pkg/aloha/util/html.py @@ -1,7 +1,5 @@ """HTML extraction helpers.""" -import re - from lxml import etree @@ -21,11 +19,21 @@ def extract_img_url(string): def extract_text(raw_data): """Extract visible text from an HTML fragment.""" if raw_data is not None: - raw_data = re.sub(r"", "", raw_data) html = etree.HTML(raw_data) content = [] if html is not None: + for script in html.xpath("//script"): + parent = script.getparent() + if parent is not None: + if script.tail: + prev = script.getprevious() + if prev is not None: + prev.tail = (prev.tail or "") + script.tail + else: + parent.text = (parent.text or "") + script.tail + parent.remove(script) + html_data = html.xpath("/html/body/*//text()") for data in html_data: tmp = ( diff --git a/pkg/aloha/util/sys_cuda.py b/pkg/aloha/util/sys_cuda.py index 181b422..06a5dbd 100644 --- a/pkg/aloha/util/sys_cuda.py +++ b/pkg/aloha/util/sys_cuda.py @@ -16,8 +16,8 @@ def get_gpu_status_for_tf(*args, **kwargs) -> dict: LOG.info("tensorflow version = %s" % tf.__version__) status = Status(version=tf.__version__, gpu_availability=tf.test.is_gpu_available()) except Exception as e: - LOG.error("Error detecting CUDA availability for tensorflow") - LOG.error(str(e)) + msg = "Error detecting CUDA availability for tensorflow: %s" % str(e) + LOG.warning(msg) return status._asdict() @@ -29,8 +29,8 @@ def get_gpu_status_for_torch(*args, **kwargs) -> dict: LOG.info("torch version = %s" % torch.__version__) status = Status(version=torch.__version__, gpu_availability=torch.cuda.is_available()) except Exception as e: - LOG.error("Error detecting CUDA availability for torch") - LOG.error(str(e)) + msg = "Error detecting CUDA availability for torch: %s" % str(e) + LOG.warning(msg) return status._asdict() @@ -43,8 +43,8 @@ def get_gpu_status_for_paddle(*args, **kwargs) -> dict: paddle.utils.run_check() status = Status(version=paddle.__version__, gpu_availability=True) except Exception as e: - LOG.error("Error detecting CUDA availability for paddle") - LOG.error(str(e)) + msg = "Error detecting CUDA availability for paddle: %s" % str(e) + LOG.warning(msg) return status._asdict() diff --git a/pkg/aloha/util/sys_gpu.py b/pkg/aloha/util/sys_gpu.py index a718645..b999031 100644 --- a/pkg/aloha/util/sys_gpu.py +++ b/pkg/aloha/util/sys_gpu.py @@ -66,14 +66,14 @@ def get_device_list(self) -> list: try: name = nvml.nvmlDeviceGetName(handler).decode(encoding="UTF-8") except Exception as e: - LOG.info("Failed to get device name!") - LOG.info(str(e)) + msg = "Failed to get device name: %s" % str(e) + LOG.info(msg) try: arch = nvml.nvmlDeviceGetArchitecture(handler) except Exception as e: - LOG.info("Failed to get device architecture!") - LOG.info(str(e)) + msg = "Failed to get device architecture: %s" % str(e) + LOG.info(msg) device = Device(index=i, name=name, arch=arch) LOG.debug("Found device {index} info: name={name}; arch={arch}".format(**device._asdict())) @@ -103,7 +103,7 @@ def get_smi(): try: return nvidia_smi.getInstance() except Exception as e: - LOG.warn("Failed to get smi: %s" % str(e)) + LOG.warning("Failed to get smi: %s" % str(e)) return diff --git a/pkg/aloha/util/time.py b/pkg/aloha/util/time.py new file mode 100644 index 0000000..66f74bc --- /dev/null +++ b/pkg/aloha/util/time.py @@ -0,0 +1,77 @@ +"""Time and timeout utilities.""" + +import asyncio +import concurrent.futures +import inspect +from typing import Any, Callable, Optional + +__all__ = ("run_with_timeout", "run_async_with_timeout") + + +def run_with_timeout( + func: Callable[..., Any], + timeout_seconds: float, + *args: Any, + fn_callback_success: Optional[Callable[[Any], Any]] = None, + fn_callback_fail: Optional[Callable[[Exception], Any]] = None, + **kwargs: Any, +) -> Any: + """Wrap a synchronous function call with a timeout. + + If the operation completes within `timeout_seconds`, `fn_callback_success(result)` + is executed if provided, and the result is returned. + If the operation times out or raises an exception, `fn_callback_fail(exception)` + is executed if provided, and the exception is reraised. + """ + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(func, *args, **kwargs) + try: + result = future.result(timeout=timeout_seconds) + if fn_callback_success is not None: + fn_callback_success(result) + return result + except Exception as e: + if isinstance(e, concurrent.futures.TimeoutError): + exc = TimeoutError(f"Operation timed out after {timeout_seconds} seconds") + else: + exc = e + if fn_callback_fail is not None: + fn_callback_fail(exc) + raise exc + + +async def run_async_with_timeout( + func: Callable[..., Any], + timeout_seconds: float, + *args: Any, + fn_callback_success: Optional[Callable[[Any], Any]] = None, + fn_callback_fail: Optional[Callable[[Exception], Any]] = None, + **kwargs: Any, +) -> Any: + """Wrap an asynchronous function call (coroutine function or sync function inside executor) with a timeout. + + If the operation completes within `timeout_seconds`, `fn_callback_success(result)` + is executed if provided, and the result is returned. + If the operation times out or raises an exception, `fn_callback_fail(exception)` + is executed if provided, and the exception is reraised. + """ + try: + if inspect.iscoroutinefunction(func): + coro = func(*args, **kwargs) + else: + # Run sync function in default executor to prevent blocking the event loop + loop = asyncio.get_running_loop() + coro = loop.run_in_executor(None, lambda: func(*args, **kwargs)) + + result = await asyncio.wait_for(coro, timeout=timeout_seconds) + if fn_callback_success is not None: + fn_callback_success(result) + return result + except Exception as e: + if isinstance(e, (asyncio.TimeoutError, concurrent.futures.TimeoutError)): + exc = TimeoutError(f"Operation timed out after {timeout_seconds} seconds") + else: + exc = e + if fn_callback_fail is not None: + fn_callback_fail(exc) + raise exc diff --git a/pkg/setup.py b/pkg/setup.py index db25938..7987507 100644 --- a/pkg/setup.py +++ b/pkg/setup.py @@ -13,7 +13,7 @@ dict_extra_requires = { "build": ["Cython"], - "service": ["requests", "tornado", "psutil", "pyjwt", "fastapi", "httpx"], + "service": ["psutil", "pyjwt", "fastapi", "httpx", "uvicorn"], "db": [ "sqlalchemy", "psycopg[binary]", @@ -38,7 +38,7 @@ author="LabNow.ai", author_email="postmaster@labnow.ai", license="Apache-2.0", - url="https://github.com/LabNow.ai/aloha", + url="https://github.com/LabNow.ai/aloha-python", project_urls={ "Source": "https://github.com/LabNow-ai/aloha-python", "CI Pipeline": "https://github.com/LabNow-ai/aloha-python/actions", @@ -54,7 +54,7 @@ **dict_extra_requires, "all": sorted(y for x in dict_extra_requires.values() for y in x), }, - python_requires=">=3.6", + python_requires=">=3.10", entry_points={"console_scripts": ["aloha = aloha.script.base:main"]}, data_files=[], description="Aloha - a versatile Python utility package for building services", diff --git a/tool/app-demo.Dockerfile b/src/app-demo.Dockerfile similarity index 92% rename from tool/app-demo.Dockerfile rename to src/app-demo.Dockerfile index 2cc98e4..8313686 100644 --- a/tool/app-demo.Dockerfile +++ b/src/app-demo.Dockerfile @@ -15,8 +15,9 @@ ARG PROFILE_LOCALIZE COPY . /tmp/app -RUN set -ex && cd /tmp/app && mkdir -pv ${DIR_APP} \ +RUN set -ex && mkdir -pv ${DIR_APP} \ && source /opt/utils/script-localize.sh ${PROFILE_LOCALIZE} \ + && cd /tmp/src \ && if [[ "$ENABLE_CODE_BUILD" = "true" ]] ; then \ echo "-> Building src to binary..." && pip install -U aloha[build] && aloha compile --base=./src --dist=${DIR_APP}/ ; \ else \ @@ -35,8 +36,8 @@ USER root WORKDIR ${DIR_APP} COPY --from=builder ${DIR_APP} ${DIR_APP}/ -ENV PORT_SVC=${PORT_SVC:-80} \ - ENTRYPOINT="app_common.debug" +ENV PORT_SVC=${PORT_SVC:-80} +ENV ENTRYPOINT="app_common.debug" RUN set -eux && pwd && ls -alh \ && source /opt/utils/script-localize.sh ${PROFILE_LOCALIZE} \ diff --git a/src/app_common/api/api_common_query_postgres.py b/src/app_common/api/api_common_query_postgres.py index ad82708..c204f9f 100644 --- a/src/app_common/api/api_common_query_postgres.py +++ b/src/app_common/api/api_common_query_postgres.py @@ -1,17 +1,15 @@ from typing import Optional import pandas as pd -from sqlalchemy import text - from aloha.base import BaseModule from aloha.db.postgres import PostgresOperator from aloha.logger import LOG from aloha.service.api.v0 import APIHandler +from sqlalchemy import text class ApiQueryPostgres(APIHandler): - def response(self, sql: str, orient: str = 'columns', config_profile: str = None, - params=None, *args, **kwargs) -> str: + def response(self, sql: str, orient: str = "columns", config_profile: str = None, params=None, *args, **kwargs) -> str: op_query_db = QueryDb() df = op_query_db.query_db(sql=sql, config_profile=config_profile, params=params) ret = df.to_json(orient=orient, force_ascii=False) @@ -26,7 +24,7 @@ def get_operator(self, config_profile: str, *args, **kwargs): return PostgresOperator(config_dict) def query_db(self, sql: str, config_profile: str = None, params=None, *args, **kwargs) -> Optional[pd.DataFrame]: - op = self.get_operator(config_profile or 'pg_rec_readonly') + op = self.get_operator(config_profile or "pg_rec_readonly") return pd.read_sql(sql=text(sql), con=op.engine, params=params) @@ -37,22 +35,24 @@ def query_db(self, sql: str, config_profile: str = None, params=None, *args, **k def main(): - import sys import argparse + import sys + sys.argv.pop(0) parser = argparse.ArgumentParser() parser.add_argument("--config-profile") - parser.add_argument("--sql", nargs='?') + parser.add_argument("--sql", nargs="?") args = parser.parse_args() dict_params = vars(args) query = QueryDb() op = query.get_operator(**dict_params) - LOG.info('Connection string: %s' % op.connection_str) + LOG.info("Connection string: %s" % op.connection_str) - if dict_params.get('sql', None) is not None: + if dict_params.get("sql", None) is not None: from tabulate import tabulate - LOG.info('Query result for: %s' % dict_params['sql']) + + LOG.info("Query result for: %s" % dict_params["sql"]) df = query.query_db(**dict_params) - table = tabulate(df, headers='keys', tablefmt='psql') + table = tabulate(df, headers="keys", tablefmt="psql") print(table) diff --git a/src/app_common/api/api_common_sys_info.py b/src/app_common/api/api_common_sys_info.py index b59ac49..f3c0e45 100644 --- a/src/app_common/api/api_common_sys_info.py +++ b/src/app_common/api/api_common_sys_info.py @@ -1,20 +1,17 @@ from datetime import datetime from aloha.service.api.v0 import APIHandler -from aloha.util import (sys_info, sys_gpu, sys_cuda) +from aloha.util import sys_cuda, sys_gpu, sys_info def echo(*args, **kwargs): - return { - 'sys_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f'), - **kwargs - } + return {"sys_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), **kwargs} class SysStatusInfo(APIHandler): @staticmethod def get_sys_info(kind: str = None, **kwargs) -> dict: - kinds = ['echo'] + kinds = ["echo"] if kind is None or len(kind) == 0: pass else: @@ -22,14 +19,12 @@ def get_sys_info(kind: str = None, **kwargs) -> dict: dict_func = { "echo": echo, - "sys": sys_info.get_sys_info, "os": sys_info.get_os_info, "cpu": sys_info.get_cpu_info, "mem": sys_info.get_mem_info, "disk": sys_info.get_disk_info, "net": sys_info.get_net_info, - "gpu": sys_gpu.get_gpu_info, "cuda": sys_cuda.get_cuda_info, "cuda-torch": sys_cuda.get_gpu_status_for_torch, @@ -39,7 +34,7 @@ def get_sys_info(kind: str = None, **kwargs) -> dict: ret = {} for k in sorted(set(kinds)): if k not in dict_func: - k = 'echo' + k = "echo" ret.update({k: dict_func.get(k)()}) return ret @@ -48,9 +43,19 @@ def response(self, kind: str = None, *args, **kwargs) -> dict: return self.get_sys_info(kind=kind) async def get(self, kind: str = None, *args, **kwargs): + # Handle path_param from URL pattern + if "path_param" in kwargs: + # If kind is not set, try to use path_param as kind + if kind is None: + kind = kwargs.pop("path_param", None) data = self.get_sys_info(kind=kind) return self.finish(data) + async def post(self, *args, **kwargs): + # For POST, use the response method + data = self.response(**kwargs) + return self.finish(data) + default_handlers = [ (r"/api/common/sys_info", SysStatusInfo), diff --git a/src/app_common/debug.py b/src/app_common/debug.py index 6425923..d4b3d81 100644 --- a/src/app_common/debug.py +++ b/src/app_common/debug.py @@ -27,7 +27,3 @@ def main(): # The event loop starts after start. app.start() - - -if __name__ == "__main__": - main() diff --git a/src/app_common/main.py b/src/app_common/main.py index 1943367..c4feeb7 100644 --- a/src/app_common/main.py +++ b/src/app_common/main.py @@ -1,4 +1,5 @@ def main(): from aloha.service.app import Application + app = Application() app.start() diff --git a/tool/dev-demo.Dockerfile b/src/dev-demo.Dockerfile similarity index 89% rename from tool/dev-demo.Dockerfile rename to src/dev-demo.Dockerfile index 80e3e3b..babf62f 100644 --- a/tool/dev-demo.Dockerfile +++ b/src/dev-demo.Dockerfile @@ -8,7 +8,7 @@ FROM ${BASE_NAMESPACE:+$BASE_NAMESPACE/}${BASE_IMG} AS dev ARG PROFILE_LOCALIZE -COPY app/requirements.txt /tmp/ +COPY src/requirements.txt /tmp/ USER root RUN set -eux && pwd && ls -alh \ @@ -17,7 +17,7 @@ RUN set -eux && pwd && ls -alh \ && npm install -g pnpm \ # ----------- handle backend matters ------------ && pip install -U --no-cache-dir pip jupyterlab \ - # && pip install -U --no-cache-dir -r /tmp/requirements.txt \ + && pip install -U --no-cache-dir -r /tmp/requirements.txt \ # ----------- install db client to connect db via terminal ------------ && source /opt/utils/script-setup-db-clients.sh && setup_postgresql_client 17 \ # ----------- clean up ----------- diff --git a/src/requirements.txt b/src/requirements.txt index f718c00..67b4f89 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -4,12 +4,11 @@ pyhocon pycryptodome packaging Cython -requests -tornado psutil pyjwt fastapi httpx +uvicorn sqlalchemy psycopg[binary] pymysql diff --git a/src/resource/config/deploy-DEV.conf b/src/resource/config/deploy-DEV.conf index 3a361ed..b9f7fd1 100644 --- a/src/resource/config/deploy-DEV.conf +++ b/src/resource/config/deploy-DEV.conf @@ -1,4 +1,6 @@ deploy = { + port_service = 9000 + postgres_db0 = { "host": "localhost", "port": 5432, diff --git a/src/tests/test_demo.py b/src/tests/test_demo.py new file mode 100644 index 0000000..cba3491 --- /dev/null +++ b/src/tests/test_demo.py @@ -0,0 +1,14 @@ +from aloha.testing.unit import UnitTestCase + + +# 1. A simple function-based test for pytest +def test_simple_addition(): + assert 1 + 1 == 2 + + +# 2. A class-based test inheriting from UnitTestCase to demonstrate integrating with the aloha package +class TestDemo(UnitTestCase): + def test_aloha_config_loaded(self): + # Verify that aloha settings config can be read + self.assertIsNotNone(self.config) + self.LOG.info("Aloha configuration verified successfully in test!") diff --git a/src/tests/test_time.py b/src/tests/test_time.py new file mode 100644 index 0000000..0d8581d --- /dev/null +++ b/src/tests/test_time.py @@ -0,0 +1,109 @@ +import pytest +import time +import asyncio +from aloha.util.time import run_with_timeout, run_async_with_timeout + +# Helpers +def sync_add(a, b, delay=0): + if delay > 0: + time.sleep(delay) + return a + b + +def sync_raise(): + raise ValueError("sync error") + +async def async_add(a, b, delay=0): + if delay > 0: + await asyncio.sleep(delay) + return a + b + +async def async_raise(): + raise ValueError("async error") + + +# 1. Sync Tests +def test_sync_success(): + success_called = False + result_val = None + + def on_success(res): + nonlocal success_called, result_val + success_called = True + result_val = res + + res = run_with_timeout(sync_add, 1.0, 2, 3, fn_callback_success=on_success) + assert res == 5 + assert success_called is True + assert result_val == 5 + +def test_sync_timeout(): + fail_called = False + error_val = None + + def on_fail(err): + nonlocal fail_called, error_val + fail_called = True + error_val = err + + with pytest.raises(TimeoutError): + run_with_timeout(sync_add, 0.1, 2, 3, delay=0.5, fn_callback_fail=on_fail) + assert fail_called is True + assert isinstance(error_val, TimeoutError) + +def test_sync_exception(): + fail_called = False + error_val = None + + def on_fail(err): + nonlocal fail_called, error_val + fail_called = True + error_val = err + + with pytest.raises(ValueError, match="sync error"): + run_with_timeout(sync_raise, 1.0, fn_callback_fail=on_fail) + assert fail_called is True + assert isinstance(error_val, ValueError) + + +# 2. Async Tests +def test_async_success(): + success_called = False + result_val = None + + def on_success(res): + nonlocal success_called, result_val + success_called = True + result_val = res + + res = asyncio.run(run_async_with_timeout(async_add, 1.0, 2, 3, fn_callback_success=on_success)) + assert res == 5 + assert success_called is True + assert result_val == 5 + +def test_async_timeout(): + fail_called = False + error_val = None + + def on_fail(err): + nonlocal fail_called, error_val + fail_called = True + error_val = err + + with pytest.raises(TimeoutError): + asyncio.run(run_async_with_timeout(async_add, 0.1, 2, 3, delay=0.5, fn_callback_fail=on_fail)) + assert fail_called is True + assert isinstance(error_val, TimeoutError) + +def test_async_exception(): + fail_called = False + error_val = None + + def on_fail(err): + nonlocal fail_called, error_val + fail_called = True + error_val = err + + with pytest.raises(ValueError, match="async error"): + asyncio.run(run_async_with_timeout(async_raise, 1.0, fn_callback_fail=on_fail)) + assert fail_called is True + assert isinstance(error_val, ValueError) diff --git a/tool/cicd/docker-compose.app-demo.DEV.yml b/tool/cicd/docker-compose.app-demo.DEV.yml index d0db854..29a3d23 100644 --- a/tool/cicd/docker-compose.app-demo.DEV.yml +++ b/tool/cicd/docker-compose.app-demo.DEV.yml @@ -6,7 +6,7 @@ services: hostname: ${CONTAINER_NAME:-dev-app-demo} build: context: ../../ - dockerfile: tool/dev-demo.Dockerfile + dockerfile: src/dev-demo.Dockerfile args: - ENABLE_CODE_BUILD=false # - PROFILE_LOCALIZE=${PROFILE_LOCALIZE:-default}