Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/202.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
tcp, tls, unix, and udp outputs: Refactor to use a single "net" output.
```
5 changes: 1 addition & 4 deletions internal/command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ import (
_ "github.com/elastic/stream/internal/output/gcs"
_ "github.com/elastic/stream/internal/output/kafka"
_ "github.com/elastic/stream/internal/output/lumberjack"
_ "github.com/elastic/stream/internal/output/tcp"
_ "github.com/elastic/stream/internal/output/tls"
_ "github.com/elastic/stream/internal/output/udp"
_ "github.com/elastic/stream/internal/output/unix"
_ "github.com/elastic/stream/internal/output/net"
_ "github.com/elastic/stream/internal/output/webhook"
)

Expand Down
129 changes: 129 additions & 0 deletions internal/output/net/net.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Licensed to Elasticsearch B.V. under one or more agreements.
// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

// Package netout provides a unified network output supporting tcp, tls, udp, and unix protocols.
package netout

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"time"

"golang.org/x/time/rate"

"github.com/elastic/stream/internal/output"
)

const burst = 1024 * 1024

func init() {
output.Register("tcp", New)
output.Register("tls", New)
output.Register("udp", New)
output.Register("unix", New)
}

// Output holds options and the active connection.
type Output struct {
opts *output.Options
conn net.Conn
ctx context.Context
limit *rate.Limiter
}

// New creates a new network output for the protocol specified in opts.Protocol.
func New(opts *output.Options) (output.Output, error) {
o := &Output{opts: opts}
if opts.Protocol == "udp" {
o.limit = rate.NewLimiter(rate.Limit(opts.RateLimit), burst)
}
return o, nil
}

// DialContext connects to the address in opts using the protocol in opts.
func (o *Output) DialContext(ctx context.Context) error {
var (
conn net.Conn
err error
)

switch o.opts.Protocol {
case "tls":
d := tls.Dialer{
Config: &tls.Config{InsecureSkipVerify: o.opts.InsecureTLS}, //nolint:gosec
NetDialer: &net.Dialer{Timeout: time.Second},
}
conn, err = d.DialContext(ctx, "tcp", o.opts.Addr)
case "udp":
conn, err = net.Dial("udp", o.opts.Addr)
o.ctx = ctx
case "tcp", "unix":

d := net.Dialer{Timeout: time.Second}
conn, err = d.DialContext(ctx, o.opts.Protocol, o.opts.Addr)
default:
return fmt.Errorf("unknown protocol: %s", o.opts.Protocol)
}

if err != nil {
return err
}
o.conn = conn
return nil
}

// Close closes the connection. For stream-oriented protocols (tcp, tls, unix) it
// performs a graceful shutdown by signalling EOF and draining remaining data.
func (o *Output) Close() error {
if o.conn == nil {
return nil
}

if o.opts.Protocol == "udp" {
return o.conn.Close()
}

// Signal EOF to the remote end.
type closeWriter interface {
CloseWrite() error
}
if cw, ok := o.conn.(closeWriter); ok {
if err := cw.CloseWrite(); err != nil {
return err
}
}

// Drain to facilitate graceful close on the other side.
deadline := time.Now().Add(5 * time.Second)
if err := o.conn.SetReadDeadline(deadline); err != nil {
return err
}
buf := make([]byte, 1024)
for {
_, err := o.conn.Read(buf)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return err
}
}

return o.conn.Close()
}

// Write writes b to the connection. A newline is appended for stream-oriented
// protocols (tcp, tls, unix); UDP datagrams are written as-is.
func (o *Output) Write(b []byte) (int, error) {
if o.opts.Protocol == "udp" {
if err := o.limit.WaitN(o.ctx, len(b)); err != nil {
return 0, err
}
return o.conn.Write(b)
}
return o.conn.Write(append(b, '\n'))
}
208 changes: 208 additions & 0 deletions internal/output/net/net_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
// Licensed to Elasticsearch B.V. under one or more agreements.
// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

package netout

import (
"context"
"net"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/stream/internal/output"
)

// helpers

func newTCPListener(t *testing.T) net.Listener {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() { l.Close() })
return l
}

func newUnixListener(t *testing.T) net.Listener {
t.Helper()
path := filepath.Join(t.TempDir(), "test.sock")
l, err := net.Listen("unix", path)
require.NoError(t, err)
t.Cleanup(func() { l.Close() })
return l
}

// acceptAndDrain accepts one connection, drains it, and closes it.
func acceptAndDrain(l net.Listener) {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 4096)
for {
_, err := conn.Read(buf)
if err != nil {
break
}
}
}

// acceptAndCollect accepts one connection, reads one chunk, and sends it on ch.
func acceptAndCollect(l net.Listener, ch chan<- []byte) {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 4096)
n, _ := conn.Read(buf)
ch <- buf[:n]
}

// TCP tests

func TestTCPDial(t *testing.T) {
l := newTCPListener(t)
go acceptAndDrain(l)

out, err := New(&output.Options{Protocol: "tcp", Addr: l.Addr().String()})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))
require.NoError(t, out.Close())
}

func TestTCPDialInvalid(t *testing.T) {
out, err := New(&output.Options{Protocol: "tcp", Addr: "127.0.0.1:1"})
require.NoError(t, err)
require.Error(t, out.DialContext(context.Background()))
}

func TestTCPWrite(t *testing.T) {
l := newTCPListener(t)
ch := make(chan []byte, 1)
go acceptAndCollect(l, ch)

out, err := New(&output.Options{Protocol: "tcp", Addr: l.Addr().String()})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))

msg := []byte("hello tcp")
_, err = out.Write(msg)
require.NoError(t, err)
assert.Equal(t, append(msg, '\n'), <-ch)
}

func TestTCPCloseWithoutDial(t *testing.T) {
out, err := New(&output.Options{Protocol: "tcp", Addr: "127.0.0.1:1"})
require.NoError(t, err)
require.NoError(t, out.Close())
}

// Unix socket tests

func TestUnixDial(t *testing.T) {
l := newUnixListener(t)
go acceptAndDrain(l)

out, err := New(&output.Options{Protocol: "unix", Addr: l.Addr().String()})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))
require.NoError(t, out.Close())
}

func TestUnixDialInvalidPath(t *testing.T) {
out, err := New(&output.Options{Protocol: "unix", Addr: "/nonexistent/path/test.sock"})
require.NoError(t, err)
require.Error(t, out.DialContext(context.Background()))
}

func TestUnixWrite(t *testing.T) {
l := newUnixListener(t)
ch := make(chan []byte, 1)
go acceptAndCollect(l, ch)

out, err := New(&output.Options{Protocol: "unix", Addr: l.Addr().String()})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))

msg := []byte("hello unix")
_, err = out.Write(msg)
require.NoError(t, err)
assert.Equal(t, append(msg, '\n'), <-ch)
}

func TestUnixWriteAppendsNewline(t *testing.T) {
l := newUnixListener(t)
ch := make(chan []byte, 1)
go acceptAndCollect(l, ch)

out, err := New(&output.Options{Protocol: "unix", Addr: l.Addr().String()})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))

_, err = out.Write([]byte("no newline"))
require.NoError(t, err)
got := <-ch
assert.Equal(t, byte('\n'), got[len(got)-1])
}

func TestUnixCloseWithoutDial(t *testing.T) {
out, err := New(&output.Options{Protocol: "unix", Addr: "/tmp/ignored.sock"})
require.NoError(t, err)
require.NoError(t, out.Close())
}

// UDP tests

func TestUDPDial(t *testing.T) {
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
defer pc.Close()

out, err := New(&output.Options{Protocol: "udp", Addr: pc.LocalAddr().String(), RateLimit: 1024 * 1024})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))
require.NoError(t, out.Close())
}

func TestUDPWrite(t *testing.T) {
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
defer pc.Close()

out, err := New(&output.Options{Protocol: "udp", Addr: pc.LocalAddr().String(), RateLimit: 1024 * 1024})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))

msg := []byte("hello udp")
_, err = out.Write(msg)
require.NoError(t, err)

buf := make([]byte, 4096)
n, _, err := pc.ReadFrom(buf)
require.NoError(t, err)
// UDP does not append a newline.
assert.Equal(t, msg, buf[:n])
}

// Unknown protocol test

func TestDialUnknownProtocol(t *testing.T) {
out, err := New(&output.Options{Protocol: "ftp", Addr: "127.0.0.1:21"})
require.NoError(t, err)
err = out.DialContext(context.Background())
require.ErrorContains(t, err, "unknown protocol")
}

// Registration tests

func TestRegistered(t *testing.T) {
available := output.Available()
for _, proto := range []string{"tcp", "tls", "udp", "unix"} {
assert.Contains(t, available, proto, "protocol %q should be registered", proto)
}
}
Loading
Loading