Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
530 changes: 298 additions & 232 deletions backend/app/api/agents.py

Large diffs are not rendered by default.

1,087 changes: 507 additions & 580 deletions backend/app/api/auth.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion backend/app/api/dingtalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ async def dingtalk_callback(
pass

# 2. Get DingTalk provider config
auth_provider = await auth_provider_registry.get_provider(db, "dingtalk", str(tenant_id) if tenant_id else None)
auth_provider = await auth_provider_registry.get_provider("dingtalk", str(tenant_id) if tenant_id else None)
if not auth_provider:
return HTMLResponse("Auth failed: DingTalk provider not configured for this tenant")

Expand Down
2 changes: 1 addition & 1 deletion backend/app/api/google_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def _handle_google_sso_callback(
auth_provider = GoogleWorkspaceAuthProvider(provider=provider, config=provider.config or {})
else:
auth_provider = await auth_provider_registry.get_provider(
db, "google_workspace", str(tenant_id) if tenant_id else None
"google_workspace", str(tenant_id) if tenant_id else None
)
if not auth_provider:
return HTMLResponse("Auth failed: Google Workspace provider not configured for this tenant")
Expand Down
3 changes: 1 addition & 2 deletions backend/app/api/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ async def broadcast_notification(
from app.services.system_email_service import (
BroadcastEmailRecipient,
deliver_broadcast_emails,
run_background_email_job,
)

for user in users:
Expand All @@ -205,7 +204,7 @@ async def broadcast_notification(

await db.commit()
if email_recipients:
background_tasks.add_task(run_background_email_job, deliver_broadcast_emails, email_recipients)
background_tasks.add_task(deliver_broadcast_emails, email_recipients)
return {
"ok": True,
"users_notified": count_users,
Expand Down
1 change: 0 additions & 1 deletion backend/app/api/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ async def admin_update_user(
if "email" in update_data or "primary_mobile" in update_data:
from app.services.registration_service import registration_service
await registration_service.sync_org_member_contact_from_user(
db,
user,
sync_email="email" in update_data,
sync_phone="primary_mobile" in update_data,
Expand Down
4 changes: 2 additions & 2 deletions backend/app/api/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De

elif p.provider_type == "dingtalk":
from app.services.auth_registry import auth_provider_registry
auth_provider = await auth_provider_registry.get_provider(db, "dingtalk", str(session.tenant_id) if session.tenant_id else None)
auth_provider = await auth_provider_registry.get_provider("dingtalk", str(session.tenant_id) if session.tenant_id else None)
if auth_provider:
redir = f"{public_base}/api/auth/dingtalk/callback"
# Use provider's standardized authorization URL
Expand All @@ -147,7 +147,7 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De
sign_google_sso_state,
)
auth_provider = await auth_provider_registry.get_provider(
db, "google_workspace", str(session.tenant_id) if session.tenant_id else None
"google_workspace", str(session.tenant_id) if session.tenant_id else None
)
if auth_provider:
redir = await get_google_redirect_uri(db, p, request)
Expand Down
1 change: 0 additions & 1 deletion backend/app/api/wecom.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,6 @@ async def wecom_callback(
# 2. Extract user info and login/register via RegistrationService
try:
auth_provider = await auth_provider_registry.get_provider(
db,
"wecom",
str(tenant_id) if tenant_id else (str(provider.tenant_id) if provider.tenant_id else None),
)
Expand Down
19 changes: 19 additions & 0 deletions backend/app/dao/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from app.dao.identity_dao import identity_dao
from app.dao.identity_provider_dao import identity_provider_dao
from app.dao.invitation_code_dao import invitation_code_dao
from app.dao.org_member_dao import org_member_dao
from app.dao.participant_dao import participant_dao
from app.dao.system_setting_dao import system_setting_dao
from app.dao.tenant_dao import tenant_dao
from app.dao.user_dao import user_dao

__all__ = [
"identity_dao",
"identity_provider_dao",
"invitation_code_dao",
"org_member_dao",
"participant_dao",
"system_setting_dao",
"tenant_dao",
"user_dao",
]
80 changes: 80 additions & 0 deletions backend/app/dao/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections.abc import AsyncGenerator, Sequence
from contextlib import asynccontextmanager
from typing import Any, Generic, Type, TypeVar

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import Base, _session_ctx, async_session

ModelType = TypeVar("ModelType", bound=Base)


class BaseDAO(Generic[ModelType]):
"""Base class for data access objects, managing session context and basic CRUD."""

def __init__(self, model: Type[ModelType]):
self.model = model

@asynccontextmanager
async def session(self) -> AsyncGenerator[AsyncSession, None]:
"""Context manager yielding the active context session or a new one."""
context_session = _session_ctx.get()
if context_session is not None:
yield context_session
else:
async with async_session() as session:
yield session
Comment on lines +26 to +27

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve DAO writes outside explicit transactions

When no ContextVar transaction is active, this fallback opens a fresh AsyncSession and yields it, but the DAO write helpers only flush() and this context closes without committing. Several refactored callers still already have a request/session but are not wrapped in transaction() (for example BaseAuthProvider._create_new_user now calls find_or_create_identity(...) without passing its db), so first-time SSO/channel identity creation is rolled back before the subsequent User insert references it, leading to missing identities or FK failures. Bind the caller's session into the context or ensure standalone DAO write sessions commit/rollback.

Useful? React with 👍 / 👎.


async def get(self, id: Any) -> ModelType | None:
"""Fetch a single record by its primary key ID."""
async with self.session() as db:
if hasattr(db, "get"):
return await db.get(self.model, id)
# Fallback for custom mock DB clients in tests
stmt = select(self.model).where(self.model.id == id)
result = await db.execute(stmt)
return result.scalar_one_or_none()

async def is_empty(self) -> bool:
"""Check if the table is empty (no records)."""
async with self.session() as db:
stmt = select(self.model.id).limit(1)
result = await db.execute(stmt)
return result.scalar() is None

async def get_all(self, skip: int = 0, limit: int = 100) -> Sequence[ModelType]:
"""Fetch all records with offset and limit."""
async with self.session() as db:
stmt = select(self.model).offset(skip).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()

async def create(self, *, obj_in: dict[str, Any]) -> ModelType:
"""Create a new record."""
async with self.session() as db:
db_obj = self.model(**obj_in)
db.add(db_obj)
await db.flush()
return db_obj

async def update(self, *, db_obj: ModelType, obj_in: dict[str, Any]) -> ModelType:
"""Update an existing record."""
async with self.session() as db:
for field, value in obj_in.items():
if hasattr(db_obj, field):
setattr(db_obj, field, value)
db.add(db_obj)
await db.flush()
return db_obj

async def delete(self, *, id: Any) -> ModelType | None:
"""Delete a record by ID."""
async with self.session() as db:
obj = await self.get(id)
if obj:
if hasattr(db, "delete"):
await db.delete(obj)
await db.flush()
return obj

82 changes: 82 additions & 0 deletions backend/app/dao/identity_dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import re
import uuid
from typing import Any

from sqlalchemy import select

from app.dao.base import BaseDAO
from app.models.user import Identity


class IdentityDAO(BaseDAO[Identity]):
"""DAO for Identity model handling authentication credentials."""

def __init__(self) -> None:
super().__init__(Identity)

async def get_by_login_identifier(self, identifier: str) -> Identity | None:
"""Find identity by email, phone, or username."""
async with self.session() as db:
query = select(Identity).where(
(Identity.email == identifier) | (Identity.phone == identifier) | (Identity.username == identifier)
)
result = await db.execute(query)
return result.scalar_one_or_none()

async def get_by_email(self, email: str) -> Identity | None:
"""Find identity by email address."""
async with self.session() as db:
query = select(Identity).where(Identity.email == email)
result = await db.execute(query)
return result.scalar_one_or_none()

async def get_by_username(self, username: str) -> Identity | None:
"""Find identity by username."""
async with self.session() as db:
query = select(Identity).where(Identity.username == username)
result = await db.execute(query)
return result.scalar_one_or_none()

async def get_by_phone(self, phone: str) -> Identity | None:
"""Find identity by normalized phone number."""
normalized = re.sub(r"[\s\-\+]", "", phone)
async with self.session() as db:
query = select(Identity).where(Identity.phone == normalized)
result = await db.execute(query)
return result.scalar_one_or_none()

async def is_username_taken(self, username: str) -> bool:
"""Return True if the username is already used by another identity."""
async with self.session() as db:
result = await db.execute(
select(Identity.id).where(Identity.username == username).limit(1)
)
return result.scalar_one_or_none() is not None

async def create_identity(
self,
*,
email: str | None = None,
phone: str | None = None,
username: str | None = None,
password_hash: str | None = None,
is_platform_admin: bool = False,
email_verified: bool = False,
) -> Identity:
"""Create and flush a new Identity row."""
normalized_phone = re.sub(r"[\s\-\+]", "", phone) if phone else None
async with self.session() as db:
identity = Identity(
email=email,
phone=normalized_phone,
username=username,
password_hash=password_hash,
is_platform_admin=is_platform_admin,
email_verified=email_verified,
)
db.add(identity)
await db.flush()
return identity


identity_dao = IdentityDAO()
61 changes: 61 additions & 0 deletions backend/app/dao/identity_provider_dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""DAO for IdentityProvider model."""

from typing import Any

from sqlalchemy import select

from app.dao.base import BaseDAO
from app.models.identity import IdentityProvider


class IdentityProviderDAO(BaseDAO[IdentityProvider]):
"""DAO for IdentityProvider model."""

def __init__(self) -> None:
super().__init__(IdentityProvider)

async def get_by_type_and_tenant(
self,
provider_type: str,
tenant_id: Any | None,
) -> IdentityProvider | None:
"""Find an IdentityProvider by type scoped to a tenant (or global if None)."""
async with self.session() as db:
query = select(IdentityProvider).where(
IdentityProvider.provider_type == provider_type,
)
if tenant_id is None:
query = query.where(IdentityProvider.tenant_id.is_(None))
else:
query = query.where(IdentityProvider.tenant_id == tenant_id)
result = await db.execute(query)
return result.scalar_one_or_none()

async def get_or_create(
self,
provider_type: str,
tenant_id: Any | None,
*,
name: str | None = None,
sso_login_enabled: bool = False,
) -> IdentityProvider:
"""Get an existing IdentityProvider or create it if missing."""
provider = await self.get_by_type_and_tenant(provider_type, tenant_id)
if provider:
return provider

async with self.session() as db:
provider = IdentityProvider(
provider_type=provider_type,
name=name or provider_type.capitalize(),
is_active=True,
sso_login_enabled=sso_login_enabled,
config={},
tenant_id=tenant_id,
)
db.add(provider)
await db.flush()
return provider


identity_provider_dao = IdentityProviderDAO()
28 changes: 28 additions & 0 deletions backend/app/dao/invitation_code_dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""DAO for InvitationCode model."""

from sqlalchemy import select

from app.dao.base import BaseDAO
from app.models.invitation_code import InvitationCode


class InvitationCodeDAO(BaseDAO[InvitationCode]):
"""DAO for InvitationCode model."""

def __init__(self) -> None:
super().__init__(InvitationCode)

async def get_active_by_code(self, code: str) -> InvitationCode | None:
"""Find an active invitation code with a tenant association."""
async with self.session() as db:
result = await db.execute(
select(InvitationCode).where(
InvitationCode.code == code,
InvitationCode.is_active == True,
InvitationCode.tenant_id.is_not(None),
)
)
return result.scalar_one_or_none()


invitation_code_dao = InvitationCodeDAO()
Loading