Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions backend/alembic/versions/f3a4b5c6d7e8_add_user_genres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""add genres and user_genres

Revision ID: f3a4b5c6d7e8
Revises: e2f3a4b5c6d7
Create Date: 2025-02-19 12:00:00.000000

"""
from alembic import op
import sqlalchemy as sa


revision = "f3a4b5c6d7e8"
down_revision = "e2f3a4b5c6d7"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"genres",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("name", sa.String(length=100), nullable=False),
sa.Column("normalized_name", sa.String(length=100), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.UniqueConstraint("normalized_name", name="uq_genres_normalized_name"),
)
op.create_index("ix_genres_normalized_name", "genres", ["normalized_name"])

op.create_table(
"user_genres",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column(
"user_id",
sa.String(length=36),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"genre_id",
sa.Integer(),
sa.ForeignKey("genres.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.UniqueConstraint(
"user_id", "genre_id", name="uq_user_genres_user_genre"
),
)
op.create_index("ix_user_genres_user_id", "user_genres", ["user_id"])
op.create_index("ix_user_genres_genre_id", "user_genres", ["genre_id"])

genres_table = sa.table(
"genres",
sa.column("name", sa.String),
sa.column("normalized_name", sa.String),
)

seed_genres = [
"ambient",
"art pop",
"dream pop",
"electronic",
"folk",
"funk",
"hip-hop",
"house",
"industrial",
"indie rock",
"jazz fusion",
"lo-fi",
"neo-soul",
"post-punk",
"progressive rock",
"psychedelic",
"R&B",
"shoegaze",
"synth-pop",
"trip-hop",
]
op.bulk_insert(
genres_table,
[
{"name": genre, "normalized_name": genre.strip().lower()}
for genre in seed_genres
],
)


def downgrade() -> None:
op.drop_index("ix_user_genres_genre_id", table_name="user_genres")
op.drop_index("ix_user_genres_user_id", table_name="user_genres")
op.drop_table("user_genres")
op.drop_index("ix_genres_normalized_name", table_name="genres")
op.drop_table("genres")
43 changes: 42 additions & 1 deletion backend/app/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class User(Base):
suno_prompts: Mapped[list["SunoPrompt"]] = relationship(
back_populates="owner", cascade="all, delete-orphan"
)
user_genres: Mapped[list["UserGenre"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)


class ExternalAccount(Base):
Expand Down Expand Up @@ -123,6 +126,44 @@ class ExternalAccount(Base):
user: Mapped[User] = relationship(back_populates="external_accounts")


class Genre(Base):
__tablename__ = "genres"

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(100), nullable=False)
normalized_name: Mapped[str] = mapped_column(
String(100), nullable=False, unique=True, index=True
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)

user_links: Mapped[list["UserGenre"]] = relationship(
back_populates="genre", cascade="all, delete-orphan"
)


class UserGenre(Base):
__tablename__ = "user_genres"
__table_args__ = (
UniqueConstraint("user_id", "genre_id", name="uq_user_genres_user_genre"),
)

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[str] = mapped_column(
ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
genre_id: Mapped[int] = mapped_column(
ForeignKey("genres.id", ondelete="CASCADE"), nullable=False, index=True
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)

user: Mapped[User] = relationship(back_populates="user_genres")
genre: Mapped[Genre] = relationship(back_populates="user_links")


class SunoPrompt(Base):
"""A saved Suno prompt that users can favorite and reuse."""

Expand Down Expand Up @@ -174,4 +215,4 @@ class SunoPrompt(Base):
owner: Mapped[User] = relationship(back_populates="suno_prompts")


__all__ = ["Base", "User", "ExternalAccount", "SunoPrompt"]
__all__ = ["Base", "User", "ExternalAccount", "Genre", "UserGenre", "SunoPrompt"]
147 changes: 144 additions & 3 deletions backend/app/routes/spotify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,73 @@
Fetches and processes user's Spotify data
"""

from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from app.db.models import Genre, UserGenre
from app.deps import get_current_user_id, get_db, get_spotify_client
from app.schemas.genres import (
GenreCatalogResponse,
GenreItem,
UserGenreAddRequest,
UserGenresResponse,
)
from app.schemas.spotify import SpotifyProfileResponse
from app.deps import get_spotify_client
from app.services.genre_catalog import ensure_genres_seeded, get_or_create_genres
from app.services.spotify_client import SpotifyClient
from app.services.taste_analyzer import (
compute_avg_popularity,
derive_mood_tags,
generate_summary,
)
from app.utils import fetch_and_parse_spotify_data

router = APIRouter()
MAX_USER_GENRES = 20


def _get_user_genre_names(db: Session, user_id: str) -> list[str]:
return (
db.scalars(
select(Genre.name)
.join(UserGenre, UserGenre.genre_id == Genre.id)
.where(UserGenre.user_id == user_id)
.order_by(UserGenre.id.asc())
)
.all()
)


def _seed_user_genres(
db: Session,
user_id: str,
genre_names: list[str],
) -> list[str]:
trimmed_names = genre_names[:MAX_USER_GENRES]
genres = get_or_create_genres(db, trimmed_names)
if not genres:
return []
try:
db.add_all(
[
UserGenre(user_id=user_id, genre_id=genre.id)
for genre in genres
]
)
db.commit()
except IntegrityError:
db.rollback()
return _get_user_genre_names(db, user_id)


@router.get("/profile", response_model=SpotifyProfileResponse)
async def get_profile(
client: SpotifyClient = Depends(get_spotify_client),
time_range: str = Query(default="medium_term", pattern="^(short_term|medium_term|long_term)$")
time_range: str = Query(default="medium_term", pattern="^(short_term|medium_term|long_term)$"),
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id),
):
"""
Get user's Spotify profile with taste analysis
Expand All @@ -30,10 +83,98 @@ async def get_profile(
top_artists, top_tracks, taste_profile = await fetch_and_parse_spotify_data(
client, time_range
)

# Ensure catalog exists and load user-managed genres
ensure_genres_seeded(db)
user_genres = _get_user_genre_names(db, user_id)
if not user_genres and taste_profile.top_genres:
user_genres = _seed_user_genres(db, user_id, taste_profile.top_genres)

if user_genres:
taste_profile.top_genres = user_genres
avg_popularity = compute_avg_popularity(top_artists)
taste_profile.mood_tags = derive_mood_tags(user_genres, avg_popularity)
taste_profile.summary_sentence = generate_summary(
user_genres, taste_profile.mood_tags, avg_popularity
)

return SpotifyProfileResponse(
top_artists=top_artists,
top_tracks=top_tracks,
taste_profile=taste_profile,
time_range=time_range
)


@router.get("/genres/catalog", response_model=GenreCatalogResponse)
def list_genre_catalog(
db: Session = Depends(get_db),
_user_id: str = Depends(get_current_user_id),
):
"""
List available genres for the top-genres picker.
"""
ensure_genres_seeded(db)
genres = db.scalars(select(Genre).order_by(Genre.name.asc())).all()
return GenreCatalogResponse(
genres=[GenreItem(id=genre.id, name=genre.name) for genre in genres]
)


@router.post("/genres", response_model=UserGenresResponse)
def add_user_genre(
body: UserGenreAddRequest,
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id),
):
"""
Add a genre to the user's top genres list.
"""
genre = db.get(Genre, body.genre_id)
if not genre:
raise HTTPException(status_code=404, detail="Genre not found")

existing = db.scalar(
select(UserGenre)
.where(UserGenre.user_id == user_id)
.where(UserGenre.genre_id == body.genre_id)
)
if not existing:
current_count = db.scalar(
select(func.count(UserGenre.id)).where(UserGenre.user_id == user_id)
)
if (current_count or 0) >= MAX_USER_GENRES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="You can only have up to 20 genres.",
)
db.add(UserGenre(user_id=user_id, genre_id=body.genre_id))
db.commit()

return UserGenresResponse(genres=_get_user_genre_names(db, user_id))


@router.delete("/genres/{genre_id}", response_model=UserGenresResponse)
def delete_user_genre(
genre_id: int,
db: Session = Depends(get_db),
user_id: str = Depends(get_current_user_id),
):
"""
Remove a genre from the user's top genres list.
"""
link = db.scalar(
select(UserGenre)
.where(UserGenre.user_id == user_id)
.where(UserGenre.genre_id == genre_id)
)
if link:
db.delete(link)
db.commit()
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Genre not linked to user",
)

return UserGenresResponse(genres=_get_user_genre_names(db, user_id))
22 changes: 22 additions & 0 deletions backend/app/schemas/genres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Pydantic schemas for user-managed genres.
"""

from pydantic import BaseModel, Field


class GenreItem(BaseModel):
id: int
name: str


class GenreCatalogResponse(BaseModel):
genres: list[GenreItem]


class UserGenreAddRequest(BaseModel):
genre_id: int = Field(..., description="ID of the genre to add")


class UserGenresResponse(BaseModel):
genres: list[str]
Loading