diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..812dad5 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,51 @@ +# AGENTS.md — OttO + +## Purpose +OttO is the core runtime/framework. Keep application-specific logic out of this repo. +Add tests and docs with minimal disruption to public APIs. + +## Non-goals +- Do NOT add/require external services (no real MQTT broker, DB, HTTP server) for unit tests. +- Do NOT access hardware, GPIO, serial, or OS-specific devices in tests. +- Avoid broad refactors. Prefer small, additive changes that enable testing. +- Always ignore files and directories that begin with an underscore '_'. + +## Go conventions +- Run `gofmt` on all changed files. +- Keep package boundaries clean; avoid circular deps. +- Prefer context-aware APIs for long-running work. +- Avoid goroutine leaks: every goroutine must have a deterministic shutdown path. + +## Testing (required) +Use **testify**: +- Use `github.com/stretchr/testify/require` for must-pass assertions. +- Use `github.com/stretchr/testify/assert` for non-fatal checks. +- Use `github.com/stretchr/testify/mock` only when a small fake isn’t practical. + +Test rules: +- Tests must be hermetic: no network, no filesystem writes outside `t.TempDir()`. +- Avoid `time.Sleep` for synchronization. Use channels, WaitGroups, or context cancellation. +- Any test that could block must use `context.WithTimeout` and fail fast. +- Prefer table-driven tests. +- Keep tests deterministic and non-flaky. + +Commands: +- Run: `go test ./...` +- When adding new packages or helpers, keep them internal: `internal/testutil` is allowed. + +## What to test first (priority) +1. Pure logic: parsing, validation, topic naming/handling, config processing. +2. Concurrency: cancellation, shutdown behavior, channel fan-in/out, race-prone areas. +3. Interface contracts: error propagation and edge cases. + +## Review checklist (before proposing changes) +- Does this change keep OttO independent of app/device-specific logic? +- Are new tests hermetic and deterministic? +- Any new goroutines? Where do they stop? +- Any sleeps/time-based flakiness introduced? Remove it. + +## Commit guidance (if committing) +Prefer small commits: +- `test: baseline` +- `testutil: add ` +- `docs: godoc for ` diff --git a/Makefile b/Makefile index 9ea7dc8..b63c350 100644 --- a/Makefile +++ b/Makefile @@ -15,18 +15,6 @@ fmt: vet: go vet ./... -# ottoctl: -# go build -o ${OTTOCTL_BINARY} -ldflags "-X github.com/rustyeddy/otto/cmd.version=${VERSION}" ./cmd/ottoctl/ottoctl - -# otto: -# go build -o ${OTTO_BINARY} -ldflags "-X github.com/rustyeddy/otto/cmd.version=${VERSION}" ./cmd/otto - -build: - $(MAKE) -C cmd - -run: build - ./otto - test: rm -f cover.out go test -benchmem -coverprofile=cover.out -cover ./... @@ -104,4 +92,4 @@ service-status: service-logs: sudo journalctl -u $(SERVICE_FILE) -f -.PHONY: all build cmd otto ottoctl clean ci fmt run test vet install install-service enable-service uninstall-service uninstall service-status service-logs $(SUBDIRS) +.PHONY: all build otto ottoctl clean ci fmt run test vet install install-service enable-service uninstall-service uninstall service-status service-logs $(SUBDIRS) diff --git a/client/client.go b/_client/client.go similarity index 100% rename from client/client.go rename to _client/client.go diff --git a/client/client_test.go b/_client/client_test.go similarity index 100% rename from client/client_test.go rename to _client/client_test.go diff --git a/cmd/Makefile b/_cmd/Makefile similarity index 100% rename from cmd/Makefile rename to _cmd/Makefile diff --git a/cmd/otto/Makefile b/_cmd/otto/Makefile similarity index 100% rename from cmd/otto/Makefile rename to _cmd/otto/Makefile diff --git a/cmd/otto/main.go b/_cmd/otto/main.go similarity index 100% rename from cmd/otto/main.go rename to _cmd/otto/main.go diff --git a/cmd/otto/otto b/_cmd/otto/otto similarity index 100% rename from cmd/otto/otto rename to _cmd/otto/otto diff --git a/cmd/ottoctl/cmd_cli.go b/_cmd/ottoctl/cmd_cli.go similarity index 100% rename from cmd/ottoctl/cmd_cli.go rename to _cmd/ottoctl/cmd_cli.go diff --git a/cmd/ottoctl/cmd_cli_test.go b/_cmd/ottoctl/cmd_cli_test.go similarity index 100% rename from cmd/ottoctl/cmd_cli_test.go rename to _cmd/ottoctl/cmd_cli_test.go diff --git a/cmd/ottoctl/cmd_log.go b/_cmd/ottoctl/cmd_log.go similarity index 100% rename from cmd/ottoctl/cmd_log.go rename to _cmd/ottoctl/cmd_log.go diff --git a/cmd/ottoctl/cmd_log_test.go b/_cmd/ottoctl/cmd_log_test.go similarity index 100% rename from cmd/ottoctl/cmd_log_test.go rename to _cmd/ottoctl/cmd_log_test.go diff --git a/cmd/ottoctl/cmd_root.go b/_cmd/ottoctl/cmd_root.go similarity index 100% rename from cmd/ottoctl/cmd_root.go rename to _cmd/ottoctl/cmd_root.go diff --git a/cmd/ottoctl/cmd_root_test.go b/_cmd/ottoctl/cmd_root_test.go similarity index 100% rename from cmd/ottoctl/cmd_root_test.go rename to _cmd/ottoctl/cmd_root_test.go diff --git a/cmd/ottoctl/cmd_shutdown.go b/_cmd/ottoctl/cmd_shutdown.go similarity index 100% rename from cmd/ottoctl/cmd_shutdown.go rename to _cmd/ottoctl/cmd_shutdown.go diff --git a/cmd/ottoctl/cmd_shutdown_test.go b/_cmd/ottoctl/cmd_shutdown_test.go similarity index 100% rename from cmd/ottoctl/cmd_shutdown_test.go rename to _cmd/ottoctl/cmd_shutdown_test.go diff --git a/cmd/ottoctl/cmd_stations.go b/_cmd/ottoctl/cmd_stations.go similarity index 100% rename from cmd/ottoctl/cmd_stations.go rename to _cmd/ottoctl/cmd_stations.go diff --git a/cmd/ottoctl/cmd_stations_test.go b/_cmd/ottoctl/cmd_stations_test.go similarity index 100% rename from cmd/ottoctl/cmd_stations_test.go rename to _cmd/ottoctl/cmd_stations_test.go diff --git a/cmd/ottoctl/cmd_stats.go b/_cmd/ottoctl/cmd_stats.go similarity index 100% rename from cmd/ottoctl/cmd_stats.go rename to _cmd/ottoctl/cmd_stats.go diff --git a/cmd/ottoctl/cmd_stats_test.go b/_cmd/ottoctl/cmd_stats_test.go similarity index 100% rename from cmd/ottoctl/cmd_stats_test.go rename to _cmd/ottoctl/cmd_stats_test.go diff --git a/cmd/ottoctl/cmd_test.go b/_cmd/ottoctl/cmd_test.go similarity index 100% rename from cmd/ottoctl/cmd_test.go rename to _cmd/ottoctl/cmd_test.go diff --git a/cmd/ottoctl/cmd_timers.go b/_cmd/ottoctl/cmd_timers.go similarity index 100% rename from cmd/ottoctl/cmd_timers.go rename to _cmd/ottoctl/cmd_timers.go diff --git a/cmd/ottoctl/cmd_timers_test.go b/_cmd/ottoctl/cmd_timers_test.go similarity index 100% rename from cmd/ottoctl/cmd_timers_test.go rename to _cmd/ottoctl/cmd_timers_test.go diff --git a/cmd/ottoctl/cmd_version.go b/_cmd/ottoctl/cmd_version.go similarity index 100% rename from cmd/ottoctl/cmd_version.go rename to _cmd/ottoctl/cmd_version.go diff --git a/cmd/ottoctl/cmd_version_test.go b/_cmd/ottoctl/cmd_version_test.go similarity index 100% rename from cmd/ottoctl/cmd_version_test.go rename to _cmd/ottoctl/cmd_version_test.go diff --git a/cmd/ottoctl/ottoctl/Makefile b/_cmd/ottoctl/ottoctl/Makefile similarity index 100% rename from cmd/ottoctl/ottoctl/Makefile rename to _cmd/ottoctl/ottoctl/Makefile diff --git a/cmd/ottoctl/ottoctl/main.go b/_cmd/ottoctl/ottoctl/main.go similarity index 100% rename from cmd/ottoctl/ottoctl/main.go rename to _cmd/ottoctl/ottoctl/main.go diff --git a/cmd/ottoctl/testdata/help.txt b/_cmd/ottoctl/testdata/help.txt similarity index 100% rename from cmd/ottoctl/testdata/help.txt rename to _cmd/ottoctl/testdata/help.txt diff --git a/cmd/ottoctl/testdata/stations.json b/_cmd/ottoctl/testdata/stations.json similarity index 100% rename from cmd/ottoctl/testdata/stations.json rename to _cmd/ottoctl/testdata/stations.json diff --git a/cmd/ottoctl/testdata/stats.json b/_cmd/ottoctl/testdata/stats.json similarity index 100% rename from cmd/ottoctl/testdata/stats.json rename to _cmd/ottoctl/testdata/stats.json diff --git a/cmd/ottoctl/testdata/timers.json b/_cmd/ottoctl/testdata/timers.json similarity index 100% rename from cmd/ottoctl/testdata/timers.json rename to _cmd/ottoctl/testdata/timers.json diff --git a/cmd/ottoctl/testdata/version.json b/_cmd/ottoctl/testdata/version.json similarity index 100% rename from cmd/ottoctl/testdata/version.json rename to _cmd/ottoctl/testdata/version.json diff --git a/data/benchmark_test.go b/_data/benchmark_test.go similarity index 100% rename from data/benchmark_test.go rename to _data/benchmark_test.go diff --git a/data/data.go b/_data/data.go similarity index 100% rename from data/data.go rename to _data/data.go diff --git a/data/data_test.go b/_data/data_test.go similarity index 100% rename from data/data_test.go rename to _data/data_test.go diff --git a/data/timeseries.go b/_data/timeseries.go similarity index 100% rename from data/timeseries.go rename to _data/timeseries.go diff --git a/data/timeseries_test.go b/_data/timeseries_test.go similarity index 100% rename from data/timeseries_test.go rename to _data/timeseries_test.go diff --git a/examples/demo/main.go b/_examples/demo/main.go similarity index 100% rename from examples/demo/main.go rename to _examples/demo/main.go diff --git a/examples/logging/main.go b/_examples/logging/main.go similarity index 98% rename from examples/logging/main.go rename to _examples/logging/main.go index b256db7..c467f97 100644 --- a/examples/logging/main.go +++ b/_examples/logging/main.go @@ -7,7 +7,7 @@ import ( "log/slog" "os" - "github.com/rustyeddy/otto/utils" + "github.com/rustyeddy/otto/_utils" ) func main() { diff --git a/otto.go b/_otto.go similarity index 100% rename from otto.go rename to _otto.go diff --git a/otto_test.go b/_otto_test.go similarity index 100% rename from otto_test.go rename to _otto_test.go diff --git a/server/ping.go b/_server/ping.go similarity index 100% rename from server/ping.go rename to _server/ping.go diff --git a/server/ping_test.go b/_server/ping_test.go similarity index 100% rename from server/ping_test.go rename to _server/ping_test.go diff --git a/server/server.go b/_server/server.go similarity index 100% rename from server/server.go rename to _server/server.go diff --git a/server/server_test.go b/_server/server_test.go similarity index 100% rename from server/server_test.go rename to _server/server_test.go diff --git a/server/testdata/app/index.html b/_server/testdata/app/index.html similarity index 100% rename from server/testdata/app/index.html rename to _server/testdata/app/index.html diff --git a/server/ws.go b/_server/ws.go similarity index 100% rename from server/ws.go rename to _server/ws.go diff --git a/server/ws_test.go b/_server/ws_test.go similarity index 100% rename from server/ws_test.go rename to _server/ws_test.go diff --git a/station/device_manager.go b/_station/device_manager.go similarity index 100% rename from station/device_manager.go rename to _station/device_manager.go diff --git a/station/device_manager_test.go b/_station/device_manager_test.go similarity index 100% rename from station/device_manager_test.go rename to _station/device_manager_test.go diff --git a/station/managed_device.go b/_station/managed_device.go similarity index 100% rename from station/managed_device.go rename to _station/managed_device.go diff --git a/station/managed_device_test.go b/_station/managed_device_test.go similarity index 100% rename from station/managed_device_test.go rename to _station/managed_device_test.go diff --git a/station/station.go b/_station/station.go similarity index 100% rename from station/station.go rename to _station/station.go diff --git a/station/station_manager.go b/_station/station_manager.go similarity index 100% rename from station/station_manager.go rename to _station/station_manager.go diff --git a/station/station_manager_test.go b/_station/station_manager_test.go similarity index 100% rename from station/station_manager_test.go rename to _station/station_manager_test.go diff --git a/station/station_metrics.go b/_station/station_metrics.go similarity index 100% rename from station/station_metrics.go rename to _station/station_metrics.go diff --git a/station/station_test.go b/_station/station_test.go similarity index 100% rename from station/station_test.go rename to _station/station_test.go diff --git a/utils/log.go b/_utils/log.go similarity index 100% rename from utils/log.go rename to _utils/log.go diff --git a/utils/log_example_test.go b/_utils/log_example_test.go similarity index 100% rename from utils/log_example_test.go rename to _utils/log_example_test.go diff --git a/utils/log_test.go b/_utils/log_test.go similarity index 100% rename from utils/log_test.go rename to _utils/log_test.go diff --git a/utils/rand.go b/_utils/rand.go similarity index 100% rename from utils/rand.go rename to _utils/rand.go diff --git a/utils/rand_test.go b/_utils/rand_test.go similarity index 100% rename from utils/rand_test.go rename to _utils/rand_test.go diff --git a/utils/station_name.go b/_utils/station_name.go similarity index 100% rename from utils/station_name.go rename to _utils/station_name.go diff --git a/utils/station_name_test.go b/_utils/station_name_test.go similarity index 100% rename from utils/station_name_test.go rename to _utils/station_name_test.go diff --git a/utils/stats.go b/_utils/stats.go similarity index 100% rename from utils/stats.go rename to _utils/stats.go diff --git a/utils/stats_test.go b/_utils/stats_test.go similarity index 100% rename from utils/stats_test.go rename to _utils/stats_test.go diff --git a/utils/timers.go b/_utils/timers.go similarity index 100% rename from utils/timers.go rename to _utils/timers.go diff --git a/utils/timers_test.go b/_utils/timers_test.go similarity index 100% rename from utils/timers_test.go rename to _utils/timers_test.go diff --git a/messenger/codec/json_test.go b/messenger/codec/json_test.go new file mode 100644 index 0000000..73b5d47 --- /dev/null +++ b/messenger/codec/json_test.go @@ -0,0 +1,29 @@ +package codec + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONRoundTrip(t *testing.T) { + t.Parallel() + + c := JSON[int]{} + + raw, err := c.Marshal(42) + require.NoError(t, err) + + got, err := c.Unmarshal(raw) + require.NoError(t, err) + assert.Equal(t, 42, got) +} + +func TestJSONUnmarshalInvalid(t *testing.T) { + t.Parallel() + + c := JSON[int]{} + _, err := c.Unmarshal([]byte(`"not-an-int"`)) + require.Error(t, err) +} diff --git a/messenger/messenger.go b/messenger/messenger.go index c986041..e6bf975 100644 --- a/messenger/messenger.go +++ b/messenger/messenger.go @@ -12,6 +12,7 @@ type subSpec struct { handler func(Message) } +// Messenger manages desired MQTT subscriptions. type Messenger struct { MQTT MQTT @@ -20,6 +21,7 @@ type Messenger struct { unsubs map[string]func() error } +// New returns a Messenger for the provided MQTT client. func New(mqtt MQTT) *Messenger { return &Messenger{ MQTT: mqtt, @@ -28,14 +30,14 @@ func New(mqtt MQTT) *Messenger { } } -// Register a subscription you want to always be active. +// WantSub registers a subscription that should always be active. func (m *Messenger) WantSub(topic string, qos byte, handler func(Message)) { m.mu.Lock() defer m.mu.Unlock() m.subscriptions[topic] = subSpec{topic: topic, qos: qos, handler: handler} } -// Apply all desired subscriptions (call on first connect and on every reconnect). +// ResubscribeAll applies desired subscriptions on connect and reconnect. func (m *Messenger) ResubscribeAll(ctx context.Context) { slog.Info("MQTT connected; (re)subscribing", "count", len(m.subscriptions)) diff --git a/messenger/messenger_test.go b/messenger/messenger_test.go new file mode 100644 index 0000000..9e572ea --- /dev/null +++ b/messenger/messenger_test.go @@ -0,0 +1,109 @@ +package messenger + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeMQTT struct { + mu sync.Mutex + subs []subCall + subscribeCalls map[string]int + unsubCalls map[string]int +} + +type subCall struct { + topic string + qos byte + handler func(Message) +} + +func newFakeMQTT() *fakeMQTT { + return &fakeMQTT{ + subscribeCalls: make(map[string]int), + unsubCalls: make(map[string]int), + } +} + +func (f *fakeMQTT) Publish(ctx context.Context, topic string, payload []byte, retain bool, qos byte) error { + return nil +} + +func (f *fakeMQTT) Subscribe(ctx context.Context, topic string, qos byte, handler func(Message)) (func() error, error) { + f.mu.Lock() + f.subs = append(f.subs, subCall{topic: topic, qos: qos, handler: handler}) + f.subscribeCalls[topic]++ + f.mu.Unlock() + + return func() error { + f.mu.Lock() + f.unsubCalls[topic]++ + f.mu.Unlock() + return nil + }, nil +} + +func (f *fakeMQTT) SetWill(topic string, payload []byte, retain bool, qos byte) error { + return nil +} + +func (f *fakeMQTT) snapshot() (subs []subCall, subscribeCalls map[string]int, unsubCalls map[string]int) { + f.mu.Lock() + defer f.mu.Unlock() + + subs = append([]subCall(nil), f.subs...) + subscribeCalls = make(map[string]int, len(f.subscribeCalls)) + for k, v := range f.subscribeCalls { + subscribeCalls[k] = v + } + unsubCalls = make(map[string]int, len(f.unsubCalls)) + for k, v := range f.unsubCalls { + unsubCalls[k] = v + } + return subs, subscribeCalls, unsubCalls +} + +func TestMessengerResubscribeAllSubscribes(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + mqtt := newFakeMQTT() + m := New(mqtt) + m.WantSub("otto/devices/lamp/set", 1, func(Message) {}) + m.WantSub("otto/devices/lamp/state", 0, func(Message) {}) + + m.ResubscribeAll(ctx) + + subs, calls, _ := mqtt.snapshot() + require.Len(t, subs, 2) + assert.Equal(t, 1, calls["otto/devices/lamp/set"]) + assert.Equal(t, 1, calls["otto/devices/lamp/state"]) +} + +func TestMessengerResubscribeAllUnsubscribesPrevious(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + mqtt := newFakeMQTT() + m := New(mqtt) + m.WantSub("otto/devices/lamp/set", 1, func(Message) {}) + m.WantSub("otto/devices/lamp/state", 0, func(Message) {}) + + m.ResubscribeAll(ctx) + m.ResubscribeAll(ctx) + + _, calls, unsubs := mqtt.snapshot() + assert.Equal(t, 2, calls["otto/devices/lamp/set"]) + assert.Equal(t, 2, calls["otto/devices/lamp/state"]) + assert.Equal(t, 1, unsubs["otto/devices/lamp/set"]) + assert.Equal(t, 1, unsubs["otto/devices/lamp/state"]) +} diff --git a/messenger/mqtt/paho_test.go b/messenger/mqtt/paho_test.go new file mode 100644 index 0000000..1bd2904 --- /dev/null +++ b/messenger/mqtt/paho_test.go @@ -0,0 +1,290 @@ +package mqtt + +import ( + "context" + "errors" + "testing" + "time" + + paho "github.com/eclipse/paho.mqtt.golang" + "github.com/rustyeddy/otto/messenger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeToken struct { + waitTimeoutResult bool + err error + waitCalls int + waitTimeoutCalls int + done chan struct{} +} + +func newFakeToken(waitTimeoutResult bool, err error) *fakeToken { + ch := make(chan struct{}) + close(ch) + return &fakeToken{ + waitTimeoutResult: waitTimeoutResult, + err: err, + done: ch, + } +} + +func (t *fakeToken) Wait() bool { + t.waitCalls++ + return true +} + +func (t *fakeToken) WaitTimeout(time.Duration) bool { + t.waitTimeoutCalls++ + return t.waitTimeoutResult +} + +func (t *fakeToken) Done() <-chan struct{} { return t.done } +func (t *fakeToken) Error() error { return t.err } + +type fakeClient struct { + connectToken paho.Token + publishToken paho.Token + subscribeToken paho.Token + unsubscribeToken paho.Token + + published []publishArgs + subscriptions []subscriptionArgs + unsubscribed []string + connectedState bool +} + +type publishArgs struct { + topic string + qos byte + retain bool + payload interface{} +} + +type subscriptionArgs struct { + topic string + qos byte + handler paho.MessageHandler +} + +func (c *fakeClient) IsConnected() bool { return c.connectedState } +func (c *fakeClient) IsConnectionOpen() bool { return c.connectedState } +func (c *fakeClient) Connect() paho.Token { return c.connectToken } +func (c *fakeClient) Disconnect(uint) {} + +func (c *fakeClient) Publish(topic string, qos byte, retained bool, payload interface{}) paho.Token { + c.published = append(c.published, publishArgs{topic: topic, qos: qos, retain: retained, payload: payload}) + return c.publishToken +} + +func (c *fakeClient) Subscribe(topic string, qos byte, callback paho.MessageHandler) paho.Token { + c.subscriptions = append(c.subscriptions, subscriptionArgs{topic: topic, qos: qos, handler: callback}) + return c.subscribeToken +} + +func (c *fakeClient) SubscribeMultiple(map[string]byte, paho.MessageHandler) paho.Token { + return newFakeToken(true, nil) +} + +func (c *fakeClient) Unsubscribe(topics ...string) paho.Token { + c.unsubscribed = append(c.unsubscribed, topics...) + return c.unsubscribeToken +} + +func (c *fakeClient) AddRoute(string, paho.MessageHandler) {} +func (c *fakeClient) OptionsReader() paho.ClientOptionsReader { + return paho.NewOptionsReader(paho.NewClientOptions()) +} + +type fakeMessage struct { + topic string + payload []byte + retain bool + qos byte +} + +func (m *fakeMessage) Duplicate() bool { return false } +func (m *fakeMessage) Qos() byte { return m.qos } +func (m *fakeMessage) Retained() bool { return m.retain } +func (m *fakeMessage) Topic() string { return m.topic } +func (m *fakeMessage) MessageID() uint16 { + return 1 +} +func (m *fakeMessage) Payload() []byte { return m.payload } +func (m *fakeMessage) Ack() {} + +func TestNewUsesProvidedClientID(t *testing.T) { + t.Parallel() + + cfg := Config{ + Broker: "tcp://example:1883", + ClientID: "client-1", + Username: "user", + Password: "pass", + CleanSession: true, + } + + p := New(cfg) + + require.NotNil(t, p.opts) + assert.Equal(t, cfg.ClientID, p.opts.ClientID) + assert.Equal(t, cfg.Username, p.opts.Username) + assert.Equal(t, cfg.Password, p.opts.Password) + assert.Equal(t, cfg.CleanSession, p.opts.CleanSession) + require.Len(t, p.opts.Servers, 1) + assert.Equal(t, cfg.Broker, p.opts.Servers[0].String()) +} + +func TestNewGeneratesClientID(t *testing.T) { + t.Parallel() + + p := New(Config{Broker: "tcp://example:1883"}) + require.NotNil(t, p.opts) + assert.Len(t, p.opts.ClientID, len("otto-")+8) + assert.Contains(t, p.opts.ClientID, "otto-") +} + +func TestRandSuffix(t *testing.T) { + t.Parallel() + + suffix := randSuffix() + assert.Len(t, suffix, 8) + for _, r := range suffix { + assert.True(t, (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9')) + } +} + +func TestSetWillWithoutOptions(t *testing.T) { + t.Parallel() + + p := &Paho{} + err := p.SetWill("topic", []byte("payload"), true, 1) + require.Error(t, err) +} + +func TestConnectTimeout(t *testing.T) { + t.Parallel() + + p := &Paho{ + opts: paho.NewClientOptions(), + c: &fakeClient{ + connectToken: newFakeToken(false, nil), + }, + } + + err := p.Connect(context.Background()) + require.Error(t, err) +} + +func TestConnectReturnsTokenError(t *testing.T) { + t.Parallel() + + p := &Paho{ + opts: paho.NewClientOptions(), + c: &fakeClient{ + connectToken: newFakeToken(true, errors.New("connect failed")), + }, + } + + err := p.Connect(context.Background()) + require.Error(t, err) +} + +func TestPublishQoS0DoesNotWait(t *testing.T) { + t.Parallel() + + token := newFakeToken(true, nil) + client := &fakeClient{publishToken: token} + p := &Paho{c: client} + + err := p.Publish(context.Background(), "topic", []byte("payload"), false, 0) + require.NoError(t, err) + assert.Equal(t, 0, token.waitTimeoutCalls) + require.Len(t, client.published, 1) + assert.Equal(t, "topic", client.published[0].topic) +} + +func TestPublishQoS1Waits(t *testing.T) { + t.Parallel() + + token := newFakeToken(true, nil) + p := &Paho{c: &fakeClient{publishToken: token}} + + err := p.Publish(context.Background(), "topic", []byte("payload"), false, 1) + require.NoError(t, err) + assert.Equal(t, 1, token.waitTimeoutCalls) +} + +func TestPublishTimeout(t *testing.T) { + t.Parallel() + + token := newFakeToken(false, nil) + p := &Paho{c: &fakeClient{publishToken: token}} + + err := p.Publish(context.Background(), "topic", []byte("payload"), false, 1) + require.Error(t, err) +} + +func TestSubscribeSuccessAndUnsubscribe(t *testing.T) { + t.Parallel() + + subToken := newFakeToken(true, nil) + unsubToken := newFakeToken(true, nil) + client := &fakeClient{ + subscribeToken: subToken, + unsubscribeToken: unsubToken, + } + p := &Paho{c: client} + + got := make(chan messenger.Message, 1) + unsub, err := p.Subscribe(context.Background(), "topic", 1, func(m messenger.Message) { + got <- m + }) + require.NoError(t, err) + require.NotNil(t, unsub) + require.Len(t, client.subscriptions, 1) + + handler := client.subscriptions[0].handler + handler(client, &fakeMessage{ + topic: "topic", + payload: []byte("payload"), + retain: true, + qos: 1, + }) + + select { + case msg := <-got: + assert.Equal(t, "topic", msg.Topic) + assert.Equal(t, []byte("payload"), msg.Payload) + assert.True(t, msg.Retain) + assert.Equal(t, byte(1), msg.QoS) + default: + require.Fail(t, "expected handler to be called") + } + + err = unsub() + require.NoError(t, err) + assert.Equal(t, 1, unsubToken.waitTimeoutCalls) + assert.Equal(t, []string{"topic"}, client.unsubscribed) +} + +func TestSubscribeTimeout(t *testing.T) { + t.Parallel() + + client := &fakeClient{subscribeToken: newFakeToken(false, nil)} + p := &Paho{c: client} + + _, err := p.Subscribe(context.Background(), "topic", 1, func(messenger.Message) {}) + require.Error(t, err) +} + +func TestSubscribeTokenError(t *testing.T) { + t.Parallel() + + client := &fakeClient{subscribeToken: newFakeToken(true, errors.New("sub failed"))} + p := &Paho{c: client} + + _, err := p.Subscribe(context.Background(), "topic", 1, func(messenger.Message) {}) + require.Error(t, err) +} diff --git a/messenger/mqtt_client.go b/messenger/mqtt_client.go index 8131398..ca7c4de 100644 --- a/messenger/mqtt_client.go +++ b/messenger/mqtt_client.go @@ -2,6 +2,7 @@ package messenger import "context" +// Message is a decoded MQTT message delivered to a handler. type Message struct { Topic string Payload []byte @@ -9,6 +10,7 @@ type Message struct { QoS byte } +// MQTT abstracts the MQTT client operations used by the messenger. type MQTT interface { // Publish should be safe to call from multiple goroutines. Publish(ctx context.Context, topic string, payload []byte, retain bool, qos byte) error diff --git a/messenger/payloads.go b/messenger/payloads.go index 8ba9a5b..ed95fc2 100644 --- a/messenger/payloads.go +++ b/messenger/payloads.go @@ -2,11 +2,13 @@ package messenger import "time" +// StatusPayload is the JSON body for status topics. type StatusPayload struct { Status string `json:"status"` // "online"|"offline" Time time.Time `json:"time"` } +// MetaPayload is the JSON body for device metadata topics. type MetaPayload struct { Name string `json:"name"` Kind string `json:"kind"` diff --git a/messenger/payloads_test.go b/messenger/payloads_test.go new file mode 100644 index 0000000..be02f9a --- /dev/null +++ b/messenger/payloads_test.go @@ -0,0 +1,92 @@ +package messenger + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStatusPayloadJSON(t *testing.T) { + t.Parallel() + + ts := time.Date(2024, time.January, 2, 3, 4, 5, 0, time.UTC) + payload := StatusPayload{ + Status: "online", + Time: ts, + } + + raw, err := json.Marshal(payload) + require.NoError(t, err) + + var got map[string]any + require.NoError(t, json.Unmarshal(raw, &got)) + + assert.Equal(t, "online", got["status"]) + assert.Equal(t, ts.Format(time.RFC3339Nano), got["time"]) +} + +func TestMetaPayloadJSONOmitEmpty(t *testing.T) { + t.Parallel() + + payload := MetaPayload{ + Name: "lamp", + Kind: "switch", + ValueType: "bool", + Access: "rw", + } + + raw, err := json.Marshal(payload) + require.NoError(t, err) + + var got map[string]any + require.NoError(t, json.Unmarshal(raw, &got)) + + assert.Equal(t, "lamp", got["name"]) + assert.Equal(t, "switch", got["kind"]) + assert.Equal(t, "bool", got["value_type"]) + assert.Equal(t, "rw", got["access"]) + assert.NotContains(t, got, "unit") + assert.NotContains(t, got, "min") + assert.NotContains(t, got, "max") + assert.NotContains(t, got, "tags") + assert.NotContains(t, got, "attrs") +} + +func TestMetaPayloadJSONFull(t *testing.T) { + t.Parallel() + + min := 1.5 + max := 9.5 + payload := MetaPayload{ + Name: "temp", + Kind: "sensor", + ValueType: "float", + Access: "ro", + Unit: "C", + Min: &min, + Max: &max, + Tags: []string{"indoor", "calibrated"}, + Attrs: map[string]string{ + "vendor": "acme", + }, + } + + raw, err := json.Marshal(payload) + require.NoError(t, err) + + var got map[string]any + require.NoError(t, json.Unmarshal(raw, &got)) + + assert.Equal(t, "temp", got["name"]) + assert.Equal(t, "sensor", got["kind"]) + assert.Equal(t, "float", got["value_type"]) + assert.Equal(t, "ro", got["access"]) + assert.Equal(t, "C", got["unit"]) + assert.Equal(t, min, got["min"]) + assert.Equal(t, max, got["max"]) + assert.Equal(t, []any{"indoor", "calibrated"}, got["tags"]) + assert.Equal(t, map[string]any{"vendor": "acme"}, got["attrs"]) +} diff --git a/messenger/registry.go b/messenger/registry.go index 9191baa..32dd9c5 100644 --- a/messenger/registry.go +++ b/messenger/registry.go @@ -10,12 +10,14 @@ import ( "github.com/rustyeddy/devices" ) +// Logger is the minimal logging interface used by Registry. type Logger interface { Info(msg string, args ...any) Warn(msg string, args ...any) Error(msg string, args ...any) } +// Registry wires devices to MQTT topics and keeps a small state cache. type Registry struct { MQTT MQTT Topics TopicScheme @@ -53,6 +55,7 @@ type Registry struct { stateAny map[string]any } +// NewRegistry builds a Registry with defaults set for QoS and retention. func NewRegistry(m MQTT, topics TopicScheme) *Registry { return &Registry{ MQTT: m, @@ -73,6 +76,7 @@ func NewRegistry(m MQTT, topics TopicScheme) *Registry { } } +// Add appends a device to the registry. func (r *Registry) Add(dev devices.Device) { r.mu.Lock() defer r.mu.Unlock() @@ -187,7 +191,7 @@ func (r *Registry) wireEvents(ctx context.Context, dev devices.Device) { } // Run starts device goroutines, wires events, and publishes status/meta. -// IMPORTANT: For reconnect-resubscribe to work, your MQTT adapter must call r.ResubscribeAll(ctx) on connect. +// For reconnect-resubscribe to work, your MQTT adapter must call ResubscribeAll on connect. func (r *Registry) Run(ctx context.Context) error { // Snapshot devices r.mu.RLock() @@ -260,7 +264,7 @@ func (r *Registry) StateRaw(name string) ([]byte, bool) { return b, ok } -// StateAny returns the last decoded state value (if known). +// StateAny returns the last decoded state value, if known. func (r *Registry) StateAny(name string) (any, bool) { r.stateMu.RLock() defer r.stateMu.RUnlock() diff --git a/messenger/registry_test.go b/messenger/registry_test.go new file mode 100644 index 0000000..b38fb68 --- /dev/null +++ b/messenger/registry_test.go @@ -0,0 +1,243 @@ +package messenger + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + "time" + + "github.com/rustyeddy/devices" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type registryMQTT struct { + mu sync.Mutex + publishes []publishCall + wills []publishCall + + subscribeCalls map[string]int + unsubCalls map[string]int +} + +type publishCall struct { + topic string + body []byte + retain bool + qos byte +} + +func newRegistryMQTT() *registryMQTT { + return ®istryMQTT{ + subscribeCalls: make(map[string]int), + unsubCalls: make(map[string]int), + } +} + +func (m *registryMQTT) Publish(ctx context.Context, topic string, payload []byte, retain bool, qos byte) error { + m.mu.Lock() + m.publishes = append(m.publishes, publishCall{topic: topic, body: payload, retain: retain, qos: qos}) + m.mu.Unlock() + return nil +} + +func (m *registryMQTT) Subscribe(ctx context.Context, topic string, qos byte, handler func(Message)) (func() error, error) { + m.mu.Lock() + m.subscribeCalls[topic]++ + m.mu.Unlock() + + return func() error { + m.mu.Lock() + m.unsubCalls[topic]++ + m.mu.Unlock() + return nil + }, nil +} + +func (m *registryMQTT) SetWill(topic string, payload []byte, retain bool, qos byte) error { + m.mu.Lock() + m.wills = append(m.wills, publishCall{topic: topic, body: payload, retain: retain, qos: qos}) + m.mu.Unlock() + return nil +} + +func (m *registryMQTT) snapshot() (publishes []publishCall, wills []publishCall, subs map[string]int, unsubs map[string]int) { + m.mu.Lock() + defer m.mu.Unlock() + + publishes = append([]publishCall(nil), m.publishes...) + wills = append([]publishCall(nil), m.wills...) + + subs = make(map[string]int, len(m.subscribeCalls)) + for k, v := range m.subscribeCalls { + subs[k] = v + } + unsubs = make(map[string]int, len(m.unsubCalls)) + for k, v := range m.unsubCalls { + unsubs[k] = v + } + return publishes, wills, subs, unsubs +} + +type fakeDevice struct { + name string + desc devices.Descriptor + events chan devices.Event + run func(ctx context.Context) error + descriptor bool +} + +func (d *fakeDevice) Name() string { return d.name } +func (d *fakeDevice) Run(ctx context.Context) error { + if d.run != nil { + return d.run(ctx) + } + <-ctx.Done() + return nil +} +func (d *fakeDevice) Events() <-chan devices.Event { return d.events } +func (d *fakeDevice) Descriptor() devices.Descriptor { + if d.descriptor { + return d.desc + } + return devices.Descriptor{} +} + +func TestRegistryResubscribeAll(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + mqtt := newRegistryMQTT() + reg := NewRegistry(mqtt, TopicScheme{Prefix: "otto"}) + reg.WantSub("otto/devices/lamp/set", 1, func(Message) {}) + reg.WantSub("otto/devices/lamp/state", 0, func(Message) {}) + + reg.ResubscribeAll(ctx) + reg.ResubscribeAll(ctx) + + _, _, subs, unsubs := mqtt.snapshot() + assert.Equal(t, 2, subs["otto/devices/lamp/set"]) + assert.Equal(t, 2, subs["otto/devices/lamp/state"]) + assert.Equal(t, 1, unsubs["otto/devices/lamp/set"]) + assert.Equal(t, 1, unsubs["otto/devices/lamp/state"]) +} + +func TestRegistryRunReturnsError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + mqtt := newRegistryMQTT() + reg := NewRegistry(mqtt, TopicScheme{Prefix: "otto"}) + + events := make(chan devices.Event) + close(events) + + dev := &fakeDevice{ + name: "lamp", + desc: devices.Descriptor{Name: "lamp", Kind: "switch", ValueType: "bool", Access: devices.ReadWrite}, + events: events, + descriptor: true, + run: func(context.Context) error { + return errors.New("boom") + }, + } + reg.Add(dev) + + err := reg.Run(ctx) + require.Error(t, err) + + publishes, wills, _, _ := mqtt.snapshot() + require.Len(t, wills, 1) + assert.Equal(t, "otto/devices/lamp/status", wills[0].topic) + + var status StatusPayload + var meta MetaPayload + var gotStatus bool + var gotMeta bool + + for _, call := range publishes { + switch call.topic { + case "otto/devices/lamp/status": + require.NoError(t, json.Unmarshal(call.body, &status)) + assert.Equal(t, "online", status.Status) + gotStatus = true + case "otto/devices/lamp/meta": + require.NoError(t, json.Unmarshal(call.body, &meta)) + assert.Equal(t, "lamp", meta.Name) + gotMeta = true + } + } + + assert.True(t, gotStatus) + assert.True(t, gotMeta) +} + +func TestRegistryRunGracefulShutdownPublishesOffline(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(waitCancel) + + mqtt := newRegistryMQTT() + reg := NewRegistry(mqtt, TopicScheme{Prefix: "otto"}) + + started := make(chan struct{}) + events := make(chan devices.Event) + + dev := &fakeDevice{ + name: "sensor", + events: events, + run: func(ctx context.Context) error { + close(started) + <-ctx.Done() + return nil + }, + } + reg.Add(dev) + + done := make(chan error, 1) + go func() { done <- reg.Run(ctx) }() + + select { + case <-started: + case <-waitCtx.Done(): + require.Fail(t, "device did not start") + } + + cancel() + + select { + case err := <-done: + require.NoError(t, err) + case <-waitCtx.Done(): + require.Fail(t, "registry did not exit after cancel") + } + + publishes, _, _, _ := mqtt.snapshot() + var gotOnline bool + var gotOffline bool + for _, call := range publishes { + if call.topic != "otto/devices/sensor/status" { + continue + } + var status StatusPayload + require.NoError(t, json.Unmarshal(call.body, &status)) + if status.Status == "online" { + gotOnline = true + } + if status.Status == "offline" { + gotOffline = true + } + } + + assert.True(t, gotOnline) + assert.True(t, gotOffline) +} diff --git a/messenger/topics.go b/messenger/topics.go index ab831b5..19f28c9 100644 --- a/messenger/topics.go +++ b/messenger/topics.go @@ -2,16 +2,24 @@ package messenger import "path" -// I think we should add a station to the MQTT topic path - +// TopicScheme builds MQTT topic paths for devices. type TopicScheme struct { Prefix string // e.g. "otto" or "home" } func (s TopicScheme) base(name string) string { return path.Join(s.Prefix, "devices", name) } -func (s TopicScheme) State(name string) string { return path.Join(s.base(name), "state") } -func (s TopicScheme) Set(name string) string { return path.Join(s.base(name), "set") } -func (s TopicScheme) Event(name string) string { return path.Join(s.base(name), "event") } +// State returns the MQTT topic for a device's state. +func (s TopicScheme) State(name string) string { return path.Join(s.base(name), "state") } + +// Set returns the MQTT topic for a device's set command. +func (s TopicScheme) Set(name string) string { return path.Join(s.base(name), "set") } + +// Event returns the MQTT topic for a device's events. +func (s TopicScheme) Event(name string) string { return path.Join(s.base(name), "event") } + +// Status returns the MQTT topic for a device's status. func (s TopicScheme) Status(name string) string { return path.Join(s.base(name), "status") } -func (s TopicScheme) Meta(name string) string { return path.Join(s.base(name), "meta") } + +// Meta returns the MQTT topic for a device's metadata. +func (s TopicScheme) Meta(name string) string { return path.Join(s.base(name), "meta") } diff --git a/messenger/topics_test.go b/messenger/topics_test.go new file mode 100644 index 0000000..55502e3 --- /dev/null +++ b/messenger/topics_test.go @@ -0,0 +1,50 @@ +package messenger + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTopicSchemePaths(t *testing.T) { + t.Parallel() + + scheme := TopicScheme{Prefix: "otto"} + + tests := []struct { + name string + got string + expected string + }{ + {name: "state", got: scheme.State("lamp"), expected: "otto/devices/lamp/state"}, + {name: "set", got: scheme.Set("lamp"), expected: "otto/devices/lamp/set"}, + {name: "event", got: scheme.Event("lamp"), expected: "otto/devices/lamp/event"}, + {name: "status", got: scheme.Status("lamp"), expected: "otto/devices/lamp/status"}, + {name: "meta", got: scheme.Meta("lamp"), expected: "otto/devices/lamp/meta"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.expected, tc.got) + }) + } +} + +func TestStateAs(t *testing.T) { + t.Parallel() + + reg := NewRegistry(nil, TopicScheme{Prefix: "otto"}) + + reg.stateMu.Lock() + reg.stateAny["relay"] = true + reg.stateMu.Unlock() + + val, ok := StateAs[bool](reg, "relay") + require.True(t, ok) + assert.True(t, val) + + _, ok = StateAs[int](reg, "relay") + assert.False(t, ok) +} diff --git a/messenger/wire_typed_test.go b/messenger/wire_typed_test.go new file mode 100644 index 0000000..b33847d --- /dev/null +++ b/messenger/wire_typed_test.go @@ -0,0 +1,161 @@ +package messenger + +import ( + "context" + "encoding/json" + "sync" + "testing" + "time" + + "github.com/rustyeddy/devices" + "github.com/rustyeddy/otto/messenger/codec" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type wireMQTT struct { + mu sync.Mutex + publishes []publishCall + publishCh chan publishCall +} + +func newWireMQTT() *wireMQTT { + return &wireMQTT{ + publishCh: make(chan publishCall, 8), + } +} + +func (m *wireMQTT) Publish(ctx context.Context, topic string, payload []byte, retain bool, qos byte) error { + call := publishCall{topic: topic, body: payload, retain: retain, qos: qos} + m.mu.Lock() + m.publishes = append(m.publishes, call) + m.mu.Unlock() + m.publishCh <- call + return nil +} + +func (m *wireMQTT) Subscribe(ctx context.Context, topic string, qos byte, handler func(Message)) (func() error, error) { + return func() error { return nil }, nil +} + +func (m *wireMQTT) SetWill(topic string, payload []byte, retain bool, qos byte) error { + return nil +} + +type fakeSource[T any] struct { + name string + out chan T + events chan devices.Event +} + +func (f *fakeSource[T]) Name() string { return f.name } +func (f *fakeSource[T]) Run(ctx context.Context) error { <-ctx.Done(); return nil } +func (f *fakeSource[T]) Events() <-chan devices.Event { return f.events } +func (f *fakeSource[T]) Out() <-chan T { return f.out } + +type fakeSink[T any] struct { + name string + in chan T + events chan devices.Event +} + +func (f *fakeSink[T]) Name() string { return f.name } +func (f *fakeSink[T]) Run(ctx context.Context) error { <-ctx.Done(); return nil } +func (f *fakeSink[T]) Events() <-chan devices.Event { return f.events } +func (f *fakeSink[T]) In() chan<- T { return f.in } + +func TestWireSourcePublishesAndCaches(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + mqtt := newWireMQTT() + reg := NewRegistry(mqtt, TopicScheme{Prefix: "otto"}) + src := &fakeSource[int]{name: "temp", out: make(chan int, 1), events: make(chan devices.Event)} + + WireSource(ctx, reg, src, codec.JSON[int]{}) + + src.out <- 42 + + select { + case call := <-mqtt.publishCh: + assert.Equal(t, "otto/devices/temp/state", call.topic) + var got int + require.NoError(t, json.Unmarshal(call.body, &got)) + assert.Equal(t, 42, got) + case <-ctx.Done(): + require.Fail(t, "publish not received") + } + + raw, ok := reg.StateRaw("temp") + require.True(t, ok) + assert.NotEmpty(t, raw) + + val, ok := reg.StateAny("temp") + require.True(t, ok) + assert.Equal(t, 42, val) + + close(src.out) +} + +func TestWireSinkDeliversToDevice(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + mqtt := newWireMQTT() + reg := NewRegistry(mqtt, TopicScheme{Prefix: "otto"}) + sink := &fakeSink[int]{name: "relay", in: make(chan int, 1), events: make(chan devices.Event)} + + WireSink(ctx, reg, sink, codec.JSON[int]{}) + + sub, ok := reg.subs["otto/devices/relay/set"] + require.True(t, ok) + + payload, err := json.Marshal(7) + require.NoError(t, err) + sub.handler(Message{Topic: "otto/devices/relay/set", Payload: payload}) + + select { + case got := <-sink.in: + assert.Equal(t, 7, got) + case <-ctx.Done(): + require.Fail(t, "set not delivered") + } +} + +func TestWireSinkContextCancelReturns(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + cancel() + + mqtt := newWireMQTT() + reg := NewRegistry(mqtt, TopicScheme{Prefix: "otto"}) + sink := &fakeSink[int]{name: "relay", in: make(chan int), events: make(chan devices.Event)} + + WireSink(ctx, reg, sink, codec.JSON[int]{}) + + sub, ok := reg.subs["otto/devices/relay/set"] + require.True(t, ok) + + payload, err := json.Marshal(9) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + sub.handler(Message{Topic: "otto/devices/relay/set", Payload: payload}) + close(done) + }() + + waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(waitCancel) + + select { + case <-done: + case <-waitCtx.Done(): + require.Fail(t, "handler did not return on canceled context") + } +} diff --git a/rules/follow.go b/rules/follow.go index c94b217..768e189 100644 --- a/rules/follow.go +++ b/rules/follow.go @@ -12,12 +12,15 @@ type Follow[T any] struct { Dst devices.Sink[T] } +// NewFollow returns a rule that copies source values into the sink. func NewFollow[T any](name string, src devices.Source[T], dst devices.Sink[T]) *Follow[T] { return &Follow[T]{name: name, Src: src, Dst: dst} } +// Name returns the rule name. func (f *Follow[T]) Name() string { return f.name } +// Run forwards source values into the sink until context cancellation. func (f *Follow[T]) Run(ctx context.Context) error { for { select { diff --git a/rules/rules.go b/rules/rules.go index 2635947..bd3bae4 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -5,17 +5,21 @@ import ( "sync" ) +// Rule is a named runnable unit of behavior. type Rule interface { Name() string Run(ctx context.Context) error } +// Runner executes a set of rules concurrently. type Runner struct { rules []Rule } +// NewRunner creates an empty Runner. func NewRunner() *Runner { return &Runner{} } +// Add registers a rule to run. func (r *Runner) Add(rule Rule) { r.rules = append(r.rules, rule) } // Run starts all rules and returns the first fatal error (or ctx cancellation). diff --git a/rules/rules_test.go b/rules/rules_test.go new file mode 100644 index 0000000..3f80996 --- /dev/null +++ b/rules/rules_test.go @@ -0,0 +1,133 @@ +package rules + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/rustyeddy/devices" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testRule struct { + name string + run func(ctx context.Context) error +} + +func (r testRule) Name() string { return r.name } +func (r testRule) Run(ctx context.Context) error { + return r.run(ctx) +} + +type fakeSource[T any] struct { + name string + out chan T + events chan devices.Event +} + +func (f *fakeSource[T]) Name() string { return f.name } +func (f *fakeSource[T]) Run(ctx context.Context) error { + <-ctx.Done() + return nil +} +func (f *fakeSource[T]) Events() <-chan devices.Event { return f.events } +func (f *fakeSource[T]) Out() <-chan T { return f.out } + +type fakeSink[T any] struct { + name string + in chan T + events chan devices.Event +} + +func (f *fakeSink[T]) Name() string { return f.name } +func (f *fakeSink[T]) Run(ctx context.Context) error { + <-ctx.Done() + return nil +} +func (f *fakeSink[T]) Events() <-chan devices.Event { return f.events } +func (f *fakeSink[T]) In() chan<- T { return f.in } + +func TestRunnerReturnsFirstError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + runner := NewRunner() + wantErr := errors.New("boom") + runner.Add(testRule{name: "err", run: func(context.Context) error { return wantErr }}) + runner.Add(testRule{name: "ok", run: func(context.Context) error { return nil }}) + + err := runner.Run(ctx) + require.ErrorIs(t, err, wantErr) +} + +func TestRunnerContextCancel(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + waitCtx, waitCancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(waitCancel) + + started := make(chan struct{}) + runner := NewRunner() + runner.Add(testRule{name: "block", run: func(ctx context.Context) error { + close(started) + <-ctx.Done() + return nil + }}) + + done := make(chan error, 1) + go func() { done <- runner.Run(ctx) }() + + select { + case <-started: + case <-waitCtx.Done(): + require.Fail(t, "rule did not start before timeout") + } + + cancel() + + select { + case err := <-done: + assert.NoError(t, err) + case <-waitCtx.Done(): + require.Fail(t, "runner did not exit after cancel") + } +} + +func TestFollowForwards(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + src := &fakeSource[bool]{name: "button", out: make(chan bool, 1), events: make(chan devices.Event)} + dst := &fakeSink[bool]{name: "relay", in: make(chan bool, 1), events: make(chan devices.Event)} + + rule := NewFollow("follow", src, dst) + + done := make(chan error, 1) + go func() { done <- rule.Run(ctx) }() + + src.out <- true + + select { + case got := <-dst.in: + require.True(t, got) + case <-ctx.Done(): + require.Fail(t, "did not receive forwarded value") + } + + close(src.out) + + select { + case err := <-done: + assert.NoError(t, err) + case <-ctx.Done(): + require.Fail(t, "follow did not exit after source closed") + } +} diff --git a/rules/toggle_on_rising.go b/rules/toggle_on_rising.go index dc479d0..402c1ea 100644 --- a/rules/toggle_on_rising.go +++ b/rules/toggle_on_rising.go @@ -23,6 +23,7 @@ type ToggleOnRisingEdge struct { MinInterval time.Duration } +// NewToggleOnRisingEdge returns a rule that toggles a relay on rising edge presses. func NewToggleOnRisingEdge(name string, reg *messenger.Registry, btn devices.Source[bool], relay devices.Duplex[bool]) *ToggleOnRisingEdge { return &ToggleOnRisingEdge{ name: name, @@ -34,8 +35,10 @@ func NewToggleOnRisingEdge(name string, reg *messenger.Registry, btn devices.Sou } } +// Name returns the rule name. func (t *ToggleOnRisingEdge) Name() string { return t.name } +// Run listens for button presses and toggles the relay. func (t *ToggleOnRisingEdge) Run(ctx context.Context) error { var last time.Time