CVE-2022-30322
Description
go-getter up to 1.5.11 and 2.0.2 allowed asymmetric resource exhaustion when go-getter processed malicious HTTP responses. Fixed in 1.6.1 and 2.1.0.
AI Insight
LLM-synthesized narrative grounded in this CVE's description and references.
go-getter before 1.6.1 and 2.1.0 allowed asymmetric resource exhaustion via malicious HTTP responses, enabling denial of service.
Vulnerability
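HashiCorp's go-getter library, used by Terraform and Nomad to download files or directories from various sources, contains a denial-of-service vulnerability in versions up to 1.5.11 (v1 branch) and 2.0.2 (v2 branch). The 'asymmetric resource exhaustion' occurs when go-getter processes crafted HTTP responses that trigger excessive resource consumption on the client side [1][2][4]. The v2 fix (#361) points to the kinds of responses involved. It adds `HttpGetter` limits on response-body size (`MaxBytes`), on request and body-read time (`ReadTimeout`, `HeadFirstTimeout`), and on how many `X-Terraform-Get` indirections a download may follow (`XTerraformGetLimit`). The vulnerable code path is reachable whenever an application uses go-getter to fetch resources from attacker-controlled or malicious URLs.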
Exploitation
An attacker must be able to serve a malicious HTTP response to a user of an application that uses go-getter, for example when Terraform resolves a module source or Nomad downloads a binary [1][4]. No authentication is required. The attacker only needs to host a crafted HTTP endpoint that the victim's client will fetch. When go-getter processes the malicious response, the asymmetry in resource usage lets the client exhaust CPU or memory while the server expends minimal effort [2]. No user interaction is required beyond the victim initiating a legitimate download.
Impact
Successful exploitation leads to a denial-of-service condition on the client system. The attacker causes disproportionate resource consumption (CPU, memory, or both) without needing corresponding investment on their own side [2][4]. This can crash the client application or degrade system performance, impacting availability. No data confidentiality or integrity is directly compromised.
Mitigation
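All users should upgrade to go-getter 1.6.1 or 2.1.0, which fix the asymmetric resource exhaustion issue [2][3][4]. Fixes were committed in pull requests #359 (v1) and #361 (v2) and released around May 2022. Applications that construct their own `HttpGetter` should also set its new limits explicitly. In the v2 patch, `ReadTimeout`, `HeadFirstTimeout`, `MaxBytes`, and `XTerraformGetLimit` all treat their zero value as "no limit". The patch's own comments also recommend setting `XTerraformGetDisabled` when go-getter is used outside Terraform. There is no known workaround; the vulnerable versions are no longer maintained. The vulnerability is not listed in CISA's Known Exploited Vulnerabilities (KEV) catalog.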
References
- GitHub - hashicorp/go-getter: Package for downloading things from a string URL using a variety of protocols.
- Multiple fixes for go-getter by eastebry · Pull Request #359 · hashicorp/go-getter
- Multiple fixes for go-getter (#359) · hashicorp/go-getter@a2ebce9
- Multiple fixes for go-getter v2 (#361) · hashicorp/go-getter@38e9738
AI Insight generated on May 21, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Ecosystem | Affected versions | Patched versions |
|---|---|---|---|
| github.com/hashicorp/go-getter | Go | < 1.6.1 | 1.6.1 |
| github.com/hashicorp/go-getter | Go | >= 2.0.0, < 2.1.0 | 2.1.0 |
| github.com/hashicorp/go-getter/v2 | Go | < 2.1.0 | 2.1.0 |
| github.com/hashicorp/go-getter/s3/v2 | Go | < 2.1.0 | 2.1.0 |
| github.com/hashicorp/go-getter/gcs/v2 | Go | < 2.1.0 | 2.1.0 |
Affected products
- go-getter/go-getter (from the description)
- GHSA coordinates, 4 packages (no CPE):
  - pkg:golang/github.com/hashicorp/go-getter: < 1.6.1
  - pkg:golang/github.com/hashicorp/go-getter/gcs/v2: < 2.1.0
  - pkg:golang/github.com/hashicorp/go-getter/s3/v2: < 2.1.0
  - pkg:golang/github.com/hashicorp/go-getter/v2: < 2.1.0
Patches
38e97387488f Multiple fixes for go-getter v2 (#361)
21 files changed · +1470 −164
.circleci/config.yml · +9 −9 · modified

```diff
@@ -49,11 +49,11 @@ commands:
 jobs:
   linux-tests:
     docker:
-      - image: circleci/golang:<< parameters.go-version >>
+      - image: cimg/go:<< parameters.go-version >>
     parameters:
       go-version:
         type: string
-    environment: 
+    environment:
       <<: *ENVIRONMENT
     parallelism: 4
     steps:
@@ -104,7 +104,7 @@ jobs:
           path: *TEST_RESULTS_PATH
 
   windows-tests:
-    executor: 
+    executor:
       name: win/default
       shell: bash --login -eo pipefail
     environment:
@@ -115,12 +115,12 @@ jobs:
         type: string
       gotestsum-version:
         type: string
-    steps: 
+    steps:
       - run: git config --global core.autocrlf false
       - checkout
       - attach_workspace:
           at: .
-      - run: 
+      - run:
           name: Setup (remove pre-installed go)
           command: |
             rm -rf "c:\Go"
@@ -131,16 +131,16 @@ jobs:
             - win-golang-<< parameters.go-version >>-cache-v1
             - win-gomod-cache-{{ checksum "go.mod" }}-v1
 
-      - run: 
+      - run:
           name: Install go version << parameters.go-version >>
-          command: | 
+          command: |
             if [ ! -d "c:\go" ]; then
               echo "Cache not found, installing new version of go"
               curl --fail --location https://dl.google.com/go/go<< parameters.go-version >>.windows-amd64.zip --output go.zip
               unzip go.zip -d "/c"
             fi
 
-      - run: 
+      - run:
           command: go mod download
 
       - save_cache:
@@ -176,7 +176,7 @@ jobs:
 
   go-smb-test:
     docker:
-      - image: circleci/golang:<< parameters.go-version >>
+      - image: cimg/go:<< parameters.go-version >>
     parameters:
       go-version:
         type: string
```
client.go · +35 −2 · modified

```diff
@@ -2,6 +2,7 @@ package getter
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"os"
@@ -14,6 +15,9 @@ import (
 	safetemp "github.com/hashicorp/go-safetemp"
 )
 
+// ErrSymlinkCopy means that a copy of a symlink was encountered on a request with DisableSymlinks enabled.
+var ErrSymlinkCopy = errors.New("copying of symlinks has been disabled")
+
 // Client is a client for downloading things.
 //
 // Top-level functions such as Get are shortcuts for interacting with a client.
@@ -27,6 +31,10 @@ type Client struct {
 	// Getters is the list of protocols supported by this client. If this
 	// is nil, then the default Getters variable will be used.
 	Getters []Getter
+
+	// Disable symlinks is used to prevent copying or writing files through symlinks for Get requests.
+	// When set to true any copying or writing through symlinks will result in a ErrSymlinkCopy error.
+	DisableSymlinks bool
 }
 
 // GetResult is the result of a Client.Get
@@ -41,15 +49,36 @@ func (c *Client) Get(ctx context.Context, req *Request) (*GetResult, error) {
 		return nil, err
 	}
 
+	// Pass along the configured Getter client in the context for usage with the X-Terraform-Get feature.
+	ctx = NewContextWithClient(ctx, c)
+
 	// Store this locally since there are cases we swap this
 	if req.GetMode == ModeInvalid {
 		req.GetMode = ModeAny
 	}
 
+	// Client setting takes precedence for all requests
+	if c.DisableSymlinks {
+		req.DisableSymlinks = true
+	}
+
 	// If there is a subdir component, then we download the root separately
 	// and then copy over the proper subdir.
 	req.Src, req.subDir = SourceDirSubdir(req.Src)
+
 	if req.subDir != "" {
+		// Check if the subdirectory is attempting to traverse upwards, outside of
+		// the cloned repository path.
+		req.subDir = filepath.Clean(req.subDir)
+		if containsDotDot(req.subDir) {
+			return nil, fmt.Errorf("subdirectory component contain path traversal out of the repository")
+		}
+
+		// Prevent absolute paths, remove a leading path separator from the subdirectory
+		if req.subDir[0] == os.PathSeparator {
+			req.subDir = req.subDir[1:]
+		}
+
 		td, tdcloser, err := safetemp.Dir("", "getter")
 		if err != nil {
 			return nil, err
@@ -123,7 +152,7 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, *
 	// Determine if we have an archive type
 	archiveV := q.Get("archive")
 	if archiveV != "" {
-		// Delete the paramter since it is a magic parameter we don't
+		// Delete the parameter since it is a magic parameter we don't
 		// want to pass on to the Getter
 		q.Del("archive")
 		req.u.RawQuery = q.Encode()
@@ -199,6 +228,10 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, *
 				filename = v
 			}
 
+			if containsDotDot(filename) {
+				return nil, &getError{true, fmt.Errorf("filename query parameter contain path traversal")}
+			}
+
 			req.Dst = filepath.Join(req.Dst, filename)
 		}
 	}
@@ -284,7 +317,7 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, *
 			return nil, &getError{true, err}
 		}
 
-		err = copyDir(ctx, req.realDst, subDir, false, req.umask())
+		err = copyDir(ctx, req.realDst, subDir, false, req.DisableSymlinks, req.umask())
 		if err != nil {
 			return nil, &getError{false, err}
 		}
```
client_option.go · +21 −0 · modified

```diff
@@ -1,5 +1,26 @@
 package getter
 
+import (
+	"context"
+)
+
+type clientContextKey int
+
+const clientContextValue clientContextKey = 0
+
+func NewContextWithClient(ctx context.Context, client *Client) context.Context {
+	return context.WithValue(ctx, clientContextValue, client)
+}
+
+func ClientFromContext(ctx context.Context) *Client {
+	// ctx.Value returns nil if ctx has no value for the key;
+	client, ok := ctx.Value(clientContextValue).(*Client)
+	if !ok {
+		return nil
+	}
+	return client
+}
+
 // configure configures a client with options.
 func (c *Client) configure() error {
 	// Default decompressor values
```
cmd/go-getter/main.go · +6 −1 · modified

```diff
@@ -16,6 +16,7 @@ import (
 func main() {
 	modeRaw := flag.String("mode", "any", "get mode (any, file, dir)")
 	progress := flag.Bool("progress", false, "display terminal progress")
+	noSymlinks := flag.Bool("disable-symlinks", false, "prevent copying or writing files through symlinks")
 	flag.Parse()
 	args := flag.Args()
 	if len(args) < 2 {
@@ -54,12 +55,16 @@ func main() {
 	if *progress {
 		req.ProgressListener = defaultProgressBar
 	}
-
 	wg := sync.WaitGroup{}
 	wg.Add(1)
 
 	client := getter.DefaultClient
 
+	// Disable symlinks for all client requests
+	if *noSymlinks {
+		client.DisableSymlinks = true
+	}
+
 	getters := getter.Getters
 	getters = append(getters, new(gcs.Getter))
 	getters = append(getters, new(s3.Getter))
```
copy_dir.go · +8 −2 · modified

```diff
@@ -11,7 +11,7 @@ import (
 // should already exist.
 //
 // If ignoreDot is set to true, then dot-prefixed files/folders are ignored.
-func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask os.FileMode) error {
+func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, disableSymlinks bool, umask os.FileMode) error {
 	src, err := filepath.EvalSymlinks(src)
 	if err != nil {
 		return err
@@ -34,6 +34,12 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask
 			}
 		}
 
+		if disableSymlinks {
+			if info.Mode()&os.ModeSymlink == os.ModeSymlink {
+				return ErrSymlinkCopy
+			}
+		}
+
 		// The "path" has the src prefixed to it. We need to join our
 		// destination with the path without the src on it.
 		dstPath := filepath.Join(dst, path[len(src):])
@@ -54,7 +60,7 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask
 		}
 
 		// If we have a file, copy the contents.
-		_, err = copyFile(ctx, dstPath, path, info.Mode(), umask)
+		_, err = copyFile(ctx, dstPath, path, disableSymlinks, info.Mode(), umask)
 		return err
 	}
 
```
detect_test.go · +6 −5 · modified

```diff
@@ -6,11 +6,12 @@ import (
 )
 
 func TestDetect(t *testing.T) {
-	gitGetter := &GitGetter{[]Detector{
-		new(GitDetector),
-		new(BitBucketDetector),
-		new(GitHubDetector),
-	},
+	gitGetter := &GitGetter{
+		Detectors: []Detector{
+			new(GitDetector),
+			new(BitBucketDetector),
+			new(GitHubDetector),
+		},
 	}
 	cases := []struct {
 		Input string
```
gcs/get_gcs.go · +27 −1 · modified

```diff
@@ -7,6 +7,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"time"
 
 	"cloud.google.com/go/storage"
 	"github.com/hashicorp/go-getter/v2"
@@ -15,10 +16,21 @@ import (
 
 // Getter is a Getter implementation that will download a module from
 // a GCS bucket.
-type Getter struct{}
+type Getter struct {
+
+	// Timeout sets a deadline which all GCS operations should
+	// complete within. Zero value means no timeout.
+	Timeout time.Duration
+}
 
 func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) {
 
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	// Parse URL
 	bucket, object, err := g.parseURL(u)
 	if err != nil {
@@ -54,6 +66,13 @@ func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) {
 }
 
 func (g *Getter) Get(ctx context.Context, req *getter.Request) error {
+
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	// Parse URL
 	bucket, object, err := g.parseURL(req.URL())
 	if err != nil {
@@ -111,6 +130,13 @@ func (g *Getter) Get(ctx context.Context, req *getter.Request) error {
 }
 
 func (g *Getter) GetFile(ctx context.Context, req *getter.Request) error {
+
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	// Parse URL
 	bucket, object, err := g.parseURL(req.URL())
 	if err != nil {
```
get_file_copy.go · +14 −1 · modified

```diff
@@ -2,6 +2,7 @@ package getter
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"os"
 )
@@ -49,7 +50,19 @@ func copyReader(dst string, src io.Reader, fmode, umask os.FileMode) error {
 }
 
 // copyFile copies a file in chunks from src path to dst path, using umask to create the dst file
-func copyFile(ctx context.Context, dst, src string, fmode, umask os.FileMode) (int64, error) {
+func copyFile(ctx context.Context, dst, src string, disableSymlinks bool, fmode, umask os.FileMode) (int64, error) {
+
+	if disableSymlinks {
+		fileInfo, err := os.Lstat(src)
+		if err != nil {
+			return 0, fmt.Errorf("failed to check copy file source for symlinks: %w", err)
+		}
+
+		if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink {
+			return 0, ErrSymlinkCopy
+		}
+	}
+
 	srcF, err := os.Open(src)
 	if err != nil {
 		return 0, err
```
get_git.go · +23 −12 · modified

```diff
@@ -14,6 +14,7 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
+	"time"
 
 	urlhelper "github.com/hashicorp/go-getter/v2/helper/url"
 	safetemp "github.com/hashicorp/go-safetemp"
@@ -24,6 +25,10 @@ import (
 // a git repository.
 type GitGetter struct {
 	Detectors []Detector
+
+	// Timeout sets a deadline which all hg CLI operations should
+	// complete within. Defaults to zero which means no timeout.
+	Timeout time.Duration
 }
 
 var defaultBranchRegexp = regexp.MustCompile(`\s->\sorigin/(.*)`)
@@ -71,10 +76,16 @@ func (g *GitGetter) Get(ctx context.Context, req *Request) error {
 		req.u.RawQuery = q.Encode()
 	}
 
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	var sshKeyFile string
 	if sshKey != "" {
 		// Check that the git version is sufficiently new.
-		if err := checkGitVersion("2.3"); err != nil {
+		if err := checkGitVersion(ctx, "2.3"); err != nil {
 			return fmt.Errorf("Error using ssh key: %v", err)
 		}
 
@@ -121,7 +132,7 @@ func (g *GitGetter) Get(ctx context.Context, req *Request) error {
 
 	// Next: check out the proper tag/branch if it is specified, and checkout
 	if ref != "" {
-		if err := g.checkout(req.Dst, ref); err != nil {
+		if err := g.checkout(ctx, req.Dst, ref); err != nil {
 			return err
 		}
 	}
@@ -163,8 +174,8 @@ func (g *GitGetter) GetFile(ctx context.Context, req *Request) error {
 	return fg.GetFile(ctx, req)
 }
 
-func (g *GitGetter) checkout(dst string, ref string) error {
-	cmd := exec.Command("git", "checkout", ref)
+func (g *GitGetter) checkout(ctx context.Context, dst string, ref string) error {
+	cmd := exec.CommandContext(ctx, "git", "checkout", ref)
 	cmd.Dir = dst
 	return getRunCommand(cmd)
 }
@@ -192,18 +203,18 @@ func (g *GitGetter) update(ctx context.Context, dst, sshKeyFile, ref string, dep
 		// Not a branch, switch to default branch. This will also catch
 		// non-existent branches, in which case we want to switch to default
 		// and then checkout the proper branch later.
-		ref = findDefaultBranch(dst)
+		ref = findDefaultBranch(ctx, dst)
 	}
 
 	// We have to be on a branch to pull
-	if err := g.checkout(dst, ref); err != nil {
+	if err := g.checkout(ctx, dst, ref); err != nil {
 		return err
 	}
 
 	if depth > 0 {
-		cmd = exec.Command("git", "pull", "--depth", strconv.Itoa(depth), "--ff-only")
+		cmd = exec.CommandContext(ctx, "git", "pull", "--depth", strconv.Itoa(depth), "--ff-only")
 	} else {
-		cmd = exec.Command("git", "pull", "--ff-only")
+		cmd = exec.CommandContext(ctx, "git", "pull", "--ff-only")
 	}
 
 	cmd.Dir = dst
@@ -226,9 +237,9 @@ func (g *GitGetter) fetchSubmodules(ctx context.Context, dst, sshKeyFile string,
 // findDefaultBranch checks the repo's origin remote for its default branch
 // (generally "master"). "master" is returned if an origin default branch
 // can't be determined.
-func findDefaultBranch(dst string) string {
+func findDefaultBranch(ctx context.Context, dst string) string {
 	var stdoutbuf bytes.Buffer
-	cmd := exec.Command("git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD")
+	cmd := exec.CommandContext(ctx, "git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD")
 	cmd.Dir = dst
 	cmd.Stdout = &stdoutbuf
 	err := cmd.Run()
@@ -278,13 +289,13 @@ func setupGitEnv(cmd *exec.Cmd, sshKeyFile string) {
 // checkGitVersion is used to check the version of git installed on the system
 // against a known minimum version. Returns an error if the installed version
 // is older than the given minimum.
-func checkGitVersion(min string) error {
+func checkGitVersion(ctx context.Context, min string) error {
 	want, err := version.NewVersion(min)
 	if err != nil {
 		return err
 	}
 
-	out, err := exec.Command("git", "version").Output()
+	out, err := exec.CommandContext(ctx, "git", "version").Output()
 	if err != nil {
 		return err
 	}
```
get_git_test.go · +125 −17 · modified

```diff
@@ -4,6 +4,8 @@ import (
 	"bytes"
 	"context"
 	"encoding/base64"
+	"errors"
+	"fmt"
 	"io/ioutil"
 	"net/url"
 	"os"
@@ -342,12 +344,13 @@ func TestGitGetter_gitVersion(t *testing.T) {
 	os.Setenv("PATH", dir)
 
 	// Asking for a higher version throws an error
-	if err := checkGitVersion("2.3"); err == nil {
+	ctx := context.Background()
+	if err := checkGitVersion(ctx, "2.3"); err == nil {
 		t.Fatal("expect git version error")
 	}
 
 	// Passes when version is satisfied
-	if err := checkGitVersion("1.9"); err != nil {
+	if err := checkGitVersion(ctx, "1.9"); err != nil {
 		t.Fatal(err)
 	}
 }
@@ -411,11 +414,12 @@ func TestGitGetter_sshSCPStyle(t *testing.T) {
 		GetMode: ModeDir,
 	}
 
-	getter := &GitGetter{[]Detector{
-		new(GitDetector),
-		new(BitBucketDetector),
-		new(GitHubDetector),
-	},
+	getter := &GitGetter{
+		Detectors: []Detector{
+			new(GitDetector),
+			new(BitBucketDetector),
+			new(GitHubDetector),
+		},
 	}
 	client := &Client{
 		Getters: []Getter{getter},
@@ -623,11 +627,12 @@ func TestGitGetter_GitHubDetector(t *testing.T) {
 	}
 
 	pwd := "/pwd"
-	f := &GitGetter{[]Detector{
-		new(GitDetector),
-		new(BitBucketDetector),
-		new(GitHubDetector),
-	},
+	f := &GitGetter{
+		Detectors: []Detector{
+			new(GitDetector),
+			new(BitBucketDetector),
+			new(GitHubDetector),
+		},
 	}
 	for i, tc := range cases {
 		req := &Request{
@@ -704,11 +709,12 @@ func TestGitGetter_Detector(t *testing.T) {
 	}
 
 	pwd := "/pwd"
-	getter := &GitGetter{[]Detector{
-		new(GitDetector),
-		new(BitBucketDetector),
-		new(GitHubDetector),
-	},
+	getter := &GitGetter{
+		Detectors: []Detector{
+			new(GitDetector),
+			new(BitBucketDetector),
+			new(GitHubDetector),
+		},
 	}
 	for _, tc := range cases {
 		t.Run(tc.Input, func(t *testing.T) {
@@ -730,6 +736,108 @@ func TestGitGetter_Detector(t *testing.T) {
 	}
 }
 
+func TestGitGetter_subdirectory_symlink(t *testing.T) {
+	dst := testing_helper.TempDir(t)
+
+	repo := testGitRepo(t, "repo-with-symlink")
+	innerDir := filepath.Join(repo.dir, "this-directory-contains-a-symlink")
+	if err := os.Mkdir(innerDir, 0700); err != nil {
+		t.Fatal(err)
+	}
+	path := filepath.Join(innerDir, "this-is-a-symlink")
+	if err := os.Symlink("/etc/passwd", path); err != nil {
+		t.Fatal(err)
+	}
+	repo.git("add", path)
+	repo.git("commit", "-m", "Adding "+path)
+
+	u, err := url.Parse(fmt.Sprintf("git::%s//this-directory-contains-a-symlink", repo.url.String()))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	req := &Request{
+		Src:     u.String(),
+		Dst:     dst,
+		Pwd:     ".",
+		GetMode: ModeDir,
+	}
+	getter := &GitGetter{
+		Detectors: []Detector{
+			new(GitDetector),
+			new(GitHubDetector),
+		},
+	}
+	client := &Client{
+		Getters:         []Getter{getter},
+		DisableSymlinks: true,
+	}
+
+	ctx := context.Background()
+	_, err = client.Get(ctx, req)
+	if runtime.GOOS == "windows" {
+		// Windows doesn't handle symlinks as one might expect with git.
+		//
+		// https://github.com/git-for-windows/git/wiki/Symbolic-Links
+		filepath.Walk(dst, func(path string, info os.FileInfo, err error) error {
+			if strings.Contains(path, "this-is-a-symlink") {
+				if info.Mode()&os.ModeSymlink == os.ModeSymlink {
+					// If you see this test fail in the future, you've probably enabled
+					// symlinks within git on your Windows system. Our CI/CD system does
+					// not do this, so this is the only way we can make this test
+					// make any sense.
+					t.Fatalf("windows git should not have cloned a symlink")
+				}
+			}
+			return nil
+		})
+	} else {
+		// We can rely on POSIX compliant systems running git to do the right thing.
+		if err == nil {
+			t.Fatalf("expected client get to fail")
+		}
+		if !errors.Is(err, ErrSymlinkCopy) {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	}
+}
+
+func TestGitGetter_subdirectory_traversal(t *testing.T) {
+	dst := testing_helper.TempDir(t)
+
+	repo := testGitRepo(t, "empty-repo")
+	u, err := url.Parse(fmt.Sprintf("git::%s//../../../../../../etc/passwd", repo.url.String()))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	req := &Request{
+		Src:     u.String(),
+		Dst:     dst,
+		Pwd:     ".",
+		GetMode: ModeDir,
+	}
+
+	getter := &GitGetter{
+		Detectors: []Detector{
+			new(GitDetector),
+			new(GitHubDetector),
+		},
+	}
+	client := &Client{
+		Getters: []Getter{getter},
+	}
+
+	ctx := context.Background()
+	_, err = client.Get(ctx, req)
+	if err == nil {
+		t.Fatalf("expected client get to fail")
+	}
+	if !strings.Contains(err.Error(), "subdirectory component contain path traversal out of the repository") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
 // gitRepo is a helper struct which controls a single temp git repo.
 type gitRepo struct {
 	t *testing.T
```
get.go · +7 −6 · modified

```diff
@@ -76,12 +76,13 @@ func init() {
 	// The order of the Getters in the list may affect the result
 	// depending if the Request.Src is detected as valid by multiple getters
 	Getters = []Getter{
-		&GitGetter{[]Detector{
-			new(GitHubDetector),
-			new(GitDetector),
-			new(BitBucketDetector),
-			new(GitLabDetector),
-		},
+		&GitGetter{
+			Detectors: []Detector{
+				new(GitHubDetector),
+				new(GitDetector),
+				new(BitBucketDetector),
+				new(GitLabDetector),
+			},
 		},
 		new(HgGetter),
 		new(SmbClientGetter),
```
get_hg.go · +21 −8 · modified

```diff
@@ -8,14 +8,20 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
+	"time"
 
 	urlhelper "github.com/hashicorp/go-getter/v2/helper/url"
 	safetemp "github.com/hashicorp/go-safetemp"
 )
 
 // HgGetter is a Getter implementation that will download a module from
 // a Mercurial repository.
-type HgGetter struct{}
+type HgGetter struct {
+
+	// Timeout sets a deadline which all hg CLI operations should
+	// complete within. Defaults to zero which means no timeout.
+	Timeout time.Duration
+}
 
 func (g *HgGetter) Mode(ctx context.Context, _ *url.URL) (Mode, error) {
 	return ModeDir, nil
@@ -49,13 +55,20 @@ func (g *HgGetter) Get(ctx context.Context, req *Request) error {
 	if err != nil && !os.IsNotExist(err) {
 		return err
 	}
+
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	if err != nil {
-		if err := g.clone(req.Dst, newURL); err != nil {
+		if err := g.clone(ctx, req.Dst, newURL); err != nil {
 			return err
 		}
 	}
 
-	if err := g.pull(req.Dst, newURL); err != nil {
+	if err := g.pull(ctx, req.Dst, newURL); err != nil {
 		return err
 	}
 
@@ -102,21 +115,21 @@ func (g *HgGetter) GetFile(ctx context.Context, req *Request) error {
 	return fg.GetFile(ctx, req)
 }
 
-func (g *HgGetter) clone(dst string, u *url.URL) error {
-	cmd := exec.Command("hg", "clone", "-U", u.String(), dst)
+func (g *HgGetter) clone(ctx context.Context, dst string, u *url.URL) error {
+	cmd := exec.CommandContext(ctx, "hg", "clone", "-U", "--", u.String(), dst)
 	return getRunCommand(cmd)
 }
 
-func (g *HgGetter) pull(dst string, u *url.URL) error {
-	cmd := exec.Command("hg", "pull")
+func (g *HgGetter) pull(ctx context.Context, dst string, u *url.URL) error {
+	cmd := exec.CommandContext(ctx, "hg", "pull")
 	cmd.Dir = dst
 	return getRunCommand(cmd)
 }
 
 func (g *HgGetter) update(ctx context.Context, dst string, u *url.URL, rev string) error {
 	args := []string{"update"}
 	if rev != "" {
-		args = append(args, rev)
+		args = append(args, "--", rev)
 	}
 
 	cmd := exec.CommandContext(ctx, "hg", args...)
```
get_hg_test.go · +104 −0 · modified

```diff
@@ -2,10 +2,13 @@ package getter
 
 import (
 	"context"
+	"net/url"
 	"os"
 	"os/exec"
 	"path/filepath"
+	"strings"
 	"testing"
+	"time"
 
 	testing_helper "github.com/hashicorp/go-getter/v2/helper/testing"
 )
@@ -118,3 +121,104 @@ func TestHgGetter_GetFile(t *testing.T) {
 	}
 	testing_helper.AssertContents(t, dst, "Hello\n")
 }
+func TestHgGetter_HgArgumentsNotAllowed(t *testing.T) {
+	if !testHasHg {
+		t.Log("hg not found, skipping")
+		t.Skip()
+	}
+	ctx := context.Background()
+
+	tc := []struct {
+		name   string
+		req    Request
+		errChk func(testing.TB, error)
+	}{
+		{
+			// If arguments are allowed in the destination, this request to Get will fail
+			name: "arguments allowed in destination",
+			req: Request{
+				Dst: "--config=alias.clone=!touch ./TEST",
+				u:   testModuleURL("basic-hg"),
+			},
+			errChk: func(t testing.TB, err error) {
+				if err != nil {
+					t.Errorf("Expected no err, got: %s", err)
+				}
+			},
+		},
+		{
+			// Test arguments passed into the `rev` parameter
+			// This clone call will fail regardless, but an exit code of 1 indicates
+			// that the `false` command executed
+			// We are expecting an hg parse error
+			name: "arguments passed into rev parameter",
+			req: Request{
+				u: testModuleURL("basic-hg?rev=--config=alias.update=!false"),
+			},
+			errChk: func(t testing.TB, err error) {
+				if err == nil {
+					return
+				}
+
+				if !strings.Contains(err.Error(), "hg: parse error") {
+					t.Errorf("Expected no err, got: %s", err)
+				}
+			},
+		},
+		{
+			// Test arguments passed in the repository URL
+			// This Get call will fail regardless, but it should fail
+			// because the repository can't be found.
+			// Other failures indicate that hg interpreted the argument passed in the URL
+			name: "arguments passed in the repository URL",
+			req: Request{
+				u: &url.URL{Path: "--config=alias.clone=false"}},
+			errChk: func(t testing.TB, err error) {
+				if err == nil {
+					return
+				}
+
+				if !strings.Contains(err.Error(), "repository --config=alias.clone=false not found") {
+					t.Errorf("Expected no err, got: %s", err)
+				}
+			},
+		},
+	}
+	for _, tt := range tc {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			g := new(HgGetter)
+
+			if tt.req.Dst == "" {
+				dst := testing_helper.TempDir(t)
+				tt.req.Dst = dst
+			}
+
+			defer os.RemoveAll(tt.req.Dst)
+			err := g.Get(ctx, &tt.req)
+			tt.errChk(t, err)
+		})
+	}
+}
+
+func TestHgGetter_GetWithTimeout(t *testing.T) {
+	if !testHasHg {
+		t.Log("hg not found, skipping")
+		t.Skip()
+	}
+	ctx := context.Background()
+	g := &HgGetter{
+		Timeout: 1 * time.Millisecond,
+	}
+
+	dst := testing_helper.TempDir(t)
+	defer os.RemoveAll(filepath.Dir(dst))
+	req := &Request{
+		Dst: dst,
+		u:   testModuleURL("basic-hg/foo.txt"),
+	}
+
+	if err := g.Get(ctx, req); err == nil {
+		t.Fatalf("err: %s", err.Error())
+	}
+}
```
get_http.go · +305 −49 · modified

```diff
@@ -10,6 +10,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"time"
 
 	safetemp "github.com/hashicorp/go-safetemp"
 )
@@ -26,7 +27,9 @@ import (
 // wish. The response must be a 2xx.
 //
 // First, a header is looked for "X-Terraform-Get" which should contain
-// a source URL to download.
+// a source URL to download. This source must use one of the configured
+// protocols and getters for the client, or "http"/"https" if using
+// the HttpGetter directly.
 //
 // If the header is not present, then a meta tag is searched for named
 // "terraform-get" and the content should be a source URL.
@@ -49,6 +52,35 @@ type HttpGetter struct {
 	// and as such it needs to be initialized before use, via something like
 	// make(http.Header).
 	Header http.Header
+	// DoNotCheckHeadFirst configures the client to NOT check if the server
+	// supports HEAD requests.
+	DoNotCheckHeadFirst bool
+
+	// HeadFirstTimeout configures the client to enforce a timeout when
+	// the server supports HEAD requests.
+	//
+	// The zero value means no timeout.
+	HeadFirstTimeout time.Duration
+
+	// ReadTimeout configures the client to enforce a timeout when
+	// making a request to an HTTP server and reading its response body.
+	//
+	// The zero value means no timeout.
+	ReadTimeout time.Duration
+
+	// MaxBytes limits the number of bytes that will be ready from an HTTP
+	// response body returned from a server. The zero value means no limit.
+	MaxBytes int64
+
+	// XTerraformGetLimit configures how many times the client with follow
+	// the " X-Terraform-Get" header value.
+	//
+	// The zero value means no limit.
+	XTerraformGetLimit int
+
+	// XTerraformGetDisabled disables the client's usage of the "X-Terraform-Get"
+	// header value.
+	XTerraformGetDisabled bool
 }
 
 func (g *HttpGetter) Mode(ctx context.Context, u *url.URL) (Mode, error) {
@@ -58,7 +90,112 @@ func (g *HttpGetter) Mode(ctx context.Context, u *url.URL) (Mode, error) {
 	return ModeFile, nil
 }
 
+type contextKey int
+
+const (
+	xTerraformGetDisable           contextKey = 0
+	xTerraformGetLimit             contextKey = 1
+	xTerraformGetLimitCurrentValue contextKey = 2
+	httpClientValue                contextKey = 3
+	httpMaxBytesValue              contextKey = 4
+)
+
+func xTerraformGetDisabled(ctx context.Context) bool {
+	value, ok := ctx.Value(xTerraformGetDisable).(bool)
+	if !ok {
+		return false
+	}
+	return value
+}
+
+func xTerraformGetLimitCurrentValueFromContext(ctx context.Context) int {
+	value, ok := ctx.Value(xTerraformGetLimitCurrentValue).(int)
+	if !ok {
+		return 1
+	}
+	return value
+}
+
+func xTerraformGetLimiConfiguredtFromContext(ctx context.Context) int {
+	value, ok := ctx.Value(xTerraformGetLimit).(int)
+	if !ok {
+		return 0
+	}
+	return value
+}
+
+func httpClientFromContext(ctx context.Context) *http.Client {
+	value, ok := ctx.Value(httpClientValue).(*http.Client)
+	if !ok {
+		return nil
+	}
+	return value
+}
+
+func httpMaxBytesFromContext(ctx context.Context) int64 {
+	value, ok := ctx.Value(httpMaxBytesValue).(int64)
+	if !ok {
+		return 0 // no limit
+	}
+	return value
+}
+
+type limitedWrappedReaderCloser struct {
+	underlying io.Reader
+	closeFn    func() error
+}
+
+func (l *limitedWrappedReaderCloser) Read(p []byte) (n int, err error) {
+	return l.underlying.Read(p)
+}
+
+func (l *limitedWrappedReaderCloser) Close() (err error) {
+	return l.closeFn()
+}
+
+func newLimitedWrappedReaderCloser(r io.ReadCloser, limit int64) io.ReadCloser {
+	return &limitedWrappedReaderCloser{
+		underlying: io.LimitReader(r, limit),
+		closeFn:    r.Close,
+	}
+}
+
 func (g *HttpGetter) Get(ctx context.Context, req *Request) error {
+	// Optionally disable any X-Terraform-Get redirects. This is recommended for usage of
+	// this client outside of Terraform's. This feature is likely not required if the
+	// source server can provider normal HTTP redirects.
+	if g.XTerraformGetDisabled {
+		ctx = context.WithValue(ctx, xTerraformGetDisable, g.XTerraformGetDisabled)
+	}
+
+	// Optionally enforce a limit on X-Terraform-Get redirects. We check this for every
+	// invocation of this function, because the value is not passed down to subsequent
+	// client Get function invocations.
+	if g.XTerraformGetLimit > 0 {
+		ctx = context.WithValue(ctx, xTerraformGetLimit, g.XTerraformGetLimit)
+	}
+
+	// If there was a limit on X-Terraform-Get redirects, check what the current count value.
+	//
+	// If the value is greater than the limit, return an error. Otherwise, increment the value,
+	// and include it in the the context to be passed along in all the subsequent client
+	// Get function invocations.
+	if limit := xTerraformGetLimiConfiguredtFromContext(ctx); limit > 0 {
+		currentValue := xTerraformGetLimitCurrentValueFromContext(ctx)
+
+		if currentValue > limit {
+			return fmt.Errorf("too many X-Terraform-Get redirects: %d", currentValue)
+		}
+
+		currentValue++
+
+		ctx = context.WithValue(ctx, xTerraformGetLimitCurrentValue, currentValue)
+	}
+
+	// Optionally enforce a maxiumum HTTP response body size.
+	if g.MaxBytes > 0 {
+		ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes)
+	}
 	// Copy the URL so we can modify it
 	var newU url.URL = *req.u
 	req.u = &newU
@@ -70,17 +207,33 @@ func (g *HttpGetter) Get(ctx context.Context, req *Request) error {
 		}
 	}
 
+	// If the HTTP client is nil, check if there is one available in the context,
+	// otherwise create one using cleanhttp's default transport.
 	if g.Client == nil {
-		g.Client = httpClient
+		if client := httpClientFromContext(ctx); client != nil {
+			g.Client = client
+		} else {
+			g.Client = httpClient
+		}
 	}
 
+	// Pass along the configured HTTP client in the context for usage with the X-Terraform-Get feature.
+	ctx = context.WithValue(ctx, httpClientValue, g.Client)
+
 	// Add terraform-get to the parameter.
 	q := req.u.Query()
 	q.Add("terraform-get", "1")
 	req.u.RawQuery = q.Encode()
 
+	readCtx := ctx
+	if g.ReadTimeout > 0 {
+		var cancel context.CancelFunc
+		readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout)
+		defer cancel()
+	}
+
 	// Get the URL
-	httpReq, err := http.NewRequestWithContext(ctx, "GET", req.u.String(), nil)
+	httpReq, err := http.NewRequestWithContext(readCtx, "GET", req.u.String(), nil)
 	if err != nil {
 		return err
 	}
@@ -92,40 +245,53 @@ func (g *HttpGetter) Get(ctx context.Context, req *Request) error {
 	if err != nil {
 		return err
 	}
-	defer resp.Body.Close()
+
+	body := resp.Body
+	if maxBytes := httpMaxBytesFromContext(ctx); maxBytes > 0 {
+		body = newLimitedWrappedReaderCloser(body, maxBytes)
+	}
+
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
 		return fmt.Errorf("bad response code: %d", resp.StatusCode)
 	}
 
+	if disabled := xTerraformGetDisabled(ctx); disabled {
+		return nil
+	}
+
+	// Get client with configured Getters from the context
+	// If the client is nil, we know we're using the HttpGetter directly. In this case,
+	// we don't know exactly which protocols are configured, but we can make a good guess.
+	//
+	// This prevents all default getters from being allowed when only using the
+	// HttpGetter directly. To enable protocol switching, a client "wrapper" must
+	// be used.
+	var getterClient *Client
+	if v := ClientFromContext(ctx); v != nil {
+		getterClient = v
+	} else {
+		getterClient = &Client{
+			Getters: []Getter{g},
+		}
+	}
+
 	// Extract the source URL
 	var source string
 	if v := resp.Header.Get("X-Terraform-Get"); v != "" {
 		source = v
 	} else {
-		source, err = g.parseMeta(resp.Body)
+		source, err = g.parseMeta(readCtx, body)
 		if err != nil {
 			return err
 		}
 	}
+
 	if source == "" {
 		return fmt.Errorf("no source URL was returned")
 	}
 
-	// If there is a subdir component, then we download the root separately
-	// into a temporary directory, then copy over the proper subdir.
-	source, subDir := SourceDirSubdir(source)
-	req = &Request{
-		GetMode: ModeDir,
-		Src:     source,
-		Dst:     req.Dst,
-	}
-	if subDir == "" {
-		_, err = DefaultClient.Get(ctx, req)
-		return err
-	}
-
-	// We have a subdir, time to jump some hoops
-	return g.getSubdir(ctx, req, source, subDir)
+	return g.getXTerraformSource(ctx, req, source, getterClient)
 }
 
 // GetFile fetches the file from src and stores it at dst.
@@ -135,6 +301,11 @@ func (g *HttpGetter) Get(ctx context.Context, req *Request) error {
 // falsely identified as being replaced, or corrupted with extra bytes
 // appended.
 func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error {
+	// Optionally enforce a maxiumum HTTP response body size.
+	if g.MaxBytes > 0 {
+		ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes)
+	}
+
 	if g.Netrc {
 		// Add auth from netrc if we can
 		if err := addAuthFromNetrc(req.u); err != nil {
@@ -157,38 +328,67 @@ func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error {
 	}
 
 	var currentFileSize int64
+	var httpReq *http.Request
+
+	if g.DoNotCheckHeadFirst == false {
+		headCtx := ctx
 
-	// We first make a HEAD request so we can check
-	// if the server supports range queries. If the server/URL doesn't
-	// support HEAD requests, we just fall back to GET.
```
- httpReq, err := http.NewRequestWithContext(ctx, "HEAD", req.u.String(), nil) + if g.HeadFirstTimeout > 0 { + var cancel context.CancelFunc + + headCtx, cancel = context.WithTimeout(ctx, g.HeadFirstTimeout) + defer cancel() + } + + // We first make a HEAD request so we can check + // if the server supports range queries. If the server/URL doesn't + // support HEAD requests, we just fall back to GET. + httpReq, err = http.NewRequestWithContext(headCtx, "HEAD", req.u.String(), nil) + if err != nil { + return err + } + if g.Header != nil { + httpReq.Header = g.Header.Clone() + } + headResp, err := g.Client.Do(httpReq) + if err == nil { + headResp.Body.Close() + if headResp.StatusCode == 200 { + // If the HEAD request succeeded, then attempt to set the range + // query if we can. + if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 { + if fi, err := f.Stat(); err == nil { + if _, err = f.Seek(0, io.SeekEnd); err == nil { + currentFileSize = fi.Size() + httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) + if currentFileSize >= headResp.ContentLength { + // file already present + return nil + } + } + } + } + } + } + } + + readCtx := ctx + if g.ReadTimeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout) + defer cancel() + } + + httpReq, err = http.NewRequestWithContext(readCtx, "GET", req.u.String(), nil) if err != nil { return err } if g.Header != nil { httpReq.Header = g.Header.Clone() } - headResp, err := g.Client.Do(httpReq) - if err == nil { - headResp.Body.Close() - if headResp.StatusCode == 200 { - // If the HEAD request succeeded, then attempt to set the range - // query if we can. 
- if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 { - if fi, err := f.Stat(); err == nil { - if _, err = f.Seek(0, io.SeekEnd); err == nil { - currentFileSize = fi.Size() - httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) - if currentFileSize >= headResp.ContentLength { - // file already present - return nil - } - } - } - } - } + if currentFileSize > 0 { + httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) } - httpReq.Method = "GET" resp, err := g.Client.Do(httpReq) if err != nil { @@ -204,23 +404,70 @@ func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error { body := resp.Body + if maxBytes := httpMaxBytesFromContext(readCtx); maxBytes > 0 { + body = newLimitedWrappedReaderCloser(body, maxBytes) + } + if req.ProgressListener != nil { // track download fn := filepath.Base(req.u.EscapedPath()) body = req.ProgressListener.TrackProgress(fn, currentFileSize, currentFileSize+resp.ContentLength, resp.Body) } defer body.Close() - n, err := Copy(ctx, f, body) + n, err := Copy(readCtx, f, body) if err == nil && n < resp.ContentLength { err = io.ErrShortWrite } return err } +// getXTerraformSource downloads the source into the destination +// using a protocol switching capable client. +func (g *HttpGetter) getXTerraformSource(ctx context.Context, req *Request, source string, client *Client) error { + + // If there is a subdir component, then we download the root separately + // into a temporary directory, then copy over the proper subdir. 
+ source, subDir := SourceDirSubdir(source) + req = &Request{ + GetMode: ModeDir, + Src: source, + Dst: req.Dst, + DisableSymlinks: req.DisableSymlinks, + } + + if subDir == "" { + // We have a X-Terraform-Get source lets check for supported Getters + var allowed bool + for _, getter := range client.Getters { + shouldDownload, err := Detect(req, getter) + if err != nil { + return fmt.Errorf("failed to detect the proper Getter to handle %s: %w", source, err) + } + if !shouldDownload { + // the request should not be processed by that getter + continue + } + allowed = true + } + + if !allowed { + protocol := strings.Split(source, ":")[0] + return fmt.Errorf("download not supported for scheme %q", protocol) + } + + _, err := client.Get(ctx, req) + return err + } + + // We have a subdir, time to jump some hoops + return g.getSubdir(ctx, req, source, subDir, client) + +} + // getSubdir downloads the source into the destination, but with // the proper subdir. -func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir string) error { +func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir string, client *Client) error { // Create a temporary directory to store the full source. This has to be // a non-existent directory. 
td, tdcloser, err := safetemp.Dir("", "getter") @@ -229,8 +476,13 @@ func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir } defer tdcloser.Close() - // Download that into the given directory - if _, err := Get(ctx, td, source); err != nil { + tdReq := &Request{ + Src: source, + Dst: td, + GetMode: ModeDir, + DisableSymlinks: req.DisableSymlinks, + } + if _, err := client.Get(ctx, tdReq); err != nil { return err } @@ -256,18 +508,22 @@ func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir return err } - return copyDir(ctx, req.Dst, sourcePath, false, req.umask()) + return copyDir(ctx, req.Dst, sourcePath, false, req.DisableSymlinks, req.umask()) } // parseMeta looks for the first meta tag in the given reader that // will give us the source URL. -func (g *HttpGetter) parseMeta(r io.Reader) (string, error) { +func (g *HttpGetter) parseMeta(ctx context.Context, r io.Reader) (string, error) { d := xml.NewDecoder(r) d.CharsetReader = charsetReader d.Strict = false var err error var t xml.Token for { + if ctx.Err() != nil { + return "", fmt.Errorf("context error while parsing meta tag: %w", ctx.Err()) + } + t, err = d.Token() if err != nil { if err == io.EOF {
get_http_test.go+601 −28 modified@@ -9,13 +9,15 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httputil" "net/url" "os" "path/filepath" "strconv" "strings" "testing" + cleanhttp "github.com/hashicorp/go-cleanhttp" testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" ) @@ -38,12 +40,26 @@ func TestHttpGetter_header(t *testing.T) { u.Path = "/header" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -103,12 +119,26 @@ func TestHttpGetter_meta(t *testing.T) { u.Path = "/meta" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -134,12 +164,26 @@ func TestHttpGetter_metaSubdir(t *testing.T) { u.Path = "/meta-subdir" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. 
+ err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "error downloading") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -165,12 +209,26 @@ func TestHttpGetter_metaSubdirGlob(t *testing.T) { u.Path = "/meta-subdir-glob" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "error downloading") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -361,12 +419,26 @@ func TestHttpGetter_auth(t *testing.T) { u.User = url.UserPassword("foo", "bar") req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -397,12 +469,26 @@ func TestHttpGetter_authNetrc(t *testing.T) { defer tempEnv(t, "NETRC", path)() req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. 
+ err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -442,14 +528,29 @@ func TestHttpGetter_cleanhttp(t *testing.T) { u.Path = "/header" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } + } func TestHttpGetter__RespectsContextCanceled(t *testing.T) { @@ -491,6 +592,432 @@ func TestHttpGetter__RespectsContextCanceled(t *testing.T) { } } +func TestHttpGetter__XTerraformGetLimit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetLoop(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/loop" + + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.Client = &http.Client{} + + req := Request{ + Dst: dst, + u: &u, + GetMode: ModeDir, + } + + err := g.Get(ctx, &req) + if !strings.Contains(err.Error(), "too many X-Terraform-Get redirects") { + t.Fatalf("too many X-Terraform-Get redirects, got: %v", err) + } +} + +func TestHttpGetter__XTerraformGetDisabled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetLoop(t) + 
+ var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/loop" + dst := testing_helper.TempDir(t) + + g := new(HttpGetter) + g.XTerraformGetDisabled = true + g.Client = &http.Client{} + + req := Request{ + Dst: dst, + u: &u, + GetMode: ModeDir, + } + + err := g.Get(ctx, &req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} +func TestHttpGetter__XTerraformGetProxyBypass(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetProxyBypass(t) + + proxyLn := testHttpServerProxy(t, ln.Addr().String()) + + t.Logf("starting malicious server on: %v", ln.Addr().String()) + t.Logf("starting proxy on: %v", proxyLn.Addr().String()) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/start" + dst := testing_helper.TempDir(t) + + proxy, err := url.Parse(fmt.Sprintf("http://%s/", proxyLn.Addr().String())) + if err != nil { + t.Fatalf("failed to parse proxy URL: %v", err) + } + + transport := cleanhttp.DefaultTransport() + transport.Proxy = http.ProxyURL(proxy) + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.Client = &http.Client{ + Transport: transport, + } + + client := &Client{ + Getters: []Getter{g}, + } + + req := Request{ + Dst: dst, + Src: u.String(), + } + + _, err = client.Get(ctx, &req) + if err != nil { + t.Logf("client get error: %v", err) + } +} + +func TestHttpGetter__XTerraformGetConfiguredGettersBypass(t *testing.T) { + tc := []struct { + name string + configuredGetters []Getter + errExpected bool + }{ + {name: "configured getter for git protocol switch", configuredGetters: []Getter{new(GitGetter)}, errExpected: false}, + {name: "configured getter for multiple protocol switch", configuredGetters: []Getter{new(GitGetter), new(HgGetter), new(FileGetter)}, errExpected: false}, + {name: "configured getter for file protocol switch", configuredGetters: []Getter{new(FileGetter)}, errExpected: true}, + } + + for _, tt := 
range tc { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetConfiguredGettersBypass(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/start" + + dst := testing_helper.TempDir(t) + + rt := hookableHTTPRoundTripper{ + before: func(req *http.Request) { + t.Logf("making request") + }, + RoundTripper: http.DefaultTransport, + } + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.Client = &http.Client{ + Transport: &rt, + } + + client := &Client{ + Getters: []Getter{g}, + } + client.Getters = append(client.Getters, tt.configuredGetters...) + + t.Logf("%v", u.String()) + + req := Request{ + Dst: dst, + Src: u.String(), + GetMode: ModeDir, + } + + _, err := client.Get(ctx, &req) + // For configured getters that support git, the git repository doesn't exist so error will not be nil. + // If we get a nil error when we expect one other than the git error git exited with -1 we should fail. 
+ if tt.errExpected && err == nil { + t.Fatalf("error expected") + } + // We only care about the error messages that indicate that we can download the git header URL + if tt.errExpected && err != nil { + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("expected download not supported for scheme, got: %v", err) + } + } + }) + } +} + +func TestHttpGetter__endless_body(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithEndlessBody(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/" + dst := testing_helper.TempDir(t) + + g := new(HttpGetter) + g.MaxBytes = 10 + g.DoNotCheckHeadFirst = true + + client := &Client{ + Getters: []Getter{g}, + } + + t.Logf("%v", u.String()) + + req := Request{ + Dst: dst, + Src: u.String(), + GetMode: ModeFile, + } + + _, err := client.Get(ctx, &req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestHttpGetter_subdirLink(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerSubDir(t) + defer ln.Close() + + dst, err := ioutil.TempDir("", "tf") + if err != nil { + t.Fatalf("err: %s", err) + } + + t.Logf("dst: %q", dst) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/regular-subdir//meta-subdir" + + g := new(HttpGetter) + client := &Client{ + Getters: []Getter{g}, + } + + t.Logf("url: %q", u.String()) + + req := Request{ + Dst: dst, + Src: u.String(), + GetMode: ModeAny, + } + + _, err = client.Get(ctx, &req) + if err != nil { + t.Fatalf("get err: %v", err) + } +} + +func testHttpServerWithXTerraformGetLoop(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("http://%v:%v", ln.Addr().String(), "/loop") + + mux := http.NewServeMux() + mux.HandleFunc("/loop", func(w 
http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving loop") + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func testHttpServerWithXTerraformGetProxyBypass(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("http://%v/bypass", ln.Addr().String()) + + mux := http.NewServeMux() + mux.HandleFunc("/start/start", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving start") + }) + + mux.HandleFunc("/bypass", func(w http.ResponseWriter, r *http.Request) { + t.Fail() + t.Logf("bypassed proxy") + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving HTTP server path: %v", r.URL.Path) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func testHttpServerWithXTerraformGetConfiguredGettersBypass(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("git::http://%v/some/repository.git", ln.Addr().String()) + + mux := http.NewServeMux() + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving start") + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving git HTTP server path: %v", r.URL.Path) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func TestHttpGetter_XTerraformWithClientFromContext(t *testing.T) { + tc := []struct { + name string + client *Client + errExpected bool + }{ + { + name: "default getters", + client: &Client{ + Getters: Getters, + }, + errExpected: false, + }, + { + name: "client configured with needed getters", + client: &Client{ + Getters: 
[]Getter{ + new(HttpGetter), + new(FileGetter), + }, + }, + errExpected: false, + }, + { + name: "nil client", + errExpected: true, + }, + } + + for _, tt := range tc { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ln := testHttpServer(t) + defer ln.Close() + ctx := context.Background() + + g := new(HttpGetter) + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/header" + + req := &Request{ + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, + } + + // Using a client stored in the ctx with a file getter should work + ctx = NewContextWithClient(ctx, tt.client) + + err := g.Get(ctx, req) + if tt.errExpected && err == nil { + t.Fatalf("error expected") + } + + if err != nil { + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("expected download not supported for scheme, got: %v", err) + } + return + } + + // Verify the main file exists + mainPath := filepath.Join(dst, "main.tf") + if _, err := os.Stat(mainPath); err != nil { + t.Fatalf("err: %s", err) + } + }) + } +} + +func testHttpServerProxy(t *testing.T, upstreamHost string) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving proxy: %v: %#+v", r.URL.Path, r.Header) + // create the reverse proxy + proxy := httputil.NewSingleHostReverseProxy(r.URL) + // Note that ServeHttp is non blocking & uses a go routine under the hood + proxy.ServeHTTP(w, r) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + func testHttpServer(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -515,6 +1042,29 @@ func testHttpServer(t *testing.T) net.Listener { return ln } +func testHttpServerWithEndlessBody(t *testing.T) net.Listener { 
+ t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + for { + w.Write([]byte(".\n")) + } + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + func testHttpHandlerExpectHeader(w http.ResponseWriter, r *http.Request) { if expected, ok := r.URL.Query()["expected"]; ok { if r.Header.Get(expected[0]) != "" { @@ -598,6 +1148,29 @@ func testHttpHandlerNoRange(w http.ResponseWriter, r *http.Request) { } } +func testHttpServerSubDir(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + t.Logf("serving: %v: %v: %#+[1]v", r.Method, r.URL.String(), r.Header) + } + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + const testHttpMetaStr = ` <html> <head>
get_test.go+22 −1 modified@@ -348,6 +348,27 @@ func TestGetFile_archive(t *testing.T) { // Verify the main file exists testing_helper.AssertContents(t, dst, "Hello\n") } +func TestGetFile_filename_path_traversal(t *testing.T) { + dst := testing_helper.TempDir(t) + u := testModule("basic-file/foo.txt") + + u += "?filename=../../../../../../../../../../../../../tmp/bar.txt" + + ctx := context.Background() + op, err := GetAny(ctx, dst, u) + + if op != nil { + t.Fatalf("unexpected op: %v", op) + } + + if err == nil { + t.Fatalf("expected error") + } + + if !strings.Contains(err.Error(), "filename query parameter contain path traversal") { + t.Fatalf("unexpected err: %s", err) + } +} func TestGetFile_archiveChecksum(t *testing.T) { ctx := context.Background() @@ -763,7 +784,7 @@ func TestGetFile_inplace_badChecksum(t *testing.T) { } } -func TestgetForcedGetter(t *testing.T) { +func TestGetForcedGetter(t *testing.T) { type args struct { src string }
helper/testing/utils.go+1 −1 modified@@ -34,7 +34,7 @@ func AssertContents(t *testing.T, path string, contents string) { } if !reflect.DeepEqual(data, []byte(contents)) { - t.Fatalf("bad. expected:\n\n%s\n\nGot:\n\n%s", contents, string(data)) + t.Fatalf("bad. expected:\n\n%q\n\nGot:\n\n%q", contents, string(data)) } }
README.md+98 −16 modified@@ -21,8 +21,8 @@ URLs. For example: "github.com/hashicorp/go-getter" would turn into a Git URL. Or "./foo" would turn into a file URL. These are extensible. This library is used by [Terraform](https://terraform.io) for -downloading modules and [Nomad](https://nomadproject.io) for downloading -binaries. +downloading modules, [Packer](https://packer.io) for downloading binaries, and +[Nomad](https://nomadproject.io) for downloading binaries. ## Installation and Usage @@ -47,6 +47,16 @@ $ go-getter github.com/foo/bar ./foo The command is useful for verifying URL structures. +## Security +Fetching resources from user-supplied URLs is an inherently dangerous operation and may +leave your application vulnerable to [server side request forgery](https://owasp.org/www-community/attacks/Server_Side_Request_Forgery), +[path traversal](https://owasp.org/www-community/attacks/Path_Traversal), [denial of service](https://owasp.org/www-community/attacks/Denial_of_Service) +or other security flaws. + +go-getter contains mitigations for some of these security issues, but should still be used with +caution in security-critical contexts. See the available [security options](#Security-Options) that +can be configured to mitigate some of these risks. + ## URL Format go-getter uses a single string URL as input to download from a variety of @@ -83,7 +93,7 @@ is built-in by default: file URLs. * GitHub URLs, such as "github.com/mitchellh/vagrant" are automatically changed to Git protocol over HTTP. - * GitLab URLs, such as "gitlab.com/inkscape/inkscape" are automatically + * GitLab URLs, such as "gitlab.com/inkscape/inkscape" are automatically changed to Git protocol over HTTP. * BitBucket URLs, such as "bitbucket.org/mitchellh/vagrant" are automatically changed to a Git or mercurial protocol using the BitBucket API. @@ -178,7 +188,7 @@ checksum string. 
Examples: ``` ./foo.txt?checksum=file:./foo.txt.sha256sum ``` - + When checksumming from a file - ex: with `checksum=file:url` - go-getter will get the file linked in the URL after `file:` using the same configuration. For example, in `file:http://releases.ubuntu.com/cosmic/MD5SUMS` go-getter will @@ -279,7 +289,7 @@ None from a private key file on disk, you would run `base64 -w0 <file>`. **Note**: Git 2.3+ is required to use this feature. - + * `depth` - The Git clone depth. The provided number specifies the last `n` revisions to clone from the repository. @@ -374,35 +384,107 @@ files from a smb shared folder whenever the url is prefixed with `smb://`. ⚠️ The [`smbclient`](https://www.samba.org/samba/docs/current/man-html/smbclient.1.html) command is available only for Linux. This is the ONLY option for a Linux user and therefore the client must be installed. - + The `smbclient` cli is not available for Windows and MacOS. The go-getter will try to get files using the file system, when this happens the getter uses the FileGetter implementation. -When connecting to a smb server, the OS creates a local mount in a system specific volume folder, and go-getter will +When connecting to a smb server, the OS creates a local mount in a system specific volume folder, and go-getter will try to access the following folders when looking for local mounts. 
- MacOS: /Volumes/<shared_path> - Windows: \\\\\<host>\\\<shared_path> -The following examples work for all the OSes: +The following examples work for all the OSes: - smb://host/shared/dir (downloads directory content) -- smb://host/shared/dir/file (downloads file) +- smb://host/shared/dir/file (downloads file) -The following examples work for Linux: +The following examples work for Linux: - smb://username:password@host/shared/dir (downloads directory content) - smb://username@host/shared/dir - smb://username:password@host/shared/dir/file (downloads file) - smb://username@host/shared/dir/file ⚠️ The above examples also work on the other OSes but the authentication is not used to access the file system. - - + + #### SMB Testing The test for `get_smb.go` requires a smb server running which can be started inside a docker container by -running `make start-smb`. Once the container is up the shared folder can be accessed via `smb://<ip|name>/public/<dir|file>` or -`smb://user:password@<ip|name>/private/<dir|file>` by another container or machine in the same network. +running `make start-smb`. Once the container is up the shared folder can be accessed via `smb://<ip|name>/public/<dir|file>` or +`smb://user:password@<ip|name>/private/<dir|file>` by another container or machine in the same network. -To run the tests inside `get_smb_test.go` and `client_test.go`, prepare the environment with `make smbtests-prepare`. On prepare some +To run the tests inside `get_smb_test.go` and `client_test.go`, prepare the environment with `make smbtests-prepare`. On prepare some mock files and directories will be added to the shared folder and a go-getter container will start together with the samba server. -Once the environment for testing is prepared, run `make smbtests` to run the tests. \ No newline at end of file +Once the environment for testing is prepared, run `make smbtests` to run the tests. 
+ +### Security Options + +**Disable Symlinks** + +In your getter client config, we recommend using the `DisableSymlinks` option, +which prevents writing through or copying from symlinks (which may point outside the directory). + +```go +client := getter.Client{ + // This will prevent copying or writing files through symlinks + DisableSymlinks: true, +} +``` + +**Disable or Limit `X-Terraform-Get`** + +Go-Getter supports arbitrary redirects via the `X-Terraform-Get` header. This functionality +exists to support [Terraform use cases](https://www.terraform.io/language/modules/sources#http-urls), +but is likely not needed in most applications. + +For code that uses the `HttpGetter`, add the following configuration options: + +```go +var httpGetter = &getter.HttpGetter{ + // Most clients should disable X-Terraform-Get + // See the note below + XTerraformGetDisabled: true, + // Your software probably doesn’t rely on X-Terraform-Get, but + // if it does, you should set the above field to false, plus + // set XTerraformGet Limit to prevent endless redirects + // XTerraformGetLimit: 10, +} +``` + +**Enforce Timeouts** + +The `HttpGetter` supports timeouts and other resource-constraining configuration options. The `GitGetter` and `HgGetter` +only support timeouts. 
+
+Configuration for the `HttpGetter`:
+
+```go
+var httpGetter = &getter.HttpGetter{
+	// Disable pre-fetch HEAD requests
+	DoNotCheckHeadFirst: true,
+
+	// As an alternative to the above setting, you can
+	// set a reasonable timeout for HEAD requests
+	// HeadFirstTimeout: 10 * time.Second,
+	// Read timeout for HTTP operations
+	ReadTimeout: 30 * time.Second,
+	// Set the maximum number of bytes
+	// that can be read by the getter
+	MaxBytes: 500000000, // 500 MB
+}
+```
+
+For code that uses the `GitGetter` or `HgGetter`, set the `Timeout` option:
+```go
+var gitGetter = &getter.GitGetter{
+	// Set a reasonable timeout for git operations
+	Timeout: 5 * time.Minute,
+}
+```
+
+```go
+var hgGetter = &getter.HgGetter{
+	// Set a reasonable timeout for hg operations
+	Timeout: 5 * time.Minute,
+}
+```
request.go+4 −0 modified
@@ -58,6 +58,10 @@ type Request struct {
 	// By default a no op progress listener is used.
 	ProgressListener ProgressTracker
 
+	// Disable symlinks is used to prevent copying or writing files through symlinks.
+	// When set to true any copying or writing through symlinks will result in a ErrSymlinkCopy error.
+	DisableSymlinks bool
+
 	u               *url.URL
 	subDir, realDst string
 }
s3/get_s3.go+30 −4 modified@@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -20,9 +21,21 @@ import ( // Getter is a Getter implementation that will download a module from // a S3 bucket. -type Getter struct{} +type Getter struct { + + // Timeout sets a deadline which all S3 operations should + // complete within. Zero value means no timeout. + Timeout time.Duration +} func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL region, bucket, path, _, creds, err := g.parseUrl(u) if err != nil { @@ -40,7 +53,7 @@ func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { Bucket: aws.String(bucket), Prefix: aws.String(path), } - resp, err := client.ListObjects(req) + resp, err := client.ListObjectsWithContext(ctx, req) if err != nil { return 0, err } @@ -64,6 +77,12 @@ func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { func (g *Getter) Get(ctx context.Context, req *getter.Request) error { + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL region, bucket, path, _, creds, err := g.parseUrl(req.URL()) if err != nil { @@ -105,7 +124,7 @@ func (g *Getter) Get(ctx context.Context, req *getter.Request) error { s3Req.Marker = aws.String(lastMarker) } - resp, err := client.ListObjects(s3Req) + resp, err := client.ListObjectsWithContext(ctx, s3Req) if err != nil { return err } @@ -139,6 +158,13 @@ func (g *Getter) Get(ctx context.Context, req *getter.Request) error { } func (g *Getter) GetFile(ctx context.Context, req *getter.Request) error { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + region, bucket, path, 
version, creds, err := g.parseUrl(req.URL()) if err != nil { return err @@ -161,7 +187,7 @@ func (g *Getter) getObject(ctx context.Context, client *s3.S3, req *getter.Reque s3req.VersionId = aws.String(version) } - resp, err := client.GetObject(s3req) + resp, err := client.GetObjectWithContext(ctx, s3req) if err != nil { return err }
source.go+3 −1 modified
@@ -58,7 +58,9 @@ func SourceDirSubdir(src string) (string, string) {
 //
 // The returned path is the full absolute path.
 func SubdirGlob(dst, subDir string) (string, error) {
-	matches, err := filepath.Glob(filepath.Join(dst, subDir))
+	pattern := filepath.Join(dst, subDir)
+
+	matches, err := filepath.Glob(pattern)
 	if err != nil {
 		return "", err
 	}
a2ebce998f8d Multiple fixes for go-getter (#359)
16 files changed · +1169 −97
client.go+23 −1 modified@@ -2,6 +2,7 @@ package getter import ( "context" + "errors" "fmt" "io/ioutil" "os" @@ -13,6 +14,9 @@ import ( safetemp "github.com/hashicorp/go-safetemp" ) +// ErrSymlinkCopy means that a copy of a symlink was encountered on a request with DisableSymlinks enabled. +var ErrSymlinkCopy = errors.New("copying of symlinks has been disabled") + // Client is a client for downloading things. // // Top-level functions such as Get are shortcuts for interacting with a client. @@ -76,6 +80,9 @@ type Client struct { // This is identical to tls.Config.InsecureSkipVerify. Insecure bool + // Disable symlinks + DisableSymlinks bool + Options []ClientOption } @@ -123,6 +130,17 @@ func (c *Client) Get() error { dst := c.Dst src, subDir := SourceDirSubdir(src) if subDir != "" { + // Check if the subdirectory is attempting to traverse updwards, outside of + // the cloned repository path. + subDir := filepath.Clean(subDir) + if containsDotDot(subDir) { + return fmt.Errorf("subdirectory component contain path traversal out of the repository") + } + // Prevent absolute paths, remove a leading path separator from the subdirectory + if subDir[0] == os.PathSeparator { + subDir = subDir[1:] + } + td, tdcloser, err := safetemp.Dir("", "getter") if err != nil { return err @@ -230,6 +248,10 @@ func (c *Client) Get() error { filename = v } + if containsDotDot(filename) { + return fmt.Errorf("filename query parameter contain path traversal") + } + dst = filepath.Join(dst, filename) } } @@ -318,7 +340,7 @@ func (c *Client) Get() error { return err } - return copyDir(c.Ctx, realDst, subDir, false, c.umask()) + return copyDir(c.Ctx, realDst, subDir, false, c.DisableSymlinks, c.umask()) } return nil
client_option.go+61 −7 modified@@ -1,46 +1,100 @@ package getter -import "context" +import ( + "context" + "os" +) -// A ClientOption allows to configure a client +// ClientOption is used to configure a client. type ClientOption func(*Client) error -// Configure configures a client with options. +// Configure applies all of the given client options, along with any default +// behavior including context, decompressors, detectors, and getters used by +// the client. func (c *Client) Configure(opts ...ClientOption) error { + // If the context has not been configured use the background context. if c.Ctx == nil { c.Ctx = context.Background() } + + // Store the options used to configure this client. c.Options = opts + + // Apply all of the client options. for _, opt := range opts { err := opt(c) if err != nil { return err } } - // Default decompressor values + + // If the client was not configured with any Decompressors, Detectors, + // or Getters, use the default values for each. if c.Decompressors == nil { c.Decompressors = Decompressors } - // Default detector values if c.Detectors == nil { c.Detectors = Detectors } - // Default getter values if c.Getters == nil { c.Getters = Getters } + // Set the client for each getter, so the top-level client can know + // the getter-specific client functions or progress tracking. for _, getter := range c.Getters { getter.SetClient(c) } + return nil } // WithContext allows to pass a context to operation // in order to be able to cancel a download in progress. -func WithContext(ctx context.Context) func(*Client) error { +func WithContext(ctx context.Context) ClientOption { return func(c *Client) error { c.Ctx = ctx return nil } } + +// WithDecompressors specifies which Decompressor are available. +func WithDecompressors(decompressors map[string]Decompressor) ClientOption { + return func(c *Client) error { + c.Decompressors = decompressors + return nil + } +} + +// WithDecompressors specifies which compressors are available. 
+func WithDetectors(detectors []Detector) ClientOption { + return func(c *Client) error { + c.Detectors = detectors + return nil + } +} + +// WithGetters specifies which getters are available. +func WithGetters(getters map[string]Getter) ClientOption { + return func(c *Client) error { + c.Getters = getters + return nil + } +} + +// WithMode specifies which client mode the getters should operate in. +func WithMode(mode ClientMode) ClientOption { + return func(c *Client) error { + c.Mode = mode + return nil + } +} + +// WithUmask specifies how to mask file permissions when storing local +// files or decompressing an archive. +func WithUmask(mode os.FileMode) ClientOption { + return func(c *Client) error { + c.Umask = mode + return nil + } +}
copy_dir.go+21 −3 modified@@ -2,6 +2,7 @@ package getter import ( "context" + "fmt" "os" "path/filepath" "strings" @@ -16,8 +17,11 @@ func mode(mode, umask os.FileMode) os.FileMode { // should already exist. // // If ignoreDot is set to true, then dot-prefixed files/folders are ignored. -func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask os.FileMode) error { - src, err := filepath.EvalSymlinks(src) +func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, disableSymlinks bool, umask os.FileMode) error { + // We can safely evaluate the symlinks here, even if disabled, because they + // will be checked before actual use in walkFn and copyFile + var err error + src, err = filepath.EvalSymlinks(src) if err != nil { return err } @@ -26,6 +30,20 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask if err != nil { return err } + + if disableSymlinks { + fileInfo, err := os.Lstat(path) + if err != nil { + return fmt.Errorf("failed to check copy file source for symlinks: %w", err) + } + if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink { + return ErrSymlinkCopy + } + // if info.Mode()&os.ModeSymlink == os.ModeSymlink { + // return ErrSymlinkCopy + // } + } + if path == src { return nil } @@ -59,7 +77,7 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask } // If we have a file, copy the contents. - _, err = copyFile(ctx, dstPath, path, info.Mode(), umask) + _, err = copyFile(ctx, dstPath, path, disableSymlinks, info.Mode(), umask) return err }
get_file_copy.go+12 −1 modified@@ -2,6 +2,7 @@ package getter import ( "context" + "fmt" "io" "os" ) @@ -49,7 +50,17 @@ func copyReader(dst string, src io.Reader, fmode, umask os.FileMode) error { } // copyFile copies a file in chunks from src path to dst path, using umask to create the dst file -func copyFile(ctx context.Context, dst, src string, fmode, umask os.FileMode) (int64, error) { +func copyFile(ctx context.Context, dst, src string, disableSymlinks bool, fmode, umask os.FileMode) (int64, error) { + if disableSymlinks { + fileInfo, err := os.Lstat(src) + if err != nil { + return 0, fmt.Errorf("failed to check copy file source for symlinks: %w", err) + } + if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink { + return 0, ErrSymlinkCopy + } + } + srcF, err := os.Open(src) if err != nil { return 0, err
get_file_unix.go+7 −1 modified
@@ -87,7 +87,13 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error {
 		return os.Symlink(path, dst)
 	}
 
+	var disableSymlinks bool
+
+	if g.client != nil && g.client.DisableSymlinks {
+		disableSymlinks = true
+	}
+
 	// Copy
-	_, err = copyFile(ctx, dst, path, fi.Mode(), g.client.umask())
+	_, err = copyFile(ctx, dst, path, disableSymlinks, fi.Mode(), g.client.umask())
 	return err
 }
get_file_windows.go+7 −1 modified
@@ -111,8 +111,14 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error {
 		}
 	}
 
+	var disableSymlinks bool
+
+	if g.client != nil && g.client.DisableSymlinks {
+		disableSymlinks = true
+	}
+
 	// Copy
-	_, err = copyFile(ctx, dst, path, 0666, g.client.umask())
+	_, err = copyFile(ctx, dst, path, disableSymlinks, 0666, g.client.umask())
 	return err
 }
get_gcs.go+26 −2 modified@@ -3,13 +3,15 @@ package getter import ( "context" "fmt" - "golang.org/x/oauth2" - "google.golang.org/api/option" "net/url" "os" "path/filepath" "strconv" "strings" + "time" + + "golang.org/x/oauth2" + "google.golang.org/api/option" "cloud.google.com/go/storage" "google.golang.org/api/iterator" @@ -19,11 +21,21 @@ import ( // a GCS bucket. type GCSGetter struct { getter + + // Timeout sets a deadline which all GCS operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, _, err := g.parseURL(u) if err != nil { @@ -61,6 +73,12 @@ func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) { func (g *GCSGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, _, err := g.parseURL(u) if err != nil { @@ -120,6 +138,12 @@ func (g *GCSGetter) Get(dst string, u *url.URL) error { func (g *GCSGetter) GetFile(dst string, u *url.URL) error { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, fragment, err := g.parseURL(u) if err != nil {
get_git.go+28 −16 modified@@ -14,6 +14,7 @@ import ( "runtime" "strconv" "strings" + "time" urlhelper "github.com/hashicorp/go-getter/helper/url" safetemp "github.com/hashicorp/go-safetemp" @@ -24,6 +25,10 @@ import ( // a git repository. type GitGetter struct { getter + + // Timeout sets a deadline which all git CLI operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } var defaultBranchRegexp = regexp.MustCompile(`\s->\sorigin/(.*)`) @@ -35,6 +40,13 @@ func (g *GitGetter) ClientMode(_ *url.URL) (ClientMode, error) { func (g *GitGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + if _, err := exec.LookPath("git"); err != nil { return fmt.Errorf("git must be available and on the PATH") } @@ -76,7 +88,7 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { var sshKeyFile string if sshKey != "" { // Check that the git version is sufficiently new. 
- if err := checkGitVersion("2.3"); err != nil { + if err := checkGitVersion(ctx, "2.3"); err != nil { return fmt.Errorf("Error using ssh key: %v", err) } @@ -123,7 +135,7 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { // Next: check out the proper tag/branch if it is specified, and checkout if ref != "" { - if err := g.checkout(dst, ref); err != nil { + if err := g.checkout(ctx, dst, ref); err != nil { return err } } @@ -161,8 +173,8 @@ func (g *GitGetter) GetFile(dst string, u *url.URL) error { return fg.GetFile(dst, u) } -func (g *GitGetter) checkout(dst string, ref string) error { - cmd := exec.Command("git", "checkout", ref) +func (g *GitGetter) checkout(ctx context.Context, dst string, ref string) error { + cmd := exec.CommandContext(ctx, "git", "checkout", ref) cmd.Dir = dst return getRunCommand(cmd) } @@ -182,7 +194,7 @@ func (g *GitGetter) clone(ctx context.Context, dst, sshKeyFile string, u *url.UR originalRef := ref // we handle an unspecified ref differently than explicitly selecting the default branch below if ref == "" { - ref = findRemoteDefaultBranch(u) + ref = findRemoteDefaultBranch(ctx, u) } if depth > 0 { args = append(args, "--depth", strconv.Itoa(depth)) @@ -211,7 +223,7 @@ func (g *GitGetter) clone(ctx context.Context, dst, sshKeyFile string, u *url.UR // If we didn't add --depth and --branch above then we will now be // on the remote repository's default branch, rather than the selected // ref, so we'll need to fix that before we return. - return g.checkout(dst, originalRef) + return g.checkout(ctx, dst, originalRef) } return nil } @@ -226,18 +238,18 @@ func (g *GitGetter) update(ctx context.Context, dst, sshKeyFile, ref string, dep // Not a branch, switch to default branch. This will also catch // non-existent branches, in which case we want to switch to default // and then checkout the proper branch later. 
- ref = findDefaultBranch(dst) + ref = findDefaultBranch(ctx, dst) } // We have to be on a branch to pull - if err := g.checkout(dst, ref); err != nil { + if err := g.checkout(ctx, dst, ref); err != nil { return err } if depth > 0 { - cmd = exec.Command("git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") } else { - cmd = exec.Command("git", "pull", "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--ff-only") } cmd.Dir = dst @@ -260,9 +272,9 @@ func (g *GitGetter) fetchSubmodules(ctx context.Context, dst, sshKeyFile string, // findDefaultBranch checks the repo's origin remote for its default branch // (generally "master"). "master" is returned if an origin default branch // can't be determined. -func findDefaultBranch(dst string) string { +func findDefaultBranch(ctx context.Context, dst string) string { var stdoutbuf bytes.Buffer - cmd := exec.Command("git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") + cmd := exec.CommandContext(ctx, "git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") cmd.Dir = dst cmd.Stdout = &stdoutbuf err := cmd.Run() @@ -275,9 +287,9 @@ func findDefaultBranch(dst string) string { // findRemoteDefaultBranch checks the remote repo's HEAD symref to return the remote repo's // default branch. "master" is returned if no HEAD symref exists. 
-func findRemoteDefaultBranch(u *url.URL) string { +func findRemoteDefaultBranch(ctx context.Context, u *url.URL) string { var stdoutbuf bytes.Buffer - cmd := exec.Command("git", "ls-remote", "--symref", u.String(), "HEAD") + cmd := exec.CommandContext(ctx, "git", "ls-remote", "--symref", u.String(), "HEAD") cmd.Stdout = &stdoutbuf err := cmd.Run() matches := lsRemoteSymRefRegexp.FindStringSubmatch(stdoutbuf.String()) @@ -326,13 +338,13 @@ func setupGitEnv(cmd *exec.Cmd, sshKeyFile string) { // checkGitVersion is used to check the version of git installed on the system // against a known minimum version. Returns an error if the installed version // is older than the given minimum. -func checkGitVersion(min string) error { +func checkGitVersion(ctx context.Context, min string) error { want, err := version.NewVersion(min) if err != nil { return err } - out, err := exec.Command("git", "version").Output() + out, err := exec.CommandContext(ctx, "git", "version").Output() if err != nil { return err }
get_git_test.go+119 −2 modified@@ -2,7 +2,10 @@ package getter import ( "bytes" + "context" "encoding/base64" + "errors" + "fmt" "io/ioutil" "net/url" "os" @@ -436,12 +439,12 @@ func TestGitGetter_gitVersion(t *testing.T) { os.Setenv("PATH", dir) // Asking for a higher version throws an error - if err := checkGitVersion("2.3"); err == nil { + if err := checkGitVersion(context.Background(), "2.3"); err == nil { t.Fatal("expect git version error") } // Passes when version is satisfied - if err := checkGitVersion("1.9"); err != nil { + if err := checkGitVersion(context.Background(), "1.9"); err != nil { t.Fatal(err) } } @@ -693,6 +696,120 @@ func TestGitGetter_setupGitEnvWithExisting_sshKey(t *testing.T) { } } +func TestGitGetter_subdirectory_symlink(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + g := new(GitGetter) + dst := tempDir(t) + + target, err := ioutil.TempFile("", "link-target") + if err != nil { + t.Fatal(err) + } + defer os.Remove(target.Name()) + + repo := testGitRepo(t, "repo-with-symlink") + innerDir := filepath.Join(repo.dir, "this-directory-contains-a-symlink") + if err := os.Mkdir(innerDir, 0700); err != nil { + t.Fatal(err) + } + path := filepath.Join(innerDir, "this-is-a-symlink") + if err := os.Symlink(target.Name(), path); err != nil { + t.Fatal(err) + } + + repo.git("add", path) + repo.git("commit", "-m", "Adding "+path) + + u, err := url.Parse(fmt.Sprintf("git::%s//this-directory-contains-a-symlink", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + client := &Client{ + Src: u.String(), + Dst: dst, + Pwd: ".", + Mode: ClientModeDir, + DisableSymlinks: true, + Detectors: []Detector{ + new(GitDetector), + }, + Getters: map[string]Getter{ + "git": g, + }, + } + + err = client.Get() + + if runtime.GOOS == "windows" { + // Windows doesn't handle symlinks as one might expect with git. 
+ // + // https://github.com/git-for-windows/git/wiki/Symbolic-Links + filepath.Walk(dst, func(path string, info os.FileInfo, err error) error { + if strings.Contains(path, "this-is-a-symlink") { + if info.Mode()&os.ModeSymlink == os.ModeSymlink { + // If you see this test fail in the future, you've probably enabled + // symlinks within git on your Windows system. Our CI/CD system does + // not do this, so this is this is the only way we can make this test + // make any sense. + t.Fatalf("windows git should not have cloned a symlink") + } + } + return nil + }) + } else { + // We can rely on POSIX compliant systems running git to do the right thing. + if err == nil { + t.Fatalf("expected client get to fail") + } + if !errors.Is(err, ErrSymlinkCopy) { + t.Fatalf("unexpected error: %v", err) + } + } + +} + +func TestGitGetter_subdirectory(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + g := new(GitGetter) + dst := tempDir(t) + + repo := testGitRepo(t, "empty-repo") + u, err := url.Parse(fmt.Sprintf("git::%s//../../../../../../etc/passwd", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + client := &Client{ + Src: u.String(), + Dst: dst, + Pwd: ".", + + Mode: ClientModeDir, + + Detectors: []Detector{ + new(GitDetector), + }, + Getters: map[string]Getter{ + "git": g, + }, + } + + err = client.Get() + if err == nil { + t.Fatalf("expected client get to fail") + } + if !strings.Contains(err.Error(), "subdirectory component contain path traversal out of the repository") { + t.Fatalf("unexpected error: %v", err) + } +} + // gitRepo is a helper struct which controls a single temp git repo. type gitRepo struct { t *testing.T
get_hg.go+19 −7 modified@@ -8,6 +8,7 @@ import ( "os/exec" "path/filepath" "runtime" + "time" urlhelper "github.com/hashicorp/go-getter/helper/url" safetemp "github.com/hashicorp/go-safetemp" @@ -17,6 +18,10 @@ import ( // a Mercurial repository. type HgGetter struct { getter + + // Timeout sets a deadline which all hg CLI operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) { @@ -25,6 +30,13 @@ func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) { func (g *HgGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + if _, err := exec.LookPath("hg"); err != nil { return fmt.Errorf("hg must be available and on the PATH") } @@ -53,12 +65,12 @@ func (g *HgGetter) Get(dst string, u *url.URL) error { return err } if err != nil { - if err := g.clone(dst, newURL); err != nil { + if err := g.clone(ctx, dst, newURL); err != nil { return err } } - if err := g.pull(dst, newURL); err != nil { + if err := g.pull(ctx, dst, newURL); err != nil { return err } @@ -101,21 +113,21 @@ func (g *HgGetter) GetFile(dst string, u *url.URL) error { return fg.GetFile(dst, u) } -func (g *HgGetter) clone(dst string, u *url.URL) error { - cmd := exec.Command("hg", "clone", "-U", u.String(), dst) +func (g *HgGetter) clone(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "clone", "-U", "--", u.String(), dst) return getRunCommand(cmd) } -func (g *HgGetter) pull(dst string, u *url.URL) error { - cmd := exec.Command("hg", "pull") +func (g *HgGetter) pull(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "pull") cmd.Dir = dst return getRunCommand(cmd) } func (g *HgGetter) update(ctx context.Context, dst string, u *url.URL, rev string) error { args := []string{"update"} if 
rev != "" { - args = append(args, rev) + args = append(args, "--", rev) } cmd := exec.CommandContext(ctx, "hg", args...)
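The `--` added to the `hg clone` and `hg update` argument lists marks the end of options, so a repository URL, destination, or revision such as `--config=alias.clone=!false` reaches `hg` as an operand instead of being parsed as a flag; the tests in `get_hg_test.go` exercise these inputs. A sketch of the argument construction with hypothetical helper names; it only builds the slices and does not run `hg`.

```go
package main

import (
	"fmt"
	"strings"
)

// hgCloneArgs puts the caller-controlled URL and destination after "--".
func hgCloneArgs(repoURL, dst string) []string {
	return []string{"clone", "-U", "--", repoURL, dst}
}

// hgUpdateArgs adds "--" only when a revision is supplied.
func hgUpdateArgs(rev string) []string {
	args := []string{"update"}
	if rev != "" {
		args = append(args, "--", rev)
	}
	return args
}

func main() {
	fmt.Println(strings.Join(hgUpdateArgs("--config=alias.update=!false"), " "))
	// update -- --config=alias.update=!false
	fmt.Println(strings.Join(hgCloneArgs("https://example.com/repo", "dst"), " "))
	// clone -U -- https://example.com/repo dst
}
```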
get_hg_test.go+44 −0 modified@@ -1,9 +1,11 @@ package getter import ( + "net/url" "os" "os/exec" "path/filepath" + "strings" "testing" ) @@ -97,3 +99,45 @@ func TestHgGetter_GetFile(t *testing.T) { } assertContents(t, dst, "Hello\n") } + +func TestHgGetter_HgArgumentsNotAllowed(t *testing.T) { + if !testHasHg { + t.Log("hg not found, skipping") + t.Skip() + } + + g := new(HgGetter) + + // If arguments are allowed in the destination, this Get call will fail + dst := "--config=alias.clone=!false" + defer os.RemoveAll(dst) + err := g.Get(dst, testModuleURL("basic-hg")) + if err != nil { + t.Fatalf("Expected no err, got: %s", err) + } + + dst = tempDir(t) + // Test arguments passed into the `rev` parameter + // This clone call will fail regardless, but an exit code of 1 indicates + // that the `false` command executed + // We are expecting an hg parse error + err = g.Get(dst, testModuleURL("basic-hg?rev=--config=alias.update=!false")) + if err != nil { + if !strings.Contains(err.Error(), "hg: parse error") { + t.Fatalf("Expected no err, got: %s", err) + } + } + + dst = tempDir(t) + // Test arguments passed in the repository URL + // This Get call will fail regardless, but it should fail + // because the repository can't be found. + // Other failures indicate that hg interpretted the argument passed in the URL + err = g.Get(dst, &url.URL{Path: "--config=alias.clone=false"}) + if err != nil { + if !strings.Contains(err.Error(), "repository --config=alias.clone=false not found") { + t.Fatalf("Expected no err, got: %s", err) + } + } + +}
get_http.go+282 −38 modified
@@ -11,6 +11,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"time"
 
 	"github.com/hashicorp/go-cleanhttp"
 	safetemp "github.com/hashicorp/go-safetemp"
@@ -28,7 +29,9 @@ import (
 // wish. The response must be a 2xx.
 //
 // First, a header is looked for "X-Terraform-Get" which should contain
-// a source URL to download.
+// a source URL to download. This source must use one of the configured
+// protocols and getters for the client, or "http"/"https" if using
+// the HttpGetter directly.
 //
 // If the header is not present, then a meta tag is searched for named
 // "terraform-get" and the content should be a source URL.
@@ -52,6 +55,36 @@ type HttpGetter struct {
 	// and as such it needs to be initialized before use, via something like
 	// make(http.Header).
 	Header http.Header
+
+	// DoNotCheckHeadFirst configures the client to NOT check if the server
+	// supports HEAD requests.
+	DoNotCheckHeadFirst bool
+
+	// HeadFirstTimeout configures the client to enforce a timeout when
+	// the server supports HEAD requests.
+	//
+	// The zero value means no timeout.
+	HeadFirstTimeout time.Duration
+
+	// ReadTimeout configures the client to enforce a timeout when
+	// making a request to an HTTP server and reading its response body.
+	//
+	// The zero value means no timeout.
+	ReadTimeout time.Duration
+
+	// MaxBytes limits the number of bytes that will be read from an HTTP
+	// response body returned from a server. The zero value means no limit.
+	MaxBytes int64
+
+	// XTerraformGetLimit configures how many times the client will follow
+	// the "X-Terraform-Get" header value.
+	//
+	// The zero value means no limit.
+	XTerraformGetLimit int
+
+	// XTerraformGetDisabled disables the client's usage of the "X-Terraform-Get"
+	// header value.
+ XTerraformGetDisabled bool } func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) { @@ -61,8 +94,115 @@ func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) { return ClientModeFile, nil } +type contextKey int + +const ( + xTerraformGetDisable contextKey = 0 + xTerraformGetLimit contextKey = 1 + xTerraformGetLimitCurrentValue contextKey = 2 + httpClientValue contextKey = 3 + httpMaxBytesValue contextKey = 4 +) + +func xTerraformGetDisabled(ctx context.Context) bool { + value, ok := ctx.Value(xTerraformGetDisable).(bool) + if !ok { + return false + } + return value +} + +func xTerraformGetLimitCurrentValueFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimitCurrentValue).(int) + if !ok { + return 1 + } + return value +} + +func xTerraformGetLimiConfiguredtFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimit).(int) + if !ok { + return 0 + } + return value +} + +func httpClientFromContext(ctx context.Context) *http.Client { + value, ok := ctx.Value(httpClientValue).(*http.Client) + if !ok { + return nil + } + return value +} + +func httpMaxBytesFromContext(ctx context.Context) int64 { + value, ok := ctx.Value(httpMaxBytesValue).(int64) + if !ok { + return 0 // no limit + } + return value +} + +type limitedWrappedReaderCloser struct { + underlying io.Reader + closeFn func() error +} + +func (l *limitedWrappedReaderCloser) Read(p []byte) (n int, err error) { + return l.underlying.Read(p) +} + +func (l *limitedWrappedReaderCloser) Close() (err error) { + return l.closeFn() +} + +func newLimitedWrappedReaderCloser(r io.ReadCloser, limit int64) io.ReadCloser { + return &limitedWrappedReaderCloser{ + underlying: io.LimitReader(r, limit), + closeFn: r.Close, + } +} + func (g *HttpGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + + // Optionally disable any X-Terraform-Get redirects. This is reccomended for usage of + // this client outside of Terraform's. 
This feature is likely not required if the + // source server can provider normal HTTP redirects. + if g.XTerraformGetDisabled { + ctx = context.WithValue(ctx, xTerraformGetDisable, g.XTerraformGetDisabled) + } + + // Optionally enforce a limit on X-Terraform-Get redirects. We check this for every + // invocation of this function, because the value is not passed down to subsequent + // client Get function invocations. + if g.XTerraformGetLimit > 0 { + ctx = context.WithValue(ctx, xTerraformGetLimit, g.XTerraformGetLimit) + } + + // If there was a limit on X-Terraform-Get redirects, check what the current count value. + // + // If the value is greater than the limit, return an error. Otherwise, increment the value, + // and include it in the the context to be passed along in all the subsequent client + // Get function invocations. + if limit := xTerraformGetLimiConfiguredtFromContext(ctx); limit > 0 { + currentValue := xTerraformGetLimitCurrentValueFromContext(ctx) + + if currentValue > limit { + return fmt.Errorf("too many X-Terraform-Get redirects: %d", currentValue) + } + + currentValue++ + + ctx = context.WithValue(ctx, xTerraformGetLimitCurrentValue, currentValue) + } + + // Optionally enforce a maxiumum HTTP response body size. + if g.MaxBytes > 0 { + ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes) + } + // Copy the URL so we can modify it var newU url.URL = *u u = &newU @@ -74,22 +214,40 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { } } + // If the HTTP client is nil, check if there is one available in the context, + // otherwise create one using cleanhttp's default transport. 
 	if g.Client == nil {
-		g.Client = httpClient
-		if g.client != nil && g.client.Insecure {
-			insecureTransport := cleanhttp.DefaultTransport()
-			insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
-			g.Client.Transport = insecureTransport
+		if client := httpClientFromContext(ctx); client != nil {
+			g.Client = client
+		} else {
+			client := httpClient
+			if g.client != nil && g.client.Insecure {
+				insecureTransport := cleanhttp.DefaultTransport()
+				insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
+				client.Transport = insecureTransport
+			}
+			g.Client = client
 		}
 	}

+	// Pass along the configured HTTP client in the context for usage with the X-Terraform-Get feature.
+	ctx = context.WithValue(ctx, httpClientValue, g.Client)
+
 	// Add terraform-get to the parameter.
 	q := u.Query()
 	q.Add("terraform-get", "1")
 	u.RawQuery = q.Encode()

+	readCtx := ctx
+
+	if g.ReadTimeout > 0 {
+		var cancel context.CancelFunc
+		readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout)
+		defer cancel()
+	}
+
 	// Get the URL
-	req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
+	req, err := http.NewRequestWithContext(readCtx, "GET", u.String(), nil)
 	if err != nil {
 		return err
 	}
@@ -102,18 +260,28 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
 	if err != nil {
 		return err
 	}
-	defer resp.Body.Close()
+
+	body := resp.Body
+
+	if maxBytes := httpMaxBytesFromContext(ctx); maxBytes > 0 {
+		body = newLimitedWrappedReaderCloser(body, maxBytes)
+	}
+
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
 		return fmt.Errorf("bad response code: %d", resp.StatusCode)
 	}

-	// Extract the source URL
+	if disabled := xTerraformGetDisabled(ctx); disabled {
+		return nil
+	}
+
+	// Extract the source URL,
 	var source string
 	if v := resp.Header.Get("X-Terraform-Get"); v != "" {
 		source = v
 	} else {
-		source, err = g.parseMeta(resp.Body)
+		source, err = g.parseMeta(readCtx, body)
 		if err != nil {
 			return err
 		}
@@ -127,9 +295,43 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
 	source, subDir := SourceDirSubdir(source)
 	if subDir == "" {
 		var opts []ClientOption
+
+		// Check if the protocol was switched to one which was not configured.
+		//
+		// Otherwise, all default getters are allowed.
+		if g.client != nil && g.client.Getters != nil {
+			protocol := strings.Split(source, ":")[0]
+			_, allowed := g.client.Getters[protocol]
+			if !allowed {
+				return fmt.Errorf("no getter available for X-Terraform-Get source protocol: %q", protocol)
+			}
+		}
+
+		// Add any getter client options.
 		if g.client != nil {
 			opts = g.client.Options
 		}
+
+		// If the client is nil, we know we're using the HttpGetter directly. In this case,
+		// we don't know exactly which protocols are configued, but we can make a good guess.
+		//
+		// This prevents all default getters from being allowed when only using the
+		// HttpGetter directly. To enable protocol switching, a client "wrapper" must
+		// be used.
+		if g.client == nil {
+			opts = append(opts, WithGetters(map[string]Getter{
+				"http":  g,
+				"https": g,
+			}))
+		}
+
+		// Ensure we pass along the context we constructed in this function.
+		//
+		// This is especially important to enforce a limit on X-Terraform-Get redirects
+		// which could be setup, if configured, at the top of this function.
+		opts = append(opts, WithContext(ctx))
+
+		// Note: this allows the protocol to be switched to another configured getters.
 		return Get(dst, source, opts...)
 	}

@@ -145,6 +347,12 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
 // appended.
 func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
 	ctx := g.Context()
+
+	// Optionally enforce a maxiumum HTTP response body size.
+	if g.MaxBytes > 0 {
+		ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes)
+	}
+
 	if g.Netrc {
 		// Add auth from netrc if we can
 		if err := addAuthFromNetrc(src); err != nil {
@@ -171,39 +379,61 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
 		}
 	}

-	var currentFileSize int64
+	var (
+		currentFileSize int64
+		req             *http.Request
+	)

-	// We first make a HEAD request so we can check
-	// if the server supports range queries. If the server/URL doesn't
-	// support HEAD requests, we just fall back to GET.
-	req, err := http.NewRequestWithContext(ctx, "HEAD", src.String(), nil)
-	if err != nil {
-		return err
-	}
-	if g.Header != nil {
-		req.Header = g.Header.Clone()
-	}
-	headResp, err := g.Client.Do(req)
-	if err == nil {
-		headResp.Body.Close()
-		if headResp.StatusCode == 200 {
-			// If the HEAD request succeeded, then attempt to set the range
-			// query if we can.
-			if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 {
-				if fi, err := f.Stat(); err == nil {
-					if _, err = f.Seek(0, io.SeekEnd); err == nil {
-						currentFileSize = fi.Size()
-						if currentFileSize >= headResp.ContentLength {
-							// file already present
-							return nil
+	if !g.DoNotCheckHeadFirst {
+		headCtx := ctx
+
+		if g.HeadFirstTimeout > 0 {
+			var cancel context.CancelFunc
+
+			headCtx, cancel = context.WithTimeout(ctx, g.HeadFirstTimeout)
+			defer cancel()
+		}
+
+		// We first make a HEAD request so we can check
+		// if the server supports range queries. If the server/URL doesn't
+		// support HEAD requests, we just fall back to GET.
+		req, err = http.NewRequestWithContext(headCtx, "HEAD", src.String(), nil)
+		if err != nil {
+			return err
+		}
+		if g.Header != nil {
+			req.Header = g.Header.Clone()
+		}
+		headResp, err := g.Client.Do(req)
+		if err == nil {
+			headResp.Body.Close()
+			if headResp.StatusCode == 200 {
+				// If the HEAD request succeeded, then attempt to set the range
+				// query if we can.
+				if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 {
+					if fi, err := f.Stat(); err == nil {
+						if _, err = f.Seek(0, io.SeekEnd); err == nil {
+							currentFileSize = fi.Size()
+							if currentFileSize >= headResp.ContentLength {
+								// file already present
+								return nil
+							}
 						}
 					}
 				}
 			}
 		}
 	}

-	req, err = http.NewRequestWithContext(ctx, "GET", src.String(), nil)
+	readCtx := ctx
+
+	if g.ReadTimeout > 0 {
+		var cancel context.CancelFunc
+		readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout)
+		defer cancel()
+	}
+
+	req, err = http.NewRequestWithContext(readCtx, "GET", src.String(), nil)
 	if err != nil {
 		return err
 	}
@@ -228,14 +458,18 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {

 	body := resp.Body

+	if maxBytes := httpMaxBytesFromContext(ctx); maxBytes > 0 {
+		body = newLimitedWrappedReaderCloser(body, maxBytes)
+	}
+
 	if g.client != nil && g.client.ProgressListener != nil {
 		// track download
 		fn := filepath.Base(src.EscapedPath())
 		body = g.client.ProgressListener.TrackProgress(fn, currentFileSize, currentFileSize+resp.ContentLength, resp.Body)
 	}
 	defer body.Close()

-	n, err := Copy(ctx, f, body)
+	n, err := Copy(readCtx, f, body)
 	if err == nil && n < resp.ContentLength {
 		err = io.ErrShortWrite
 	}
@@ -284,18 +518,28 @@ func (g *HttpGetter) getSubdir(ctx context.Context, dst, source, subDir string)
 		return err
 	}

-	return copyDir(ctx, dst, sourcePath, false, g.client.umask())
+	var disableSymlinks bool
+
+	if g.client != nil && g.client.DisableSymlinks {
+		disableSymlinks = true
+	}
+
+	return copyDir(ctx, dst, sourcePath, false, disableSymlinks, g.client.umask())
 }

 // parseMeta looks for the first meta tag in the given reader that
 // will give us the source URL.
-func (g *HttpGetter) parseMeta(r io.Reader) (string, error) {
+func (g *HttpGetter) parseMeta(ctx context.Context, r io.Reader) (string, error) {
 	d := xml.NewDecoder(r)
 	d.CharsetReader = charsetReader
 	d.Strict = false
 	var err error
 	var t xml.Token
 	for {
+		if ctx.Err() != nil {
+			return "", fmt.Errorf("context error while parsing meta tag: %w", ctx.Err())
+		}
+
 		t, err = d.Token()
 		if err != nil {
 			if err == io.EOF {
get_http_test.go (+473 −14, modified)
@@ -9,12 +9,15 @@ import (
 	"io/ioutil"
 	"net"
 	"net/http"
+	"net/http/httputil"
 	"net/url"
 	"os"
 	"path/filepath"
 	"strconv"
 	"strings"
 	"testing"
+
+	"github.com/hashicorp/go-cleanhttp"
 )

 func TestHttpGetter_impl(t *testing.T) {
@@ -34,8 +37,27 @@ func TestHttpGetter_header(t *testing.T) {
 	u.Host = ln.Addr().String()
 	u.Path = "/header"

-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}

@@ -44,6 +66,7 @@ func TestHttpGetter_header(t *testing.T) {
 	if _, err := os.Stat(mainPath); err != nil {
 		t.Fatalf("err: %s", err)
 	}
+
 }

 func TestHttpGetter_requestHeader(t *testing.T) {
@@ -87,8 +110,27 @@ func TestHttpGetter_meta(t *testing.T) {
 	u.Host = ln.Addr().String()
 	u.Path = "/meta"

-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}

@@ -330,14 +372,27 @@ func TestHttpGetter_auth(t *testing.T) {
 	u.Path = "/meta-auth"
 	u.User = url.UserPassword("foo", "bar")

-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
-		t.Fatalf("err: %s", err)
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
 	}

-	// Verify the main file exists
-	mainPath := filepath.Join(dst, "main.tf")
-	if _, err := os.Stat(mainPath); err != nil {
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}
 }
@@ -360,8 +415,27 @@ func TestHttpGetter_authNetrc(t *testing.T) {
 	defer closer()
 	defer tempEnv(t, "NETRC", path)()

-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}

@@ -399,8 +473,27 @@ func TestHttpGetter_cleanhttp(t *testing.T) {
 	u.Host = ln.Addr().String()
 	u.Path = "/header"

-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}
 }
@@ -441,6 +534,326 @@ func TestHttpGetter__RespectsContextCanceled(t *testing.T) {
 	}
 }

+func TestHttpGetter__XTerraformGetLimit(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithXTerraformGetLoop(t)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/loop"
+	dst := tempDir(t)
+
+	g := new(HttpGetter)
+	g.XTerraformGetLimit = 10
+	g.client = &Client{
+		Ctx: ctx,
+	}
+	g.Client = &http.Client{}
+
+	err := g.Get(dst, &u)
+	if !strings.Contains(err.Error(), "too many X-Terraform-Get redirects") {
+		t.Fatalf("too many X-Terraform-Get redirects, got: %v", err)
+	}
+}
+
+func TestHttpGetter__XTerraformGetDisabled(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithXTerraformGetLoop(t)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/loop"
+	dst := tempDir(t)
+
+	g := new(HttpGetter)
+	g.XTerraformGetDisabled = true
+	g.client = &Client{
+		Ctx: ctx,
+	}
+	g.Client = &http.Client{}
+
+	err := g.Get(dst, &u)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+func TestHttpGetter__XTerraformGetProxyBypass(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithXTerraformGetProxyBypass(t)
+
+	proxyLn := testHttpServerProxy(t, ln.Addr().String())
+
+	t.Logf("starting malicious server on: %v", ln.Addr().String())
+	t.Logf("starting proxy on: %v", proxyLn.Addr().String())
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/start"
+	dst := tempDir(t)
+
+	proxy, err := url.Parse(fmt.Sprintf("http://%s/", proxyLn.Addr().String()))
+	if err != nil {
+		t.Fatalf("failed to parse proxy URL: %v", err)
+	}
+
+	transport := cleanhttp.DefaultTransport()
+	transport.Proxy = http.ProxyURL(proxy)
+
+	httpGetter := new(HttpGetter)
+	httpGetter.XTerraformGetLimit = 10
+	httpGetter.Client = &http.Client{
+		Transport: transport,
+	}
+
+	client := &Client{
+		Ctx: ctx,
+		Getters: map[string]Getter{
+			"http": httpGetter,
+		},
+	}
+
+	client.Src = u.String()
+	client.Dst = dst
+
+	err = client.Get()
+	if err != nil {
+		t.Logf("client get error: %v", err)
+	}
+}
+
+func TestHttpGetter__XTerraformGetConfiguredGettersBypass(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithXTerraformGetConfiguredGettersBypass(t)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/start"
+	dst := tempDir(t)
+
+	rt := hookableHTTPRoundTripper{
+		before: func(req *http.Request) {
+			t.Logf("making request")
+		},
+		RoundTripper: http.DefaultTransport,
+	}
+
+	httpGetter := new(HttpGetter)
+	httpGetter.XTerraformGetLimit = 10
+	httpGetter.Client = &http.Client{
+		Transport: &rt,
+	}
+
+	client := &Client{
+		Ctx:  ctx,
+		Mode: ClientModeDir,
+		Getters: map[string]Getter{
+			"http": httpGetter,
+		},
+	}
+
+	t.Logf("%v", u.String())
+
+	client.Src = u.String()
+	client.Dst = dst
+
+	err := client.Get()
+	if err != nil {
+		if !strings.Contains(err.Error(), "no getter available for X-Terraform-Get source protocol") {
+			t.Fatalf("expected no getter available for X-Terraform-Get source protocol, got: %v", err)
+		}
+	}
+}
+
+func TestHttpGetter__endless_body(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithEndlessBody(t)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/"
+	dst := tempDir(t)
+
+	httpGetter := new(HttpGetter)
+	httpGetter.MaxBytes = 10
+	httpGetter.DoNotCheckHeadFirst = true
+
+	client := &Client{
+		Ctx:  ctx,
+		Mode: ClientModeFile,
+		// Mode: ClientModeDir,
+		Getters: map[string]Getter{
+			"http": httpGetter,
+		},
+	}
+
+	t.Logf("%v", u.String())
+
+	client.Src = u.String()
+	client.Dst = dst
+
+	err := client.Get()
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+func TestHttpGetter_subdirLink(t *testing.T) {
+	ln := testHttpServerSubDir(t)
+	defer ln.Close()
+
+	httpGetter := new(HttpGetter)
+	dst, err := ioutil.TempDir("", "tf")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	t.Logf("dst: %q", dst)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/regular-subdir//meta-subdir"
+
+	t.Logf("url: %q", u.String())
+
+	client := &Client{
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeAny,
+		Getters: map[string]Getter{
+			"http": httpGetter,
+		},
+	}
+
+	err = client.Get()
+	if err != nil {
+		t.Fatalf("get err: %v", err)
+	}
+}
+
+func testHttpServerWithXTerraformGetLoop(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	header := fmt.Sprintf("http://%v:%v", ln.Addr().String(), "/loop")
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/loop", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("X-Terraform-Get", header)
+		t.Logf("serving loop")
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
+func testHttpServerWithXTerraformGetProxyBypass(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	header := fmt.Sprintf("http://%v/bypass", ln.Addr().String())
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/start/start", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("X-Terraform-Get", header)
+		t.Logf("serving start")
+	})
+
+	mux.HandleFunc("/bypass", func(w http.ResponseWriter, r *http.Request) {
+		t.Fail()
+		t.Logf("bypassed proxy")
+	})
+
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		t.Logf("serving HTTP server path: %v", r.URL.Path)
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
+func testHttpServerWithXTerraformGetConfiguredGettersBypass(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	header := fmt.Sprintf("git::http://%v/some/repository.git", ln.Addr().String())
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("X-Terraform-Get", header)
+		t.Logf("serving start")
+	})
+
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		t.Logf("serving git HTTP server path: %v", r.URL.Path)
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
+func testHttpServerProxy(t *testing.T, upstreamHost string) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	mux := http.NewServeMux()
+
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		t.Logf("serving proxy: %v: %#+v", r.URL.Path, r.Header)
+		// create the reverse proxy
+		proxy := httputil.NewSingleHostReverseProxy(r.URL)
+		// Note that ServeHttp is non blocking & uses a go routine under the hood
+		proxy.ServeHTTP(w, r)
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
 func testHttpServer(t *testing.T) net.Listener {
 	ln, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
@@ -504,6 +917,29 @@ func testHttpHandlerMetaAuth(w http.ResponseWriter, r *http.Request) {
 	w.Write([]byte(fmt.Sprintf(testHttpMetaStr, testModuleURL("basic").String())))
 }

+func testHttpServerWithEndlessBody(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		for {
+			w.Write([]byte(".\n"))
+		}
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
 func testHttpHandlerMetaSubdir(w http.ResponseWriter, r *http.Request) {
 	w.Write([]byte(fmt.Sprintf(testHttpMetaStr, testModuleURL("basic//subdir").String())))
 }
@@ -548,6 +984,29 @@ func testHttpHandlerNoRange(w http.ResponseWriter, r *http.Request) {
 	}
 }

+func testHttpServerSubDir(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		switch r.Method {
+		case http.MethodGet:
+			t.Logf("serving: %v: %v: %#+[1]v", r.Method, r.URL.String(), r.Header)
+		}
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
 const testHttpMetaStr = `
 <html>
 <head>
get_s3.go (+29 −3, modified)
@@ -7,6 +7,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"time"

 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/credentials"
@@ -20,10 +21,22 @@ import (
 // a S3 bucket.
 type S3Getter struct {
 	getter
+
+	// Timeout sets a deadline which all S3 operations should
+	// complete within. Zero value means no timeout.
+	Timeout time.Duration
 }

 func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
 	// Parse URL
+	ctx := g.Context()
+
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	region, bucket, path, _, creds, err := g.parseUrl(u)
 	if err != nil {
 		return 0, err
@@ -40,7 +53,7 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
 		Bucket: aws.String(bucket),
 		Prefix: aws.String(path),
 	}
-	resp, err := client.ListObjects(req)
+	resp, err := client.ListObjectsWithContext(ctx, req)
 	if err != nil {
 		return 0, err
 	}
@@ -65,6 +78,12 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
 func (g *S3Getter) Get(dst string, u *url.URL) error {
 	ctx := g.Context()

+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	// Parse URL
 	region, bucket, path, _, creds, err := g.parseUrl(u)
 	if err != nil {
@@ -106,7 +125,7 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {
 			req.Marker = aws.String(lastMarker)
 		}

-		resp, err := client.ListObjects(req)
+		resp, err := client.ListObjectsWithContext(ctx, req)
 		if err != nil {
 			return err
 		}
@@ -141,6 +160,13 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {

 func (g *S3Getter) GetFile(dst string, u *url.URL) error {
 	ctx := g.Context()
+
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	region, bucket, path, version, creds, err := g.parseUrl(u)
 	if err != nil {
 		return err
@@ -163,7 +189,7 @@ func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, ke
 		req.VersionId = aws.String(version)
 	}

-	resp, err := client.GetObject(req)
+	resp, err := client.GetObjectWithContext(ctx, req)
 	if err != nil {
 		return err
 	}
get_test.go (+15 −0, modified)
@@ -492,6 +492,21 @@ func TestGetFile_filename(t *testing.T) {
 	}
 }

+func TestGetFile_filename_path_traversal(t *testing.T) {
+	dst := tempDir(t)
+	u := testModule("basic-file/foo.txt")
+
+	u += "?filename=../../../../../../../../../../../../../tmp/bar.txt"
+
+	err := GetAny(dst, u)
+	if err == nil {
+		t.Fatalf("expected error")
+	}
+	if !strings.Contains(err.Error(), "filename query parameter contain path traversal") {
+		t.Fatalf("unexpected err: %s", err)
+	}
+}
+
 func TestGetFile_checksumSkip(t *testing.T) {
 	dst := tempTestFile(t)
 	defer os.RemoveAll(filepath.Dir(dst))
source.go (+3 −1, modified)
@@ -58,7 +58,9 @@ func SourceDirSubdir(src string) (string, string) {
 //
 // The returned path is the full absolute path.
 func SubdirGlob(dst, subDir string) (string, error) {
-	matches, err := filepath.Glob(filepath.Join(dst, subDir))
+	pattern := filepath.Join(dst, subDir)
+
+	matches, err := filepath.Glob(pattern)
 	if err != nil {
 		return "", err
 	}
Vulnerability mechanics
Generated on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References (11)
- github.com/advisories/GHSA-cjr4-fv6c-f3mv (ghsa; ADVISORY)
- nvd.nist.gov/vuln/detail/CVE-2022-30322 (ghsa; ADVISORY)
- discuss.hashicorp.com (ghsa; x_refsource_MISC; WEB)
- discuss.hashicorp.com/t/hcsec-2022-13-multiple-vulnerabilities-in-go-getter-library (ghsa; WEB)
- discuss.hashicorp.com/t/hcsec-2022-13-multiple-vulnerabilities-in-go-getter-library/ (mitre; x_refsource_MISC)
- discuss.hashicorp.com/t/hcsec-2022-13-multiple-vulnerabilities-in-go-getter-library/39930 (ghsa; x_refsource_MISC; WEB)
- github.com/hashicorp/go-getter/commit/38e97387488f5439616be60874979433a12edb48 (ghsa; WEB)
- github.com/hashicorp/go-getter/commit/a2ebce998f8d4105bd4b78d6c99a12803ad97a45 (ghsa; WEB)
- github.com/hashicorp/go-getter/pull/359 (ghsa; WEB)
- github.com/hashicorp/go-getter/pull/361 (ghsa; WEB)
- pkg.go.dev/vuln/GO-2022-0586 (ghsa; WEB)