Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type threadStore interface {
ListOrganizationThreads(ctx context.Context, organizationID uuid.UUID, filter store.OrganizationThreadFilter, sort store.OrganizationThreadSort, pageSize int32, cursor *store.OrganizationThreadCursor) (store.OrganizationThreadListResult, error)
ListMessages(ctx context.Context, threadID uuid.UUID, pageSize int32, cursor *store.MessageCursor, order store.MessageOrder) (store.MessageListResult, error)
ListUnackedMessages(ctx context.Context, participantID uuid.UUID, threadID *uuid.UUID, pageSize int32, cursor *store.MessageCursor) (store.MessageListResult, error)
GetUnackedMessageCounts(ctx context.Context, participantID uuid.UUID) (map[uuid.UUID]int32, error)
AckMessages(ctx context.Context, participantID uuid.UUID, messageIDs []uuid.UUID) (int32, error)
}

Expand Down Expand Up @@ -620,6 +621,33 @@ func (s *Server) GetUnackedMessages(ctx context.Context, req *threadsv1.GetUnack
return resp, nil
}

func (s *Server) GetUnackedMessageCounts(ctx context.Context, req *threadsv1.GetUnackedMessageCountsRequest) (*threadsv1.GetUnackedMessageCountsResponse, error) {
participantID, err := parseUUID(req.GetParticipantId())
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "participant_id: %v", err)
}
identityID, err := identityIDFromContext(ctx)
if err != nil {
return nil, err
}
if identityID != participantID {
return nil, status.Error(codes.PermissionDenied, "permission denied")
}

counts, err := s.store.GetUnackedMessageCounts(ctx, participantID)
if err != nil {
return nil, toStatusError(err)
}
resp := &threadsv1.GetUnackedMessageCountsResponse{CountsByThreadId: make(map[string]int32, len(counts))}
for threadID, count := range counts {
if count <= 0 {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] COUNT(*) in the store query can never yield <= 0, so the if count <= 0 { continue } branch is redundant and slightly masks invariant violations. I’d either remove it (and just return the map) or treat unexpected values as an internal error.

continue
}
resp.CountsByThreadId[threadID.String()] = count
}
return resp, nil
}

func (s *Server) AckMessages(ctx context.Context, req *threadsv1.AckMessagesRequest) (*threadsv1.AckMessagesResponse, error) {
participantID, err := parseUUID(req.GetParticipantId())
if err != nil {
Expand Down
60 changes: 60 additions & 0 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type stubThreadStore struct {
listOrgThreadsFn func(ctx context.Context, organizationID uuid.UUID, filter store.OrganizationThreadFilter, sort store.OrganizationThreadSort, pageSize int32, cursor *store.OrganizationThreadCursor) (store.OrganizationThreadListResult, error)
listMessagesFn func(ctx context.Context, threadID uuid.UUID, pageSize int32, cursor *store.MessageCursor, order store.MessageOrder) (store.MessageListResult, error)
listUnackedFn func(ctx context.Context, participantID uuid.UUID, threadID *uuid.UUID, pageSize int32, cursor *store.MessageCursor) (store.MessageListResult, error)
unackedCountsFn func(ctx context.Context, participantID uuid.UUID) (map[uuid.UUID]int32, error)
ackMessagesFn func(ctx context.Context, participantID uuid.UUID, messageIDs []uuid.UUID) (int32, error)
}

Expand Down Expand Up @@ -117,6 +118,14 @@ func (s *stubThreadStore) ListUnackedMessages(ctx context.Context, participantID
return s.listUnackedFn(ctx, participantID, threadID, pageSize, cursor)
}

func (s *stubThreadStore) GetUnackedMessageCounts(ctx context.Context, participantID uuid.UUID) (map[uuid.UUID]int32, error) {
if s.unackedCountsFn == nil {
s.unexpectedCall("GetUnackedMessageCounts")
return nil, nil
}
return s.unackedCountsFn(ctx, participantID)
}

func (s *stubThreadStore) AckMessages(ctx context.Context, participantID uuid.UUID, messageIDs []uuid.UUID) (int32, error) {
if s.ackMessagesFn == nil {
s.unexpectedCall("AckMessages")
Expand Down Expand Up @@ -2315,6 +2324,57 @@ func TestGetUnackedMessagesPermissionDenied(t *testing.T) {
}
}

func TestGetUnackedMessageCountsPermissionDenied(t *testing.T) {
participantID := uuid.New()
identityID := uuid.New()

srv := New(&stubThreadStore{t: t}, nil, nil, nil, nil, nil)
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", identityID.String()))
_, err := srv.GetUnackedMessageCounts(ctx, &threadsv1.GetUnackedMessageCountsRequest{ParticipantId: participantID.String()})
if err == nil {
t.Fatal("expected error")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("expected PermissionDenied, got %s: %s", st.Code(), st.Message())
}
}

func TestGetUnackedMessageCounts(t *testing.T) {
participantID := uuid.New()
threadID := uuid.New()
zeroThreadID := uuid.New()
storeCalled := false

storeStub := &stubThreadStore{
t: t,
unackedCountsFn: func(ctx context.Context, id uuid.UUID) (map[uuid.UUID]int32, error) {
storeCalled = true
if id != participantID {
t.Fatalf("expected participant id %s, got %s", participantID, id)
}
return map[uuid.UUID]int32{threadID: 3, zeroThreadID: 0}, nil
},
}

srv := New(storeStub, nil, nil, nil, nil, nil)
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", participantID.String()))
resp, err := srv.GetUnackedMessageCounts(ctx, &threadsv1.GetUnackedMessageCountsRequest{ParticipantId: participantID.String()})
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !storeCalled {
t.Fatal("expected store call")
}
expected := map[string]int32{threadID.String(): 3}
if !reflect.DeepEqual(resp.GetCountsByThreadId(), expected) {
t.Fatalf("expected counts %v, got %v", expected, resp.GetCountsByThreadId())
}
}

func TestAckMessagesPermissionDenied(t *testing.T) {
participantID := uuid.New()
identityID := uuid.New()
Expand Down
25 changes: 25 additions & 0 deletions internal/store/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,31 @@ func (s *Store) ListUnackedMessages(ctx context.Context, participantID uuid.UUID
return MessageListResult{Messages: messages, NextCursor: nextCursor}, nil
}

func (s *Store) GetUnackedMessageCounts(ctx context.Context, participantID uuid.UUID) (map[uuid.UUID]int32, error) {
rows, err := s.pool.Query(ctx, `SELECT thread_id, COUNT(*) FROM message_recipients WHERE participant_id = $1 AND acked_at IS NULL GROUP BY thread_id`, participantID)
if err != nil {
return nil, err
}
defer rows.Close()

counts := make(map[uuid.UUID]int32)
for rows.Next() {
var threadID uuid.UUID
var count int64
if err := rows.Scan(&threadID, &count); err != nil {
return nil, err
}
if count > int64(^uint32(0)>>1) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] int64(^uint32(0)>>1) is a pretty opaque way to express max int32. Prefer a named const (e.g. const maxInt32 = int64(math.MaxInt32)) so the intent is obvious.

return nil, fmt.Errorf("unacked count overflow: %d", count)
}
counts[threadID] = int32(count)
}
if err := rows.Err(); err != nil {
return nil, err
}
return counts, nil
}

func buildUnackedMessagesQuery(participantID uuid.UUID, threadID *uuid.UUID, cursor *MessageCursor, limit int32) (string, []any) {
query := strings.Builder{}
query.WriteString(`SELECT m.id, m.thread_id, m.sender_id, m.body, m.file_ids, m.created_at
Expand Down
Loading