From 0018938413830106e94a19a2e2408548c3cfb503 Mon Sep 17 00:00:00 2001 From: Ryan Yeske Date: Sat, 20 Jun 2026 11:37:29 -0700 Subject: [PATCH] Use river to download music --- api/models.go | 112 ++++++++++++++++++++ api/music_tracks.sql.go | 45 ++++++++ handlers/fun/music/page.go | 122 +++------------------- queries/music_tracks.sql | 7 ++ worker/worker.go | 2 + worker/youtubedownload/youtubedownload.go | 108 +++++++++++++++++++ 6 files changed, 290 insertions(+), 106 deletions(-) create mode 100644 worker/youtubedownload/youtubedownload.go diff --git a/api/models.go b/api/models.go index 08080fc..e473ae3 100644 --- a/api/models.go +++ b/api/models.go @@ -5,11 +5,62 @@ package api import ( + "database/sql/driver" + "fmt" + "github.com/jackc/pgx/v5/pgtype" "oj/avatar" "oj/gradient" ) +type RiverJobState string + +const ( + RiverJobStateAvailable RiverJobState = "available" + RiverJobStateCancelled RiverJobState = "cancelled" + RiverJobStateCompleted RiverJobState = "completed" + RiverJobStateDiscarded RiverJobState = "discarded" + RiverJobStatePending RiverJobState = "pending" + RiverJobStateRetryable RiverJobState = "retryable" + RiverJobStateRunning RiverJobState = "running" + RiverJobStateScheduled RiverJobState = "scheduled" +) + +func (e *RiverJobState) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = RiverJobState(s) + case string: + *e = RiverJobState(s) + default: + return fmt.Errorf("unsupported scan type for RiverJobState: %T", src) + } + return nil +} + +type NullRiverJobState struct { + RiverJobState RiverJobState + Valid bool // Valid is true if RiverJobState is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullRiverJobState) Scan(value interface{}) error { + if value == nil { + ns.RiverJobState, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.RiverJobState.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullRiverJobState) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.RiverJobState), nil +} + type Attempt struct { ID int64 CreatedAt pgtype.Timestamptz @@ -157,6 +208,67 @@ type Response struct { Text string } +type RiverClient struct { + ID string + CreatedAt pgtype.Timestamptz + Metadata []byte + PausedAt pgtype.Timestamptz + UpdatedAt pgtype.Timestamptz +} + +type RiverClientQueue struct { + RiverClientID string + Name string + CreatedAt pgtype.Timestamptz + MaxWorkers int64 + Metadata []byte + NumJobsCompleted int64 + NumJobsRunning int64 + UpdatedAt pgtype.Timestamptz +} + +type RiverJob struct { + ID int64 + State RiverJobState + Attempt int16 + MaxAttempts int16 + AttemptedAt pgtype.Timestamptz + CreatedAt pgtype.Timestamptz + FinalizedAt pgtype.Timestamptz + ScheduledAt pgtype.Timestamptz + Priority int16 + Args []byte + AttemptedBy []string + Errors [][]byte + Kind string + Metadata []byte + Queue string + Tags []string + UniqueKey []byte + UniqueStates pgtype.Bits +} + +type RiverLeader struct { + ElectedAt pgtype.Timestamptz + ExpiresAt pgtype.Timestamptz + LeaderID string + Name string +} + +type RiverMigration struct { + ID int64 + CreatedAt pgtype.Timestamptz + Version int64 +} + +type RiverQueue struct { + Name string + CreatedAt pgtype.Timestamptz + Metadata []byte + PausedAt pgtype.Timestamptz + UpdatedAt pgtype.Timestamptz +} + type Room struct { ID int64 CreatedAt pgtype.Timestamptz diff --git a/api/music_tracks.sql.go b/api/music_tracks.sql.go index 9f79554..ae82aa3 100644 --- a/api/music_tracks.sql.go +++ b/api/music_tracks.sql.go @@ -52,6 +52,27 @@ func (q *Queries) InsertMusicTrack(ctx context.Context, arg InsertMusicTrackPara return i, err } +const musicTrackByID = `-- name: MusicTrackByID :one +select id, created_at, user_id, url, title, uploader, filename, status, error from music_tracks where id = $1 +` + +func (q *Queries) MusicTrackByID(ctx context.Context, id int64) (MusicTrack, error) { + row := q.db.QueryRow(ctx, musicTrackByID, id) + var i MusicTrack + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UserID, + &i.Url, + &i.Title, + &i.Uploader, + &i.Filename, + &i.Status, + &i.Error, + ) + return i, err +} + const updateMusicTrack = `-- name: UpdateMusicTrack :exec update music_tracks set title = $1, uploader = $2, filename = $3, status = $4, error = $5 where user_id = $6 and filename = $7 @@ -80,6 +101,30 @@ func (q *Queries) UpdateMusicTrack(ctx context.Context, arg UpdateMusicTrackPara return err } +const updateMusicTrackByID = `-- name: UpdateMusicTrackByID :exec +update music_tracks set title = $1, uploader = $2, status = $3, error = $4 +where id = $5 +` + +type UpdateMusicTrackByIDParams struct { + Title pgtype.Text + Uploader pgtype.Text + Status string + Error pgtype.Text + ID int64 +} + +func (q *Queries) UpdateMusicTrackByID(ctx context.Context, arg UpdateMusicTrackByIDParams) error { + _, err := q.db.Exec(ctx, updateMusicTrackByID, + arg.Title, + arg.Uploader, + arg.Status, + arg.Error, + arg.ID, + ) + return err +} + const userMusicTracks = `-- name: UserMusicTracks :many select id, created_at, user_id, url, title, uploader, filename, status, error from music_tracks where user_id = $1 order by created_at desc ` diff --git a/handlers/fun/music/page.go b/handlers/fun/music/page.go index ef100ed..731c879 100644 --- a/handlers/fun/music/page.go +++ b/handlers/fun/music/page.go @@ -15,18 +15,17 @@ import ( "oj/api" "oj/handlers/layout" "oj/internal/middleware/auth" + "oj/worker" + "oj/worker/youtubedownload" g "maragu.dev/gomponents" h "maragu.dev/gomponents/html" "github.com/jackc/pgx/v5/pgtype" - "github.com/lrstanley/go-ytdlp" ) var ( - activeDownloads = map[string]string{} - activeMu sync.Mutex - initOnce sync.Once + initOnce sync.Once ) type service struct { @@ -75,7 +74,7 @@ func (s *service) Download(w http.ResponseWriter, r *http.Request) { id := fmt.Sprintf("%d", time.Now().UnixNano()) - _, err := s.Queries.InsertMusicTrack(ctx, api.InsertMusicTrackParams{ + track, err := s.Queries.InsertMusicTrack(ctx, api.InsertMusicTrackParams{ UserID: user.ID, Url: url, Filename: id + ".mp3", @@ -86,19 +85,15 @@ func (s *service) Download(w http.ResponseWriter, r *http.Request) { return } - activeMu.Lock() - activeDownloads[id] = url - activeMu.Unlock() - - err = s.runDownload(id, url, user.ID) + _, err = worker.RiverClient.Insert(ctx, youtubedownload.YoutubeDownloadArgs{MusicTrackID: track.ID}, nil) if err != nil { - _ = s.Queries.UpdateMusicTrack(ctx, api.UpdateMusicTrackParams{ - UserID: user.ID, - OldFilename: id + ".mp3", - NewFilename: id + ".mp3", - Status: "error", - Error: toPgText(err.Error()), + _ = s.Queries.UpdateMusicTrackByID(ctx, api.UpdateMusicTrackByIDParams{ + ID: track.ID, + Status: "error", + Error: toPgText(err.Error()), }) + http.Error(w, err.Error(), http.StatusInternalServerError) + return } w.Header().Set("Content-Type", "text/html; charset=utf-8") @@ -112,16 +107,6 @@ func (s *service) Status(w http.ResponseWriter, r *http.Request) { return } - activeMu.Lock() - _, active := activeDownloads[id] - activeMu.Unlock() - - if active { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _ = downloadingRow(id, activeDownloads[id]).Render(w) - return - } - user := auth.FromContext(r.Context()) tracks, err := s.Queries.UserMusicTracks(r.Context(), user.ID) if err != nil { @@ -194,58 +179,6 @@ func slugify(title, uploader string) string { return slug } -func (s *service) runDownload(id, url string, userID int64) error { - ctx := context.TODO() - - cmd := os.Getenv("YTDLP_EXECUTABLE") - if cmd == "" { - return fmt.Errorf("YTDLP_EXECUTABLE not set") - } - - dl := ytdlp.New(). - SetExecutable(cmd). - FormatSort("res,ext:mp4:m4a"). - ExtractAudio(). - AudioFormat("mp3"). - Output(dlPath() + "/" + id + ".%(ext)s"). - NoProgress(). - PrintJSON() - - cookiesFile := os.Getenv("YTDLP_COOKIES_FILE") - if cookiesFile != "" { - dl.Cookies(cookiesFile) - } - - result, err := dl.Run(ctx, url) - if err != nil { - return err - } - - title := "" - uploader := "" - if info, e := result.GetExtractedInfo(); e == nil && len(info) > 0 { - if info[0].Title != nil { - title = *info[0].Title - } - if info[0].Uploader != nil { - uploader = *info[0].Uploader - } else if info[0].Channel != nil { - uploader = *info[0].Channel - } - } - - _ = s.Queries.UpdateMusicTrack(ctx, api.UpdateMusicTrackParams{ - UserID: userID, - OldFilename: id + ".mp3", - NewFilename: id + ".mp3", - Title: toPgText(title), - Uploader: toPgText(uploader), - Status: "done", - }) - - return nil -} - func (s *service) musicPage(userID int64) g.Node { return g.Group{ h.H1( @@ -311,38 +244,15 @@ function playTrack(id,btn){ func (s *service) downloadsList(userID int64) g.Node { tracks, err := s.Queries.UserMusicTracks(context.Background(), userID) if err != nil || len(tracks) == 0 { - activeMu.Lock() - hasActive := len(activeDownloads) > 0 - activeMu.Unlock() - - if !hasActive { - return h.P( - h.Class("nes-text is-disabled"), - g.Text("Nothing downloaded yet."), - ) - } - - var nodes []g.Node - activeMu.Lock() - for id, url := range activeDownloads { - nodes = append(nodes, downloadingRow(id, url)) - } - activeMu.Unlock() - return g.Group(nodes) + return h.P( + h.Class("nes-text is-disabled"), + g.Text("Nothing downloaded yet."), + ) } var nodes []g.Node for _, t := range tracks { - id := t.Filename[:len(t.Filename)-len(".mp3")] - activeMu.Lock() - _, active := activeDownloads[id] - activeMu.Unlock() - - if active && t.Status == "downloading" { - nodes = append(nodes, downloadingRow(id, t.Url)) - } else { - nodes = append(nodes, trackRow(t)) - } + nodes = append(nodes, trackRow(t)) } return g.Group(nodes) } diff --git a/queries/music_tracks.sql b/queries/music_tracks.sql index 2daf04c..be5ae62 100644 --- a/queries/music_tracks.sql +++ b/queries/music_tracks.sql @@ -1,6 +1,9 @@ -- name: UserMusicTracks :many select * from music_tracks where user_id = @user_id order by created_at desc; +-- name: MusicTrackByID :one +select * from music_tracks where id = @id; + -- name: InsertMusicTrack :one insert into music_tracks(user_id, url, title, uploader, filename, status, error) values(@user_id, @url, @title, @uploader, @filename, @status, @error) @@ -9,3 +12,7 @@ returning *; -- name: UpdateMusicTrack :exec update music_tracks set title = @title, uploader = @uploader, filename = @new_filename, status = @status, error = @error where user_id = @user_id and filename = @old_filename; + +-- name: UpdateMusicTrackByID :exec +update music_tracks set title = @title, uploader = @uploader, status = @status, error = @error +where id = @id; diff --git a/worker/worker.go b/worker/worker.go index 37c92e2..05d1ca0 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -8,6 +8,7 @@ import ( "oj/worker/notifydelivery" "oj/worker/notifyfriend" "oj/worker/notifykidfriend" + "oj/worker/youtubedownload" "time" "github.com/acaloiaro/neoq" @@ -36,6 +37,7 @@ func Start(ctx context.Context, queries *api.Queries, conn *pgxpool.Pool) error workers := river.NewWorkers() river.AddWorker(workers, &helloworld.Worker{}) + river.AddWorker(workers, youtubedownload.NewWorker(queries)) RiverClient, err = river.NewClient(riverpgxv5.New(conn), &river.Config{ Queues: map[string]river.QueueConfig{ diff --git a/worker/youtubedownload/youtubedownload.go b/worker/youtubedownload/youtubedownload.go new file mode 100644 index 0000000..0ef7cf4 --- /dev/null +++ b/worker/youtubedownload/youtubedownload.go @@ -0,0 +1,108 @@ +package youtubedownload + +import ( + "context" + "fmt" + "log" + "os" + + "oj/api" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/lrstanley/go-ytdlp" + "github.com/riverqueue/river" +) + +type YoutubeDownloadArgs struct { + MusicTrackID int64 +} + +func (YoutubeDownloadArgs) Kind() string { return "youtube_download" } + +type Worker struct { + river.WorkerDefaults[YoutubeDownloadArgs] + Queries *api.Queries +} + +func NewWorker(q *api.Queries) *Worker { + return &Worker{Queries: q} +} + +func dlPath() string { + p := os.Getenv("MUSIC_DOWNLOAD_PATH") + if p == "" { + return "./music-downloads" + } + return p +} + +func (w *Worker) Work(ctx context.Context, job *river.Job[YoutubeDownloadArgs]) error { + trackID := job.Args.MusicTrackID + + track, err := w.Queries.MusicTrackByID(ctx, trackID) + if err != nil { + return fmt.Errorf("fetching track %d: %w", trackID, err) + } + + cmd := os.Getenv("YTDLP_EXECUTABLE") + if cmd == "" { + return fmt.Errorf("YTDLP_EXECUTABLE not set") + } + + dl := ytdlp.New(). + SetExecutable(cmd). + FormatSort("res,ext:mp4:m4a"). + ExtractAudio(). + AudioFormat("mp3"). + Output(dlPath() + "/" + track.Filename). + NoProgress(). + PrintJSON() + + cookiesFile := os.Getenv("YTDLP_COOKIES_FILE") + if cookiesFile != "" { + dl.Cookies(cookiesFile) + } + + result, err := dl.Run(ctx, track.Url) + if err != nil { + _ = w.Queries.UpdateMusicTrackByID(ctx, api.UpdateMusicTrackByIDParams{ + ID: trackID, + Status: "error", + Error: toPgText(err.Error()), + }) + return fmt.Errorf("downloading: %w", err) + } + + title := "" + uploader := "" + if info, e := result.GetExtractedInfo(); e == nil && len(info) > 0 { + if info[0].Title != nil { + title = *info[0].Title + } + if info[0].Uploader != nil { + uploader = *info[0].Uploader + } else if info[0].Channel != nil { + uploader = *info[0].Channel + } + } + + err = w.Queries.UpdateMusicTrackByID(ctx, api.UpdateMusicTrackByIDParams{ + ID: trackID, + Title: toPgText(title), + Uploader: toPgText(uploader), + Status: "done", + }) + if err != nil { + return fmt.Errorf("updating track %d: %w", trackID, err) + } + + log.Printf("downloaded track %d: %s - %s", trackID, title, uploader) + return nil +} + +func toPgText(s string) pgtype.Text { + if s == "" { + return pgtype.Text{Valid: false} + } + return pgtype.Text{String: s, Valid: true} +}