diff --git a/.gitignore b/.gitignore index cabd656..6d36170 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,6 @@ examples/hello-world-*.yaml # Review and scratch files review_report.md + +# AI planning artifacts (local only) +docs/ diff --git a/claude.md b/CLAUDE.md similarity index 97% rename from claude.md rename to CLAUDE.md index 3b0850f..d49fa73 100644 --- a/claude.md +++ b/CLAUDE.md @@ -2,6 +2,10 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. +## AI Planning Artifacts + +The `docs/` directory is used exclusively for AI-assisted planning artifacts (specs, plans, design docs). These files are local only and must never be committed. They are listed in `.gitignore`. + ## Project Overview kubectl-coco is a kubectl plugin that transforms regular Kubernetes manifests into Confidential Containers (CoCo) enabled manifests. It automates RuntimeClass configuration, secret conversion to sealed format, imagePullSecrets handling, initdata generation, and Trustee KBS deployment/management. diff --git a/cmd/dump_initdata.go b/cmd/dump_initdata.go deleted file mode 100644 index f63a41f..0000000 --- a/cmd/dump_initdata.go +++ /dev/null @@ -1,160 +0,0 @@ -// Package cmd provides the command-line interface for kubectl-coco. -package cmd - -import ( - "bytes" - "compress/gzip" - "encoding/base64" - "fmt" - "io" - "strings" - - "github.com/confidential-devhub/cococtl/pkg/config" - "github.com/confidential-devhub/cococtl/pkg/initdata" - "github.com/pelletier/go-toml/v2" - "github.com/spf13/cobra" -) - -// dumpInitdataCmd displays generated initdata for inspection and debugging. -var dumpInitdataCmd = &cobra.Command{ - Use: "dump-initdata", - Short: "Display generated initdata for inspection", - Long: `Display the generated initdata configuration for inspection and debugging. - -By default, this command shows the decoded contents of: - - aa.toml (Attestation Agent configuration) - - cdh.toml (Confidential Data Hub configuration) - - policy.rego (Kata agent policy) - -Use --raw to output the gzip+base64 encoded annotation value that would -be added to Kubernetes manifests. - -Examples: - # Show decoded initdata from default config - kubectl coco dump-initdata - - # Show decoded initdata from specific config file - kubectl coco dump-initdata --config /path/to/coco-config.toml - - # Show raw base64-encoded annotation value - kubectl coco dump-initdata --raw`, - RunE: runDumpInitdata, -} - -var ( - dumpInitdataConfigPath string - dumpInitdataRaw bool -) - -func init() { - rootCmd.AddCommand(dumpInitdataCmd) - - dumpInitdataCmd.Flags().StringVar(&dumpInitdataConfigPath, "config", "", "Path to CoCo config file (default: ~/.kube/coco-config.toml)") - dumpInitdataCmd.Flags().BoolVar(&dumpInitdataRaw, "raw", false, "Output gzip+base64 encoded annotation value instead of decoded content") -} - -// runDumpInitdata generates and displays initdata for inspection. -func runDumpInitdata(_ *cobra.Command, _ []string) error { - // Determine config path - configPath := dumpInitdataConfigPath - if configPath == "" { - var err error - configPath, err = config.GetConfigPath() - if err != nil { - return fmt.Errorf("failed to get default config path: %w", err) - } - } - - // Load configuration - cfg, err := config.Load(configPath) - if err != nil { - return fmt.Errorf("failed to load config from %s: %w", configPath, err) - } - - if cfg.TrusteeServer == "" { - return fmt.Errorf("trustee_server is empty, initdata cannot be generated") - } - - if err := cfg.Validate(); err != nil { - return fmt.Errorf("invalid configuration: %w", err) - } - - // Generate initdata (nil for imagePullSecrets - not needed for inspection) - encoded, err := initdata.Generate(cfg, nil) - if err != nil { - return fmt.Errorf("failed to generate initdata: %w", err) - } - - // Output based on --raw flag - if dumpInitdataRaw { - // Output raw base64-encoded value - fmt.Println("# This is the gzip+base64 encoded initdata annotation value") - fmt.Println("# Use this value for the io.katacontainers.config.hypervisor.cc_init_data annotation") - fmt.Println(encoded) - return nil - } - - // Decode and display human-readable content - decoded, err := decodeInitdata(encoded) - if err != nil { - return fmt.Errorf("failed to decode initdata: %w", err) - } - - // Print each section with headers - fmt.Println("=== aa.toml ===") - if aaToml, ok := decoded["aa.toml"]; ok { - fmt.Println(strings.TrimSpace(aaToml)) - } else { - fmt.Println("(not found)") - } - fmt.Println() - - fmt.Println("=== cdh.toml ===") - if cdhToml, ok := decoded["cdh.toml"]; ok { - fmt.Println(strings.TrimSpace(cdhToml)) - } else { - fmt.Println("(not found)") - } - fmt.Println() - - fmt.Println("=== policy.rego ===") - if policy, ok := decoded["policy.rego"]; ok { - fmt.Println(strings.TrimSpace(policy)) - } else { - fmt.Println("(not found)") - } - - return nil -} - -// decodeInitdata decodes a base64+gzip encoded initdata string and extracts the data map. -func decodeInitdata(encoded string) (map[string]string, error) { - // Decode base64 - gzipData, err := base64.StdEncoding.DecodeString(encoded) - if err != nil { - return nil, fmt.Errorf("failed to decode base64: %w", err) - } - - // Decompress gzip - gzipReader, err := gzip.NewReader(bytes.NewReader(gzipData)) - if err != nil { - return nil, fmt.Errorf("failed to create gzip reader: %w", err) - } - defer func() { - _ = gzipReader.Close() - }() - - tomlData, err := io.ReadAll(gzipReader) - if err != nil { - return nil, fmt.Errorf("failed to decompress gzip data: %w", err) - } - - // Parse TOML to extract data map - var initdataStruct initdata.InitData - - if err := toml.Unmarshal(tomlData, &initdataStruct); err != nil { - return nil, fmt.Errorf("failed to parse initdata TOML: %w", err) - } - - return initdataStruct.Data, nil -} diff --git a/cmd/dump_initdata_test.go b/cmd/dump_initdata_test.go deleted file mode 100644 index 2e04906..0000000 --- a/cmd/dump_initdata_test.go +++ /dev/null @@ -1,271 +0,0 @@ -package cmd - -import ( - "bytes" - "encoding/base64" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/confidential-devhub/cococtl/pkg/config" - "github.com/confidential-devhub/cococtl/pkg/initdata" -) - -// TestDumpInitdataWithValidConfig tests dump-initdata with a valid config file. -func TestDumpInitdataWithValidConfig(t *testing.T) { - // Create temporary config file with valid TOML - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "test-config.toml") - - validConfig := `trustee_server = "http://kbs-service.trustee-operator-system.svc.cluster.local:8080" -runtime_class = "kata-cc" -` - - if err := os.WriteFile(configPath, []byte(validConfig), 0600); err != nil { - t.Fatalf("Failed to create test config: %v", err) - } - - // Set the config path flag - originalConfigPath := dumpInitdataConfigPath - originalRaw := dumpInitdataRaw - defer func() { - dumpInitdataConfigPath = originalConfigPath - dumpInitdataRaw = originalRaw - }() - - dumpInitdataConfigPath = configPath - dumpInitdataRaw = false - - // Capture stdout - oldStdout := os.Stdout - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("Failed to create pipe: %v", err) - } - - defer func() { - _ = r.Close() - }() - - os.Stdout = w - - runErr := runDumpInitdata(nil, nil) - - // Restore stdout and read captured output - _ = w.Close() - os.Stdout = oldStdout - - var buf bytes.Buffer - _, _ = buf.ReadFrom(r) - output := buf.String() - - if runErr != nil { - t.Fatalf("runDumpInitdata failed: %v", runErr) - } - - // Verify output contains expected section headers - if !strings.Contains(output, "=== aa.toml ===") { - t.Error("Output should contain '=== aa.toml ===' section header") - } - if !strings.Contains(output, "=== cdh.toml ===") { - t.Error("Output should contain '=== cdh.toml ===' section header") - } - if !strings.Contains(output, "=== policy.rego ===") { - t.Error("Output should contain '=== policy.rego ===' section header") - } - - // Verify trustee server appears in the output - if !strings.Contains(output, "kbs-service.trustee-operator-system.svc.cluster.local") { - t.Error("Output should contain the configured trustee server URL") - } -} - -// TestDumpInitdataWithRawFlag tests dump-initdata with --raw flag. -func TestDumpInitdataWithRawFlag(t *testing.T) { - // Create temporary config file - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "test-config.toml") - - validConfig := `trustee_server = "http://kbs.test.svc:8080" -runtime_class = "kata-cc" -` - - if err := os.WriteFile(configPath, []byte(validConfig), 0600); err != nil { - t.Fatalf("Failed to create test config: %v", err) - } - - // Set the config path flag and enable raw mode - originalConfigPath := dumpInitdataConfigPath - originalRaw := dumpInitdataRaw - defer func() { - dumpInitdataConfigPath = originalConfigPath - dumpInitdataRaw = originalRaw - }() - - dumpInitdataConfigPath = configPath - dumpInitdataRaw = true - - // Capture stdout - oldStdout := os.Stdout - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("Failed to create pipe: %v", err) - } - - defer func() { - _ = r.Close() - }() - - os.Stdout = w - - runErr := runDumpInitdata(nil, nil) - - // Restore stdout and read captured output - _ = w.Close() - os.Stdout = oldStdout - - var buf bytes.Buffer - _, _ = buf.ReadFrom(r) - output := buf.String() - - if runErr != nil { - t.Fatalf("runDumpInitdata with --raw failed: %v", runErr) - } - - // Find the base64 line (skip comment lines) - lines := strings.Split(strings.TrimSpace(output), "\n") - var base64Line string - for _, line := range lines { - if !strings.HasPrefix(line, "#") && len(line) > 0 { - base64Line = line - break - } - } - - if base64Line == "" { - t.Fatal("Output should contain a base64-encoded string") - } - - // Verify it's valid base64 - _, err = base64.StdEncoding.DecodeString(base64Line) - if err != nil { - t.Errorf("Output is not valid base64: %v", err) - } - - // Verify comment header is present - if !strings.Contains(output, "gzip+base64 encoded initdata") { - t.Error("Output should contain explanatory comment about gzip+base64 encoding") - } -} - -// TestDumpInitdataMissingConfig tests dump-initdata with non-existent config file. -func TestDumpInitdataMissingConfig(t *testing.T) { - // Set a non-existent config path - originalConfigPath := dumpInitdataConfigPath - originalRaw := dumpInitdataRaw - defer func() { - dumpInitdataConfigPath = originalConfigPath - dumpInitdataRaw = originalRaw - }() - - dumpInitdataConfigPath = "/nonexistent/path/to/config.toml" - dumpInitdataRaw = false - - err := runDumpInitdata(nil, nil) - - if err == nil { - t.Fatal("Expected error for missing config file, got nil") - } - - // Verify error message mentions config loading failure - if !strings.Contains(err.Error(), "failed to load config") { - t.Errorf("Error message should mention 'failed to load config', got: %v", err) - } -} - -// TestDumpInitdataInvalidConfig tests dump-initdata with invalid TOML config. -func TestDumpInitdataInvalidConfig(t *testing.T) { - // Create temporary file with invalid TOML - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "invalid-config.toml") - - invalidConfig := `trustee_server = "http://test.local -runtime_class = [broken -` - - if err := os.WriteFile(configPath, []byte(invalidConfig), 0600); err != nil { - t.Fatalf("Failed to create test config: %v", err) - } - - // Set the config path - originalConfigPath := dumpInitdataConfigPath - originalRaw := dumpInitdataRaw - defer func() { - dumpInitdataConfigPath = originalConfigPath - dumpInitdataRaw = originalRaw - }() - - dumpInitdataConfigPath = configPath - dumpInitdataRaw = false - - err := runDumpInitdata(nil, nil) - - if err == nil { - t.Fatal("Expected error for invalid TOML config, got nil") - } - - // Verify error message indicates config or parsing failure - if !strings.Contains(err.Error(), "failed to load config") && !strings.Contains(err.Error(), "parse") { - t.Errorf("Error message should mention config loading or parsing failure, got: %v", err) - } -} - -// TestDecodeInitdata tests the decodeInitdata helper function. -func TestDecodeInitdata(t *testing.T) { - // Create a simple test config - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "test-config.toml") - - validConfig := `trustee_server = "http://test-kbs.default.svc:8080" -runtime_class = "kata-cc" -` - - if err := os.WriteFile(configPath, []byte(validConfig), 0600); err != nil { - t.Fatalf("Failed to create test config: %v", err) - } - - // Load config using the real config package - cfg, err := config.Load(configPath) - if err != nil { - t.Fatalf("Failed to load test config: %v", err) - } - - // Generate initdata using the real initdata package - encoded, err := initdata.Generate(cfg, nil) - if err != nil { - t.Fatalf("Failed to generate initdata: %v", err) - } - - // Now test decoding - data, err := decodeInitdata(encoded) - if err != nil { - t.Fatalf("decodeInitdata failed: %v", err) - } - - // Verify we got the expected keys - if _, ok := data["aa.toml"]; !ok { - t.Error("Decoded data should contain 'aa.toml' key") - } - if _, ok := data["cdh.toml"]; !ok { - t.Error("Decoded data should contain 'cdh.toml' key") - } - if _, ok := data["policy.rego"]; !ok { - t.Error("Decoded data should contain 'policy.rego' key") - } - - // Verify aa.toml contains the trustee server URL - if !strings.Contains(data["aa.toml"], "test-kbs.default.svc:8080") { - t.Error("aa.toml should contain the configured trustee server URL") - } -} diff --git a/cmd/initdata/common.go b/cmd/initdata/common.go new file mode 100644 index 0000000..aeba3b0 --- /dev/null +++ b/cmd/initdata/common.go @@ -0,0 +1,216 @@ +// Package initdata provides the initdata subcommand group for cococtl. +package initdata + +import ( + "bytes" + "compress/gzip" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/confidential-devhub/cococtl/pkg/config" + "github.com/pelletier/go-toml/v2" +) + +func loadCerts(path string) ([]*x509.Certificate, error) { + // #nosec G304 -- path comes from the user-provided --cacert flag + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %w", path, err) + } + return parsePEMCerts(data) +} + +func loadCertsFromDir(dir string) ([]*x509.Certificate, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("failed to read directory %s: %w", dir, err) + } + var all []*x509.Certificate + for _, entry := range entries { + if entry.IsDir() { + continue + } + path := filepath.Join(dir, entry.Name()) + // #nosec G304 -- path is constructed from the user-provided --capath directory + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %w", path, err) + } + certs, err := parsePEMCerts(data) + if err != nil { + return nil, fmt.Errorf("failed to parse certs from %s: %w", path, err) + } + if len(certs) == 0 { + fmt.Fprintf(os.Stderr, "skipping %s: no PEM blocks found\n", entry.Name()) + continue + } + all = append(all, certs...) + } + return all, nil +} + +func parsePEMCerts(data []byte) ([]*x509.Certificate, error) { + var certs []*x509.Certificate + for len(data) > 0 { + block, rest := pem.Decode(data) + if block == nil { + break + } + data = rest + if block.Type != "CERTIFICATE" { + continue + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + certs = append(certs, cert) + } + return certs, nil +} + +func validateCACert(cert *x509.Certificate) error { + if !cert.IsCA { + return fmt.Errorf("certificate %q: IsCA is false", cert.Subject.CommonName) + } + if cert.KeyUsage&x509.KeyUsageCertSign == 0 { + return fmt.Errorf("certificate %q: missing KeyUsageCertSign", cert.Subject.CommonName) + } + return nil +} + +func validateCerts(certs []*x509.Certificate) error { + var errs []string + for _, cert := range certs { + if err := validateCACert(cert); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + return fmt.Errorf("cert validation failed:\n %s", strings.Join(errs, "\n ")) + } + return nil +} + +func certsToPEM(certs []*x509.Certificate) string { + var sb strings.Builder + for _, cert := range certs { + _ = pem.Encode(&sb, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) + } + return sb.String() +} + +func loadConfig(path string) (*config.CocoConfig, error) { + if path == "" { + var err error + path, err = config.GetConfigPath() + if err != nil { + return nil, fmt.Errorf("failed to get default config path: %w", err) + } + } + cfg, err := config.Load(path) + if err != nil { + return nil, fmt.Errorf("failed to load config from %s: %w", path, err) + } + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) + } + return cfg, nil +} + +func loadInitdataTOML(filePath string, r io.Reader) ([]byte, error) { + if filePath != "" { + // #nosec G304 -- path comes from the user-provided --file flag + data, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %w", filePath, err) + } + return data, nil + } + encoded, err := io.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("failed to read stdin: %w", err) + } + return decompressBlob(strings.TrimSpace(string(encoded))) +} + +func decompressBlob(encoded string) ([]byte, error) { + gzipData, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("failed to decode base64: %w", err) + } + gr, err := gzip.NewReader(bytes.NewReader(gzipData)) + if err != nil { + return nil, fmt.Errorf("failed to create gzip reader: %w", err) + } + defer func() { _ = gr.Close() }() + data, err := io.ReadAll(gr) + if err != nil { + return nil, fmt.Errorf("failed to decompress: %w", err) + } + return data, nil +} + +func extractCertsFromInitdata(data map[string]string) ([]*x509.Certificate, error) { + var pemStrings []string + + if aaToml, ok := data["aa.toml"]; ok && aaToml != "" { + var aa map[string]interface{} + if err := toml.Unmarshal([]byte(aaToml), &aa); err != nil { + return nil, fmt.Errorf("failed to parse aa.toml: %w", err) + } + if tc, ok := aa["token_configs"].(map[string]interface{}); ok { + for _, v := range tc { + if entry, ok := v.(map[string]interface{}); ok { + if cert, ok := entry["cert"].(string); ok && cert != "" { + pemStrings = append(pemStrings, cert) + } + } + } + } + } + + if cdhToml, ok := data["cdh.toml"]; ok && cdhToml != "" { + var cdh map[string]interface{} + if err := toml.Unmarshal([]byte(cdhToml), &cdh); err != nil { + return nil, fmt.Errorf("failed to parse cdh.toml: %w", err) + } + if kbc, ok := cdh["kbc"].(map[string]interface{}); ok { + if cert, ok := kbc["kbs_cert"].(string); ok && cert != "" { + pemStrings = append(pemStrings, cert) + } + } + if img, ok := cdh["image"].(map[string]interface{}); ok { + if extra, ok := img["extra_root_certificates"].([]interface{}); ok { + for _, c := range extra { + if cert, ok := c.(string); ok && cert != "" { + pemStrings = append(pemStrings, cert) + } + } + } + } + } + + seen := make(map[string]bool) + var all []*x509.Certificate + for _, pemStr := range pemStrings { + certs, err := parsePEMCerts([]byte(pemStr)) + if err != nil { + return nil, err + } + for _, cert := range certs { + fp := base64.StdEncoding.EncodeToString(cert.Raw) + if !seen[fp] { + seen[fp] = true + all = append(all, cert) + } + } + } + return all, nil +} diff --git a/cmd/initdata/common_test.go b/cmd/initdata/common_test.go new file mode 100644 index 0000000..20cf958 --- /dev/null +++ b/cmd/initdata/common_test.go @@ -0,0 +1,318 @@ +package initdata + +import ( + "bytes" + "compress/gzip" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "strings" + "testing" + "time" + + pkginitdata "github.com/confidential-devhub/cococtl/pkg/initdata" +) + +func makeTestCACert(t *testing.T) (*x509.Certificate, *rsa.PrivateKey, []byte) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Test CA"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + t.Fatalf("create cert: %v", err) + } + cert, _ := x509.ParseCertificate(der) + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + return cert, key, pemBytes +} + +func makeTestLeafCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey) (*x509.Certificate, []byte) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{CommonName: "Test Leaf"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + IsCA: false, + BasicConstraintsValid: true, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, caCert, &key.PublicKey, caKey) + if err != nil { + t.Fatalf("create cert: %v", err) + } + cert, _ := x509.ParseCertificate(der) + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + return cert, pemBytes +} + +func writeTempPEM(t *testing.T, dir, name string, pemData []byte) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, pemData, 0600); err != nil { + t.Fatalf("write temp PEM: %v", err) + } + return path +} + +func TestLoadCerts_SingleCACert(t *testing.T) { + _, _, pemBytes := makeTestCACert(t) + path := writeTempPEM(t, t.TempDir(), "ca.pem", pemBytes) + certs, err := loadCerts(path) + if err != nil { + t.Fatalf("loadCerts() error: %v", err) + } + if len(certs) != 1 { + t.Errorf("got %d certs, want 1", len(certs)) + } +} + +func TestLoadCerts_MultipleCerts(t *testing.T) { + _, key, pem1 := makeTestCACert(t) + block, _ := pem.Decode(pem1) + ca, _ := x509.ParseCertificate(block.Bytes) + _, pem2 := makeTestLeafCert(t, ca, key) + combined := append(pem1, pem2...) + path := writeTempPEM(t, t.TempDir(), "bundle.pem", combined) + certs, err := loadCerts(path) + if err != nil { + t.Fatalf("loadCerts() error: %v", err) + } + if len(certs) != 2 { + t.Errorf("got %d certs, want 2", len(certs)) + } +} + +func TestLoadCerts_NonExistentFile(t *testing.T) { + _, err := loadCerts("/nonexistent/file.pem") + if err == nil { + t.Fatal("expected error for non-existent file") + } +} + +func TestLoadCerts_NoPEMBlocks(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "notapem.txt") + _ = os.WriteFile(path, []byte("hello world"), 0600) + certs, err := loadCerts(path) + if err != nil { + t.Fatalf("loadCerts() error: %v", err) + } + if len(certs) != 0 { + t.Errorf("got %d certs, want 0", len(certs)) + } +} + +func TestLoadCertsFromDir_LoadsPEMFiles(t *testing.T) { + _, _, pem1 := makeTestCACert(t) + _, _, pem2 := makeTestCACert(t) + dir := t.TempDir() + writeTempPEM(t, dir, "a.pem", pem1) + writeTempPEM(t, dir, "b.pem", pem2) + certs, err := loadCertsFromDir(dir) + if err != nil { + t.Fatalf("loadCertsFromDir() error: %v", err) + } + if len(certs) != 2 { + t.Errorf("got %d certs, want 2", len(certs)) + } +} + +func TestLoadCertsFromDir_SkipsNonPEM(t *testing.T) { + _, _, pemBytes := makeTestCACert(t) + dir := t.TempDir() + writeTempPEM(t, dir, "ca.pem", pemBytes) + _ = os.WriteFile(filepath.Join(dir, "readme.txt"), []byte("not a cert"), 0600) + certs, err := loadCertsFromDir(dir) + if err != nil { + t.Fatalf("loadCertsFromDir() error: %v", err) + } + if len(certs) != 1 { + t.Errorf("got %d certs, want 1", len(certs)) + } +} + +func TestValidateCACert_ValidCA(t *testing.T) { + cert, _, _ := makeTestCACert(t) + if err := validateCACert(cert); err != nil { + t.Errorf("validateCACert() unexpected error: %v", err) + } +} + +func TestValidateCACert_LeafCert(t *testing.T) { + caCert, caKey, _ := makeTestCACert(t) + leaf, _ := makeTestLeafCert(t, caCert, caKey) + if err := validateCACert(leaf); err == nil { + t.Error("validateCACert() should reject leaf cert") + } +} + +func TestValidateCerts_AllValid(t *testing.T) { + cert1, _, _ := makeTestCACert(t) + cert2, _, _ := makeTestCACert(t) + if err := validateCerts([]*x509.Certificate{cert1, cert2}); err != nil { + t.Errorf("validateCerts() unexpected error: %v", err) + } +} + +func TestValidateCerts_OneInvalid(t *testing.T) { + caCert, caKey, _ := makeTestCACert(t) + leaf, _ := makeTestLeafCert(t, caCert, caKey) + err := validateCerts([]*x509.Certificate{caCert, leaf}) + if err == nil { + t.Fatal("validateCerts() should fail with one leaf cert") + } + if !strings.Contains(err.Error(), "Test Leaf") { + t.Errorf("error should name the failing cert, got: %v", err) + } +} + +func TestCertsToPEM_RoundTrip(t *testing.T) { + cert, _, _ := makeTestCACert(t) + pemStr := certsToPEM([]*x509.Certificate{cert}) + certs, err := loadCerts(writeTempPEM(t, t.TempDir(), "out.pem", []byte(pemStr))) + if err != nil { + t.Fatalf("round-trip parse error: %v", err) + } + if len(certs) != 1 { + t.Errorf("got %d certs after round-trip, want 1", len(certs)) + } + if !certs[0].Equal(cert) { + t.Error("cert not equal after round-trip") + } +} + +func TestLoadInitdataTOML_FromFile(t *testing.T) { + content := []byte(`version = "0.1.0"`) + dir := t.TempDir() + path := filepath.Join(dir, "initdata.toml") + _ = os.WriteFile(path, content, 0600) + got, err := loadInitdataTOML(path, nil) + if err != nil { + t.Fatalf("loadInitdataTOML() error: %v", err) + } + if !bytes.Equal(got, content) { + t.Errorf("got %q, want %q", got, content) + } +} + +func TestLoadInitdataTOML_FromReader(t *testing.T) { + raw := []byte("version = \"0.1.0\"\nalgorithm = \"sha256\"\n") + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write(raw) + _ = gz.Close() + encoded := base64.StdEncoding.EncodeToString(buf.Bytes()) + reader := strings.NewReader(encoded) + got, err := loadInitdataTOML("", reader) + if err != nil { + t.Fatalf("loadInitdataTOML() error: %v", err) + } + if !bytes.Equal(got, raw) { + t.Errorf("got %q, want %q", got, raw) + } +} + +func TestExtractCertsFromInitdata_ExtractsCerts(t *testing.T) { + _, _, pemBytes := makeTestCACert(t) + aaToml := "[token_configs]\n[token_configs.kbs]\nurl = \"http://kbs.test:8080\"\ncert = \"\"\"\n" + + string(pemBytes) + "\n\"\"\"\n" + data := map[string]string{ + "aa.toml": aaToml, + "cdh.toml": "[kbc]\nname = \"cc_kbc\"\nurl = \"http://kbs.test:8080\"\n", + "policy.rego": "package agent_policy\n", + } + certs, err := extractCertsFromInitdata(data) + if err != nil { + t.Fatalf("extractCertsFromInitdata() error: %v", err) + } + if len(certs) != 1 { + t.Errorf("got %d certs, want 1", len(certs)) + } +} + +func TestExtractCertsFromInitdata_EmptyData(t *testing.T) { + data := map[string]string{ + "aa.toml": "[token_configs]\n[token_configs.kbs]\nurl = \"http://kbs.test:8080\"\n", + "cdh.toml": "[kbc]\nname = \"cc_kbc\"\nurl = \"http://kbs.test:8080\"\n", + "policy.rego": "package agent_policy\n", + } + certs, err := extractCertsFromInitdata(data) + if err != nil { + t.Fatalf("extractCertsFromInitdata() error: %v", err) + } + if len(certs) != 0 { + t.Errorf("got %d certs, want 0", len(certs)) + } +} + +func TestLoadConfig_ValidPath(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "cfg.toml") + _ = os.WriteFile(path, []byte("trustee_server = \"http://kbs.test.svc:8080\"\nruntime_class = \"kata-cc\"\n"), 0600) + cfg, err := loadConfig(path) + if err != nil { + t.Fatalf("loadConfig() error: %v", err) + } + if cfg.TrusteeServer == "" { + t.Error("expected TrusteeServer to be set") + } +} + +func TestLoadConfig_InvalidPath(t *testing.T) { + _, err := loadConfig("/nonexistent/config.toml") + if err == nil { + t.Fatal("expected error for non-existent config file") + } +} + +func TestLoadCertsFromDir_NonRecursive(t *testing.T) { + _, _, pemBytes := makeTestCACert(t) + dir := t.TempDir() + subDir := filepath.Join(dir, "sub") + _ = os.Mkdir(subDir, 0750) + // cert inside subdir should NOT be loaded + _ = os.WriteFile(filepath.Join(subDir, "ca.pem"), pemBytes, 0600) + certs, err := loadCertsFromDir(dir) + if err != nil { + t.Fatalf("loadCertsFromDir() error: %v", err) + } + if len(certs) != 0 { + t.Errorf("got %d certs, want 0 (subdirectory should be skipped)", len(certs)) + } +} + +func TestExtractCertsFromInitdata_MalformedTOML(t *testing.T) { + data := map[string]string{ + "aa.toml": "this is not [valid toml", + "cdh.toml": "[kbc]\nname = \"cc_kbc\"\nurl = \"http://kbs.test:8080\"\n", + "policy.rego": "package agent_policy\n", + } + _, err := extractCertsFromInitdata(data) + if err == nil { + t.Fatal("expected error for malformed aa.toml") + } +} + +// Verify pkginitdata import is used (compile check) +var _ = pkginitdata.InitDataVersion diff --git a/cmd/initdata/create.go b/cmd/initdata/create.go new file mode 100644 index 0000000..353777c --- /dev/null +++ b/cmd/initdata/create.go @@ -0,0 +1,100 @@ +package initdata + +import ( + "fmt" + "os" + "path/filepath" + + pkginitdata "github.com/confidential-devhub/cococtl/pkg/initdata" + "github.com/spf13/cobra" +) + +var createCmd = &cobra.Command{ + Use: "create", + Short: "Generate initdata TOML from CoCo config and save to disk", + Long: `Generate initdata TOML from coco-config.toml and save it to disk. + +Examples: + kubectl coco initdata create + kubectl coco initdata create --cacert /path/to/ca.crt + kubectl coco initdata create --capath /etc/ssl/certs --output /tmp/initdata.toml`, + RunE: runCreate, +} + +var ( + createConfigPath string + createCACert string + createCAPath string + createOutput string +) + +func init() { + createCmd.Flags().StringVar(&createConfigPath, "config", "", "Path to CoCo config file (default: ~/.kube/coco-config.toml)") + createCmd.Flags().StringVar(&createCACert, "cacert", "", "Path to CA cert PEM file") + createCmd.Flags().StringVar(&createCAPath, "capath", "", "Path to directory of CA cert PEM files") + createCmd.Flags().StringVar(&createOutput, "output", "", "Output file for raw TOML (default: ~/.kube/coco-initdata.toml)") + createCmd.MarkFlagsMutuallyExclusive("cacert", "capath") +} + +func runCreate(_ *cobra.Command, _ []string) error { + if createCACert != "" && createCAPath != "" { + return fmt.Errorf("--cacert and --capath are mutually exclusive") + } + + cfg, err := loadConfig(createConfigPath) + if err != nil { + return err + } + + var certPEM string + switch { + case createCACert != "": + certs, err := loadCerts(createCACert) + if err != nil { + return err + } + if len(certs) == 0 { + return fmt.Errorf("--cacert %s: no certificates found", createCACert) + } + if err := validateCerts(certs); err != nil { + return err + } + certPEM = certsToPEM(certs) + case createCAPath != "": + certs, err := loadCertsFromDir(createCAPath) + if err != nil { + return err + } + if len(certs) == 0 { + return fmt.Errorf("--capath %s: no certificates found", createCAPath) + } + if err := validateCerts(certs); err != nil { + return err + } + certPEM = certsToPEM(certs) + } + + raw, err := pkginitdata.GenerateRaw(cfg, certPEM, nil) + if err != nil { + return fmt.Errorf("failed to generate initdata: %w", err) + } + + outputPath := createOutput + if outputPath == "" { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to get home directory: %w", err) + } + outputPath = filepath.Join(home, ".kube", "coco-initdata.toml") + } + + if err := os.MkdirAll(filepath.Dir(outputPath), 0750); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + if err := os.WriteFile(outputPath, raw, 0600); err != nil { + return fmt.Errorf("failed to write initdata: %w", err) + } + + fmt.Printf("Initdata written to %s\n", outputPath) + return nil +} diff --git a/cmd/initdata/create_test.go b/cmd/initdata/create_test.go new file mode 100644 index 0000000..31351a3 --- /dev/null +++ b/cmd/initdata/create_test.go @@ -0,0 +1,201 @@ +package initdata + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/pelletier/go-toml/v2" + + pkginitdata "github.com/confidential-devhub/cococtl/pkg/initdata" +) + +func makeCACertPEMFile(t *testing.T, dir, name string) string { + t.Helper() + key, _ := rsa.GenerateKey(rand.Reader, 2048) + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "CA"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + } + der, _ := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + path := filepath.Join(dir, name) + _ = os.WriteFile(path, pemBytes, 0600) + return path +} + +func makeLeafCertPEMFile(t *testing.T, dir, name string) string { + t.Helper() + caKey, _ := rsa.GenerateKey(rand.Reader, 2048) + caTmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "CA"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageCertSign, + } + caDER, _ := x509.CreateCertificate(rand.Reader, caTmpl, caTmpl, &caKey.PublicKey, caKey) + caCert, _ := x509.ParseCertificate(caDER) + leafKey, _ := rsa.GenerateKey(rand.Reader, 2048) + leafTmpl := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{CommonName: "Leaf"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + IsCA: false, + BasicConstraintsValid: true, + } + der, _ := x509.CreateCertificate(rand.Reader, leafTmpl, caCert, &leafKey.PublicKey, caKey) + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + path := filepath.Join(dir, name) + _ = os.WriteFile(path, pemBytes, 0600) + return path +} + +func makeTestConfigFile(t *testing.T, dir string) string { + t.Helper() + path := filepath.Join(dir, "coco-config.toml") + _ = os.WriteFile(path, []byte("trustee_server = \"http://kbs.test.svc:8080\"\nruntime_class = \"kata-cc\"\n"), 0600) + return path +} + +func TestRunCreate_MutuallyExclusive(t *testing.T) { + createCACert = "/some/cert.pem" + createCAPath = "/some/dir" + defer func() { createCACert = ""; createCAPath = "" }() + err := runCreate(nil, nil) + if err == nil || !strings.Contains(err.Error(), "mutually exclusive") { + t.Errorf("expected mutually exclusive error, got: %v", err) + } +} + +func TestRunCreate_NoCert_WritesFile(t *testing.T) { + dir := t.TempDir() + createConfigPath = makeTestConfigFile(t, dir) + createCACert = "" + createCAPath = "" + createOutput = filepath.Join(dir, "initdata.toml") + defer func() { createConfigPath = ""; createOutput = "" }() + + if err := runCreate(nil, nil); err != nil { + t.Fatalf("runCreate() error: %v", err) + } + data, err := os.ReadFile(createOutput) + if err != nil { + t.Fatalf("output file not written: %v", err) + } + var id pkginitdata.InitData + if err := toml.Unmarshal(data, &id); err != nil { + t.Fatalf("output is not valid TOML: %v", err) + } + if id.Version != pkginitdata.InitDataVersion { + t.Errorf("version = %q, want %q", id.Version, pkginitdata.InitDataVersion) + } +} + +func TestRunCreate_WithCACert_EmbedsCert(t *testing.T) { + dir := t.TempDir() + createConfigPath = makeTestConfigFile(t, dir) + createCACert = makeCACertPEMFile(t, dir, "ca.pem") + createCAPath = "" + createOutput = filepath.Join(dir, "initdata.toml") + defer func() { createConfigPath = ""; createCACert = ""; createOutput = "" }() + + if err := runCreate(nil, nil); err != nil { + t.Fatalf("runCreate() error: %v", err) + } + data, _ := os.ReadFile(createOutput) + if !strings.Contains(string(data), "CERTIFICATE") { + t.Error("output TOML should contain embedded certificate") + } +} + +func TestRunCreate_WithCAPath_LoadsDir(t *testing.T) { + dir := t.TempDir() + certDir := filepath.Join(dir, "certs") + _ = os.Mkdir(certDir, 0750) + makeCACertPEMFile(t, certDir, "ca1.pem") + makeCACertPEMFile(t, certDir, "ca2.pem") + createConfigPath = makeTestConfigFile(t, dir) + createCACert = "" + createCAPath = certDir + createOutput = filepath.Join(dir, "initdata.toml") + defer func() { createConfigPath = ""; createCAPath = ""; createOutput = "" }() + + if err := runCreate(nil, nil); err != nil { + t.Fatalf("runCreate() error: %v", err) + } + data, _ := os.ReadFile(createOutput) + if !strings.Contains(string(data), "CERTIFICATE") { + t.Error("output TOML should contain embedded certificates") + } +} + +func TestRunCreate_RejectsLeafCert(t *testing.T) { + dir := t.TempDir() + createConfigPath = makeTestConfigFile(t, dir) + createCACert = makeLeafCertPEMFile(t, dir, "leaf.pem") + createCAPath = "" + createOutput = filepath.Join(dir, "initdata.toml") + defer func() { createConfigPath = ""; createCACert = ""; createOutput = "" }() + + err := runCreate(nil, nil) + if err == nil { + t.Fatal("expected error for leaf cert, got nil") + } + if !strings.Contains(err.Error(), "cert validation failed") { + t.Errorf("expected cert validation error, got: %v", err) + } +} + +func TestRunCreate_EmptyCACert_Errors(t *testing.T) { + dir := t.TempDir() + // Write a file with no PEM blocks + emptyPath := filepath.Join(dir, "empty.pem") + _ = os.WriteFile(emptyPath, []byte("not a cert\n"), 0600) + createConfigPath = makeTestConfigFile(t, dir) + createCACert = emptyPath + createCAPath = "" + createOutput = filepath.Join(dir, "initdata.toml") + defer func() { createConfigPath = ""; createCACert = ""; createOutput = "" }() + + err := runCreate(nil, nil) + if err == nil || !strings.Contains(err.Error(), "no certificates found") { + t.Errorf("expected no-certs error, got: %v", err) + } +} + +func TestRunCreate_OutputFileMode(t *testing.T) { + dir := t.TempDir() + createConfigPath = makeTestConfigFile(t, dir) + createCACert = "" + createCAPath = "" + createOutput = filepath.Join(dir, "initdata.toml") + defer func() { createConfigPath = ""; createOutput = "" }() + + if err := runCreate(nil, nil); err != nil { + t.Fatalf("runCreate() error: %v", err) + } + info, err := os.Stat(createOutput) + if err != nil { + t.Fatalf("stat output file: %v", err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("file mode = %o, want 0600", info.Mode().Perm()) + } +} diff --git a/cmd/initdata/dump.go b/cmd/initdata/dump.go new file mode 100644 index 0000000..3f925f8 --- /dev/null +++ b/cmd/initdata/dump.go @@ -0,0 +1,72 @@ +package initdata + +import ( + "bytes" + "compress/gzip" + "encoding/base64" + "fmt" + "os" + "path/filepath" + + "github.com/spf13/cobra" +) + +var dumpCmd = &cobra.Command{ + Use: "dump", + Short: "Display initdata from saved TOML file", + Long: `Display initdata from the saved raw TOML file. + +Default output is the base64+gzip encoded blob ready for use as the +io.katacontainers.config.hypervisor.cc_init_data annotation. + +Use --raw to output the plaintext TOML instead. + +Examples: + kubectl coco initdata dump + kubectl coco initdata dump --raw + kubectl coco initdata dump --file /path/to/initdata.toml`, + RunE: runDump, +} + +var ( + dumpFile string + dumpRaw bool +) + +func init() { + dumpCmd.Flags().StringVar(&dumpFile, "file", "", "Path to raw initdata TOML file (default: ~/.kube/coco-initdata.toml)") + dumpCmd.Flags().BoolVar(&dumpRaw, "raw", false, "Output plaintext TOML instead of encoded blob") +} + +func runDump(_ *cobra.Command, _ []string) error { + filePath := dumpFile + if filePath == "" { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to get home directory: %w", err) + } + filePath = filepath.Join(home, ".kube", "coco-initdata.toml") + } + + // #nosec G304 -- path comes from --file flag or defaults to ~/.kube/coco-initdata.toml + data, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("failed to read %s: %w", filePath, err) + } + + if dumpRaw { + _, err = os.Stdout.Write(data) + return err + } + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + if _, err := gz.Write(data); err != nil { + return fmt.Errorf("failed to compress: %w", err) + } + if err := gz.Close(); err != nil { + return fmt.Errorf("failed to close gzip writer: %w", err) + } + fmt.Println(base64.StdEncoding.EncodeToString(buf.Bytes())) + return nil +} diff --git a/cmd/initdata/dump_test.go b/cmd/initdata/dump_test.go new file mode 100644 index 0000000..a1382ea --- /dev/null +++ b/cmd/initdata/dump_test.go @@ -0,0 +1,74 @@ +package initdata + +import ( + "bytes" + "encoding/base64" + "os" + "strings" + "testing" + + "github.com/pelletier/go-toml/v2" + + pkginitdata "github.com/confidential-devhub/cococtl/pkg/initdata" +) + +func captureStdout(t *testing.T, fn func()) string { + t.Helper() + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("pipe: %v", err) + } + old := os.Stdout + os.Stdout = w + fn() + _ = w.Close() + os.Stdout = old + var buf bytes.Buffer + _, _ = buf.ReadFrom(r) + _ = r.Close() + return buf.String() +} + +func TestRunDump_Encoded(t *testing.T) { + dumpFile = "testdata/valid.toml" + dumpRaw = false + defer func() { dumpFile = ""; dumpRaw = false }() + + var runErr error + output := captureStdout(t, func() { runErr = runDump(nil, nil) }) + if runErr != nil { + t.Fatalf("runDump() error: %v", runErr) + } + trimmed := strings.TrimSpace(output) + if _, err := base64.StdEncoding.DecodeString(trimmed); err != nil { + t.Errorf("output is not valid base64: %v", err) + } +} + +func TestRunDump_Raw(t *testing.T) { + dumpFile = "testdata/valid.toml" + dumpRaw = true + defer func() { dumpFile = ""; dumpRaw = false }() + + var runErr error + output := captureStdout(t, func() { runErr = runDump(nil, nil) }) + if runErr != nil { + t.Fatalf("runDump --raw error: %v", runErr) + } + var id pkginitdata.InitData + if err := toml.Unmarshal([]byte(output), &id); err != nil { + t.Fatalf("--raw output is not valid TOML: %v", err) + } + if id.Version != pkginitdata.InitDataVersion { + t.Errorf("version = %q, want %q", id.Version, pkginitdata.InitDataVersion) + } +} + +func TestRunDump_MissingFile(t *testing.T) { + dumpFile = "/nonexistent/initdata.toml" + dumpRaw = false + defer func() { dumpFile = "" }() + if err := runDump(nil, nil); err == nil { + t.Fatal("expected error for missing file") + } +} diff --git a/cmd/initdata/initdata.go b/cmd/initdata/initdata.go new file mode 100644 index 0000000..3002813 --- /dev/null +++ b/cmd/initdata/initdata.go @@ -0,0 +1,21 @@ +package initdata + +import "github.com/spf13/cobra" + +// InitdataCmd is the root command for initdata operations. +var InitdataCmd = &cobra.Command{ + Use: "initdata", + Short: "Manage initdata for Confidential Containers", + Long: `Commands for creating, inspecting, and validating initdata. + +Available subcommands: + create Generate initdata TOML from CoCo config and save to disk + dump Display initdata as base64+gzip blob or plaintext TOML + validate Validate initdata structure and embedded certificates`, +} + +func init() { + InitdataCmd.AddCommand(createCmd) + InitdataCmd.AddCommand(dumpCmd) + InitdataCmd.AddCommand(validateCmd) +} diff --git a/cmd/initdata/testdata/invalid-algorithm.toml b/cmd/initdata/testdata/invalid-algorithm.toml new file mode 100644 index 0000000..4304464 --- /dev/null +++ b/cmd/initdata/testdata/invalid-algorithm.toml @@ -0,0 +1,25 @@ +version = "0.1.0" +algorithm = "md5" + +[data] +"aa.toml" = ''' +[token_configs] + +[token_configs.kbs] +url = "http://kbs.example.svc:8080" + +''' + +"cdh.toml" = ''' +[kbc] +name = "cc_kbc" +url = "http://kbs.example.svc:8080" + +''' + +"policy.rego" = ''' +package agent_policy + +default ExecProcessRequest := false + +''' diff --git a/cmd/initdata/testdata/invalid-missing-cdh.toml b/cmd/initdata/testdata/invalid-missing-cdh.toml new file mode 100644 index 0000000..949e673 --- /dev/null +++ b/cmd/initdata/testdata/invalid-missing-cdh.toml @@ -0,0 +1,11 @@ +version = "0.1.0" +algorithm = "sha256" + +[data] +"aa.toml" = ''' +[token_configs] + +[token_configs.kbs] +url = "http://kbs.example.svc:8080" + +''' diff --git a/cmd/initdata/testdata/invalid-version.toml b/cmd/initdata/testdata/invalid-version.toml new file mode 100644 index 0000000..4285ab9 --- /dev/null +++ b/cmd/initdata/testdata/invalid-version.toml @@ -0,0 +1,25 @@ +version = "9.9.9" +algorithm = "sha256" + +[data] +"aa.toml" = ''' +[token_configs] + +[token_configs.kbs] +url = "http://kbs.example.svc:8080" + +''' + +"cdh.toml" = ''' +[kbc] +name = "cc_kbc" +url = "http://kbs.example.svc:8080" + +''' + +"policy.rego" = ''' +package agent_policy + +default ExecProcessRequest := false + +''' diff --git a/cmd/initdata/testdata/valid-no-policy.toml b/cmd/initdata/testdata/valid-no-policy.toml new file mode 100644 index 0000000..4cdace5 --- /dev/null +++ b/cmd/initdata/testdata/valid-no-policy.toml @@ -0,0 +1,18 @@ +version = "0.1.0" +algorithm = "sha256" + +[data] +"aa.toml" = ''' +[token_configs] + +[token_configs.kbs] +url = "http://kbs.example.svc:8080" + +''' + +"cdh.toml" = ''' +[kbc] +name = "cc_kbc" +url = "http://kbs.example.svc:8080" + +''' diff --git a/cmd/initdata/testdata/valid.toml b/cmd/initdata/testdata/valid.toml new file mode 100644 index 0000000..0e916f8 --- /dev/null +++ b/cmd/initdata/testdata/valid.toml @@ -0,0 +1,26 @@ +version = "0.1.0" +algorithm = "sha256" + +[data] +"aa.toml" = ''' +[token_configs] + +[token_configs.kbs] +url = "http://kbs.example.svc:8080" + +''' + +"cdh.toml" = ''' +[kbc] +name = "cc_kbc" +url = "http://kbs.example.svc:8080" + +''' + +"policy.rego" = ''' +package agent_policy + +default CreateContainerRequest := true +default ExecProcessRequest := false + +''' diff --git a/cmd/initdata/validate.go b/cmd/initdata/validate.go new file mode 100644 index 0000000..f2c222d --- /dev/null +++ b/cmd/initdata/validate.go @@ -0,0 +1,78 @@ +package initdata + +import ( + "fmt" + "os" + "strings" + + pkginitdata "github.com/confidential-devhub/cococtl/pkg/initdata" + "github.com/pelletier/go-toml/v2" + "github.com/spf13/cobra" +) + +var validateCmd = &cobra.Command{ + Use: "validate", + Short: "Validate initdata structure and embedded certificates", + Long: `Validate an initdata for structural correctness and certificate validity. + +Reads from --file (plaintext TOML) or stdin (base64+gzip encoded blob). + +Checks: + - TOML parses cleanly + - version == "0.1.0" and algorithm == "sha256" + - aa.toml and cdh.toml are present (policy.rego is optional) + - Embedded CA certs in aa.toml / cdh.toml pass validation + +Examples: + kubectl coco initdata validate --file ~/.kube/coco-initdata.toml + kubectl coco initdata dump | kubectl coco initdata validate`, + RunE: runValidate, +} + +var validateFile string + +func init() { + validateCmd.Flags().StringVar(&validateFile, "file", "", "Path to plaintext initdata TOML file (reads encoded blob from stdin if not set)") +} + +func runValidate(_ *cobra.Command, _ []string) error { + tomlBytes, err := loadInitdataTOML(validateFile, os.Stdin) + if err != nil { + return fmt.Errorf("failed to load initdata: %w", err) + } + + var id pkginitdata.InitData + if err := toml.Unmarshal(tomlBytes, &id); err != nil { + return fmt.Errorf("failed to parse TOML: %w", err) + } + + var failures []string + + if id.Version != pkginitdata.InitDataVersion { + failures = append(failures, fmt.Sprintf("version: got %q, want %q", id.Version, pkginitdata.InitDataVersion)) + } + if id.Algorithm != pkginitdata.InitDataAlgorithm { + failures = append(failures, fmt.Sprintf("algorithm: got %q, want %q", id.Algorithm, pkginitdata.InitDataAlgorithm)) + } + for _, key := range []string{"aa.toml", "cdh.toml"} { + if _, ok := id.Data[key]; !ok { + failures = append(failures, fmt.Sprintf("missing required data key: %s", key)) + } + } + + certs, err := extractCertsFromInitdata(id.Data) + if err != nil { + failures = append(failures, fmt.Sprintf("cert extraction failed: %v", err)) + } else if len(certs) > 0 { + if err := validateCerts(certs); err != nil { + failures = append(failures, err.Error()) + } + } + + if len(failures) > 0 { + return fmt.Errorf("validation failed:\n %s", strings.Join(failures, "\n ")) + } + + fmt.Println("Validation passed.") + return nil +} diff --git a/cmd/initdata/validate_test.go b/cmd/initdata/validate_test.go new file mode 100644 index 0000000..a816518 --- /dev/null +++ b/cmd/initdata/validate_test.go @@ -0,0 +1,141 @@ +package initdata + +import ( + "bytes" + "compress/gzip" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/pem" + "math/big" + "os" + "strings" + "testing" + "time" + + "github.com/pelletier/go-toml/v2" + + pkginitdata "github.com/confidential-devhub/cococtl/pkg/initdata" +) + +func encodeBlobFromFile(t *testing.T, path string) string { + t.Helper() + raw, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read fixture: %v", err) + } + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write(raw) + _ = gz.Close() + return base64.StdEncoding.EncodeToString(buf.Bytes()) +} + +func TestRunValidate_ValidFile(t *testing.T) { + validateFile = "testdata/valid.toml" + defer func() { validateFile = "" }() + if err := runValidate(nil, nil); err != nil { + t.Errorf("runValidate() unexpected error: %v", err) + } +} + +func TestRunValidate_ValidNoPolicyRego(t *testing.T) { + validateFile = "testdata/valid-no-policy.toml" + defer func() { validateFile = "" }() + if err := runValidate(nil, nil); err != nil { + t.Errorf("runValidate() should accept missing policy.rego: %v", err) + } +} + +func TestRunValidate_FromStdin(t *testing.T) { + encoded := encodeBlobFromFile(t, "testdata/valid.toml") + validateFile = "" + defer func() { validateFile = "" }() + + r, w, _ := os.Pipe() + _, _ = w.WriteString(encoded) + _ = w.Close() + oldStdin := os.Stdin + os.Stdin = r + defer func() { os.Stdin = oldStdin; _ = r.Close() }() + + if err := runValidate(nil, nil); err != nil { + t.Errorf("runValidate() from stdin error: %v", err) + } +} + +func TestRunValidate_WrongVersion(t *testing.T) { + validateFile = "testdata/invalid-version.toml" + defer func() { validateFile = "" }() + + err := runValidate(nil, nil) + if err == nil || !strings.Contains(err.Error(), "version") { + t.Errorf("expected version error, got: %v", err) + } +} + +func TestRunValidate_WrongAlgorithm(t *testing.T) { + validateFile = "testdata/invalid-algorithm.toml" + defer func() { validateFile = "" }() + + err := runValidate(nil, nil) + if err == nil || !strings.Contains(err.Error(), "algorithm") { + t.Errorf("expected algorithm error, got: %v", err) + } +} + +func TestRunValidate_MissingRequiredKey(t *testing.T) { + validateFile = "testdata/invalid-missing-cdh.toml" + defer func() { validateFile = "" }() + + err := runValidate(nil, nil) + if err == nil || !strings.Contains(err.Error(), "cdh.toml") { + t.Errorf("expected missing cdh.toml error, got: %v", err) + } +} + +func TestRunValidate_InvalidEmbeddedCert(t *testing.T) { + caKey, _ := rsa.GenerateKey(rand.Reader, 2048) + caTmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: "CA"}, + NotBefore: time.Now().Add(-time.Hour), NotAfter: time.Now().Add(time.Hour), + IsCA: true, BasicConstraintsValid: true, KeyUsage: x509.KeyUsageCertSign, + } + caDER, _ := x509.CreateCertificate(rand.Reader, caTmpl, caTmpl, &caKey.PublicKey, caKey) + caCert, _ := x509.ParseCertificate(caDER) + leafKey, _ := rsa.GenerateKey(rand.Reader, 2048) + leafTmpl := &x509.Certificate{ + SerialNumber: big.NewInt(2), Subject: pkix.Name{CommonName: "Bad Leaf"}, + NotBefore: time.Now().Add(-time.Hour), NotAfter: time.Now().Add(time.Hour), + IsCA: false, BasicConstraintsValid: true, + } + leafDER, _ := x509.CreateCertificate(rand.Reader, leafTmpl, caCert, &leafKey.PublicKey, caKey) + leafPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: leafDER}) + + id := pkginitdata.InitData{ + Version: pkginitdata.InitDataVersion, + Algorithm: pkginitdata.InitDataAlgorithm, + Data: map[string]string{ + "aa.toml": "[token_configs]\n[token_configs.kbs]\nurl = \"http://kbs.test:8080\"\ncert = \"\"\"\n" + + string(leafPEM) + "\n\"\"\"\n", + "cdh.toml": "[kbc]\nname = \"cc_kbc\"\nurl = \"http://kbs.test:8080\"\n", + }, + } + raw, err := toml.Marshal(id) + if err != nil { + t.Fatalf("marshal: %v", err) + } + path := t.TempDir() + "/initdata.toml" + if err := os.WriteFile(path, raw, 0600); err != nil { + t.Fatalf("write fixture: %v", err) + } + validateFile = path + defer func() { validateFile = "" }() + + err = runValidate(nil, nil) + if err == nil || !strings.Contains(err.Error(), "cert validation failed") { + t.Errorf("expected cert validation error, got: %v", err) + } +} diff --git a/cmd/root.go b/cmd/root.go index 49ecd00..766c4aa 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" + "github.com/confidential-devhub/cococtl/cmd/initdata" "github.com/confidential-devhub/cococtl/cmd/kbs" ) @@ -35,6 +36,7 @@ func Execute() error { func init() { cobra.OnInitialize() rootCmd.AddCommand(kbs.KbsCmd) + rootCmd.AddCommand(initdata.InitdataCmd) } // contextKey is the type for context keys used in cococtl diff --git a/pkg/initdata/initdata.go b/pkg/initdata/initdata.go index 0c516df..4baca3e 100644 --- a/pkg/initdata/initdata.go +++ b/pkg/initdata/initdata.go @@ -6,8 +6,10 @@ import ( "compress/gzip" "encoding/base64" "fmt" + "io" "os" "path/filepath" + "sort" "strings" "github.com/confidential-devhub/cococtl/pkg/config" @@ -35,45 +37,52 @@ type ImagePullSecretInfo struct { Key string } -// Generate creates initdata based on the CoCo configuration -// imagePullSecrets is optional - pass nil if no imagePullSecrets need to be added +// Generate creates initdata based on the CoCo configuration. func Generate(cfg *config.CocoConfig, imagePullSecrets []ImagePullSecretInfo) (string, error) { + raw, err := GenerateRaw(cfg, "", imagePullSecrets) + if err != nil { + return "", err + } + encoded, err := compressAndEncode(raw) + if err != nil { + return "", fmt.Errorf("failed to compress and encode initdata: %w", err) + } + return encoded, nil +} + +// GenerateRaw returns the raw initdata TOML bytes without gzip/base64 encoding. +// When certPEM is non-empty it is used directly instead of reading cfg.TrusteeCACert. +func GenerateRaw(cfg *config.CocoConfig, certPEM string, imagePullSecrets []ImagePullSecretInfo) ([]byte, error) { if cfg.TrusteeServer == "" { - return "", fmt.Errorf("trustee server URL is required for initdata generation") + return nil, fmt.Errorf("trustee server URL is required for initdata generation") } - // Read the CA cert once so that both aa.toml and cdh.toml are guaranteed - // to embed identical content (avoids a TOCTOU window from multiple reads). - var caCert string - if cfg.TrusteeCACert != "" { + caCert := certPEM + if caCert == "" && cfg.TrusteeCACert != "" { raw, err := os.ReadFile(cfg.TrusteeCACert) if err != nil { - return "", fmt.Errorf("failed to read CA cert from %q: %w", cfg.TrusteeCACert, err) + return nil, fmt.Errorf("failed to read CA cert from %q: %w", cfg.TrusteeCACert, err) } caCert = string(raw) } - // Generate aa.toml (Attestation Agent configuration) aaToml, err := generateAAToml(cfg, caCert) if err != nil { - return "", fmt.Errorf("failed to generate aa.toml: %w", err) + return nil, fmt.Errorf("failed to generate aa.toml: %w", err) } - // Generate cdh.toml (Confidential Data Hub configuration) cdhToml, err := generateCDHToml(cfg, caCert, imagePullSecrets) if err != nil { - return "", fmt.Errorf("failed to generate cdh.toml: %w", err) + return nil, fmt.Errorf("failed to generate cdh.toml: %w", err) } - // Get policy.rego var policy string if cfg.KataAgentPolicy != "" { policy, err = loadPolicyFile(cfg.KataAgentPolicy) if err != nil { - return "", fmt.Errorf("failed to load policy file: %w", err) + return nil, fmt.Errorf("failed to load policy file: %w", err) } } else { - // Use default restrictive policy (exec disabled, logs enabled) policy = getDefaultPolicy() } @@ -87,18 +96,53 @@ func Generate(cfg *config.CocoConfig, imagePullSecrets []ImagePullSecretInfo) (s }, } - tomlData, err := toml.Marshal(id) - if err != nil { - return "", fmt.Errorf("failed to marshal initdata: %w", err) + return marshalInitData(id) +} + +// marshalInitData serialises InitData to TOML using ''' literal multi-line strings +// for data values so the output is human-readable without escape sequences. +func marshalInitData(id InitData) ([]byte, error) { + var sb strings.Builder + fmt.Fprintf(&sb, "version = %q\n", id.Version) + fmt.Fprintf(&sb, "algorithm = %q\n", id.Algorithm) + sb.WriteString("\n[data]\n") + + keys := make([]string, 0, len(id.Data)) + for k := range id.Data { + keys = append(keys, k) } + sort.Strings(keys) - // Compress with gzip and encode to base64 - encoded, err := compressAndEncode(tomlData) - if err != nil { - return "", fmt.Errorf("failed to compress and encode initdata: %w", err) + for _, k := range keys { + v := id.Data[k] + if !strings.HasSuffix(v, "\n") { + v += "\n" + } + fmt.Fprintf(&sb, "\n%q = '''\n%s'''\n", k, v) } + return []byte(sb.String()), nil +} - return encoded, nil +// Decode decodes a base64+gzip encoded initdata string and returns the data map. +func Decode(encoded string) (map[string]string, error) { + gzipData, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("failed to decode base64: %w", err) + } + gzipReader, err := gzip.NewReader(bytes.NewReader(gzipData)) + if err != nil { + return nil, fmt.Errorf("failed to create gzip reader: %w", err) + } + defer func() { _ = gzipReader.Close() }() + tomlData, err := io.ReadAll(gzipReader) + if err != nil { + return nil, fmt.Errorf("failed to decompress gzip data: %w", err) + } + var id InitData + if err := toml.Unmarshal(tomlData, &id); err != nil { + return nil, fmt.Errorf("failed to parse initdata TOML: %w", err) + } + return id.Data, nil } // generateAAToml creates the Attestation Agent configuration. @@ -152,9 +196,7 @@ func generateCDHToml(cfg *config.CocoConfig, caCert string, imagePullSecrets []I // Add authenticated registry credentials URI // Priority: imagePullSecrets (dynamic) > config.RegistryCredURI (static) if len(imagePullSecrets) > 0 { - // CDH spec only supports a single authenticated_registry_credentials_uri - // Use the first (and typically only) imagePullSecret - // Format: kbs:///namespace/secret-name/key + // CDH spec only supports one URI; use the first entry. ips := imagePullSecrets[0] uri := fmt.Sprintf("kbs:///%s/%s/%s", ips.Namespace, ips.SecretName, ips.Key) imageConfig["authenticated_registry_credentials_uri"] = uri @@ -187,34 +229,20 @@ func generateCDHToml(cfg *config.CocoConfig, caCert string, imagePullSecrets []I // loadPolicyFile reads a policy file from disk func loadPolicyFile(path string) (string, error) { - // Validate and sanitize the path to prevent directory traversal - // Source - https://stackoverflow.com/a/57534618 - // Posted by Kenny Grant, modified by community. See post 'Timeline' for change history - // Retrieved 2025-11-14, License - CC BY-SA 4.0 cleanPath := filepath.Clean(path) - - // For absolute paths, validate they don't escape the filesystem root - // For relative paths, ensure they're relative to current directory - if filepath.IsAbs(cleanPath) { - // Absolute paths are allowed for policy files - // but ensure path doesn't contain traversal attempts - if strings.Contains(path, "..") { - return "", fmt.Errorf("invalid policy path: contains directory traversal") - } - } else { - // For relative paths, ensure they resolve within current directory + if !filepath.IsAbs(cleanPath) { cwd, err := os.Getwd() if err != nil { return "", fmt.Errorf("failed to get current directory: %w", err) } absPath := filepath.Join(cwd, cleanPath) - if !strings.HasPrefix(absPath, cwd) { - return "", fmt.Errorf("invalid policy path: escapes current directory") + rel, err := filepath.Rel(cwd, absPath) + if err != nil || strings.HasPrefix(rel, "..") { + return "", fmt.Errorf("policy path %q escapes working directory", path) } cleanPath = absPath } - - // #nosec G304 - Path is validated above + // #nosec G304 -- path is cleaned and validated; absolute paths are intentionally allowed data, err := os.ReadFile(cleanPath) if err != nil { return "", err diff --git a/pkg/initdata/initdata_test.go b/pkg/initdata/initdata_test.go new file mode 100644 index 0000000..baa7b3c --- /dev/null +++ b/pkg/initdata/initdata_test.go @@ -0,0 +1,125 @@ +package initdata + +import ( + "os" + "strings" + "testing" + + "github.com/confidential-devhub/cococtl/pkg/config" + "github.com/pelletier/go-toml/v2" +) + +func minimalCfg() *config.CocoConfig { + return &config.CocoConfig{ + TrusteeServer: "http://kbs.test.svc:8080", + RuntimeClass: "kata-cc", + } +} + +func TestGenerateRaw_ReturnsValidTOML(t *testing.T) { + raw, err := GenerateRaw(minimalCfg(), "", nil) + if err != nil { + t.Fatalf("GenerateRaw() error: %v", err) + } + var id InitData + if err := toml.Unmarshal(raw, &id); err != nil { + t.Fatalf("output is not valid TOML: %v", err) + } + if id.Version != InitDataVersion { + t.Errorf("version = %q, want %q", id.Version, InitDataVersion) + } + if id.Algorithm != InitDataAlgorithm { + t.Errorf("algorithm = %q, want %q", id.Algorithm, InitDataAlgorithm) + } + for _, key := range []string{"aa.toml", "cdh.toml", "policy.rego"} { + if _, ok := id.Data[key]; !ok { + t.Errorf("data[%q] missing", key) + } + } +} + +func TestGenerateRaw_WithCertPEM(t *testing.T) { + const fakePEM = "FAKECERT" + raw, err := GenerateRaw(minimalCfg(), fakePEM, nil) + if err != nil { + t.Fatalf("GenerateRaw() error: %v", err) + } + if !strings.Contains(string(raw), fakePEM) { + t.Error("cert PEM not found in raw output") + } +} + +func TestGenerateRaw_NoCert_Succeeds(t *testing.T) { + raw, err := GenerateRaw(minimalCfg(), "", nil) + if err != nil { + t.Fatalf("GenerateRaw() error: %v", err) + } + if len(raw) == 0 { + t.Error("expected non-empty output") + } +} + +func TestGenerateRaw_ReadsCertFromFile(t *testing.T) { + const fakePEM = "FILECERT" + f, err := os.CreateTemp(t.TempDir(), "ca-*.pem") + if err != nil { + t.Fatal(err) + } + if _, err := f.WriteString(fakePEM); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + + cfg := minimalCfg() + cfg.TrusteeCACert = f.Name() + + raw, err := GenerateRaw(cfg, "", nil) + if err != nil { + t.Fatalf("GenerateRaw() error: %v", err) + } + if !strings.Contains(string(raw), fakePEM) { + t.Error("cert read from file not found in raw output") + } +} + +func TestGenerateRaw_RequiresTrusteeServer(t *testing.T) { + _, err := GenerateRaw(&config.CocoConfig{}, "", nil) + if err == nil { + t.Fatal("expected error for empty TrusteeServer") + } +} + +func TestGenerate_StillWorks(t *testing.T) { + encoded, err := Generate(minimalCfg(), nil) + if err != nil { + t.Fatalf("Generate() error: %v", err) + } + if encoded == "" { + t.Error("Generate() returned empty string") + } +} + +func TestDecode_RoundTrip(t *testing.T) { + encoded, err := Generate(minimalCfg(), nil) + if err != nil { + t.Fatalf("Generate() error: %v", err) + } + data, err := Decode(encoded) + if err != nil { + t.Fatalf("Decode() error: %v", err) + } + for _, key := range []string{"aa.toml", "cdh.toml", "policy.rego"} { + if _, ok := data[key]; !ok { + t.Errorf("Decode() missing key %q", key) + } + } +} + +func TestDecode_InvalidBase64(t *testing.T) { + _, err := Decode("not-valid-base64!!!") + if err == nil { + t.Fatal("expected error for invalid base64") + } +}