From 2856204bb4836b022e8b85192321ca5f2f2bdf0a Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Fri, 12 Jun 2026 17:54:20 +0800 Subject: [PATCH 1/3] fix(email): support RFC-compliant non-ascii filename encoding for attachments --- backend/app/api/agents.py | 530 ++++++++++-------- backend/app/services/agent_tools.py | 1 + backend/app/services/email_service.py | 41 +- backend/app/services/llm/caller.py | 2 +- backend/app/services/tool_seeder.py | 2 +- deploy/docker-compose.yml | 2 +- frontend/src/i18n/en.json | 10 +- frontend/src/i18n/zh.json | 10 +- .../pages/agent-detail/AgentDetailPage.tsx | 80 +-- 9 files changed, 399 insertions(+), 279 deletions(-) diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py index 169366574..82f2bec06 100644 --- a/backend/app/api/agents.py +++ b/backend/app/api/agents.py @@ -3,12 +3,11 @@ import hashlib import json import secrets -import time import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from pathlib import Path -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status from loguru import logger from sqlalchemy import cast, func, select, String from sqlalchemy.ext.asyncio import AsyncSession @@ -17,7 +16,7 @@ from app.config import get_settings from app.core.permissions import build_visible_agents_query, check_agent_access, is_agent_creator from app.core.security import get_current_user -from app.database import get_db +from app.database import async_session, get_db from app.models.agent import Agent, AgentPermission, AgentTemplate from app.models.org import OrgMember from app.models.audit import ChatMessage @@ -26,6 +25,13 @@ from app.schemas.schemas import AgentCreate, AgentOut, AgentUpdate from app.services.storage import get_storage_backend from app.services.access_relationships import ensure_access_granted_platform_relationships +from app.services.quota_guard import check_agent_creation_quota, QuotaExceeded +from app.models.tenant import Tenant +from app.models.participant import Participant +from app.services.okr_agent_hook import hook_new_agent +from app.services.agent_manager import agent_manager +from app.models.skill import Skill +from app.services.resource_discovery import import_mcp_from_smithery router = APIRouter(prefix="/agents", tags=["agents"]) settings = get_settings() @@ -66,7 +72,9 @@ async def _archive_agent_task_history(db: AsyncSession, agent_id: uuid.UUID, arc } for task in tasks: - log_result = await db.execute(select(TaskLog).where(TaskLog.task_id == task.id).order_by(TaskLog.created_at.asc())) + log_result = await db.execute( + select(TaskLog).where(TaskLog.task_id == task.id).order_by(TaskLog.created_at.asc()) + ) logs = log_result.scalars().all() payload["tasks"].append( { @@ -110,6 +118,7 @@ async def _lazy_reset_token_counters(agent: Agent, db: AsyncSession) -> bool: Returns True if any counter was reset (caller should commit/flush). """ from datetime import datetime, timezone as tz + now = datetime.now(tz.utc) changed = False @@ -156,7 +165,8 @@ async def _build_unread_count_by_agent( ChatSession.is_group.is_(False), ChatSession.source_channel.notin_(["agent", "trigger"]), ChatMessage.role.in_(["assistant", "system", "tool_call"]), - ChatMessage.created_at > func.coalesce( + ChatMessage.created_at + > func.coalesce( ChatSession.last_read_at_by_user, datetime(1970, 1, 1, tzinfo=timezone.utc), ), @@ -179,6 +189,7 @@ async def list_templates( ): """List all available agent templates.""" from app.models.agent import AgentTemplate + result = await db.execute( select(AgentTemplate).order_by(AgentTemplate.is_builtin.desc(), AgentTemplate.created_at.asc()) ) @@ -208,6 +219,7 @@ async def _agent_to_out( ) -> AgentOut: """Serialize one agent with ``onboarded_for_me`` for the given viewer.""" from app.services.onboarding import is_onboarded + model = AgentOut.model_validate(agent) model.onboarded_for_me = await is_onboarded(db, agent.id, viewer_id) return model @@ -220,6 +232,7 @@ async def _agents_to_out( ) -> list[AgentOut]: """List variant that fetches all junction rows in one query.""" from app.services.onboarding import onboarded_agent_ids + onboarded = await onboarded_agent_ids(db, viewer_id, [a.id for a in agents]) out: list[AgentOut] = [] for a in agents: @@ -260,6 +273,7 @@ async def list_agents( await db.commit() unread_by_agent = await _build_unread_count_by_agent(db, agents, current_user) from app.services.onboarding import onboarded_agent_ids + onboarded = await onboarded_agent_ids(db, current_user.id, [a.id for a in agents]) out: list[AgentOut] = [] for a in agents: @@ -269,22 +283,158 @@ async def list_agents( return out +async def _background_agent_setup( + agent_id: uuid.UUID, + personality: str, + boundaries: str, + skill_ids: list[uuid.UUID], + template_skill_folder_names: list[str], + template_mcp_servers: list[str], +) -> None: + """Run all creation tasks asynchronously with small, short-lived transactions.""" + # 1. Initialize agent file system from template + try: + async with async_session() as db: + agent_result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = agent_result.scalar_one_or_none() + if not agent: + logger.error(f"[background_agent_setup] Agent {agent_id} not found") + return + await agent_manager.initialize_agent_files( + db, + agent, + personality=personality, + boundaries=boundaries, + ) + await db.commit() + except Exception as e: + logger.exception(f"Error during agent file initialization for {agent_id}: {e}") + async with async_session() as db: + agent_result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = agent_result.scalar_one_or_none() + if agent: + agent.status = "error" + await db.commit() + return + + # 2. Skill resolution (reads from DB) + skill_files_to_write = [] + try: + async with async_session() as db: + default_result = await db.execute(select(Skill).where(Skill.is_default)) + default_ids = {s.id for s in default_result.scalars().all()} + + template_skill_ids = set() + if template_skill_folder_names: + tpl_skills_r = await db.execute(select(Skill).where(Skill.folder_name.in_(template_skill_folder_names))) + template_skill_ids = {s.id for s in tpl_skills_r.scalars().all()} + + all_skill_ids = set(skill_ids) | default_ids | template_skill_ids + + if all_skill_ids: + skills_result = await db.execute( + select(Skill).where(Skill.id.in_(all_skill_ids)).options(selectinload(Skill.files)) + ) + skills = skills_result.scalars().all() + agent_prefix = agent_manager._agent_storage_prefix(agent_id) + for skill in skills: + for sf in skill.files: + skill_files_to_write.append( + (f"{agent_prefix}/skills/{skill.folder_name}/{sf.path}", sf.content) + ) + except Exception as e: + logger.exception(f"Error resolving skills for agent {agent_id}: {e}") + async with async_session() as db: + agent_result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = agent_result.scalar_one_or_none() + if agent: + agent.status = "error" + await db.commit() + return + + # 3. Skills Copying (I/O only, NO db connection held!) + if skill_files_to_write: + try: + import asyncio + + storage = get_storage_backend() + await asyncio.gather( + *[storage.write_text(key, content, encoding="utf-8") for key, content in skill_files_to_write] + ) + logger.info(f"[_skills_copy] background agent={agent_id} files={len(skill_files_to_write)} completed") + except Exception as e: + logger.exception(f"Error copying skills files for agent {agent_id}: {e}") + async with async_session() as db: + agent_result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = agent_result.scalar_one_or_none() + if agent: + agent.status = "error" + await db.commit() + return + + # 4. Install template MCP servers + if template_mcp_servers: + for server_id in template_mcp_servers: + try: + result_msg = await import_mcp_from_smithery( + server_id=server_id, + agent_id=agent_id, + config={}, + ) + if result_msg.startswith("❌"): + logger.warning( + f"[create_agent] background MCP pre-install for '{server_id}' " + f"on agent {agent_id} reported error: {result_msg[:200]}" + ) + else: + logger.info( + f"[create_agent] background MCP pre-install '{server_id}' succeeded for agent {agent_id}" + ) + except Exception as e: + logger.warning( + f"[create_agent] background MCP pre-install for '{server_id}' on agent {agent_id} raised: {e}" + ) + + # 5. Start container and Hook OKR Agent + try: + async with async_session() as db: + agent_result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = agent_result.scalar_one_or_none() + if not agent: + logger.error(f"[background_agent_setup] Agent {agent_id} not found before starting container") + return + + await agent_manager.start_container(db, agent) + + if agent.tenant_id: + await hook_new_agent(db, agent.id, agent.tenant_id) + + await db.commit() + except Exception as e: + logger.exception(f"Error starting container for agent {agent_id}: {e}") + async with async_session() as db: + agent_result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = agent_result.scalar_one_or_none() + if agent: + agent.status = "error" + await db.commit() + + @router.post("/", status_code=status.HTTP_201_CREATED) async def create_agent( data: AgentCreate, + background_tasks: BackgroundTasks, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): """Create a new digital employee (any authenticated user).""" # Check agent creation quota - from app.services.quota_guard import check_agent_creation_quota, QuotaExceeded try: await check_agent_creation_quota(current_user.id) except QuotaExceeded as e: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.message) # A TTL of 0 or less means the agent never expires. - from datetime import datetime, timedelta, timezone as tz ttl_hours = current_user.quota_agent_ttl_hours # Determine target tenant: normally user's tenant; admins can override via payload @@ -300,7 +450,6 @@ async def create_agent( default_heartbeat_interval = 240 # model default tenant_default_model_id = None if target_tenant_id: - from app.models.tenant import Tenant tenant_result = await db.execute(select(Tenant).where(Tenant.id == target_tenant_id)) tenant = tenant_result.scalar_one_or_none() if tenant: @@ -311,12 +460,15 @@ async def create_agent( default_webhook_rate = tenant.max_webhook_rate_ceiling or 5 tenant_default_model_id = tenant.default_model_id # Enforce heartbeat floor: new agents must respect company minimum - if tenant.min_heartbeat_interval_minutes and tenant.min_heartbeat_interval_minutes > default_heartbeat_interval: + if ( + tenant.min_heartbeat_interval_minutes + and tenant.min_heartbeat_interval_minutes > default_heartbeat_interval + ): default_heartbeat_interval = tenant.min_heartbeat_interval_minutes # If the caller didn't pick a model, fall back to the tenant's default. effective_primary_model_id = data.primary_model_id or tenant_default_model_id - expires_at = datetime.now(tz.utc) + timedelta(hours=ttl_hours) if ttl_hours and ttl_hours > 0 else None + expires_at = datetime.now(timezone.utc) + timedelta(hours=ttl_hours) if ttl_hours and ttl_hours > 0 else None agent = Agent( name=data.name, @@ -346,11 +498,14 @@ async def create_agent( await db.flush() # Auto-create Participant identity for the new agent - from app.models.participant import Participant - db.add(Participant( - type="agent", ref_id=agent.id, - display_name=agent.name, avatar_url=agent.avatar_url, - )) + db.add( + Participant( + type="agent", + ref_id=agent.id, + display_name=agent.name, + avatar_url=agent.avatar_url, + ) + ) await db.flush() # Set permissions @@ -366,10 +521,14 @@ async def create_agent( agent.company_access_level = access_level if data.permission_scope_ids: for scope_id in data.permission_scope_ids: - db.add(AgentPermission(agent_id=agent.id, scope_type="user", scope_id=scope_id, access_level=access_level)) + db.add( + AgentPermission(agent_id=agent.id, scope_type="user", scope_id=scope_id, access_level=access_level) + ) else: # "仅自己" — insert creator as the only permitted user - db.add(AgentPermission(agent_id=agent.id, scope_type="user", scope_id=current_user.id, access_level="manage")) + db.add( + AgentPermission(agent_id=agent.id, scope_type="user", scope_id=current_user.id, access_level="manage") + ) elif data.permission_scope_type == "custom": agent.access_mode = "custom" agent.company_access_level = access_level @@ -385,7 +544,6 @@ async def create_agent( agent.status = "idle" await db.commit() - from app.services.okr_agent_hook import hook_new_agent if agent.tenant_id: await hook_new_agent(db, agent.id, agent.tenant_id) await db.commit() @@ -395,169 +553,34 @@ async def create_agent( out["api_key"] = raw_key # Return once on creation return out - # Initialize agent file system from template - from app.services.agent_manager import agent_manager - await agent_manager.initialize_agent_files( - db, agent, - personality=data.personality, - boundaries=data.boundaries, - ) - from app.api.relationships import _regenerate_relationships_file - await _regenerate_relationships_file(db, agent.id) - - # Copy selected skills + mandatory default skills into agent workspace - from app.models.skill import Skill - from sqlalchemy.orm import selectinload - - # Always include global default skills (mcp-installer, skill-creator, - # complex-task-executor) - t_skills_copy_start = time.perf_counter() - t_default_query_start = time.perf_counter() - default_result = await db.execute(select(Skill).where(Skill.is_default)) - default_ids = {s.id for s in default_result.scalars().all()} - t_default_query = time.perf_counter() - t_default_query_start - - # Include the template's declared default skills (e.g. trading templates - # ship with `market-data` / `financial-calendar` in their meta.yaml). - # Without this, the SKILL.md never reaches `/skills//`, - # so the agent has no idea those MCP-backed skills exist and silently - # falls back to web search. - template_skill_ids: set = set() - t_template_query = 0.0 + # Resolve template settings + folder_names = [] + template_mcp_servers = [] if data.template_id: - t_template_query_start = time.perf_counter() - tpl_r = await db.execute( - select(AgentTemplate).where(AgentTemplate.id == data.template_id) - ) + tpl_r = await db.execute(select(AgentTemplate).where(AgentTemplate.id == data.template_id)) tpl = tpl_r.scalar_one_or_none() - folder_names = list((tpl.default_skills if tpl else None) or []) - if folder_names: - tpl_skills_r = await db.execute( - select(Skill).where(Skill.folder_name.in_(folder_names)) - ) - template_skill_ids = {s.id for s in tpl_skills_r.scalars().all()} - t_template_query = time.perf_counter() - t_template_query_start - - # Merge user-selected + global default + template-default skill IDs - all_skill_ids = set(data.skill_ids or []) | default_ids | template_skill_ids - - if all_skill_ids: - import asyncio - storage = get_storage_backend() - agent_prefix = agent_manager._agent_storage_prefix(agent.id) - - t_skill_fetch_start = time.perf_counter() - skills_result = await db.execute( - select(Skill).where(Skill.id.in_(all_skill_ids)).options(selectinload(Skill.files)) - ) - skills = skills_result.scalars().all() - t_skill_fetch = time.perf_counter() - t_skill_fetch_start - - file_specs = [ - (f"{agent_prefix}/skills/{skill.folder_name}/{sf.path}", sf.content) - for skill in skills - for sf in skill.files - ] - - if file_specs: - t_upload_start = time.perf_counter() - await asyncio.gather(*[ - storage.write_text(key, content, encoding="utf-8") - for key, content in file_specs - ]) - logger.info( - f"[_skills_copy] agent={agent.id} skills={len(skills)} files={len(file_specs)} " - f"fetch={t_skill_fetch:.2f}s upload={time.perf_counter() - t_upload_start:.2f}s " - f"total={time.perf_counter() - t_skills_copy_start:.2f}s" - ) - else: - logger.info( - f"[_skills_copy] agent={agent.id} no files " - f"fetch={t_skill_fetch:.2f}s total={time.perf_counter() - t_skills_copy_start:.2f}s" - ) - - # Auto-install template-declared MCP servers using the system Smithery key. - # For trading agents, this means shibui/finance lands in the agent's tool - # list at creation time rather than relying on the agent to install it on - # first use via the MCP_INSTALLER skill (which depends on LLM compliance). - # Failures are logged and swallowed — agent creation must not fail because - # an external Smithery call did. - template_mcp_servers = list((tpl.default_mcp_servers if data.template_id and tpl else None) or []) - if template_mcp_servers: - # Commit the in-flight transaction first so the agent row exists in - # the database when import_mcp_from_smithery opens its own session - # to insert AgentTool rows. Without this commit the FK to agents.id - # is invisible to the parallel session and we get a FK violation. - await db.commit() - await db.refresh(agent) - - from app.services.resource_discovery import import_mcp_from_smithery - for server_id in template_mcp_servers: - try: - result_msg = await import_mcp_from_smithery( - server_id=server_id, - agent_id=agent.id, - config={}, # falls back to system Smithery key - ) - if result_msg.startswith("❌"): - logger.warning( - f"[create_agent] MCP pre-install for '{server_id}' " - f"on agent {agent.id} reported error: {result_msg[:200]}" - ) - else: - logger.info( - f"[create_agent] MCP pre-install '{server_id}' " - f"succeeded for agent {agent.id}" - ) - except Exception as e: - logger.warning( - f"[create_agent] MCP pre-install for '{server_id}' " - f"on agent {agent.id} raised: {e}" - ) + if tpl: + folder_names = list(tpl.default_skills or []) + template_mcp_servers = list(tpl.default_mcp_servers or []) - # Start container first (non-blocking if Docker available) - await agent_manager.start_container(db, agent) - await db.flush() + # Prepare return response before transaction is committed + out = await _agent_to_out(db, agent, current_user.id) - # Commit agent and basic setup before async operations - from app.services.okr_agent_hook import hook_new_agent - if agent.tenant_id: - await hook_new_agent(db, agent.id, agent.tenant_id) + # Commit initial state to DB so background task can read the agent row await db.commit() - await db.refresh(agent) - - # MCP import runs in background to avoid blocking the response - if template_mcp_servers: - import asyncio - from app.services.resource_discovery import import_mcp_from_smithery - - async def _background_mcp_import(agent_id: uuid.UUID, server_ids: list[str]): - for server_id in server_ids: - try: - result_msg = await import_mcp_from_smithery( - server_id=server_id, - agent_id=agent_id, - config={}, - ) - if result_msg.startswith("❌"): - logger.warning( - f"[create_agent] MCP pre-install for '{server_id}' " - f"on agent {agent_id} reported error: {result_msg[:200]}" - ) - else: - logger.info( - f"[create_agent] MCP pre-install '{server_id}' " - f"succeeded for agent {agent_id}" - ) - except Exception as e: - logger.warning( - f"[create_agent] MCP pre-install for '{server_id}' " - f"on agent {agent_id} raised: {e}" - ) - asyncio.create_task(_background_mcp_import(agent.id, template_mcp_servers)) + # Dispatch heavy setup to background task + background_tasks.add_task( + _background_agent_setup, + agent_id=agent.id, + personality=data.personality or "", + boundaries=data.boundaries or "", + skill_ids=list(data.skill_ids or []), + template_skill_folder_names=folder_names, + template_mcp_servers=template_mcp_servers, + ) - return await _agent_to_out(db, agent, current_user.id) + return out @router.get("/{agent_id}") @@ -582,10 +605,9 @@ async def get_agent( if agent.creator_id: from sqlalchemy.orm import selectinload from app.models.user import Identity # noqa: F401 + creator_result = await db.execute( - select(User) - .where(User.id == agent.creator_id) - .options(selectinload(User.identity)) + select(User).where(User.id == agent.creator_id).options(selectinload(User.identity)) ) creator = creator_result.scalar_one_or_none() out["creator_username"] = creator.username if creator else None @@ -594,6 +616,7 @@ async def get_agent( effective_tz = agent.timezone if not effective_tz and agent.tenant_id: from app.models.tenant import Tenant + t_result = await db.execute(select(Tenant).where(Tenant.id == agent.tenant_id)) tenant = t_result.scalar_one_or_none() if tenant: @@ -654,7 +677,13 @@ async def get_agent_permissions( if perm.scope_type == "user" and perm.scope_id } ordered_user_ids = [str(uid) for uid in display_user_ids] - ordered_user_ids.sort(key=lambda sid: (users_by_id.get(sid).display_name or users_by_id.get(sid).username or "") if users_by_id.get(sid) else "") + ordered_user_ids.sort( + key=lambda sid: ( + (users_by_id.get(sid).display_name or users_by_id.get(sid).username or "") + if users_by_id.get(sid) + else "" + ) + ) for perm in perms: if perm.scope_type != "user" or not perm.scope_id: continue @@ -720,6 +749,7 @@ async def update_agent_permissions( # Delete existing permissions from sqlalchemy import delete as sql_delete + await db.execute(sql_delete(AgentPermission).where(AgentPermission.agent_id == agent_id)) # Insert new permissions @@ -732,7 +762,14 @@ async def update_agent_permissions( agent.company_access_level = access_level # "Only me" means private to the agent creator, even when an org admin # is managing a company-visible agent created by someone else. - db.add(AgentPermission(agent_id=agent_id, scope_type="user", scope_id=agent.creator_id or current_user.id, access_level="manage")) + db.add( + AgentPermission( + agent_id=agent_id, + scope_type="user", + scope_id=agent.creator_id or current_user.id, + access_level="manage", + ) + ) elif scope_type == "custom": agent.access_mode = "custom" agent.company_access_level = access_level @@ -758,12 +795,14 @@ async def update_agent_permissions( uid = uuid.UUID(str(sid)) if uid not in seen_user_ids: seen_user_ids.add(uid) - db.add(AgentPermission( - agent_id=agent_id, - scope_type="user", - scope_id=uid, - access_level="manage" if uid in required_manager_ids else access_level, - )) + db.add( + AgentPermission( + agent_id=agent_id, + scope_type="user", + scope_id=uid, + access_level="manage" if uid in required_manager_ids else access_level, + ) + ) for uid in required_manager_ids: if uid not in seen_user_ids: db.add(AgentPermission(agent_id=agent_id, scope_type="user", scope_id=uid, access_level="manage")) @@ -776,6 +815,7 @@ async def update_agent_permissions( ) if relationships_changed: from app.api.relationships import _regenerate_relationships_file + await _regenerate_relationships_file(db, agent_id) await db.commit() @@ -806,10 +846,10 @@ async def get_agent_permission_candidates( if search: pattern = f"%{search}%" member_query = member_query.where( - OrgMember.name.ilike(pattern) | - OrgMember.email.ilike(pattern) | - OrgMember.name_translit_full.ilike(pattern) | - OrgMember.name_translit_initial.ilike(pattern) + OrgMember.name.ilike(pattern) + | OrgMember.email.ilike(pattern) + | OrgMember.name_translit_full.ilike(pattern) + | OrgMember.name_translit_initial.ilike(pattern) ) members_result = await db.execute(member_query.order_by(OrgMember.name.asc()).limit(50)) @@ -836,9 +876,7 @@ async def get_agent_permission_candidates( # No platform account yet — find-or-create one from OrgMember info # and link it back so future lookups hit Case 1. try: - u = await get_platform_user_by_org_member( - db, m, agent_tenant_id=agent.tenant_id - ) + u = await get_platform_user_by_org_member(db, m, agent_tenant_id=agent.tenant_id) except Exception: # If user creation fails for any reason, skip this member continue @@ -846,14 +884,16 @@ async def get_agent_permission_candidates( if u is None: continue - candidates.append({ - "id": str(u.id), # always a valid User.id - "name": m.name, - "username": u.username if u else None, - "email": m.email or (u.email if u else None), - "title": m.title or None, - "avatar_url": m.avatar_url or None, - }) + candidates.append( + { + "id": str(u.id), # always a valid User.id + "name": m.name, + "username": u.username if u else None, + "email": m.email or (u.email if u else None), + "title": m.title or None, + "avatar_url": m.avatar_url or None, + } + ) await db.commit() @@ -876,7 +916,9 @@ async def update_agent( is_admin = current_user.role in ("platform_admin", "org_admin") if not is_agent_creator(current_user, agent) and not is_admin: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only creator or admin can update agent settings") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Only creator or admin can update agent settings" + ) update_data = data.model_dump(exclude_unset=True) @@ -885,6 +927,7 @@ async def update_agent( if not is_admin: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only admin can modify agent expiry time") from datetime import datetime, timezone as tz + new_expires = update_data["expires_at"] # Allow any value: extend, shorten, or null (permanent). # Re-activate the agent if new expiry is in the future or cleared. @@ -897,21 +940,25 @@ async def update_agent( clamped_fields = [] # track fields adjusted by tenant floor if "heartbeat_interval_minutes" in update_data and current_user.tenant_id: from app.models.tenant import Tenant + t_result = await db.execute(select(Tenant).where(Tenant.id == current_user.tenant_id)) tenant = t_result.scalar_one_or_none() if tenant and update_data["heartbeat_interval_minutes"] < tenant.min_heartbeat_interval_minutes: update_data["heartbeat_interval_minutes"] = tenant.min_heartbeat_interval_minutes - clamped_fields.append({ - "field": "heartbeat_interval_minutes", - "requested": update_data["heartbeat_interval_minutes"], - "applied": tenant.min_heartbeat_interval_minutes, - "reason": "company_floor", - }) + clamped_fields.append( + { + "field": "heartbeat_interval_minutes", + "requested": update_data["heartbeat_interval_minutes"], + "applied": tenant.min_heartbeat_interval_minutes, + "reason": "company_floor", + } + ) # Enforce trigger limit floors from tenant trigger_fields = {"min_poll_interval_min", "webhook_rate_limit", "max_triggers"} if trigger_fields & set(update_data.keys()) and current_user.tenant_id: from app.models.tenant import Tenant + t_result = await db.execute(select(Tenant).where(Tenant.id == current_user.tenant_id)) tenant = t_result.scalar_one_or_none() if tenant: @@ -919,22 +966,26 @@ async def update_agent( original = update_data["min_poll_interval_min"] update_data["min_poll_interval_min"] = max(original, tenant.min_poll_interval_floor) if update_data["min_poll_interval_min"] != original: - clamped_fields.append({ - "field": "min_poll_interval_min", - "requested": original, - "applied": update_data["min_poll_interval_min"], - "reason": "company_floor", - }) + clamped_fields.append( + { + "field": "min_poll_interval_min", + "requested": original, + "applied": update_data["min_poll_interval_min"], + "reason": "company_floor", + } + ) if "webhook_rate_limit" in update_data: original = update_data["webhook_rate_limit"] update_data["webhook_rate_limit"] = min(original, tenant.max_webhook_rate_ceiling) if update_data["webhook_rate_limit"] != original: - clamped_fields.append({ - "field": "webhook_rate_limit", - "requested": original, - "applied": update_data["webhook_rate_limit"], - "reason": "company_ceiling", - }) + clamped_fields.append( + { + "field": "webhook_rate_limit", + "requested": original, + "applied": update_data["webhook_rate_limit"], + "reason": "company_ceiling", + } + ) for field, value in update_data.items(): setattr(agent, field, value) @@ -943,6 +994,7 @@ async def update_agent( # Sync Participant display_name / avatar if changed if "name" in update_data or "avatar_url" in update_data: from app.models.participant import Participant + p_r = await db.execute(select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id)) p = p_r.scalar_one_or_none() if p: @@ -967,7 +1019,11 @@ async def delete_agent( ): """Delete a digital employee (creator only).""" agent, _access = await check_agent_access(db, current_user, agent_id) - if not is_agent_creator(current_user, agent) and current_user.role not in ("super_admin", "org_admin", "platform_admin"): + if not is_agent_creator(current_user, agent) and current_user.role not in ( + "super_admin", + "org_admin", + "platform_admin", + ): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only creator or admin can delete agent") # System agents (OKR Agent, etc.) cannot be deleted — they are seeded by the @@ -980,6 +1036,7 @@ async def delete_agent( # Stop container and archive files (best effort) from app.services.agent_manager import agent_manager + archive_dir: Path | None = None try: await agent_manager.remove_container(agent) @@ -1082,6 +1139,7 @@ async def start_agent( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only manager can start agent") from app.services.agent_manager import agent_manager + await agent_manager.start_container(db, agent) await db.flush() return await _agent_to_out(db, agent, current_user.id) @@ -1099,6 +1157,7 @@ async def stop_agent( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only manager can stop agent") from app.services.agent_manager import agent_manager + await agent_manager.stop_container(agent) await db.flush() return await _agent_to_out(db, agent, current_user.id) @@ -1117,9 +1176,12 @@ async def list_agent_approvals( """List approval requests for a specific agent. Only creator or admin can view.""" agent, _access = await check_agent_access(db, current_user, agent_id) if not is_agent_creator(current_user, agent) and current_user.role not in ("platform_admin", "org_admin"): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only agent creator or admin can view approvals") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Only agent creator or admin can view approvals" + ) from app.models.audit import ApprovalRequest + query = select(ApprovalRequest).where(ApprovalRequest.agent_id == agent_id) if status_filter: query = query.where(ApprovalRequest.status == status_filter) @@ -1154,6 +1216,7 @@ async def resolve_agent_approval( agent, _access = await check_agent_access(db, current_user, agent_id) from app.services.autonomy_service import autonomy_service + action = data.get("action", "reject") try: approval = await autonomy_service.resolve_approval(db, approval_id, current_user, action) @@ -1201,6 +1264,7 @@ async def list_gateway_messages( agent, _access = await check_agent_access(db, current_user, agent_id) from app.models.gateway_message import GatewayMessage + result = await db.execute( select(GatewayMessage) .where(GatewayMessage.agent_id == agent_id) @@ -1215,14 +1279,16 @@ async def list_gateway_messages( if m.sender_agent_id: r = await db.execute(select(Agent.name).where(Agent.id == m.sender_agent_id)) sender_name = r.scalar_one_or_none() - out.append({ - "id": str(m.id), - "sender_agent_name": sender_name, - "content": m.content, - "status": m.status, - "result": m.result, - "created_at": m.created_at.isoformat() if m.created_at else None, - "delivered_at": m.delivered_at.isoformat() if m.delivered_at else None, - "completed_at": m.completed_at.isoformat() if m.completed_at else None, - }) + out.append( + { + "id": str(m.id), + "sender_agent_name": sender_name, + "content": m.content, + "status": m.status, + "result": m.result, + "created_at": m.created_at.isoformat() if m.created_at else None, + "delivered_at": m.delivered_at.isoformat() if m.delivered_at else None, + "completed_at": m.completed_at.isoformat() if m.completed_at else None, + } + ) return out diff --git a/backend/app/services/agent_tools.py b/backend/app/services/agent_tools.py index c9c7c89b4..e0058e8b7 100644 --- a/backend/app/services/agent_tools.py +++ b/backend/app/services/agent_tools.py @@ -11766,6 +11766,7 @@ async def _handle_email_tool(tool_name: str, agent_id: uuid.UUID, ws: Path, argu cc=arguments.get("cc"), attachments=arguments.get("attachments"), workspace_path=ws, + agent_id=agent_id, ) elif tool_name == "read_emails": return await read_emails( diff --git a/backend/app/services/email_service.py b/backend/app/services/email_service.py index 8401b149a..0f1bc4cd3 100644 --- a/backend/app/services/email_service.py +++ b/backend/app/services/email_service.py @@ -9,6 +9,8 @@ import smtplib import ssl import email as email_lib +import uuid +import re from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart from email.mime.base import MIMEBase @@ -159,6 +161,7 @@ async def send_email( cc: Optional[str] = None, attachments: Optional[list[str]] = None, workspace_path: Optional[Path] = None, + agent_id: Optional[uuid.UUID] = None, ) -> str: """Send an email via SMTP. @@ -170,6 +173,7 @@ async def send_email( cc: CC recipients, comma-separated attachments: List of workspace-relative file paths to attach workspace_path: Agent workspace root for resolving attachment paths + agent_id: Optional UUID of the agent for retrieving files from storage """ cfg = resolve_config(config) addr = cfg["email_address"] @@ -191,14 +195,39 @@ async def send_email( # Attach files if attachments and workspace_path: + from app.services.storage import get_storage_backend, normalize_storage_key + storage = get_storage_backend() + for rel_path in attachments: - full_path = workspace_path / rel_path - if full_path.exists() and full_path.is_file(): - with open(full_path, "rb") as f: - part = MIMEBase("application", "octet-stream") - part.set_payload(f.read()) + clean_rel = rel_path.replace("\\", "/").strip().lstrip("/") + prefix = str(agent_id) if agent_id else workspace_path.name + storage_key = normalize_storage_key(f"{prefix}/{clean_rel}") + file_bytes = None + filename = Path(clean_rel).name + + # 1. Try to read from the storage backend (e.g. S3 or local storage) + try: + if await storage.exists(storage_key) and await storage.is_file(storage_key): + file_bytes = await storage.read_bytes(storage_key) + except Exception: + pass + + # 2. Fall back to local disk if not found in storage backend + if file_bytes is None: + full_path = workspace_path / rel_path + if full_path.exists() and full_path.is_file(): + try: + with open(full_path, "rb") as f: + file_bytes = f.read() + filename = full_path.name + except Exception: + pass + + if file_bytes is not None: + part = MIMEBase("application", "octet-stream") + part.set_payload(file_bytes) encoders.encode_base64(part) - part.add_header("Content-Disposition", f"attachment; filename={full_path.name}") + part.add_header("Content-Disposition", "attachment", filename=filename) msg.attach(part) try: diff --git a/backend/app/services/llm/caller.py b/backend/app/services/llm/caller.py index 867b153e5..2dc66faa8 100644 --- a/backend/app/services/llm/caller.py +++ b/backend/app/services/llm/caller.py @@ -309,7 +309,7 @@ async def _process_tool_call( fn = tc["function"] tool_name = fn["name"] raw_args = fn.get("arguments", "{}") - logger.info(f"[LLM] Calling tool: {tool_name}({json.dumps(raw_args, ensure_ascii=False)[:100]})") + logger.info(f"[LLM] Calling tool: {tool_name}({json.dumps(raw_args, ensure_ascii=False)})") try: args = json.loads(raw_args) if raw_args else {} diff --git a/backend/app/services/tool_seeder.py b/backend/app/services/tool_seeder.py index c1ddbd92d..ab90eb507 100644 --- a/backend/app/services/tool_seeder.py +++ b/backend/app/services/tool_seeder.py @@ -1385,7 +1385,7 @@ def _global_builtin_config(tool_data: dict) -> dict: "attachments": { "type": "array", "items": {"type": "string"}, - "description": "List of workspace-relative file paths to attach (optional)", + "description": "List of workspace-relative file paths to attach (optional). E.g. ['workspace/filename.ext']. Always specify this parameter if the user uploads a file or mentions sending/attaching a file.", }, }, "required": ["to", "subject", "body"], diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 613a253a1..b2a04275c 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -90,7 +90,7 @@ services: API_UPSTREAM: ${API_UPSTREAM:-backend:8000} MINIO_UPSTREAM: ${MINIO_UPSTREAM:-minio:9000} volumes: - - ./deploy/nginx/nginx.conf:/etc/nginx/templates/default.conf.template:ro + - ./nginx/nginx.conf:/etc/nginx/templates/default.conf.template:ro networks: - default depends_on: diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json index 9c93bc6e5..f7cefc213 100644 --- a/frontend/src/i18n/en.json +++ b/frontend/src/i18n/en.json @@ -327,7 +327,15 @@ "showMore": "Show {{count}} more...", "showLess": "Show less", "hideCompleted": "Hide completed", - "showCompleted": "Show {{count}} completed" + "showCompleted": "Show {{count}} completed", + "triggerOnce": "Once at {{time}}", + "triggerEveryMin": "Every {{min}} min", + "triggerEveryHour": "Every {{hour}}h", + "triggerPoll": "Poll: {{url}}", + "triggerOnMessage": "On message from {{sender}}", + "triggerUnknown": "unknown", + "today": "Today", + "viewAll": "View All" }, "soul": { "title": "Soul.md — Personality Definition", diff --git a/frontend/src/i18n/zh.json b/frontend/src/i18n/zh.json index 1c684d42d..964a97ee2 100644 --- a/frontend/src/i18n/zh.json +++ b/frontend/src/i18n/zh.json @@ -334,7 +334,15 @@ "showMore": "显示更多 {{count}} 项...", "showLess": "收起", "hideCompleted": "隐藏已完成", - "showCompleted": "显示 {{count}} 项已完成" + "showCompleted": "显示 {{count}} 项已完成", + "triggerOnce": "一次性:{{time}}", + "triggerEveryMin": "每 {{min}} 分钟", + "triggerEveryHour": "每 {{hour}} 小时", + "triggerPoll": "轮询:{{url}}", + "triggerOnMessage": "收到 {{sender}} 的消息时", + "triggerUnknown": "未知对象", + "today": "今天", + "viewAll": "查看全部" }, "soul": { "title": "Soul.md — 人格定义", diff --git a/frontend/src/pages/agent-detail/AgentDetailPage.tsx b/frontend/src/pages/agent-detail/AgentDetailPage.tsx index 3a2f2191e..5c08f8266 100644 --- a/frontend/src/pages/agent-detail/AgentDetailPage.tsx +++ b/frontend/src/pages/agent-detail/AgentDetailPage.tsx @@ -1964,7 +1964,7 @@ function RelationshipEditor({ agentId, readOnly = false }: { agentId: string; re export default function AgentDetailPage() { const { t, i18n } = useTranslation(); - const messageTimestampLocale = i18n.language?.startsWith('zh') ? 'zh-CN' : 'en-US'; + const tsLocale = i18n.language?.startsWith('zh') ? 'zh-CN' : 'en-US'; const dialog = useDialog(); const toast = useToast(); const { id } = useParams<{ id: string }>(); @@ -3549,9 +3549,9 @@ export default function AgentDetailPage() { const diffMs = now.getTime() - d.getTime(); const isToday = d.toDateString() === now.toDateString(); let timeStr = ''; - if (isToday) timeStr = d.toLocaleTimeString(messageTimestampLocale, { hour: '2-digit', minute: '2-digit' }); - else if (diffMs < 7 * 86400000) timeStr = d.toLocaleDateString(messageTimestampLocale, { weekday: 'short' }) + ' ' + d.toLocaleTimeString(messageTimestampLocale, { hour: '2-digit', minute: '2-digit' }); - else timeStr = d.toLocaleDateString(messageTimestampLocale, { month: 'short', day: 'numeric' }) + ' ' + d.toLocaleTimeString(messageTimestampLocale, { hour: '2-digit', minute: '2-digit' }); + if (isToday) timeStr = d.toLocaleTimeString(tsLocale, { hour: '2-digit', minute: '2-digit' }); + else if (diffMs < 7 * 86400000) timeStr = d.toLocaleDateString(tsLocale, { weekday: 'short' }) + ' ' + d.toLocaleTimeString(tsLocale, { hour: '2-digit', minute: '2-digit' }); + else timeStr = d.toLocaleDateString(tsLocale, { month: 'short', day: 'numeric' }) + ' ' + d.toLocaleTimeString(tsLocale, { hour: '2-digit', minute: '2-digit' }); return (
{timeStr} @@ -3730,14 +3730,23 @@ export default function AgentDetailPage() { attachedFiles.forEach(file => { filesDisplay += `[Attachment: ${file.name}] `; + const wsPath = file.path || ''; + const codePath = wsPath.replace(/^workspace\//, ''); + const fileLoc = wsPath ? `\nFile location: ${wsPath} (for read_file/read_document/send_email tools)\nIn execute_code, use relative path: "${codePath}" (working directory is workspace/)\n` : ''; + if (file.imageUrl && supportsVision) { filesPrompt += `[image_data:${file.imageUrl}]\n`; + if (fileLoc) { + filesPrompt += `[Image File Path Reference]${fileLoc}\n`; + } } else if (file.imageUrl) { - filesPrompt += t('common.file.imageUploaded', '[图片文件已上传: {{name}}...]', { name: file.name }) + '\n'; + filesPrompt += t('common.file.imageUploaded', '[图片文件已上传: {{name}}...]', { name: file.name }); + if (fileLoc) { + filesPrompt += `${fileLoc}\n`; + } else { + filesPrompt += '\n'; + } } else { - const wsPath = file.path || ''; - const codePath = wsPath.replace(/^workspace\//, ''); - const fileLoc = wsPath ? `\nFile location: ${wsPath} (for read_file/read_document tools)\nIn execute_code, use relative path: "${codePath}" (working directory is workspace/)\n` : ''; if (file.source === 'workspace_auto') { filesPrompt += `[Workspace reference: ${file.name}]${fileLoc}\nUse read_file or read_document if you need the file contents.\n\n`; } else { @@ -4162,7 +4171,7 @@ export default function AgentDetailPage() { const canManage = (agent as any).access_level === 'manage'; const formatAgentDate = (d?: string | null) => { if (!d) return '—'; - try { return new Date(d).toLocaleDateString(undefined, { year: 'numeric', month: 'short', day: 'numeric' }); } catch { return d; } + try { return new Date(d).toLocaleDateString(tsLocale, { year: 'numeric', month: 'short', day: 'numeric' }); } catch { return d; } }; const primaryModel = llmModels.find((m: any) => m.id === agent.primary_model_id); const showNoModelState = !llmModelsLoading && (agent as any).agent_type !== 'openclaw' && (enabledModelCount === 0 || !effectiveModelReady); @@ -4198,7 +4207,7 @@ export default function AgentDetailPage() { const expiryLabel = (agent as any).is_expired ? t('agent.settings.expiry.expired') : (agent as any).expires_at - ? new Date((agent as any).expires_at).toLocaleDateString(undefined, { year: 'numeric', month: 'short', day: 'numeric' }) + ? new Date((agent as any).expires_at).toLocaleDateString(tsLocale, { year: 'numeric', month: 'short', day: 'numeric' }) : t('agent.settings.expiry.neverExpires'); const renderAgentInfoCard = () => (
@@ -4328,8 +4337,8 @@ export default function AgentDetailPage() { const isZh = i18n.language?.startsWith('zh'); const formatTrigger = (trig: any) => { if (trig.type === 'cron' && trig.config?.expr) return `Cron ${trig.config.expr}`; - if (trig.type === 'interval' && trig.config?.minutes) return isZh ? `每 ${trig.config.minutes} 分钟` : `Every ${trig.config.minutes} min`; - if (trig.type === 'once' && trig.config?.at) return new Date(trig.config.at).toLocaleString(); + if (trig.type === 'interval' && trig.config?.minutes) return t('agent.aware.triggerEveryMin', { min: trig.config.minutes }); + if (trig.type === 'once' && trig.config?.at) return new Date(trig.config.at).toLocaleString(tsLocale, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' }); return trig.name || trig.type; }; const triggerTitle = (trig: any) => String(trig.reason || trig.name || trig.type || '').trim(); @@ -4480,14 +4489,14 @@ export default function AgentDetailPage() { })(); const calendarRangeLabel = (() => { if (awareCalendarMode === 'day') { - return calendarAnchor.toLocaleDateString(undefined, { year: 'numeric', month: 'short', day: 'numeric', weekday: 'short' }); + return calendarAnchor.toLocaleDateString(tsLocale, { year: 'numeric', month: 'short', day: 'numeric', weekday: 'short' }); } if (awareCalendarMode === 'month') { - return calendarAnchor.toLocaleDateString(undefined, { year: 'numeric', month: 'long' }); + return calendarAnchor.toLocaleDateString(tsLocale, { year: 'numeric', month: 'long' }); } const first = calendarDays[0]; const last = calendarDays[calendarDays.length - 1]; - return `${first.toLocaleDateString(undefined, { month: 'short', day: 'numeric' })} - ${last.toLocaleDateString(undefined, { month: 'short', day: 'numeric', year: 'numeric' })}`; + return `${first.toLocaleDateString(tsLocale, { month: 'short', day: 'numeric' })} - ${last.toLocaleDateString(tsLocale, { month: 'short', day: 'numeric', year: 'numeric' })}`; })(); const shiftCalendar = (direction: -1 | 1) => { setAwareCalendarDate(prev => { @@ -4538,8 +4547,8 @@ export default function AgentDetailPage() { return (
- {day.toLocaleDateString(undefined, awareCalendarMode === 'month' ? { day: 'numeric' } : (awareCalendarMode === 'week' ? { weekday: 'short', day: 'numeric' } : { weekday: 'short', month: 'numeric', day: 'numeric' }))} - {isToday && awareCalendarMode === 'day' && {isZh ? '今天' : 'Today'}} + {day.toLocaleDateString(tsLocale, awareCalendarMode === 'month' ? { day: 'numeric' } : (awareCalendarMode === 'week' ? { weekday: 'short', day: 'numeric' } : { weekday: 'short', month: 'numeric', day: 'numeric' }))} + {isToday && awareCalendarMode === 'day' && {t('agent.aware.today')}}
{items.length === 0 ? (
-
@@ -4656,7 +4665,7 @@ export default function AgentDetailPage() {
{formatReflectionTitle(session.title, !!isZh)}
- {new Date(session.created_at).toLocaleString(undefined, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' })} + {new Date(session.created_at).toLocaleString(tsLocale, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' })} {session.message_count > 0 ? ` · ${session.message_count}` : ''}
@@ -5020,13 +5029,13 @@ export default function AgentDetailPage() {

{t('agent.activity.recent', 'Recent Activity')}

- +
{activityLogs.slice(0, 5).map((log: any, i: number) => (
- {new Date(log.created_at).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })} + {new Date(log.created_at).toLocaleTimeString(tsLocale, { hour: '2-digit', minute: '2-digit' })} {log.summary || log.action_type}
@@ -5080,19 +5089,18 @@ export default function AgentDetailPage() { } if (trig.type === 'once' && trig.config?.at) { try { - return isZh - ? `一次性:${new Date(trig.config.at).toLocaleString()}` - : `Once at ${new Date(trig.config.at).toLocaleString(undefined, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' })}`; - } catch { return isZh ? `一次性:${trig.config.at}` : `Once at ${trig.config.at}`; } + const timeStr = new Date(trig.config.at).toLocaleString(tsLocale, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' }); + return t('agent.aware.triggerOnce', { time: timeStr }); + } catch { return t('agent.aware.triggerOnce', { time: trig.config.at }); } } if (trig.type === 'interval' && trig.config?.minutes) { const m = trig.config.minutes; - return isZh ? `每 ${m >= 60 ? `${m / 60} 小时` : `${m} 分钟`}` : (m >= 60 ? `Every ${m / 60}h` : `Every ${m} min`); + return m >= 60 ? t('agent.aware.triggerEveryHour', { hour: m / 60 }) : t('agent.aware.triggerEveryMin', { min: m }); } - if (trig.type === 'poll') return `${isZh ? '轮询' : 'Poll'}: ${trig.config?.url?.substring(0, 40) || 'URL'}`; + if (trig.type === 'poll') return t('agent.aware.triggerPoll', { url: trig.config?.url?.substring(0, 40) || 'URL' }); if (trig.type === 'on_message') { - const sender = trig.config?.from_agent_name || trig.config?.from_user_name || (isZh ? '未知对象' : 'unknown'); - return isZh ? `收到 ${sender} 的消息时` : `On message from ${sender}`; + const sender = trig.config?.from_agent_name || trig.config?.from_user_name || t('agent.aware.triggerUnknown'); + return t('agent.aware.triggerOnMessage', { sender }); } if (trig.type === 'webhook') { return `Webhook${trig.config?.token ? ` (${trig.config.token.substring(0, 6)}...)` : ''}`; @@ -5320,7 +5328,7 @@ export default function AgentDetailPage() { fontWeight: 500, }}>{log.action_type?.replace('trigger_', '')} - {new Date(log.created_at).toLocaleString(undefined, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' })} + {new Date(log.created_at).toLocaleString(tsLocale, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' })}
{log.summary}
@@ -5477,8 +5485,8 @@ export default function AgentDetailPage() { {formatReflectionTitle(session.title, !!isZh)}
- {new Date(session.created_at).toLocaleString(undefined, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' })} - {session.message_count > 0 && ` · ${session.message_count} msg`} + {new Date(session.created_at).toLocaleString(tsLocale, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' })} + {session.message_count > 0 && ` · ${session.message_count}`}
{s.last_message_at - ? new Date(s.last_message_at).toLocaleString(i18n.language === 'zh' ? 'zh-CN' : 'en-US', { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' }) - : new Date(s.created_at).toLocaleString(i18n.language === 'zh' ? 'zh-CN' : 'en-US', { month: 'short', day: 'numeric' })} + ? new Date(s.last_message_at).toLocaleString(tsLocale, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' }) + : new Date(s.created_at).toLocaleString(tsLocale, { month: 'short', day: 'numeric' })} {s.message_count > 0 && {s.message_count}}
@@ -5947,7 +5955,7 @@ export default function AgentDetailPage() {
{s.username || ''} - {s.last_message_at ? new Date(s.last_message_at).toLocaleString(i18n.language === 'zh' ? 'zh-CN' : 'en-US', { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' }) : ''}{s.message_count > 0 ? ` · ${s.message_count}` : ''} + {s.last_message_at ? new Date(s.last_message_at).toLocaleString(tsLocale, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' }) : ''}{s.message_count > 0 ? ` · ${s.message_count}` : ''}
); @@ -6602,7 +6610,7 @@ export default function AgentDetailPage() { heartbeat: , plaza_post: , }; - const time = log.created_at ? new Date(log.created_at).toLocaleString('zh-CN', { + const time = log.created_at ? new Date(log.created_at).toLocaleString(tsLocale, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit', second: '2-digit', }) : ''; const isExpanded = expandedLogId === log.id; @@ -6772,7 +6780,7 @@ export default function AgentDetailPage() { {(agent as any).is_expired ? {t('agent.settings.expiry.expired')} : (agent as any).expires_at - ? <>{t('agent.settings.expiry.currentExpiry')} {new Date((agent as any).expires_at).toLocaleString(i18n.language === 'zh' ? 'zh-CN' : 'en-US')} + ? <>{t('agent.settings.expiry.currentExpiry')} {new Date((agent as any).expires_at).toLocaleString(tsLocale)} : {t('agent.settings.expiry.neverExpires')} } From 04c507935ed861393c6ee0dea064e6284367ed72 Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Fri, 12 Jun 2026 23:10:40 +0800 Subject: [PATCH 2/3] refactor(db): introduce ContextVar DAO layer and shrink transaction boundaries in auth API --- backend/app/api/auth.py | 1087 ++++++++--------- backend/app/api/dingtalk.py | 2 +- backend/app/api/google_workspace.py | 2 +- backend/app/api/notification.py | 3 +- backend/app/api/organization.py | 1 - backend/app/api/sso.py | 4 +- backend/app/api/wecom.py | 1 - backend/app/dao/__init__.py | 19 + backend/app/dao/base.py | 80 ++ backend/app/dao/identity_dao.py | 82 ++ backend/app/dao/identity_provider_dao.py | 61 + backend/app/dao/invitation_code_dao.py | 28 + backend/app/dao/org_member_dao.py | 120 ++ backend/app/dao/participant_dao.py | 32 + backend/app/dao/system_setting_dao.py | 43 + backend/app/dao/tenant_dao.py | 43 + backend/app/dao/user_dao.py | 94 ++ backend/app/database.py | 39 + .../scripts/cleanup_duplicate_feishu_users.py | 2 +- backend/app/services/access_relationships.py | 2 +- backend/app/services/auth_provider.py | 5 +- backend/app/services/auth_registry.py | 43 +- backend/app/services/channel_user_service.py | 2 - backend/app/services/feishu_service.py | 3 +- .../app/services/password_reset_service.py | 12 +- backend/app/services/platform_service.py | 37 +- backend/app/services/registration_service.py | 596 ++++----- backend/app/services/system_email_service.py | 72 +- backend/tests/test_auth.py | 19 +- backend/tests/test_auth_provider.py | 7 +- .../test_password_reset_and_notifications.py | 130 +- backend/tests/test_sso_toggle.py | 32 +- deploy/docker-compose.yml | 1 - 33 files changed, 1581 insertions(+), 1123 deletions(-) create mode 100644 backend/app/dao/__init__.py create mode 100644 backend/app/dao/base.py create mode 100644 backend/app/dao/identity_dao.py create mode 100644 backend/app/dao/identity_provider_dao.py create mode 100644 backend/app/dao/invitation_code_dao.py create mode 100644 backend/app/dao/org_member_dao.py create mode 100644 backend/app/dao/participant_dao.py create mode 100644 backend/app/dao/system_setting_dao.py create mode 100644 backend/app/dao/tenant_dao.py create mode 100644 backend/app/dao/user_dao.py diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 311c08b59..e943dee70 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -2,56 +2,51 @@ import uuid from datetime import datetime, timezone -import uuid - from typing import Any from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request, status from loguru import logger -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.core.security import create_access_token, get_authenticated_user, get_current_user, hash_password_async, verify_password_async -from app.database import get_db -from app.models.user import Identity, User +from app.core.security import ( + create_access_token, + get_authenticated_user, + get_current_user, + hash_password_async, + verify_password_async, +) +from app.dao import identity_dao, system_setting_dao, tenant_dao, user_dao +from app.database import transaction +from app.models.user import User from app.schemas.schemas import ( ForgotPasswordRequest, - ResetPasswordRequest, IdentityBindRequest, + IdentityOut, IdentityUnbindRequest, + MultiTenantResponse, OAuthAuthorizeResponse, OAuthCallbackRequest, - TokenResponse, - UserLogin, - UserOut, - UserRegister, - UserUpdate, - VerifyEmailRequest, - ResendVerificationRequest, - NeedsVerificationResponse, RegisterInitRequest, RegisterInitResponse, - RegisterCompleteRequest, - RegisterCompleteResponse, + ResendVerificationRequest, + ResetPasswordRequest, SSORegisterRequest, TenantChoice, - MultiTenantResponse, - IdentityOut, TenantSwitchRequest, TenantSwitchResponse, + TokenResponse, + UserLogin, + UserOut, + UserRegister, + UserUpdate, + VerifyEmailRequest, ) -from sqlalchemy.orm import selectinload router = APIRouter(prefix="/auth", tags=["auth"]) @router.get("/registration-config") -async def get_registration_config(db: AsyncSession = Depends(get_db)): +async def get_registration_config(): """Public endpoint — returns registration requirements (no auth needed).""" - from app.models.system_settings import SystemSetting - result = await db.execute(select(SystemSetting).where(SystemSetting.key == "invitation_code_enabled")) - setting = result.scalar_one_or_none() - enabled = setting.value.get("enabled", False) if setting else False + enabled = await system_setting_dao.is_invitation_code_enabled() return {"invitation_code_required": enabled} @@ -59,24 +54,18 @@ async def get_registration_config(db: AsyncSession = Depends(get_db)): async def check_duplicate( email: str | None = Query(None, description="Email to check"), username: str | None = Query(None, description="Username to check"), - db: AsyncSession = Depends(get_db), ): """Check if email or username already exists.""" - from app.models.user import Identity, User result = {"email_exists": False, "username_exists": False, "conflicts": []} if email: # Check Identity email - existing = await db.execute( - select(Identity).where(Identity.email == email) - ) - if existing.scalar_one_or_none(): + if await identity_dao.get_by_email(email): result["email_exists"] = True result["conflicts"].append({"type": "email", "scope": "global", "message": "Email already registered"}) if username: - existing = await db.execute(select(Identity).where(Identity.username == username)) - if existing.scalar_one_or_none(): + if await identity_dao.get_by_username(username): result["username_exists"] = True result["conflicts"].append({"type": "username", "scope": "global", "message": "Username already taken"}) @@ -88,31 +77,28 @@ async def _send_verification_email_task( user: User, background_tasks: BackgroundTasks, settings: Any, - db: AsyncSession, ) -> None: """Helper to create verification token and add email task to background tasks.""" - # Check if email is configured — either via DB (platform settings UI) or env vars. - # We must check the DB config too, since most users configure SMTP via the UI. from app.services.system_email_service import resolve_email_config_async - email_config = await resolve_email_config_async(db) + from app.services.email_verification_service import email_verification_service + + email_config = await resolve_email_config_async() if not email_config: logger.debug("No email config found (env or DB), skipping verification email") return - from app.services.email_verification_service import email_verification_service - try: - # Get identity for this user - res = await db.execute(select(Identity).where(Identity.id == user.identity_id)) - identity = res.scalar_one_or_none() + identity = await identity_dao.get(user.identity_id) if not identity: logger.warning(f"No identity found for user {user.id} ({user.email}). Cannot send verification.") return - raw_code, expires_at = await email_verification_service.create_email_verification_token(identity.id, identity.email) + raw_code, expires_at = await email_verification_service.create_email_verification_token( + identity.id, identity.email + ) expiry_minutes = int((expires_at - datetime.now(timezone.utc)).total_seconds() // 60) - + background_tasks.add_task( email_verification_service.send_verification_email, identity.email, @@ -129,7 +115,6 @@ async def _send_verification_email_task( async def register( data: UserRegister, background_tasks: BackgroundTasks, - db: AsyncSession = Depends(get_db), ): """Legacy registration endpoint - kept for backward compatibility. @@ -139,121 +124,121 @@ async def register( - /verify-email - Step 3: Verify email """ from app.config import get_settings + settings = get_settings() # Handle SSO registration if provider info provided if data.provider and data.provider_code: - return await _handle_sso_register(data, db) + return await _handle_sso_register(data) # Regular username/password registration - delegate to new flow - return await _handle_normal_register(data, background_tasks, db, settings) + return await _handle_normal_register(data, background_tasks, settings) + + @router.post("/register/init", response_model=RegisterInitResponse, status_code=status.HTTP_201_CREATED) async def register_init( data: RegisterInitRequest, background_tasks: BackgroundTasks, - db: AsyncSession = Depends(get_db), ): """Step 1: Initialize registration with account credentials. Creates/finds a global Identity and a tenant-scoped User. """ from app.config import get_settings - settings = get_settings() + from app.services.system_email_service import resolve_email_config_async from app.services.registration_service import registration_service - from app.models.user import Identity, User + settings = get_settings() logger.info(f"[REGISTER_INIT] Starting registration for email={data.email}") - # Resolve email config once - from app.services.system_email_service import resolve_email_config_async - email_config = await resolve_email_config_async(db) - - # Check if this is the first user (platform admin setup) - Optimize with EXISTS - is_first_user = (await db.execute(select(Identity.id).limit(1))).scalar() is None - - # Find or Create Identity - identity = await registration_service.find_or_create_identity( - db, - email=data.email, - username=data.username, - password=data.password, - is_platform_admin=is_first_user, - email_config=email_config, - ) - # Defense-in-depth: verify the returned identity actually belongs to the - # submitted email. Under normal circumstances this should never trigger - # (find_or_create_identity no longer uses username as a lookup key), but - # this guard protects against future regressions. - if identity.email and identity.email != data.email: - logger.warning( - "[REGISTER_INIT] Identity email mismatch: submitted=%s returned=%s — rejecting", - data.email, - identity.email, - ) - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail="Username already taken. Please choose a different username.", - ) + # 1. Resolve email config outside transaction + email_config = await resolve_email_config_async() - # If identity existed, verify password - if identity.password_hash and not await verify_password_async(data.password, identity.password_hash): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Email already registered. Incorrect password." - ) + # 2. Compute hash first (without DB connection checked out) + password_hash = None + if data.password: + password_hash = await hash_password_async(data.password) - # For first user: auto-create/get default tenant - tenant_uuid = None - if is_first_user: - from app.models.tenant import Tenant - default = await db.execute(select(Tenant).where(Tenant.slug == "default")) - tenant = default.scalar_one_or_none() - if not tenant: - tenant = Tenant(name="Default", slug="default", im_provider="web_only") - db.add(tenant) - await db.flush() - tenant_uuid = tenant.id - - # Create User (tenant-scoped) - # Check if user already exists in this tenant (if tenant_uuid is set) - if tenant_uuid: - existing_user_res = await db.execute( - select(User).where(User.identity_id == identity.id, User.tenant_id == tenant_uuid) - ) - user = existing_user_res.scalar_one_or_none() - else: - # Check for a "tenant-less" user (pending company setup) - existing_user_res = await db.execute( - select(User).where(User.identity_id == identity.id, User.tenant_id == None) - ) - user = existing_user_res.scalar_one_or_none() + # 3. Check if this is the first user (platform admin setup) + is_first_user = await identity_dao.is_empty() - if not user: - user = await registration_service.create_user_with_identity( - db, - identity=identity, - display_name=data.display_name or data.username, - role="platform_admin" if is_first_user else "member", - tenant_id=tenant_uuid, + # 4. Check duplicate/existing identity first (outside transaction) + identity = await identity_dao.get_by_email(data.email) + if identity: + # Defense-in-depth: verify the returned identity actually belongs to the submitted email. + if identity.email and identity.email != data.email: + logger.warning( + f"[REGISTER_INIT] Identity email mismatch: submitted={data.email} returned={identity.email} — rejecting" + ) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Username already taken. Please choose a different username.", + ) + + # Verify password outside transaction + if identity.password_hash and not await verify_password_async(data.password, identity.password_hash): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Email already registered. Incorrect password." + ) + + async with transaction() as session: + # Find or Create Identity inside transaction (handles concurrent creation safely) + identity = await registration_service.find_or_create_identity( + email=data.email, + username=data.username, + password=data.password, + is_platform_admin=is_first_user, + email_config=email_config, + password_hash=password_hash, ) - # Set initial status - user.is_active = is_first_user # Active immediately if first user - user.email_verified = identity.email_verified - await db.flush() - # Generate token + # For first user: auto-create/get default tenant + tenant_uuid = None + if is_first_user: + tenant = await tenant_dao.get_by_slug("default") + if not tenant: + tenant = await tenant_dao.create( + obj_in={ + "name": "Default", + "slug": "default", + "im_provider": "web_only", + } + ) + tenant_uuid = tenant.id + + # Create User (tenant-scoped) + if tenant_uuid: + user = await user_dao.get_by_identity_and_tenant(identity.id, tenant_uuid) + else: + user = await user_dao.get_by_identity_and_tenant(identity.id, None) + + if not user: + user = await registration_service.create_user_with_identity( + identity=identity, + display_name=data.display_name or data.username, + role="platform_admin" if is_first_user else "member", + tenant_id=tenant_uuid, + ) + # Set initial status + user.is_active = is_first_user # Active immediately if first user + user.email_verified = identity.email_verified + await session.flush() + + # 5. Generate token outside transaction token = create_access_token(str(user.id), user.role) - # Send verification email if not verified + # 6. Send verification email if not verified (outside transaction) if not identity.email_verified: - await _send_verification_email_task(user, background_tasks, settings, db) + await _send_verification_email_task(user, background_tasks, settings) return RegisterInitResponse( user_id=user.id, email=identity.email, access_token=token, user=UserOut.model_validate(user), - message="Registration initiated. Please verify your email." if not identity.email_verified else "Registration successful.", + message="Registration initiated. Please verify your email." + if not identity.email_verified + else "Registration successful.", needs_company_setup=user.tenant_id is None, target_tenant_id=data.target_tenant_id, ) @@ -262,7 +247,6 @@ async def register_init( @router.post("/register/sso", response_model=TokenResponse) async def register_sso( data: SSORegisterRequest, - db: AsyncSession = Depends(get_db), ): """SSO registration - completely separate from normal registration flow. @@ -273,29 +257,30 @@ async def register_sso( logger.info(f"[REGISTER_SSO] Starting SSO registration: provider={data.provider}") - # Get provider - auth_provider = await auth_provider_registry.get_provider(db, data.provider) + # Move provider lookup outside transaction + auth_provider = await auth_provider_registry.get_provider(data.provider) if not auth_provider: raise HTTPException(status_code=400, detail=f"Provider '{data.provider}' not supported") - # Perform SSO registration - user, is_new, error = await registration_service.register_with_sso( - db, data.provider, data.code, auth_provider - ) + async with transaction() as session: + # Perform SSO registration + user, is_new, error = await registration_service.register_with_sso( + data.provider, data.code, auth_provider + ) - if error: - raise HTTPException(status_code=400, detail=error) + if error: + raise HTTPException(status_code=400, detail=error) - # If no tenant, check for email domain match - if not user.tenant_id and user.email: - tenant, _ = await registration_service.get_tenant_for_registration( - db, email=user.email, invitation_code=data.invitation_code - ) - if tenant: - user.tenant_id = tenant.id - await db.flush() + # If no tenant, check for email domain match + if not user.tenant_id and user.email: + tenant, _ = await registration_service.get_tenant_for_registration( + email=user.email, invitation_code=data.invitation_code + ) + if tenant: + user.tenant_id = tenant.id + await session.flush() - # Generate token + # Move token generation outside transaction token = create_access_token(str(user.id), user.role) logger.info(f"[REGISTER_SSO] SSO successful: user_id={user.id}, is_new={is_new}") @@ -307,203 +292,187 @@ async def register_sso( ) -async def _handle_normal_register(data: UserRegister, background_tasks: BackgroundTasks, db: AsyncSession, settings): +async def _handle_normal_register(data: UserRegister, background_tasks: BackgroundTasks, settings): """Legacy normal registration handler.""" logger.info(f"[REGISTER_LEGACY] email={data.email}") from app.services.registration_service import registration_service - from sqlalchemy import func - - # Resolve email config once from app.services.system_email_service import resolve_email_config_async - email_config = await resolve_email_config_async(db) - # Check if first user - Optimize with EXISTS - is_first_user = (await db.execute(select(User.id).limit(1))).scalar() is None + # 1. Compute hash first (without DB connection checked out) + password_hash = None + if data.password: + password_hash = await hash_password_async(data.password) - # Resolve tenant - tenant_uuid = None - if is_first_user: - from app.models.tenant import Tenant - default = await db.execute(select(Tenant).where(Tenant.slug == "default")) - tenant = default.scalar_one_or_none() - if not tenant: - tenant = Tenant(name="Default", slug="default", im_provider="web_only") - db.add(tenant) - await db.flush() - tenant_uuid = tenant.id - role = "platform_admin" - else: - tenant, _ = await registration_service.get_tenant_for_registration( - db, email=data.email, invitation_code=data.invitation_code - ) - if tenant: - tenant_uuid = tenant.id - role = "member" + # 2. Resolve email config once outside transaction + email_config = await resolve_email_config_async() - # 1. Check for existing Identity/Tenant-User - from app.services.registration_service import registration_service - - # Check if this email is already registered globally - identity_query = select(Identity).where(Identity.email == data.email) - ident_res = await db.execute(identity_query) - identity = ident_res.scalar_one_or_none() + # 3. Check if first user outside transaction + is_first_user = await user_dao.is_empty() + # 4. Check if this email is already registered globally outside transaction + identity = await identity_dao.get_by_email(data.email) if identity: raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail="Email already registered, please login directly." + status_code=status.HTTP_409_CONFLICT, detail="Email already registered, please login directly." ) - - # 2. Uniqueness Check (Already handled by Identity lookup above, but let's be explicit for Phone if needed) - # conflicts = await registration_service.check_duplicate_identity(db, email=data.email) - # ... - - # 3. Resolve or create Identity - # If it's the first user, we auto-verify (trusted admin) - identity = await registration_service.find_or_create_identity( - db, - email=data.email, - username=data.username, - password=data.password, - is_platform_admin=is_first_user, - email_config=email_config, - ) - # Defense-in-depth: verify the returned identity actually belongs to the - # submitted email. Should be unreachable after the username-lookup fix, but - # acts as a safety net against future regressions. - if identity.email and identity.email != data.email: - logger.warning( - "[REGISTER_LEGACY] Identity email mismatch: submitted=%s returned=%s — rejecting", - data.email, - identity.email, + async with transaction() as session: + # Resolve tenant + tenant_uuid = None + if is_first_user: + tenant = await tenant_dao.get_by_slug("default") + if not tenant: + tenant = await tenant_dao.create( + obj_in={ + "name": "Default", + "slug": "default", + "im_provider": "web_only", + } + ) + tenant_uuid = tenant.id + role = "platform_admin" + else: + tenant, _ = await registration_service.get_tenant_for_registration( + email=data.email, invitation_code=data.invitation_code + ) + if tenant: + tenant_uuid = tenant.id + role = "member" + + # Resolve or create Identity inside transaction + identity = await registration_service.find_or_create_identity( + email=data.email, + username=data.username, + password=data.password, + is_platform_admin=is_first_user, + email_config=email_config, + password_hash=password_hash, ) - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail="Username already taken. Please choose a different username.", + + # Defense-in-depth: verify the returned identity actually belongs to the submitted email. + if identity.email and identity.email != data.email: + logger.warning( + f"[REGISTER_LEGACY] Identity email mismatch: submitted={data.email} returned={identity.email} — rejecting" + ) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Username already taken. Please choose a different username.", + ) + + if is_first_user: + identity.email_verified = True + identity.is_active = True + await session.flush() + + # Create Tenant User + user = await registration_service.create_user_with_identity( + identity=identity, + display_name=data.display_name or data.username, + role=role, + tenant_id=tenant_uuid, + registration_source="web", + email_config=email_config, ) - - if is_first_user: - identity.email_verified = True - identity.is_active = True - await db.flush() - - # 4. Create Tenant User (Handles OrgMember binding and Participant creation) - user = await registration_service.create_user_with_identity( - db, - identity=identity, - display_name=data.display_name or data.username, - role=role, - tenant_id=tenant_uuid, - registration_source="web", - email_config=email_config, - ) - # Seed default agents for first user + # 5. Seed default agents for first user outside main registration transaction block if is_first_user: - await db.commit() try: from app.services.agent_seeder import seed_default_agents await seed_default_agents() except Exception as e: logger.warning(f"Failed to seed default agents: {e}") - # Send verification email only when the identity still needs it. If the - # platform has no system email configured, registration_service auto-verifies - # the identity so local/self-hosted installs are not blocked. + # 6. Send verification email only when the identity still needs it (outside transaction) if not identity.email_verified: - await _send_verification_email_task(user, background_tasks, settings, db) + await _send_verification_email_task(user, background_tasks, settings) - return RegisterInitResponse( + # 7. Generate access token and build response payload outside transaction + token = create_access_token(str(user.id), user.role) + response_data = RegisterInitResponse( user_id=user.id, email=user.email, - access_token=create_access_token(str(user.id), user.role), + access_token=token, user=UserOut.model_validate(user), - message="Registration successful. Please verify your email." if not identity.email_verified else "Registration successful.", + message="Registration successful. Please verify your email." + if not identity.email_verified + else "Registration successful.", needs_company_setup=user.tenant_id is None, ) + return response_data + -async def _handle_sso_register(data: UserRegister, db: AsyncSession): +async def _handle_sso_register(data: UserRegister): """Legacy SSO registration handler - delegates to new SSO endpoint logic.""" # Redirect to new SSO flow - sso_data = SSORegisterRequest( - provider=data.provider, - code=data.provider_code, - invitation_code=data.invitation_code - ) - return await register_sso(sso_data, db) - + sso_data = SSORegisterRequest(provider=data.provider, code=data.provider_code, invitation_code=data.invitation_code) + return await register_sso(sso_data) @router.post("/login", response_model=Any) -async def login(data: UserLogin, background_tasks: BackgroundTasks, db: AsyncSession = Depends(get_db)): +async def login(data: UserLogin, background_tasks: BackgroundTasks): """Login with email/phone/username and password. Supports multi-tenant selection.""" - from app.models.tenant import Tenant - from app.models.user import Identity, User - # 1. Query Identity - query = select(Identity).where( - (Identity.email == data.login_identifier) | - (Identity.phone == data.login_identifier) | - (Identity.username == data.login_identifier) - ) - result = await db.execute(query) - identity = result.scalar_one_or_none() + identity = await identity_dao.get_by_login_identifier(data.login_identifier) - if not identity or not identity.password_hash or not await verify_password_async(data.password, identity.password_hash): - logger.warning(f"[LOGIN] Invalid credentials for {data.login_identifier} identity_id={identity.id if identity else 'None'}") + if ( + not identity + or not identity.password_hash + or not await verify_password_async(data.password, identity.password_hash) + ): + logger.warning( + f"[LOGIN] Invalid credentials for {data.login_identifier} identity_id={identity.id if identity else 'None'}" + ) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") # 2. Check Global Activity & Verification if not identity.is_active: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Your account has been disabled.") + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Your account has been disabled.") if not identity.email_verified: from app.config import get_settings - from sqlalchemy import update from app.services.system_email_service import resolve_email_config_async - email_config = await resolve_email_config_async(db) + email_config = await resolve_email_config_async() + if not email_config: - identity.email_verified = True - identity.is_active = True - await db.execute( - update(User) - .where(User.identity_id == identity.id) - .values(is_active=True) - ) - await db.flush() + # SMTP missing: auto-verify users under a transaction + async with transaction(): + tx_identity = await identity_dao.get(identity.id) + if tx_identity: + tx_identity.email_verified = True + tx_identity.is_active = True + identity.email_verified = True + identity.is_active = True + users = await user_dao.get_by_identity_id(tx_identity.id) + for u in users: + u.is_active = True else: # Find any user record (just for the task) - user_res = await db.execute(select(User).where(User.identity_id == identity.id).limit(1)) - user = user_res.scalar_one_or_none() - + user = await user_dao.get_representative_user_for_identity(identity.id) + # Trigger email delivery in background if user: - await _send_verification_email_task(user, background_tasks, get_settings(), db) - + await _send_verification_email_task(user, background_tasks, get_settings()) + # Consistent with identity-first flow: Return 403 Forbidden with verification intent raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail={ "needs_verification": True, "email": identity.email, - "message": "Please verify your email to continue." - } + "message": "Please verify your email to continue.", + }, ) # 3. Find all User records (tenants) - result = await db.execute(select(User).where(User.identity_id == identity.id).options(selectinload(User.identity))) - valid_users = list(result.scalars().all()) + valid_users = await user_dao.get_by_identity_id(identity.id, include_identity=True) if not valid_users: - # User has an identity but no tenant records? Should they create one? - # Create a "tenant-less" user if needed, or redirect to company setup - # For now, if no users, they need company setup. - # But wait, register_init should have created one. - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No organization associated with this account.") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="No organization associated with this account." + ) # 4. Handle Tenant Selection if not data.tenant_id: @@ -512,20 +481,20 @@ async def login(data: UserLogin, background_tasks: BackgroundTasks, db: AsyncSes tenant_ids = [u.tenant_id for u in valid_users if u.tenant_id] tenants_map = {} if tenant_ids: - tenants_result = await db.execute( - select(Tenant).where(Tenant.id.in_(tenant_ids)) - ) - tenants_map = {str(t.id): t for t in tenants_result.scalars().all()} + tenants_result = await tenant_dao.get_by_ids(tenant_ids) + tenants_map = {str(t.id): t for t in tenants_result} tenant_choices = [] for u in valid_users: tenant = tenants_map.get(str(u.tenant_id)) if u.tenant_id else None - tenant_choices.append(TenantChoice( - tenant_id=u.tenant_id, - tenant_name=tenant.name if tenant else "Create or Join Organization", - tenant_slug=tenant.slug if tenant else "", - logo_url=tenant.logo_url if tenant else None, - )) + tenant_choices.append( + TenantChoice( + tenant_id=u.tenant_id, + tenant_name=tenant.name if tenant else "Create or Join Organization", + tenant_slug=tenant.slug if tenant else "", + logo_url=tenant.logo_url if tenant else None, + ) + ) return MultiTenantResponse( requires_tenant_selection=True, @@ -537,22 +506,17 @@ async def login(data: UserLogin, background_tasks: BackgroundTasks, db: AsyncSes user = valid_users[0] else: # Specific tenant requested (Dedicated Link flow) - # Search for the user record in that tenant user = next((u for u in valid_users if u.tenant_id == data.tenant_id), None) - + # Cross-tenant access check if not user: - # Even platform admins must have a valid record in the targeted tenant - # when logging in via a dedicated tenant URL / tenant_id. - raise HTTPException( + raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="This account does not belong to the selected organization.", ) - if user.tenant_id: - t_result = await db.execute(select(Tenant).where(Tenant.id == user.tenant_id)) - tenant = t_result.scalar_one_or_none() + tenant = await tenant_dao.get(user.tenant_id) if tenant and not tenant.is_active: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -570,26 +534,24 @@ async def login(data: UserLogin, background_tasks: BackgroundTasks, db: AsyncSes @router.get("/email-hint") -async def get_email_hint(username: str, db: AsyncSession = Depends(get_db)): +async def get_email_hint(username: str): """Return a hinted email address for a given username.""" - from app.models.user import Identity - result = await db.execute(select(Identity).where(Identity.username == username)) - identity = result.scalar_one_or_none() - + identity = await identity_dao.get_by_username(username) + if not identity or not identity.email: raise HTTPException(status_code=404, detail="Account not found.") - + email = identity.email parts = email.split("@") if len(parts) == 2: name, domain = parts - + # Obfuscate name if len(name) <= 2: obs_name = name[0] + "***" else: obs_name = name[:2] + "***" + name[-1] - + # Obfuscate domain domain_parts = domain.split(".") if len(domain_parts) >= 2: @@ -604,7 +566,7 @@ async def get_email_hint(username: str, db: AsyncSession = Depends(get_db)): hint = f"{obs_name}@{domain}" else: hint = email[:3] + "***" - + return {"hint": hint} @@ -612,16 +574,16 @@ async def get_email_hint(username: str, db: AsyncSession = Depends(get_db)): async def forgot_password( data: ForgotPasswordRequest, background_tasks: BackgroundTasks, - db: AsyncSession = Depends(get_db), ): """Request a password reset link for a global Identity.""" from app.services.system_email_service import resolve_email_config_async - email_config = await resolve_email_config_async(db) + + email_config = await resolve_email_config_async() if not email_config: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Password reset is currently unavailable (no mail server configured)." + detail="Password reset is currently unavailable (no mail server configured).", ) generic_response = { @@ -630,22 +592,18 @@ async def forgot_password( } # Find Identity by email - identity_query = select(Identity).where(Identity.email == data.email) - identity_result = await db.execute(identity_query) - identity = identity_result.scalar_one_or_none() - + identity = await identity_dao.get_by_email(data.email) + if not identity or not identity.is_active: return generic_response try: from app.services.password_reset_service import build_password_reset_url, create_password_reset_token - from app.services.system_email_service import ( - send_password_reset_email, - ) + from app.services.system_email_service import send_password_reset_email raw_token, expires_at = await create_password_reset_token(identity.id) - reset_url = await build_password_reset_url(db, raw_token) + reset_url = await build_password_reset_url(raw_token) expiry_minutes = int((expires_at - datetime.now(timezone.utc)).total_seconds() // 60) background_tasks.add_task( send_password_reset_email, @@ -661,26 +619,27 @@ async def forgot_password( @router.post("/reset-password") -async def reset_password(data: ResetPasswordRequest, db: AsyncSession = Depends(get_db)): +async def reset_password(data: ResetPasswordRequest): """Reset a password using a valid single-use token.""" from app.services.password_reset_service import consume_password_reset_token + # Consume token outside transaction token_data = await consume_password_reset_token(data.token) if not token_data: raise HTTPException(status_code=400, detail="Invalid or expired reset token") identity_id = token_data["identity_id"] - result = await db.execute(select(Identity).where(Identity.id == identity_id)) - identity = result.scalar_one_or_none() - - if not identity or not identity.is_active: - raise HTTPException(status_code=400, detail="Invalid or expired reset token") + # Hash new password outside transaction (CPU intensive) new_hash = await hash_password_async(data.new_password) - identity.password_hash = new_hash - await db.flush() - await db.commit() + # Perform DB update in a brief transaction (single select and update) + async with transaction(): + identity = await identity_dao.get(identity_id) + if not identity or not identity.is_active: + raise HTTPException(status_code=400, detail="Invalid or expired reset token") + identity.password_hash = new_hash + return {"ok": True} @@ -696,80 +655,67 @@ async def get_me(current_user: User = Depends(get_authenticated_user)): async def update_me( data: UserUpdate, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db), ): """Update current user profile.""" update_data = data.model_dump(exclude_unset=True) - # Validate username uniqueness if changing - if "username" in update_data and update_data["username"] != current_user.username: - existing = await db.execute( - select(User) - .join(Identity, User.identity_id == Identity.id) - .where(Identity.username == update_data["username"]) - ) - if existing.scalars().first(): - raise HTTPException(status_code=409, detail="Username already taken") - - # Validate email uniqueness within tenant if changing - if "email" in update_data and update_data["email"] != current_user.email: - existing = await db.execute( - select(User) - .join(Identity, User.identity_id == Identity.id) - .where( - Identity.email == update_data["email"], - User.tenant_id == current_user.tenant_id, - User.id != current_user.id, + async with transaction() as session: + # Fetch current user in the transaction session + user = await user_dao.get_with_identity(current_user.id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Validate username uniqueness if changing + if "username" in update_data and update_data["username"] != user.identity.username: + existing = await user_dao.get_by_identity_username(update_data["username"]) + if existing: + raise HTTPException(status_code=409, detail="Username already taken") + + # Validate email uniqueness within tenant if changing + if "email" in update_data and update_data["email"] != user.identity.email: + existing = await user_dao.get_by_email_and_tenant( + email=update_data["email"], + tenant_id=user.tenant_id, + exclude_user_id=user.id, ) - ) - if existing.scalar_one_or_none(): - raise HTTPException(status_code=409, detail="Email already registered") - - # Validate mobile uniqueness within tenant if changing - if "primary_mobile" in update_data and update_data["primary_mobile"] != current_user.primary_mobile: - existing = await db.execute( - select(User) - .join(Identity, User.identity_id == Identity.id) - .where( - Identity.phone == update_data["primary_mobile"], - User.tenant_id == current_user.tenant_id, - User.id != current_user.id, + if existing: + raise HTTPException(status_code=409, detail="Email already registered") + + # Validate mobile uniqueness within tenant if changing + if "primary_mobile" in update_data and update_data["primary_mobile"] != user.identity.phone: + existing = await user_dao.get_by_phone_and_tenant( + phone=update_data["primary_mobile"], + tenant_id=user.tenant_id, + exclude_user_id=user.id, ) - ) - if existing.scalar_one_or_none(): - raise HTTPException(status_code=409, detail="Mobile already registered") - - for field, value in update_data.items(): - setattr(current_user, field, value) - await db.commit() - await db.refresh(current_user) - - # Sync email/phone to OrgMember if changed - if "email" in update_data or "primary_mobile" in update_data: - from app.services.registration_service import registration_service - await registration_service.sync_org_member_contact_from_user( - db, - current_user, - sync_email="email" in update_data, - sync_phone="primary_mobile" in update_data, - ) + if existing: + raise HTTPException(status_code=409, detail="Mobile already registered") + + for field, value in update_data.items(): + setattr(user, field, value) - return UserOut.model_validate(current_user) + await session.flush() + + # Sync email/phone to OrgMember if changed + if "email" in update_data or "primary_mobile" in update_data: + from app.services.registration_service import registration_service + + await registration_service.sync_org_member_contact_from_user( + user, + sync_email="email" in update_data, + sync_phone="primary_mobile" in update_data, + ) + + return UserOut.model_validate(user) @router.get("/my-tenants", response_model=list[TenantChoice]) async def get_my_tenants( current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db), ): """Get all tenants associated with the current user's identity.""" - from app.models.tenant import Tenant - # 1. Get all user records for this identity - result = await db.execute( - select(User).where(User.identity_id == current_user.identity_id) - ) - users = result.scalars().all() + users = await user_dao.get_by_identity_id(current_user.identity_id) # 2. Extract tenant IDs tenant_ids = [u.tenant_id for u in users if u.tenant_id] @@ -777,10 +723,7 @@ async def get_my_tenants( return [] # 3. Get tenant details - result = await db.execute( - select(Tenant).where(Tenant.id.in_(tenant_ids)) - ) - tenants = result.scalars().all() + tenants = await tenant_dao.get_by_ids(tenant_ids) return [ TenantChoice( @@ -788,7 +731,8 @@ async def get_my_tenants( tenant_name=t.name, tenant_slug=t.slug, logo_url=t.logo_url, - ) for t in tenants + ) + for t in tenants ] @@ -797,75 +741,52 @@ async def switch_tenant( data: TenantSwitchRequest, request: Request, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db), ): """Switch to a different tenant and return a new token and redirect URL.""" - from app.models.tenant import Tenant - from app.models.system_settings import SystemSetting - # 1. Verify membership - result = await db.execute( - select(User).where( - User.identity_id == current_user.identity_id, - User.tenant_id == data.tenant_id - ) - ) - target_user = result.scalar_one_or_none() + target_user = await user_dao.get_by_identity_and_tenant(current_user.identity_id, data.tenant_id) if not target_user: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You do not have access to this organization." + status_code=status.HTTP_403_FORBIDDEN, detail="You do not have access to this organization." ) # 2. Get tenant details - result = await db.execute(select(Tenant).where(Tenant.id == data.tenant_id)) - tenant = result.scalar_one_or_none() + tenant = await tenant_dao.get(data.tenant_id) if not tenant or not tenant.is_active: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="This organization is currently unavailable." + status_code=status.HTTP_403_FORBIDDEN, detail="This organization is currently unavailable." ) # 3. Generate new token token = create_access_token(str(target_user.id), target_user.role) # 4. Determine redirect URL - # Determine redirect URL (Priority: sso_domain > ENV > Request > Fallback) from app.services.platform_service import platform_service - from app.models.system_settings import SystemSetting - - # Check if custom domain SSO redirect is enabled globally - setting_result = await db.execute( - select(SystemSetting).where(SystemSetting.key == "sso_custom_domain_redirect_enabled") - ) - setting_s = setting_result.scalar_one_or_none() - sso_redirect_enabled = setting_s.value.get("enabled", True) if setting_s else True + + sso_redirect_enabled = await system_setting_dao.is_sso_custom_domain_redirect_enabled() if not sso_redirect_enabled: redirect_url = None else: - redirect_url = await platform_service.get_tenant_sso_base_url(db, tenant, request) - + async with tenant_dao.session() as session: + redirect_url = await platform_service.get_tenant_sso_base_url( + session, tenant, request, sso_redirect_enabled=sso_redirect_enabled + ) # Include token in redirect URL for cross-domain switching if needed if redirect_url: separator = "&" if "?" in redirect_url else "?" redirect_url = f"{redirect_url}{separator}token={token}" - return TenantSwitchResponse( - access_token=token, - redirect_url=redirect_url, - message="Switching organization..." - ) + return TenantSwitchResponse(access_token=token, redirect_url=redirect_url, message="Switching organization...") @router.put("/me/password") async def change_password( data: dict, current_user: User = Depends(get_authenticated_user), - db: AsyncSession = Depends(get_db), ): """Change current user's password. Updates the global identity password.""" old_password = data.get("old_password", "") @@ -877,19 +798,30 @@ async def change_password( if len(new_password) < 6: raise HTTPException(status_code=400, detail="New password must be at least 6 characters") - # Access identity through current_user (TenantUser) - res = await db.execute(select(User).where(User.id == current_user.id).options(selectinload(User.identity))) - user = res.scalar_one() + # Look up user & identity outside transaction + user = await user_dao.get_with_identity(current_user.id) + if not user: + raise HTTPException(status_code=404, detail="User not found") identity = user.identity - if not identity or not identity.password_hash or not await verify_password_async(old_password, identity.password_hash): + # Verify old password outside transaction (CPU intensive) + if ( + not identity + or not identity.password_hash + or not await verify_password_async(old_password, identity.password_hash) + ): raise HTTPException(status_code=400, detail="Current password is incorrect") + # Compute new hash outside transaction (CPU intensive) new_hash = await hash_password_async(new_password) - identity.password_hash = new_hash - - await db.flush() - await db.commit() + + # Perform DB update in a brief transaction + async with transaction(): + tx_identity = await identity_dao.get(identity.id) + if not tx_identity: + raise HTTPException(status_code=404, detail="Identity not found") + tx_identity.password_hash = new_hash + return {"ok": True} @@ -898,14 +830,16 @@ async def change_password( @router.get("/providers") async def list_providers( - db: AsyncSession = Depends(get_db), tenant_id: uuid.UUID | None = Query(None, description="Optional tenant ID"), ): """List all available identity providers.""" from app.services.auth_registry import auth_provider_registry - providers = await auth_provider_registry.list_providers(db, str(tenant_id) if tenant_id else None) - return [{"id": str(p.id), "provider_type": p.provider_type, "name": p.name, "is_active": p.is_active} for p in providers] + providers = await auth_provider_registry.list_providers(str(tenant_id) if tenant_id else None) + return [ + {"id": str(p.id), "provider_type": p.provider_type, "name": p.name, "is_active": p.is_active} + for p in providers + ] # Redis keys for OAuth two-step tenant selection @@ -922,12 +856,15 @@ async def _cache_oauth_pending( """Store OAuth intermediate data in Redis for the two-step tenant-selection flow.""" import json from app.core.events import get_redis + r = await get_redis() - payload = json.dumps({ - "provider_type": provider_type, - "user_info": user_info_dict, - "token_data": token_data, - }) + payload = json.dumps( + { + "provider_type": provider_type, + "user_info": user_info_dict, + "token_data": token_data, + } + ) await r.set(f"{_OAUTH_PENDING_PREFIX}{pending_token}", payload, ex=_OAUTH_PENDING_TTL) @@ -935,6 +872,7 @@ async def _get_oauth_pending(pending_token: str) -> dict | None: """Retrieve (and delete) cached OAuth data from Redis. Returns None if expired/missing.""" import json from app.core.events import get_redis + r = await get_redis() raw = await r.get(f"{_OAUTH_PENDING_PREFIX}{pending_token}") if not raw: @@ -949,14 +887,12 @@ async def authorize( provider: str, redirect_uri: str = Query(..., description="OAuth callback URI"), state: str = Query("", description="CSRF state parameter"), - db: AsyncSession = Depends(get_db), ): """Start OAuth authorization flow for a provider.""" from app.services.auth_registry import auth_provider_registry - from app.services.sso_service import sso_service # Get provider - auth_provider = await auth_provider_registry.get_provider(db, provider) + auth_provider = await auth_provider_registry.get_provider(provider) if not auth_provider: raise HTTPException(status_code=404, detail=f"Provider '{provider}' not supported") @@ -976,7 +912,6 @@ async def authorize( async def oauth_callback( provider: str, data: OAuthCallbackRequest, - db: AsyncSession = Depends(get_db), ): """Handle OAuth callback — supports a two-step flow for multi-tenant selection. @@ -987,7 +922,6 @@ async def oauth_callback( call find_or_create_user with the chosen tenant_id, return TokenResponse. """ import uuid as _uuid - from app.models.tenant import Tenant from app.services.auth_registry import auth_provider_registry # ── Step 2: User has selected a tenant ─────────────────────────────────── @@ -999,7 +933,7 @@ async def oauth_callback( detail="OAuth session expired or invalid. Please sign in again.", ) - auth_provider = await auth_provider_registry.get_provider(db, pending["provider_type"]) + auth_provider = await auth_provider_registry.get_provider(pending["provider_type"]) if not auth_provider: raise HTTPException( status_code=404, @@ -1007,13 +941,15 @@ async def oauth_callback( ) from app.services.auth_provider import ExternalUserInfo + user_info = ExternalUserInfo(**pending["user_info"]) - user, _ = await auth_provider.find_or_create_user(db, user_info, tenant_id=data.tenant_id) - if not user: - raise HTTPException(status_code=500, detail="Failed to create user") - if not user.is_active: - raise HTTPException(status_code=403, detail="Account is disabled") + async with transaction() as session: + user, _ = await auth_provider.find_or_create_user(session, user_info, tenant_id=data.tenant_id) + if not user: + raise HTTPException(status_code=500, detail="Failed to create user") + if not user.is_active: + raise HTTPException(status_code=403, detail="Account is disabled") jwt_token = create_access_token(str(user.id), user.role) return TokenResponse( @@ -1026,77 +962,82 @@ async def oauth_callback( if not data.code: raise HTTPException(status_code=400, detail="Missing authorization code") - auth_provider = await auth_provider_registry.get_provider(db, provider) + auth_provider = await auth_provider_registry.get_provider(provider) if not auth_provider: raise HTTPException(status_code=404, detail=f"Provider '{provider}' not supported") try: + # Perform external network requests outside transaction token_data = await auth_provider.exchange_code_for_token(data.code, data.redirect_uri) access_token = token_data.get("access_token") if not access_token: raise HTTPException(status_code=400, detail="Failed to get access token from provider") user_info = await auth_provider.get_user_info(access_token) - user, is_new = await auth_provider.find_or_create_user(db, user_info) - - if not user: - raise HTTPException(status_code=500, detail="Failed to create user") - if not user.is_active: - raise HTTPException(status_code=403, detail="Account is disabled") - except HTTPException: raise except Exception as e: logger.error(f"OAuth callback failed for {provider}: {e}") raise HTTPException(status_code=500, detail="OAuth authentication failed") - # Check if this identity has multiple tenant memberships - if user.identity_id: - all_users_result = await db.execute( - select(User).where(User.identity_id == user.identity_id) - ) - all_users = list(all_users_result.scalars().all()) - tenant_users = [u for u in all_users if u.tenant_id is not None] - - if len(tenant_users) > 1: - # Cache the full user_info in Redis so Step 2 can reconstruct it - pending_token = _uuid.uuid4().hex - await _cache_oauth_pending( - pending_token, - provider, - { - "provider_type": user_info.provider_type, - "provider_union_id": user_info.provider_union_id, - "provider_user_id": user_info.provider_user_id, - "name": user_info.name, - "email": user_info.email, - "avatar_url": user_info.avatar_url, - "mobile": user_info.mobile, - "raw_data": user_info.raw_data, - }, - token_data, - ) + tenant_users = [] + tenants_map = {} - tenant_ids = [u.tenant_id for u in tenant_users] - tenants_result = await db.execute(select(Tenant).where(Tenant.id.in_(tenant_ids))) - tenants_map = {str(t.id): t for t in tenants_result.scalars().all()} + async with transaction() as session: + user, is_new = await auth_provider.find_or_create_user(session, user_info) - tenant_choices = [ - TenantChoice( - tenant_id=u.tenant_id, - tenant_name=tenants_map[str(u.tenant_id)].name if str(u.tenant_id) in tenants_map else "Unknown", - tenant_slug=tenants_map[str(u.tenant_id)].slug if str(u.tenant_id) in tenants_map else "", - logo_url=tenants_map[str(u.tenant_id)].logo_url if str(u.tenant_id) in tenants_map else None, - ) - for u in tenant_users - ] + if not user: + raise HTTPException(status_code=500, detail="Failed to create user") + if not user.is_active: + raise HTTPException(status_code=403, detail="Account is disabled") - return MultiTenantResponse( - requires_tenant_selection=True, - login_identifier=user_info.email or "", - tenants=tenant_choices, - pending_token=pending_token, + # Check if this identity has multiple tenant memberships + if user.identity_id: + all_users = await user_dao.get_by_identity_id(user.identity_id) + tenant_users = [u for u in all_users if u.tenant_id is not None] + + if len(tenant_users) > 1: + tenant_ids = [u.tenant_id for u in tenant_users] + tenants_result = await tenant_dao.get_by_ids(tenant_ids) + tenants_map = {str(t.id): t for t in tenants_result} + + if len(tenant_users) > 1: + # Cache the full user_info in Redis so Step 2 can reconstruct it (outside transaction) + pending_token = _uuid.uuid4().hex + await _cache_oauth_pending( + pending_token, + provider, + { + "provider_type": user_info.provider_type, + "provider_union_id": user_info.provider_union_id, + "provider_user_id": user_info.provider_user_id, + "name": user_info.name, + "email": user_info.email, + "avatar_url": user_info.avatar_url, + "mobile": user_info.mobile, + "raw_data": user_info.raw_data, + }, + token_data, + ) + + tenant_choices = [ + TenantChoice( + tenant_id=u.tenant_id, + tenant_name=tenants_map[str(u.tenant_id)].name + if str(u.tenant_id) in tenants_map + else "Unknown", + tenant_slug=tenants_map[str(u.tenant_id)].slug if str(u.tenant_id) in tenants_map else "", + logo_url=tenants_map[str(u.tenant_id)].logo_url if str(u.tenant_id) in tenants_map else None, ) + for u in tenant_users + ] + + return MultiTenantResponse( + requires_tenant_selection=True, + login_identifier=user_info.email or "", + tenants=tenant_choices, + pending_token=pending_token, + ) # Single tenant (or new user with no tenant yet) — issue token directly jwt_token = create_access_token(str(user.id), user.role) @@ -1112,49 +1053,49 @@ async def bind_identity( provider: str, data: IdentityBindRequest, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db), ): """Bind an external identity to the current user.""" from app.services.auth_registry import auth_provider_registry from app.services.sso_service import sso_service - # Get provider - auth_provider = await auth_provider_registry.get_provider(db, provider) + # Get provider outside transaction + auth_provider = await auth_provider_registry.get_provider(provider) if not auth_provider: raise HTTPException(status_code=404, detail=f"Provider '{provider}' not supported") try: - # Exchange code for token + # Exchange code for token (network call) outside transaction token_data = await auth_provider.exchange_code_for_token(data.code) access_token = token_data.get("access_token") if not access_token: raise HTTPException(status_code=400, detail="Failed to get access token from provider") - # Get user info + # Get user info (network call) outside transaction user_info = await auth_provider.get_user_info(access_token) - # Check if identity is already linked to another user - lookup_provider_user_id = user_info.provider_user_id - existing_user = await sso_service.check_duplicate_identity( - db, - provider, - lookup_provider_user_id, - identity_data=user_info.raw_data, - ) - if existing_user and existing_user.id != current_user.id: - raise HTTPException( - status_code=409, - detail="This identity is already linked to another account", + async with transaction() as session: + # Check if identity is already linked to another user + lookup_provider_user_id = user_info.provider_user_id + existing_user = await sso_service.check_duplicate_identity( + session, + provider, + lookup_provider_user_id, + identity_data=user_info.raw_data, ) + if existing_user and existing_user.id != current_user.id: + raise HTTPException( + status_code=409, + detail="This identity is already linked to another account", + ) - # Link identity to current user - await sso_service.link_identity( - db, - str(current_user.id), - provider, - lookup_provider_user_id, - user_info.raw_data, - ) + # Link identity to current user + await sso_service.link_identity( + session, + str(current_user.id), + provider, + lookup_provider_user_id, + user_info.raw_data, + ) except HTTPException: raise @@ -1162,7 +1103,8 @@ async def bind_identity( logger.error(f"Identity bind failed for {provider}: {e}") raise HTTPException(status_code=500, detail="Failed to bind identity") - return UserOut.model_validate(current_user) + user = await user_dao.get(current_user.id) + return UserOut.model_validate(user) @router.post("/{provider}/unbind", response_model=UserOut) @@ -1170,30 +1112,31 @@ async def unbind_identity( provider: str, data: IdentityUnbindRequest, current_user: User = Depends(get_current_user), - db: AsyncSession = Depends(get_db), ): """Unlink an external identity from the current user.""" from app.services.sso_service import sso_service - # Unlink identity - success = await sso_service.unlink_identity(db, str(current_user.id), provider) - if not success: - raise HTTPException(status_code=404, detail=f"No linked identity found for provider '{provider}'") + async with transaction() as session: + success = await sso_service.unlink_identity(session, str(current_user.id), provider) + if not success: + raise HTTPException(status_code=404, detail=f"No linked identity found for provider '{provider}'") - return UserOut.model_validate(current_user) + user = await user_dao.get(current_user.id) + return UserOut.model_validate(user) # ─── Email Verification Endpoints ────────────────────────────────────── @router.post("/verify-email") -async def verify_email(data: VerifyEmailRequest, db: AsyncSession = Depends(get_db)): +async def verify_email(data: VerifyEmailRequest): """Verify email address using a token from the verification email. On success, returns user info and access token to allow immediate login. """ from app.services.email_verification_service import email_verification_service + # Consume verification token outside transaction (Redis operation) token_data = await email_verification_service.consume_email_verification_token(data.token) if not token_data: raise HTTPException(status_code=400, detail="Invalid or expired verification token") @@ -1201,41 +1144,29 @@ async def verify_email(data: VerifyEmailRequest, db: AsyncSession = Depends(get_ identity_id = token_data.get("identity_id") if not identity_id: raise HTTPException(status_code=400, detail="Token does not contain identity information") - - # 1. Update Identity - identity_result = await db.execute(select(Identity).where(Identity.id == identity_id)) - identity = identity_result.scalar_one_or_none() - if not identity: - raise HTTPException(status_code=400, detail="Identity not found") - - identity.email_verified = True - identity.is_active = True - - # 2. Activate all linked User accounts - # email_verified is a proxy to Identity, so only update physical is_active column - from sqlalchemy import update - await db.execute( - update(User) - .where(User.identity_id == identity.id) - .values(is_active=True) - ) - - await db.flush() - await db.commit() - - # Refresh after commit to avoid MissingGreenlet during Pydantic validation - await db.refresh(identity) - - # 3. Find a representative user for the token (for immediate login) - user_result = await db.execute( - select(User) - .where(User.identity_id == identity.id) - .order_by(User.created_at.desc()) - .limit(1) - ) - user = user_result.scalar_one_or_none() - # 4. Generate token and return full response for Auto Login (TokenResponse) + async with transaction() as session: + # 1. Update Identity + identity = await identity_dao.get(identity_id) + if not identity: + raise HTTPException(status_code=400, detail="Identity not found") + + identity.email_verified = True + identity.is_active = True + + # 2. Activate all linked User accounts + users = await user_dao.get_by_identity_id(identity.id) + for u in users: + u.is_active = True + + await session.flush() + # Refresh inside transaction to ensure we have the committed model state + await session.refresh(identity) + + # 3. Find a representative user outside transaction (read-only) + user = await user_dao.get_representative_user_for_identity(identity.id) + + # 4. Generate token and return full response outside transaction effective_id = str(user.id) if user else str(identity.id) effective_role = user.role if user else "user" token = create_access_token(effective_id, effective_role) @@ -1252,7 +1183,6 @@ async def verify_email(data: VerifyEmailRequest, db: AsyncSession = Depends(get_ async def resend_verification( data: ResendVerificationRequest, background_tasks: BackgroundTasks, - db: AsyncSession = Depends(get_db), ): """Resend email verification link.""" from app.config import get_settings @@ -1265,26 +1195,23 @@ async def resend_verification( } settings = get_settings() - # Check if email is configured (DB-only, no env fallback) - email_config = await resolve_email_config_async(db) + # Check if email is configured (DB-only, no env fallback) outside transaction (read-only) + email_config = await resolve_email_config_async() if not email_config: return generic_response - # Find Identity by email - id_result = await db.execute(select(Identity).where(Identity.email == data.email)) - identity = id_result.scalar_one_or_none() + # Find Identity by email (read-only) + identity = await identity_dao.get_by_email(data.email) # Don't reveal if user exists or already verified if not identity or identity.email_verified: return generic_response # Pick a representative user context (e.g. latest one) - u_result = await db.execute( - select(User).where(User.identity_id == identity.id).order_by(User.created_at.desc()).limit(1) - ) - user = u_result.scalar_one_or_none() - + user = await user_dao.get_representative_user_for_identity(identity.id) + if user: - await _send_verification_email_task(user, background_tasks, settings, db) + # Queue email task outside transaction + await _send_verification_email_task(user, background_tasks, settings) return generic_response diff --git a/backend/app/api/dingtalk.py b/backend/app/api/dingtalk.py index c7c0ac9e7..68646f79b 100644 --- a/backend/app/api/dingtalk.py +++ b/backend/app/api/dingtalk.py @@ -449,7 +449,7 @@ async def dingtalk_callback( pass # 2. Get DingTalk provider config - auth_provider = await auth_provider_registry.get_provider(db, "dingtalk", str(tenant_id) if tenant_id else None) + auth_provider = await auth_provider_registry.get_provider("dingtalk", str(tenant_id) if tenant_id else None) if not auth_provider: return HTMLResponse("Auth failed: DingTalk provider not configured for this tenant") diff --git a/backend/app/api/google_workspace.py b/backend/app/api/google_workspace.py index bcceca43c..56972b97b 100644 --- a/backend/app/api/google_workspace.py +++ b/backend/app/api/google_workspace.py @@ -79,7 +79,7 @@ async def _handle_google_sso_callback( auth_provider = GoogleWorkspaceAuthProvider(provider=provider, config=provider.config or {}) else: auth_provider = await auth_provider_registry.get_provider( - db, "google_workspace", str(tenant_id) if tenant_id else None + "google_workspace", str(tenant_id) if tenant_id else None ) if not auth_provider: return HTMLResponse("Auth failed: Google Workspace provider not configured for this tenant") diff --git a/backend/app/api/notification.py b/backend/app/api/notification.py index b726a2c7e..cebfd0b8b 100644 --- a/backend/app/api/notification.py +++ b/backend/app/api/notification.py @@ -183,7 +183,6 @@ async def broadcast_notification( from app.services.system_email_service import ( BroadcastEmailRecipient, deliver_broadcast_emails, - run_background_email_job, ) for user in users: @@ -205,7 +204,7 @@ async def broadcast_notification( await db.commit() if email_recipients: - background_tasks.add_task(run_background_email_job, deliver_broadcast_emails, email_recipients) + background_tasks.add_task(deliver_broadcast_emails, email_recipients) return { "ok": True, "users_notified": count_users, diff --git a/backend/app/api/organization.py b/backend/app/api/organization.py index 578e0a5f8..5d9e428ba 100644 --- a/backend/app/api/organization.py +++ b/backend/app/api/organization.py @@ -97,7 +97,6 @@ async def admin_update_user( if "email" in update_data or "primary_mobile" in update_data: from app.services.registration_service import registration_service await registration_service.sync_org_member_contact_from_user( - db, user, sync_email="email" in update_data, sync_phone="primary_mobile" in update_data, diff --git a/backend/app/api/sso.py b/backend/app/api/sso.py index 6257ef1ae..07048a320 100644 --- a/backend/app/api/sso.py +++ b/backend/app/api/sso.py @@ -125,7 +125,7 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De elif p.provider_type == "dingtalk": from app.services.auth_registry import auth_provider_registry - auth_provider = await auth_provider_registry.get_provider(db, "dingtalk", str(session.tenant_id) if session.tenant_id else None) + auth_provider = await auth_provider_registry.get_provider("dingtalk", str(session.tenant_id) if session.tenant_id else None) if auth_provider: redir = f"{public_base}/api/auth/dingtalk/callback" # Use provider's standardized authorization URL @@ -147,7 +147,7 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De sign_google_sso_state, ) auth_provider = await auth_provider_registry.get_provider( - db, "google_workspace", str(session.tenant_id) if session.tenant_id else None + "google_workspace", str(session.tenant_id) if session.tenant_id else None ) if auth_provider: redir = await get_google_redirect_uri(db, p, request) diff --git a/backend/app/api/wecom.py b/backend/app/api/wecom.py index ed15ed477..26a2d7dd8 100644 --- a/backend/app/api/wecom.py +++ b/backend/app/api/wecom.py @@ -710,7 +710,6 @@ async def wecom_callback( # 2. Extract user info and login/register via RegistrationService try: auth_provider = await auth_provider_registry.get_provider( - db, "wecom", str(tenant_id) if tenant_id else (str(provider.tenant_id) if provider.tenant_id else None), ) diff --git a/backend/app/dao/__init__.py b/backend/app/dao/__init__.py new file mode 100644 index 000000000..d1d5f5102 --- /dev/null +++ b/backend/app/dao/__init__.py @@ -0,0 +1,19 @@ +from app.dao.identity_dao import identity_dao +from app.dao.identity_provider_dao import identity_provider_dao +from app.dao.invitation_code_dao import invitation_code_dao +from app.dao.org_member_dao import org_member_dao +from app.dao.participant_dao import participant_dao +from app.dao.system_setting_dao import system_setting_dao +from app.dao.tenant_dao import tenant_dao +from app.dao.user_dao import user_dao + +__all__ = [ + "identity_dao", + "identity_provider_dao", + "invitation_code_dao", + "org_member_dao", + "participant_dao", + "system_setting_dao", + "tenant_dao", + "user_dao", +] diff --git a/backend/app/dao/base.py b/backend/app/dao/base.py new file mode 100644 index 000000000..5ca484d38 --- /dev/null +++ b/backend/app/dao/base.py @@ -0,0 +1,80 @@ +from collections.abc import AsyncGenerator, Sequence +from contextlib import asynccontextmanager +from typing import Any, Generic, Type, TypeVar + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import Base, _session_ctx, async_session + +ModelType = TypeVar("ModelType", bound=Base) + + +class BaseDAO(Generic[ModelType]): + """Base class for data access objects, managing session context and basic CRUD.""" + + def __init__(self, model: Type[ModelType]): + self.model = model + + @asynccontextmanager + async def session(self) -> AsyncGenerator[AsyncSession, None]: + """Context manager yielding the active context session or a new one.""" + context_session = _session_ctx.get() + if context_session is not None: + yield context_session + else: + async with async_session() as session: + yield session + + async def get(self, id: Any) -> ModelType | None: + """Fetch a single record by its primary key ID.""" + async with self.session() as db: + if hasattr(db, "get"): + return await db.get(self.model, id) + # Fallback for custom mock DB clients in tests + stmt = select(self.model).where(self.model.id == id) + result = await db.execute(stmt) + return result.scalar_one_or_none() + + async def is_empty(self) -> bool: + """Check if the table is empty (no records).""" + async with self.session() as db: + stmt = select(self.model.id).limit(1) + result = await db.execute(stmt) + return result.scalar() is None + + async def get_all(self, skip: int = 0, limit: int = 100) -> Sequence[ModelType]: + """Fetch all records with offset and limit.""" + async with self.session() as db: + stmt = select(self.model).offset(skip).limit(limit) + result = await db.execute(stmt) + return result.scalars().all() + + async def create(self, *, obj_in: dict[str, Any]) -> ModelType: + """Create a new record.""" + async with self.session() as db: + db_obj = self.model(**obj_in) + db.add(db_obj) + await db.flush() + return db_obj + + async def update(self, *, db_obj: ModelType, obj_in: dict[str, Any]) -> ModelType: + """Update an existing record.""" + async with self.session() as db: + for field, value in obj_in.items(): + if hasattr(db_obj, field): + setattr(db_obj, field, value) + db.add(db_obj) + await db.flush() + return db_obj + + async def delete(self, *, id: Any) -> ModelType | None: + """Delete a record by ID.""" + async with self.session() as db: + obj = await self.get(id) + if obj: + if hasattr(db, "delete"): + await db.delete(obj) + await db.flush() + return obj + diff --git a/backend/app/dao/identity_dao.py b/backend/app/dao/identity_dao.py new file mode 100644 index 000000000..fa97df27c --- /dev/null +++ b/backend/app/dao/identity_dao.py @@ -0,0 +1,82 @@ +import re +import uuid +from typing import Any + +from sqlalchemy import select + +from app.dao.base import BaseDAO +from app.models.user import Identity + + +class IdentityDAO(BaseDAO[Identity]): + """DAO for Identity model handling authentication credentials.""" + + def __init__(self) -> None: + super().__init__(Identity) + + async def get_by_login_identifier(self, identifier: str) -> Identity | None: + """Find identity by email, phone, or username.""" + async with self.session() as db: + query = select(Identity).where( + (Identity.email == identifier) | (Identity.phone == identifier) | (Identity.username == identifier) + ) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_by_email(self, email: str) -> Identity | None: + """Find identity by email address.""" + async with self.session() as db: + query = select(Identity).where(Identity.email == email) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_by_username(self, username: str) -> Identity | None: + """Find identity by username.""" + async with self.session() as db: + query = select(Identity).where(Identity.username == username) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_by_phone(self, phone: str) -> Identity | None: + """Find identity by normalized phone number.""" + normalized = re.sub(r"[\s\-\+]", "", phone) + async with self.session() as db: + query = select(Identity).where(Identity.phone == normalized) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def is_username_taken(self, username: str) -> bool: + """Return True if the username is already used by another identity.""" + async with self.session() as db: + result = await db.execute( + select(Identity.id).where(Identity.username == username).limit(1) + ) + return result.scalar_one_or_none() is not None + + async def create_identity( + self, + *, + email: str | None = None, + phone: str | None = None, + username: str | None = None, + password_hash: str | None = None, + is_platform_admin: bool = False, + email_verified: bool = False, + ) -> Identity: + """Create and flush a new Identity row.""" + normalized_phone = re.sub(r"[\s\-\+]", "", phone) if phone else None + async with self.session() as db: + identity = Identity( + email=email, + phone=normalized_phone, + username=username, + password_hash=password_hash, + is_platform_admin=is_platform_admin, + email_verified=email_verified, + ) + db.add(identity) + await db.flush() + return identity + + +identity_dao = IdentityDAO() diff --git a/backend/app/dao/identity_provider_dao.py b/backend/app/dao/identity_provider_dao.py new file mode 100644 index 000000000..8a834bd1c --- /dev/null +++ b/backend/app/dao/identity_provider_dao.py @@ -0,0 +1,61 @@ +"""DAO for IdentityProvider model.""" + +from typing import Any + +from sqlalchemy import select + +from app.dao.base import BaseDAO +from app.models.identity import IdentityProvider + + +class IdentityProviderDAO(BaseDAO[IdentityProvider]): + """DAO for IdentityProvider model.""" + + def __init__(self) -> None: + super().__init__(IdentityProvider) + + async def get_by_type_and_tenant( + self, + provider_type: str, + tenant_id: Any | None, + ) -> IdentityProvider | None: + """Find an IdentityProvider by type scoped to a tenant (or global if None).""" + async with self.session() as db: + query = select(IdentityProvider).where( + IdentityProvider.provider_type == provider_type, + ) + if tenant_id is None: + query = query.where(IdentityProvider.tenant_id.is_(None)) + else: + query = query.where(IdentityProvider.tenant_id == tenant_id) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_or_create( + self, + provider_type: str, + tenant_id: Any | None, + *, + name: str | None = None, + sso_login_enabled: bool = False, + ) -> IdentityProvider: + """Get an existing IdentityProvider or create it if missing.""" + provider = await self.get_by_type_and_tenant(provider_type, tenant_id) + if provider: + return provider + + async with self.session() as db: + provider = IdentityProvider( + provider_type=provider_type, + name=name or provider_type.capitalize(), + is_active=True, + sso_login_enabled=sso_login_enabled, + config={}, + tenant_id=tenant_id, + ) + db.add(provider) + await db.flush() + return provider + + +identity_provider_dao = IdentityProviderDAO() diff --git a/backend/app/dao/invitation_code_dao.py b/backend/app/dao/invitation_code_dao.py new file mode 100644 index 000000000..fa20ea634 --- /dev/null +++ b/backend/app/dao/invitation_code_dao.py @@ -0,0 +1,28 @@ +"""DAO for InvitationCode model.""" + +from sqlalchemy import select + +from app.dao.base import BaseDAO +from app.models.invitation_code import InvitationCode + + +class InvitationCodeDAO(BaseDAO[InvitationCode]): + """DAO for InvitationCode model.""" + + def __init__(self) -> None: + super().__init__(InvitationCode) + + async def get_active_by_code(self, code: str) -> InvitationCode | None: + """Find an active invitation code with a tenant association.""" + async with self.session() as db: + result = await db.execute( + select(InvitationCode).where( + InvitationCode.code == code, + InvitationCode.is_active == True, + InvitationCode.tenant_id.is_not(None), + ) + ) + return result.scalar_one_or_none() + + +invitation_code_dao = InvitationCodeDAO() diff --git a/backend/app/dao/org_member_dao.py b/backend/app/dao/org_member_dao.py new file mode 100644 index 000000000..4ab466580 --- /dev/null +++ b/backend/app/dao/org_member_dao.py @@ -0,0 +1,120 @@ +"""DAO for OrgMember model.""" + +from typing import Any, Sequence + +from sqlalchemy import select + +from app.dao.base import BaseDAO +from app.models.org import OrgMember + + +class OrgMemberDAO(BaseDAO[OrgMember]): + """DAO for OrgMember model.""" + + def __init__(self) -> None: + super().__init__(OrgMember) + + async def find_unbound_by_email( + self, + email: str, + tenant_id: Any, + ) -> OrgMember | None: + """Find an OrgMember without a linked user that matches by email.""" + async with self.session() as db: + result = await db.execute( + select(OrgMember).where( + OrgMember.email == email, + OrgMember.tenant_id == tenant_id, + OrgMember.user_id == None, + ).limit(1) + ) + return result.scalar_one_or_none() + + async def find_unbound_by_phone( + self, + phone: str, + tenant_id: Any, + ) -> OrgMember | None: + """Find an OrgMember without a linked user that matches by phone.""" + async with self.session() as db: + result = await db.execute( + select(OrgMember).where( + OrgMember.phone == phone, + OrgMember.tenant_id == tenant_id, + OrgMember.user_id == None, + ).limit(1) + ) + return result.scalar_one_or_none() + + async def get_by_user_and_provider( + self, + user_id: Any, + tenant_id: Any, + provider_id: Any, + ) -> OrgMember | None: + """Find the OrgMember record for a user under a specific provider.""" + async with self.session() as db: + result = await db.execute( + select(OrgMember).where( + OrgMember.user_id == user_id, + OrgMember.tenant_id == tenant_id, + OrgMember.provider_id == provider_id, + ).limit(1) + ) + return result.scalar_one_or_none() + + async def find_unbound_by_email_and_provider( + self, + email: str, + tenant_id: Any, + provider_id: Any, + ) -> OrgMember | None: + """Find an unlinked OrgMember by email under a specific provider.""" + async with self.session() as db: + result = await db.execute( + select(OrgMember).where( + OrgMember.email == email, + OrgMember.tenant_id == tenant_id, + OrgMember.provider_id == provider_id, + OrgMember.user_id == None, + ).limit(1) + ) + return result.scalar_one_or_none() + + async def find_unbound_by_phone_and_provider( + self, + phone: str, + tenant_id: Any, + provider_id: Any, + ) -> OrgMember | None: + """Find an unlinked OrgMember by phone under a specific provider.""" + async with self.session() as db: + result = await db.execute( + select(OrgMember).where( + OrgMember.phone == phone, + OrgMember.tenant_id == tenant_id, + OrgMember.provider_id == provider_id, + OrgMember.user_id == None, + ).limit(1) + ) + return result.scalar_one_or_none() + + async def get_by_user_and_tenant_and_provider( + self, + user_id: Any, + tenant_id: Any, + provider_id: Any, + ) -> Sequence[OrgMember]: + """Get all OrgMember records for a user+tenant+provider combination.""" + async with self.session() as db: + result = await db.execute( + select(OrgMember).where( + OrgMember.user_id == user_id, + OrgMember.tenant_id == tenant_id, + OrgMember.provider_id == provider_id, + ) + ) + return result.scalars().all() + + +org_member_dao = OrgMemberDAO() diff --git a/backend/app/dao/participant_dao.py b/backend/app/dao/participant_dao.py new file mode 100644 index 000000000..f8435c9d0 --- /dev/null +++ b/backend/app/dao/participant_dao.py @@ -0,0 +1,32 @@ +"""DAO for Participant model.""" + +from app.dao.base import BaseDAO +from app.models.participant import Participant + + +class ParticipantDAO(BaseDAO[Participant]): + """DAO for Participant model.""" + + def __init__(self) -> None: + super().__init__(Participant) + + async def create_for_user( + self, + user_id, + display_name: str | None = None, + avatar_url: str | None = None, + ) -> Participant: + """Create a Participant record linked to a User.""" + async with self.session() as db: + participant = Participant( + type="user", + ref_id=user_id, + display_name=display_name, + avatar_url=avatar_url, + ) + db.add(participant) + await db.flush() + return participant + + +participant_dao = ParticipantDAO() diff --git a/backend/app/dao/system_setting_dao.py b/backend/app/dao/system_setting_dao.py new file mode 100644 index 000000000..ad10b460f --- /dev/null +++ b/backend/app/dao/system_setting_dao.py @@ -0,0 +1,43 @@ +"""DAO for the system_settings key-value table.""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy import select + +from app.dao.base import BaseDAO +from app.models.system_settings import SystemSetting + + +class SystemSettingDAO(BaseDAO[SystemSetting]): + """Typed access layer for platform-level system settings.""" + + def __init__(self) -> None: + super().__init__(SystemSetting) + + async def get_by_key(self, key: str) -> SystemSetting | None: + """Fetch a single SystemSetting row by its primary key.""" + async with self.session() as db: + result = await db.execute(select(SystemSetting).where(SystemSetting.key == key)) + return result.scalar_one_or_none() + + async def get_value(self, key: str, default: Any = None) -> Any: + """Return the JSON value for a key, or *default* when the row is absent.""" + setting = await self.get_by_key(key) + if setting is None: + return default + return setting.value + + async def is_invitation_code_enabled(self) -> bool: + """Return whether invitation-code enforcement is active.""" + value = await self.get_value("invitation_code_enabled", {}) + return bool(value.get("enabled", False)) + + async def is_sso_custom_domain_redirect_enabled(self) -> bool: + """Return whether cross-domain SSO redirect is globally enabled.""" + value = await self.get_value("sso_custom_domain_redirect_enabled", {}) + return bool(value.get("enabled", True)) + + +system_setting_dao = SystemSettingDAO() diff --git a/backend/app/dao/tenant_dao.py b/backend/app/dao/tenant_dao.py new file mode 100644 index 000000000..586af192d --- /dev/null +++ b/backend/app/dao/tenant_dao.py @@ -0,0 +1,43 @@ +from typing import Any, Sequence + +from sqlalchemy import select + +from app.dao.base import BaseDAO +from app.models.tenant import Tenant + + +class TenantDAO(BaseDAO[Tenant]): + """DAO for Tenant model handling organization-scoped records.""" + + def __init__(self) -> None: + super().__init__(Tenant) + + async def get_by_slug(self, slug: str) -> Tenant | None: + """Find a tenant by its unique slug identifier.""" + async with self.session() as db: + query = select(Tenant).where(Tenant.slug == slug) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_by_ids(self, ids: Sequence[Any]) -> Sequence[Tenant]: + """Find multiple tenants by a list of their IDs.""" + if not ids: + return [] + async with self.session() as db: + query = select(Tenant).where(Tenant.id.in_(ids)) + result = await db.execute(query) + return result.scalars().all() + + async def get_by_sso_domain(self, domain: str) -> Tenant | None: + """Find an active tenant matching the given SSO email domain.""" + async with self.session() as db: + result = await db.execute( + select(Tenant).where( + Tenant.sso_domain == domain.lower(), + Tenant.is_active == True, + ) + ) + return result.scalar_one_or_none() + + +tenant_dao = TenantDAO() diff --git a/backend/app/dao/user_dao.py b/backend/app/dao/user_dao.py new file mode 100644 index 000000000..428e98542 --- /dev/null +++ b/backend/app/dao/user_dao.py @@ -0,0 +1,94 @@ +from typing import Any, Sequence + +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from app.dao.base import BaseDAO +from app.models.user import Identity, User + + +class UserDAO(BaseDAO[User]): + """DAO for User model handling tenant-scoped user records.""" + + def __init__(self) -> None: + super().__init__(User) + + async def get_by_identity_and_tenant(self, identity_id: Any, tenant_id: Any | None) -> User | None: + """Find a user in a specific tenant (or tenant-less) by identity ID.""" + async with self.session() as db: + query = select(User).where(User.identity_id == identity_id) + if tenant_id is not None: + query = query.where(User.tenant_id == tenant_id) + else: + query = query.where(User.tenant_id.is_(None)) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_by_identity_id(self, identity_id: Any, include_identity: bool = False) -> Sequence[User]: + """Find all users associated with an identity ID.""" + async with self.session() as db: + query = select(User).where(User.identity_id == identity_id) + if include_identity: + query = query.options(selectinload(User.identity)) + result = await db.execute(query) + return result.scalars().all() + + async def get_by_identity_username(self, username: str) -> User | None: + """Find user by identity username.""" + async with self.session() as db: + query = select(User).join(Identity, User.identity_id == Identity.id).where(Identity.username == username) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_by_email_and_tenant( + self, email: str, tenant_id: Any | None, exclude_user_id: Any | None = None + ) -> User | None: + """Find user by identity email in a specific tenant, optionally excluding a user ID.""" + async with self.session() as db: + query = ( + select(User) + .join(Identity, User.identity_id == Identity.id) + .where( + Identity.email == email, + User.tenant_id == tenant_id, + ) + ) + if exclude_user_id is not None: + query = query.where(User.id != exclude_user_id) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_by_phone_and_tenant( + self, phone: str, tenant_id: Any | None, exclude_user_id: Any | None = None + ) -> User | None: + """Find user by identity phone in a specific tenant, optionally excluding a user ID.""" + async with self.session() as db: + query = ( + select(User) + .join(Identity, User.identity_id == Identity.id) + .where( + Identity.phone == phone, + User.tenant_id == tenant_id, + ) + ) + if exclude_user_id is not None: + query = query.where(User.id != exclude_user_id) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_with_identity(self, user_id: Any) -> User | None: + """Fetch user by ID with identity preloaded.""" + async with self.session() as db: + query = select(User).where(User.id == user_id).options(selectinload(User.identity)) + result = await db.execute(query) + return result.scalar_one_or_none() + + async def get_representative_user_for_identity(self, identity_id: Any) -> User | None: + """Find a representative user (e.g. latest created) associated with an identity ID.""" + async with self.session() as db: + query = select(User).where(User.identity_id == identity_id).order_by(User.created_at.desc()).limit(1) + result = await db.execute(query) + return result.scalar_one_or_none() + + +user_dao = UserDAO() diff --git a/backend/app/database.py b/backend/app/database.py index e9bd7afd4..8c9ad630c 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -1,6 +1,8 @@ """Database connection and session management.""" from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from contextvars import ContextVar from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase @@ -34,3 +36,40 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]: except Exception: await session.rollback() raise + + +_session_ctx: ContextVar[AsyncSession | None] = ContextVar("db_session_ctx", default=None) + + +@asynccontextmanager +async def transaction(session: AsyncSession | None = None) -> AsyncGenerator[AsyncSession, None]: + """Provide a transactional boundary using contextvars.""" + if session is not None: + token = _session_ctx.set(session) + try: + yield session + if hasattr(session, "commit"): + await session.commit() + except Exception: + if hasattr(session, "rollback"): + await session.rollback() + raise + finally: + _session_ctx.reset(token) + return + + existing_session = _session_ctx.get() + if existing_session is not None: + yield existing_session + return + + async with async_session() as session: + token = _session_ctx.set(session) + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + _session_ctx.reset(token) diff --git a/backend/app/scripts/cleanup_duplicate_feishu_users.py b/backend/app/scripts/cleanup_duplicate_feishu_users.py index 0e8a8addd..9300195db 100644 --- a/backend/app/scripts/cleanup_duplicate_feishu_users.py +++ b/backend/app/scripts/cleanup_duplicate_feishu_users.py @@ -33,7 +33,7 @@ async def main(): async with async_session() as db: # ── Step 0: Load org sync app credentials ── - provider = await auth_provider_registry.get_provider(db, "feishu") + provider = await auth_provider_registry.get_provider("feishu") if not provider: logger.warning("No feishu identity provider configured. Cannot resolve user_ids. Skipping backfill.") logger.info("You can still run Sync Now from the UI after configuring feishu identity provider.") diff --git a/backend/app/services/access_relationships.py b/backend/app/services/access_relationships.py index da578a38c..1a6d5c810 100644 --- a/backend/app/services/access_relationships.py +++ b/backend/app/services/access_relationships.py @@ -61,7 +61,7 @@ async def ensure_access_granted_platform_relationships( changed = False for user in users_result.scalars().all(): - member = await registration_service.ensure_web_org_member(db, user) + member = await registration_service.ensure_web_org_member(user) if not member or member.status != "active": continue db.add( diff --git a/backend/app/services/auth_provider.py b/backend/app/services/auth_provider.py index c070dc2c1..9cae3f519 100644 --- a/backend/app/services/auth_provider.py +++ b/backend/app/services/auth_provider.py @@ -138,7 +138,7 @@ async def find_or_create_user( # Update user info and ensure identity is loaded if not user.identity_id: from app.services.registration_service import registration_service - identity = await registration_service.find_or_create_identity(db, email=user_info.email, phone=user_info.mobile) + identity = await registration_service.find_or_create_identity(email=user_info.email, phone=user_info.mobile) user.identity_id = identity.id await self._update_existing_user(db, user, user_info) @@ -159,7 +159,7 @@ async def find_or_create_user( # SSO users should also appear as Web members for tenant-side user management. from app.services.registration_service import registration_service - await registration_service.ensure_web_org_member(db, user) + await registration_service.ensure_web_org_member(user) return user, is_new @@ -219,7 +219,6 @@ async def _create_new_user( effective_id = user_info.provider_user_id or user_info.provider_union_id or "unknown" identity = await registration_service.find_or_create_identity( - db, email=user_info.email, phone=user_info.mobile, username=user_info.email.split("@")[0] if user_info.email else None, diff --git a/backend/app/services/auth_registry.py b/backend/app/services/auth_registry.py index 0e5db7230..f61da5bbf 100644 --- a/backend/app/services/auth_registry.py +++ b/backend/app/services/auth_registry.py @@ -8,6 +8,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.dao import identity_provider_dao from app.models.identity import IdentityProvider from app.services.auth_provider import ( PROVIDER_CLASSES, @@ -34,12 +35,11 @@ def __init__(self): self._cache: dict[str, BaseAuthProvider] = {} async def get_provider( - self, db: AsyncSession, provider_type: str, tenant_id: str | None = None + self, provider_type: str, tenant_id: str | None = None ) -> BaseAuthProvider | None: """Get or create an authentication provider instance. Args: - db: Database session provider_type: The type of provider (feishu, dingtalk, etc.) tenant_id: Optional tenant ID for tenant-specific providers @@ -52,12 +52,13 @@ async def get_provider( return self._cache[cache_key] # Try to get provider config from database - provider_model = await get_preferred_identity_provider( - db, - provider_type, - tenant_id, - is_active=True, - ) + async with identity_provider_dao.session() as db: + provider_model = await get_preferred_identity_provider( + db, + provider_type, + tenant_id, + is_active=True, + ) # Create provider instance provider = self._create_provider(provider_type, provider_model) @@ -86,28 +87,28 @@ def _create_provider( return provider_class(provider=provider_model, config=config) async def list_providers( - self, db: AsyncSession, tenant_id: str | None = None + self, tenant_id: str | None = None ) -> list[IdentityProvider]: """List all available identity providers. Args: - db: Database session tenant_id: Optional tenant ID to filter by Returns: List of IdentityProvider records """ - query = select(IdentityProvider).where(IdentityProvider.is_active == True) - - if tenant_id: - # Only include tenant-specific ones - query = query.where(IdentityProvider.tenant_id == tenant_id) - else: - # Public OAuth login should only expose global providers. - query = query.where(IdentityProvider.tenant_id.is_(None)) - - result = await db.execute(query) - return list(result.scalars().all()) + async with identity_provider_dao.session() as db: + query = select(IdentityProvider).where(IdentityProvider.is_active == True) + + if tenant_id: + # Only include tenant-specific ones + query = query.where(IdentityProvider.tenant_id == tenant_id) + else: + # Public OAuth login should only expose global providers. + query = query.where(IdentityProvider.tenant_id.is_(None)) + + result = await db.execute(query) + return list(result.scalars().all()) async def create_provider( self, diff --git a/backend/app/services/channel_user_service.py b/backend/app/services/channel_user_service.py index bd7e88948..002a89a45 100644 --- a/backend/app/services/channel_user_service.py +++ b/backend/app/services/channel_user_service.py @@ -441,7 +441,6 @@ async def _create_channel_user( # Step 1: Find or create global Identity using unified registration service from app.services.registration_service import registration_service identity = await registration_service.find_or_create_identity( - db, email=email, phone=extra_info.get("mobile"), username=username, @@ -557,7 +556,6 @@ async def get_platform_user_by_org_member( from app.services.registration_service import registration_service # Use unified find_or_create_identity with dual lookup (email/phone) identity = await registration_service.find_or_create_identity( - db, email=email, phone=org_member.phone, username=username, diff --git a/backend/app/services/feishu_service.py b/backend/app/services/feishu_service.py index 5e4aa0928..86d556625 100644 --- a/backend/app/services/feishu_service.py +++ b/backend/app/services/feishu_service.py @@ -296,9 +296,8 @@ async def login_or_register(self, db: AsyncSession, feishu_user: dict, tenant_id from app.services.registration_service import registration_service # No phone available in this specific Feishu login block, but it handles email/username matching identity = await registration_service.find_or_create_identity( - db, email=email, - phone=user_info.get("mobile"), + phone=feishu_user.get("mobile"), username=username, password=open_id, ) diff --git a/backend/app/services/password_reset_service.py b/backend/app/services/password_reset_service.py index 7416993c9..4db859f5f 100644 --- a/backend/app/services/password_reset_service.py +++ b/backend/app/services/password_reset_service.py @@ -7,12 +7,8 @@ import uuid from datetime import datetime, timedelta, timezone -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - from app.config import get_settings from app.core.events import get_redis -from app.models.system_settings import SystemSetting # Key prefixes for Redis TOKEN_PREFIX = "pwd_reset:token:" @@ -53,15 +49,15 @@ async def create_password_reset_token(identity_id: uuid.UUID) -> tuple[str, date return raw_token, expires_at -async def get_public_base_url(db: AsyncSession) -> str: +async def get_public_base_url() -> str: """Resolve the public base URL used for user-facing links.""" from app.services.platform_service import platform_service - return await platform_service.get_public_base_url(db) + return await platform_service.get_public_base_url() -async def build_password_reset_url(db: AsyncSession, raw_token: str) -> str: +async def build_password_reset_url(raw_token: str) -> str: """Build the user-facing reset URL.""" - base_url = await get_public_base_url(db) + base_url = await get_public_base_url() return f"{base_url}/reset-password?token={raw_token}" diff --git a/backend/app/services/platform_service.py b/backend/app/services/platform_service.py index fa93d7ac2..22e2fa109 100644 --- a/backend/app/services/platform_service.py +++ b/backend/app/services/platform_service.py @@ -3,9 +3,7 @@ import os import re from fastapi import Request -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.models.system_settings import SystemSetting class PlatformService: """Service to handle platform-wide settings and URL resolution.""" @@ -40,15 +38,20 @@ async def get_public_base_url(self, db: AsyncSession | None = None, request: Req return "https://try.clawith.ai" - async def get_tenant_sso_base_url(self, db: AsyncSession, tenant, request: Request | None = None) -> str: - """Generate the SSO base URL for a tenant based on IP/Domain logic.""" - # Check if custom domain SSO redirect is enabled globally - setting_result = await db.execute( - select(SystemSetting).where(SystemSetting.key == "sso_custom_domain_redirect_enabled") - ) - setting_s = setting_result.scalar_one_or_none() - sso_redirect_enabled = setting_s.value.get("enabled", True) if setting_s else True + async def get_tenant_sso_base_url( + self, + db: AsyncSession, + tenant, + request: Request | None = None, + *, + sso_redirect_enabled: bool = True, + ) -> str: + """Generate the SSO base URL for a tenant based on IP/Domain logic. + ``sso_redirect_enabled`` should be pre-resolved by the caller via + ``system_setting_dao.is_sso_custom_domain_redirect_enabled()`` so this + method never issues an extra DB round-trip for the setting. + """ if sso_redirect_enabled and tenant.sso_domain: return tenant.sso_domain.rstrip("/") @@ -56,21 +59,21 @@ async def get_tenant_sso_base_url(self, db: AsyncSession, tenant, request: Reque return await self.get_public_base_url(db, request) base_url = await self.get_public_base_url(db, request) - + # Parse protocol and host # Example: http://1.2.3.4:8000 or http://clawith.ai parts = base_url.split("://") if len(parts) < 2: return base_url - + protocol = parts[0] host_port = parts[1] - + # Split host and port host_parts = host_port.split(":") host = host_parts[0] port = f":{host_parts[1]}" if len(host_parts) > 1 else "" - + if self.is_ip_address(host): # IP: No subdomain, just base URL return base_url @@ -79,15 +82,15 @@ async def get_tenant_sso_base_url(self, db: AsyncSession, tenant, request: Reque # Special case for localhost: keep it as is or handle it if host == "localhost": return f"{protocol}://{host}{port}" - - # Generic logic: if host has a subdomain (e.g. try.clawith.ai), + + # Generic logic: if host has a subdomain (e.g. try.clawith.ai), # we strip the first component to form a base for tenant subdomains. h_parts = host.split(".") if len(h_parts) > 2: target_host = ".".join(h_parts[1:]) else: target_host = host - + return f"{protocol}://{tenant.slug}.{target_host}{port}" diff --git a/backend/app/services/registration_service.py b/backend/app/services/registration_service.py index 601b7a95a..a999d5fac 100644 --- a/backend/app/services/registration_service.py +++ b/backend/app/services/registration_service.py @@ -8,27 +8,34 @@ import re import uuid -from datetime import datetime from typing import Any -from sqlalchemy import select, or_, and_ -from sqlalchemy.ext.asyncio import AsyncSession - from app.config import get_settings from app.core.security import hash_password_async +from app.dao import ( + identity_dao, + identity_provider_dao, + invitation_code_dao, + org_member_dao, + participant_dao, + tenant_dao, + user_dao, +) from app.models.identity import IdentityProvider from app.models.tenant import Tenant from app.models.user import User, Identity from app.services.sso_service import sso_service +from app.services.system_email_service import resolve_email_config_async from loguru import logger class RegistrationService: """Service for handling user registration flows.""" + # ── Identity provider ──────────────────────────────────────────────────── + async def ensure_identity_provider( self, - db: AsyncSession, provider_type: str, tenant_id: uuid.UUID | None, *, @@ -36,178 +43,120 @@ async def ensure_identity_provider( sso_login_enabled: bool = False, ) -> IdentityProvider: """Get or create an identity provider record for a tenant.""" - query = select(IdentityProvider).where( - IdentityProvider.provider_type == provider_type, - ) - if tenant_id is None: - query = query.where(IdentityProvider.tenant_id.is_(None)) - else: - query = query.where(IdentityProvider.tenant_id == tenant_id) - - result = await db.execute(query) - provider = result.scalar_one_or_none() - if provider: - return provider - - provider = IdentityProvider( - provider_type=provider_type, - name=name or provider_type.capitalize(), - is_active=True, + return await identity_provider_dao.get_or_create( + provider_type, + tenant_id, + name=name, sso_login_enabled=sso_login_enabled, - config={}, - tenant_id=tenant_id, ) - db.add(provider) - await db.flush() - return provider - - async def detect_tenant_by_email(self, db: AsyncSession, email: str) -> Tenant | None: - """Detect tenant based on email domain. - Args: - db: Database session - email: User email address + # ── Tenant detection ───────────────────────────────────────────────────── - Returns: - Tenant if found by domain match, None otherwise - """ + async def detect_tenant_by_email(self, email: str) -> Tenant | None: + """Detect tenant based on email domain.""" if not email or "@" not in email: return None - domain = email.split("@")[1].lower() + return await tenant_dao.get_by_sso_domain(domain) - # Try to find tenant by custom domain - Exact match to use index - result = await db.execute( - select(Tenant).where( - Tenant.sso_domain == domain, - Tenant.is_active == True, - ) - ) - return result.scalar_one_or_none() + # ── Duplicate check ────────────────────────────────────────────────────── async def check_duplicate_identity( self, - db: AsyncSession, email: str | None = None, mobile: str | None = None, ) -> dict[str, Any]: - """Check for existing identities or tenant-users that might conflict. - - Args: - db: Database session - email: Email address - mobile: Mobile phone - username: Username - tenant_id: Optional tenant to scope the search (for tenant-user conflicts) + """Check for existing identities that might conflict. Returns: - Dict with conflict information + Dict with ``has_conflict`` bool and ``conflicts`` list. """ conflicts = [] - # 1. Check Global Identity Conflicts - if email: - ident_result = await db.execute(select(Identity).where(Identity.email == email)) - if ident_result.scalar_one_or_none(): - conflicts.append({ - "type": "email", - "scope": "global", - "message": "Email already registered", - }) - + if email and await identity_dao.get_by_email(email): + conflicts.append({ + "type": "email", + "scope": "global", + "message": "Email already registered", + }) + if mobile: - normalized_mobile = re.sub(r"[\s\-\+]", "", mobile) - ident_result = await db.execute(select(Identity).where(Identity.phone == normalized_mobile)) - if ident_result.scalar_one_or_none(): + normalized = re.sub(r"[\s\-\+]", "", mobile) + if await identity_dao.get_by_phone(normalized): conflicts.append({ "type": "mobile", "scope": "global", "message": "Mobile already registered", }) - return { - "has_conflict": len(conflicts) > 0, - "conflicts": conflicts, - } + return {"has_conflict": len(conflicts) > 0, "conflicts": conflicts} + + # ── Identity find / create ─────────────────────────────────────────────── async def find_or_create_identity( self, - db: AsyncSession, email: str | None = None, phone: str | None = None, username: str | None = None, password: str | None = None, is_platform_admin: bool = False, email_config: Any = None, + password_hash: str | None = None, ) -> Identity: """Find an existing identity or create a new one. Security note: only email and phone are authoritative identity claims. """ - identity = None + identity: Identity | None = None # Match by email (primary ownership claim) if email: - res = await db.execute(select(Identity).where(Identity.email == email)) - identity = res.scalar_one_or_none() + identity = await identity_dao.get_by_email(email) # Match by phone (secondary ownership claim) if not identity and phone: - normalized_phone = re.sub(r"[\s\-\+]", "", phone) - res = await db.execute(select(Identity).where(Identity.phone == normalized_phone)) - identity = res.scalar_one_or_none() + identity = await identity_dao.get_by_phone(phone) if identity: # Auto-verify if SMTP is not configured if not email_config: - from app.services.system_email_service import resolve_email_config_async - email_config = await resolve_email_config_async(db) - - if not email_config: - if not identity.email_verified: - identity.email_verified = True - db.add(identity) + email_config = await resolve_email_config_async() + if not email_config and not identity.email_verified: + await identity_dao.update(db_obj=identity, obj_in={"email_verified": True}) return identity - # Check if SMTP is configured for auto-verification + # Determine verified status if not email_config: - from app.services.system_email_service import resolve_email_config_async - email_config = await resolve_email_config_async(db) - - is_verified = not email_config # Auto-verify only if no SMTP configured + email_config = await resolve_email_config_async() + is_verified = not email_config # Auto-verify only when no SMTP configured - # Resolve a safe username + # Resolve a safe, unique username final_username = username - if username: - # Use EXISTS for faster lookup - existing_res = await db.execute( - select(Identity.id).where(Identity.username == username).limit(1) + if username and await identity_dao.is_username_taken(username): + final_username = f"{username}_{uuid.uuid4().hex[:6]}" + logger.info( + "Username '%s' already taken; assigned '%s' to new identity", + username, + final_username, ) - if existing_res.scalar_one_or_none(): - final_username = f"{username}_{uuid.uuid4().hex[:6]}" - logger.info( - "Username '%s' already taken; assigned '%s' to new identity", - username, - final_username, - ) - # Create new identity - normalized_phone = re.sub(r"[\s\-\+]", "", phone) if phone else None - identity = Identity( + # Hash password if not pre-hashed + if not password_hash and password: + password_hash = await hash_password_async(password) + + return await identity_dao.create_identity( email=email, - phone=normalized_phone, + phone=phone, username=final_username, - password_hash=await hash_password_async(password) if password else None, + password_hash=password_hash, is_platform_admin=is_platform_admin, email_verified=is_verified, ) - db.add(identity) - await db.flush() - return identity + + # ── User create ────────────────────────────────────────────────────────── async def create_user_with_identity( self, - db: AsyncSession, identity: Identity, display_name: str | None = None, role: str = "member", @@ -218,159 +167,116 @@ async def create_user_with_identity( """Create a new tenant-specific user linked to an identity.""" name = display_name or identity.username or "User" - # Check if SMTP is configured for auto-activation if not email_config: - from app.services.system_email_service import resolve_email_config_async - email_config = await resolve_email_config_async(db) - + email_config = await resolve_email_config_async() + is_active = identity.email_verified if not email_config: - is_active = True # Auto-activate if no SMTP configured - - # Create tenant-user record - user = User( - identity_id=identity.id, - tenant_id=tenant_id, - display_name=name, - role=role, - registration_source=registration_source, - is_active=is_active or identity.is_platform_admin, - ) + is_active = True # Auto-activate when no SMTP configured - db.add(user) - await db.flush() + user = await user_dao.create(obj_in={ + "identity_id": identity.id, + "tenant_id": tenant_id, + "display_name": name, + "role": role, + "registration_source": registration_source, + "is_active": is_active or identity.is_platform_admin, + }) # Link to OrgMember if exists - await self.bind_org_member(db, user) + await self.bind_org_member(user) # Create Participant record - from app.models.participant import Participant - db.add(Participant( - type="user", - ref_id=user.id, + await participant_dao.create_for_user( + user.id, display_name=user.display_name, avatar_url=user.avatar_url, - )) + ) - await db.flush() return user + # ── SSO flows ──────────────────────────────────────────────────────────── + async def handle_sso_registration( self, - db: AsyncSession, provider_type: str, provider_user_id: str, user_info: dict, existing_user: User | None = None, ) -> tuple[User, bool]: - """Handle SSO-based registration flow. - - If existing_user is provided, links the identity to that user. - Otherwise, creates a new user or returns existing one. - - Args: - db: Database session - provider_type: Provider type (feishu, dingtalk, etc.) - provider_user_id: User ID in external system - user_info: User info from provider - existing_user: Optional existing user to link to - - Returns: - Tuple of (user, is_new) - """ - # Try to detect tenant from email + """Handle SSO-based registration flow.""" email = user_info.get("email", "") - tenant = None tenant_id = None if email: - tenant = await self.detect_tenant_by_email(db, email) + tenant = await self.detect_tenant_by_email(email) tenant_id = tenant.id if tenant else None - # Check if identity already exists - lookup_provider_user_id = user_info.get("union_id") or user_info.get("unionId") or provider_user_id - existing = await sso_service.resolve_user_identity( - db, - lookup_provider_user_id, - provider_type, - tenant_id=tenant_id, - identity_data=user_info, + lookup_provider_user_id = ( + user_info.get("union_id") or user_info.get("unionId") or provider_user_id ) - - if existing: - # Identity already linked - return existing, False - - if existing_user: - # Link to existing user - await sso_service.link_identity( + async with identity_dao.session() as db: + existing = await sso_service.resolve_user_identity( db, - str(existing_user.id), - provider_type, lookup_provider_user_id, - user_info, - tenant_id=str(existing_user.tenant_id) if existing_user.tenant_id else tenant_id, + provider_type, + tenant_id=tenant_id, + identity_data=user_info, ) - return existing_user, False - - # (moved up) - pass + if existing: + return existing, False - # Step 2: Ensure Identity exists - # Generate username from email or provider ID (fallback to open_id) - effective_id = provider_user_id or user_info.get("open_id") or user_info.get("union_id") or uuid.uuid4().hex[:8] + if existing_user: + await sso_service.link_identity( + db, + str(existing_user.id), + provider_type, + lookup_provider_user_id, + user_info, + tenant_id=str(existing_user.tenant_id) if existing_user.tenant_id else tenant_id, + ) + return existing_user, False + + # Create new Identity + User + effective_id = ( + provider_user_id + or user_info.get("open_id") + or user_info.get("union_id") + or uuid.uuid4().hex[:8] + ) username = email.split("@")[0] if email else f"{provider_type}_{effective_id[:8]}" identity = await self.find_or_create_identity( - db, email=email, phone=user_info.get("mobile") or user_info.get("phone"), username=username, - password=effective_id, # Placeholder for SSO users + password=effective_id, ) - - # Step 3: Create User linked to Identity user = await self.create_user_with_identity( - db, identity=identity, display_name=user_info.get("name", username), registration_source=provider_type, tenant_id=tenant_id, ) - return user, True async def register_with_sso( self, - db: AsyncSession, provider_type: str, code: str, auth_provider, ) -> tuple[User, bool, str | None]: - """Register or login user via SSO. - - Args: - db: Database session - provider_type: Provider type - code: OAuth authorization code - auth_provider: Auth provider instance - - Returns: - Tuple of (user, is_new, error_message) - """ + """Register or login user via SSO.""" try: - # Exchange code for token token_data = await auth_provider.exchange_code_for_token(code) access_token = token_data.get("access_token") if not access_token: return None, False, "Failed to get access token from provider" - # Get user info from app.services.auth_provider import ExternalUserInfo user_info_obj = await auth_provider.get_user_info(access_token) - # Convert to dict user_info = { "name": user_info_obj.name, "email": user_info_obj.email, @@ -379,249 +285,183 @@ async def register_with_sso( "raw_data": user_info_obj.raw_data, } - # Try to detect tenant from email email_addr = user_info_obj.email tenant_id = None if email_addr: - tenant = await self.detect_tenant_by_email(db, email_addr) + tenant = await self.detect_tenant_by_email(email_addr) tenant_id = tenant.id if tenant else None - # Try to find existing user by identity - lookup_provider_user_id = user_info_obj.provider_union_id or user_info_obj.provider_user_id - existing_user = await sso_service.resolve_user_identity( - db, - lookup_provider_user_id, - provider_type, - tenant_id=tenant_id, - identity_data=user_info, + lookup_provider_user_id = ( + user_info_obj.provider_union_id or user_info_obj.provider_user_id ) + async with identity_dao.session() as db: + existing_user = await sso_service.resolve_user_identity( + db, + lookup_provider_user_id, + provider_type, + tenant_id=tenant_id, + identity_data=user_info, + ) + if existing_user: + return existing_user, False, None - if existing_user: - # Update last login - return existing_user, False, None - - # Also try matching by email - if user_info_obj.email: - existing_by_email = await sso_service.match_user_by_email(db, user_info_obj.email, tenant_id=tenant_id) - if existing_by_email: - # Link identity to existing user - await sso_service.link_identity( - db, - str(existing_by_email.id), - provider_type, - lookup_provider_user_id, - user_info, - tenant_id=str(existing_by_email.tenant_id) if existing_by_email.tenant_id else tenant_id, + if user_info_obj.email: + existing_by_email = await sso_service.match_user_by_email( + db, user_info_obj.email, tenant_id=tenant_id ) - return existing_by_email, False, None + if existing_by_email: + await sso_service.link_identity( + db, + str(existing_by_email.id), + provider_type, + lookup_provider_user_id, + user_info, + tenant_id=( + str(existing_by_email.tenant_id) + if existing_by_email.tenant_id + else tenant_id + ), + ) + return existing_by_email, False, None - # Create new user user, is_new = await self.handle_sso_registration( - db, provider_type, lookup_provider_user_id, user_info, ) - # Bind to OrgMember via email/phone if possible - await self.bind_org_member(db, user) - + await self.bind_org_member(user) return user, is_new, None - except Exception as e: + except Exception: logger.exception("SSO registration failed for %s provider", provider_type) - return None, False, f"SSO registration failed: {str(e)}" + return None, False, f"SSO registration failed" + + # ── Tenant for registration ────────────────────────────────────────────── async def get_tenant_for_registration( - self, db: AsyncSession, email: str | None = None, invitation_code: str | None = None + self, + email: str | None = None, + invitation_code: str | None = None, ) -> tuple[Tenant | None, str]: - """Determine tenant for new user registration. - - Args: - db: Database session - email: User email (for domain matching) - invitation_code: Invitation code (for tenant association) - - Returns: - Tuple of (tenant, error_message) - """ - # First check invitation code + """Determine tenant for new user registration.""" if invitation_code: - from app.models.invitation_code import InvitationCode - result = await db.execute( - select(InvitationCode).where( - InvitationCode.code == invitation_code, - InvitationCode.is_active == True, - InvitationCode.tenant_id.is_not(None), - ) - ) - inv = result.scalar_one_or_none() + inv = await invitation_code_dao.get_active_by_code(invitation_code) if inv and inv.used_count < inv.max_uses: - # Get tenant from invitation - tenant_result = await db.execute(select(Tenant).where(Tenant.id == inv.tenant_id)) - tenant = tenant_result.scalar_one_or_none() - if tenant and tenant.is_active: - return tenant, None + t = await tenant_dao.get(inv.tenant_id) + if t and t.is_active: + return t, None return None, "Invitation code tenant is inactive" - # Try email domain matching if email: - tenant = await self.detect_tenant_by_email(db, email) + tenant = await self.detect_tenant_by_email(email) if tenant: return tenant, None - # No tenant association - user will need to create/join return None, None - async def bind_org_member(self, db: AsyncSession, user: User) -> None: - """Find and bind OrgMember to User based on email/phone and tenant_id. - - This establishes the link between a platform user and their entry in the - synchronized organizational structure. - """ + # ── OrgMember binding ──────────────────────────────────────────────────── + + async def bind_org_member(self, user: User) -> None: + """Find and bind OrgMember to User based on email/phone and tenant_id.""" if not user.tenant_id: return - from app.models.org import OrgMember - member = await self._find_unbound_org_member_by_contact(db, user) + member = await self._find_unbound_org_member_by_contact(user) if member: member.user_id = user.id if user.email and member.email != user.email: member.email = user.email elif not user.email and member.email: user.email = member.email - if user.primary_mobile and member.phone != user.primary_mobile: member.phone = user.primary_mobile elif not user.primary_mobile and member.phone: user.primary_mobile = member.phone - await db.flush() + + async with org_member_dao.session() as db: + await db.flush() from app.services.okr_agent_hook import hook_new_org_member - await hook_new_org_member(db, member.id, user.tenant_id) + async with org_member_dao.session() as db: + await hook_new_org_member(db, member.id, user.tenant_id) - await self.ensure_web_org_member(db, user) - - async def _find_unbound_org_member_by_contact( - self, - db: AsyncSession, - user: User, - ): - from app.models.org import OrgMember + await self.ensure_web_org_member(user) + async def _find_unbound_org_member_by_contact(self, user: User): if user.email: - result = await db.execute( - select(OrgMember).where( - OrgMember.email == user.email, - OrgMember.tenant_id == user.tenant_id, - OrgMember.user_id == None, - ).limit(1) - ) - member = result.scalar_one_or_none() + member = await org_member_dao.find_unbound_by_email(user.email, user.tenant_id) if member: return member - if user.primary_mobile: - result = await db.execute( - select(OrgMember).where( - OrgMember.phone == user.primary_mobile, - OrgMember.tenant_id == user.tenant_id, - OrgMember.user_id == None, - ).limit(1) - ) - member = result.scalar_one_or_none() - if member: - return member - + return await org_member_dao.find_unbound_by_phone(user.primary_mobile, user.tenant_id) return None - async def ensure_web_org_member(self, db: AsyncSession, user: User): + async def ensure_web_org_member(self, user: User): """Ensure the user has a dedicated platform OrgMember record in their tenant.""" if not user.tenant_id: return None from app.models.org import OrgMember - web_provider = await self.ensure_identity_provider( - db, - "web", - user.tenant_id, - name="Platform", - ) + web_provider = await self.ensure_identity_provider("web", user.tenant_id, name="Platform") if web_provider.name == "Web": web_provider.name = "Platform" - result = await db.execute( - select(OrgMember).where( - OrgMember.user_id == user.id, - OrgMember.tenant_id == user.tenant_id, - OrgMember.provider_id == web_provider.id, - ).limit(1) + # Look up existing OrgMember + member = await org_member_dao.get_by_user_and_provider( + user.id, user.tenant_id, web_provider.id ) - member = result.scalar_one_or_none() - if not member and user.email: - result = await db.execute( - select(OrgMember).where( - OrgMember.email == user.email, - OrgMember.tenant_id == user.tenant_id, - OrgMember.provider_id == web_provider.id, - OrgMember.user_id == None, - ).limit(1) + member = await org_member_dao.find_unbound_by_email_and_provider( + user.email, user.tenant_id, web_provider.id ) - member = result.scalar_one_or_none() - if not member and user.primary_mobile: - result = await db.execute( - select(OrgMember).where( - OrgMember.phone == user.primary_mobile, - OrgMember.tenant_id == user.tenant_id, - OrgMember.provider_id == web_provider.id, - OrgMember.user_id == None, - ).limit(1) + member = await org_member_dao.find_unbound_by_phone_and_provider( + user.primary_mobile, user.tenant_id, web_provider.id ) - member = result.scalar_one_or_none() created = False linked_existing = False - if member: - linked_existing = member.user_id is None - member.user_id = user.id - else: - member = OrgMember( - name=user.display_name or "User", - email=user.email, - phone=user.primary_mobile, - provider_id=web_provider.id, - title="Platform User", - tenant_id=user.tenant_id, - user_id=user.id, - status="active", - ) - db.add(member) - created = True + async with org_member_dao.session() as db: + if member: + linked_existing = member.user_id is None + member.user_id = user.id + else: + member = OrgMember( + name=user.display_name or "User", + email=user.email, + phone=user.primary_mobile, + provider_id=web_provider.id, + title="Platform User", + tenant_id=user.tenant_id, + user_id=user.id, + status="active", + ) + db.add(member) + created = True - desired_name = user.display_name or member.name or "User" - if desired_name and member.name != desired_name: - member.name = desired_name - if member.email != user.email: - member.email = user.email - if member.phone != user.primary_mobile: - member.phone = user.primary_mobile - if member.title in (None, "", "Web User"): - member.title = "Platform User" + desired_name = user.display_name or member.name or "User" + if desired_name and member.name != desired_name: + member.name = desired_name + if member.email != user.email: + member.email = user.email + if member.phone != user.primary_mobile: + member.phone = user.primary_mobile + if member.title in (None, "", "Web User"): + member.title = "Platform User" - await db.flush() + await db.flush() if created or linked_existing: from app.services.okr_agent_hook import hook_new_org_member - await hook_new_org_member(db, member.id, user.tenant_id) + async with org_member_dao.session() as db: + await hook_new_org_member(db, member.id, user.tenant_id) return member async def sync_org_member_contact_from_user( self, - db: AsyncSession, user: User, *, sync_email: bool = False, @@ -631,29 +471,23 @@ async def sync_org_member_contact_from_user( if not user.tenant_id or not (sync_email or sync_phone): return - from app.models.org import OrgMember - web_provider = await self.ensure_identity_provider(db, "web", user.tenant_id, name="Platform") + web_provider = await self.ensure_identity_provider("web", user.tenant_id, name="Platform") if web_provider.name == "Web": web_provider.name = "Platform" - result = await db.execute( - select(OrgMember).where( - OrgMember.user_id == user.id, - OrgMember.tenant_id == user.tenant_id, - OrgMember.provider_id == web_provider.id, - ) + members = await org_member_dao.get_by_user_and_tenant_and_provider( + user.id, user.tenant_id, web_provider.id ) - members = result.scalars().all() if not members: return - for member in members: - if sync_email and member.email != user.email: - member.email = user.email - if sync_phone and member.phone != user.primary_mobile: - member.phone = user.primary_mobile - - await db.flush() + async with org_member_dao.session() as db: + for member in members: + if sync_email and member.email != user.email: + member.email = user.email + if sync_phone and member.phone != user.primary_mobile: + member.phone = user.primary_mobile + await db.flush() # Global registration service diff --git a/backend/app/services/system_email_service.py b/backend/app/services/system_email_service.py index db50d5fcd..14c90a2b3 100644 --- a/backend/app/services/system_email_service.py +++ b/backend/app/services/system_email_service.py @@ -51,19 +51,18 @@ class BroadcastEmailRecipient: -async def resolve_email_config_async(db, *, include_disabled: bool = False) -> SystemEmailConfig | None: - """Resolve email configuration by searching in order: - 1. Platform-level settings in DB ('system_email_platform') +async def resolve_email_config_async(db=None, *, include_disabled: bool = False) -> SystemEmailConfig | None: + """Resolve email configuration from the 'system_email_platform' system setting. + + ``db`` is accepted for call-site compatibility but ignored — the lookup + goes through ``system_setting_dao`` which manages its own session. """ - from sqlalchemy import select - from app.models.system_settings import SystemSetting + from app.dao import system_setting_dao - # 1. Try platform-level config in DB + # Try platform-level config in DB try: - result = await db.execute(select(SystemSetting).where(SystemSetting.key == "system_email_platform")) - setting = result.scalar_one_or_none() - if setting and setting.value: - v = setting.value + v = await system_setting_dao.get_value("system_email_platform", {}) + if v: if v.get("SYSTEM_EMAIL_ENABLED") is False and not include_disabled: return None if v.get("SYSTEM_EMAIL_FROM_ADDRESS") and v.get("SYSTEM_SMTP_HOST"): @@ -72,7 +71,8 @@ async def resolve_email_config_async(db, *, include_disabled: bool = False) -> S from_name=str(v.get("SYSTEM_EMAIL_FROM_NAME", "Clawith")).strip() or "Clawith", smtp_host=str(v.get("SYSTEM_SMTP_HOST", "")).strip(), smtp_port=int(v.get("SYSTEM_SMTP_PORT", 465)), - smtp_username=str(v.get("SYSTEM_SMTP_USERNAME", "")).strip() or str(v.get("SYSTEM_EMAIL_FROM_ADDRESS", "")).strip(), + smtp_username=str(v.get("SYSTEM_SMTP_USERNAME", "")).strip() + or str(v.get("SYSTEM_EMAIL_FROM_ADDRESS", "")).strip(), smtp_password=str(v.get("SYSTEM_SMTP_PASSWORD", "")), smtp_ssl=bool(v.get("SYSTEM_SMTP_SSL", True)), smtp_timeout_seconds=max(1, int(v.get("SYSTEM_SMTP_TIMEOUT_SECONDS", 15))), @@ -90,14 +90,9 @@ async def send_system_email(to: str, subject: str, body: str, db=None) -> None: to: Recipient email address subject: Email subject body: Email body text - db: Optional database session + db: Ignored; kept for call-site compatibility """ - if not db: - from app.database import async_session - async with async_session() as session: - config = await resolve_email_config_async(session) - else: - config = await resolve_email_config_async(db) + config = await resolve_email_config_async() if not config: logger.warning(f"System email not configured, skipped sending to {to}") @@ -233,36 +228,23 @@ async def deliver_broadcast_emails(recipients: Iterable[BroadcastEmailRecipient] } -async def get_email_templates(db=None) -> dict[str, dict[str, str]]: +async def get_email_templates() -> dict[str, dict[str, str]]: """Load email templates from DB, falling back to defaults. Returns: A dict mapping scenario_key -> {"subject": str, "body": str} """ - from sqlalchemy import select - from app.models.system_settings import SystemSetting - templates = dict(DEFAULT_EMAIL_TEMPLATES) # start with defaults + return await _load_templates_from_db(templates) - if not db: - from app.database import async_session - async with async_session() as session: - return await _load_templates_from_db(session, templates) - return await _load_templates_from_db(db, templates) - -async def _load_templates_from_db(db, templates: dict) -> dict: +async def _load_templates_from_db(templates: dict) -> dict: """Internal helper: overlay DB-saved templates on top of defaults.""" - from sqlalchemy import select - from app.models.system_settings import SystemSetting + from app.dao import system_setting_dao try: - result = await db.execute( - select(SystemSetting).where(SystemSetting.key == "email_templates") - ) - setting = result.scalar_one_or_none() - if setting and setting.value: - saved = setting.value + saved = await system_setting_dao.get_value("email_templates", {}) + if saved: for key in templates: if key in saved and isinstance(saved[key], dict): # Only override subject/body if present and non-empty @@ -287,19 +269,19 @@ def _render_template(template_str: str, variables: dict[str, str]) -> str: async def render_email_template( scenario_key: str, variables: dict[str, str], - db=None, + db=None, # kept for call-site compat, ignored ) -> tuple[str, str]: """Render an email template for a given scenario. Args: scenario_key: One of the known scenario keys (e.g. 'email_verification') variables: Dict of variable_name -> value to substitute - db: Optional database session + db: Ignored; kept for backward-compatibility Returns: (subject, body) tuple with variables substituted """ - templates = await get_email_templates(db=db) + templates = await get_email_templates() template = templates.get(scenario_key, DEFAULT_EMAIL_TEMPLATES.get(scenario_key, {})) subject = _render_template(template.get("subject", ""), variables) @@ -312,15 +294,9 @@ async def send_test_email(to: str, db=None) -> None: Args: to: Recipient email address - db: Optional database session for resolving config + db: Ignored; kept for call-site compatibility """ - config = None - if not db: - from app.database import async_session - async with async_session() as session: - config = await resolve_email_config_async(session, include_disabled=True) - else: - config = await resolve_email_config_async(db, include_disabled=True) + config = await resolve_email_config_async(include_disabled=True) if not config: raise RuntimeError("System email SMTP settings are not configured.") diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index 39eb3eb3b..b805801f3 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -9,6 +9,15 @@ from app.api import auth as auth_api from app.core.security import hash_password +from app.database import _session_ctx + + +async def run_with_db(db, func, *args, **kwargs): + token = _session_ctx.set(db) + try: + return await func(*args, **kwargs) + finally: + _session_ctx.reset(token) # --------------------------------------------------------------------------- @@ -111,7 +120,7 @@ async def test_login_invalid_credentials_no_identity(): bg = AsyncMock() with pytest.raises(HTTPException) as exc: - await auth_api.login(data, bg, db) + await run_with_db(db, auth_api.login, data, bg) assert exc.value.status_code == 401 @@ -124,7 +133,7 @@ async def test_login_invalid_credentials_wrong_password(): bg = AsyncMock() with pytest.raises(HTTPException) as exc: - await auth_api.login(data, bg, db) + await run_with_db(db, auth_api.login, data, bg) assert exc.value.status_code == 401 @@ -137,7 +146,7 @@ async def test_login_disabled_account(): bg = AsyncMock() with pytest.raises(HTTPException) as exc: - await auth_api.login(data, bg, db) + await run_with_db(db, auth_api.login, data, bg) assert exc.value.status_code == 403 assert "disabled" in str(exc.value.detail).lower() @@ -157,7 +166,7 @@ async def test_login_unverified_email(): with patch("app.services.system_email_service.resolve_email_config_async", new_callable=AsyncMock, return_value={"host": "localhost"}): with patch.object(auth_api, "_send_verification_email_task", new_callable=AsyncMock): with pytest.raises(HTTPException) as exc: - await auth_api.login(data, bg, db) + await run_with_db(db, auth_api.login, data, bg) assert exc.value.status_code == 403 assert exc.value.detail["needs_verification"] is True @@ -222,7 +231,7 @@ def __init__(self, access_token, **kwargs): with patch("app.api.auth.UserOut") as MockUserOut: MockUserOut.model_validate.return_value = {"id": str(user.id)} with patch.object(auth_api, "create_access_token", return_value="jwt-token"): - result = await auth_api.oauth_callback("google", data, RecordingDB()) + result = await run_with_db(RecordingDB(), auth_api.oauth_callback, "google", data) provider.exchange_code_for_token.assert_awaited_once_with("oauth-code", "https://example.com/oauth/callback/google") assert result.access_token == "jwt-token" diff --git a/backend/tests/test_auth_provider.py b/backend/tests/test_auth_provider.py index 1bbf25f85..dc99a0dc7 100644 --- a/backend/tests/test_auth_provider.py +++ b/backend/tests/test_auth_provider.py @@ -135,7 +135,12 @@ async def test_auth_registry_uses_preferred_provider_when_duplicates_exist(): db = _DummyDB([[provider]]) registry = AuthProviderRegistry() - result = await registry.get_provider(db, "google_workspace", str(tenant_id)) + from app.database import _session_ctx + token = _session_ctx.set(db) + try: + result = await registry.get_provider("google_workspace", str(tenant_id)) + finally: + _session_ctx.reset(token) assert result is not None assert result.provider is provider diff --git a/backend/tests/test_password_reset_and_notifications.py b/backend/tests/test_password_reset_and_notifications.py index 03ddec433..7b353aa52 100644 --- a/backend/tests/test_password_reset_and_notifications.py +++ b/backend/tests/test_password_reset_and_notifications.py @@ -9,10 +9,16 @@ from app.api import auth as auth_api from app.api.notification import BroadcastRequest, broadcast_notification -from app.core.security import verify_password +from app.core.security import verify_password, hash_password from app.models.user import User from app.schemas.schemas import ForgotPasswordRequest, ResetPasswordRequest from app.services import password_reset_service, system_email_service +from app.database import _session_ctx, transaction + + +async def run_with_db(db, func, *args, **kwargs): + async with transaction(db): + return await func(*args, **kwargs) class DummyScalars: @@ -35,6 +41,38 @@ def scalars(self): return DummyScalars(self._values) +class MockPipeline: + def __init__(self, redis): + self.redis = redis + self.commands = [] + + def setex(self, key, ttl, value): + self.commands.append(("setex", key, ttl, value)) + return self + + def delete(self, key): + self.commands.append(("delete", key)) + return self + + async def __aenter__(self): + return self + + async def __aexit__(self, *_): + pass + + async def execute(self): + for cmd in self.commands: + if cmd[0] == "setex": + _, key, ttl, value = cmd + self.redis.setex_calls.append((key, ttl, value)) + self.redis._data[key] = value + elif cmd[0] == "delete": + _, key = cmd + self.redis.deleted.append(key) + self.redis._data.pop(key, None) + self.commands.clear() + + class MockRedis: def __init__(self, initial_data=None): self._data = initial_data or {} @@ -53,16 +91,7 @@ async def setex(self, key, ttl, value): self._data[key] = value def pipeline(self, transaction=True): - return self - - async def __aenter__(self): - return self - - async def __aexit__(self, *_): - pass - - async def execute(self): - pass + return MockPipeline(self) class RecordingDB: @@ -86,6 +115,7 @@ async def flush(self): self.flushed = True async def commit(self): + self.flushed = True self.committed = True @@ -111,12 +141,12 @@ async def test_create_password_reset_token_invalidates_older_tokens(monkeypatch) "get_settings", lambda: SimpleNamespace(PASSWORD_RESET_TOKEN_EXPIRE_MINUTES=15, PUBLIC_BASE_URL=""), ) - mock_redis = MockRedis(initial_data={"pwd_reset:user:user-id-123": "old-token-hash"}) + user_id = uuid.uuid4() + mock_redis = MockRedis(initial_data={f"pwd_reset:user:{user_id}": "old-token-hash"}) async def fake_get_redis(): return mock_redis monkeypatch.setattr(password_reset_service, "get_redis", fake_get_redis) db = RecordingDB() - user_id = uuid.uuid4() raw_token, expires_at = await password_reset_service.create_password_reset_token(user_id) @@ -132,14 +162,9 @@ async def fake_get_redis(): return mock_redis @pytest.mark.asyncio async def test_build_password_reset_url_uses_env_public_base_url(monkeypatch): - monkeypatch.setattr( - password_reset_service, - "get_settings", - lambda: SimpleNamespace(PASSWORD_RESET_TOKEN_EXPIRE_MINUTES=30, PUBLIC_BASE_URL="https://app.example.com/"), - ) - db = RecordingDB([DummyResult(None)]) + monkeypatch.setenv("PUBLIC_BASE_URL", "https://app.example.com/") - url = await password_reset_service.build_password_reset_url(db, "abc123") + url = await password_reset_service.build_password_reset_url("abc123") assert url == "https://app.example.com/reset-password?token=abc123" @@ -162,21 +187,42 @@ async def fake_get_redis(): return mock_redis result = await password_reset_service.consume_password_reset_token(raw_token) assert result is not None - assert result["user_id"] == user_id + assert result["identity_id"] == user_id # Should be deleted after consumption assert f"pwd_reset:token:{token_hash}" in mock_redis.deleted assert f"pwd_reset:user:{user_id}" in mock_redis.deleted @pytest.mark.asyncio -async def test_forgot_password_returns_generic_response_for_unknown_email(): - db = RecordingDB([DummyResult(None)]) +async def test_forgot_password_returns_generic_response_for_unknown_email(monkeypatch): + async def fake_resolve_email_config_async(): + return system_email_service.SystemEmailConfig( + from_address="bot@example.com", + from_name="Clawith", + smtp_host="smtp.example.com", + smtp_port=465, + smtp_username="bot@example.com", + smtp_password="secret", + smtp_ssl=True, + smtp_timeout_seconds=15, + ) + monkeypatch.setattr( + "app.services.system_email_service.resolve_email_config_async", + fake_resolve_email_config_async, + ) background_tasks = BackgroundTasks() + # Patch identity_dao.get_by_email to return None + from app.dao import identity_dao + + async def fake_get_by_email(email): + return None + + monkeypatch.setattr(identity_dao, "get_by_email", fake_get_by_email) + response = await auth_api.forgot_password( ForgotPasswordRequest(email="missing@example.com"), background_tasks, - db, ) assert response == { @@ -191,8 +237,23 @@ async def test_forgot_password_returns_generic_response_for_unknown_email(): @pytest.mark.asyncio async def test_forgot_password_queues_background_email(monkeypatch): + async def fake_resolve_email_config_async(): + return system_email_service.SystemEmailConfig( + from_address="bot@example.com", + from_name="Clawith", + smtp_host="smtp.example.com", + smtp_port=465, + smtp_username="bot@example.com", + smtp_password="secret", + smtp_ssl=True, + smtp_timeout_seconds=15, + ) + monkeypatch.setattr( + "app.services.system_email_service.resolve_email_config_async", + fake_resolve_email_config_async, + ) + user = make_user() - db = RecordingDB([DummyResult(user)]) background_tasks = BackgroundTasks() async def fake_create_password_reset_token(*_args, **_kwargs): @@ -204,11 +265,17 @@ async def fake_build_password_reset_url(*_args, **_kwargs): monkeypatch.setattr(password_reset_service, "create_password_reset_token", fake_create_password_reset_token) monkeypatch.setattr(password_reset_service, "build_password_reset_url", fake_build_password_reset_url) + # Patch identity_dao.get_by_email to return our fake user + from app.dao import identity_dao + + async def fake_get_by_email(email): + return user - response = await auth_api.forgot_password(ForgotPasswordRequest(email=user.email), background_tasks, db) + monkeypatch.setattr(identity_dao, "get_by_email", fake_get_by_email) + + response = await auth_api.forgot_password(ForgotPasswordRequest(email=user.email), background_tasks) assert response["ok"] is True - assert db.committed is True assert len(background_tasks.tasks) == 1 @@ -260,17 +327,18 @@ def sendmail(self, from_address: str, to_addresses: list[str], message: str): @pytest.mark.asyncio async def test_reset_password_updates_user(monkeypatch): - user = make_user(password_hash=auth_api.hash_password("old-password")) + user = make_user(password_hash=hash_password("old-password")) db = RecordingDB([DummyResult(user)]) async def fake_consume_password_reset_token(*_args, **_kwargs): - return {"user_id": user.id} + return {"identity_id": user.id} monkeypatch.setattr(password_reset_service, "consume_password_reset_token", fake_consume_password_reset_token) - response = await auth_api.reset_password( - ResetPasswordRequest(token="t" * 20, new_password="new-password"), + response = await run_with_db( db, + auth_api.reset_password, + ResetPasswordRequest(token="t" * 20, new_password="new-password"), ) assert response == {"ok": True} diff --git a/backend/tests/test_sso_toggle.py b/backend/tests/test_sso_toggle.py index 0a8522120..6228c0faa 100644 --- a/backend/tests/test_sso_toggle.py +++ b/backend/tests/test_sso_toggle.py @@ -7,6 +7,15 @@ from app.api import tenants as tenants_api from app.services.platform_service import platform_service from tests.test_auth import RecordingDB, DummyResult +from app.database import _session_ctx + + +async def run_with_db(db, func, *args, **kwargs): + token = _session_ctx.set(db) + try: + return await func(*args, **kwargs) + finally: + _session_ctx.reset(token) @pytest.mark.asyncio async def test_get_platform_settings_sso_toggle_default(): @@ -68,23 +77,20 @@ async def test_resolve_tenant_by_domain_sso_toggle(): @pytest.mark.asyncio async def test_get_tenant_sso_base_url_toggle(): - """Verify that get_tenant_sso_base_url respects the sso_custom_domain_redirect_enabled toggle.""" + """Verify that get_tenant_sso_base_url respects the sso_redirect_enabled kwarg.""" tenant = SimpleNamespace(slug="acme", sso_domain="https://acme.com") - + # 1. Enabled: returns the custom sso_domain - db_enabled = RecordingDB(responses=[ - DummyResult(), # sso_custom_domain_redirect_enabled -> None (default True) - ]) - url = await platform_service.get_tenant_sso_base_url(db=db_enabled, tenant=tenant) + url = await platform_service.get_tenant_sso_base_url( + db=None, tenant=tenant, sso_redirect_enabled=True + ) assert url == "https://acme.com" # 2. Disabled: falls back to public base URL - setting_disabled = SimpleNamespace(key="sso_custom_domain_redirect_enabled", value={"enabled": False}) - db_disabled = RecordingDB(responses=[ - DummyResult(values=[setting_disabled]), # sso_custom_domain_redirect_enabled -> False - ]) with patch.object(platform_service, "get_public_base_url", return_value="https://try.clawith.ai"): - url = await platform_service.get_tenant_sso_base_url(db=db_disabled, tenant=tenant) + url = await platform_service.get_tenant_sso_base_url( + db=None, tenant=tenant, sso_redirect_enabled=False + ) assert url == "https://try.clawith.ai" @@ -110,7 +116,7 @@ async def test_switch_tenant_sso_toggle(): DummyResult(), # platform_service setting check (default True) ]) with patch("app.api.auth.create_access_token", return_value="jwt-token"): - res = await auth_api.switch_tenant(data, request, current_user, db_enabled) + res = await run_with_db(db_enabled, auth_api.switch_tenant, data, request, current_user) assert res.access_token == "jwt-token" assert res.redirect_url is not None assert "https://acme.com" in res.redirect_url @@ -123,6 +129,6 @@ async def test_switch_tenant_sso_toggle(): DummyResult(values=[setting_disabled]), # auth_api setting check (disabled) ]) with patch("app.api.auth.create_access_token", return_value="jwt-token"): - res = await auth_api.switch_tenant(data, request, current_user, db_disabled) + res = await run_with_db(db_disabled, auth_api.switch_tenant, data, request, current_user) assert res.access_token == "jwt-token" assert res.redirect_url is None diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index b2a04275c..0c2f98464 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -50,7 +50,6 @@ services: CORS_ORIGINS: '["*"]' FEISHU_APP_ID: ${FEISHU_APP_ID:-} FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} - DOCKER_NETWORK: clawith_yaojin_network SS_CONFIG_FILE: /data/ss-nodes.json # Public base URL for constructing OAuth callback URLs and email links. # Required when deployed behind a reverse proxy (e.g. Nginx, Cloudflare). From 5a72456380320d99b690538dd243827241ea2cdf Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Fri, 12 Jun 2026 23:11:24 +0800 Subject: [PATCH 3/3] update --- docker-compose.yml | 105 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 docker-compose.yml diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 000000000..b6c2b2984 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,105 @@ +services: + postgres: + image: postgres:15-alpine + restart: unless-stopped + networks: + - default + environment: + POSTGRES_USER: clawith + POSTGRES_PASSWORD: clawith + POSTGRES_DB: clawith + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: [ "CMD-SHELL", "pg_isready -U clawith" ] + interval: 5s + timeout: 5s + retries: 5 + + redis: + image: redis:7-alpine + restart: unless-stopped + networks: + - default + volumes: + - redisdata:/data + healthcheck: + test: [ "CMD", "redis-cli", "ping" ] + interval: 5s + timeout: 5s + retries: 5 + + backend: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: all + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + SS_CONFIG_FILE: /data/ss-nodes.json + # Public base URL for constructing OAuth callback URLs and email links. + # Required when deployed behind a reverse proxy (e.g. Nginx, Cloudflare). + # If not set, the server infers the URL from the incoming request host. + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + # Password reset token lifetime in minutes (default: 30) + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + privileged: true + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + - apparmor=unconfined + networks: + - default + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + frontend: + build: ./frontend + restart: unless-stopped + ports: + - "${FRONTEND_PORT:-3008}:3000" + environment: + VITE_API_URL: http://localhost:8000 + API_UPSTREAM: ${API_UPSTREAM:-backend:8000} + MINIO_UPSTREAM: ${MINIO_UPSTREAM:-minio:9000} + volumes: + - ./frontend/nginx.conf.template:/etc/nginx/templates/default.conf.template:ro + networks: + - default + depends_on: + - backend + +volumes: + pgdata: + redisdata: + + +networks: + default: + name: clawith_network