diff --git a/upup/pkg/fi/http.go b/upup/pkg/fi/http.go index 89b458af92440..a5bbaac750fbd 100644 --- a/upup/pkg/fi/http.go +++ b/upup/pkg/fi/http.go @@ -23,13 +23,15 @@ import ( "net" "net/http" "os" - "path" + "path/filepath" "time" "k8s.io/klog/v2" "k8s.io/kops/util/pkg/hashing" ) +const downloadTimeout = 3 * time.Minute + // DownloadURL will download the file at the given url and store it as dest. // If hash is non-nil, it will also verify that it matches the hash of the downloaded file. func DownloadURL(url string, dest string, hash *hashing.Hash) (*hashing.Hash, error) { @@ -44,7 +46,7 @@ func DownloadURL(url string, dest string, hash *hashing.Hash) (*hashing.Hash, er } dirMode := os.FileMode(0o755) - err := downloadURLAlways(url, dest, dirMode) + err := downloadURLAlways(url, dest, dirMode, hash) if err != nil { return nil, err } @@ -67,63 +69,121 @@ func DownloadURL(url string, dest string, hash *hashing.Hash) (*hashing.Hash, er return hash, nil } -func downloadURLAlways(url string, destPath string, dirMode os.FileMode) error { - err := os.MkdirAll(path.Dir(destPath), dirMode) +func downloadURLAlways(url string, destPath string, dirMode os.FileMode, hash *hashing.Hash) error { + dir := filepath.Dir(destPath) + err := os.MkdirAll(dir, dirMode) if err != nil { return fmt.Errorf("error creating directories for destination file %q: %v", destPath, err) } - output, err := os.Create(destPath) + output, err := os.CreateTemp(dir, "."+filepath.Base(destPath)+".tmp") + if err != nil { + return fmt.Errorf("error creating temporary file for download %q: %v", destPath, err) + } + tempPath := output.Name() + defer os.Remove(tempPath) + + _, err = DownloadURLToWriter(url, output, hash) + if closeErr := output.Close(); closeErr != nil && err == nil { + err = closeErr + } + if err != nil { + return err + } + if err := os.Chmod(tempPath, 0o644); err != nil { + return fmt.Errorf("error setting mode on downloaded file %q: %v", tempPath, err) + } + if err := os.Rename(tempPath, destPath); err != nil { + return fmt.Errorf("error moving downloaded file %q to %q: %v", tempPath, destPath, err) + } + return nil +} + +// DownloadURLToWriter streams the file at the given url to dest. +// If hash is non-nil, it will also verify that it matches the downloaded bytes. +func DownloadURLToWriter(url string, dest io.Writer, hash *hashing.Hash) (*hashing.Hash, error) { + responseBody, err := OpenURL(url) if err != nil { - return fmt.Errorf("error creating file for download %q: %v", destPath, err) + return nil, err } - defer output.Close() + defer responseBody.Close() klog.V(2).Infof("Downloading %q", url) - // Create a client with custom timeouts - // to avoid idle downloads to hang the program - httpClient := &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 30 * time.Second, - IdleConnTimeout: 30 * time.Second, - }, + start := time.Now() + defer func() { + klog.V(2).Infof("Downloading %q took %q", url, time.Since(start)) + }() + + algorithm := hashing.HashAlgorithmSHA256 + if hash != nil { + algorithm = hash.Algorithm + } + hasher := algorithm.NewHasher() + writer := io.MultiWriter(dest, hasher) + + if _, err := io.Copy(writer, responseBody); err != nil { + return nil, fmt.Errorf("error downloading HTTP content from %q: %v", url, err) } - // this will stop slow downloads after 3 minutes - // and interrupt reading of the Response.Body - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) - defer cancel() + actual := &hashing.Hash{ + Algorithm: algorithm, + HashValue: hasher.Sum(nil), + } + if hash != nil && !actual.Equal(hash) { + return nil, fmt.Errorf("downloaded from %q but hash did not match expected %q", url, hash) + } + return actual, nil +} + +// OpenURL opens a hardened HTTP GET stream for url. +func OpenURL(url string) (io.ReadCloser, error) { + httpClient := newDownloadHTTPClient() + ctx, cancel := context.WithTimeout(context.Background(), downloadTimeout) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - return fmt.Errorf("Cannot create request: %v", err) + cancel() + return nil, fmt.Errorf("cannot create request: %v", err) } response, err := httpClient.Do(req) if err != nil { - return fmt.Errorf("error doing HTTP fetch of %q: %v", url, err) + cancel() + return nil, fmt.Errorf("error doing HTTP fetch of %q: %v", url, err) } - defer response.Body.Close() - if response.StatusCode >= 400 { - return fmt.Errorf("error response from %q: HTTP %v", url, response.StatusCode) + if response.StatusCode < 200 || response.StatusCode > 299 { + response.Body.Close() + cancel() + return nil, fmt.Errorf("unexpected response from %q: HTTP %s", url, response.Status) } - start := time.Now() - defer func() { - klog.V(2).Infof("Copying %q to %q took %q", url, destPath, time.Since(start)) - }() + return &cancelOnCloseReadCloser{ReadCloser: response.Body, cancel: cancel}, nil +} - _, err = io.Copy(output, response.Body) - if err != nil { - return fmt.Errorf("error downloading HTTP content from %q: %v", url, err) +func newDownloadHTTPClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + IdleConnTimeout: 30 * time.Second, + }, } - return nil +} + +type cancelOnCloseReadCloser struct { + io.ReadCloser + cancel context.CancelFunc +} + +func (r *cancelOnCloseReadCloser) Close() error { + err := r.ReadCloser.Close() + r.cancel() + return err } diff --git a/upup/pkg/fi/http_test.go b/upup/pkg/fi/http_test.go new file mode 100644 index 0000000000000..893ad3b785a2c --- /dev/null +++ b/upup/pkg/fi/http_test.go @@ -0,0 +1,78 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fi + +import ( + "bytes" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "k8s.io/kops/util/pkg/hashing" +) + +func TestDownloadURLRejectsNon2xxAndPreservesDestination(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusFound) + _, _ = w.Write([]byte("redirect body")) + })) + defer server.Close() + + dest := filepath.Join(t.TempDir(), "download") + if err := os.WriteFile(dest, []byte("original"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + if _, err := DownloadURL(server.URL, dest, nil); err == nil { + t.Fatalf("DownloadURL() expected error") + } + + actual, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + if string(actual) != "original" { + t.Fatalf("download destination = %q, expected original contents", actual) + } +} + +func TestDownloadURLToWriterVerifiesHash(t *testing.T) { + body := []byte("payload") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(body) + })) + defer server.Close() + + expectedHash, err := hashing.HashAlgorithmSHA256.Hash(bytes.NewReader(body)) + if err != nil { + t.Fatalf("Hash() error = %v", err) + } + + var output bytes.Buffer + actualHash, err := DownloadURLToWriter(server.URL, &output, expectedHash) + if err != nil { + t.Fatalf("DownloadURLToWriter() error = %v", err) + } + if !actualHash.Equal(expectedHash) { + t.Fatalf("DownloadURLToWriter() hash = %v, expected %v", actualHash, expectedHash) + } + if !bytes.Equal(output.Bytes(), body) { + t.Fatalf("DownloadURLToWriter() body = %q, expected %q", output.Bytes(), body) + } +} diff --git a/upup/pkg/fi/nodeup/nodetasks/archive.go b/upup/pkg/fi/nodeup/nodetasks/archive.go index bcd90f030681b..a554e32c0ee52 100644 --- a/upup/pkg/fi/nodeup/nodetasks/archive.go +++ b/upup/pkg/fi/nodeup/nodetasks/archive.go @@ -20,11 +20,9 @@ import ( "encoding/json" "fmt" "os" - "os/exec" "path" "path/filepath" "reflect" - "strconv" "strings" "k8s.io/klog/v2" @@ -159,16 +157,8 @@ func (_ *Archive) RenderLocal(t *local.LocalTarget, a, e, changes *Archive) erro if err := os.MkdirAll(targetDir, 0o755); err != nil { return fmt.Errorf("error creating directories %q: %v", targetDir, err) } - - args := []string{"tar", "xf", localFile, "-C", targetDir} - if e.StripComponents != 0 { - args = append(args, "--strip-components="+strconv.Itoa(e.StripComponents)) - } - - klog.Infof("running command %s", args) - cmd := exec.Command(args[0], args[1:]...) - if output, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("error installing archive %q: %v: %s", e.Name, err, string(output)) + if err := extractArchive(localFile, targetDir, e.StripComponents, ""); err != nil { + return fmt.Errorf("error installing archive %q: %v", e.Name, err) } } else { for src, dest := range e.MapFiles { @@ -177,13 +167,8 @@ func (_ *Archive) RenderLocal(t *local.LocalTarget, a, e, changes *Archive) erro if err := os.MkdirAll(targetDir, 0o755); err != nil { return fmt.Errorf("error creating directories %q: %v", targetDir, err) } - - args := []string{"tar", "xf", localFile, "-C", targetDir, "--wildcards", "--strip-components=" + strconv.Itoa(stripCount), src} - - klog.Infof("running command %s", args) - cmd := exec.Command(args[0], args[1:]...) - if output, err := cmd.CombinedOutput(); err != nil { - return fmt.Errorf("error installing archive %q: %v: %s", e.Name, err, string(output)) + if err := extractArchive(localFile, targetDir, stripCount, src); err != nil { + return fmt.Errorf("error installing archive %q: %v", e.Name, err) } } } diff --git a/upup/pkg/fi/nodeup/nodetasks/archive_extract.go b/upup/pkg/fi/nodeup/nodetasks/archive_extract.go new file mode 100644 index 0000000000000..259a2e7b6d8f5 --- /dev/null +++ b/upup/pkg/fi/nodeup/nodetasks/archive_extract.go @@ -0,0 +1,360 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package nodetasks + +import ( + "archive/tar" + "fmt" + "io" + "os" + "path" + "path/filepath" + "strings" +) + +func extractArchive(archivePath, targetDir string, stripComponents int, pattern string) error { + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return fmt.Errorf("error creating directories %q: %v", targetDir, err) + } + + targetRoot, err := filepath.Abs(targetDir) + if err != nil { + return fmt.Errorf("error resolving target directory %q: %v", targetDir, err) + } + targetRoot, err = filepath.EvalSymlinks(targetRoot) + if err != nil { + return fmt.Errorf("error resolving target directory %q: %v", targetDir, err) + } + + tarReader, closeArchive, err := openArchive(archivePath) + if err != nil { + return err + } + defer closeArchive() + + for { + header, err := tarReader.Next() + if err == io.EOF { + return nil + } + if err != nil { + return fmt.Errorf("error reading archive %q: %v", archivePath, err) + } + if pattern != "" && !archivePathMatches(pattern, header.Name) { + continue + } + if err := extractArchiveEntry(tarReader, header, targetRoot, stripComponents); err != nil { + return fmt.Errorf("error extracting %q from %q: %v", header.Name, archivePath, err) + } + } +} + +func openArchive(archivePath string) (*tar.Reader, func() error, error) { + file, err := os.Open(archivePath) + if err != nil { + return nil, nil, fmt.Errorf("error opening archive %q: %v", archivePath, err) + } + + reader, gzipCloser, err := maybeGzipReader(file) + if err != nil { + file.Close() + return nil, nil, fmt.Errorf("error reading archive %q: %v", archivePath, err) + } + + closeArchive := func() error { + if gzipCloser != nil { + if err := gzipCloser.Close(); err != nil { + _ = file.Close() + return err + } + } + return file.Close() + } + return tar.NewReader(reader), closeArchive, nil +} + +func archivePathMatches(pattern, name string) bool { + if pattern == name { + return true + } + matches, err := path.Match(pattern, name) + return err == nil && matches +} + +func extractArchiveEntry(reader io.Reader, header *tar.Header, targetRoot string, stripComponents int) error { + targetPath, ok, err := archiveTargetPath(targetRoot, header.Name, stripComponents) + if err != nil || !ok { + return err + } + + switch header.Typeflag { + case tar.TypeDir: + return extractArchiveDirectory(targetRoot, targetPath, header) + case tar.TypeReg, tar.TypeRegA: + return extractArchiveRegularFile(reader, targetRoot, targetPath, header) + case tar.TypeSymlink: + return extractArchiveSymlink(targetRoot, targetPath, header) + case tar.TypeLink: + return extractArchiveHardlink(targetRoot, targetPath, header, stripComponents) + case tar.TypeXGlobalHeader, tar.TypeXHeader: + return nil + default: + return fmt.Errorf("unsupported archive entry type %q", header.Typeflag) + } +} + +func archiveTargetPath(targetRoot, name string, stripComponents int) (string, bool, error) { + relativePath, ok, err := archiveRelativePath(name, stripComponents) + if err != nil || !ok { + return "", ok, err + } + targetPath := filepath.Join(targetRoot, filepath.FromSlash(relativePath)) + if err := ensurePathInside(targetRoot, targetPath); err != nil { + return "", false, err + } + return targetPath, true, nil +} + +func archiveRelativePath(name string, stripComponents int) (string, bool, error) { + if stripComponents < 0 { + return "", false, fmt.Errorf("strip components cannot be negative") + } + if name == "" { + return "", false, nil + } + if path.IsAbs(name) { + return "", false, fmt.Errorf("archive path %q is absolute", name) + } + + var components []string + for _, component := range strings.Split(name, "/") { + switch component { + case "", ".": + continue + case "..": + return "", false, fmt.Errorf("archive path %q escapes target directory", name) + default: + components = append(components, component) + } + } + + if stripComponents > len(components) { + return "", false, nil + } + components = components[stripComponents:] + if len(components) == 0 { + return "", false, nil + } + return path.Join(components...), true, nil +} + +func extractArchiveDirectory(targetRoot, targetPath string, header *tar.Header) error { + if err := ensureParentReady(targetRoot, targetPath); err != nil { + return err + } + if err := ensureExistingSymlinkInside(targetRoot, targetPath); err != nil { + return err + } + mode := archiveFileMode(header, 0o755) + if err := os.MkdirAll(targetPath, mode); err != nil { + return fmt.Errorf("error creating directory %q: %v", targetPath, err) + } + if err := ensureExistingSymlinkInside(targetRoot, targetPath); err != nil { + return err + } + if err := os.Chmod(targetPath, mode); err != nil { + return fmt.Errorf("error setting mode on directory %q: %v", targetPath, err) + } + return setArchiveModTime(targetPath, header) +} + +func extractArchiveRegularFile(reader io.Reader, targetRoot, targetPath string, header *tar.Header) error { + if err := ensureParentReady(targetRoot, targetPath); err != nil { + return err + } + if err := removeExistingNonDirectory(targetPath); err != nil { + return err + } + + mode := archiveFileMode(header, 0o644) + file, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) + if err != nil { + return fmt.Errorf("error creating file %q: %v", targetPath, err) + } + if _, err := io.Copy(file, reader); err != nil { + _ = file.Close() + return fmt.Errorf("error writing file %q: %v", targetPath, err) + } + if err := file.Close(); err != nil { + return fmt.Errorf("error closing file %q: %v", targetPath, err) + } + if err := os.Chmod(targetPath, mode); err != nil { + return fmt.Errorf("error setting mode on file %q: %v", targetPath, err) + } + return setArchiveModTime(targetPath, header) +} + +func extractArchiveSymlink(targetRoot, targetPath string, header *tar.Header) error { + if header.Linkname == "" { + return fmt.Errorf("symlink %q has empty target", header.Name) + } + if err := ensureParentReady(targetRoot, targetPath); err != nil { + return err + } + if err := removeExistingNonDirectory(targetPath); err != nil { + return err + } + if err := os.Symlink(header.Linkname, targetPath); err != nil { + return fmt.Errorf("error creating symlink %q -> %q: %v", targetPath, header.Linkname, err) + } + return nil +} + +func extractArchiveHardlink(targetRoot, targetPath string, header *tar.Header, stripComponents int) error { + linkPath, ok, err := archiveTargetPath(targetRoot, header.Linkname, stripComponents) + if err != nil || !ok { + return err + } + if err := ensureExistingSymlinkInside(targetRoot, linkPath); err != nil { + return err + } + if err := ensureParentReady(targetRoot, targetPath); err != nil { + return err + } + if err := removeExistingNonDirectory(targetPath); err != nil { + return err + } + if err := os.Link(linkPath, targetPath); err != nil { + return fmt.Errorf("error creating hardlink %q -> %q: %v", targetPath, linkPath, err) + } + return setArchiveModTime(targetPath, header) +} + +func archiveFileMode(header *tar.Header, defaultMode os.FileMode) os.FileMode { + mode := os.FileMode(header.Mode) & os.ModePerm + if mode == 0 { + return defaultMode + } + return mode +} + +func setArchiveModTime(targetPath string, header *tar.Header) error { + if header.ModTime.IsZero() { + return nil + } + if err := os.Chtimes(targetPath, header.ModTime, header.ModTime); err != nil { + return fmt.Errorf("error setting timestamps on %q: %v", targetPath, err) + } + return nil +} + +func ensureParentReady(targetRoot, targetPath string) error { + if err := ensureParentPathSafe(targetRoot, targetPath); err != nil { + return err + } + parent := filepath.Dir(targetPath) + if err := os.MkdirAll(parent, 0o755); err != nil { + return fmt.Errorf("error creating directories %q: %v", parent, err) + } + return ensureParentPathSafe(targetRoot, targetPath) +} + +func ensureParentPathSafe(targetRoot, targetPath string) error { + parent := filepath.Dir(targetPath) + if err := ensurePathInside(targetRoot, parent); err != nil { + return err + } + + relative, err := filepath.Rel(targetRoot, parent) + if err != nil { + return fmt.Errorf("error checking path %q: %v", parent, err) + } + if relative == "." { + return nil + } + + current := targetRoot + for _, component := range strings.Split(relative, string(os.PathSeparator)) { + current = filepath.Join(current, component) + info, err := os.Lstat(current) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return fmt.Errorf("error checking path %q: %v", current, err) + } + if info.Mode()&os.ModeSymlink != 0 { + if err := ensureExistingSymlinkInside(targetRoot, current); err != nil { + return err + } + continue + } + if !info.IsDir() { + return fmt.Errorf("archive parent path %q is not a directory", current) + } + } + return nil +} + +func ensureExistingSymlinkInside(targetRoot, targetPath string) error { + info, err := os.Lstat(targetPath) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return fmt.Errorf("error checking path %q: %v", targetPath, err) + } + if info.Mode()&os.ModeSymlink == 0 { + return nil + } + resolved, err := filepath.EvalSymlinks(targetPath) + if err != nil { + return fmt.Errorf("error resolving symlink %q: %v", targetPath, err) + } + if err := ensurePathInside(targetRoot, resolved); err != nil { + return fmt.Errorf("archive path %q resolves outside target directory: %v", targetPath, err) + } + return nil +} + +func removeExistingNonDirectory(targetPath string) error { + info, err := os.Lstat(targetPath) + if os.IsNotExist(err) { + return nil + } + if err != nil { + return fmt.Errorf("error checking path %q: %v", targetPath, err) + } + if info.IsDir() { + return fmt.Errorf("cannot replace directory %q with archive entry", targetPath) + } + if err := os.Remove(targetPath); err != nil { + return fmt.Errorf("error removing existing path %q: %v", targetPath, err) + } + return nil +} + +func ensurePathInside(targetRoot, targetPath string) error { + relative, err := filepath.Rel(targetRoot, targetPath) + if err != nil { + return fmt.Errorf("error checking path %q: %v", targetPath, err) + } + if relative == ".." || strings.HasPrefix(relative, ".."+string(os.PathSeparator)) { + return fmt.Errorf("path %q is outside target directory %q", targetPath, targetRoot) + } + return nil +} diff --git a/upup/pkg/fi/nodeup/nodetasks/archive_test.go b/upup/pkg/fi/nodeup/nodetasks/archive_test.go index f1a4f758d884e..670040626c64e 100644 --- a/upup/pkg/fi/nodeup/nodetasks/archive_test.go +++ b/upup/pkg/fi/nodeup/nodetasks/archive_test.go @@ -17,6 +17,12 @@ limitations under the License. package nodetasks import ( + "archive/tar" + "bytes" + "compress/gzip" + "io" + "os" + "path/filepath" "testing" "k8s.io/kops/upup/pkg/fi" @@ -63,3 +69,170 @@ func TestArchiveDependencies(t *testing.T) { } } } + +func TestExtractArchive(t *testing.T) { + archivePath := writeTestArchive(t, true, []testTarEntry{ + { + name: "root/bin/tool", + mode: 0o755, + body: "hello", + }, + { + name: "root/etc/config", + mode: 0o644, + body: "config", + }, + }) + + targetDir := t.TempDir() + if err := extractArchive(archivePath, targetDir, 1, ""); err != nil { + t.Fatalf("extractArchive() error = %v", err) + } + + assertFileContents(t, filepath.Join(targetDir, "bin/tool"), "hello") + assertFileContents(t, filepath.Join(targetDir, "etc/config"), "config") +} + +func TestExtractArchiveMapFiles(t *testing.T) { + archivePath := writeTestArchive(t, false, []testTarEntry{ + { + name: "pkg/bin/tool", + mode: 0o755, + body: "hello", + }, + { + name: "pkg/lib/ignored", + mode: 0o644, + body: "ignored", + }, + }) + + targetDir := t.TempDir() + if err := extractArchive(archivePath, targetDir, 2, "pkg/bin/*"); err != nil { + t.Fatalf("extractArchive() error = %v", err) + } + + assertFileContents(t, filepath.Join(targetDir, "tool"), "hello") + if _, err := os.Stat(filepath.Join(targetDir, "ignored")); !os.IsNotExist(err) { + t.Fatalf("expected ignored file not to be extracted, stat error = %v", err) + } +} + +func TestExtractArchiveRejectsTraversal(t *testing.T) { + baseDir := t.TempDir() + archivePath := writeTestArchive(t, false, []testTarEntry{ + { + name: "../evil", + mode: 0o644, + body: "bad", + }, + }) + + err := extractArchive(archivePath, filepath.Join(baseDir, "target"), 0, "") + if err == nil { + t.Fatalf("extractArchive() expected error") + } + if _, err := os.Stat(filepath.Join(baseDir, "evil")); !os.IsNotExist(err) { + t.Fatalf("expected traversal target not to be created, stat error = %v", err) + } +} + +func TestExtractArchiveRejectsSymlinkEscape(t *testing.T) { + outsideDir := t.TempDir() + archivePath := writeTestArchive(t, false, []testTarEntry{ + { + name: "escape", + typeflag: tar.TypeSymlink, + linkname: outsideDir, + }, + { + name: "escape/pwned", + mode: 0o644, + body: "bad", + }, + }) + + err := extractArchive(archivePath, t.TempDir(), 0, "") + if err == nil { + t.Fatalf("extractArchive() expected error") + } + if _, err := os.Stat(filepath.Join(outsideDir, "pwned")); !os.IsNotExist(err) { + t.Fatalf("expected symlink escape target not to be created, stat error = %v", err) + } +} + +type testTarEntry struct { + name string + typeflag byte + linkname string + mode int64 + body string +} + +func writeTestArchive(t *testing.T, gzipArchive bool, entries []testTarEntry) string { + t.Helper() + + var buffer bytes.Buffer + var output io.Writer = &buffer + var gzipWriter *gzip.Writer + if gzipArchive { + gzipWriter = gzip.NewWriter(&buffer) + output = gzipWriter + } + writer := tar.NewWriter(output) + + for _, entry := range entries { + typeflag := entry.typeflag + if typeflag == 0 { + typeflag = tar.TypeReg + } + mode := entry.mode + if mode == 0 { + mode = 0o644 + } + header := &tar.Header{ + Name: entry.name, + Typeflag: typeflag, + Linkname: entry.linkname, + Mode: mode, + Size: int64(len(entry.body)), + } + if typeflag != tar.TypeReg { + header.Size = 0 + } + if err := writer.WriteHeader(header); err != nil { + t.Fatalf("WriteHeader() error = %v", err) + } + if header.Size != 0 { + if _, err := writer.Write([]byte(entry.body)); err != nil { + t.Fatalf("Write() error = %v", err) + } + } + } + + if err := writer.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + if gzipWriter != nil { + if err := gzipWriter.Close(); err != nil { + t.Fatalf("gzip Close() error = %v", err) + } + } + + archivePath := filepath.Join(t.TempDir(), "archive.tar") + if err := os.WriteFile(archivePath, buffer.Bytes(), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + return archivePath +} + +func assertFileContents(t *testing.T, path string, expected string) { + t.Helper() + actual, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile(%q) error = %v", path, err) + } + if string(actual) != expected { + t.Fatalf("ReadFile(%q) = %q, expected %q", path, actual, expected) + } +} diff --git a/upup/pkg/fi/nodeup/nodetasks/compression.go b/upup/pkg/fi/nodeup/nodetasks/compression.go new file mode 100644 index 0000000000000..b3f0c84c1152d --- /dev/null +++ b/upup/pkg/fi/nodeup/nodetasks/compression.go @@ -0,0 +1,42 @@ +/* +Copyright 2026 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package nodetasks + +import ( + "bufio" + "bytes" + "compress/gzip" + "io" +) + +var gzipMagic = []byte{0x1f, 0x8b, 0x08} + +func maybeGzipReader(r io.Reader) (io.Reader, io.Closer, error) { + buffered := bufio.NewReader(r) + header, err := buffered.Peek(len(gzipMagic)) + if err != nil && err != io.EOF { + return nil, nil, err + } + if len(header) == len(gzipMagic) && bytes.Equal(header, gzipMagic) { + gzipReader, err := gzip.NewReader(buffered) + if err != nil { + return nil, nil, err + } + return gzipReader, gzipReader, nil + } + return buffered, nil, nil +} diff --git a/upup/pkg/fi/nodeup/nodetasks/load_image.go b/upup/pkg/fi/nodeup/nodetasks/load_image.go index c1ab2cdc7d2db..29db59b04a59b 100644 --- a/upup/pkg/fi/nodeup/nodetasks/load_image.go +++ b/upup/pkg/fi/nodeup/nodetasks/load_image.go @@ -18,17 +18,14 @@ package nodetasks import ( "fmt" - "os" + "io" "os/exec" - "path" - "path/filepath" "strings" "k8s.io/klog/v2" "k8s.io/kops/pkg/backoff" "k8s.io/kops/upup/pkg/fi" "k8s.io/kops/upup/pkg/fi/nodeup/local" - "k8s.io/kops/upup/pkg/fi/utils" "k8s.io/kops/util/pkg/hashing" ) @@ -87,7 +84,7 @@ func (_ *LoadImageTask) CheckChanges(a, e, changes *LoadImageTask) error { return nil } -func (_ *LoadImageTask) RenderLocal(t *local.LocalTarget, a, e, changes *LoadImageTask) error { +func (_ *LoadImageTask) RenderLocal(_ *local.LocalTarget, a, e, changes *LoadImageTask) error { hash, err := hashing.FromString(e.Hash) if err != nil { return err @@ -98,15 +95,11 @@ func (_ *LoadImageTask) RenderLocal(t *local.LocalTarget, a, e, changes *LoadIma return fmt.Errorf("no sources specified: %v", err) } - // We assume the first url is the "main" url, and download to a local file based on that _name_, wherever we get it from primaryURL := urls[0] - key := path.Base(primaryURL) - localFile := filepath.Join(t.CacheDir, hash.String()+"_"+utils.SanitizeString(key)) - for _, url := range urls { - _, err = fi.DownloadURL(url, localFile, hash) + err = importContainerImage(url, hash) if err != nil { - klog.Warningf("error downloading url %q: %v", url, err) + klog.Warningf("error importing image from url %q: %v", url, err) continue } else { break @@ -114,43 +107,73 @@ func (_ *LoadImageTask) RenderLocal(t *local.LocalTarget, a, e, changes *LoadIma } if err != nil { // Hack to try to avoid failed downloads causing massive bandwidth bills - backoff.DoGlobalBackoff(fmt.Errorf("failed to download image %s: %v", primaryURL, err)) + backoff.DoGlobalBackoff(fmt.Errorf("failed to import image %s: %v", primaryURL, err)) return err } - // containerd can't import gzipped container images, if the image is gzipped extract it to tmp dir - // TODO: Improve the naive gzip format detection by checking the content type bytes "\x1F\x8B\x08" - var tarFile string - if strings.HasSuffix(localFile, "gz") { - tmpDir, err := os.MkdirTemp("", "loadimage") - if err != nil { - return fmt.Errorf("error creating temp dir: %v", err) - } - defer func() { - if err := os.RemoveAll(tmpDir); err != nil { - klog.Warningf("error deleting temp dir %q: %v", tmpDir, err) - } - }() - tarFile = path.Join(tmpDir, utils.SanitizeString(primaryURL)) - err = utils.UngzipFile(localFile, tarFile) - if err != nil { - return fmt.Errorf("error ungzipping container image: %v", err) - } - } else { - // Assume container image is tar file alerady - tarFile = localFile + return nil +} + +func importContainerImage(url string, expectedHash *hashing.Hash) error { + responseBody, err := fi.OpenURL(url) + if err != nil { + return err + } + defer responseBody.Close() + + imageReader, verifyHash, imageCloser, err := imageImportReader(responseBody, expectedHash) + if err != nil { + return err } - // Load the container image - args := []string{"ctr", "--namespace", "k8s.io", "images", "import", tarFile} + args := containerImageImportArgs() human := strings.Join(args, " ") klog.Infof("running command %s", human) cmd := exec.Command(args[0], args[1:]...) + cmd.Stdin = imageReader output, err := cmd.CombinedOutput() + if imageCloser != nil { + if closeErr := imageCloser.Close(); closeErr != nil && err == nil { + err = closeErr + } + } if err != nil { return fmt.Errorf("error loading docker image with '%s': %v: %s", human, err, string(output)) } + if err := verifyHash(); err != nil { + return err + } return nil } + +func containerImageImportArgs() []string { + return []string{"ctr", "--namespace", "k8s.io", "images", "import", "--no-unpack", "-"} +} + +func imageImportReader(r io.Reader, expectedHash *hashing.Hash) (io.Reader, func() error, io.Closer, error) { + algorithm := hashing.HashAlgorithmSHA256 + if expectedHash != nil { + algorithm = expectedHash.Algorithm + } + + hasher := algorithm.NewHasher() + hashedReader := io.TeeReader(r, hasher) + imageReader, closer, err := maybeGzipReader(hashedReader) + if err != nil { + return nil, nil, nil, fmt.Errorf("error reading container image stream: %v", err) + } + + verifyHash := func() error { + actualHash := &hashing.Hash{ + Algorithm: algorithm, + HashValue: hasher.Sum(nil), + } + if expectedHash != nil && !actualHash.Equal(expectedHash) { + return fmt.Errorf("downloaded container image but hash did not match expected %q", expectedHash) + } + return nil + } + return imageReader, verifyHash, closer, nil +} diff --git a/upup/pkg/fi/nodeup/nodetasks/loadimage_test.go b/upup/pkg/fi/nodeup/nodetasks/loadimage_test.go index 5d932e2ae1a94..ea87a9fc14cd3 100644 --- a/upup/pkg/fi/nodeup/nodetasks/loadimage_test.go +++ b/upup/pkg/fi/nodeup/nodetasks/loadimage_test.go @@ -17,10 +17,14 @@ limitations under the License. package nodetasks import ( + "bytes" + "compress/gzip" + "io" "reflect" "testing" "k8s.io/kops/upup/pkg/fi" + "k8s.io/kops/util/pkg/hashing" ) func TestLoadImageTask_Deps(t *testing.T) { @@ -38,3 +42,49 @@ func TestLoadImageTask_Deps(t *testing.T) { t.Fatalf("unexpected deps. expected=%v, actual=%v", expected, deps) } } + +func TestImageImportReaderUngzipsAndHashesDownload(t *testing.T) { + image := []byte("container image tar") + var compressed bytes.Buffer + gzipWriter := gzip.NewWriter(&compressed) + if _, err := gzipWriter.Write(image); err != nil { + t.Fatalf("Write() error = %v", err) + } + if err := gzipWriter.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + expectedHash, err := hashing.HashAlgorithmSHA256.Hash(bytes.NewReader(compressed.Bytes())) + if err != nil { + t.Fatalf("Hash() error = %v", err) + } + + reader, verifyHash, closer, err := imageImportReader(bytes.NewReader(compressed.Bytes()), expectedHash) + if err != nil { + t.Fatalf("imageImportReader() error = %v", err) + } + + actual, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if closer != nil { + if err := closer.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + } + if !bytes.Equal(actual, image) { + t.Fatalf("imageImportReader() body = %q, expected %q", actual, image) + } + if err := verifyHash(); err != nil { + t.Fatalf("verifyHash() error = %v", err) + } +} + +func TestContainerImageImportArgsDoesNotUnpack(t *testing.T) { + args := containerImageImportArgs() + expected := []string{"ctr", "--namespace", "k8s.io", "images", "import", "--no-unpack", "-"} + if !reflect.DeepEqual(args, expected) { + t.Fatalf("containerImageImportArgs() = %v, expected %v", args, expected) + } +}