diff --git a/.changelog/202.txt b/.changelog/202.txt new file mode 100644 index 0000000..58db719 --- /dev/null +++ b/.changelog/202.txt @@ -0,0 +1,3 @@ +```release-note:enhancement +tcp, tls, unix, and udp outputs: Refactor to use a single "net" output. +``` diff --git a/internal/command/root.go b/internal/command/root.go index 386ed06..8116c0d 100644 --- a/internal/command/root.go +++ b/internal/command/root.go @@ -32,10 +32,7 @@ import ( _ "github.com/elastic/stream/internal/output/gcs" _ "github.com/elastic/stream/internal/output/kafka" _ "github.com/elastic/stream/internal/output/lumberjack" - _ "github.com/elastic/stream/internal/output/tcp" - _ "github.com/elastic/stream/internal/output/tls" - _ "github.com/elastic/stream/internal/output/udp" - _ "github.com/elastic/stream/internal/output/unix" + _ "github.com/elastic/stream/internal/output/net" _ "github.com/elastic/stream/internal/output/webhook" ) diff --git a/internal/output/net/net.go b/internal/output/net/net.go new file mode 100644 index 0000000..7e300b4 --- /dev/null +++ b/internal/output/net/net.go @@ -0,0 +1,129 @@ +// Licensed to Elasticsearch B.V. under one or more agreements. +// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +// Package netout provides a unified network output supporting tcp, tls, udp, and unix protocols. +package netout + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "time" + + "golang.org/x/time/rate" + + "github.com/elastic/stream/internal/output" +) + +const burst = 1024 * 1024 + +func init() { + output.Register("tcp", New) + output.Register("tls", New) + output.Register("udp", New) + output.Register("unix", New) +} + +// Output holds options and the active connection. +type Output struct { + opts *output.Options + conn net.Conn + ctx context.Context + limit *rate.Limiter +} + +// New creates a new network output for the protocol specified in opts.Protocol. +func New(opts *output.Options) (output.Output, error) { + o := &Output{opts: opts} + if opts.Protocol == "udp" { + o.limit = rate.NewLimiter(rate.Limit(opts.RateLimit), burst) + } + return o, nil +} + +// DialContext connects to the address in opts using the protocol in opts. +func (o *Output) DialContext(ctx context.Context) error { + var ( + conn net.Conn + err error + ) + + switch o.opts.Protocol { + case "tls": + d := tls.Dialer{ + Config: &tls.Config{InsecureSkipVerify: o.opts.InsecureTLS}, //nolint:gosec + NetDialer: &net.Dialer{Timeout: time.Second}, + } + conn, err = d.DialContext(ctx, "tcp", o.opts.Addr) + case "udp": + conn, err = net.Dial("udp", o.opts.Addr) + o.ctx = ctx + case "tcp", "unix": + + d := net.Dialer{Timeout: time.Second} + conn, err = d.DialContext(ctx, o.opts.Protocol, o.opts.Addr) + default: + return fmt.Errorf("unknown protocol: %s", o.opts.Protocol) + } + + if err != nil { + return err + } + o.conn = conn + return nil +} + +// Close closes the connection. For stream-oriented protocols (tcp, tls, unix) it +// performs a graceful shutdown by signalling EOF and draining remaining data. +func (o *Output) Close() error { + if o.conn == nil { + return nil + } + + if o.opts.Protocol == "udp" { + return o.conn.Close() + } + + // Signal EOF to the remote end. + type closeWriter interface { + CloseWrite() error + } + if cw, ok := o.conn.(closeWriter); ok { + if err := cw.CloseWrite(); err != nil { + return err + } + } + + // Drain to facilitate graceful close on the other side. + deadline := time.Now().Add(5 * time.Second) + if err := o.conn.SetReadDeadline(deadline); err != nil { + return err + } + buf := make([]byte, 1024) + for { + _, err := o.conn.Read(buf) + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + } + + return o.conn.Close() +} + +// Write writes b to the connection. A newline is appended for stream-oriented +// protocols (tcp, tls, unix); UDP datagrams are written as-is. +func (o *Output) Write(b []byte) (int, error) { + if o.opts.Protocol == "udp" { + if err := o.limit.WaitN(o.ctx, len(b)); err != nil { + return 0, err + } + return o.conn.Write(b) + } + return o.conn.Write(append(b, '\n')) +} diff --git a/internal/output/net/net_test.go b/internal/output/net/net_test.go new file mode 100644 index 0000000..a989922 --- /dev/null +++ b/internal/output/net/net_test.go @@ -0,0 +1,208 @@ +// Licensed to Elasticsearch B.V. under one or more agreements. +// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information. + +package netout + +import ( + "context" + "net" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/elastic/stream/internal/output" +) + +// helpers + +func newTCPListener(t *testing.T) net.Listener { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { l.Close() }) + return l +} + +func newUnixListener(t *testing.T) net.Listener { + t.Helper() + path := filepath.Join(t.TempDir(), "test.sock") + l, err := net.Listen("unix", path) + require.NoError(t, err) + t.Cleanup(func() { l.Close() }) + return l +} + +// acceptAndDrain accepts one connection, drains it, and closes it. +func acceptAndDrain(l net.Listener) { + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 4096) + for { + _, err := conn.Read(buf) + if err != nil { + break + } + } +} + +// acceptAndCollect accepts one connection, reads one chunk, and sends it on ch. +func acceptAndCollect(l net.Listener, ch chan<- []byte) { + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 4096) + n, _ := conn.Read(buf) + ch <- buf[:n] +} + +// TCP tests + +func TestTCPDial(t *testing.T) { + l := newTCPListener(t) + go acceptAndDrain(l) + + out, err := New(&output.Options{Protocol: "tcp", Addr: l.Addr().String()}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + require.NoError(t, out.Close()) +} + +func TestTCPDialInvalid(t *testing.T) { + out, err := New(&output.Options{Protocol: "tcp", Addr: "127.0.0.1:1"}) + require.NoError(t, err) + require.Error(t, out.DialContext(context.Background())) +} + +func TestTCPWrite(t *testing.T) { + l := newTCPListener(t) + ch := make(chan []byte, 1) + go acceptAndCollect(l, ch) + + out, err := New(&output.Options{Protocol: "tcp", Addr: l.Addr().String()}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + + msg := []byte("hello tcp") + _, err = out.Write(msg) + require.NoError(t, err) + assert.Equal(t, append(msg, '\n'), <-ch) +} + +func TestTCPCloseWithoutDial(t *testing.T) { + out, err := New(&output.Options{Protocol: "tcp", Addr: "127.0.0.1:1"}) + require.NoError(t, err) + require.NoError(t, out.Close()) +} + +// Unix socket tests + +func TestUnixDial(t *testing.T) { + l := newUnixListener(t) + go acceptAndDrain(l) + + out, err := New(&output.Options{Protocol: "unix", Addr: l.Addr().String()}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + require.NoError(t, out.Close()) +} + +func TestUnixDialInvalidPath(t *testing.T) { + out, err := New(&output.Options{Protocol: "unix", Addr: "/nonexistent/path/test.sock"}) + require.NoError(t, err) + require.Error(t, out.DialContext(context.Background())) +} + +func TestUnixWrite(t *testing.T) { + l := newUnixListener(t) + ch := make(chan []byte, 1) + go acceptAndCollect(l, ch) + + out, err := New(&output.Options{Protocol: "unix", Addr: l.Addr().String()}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + + msg := []byte("hello unix") + _, err = out.Write(msg) + require.NoError(t, err) + assert.Equal(t, append(msg, '\n'), <-ch) +} + +func TestUnixWriteAppendsNewline(t *testing.T) { + l := newUnixListener(t) + ch := make(chan []byte, 1) + go acceptAndCollect(l, ch) + + out, err := New(&output.Options{Protocol: "unix", Addr: l.Addr().String()}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + + _, err = out.Write([]byte("no newline")) + require.NoError(t, err) + got := <-ch + assert.Equal(t, byte('\n'), got[len(got)-1]) +} + +func TestUnixCloseWithoutDial(t *testing.T) { + out, err := New(&output.Options{Protocol: "unix", Addr: "/tmp/ignored.sock"}) + require.NoError(t, err) + require.NoError(t, out.Close()) +} + +// UDP tests + +func TestUDPDial(t *testing.T) { + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer pc.Close() + + out, err := New(&output.Options{Protocol: "udp", Addr: pc.LocalAddr().String(), RateLimit: 1024 * 1024}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + require.NoError(t, out.Close()) +} + +func TestUDPWrite(t *testing.T) { + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + defer pc.Close() + + out, err := New(&output.Options{Protocol: "udp", Addr: pc.LocalAddr().String(), RateLimit: 1024 * 1024}) + require.NoError(t, err) + require.NoError(t, out.DialContext(context.Background())) + + msg := []byte("hello udp") + _, err = out.Write(msg) + require.NoError(t, err) + + buf := make([]byte, 4096) + n, _, err := pc.ReadFrom(buf) + require.NoError(t, err) + // UDP does not append a newline. + assert.Equal(t, msg, buf[:n]) +} + +// Unknown protocol test + +func TestDialUnknownProtocol(t *testing.T) { + out, err := New(&output.Options{Protocol: "ftp", Addr: "127.0.0.1:21"}) + require.NoError(t, err) + err = out.DialContext(context.Background()) + require.ErrorContains(t, err, "unknown protocol") +} + +// Registration tests + +func TestRegistered(t *testing.T) { + available := output.Available() + for _, proto := range []string{"tcp", "tls", "udp", "unix"} { + assert.Contains(t, available, proto, "protocol %q should be registered", proto) + } +} diff --git a/internal/output/tcp/tcp.go b/internal/output/tcp/tcp.go deleted file mode 100644 index 4228310..0000000 --- a/internal/output/tcp/tcp.go +++ /dev/null @@ -1,74 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more agreements. -// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. -// See the LICENSE file in the project root for more information. - -package tcp - -import ( - "context" - "errors" - "io" - "net" - "time" - - "github.com/elastic/stream/internal/output" -) - -func init() { - output.Register("tcp", New) -} - -type Output struct { - opts *output.Options - conn *net.TCPConn -} - -func New(opts *output.Options) (output.Output, error) { - return &Output{opts: opts}, nil -} - -func (o *Output) DialContext(ctx context.Context) error { - d := net.Dialer{Timeout: time.Second} - - conn, err := d.DialContext(ctx, "tcp", o.opts.Addr) - if err != nil { - return err - } - - o.conn = conn.(*net.TCPConn) - return nil -} - -func (o *Output) Conn() net.Conn { - return o.conn -} - -func (o *Output) Close() error { - if o.conn != nil { - if err := o.conn.CloseWrite(); err != nil { - return err - } - - // drain to facilitate graceful close on the other side - deadline := time.Now().Add(5 * time.Second) - if err := o.conn.SetReadDeadline(deadline); err != nil { - return err - } - buffer := make([]byte, 1024) - for { - _, err := o.conn.Read(buffer) - if errors.Is(err, io.EOF) { - break - } else if err != nil { - return err - } - } - - return o.conn.Close() - } - return nil -} - -func (o *Output) Write(b []byte) (int, error) { - return o.conn.Write(append(b, '\n')) -} diff --git a/internal/output/tls/tls.go b/internal/output/tls/tls.go deleted file mode 100644 index f146ade..0000000 --- a/internal/output/tls/tls.go +++ /dev/null @@ -1,76 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more agreements. -// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. -// See the LICENSE file in the project root for more information. - -package tcp - -import ( - "context" - "crypto/tls" - "errors" - "io" - "net" - "time" - - "github.com/elastic/stream/internal/output" -) - -func init() { - output.Register("tls", New) -} - -type Output struct { - opts *output.Options - conn *tls.Conn -} - -func New(opts *output.Options) (output.Output, error) { - return &Output{opts: opts}, nil -} - -func (o *Output) DialContext(ctx context.Context) error { - d := tls.Dialer{ - Config: &tls.Config{ - InsecureSkipVerify: o.opts.InsecureTLS, - }, - NetDialer: &net.Dialer{Timeout: time.Second}, - } - - conn, err := d.DialContext(ctx, "tcp", o.opts.Addr) - if err != nil { - return err - } - - o.conn = conn.(*tls.Conn) - return nil -} - -func (o *Output) Close() error { - if o.conn != nil { - if err := o.conn.CloseWrite(); err != nil { - return err - } - - // drain to facilitate graceful close on the other side - deadline := time.Now().Add(5 * time.Second) - if err := o.conn.SetReadDeadline(deadline); err != nil { - return err - } - buffer := make([]byte, 1024) - for { - _, err := o.conn.Read(buffer) - if errors.Is(err, io.EOF) { - break - } else if err != nil { - return err - } - } - - return o.conn.Close() - } - return nil -} - -func (o *Output) Write(b []byte) (int, error) { - return o.conn.Write(append(b, '\n')) -} diff --git a/internal/output/udp/udp.go b/internal/output/udp/udp.go deleted file mode 100644 index 2752ec1..0000000 --- a/internal/output/udp/udp.go +++ /dev/null @@ -1,61 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more agreements. -// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. -// See the LICENSE file in the project root for more information. - -package tcp - -import ( - "context" - "net" - - "golang.org/x/time/rate" - - "github.com/elastic/stream/internal/output" -) - -const burst = 1024 * 1024 - -func init() { - output.Register("udp", New) -} - -type Output struct { - opts *output.Options - conn *net.UDPConn - ctx context.Context - limit *rate.Limiter -} - -func New(opts *output.Options) (output.Output, error) { - return &Output{ - opts: opts, - limit: rate.NewLimiter(rate.Limit(opts.RateLimit), burst), - }, nil -} - -func (o *Output) DialContext(ctx context.Context) error { - udpAddr, err := net.ResolveUDPAddr("udp", o.opts.Addr) - if err != nil { - return err - } - - conn, err := net.DialUDP("udp", nil, udpAddr) - if err != nil { - return err - } - - o.conn = conn - o.ctx = ctx - return nil -} - -func (o *Output) Close() error { - return o.conn.Close() -} - -func (o *Output) Write(b []byte) (int, error) { - if err := o.limit.WaitN(o.ctx, len(b)); err != nil { - return 0, err - } - return o.conn.Write(b) -} diff --git a/internal/output/unix/unix.go b/internal/output/unix/unix.go deleted file mode 100644 index 6986f3c..0000000 --- a/internal/output/unix/unix.go +++ /dev/null @@ -1,83 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more agreements. -// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. -// See the LICENSE file in the project root for more information. - -// Package unix provides an output for writing to unix sockets -package unix - -import ( - "context" - "errors" - "io" - "net" - "time" - - "github.com/elastic/stream/internal/output" -) - -func init() { - output.Register("unix", New) -} - -// Output holds options and connection -type Output struct { - opts *output.Options - conn *net.UnixConn -} - -// New creates a new unix output -func New(opts *output.Options) (output.Output, error) { - return &Output{opts: opts}, nil -} - -// DialContext connects to the address in the Output struct using the supplied context -func (o *Output) DialContext(ctx context.Context) error { - d := net.Dialer{Timeout: time.Second} - - conn, err := d.DialContext(ctx, "unix", o.opts.Addr) - if err != nil { - return err - } - - o.conn = conn.(*net.UnixConn) - return nil -} - -// Conn returns the connection -func (o *Output) Conn() net.Conn { - return o.conn -} - -// Close gracefully closes the connection -func (o *Output) Close() error { - if o.conn != nil { - if err := o.conn.CloseWrite(); err != nil { - return err - } - - // drain to facilitate graceful close on the other side - deadline := time.Now().Add(5 * time.Second) - if err := o.conn.SetReadDeadline(deadline); err != nil { - return err - } - buffer := make([]byte, 1024) - for { - _, err := o.conn.Read(buffer) - if errors.Is(err, io.EOF) { - break - } else if err != nil { - return err - } - } - - return o.conn.Close() - } - return nil -} - -// Write the supplied bytes to the connection and appends a newline -// character. The adding of the newline character is to behave the -// same as the tcp output. -func (o *Output) Write(b []byte) (int, error) { - return o.conn.Write(append(b, '\n')) //nolint:staticcheck // convention established in tcp output -} diff --git a/internal/output/unix/unix_test.go b/internal/output/unix/unix_test.go deleted file mode 100644 index fcbb8c5..0000000 --- a/internal/output/unix/unix_test.go +++ /dev/null @@ -1,154 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more agreements. -// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License. -// See the LICENSE file in the project root for more information. - -package unix - -import ( - "context" - "net" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/elastic/stream/internal/output" -) - -func newListener(t *testing.T) *net.UnixListener { - t.Helper() - path := filepath.Join(t.TempDir(), "test.sock") - l, err := net.Listen("unix", path) - require.NoError(t, err) - t.Cleanup(func() { l.Close() }) - return l.(*net.UnixListener) -} - -func TestDial(t *testing.T) { - l := newListener(t) - - // Accept and drain so Close()'s graceful shutdown can complete. - go func() { - conn, err := l.Accept() - if err != nil { - return - } - defer conn.Close() - buf := make([]byte, 4096) - for { - _, err := conn.Read(buf) - if err != nil { - break - } - } - }() - - out, err := New(&output.Options{Addr: l.Addr().String()}) - require.NoError(t, err) - - err = out.DialContext(context.Background()) - require.NoError(t, err) - require.NoError(t, out.Close()) -} - -func TestDialInvalidPath(t *testing.T) { - out, err := New(&output.Options{Addr: "/nonexistent/path/test.sock"}) - require.NoError(t, err) - - err = out.DialContext(context.Background()) - require.Error(t, err) -} - -func TestWrite(t *testing.T) { - l := newListener(t) - - // Accept one connection and read all data from it. - received := make(chan []byte, 1) - go func() { - conn, err := l.Accept() - if err != nil { - return - } - defer conn.Close() - buf := make([]byte, 4096) - n, _ := conn.Read(buf) - received <- buf[:n] - }() - - out, err := New(&output.Options{Addr: l.Addr().String()}) - require.NoError(t, err) - require.NoError(t, out.DialContext(context.Background())) - - msg := []byte("hello world") - _, err = out.Write(msg) - require.NoError(t, err) - - got := <-received - assert.Equal(t, append(msg, '\n'), got) -} - -func TestWriteAppendsNewline(t *testing.T) { - l := newListener(t) - - received := make(chan []byte, 1) - go func() { - conn, err := l.Accept() - if err != nil { - return - } - defer conn.Close() - buf := make([]byte, 4096) - n, _ := conn.Read(buf) - received <- buf[:n] - }() - - out, err := New(&output.Options{Addr: l.Addr().String()}) - require.NoError(t, err) - require.NoError(t, out.DialContext(context.Background())) - - _, err = out.Write([]byte("no newline")) - require.NoError(t, err) - - got := <-received - assert.Equal(t, byte('\n'), got[len(got)-1]) -} - -func TestClose(t *testing.T) { - l := newListener(t) - - done := make(chan struct{}) - go func() { - defer close(done) - conn, err := l.Accept() - if err != nil { - return - } - // Drain until EOF so the graceful close can complete. - buf := make([]byte, 4096) - for { - _, err := conn.Read(buf) - if err != nil { - break - } - } - conn.Close() - }() - - out, err := New(&output.Options{Addr: l.Addr().String()}) - require.NoError(t, err) - require.NoError(t, out.DialContext(context.Background())) - require.NoError(t, out.Close()) - <-done -} - -func TestCloseWithoutDial(t *testing.T) { - out, err := New(&output.Options{Addr: "/tmp/ignored.sock"}) - require.NoError(t, err) - require.NoError(t, out.Close()) -} - -func TestRegistered(t *testing.T) { - outputs := output.Available() - assert.Contains(t, outputs, "unix") -}