diff --git a/.flake8 b/.flake8 index 2f7531d6..2d2c1e88 100644 --- a/.flake8 +++ b/.flake8 @@ -19,7 +19,7 @@ extend-ignore = # # not required or shadowed by other plugins D I FI TC Q U101 S101 WPS118 WPS400 # black - WPS220 WPS317 WPS318 WPS348 WPS352 E501 C812 C815 C816 C819 E203 + WPS220 WPS317 WPS318 WPS326 WPS348 WPS352 E501 C812 C815 C816 C819 E203 # mypy (for __init__) WPS410 WPS412 # sqlalchemy needs `id` @@ -28,7 +28,7 @@ extend-ignore = # # weird PIE803 C101 FNE007 FNE008 N812 ANN101 ANN102 PT004 WPS110 WPS111 WPS114 WPS338 WPS407 WPS414 WPS440 VNE001 VNE002 CM001 # too many - WPS200 WPS201 WPS202 WPS203 WPS204 WPS210 WPS211 WPS212 WPS213 WPS214 WPS217 WPS218 WPS221 WPS224 WPS230 WPS231 WPS234 WPS235 WPS238 + WPS200 WPS201 WPS202 WPS203 WPS204 WPS210 WPS211 WPS212 WPS213 WPS214 WPS217 WPS218 WPS221 WPS222 WPS224 WPS230 WPS231 WPS234 WPS235 WPS238 # "vague" imports WPS347 @@ -38,6 +38,8 @@ extend-ignore = U100 # fails to understand enums WPS115 + # broken for longer statements (also shadowed by the formatter) + WPS361 # fails to understand overloading WPS428 # fails to understand pipe-unions diff --git a/alembic/versions/057_classroom_events_2_0.py b/alembic/versions/057_classroom_events_2_0.py new file mode 100644 index 00000000..2bce336f --- /dev/null +++ b/alembic/versions/057_classroom_events_2_0.py @@ -0,0 +1,240 @@ +"""classroom_events_2_0 + +Revision ID: 057 +Revises: 056 +Create Date: 2026-04-28 01:01:43.574455 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.dialects.postgresql import BIT + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "057" +down_revision: Union[str, None] = "056" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("scheduler_events", schema="xi_back_2") + sa.Enum(name="eventkind").drop(bind=op.get_bind()) + op.create_table( + "events", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("description", sa.String(length=1000), nullable=True), + sa.Column("kind", sa.Enum("CLASSROOM", name="eventkind"), nullable=False), + sa.Column("classroom_id", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("pk_events")), + schema="xi_back_2", + ) + op.create_table( + "repetition_modes", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("event_id", sa.Integer(), nullable=False), + sa.Column( + "kind", sa.Enum("DAILY", "WEEKLY", name="repetitionkind"), nullable=False + ), + sa.Column( + "starts_at", + postgresql.TIMESTAMP(timezone=True, precision=0), + nullable=False, + ), + sa.Column( + "ends_at", postgresql.TIMESTAMP(timezone=True, precision=0), nullable=False + ), + sa.Column("is_finite", sa.Boolean(), nullable=True), + sa.Column( + "weekly_starting_bitmask", + BIT(length=7), + nullable=True, + ), + sa.Column( + "weekly_combined_bitmask", + BIT(length=7), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["event_id"], + ["xi_back_2.events.id"], + name=op.f("fk_repetition_modes_event_id_events"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_repetition_modes")), + schema="xi_back_2", + ) + op.create_index( + "index_repetition_modes_kind_and_interval", + "repetition_modes", + ["kind", "starts_at", "ends_at"], + unique=False, + schema="xi_back_2", + ) + op.create_index( + op.f("ix_xi_back_2_repetition_modes_event_id"), + "repetition_modes", + ["event_id"], + unique=False, + schema="xi_back_2", + ) + op.create_table( + "event_instances", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column( + "kind", + sa.Enum("SOLE", "REPEATED", name="eventinstancekind"), + nullable=False, + ), + sa.Column("event_id", sa.Integer(), nullable=False), + sa.Column("cancelled_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "starts_at", postgresql.TIMESTAMP(timezone=True, precision=0), nullable=True + ), + sa.Column( + "ends_at", postgresql.TIMESTAMP(timezone=True, precision=0), nullable=True + ), + sa.Column("repetition_mode_id", sa.Uuid(), nullable=True), + sa.Column("instance_index", sa.Integer(), nullable=True), + sa.Column( + "starts_at_override", + postgresql.TIMESTAMP(timezone=True, precision=0), + nullable=True, + ), + sa.Column( + "ends_at_override", + postgresql.TIMESTAMP(timezone=True, precision=0), + nullable=True, + ), + sa.Column("name_override", sa.String(length=100), nullable=True), + sa.Column("description_override", sa.String(length=1000), nullable=True), + sa.ForeignKeyConstraint( + ["event_id"], + ["xi_back_2.events.id"], + name=op.f("fk_event_instances_event_id_events"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["repetition_mode_id"], + ["xi_back_2.repetition_modes.id"], + name=op.f("fk_event_instances_repetition_mode_id_repetition_modes"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_event_instances")), + schema="xi_back_2", + ) + op.create_index( + "index_repeated_event_instance_interval_override", + "event_instances", + ["starts_at_override", "ends_at_override"], + unique=False, + schema="xi_back_2", + postgresql_where=sa.text( + "kind = 'REPEATED' AND starts_at_override IS NOT NULL AND ends_at_override IS NOT NULL" + ), + ) + op.create_index( + "index_repeated_event_instances_ids", + "event_instances", + ["repetition_mode_id", "instance_index"], + unique=False, + schema="xi_back_2", + postgresql_where=sa.text("kind = 'REPEATED'"), + ) + op.create_index( + "index_sole_event_instance_interval", + "event_instances", + ["starts_at", "ends_at"], + unique=False, + schema="xi_back_2", + postgresql_where=sa.text("kind = 'SOLE'"), + ) + op.create_index( + "unique_index_sole_event_instances_event_id", + "event_instances", + ["event_id"], + unique=True, + schema="xi_back_2", + postgresql_where=sa.text("kind = 'SOLE'"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index( + "unique_index_sole_event_instances_event_id", + table_name="event_instances", + schema="xi_back_2", + postgresql_where=sa.text("kind = 'SOLE'"), + ) + op.drop_index( + "index_sole_event_instance_interval", + table_name="event_instances", + schema="xi_back_2", + postgresql_where=sa.text("kind = 'SOLE'"), + ) + op.drop_index( + "index_repeated_event_instances_ids", + table_name="event_instances", + schema="xi_back_2", + postgresql_where=sa.text("kind = 'REPEATED'"), + ) + op.drop_index( + "index_repeated_event_instance_interval_override", + table_name="event_instances", + schema="xi_back_2", + postgresql_where=sa.text( + "kind = 'REPEATED' AND starts_at_override IS NOT NULL AND ends_at_override IS NOT NULL" + ), + ) + op.drop_table("event_instances", schema="xi_back_2") + op.drop_index( + op.f("ix_xi_back_2_repetition_modes_event_id"), + table_name="repetition_modes", + schema="xi_back_2", + ) + op.drop_index( + "index_repetition_modes_kind_and_interval", + table_name="repetition_modes", + schema="xi_back_2", + ) + op.drop_table("repetition_modes", schema="xi_back_2") + op.drop_table("events", schema="xi_back_2") + sa.Enum(name="eventkind").drop(bind=op.get_bind()) + op.create_table( + "scheduler_events", + sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column( + "starts_at", + postgresql.TIMESTAMP(timezone=True), + autoincrement=False, + nullable=False, + ), + sa.Column( + "ends_at", + postgresql.TIMESTAMP(timezone=True), + autoincrement=False, + nullable=False, + ), + sa.Column("name", sa.VARCHAR(length=100), autoincrement=False, nullable=False), + sa.Column( + "description", sa.VARCHAR(length=1000), autoincrement=False, nullable=True + ), + sa.Column( + "kind", + postgresql.ENUM("CLASSROOM", name="eventkind"), + autoincrement=False, + nullable=False, + ), + sa.Column("classroom_id", sa.INTEGER(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint("id", name="pk_scheduler_events"), + schema="xi_back_2", + ) + # ### end Alembic commands ### diff --git a/app/classrooms/routes/classrooms_int.py b/app/classrooms/routes/classrooms_int.py index c72af413..71695c31 100644 --- a/app/classrooms/routes/classrooms_int.py +++ b/app/classrooms/routes/classrooms_int.py @@ -1,10 +1,18 @@ from collections.abc import Sequence -from typing import assert_never +from typing import Annotated, assert_never + +from fastapi import Path +from sqlalchemy import or_, select from app.classrooms.dependencies.classrooms_dep import ClassroomByID -from app.classrooms.models.classrooms_db import GroupClassroom, IndividualClassroom +from app.classrooms.models.classrooms_db import ( + Classroom, + GroupClassroom, + IndividualClassroom, +) from app.classrooms.models.enrollments_db import Enrollment from app.common.fastapi_ext import APIRouterExt +from app.common.sqlalchemy_ext import db router = APIRouterExt(tags=["classrooms internal"]) @@ -23,3 +31,39 @@ async def list_classroom_student_ids(classroom: ClassroomByID) -> Sequence[int]: ) case _: assert_never(classroom) + + +@router.get( + path="/tutors/{tutor_id}/classroom-ids/", + summary="List all classroom ids for a tutor by id", +) +async def list_tutor_classroom_ids( + tutor_id: Annotated[int, Path()], +) -> list[int]: + return await db.get_all_with_assumed_limit( + select(Classroom.id) + .filter_by(tutor_id=tutor_id) + .order_by(Classroom.created_at.desc()), + limit=100, + ) + + +@router.get( + path="/students/{student_id}/classroom-ids/", + summary="List all classroom ids for a student by id", +) +async def list_student_classroom_ids( + student_id: Annotated[int, Path()], +) -> list[int]: + return await db.get_all_with_assumed_limit( + select(Classroom.id) + .join(Enrollment, isouter=True) + .filter( + or_( + IndividualClassroom.student_id == student_id, + Enrollment.student_id == student_id, + ) + ) + .order_by(Classroom.created_at.desc()), + limit=100, + ) diff --git a/app/common/bridges/classrooms_bdg.py b/app/common/bridges/classrooms_bdg.py index 4c04bb20..8aad2740 100644 --- a/app/common/bridges/classrooms_bdg.py +++ b/app/common/bridges/classrooms_bdg.py @@ -18,3 +18,15 @@ async def list_classroom_student_ids(self, classroom_id: int) -> Response: return await self.client.get( f"/classrooms/{classroom_id}/students/", ) + + @validate_external_json_response(TypeAdapter(list[int])) + async def list_tutor_classroom_ids(self, tutor_id: int) -> Response: + return await self.client.get( + f"/tutors/{tutor_id}/classroom-ids/", + ) + + @validate_external_json_response(TypeAdapter(list[int])) + async def list_student_classroom_ids(self, student_id: int) -> Response: + return await self.client.get( + f"/students/{student_id}/classroom-ids/", + ) diff --git a/app/common/sqlalchemy_ext.py b/app/common/sqlalchemy_ext.py index fb69f413..42b5e969 100644 --- a/app/common/sqlalchemy_ext.py +++ b/app/common/sqlalchemy_ext.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import sys from collections.abc import Iterable, Sequence from contextvars import ContextVar @@ -54,6 +55,21 @@ async def get_count(self, stmt: Select[tuple[int]]) -> int: async def get_all(self, stmt: Select[Any] | ReturningInsert[Any]) -> Sequence[Any]: return (await self.session.execute(stmt)).scalars().all() + async def get_all_with_assumed_limit( + self, + stmt: Select[Any], + limit: int, + ) -> list[Any]: + result = list(await self.get_all(stmt.limit(limit))) + + if len(result) == limit: + logging.warning( + f"Reached the limit of {limit} in one query", + extra={"stmt": str(stmt)}, + ) + + return result + async def get_paginated( self, stmt: Select[Any], offset: int, limit: int ) -> Sequence[Any]: diff --git a/app/common/utils/bitwise.py b/app/common/utils/bitwise.py new file mode 100644 index 00000000..853bd950 --- /dev/null +++ b/app/common/utils/bitwise.py @@ -0,0 +1,6 @@ +def bitwise_cyclic_shift_left(value: int, size: int, rotations: int = 1) -> int: + return ((value << rotations) % (1 << size)) | (value >> (size - rotations)) + + +def bitwise_cyclic_shift_right(value: int, size: int, rotations: int = 1) -> int: + return ((1 << size) - 1) & (value >> rotations | value << (size - rotations)) diff --git a/app/main.py b/app/main.py index 3d694b72..7440418c 100644 --- a/app/main.py +++ b/app/main.py @@ -11,6 +11,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.staticfiles import StaticFiles +from starlette_exporter import PrometheusMiddleware, handle_metrics from tmexio import AsyncSocket, EventException, EventName, PydanticPackager from tmexio.documentation import OpenAPIBuilder @@ -143,6 +144,21 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]: lifespan=lifespan, ) +app.add_middleware( + PrometheusMiddleware, + group_paths=True, + app_name="xi.back-2", + prefix="fastapi", + skip_paths=["/metrics"], + labels={ + "instance_name": settings.instance_name, + # TODO: Add these only to some metrics to save on series amount + # `"x_user_id": from_header("X-User-ID")`, + # `"x_session_id": from_header("X-Session-ID")`, + }, +) +app.add_route("/metrics", handle_metrics) + app.mount("/static", StaticFiles(directory="static"), name="static") diff --git a/app/scheduler/config.py b/app/scheduler/config.py new file mode 100644 index 00000000..6409939f --- /dev/null +++ b/app/scheduler/config.py @@ -0,0 +1,7 @@ +from datetime import timedelta +from typing import Final + +MIN_EVENT_INSTANCE_DURATION: Final[timedelta] = timedelta() +MAX_EVENT_INSTANCE_DURATION: Final[timedelta] = timedelta(hours=12) +MAX_TIMEDELTA_TO_THE_PAST: Final[timedelta] = timedelta(days=-370) +MAX_TIMEDELTA_TO_THE_FUTURE: Final[timedelta] = timedelta(days=370) diff --git a/app/scheduler/dependencies/event_instances_dep.py b/app/scheduler/dependencies/event_instances_dep.py new file mode 100644 index 00000000..fb2a9aed --- /dev/null +++ b/app/scheduler/dependencies/event_instances_dep.py @@ -0,0 +1,77 @@ +from typing import Annotated +from uuid import UUID + +from fastapi import Depends, Path +from starlette import status + +from app.common.fastapi_ext import Responses, with_responses +from app.scheduler.models.event_instances_db import AnyEventInstance, EventInstance +from app.scheduler.models.events_db import ClassroomEvent + + +class EventInstanceResponses(Responses): + EVENT_INSTANCE_NOT_FOUND = status.HTTP_404_NOT_FOUND, "Event instance not found" + + +@with_responses(EventInstanceResponses) +async def get_event_instance_by_id( + event_instance_id: Annotated[UUID, Path()], +) -> AnyEventInstance: + event_instance = await EventInstance.find_first_by_id(event_instance_id) + if event_instance is None: + raise EventInstanceResponses.EVENT_INSTANCE_NOT_FOUND + if not isinstance(event_instance, AnyEventInstance): # pragma: no cover + raise TypeError("SQLAlchemy returned an unknown type of EventInstance") + return event_instance + + +EventInstanceByID = Annotated[AnyEventInstance, Depends(get_event_instance_by_id)] + + +class ClassroomEventInstanceResponses(Responses): + EVENT_INSTANCE_IS_NOT_IN_A_CLASSROOM = ( + status.HTTP_403_FORBIDDEN, + "Event instance is not in a classroom", + ) + + +@with_responses(ClassroomEventInstanceResponses) +async def get_classroom_event_by_instance_id( + event_instance: EventInstanceByID, +) -> ClassroomEvent: + if not isinstance(event_instance.event, ClassroomEvent): + raise ClassroomEventInstanceResponses.EVENT_INSTANCE_IS_NOT_IN_A_CLASSROOM + return event_instance.event + + +ClassroomEventByInstanceID = Annotated[ + ClassroomEvent, + Depends(get_classroom_event_by_instance_id), +] + + +class MyClassroomEventInstanceResponses(Responses): + CLASSROOM_EVENT_INSTANCE_ACCESS_DENIED = ( + status.HTTP_403_FORBIDDEN, + "Classroom event instance access denied", + ) + + +@with_responses(MyClassroomEventInstanceResponses) +async def get_my_classroom_event_instance_by_ids( + event_instance: EventInstanceByID, + classroom_event: ClassroomEventByInstanceID, + classroom_id: Annotated[int, Path()], +) -> AnyEventInstance: + if classroom_event.classroom_id != classroom_id: + raise MyClassroomEventInstanceResponses.CLASSROOM_EVENT_INSTANCE_ACCESS_DENIED + return event_instance + + +MyClassroomEventInstanceByIDs = Annotated[ + AnyEventInstance, + Depends(get_my_classroom_event_instance_by_ids), +] + + +EventInstanceIndex = Annotated[int, Path(ge=0)] diff --git a/app/scheduler/dependencies/events_dep.py b/app/scheduler/dependencies/events_dep.py index 81b99693..77145638 100644 --- a/app/scheduler/dependencies/events_dep.py +++ b/app/scheduler/dependencies/events_dep.py @@ -1,18 +1,31 @@ -from typing import Annotated, Self +from datetime import datetime, timedelta +from typing import Annotated, ClassVar, Self from fastapi import Query -from pydantic import AwareDatetime, BaseModel, model_validator +from pydantic import AwareDatetime, BaseModel, field_validator, model_validator class EventTimeFrameSchema(BaseModel): + min_period_duration: ClassVar[timedelta] = timedelta(days=1) + max_period_duration: ClassVar[timedelta] = timedelta(days=30) + happens_after: AwareDatetime happens_before: AwareDatetime + @classmethod + @field_validator("happens_after", "happens_before", mode="after") + def remove_microseconds_from_timestamps(cls, value: datetime) -> datetime: + # TODO (170) replace with a reusable AwareDatetimeNoMS type (use better naming) + return value.replace(microsecond=0) + @model_validator(mode="after") def validate_happens_after_and_happens_before(self) -> Self: - if self.happens_after >= self.happens_before: + period_duration = self.happens_before - self.happens_after + if period_duration < self.min_period_duration: + raise ValueError("happens_before must be later in time than happens_after") + if period_duration > self.max_period_duration: raise ValueError( - "parameter happens_before must be later in time than happens_after" + "happens_before is too far in the future from happens_after" ) return self diff --git a/app/scheduler/dependencies/repetition_modes_dep.py b/app/scheduler/dependencies/repetition_modes_dep.py new file mode 100644 index 00000000..3e873af0 --- /dev/null +++ b/app/scheduler/dependencies/repetition_modes_dep.py @@ -0,0 +1,72 @@ +from typing import Annotated +from uuid import UUID + +from fastapi import Depends, Path +from starlette import status + +from app.common.fastapi_ext import Responses, with_responses +from app.scheduler.models.events_db import ClassroomEvent +from app.scheduler.models.repetition_modes_db import RepetitionMode + + +class RepetitionModeResponses(Responses): + REPETITION_MODE_NOT_FOUND = status.HTTP_404_NOT_FOUND, "Repetition mode not found" + + +@with_responses(RepetitionModeResponses) +async def get_repetition_mode_by_id( + repetition_mode_id: Annotated[UUID, Path()], +) -> RepetitionMode: + repetition_mode = await RepetitionMode.find_first_by_id(repetition_mode_id) + if repetition_mode is None: + raise RepetitionModeResponses.REPETITION_MODE_NOT_FOUND + return repetition_mode + + +RepetitionModeByID = Annotated[RepetitionMode, Depends(get_repetition_mode_by_id)] + + +class ClassroomRepetitionModeResponses(Responses): + REPETITION_MODE_IS_NOT_IN_A_CLASSROOM = ( + status.HTTP_403_FORBIDDEN, + "Repetition mode is not in a classroom", + ) + + +@with_responses(ClassroomRepetitionModeResponses) +async def get_classroom_event_by_repetition_mode_id( + repetition_mode: RepetitionModeByID, +) -> ClassroomEvent: + if not isinstance(repetition_mode.event, ClassroomEvent): + raise ClassroomRepetitionModeResponses.REPETITION_MODE_IS_NOT_IN_A_CLASSROOM + return repetition_mode.event + + +ClassroomEventByRepetitionModeID = Annotated[ + ClassroomEvent, + Depends(get_classroom_event_by_repetition_mode_id), +] + + +class MyClassroomRepetitionModeResponses(Responses): + CLASSROOM_REPETITION_MODE_ACCESS_DENIED = ( + status.HTTP_403_FORBIDDEN, + "Classroom repetition mode access denied", + ) + + +@with_responses(MyClassroomRepetitionModeResponses) +async def get_my_classroom_repetition_mode_by_ids( + repetition_mode: RepetitionModeByID, + classroom_event: ClassroomEventByRepetitionModeID, + classroom_id: Annotated[int, Path()], +) -> RepetitionMode: + if classroom_event.classroom_id != classroom_id: + raise MyClassroomRepetitionModeResponses.CLASSROOM_REPETITION_MODE_ACCESS_DENIED + return repetition_mode + + +MyClassroomRepetitionModeByIDs = Annotated[ + RepetitionMode, + Depends(get_my_classroom_repetition_mode_by_ids), +] diff --git a/app/scheduler/main.py b/app/scheduler/main.py index 58398042..6ec072a4 100644 --- a/app/scheduler/main.py +++ b/app/scheduler/main.py @@ -7,8 +7,9 @@ from app.common.dependencies.mub_dep import MUBProtection from app.common.fastapi_ext import APIRouterExt from app.scheduler.routes import ( - classroom_events_student_rst, + classroom_event_instances_rst, classroom_events_tutor_rst, + classroom_schedules_rst, ) outside_router = APIRouterExt(prefix="/api/public/scheduler-service") @@ -18,7 +19,8 @@ prefix="/api/protected/scheduler-service", ) authorized_router.include_router(classroom_events_tutor_rst.router) -authorized_router.include_router(classroom_events_student_rst.router) +authorized_router.include_router(classroom_schedules_rst.router) +authorized_router.include_router(classroom_event_instances_rst.router) mub_router = APIRouterExt( dependencies=[MUBProtection], diff --git a/app/scheduler/models/event_instances_db.py b/app/scheduler/models/event_instances_db.py new file mode 100644 index 00000000..c895b2e5 --- /dev/null +++ b/app/scheduler/models/event_instances_db.py @@ -0,0 +1,278 @@ +from datetime import datetime, timedelta +from enum import StrEnum, auto +from typing import Annotated, Literal, Self +from uuid import UUID, uuid4 + +from pydantic import AwareDatetime, BaseModel, Field, computed_field +from pydantic_marshals.sqlalchemy import MappedModel +from sqlalchemy import DateTime, Enum, ForeignKey, Index, String, and_, delete +from sqlalchemy.dialects.postgresql import TIMESTAMP +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.common.config import Base +from app.common.sqlalchemy_ext import db +from app.scheduler.config import ( + MAX_EVENT_INSTANCE_DURATION, + MIN_EVENT_INSTANCE_DURATION, +) +from app.scheduler.models.events_db import Event +from app.scheduler.models.repetition_modes_db import RepetitionMode + + +class EventInstanceResponseSchemaKind(StrEnum): + SOLE = auto() + REPEATED_PERSISTED = auto() + REPEATED_VIRTUAL = auto() + + +class BaseEventInstanceResponseSchema(BaseModel): + event_id: int + classroom_id: int # TODO (170) ClassroomEvent-specific + + starts_at: AwareDatetime + ends_at: AwareDatetime + + name: str + description: str | None = None + + +class PersistedEventInstanceDataMixin(BaseModel): + id: UUID + cancelled_at: AwareDatetime | None = None + + # TODO "meta" + # TODO (170) could just add name & description from Event as proxies + + +class SoleEventInstanceResponseSchema( + BaseEventInstanceResponseSchema, + PersistedEventInstanceDataMixin, +): + kind: Literal[EventInstanceResponseSchemaKind.SOLE] = ( + EventInstanceResponseSchemaKind.SOLE + ) + + +class BaseRepeatedEventInstanceResponseSchema(BaseEventInstanceResponseSchema): + repetition_mode_id: UUID + instance_index: int + + +class PersistedRepeatedEventInstanceResponseSchema( + BaseRepeatedEventInstanceResponseSchema, + PersistedEventInstanceDataMixin, +): + kind: Literal[EventInstanceResponseSchemaKind.REPEATED_PERSISTED] = ( + EventInstanceResponseSchemaKind.REPEATED_PERSISTED + ) + + +class VirtualRepeatedEventInstanceResponseSchema( + BaseRepeatedEventInstanceResponseSchema, +): + kind: Literal[EventInstanceResponseSchemaKind.REPEATED_VIRTUAL] = ( + EventInstanceResponseSchemaKind.REPEATED_VIRTUAL + ) + + +EventInstanceResponseSchema = Annotated[ + SoleEventInstanceResponseSchema + | PersistedRepeatedEventInstanceResponseSchema + | VirtualRepeatedEventInstanceResponseSchema, + Field(discriminator="kind"), +] + + +class EventInstanceKind(StrEnum): + SOLE = auto() + REPEATED = auto() + + +class EventInstance(Base): + __tablename__: str | None = "event_instances" + + id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) + kind: Mapped[EventInstanceKind] = mapped_column(Enum(EventInstanceKind)) + + event_id: Mapped[int] = mapped_column( + # In RepeatedEventInstance this is denormalization, + # but it is useful for faster and more consistent queries + # Also `ForeignKey` can't generate two different constraints for subclasses + ForeignKey(Event.id, ondelete="CASCADE"), + use_existing_column=True, + ) + event: Mapped[Event] = relationship(lazy="joined") + + cancelled_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), + default=None, + ) + # TODO "meta" + + __mapper_args__ = { + "polymorphic_on": kind, + "polymorphic_abstract": True, + } + + StandaloneResponseSchema = MappedModel.create( + columns=[id, cancelled_at], + ) + + def reschedule(self, new_starts_at: datetime, new_ends_at: datetime) -> None: + # TODO (170) mb check if changed to not send a notification + raise NotImplementedError + + +class SoleEventInstance(EventInstance): + __tablename__ = None + + __mapper_args__ = { + "polymorphic_identity": EventInstanceKind.SOLE, + "polymorphic_load": "inline", + } + + starts_at: Mapped[datetime] = mapped_column( + TIMESTAMP(precision=0, timezone=True), + nullable=True, + ) + ends_at: Mapped[datetime] = mapped_column( + TIMESTAMP(precision=0, timezone=True), + nullable=True, + ) + + StandaloneResponseSchema = MappedModel.create( + bases=[EventInstance.StandaloneResponseSchema], + columns=[(starts_at, AwareDatetime), (ends_at, AwareDatetime)], + ) + + def reschedule(self, new_starts_at: datetime, new_ends_at: datetime) -> None: + self.starts_at = new_starts_at + self.ends_at = new_ends_at + + +# declared outside the class, because STI doesn't support indexes on child classes +Index( + "unique_index_sole_event_instances_event_id", + SoleEventInstance.event_id, + postgresql_where=EventInstance.kind == EventInstanceKind.SOLE, + unique=True, +) +Index( + "index_sole_event_instance_interval", + SoleEventInstance.starts_at, + SoleEventInstance.ends_at, + postgresql_where=EventInstance.kind == EventInstanceKind.SOLE, +) + + +class RepeatedEventInstance(EventInstance): + __tablename__ = None + + __mapper_args__ = { + "polymorphic_identity": EventInstanceKind.REPEATED, + "polymorphic_load": "inline", + } + + repetition_mode_id: Mapped[UUID] = mapped_column( + ForeignKey(RepetitionMode.id, ondelete="CASCADE"), + nullable=True, + ) + repetition_mode: Mapped[RepetitionMode] = relationship(lazy="joined") + instance_index: Mapped[int] = mapped_column(nullable=True) + + starts_at_override: Mapped[datetime | None] = mapped_column( + TIMESTAMP(precision=0, timezone=True), + default=None, + ) + ends_at_override: Mapped[datetime | None] = mapped_column( + TIMESTAMP(precision=0, timezone=True), + default=None, + ) + # TODO make sure that either both or neither are specified + + name_override: Mapped[str | None] = mapped_column( + String(100), + default=None, + ) + description_override: Mapped[str | None] = mapped_column( + String(1000), + default=None, + ) + + StandaloneResponseSchema = MappedModel.create( + bases=[EventInstance.StandaloneResponseSchema], + columns=[ + (starts_at_override, AwareDatetime | None), + (ends_at_override, AwareDatetime | None), + name_override, + description_override, + ], + ) + + @classmethod + async def find_by_repetition_mode_id_and_index( + cls, + repetition_mode_id: UUID, + instance_index: int, + ) -> Self | None: + return await cls.find_first_by_kwargs( + repetition_mode_id=repetition_mode_id, + instance_index=instance_index, + ) + + @classmethod + async def delete_all_after_index( + cls, + repetition_mode_id: UUID, + instance_index: int, + ) -> None: + await db.session.execute( + delete(cls).filter( + cls.repetition_mode_id == repetition_mode_id, + cls.instance_index > instance_index, + ) + ) + + def reschedule(self, new_starts_at: datetime, new_ends_at: datetime) -> None: + self.starts_at_override = new_starts_at + self.ends_at_override = new_ends_at + + +# declared outside the class, because STI doesn't support indexes on child classes +Index( + "index_repeated_event_instances_ids", + RepeatedEventInstance.repetition_mode_id, + RepeatedEventInstance.instance_index, + postgresql_where=EventInstance.kind == EventInstanceKind.REPEATED, +) +Index( + "index_repeated_event_instance_interval_override", + RepeatedEventInstance.starts_at_override, + RepeatedEventInstance.ends_at_override, + postgresql_where=and_( + EventInstance.kind == EventInstanceKind.REPEATED, + RepeatedEventInstance.starts_at_override.is_not(None), + RepeatedEventInstance.ends_at_override.is_not(None), + ), +) + + +AnyEventInstance = SoleEventInstance | RepeatedEventInstance + + +class EventInstanceTimeSlotInputSchema(BaseModel): + starts_at: AwareDatetime + duration_seconds: int = Field( + gt=MIN_EVENT_INSTANCE_DURATION.seconds, + le=MAX_EVENT_INSTANCE_DURATION.seconds, + exclude=True, + ) + + @computed_field + @property + def ends_at(self) -> datetime: + return self.starts_at + timedelta(seconds=self.duration_seconds) + + +class SoleEventInstanceInputSchema(EventInstanceTimeSlotInputSchema): + pass # TODO meta diff --git a/app/scheduler/models/events_db.py b/app/scheduler/models/events_db.py index 4de4a923..d13b901a 100644 --- a/app/scheduler/models/events_db.py +++ b/app/scheduler/models/events_db.py @@ -1,11 +1,10 @@ from collections.abc import Sequence -from datetime import datetime from enum import StrEnum, auto from typing import Annotated, Literal, Self -from pydantic import AwareDatetime, Field +from pydantic import Field from pydantic_marshals.sqlalchemy import MappedModel -from sqlalchemy import DateTime, Enum, String, and_, select +from sqlalchemy import Enum, String, select from sqlalchemy.orm import Mapped, mapped_column from app.common.config import Base @@ -17,34 +16,35 @@ class EventKind(StrEnum): class Event(Base): - __tablename__: str | None = "scheduler_events" + __tablename__: str | None = "events" id: Mapped[int] = mapped_column(primary_key=True) - starts_at: Mapped[datetime] = mapped_column(DateTime(timezone=True)) - ends_at: Mapped[datetime] = mapped_column(DateTime(timezone=True)) name: Mapped[str] = mapped_column(String(100)) description: Mapped[str | None] = mapped_column(String(1000), default=None) kind: Mapped[EventKind] = mapped_column(Enum(EventKind)) - NameType = Annotated[str, Field(min_length=1, max_length=100)] - DescriptionType = Annotated[str | None, Field(min_length=1, max_length=1000)] - __mapper_args__ = { "polymorphic_on": kind, "polymorphic_abstract": True, } + NameType = Annotated[str, Field(min_length=1, max_length=100)] + DescriptionType = Annotated[str | None, Field(min_length=1, max_length=1000)] + InputSchema = MappedModel.create( columns=[ - (starts_at, AwareDatetime), - (ends_at, AwareDatetime), (name, NameType), (description, DescriptionType), ], ) + PatchSchema = InputSchema.as_patch() ResponseSchema = InputSchema.extend(columns=[id]) + @classmethod + async def find_all_by_ids(cls, event_ids: list[int]) -> Sequence[Self]: + return await db.get_all(select(cls).filter(cls.id.in_(event_ids))) + class ClassroomEvent(Event): __tablename__ = None @@ -56,23 +56,8 @@ class ClassroomEvent(Event): classroom_id: Mapped[int] = mapped_column(nullable=True) - InputSchema = MappedModel.create(bases=[Event.InputSchema]) ResponseSchema = MappedModel.create( bases=[Event.ResponseSchema], columns=[classroom_id], extra_fields={"kind": (Literal[EventKind.CLASSROOM], EventKind.CLASSROOM)}, ) - - @classmethod - async def find_all_by_classroom_id_in_time_frame( - cls, - classroom_id: int, - happens_after: datetime, - happens_before: datetime, - ) -> Sequence[Self]: - return await db.get_all( - select(cls) - .filter_by(classroom_id=classroom_id) - .filter(and_(cls.starts_at < happens_before, cls.ends_at > happens_after)) - .order_by(cls.starts_at.desc()) - ) diff --git a/app/scheduler/models/repetition_modes_db.py b/app/scheduler/models/repetition_modes_db.py new file mode 100644 index 00000000..ef84cc4f --- /dev/null +++ b/app/scheduler/models/repetition_modes_db.py @@ -0,0 +1,543 @@ +from abc import abstractmethod +from collections.abc import Iterator +from datetime import datetime, timedelta, timezone +from enum import StrEnum, auto +from typing import Annotated, ClassVar, Literal, Self +from uuid import UUID, uuid4 + +from pydantic import ( + AwareDatetime, + BaseModel, + Field, + TypeAdapter, + computed_field, + model_validator, +) +from pydantic_marshals.sqlalchemy import MappedModel +from sqlalchemy import ( + Enum, + ForeignKey, + Index, + SQLColumnExpression, + delete, + or_, + select, +) +from sqlalchemy.dialects.postgresql import TIMESTAMP +from sqlalchemy.orm import ( + InstrumentedAttribute, + Mapped, + mapped_column, + relationship, +) + +from app.common.config import Base +from app.common.sqlalchemy_ext import db +from app.common.utils.datetime import datetime_utc_now +from app.scheduler.config import ( + MAX_EVENT_INSTANCE_DURATION, + MAX_TIMEDELTA_TO_THE_FUTURE, + MAX_TIMEDELTA_TO_THE_PAST, + MIN_EVENT_INSTANCE_DURATION, +) +from app.scheduler.models.events_db import Event +from app.scheduler.utils.bitmasks import ( + PSQLBitmask, + TimestampRelativeBitmask, + WeeklyBitmask, +) + + +class RepetitionKind(StrEnum): + DAILY = auto() + WEEKLY = auto() + + +class RepetitionMode(Base): + __tablename__: str | None = "repetition_modes" + + id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) + + event_id: Mapped[int] = mapped_column( + ForeignKey(Event.id, ondelete="CASCADE"), + index=True, + ) + event: Mapped[Event] = relationship(lazy="joined") + + kind: Mapped[RepetitionKind] = mapped_column(Enum(RepetitionKind)) + + starts_at: Mapped[datetime] = mapped_column(TIMESTAMP(precision=0, timezone=True)) + ends_at: Mapped[datetime] = mapped_column(TIMESTAMP(precision=0, timezone=True)) + is_finite: Mapped[bool] = mapped_column(default=False, nullable=True) + + __mapper_args__ = { + "polymorphic_on": kind, + "polymorphic_abstract": True, + "with_polymorphic": "*", # `polymorphic_load: inline` doesn't work in complex queries for some reason + } + + __table_args__ = ( + Index("index_repetition_modes_kind_and_interval", kind, starts_at, ends_at), + ) + + @property + def duration(self) -> timedelta: + return self.ends_at - self.starts_at + + @property + def duration_seconds(self) -> int: + return self.duration.seconds + + @property + def event_instance_duration(self) -> timedelta: + return timedelta(seconds=self.duration_seconds) + + @property + def active_period_days(self) -> int | None: + return self.duration.days if self.is_finite else None + + ResponseSchema = MappedModel.create( + columns=[id, event_id, starts_at], + properties=[duration_seconds, active_period_days], + ) + + @classmethod + async def delete_all_at_or_after_timestamp( + cls, + event_id: int, + timestamp: datetime, + ) -> None: + await db.session.execute( + delete(cls).filter( + cls.event_id == event_id, + cls.starts_at >= timestamp, + ) + ) + + @classmethod + async def find_last_bordering_on_a_timestamp( + cls, + event_id: int, + timestamp: datetime, + ) -> Self | None: + return await db.get_first( + select(cls).filter( + cls.event_id == event_id, + cls.starts_at < timestamp, + or_( + cls.ends_at > timestamp, + cls.is_finite.is_(False), + ), + ) + ) + + @classmethod + def iter_in_range_conditions( + cls, + happens_after: datetime, + happens_before: datetime, + ) -> Iterator[SQLColumnExpression[bool]]: + yield cls.kind == cls.__mapper__.polymorphic_identity + yield cls.starts_at <= happens_before + yield or_(cls.is_finite.is_(False), cls.ends_at > happens_after) + + def calculate_event_instance_starts_at_for_index( + self, + instance_index: int, + ) -> datetime: + raise NotImplementedError + + def calculate_closest_past_event_instance_starts_at_for_timestamp( + self, + timestamp: datetime, + ) -> datetime | None: + if timestamp < self.starts_at: + return None + + if self.is_finite and timestamp >= self.ends_at: + timestamp = self.ends_at - self.event_instance_duration + else: + timestamp = self.starts_at + (timestamp - self.starts_at) // timedelta( + days=1 + ) * timedelta(days=1) + + return timestamp + + def calculate_event_instance_index_for_starts_at( + self, + event_instance_starts_at: datetime, + ) -> int: + raise NotImplementedError + + def get_starts_at_bounds_in_range( + self, + happens_after: datetime, + happens_before: datetime, + ) -> tuple[datetime, datetime]: + if self.starts_at > happens_after - self.event_instance_duration: + starts_at_lower_bound = self.starts_at + else: + starts_at_lower_bound = datetime.combine( + happens_after.astimezone(timezone.utc).date(), + self.starts_at.time(), + self.starts_at.tzinfo, + ) + if starts_at_lower_bound + self.event_instance_duration <= happens_after: + # TODO use bitmask's unit instead of `days=1` + # or just implement "skipping" the first starts at + starts_at_lower_bound += timedelta(days=1) + + if self.is_finite and self.ends_at < happens_before: + starts_at_upper_bound = self.ends_at + else: + starts_at_upper_bound = happens_before + + return starts_at_lower_bound, starts_at_upper_bound + + def iter_event_instances_in_range( + self, + happens_after: datetime, + happens_before: datetime, + ) -> Iterator[tuple[int, datetime]]: + """This method assumes, that the repetition mode is inside the range (checked on query level)""" + raise NotImplementedError + + +class DailyRepetitionMode(RepetitionMode): + __tablename__ = None + __mapper_args__ = { + "polymorphic_identity": RepetitionKind.DAILY, + "polymorphic_load": "inline", + } + + ResponseSchema = MappedModel.create( + bases=[RepetitionMode.ResponseSchema], + extra_fields={"kind": (Literal[RepetitionKind.DAILY], RepetitionKind.DAILY)}, + ) + + def calculate_event_instance_starts_at_for_index( + self, + instance_index: int, + ) -> datetime: + return self.starts_at + timedelta(days=instance_index) + + def calculate_event_instance_index_for_starts_at( + self, + event_instance_starts_at: datetime, + ) -> int: + return (event_instance_starts_at - self.starts_at).days + + def iter_event_instances_in_range( + self, + happens_after: datetime, + happens_before: datetime, + ) -> Iterator[tuple[int, datetime]]: + current_starts_at, starts_at_upper_bound = self.get_starts_at_bounds_in_range( + happens_after=happens_after, + happens_before=happens_before, + ) + current_event_instance_index: int = ( + self.calculate_event_instance_index_for_starts_at( + event_instance_starts_at=current_starts_at + ) + ) + while current_starts_at < starts_at_upper_bound: + yield current_event_instance_index, current_starts_at + current_starts_at += timedelta(days=1) + current_event_instance_index += 1 + + +class BitMaskedRepeatingRepetitionMode(RepetitionMode): + __tablename__ = None + __mapper_args__ = { + "polymorphic_abstract": True, + } + + bitmask_type: ClassVar[type[TimestampRelativeBitmask]] + + @classmethod + def get_combined_bitmask_field(cls) -> InstrumentedAttribute[int]: + raise NotImplementedError + + @property + def starting_bitmask(self) -> TimestampRelativeBitmask: + raise NotImplementedError + + @classmethod + def iter_in_range_conditions( + cls, + happens_after: datetime, + happens_before: datetime, + ) -> Iterator[SQLColumnExpression[bool]]: + yield from super().iter_in_range_conditions( + happens_after=happens_after, + happens_before=happens_before, + ) + + if ( + happens_before - happens_after + < (cls.bitmask_type.size - 1) * cls.bitmask_type.unit_duration + ): + interval_bitmask = cls.bitmask_type.build_continuous( + start_timestamp=happens_after.astimezone(timezone.utc), + end_timestamp=happens_before.astimezone(timezone.utc), + ) + yield cls.get_combined_bitmask_field().bitwise_and( + interval_bitmask.value + ) != 0 + + def calculate_event_instance_starts_at_for_index( + self, + instance_index: int, + ) -> datetime: + offset_in_cycles = instance_index // self.starting_bitmask.value.bit_count() + + rotated_bitmask_value: int = self.starting_bitmask.rotate( + source_position=self.starting_bitmask.position_from_timestamp( + self.starts_at.astimezone(timezone.utc) + ), + target_position=-1, + ).value + + required_bit_count: int = ( + instance_index % self.starting_bitmask.value.bit_count() + ) + offset_in_units: int = 0 + while required_bit_count > 0: + if rotated_bitmask_value & 1: + required_bit_count -= 1 + rotated_bitmask_value >>= 1 + offset_in_units += 1 + + return ( + self.starts_at + + offset_in_cycles * self.starting_bitmask.get_cycle_duration() + + offset_in_units * self.starting_bitmask.unit_duration + ) + + def calculate_closest_past_event_instance_starts_at_for_timestamp( + self, + timestamp: datetime, + ) -> datetime | None: + result = super().calculate_closest_past_event_instance_starts_at_for_timestamp( + timestamp=timestamp + ) + + if result is None: + return None + + while not self.starting_bitmask.check_if_timestamp_matches(result): + result -= self.bitmask_type.unit_duration + if result < self.starts_at: + return None + + return result + + def calculate_event_instance_index_for_starts_at( + self, + event_instance_starts_at: datetime, + ) -> int: + repetition_mode_cycle_offset: int = ( + self.starting_bitmask.calculate_cycle_offset_for_timestamp( + timestamp=self.starts_at.astimezone(timezone.utc), + ) + ) + event_instance_cycle_offset: int = ( + self.starting_bitmask.calculate_cycle_offset_for_timestamp( + timestamp=event_instance_starts_at.astimezone(timezone.utc), + ) + ) + return ( + (event_instance_starts_at - self.starts_at) + // self.bitmask_type.get_cycle_duration() + * self.starting_bitmask.value.bit_count() + ) + ( + (event_instance_cycle_offset - repetition_mode_cycle_offset) + % self.starting_bitmask.value.bit_count() + ) + + def iter_event_instances_in_range( + self, + happens_after: datetime, + happens_before: datetime, + ) -> Iterator[tuple[int, datetime]]: + current_starts_at, starts_at_upper_bound = self.get_starts_at_bounds_in_range( + happens_after=happens_after, + happens_before=happens_before, + ) + + current_event_instance_index: int | None = None + while current_starts_at < starts_at_upper_bound: + if self.starting_bitmask.check_if_timestamp_matches(current_starts_at): + if current_event_instance_index is None: + current_event_instance_index = ( + self.calculate_event_instance_index_for_starts_at( + current_starts_at + ) + ) + yield current_event_instance_index, current_starts_at + current_event_instance_index += 1 + current_starts_at += self.starting_bitmask.unit_duration + + +class WeeklyRepetitionMode(BitMaskedRepeatingRepetitionMode): + __tablename__ = None + __mapper_args__ = { + "polymorphic_identity": RepetitionKind.WEEKLY, + "polymorphic_load": "inline", + } + + bitmask_type = WeeklyBitmask + bitmask_size = WeeklyBitmask.size + + weekly_starting_bitmask: Mapped[int] = mapped_column( + PSQLBitmask(bitmask_size), nullable=True + ) + weekly_combined_bitmask: Mapped[int] = mapped_column( + PSQLBitmask(bitmask_size), nullable=True + ) + + ResponseSchema = MappedModel.create( + bases=[BitMaskedRepeatingRepetitionMode.ResponseSchema], + columns=[weekly_starting_bitmask], + extra_fields={"kind": (Literal[RepetitionKind.WEEKLY], RepetitionKind.WEEKLY)}, + ) + + @classmethod + def get_combined_bitmask_field(cls) -> InstrumentedAttribute[int]: + return cls.weekly_combined_bitmask + + @property + def starting_bitmask(self) -> WeeklyBitmask: + return WeeklyBitmask(self.weekly_starting_bitmask) + + +ConcreteRepetitionModeClasses: tuple[type[RepetitionMode], ...] = ( + DailyRepetitionMode, + WeeklyRepetitionMode, +) + + +class BaseRepetitionModeInputSchema(BaseModel): + db_class: ClassVar[type[RepetitionMode]] + + starts_at: AwareDatetime + duration_seconds: int = Field( + gt=MIN_EVENT_INSTANCE_DURATION.seconds, + le=MAX_EVENT_INSTANCE_DURATION.seconds, + exclude=True, + ) + active_period_days: int | None = Field(None, gt=0, exclude=True) + + @model_validator(mode="after") + def validate_starts_at_range(self) -> Self: + timedelta_from_now_to_start: timedelta = self.starts_at - datetime_utc_now() + if timedelta_from_now_to_start < MAX_TIMEDELTA_TO_THE_PAST: + raise ValueError("start is too far in the past") + if timedelta_from_now_to_start > MAX_TIMEDELTA_TO_THE_FUTURE: + raise ValueError("start is too far in the future") + return self + + @model_validator(mode="after") + def validate_active_period_does_not_end_too_far_in_the_future(self) -> Self: + if self.active_period_days is None: + return self + active_period_ends_at: datetime = self.starts_at + timedelta( + days=self.active_period_days + ) + if active_period_ends_at - datetime_utc_now() <= MAX_TIMEDELTA_TO_THE_FUTURE: + return self + raise ValueError("active period's end is too far in the future") + + @property + def starts_at_utc(self) -> datetime: + return self.starts_at.astimezone(timezone.utc) + + @computed_field + @property + def ends_at(self) -> datetime: + return self.starts_at + timedelta( + seconds=self.duration_seconds, + days=self.active_period_days or 0, + ) + + @computed_field + @property + def is_finite(self) -> bool: + return self.active_period_days is not None + + +class DailyRepetitionModeInputSchema(BaseRepetitionModeInputSchema): + db_class = DailyRepetitionMode + + kind: Literal[RepetitionKind.DAILY] = RepetitionKind.DAILY + + +class BaseBitMaskedRepetitionModeInputSchema[BitmaskType: TimestampRelativeBitmask]( + BaseRepetitionModeInputSchema +): + @property + @abstractmethod + def bitmask(self) -> BitmaskType: + raise NotImplementedError + + @property + def starting_bitmask(self) -> BitmaskType: + return self.bitmask.replace_origin( + old_origin=self.starts_at, + new_origin=self.starts_at_utc, + ) + + @property + def ending_bitmask(self) -> BitmaskType: + return self.bitmask.replace_origin( + old_origin=self.starts_at, + new_origin=self.starts_at_utc + timedelta(seconds=self.duration_seconds), + ) + + @property + def combined_bitmask_value(self) -> int: + return self.starting_bitmask.value | self.ending_bitmask.value + + +class WeeklyOccurrenceModeInputSchema( + BaseBitMaskedRepetitionModeInputSchema[WeeklyBitmask] +): + db_class = WeeklyRepetitionMode + + kind: Literal[RepetitionKind.WEEKLY] = RepetitionKind.WEEKLY + + weekly_bitmask: int = Field( + gt=0, + lt=2**WeeklyRepetitionMode.bitmask_size - 1, + exclude=True, + ) + + @property + def bitmask(self) -> WeeklyBitmask: + return WeeklyBitmask(self.weekly_bitmask) + + @computed_field + @property + def weekly_starting_bitmask(self) -> int: + return self.starting_bitmask.value + + @computed_field + @property + def weekly_combined_bitmask(self) -> int: + return self.combined_bitmask_value + + +RepetitionModeInputSchema = Annotated[ + DailyRepetitionModeInputSchema | WeeklyOccurrenceModeInputSchema, + Field(discriminator="kind"), +] + +RepetitionModeResponseSchema = Annotated[ + DailyRepetitionMode.ResponseSchema | WeeklyRepetitionMode.ResponseSchema, + Field(discriminator="kind"), +] + +REPETITION_MODE_TYPE_ADAPTER: TypeAdapter[RepetitionModeResponseSchema] = TypeAdapter( + RepetitionModeResponseSchema +) diff --git a/app/scheduler/routes/classroom_event_instances_rst.py b/app/scheduler/routes/classroom_event_instances_rst.py new file mode 100644 index 00000000..0742def7 --- /dev/null +++ b/app/scheduler/routes/classroom_event_instances_rst.py @@ -0,0 +1,354 @@ +from typing import Annotated, Literal + +from pydantic import AwareDatetime, BaseModel, Field +from starlette import status + +from app.common.fastapi_ext import APIRouterExt, Responses +from app.common.utils.datetime import datetime_utc_now +from app.scheduler.dependencies.event_instances_dep import ( + ClassroomEventByInstanceID, + EventInstanceIndex, + MyClassroomEventInstanceByIDs, +) +from app.scheduler.dependencies.repetition_modes_dep import ( + ClassroomEventByRepetitionModeID, + MyClassroomRepetitionModeByIDs, +) +from app.scheduler.models.event_instances_db import ( + EventInstanceResponseSchemaKind, + EventInstanceTimeSlotInputSchema, + RepeatedEventInstance, + SoleEventInstance, +) +from app.scheduler.models.events_db import ClassroomEvent +from app.scheduler.models.repetition_modes_db import ( + REPETITION_MODE_TYPE_ADAPTER, + RepetitionModeResponseSchema, +) + +router = APIRouterExt(tags=["classroom event instances"]) + + +class VirtualRepeatedEventInstanceStandaloneResponseSchema(BaseModel): + starts_at: AwareDatetime + ends_at: AwareDatetime + + +class BaseEventInstanceDetailedResponseSchema(BaseModel): + event: ClassroomEvent.ResponseSchema + + +class SoleEventInstanceDetailedResponseSchema(BaseEventInstanceDetailedResponseSchema): + kind: Literal[EventInstanceResponseSchemaKind.SOLE] = ( + EventInstanceResponseSchemaKind.SOLE + ) + + persisted_event_instance: SoleEventInstance.StandaloneResponseSchema + + +class BaseRepeatedEventInstanceDetailedResponseSchema( + BaseEventInstanceDetailedResponseSchema +): + repetition_mode: RepetitionModeResponseSchema + instance_index: int + + virtual_event_instance: VirtualRepeatedEventInstanceStandaloneResponseSchema + + +class PersistedRepeatedEventInstanceDetailedResponseSchema( + BaseRepeatedEventInstanceDetailedResponseSchema +): + kind: Literal[EventInstanceResponseSchemaKind.REPEATED_PERSISTED] = ( + EventInstanceResponseSchemaKind.REPEATED_PERSISTED + ) + + persisted_event_instance: RepeatedEventInstance.StandaloneResponseSchema + + +class VirtualRepeatedEventInstanceDetailedResponseSchema( + BaseRepeatedEventInstanceDetailedResponseSchema +): + kind: Literal[EventInstanceResponseSchemaKind.REPEATED_VIRTUAL] = ( + EventInstanceResponseSchemaKind.REPEATED_VIRTUAL + ) + + +EventInstanceDetailedResponseSchema = Annotated[ + SoleEventInstanceDetailedResponseSchema + | PersistedRepeatedEventInstanceDetailedResponseSchema + | VirtualRepeatedEventInstanceDetailedResponseSchema, + Field(discriminator="kind"), +] + + +@router.get( + path=( + "/roles/tutor/classrooms/{classroom_id}" + "/event-instances/{event_instance_id}" + "/" + ), + summary="Retrieve detailed data for any classroom event instance by id", +) +@router.get( + path=( + "/roles/student/classrooms/{classroom_id}" + "/event-instances/{event_instance_id}" + "/" + ), + summary="Retrieve detailed data for any classroom event instance by id", +) +async def retrieve_detailed_classroom_event_instance( + classroom_event: ClassroomEventByInstanceID, + event_instance: MyClassroomEventInstanceByIDs, +) -> EventInstanceDetailedResponseSchema: + # TODO (170) move to _schedules_rst? XOR move common logic to "svc" + match event_instance: + case SoleEventInstance(): + return SoleEventInstanceDetailedResponseSchema( + event=ClassroomEvent.ResponseSchema.model_validate( + classroom_event, + from_attributes=True, + ), + persisted_event_instance=SoleEventInstance.StandaloneResponseSchema.model_validate( + event_instance, + from_attributes=True, + ), + ) + case RepeatedEventInstance(): + virtual_instance_starts_at = event_instance.repetition_mode.calculate_event_instance_starts_at_for_index( + instance_index=event_instance.instance_index, + ) + return PersistedRepeatedEventInstanceDetailedResponseSchema( + event=ClassroomEvent.ResponseSchema.model_validate( + classroom_event, from_attributes=True + ), + repetition_mode=REPETITION_MODE_TYPE_ADAPTER.validate_python( + event_instance.repetition_mode, + from_attributes=True, + ), + instance_index=event_instance.instance_index, + virtual_event_instance=VirtualRepeatedEventInstanceStandaloneResponseSchema( + starts_at=virtual_instance_starts_at, + ends_at=( + virtual_instance_starts_at + + event_instance.repetition_mode.event_instance_duration + ), + ), + persisted_event_instance=RepeatedEventInstance.StandaloneResponseSchema.model_validate( + event_instance, from_attributes=True + ), + ) + + +@router.get( + path=( + "/roles/tutor/classrooms/{classroom_id}" + "/repetition-modes/{repetition_mode_id}" + "/instances/{instance_index}" + "/" + ), + summary="Reschedule detailed data for any classroom event instance in a repetition mode by id and index", +) +@router.get( + path=( + "/roles/student/classrooms/{classroom_id}" + "/repetition-modes/{repetition_mode_id}" + "/instances/{instance_index}" + "/" + ), + summary="Reschedule detailed data for any classroom event instance in a repetition mode by id and index", +) +async def retrieve_detailed_repeated_classroom_event_instance( + classroom_event: ClassroomEventByRepetitionModeID, + repetition_mode: MyClassroomRepetitionModeByIDs, + instance_index: EventInstanceIndex, +) -> EventInstanceDetailedResponseSchema: + # TODO (170) DRY (aaaaaaaaaaaa) + # TODO (170) move to _schedules_rst? XOR move common logic to "svc" + event_instance = await RepeatedEventInstance.find_by_repetition_mode_id_and_index( + repetition_mode_id=repetition_mode.id, + instance_index=instance_index, + ) + + response_schema: type[ + VirtualRepeatedEventInstanceDetailedResponseSchema + | PersistedRepeatedEventInstanceDetailedResponseSchema + ] = ( + VirtualRepeatedEventInstanceDetailedResponseSchema + if event_instance is None + else PersistedRepeatedEventInstanceDetailedResponseSchema + ) + + virtual_instance_starts_at = ( # TODO (170) restructure better + repetition_mode.calculate_event_instance_starts_at_for_index( + instance_index=instance_index, + ) + ) + return response_schema( + event=ClassroomEvent.ResponseSchema.model_validate( + classroom_event, + from_attributes=True, + ), + repetition_mode=REPETITION_MODE_TYPE_ADAPTER.validate_python( + repetition_mode, + from_attributes=True, + ), + instance_index=instance_index, + virtual_event_instance=VirtualRepeatedEventInstanceStandaloneResponseSchema( + starts_at=virtual_instance_starts_at, + ends_at=( + virtual_instance_starts_at + repetition_mode.event_instance_duration + ), + ), + persisted_event_instance=( # type: ignore[call-arg] + None + if event_instance is None + else RepeatedEventInstance.StandaloneResponseSchema.model_validate( + event_instance, + from_attributes=True, + ) + ), + ) + + +@router.put( + path=( + "/roles/tutor/classrooms/{classroom_id}" + "/event-instances/{event_instance_id}" + "/time-slot/" + ), + status_code=status.HTTP_204_NO_CONTENT, # TODO (170) response schema + summary="Reschedule any classroom event instance by id", +) +async def reschedule_persisted_classroom_event_instance( + event_instance: MyClassroomEventInstanceByIDs, + data: EventInstanceTimeSlotInputSchema, +) -> None: + event_instance.reschedule( + new_starts_at=data.starts_at, + new_ends_at=data.ends_at, + ) + + +@router.put( + path=( + "/roles/tutor/classrooms/{classroom_id}" + "/repetition-modes/{repetition_mode_id}" + "/instances/{instance_index}" + "/time-slot/" + ), + status_code=status.HTTP_204_NO_CONTENT, # TODO (170) response schema + summary="Reschedule any classroom event instance in a repetition mode by id and index", +) +async def reschedule_repeated_classroom_event_instance( + repetition_mode: MyClassroomRepetitionModeByIDs, + instance_index: EventInstanceIndex, + data: EventInstanceTimeSlotInputSchema, +) -> None: + # TODO (170) DRY (repeated in cancel_repeated_classroom_event_instance) + event_instance = await RepeatedEventInstance.find_by_repetition_mode_id_and_index( + repetition_mode_id=repetition_mode.id, + instance_index=instance_index, + ) + if event_instance is None: + # TODO (170) generate the actual event instance and check it's not outside of the range + # TODO (170) check new time-slot is not equal to the generated one + await RepeatedEventInstance.create( + event_id=repetition_mode.event_id, + repetition_mode_id=repetition_mode.id, + instance_index=instance_index, + starts_at_override=data.starts_at, + ends_at_override=data.ends_at, + ) + # TODO(?) + # `event_instance = await create(...)` + # `event_instance.reschedule(...)` + else: + event_instance.reschedule( + new_starts_at=data.starts_at, + new_ends_at=data.ends_at, + ) + + +class EventInstanceCancellationResponses(Responses): + EVENT_INSTANCE_ALREADY_CANCELLED = ( + status.HTTP_409_CONFLICT, + "Event instance already cancelled", + ) + + +@router.post( + path=( + "/roles/tutor/classrooms/{classroom_id}" + "/event-instances/{event_instance_id}" + "/cancellation/" + ), + status_code=status.HTTP_201_CREATED, # TODO (170) mb a response schema + responses=EventInstanceCancellationResponses.responses(), + summary="Cancel any classroom event instance by id", +) +async def cancel_persisted_classroom_event_instance( + event_instance: MyClassroomEventInstanceByIDs, +) -> None: + if event_instance.cancelled_at is not None: + raise EventInstanceCancellationResponses.EVENT_INSTANCE_ALREADY_CANCELLED + event_instance.cancelled_at = datetime_utc_now() + + +@router.post( + path=( + "/roles/tutor/classrooms/{classroom_id}" + "/repetition-modes/{repetition_mode_id}" + "/instances/{instance_index}" + "/cancellation/" + ), + status_code=status.HTTP_201_CREATED, # TODO (170) mb a response schema + responses=EventInstanceCancellationResponses.responses(), + summary="Cancel any classroom event instance in a repetition mode by id and index", +) +async def cancel_repeated_classroom_event_instance( + repetition_mode: MyClassroomRepetitionModeByIDs, + instance_index: EventInstanceIndex, +) -> None: + event_instance = await RepeatedEventInstance.find_by_repetition_mode_id_and_index( + repetition_mode_id=repetition_mode.id, + instance_index=instance_index, + ) + if event_instance is None: + # TODO (170) generate the actual event instance and check it's not outside of the range + await RepeatedEventInstance.create( + event_id=repetition_mode.event_id, + repetition_mode_id=repetition_mode.id, + instance_index=instance_index, + cancelled_at=datetime_utc_now(), + ) + elif event_instance.cancelled_at is not None: + raise EventInstanceCancellationResponses.EVENT_INSTANCE_ALREADY_CANCELLED + else: + event_instance.cancelled_at = datetime_utc_now() + + +class EventInstanceUncancellationResponses(Responses): + EVENT_INSTANCE_IS_NOT_CANCELLED = ( + status.HTTP_409_CONFLICT, + "Event instance is not cancelled", + ) + + +@router.delete( + path=( + "/roles/tutor/classrooms/{classroom_id}" + "/event-instances/{event_instance_id}" + "/cancellation/" + ), + status_code=status.HTTP_204_NO_CONTENT, # TODO (170) mb a response schema + responses=EventInstanceUncancellationResponses.responses(), + summary="Uncancel any classroom event instance by id", +) +async def uncancel_persisted_classroom_event_instance( + event_instance: MyClassroomEventInstanceByIDs, +) -> None: + if event_instance.cancelled_at is None: + raise EventInstanceUncancellationResponses.EVENT_INSTANCE_IS_NOT_CANCELLED + + event_instance.cancelled_at = None diff --git a/app/scheduler/routes/classroom_events_student_rst.py b/app/scheduler/routes/classroom_events_student_rst.py deleted file mode 100644 index 89744d67..00000000 --- a/app/scheduler/routes/classroom_events_student_rst.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Sequence -from typing import Annotated - -from fastapi import Path - -from app.common.fastapi_ext import APIRouterExt -from app.scheduler.dependencies.classroom_events_dep import ( - MyClassroomEventByIDs, -) -from app.scheduler.dependencies.events_dep import EventTimeFrameQuery -from app.scheduler.models.events_db import ClassroomEvent - -router = APIRouterExt(tags=["student classroom events"]) - - -@router.get( - path="/roles/student/classrooms/{classroom_id}/events/", - response_model=list[ClassroomEvent.ResponseSchema], - summary="List paginated events in a classroom by id", -) -async def list_classroom_events( - classroom_id: Annotated[int, Path()], - time_frame: EventTimeFrameQuery, -) -> Sequence[ClassroomEvent]: - return await ClassroomEvent.find_all_by_classroom_id_in_time_frame( - classroom_id=classroom_id, - happens_after=time_frame.happens_after, - happens_before=time_frame.happens_before, - ) - - -@router.get( - path="/roles/student/classrooms/{classroom_id}/events/{event_id}/", - response_model=ClassroomEvent.ResponseSchema, - summary="Retrieve a classroom event by ids", -) -async def retrieve_classroom_event( - classroom_event: MyClassroomEventByIDs, -) -> ClassroomEvent: - return classroom_event diff --git a/app/scheduler/routes/classroom_events_tutor_rst.py b/app/scheduler/routes/classroom_events_tutor_rst.py index ef2229ea..9ac9854e 100644 --- a/app/scheduler/routes/classroom_events_tutor_rst.py +++ b/app/scheduler/routes/classroom_events_tutor_rst.py @@ -1,83 +1,212 @@ -from collections.abc import Sequence -from typing import Annotated, Self +from datetime import datetime +from enum import StrEnum, auto +from typing import Annotated, Literal, assert_never -from fastapi import Path -from pydantic import model_validator +from fastapi import Body, Path +from pydantic import AwareDatetime, BaseModel, Field from starlette import status from app.common.fastapi_ext import APIRouterExt -from app.scheduler.dependencies.classroom_events_dep import ( - MyClassroomEventByIDs, +from app.scheduler.dependencies.classroom_events_dep import MyClassroomEventByIDs +from app.scheduler.models.event_instances_db import ( + RepeatedEventInstance, + SoleEventInstance, + SoleEventInstanceInputSchema, ) -from app.scheduler.dependencies.events_dep import EventTimeFrameQuery from app.scheduler.models.events_db import ClassroomEvent +from app.scheduler.models.repetition_modes_db import ( + REPETITION_MODE_TYPE_ADAPTER, + RepetitionMode, + RepetitionModeInputSchema, + RepetitionModeResponseSchema, +) router = APIRouterExt(tags=["tutor classroom events"]) -@router.get( - path="/roles/tutor/classrooms/{classroom_id}/events/", - response_model=list[ClassroomEvent.ResponseSchema], - summary="List paginated events in a classroom by id", -) -async def list_classroom_events( - classroom_id: Annotated[int, Path()], - time_frame: EventTimeFrameQuery, -) -> Sequence[ClassroomEvent]: - return await ClassroomEvent.find_all_by_classroom_id_in_time_frame( - classroom_id=classroom_id, - happens_after=time_frame.happens_after, - happens_before=time_frame.happens_before, - ) +class EventSchemaKind(StrEnum): + SINGLE = auto() + REPEATING = auto() -class ClassroomEventInputSchema(ClassroomEvent.InputSchema): - @model_validator(mode="after") - def validate_event_start_and_end_time(self) -> Self: - if self.starts_at >= self.ends_at: - raise ValueError( - "the start time of an event cannot be greater than or equal to the end time" - ) - return self +class BaseEventInputSchema(BaseModel): + event: ClassroomEvent.InputSchema + + +class SingleEventInputSchema(BaseEventInputSchema): + kind: Literal[EventSchemaKind.SINGLE] = EventSchemaKind.SINGLE + sole_instance: SoleEventInstanceInputSchema + + +class RepeatingEventInputSchema(BaseEventInputSchema): + kind: Literal[EventSchemaKind.REPEATING] = EventSchemaKind.REPEATING + repetition_mode: RepetitionModeInputSchema + + +EventInputSchema = Annotated[ + SingleEventInputSchema | RepeatingEventInputSchema, + Field(discriminator="kind"), +] + + +class BaseEventResponseSchema(BaseModel): + event: ClassroomEvent.ResponseSchema + + +class SingleEventResponseSchema(BaseEventResponseSchema): + kind: Literal[EventSchemaKind.SINGLE] = EventSchemaKind.SINGLE + sole_instance: SoleEventInstance.StandaloneResponseSchema + + +class RepeatingEventResponseSchema(BaseEventResponseSchema): + kind: Literal[EventSchemaKind.REPEATING] = EventSchemaKind.REPEATING + repetition_mode: RepetitionModeResponseSchema + + +EventResponseSchema = Annotated[ + SingleEventResponseSchema | RepeatingEventResponseSchema, + Field(discriminator="kind"), +] @router.post( path="/roles/tutor/classrooms/{classroom_id}/events/", status_code=status.HTTP_201_CREATED, - response_model=ClassroomEvent.ResponseSchema, summary="Create a new event in a classroom by id", ) async def create_classroom_event( classroom_id: Annotated[int, Path()], - input_data: ClassroomEventInputSchema, -) -> ClassroomEvent: - return await ClassroomEvent.create( - **input_data.model_dump(), classroom_id=classroom_id + data: EventInputSchema, +) -> EventResponseSchema: + classroom_event = await ClassroomEvent.create( + **data.event.model_dump(), + classroom_id=classroom_id, ) + match data: + case SingleEventInputSchema(): + sole_instance = await SoleEventInstance.create( + **data.sole_instance.model_dump(), + event_id=classroom_event.id, + ) + return SingleEventResponseSchema( + event=ClassroomEvent.ResponseSchema.model_validate( + classroom_event, from_attributes=True + ), + sole_instance=SoleEventInstance.StandaloneResponseSchema.model_validate( + sole_instance, + from_attributes=True, + ), + ) + case RepeatingEventInputSchema(): + repetition_mode = await data.repetition_mode.db_class.create( + **data.repetition_mode.model_dump(), + event_id=classroom_event.id, + ) + return RepeatingEventResponseSchema( + event=ClassroomEvent.ResponseSchema.model_validate( + classroom_event, from_attributes=True + ), + repetition_mode=REPETITION_MODE_TYPE_ADAPTER.validate_python( + repetition_mode, + from_attributes=True, + ), + ) + case _: + assert_never(data) + -@router.get( +@router.patch( path="/roles/tutor/classrooms/{classroom_id}/events/{event_id}/", response_model=ClassroomEvent.ResponseSchema, - summary="Retrieve a classroom event by ids", + summary="Update a classroom event by ids", ) -async def retrieve_classroom_event( +async def patch_classroom_event( classroom_event: MyClassroomEventByIDs, + data: ClassroomEvent.PatchSchema, ) -> ClassroomEvent: + classroom_event.update(**data.model_dump(exclude_defaults=True)) return classroom_event -@router.put( - path="/roles/tutor/classrooms/{classroom_id}/events/{event_id}/", - response_model=ClassroomEvent.ResponseSchema, - summary="Update a classroom event by ids", +async def cancel_repetition_modes_after_timestamp( + classroom_event: ClassroomEvent, + timestamp: datetime, +) -> None: + await RepetitionMode.delete_all_at_or_after_timestamp( + event_id=classroom_event.id, + timestamp=timestamp, + ) + + border_repetition_mode = await RepetitionMode.find_last_bordering_on_a_timestamp( + event_id=classroom_event.id, + timestamp=timestamp, + ) + if border_repetition_mode is None: + return + + last_starts_at = border_repetition_mode.calculate_closest_past_event_instance_starts_at_for_timestamp( + timestamp=timestamp + ) + if last_starts_at is None: + await border_repetition_mode.delete() + return + + border_repetition_mode.is_finite = True + border_repetition_mode.ends_at = ( + last_starts_at + border_repetition_mode.event_instance_duration + ) + + last_instance_index: int = ( + border_repetition_mode.calculate_event_instance_index_for_starts_at( + event_instance_starts_at=last_starts_at + ) + ) + await RepeatedEventInstance.delete_all_after_index( + repetition_mode_id=border_repetition_mode.id, + instance_index=last_instance_index, + ) + + +@router.post( + path=( + "/roles/tutor/classrooms/{classroom_id}" + "/events/{event_id}/last-repetition-mode/" + ), + status_code=status.HTTP_201_CREATED, + response_model=RepetitionModeResponseSchema, + summary="Create a new repetition mode at the end for a classroom event by id", ) -async def put_classroom_event( +async def create_last_repetition_mode( classroom_event: MyClassroomEventByIDs, - put_data: ClassroomEventInputSchema, -) -> ClassroomEvent: - classroom_event.update(**put_data.model_dump()) - return classroom_event + data: RepetitionModeInputSchema, +) -> RepetitionMode: + # TODO (170) check if this is a single event + + await cancel_repetition_modes_after_timestamp( + classroom_event=classroom_event, + timestamp=data.starts_at, + ) + + return await data.db_class.create( + **data.model_dump(), + event_id=classroom_event.id, + ) + + +@router.post( + path="/roles/tutor/classrooms/{classroom_id}/events/{event_id}/cancellations/", + status_code=status.HTTP_204_NO_CONTENT, + summary="Cancel a repeating classroom event by id after some timestamp", +) +async def cancel_repeating_event_after_timestamp( + classroom_event: MyClassroomEventByIDs, + starts_at: Annotated[AwareDatetime, Body(embed=True)], +) -> None: + await cancel_repetition_modes_after_timestamp( + classroom_event=classroom_event, + timestamp=starts_at, + ) @router.delete( diff --git a/app/scheduler/routes/classroom_schedules_rst.py b/app/scheduler/routes/classroom_schedules_rst.py new file mode 100644 index 00000000..354acdf1 --- /dev/null +++ b/app/scheduler/routes/classroom_schedules_rst.py @@ -0,0 +1,454 @@ +from collections.abc import Iterator +from dataclasses import dataclass +from datetime import datetime +from typing import Annotated, assert_never, cast +from uuid import UUID + +from fastapi import Path +from pydantic import AwareDatetime +from sqlalchemy import and_, or_, select, tuple_ +from sqlalchemy.orm import raiseload + +from app.common.config_bdg import classrooms_bridge +from app.common.dependencies.authorization_dep import AuthorizationData +from app.common.fastapi_ext import APIRouterExt +from app.common.sqlalchemy_ext import db +from app.scheduler.dependencies.events_dep import ( + EventTimeFrameQuery, + EventTimeFrameSchema, +) +from app.scheduler.models.event_instances_db import ( + AnyEventInstance, + EventInstance, + EventInstanceKind, + EventInstanceResponseSchema, + PersistedRepeatedEventInstanceResponseSchema, + RepeatedEventInstance, + SoleEventInstance, + SoleEventInstanceResponseSchema, + VirtualRepeatedEventInstanceResponseSchema, +) +from app.scheduler.models.events_db import ClassroomEvent +from app.scheduler.models.repetition_modes_db import ( + ConcreteRepetitionModeClasses, + RepetitionMode, +) + +router = APIRouterExt(tags=["classroom schedules"]) + + +# TODO (170) naming: `_range`??? + + +async def get_repetition_modes_in_range( + classroom_ids: list[int], + happens_after: datetime, + happens_before: datetime, +) -> list[RepetitionMode]: + return await db.get_all_with_assumed_limit( + select(RepetitionMode) + .options(raiseload(RepetitionMode.event)) + .join(ClassroomEvent) + .filter( + ClassroomEvent.classroom_id.in_(classroom_ids), + or_( + *( + and_( + *klass.iter_in_range_conditions( + happens_after=happens_after, + happens_before=happens_before, + ) + ) + for klass in ConcreteRepetitionModeClasses + ) + ), + ), + limit=1000, + ) + + +@dataclass(frozen=True) +class VirtualRepeatedEventInstanceKeyData: + repetition_mode_id: UUID + instance_index: int + + +@dataclass(frozen=True) +class VirtualRepeatedEventInstanceValueData: + starts_at: AwareDatetime + ends_at: AwareDatetime + event_id: int + + +def iter_virtual_repeated_event_instances_in_range( + repetition_modes: list[RepetitionMode], + happens_after: datetime, + happens_before: datetime, +) -> Iterator[ + tuple[ + VirtualRepeatedEventInstanceKeyData, + VirtualRepeatedEventInstanceValueData, + ] +]: + for repetition_mode in repetition_modes: + event_instance_duration = repetition_mode.event_instance_duration + yield from ( + ( + VirtualRepeatedEventInstanceKeyData( + repetition_mode_id=repetition_mode.id, + instance_index=instance_index, + ), + VirtualRepeatedEventInstanceValueData( + starts_at=starts_at, + ends_at=starts_at + event_instance_duration, + event_id=repetition_mode.event_id, + ), + ) + for ( + instance_index, + starts_at, + ) in repetition_mode.iter_event_instances_in_range( + happens_after=happens_after, + happens_before=happens_before, + ) + ) + + +async def get_event_instances_in_range( + classroom_ids: list[int], + happens_after: datetime, + happens_before: datetime, + virtual_repeated_instance_keys: list[VirtualRepeatedEventInstanceKeyData], +) -> list[AnyEventInstance]: + filters_or = [ + and_( + RepeatedEventInstance.kind == EventInstanceKind.SOLE, + SoleEventInstance.starts_at <= happens_before, + SoleEventInstance.ends_at > happens_after, + ), + and_( + RepeatedEventInstance.kind == EventInstanceKind.REPEATED, + RepeatedEventInstance.starts_at_override.is_not(None), + RepeatedEventInstance.ends_at_override.is_not(None), + RepeatedEventInstance.starts_at_override <= happens_before, + RepeatedEventInstance.ends_at_override > happens_after, + ), + ] + if len(virtual_repeated_instance_keys) > 0: + filters_or.append( + and_( + RepeatedEventInstance.kind == EventInstanceKind.REPEATED, + tuple_( + RepeatedEventInstance.repetition_mode_id, + RepeatedEventInstance.instance_index, + ).in_( + [ + (key.repetition_mode_id, key.instance_index) + for key in virtual_repeated_instance_keys + ] + ), + ) + ) + + return cast( # no good way to type this in SQLAlchemy + list[AnyEventInstance], + await db.get_all_with_assumed_limit( + select(EventInstance) + .options( + raiseload(EventInstance.event), + # TODO enable `raiseload` for `RepeatedEventInstance.repetition_mode` + # Currently disabled for generating virtual event in `iter_persisted_repeated_event_instances` + ) + .join(ClassroomEvent) + .filter( + ClassroomEvent.classroom_id.in_(classroom_ids), + or_(*filters_or), + ), + limit=1000, + ), + ) + + +class ScheduleResponseSchemaAdapter: + def __init__( + self, + events_by_id: dict[int, ClassroomEvent], + sole_event_instances: list[SoleEventInstance], + persisted_repeated_event_instances: list[RepeatedEventInstance], + persisted_repeated_event_instance_keys: set[ + VirtualRepeatedEventInstanceKeyData + ], + virtual_repeated_instances_by_id: dict[ + VirtualRepeatedEventInstanceKeyData, + VirtualRepeatedEventInstanceValueData, + ], + ) -> None: + self.events_by_id = events_by_id + self.virtual_repeated_instances_by_id = virtual_repeated_instances_by_id + self.sole_event_instances = sole_event_instances + self.persisted_repeated_event_instances = persisted_repeated_event_instances + self.persisted_repeated_event_instance_keys = ( + persisted_repeated_event_instance_keys + ) + + def iter_sole_event_instances(self) -> Iterator[SoleEventInstanceResponseSchema]: + for sole_event_instance in self.sole_event_instances: + event = self.events_by_id[sole_event_instance.event_id] + yield SoleEventInstanceResponseSchema( + id=sole_event_instance.id, + event_id=event.id, + classroom_id=event.classroom_id, + cancelled_at=sole_event_instance.cancelled_at, + starts_at=sole_event_instance.starts_at, + ends_at=sole_event_instance.ends_at, + name=event.name, + description=event.description, + ) + + def iter_persisted_repeated_event_instances( + self, + ) -> Iterator[PersistedRepeatedEventInstanceResponseSchema]: + for ( + persisted_repeated_event_instance + ) in self.persisted_repeated_event_instances: + event = self.events_by_id[persisted_repeated_event_instance.event_id] + + virtual_event_instance_value = self.virtual_repeated_instances_by_id.get( + VirtualRepeatedEventInstanceKeyData( + repetition_mode_id=persisted_repeated_event_instance.repetition_mode_id, + instance_index=persisted_repeated_event_instance.instance_index, + ) + ) + if virtual_event_instance_value is None: + repetition_mode = persisted_repeated_event_instance.repetition_mode + starts_at = ( + repetition_mode.calculate_event_instance_starts_at_for_index( + instance_index=persisted_repeated_event_instance.instance_index, + ) + ) + virtual_event_instance_value = VirtualRepeatedEventInstanceValueData( + starts_at=starts_at, + ends_at=starts_at + repetition_mode.event_instance_duration, + event_id=event.id, + ) + + yield PersistedRepeatedEventInstanceResponseSchema( + id=persisted_repeated_event_instance.id, + event_id=event.id, + classroom_id=event.classroom_id, + repetition_mode_id=persisted_repeated_event_instance.repetition_mode_id, + instance_index=persisted_repeated_event_instance.instance_index, + cancelled_at=persisted_repeated_event_instance.cancelled_at, + starts_at=( + persisted_repeated_event_instance.starts_at_override + or virtual_event_instance_value.starts_at + ), + ends_at=( + persisted_repeated_event_instance.ends_at_override + or virtual_event_instance_value.ends_at + ), + name=persisted_repeated_event_instance.name_override or event.name, + description=( + persisted_repeated_event_instance.description_override + or event.description + ), + ) + + def iter_virtual_repeated_event_instances( + self, + ) -> Iterator[VirtualRepeatedEventInstanceResponseSchema]: + for ( + virtual_repeated_event_instance_key, + virtual_repeated_event_instance_value, + ) in self.virtual_repeated_instances_by_id.items(): + if ( + virtual_repeated_event_instance_key + in self.persisted_repeated_event_instance_keys + ): + continue + + event = self.events_by_id[virtual_repeated_event_instance_value.event_id] + yield VirtualRepeatedEventInstanceResponseSchema( + event_id=event.id, + classroom_id=event.classroom_id, + repetition_mode_id=virtual_repeated_event_instance_key.repetition_mode_id, + instance_index=virtual_repeated_event_instance_key.instance_index, + starts_at=virtual_repeated_event_instance_value.starts_at, + ends_at=virtual_repeated_event_instance_value.ends_at, + name=event.name, + description=event.description, + ) + + def iter_event_instances(self) -> Iterator[EventInstanceResponseSchema]: + yield from self.iter_sole_event_instances() + yield from self.iter_persisted_repeated_event_instances() + yield from self.iter_virtual_repeated_event_instances() + + def adapt(self) -> list[EventInstanceResponseSchema]: + return list(self.iter_event_instances()) + + +async def list_classroom_event_instances( + classroom_ids: list[int], + time_frame: EventTimeFrameSchema, +) -> list[EventInstanceResponseSchema]: + repetition_modes = await get_repetition_modes_in_range( + classroom_ids=classroom_ids, + happens_after=time_frame.happens_after, + happens_before=time_frame.happens_before, + ) + + virtual_repeated_instances_by_id: dict[ + VirtualRepeatedEventInstanceKeyData, + VirtualRepeatedEventInstanceValueData, + ] = dict( + iter_virtual_repeated_event_instances_in_range( + repetition_modes=repetition_modes, + happens_after=time_frame.happens_after, + happens_before=time_frame.happens_before, + ) + ) + + persisted_event_instances = await get_event_instances_in_range( + classroom_ids=classroom_ids, + happens_after=time_frame.happens_after, + happens_before=time_frame.happens_before, + virtual_repeated_instance_keys=list(virtual_repeated_instances_by_id.keys()), + ) + + sole_event_instances: list[SoleEventInstance] = [] + persisted_repeated_event_instances: list[RepeatedEventInstance] = [] + + for persisted_event_instance in persisted_event_instances: + match persisted_event_instance: + case SoleEventInstance(): + if ( + persisted_event_instance.cancelled_at is not None + or persisted_event_instance.starts_at > time_frame.happens_before + or persisted_event_instance.ends_at <= time_frame.happens_after + ): + continue + sole_event_instances.append(persisted_event_instance) + case RepeatedEventInstance(): + if ( + persisted_event_instance.cancelled_at is not None + or ( + persisted_event_instance.starts_at_override is not None + and persisted_event_instance.starts_at_override + > time_frame.happens_before + ) + or ( + persisted_event_instance.ends_at_override is not None + and persisted_event_instance.ends_at_override + <= time_frame.happens_after + ) + ): + virtual_repeated_instances_by_id.pop( + VirtualRepeatedEventInstanceKeyData( + persisted_event_instance.repetition_mode_id, + persisted_event_instance.instance_index, + ), + None, + ) + continue + persisted_repeated_event_instances.append(persisted_event_instance) + case _: + assert_never(persisted_event_instance) + + persisted_repeated_event_instance_keys: set[VirtualRepeatedEventInstanceKeyData] = { + VirtualRepeatedEventInstanceKeyData( + repetition_mode_id=event_instance.repetition_mode_id, + instance_index=event_instance.instance_index, + ) + for event_instance in persisted_repeated_event_instances + } + + repetition_mode_ids_used_in_event_instances: set[UUID] = { + key.repetition_mode_id + for key in ( + *virtual_repeated_instances_by_id.keys(), + *persisted_repeated_event_instance_keys, + ) + } + + event_ids: list[int] = list( + { + repetition_mode.event_id + for repetition_mode in repetition_modes + if repetition_mode.id in repetition_mode_ids_used_in_event_instances + } + | {event_instance.event_id for event_instance in sole_event_instances} + | { + event_instance.event_id + for event_instance in persisted_repeated_event_instances + } + ) + + events_by_id: dict[int, ClassroomEvent] + if len(event_ids) == 0: + events_by_id = {} + else: + events_by_id = { + classroom_event.id: classroom_event + for classroom_event in await ClassroomEvent.find_all_by_ids( + event_ids=event_ids + ) + } + + return ScheduleResponseSchemaAdapter( + events_by_id=events_by_id, + virtual_repeated_instances_by_id=virtual_repeated_instances_by_id, + sole_event_instances=sole_event_instances, + persisted_repeated_event_instances=persisted_repeated_event_instances, + persisted_repeated_event_instance_keys=persisted_repeated_event_instance_keys, + ).adapt() + + +@router.get( + path="/roles/tutor/classrooms/{classroom_id}/schedule/", + summary="Retrieve a schedule for all of the events in a classroom by id", +) +@router.get( + path="/roles/student/classrooms/{classroom_id}/schedule/", + summary="Retrieve a schedule for all of the events in a classroom by id", +) +async def retrieve_classroom_schedule( + classroom_id: Annotated[int, Path()], + time_frame: EventTimeFrameQuery, +) -> list[EventInstanceResponseSchema]: + return await list_classroom_event_instances( + classroom_ids=[classroom_id], + time_frame=time_frame, + ) + + +@router.get( + path="/roles/tutor/schedule/", + summary="Retrieve a schedule for all events for the current tutor", +) +async def retrieve_tutor_schedule( + auth_data: AuthorizationData, + time_frame: EventTimeFrameQuery, +) -> list[EventInstanceResponseSchema]: + return await list_classroom_event_instances( + classroom_ids=await classrooms_bridge.list_tutor_classroom_ids( + tutor_id=auth_data.user_id + ), + time_frame=time_frame, + ) + + +@router.get( + path="/roles/student/schedule/", + summary="Retrieve a schedule for all events for the current student", +) +async def retrieve_student_schedule( + auth_data: AuthorizationData, + time_frame: EventTimeFrameQuery, +) -> list[EventInstanceResponseSchema]: + return await list_classroom_event_instances( + classroom_ids=await classrooms_bridge.list_student_classroom_ids( + student_id=auth_data.user_id + ), + time_frame=time_frame, + ) diff --git a/app/scheduler/utils/__init__.py b/app/scheduler/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/scheduler/utils/bitmasks.py b/app/scheduler/utils/bitmasks.py new file mode 100644 index 00000000..cb90f631 --- /dev/null +++ b/app/scheduler/utils/bitmasks.py @@ -0,0 +1,107 @@ +from datetime import datetime, timedelta +from typing import Any, ClassVar, Self + +from sqlalchemy import Dialect, TypeDecorator +from sqlalchemy.dialects.postgresql import BIT + +from app.common.utils.bitwise import ( + bitwise_cyclic_shift_left, + bitwise_cyclic_shift_right, +) + + +class TimestampRelativeBitmask: + size: ClassVar[int] + unit_duration: ClassVar[timedelta] + + @classmethod + def get_cycle_duration(cls) -> timedelta: + return cls.size * cls.unit_duration + + def __init__(self, value: int) -> None: + self.value = value + + @classmethod + def position_from_timestamp(cls, timestamp: datetime) -> int: + raise NotImplementedError + + @classmethod + def build_continuous( + cls, start_timestamp: datetime, end_timestamp: datetime + ) -> Self: + start_position: int = cls.position_from_timestamp(start_timestamp) + end_position: int = cls.position_from_timestamp(end_timestamp) + + if start_position <= end_position: + bitmask_value = 0 + for bit_position in range(start_position, end_position + 1): + bitmask_value ^= 1 << bit_position + else: + bitmask_value = (1 << cls.size) - 1 + for bit_position in range(end_position + 1, start_position): + bitmask_value ^= 1 << bit_position + + return cls(value=bitmask_value) + + def check_if_timestamp_matches(self, timestamp: datetime) -> bool: + return bool(self.value & (1 << self.position_from_timestamp(timestamp))) + + def calculate_cycle_offset_for_timestamp(self, timestamp: datetime) -> int: + bitmask_position: int = self.position_from_timestamp(timestamp) + return (((1 << bitmask_position) - 1) & self.value).bit_count() + + def rotate(self, source_position: int, target_position: int) -> Self: + position_difference: int = (target_position - source_position) % self.size + if position_difference > self.size // 2: + position_difference -= self.size + + if position_difference < 0: + new_value = bitwise_cyclic_shift_right( + value=self.value, + size=self.size, + rotations=-position_difference, + ) + elif position_difference > 0: + new_value = bitwise_cyclic_shift_left( + value=self.value, + size=self.size, + rotations=position_difference, + ) + else: + new_value = self.value + + return type(self)(value=new_value) + + def replace_origin(self, old_origin: datetime, new_origin: datetime) -> Self: + return self.rotate( + source_position=self.position_from_timestamp(old_origin), + target_position=self.position_from_timestamp(new_origin), + ) + + +class WeeklyBitmask(TimestampRelativeBitmask): + size = 7 + unit_duration = timedelta(days=1) + + @classmethod + def position_from_timestamp(cls, timestamp: datetime) -> int: + return timestamp.weekday() + + +class PSQLBitmask(TypeDecorator[int]): + impl = BIT + cache_ok = True + + @property + def python_type(self) -> type[Any]: + return int + + def process_bind_param(self, value: int | None, dialect: Dialect) -> str | None: + if value is None: + return None + return bin(value).partition("b")[2].rjust(7, "0") + + def process_result_value(self, value: str | None, dialect: Dialect) -> int | None: + if value is None: + return None + return int(value, base=2) diff --git a/poetry.lock b/poetry.lock index fbc32382..262a45fa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2628,6 +2628,23 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" +[[package]] +name = "prometheus-client" +version = "0.25.0" +description = "Python client for the Prometheus monitoring system." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "prometheus_client-0.25.0-py3-none-any.whl", hash = "sha256:d5aec89e349a6ec230805d0df882f3807f74fd6c1a2fa86864e3c2279059fed1"}, + {file = "prometheus_client-0.25.0.tar.gz", hash = "sha256:5e373b75c31afb3c86f1a52fa1ad470c9aace18082d39ec0d2f918d11cc9ba28"}, +] + +[package.extras] +aiohttp = ["aiohttp"] +django = ["django"] +twisted = ["twisted"] + [[package]] name = "propcache" version = "0.2.0" @@ -3676,6 +3693,22 @@ anyio = ">=3.6.2,<5" [package.extras] full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] +[[package]] +name = "starlette-exporter" +version = "0.23.0" +description = "Prometheus metrics exporter for Starlette applications." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "starlette_exporter-0.23.0-py3-none-any.whl", hash = "sha256:ea1a27f2aae48122931e2384a361a03e00261efbb4a665ce1ae2e46f29123d5e"}, + {file = "starlette_exporter-0.23.0.tar.gz", hash = "sha256:f80998db2d4a3462808a9bce56950046b113d3fab6ec6c20cb6de4431d974969"}, +] + +[package.dependencies] +prometheus-client = ">=0.12" +starlette = ">=0.35" + [[package]] name = "stevedore" version = "5.4.1" @@ -4276,4 +4309,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.1" python-versions = "~=3.12,<3.15" -content-hash = "855178da1cdda866751889417cb42c114e1b9dba5d80c45ab2199bd7c1eb1a09" +content-hash = "e61c08fe15b7619ca64347fc008f1be5c46dc645931330304fffaf68868d694b" diff --git a/pyproject.toml b/pyproject.toml index cb00c9af..fdf4a0ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ itsdangerous = "^2.2.0" faststream = {extras = ["redis"], version = "^0.6.2"} sentry-sdk = {extras = ["asyncio", "fastapi", "sqlalchemy", "redis", "httpx"], version = "2.44.0"} pillow = "^12.2.0" +starlette-exporter = "^0.23.0" [tool.poetry.group.types.dependencies] types-passlib = "^1.7.7.13" @@ -161,5 +162,8 @@ omit = [ "app/common/tmexio_ext.py", "app/common/bridges/base_bdg.py", "app/common/abscract_models/ordered_lists_db.py", # TODO (33602197) + "app/scheduler/*", # TODO (170) + "app/common/bridges/classrooms_bdg.py", # TODO (170) + "app/common/utils/bitwise.py", # TODO (170) "app/setup_ci.py", ] diff --git a/tests/classrooms/functional/test_classrooms_int.py b/tests/classrooms/functional/test_classrooms_int.py index c5969a48..14b896f5 100644 --- a/tests/classrooms/functional/test_classrooms_int.py +++ b/tests/classrooms/functional/test_classrooms_int.py @@ -1,4 +1,5 @@ import pytest +from pydantic_marshals.contains import UnorderedLiteralCollection from starlette import status from starlette.testclient import TestClient @@ -47,3 +48,38 @@ async def test_listing_classroom_students_classroom_not_found( expected_code=status.HTTP_404_NOT_FOUND, expected_json={"detail": "Classroom not found"}, ) + + +# TODO maybe expand + + +async def test_listing_tutor_classroom_ids( + internal_client: TestClient, + tutor_user_id: int, + individual_classroom: IndividualClassroom, + group_classroom: GroupClassroom, +) -> None: + assert_response( + internal_client.get( + f"/internal/classroom-service/tutors/{tutor_user_id}/classroom-ids/", + ), + expected_json=UnorderedLiteralCollection( + [individual_classroom.id, group_classroom.id], + ), + ) + + +async def test_listing_student_classroom_ids( + internal_client: TestClient, + student_user_id: int, + individual_classroom: IndividualClassroom, + enrollment: Enrollment, +) -> None: + assert_response( + internal_client.get( + f"/internal/classroom-service/students/{student_user_id}/classroom-ids/", + ), + expected_json=UnorderedLiteralCollection( + [individual_classroom.id, enrollment.group_classroom_id], + ), + ) diff --git a/tests/scheduler/conftest.py b/tests/scheduler/conftest.py deleted file mode 100644 index ae8a49ac..00000000 --- a/tests/scheduler/conftest.py +++ /dev/null @@ -1,68 +0,0 @@ -import pytest -from faker import Faker -from starlette.testclient import TestClient - -from app.common.dependencies.authorization_dep import ProxyAuthData -from app.scheduler.models.events_db import ClassroomEvent -from tests.common.active_session import ActiveSession -from tests.common.types import AnyJSON -from tests.factories import ProxyAuthDataFactory -from tests.scheduler import factories - - -@pytest.fixture() -def tutor_auth_data() -> ProxyAuthData: - return ProxyAuthDataFactory.build() - - -@pytest.fixture() -def tutor_client(client: TestClient, tutor_auth_data: ProxyAuthData) -> TestClient: - return TestClient(client.app, headers=tutor_auth_data.as_headers) - - -@pytest.fixture() -def student_auth_data() -> ProxyAuthData: - return ProxyAuthDataFactory.build() - - -@pytest.fixture() -def student_client(client: TestClient, student_auth_data: ProxyAuthData) -> TestClient: - return TestClient(client.app, headers=student_auth_data.as_headers) - - -@pytest.fixture() -def classroom_id(faker: Faker) -> int: - return faker.random_int(1, 1000) - - -@pytest.fixture() -def other_classroom_id(faker: Faker, classroom_id: int) -> int: - return faker.random_int(classroom_id + 1, classroom_id + 1000) - - -@pytest.fixture() -async def classroom_event( - active_session: ActiveSession, classroom_id: int -) -> ClassroomEvent: - async with active_session(): - return await ClassroomEvent.create( - **factories.ClassroomEventInputFactory.build_python(), - classroom_id=classroom_id, - ) - - -@pytest.fixture() -def classroom_event_data(classroom_event: ClassroomEvent) -> AnyJSON: - return ClassroomEvent.ResponseSchema.model_validate( - classroom_event, from_attributes=True - ).model_dump(mode="json") - - -@pytest.fixture() -async def deleted_classroom_event_id( - active_session: ActiveSession, - classroom_event: ClassroomEvent, -) -> int: - async with active_session(): - await classroom_event.delete() - return classroom_event.id diff --git a/tests/scheduler/factories.py b/tests/scheduler/factories.py deleted file mode 100644 index 91c7d145..00000000 --- a/tests/scheduler/factories.py +++ /dev/null @@ -1,28 +0,0 @@ -from datetime import timezone - -from polyfactory import PostGenerated - -from app.scheduler.models.events_db import ClassroomEvent -from tests.common.polyfactory_ext import BaseModelFactory - - -class ClassroomEventInputFactory(BaseModelFactory[ClassroomEvent.InputSchema]): - __model__ = ClassroomEvent.InputSchema - - ends_at = PostGenerated( - lambda _, values: BaseModelFactory.__faker__.date_time_between( - start_date=values["starts_at"], end_date="+120m", tzinfo=timezone.utc - ) - ) - - -class ClassroomEventInvalidTimeFrameInputFactory( - BaseModelFactory[ClassroomEvent.InputSchema] -): - __model__ = ClassroomEvent.InputSchema - - ends_at = PostGenerated( - lambda _, values: BaseModelFactory.__faker__.date_time( - end_datetime=values["starts_at"], tzinfo=timezone.utc - ) - ) diff --git a/tests/scheduler/functional/test_classroom_events_list_rst.py b/tests/scheduler/functional/test_classroom_events_list_rst.py deleted file mode 100644 index 963070c1..00000000 --- a/tests/scheduler/functional/test_classroom_events_list_rst.py +++ /dev/null @@ -1,173 +0,0 @@ -from collections.abc import AsyncIterator -from datetime import datetime, timedelta, timezone -from typing import Literal, assert_never - -import pytest -from faker import Faker -from starlette import status -from starlette.testclient import TestClient - -from app.scheduler.models.events_db import ClassroomEvent -from tests.common.active_session import ActiveSession -from tests.common.assert_contains_ext import assert_response -from tests.scheduler.factories import ClassroomEventInputFactory - -pytestmark = pytest.mark.anyio - -CLASSROOM_EVENT_LIST_SIZE = 6 - - -@pytest.fixture() -async def classroom_events( - faker: Faker, - active_session: ActiveSession, - classroom_id: int, -) -> AsyncIterator[list[ClassroomEvent]]: - classroom_events: list[ClassroomEvent] = [] - start_datetime: datetime = faker.date_time_between(tzinfo=timezone.utc) - - async with active_session(): - for _ in range(CLASSROOM_EVENT_LIST_SIZE): - end_datetime: datetime = ( - start_datetime - + timedelta(minutes=10) - + faker.time_delta(end_datetime="+120m") - ) - classroom_events.append( - await ClassroomEvent.create( - **ClassroomEventInputFactory.build_python( - starts_at=start_datetime, - ends_at=end_datetime, - ), - classroom_id=classroom_id, - ) - ) - start_datetime = end_datetime + faker.time_delta(end_datetime="+360m") - - classroom_events.sort( - key=lambda classroom_event: classroom_event.starts_at, reverse=True - ) - - yield classroom_events - - async with active_session(): - for classroom_event in classroom_events: - await classroom_event.delete() - - -classroom_events_list_request_parametrization = pytest.mark.parametrize( - ("index_happens_before", "index_happens_after"), - [ - pytest.param(None, None, id="start_to_end"), - pytest.param(None, CLASSROOM_EVENT_LIST_SIZE // 2, id="start_to_middle"), - pytest.param(CLASSROOM_EVENT_LIST_SIZE // 2, None, id="middle_to_end"), - pytest.param(None, 0, id="before_the_start"), - pytest.param(-1, None, id="after_the_end"), - ], -) - - -classroom_events_role_parametrization = pytest.mark.parametrize( - "role", - [ - pytest.param("student", id="student"), - pytest.param("tutor", id="tutor"), - ], -) - - -@classroom_events_list_request_parametrization -@classroom_events_role_parametrization -async def test_tutor_classroom_events_listing( - faker: Faker, - authorized_client: TestClient, - classroom_id: int, - classroom_events: list[ClassroomEvent], - index_happens_before: int | None, - index_happens_after: int | None, - role: Literal["tutor", "student"], -) -> None: - happens_after: datetime = ( - faker.date_time_between( - end_date=classroom_events[0].ends_at, tzinfo=timezone.utc - ) - if index_happens_after is None - else classroom_events[index_happens_after].ends_at - ) - happens_before: datetime = ( - faker.date_time_between( - start_date=classroom_events[-1].starts_at, tzinfo=timezone.utc - ) - if index_happens_before is None - else classroom_events[index_happens_before].starts_at - ) - - assert_response( - authorized_client.get( - f"/api/protected/scheduler-service/roles/{role}/classrooms/{classroom_id}/events/", - params={ - "happens_after": happens_after.isoformat(), - "happens_before": happens_before.isoformat(), - }, - ), - expected_json=[ - ClassroomEvent.ResponseSchema.model_validate( - classroom_event, from_attributes=True - ) - for classroom_event in classroom_events - if classroom_event.starts_at < happens_before - and classroom_event.ends_at > happens_after - ], - ) - - -@pytest.mark.parametrize( - "happens_before_mode", - [ - pytest.param("equal_to_happens_after", id="before_is_equal_to_after"), - pytest.param("less_than_happens_after", id="before_is_less_than_after"), - ], -) -@classroom_events_role_parametrization -async def test_classroom_events_listing_happens_before_le_happens_after( - faker: Faker, - authorized_client: TestClient, - classroom_id: int, - classroom_events: list[ClassroomEvent], - role: Literal["tutor", "student"], - happens_before_mode: Literal["equal_to_happens_after", "less_than_happens_after"], -) -> None: - happens_after: datetime = faker.date_time_between( - tzinfo=timezone.utc, - ) - happens_before: datetime - match happens_before_mode: - case "equal_to_happens_after": - happens_before = happens_after - case "less_than_happens_after": - happens_before = faker.date_time( - end_datetime=happens_after, tzinfo=timezone.utc - ) - case _: - assert_never(happens_before_mode) - - assert_response( - authorized_client.get( - f"/api/protected/scheduler-service/roles/{role}" - f"/classrooms/{classroom_id}/events/", - params={ - "happens_after": happens_after.isoformat(), - "happens_before": happens_before.isoformat(), - }, - ), - expected_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - expected_json={ - "detail": [ - { - "type": "value_error", - "loc": ["query"], - "msg": "Value error, parameter happens_before must be later in time than happens_after", - }, - ] - }, - ) diff --git a/tests/scheduler/functional/test_classroom_events_student_rst.py b/tests/scheduler/functional/test_classroom_events_student_rst.py deleted file mode 100644 index 3967d9b5..00000000 --- a/tests/scheduler/functional/test_classroom_events_student_rst.py +++ /dev/null @@ -1,53 +0,0 @@ -import pytest -from starlette import status -from starlette.testclient import TestClient - -from app.scheduler.models.events_db import ClassroomEvent -from tests.common.assert_contains_ext import assert_response -from tests.common.types import AnyJSON - -pytestmark = pytest.mark.anyio - - -async def test_student_classroom_events_retrieving( - student_client: TestClient, - classroom_event: ClassroomEvent, - classroom_event_data: AnyJSON, -) -> None: - assert_response( - student_client.get( - "/api/protected/scheduler-service/roles/student" - f"/classrooms/{classroom_event.classroom_id}/events/{classroom_event.id}/", - ), - expected_json=classroom_event_data, - ) - - -async def test_student_classroom_event_requesting_access_denied( - student_client: TestClient, - other_classroom_id: int, - classroom_event: ClassroomEvent, -) -> None: - assert_response( - student_client.get( - "/api/protected/scheduler-service/roles/student" - f"/classrooms/{other_classroom_id}/events/{classroom_event.id}/", - ), - expected_code=status.HTTP_403_FORBIDDEN, - expected_json={"detail": "Classroom event access denied"}, - ) - - -async def test_student_classroom_event_requesting_not_finding( - student_client: TestClient, - classroom_id: int, - deleted_classroom_event_id: int, -) -> None: - assert_response( - student_client.get( - "/api/protected/scheduler-service/roles/student" - f"/classrooms/{classroom_id}/events/{deleted_classroom_event_id}/", - ), - expected_code=status.HTTP_404_NOT_FOUND, - expected_json={"detail": "Classroom event not found"}, - ) diff --git a/tests/scheduler/functional/test_classroom_events_tutor_rst.py b/tests/scheduler/functional/test_classroom_events_tutor_rst.py deleted file mode 100644 index 42036b2b..00000000 --- a/tests/scheduler/functional/test_classroom_events_tutor_rst.py +++ /dev/null @@ -1,186 +0,0 @@ -from typing import Any - -import pytest -from starlette import status -from starlette.testclient import TestClient - -from app.scheduler.models.events_db import ClassroomEvent, EventKind -from tests.common.active_session import ActiveSession -from tests.common.assert_contains_ext import assert_nodata_response, assert_response -from tests.common.polyfactory_ext import BaseModelFactory -from tests.common.types import AnyJSON -from tests.scheduler.factories import ( - ClassroomEventInputFactory, - ClassroomEventInvalidTimeFrameInputFactory, -) - -pytestmark = pytest.mark.anyio - - -async def test_tutor_classroom_event_creation( - active_session: ActiveSession, - tutor_client: TestClient, - classroom_id: int, -) -> None: - classroom_event_input_data = ClassroomEventInputFactory.build_json() - - classroom_event_id: int = assert_response( - tutor_client.post( - f"/api/protected/scheduler-service/roles/tutor/classrooms/{classroom_id}/events/", - json=classroom_event_input_data, - ), - expected_code=status.HTTP_201_CREATED, - expected_json={ - **classroom_event_input_data, - "id": int, - "classroom_id": classroom_id, - "kind": EventKind.CLASSROOM, - }, - ).json()["id"] - - async with active_session(): - classroom_event = await ClassroomEvent.find_first_by_id(classroom_event_id) - assert classroom_event is not None - await classroom_event.delete() - - -async def test_tutor_classroom_event_creation_invalid_time_frame( - tutor_client: TestClient, - classroom_id: int, -) -> None: - assert_response( - tutor_client.post( - f"/api/protected/scheduler-service/roles/tutor/classrooms/{classroom_id}/events/", - json=ClassroomEventInvalidTimeFrameInputFactory.build_json(), - ), - expected_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - expected_json={ - "detail": [ - { - "type": "value_error", - "loc": ["body"], - "msg": "Value error, the start time of an event cannot be greater than or equal to the end time", - } - ] - }, - ) - - -async def test_tutor_classroom_event_retrieving( - tutor_client: TestClient, - classroom_event: ClassroomEvent, - classroom_event_data: AnyJSON, -) -> None: - assert_response( - tutor_client.get( - "/api/protected/scheduler-service/roles/tutor" - f"/classrooms/{classroom_event.classroom_id}/events/{classroom_event.id}/", - ), - expected_json=classroom_event_data, - ) - - -async def test_tutor_classroom_event_updating( - tutor_client: TestClient, - classroom_event: ClassroomEvent, - classroom_event_data: AnyJSON, -) -> None: - classroom_event_put_data = ClassroomEventInputFactory.build_json() - - assert_response( - tutor_client.put( - "/api/protected/scheduler-service/roles/tutor" - f"/classrooms/{classroom_event.classroom_id}/events/{classroom_event.id}/", - json=classroom_event_put_data, - ), - expected_json={**classroom_event_data, **classroom_event_put_data}, - ) - - -async def test_tutor_classroom_event_updating_invalid_time_frame( - tutor_client: TestClient, - classroom_event: ClassroomEvent, -) -> None: - assert_response( - tutor_client.put( - "/api/protected/scheduler-service/roles/tutor" - f"/classrooms/{classroom_event.classroom_id}/events/{classroom_event.id}/", - json=ClassroomEventInvalidTimeFrameInputFactory.build_json(), - ), - expected_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - expected_json={ - "detail": [ - { - "type": "value_error", - "loc": ["body"], - "msg": "Value error, the start time of an event cannot be greater than or equal to the end time", - } - ] - }, - ) - - -async def test_tutor_classroom_event_deleting( - active_session: ActiveSession, - tutor_client: TestClient, - classroom_event: ClassroomEvent, -) -> None: - assert_nodata_response( - tutor_client.delete( - "/api/protected/scheduler-service/roles/tutor" - f"/classrooms/{classroom_event.classroom_id}/events/{classroom_event.id}/", - ) - ) - - async with active_session(): - assert await ClassroomEvent.find_first_by_id(classroom_event.id) is None - - -tutor_classroom_events_request_parametrization = pytest.mark.parametrize( - ("method", "body_factory"), - [ - pytest.param("GET", None, id="retrieve"), - pytest.param("PUT", ClassroomEventInputFactory, id="update"), - pytest.param("DELETE", None, id="delete"), - ], -) - - -@tutor_classroom_events_request_parametrization -async def test_tutor_classroom_event_requesting_access_denied( - tutor_client: TestClient, - other_classroom_id: int, - classroom_event: ClassroomEvent, - method: str, - body_factory: type[BaseModelFactory[Any]] | None, -) -> None: - assert_response( - tutor_client.request( - method=method, - url="/api/protected/scheduler-service/roles/tutor" - f"/classrooms/{other_classroom_id}/events/{classroom_event.id}/", - json=body_factory and body_factory.build_json(), - ), - expected_code=status.HTTP_403_FORBIDDEN, - expected_json={"detail": "Classroom event access denied"}, - ) - - -@tutor_classroom_events_request_parametrization -async def test_tutor_classroom_event_requesting_not_finding( - tutor_client: TestClient, - classroom_id: int, - deleted_classroom_event_id: int, - method: str, - body_factory: type[BaseModelFactory[Any]] | None, -) -> None: - assert_response( - tutor_client.request( - method=method, - url="/api/protected/scheduler-service/roles/tutor" - f"/classrooms/{classroom_id}/events/{deleted_classroom_event_id}/", - json=body_factory and body_factory.build_json(), - ), - expected_code=status.HTTP_404_NOT_FOUND, - expected_json={"detail": "Classroom event not found"}, - )