diff --git a/cmd/agent/main.go b/cmd/agent/main.go index f76f5b3..0a51fab 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -97,7 +97,7 @@ func main() { } peerFactory := func(id, addr string) (paxos.PeerClient, error) { - return server.NewPaxosClient(id, addr) + return server.NewPaxosClient(id, addr, ident) } acceptor := paxos.NewAcceptor(agentID, ident, store) @@ -115,7 +115,7 @@ func main() { if *peerAddr != "" { glog.Infof("Attempting to fetch peer info from: %s", *peerAddr) - joinClient, err = server.NewPaxosClient("temp-peer", *peerAddr) + joinClient, err = server.NewPaxosClient("temp-peer", *peerAddr, ident) if err == nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) req := &paxosv1.GetKVEntryRequest{Key: constants.PeersKey} @@ -309,14 +309,14 @@ func main() { defer cancel() go func() { - err := server.RunGRPCServer(ctx, grpcLis, paxosSrv) + err := server.RunGRPCServer(ctx, grpcLis, paxosSrv, ident) if err != nil { errChan <- err } }() go func() { - httpSrv := server.NewHTTPServer(*httpAddr, store, cell) + httpSrv := server.NewHTTPServer(*httpAddr, store, ident, cell) err := httpSrv.Run(httpLis) if err != nil { errChan <- err diff --git a/internal/identity/identity.go b/internal/identity/identity.go index e961d49..9554e62 100644 --- a/internal/identity/identity.go +++ b/internal/identity/identity.go @@ -8,6 +8,7 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/sha256" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" @@ -15,6 +16,7 @@ import ( "encoding/pem" "fmt" "math/big" + "strings" "time" "google.golang.org/protobuf/proto" @@ -53,6 +55,7 @@ func Generate(shortName string) (*Identity, error) { KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, BasicConstraintsValid: true, + IsCA: true, } derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) @@ -73,15 +76,95 @@ func Generate(shortName string) (*Identity, error) { // AgentID derives a unique agent ID from the certificate's public key. func (i *Identity) AgentID() string { - pubBytes, err := x509.MarshalPKIXPublicKey(i.Certificate.PublicKey) + return AgentIDFromCertificate(i.Certificate) +} + +// AgentIDFromCertificate derives a unique agent ID from the certificate's public key. +func AgentIDFromCertificate(cert *x509.Certificate) string { + if cert == nil { + return "" + } + pubBytes, err := x509.MarshalPKIXPublicKey(cert.PublicKey) if err != nil { - // This should not happen with a valid certificate return "" } hash := sha256.Sum256(pubBytes) return hex.EncodeToString(hash[:]) } +// ServerTLSConfig returns a TLS configuration for a gRPC server. +func (i *Identity) ServerTLSConfig() (*tls.Config, error) { + if i == nil { + return nil, fmt.Errorf("identity is nil") + } + cert := tls.Certificate{ + Certificate: [][]byte{i.Certificate.Raw}, + PrivateKey: i.PrivateKey, + Leaf: i.Certificate, + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientAuth: tls.RequireAnyClientCert, + // We use VerifyPeerCertificate for custom validation of self-signed certs. + InsecureSkipVerify: true, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + _, err := verifySelfSigned(rawCerts) + return err + }, + }, nil +} + +// ClientTLSConfig returns a TLS configuration for a gRPC client. +func (i *Identity) ClientTLSConfig(expectedRemoteID string) (*tls.Config, error) { + if i == nil { + return nil, fmt.Errorf("identity is nil") + } + cert := tls.Certificate{ + Certificate: [][]byte{i.Certificate.Raw}, + PrivateKey: i.PrivateKey, + Leaf: i.Certificate, + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + cert, err := verifySelfSigned(rawCerts) + if err != nil { + return err + } + if expectedRemoteID != "" && !strings.HasPrefix(expectedRemoteID, "temp-") { + remoteID := AgentIDFromCertificate(cert) + if remoteID != expectedRemoteID { + return fmt.Errorf("remote agent ID mismatch: expected %s, got %s", expectedRemoteID, remoteID) + } + } + return nil + }, + }, nil +} + +func verifySelfSigned(rawCerts [][]byte) (*x509.Certificate, error) { + if len(rawCerts) == 0 { + return nil, fmt.Errorf("no certificates provided") + } + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + // Check expiry + now := time.Now() + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + return nil, fmt.Errorf("certificate is expired or not yet valid") + } + // Check self-signature + if err := cert.CheckSignatureFrom(cert); err != nil { + return nil, fmt.Errorf("certificate is not self-signed: %w", err) + } + return cert, nil +} + // Sign creates a signature for the given data using the private key. func (i *Identity) Sign(data []byte) ([]byte, error) { hash := sha256.Sum256(data) diff --git a/internal/server/BUILD.bazel b/internal/server/BUILD.bazel index cc431d6..2a2e981 100644 --- a/internal/server/BUILD.bazel +++ b/internal/server/BUILD.bazel @@ -27,7 +27,7 @@ go_library( "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//channelz/grpc_channelz_v1", "@org_golang_google_grpc//channelz/service", - "@org_golang_google_grpc//credentials/insecure", + "@org_golang_google_grpc//credentials", "@rules_go//go/runfiles", ], ) diff --git a/internal/server/grpc_client.go b/internal/server/grpc_client.go index c54e2eb..60ebb4f 100644 --- a/internal/server/grpc_client.go +++ b/internal/server/grpc_client.go @@ -6,9 +6,10 @@ import ( "context" "fmt" + "github.com/filmil/synod/internal/identity" paxosv1 "github.com/filmil/synod/proto/paxos/v1" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials" ) // PaxosClient wraps a gRPC connection to a remote Paxos agent. @@ -22,8 +23,12 @@ type PaxosClient struct { } // NewPaxosClient establishes a connection to the specified address. -func NewPaxosClient(agentID string, addr string) (*PaxosClient, error) { - conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) +func NewPaxosClient(agentID string, addr string, ident *identity.Identity) (*PaxosClient, error) { + tlsConfig, err := ident.ClientTLSConfig(agentID) + if err != nil { + return nil, fmt.Errorf("failed to create client TLS config: %w", err) + } + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } diff --git a/internal/server/grpc_server.go b/internal/server/grpc_server.go index 030a3b7..5bcfd5d 100644 --- a/internal/server/grpc_server.go +++ b/internal/server/grpc_server.go @@ -4,6 +4,7 @@ package server import ( "context" + "fmt" "net" "time" @@ -15,6 +16,7 @@ import ( "github.com/golang/glog" "google.golang.org/grpc" "google.golang.org/grpc/channelz/service" + "google.golang.org/grpc/credentials" ) // PaxosServer implements both the internal PaxosService and the client-facing UserService over gRPC. @@ -338,8 +340,15 @@ func timeoutInterceptor(ctx context.Context, req interface{}, info *grpc.UnarySe } // RunGRPCServer starts the gRPC server and registers the Paxos and User APIs. -func RunGRPCServer(ctx context.Context, lis net.Listener, srv *PaxosServer) error { - s := grpc.NewServer(grpc.UnaryInterceptor(timeoutInterceptor)) +func RunGRPCServer(ctx context.Context, lis net.Listener, srv *PaxosServer, ident *identity.Identity) error { + tlsConfig, err := ident.ServerTLSConfig() + if err != nil { + return fmt.Errorf("failed to create server TLS config: %w", err) + } + s := grpc.NewServer( + grpc.Creds(credentials.NewTLS(tlsConfig)), + grpc.UnaryInterceptor(timeoutInterceptor), + ) paxosv1.RegisterPaxosServiceServer(s, srv) paxosv1.RegisterUserServiceServer(s, srv) diff --git a/internal/server/grpc_status.go b/internal/server/grpc_status.go index 7399e6b..17b24cb 100644 --- a/internal/server/grpc_status.go +++ b/internal/server/grpc_status.go @@ -11,7 +11,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/channelz/grpc_channelz_v1" - "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials" "html" ) @@ -49,7 +49,14 @@ func (s *HTTPServer) handleGRPC(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - conn, err := grpc.NewClient(selfInfo.GRPCAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + tlsConfig, err := s.ident.ClientTLSConfig(agentID) + if err != nil { + data.ErrorMsg = template.HTML(fmt.Sprintf("
Failed to create client TLS config: %v
", html.EscapeString(err.Error()))) + s.renderGRPCStatus(w, data) + return + } + + conn, err := grpc.NewClient(selfInfo.GRPCAddr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) if err != nil { data.ErrorMsg = template.HTML(fmt.Sprintf("
Failed to connect to local gRPC channelz: %v
", html.EscapeString(err.Error()))) s.renderGRPCStatus(w, data) diff --git a/internal/server/http_server.go b/internal/server/http_server.go index 8860abf..d9659c4 100644 --- a/internal/server/http_server.go +++ b/internal/server/http_server.go @@ -258,6 +258,7 @@ const ( // HTTPServer provides a web dashboard for inspecting agent state and issuing commands. type HTTPServer struct { store *state.Store + ident *identity.Identity cell *paxos.Cell addr string @@ -266,10 +267,11 @@ type HTTPServer struct { } // NewHTTPServer initializes a new HTTPServer. -func NewHTTPServer(addr string, store *state.Store, cell *paxos.Cell) *HTTPServer { +func NewHTTPServer(addr string, store *state.Store, ident *identity.Identity, cell *paxos.Cell) *HTTPServer { return &HTTPServer{ addr: addr, store: store, + ident: ident, cell: cell, ongoingRequests: []OngoingRequest{}, } diff --git a/internal/server/http_server_test.go b/internal/server/http_server_test.go index ffc6ceb..e46ec71 100644 --- a/internal/server/http_server_test.go +++ b/internal/server/http_server_test.go @@ -39,7 +39,7 @@ func setupTestServer(t *testing.T) (*HTTPServer, func()) { } cell := paxos.NewCell(agentID, store, ident, acceptor, factory, ":50101", "http://localhost:8081") - server := NewHTTPServer(":8081", store, cell) + server := NewHTTPServer(":8081", store, ident, cell) cleanup := func() { store.Close() diff --git a/test/integration/BUILD.bazel b/test/integration/BUILD.bazel index 41fa49f..d6f0662 100644 --- a/test/integration/BUILD.bazel +++ b/test/integration/BUILD.bazel @@ -12,6 +12,7 @@ go_test( deps = [ "//internal/backoff", "//internal/constants", + "//internal/identity", "//internal/paxos", "//internal/server", "//internal/state", diff --git a/test/integration/integration_test.go b/test/integration/integration_test.go index 3b0daff..b2f42ab 100644 --- a/test/integration/integration_test.go +++ b/test/integration/integration_test.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "github.com/filmil/synod/internal/constants" + "github.com/filmil/synod/internal/identity" "io" "net/http" "os" @@ -85,7 +86,7 @@ func TestIntegration_5Agents(t *testing.T) { } // Join agent-0 - client, err := server.NewPaxosClient("temp-joiner", addr0) + client, err := server.NewPaxosClient("temp-joiner", addr0, agents[i].ident) if err != nil { t.Fatalf("failed to create join client: %v", err) } @@ -218,6 +219,7 @@ type agentInstance struct { httpURL string dir string store *state.Store + ident *identity.Identity cell *paxos.Cell srv *server.PaxosServer cancelFunc context.CancelFunc @@ -236,7 +238,7 @@ func newAgentInstance(t *testing.T, id, dir, addr string) *agentInstance { ident, _ := store.GetIdentity("") peerFactory := func(id, addr string) (paxos.PeerClient, error) { - return server.NewPaxosClient(id, addr) + return server.NewPaxosClient(id, addr, ident) } acceptor := paxos.NewAcceptor(actualID, ident, store) @@ -252,6 +254,7 @@ func newAgentInstance(t *testing.T, id, dir, addr string) *agentInstance { id: actualID, dir: dir, store: store, + ident: ident, cell: cell, srv: srv, } @@ -283,7 +286,7 @@ func (a *agentInstance) run() { a.cell.StartEndpointSyncLoop(ctx, 2*time.Second) go func() { - server.RunGRPCServer(ctx, grpcLis, a.srv) + server.RunGRPCServer(ctx, grpcLis, a.srv, a.ident) }() go func() { diff --git a/test/integration/lock_test.go b/test/integration/lock_test.go index 2bded4a..6dbb3cb 100644 --- a/test/integration/lock_test.go +++ b/test/integration/lock_test.go @@ -47,7 +47,7 @@ func TestIntegration_Locks(t *testing.T) { infoI := state.PeerInfo{ShortName: fmt.Sprintf("agent-%d", i), GRPCAddr: agents[i].grpcAddr, HTTPURL: agents[i].httpURL} agents[i].store.AddMember(agents[i].id, infoI) - client, _ := server.NewPaxosClient("temp-joiner", addr0) + client, _ := server.NewPaxosClient("temp-joiner", addr0, agents[i].ident) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) client.JoinCluster(ctx, &paxosv1.JoinClusterRequest{ AgentId: agents[i].id, diff --git a/test/integration/ping_test.go b/test/integration/ping_test.go index f255d5c..c08f763 100644 --- a/test/integration/ping_test.go +++ b/test/integration/ping_test.go @@ -57,7 +57,7 @@ func TestIntegration_PeerRemovalOnFailure(t *testing.T) { agents[i].store.AddMember(agents[i].id, infoI) // Join via agent-0 - client, _ := server.NewPaxosClient("temp-joiner", addr0) + client, _ := server.NewPaxosClient("temp-joiner", addr0, agents[i].ident) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) client.JoinCluster(ctx, &paxosv1.JoinClusterRequest{ AgentId: agents[i].id,