Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions fleetspeak/src/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,23 @@ func (c *Client) Stop() {
done <- struct{}{}
log.Info("Messages have been drained.")
}

// FlushServices triggers all services that support it to flush their data.
func (c *Client) FlushServices(ctx context.Context) {
log.InfoContextf(ctx, "Client: FlushServices called")
c.sc.FlushServices(ctx)
}

// ForceResetAndFlush programmatically and forcefully tears down the existing
// network connection, establishes a fresh one, and blocks until the outbox is
// successfully drained or the context expires.
func (c *Client) ForceResetAndFlush(ctx context.Context) error {
log.InfoContextf(ctx, "ForceResetAndFlush triggered")
if c.com == nil {
return fmt.Errorf("communicator not set")
}
log.InfoContextf(ctx, "Calling CancelConnections")
c.com.CancelConnections()
log.InfoContextf(ctx, "CancelConnections finished, calling WaitForFlush")
return c.com.WaitForFlush(ctx)
}
14 changes: 14 additions & 0 deletions fleetspeak/src/client/comms/comms.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,27 @@ import (
fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
)

// darkWakeKeyType is the type for the DarkWakeKey context key.
type darkWakeKeyType struct{}

// DarkWakeKey is a context key used to indicate that an operation is part of a DarkWake flush.
var DarkWakeKey = darkWakeKeyType{}

// A Communicator is a component which allows a client to communicate with a
// Fleetspeak server.
type Communicator interface {
Setup(Context) error // Configure the communicator to work with Client.
Start() error // Tells the communicator to start sending and receiving messages.
Stop() // Tells the communicator to stop sending and receiving messages.

// CancelConnections forcefully cancels any underlying network connections.
CancelConnections()

// WaitForFlush blocks until the communicator has successfully established a
// fresh connection and flushed all pending messages, or until the provided
// context expires.
WaitForFlush(ctx context.Context) error

// GetFileIfModified attempts to retrieve a file from a server, if it
// has been modified since modSince. If it has not been modified, it
// returns nil. Otherwise, it returns a ReadCloser for the file's data
Expand Down
66 changes: 65 additions & 1 deletion fleetspeak/src/client/https/polling.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ type Communicator struct {
clientCertificateHeader string

certBytes []byte

wakeUp chan struct{}
pollComplete chan error
mu sync.Mutex
pollCancel context.CancelFunc
}

// Setup implements comms.Communicator.
Expand Down Expand Up @@ -108,6 +113,8 @@ func (c *Communicator) configure() error {
}
c.ctx, c.done = context.WithCancel(context.Background())
c.clientCertificateHeader = si.ClientCertificateHeader
c.wakeUp = make(chan struct{}, 1)
c.pollComplete = make(chan error, 1)
c.certBytes = certBytes
return nil
}
Expand All @@ -127,6 +134,39 @@ func (c *Communicator) Stop() {
c.wd.Stop()
}

// CancelConnections implements comms.Communicator.
func (c *Communicator) CancelConnections() {
log.Infof("CancelConnections called")
c.mu.Lock()
if c.pollCancel != nil {
c.pollCancel()
}
c.mu.Unlock()
c.hc.Transport.(*http.Transport).CloseIdleConnections()
// Drain pollComplete to ensure WaitForFlush waits for a new poll.
select {
case <-c.pollComplete:
default:
}
select {
case c.wakeUp <- struct{}{}:
default:
}
}

// WaitForFlush implements comms.Communicator.
func (c *Communicator) WaitForFlush(ctx context.Context) error {
log.InfoContextf(ctx, "WaitForFlush called")
for {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-c.pollComplete:
return err
}
}
}

// processingLoop polls the server according to the configured policies while
// the communicator is active.
//
Expand All @@ -152,11 +192,19 @@ func (c *Communicator) processingLoop() {
// and updates the variables defined above. In case of failure it also sleeps
// for the MinFailureDelay.
poll := func() {
var err error
defer func() {
select {
case c.pollComplete <- err:
default:
}
}()
c.wd.Reset()
if c.cctx.CurrentID() != c.id {
c.configure()
}
active, err := c.poll(toSend)
var active bool
active, err = c.poll(toSend)
if err != nil {
log.Warningf("Failure during polling: %v", err)
for _, m := range toSend {
Expand All @@ -175,6 +223,8 @@ func (c *Communicator) processingLoop() {
case <-t.C:
case <-c.ctx.Done():
t.Stop()
case <-c.wakeUp:
t.Stop()
}
return
}
Expand Down Expand Up @@ -255,6 +305,9 @@ func (c *Communicator) processingLoop() {
return
case <-t.C:
poll()
case <-c.wakeUp:
t.Stop()
poll()
case m := <-c.cctx.Outbox():
t.Stop()
toSend = append(toSend, m)
Expand Down Expand Up @@ -339,6 +392,17 @@ func (c *Communicator) pollHost(host string, data []byte) (*fspb.ContactData, er
if sendErr != nil {
return nil, sendErr
}
var reqCtx context.Context
c.mu.Lock()
reqCtx, c.pollCancel = context.WithCancel(c.ctx)
c.mu.Unlock()
defer func() {
c.mu.Lock()
c.pollCancel()
c.pollCancel = nil
c.mu.Unlock()
}()
req = req.WithContext(reqCtx)
SetContentEncoding(req.Header, c.conf.GetCompression())
if c.clientCertificateHeader != "" {
bc := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.certBytes})
Expand Down
142 changes: 142 additions & 0 deletions fleetspeak/src/client/https/polling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,3 +733,145 @@ func TestCertificateRevoked(t *testing.T) {
tl.Close()
cl.Stop()
}

func TestCancelConnections(t *testing.T) {
var c Communicator
conf := config.Configuration{
Servers: []string{"localhost:1234"},
CommunicatorConfig: &clpb.CommunicatorConfig{
MaxPollDelaySeconds: 10,
MinFailureDelaySeconds: 10,
},
}
dialCalls := make(chan struct{}, 10)
c.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialCalls <- struct{}{}
<-ctx.Done()
return nil, ctx.Err()
}

cl, err := client.New(
conf,
client.Components{
Communicator: &c,
})
if err != nil {
t.Fatalf("unable to create client: %v", err)
}
defer cl.Stop()

// Wait for first dial attempt
select {
case <-dialCalls:
case <-time.After(5 * time.Second):
t.Fatal("Timed out waiting for first dial attempt")
}

// Call CancelConnections
c.CancelConnections()

// Wait for second dial attempt (should be immediate, not waiting 10s)
select {
case <-dialCalls:
case <-time.After(2 * time.Second):
t.Fatal("Timed out waiting for second dial attempt after CancelConnections")
}
}

func TestWaitForFlush(t *testing.T) {
// Create a local https server for the client to talk to.
pemCert, pemKey, err := common_util.ServerCert()
if err != nil {
t.Fatal(err)
}
cp, err := tls.X509KeyPair(pemCert, pemKey)
if err != nil {
t.Fatal(err)
}
ad, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
tl, err := net.ListenTCP("tcp", ad)
if err != nil {
t.Fatal(err)
}
addr := tl.Addr().String()

pollCount := int32(0)
mux := http.NewServeMux()
mux.HandleFunc("/message", func(res http.ResponseWriter, req *http.Request) {
atomic.AddInt32(&pollCount, 1)
cd := fspb.ContactData{
SequencingNonce: 42,
}
buf, err := proto.Marshal(&cd)
if err != nil {
t.Fatalf("unable to marshal ContactData: %v", err)
}
res.Header().Set("Content-Type", "application/octet-stream")
res.WriteHeader(http.StatusOK)
res.Write(buf)
})

server := http.Server{
Addr: addr,
Handler: mux,
TLSConfig: &tls.Config{
ClientAuth: tls.RequireAnyClientCert,
Certificates: []tls.Certificate{cp},
NextProtos: []string{"h2"},
},
}
l := tls.NewListener(tl, server.TLSConfig)
go server.Serve(l)
defer tl.Close()

var c Communicator
conf := config.Configuration{
TrustedCerts: x509.NewCertPool(),
Servers: []string{addr},
CommunicatorConfig: &clpb.CommunicatorConfig{
MaxPollDelaySeconds: 10,
MinFailureDelaySeconds: 10,
},
}
if !conf.TrustedCerts.AppendCertsFromPEM(pemCert) {
t.Fatal("unable to add server cert to pool")
}

cl, err := client.New(
conf,
client.Components{
Communicator: &c,
})
if err != nil {
t.Fatalf("unable to create client: %v", err)
}
defer cl.Stop()

// Wait for first poll
for atomic.LoadInt32(&pollCount) == 0 {
time.Sleep(100 * time.Millisecond)
}

// Call WaitForFlush in a goroutine
flushDone := make(chan error, 1)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
flushDone <- c.WaitForFlush(ctx)
}()

// Trigger a poll via wakeUp (indirectly via CancelConnections)
c.CancelConnections()

select {
case err := <-flushDone:
if err != nil {
t.Errorf("WaitForFlush failed: %v", err)
}
case <-time.After(6 * time.Second):
t.Fatal("Timed out waiting for WaitForFlush")
}
}
Loading
Loading