From d67f90cb79108154ad7e8d4f8578d40c95ef1b2e Mon Sep 17 00:00:00 2001
From: Copybara Service
Date: Mon, 11 May 2026 10:58:37 -0700
Subject: [PATCH] Add connections and buffers flush support.

PiperOrigin-RevId: 913782444
---
 fleetspeak/src/client/client.go               |  20 +++
 fleetspeak/src/client/comms/comms.go          |  14 ++
 fleetspeak/src/client/https/polling.go        |  66 +++++++-
 fleetspeak/src/client/https/polling_test.go   | 142 ++++++++++++++++++
 fleetspeak/src/client/https/streaming.go      |  88 ++++++++++-
 fleetspeak/src/client/https/streaming_test.go | 137 ++++++++++++++++-
 fleetspeak/src/client/service/service.go      |   6 +
 fleetspeak/src/client/services.go             |  14 ++
 8 files changed, 475 insertions(+), 12 deletions(-)

diff --git a/fleetspeak/src/client/client.go b/fleetspeak/src/client/client.go
index 3f4f0fa0..d681ae3e 100644
--- a/fleetspeak/src/client/client.go
+++ b/fleetspeak/src/client/client.go
@@ -310,3 +310,23 @@ func (c *Client) Stop() {
 	done <- struct{}{}
 	log.Info("Messages have been drained.")
 }
+
+// FlushServices asks every service that supports flushing to flush its data.
+func (c *Client) FlushServices(ctx context.Context) {
+	log.InfoContextf(ctx, "Client: FlushServices called")
+	c.sc.FlushServices(ctx)
+}
+
+// ForceResetAndFlush cancels the communicator's connections and then blocks in
+// WaitForFlush until it returns or ctx is done; it errors if c.com is unset.
+func (c *Client) ForceResetAndFlush(ctx context.Context) error {
+	log.InfoContextf(ctx, "ForceResetAndFlush triggered")
+	com := c.com
+	if com == nil {
+		return fmt.Errorf("communicator not set")
+	}
+	log.InfoContextf(ctx, "Calling CancelConnections")
+	com.CancelConnections()
+	log.InfoContextf(ctx, "CancelConnections finished, calling WaitForFlush")
+	return com.WaitForFlush(ctx)
+}
diff --git a/fleetspeak/src/client/comms/comms.go b/fleetspeak/src/client/comms/comms.go
index f429923c..2dd61be9 100644
--- a/fleetspeak/src/client/comms/comms.go
+++ b/fleetspeak/src/client/comms/comms.go
@@ -32,6 +32,12 @@ import (
 	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
 )
 
+// darkWakeKeyType is the unexported type of the DarkWakeKey context key.
+type darkWakeKeyType struct{}
+
+// DarkWakeKey is the context key marking an operation as part of a DarkWake flush.
+var DarkWakeKey darkWakeKeyType
+
 // A Communicator is a component which allows a client to communicate with a
 // Fleetspeak server.
 type Communicator interface {
@@ -39,6 +45,14 @@ type Communicator interface {
 	Start() error // Tells the communicator to start sending and receiving messages.
 	Stop()        // Tells the communicator to stop sending and receiving messages.
 
+	// CancelConnections forcibly cancels any underlying network connections.
+	CancelConnections()
+
+	// WaitForFlush blocks until the communicator has established a fresh
+	// connection and flushed all pending messages, or until ctx expires, in
+	// which case it returns an error.
+	WaitForFlush(ctx context.Context) error
+
 	// GetFileIfModified attempts to retrieve a file from a server, if it
 	// has been modified since modSince. If it has not been modified, it
 	// returns nil.
Otherwise, it returns a ReadCloser for the file's data
diff --git a/fleetspeak/src/client/https/polling.go b/fleetspeak/src/client/https/polling.go
index 85459be6..af3f68ac 100644
--- a/fleetspeak/src/client/https/polling.go
+++ b/fleetspeak/src/client/https/polling.go
@@ -58,6 +58,11 @@ type Communicator struct {
 
 	clientCertificateHeader string
 	certBytes               []byte
+
+	wakeUp       chan struct{}      // Tells processingLoop to poll immediately.
+	pollComplete chan error         // Carries a completed poll's result to WaitForFlush.
+	mu           sync.Mutex         // Protects pollCancel.
+	pollCancel   context.CancelFunc // Cancels the in-flight poll request; nil when idle.
 }
 
 // Setup implements comms.Communicator.
@@ -108,6 +113,10 @@ func (c *Communicator) configure() error {
 	}
 	c.ctx, c.done = context.WithCancel(context.Background())
 	c.clientCertificateHeader = si.ClientCertificateHeader
+	if c.wakeUp == nil { // configure re-runs on client ID change; keep channels waiters hold.
+		c.wakeUp = make(chan struct{}, 1)
+		c.pollComplete = make(chan error, 1)
+	}
 	c.certBytes = certBytes
 	return nil
 }
@@ -127,6 +136,37 @@ func (c *Communicator) Stop() {
 	c.wd.Stop()
 }
 
+// CancelConnections implements comms.Communicator: it cancels any in-flight poll and requests a new one.
+func (c *Communicator) CancelConnections() {
+	log.Infof("CancelConnections called")
+	c.mu.Lock()
+	if c.pollCancel != nil {
+		c.pollCancel()
+	}
+	c.mu.Unlock()
+	c.hc.CloseIdleConnections() // Unlike asserting *http.Transport, this cannot panic.
+	// Drain pollComplete to ensure WaitForFlush waits for a new poll.
+	select {
+	case <-c.pollComplete:
+	default:
+	}
+	select {
+	case c.wakeUp <- struct{}{}:
+	default:
+	}
+}
+
+// WaitForFlush implements comms.Communicator: it returns the next result from pollComplete, or ctx.Err().
+func (c *Communicator) WaitForFlush(ctx context.Context) error {
+	log.InfoContextf(ctx, "WaitForFlush called")
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case err := <-c.pollComplete:
+		return err
+	}
+}
+
 // processingLoop polls the server according to the configured policies while
 // the communicator is active.
 //
@@ -152,11 +192,19 @@ func (c *Communicator) processingLoop() {
 	// and updates the variables defined above. In case of failure it also sleeps
 	// for the MinFailureDelay.
poll := func() { + var err error + defer func() { + select { + case c.pollComplete <- err: + default: + } + }() c.wd.Reset() if c.cctx.CurrentID() != c.id { c.configure() } - active, err := c.poll(toSend) + var active bool + active, err = c.poll(toSend) if err != nil { log.Warningf("Failure during polling: %v", err) for _, m := range toSend { @@ -175,6 +223,8 @@ func (c *Communicator) processingLoop() { case <-t.C: case <-c.ctx.Done(): t.Stop() + case <-c.wakeUp: + t.Stop() } return } @@ -255,6 +305,9 @@ func (c *Communicator) processingLoop() { return case <-t.C: poll() + case <-c.wakeUp: + t.Stop() + poll() case m := <-c.cctx.Outbox(): t.Stop() toSend = append(toSend, m) @@ -339,6 +392,17 @@ func (c *Communicator) pollHost(host string, data []byte) (*fspb.ContactData, er if sendErr != nil { return nil, sendErr } + var reqCtx context.Context + c.mu.Lock() + reqCtx, c.pollCancel = context.WithCancel(c.ctx) + c.mu.Unlock() + defer func() { + c.mu.Lock() + c.pollCancel() + c.pollCancel = nil + c.mu.Unlock() + }() + req = req.WithContext(reqCtx) SetContentEncoding(req.Header, c.conf.GetCompression()) if c.clientCertificateHeader != "" { bc := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.certBytes}) diff --git a/fleetspeak/src/client/https/polling_test.go b/fleetspeak/src/client/https/polling_test.go index 084a689c..f7465588 100644 --- a/fleetspeak/src/client/https/polling_test.go +++ b/fleetspeak/src/client/https/polling_test.go @@ -733,3 +733,145 @@ func TestCertificateRevoked(t *testing.T) { tl.Close() cl.Stop() } + +func TestCancelConnections(t *testing.T) { + var c Communicator + conf := config.Configuration{ + Servers: []string{"localhost:1234"}, + CommunicatorConfig: &clpb.CommunicatorConfig{ + MaxPollDelaySeconds: 10, + MinFailureDelaySeconds: 10, + }, + } + dialCalls := make(chan struct{}, 10) + c.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + dialCalls <- struct{}{} + <-ctx.Done() + return nil, ctx.Err() + } 
+ + cl, err := client.New( + conf, + client.Components{ + Communicator: &c, + }) + if err != nil { + t.Fatalf("unable to create client: %v", err) + } + defer cl.Stop() + + // Wait for first dial attempt + select { + case <-dialCalls: + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for first dial attempt") + } + + // Call CancelConnections + c.CancelConnections() + + // Wait for second dial attempt (should be immediate, not waiting 10s) + select { + case <-dialCalls: + case <-time.After(2 * time.Second): + t.Fatal("Timed out waiting for second dial attempt after CancelConnections") + } +} + +func TestWaitForFlush(t *testing.T) { + // Create a local https server for the client to talk to. + pemCert, pemKey, err := common_util.ServerCert() + if err != nil { + t.Fatal(err) + } + cp, err := tls.X509KeyPair(pemCert, pemKey) + if err != nil { + t.Fatal(err) + } + ad, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + tl, err := net.ListenTCP("tcp", ad) + if err != nil { + t.Fatal(err) + } + addr := tl.Addr().String() + + pollCount := int32(0) + mux := http.NewServeMux() + mux.HandleFunc("/message", func(res http.ResponseWriter, req *http.Request) { + atomic.AddInt32(&pollCount, 1) + cd := fspb.ContactData{ + SequencingNonce: 42, + } + buf, err := proto.Marshal(&cd) + if err != nil { + t.Fatalf("unable to marshal ContactData: %v", err) + } + res.Header().Set("Content-Type", "application/octet-stream") + res.WriteHeader(http.StatusOK) + res.Write(buf) + }) + + server := http.Server{ + Addr: addr, + Handler: mux, + TLSConfig: &tls.Config{ + ClientAuth: tls.RequireAnyClientCert, + Certificates: []tls.Certificate{cp}, + NextProtos: []string{"h2"}, + }, + } + l := tls.NewListener(tl, server.TLSConfig) + go server.Serve(l) + defer tl.Close() + + var c Communicator + conf := config.Configuration{ + TrustedCerts: x509.NewCertPool(), + Servers: []string{addr}, + CommunicatorConfig: &clpb.CommunicatorConfig{ + 
MaxPollDelaySeconds: 10, + MinFailureDelaySeconds: 10, + }, + } + if !conf.TrustedCerts.AppendCertsFromPEM(pemCert) { + t.Fatal("unable to add server cert to pool") + } + + cl, err := client.New( + conf, + client.Components{ + Communicator: &c, + }) + if err != nil { + t.Fatalf("unable to create client: %v", err) + } + defer cl.Stop() + + // Wait for first poll + for atomic.LoadInt32(&pollCount) == 0 { + time.Sleep(100 * time.Millisecond) + } + + // Call WaitForFlush in a goroutine + flushDone := make(chan error, 1) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + flushDone <- c.WaitForFlush(ctx) + }() + + // Trigger a poll via wakeUp (indirectly via CancelConnections) + c.CancelConnections() + + select { + case err := <-flushDone: + if err != nil { + t.Errorf("WaitForFlush failed: %v", err) + } + case <-time.After(6 * time.Second): + t.Fatal("Timed out waiting for WaitForFlush") + } +} diff --git a/fleetspeak/src/client/https/streaming.go b/fleetspeak/src/client/https/streaming.go index 89b0db6c..d1e68ca9 100644 --- a/fleetspeak/src/client/https/streaming.go +++ b/fleetspeak/src/client/https/streaming.go @@ -16,6 +16,7 @@ package https import ( "bufio" + "context" "crypto/tls" "encoding/binary" @@ -65,12 +66,17 @@ type StreamingCommunicator struct { clientCertificateHeader string certBytes []byte + + wakeUp chan struct{} + mu sync.Mutex + curCon *connection } // Setup implements comms.Communicator. func (c *StreamingCommunicator) Setup(cl comms.Context) error { c.cctx = cl c.ctx, c.fin = context.WithCancel(context.Background()) + c.wakeUp = make(chan struct{}, 1) c.conf = c.cctx.CommunicatorConfig() if c.conf == nil { return errors.New("no communicator_config in client configuration") @@ -99,6 +105,59 @@ func (c *StreamingCommunicator) Stop() { c.wd.Stop() } +// CancelConnections implements comms.Communicator. 
+func (c *StreamingCommunicator) CancelConnections() {
+	log.Infof("CancelConnections called")
+	c.mu.Lock()
+	if c.curCon != nil {
+		c.curCon.stop()
+	}
+	c.mu.Unlock()
+	// Client.CloseIdleConnections is a no-op when the transport cannot close
+	// idle connections; asserting *http.Transport would panic instead.
+	c.hc.CloseIdleConnections()
+	select {
+	case c.wakeUp <- struct{}{}:
+	default:
+	}
+}
+
+// WaitForFlush implements comms.Communicator. It waits for a live connection
+// whose pending set is empty, surviving reconnects, or returns ctx.Err().
+func (c *StreamingCommunicator) WaitForFlush(ctx context.Context) error {
+	log.InfoContextf(ctx, "WaitForFlush called")
+	tick := time.NewTicker(50 * time.Millisecond)
+	defer tick.Stop()
+	for {
+		c.mu.Lock()
+		con := c.curCon
+		c.mu.Unlock()
+		if con == nil || con.ctx.Err() != nil {
+			log.InfoContextf(ctx, "WaitForFlush: No active connection or connection is dying, waiting for reconnect.")
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case <-tick.C:
+			}
+			continue
+		}
+		con.pendingLock.Lock()
+		l := len(con.pending)
+		con.pendingLock.Unlock()
+		if l == 0 {
+			log.InfoContextf(ctx, "WaitForFlush: Connection active but no pending messages.")
+			return nil
+		}
+		log.InfoContextf(ctx, "WaitForFlush: Waiting for %d pending messages to flush.", l)
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-con.ctx.Done(): // Connection died; go back to waiting for its replacement.
+		case <-con.pendingEmpty: // Token may be stale; loop to re-check the pending count.
+		}
+	}
+}
+
 func (c *StreamingCommunicator) GetFileIfModified(ctx context.Context, service, name string, modSince time.Time) (io.ReadCloser, time.Time, error) {
 	c.hostLock.RLock()
 	hosts := append([]string(nil), c.hosts...)
@@ -189,11 +248,21 @@ func (c *StreamingCommunicator) connectLoop() { case <-c.ctx.Done(): t.Stop() return + case <-c.wakeUp: + t.Stop() } continue } + c.mu.Lock() + c.curCon = con + c.mu.Unlock() + log.Info("Signal new connection established.") log.V(2).Infof("--%p: started", con) con.working.Wait() + c.mu.Lock() + c.curCon = nil + c.mu.Unlock() + log.Info("Connection cleared.") lastContact = time.Now() for _, l := range con.pending { for _, m := range l { @@ -231,11 +300,12 @@ func (c *connection) readContact(body *bufio.Reader) (cd *fspb.ContactData, err // connection to the given host. ctx only regulates this initial connection. func (c *StreamingCommunicator) connect(ctx context.Context, host string, maxLife time.Duration) (*connection, error) { ret := connection{ - cctx: c.cctx, - pending: make(map[int][]comms.MessageInfo), - serverDone: make(chan struct{}), - writingDone: make(chan struct{}), - host: host, + cctx: c.cctx, + pending: make(map[int][]comms.MessageInfo), + serverDone: make(chan struct{}), + writingDone: make(chan struct{}), + host: host, + pendingEmpty: make(chan struct{}, 1), } ret.ctx, ret.stop = context.WithTimeout(c.ctx, maxLife) @@ -396,6 +466,8 @@ type connection struct { // Closure which can be called to terminate the connection. stop func() + pendingEmpty chan struct{} + // Count of processed messages (per service), as of the last message // sent to the server. Used to update the number of messages the server // can send us. 
@@ -586,6 +658,12 @@ func (c *connection) readLoop(body *bufio.Reader, closer io.Closer) {
 		delete(c.pending, cnt)
 		l := len(c.pending)
 		c.pendingLock.Unlock()
+		if l == 0 {
+			select {
+			case c.pendingEmpty <- struct{}{}:
+			default:
+			}
+		}
 		cnt++
 		for _, m := range toAck {
 			m.Ack()
diff --git a/fleetspeak/src/client/https/streaming_test.go b/fleetspeak/src/client/https/streaming_test.go
index 2e12bd6e..552c69ec 100644
--- a/fleetspeak/src/client/https/streaming_test.go
+++ b/fleetspeak/src/client/https/streaming_test.go
@@ -60,9 +60,10 @@ func TestStreamingCreate(t *testing.T) {
 
 type streamingTestServer struct {
 	// These should be populated by the creator.
-	t        *testing.T
-	received chan<- *fspb.ContactData
-	toSend   <-chan *fspb.ContactData
+	t                     *testing.T
+	received              chan<- *fspb.ContactData
+	toSend                <-chan *fspb.ContactData
+	allowMultipleRequests bool
 
 	// These are populated by Start.
 	pemCert, pemKey []byte
@@ -100,9 +101,12 @@ func (s *streamingTestServer) Start() {
 		if err != nil {
 			s.t.Errorf("unable to make ClientID in test server: %v", err)
 		}
-		if reqs := atomic.AddInt32(&s.rc, 1); reqs != 1 {
-			s.t.Errorf("Only expected 1 request, but this is request %d", reqs)
-			http.Error(res, "only expected 1 request", http.StatusBadRequest)
+		reqs := atomic.AddInt32(&s.rc, 1)
+		if !s.allowMultipleRequests {
+			if reqs != 1 {
+				s.t.Errorf("Only expected 1 request, but this is request %d", reqs)
+				http.Error(res, "only expected 1 request", http.StatusBadRequest)
+			}
 		}
 		body := bufio.NewReader(req.Body)
 		b := make([]byte, 4)
@@ -629,3 +633,124 @@ func TestStreamingCommunicatorBulkFast(t *testing.T) {
 	}
 	close(toSend)
 }
+
+func TestStreamingCancelConnections(t *testing.T) {
+	var c StreamingCommunicator
+	conf := config.Configuration{
+		Servers: []string{"localhost:1234"},
+		CommunicatorConfig: &clpb.CommunicatorConfig{
+			MinFailureDelaySeconds: 10,
+		},
+	}
+	dialCalls := make(chan struct{}, 10)
+	c.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+		dialCalls <- struct{}{}
+		<-ctx.Done()
+		return nil, ctx.Err()
+	}
+
+	cl, err := client.New(
+		conf,
+		client.Components{
+			Communicator: &c,
+		})
+	if err != nil {
+		t.Fatalf("unable to create client: %v", err)
+	}
+	defer cl.Stop()
+
+	// Wait for first dial attempt
+	select {
+	case <-dialCalls:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Timed out waiting for first dial attempt")
+	}
+
+	// Call CancelConnections
+	c.CancelConnections()
+
+	// Wait for second dial attempt (should be immediate, not waiting 10s)
+	select {
+	case <-dialCalls:
+	case <-time.After(2 * time.Second):
+		t.Fatal("Timed out waiting for second dial attempt after CancelConnections")
+	}
+}
+
+func TestStreamingWaitForFlushAfterCancel(t *testing.T) {
+	received := make(chan *fspb.ContactData, 5)
+	toSend := make(chan *fspb.ContactData, 5)
+
+	server := streamingTestServer{
+		t:                     t,
+		received:              received,
+		toSend:                toSend,
+		allowMultipleRequests: true,
+	}
+	server.Start()
+	defer server.Stop()
+
+	var c StreamingCommunicator
+	conf := config.Configuration{
+		Servers:       []string{server.Addr()},
+		TrustedCerts:  x509.NewCertPool(),
+		FixedServices: []*fspb.ClientServiceConfig{{Name: "NOOPService", Factory: "NOOP"}},
+		CommunicatorConfig: &clpb.CommunicatorConfig{
+			MaxPollDelaySeconds:    2,
+			MaxBufferDelaySeconds:  1,
+			MinFailureDelaySeconds: 1,
+			Compression:            fspb.CompressionAlgorithm_COMPRESSION_NONE,
+		},
+	}
+	if !conf.TrustedCerts.AppendCertsFromPEM(server.pemCert) {
+		t.Fatal("unable to add server cert to pool")
+	}
+	cl, err := client.New(
+		conf,
+		client.Components{
+			ServiceFactories: map[string]service.Factory{"NOOP": service.NOOPFactory},
+			Communicator:     &c})
+	if err != nil {
+		t.Fatalf("unable to create client: %v", err)
+	}
+	defer cl.Stop()
+
+	// Wait for initial connection to be established.
+	acks := make(chan int, 1)
+	if err := cl.ProcessMessage(context.Background(),
+		service.AckMessage{
+			M: &fspb.Message{
+				Destination: &fspb.Address{ServiceName: "DummyService"}},
+			Ack: func() { acks <- 0 },
+		}); err != nil {
+		t.Fatalf("unable to hand message to client: %v", err)
+	}
+
+	<-received
+	<-acks
+
+	// Connection is idle.
+
+	// Call CancelConnections to break it.
+	c.CancelConnections()
+
+	// Call WaitForFlush in a goroutine.
+	flushDone := make(chan error, 1)
+	go func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		flushDone <- c.WaitForFlush(ctx)
+	}()
+
+	select {
+	case err := <-flushDone:
+		if err != nil {
+			t.Errorf("WaitForFlush failed: %v", err)
+		}
+		if rc := atomic.LoadInt32(&server.rc); rc < 2 {
+			t.Errorf("WaitForFlush returned before reconnect. Expected at least 2 requests, got %d", rc)
+		}
+	case <-time.After(10 * time.Second):
+		t.Fatal("Timed out waiting for WaitForFlush to complete")
+	}
+}
diff --git a/fleetspeak/src/client/service/service.go b/fleetspeak/src/client/service/service.go
index c102816b..00963cdd 100644
--- a/fleetspeak/src/client/service/service.go
+++ b/fleetspeak/src/client/service/service.go
@@ -59,6 +59,12 @@ type Service interface {
 	Stop() error
 }
 
+// A Flusher is a service that can be requested to flush its buffers or
+// process data immediately, e.g., during a darkwake.
+type Flusher interface {
+	Flush(ctx context.Context) error
+}
+
 // LocalInfo stores summary information about the local client.
 type LocalInfo struct {
 	ClientID common.ClientID // The ClientID of the local client.
diff --git a/fleetspeak/src/client/services.go b/fleetspeak/src/client/services.go
index 15f5113c..3ea14fba 100644
--- a/fleetspeak/src/client/services.go
+++ b/fleetspeak/src/client/services.go
@@ -210,6 +210,30 @@ func (c *serviceConfiguration) Stop() {
 	c.services = make(map[string]*serviceData)
 }
 
+// FlushServices calls Flush on all services that implement the Flusher
+// interface. The service list is snapshotted under the read lock and Flush is
+// called after releasing it, so a slow or re-entrant Flush cannot stall
+// writers (and, queued behind them, other readers) of c.lock.
+func (c *serviceConfiguration) FlushServices(ctx context.Context) {
+	c.lock.RLock()
+	targets := make([]*serviceData, 0, len(c.services))
+	for _, sd := range c.services {
+		targets = append(targets, sd)
+	}
+	c.lock.RUnlock()
+
+	for _, sd := range targets {
+		flusher, ok := sd.service.(service.Flusher)
+		if !ok {
+			continue
+		}
+		log.Infof("Triggering flush for service %s", sd.name)
+		if err := flusher.Flush(ctx); err != nil {
+			log.Errorf("Error flushing service %s: %v", sd.name, err)
+		}
+	}
+}
+
 // A serviceData contains the data we have about a configured service, wrapping
 // a Service interface and mediating communication between it and the rest of
 // the Fleetspeak client.