Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/201.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
unix output: Added support for writing to unix sockets, behaves the same as tcp output.
```
1 change: 1 addition & 0 deletions internal/command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
_ "github.com/elastic/stream/internal/output/tcp"
_ "github.com/elastic/stream/internal/output/tls"
_ "github.com/elastic/stream/internal/output/udp"
_ "github.com/elastic/stream/internal/output/unix"
_ "github.com/elastic/stream/internal/output/webhook"
)

Expand Down
83 changes: 83 additions & 0 deletions internal/output/unix/unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Licensed to Elasticsearch B.V. under one or more agreements.
// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

// Package unix provides an output for writing to unix sockets
package unix

import (
"context"
"errors"
"io"
"net"
"time"

"github.com/elastic/stream/internal/output"
)

func init() {
output.Register("unix", New)
}

// Output holds options and connection
type Output struct {
opts *output.Options
conn *net.UnixConn
}

// New creates a new unix output
func New(opts *output.Options) (output.Output, error) {
return &Output{opts: opts}, nil
}

// DialContext connects to the address in the Output struct using the supplied context
func (o *Output) DialContext(ctx context.Context) error {
d := net.Dialer{Timeout: time.Second}

conn, err := d.DialContext(ctx, "unix", o.opts.Addr)
if err != nil {
return err
}

o.conn = conn.(*net.UnixConn)
return nil
}

// Conn returns the connection
func (o *Output) Conn() net.Conn {
return o.conn
}

// Close gracefully closes the connection
func (o *Output) Close() error {
if o.conn != nil {
if err := o.conn.CloseWrite(); err != nil {
return err
}

// drain to facilitate graceful close on the other side
deadline := time.Now().Add(5 * time.Second)
if err := o.conn.SetReadDeadline(deadline); err != nil {
return err
}
buffer := make([]byte, 1024)
for {
_, err := o.conn.Read(buffer)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return err
}
}

return o.conn.Close()
}
return nil
}

// Write the supplied bytes to the connection and appends a newline
// character. The adding of the newline character is to behave the
// same as the tcp output.
func (o *Output) Write(b []byte) (int, error) {
return o.conn.Write(append(b, '\n')) //nolint:staticcheck // convention established in tcp output
}
154 changes: 154 additions & 0 deletions internal/output/unix/unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Licensed to Elasticsearch B.V. under one or more agreements.
// Elasticsearch B.V. licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

package unix

import (
"context"
"net"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/stream/internal/output"
)

func newListener(t *testing.T) *net.UnixListener {
t.Helper()
path := filepath.Join(t.TempDir(), "test.sock")
l, err := net.Listen("unix", path)
require.NoError(t, err)
t.Cleanup(func() { l.Close() })
return l.(*net.UnixListener)
}

func TestDial(t *testing.T) {
l := newListener(t)

// Accept and drain so Close()'s graceful shutdown can complete.
go func() {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 4096)
for {
_, err := conn.Read(buf)
if err != nil {
break
}
}
}()

out, err := New(&output.Options{Addr: l.Addr().String()})
require.NoError(t, err)

err = out.DialContext(context.Background())
require.NoError(t, err)
require.NoError(t, out.Close())
}

func TestDialInvalidPath(t *testing.T) {
out, err := New(&output.Options{Addr: "/nonexistent/path/test.sock"})
require.NoError(t, err)

err = out.DialContext(context.Background())
require.Error(t, err)
}

func TestWrite(t *testing.T) {
l := newListener(t)

// Accept one connection and read all data from it.
received := make(chan []byte, 1)
go func() {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 4096)
n, _ := conn.Read(buf)
received <- buf[:n]
}()

out, err := New(&output.Options{Addr: l.Addr().String()})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))

msg := []byte("hello world")
_, err = out.Write(msg)
require.NoError(t, err)

got := <-received
assert.Equal(t, append(msg, '\n'), got)
}

func TestWriteAppendsNewline(t *testing.T) {
l := newListener(t)

received := make(chan []byte, 1)
go func() {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 4096)
n, _ := conn.Read(buf)
received <- buf[:n]
}()

out, err := New(&output.Options{Addr: l.Addr().String()})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))

_, err = out.Write([]byte("no newline"))
require.NoError(t, err)

got := <-received
assert.Equal(t, byte('\n'), got[len(got)-1])
}

func TestClose(t *testing.T) {
l := newListener(t)

done := make(chan struct{})
go func() {
defer close(done)
conn, err := l.Accept()
if err != nil {
return
}
// Drain until EOF so the graceful close can complete.
buf := make([]byte, 4096)
for {
_, err := conn.Read(buf)
if err != nil {
break
}
}
conn.Close()
}()

out, err := New(&output.Options{Addr: l.Addr().String()})
require.NoError(t, err)
require.NoError(t, out.DialContext(context.Background()))
require.NoError(t, out.Close())
<-done
}

func TestCloseWithoutDial(t *testing.T) {
out, err := New(&output.Options{Addr: "/tmp/ignored.sock"})
require.NoError(t, err)
require.NoError(t, out.Close())
}

func TestRegistered(t *testing.T) {
outputs := output.Available()
assert.Contains(t, outputs, "unix")
}
Loading