Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions cmd/openwatch/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,22 @@ func cmdServe(cfg *config.Config, _ []string, stdout, stderr *os.File) int {
}

credSvc := credential.NewService(pool)

// Per-host SSH connection memory shared by every path that talks to
// a managed host: the liveness privilege probe, OS discovery, OS
// intelligence collection, and the compliance scan all lead the dial
// with this host's last known-good auth method and record what
// authenticated. Spec system-connection-profile.
connStore := connprofile.NewStore(pool)

// Spec system-ssh-connectivity v1.2.0 C-09 / AC-18: thread the
// SecurityConfig reader so the privilege probe can retry sudo -n
// failures via sudo -S -k with the credential password — same
// gating as the collector + discovery firewall probe.
privProbe := sshprivilege.Probe(credSvc, sshprivilege.WithPolicyLoader(cfgStore))
// gating as the collector + discovery firewall probe. WithProfiles
// adds the per-host auth-method learning (system-connection-profile).
privProbe := sshprivilege.Probe(credSvc,
sshprivilege.WithPolicyLoader(cfgStore),
sshprivilege.WithProfiles(connStore))

liveSvc := liveness.NewService(pool, audit.Emit, bus).
WithConfigLoader(cfgStore.LoadConnectivity).
Expand All @@ -369,6 +380,10 @@ func cmdServe(cfg *config.Config, _ []string, stdout, stderr *os.File) int {
discoSvc := discovery.NewService(pool, audit.Emit, bus).
WithHostLookup(discovery.PoolHostLookup{Pool: pool}).
WithCredentialService(credSvc).
// Profile-aware transport: lead the dial with the host's learned
// SSH auth method + record what authenticated (system-connection-profile).
WithSSHTransport(discovery.NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()).
WithProfiles(connStore)).
// Spec system-ssh-connectivity v1.2.0 C-09 / AC-20: thread the
// SecurityConfig reader so the firewall probe can retry a
// sudo -n failure via sudo -S -k with the credential password
Expand All @@ -389,7 +404,8 @@ func cmdServe(cfg *config.Config, _ []string, stdout, stderr *os.File) int {
WithCredentialService(credSvc).
WithHostLookup(collector.PoolHostLookup{Pool: pool}).
WithSSHTransport(collectorSSHAdapter{
inner: discovery.NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()),
inner: discovery.NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()).
WithProfiles(connStore),
}).
// Spec system-ssh-connectivity v1.1.0 C-09: load the
// allow_credential_sudo_password knob at cycle start. When the
Expand Down Expand Up @@ -508,7 +524,7 @@ func cmdServe(cfg *config.Config, _ []string, stdout, stderr *os.File) int {
vars, err := cfgStore.LoadScanVars(ctx)
return vars, err
},
Profiles: connprofile.NewStore(pool),
Profiles: connStore,
Policy: func(ctx context.Context) (bool, error) {
cfg, err := cfgStore.LoadSecurity(ctx)
return cfg.AllowCredentialSudoPassword, err
Expand Down
28 changes: 28 additions & 0 deletions internal/connprofile/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package connprofile

import (
"context"

"github.com/google/uuid"
)

type hostIDKey struct{}

// WithHostID stashes the host id on ctx so a lower layer that only receives
// host:port + credential (e.g. an SSH transport behind an interface that
// can't easily grow a hostID parameter) can still look up and record this
// host's connection profile.
//
// This is best-effort learning enrichment, not a required dial parameter:
// a connection with no host id on the context simply skips the profile
// lookup/record and dials in the default order.
func WithHostID(ctx context.Context, id uuid.UUID) context.Context {
return context.WithValue(ctx, hostIDKey{}, id)
}

// HostIDFrom returns the host id stashed by WithHostID. The bool is false
// when no (or a nil) id is present, so callers can guard the learning path.
func HostIDFrom(ctx context.Context) (uuid.UUID, bool) {
id, ok := ctx.Value(hostIDKey{}).(uuid.UUID)
return id, ok && id != uuid.Nil
}
4 changes: 4 additions & 0 deletions internal/intelligence/collector/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/Hanalyx/openwatch/internal/audit"
"github.com/Hanalyx/openwatch/internal/connprofile"
"github.com/Hanalyx/openwatch/internal/correlation"
"github.com/Hanalyx/openwatch/internal/credential"
"github.com/Hanalyx/openwatch/internal/eventbus"
Expand Down Expand Up @@ -232,6 +233,9 @@ func (s *Service) runCycleWithTransport(ctx context.Context, hf hostFacts) (Snap
if s.transport == nil {
return Snapshot{}, 0, errors.New("collector: ssh transport not wired")
}
// Carry the host id so a profile-aware transport can lead with this
// host's known-good SSH auth method and record what authenticated.
ctx = connprofile.WithHostID(ctx, hf.HostID)
sess, err := s.transport.Dial(ctx, hf.Addr, hf.Port, hf.Cred)
if err != nil {
return Snapshot{}, 0, fmt.Errorf("collector: dial: %w", err)
Expand Down
5 changes: 5 additions & 0 deletions internal/intelligence/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/Hanalyx/openwatch/internal/audit"
"github.com/Hanalyx/openwatch/internal/connprofile"
"github.com/Hanalyx/openwatch/internal/credential"
"github.com/Hanalyx/openwatch/internal/eventbus"
"github.com/Hanalyx/openwatch/internal/intelligence/probe"
Expand Down Expand Up @@ -266,6 +267,10 @@ func (s *Service) discoverWithTransport(ctx context.Context, hf hostFacts) (Syst
return SystemFacts{}, errors.New("discovery: ssh transport not wired")
}

// Carry the host id so a profile-aware transport can lead with this
// host's known-good SSH auth method and record what authenticated.
ctx = connprofile.WithHostID(ctx, hf.HostID)

sess, err := s.transport.Dial(ctx, hf.Addr, hf.Port, hf.Cred)
if err != nil {
return SystemFacts{}, fmt.Errorf("discovery: dial: %w", err)
Expand Down
70 changes: 62 additions & 8 deletions internal/intelligence/discovery/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,23 @@ import (
"errors"
"time"

"github.com/google/uuid"
"golang.org/x/crypto/ssh"

"github.com/Hanalyx/openwatch/internal/connprofile"
"github.com/Hanalyx/openwatch/internal/credential"
owssh "github.com/Hanalyx/openwatch/internal/ssh"
"golang.org/x/crypto/ssh"
)

// ConnProfileStore is the subset of connprofile the transport uses to lead
// with the host's known-good SSH auth method and record what authenticated.
// nil disables learning (dial in the default key-first order). The host id
// is read from the context via connprofile.HostIDFrom.
type ConnProfileStore interface {
Get(ctx context.Context, hostID uuid.UUID) (connprofile.Profile, error)
RecordSSHAuth(ctx context.Context, hostID uuid.UUID, m connprofile.SSHAuthMethod) error
}

// SSHTransport is the seam between the discovery service and the actual
// SSH path. Production uses sshTransport (wraps owssh.Dial + ssh.Session);
// tests use stubSSHTransport.
Expand Down Expand Up @@ -42,33 +54,75 @@ const DefaultProbeTimeout = 10 * time.Second
// and tests that want a real (not stubbed) transport can construct one
// without internal package boundaries.
type SSHTransportProd struct {
mode owssh.Mode
store owssh.KnownHostsStore
mode owssh.Mode
store owssh.KnownHostsStore
profiles ConnProfileStore
// dial is the seam over internal/ssh.Dial — overridden in tests to
// exercise the auth-learning path (PreferAuth in / ObservedAuth out)
// without standing up a real SSH server. Defaults to dialReal.
dial func(ctx context.Context, host string, port int, cred *credential.Credential, opts owssh.DialOptions) (SSHSession, error)
}

// NewSSHTransport returns a production SSHTransport with the given
// host-key policy. NewService() calls this with TOFU + an in-memory
// known-hosts store by default; cmd/openwatch can override with a
// strict + persistent store later.
func NewSSHTransport(mode owssh.Mode, store owssh.KnownHostsStore) *SSHTransportProd {
return &SSHTransportProd{mode: mode, store: store}
return &SSHTransportProd{mode: mode, store: store, dial: dialReal}
}

// dialReal is the production dial: open the real SSH client and wrap it.
func dialReal(ctx context.Context, host string, port int, cred *credential.Credential, opts owssh.DialOptions) (SSHSession, error) {
client, err := owssh.Dial(ctx, host, port, cred, opts)
if err != nil {
return nil, err
}
return &SSHClientSession{client: client}, nil
}

// WithProfiles enables per-host SSH auth-method learning: the transport
// leads the dial with the host's recorded method and records which method
// authenticated. The host id comes from connprofile.WithHostID on the ctx.
// nil (the default) keeps the historical key-first, no-learning behaviour.
func (t *SSHTransportProd) WithProfiles(p ConnProfileStore) *SSHTransportProd {
t.profiles = p
return t
}

// Dial opens one SSH client connection and returns it as an SSHSession
// that multiplexes ssh.Session per Run call.
// that multiplexes ssh.Session per Run call. When a profile store is wired
// and the ctx carries a host id, the dial leads with the host's recorded
// auth method and records the one that authenticated (a hint, not a lock:
// both methods are still offered, and a stale hint self-heals).
func (t *SSHTransportProd) Dial(ctx context.Context, host string, port int, cred *credential.Credential) (SSHSession, error) {
if cred == nil {
return nil, errors.New("discovery: dial requires a resolved credential")
}
client, err := owssh.Dial(ctx, host, port, cred, owssh.DialOptions{

hostID, learn := connprofile.HostIDFrom(ctx)
learn = learn && t.profiles != nil

opts := owssh.DialOptions{
Mode: t.mode,
Store: t.store,
Timeout: owssh.DefaultDialTimeout,
})
}
var observed string
if learn {
if p, err := t.profiles.Get(ctx, hostID); err == nil {
opts.PreferAuth = string(p.SSHAuthMethod) // "key"/"password" match the ssh.Prefer* tokens
}
opts.ObservedAuth = &observed
}

sess, err := t.dial(ctx, host, port, cred, opts)
if err != nil {
return nil, err
}
return &SSHClientSession{client: client}, nil
if learn && observed != "" {
_ = t.profiles.RecordSSHAuth(ctx, hostID, connprofile.SSHAuthMethod(observed))
}
return sess, nil
}

// SSHClientSession is the per-host live SSH client. Each Run opens a
Expand Down
146 changes: 146 additions & 0 deletions internal/intelligence/discovery/transport_learning_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// @spec system-connection-profile
//
// AC traceability (this file):
//
// AC-08 TestSSHTransportProd_AuthLearning
//
// The dial seam is stubbed so the auth-learning wiring (PreferAuth in,
// ObservedAuth out, RecordSSHAuth on success) is exercised without a real
// SSH server. The shared discovery prod transport is the one path both OS
// discovery and OS intelligence (collector) dial through.

package discovery

import (
"context"
"errors"
"testing"

"github.com/google/uuid"

"github.com/Hanalyx/openwatch/internal/connprofile"
"github.com/Hanalyx/openwatch/internal/credential"
owssh "github.com/Hanalyx/openwatch/internal/ssh"
)

// recordingProfiles is an in-memory connprofile store for the test.
type recordingProfiles struct {
prefer connprofile.SSHAuthMethod
getErr error
gotID uuid.UUID
recorded connprofile.SSHAuthMethod
}

func (p *recordingProfiles) Get(_ context.Context, hostID uuid.UUID) (connprofile.Profile, error) {
p.gotID = hostID
if p.getErr != nil {
return connprofile.Profile{}, p.getErr
}
return connprofile.Profile{SSHAuthMethod: p.prefer}, nil
}

func (p *recordingProfiles) RecordSSHAuth(_ context.Context, _ uuid.UUID, m connprofile.SSHAuthMethod) error {
p.recorded = m
return nil
}

func TestSSHTransportProd_AuthLearning(t *testing.T) {
cred := &credential.Credential{Username: "u", AuthMethod: credential.AuthBoth, Password: "p"}

// stubDial captures the PreferAuth handed down and simulates the
// observed method crypto/ssh would report on a successful handshake.
newStubDial := func(gotPrefer *string, simulateObserved string) func(context.Context, string, int, *credential.Credential, owssh.DialOptions) (SSHSession, error) {
return func(_ context.Context, _ string, _ int, _ *credential.Credential, opts owssh.DialOptions) (SSHSession, error) {
*gotPrefer = opts.PreferAuth
if opts.ObservedAuth != nil {
*opts.ObservedAuth = simulateObserved
}
return learnStubSession{}, nil
}
}

t.Run("system-connection-profile/AC-08", func(t *testing.T) {
hostID := uuid.Must(uuid.NewV7())
profiles := &recordingProfiles{prefer: connprofile.AuthPassword}
var gotPrefer string

tr := NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()).WithProfiles(profiles)
tr.dial = newStubDial(&gotPrefer, "password")

ctx := connprofile.WithHostID(context.Background(), hostID)
if _, err := tr.Dial(ctx, "192.0.2.1", 22, cred); err != nil {
t.Fatalf("dial: %v", err)
}
if gotPrefer != "password" {
t.Errorf("lead-with: want PreferAuth=password, got %q", gotPrefer)
}
if profiles.gotID != hostID {
t.Errorf("lookup: want Get(%s), got Get(%s)", hostID, profiles.gotID)
}
if profiles.recorded != connprofile.AuthPassword {
t.Errorf("record: want recorded=password, got %q", profiles.recorded)
}
})

t.Run("no host id on ctx: no learning", func(t *testing.T) {
profiles := &recordingProfiles{prefer: connprofile.AuthPassword}
var gotPrefer string

tr := NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()).WithProfiles(profiles)
tr.dial = newStubDial(&gotPrefer, "key")

if _, err := tr.Dial(context.Background(), "192.0.2.1", 22, cred); err != nil {
t.Fatalf("dial: %v", err)
}
if gotPrefer != "" {
t.Errorf("no host id: want empty PreferAuth, got %q", gotPrefer)
}
if profiles.recorded != "" {
t.Errorf("no host id: want no record, got %q", profiles.recorded)
}
})

t.Run("no store wired: no learning", func(t *testing.T) {
var gotPrefer string
tr := NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore())
tr.dial = newStubDial(&gotPrefer, "key")

ctx := connprofile.WithHostID(context.Background(), uuid.Must(uuid.NewV7()))
if _, err := tr.Dial(ctx, "192.0.2.1", 22, cred); err != nil {
t.Fatalf("dial: %v", err)
}
if gotPrefer != "" {
t.Errorf("no store: want empty PreferAuth, got %q", gotPrefer)
}
})

t.Run("profile lookup error is non-fatal", func(t *testing.T) {
profiles := &recordingProfiles{getErr: errors.New("db down")}
var gotPrefer string

tr := NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()).WithProfiles(profiles)
tr.dial = newStubDial(&gotPrefer, "password")

ctx := connprofile.WithHostID(context.Background(), uuid.Must(uuid.NewV7()))
if _, err := tr.Dial(ctx, "192.0.2.1", 22, cred); err != nil {
t.Fatalf("dial: want success despite lookup error, got %v", err)
}
if gotPrefer != "" {
t.Errorf("lookup error: want default order (empty PreferAuth), got %q", gotPrefer)
}
// observed still recorded — learning continues even when the
// lead-with hint was unavailable.
if profiles.recorded != connprofile.AuthPassword {
t.Errorf("record: want recorded=password, got %q", profiles.recorded)
}
})
}

// learnStubSession is a no-op SSHSession for the dial-seam tests.
type learnStubSession struct{}

func (learnStubSession) Run(context.Context, string) ([]byte, int, error) { return nil, 0, nil }
func (learnStubSession) RunWithStdin(context.Context, string, []byte) ([]byte, int, error) {
return nil, 0, nil
}
func (learnStubSession) Close() error { return nil }
Loading
Loading