diff --git a/.changelog/201.txt b/.changelog/201.txt new file mode 100644 index 0000000..8b0e125 --- /dev/null +++ b/.changelog/201.txt @@ -0,0 +1,3 @@ +```release-note:enhancement +unix output: Added support for writing to unix sockets, behaves the same as tcp output. +``` diff --git a/internal/command/root.go b/internal/command/root.go index cba4dfe..386ed06 100644 --- a/internal/command/root.go +++ b/internal/command/root.go @@ -35,6 +35,7 @@ import ( _ "github.com/elastic/stream/internal/output/tcp" _ "github.com/elastic/stream/internal/output/tls" _ "github.com/elastic/stream/internal/output/udp" + _ "github.com/elastic/stream/internal/output/unix" _ "github.com/elastic/stream/internal/output/webhook" ) diff --git a/internal/output/unix/unix.go b/internal/output/unix/unix.go new file mode 100644 index 0000000..6986f3c --- /dev/null +++ b/internal/output/unix/unix.go @@ -0,0 +1,83 @@ +// Licensed to Elasticsearch B.V. under one or more agreements. +// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +// Package unix provides an output for writing to unix sockets +package unix + +import ( + "context" + "errors" + "io" + "net" + "time" + + "github.com/elastic/stream/internal/output" +) + +func init() { + output.Register("unix", New) +} + +// Output holds options and connection +type Output struct { + opts *output.Options + conn *net.UnixConn +} + +// New creates a new unix output +func New(opts *output.Options) (output.Output, error) { + return &Output{opts: opts}, nil +} + +// DialContext connects to the address in the Output struct using the supplied context +func (o *Output) DialContext(ctx context.Context) error { + d := net.Dialer{Timeout: time.Second} + + conn, err := d.DialContext(ctx, "unix", o.opts.Addr) + if err != nil { + return err + } + + o.conn = conn.(*net.UnixConn) + return nil +} + +// Conn returns the connection +func (o *Output) Conn() net.Conn { + return o.conn +} + +// Close gracefully closes the connection +func (o *Output) Close() error { + if o.conn != nil { + if err := o.conn.CloseWrite(); err != nil { + return err + } + + // drain to facilitate graceful close on the other side + deadline := time.Now().Add(5 * time.Second) + if err := o.conn.SetReadDeadline(deadline); err != nil { + return err + } + buffer := make([]byte, 1024) + for { + _, err := o.conn.Read(buffer) + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + } + + return o.conn.Close() + } + return nil +} + +// Write the supplied bytes to the connection and appends a newline +// character. The adding of the newline character is to behave the +// same as the tcp output. +func (o *Output) Write(b []byte) (int, error) { + return o.conn.Write(append(b, '\n')) //nolint:staticcheck // convention established in tcp output +} diff --git a/internal/output/unix/unix_test.go b/internal/output/unix/unix_test.go new file mode 100644 index 0000000..fcbb8c5 --- /dev/null +++ b/internal/output/unix/unix_test.go @@ -0,0 +1,154 @@ +// Licensed to Elasticsearch B.V. under one or more agreements. +// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +package unix + +import ( + "context" + "net" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/elastic/stream/internal/output" +) + +func newListener(t *testing.T) *net.UnixListener { + t.Helper() + path := filepath.Join(t.TempDir(), "test.sock") + l, err := net.Listen("unix", path) + require.NoError(t, err) + t.Cleanup(func() { l.Close() }) + return l.(*net.UnixListener) +} + +func TestDial(t *testing.T) { + l := newListener(t) + + // Accept and drain so Close()'s graceful shutdown can complete. + go func() { + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 4096) + for { + _, err := conn.Read(buf) + if err != nil { + break + } + } + }() + + out, err := New(&output.Options{Addr: l.Addr().String()}) + require.NoError(t, err) + + err = out.DialContext(context.Background()) + require.NoError(t, err) + require.NoError(t, out.Close()) +} + +func TestDialInvalidPath(t *testing.T) { + out, err := New(&output.Options{Addr: "/nonexistent/path/test.sock"}) + require.NoError(t, err) + + err = out.DialContext(context.Background()) + require.Error(t, err) +} + +func TestWrite(t *testing.T) { + l := newListener(t) + + // Accept one connection and read all data from it. + received := make(chan []byte, 1) + go func() { + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 4096) + n, _ := conn.Read(buf) + received <- buf[:n] + }() + + out, err := New(&output.Options{Addr: l.Addr().String()}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + + msg := []byte("hello world") + _, err = out.Write(msg) + require.NoError(t, err) + + got := <-received + assert.Equal(t, append(msg, '\n'), got) +} + +func TestWriteAppendsNewline(t *testing.T) { + l := newListener(t) + + received := make(chan []byte, 1) + go func() { + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 4096) + n, _ := conn.Read(buf) + received <- buf[:n] + }() + + out, err := New(&output.Options{Addr: l.Addr().String()}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + + _, err = out.Write([]byte("no newline")) + require.NoError(t, err) + + got := <-received + assert.Equal(t, byte('\n'), got[len(got)-1]) +} + +func TestClose(t *testing.T) { + l := newListener(t) + + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := l.Accept() + if err != nil { + return + } + // Drain until EOF so the graceful close can complete. + buf := make([]byte, 4096) + for { + _, err := conn.Read(buf) + if err != nil { + break + } + } + conn.Close() + }() + + out, err := New(&output.Options{Addr: l.Addr().String()}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + require.NoError(t, out.Close()) + <-done +} + +func TestCloseWithoutDial(t *testing.T) { + out, err := New(&output.Options{Addr: "/tmp/ignored.sock"}) + require.NoError(t, err) + require.NoError(t, out.Close()) +} + +func TestRegistered(t *testing.T) { + outputs := output.Available() + assert.Contains(t, outputs, "unix") +}