From 0f5fed6dc3bcf60e69d1ba1c497298ed5536ec34 Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Thu, 14 May 2026 11:00:10 +0000 Subject: [PATCH 1/2] fix: record platform usage with thread org --- internal/server/server.go | 28 +++--- internal/server/server_test.go | 161 +++++++++++++++++++++++++++++++++ internal/store/messages.go | 8 +- 3 files changed, 179 insertions(+), 18 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 4d382b5..8f32201 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -398,7 +398,7 @@ func (s *Server) SendMessage(ctx context.Context, req *threadsv1.SendMessageRequ if err != nil { return nil, toStatusError(err) } - s.recordMessageSent(ctx, result.Message) + s.recordMessageSent(ctx, result) if err := s.notifier.PublishMessageCreated(ctx, threadID, result.Message.ID, result.Recipients); err != nil { return nil, status.Errorf(codes.Internal, "notify recipients: %v", err) } @@ -741,38 +741,36 @@ func (s *Server) AckMessages(ctx context.Context, req *threadsv1.AckMessagesRequ } func (s *Server) recordThreadCreated(ctx context.Context, thread store.Thread) { + if thread.OrganizationID == nil { + panic("thread organization_id missing") + } threadID := thread.ID + orgID := *thread.OrganizationID createdAt := thread.CreatedAt - s.recordUsageAsync(ctx, "thread_created", func(recordCtx context.Context, orgID uuid.UUID) error { + s.recordUsageAsync(ctx, "thread_created", func(recordCtx context.Context) error { return s.metering.RecordThreadCreated(recordCtx, orgID, threadID, createdAt) }) } -func (s *Server) recordMessageSent(ctx context.Context, message store.Message) { +func (s *Server) recordMessageSent(ctx context.Context, result store.SendMessageResult) { + orgID := result.OrganizationID + message := result.Message messageID := message.ID threadID := message.ThreadID createdAt := message.CreatedAt - s.recordUsageAsync(ctx, "message_sent", func(recordCtx context.Context, orgID uuid.UUID) error { + s.recordUsageAsync(ctx, "message_sent", func(recordCtx context.Context) error { return s.metering.RecordMessageSent(recordCtx, orgID, threadID, messageID, createdAt) }) } -func (s *Server) recordUsageAsync(ctx context.Context, label string, record func(context.Context, uuid.UUID) error) { +func (s *Server) recordUsageAsync(ctx context.Context, label string, record func(context.Context) error) { if s.metering == nil { return } - orgID, ok, err := organizationIDFromContext(ctx) - if err != nil { - log.Printf("metering: %s: %v", label, err) - return - } - if !ok { - return - } go func() { - recordCtx, cancel := context.WithTimeout(context.Background(), meteringTimeout) + recordCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), meteringTimeout) defer cancel() - if err := record(recordCtx, orgID); err != nil { + if err := record(recordCtx); err != nil { log.Printf("metering: %s: %v", label, err) } }() diff --git a/internal/server/server_test.go b/internal/server/server_test.go index b723857..cea0630 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -188,6 +188,41 @@ type stubAuthorizationService struct { writeFn func(ctx context.Context, req *authorizationv1.WriteRequest, opts ...grpc.CallOption) (*authorizationv1.WriteResponse, error) } +type stubNotifier struct { + t *testing.T + publishFn func(ctx context.Context, threadID, messageID uuid.UUID, recipients []uuid.UUID) error +} + +func (s *stubNotifier) PublishMessageCreated(ctx context.Context, threadID, messageID uuid.UUID, recipients []uuid.UUID) error { + s.t.Helper() + if s.publishFn == nil { + return nil + } + return s.publishFn(ctx, threadID, messageID, recipients) +} + +type stubMeteringRecorder struct { + t *testing.T + recordThreadCreatedFn func(ctx context.Context, orgID, threadID uuid.UUID, createdAt time.Time) error + recordMessageSentFn func(ctx context.Context, orgID, threadID, messageID uuid.UUID, createdAt time.Time) error +} + +func (s *stubMeteringRecorder) RecordThreadCreated(ctx context.Context, orgID, threadID uuid.UUID, createdAt time.Time) error { + s.t.Helper() + if s.recordThreadCreatedFn == nil { + s.t.Fatalf("unexpected RecordThreadCreated call") + } + return s.recordThreadCreatedFn(ctx, orgID, threadID, createdAt) +} + +func (s *stubMeteringRecorder) RecordMessageSent(ctx context.Context, orgID, threadID, messageID uuid.UUID, createdAt time.Time) error { + s.t.Helper() + if s.recordMessageSentFn == nil { + s.t.Fatalf("unexpected RecordMessageSent call") + } + return s.recordMessageSentFn(ctx, orgID, threadID, messageID, createdAt) +} + func (s *stubAuthorizationService) Check(ctx context.Context, req *authorizationv1.CheckRequest, opts ...grpc.CallOption) (*authorizationv1.CheckResponse, error) { s.t.Helper() if s.checkFn == nil { @@ -214,6 +249,70 @@ func allowAuthStub(t *testing.T) *stubAuthorizationService { } } +func TestCreateThreadRecordsUsageWithCreatedThreadOrganization(t *testing.T) { + threadID := uuid.New() + organizationID := uuid.New() + identityID := uuid.New() + participantID := uuid.New() + now := time.Now().UTC() + recorded := make(chan struct{}, 1) + + storeStub := &stubThreadStore{ + t: t, + createThreadFn: func(ctx context.Context, orgID uuid.UUID, participants []store.ParticipantInput) (store.Thread, error) { + if orgID != organizationID { + t.Fatalf("expected organization %s, got %s", organizationID, orgID) + } + return store.Thread{ + ID: threadID, + OrganizationID: &organizationID, + MessageCount: 0, + Status: store.ThreadStatusActive, + CreatedAt: now, + UpdatedAt: now, + Participants: []store.Participant{ + {ID: identityID, JoinedAt: now, Passive: false}, + {ID: participantID, JoinedAt: now, Passive: false}, + }, + }, nil + }, + } + meteringStub := &stubMeteringRecorder{ + t: t, + recordThreadCreatedFn: func(ctx context.Context, orgID, recordedThreadID uuid.UUID, createdAt time.Time) error { + if orgID != organizationID { + t.Fatalf("expected metering organization %s, got %s", organizationID, orgID) + } + if recordedThreadID != threadID { + t.Fatalf("expected metering thread %s, got %s", threadID, recordedThreadID) + } + if !createdAt.Equal(now) { + t.Fatalf("expected metering created_at %s, got %s", now, createdAt) + } + recorded <- struct{}{} + return nil + }, + } + + srv := New(storeStub, nil, allowAuthStub(t), &stubIdentityResolver{t: t}, nil, meteringStub) + ctx := metadata.NewIncomingContext( + context.Background(), + metadata.Pairs("x-identity-id", identityID.String(), "x-identity-type", "user"), + ) + _, err := srv.CreateThread(ctx, &threadsv1.CreateThreadRequest{ + OrganizationId: &[]string{organizationID.String()}[0], + ParticipantIds: []string{participantID.String()}, + }) + if err != nil { + t.Fatalf("CreateThread returned error: %v", err) + } + select { + case <-recorded: + case <-time.After(time.Second): + t.Fatal("expected thread usage to be recorded") + } +} + func TestCreateThreadParticipantIDsRequireIdentityResolver(t *testing.T) { organizationID := uuid.New() identityID := uuid.New() @@ -1712,6 +1811,68 @@ func TestSendMessageAuthorizationDenied(t *testing.T) { } } +func TestSendMessageRecordsUsageWithThreadOrganization(t *testing.T) { + threadID := uuid.New() + messageID := uuid.New() + organizationID := uuid.New() + identityID := uuid.New() + now := time.Now().UTC() + recorded := make(chan struct{}, 1) + + storeStub := &stubThreadStore{ + t: t, + sendMessageFn: func(ctx context.Context, threadArg, senderArg uuid.UUID, body string, fileIDs []uuid.UUID) (store.SendMessageResult, error) { + if threadArg != threadID { + t.Fatalf("expected thread %s, got %s", threadID, threadArg) + } + if senderArg != identityID { + t.Fatalf("expected sender %s, got %s", identityID, senderArg) + } + return store.SendMessageResult{ + Message: store.Message{ + ID: messageID, + ThreadID: threadID, + SenderID: identityID, + Body: body, + CreatedAt: now, + }, + OrganizationID: organizationID, + }, nil + }, + } + meteringStub := &stubMeteringRecorder{ + t: t, + recordMessageSentFn: func(ctx context.Context, orgID, recordedThreadID, recordedMessageID uuid.UUID, createdAt time.Time) error { + if orgID != organizationID { + t.Fatalf("expected metering organization %s, got %s", organizationID, orgID) + } + if recordedThreadID != threadID { + t.Fatalf("expected metering thread %s, got %s", threadID, recordedThreadID) + } + if recordedMessageID != messageID { + t.Fatalf("expected metering message %s, got %s", messageID, recordedMessageID) + } + if !createdAt.Equal(now) { + t.Fatalf("expected metering created_at %s, got %s", now, createdAt) + } + recorded <- struct{}{} + return nil + }, + } + + srv := New(storeStub, &stubNotifier{t: t}, allowAuthStub(t), nil, nil, meteringStub) + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", identityID.String())) + _, err := srv.SendMessage(ctx, &threadsv1.SendMessageRequest{ThreadId: threadID.String(), Body: "hi"}) + if err != nil { + t.Fatalf("SendMessage returned error: %v", err) + } + select { + case <-recorded: + case <-time.After(time.Second): + t.Fatal("expected message usage to be recorded") + } +} + func TestSendMessageRejectsSenderMismatch(t *testing.T) { threadID := uuid.New() identityID := uuid.New() diff --git a/internal/store/messages.go b/internal/store/messages.go index d944f51..a1718a5 100644 --- a/internal/store/messages.go +++ b/internal/store/messages.go @@ -12,8 +12,9 @@ import ( ) type SendMessageResult struct { - Message Message - Recipients []uuid.UUID + Message Message + OrganizationID uuid.UUID + Recipients []uuid.UUID } func (s *Store) SendMessage(ctx context.Context, threadID, senderID uuid.UUID, body string, fileIDs []uuid.UUID) (SendMessageResult, error) { @@ -62,7 +63,8 @@ func (s *Store) SendMessage(ctx context.Context, threadID, senderID uuid.UUID, b FileIDs: fileIDs, CreatedAt: now, }, - Recipients: recipients, + OrganizationID: *thread.OrganizationID, + Recipients: recipients, } return nil }) From 38f620463aa94035636c67b57c1a0a978b562d0c Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Thu, 14 May 2026 11:08:25 +0000 Subject: [PATCH 2/2] fix: guard missing thread organization --- internal/server/server.go | 2 ++ internal/server/server_test.go | 29 +++++++++++++++++++++++++++++ internal/store/messages.go | 6 +++++- internal/store/types.go | 9 +++++---- 4 files changed, 41 insertions(+), 5 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 8f32201..7a9c271 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1296,6 +1296,8 @@ func toStatusError(err error) error { return status.Error(codes.FailedPrecondition, err.Error()) case errors.Is(err, store.ErrThreadDegraded): return status.Error(codes.FailedPrecondition, err.Error()) + case errors.Is(err, store.ErrThreadOrganizationMissing): + return status.Error(codes.FailedPrecondition, err.Error()) case errors.Is(err, store.ErrParticipantNotInThread): return status.Error(codes.InvalidArgument, err.Error()) default: diff --git a/internal/server/server_test.go b/internal/server/server_test.go index cea0630..3c03ef9 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1873,6 +1873,35 @@ func TestSendMessageRecordsUsageWithThreadOrganization(t *testing.T) { } } +func TestSendMessageRejectsThreadWithoutOrganization(t *testing.T) { + threadID := uuid.New() + identityID := uuid.New() + + storeStub := &stubThreadStore{ + t: t, + sendMessageFn: func(ctx context.Context, threadArg, senderArg uuid.UUID, body string, fileIDs []uuid.UUID) (store.SendMessageResult, error) { + return store.SendMessageResult{}, store.ErrThreadOrganizationMissing + }, + } + + srv := New(storeStub, nil, allowAuthStub(t), nil, nil, nil) + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", identityID.String())) + _, err := srv.SendMessage(ctx, &threadsv1.SendMessageRequest{ThreadId: threadID.String(), Body: "hi"}) + if err == nil { + t.Fatal("expected error") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.FailedPrecondition { + t.Fatalf("expected FailedPrecondition, got %s: %s", st.Code(), st.Message()) + } + if st.Message() != store.ErrThreadOrganizationMissing.Error() { + t.Fatalf("expected message %q, got %q", store.ErrThreadOrganizationMissing.Error(), st.Message()) + } +} + func TestSendMessageRejectsSenderMismatch(t *testing.T) { threadID := uuid.New() identityID := uuid.New() diff --git a/internal/store/messages.go b/internal/store/messages.go index a1718a5..18a587b 100644 --- a/internal/store/messages.go +++ b/internal/store/messages.go @@ -30,6 +30,10 @@ func (s *Store) SendMessage(ctx context.Context, threadID, senderID uuid.UUID, b if thread.Status == ThreadStatusDegraded { return ErrThreadDegraded } + if thread.OrganizationID == nil { + return ErrThreadOrganizationMissing + } + organizationID := *thread.OrganizationID now := time.Now().UTC() messageID := uuid.New() fileIDArray := pgtype.FlatArray[string](uuidsToStrings(fileIDs)) @@ -63,7 +67,7 @@ func (s *Store) SendMessage(ctx context.Context, threadID, senderID uuid.UUID, b FileIDs: fileIDs, CreatedAt: now, }, - OrganizationID: *thread.OrganizationID, + OrganizationID: organizationID, Recipients: recipients, } return nil diff --git a/internal/store/types.go b/internal/store/types.go index 3e1969f..6917a9e 100644 --- a/internal/store/types.go +++ b/internal/store/types.go @@ -8,10 +8,11 @@ import ( ) var ( - ErrThreadNotFound = errors.New("thread not found") - ErrThreadArchived = errors.New("thread is archived") - ErrThreadDegraded = errors.New("thread is degraded") - ErrParticipantNotInThread = errors.New("participant not in thread") + ErrThreadNotFound = errors.New("thread not found") + ErrThreadArchived = errors.New("thread is archived") + ErrThreadDegraded = errors.New("thread is degraded") + ErrThreadOrganizationMissing = errors.New("thread organization_id missing") + ErrParticipantNotInThread = errors.New("participant not in thread") ) type ThreadStatus int16