From f199e1c7820126ea93c6fdd086b81a9fccd1ac0a Mon Sep 17 00:00:00 2001 From: Remylus Losius Date: Tue, 16 Jun 2026 11:38:24 -0400 Subject: [PATCH] feat(ssh): wire per-host auth-method learning into discovery/intelligence/liveness The connection-profile memo (PR #566) only led the dial with the host's known-good SSH auth method on the compliance-scan path. The other three paths that talk to a managed host -- OS discovery, OS intelligence (collector), and the liveness privilege probe -- still dialed key-first every cycle, re-offering an unauthorized public key to password-only hosts (a failed publickey attempt that counts against MaxAuthTries and can trip fail2ban) on a loop. Extend the shared connprofile store into those three paths: - connprofile.WithHostID / HostIDFrom: context helpers so a transport that only receives host:port+cred can still look up + record the host's profile, without churning the SSHTransport.Dial signature (and its test stubs across discovery + collector). - discovery.SSHTransportProd gains WithProfiles + a dial seam: when a store is wired and the ctx carries a host id, it sets PreferAuth from the recorded method and records ObservedAuth after a successful dial. This one transport is what BOTH discovery and the collector dial through (via collectorSSHAdapter), so both learn at once. - discovery.go / collector.go wrap the ctx with the host id at the dial site (both hostFacts already carry HostID). - sshprivilege.Probe gains WithProfiles: it leads the dial with the recorded method (reordering buildAuthMethods for AuthBoth) and records which method authenticated via a local single-factor observer. - cmd/openwatch wires one shared connprofile.NewStore(pool) across all four paths (the scan now reuses it too). Learning stays best-effort: a missing host id, absent profile row, or store error dials in the default order and never fails the connection (hint, not a lock -- a stale hint self-heals on the next dial). Scope: the SSH auth-method dimension. 
sudo-mode (NOPASSWD vs password) learning for these three paths stays a follow-up -- they already probe sudo mode correctly each cycle; only the scan learns both today. Spec system-connection-profile -> v1.1.0: C-06, AC-08 (discovery/ collector transport), AC-09 (liveness probe). --- cmd/openwatch/main.go | 24 ++- internal/connprofile/context.go | 28 ++++ internal/intelligence/collector/collector.go | 4 + internal/intelligence/discovery/discovery.go | 5 + internal/intelligence/discovery/transport.go | 70 ++++++++- .../discovery/transport_learning_test.go | 146 ++++++++++++++++++ internal/sshprivilege/privilege.go | 139 ++++++++++++++--- internal/sshprivilege/privilege_test.go | 97 ++++++++++-- specs/system/connection-profile.spec.yaml | 25 ++- 9 files changed, 491 insertions(+), 47 deletions(-) create mode 100644 internal/connprofile/context.go create mode 100644 internal/intelligence/discovery/transport_learning_test.go diff --git a/cmd/openwatch/main.go b/cmd/openwatch/main.go index 5ab69ed9..15e676a4 100644 --- a/cmd/openwatch/main.go +++ b/cmd/openwatch/main.go @@ -348,11 +348,22 @@ func cmdServe(cfg *config.Config, _ []string, stdout, stderr *os.File) int { } credSvc := credential.NewService(pool) + + // Per-host SSH connection memory shared by every path that talks to + // a managed host: the liveness privilege probe, OS discovery, OS + // intelligence collection, and the compliance scan all lead the dial + // with this host's last known-good auth method and record what + // authenticated. Spec system-connection-profile. + connStore := connprofile.NewStore(pool) + // Spec system-ssh-connectivity v1.2.0 C-09 / AC-18: thread the // SecurityConfig reader so the privilege probe can retry sudo -n // failures via sudo -S -k with the credential password — same - // gating as the collector + discovery firewall probe. - privProbe := sshprivilege.Probe(credSvc, sshprivilege.WithPolicyLoader(cfgStore)) + // gating as the collector + discovery firewall probe. 
WithProfiles + // adds the per-host auth-method learning (system-connection-profile). + privProbe := sshprivilege.Probe(credSvc, + sshprivilege.WithPolicyLoader(cfgStore), + sshprivilege.WithProfiles(connStore)) liveSvc := liveness.NewService(pool, audit.Emit, bus). WithConfigLoader(cfgStore.LoadConnectivity). @@ -369,6 +380,10 @@ func cmdServe(cfg *config.Config, _ []string, stdout, stderr *os.File) int { discoSvc := discovery.NewService(pool, audit.Emit, bus). WithHostLookup(discovery.PoolHostLookup{Pool: pool}). WithCredentialService(credSvc). + // Profile-aware transport: lead the dial with the host's learned + // SSH auth method + record what authenticated (system-connection-profile). + WithSSHTransport(discovery.NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()). + WithProfiles(connStore)). // Spec system-ssh-connectivity v1.2.0 C-09 / AC-20: thread the // SecurityConfig reader so the firewall probe can retry a // sudo -n failure via sudo -S -k with the credential password @@ -389,7 +404,8 @@ func cmdServe(cfg *config.Config, _ []string, stdout, stderr *os.File) int { WithCredentialService(credSvc). WithHostLookup(collector.PoolHostLookup{Pool: pool}). WithSSHTransport(collectorSSHAdapter{ - inner: discovery.NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()), + inner: discovery.NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()). + WithProfiles(connStore), }). // Spec system-ssh-connectivity v1.1.0 C-09: load the // allow_credential_sudo_password knob at cycle start. 
When the @@ -508,7 +524,7 @@ func cmdServe(cfg *config.Config, _ []string, stdout, stderr *os.File) int { vars, err := cfgStore.LoadScanVars(ctx) return vars, err }, - Profiles: connprofile.NewStore(pool), + Profiles: connStore, Policy: func(ctx context.Context) (bool, error) { cfg, err := cfgStore.LoadSecurity(ctx) return cfg.AllowCredentialSudoPassword, err diff --git a/internal/connprofile/context.go b/internal/connprofile/context.go new file mode 100644 index 00000000..61a7126a --- /dev/null +++ b/internal/connprofile/context.go @@ -0,0 +1,28 @@ +package connprofile + +import ( + "context" + + "github.com/google/uuid" +) + +type hostIDKey struct{} + +// WithHostID stashes the host id on ctx so a lower layer that only receives +// host:port + credential (e.g. an SSH transport behind an interface that +// can't easily grow a hostID parameter) can still look up and record this +// host's connection profile. +// +// This is best-effort learning enrichment, not a required dial parameter: +// a connection with no host id on the context simply skips the profile +// lookup/record and dials in the default order. +func WithHostID(ctx context.Context, id uuid.UUID) context.Context { + return context.WithValue(ctx, hostIDKey{}, id) +} + +// HostIDFrom returns the host id stashed by WithHostID. The bool is false +// when no (or a nil) id is present, so callers can guard the learning path. 
+func HostIDFrom(ctx context.Context) (uuid.UUID, bool) { + id, ok := ctx.Value(hostIDKey{}).(uuid.UUID) + return id, ok && id != uuid.Nil +} diff --git a/internal/intelligence/collector/collector.go b/internal/intelligence/collector/collector.go index 1ff95675..fa71781c 100644 --- a/internal/intelligence/collector/collector.go +++ b/internal/intelligence/collector/collector.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Hanalyx/openwatch/internal/audit" + "github.com/Hanalyx/openwatch/internal/connprofile" "github.com/Hanalyx/openwatch/internal/correlation" "github.com/Hanalyx/openwatch/internal/credential" "github.com/Hanalyx/openwatch/internal/eventbus" @@ -232,6 +233,9 @@ func (s *Service) runCycleWithTransport(ctx context.Context, hf hostFacts) (Snap if s.transport == nil { return Snapshot{}, 0, errors.New("collector: ssh transport not wired") } + // Carry the host id so a profile-aware transport can lead with this + // host's known-good SSH auth method and record what authenticated. 
+ ctx = connprofile.WithHostID(ctx, hf.HostID) sess, err := s.transport.Dial(ctx, hf.Addr, hf.Port, hf.Cred) if err != nil { return Snapshot{}, 0, fmt.Errorf("collector: dial: %w", err) diff --git a/internal/intelligence/discovery/discovery.go b/internal/intelligence/discovery/discovery.go index d4014744..19e8202a 100644 --- a/internal/intelligence/discovery/discovery.go +++ b/internal/intelligence/discovery/discovery.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Hanalyx/openwatch/internal/audit" + "github.com/Hanalyx/openwatch/internal/connprofile" "github.com/Hanalyx/openwatch/internal/credential" "github.com/Hanalyx/openwatch/internal/eventbus" "github.com/Hanalyx/openwatch/internal/intelligence/probe" @@ -266,6 +267,10 @@ func (s *Service) discoverWithTransport(ctx context.Context, hf hostFacts) (Syst return SystemFacts{}, errors.New("discovery: ssh transport not wired") } + // Carry the host id so a profile-aware transport can lead with this + // host's known-good SSH auth method and record what authenticated. + ctx = connprofile.WithHostID(ctx, hf.HostID) + sess, err := s.transport.Dial(ctx, hf.Addr, hf.Port, hf.Cred) if err != nil { return SystemFacts{}, fmt.Errorf("discovery: dial: %w", err) diff --git a/internal/intelligence/discovery/transport.go b/internal/intelligence/discovery/transport.go index 57f2190b..e2f358a4 100644 --- a/internal/intelligence/discovery/transport.go +++ b/internal/intelligence/discovery/transport.go @@ -5,11 +5,23 @@ import ( "errors" "time" + "github.com/google/uuid" + "golang.org/x/crypto/ssh" + + "github.com/Hanalyx/openwatch/internal/connprofile" "github.com/Hanalyx/openwatch/internal/credential" owssh "github.com/Hanalyx/openwatch/internal/ssh" - "golang.org/x/crypto/ssh" ) +// ConnProfileStore is the subset of connprofile the transport uses to lead +// with the host's known-good SSH auth method and record what authenticated. +// nil disables learning (dial in the default key-first order). 
The host id +// is read from the context via connprofile.HostIDFrom. +type ConnProfileStore interface { + Get(ctx context.Context, hostID uuid.UUID) (connprofile.Profile, error) + RecordSSHAuth(ctx context.Context, hostID uuid.UUID, m connprofile.SSHAuthMethod) error +} + // SSHTransport is the seam between the discovery service and the actual // SSH path. Production uses sshTransport (wraps owssh.Dial + ssh.Session); // tests use stubSSHTransport. @@ -42,8 +54,13 @@ const DefaultProbeTimeout = 10 * time.Second // and tests that want a real (not stubbed) transport can construct one // without internal package boundaries. type SSHTransportProd struct { - mode owssh.Mode - store owssh.KnownHostsStore + mode owssh.Mode + store owssh.KnownHostsStore + profiles ConnProfileStore + // dial is the seam over internal/ssh.Dial — overridden in tests to + // exercise the auth-learning path (PreferAuth in / ObservedAuth out) + // without standing up a real SSH server. Defaults to dialReal. + dial func(ctx context.Context, host string, port int, cred *credential.Credential, opts owssh.DialOptions) (SSHSession, error) } // NewSSHTransport returns a production SSHTransport with the given @@ -51,24 +68,61 @@ type SSHTransportProd struct { // known-hosts store by default; cmd/openwatch can override with a // strict + persistent store later. func NewSSHTransport(mode owssh.Mode, store owssh.KnownHostsStore) *SSHTransportProd { - return &SSHTransportProd{mode: mode, store: store} + return &SSHTransportProd{mode: mode, store: store, dial: dialReal} +} + +// dialReal is the production dial: open the real SSH client and wrap it. 
+func dialReal(ctx context.Context, host string, port int, cred *credential.Credential, opts owssh.DialOptions) (SSHSession, error) { + client, err := owssh.Dial(ctx, host, port, cred, opts) + if err != nil { + return nil, err + } + return &SSHClientSession{client: client}, nil +} + +// WithProfiles enables per-host SSH auth-method learning: the transport +// leads the dial with the host's recorded method and records which method +// authenticated. The host id comes from connprofile.WithHostID on the ctx. +// nil (the default) keeps the historical key-first, no-learning behaviour. +func (t *SSHTransportProd) WithProfiles(p ConnProfileStore) *SSHTransportProd { + t.profiles = p + return t } // Dial opens one SSH client connection and returns it as an SSHSession -// that multiplexes ssh.Session per Run call. +// that multiplexes ssh.Session per Run call. When a profile store is wired +// and the ctx carries a host id, the dial leads with the host's recorded +// auth method and records the one that authenticated (a hint, not a lock: +// both methods are still offered, and a stale hint self-heals). 
func (t *SSHTransportProd) Dial(ctx context.Context, host string, port int, cred *credential.Credential) (SSHSession, error) { if cred == nil { return nil, errors.New("discovery: dial requires a resolved credential") } - client, err := owssh.Dial(ctx, host, port, cred, owssh.DialOptions{ + + hostID, learn := connprofile.HostIDFrom(ctx) + learn = learn && t.profiles != nil + + opts := owssh.DialOptions{ Mode: t.mode, Store: t.store, Timeout: owssh.DefaultDialTimeout, - }) + } + var observed string + if learn { + if p, err := t.profiles.Get(ctx, hostID); err == nil { + opts.PreferAuth = string(p.SSHAuthMethod) // "key"/"password" match the ssh.Prefer* tokens + } + opts.ObservedAuth = &observed + } + + sess, err := t.dial(ctx, host, port, cred, opts) if err != nil { return nil, err } - return &SSHClientSession{client: client}, nil + if learn && observed != "" { + _ = t.profiles.RecordSSHAuth(ctx, hostID, connprofile.SSHAuthMethod(observed)) + } + return sess, nil } // SSHClientSession is the per-host live SSH client. Each Run opens a diff --git a/internal/intelligence/discovery/transport_learning_test.go b/internal/intelligence/discovery/transport_learning_test.go new file mode 100644 index 00000000..cc169a15 --- /dev/null +++ b/internal/intelligence/discovery/transport_learning_test.go @@ -0,0 +1,146 @@ +// @spec system-connection-profile +// +// AC traceability (this file): +// +// AC-08 TestSSHTransportProd_AuthLearning +// +// The dial seam is stubbed so the auth-learning wiring (PreferAuth in, +// ObservedAuth out, RecordSSHAuth on success) is exercised without a real +// SSH server. The shared discovery prod transport is the one path both OS +// discovery and OS intelligence (collector) dial through. 
+ +package discovery + +import ( + "context" + "errors" + "testing" + + "github.com/google/uuid" + + "github.com/Hanalyx/openwatch/internal/connprofile" + "github.com/Hanalyx/openwatch/internal/credential" + owssh "github.com/Hanalyx/openwatch/internal/ssh" +) + +// recordingProfiles is an in-memory connprofile store for the test. +type recordingProfiles struct { + prefer connprofile.SSHAuthMethod + getErr error + gotID uuid.UUID + recorded connprofile.SSHAuthMethod +} + +func (p *recordingProfiles) Get(_ context.Context, hostID uuid.UUID) (connprofile.Profile, error) { + p.gotID = hostID + if p.getErr != nil { + return connprofile.Profile{}, p.getErr + } + return connprofile.Profile{SSHAuthMethod: p.prefer}, nil +} + +func (p *recordingProfiles) RecordSSHAuth(_ context.Context, _ uuid.UUID, m connprofile.SSHAuthMethod) error { + p.recorded = m + return nil +} + +func TestSSHTransportProd_AuthLearning(t *testing.T) { + cred := &credential.Credential{Username: "u", AuthMethod: credential.AuthBoth, Password: "p"} + + // stubDial captures the PreferAuth handed down and simulates the + // observed method crypto/ssh would report on a successful handshake. 
+ newStubDial := func(gotPrefer *string, simulateObserved string) func(context.Context, string, int, *credential.Credential, owssh.DialOptions) (SSHSession, error) { + return func(_ context.Context, _ string, _ int, _ *credential.Credential, opts owssh.DialOptions) (SSHSession, error) { + *gotPrefer = opts.PreferAuth + if opts.ObservedAuth != nil { + *opts.ObservedAuth = simulateObserved + } + return learnStubSession{}, nil + } + } + + t.Run("system-connection-profile/AC-08", func(t *testing.T) { + hostID := uuid.Must(uuid.NewV7()) + profiles := &recordingProfiles{prefer: connprofile.AuthPassword} + var gotPrefer string + + tr := NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()).WithProfiles(profiles) + tr.dial = newStubDial(&gotPrefer, "password") + + ctx := connprofile.WithHostID(context.Background(), hostID) + if _, err := tr.Dial(ctx, "192.0.2.1", 22, cred); err != nil { + t.Fatalf("dial: %v", err) + } + if gotPrefer != "password" { + t.Errorf("lead-with: want PreferAuth=password, got %q", gotPrefer) + } + if profiles.gotID != hostID { + t.Errorf("lookup: want Get(%s), got Get(%s)", hostID, profiles.gotID) + } + if profiles.recorded != connprofile.AuthPassword { + t.Errorf("record: want recorded=password, got %q", profiles.recorded) + } + }) + + t.Run("no host id on ctx: no learning", func(t *testing.T) { + profiles := &recordingProfiles{prefer: connprofile.AuthPassword} + var gotPrefer string + + tr := NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()).WithProfiles(profiles) + tr.dial = newStubDial(&gotPrefer, "key") + + if _, err := tr.Dial(context.Background(), "192.0.2.1", 22, cred); err != nil { + t.Fatalf("dial: %v", err) + } + if gotPrefer != "" { + t.Errorf("no host id: want empty PreferAuth, got %q", gotPrefer) + } + if profiles.recorded != "" { + t.Errorf("no host id: want no record, got %q", profiles.recorded) + } + }) + + t.Run("no store wired: no learning", func(t *testing.T) { + var gotPrefer string + tr := 
NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()) + tr.dial = newStubDial(&gotPrefer, "key") + + ctx := connprofile.WithHostID(context.Background(), uuid.Must(uuid.NewV7())) + if _, err := tr.Dial(ctx, "192.0.2.1", 22, cred); err != nil { + t.Fatalf("dial: %v", err) + } + if gotPrefer != "" { + t.Errorf("no store: want empty PreferAuth, got %q", gotPrefer) + } + }) + + t.Run("profile lookup error is non-fatal", func(t *testing.T) { + profiles := &recordingProfiles{getErr: errors.New("db down")} + var gotPrefer string + + tr := NewSSHTransport(owssh.ModeTOFU, owssh.NewMemoryStore()).WithProfiles(profiles) + tr.dial = newStubDial(&gotPrefer, "password") + + ctx := connprofile.WithHostID(context.Background(), uuid.Must(uuid.NewV7())) + if _, err := tr.Dial(ctx, "192.0.2.1", 22, cred); err != nil { + t.Fatalf("dial: want success despite lookup error, got %v", err) + } + if gotPrefer != "" { + t.Errorf("lookup error: want default order (empty PreferAuth), got %q", gotPrefer) + } + // observed still recorded — learning continues even when the + // lead-with hint was unavailable. + if profiles.recorded != connprofile.AuthPassword { + t.Errorf("record: want recorded=password, got %q", profiles.recorded) + } + }) +} + +// learnStubSession is a no-op SSHSession for the dial-seam tests. 
+type learnStubSession struct{} + +func (learnStubSession) Run(context.Context, string) ([]byte, int, error) { return nil, 0, nil } +func (learnStubSession) RunWithStdin(context.Context, string, []byte) ([]byte, int, error) { + return nil, 0, nil +} +func (learnStubSession) Close() error { return nil } diff --git a/internal/sshprivilege/privilege.go b/internal/sshprivilege/privilege.go index ed828637..191af50a 100644 --- a/internal/sshprivilege/privilege.go +++ b/internal/sshprivilege/privilege.go @@ -29,10 +29,12 @@ import ( "fmt" "io" "strings" + "sync" "time" "golang.org/x/crypto/ssh" + "github.com/Hanalyx/openwatch/internal/connprofile" "github.com/Hanalyx/openwatch/internal/credential" "github.com/Hanalyx/openwatch/internal/liveness" "github.com/Hanalyx/openwatch/internal/systemconfig" @@ -72,15 +74,29 @@ type SessionExecutor interface { // Dialer opens an SSH session against a host. Production uses // realDialer (crypto/ssh.Dial); tests inject a stub. +// +// prefer is the host's learned SSH auth method (connprofile.AuthUnknown +// when none): the dialer leads with it but still offers the other method. +// The returned method is the one that authenticated, for the caller to +// record. Both are best-effort learning, never a hard requirement. type Dialer interface { - Dial(ctx context.Context, cred *credential.Credential, addr string, timeout time.Duration) (SessionExecutor, error) + Dial(ctx context.Context, cred *credential.Credential, addr string, timeout time.Duration, prefer connprofile.SSHAuthMethod) (SessionExecutor, connprofile.SSHAuthMethod, error) +} + +// ConnProfileStore is the subset of connprofile the probe uses to lead the +// dial with the host's known-good SSH auth method and record what +// authenticated. nil (the default) disables learning. 
+type ConnProfileStore interface { + Get(ctx context.Context, hostID uuid.UUID) (connprofile.Profile, error) + RecordSSHAuth(ctx context.Context, hostID uuid.UUID, m connprofile.SSHAuthMethod) error } // probeConfig accumulates the optional dependencies a Probe needs. // Constructed by Probe(...) and the With* options. type probeConfig struct { - dialer Dialer - policy PolicyLoader + dialer Dialer + policy PolicyLoader + profiles ConnProfileStore } // ProbeOption configures the probe at construction time. Use @@ -102,6 +118,14 @@ func WithPolicyLoader(p PolicyLoader) ProbeOption { return func(c *probeConfig) { c.policy = p } } +// WithProfiles enables per-host SSH auth-method learning: the probe leads +// the dial with the host's recorded method and records which method +// authenticated. nil (the default) keeps the historical key-first, +// no-learning order. See system-connection-profile. +func WithProfiles(p ConnProfileStore) ProbeOption { + return func(c *probeConfig) { c.profiles = p } +} + // Probe builds a liveness.PrivilegeProbeFunc backed by the given // resolver. The returned function: // @@ -138,12 +162,27 @@ func Probe(resolver Resolver, opts ...ProbeOption) liveness.PrivilegeProbeFunc { return true, false, fmt.Errorf("resolve credential: %w", rerr) } - exec, derr := cfg.dialer.Dial(ctx, cred, addr, timeout) + // Learning: lead the dial with the host's recorded auth method + // (if a profile store is wired and a row exists), then record the + // method that actually authenticated. Both are best-effort: a + // lookup miss just dials in the default order. 
+ var prefer connprofile.SSHAuthMethod + if cfg.profiles != nil { + if p, gerr := cfg.profiles.Get(ctx, id); gerr == nil { + prefer = p.SSHAuthMethod + } + } + + exec, observed, derr := cfg.dialer.Dial(ctx, cred, addr, timeout, prefer) if derr != nil { return true, false, fmt.Errorf("ssh dial: %w", derr) } defer func() { _ = exec.Close() }() + if cfg.profiles != nil && observed != "" { + _ = cfg.profiles.RecordSSHAuth(ctx, id, observed) + } + // Layer 1: sudo -n true. The 80% case where NOPASSWD is set. out, code, runErr := exec.Run(ctx, "sudo -n true") if runErr == nil && code == 0 { @@ -193,10 +232,11 @@ func canFallback(ctx context.Context, loader PolicyLoader, cred *credential.Cred type realDialer struct{} -func (realDialer) Dial(_ context.Context, cred *credential.Credential, addr string, timeout time.Duration) (SessionExecutor, error) { - methods, merr := buildAuthMethods(cred) +func (realDialer) Dial(_ context.Context, cred *credential.Credential, addr string, timeout time.Duration, prefer connprofile.SSHAuthMethod) (SessionExecutor, connprofile.SSHAuthMethod, error) { + obs := &authObserver{} + methods, merr := buildAuthMethods(cred, prefer, obs) if merr != nil { - return nil, merr + return nil, "", merr } cfg := &ssh.ClientConfig{ User: cred.Username, @@ -213,9 +253,34 @@ func (realDialer) Dial(_ context.Context, cred *credential.Credential, addr stri client, derr := ssh.Dial("tcp", addr, cfg) if derr != nil { - return nil, derr + return nil, "", derr } - return &realSession{client: client}, nil + // obs.Last() is the method that authenticated (single-factor: the + // client stops at the first accepted method). "" when nothing fired. + return &realSession{client: client}, obs.Last(), nil +} + +// authObserver records which auth-method callback last fired during a +// handshake. For OpenWatch's single-factor model the client stops at the +// first accepted method, so the last-attempted method after a SUCCESSFUL +// handshake is the one that authenticated. 
Mirrors internal/ssh's observer +// (kept local so this package stays decoupled from the scan dial path, +// which pins host keys — see the package doc). +type authObserver struct { + mu sync.Mutex + last connprofile.SSHAuthMethod +} + +func (o *authObserver) note(m connprofile.SSHAuthMethod) { + o.mu.Lock() + o.last = m + o.mu.Unlock() +} + +func (o *authObserver) Last() connprofile.SSHAuthMethod { + o.mu.Lock() + defer o.mu.Unlock() + return o.last } // buildAuthMethods translates the resolved credential into the ssh @@ -226,37 +291,73 @@ func (realDialer) Dial(_ context.Context, cred *credential.Credential, addr stri // for AuthBoth was the dialer bug behind the post-v1.2.0 regression // where the probe never made it past handshake on password-fallback // hosts. -func buildAuthMethods(cred *credential.Credential) ([]ssh.AuthMethod, error) { +// +// prefer reorders the AuthBoth list to lead with the host's learned +// method; both methods stay offered (a hint, not a lock). Each method is +// wrapped so obs records which one authenticated. 
+func buildAuthMethods(cred *credential.Credential, prefer connprofile.SSHAuthMethod, obs *authObserver) ([]ssh.AuthMethod, error) { + mkKey := func() (ssh.AuthMethod, error) { + signer, perr := parseSigner(cred) + if perr != nil { + return nil, perr + } + return ssh.PublicKeysCallback(func() ([]ssh.Signer, error) { + obs.note(connprofile.AuthKey) + return []ssh.Signer{signer}, nil + }), nil + } + mkPassword := func() ssh.AuthMethod { + pw := cred.Password + return ssh.PasswordCallback(func() (string, error) { + obs.note(connprofile.AuthPassword) + return pw, nil + }) + } + switch cred.AuthMethod { case credential.AuthSSHKey: - signer, perr := parseSigner(cred) + keyM, perr := mkKey() if perr != nil { return nil, perr } - return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil + return []ssh.AuthMethod{keyM}, nil case credential.AuthPassword: - return []ssh.AuthMethod{ssh.Password(cred.Password)}, nil + return []ssh.AuthMethod{mkPassword()}, nil case credential.AuthBoth: - var methods []ssh.AuthMethod + var keyM, pwM ssh.AuthMethod if cred.PrivateKey != "" { - signer, perr := parseSigner(cred) + k, perr := mkKey() if perr != nil { return nil, perr } - methods = append(methods, ssh.PublicKeys(signer)) + keyM = k } if cred.Password != "" { - methods = append(methods, ssh.Password(cred.Password)) + pwM = mkPassword() } - if len(methods) == 0 { + if keyM == nil && pwM == nil { return nil, fmt.Errorf("auth method 'both' but credential carries neither key nor password") } - return methods, nil + if prefer == connprofile.AuthPassword { + return compactMethods(pwM, keyM), nil + } + return compactMethods(keyM, pwM), nil default: return nil, fmt.Errorf("unknown auth method %q", cred.AuthMethod) } } +// compactMethods drops nil entries, preserving order. 
+func compactMethods(ms ...ssh.AuthMethod) []ssh.AuthMethod { + out := make([]ssh.AuthMethod, 0, len(ms)) + for _, m := range ms { + if m != nil { + out = append(out, m) + } + } + return out +} + type realSession struct { client *ssh.Client } diff --git a/internal/sshprivilege/privilege_test.go b/internal/sshprivilege/privilege_test.go index ec64f50b..487667a9 100644 --- a/internal/sshprivilege/privilege_test.go +++ b/internal/sshprivilege/privilege_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/Hanalyx/openwatch/internal/connprofile" "github.com/Hanalyx/openwatch/internal/credential" "github.com/Hanalyx/openwatch/internal/liveness" "github.com/Hanalyx/openwatch/internal/systemconfig" @@ -122,16 +123,22 @@ func (p stubPolicy) LoadSecurity(_ context.Context) (systemconfig.SecurityConfig } // stubDialer returns a programmed stubExec without touching real SSH. +// gotPrefer records the auth-method hint the probe passed, so tests can +// assert the learning lead-with behaviour; observed is the method the +// stub reports as having authenticated. 
type stubDialer struct { - exec *stubExec - dialErr error + exec *stubExec + dialErr error + observed connprofile.SSHAuthMethod + gotPrefer connprofile.SSHAuthMethod } -func (d *stubDialer) Dial(_ context.Context, _ *credential.Credential, _ string, _ time.Duration) (SessionExecutor, error) { +func (d *stubDialer) Dial(_ context.Context, _ *credential.Credential, _ string, _ time.Duration, prefer connprofile.SSHAuthMethod) (SessionExecutor, connprofile.SSHAuthMethod, error) { + d.gotPrefer = prefer if d.dialErr != nil { - return nil, d.dialErr + return nil, "", d.dialErr } - return d.exec, nil + return d.exec, d.observed, nil } func validCred() *credential.Credential { @@ -279,7 +286,7 @@ func TestBuildAuthMethods(t *testing.T) { AuthMethod: credential.AuthSSHKey, PrivateKey: keyPEM, } - methods, err := buildAuthMethods(cred) + methods, err := buildAuthMethods(cred, "", &authObserver{}) if err != nil { t.Fatalf("err: %v", err) } @@ -294,7 +301,7 @@ func TestBuildAuthMethods(t *testing.T) { AuthMethod: credential.AuthPassword, Password: "p", } - methods, err := buildAuthMethods(cred) + methods, err := buildAuthMethods(cred, "", &authObserver{}) if err != nil { t.Fatalf("err: %v", err) } @@ -310,7 +317,7 @@ func TestBuildAuthMethods(t *testing.T) { PrivateKey: keyPEM, Password: "p", } - methods, err := buildAuthMethods(cred) + methods, err := buildAuthMethods(cred, "", &authObserver{}) if err != nil { t.Fatalf("err: %v", err) } @@ -325,7 +332,7 @@ func TestBuildAuthMethods(t *testing.T) { AuthMethod: credential.AuthBoth, PrivateKey: keyPEM, } - methods, err := buildAuthMethods(cred) + methods, err := buildAuthMethods(cred, "", &authObserver{}) if err != nil { t.Fatalf("err: %v", err) } @@ -340,7 +347,7 @@ func TestBuildAuthMethods(t *testing.T) { AuthMethod: credential.AuthBoth, Password: "p", } - methods, err := buildAuthMethods(cred) + methods, err := buildAuthMethods(cred, "", &authObserver{}) if err != nil { t.Fatalf("err: %v", err) } @@ -354,7 +361,7 @@ func 
TestBuildAuthMethods(t *testing.T) { Username: "u", AuthMethod: credential.AuthBoth, } - _, err := buildAuthMethods(cred) + _, err := buildAuthMethods(cred, "", &authObserver{}) if err == nil { t.Errorf("err: want non-nil for empty AuthBoth credential") } @@ -365,7 +372,7 @@ func TestBuildAuthMethods(t *testing.T) { Username: "u", AuthMethod: credential.AuthMethod("bogus"), } - _, err := buildAuthMethods(cred) + _, err := buildAuthMethods(cred, "", &authObserver{}) if err == nil { t.Errorf("err: want non-nil for unknown auth method") } @@ -398,3 +405,69 @@ func TestPrivilegeProbe_NoFallbackOnSudoNSuccess(t *testing.T) { } }) } + +// stubProfiles is an in-memory connprofile store for the learning tests. +type stubProfiles struct { + mu sync.Mutex + prefer connprofile.SSHAuthMethod + recorded connprofile.SSHAuthMethod + getErr error +} + +func (s *stubProfiles) Get(_ context.Context, _ uuid.UUID) (connprofile.Profile, error) { + if s.getErr != nil { + return connprofile.Profile{}, s.getErr + } + return connprofile.Profile{SSHAuthMethod: s.prefer}, nil +} + +func (s *stubProfiles) RecordSSHAuth(_ context.Context, _ uuid.UUID, m connprofile.SSHAuthMethod) error { + s.mu.Lock() + s.recorded = m + s.mu.Unlock() + return nil +} + +// @spec system-connection-profile +// @ac AC-09 +// AC-09 (liveness half): when a profile store is wired, the probe leads +// the dial with the host's recorded auth method and records the method +// that authenticated. Without a store, no learning occurs. 
+func TestPrivilegeProbe_AuthLearning(t *testing.T) { + t.Run("system-connection-profile/AC-09", func(t *testing.T) { + hostID := liveness.HostID(uuid.Must(uuid.NewV7()).String()) + exec := &stubExec{outcomes: map[string]execResult{"sudo -n true": {code: 0}}} + + profiles := &stubProfiles{prefer: connprofile.AuthPassword} + dialer := &stubDialer{exec: exec, observed: connprofile.AuthPassword} + + probe := Probe( + stubResolver{cred: validCred()}, + WithDialer(dialer), + WithProfiles(profiles), + ) + if _, ok, err := probe(context.Background(), hostID, "192.0.2.1:22", 2*time.Second); !ok { + t.Fatalf("ok: want true, got false; err=%v", err) + } + if dialer.gotPrefer != connprofile.AuthPassword { + t.Errorf("lead-with: want dial prefer=password, got %q", dialer.gotPrefer) + } + if profiles.recorded != connprofile.AuthPassword { + t.Errorf("record: want recorded=password, got %q", profiles.recorded) + } + }) + + t.Run("no store: no learning", func(t *testing.T) { + hostID := liveness.HostID(uuid.Must(uuid.NewV7()).String()) + exec := &stubExec{outcomes: map[string]execResult{"sudo -n true": {code: 0}}} + dialer := &stubDialer{exec: exec, observed: connprofile.AuthKey} + + probe := Probe(stubResolver{cred: validCred()}, WithDialer(dialer)) + if _, ok, _ := probe(context.Background(), hostID, "192.0.2.1:22", 2*time.Second); !ok { + t.Fatalf("ok: want true") + } + if dialer.gotPrefer != "" { + t.Errorf("no-store: want empty prefer, got %q", dialer.gotPrefer) + } + }) +} diff --git a/specs/system/connection-profile.spec.yaml b/specs/system/connection-profile.spec.yaml index 67de607a..22efeddf 100644 --- a/specs/system/connection-profile.spec.yaml +++ b/specs/system/connection-profile.spec.yaml @@ -1,7 +1,7 @@ spec: id: system-connection-profile title: Per-host SSH connection learning (auth method + sudo mode) - version: "1.0.0" + version: "1.1.0" status: approved tier: 2 @@ -37,11 +37,16 @@ spec: - Dial-layer auth-method ordering (PreferAuth) + observation (ObservedAuth) 
Compliance-scan sudo -S support with per-connection sudo-mode probe + learning - Default-on sudo -S password fallback (kill-switch retained) + - Auth-method learning wired into the OS discovery, OS intelligence + (collector), and liveness privilege-probe paths (v1.1.0) — each + leads the dial with the host's recorded method and records the + method that authenticated in the shared connprofile store excludes: - Settings UI/API to toggle the kill-switch (DB-only for now) - - Extending auth/sudo learning to the discovery, intelligence, and - liveness paths (the substrate + dial mechanism land here; wiring - those paths is a follow-up) + - sudo-mode (NOPASSWD vs password) learning for the discovery, + intelligence, and liveness paths — those still probe sudo mode each + cycle; only the SSH auth-method dimension is wired for them so far + (the compliance scan learns both) constraints: - id: C-01 @@ -64,6 +69,10 @@ spec: description: The compliance scan MUST gate its sudo -S use of the credential password on the SAME two conditions the collector/liveness/discovery paths enforce — the AllowCredentialSudoPassword kill-switch is on AND the credential auth method is password or both. When either fails, the scan MUST NOT attempt sudo -S (it degrades to sudo -n). This gates only the sudo use of the password; SSH password authentication is independent. type: security enforcement: error + - id: C-06 + description: The OS discovery, OS intelligence (collector), and liveness privilege-probe paths MUST lead the SSH dial with the host's recorded auth method and record the method that authenticated, using the shared connprofile store. This learning MUST be best-effort — a missing host id, an absent profile row, or a store error MUST dial in the default order and MUST NOT fail the connection. 
+ type: technical + enforcement: error acceptance_criteria: - id: AC-01 @@ -94,3 +103,11 @@ spec: description: The scan resolves a sudo password only when the kill-switch is on AND the credential auth method is password or both; a key-only credential or a kill-switch set false yields no sudo password (sudoPasswordFor returns empty), so the scan never attempts sudo -S in those cases. priority: critical references_constraints: [C-05] + - id: AC-08 + description: The discovery/collector SSH transport, when a profile store is wired and the ctx carries a host id (connprofile.WithHostID), sets DialOptions.PreferAuth from the host's recorded SSH auth method and records the observed method via RecordSSHAuth after a successful dial. With no host id on the ctx OR no store wired, it dials in the default order and records nothing. + priority: high + references_constraints: [C-06] + - id: AC-09 + description: The liveness privilege probe, when WithProfiles is set, leads the dial with the host's recorded SSH auth method and records the method that authenticated; with no profile store wired it passes no preference and records nothing. A profile lookup error is non-fatal (dials in the default order). + priority: high + references_constraints: [C-06]