Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 97 additions & 37 deletions upup/pkg/fi/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ import (
"net"
"net/http"
"os"
"path"
"path/filepath"
"time"

"k8s.io/klog/v2"
"k8s.io/kops/util/pkg/hashing"
)

const downloadTimeout = 3 * time.Minute

// DownloadURL will download the file at the given url and store it as dest.
// If hash is non-nil, it will also verify that it matches the hash of the downloaded file.
func DownloadURL(url string, dest string, hash *hashing.Hash) (*hashing.Hash, error) {
Expand All @@ -44,7 +46,7 @@ func DownloadURL(url string, dest string, hash *hashing.Hash) (*hashing.Hash, er
}

dirMode := os.FileMode(0o755)
err := downloadURLAlways(url, dest, dirMode)
err := downloadURLAlways(url, dest, dirMode, hash)
if err != nil {
return nil, err
}
Expand All @@ -67,63 +69,121 @@ func DownloadURL(url string, dest string, hash *hashing.Hash) (*hashing.Hash, er
return hash, nil
}

func downloadURLAlways(url string, destPath string, dirMode os.FileMode) error {
err := os.MkdirAll(path.Dir(destPath), dirMode)
func downloadURLAlways(url string, destPath string, dirMode os.FileMode, hash *hashing.Hash) error {
dir := filepath.Dir(destPath)
err := os.MkdirAll(dir, dirMode)
if err != nil {
return fmt.Errorf("error creating directories for destination file %q: %v", destPath, err)
}

output, err := os.Create(destPath)
output, err := os.CreateTemp(dir, "."+filepath.Base(destPath)+".tmp")
if err != nil {
return fmt.Errorf("error creating temporary file for download %q: %v", destPath, err)
}
tempPath := output.Name()
defer os.Remove(tempPath)

_, err = DownloadURLToWriter(url, output, hash)
if closeErr := output.Close(); closeErr != nil && err == nil {
err = closeErr
}
if err != nil {
return err
}
if err := os.Chmod(tempPath, 0o644); err != nil {
return fmt.Errorf("error setting mode on downloaded file %q: %v", tempPath, err)
}
if err := os.Rename(tempPath, destPath); err != nil {
return fmt.Errorf("error moving downloaded file %q to %q: %v", tempPath, destPath, err)
}
return nil
}

// DownloadURLToWriter streams the file at the given url to dest.
// If hash is non-nil, it will also verify that it matches the downloaded bytes.
func DownloadURLToWriter(url string, dest io.Writer, hash *hashing.Hash) (*hashing.Hash, error) {
responseBody, err := OpenURL(url)
if err != nil {
return fmt.Errorf("error creating file for download %q: %v", destPath, err)
return nil, err
}
defer output.Close()
defer responseBody.Close()

klog.V(2).Infof("Downloading %q", url)

// Create a client with custom timeouts
// to avoid idle downloads to hang the program
httpClient := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
IdleConnTimeout: 30 * time.Second,
},
start := time.Now()
defer func() {
klog.V(2).Infof("Downloading %q took %q", url, time.Since(start))
}()

algorithm := hashing.HashAlgorithmSHA256
if hash != nil {
algorithm = hash.Algorithm
}
hasher := algorithm.NewHasher()
writer := io.MultiWriter(dest, hasher)

if _, err := io.Copy(writer, responseBody); err != nil {
return nil, fmt.Errorf("error downloading HTTP content from %q: %v", url, err)
}

// this will stop slow downloads after 3 minutes
// and interrupt reading of the Response.Body
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
actual := &hashing.Hash{
Algorithm: algorithm,
HashValue: hasher.Sum(nil),
}
if hash != nil && !actual.Equal(hash) {
return nil, fmt.Errorf("downloaded from %q but hash did not match expected %q", url, hash)
}
return actual, nil
}

// OpenURL opens a hardened HTTP GET stream for url.
func OpenURL(url string) (io.ReadCloser, error) {
httpClient := newDownloadHTTPClient()

ctx, cancel := context.WithTimeout(context.Background(), downloadTimeout)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("Cannot create request: %v", err)
cancel()
return nil, fmt.Errorf("cannot create request: %v", err)
}

response, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("error doing HTTP fetch of %q: %v", url, err)
cancel()
return nil, fmt.Errorf("error doing HTTP fetch of %q: %v", url, err)
}
defer response.Body.Close()

if response.StatusCode >= 400 {
return fmt.Errorf("error response from %q: HTTP %v", url, response.StatusCode)
if response.StatusCode < 200 || response.StatusCode > 299 {
response.Body.Close()
cancel()
return nil, fmt.Errorf("unexpected response from %q: HTTP %s", url, response.Status)
}

start := time.Now()
defer func() {
klog.V(2).Infof("Copying %q to %q took %q", url, destPath, time.Since(start))
}()
return &cancelOnCloseReadCloser{ReadCloser: response.Body, cancel: cancel}, nil
}

_, err = io.Copy(output, response.Body)
if err != nil {
return fmt.Errorf("error downloading HTTP content from %q: %v", url, err)
func newDownloadHTTPClient() *http.Client {
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
IdleConnTimeout: 30 * time.Second,
},
}
return nil
}

type cancelOnCloseReadCloser struct {
io.ReadCloser
cancel context.CancelFunc
}

func (r *cancelOnCloseReadCloser) Close() error {
err := r.ReadCloser.Close()
r.cancel()
return err
}
78 changes: 78 additions & 0 deletions upup/pkg/fi/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
Copyright 2026 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package fi

import (
"bytes"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"

"k8s.io/kops/util/pkg/hashing"
)

func TestDownloadURLRejectsNon2xxAndPreservesDestination(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusFound)
_, _ = w.Write([]byte("redirect body"))
}))
defer server.Close()

dest := filepath.Join(t.TempDir(), "download")
if err := os.WriteFile(dest, []byte("original"), 0o644); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}

if _, err := DownloadURL(server.URL, dest, nil); err == nil {
t.Fatalf("DownloadURL() expected error")
}

actual, err := os.ReadFile(dest)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
if string(actual) != "original" {
t.Fatalf("download destination = %q, expected original contents", actual)
}
}

func TestDownloadURLToWriterVerifiesHash(t *testing.T) {
body := []byte("payload")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(body)
}))
defer server.Close()

expectedHash, err := hashing.HashAlgorithmSHA256.Hash(bytes.NewReader(body))
if err != nil {
t.Fatalf("Hash() error = %v", err)
}

var output bytes.Buffer
actualHash, err := DownloadURLToWriter(server.URL, &output, expectedHash)
if err != nil {
t.Fatalf("DownloadURLToWriter() error = %v", err)
}
if !actualHash.Equal(expectedHash) {
t.Fatalf("DownloadURLToWriter() hash = %v, expected %v", actualHash, expectedHash)
}
if !bytes.Equal(output.Bytes(), body) {
t.Fatalf("DownloadURLToWriter() body = %q, expected %q", output.Bytes(), body)
}
}
23 changes: 4 additions & 19 deletions upup/pkg/fi/nodeup/nodetasks/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path"
"path/filepath"
"reflect"
"strconv"
"strings"

"k8s.io/klog/v2"
Expand Down Expand Up @@ -159,16 +157,8 @@ func (_ *Archive) RenderLocal(t *local.LocalTarget, a, e, changes *Archive) erro
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("error creating directories %q: %v", targetDir, err)
}

args := []string{"tar", "xf", localFile, "-C", targetDir}
if e.StripComponents != 0 {
args = append(args, "--strip-components="+strconv.Itoa(e.StripComponents))
}

klog.Infof("running command %s", args)
cmd := exec.Command(args[0], args[1:]...)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("error installing archive %q: %v: %s", e.Name, err, string(output))
if err := extractArchive(localFile, targetDir, e.StripComponents, ""); err != nil {
return fmt.Errorf("error installing archive %q: %v", e.Name, err)
}
} else {
for src, dest := range e.MapFiles {
Expand All @@ -177,13 +167,8 @@ func (_ *Archive) RenderLocal(t *local.LocalTarget, a, e, changes *Archive) erro
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("error creating directories %q: %v", targetDir, err)
}

args := []string{"tar", "xf", localFile, "-C", targetDir, "--wildcards", "--strip-components=" + strconv.Itoa(stripCount), src}

klog.Infof("running command %s", args)
cmd := exec.Command(args[0], args[1:]...)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("error installing archive %q: %v: %s", e.Name, err, string(output))
if err := extractArchive(localFile, targetDir, stripCount, src); err != nil {
return fmt.Errorf("error installing archive %q: %v", e.Name, err)
}
}
}
Expand Down
Loading
Loading