Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ func (s *Server) SendMessage(ctx context.Context, req *threadsv1.SendMessageRequ
if err != nil {
return nil, toStatusError(err)
}
s.recordMessageSent(ctx, result.Message)
s.recordMessageSent(ctx, result)
if err := s.notifier.PublishMessageCreated(ctx, threadID, result.Message.ID, result.Recipients); err != nil {
return nil, status.Errorf(codes.Internal, "notify recipients: %v", err)
}
Expand Down Expand Up @@ -741,38 +741,36 @@ func (s *Server) AckMessages(ctx context.Context, req *threadsv1.AckMessagesRequ
}

func (s *Server) recordThreadCreated(ctx context.Context, thread store.Thread) {
if thread.OrganizationID == nil {
panic("thread organization_id missing")
}
threadID := thread.ID
orgID := *thread.OrganizationID
createdAt := thread.CreatedAt
s.recordUsageAsync(ctx, "thread_created", func(recordCtx context.Context, orgID uuid.UUID) error {
s.recordUsageAsync(ctx, "thread_created", func(recordCtx context.Context) error {
return s.metering.RecordThreadCreated(recordCtx, orgID, threadID, createdAt)
})
}

func (s *Server) recordMessageSent(ctx context.Context, message store.Message) {
func (s *Server) recordMessageSent(ctx context.Context, result store.SendMessageResult) {
orgID := result.OrganizationID
message := result.Message
messageID := message.ID
threadID := message.ThreadID
createdAt := message.CreatedAt
s.recordUsageAsync(ctx, "message_sent", func(recordCtx context.Context, orgID uuid.UUID) error {
s.recordUsageAsync(ctx, "message_sent", func(recordCtx context.Context) error {
return s.metering.RecordMessageSent(recordCtx, orgID, threadID, messageID, createdAt)
})
}

func (s *Server) recordUsageAsync(ctx context.Context, label string, record func(context.Context, uuid.UUID) error) {
func (s *Server) recordUsageAsync(ctx context.Context, label string, record func(context.Context) error) {
if s.metering == nil {
return
}
orgID, ok, err := organizationIDFromContext(ctx)
if err != nil {
log.Printf("metering: %s: %v", label, err)
return
}
if !ok {
return
}
go func() {
recordCtx, cancel := context.WithTimeout(context.Background(), meteringTimeout)
recordCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), meteringTimeout)
defer cancel()
if err := record(recordCtx, orgID); err != nil {
if err := record(recordCtx); err != nil {
log.Printf("metering: %s: %v", label, err)
}
}()
Expand Down Expand Up @@ -1298,6 +1296,8 @@ func toStatusError(err error) error {
return status.Error(codes.FailedPrecondition, err.Error())
case errors.Is(err, store.ErrThreadDegraded):
return status.Error(codes.FailedPrecondition, err.Error())
case errors.Is(err, store.ErrThreadOrganizationMissing):
return status.Error(codes.FailedPrecondition, err.Error())
case errors.Is(err, store.ErrParticipantNotInThread):
return status.Error(codes.InvalidArgument, err.Error())
default:
Expand Down
190 changes: 190 additions & 0 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,41 @@ type stubAuthorizationService struct {
writeFn func(ctx context.Context, req *authorizationv1.WriteRequest, opts ...grpc.CallOption) (*authorizationv1.WriteResponse, error)
}

type stubNotifier struct {
t *testing.T
publishFn func(ctx context.Context, threadID, messageID uuid.UUID, recipients []uuid.UUID) error
}

func (s *stubNotifier) PublishMessageCreated(ctx context.Context, threadID, messageID uuid.UUID, recipients []uuid.UUID) error {
s.t.Helper()
if s.publishFn == nil {
return nil
}
return s.publishFn(ctx, threadID, messageID, recipients)
}

type stubMeteringRecorder struct {
t *testing.T
recordThreadCreatedFn func(ctx context.Context, orgID, threadID uuid.UUID, createdAt time.Time) error
recordMessageSentFn func(ctx context.Context, orgID, threadID, messageID uuid.UUID, createdAt time.Time) error
}

func (s *stubMeteringRecorder) RecordThreadCreated(ctx context.Context, orgID, threadID uuid.UUID, createdAt time.Time) error {
s.t.Helper()
if s.recordThreadCreatedFn == nil {
s.t.Fatalf("unexpected RecordThreadCreated call")
}
return s.recordThreadCreatedFn(ctx, orgID, threadID, createdAt)
}

func (s *stubMeteringRecorder) RecordMessageSent(ctx context.Context, orgID, threadID, messageID uuid.UUID, createdAt time.Time) error {
s.t.Helper()
if s.recordMessageSentFn == nil {
s.t.Fatalf("unexpected RecordMessageSent call")
}
return s.recordMessageSentFn(ctx, orgID, threadID, messageID, createdAt)
}

func (s *stubAuthorizationService) Check(ctx context.Context, req *authorizationv1.CheckRequest, opts ...grpc.CallOption) (*authorizationv1.CheckResponse, error) {
s.t.Helper()
if s.checkFn == nil {
Expand All @@ -214,6 +249,70 @@ func allowAuthStub(t *testing.T) *stubAuthorizationService {
}
}

func TestCreateThreadRecordsUsageWithCreatedThreadOrganization(t *testing.T) {
threadID := uuid.New()
organizationID := uuid.New()
identityID := uuid.New()
participantID := uuid.New()
now := time.Now().UTC()
recorded := make(chan struct{}, 1)

storeStub := &stubThreadStore{
t: t,
createThreadFn: func(ctx context.Context, orgID uuid.UUID, participants []store.ParticipantInput) (store.Thread, error) {
if orgID != organizationID {
t.Fatalf("expected organization %s, got %s", organizationID, orgID)
}
return store.Thread{
ID: threadID,
OrganizationID: &organizationID,
MessageCount: 0,
Status: store.ThreadStatusActive,
CreatedAt: now,
UpdatedAt: now,
Participants: []store.Participant{
{ID: identityID, JoinedAt: now, Passive: false},
{ID: participantID, JoinedAt: now, Passive: false},
},
}, nil
},
}
meteringStub := &stubMeteringRecorder{
t: t,
recordThreadCreatedFn: func(ctx context.Context, orgID, recordedThreadID uuid.UUID, createdAt time.Time) error {
if orgID != organizationID {
t.Fatalf("expected metering organization %s, got %s", organizationID, orgID)
}
if recordedThreadID != threadID {
t.Fatalf("expected metering thread %s, got %s", threadID, recordedThreadID)
}
if !createdAt.Equal(now) {
t.Fatalf("expected metering created_at %s, got %s", now, createdAt)
}
recorded <- struct{}{}
return nil
},
}

srv := New(storeStub, nil, allowAuthStub(t), &stubIdentityResolver{t: t}, nil, meteringStub)
ctx := metadata.NewIncomingContext(
context.Background(),
metadata.Pairs("x-identity-id", identityID.String(), "x-identity-type", "user"),
)
_, err := srv.CreateThread(ctx, &threadsv1.CreateThreadRequest{
OrganizationId: &[]string{organizationID.String()}[0],
ParticipantIds: []string{participantID.String()},
})
if err != nil {
t.Fatalf("CreateThread returned error: %v", err)
}
select {
case <-recorded:
case <-time.After(time.Second):
t.Fatal("expected thread usage to be recorded")
}
}

func TestCreateThreadParticipantIDsRequireIdentityResolver(t *testing.T) {
organizationID := uuid.New()
identityID := uuid.New()
Expand Down Expand Up @@ -1712,6 +1811,97 @@ func TestSendMessageAuthorizationDenied(t *testing.T) {
}
}

func TestSendMessageRecordsUsageWithThreadOrganization(t *testing.T) {
threadID := uuid.New()
messageID := uuid.New()
organizationID := uuid.New()
identityID := uuid.New()
now := time.Now().UTC()
recorded := make(chan struct{}, 1)

storeStub := &stubThreadStore{
t: t,
sendMessageFn: func(ctx context.Context, threadArg, senderArg uuid.UUID, body string, fileIDs []uuid.UUID) (store.SendMessageResult, error) {
if threadArg != threadID {
t.Fatalf("expected thread %s, got %s", threadID, threadArg)
}
if senderArg != identityID {
t.Fatalf("expected sender %s, got %s", identityID, senderArg)
}
return store.SendMessageResult{
Message: store.Message{
ID: messageID,
ThreadID: threadID,
SenderID: identityID,
Body: body,
CreatedAt: now,
},
OrganizationID: organizationID,
}, nil
},
}
meteringStub := &stubMeteringRecorder{
t: t,
recordMessageSentFn: func(ctx context.Context, orgID, recordedThreadID, recordedMessageID uuid.UUID, createdAt time.Time) error {
if orgID != organizationID {
t.Fatalf("expected metering organization %s, got %s", organizationID, orgID)
}
if recordedThreadID != threadID {
t.Fatalf("expected metering thread %s, got %s", threadID, recordedThreadID)
}
if recordedMessageID != messageID {
t.Fatalf("expected metering message %s, got %s", messageID, recordedMessageID)
}
if !createdAt.Equal(now) {
t.Fatalf("expected metering created_at %s, got %s", now, createdAt)
}
recorded <- struct{}{}
return nil
},
}

srv := New(storeStub, &stubNotifier{t: t}, allowAuthStub(t), nil, nil, meteringStub)
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", identityID.String()))
_, err := srv.SendMessage(ctx, &threadsv1.SendMessageRequest{ThreadId: threadID.String(), Body: "hi"})
if err != nil {
t.Fatalf("SendMessage returned error: %v", err)
}
select {
case <-recorded:
case <-time.After(time.Second):
t.Fatal("expected message usage to be recorded")
}
}

func TestSendMessageRejectsThreadWithoutOrganization(t *testing.T) {
threadID := uuid.New()
identityID := uuid.New()

storeStub := &stubThreadStore{
t: t,
sendMessageFn: func(ctx context.Context, threadArg, senderArg uuid.UUID, body string, fileIDs []uuid.UUID) (store.SendMessageResult, error) {
return store.SendMessageResult{}, store.ErrThreadOrganizationMissing
},
}

srv := New(storeStub, nil, allowAuthStub(t), nil, nil, nil)
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-identity-id", identityID.String()))
_, err := srv.SendMessage(ctx, &threadsv1.SendMessageRequest{ThreadId: threadID.String(), Body: "hi"})
if err == nil {
t.Fatal("expected error")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.FailedPrecondition {
t.Fatalf("expected FailedPrecondition, got %s: %s", st.Code(), st.Message())
}
if st.Message() != store.ErrThreadOrganizationMissing.Error() {
t.Fatalf("expected message %q, got %q", store.ErrThreadOrganizationMissing.Error(), st.Message())
}
}

func TestSendMessageRejectsSenderMismatch(t *testing.T) {
threadID := uuid.New()
identityID := uuid.New()
Expand Down
12 changes: 9 additions & 3 deletions internal/store/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (
)

type SendMessageResult struct {
Message Message
Recipients []uuid.UUID
Message Message
OrganizationID uuid.UUID
Recipients []uuid.UUID
}

func (s *Store) SendMessage(ctx context.Context, threadID, senderID uuid.UUID, body string, fileIDs []uuid.UUID) (SendMessageResult, error) {
Expand All @@ -29,6 +30,10 @@ func (s *Store) SendMessage(ctx context.Context, threadID, senderID uuid.UUID, b
if thread.Status == ThreadStatusDegraded {
return ErrThreadDegraded
}
if thread.OrganizationID == nil {
return ErrThreadOrganizationMissing
}
organizationID := *thread.OrganizationID
now := time.Now().UTC()
messageID := uuid.New()
fileIDArray := pgtype.FlatArray[string](uuidsToStrings(fileIDs))
Expand Down Expand Up @@ -62,7 +67,8 @@ func (s *Store) SendMessage(ctx context.Context, threadID, senderID uuid.UUID, b
FileIDs: fileIDs,
CreatedAt: now,
},
Recipients: recipients,
OrganizationID: organizationID,
Recipients: recipients,
}
return nil
})
Expand Down
9 changes: 5 additions & 4 deletions internal/store/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (
)

var (
ErrThreadNotFound = errors.New("thread not found")
ErrThreadArchived = errors.New("thread is archived")
ErrThreadDegraded = errors.New("thread is degraded")
ErrParticipantNotInThread = errors.New("participant not in thread")
ErrThreadNotFound = errors.New("thread not found")
ErrThreadArchived = errors.New("thread is archived")
ErrThreadDegraded = errors.New("thread is degraded")
ErrThreadOrganizationMissing = errors.New("thread organization_id missing")
ErrParticipantNotInThread = errors.New("participant not in thread")
)

type ThreadStatus int16
Expand Down
Loading