diff --git a/internal/server/server.go b/internal/server/server.go index 8650386..230b84a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -54,6 +54,7 @@ type threadStore interface { ListOrganizationThreads(ctx context.Context, organizationID uuid.UUID, filter store.OrganizationThreadFilter, sort store.OrganizationThreadSort, pageSize int32, cursor *store.OrganizationThreadCursor) (store.OrganizationThreadListResult, error) ListMessages(ctx context.Context, threadID uuid.UUID, pageSize int32, cursor *store.MessageCursor, order store.MessageOrder) (store.MessageListResult, error) ListUnackedMessages(ctx context.Context, participantID uuid.UUID, threadID *uuid.UUID, pageSize int32, cursor *store.MessageCursor) (store.MessageListResult, error) + GetUnackedMessageCounts(ctx context.Context, participantID uuid.UUID) (map[uuid.UUID]int32, error) AckMessages(ctx context.Context, participantID uuid.UUID, messageIDs []uuid.UUID) (int32, error) } @@ -620,6 +621,33 @@ func (s *Server) GetUnackedMessages(ctx context.Context, req *threadsv1.GetUnack return resp, nil } +func (s *Server) GetUnackedMessageCounts(ctx context.Context, req *threadsv1.GetUnackedMessageCountsRequest) (*threadsv1.GetUnackedMessageCountsResponse, error) { + participantID, err := parseUUID(req.GetParticipantId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "participant_id: %v", err) + } + identityID, err := identityIDFromContext(ctx) + if err != nil { + return nil, err + } + if identityID != participantID { + return nil, status.Error(codes.PermissionDenied, "permission denied") + } + + counts, err := s.store.GetUnackedMessageCounts(ctx, participantID) + if err != nil { + return nil, toStatusError(err) + } + resp := &threadsv1.GetUnackedMessageCountsResponse{CountsByThreadId: make(map[string]int32, len(counts))} + for threadID, count := range counts { + if count <= 0 { + continue + } + resp.CountsByThreadId[threadID.String()] = count + } + return resp, nil +} + func (s *Server) AckMessages(ctx context.Context, req *threadsv1.AckMessagesRequest) (*threadsv1.AckMessagesResponse, error) { participantID, err := parseUUID(req.GetParticipantId()) if err != nil { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index e24da35..086bf24 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -32,6 +32,7 @@ type stubThreadStore struct { listOrgThreadsFn func(ctx context.Context, organizationID uuid.UUID, filter store.OrganizationThreadFilter, sort store.OrganizationThreadSort, pageSize int32, cursor *store.OrganizationThreadCursor) (store.OrganizationThreadListResult, error) listMessagesFn func(ctx context.Context, threadID uuid.UUID, pageSize int32, cursor *store.MessageCursor, order store.MessageOrder) (store.MessageListResult, error) listUnackedFn func(ctx context.Context, participantID uuid.UUID, threadID *uuid.UUID, pageSize int32, cursor *store.MessageCursor) (store.MessageListResult, error) + unackedCountsFn func(ctx context.Context, participantID uuid.UUID) (map[uuid.UUID]int32, error) ackMessagesFn func(ctx context.Context, participantID uuid.UUID, messageIDs []uuid.UUID) (int32, error) } @@ -117,6 +118,14 @@ func (s *stubThreadStore) ListUnackedMessages(ctx context.Context, participantID return s.listUnackedFn(ctx, participantID, threadID, pageSize, cursor) } +func (s *stubThreadStore) GetUnackedMessageCounts(ctx context.Context, participantID uuid.UUID) (map[uuid.UUID]int32, error) { + if s.unackedCountsFn == nil { + s.unexpectedCall("GetUnackedMessageCounts") + return nil, nil + } + return s.unackedCountsFn(ctx, participantID) +} + func (s *stubThreadStore) AckMessages(ctx context.Context, participantID uuid.UUID, messageIDs []uuid.UUID) (int32, error) { if s.ackMessagesFn == nil { s.unexpectedCall("AckMessages") @@ -2315,6 +2324,57 @@ func TestGetUnackedMessagesPermissionDenied(t *testing.T) { } } +func TestGetUnackedMessageCountsPermissionDenied(t *testing.T) { + participantID := uuid.New() + identityID := uuid.New() + + srv := New(&stubThreadStore{t: t}, nil, nil, nil, nil, nil) + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", identityID.String())) + _, err := srv.GetUnackedMessageCounts(ctx, &threadsv1.GetUnackedMessageCountsRequest{ParticipantId: participantID.String()}) + if err == nil { + t.Fatal("expected error") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.PermissionDenied { + t.Fatalf("expected PermissionDenied, got %s: %s", st.Code(), st.Message()) + } +} + +func TestGetUnackedMessageCounts(t *testing.T) { + participantID := uuid.New() + threadID := uuid.New() + zeroThreadID := uuid.New() + storeCalled := false + + storeStub := &stubThreadStore{ + t: t, + unackedCountsFn: func(ctx context.Context, id uuid.UUID) (map[uuid.UUID]int32, error) { + storeCalled = true + if id != participantID { + t.Fatalf("expected participant id %s, got %s", participantID, id) + } + return map[uuid.UUID]int32{threadID: 3, zeroThreadID: 0}, nil + }, + } + + srv := New(storeStub, nil, nil, nil, nil, nil) + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", participantID.String())) + resp, err := srv.GetUnackedMessageCounts(ctx, &threadsv1.GetUnackedMessageCountsRequest{ParticipantId: participantID.String()}) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !storeCalled { + t.Fatal("expected store call") + } + expected := map[string]int32{threadID.String(): 3} + if !reflect.DeepEqual(resp.GetCountsByThreadId(), expected) { + t.Fatalf("expected counts %v, got %v", expected, resp.GetCountsByThreadId()) + } +} + func TestAckMessagesPermissionDenied(t *testing.T) { participantID := uuid.New() identityID := uuid.New() diff --git a/internal/store/messages.go b/internal/store/messages.go index 13d2c04..d944f51 100644 --- a/internal/store/messages.go +++ b/internal/store/messages.go @@ -206,6 +206,31 @@ func (s *Store) ListUnackedMessages(ctx context.Context, participantID uuid.UUID return MessageListResult{Messages: messages, NextCursor: nextCursor}, nil } +func (s *Store) GetUnackedMessageCounts(ctx context.Context, participantID uuid.UUID) (map[uuid.UUID]int32, error) { + rows, err := s.pool.Query(ctx, `SELECT thread_id, COUNT(*) FROM message_recipients WHERE participant_id = $1 AND acked_at IS NULL GROUP BY thread_id`, participantID) + if err != nil { + return nil, err + } + defer rows.Close() + + counts := make(map[uuid.UUID]int32) + for rows.Next() { + var threadID uuid.UUID + var count int64 + if err := rows.Scan(&threadID, &count); err != nil { + return nil, err + } + if count > int64(^uint32(0)>>1) { + return nil, fmt.Errorf("unacked count overflow: %d", count) + } + counts[threadID] = int32(count) + } + if err := rows.Err(); err != nil { + return nil, err + } + return counts, nil +} + func buildUnackedMessagesQuery(participantID uuid.UUID, threadID *uuid.UUID, cursor *MessageCursor, limit int32) (string, []any) { query := strings.Builder{} query.WriteString(`SELECT m.id, m.thread_id, m.sender_id, m.body, m.file_ids, m.created_at