CVE-2022-30321
Description
go-getter up to 1.5.11 and 2.0.2 allowed arbitrary host access via go-getter path traversal, symlink processing, and command injection flaws. Fixed in 1.6.1 and 2.1.0.
AI Insight
LLM-synthesized narrative grounded in this CVE's description and references.
go-getter before v1.6.1 (1.x series) and v2.1.0 (2.x series) is vulnerable to path traversal and symlink abuse, allowing arbitrary file access.
Vulnerability
go-getter versions v1.0.0 through v1.5.11 and v2.0.0 through v2.0.2 contain multiple security vulnerabilities: path traversal via the src subdirectory or filename query parameter, and improper handling of symlinks when the DisableSymlinks option is not set [2][3][4]. A specially crafted URL can cause directory traversal outside the intended destination.
Exploitation
An attacker can exploit these issues by providing a malicious URL to an application that uses go-getter to download resources. For path traversal, the attacker includes ../ sequences in the subdirectory or filename portion of the URL to write files to arbitrary locations [3]. For symlink processing, if the repository contains symlinks, go-getter would follow them unless DisableSymlinks is explicitly enabled, allowing the attacker to read or write files outside the target directory [2].
Impact
Successful exploitation can lead to arbitrary file read or write within the context of the vulnerable application, potentially resulting in information disclosure or code execution depending on the application's privileges.
Mitigation
The vulnerabilities are fixed in go-getter v1.6.1 (for the 1.x series) and v2.1.0 (for the 2.x series) [2][3][4]. Users should upgrade to these versions. Upgrading alone does not block symlink following: the DisableSymlinks client option is introduced by the fix and is off by default, so applications that fetch untrusted sources should enable it explicitly [2].
AI Insight generated on May 21, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| github.com/hashicorp/go-getter (Go) | < 1.6.1 | 1.6.1 |
| github.com/hashicorp/go-getter (Go) | >= 2.0.0, < 2.1.0 | 2.1.0 |
| github.com/hashicorp/go-getter/v2 (Go) | < 2.1.0 | 2.1.0 |
| github.com/hashicorp/go-getter/s3/v2 (Go) | < 2.1.0 | 2.1.0 |
| github.com/hashicorp/go-getter/gcs/v2 (Go) | < 2.1.0 | 2.1.0 |
Affected products (5)
- go-getter/go-getter (from description)
- GHSA coordinates (4 versions):
  - pkg:golang/github.com/hashicorp/go-getter, range: < 1.6.1 (no CPE)
  - pkg:golang/github.com/hashicorp/go-getter/gcs/v2, range: < 2.1.0 (no CPE)
  - pkg:golang/github.com/hashicorp/go-getter/s3/v2, range: < 2.1.0 (no CPE)
  - pkg:golang/github.com/hashicorp/go-getter/v2, range: < 2.1.0 (no CPE)
Patches
238e97387488f Multiple fixes for go-getter v2 (#361)
21 files changed · +1470 −164
.circleci/config.yml+9 −9 modified@@ -49,11 +49,11 @@ commands: jobs: linux-tests: docker: - - image: circleci/golang:<< parameters.go-version >> + - image: cimg/go:<< parameters.go-version >> parameters: go-version: type: string - environment: + environment: <<: *ENVIRONMENT parallelism: 4 steps: @@ -104,7 +104,7 @@ jobs: path: *TEST_RESULTS_PATH windows-tests: - executor: + executor: name: win/default shell: bash --login -eo pipefail environment: @@ -115,12 +115,12 @@ jobs: type: string gotestsum-version: type: string - steps: + steps: - run: git config --global core.autocrlf false - checkout - attach_workspace: at: . - - run: + - run: name: Setup (remove pre-installed go) command: | rm -rf "c:\Go" @@ -131,16 +131,16 @@ jobs: - win-golang-<< parameters.go-version >>-cache-v1 - win-gomod-cache-{{ checksum "go.mod" }}-v1 - - run: + - run: name: Install go version << parameters.go-version >> - command: | + command: | if [ ! -d "c:\go" ]; then echo "Cache not found, installing new version of go" curl --fail --location https://dl.google.com/go/go<< parameters.go-version >>.windows-amd64.zip --output go.zip unzip go.zip -d "/c" fi - - run: + - run: command: go mod download - save_cache: @@ -176,7 +176,7 @@ jobs: go-smb-test: docker: - - image: circleci/golang:<< parameters.go-version >> + - image: cimg/go:<< parameters.go-version >> parameters: go-version: type: string
client.go +35 −2 modified
@@ -2,6 +2,7 @@ package getter
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"os"
@@ -14,6 +15,9 @@ import (
 	safetemp "github.com/hashicorp/go-safetemp"
 )
 
+// ErrSymlinkCopy means that a copy of a symlink was encountered on a request with DisableSymlinks enabled.
+var ErrSymlinkCopy = errors.New("copying of symlinks has been disabled")
+
 // Client is a client for downloading things.
 //
 // Top-level functions such as Get are shortcuts for interacting with a client.
@@ -27,6 +31,10 @@ type Client struct {
 	// Getters is the list of protocols supported by this client. If this
 	// is nil, then the default Getters variable will be used.
 	Getters []Getter
+
+	// Disable symlinks is used to prevent copying or writing files through symlinks for Get requests.
+	// When set to true any copying or writing through symlinks will result in a ErrSymlinkCopy error.
+	DisableSymlinks bool
 }
 
 // GetResult is the result of a Client.Get
@@ -41,15 +49,36 @@ func (c *Client) Get(ctx context.Context, req *Request) (*GetResult, error) {
 		return nil, err
 	}
 
+	// Pass along the configured Getter client in the context for usage with the X-Terraform-Get feature.
+	ctx = NewContextWithClient(ctx, c)
+
 	// Store this locally since there are cases we swap this
 	if req.GetMode == ModeInvalid {
 		req.GetMode = ModeAny
 	}
 
+	// Client setting takes precedence for all requests
+	if c.DisableSymlinks {
+		req.DisableSymlinks = true
+	}
+
 	// If there is a subdir component, then we download the root separately
 	// and then copy over the proper subdir.
 	req.Src, req.subDir = SourceDirSubdir(req.Src)
+
 	if req.subDir != "" {
+		// Check if the subdirectory is attempting to traverse upwards, outside of
+		// the cloned repository path.
+		req.subDir = filepath.Clean(req.subDir)
+		if containsDotDot(req.subDir) {
+			return nil, fmt.Errorf("subdirectory component contain path traversal out of the repository")
+		}
+
+		// Prevent absolute paths, remove a leading path separator from the subdirectory
+		if req.subDir[0] == os.PathSeparator {
+			req.subDir = req.subDir[1:]
+		}
+
 		td, tdcloser, err := safetemp.Dir("", "getter")
 		if err != nil {
 			return nil, err
@@ -123,7 +152,7 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, *
 	// Determine if we have an archive type
 	archiveV := q.Get("archive")
 	if archiveV != "" {
-		// Delete the paramter since it is a magic parameter we don't
+		// Delete the parameter since it is a magic parameter we don't
 		// want to pass on to the Getter
 		q.Del("archive")
 		req.u.RawQuery = q.Encode()
@@ -199,6 +228,10 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, *
 				filename = v
 			}
 
+			if containsDotDot(filename) {
+				return nil, &getError{true, fmt.Errorf("filename query parameter contain path traversal")}
+			}
+
 			req.Dst = filepath.Join(req.Dst, filename)
 		}
 	}
@@ -284,7 +317,7 @@ func (c *Client) get(ctx context.Context, req *Request, g Getter) (*GetResult, *
 			return nil, &getError{true, err}
 		}
 
-		err = copyDir(ctx, req.realDst, subDir, false, req.umask())
+		err = copyDir(ctx, req.realDst, subDir, false, req.DisableSymlinks, req.umask())
 		if err != nil {
 			return nil, &getError{false, err}
 		}
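The subdirectory check that the client.go change adds can be sketched in isolation. This is a minimal illustration under stated assumptions, not the library's code: `hasDotDotSegment` and `sanitizeSubdir` are hypothetical names, and the patch's own `containsDotDot` helper is not shown in this diff, so its exact behavior is assumed rather than reproduced.

```go
package main

import (
	"errors"
	"fmt"
	"path/filepath"
	"strings"
)

// errTraversal is a hypothetical sentinel used only in this sketch.
var errTraversal = errors.New("subdirectory escapes the repository root")

// hasDotDotSegment reports whether any path segment of p is exactly "..".
// Both "/" and "\" are treated as separators (an assumption of this sketch).
func hasDotDotSegment(p string) bool {
	isSep := func(r rune) bool { return r == '/' || r == '\\' }
	for _, seg := range strings.FieldsFunc(p, isSep) {
		if seg == ".." {
			return true
		}
	}
	return false
}

// sanitizeSubdir cleans sub, rejects it if a ".." segment survives cleaning,
// and strips leading separators so the result is always relative.
func sanitizeSubdir(sub string) (string, error) {
	if sub == "" {
		return "", nil
	}
	cleaned := filepath.Clean(sub)
	if hasDotDotSegment(cleaned) {
		return "", errTraversal
	}
	return strings.TrimLeft(cleaned, `/\`), nil
}

func main() {
	for _, s := range []string{"modules/vpc", "../../etc/passwd", "/abs/path", "a/../b"} {
		out, err := sanitizeSubdir(s)
		fmt.Printf("%q -> %q, err=%v\n", s, out, err)
	}
}
```

Cleaning first matters: `filepath.Clean` collapses harmless inner sequences such as `a/../b`, so any `..` segment that remains afterwards necessarily points above the starting directory and can be rejected outright.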
client_option.go +21 −0 modified
@@ -1,5 +1,26 @@
 package getter
 
+import (
+	"context"
+)
+
+type clientContextKey int
+
+const clientContextValue clientContextKey = 0
+
+func NewContextWithClient(ctx context.Context, client *Client) context.Context {
+	return context.WithValue(ctx, clientContextValue, client)
+}
+
+func ClientFromContext(ctx context.Context) *Client {
+	// ctx.Value returns nil if ctx has no value for the key;
+	client, ok := ctx.Value(clientContextValue).(*Client)
+	if !ok {
+		return nil
+	}
+	return client
+}
+
 // configure configures a client with options.
 func (c *Client) configure() error {
 	// Default decompressor values
cmd/go-getter/main.go +6 −1 modified
@@ -16,6 +16,7 @@ import (
 func main() {
 	modeRaw := flag.String("mode", "any", "get mode (any, file, dir)")
 	progress := flag.Bool("progress", false, "display terminal progress")
+	noSymlinks := flag.Bool("disable-symlinks", false, "prevent copying or writing files through symlinks")
 	flag.Parse()
 	args := flag.Args()
 	if len(args) < 2 {
@@ -54,12 +55,16 @@ func main() {
 	if *progress {
 		req.ProgressListener = defaultProgressBar
 	}
-
 	wg := sync.WaitGroup{}
 	wg.Add(1)
 
 	client := getter.DefaultClient
 
+	// Disable symlinks for all client requests
+	if *noSymlinks {
+		client.DisableSymlinks = true
+	}
+
 	getters := getter.Getters
 	getters = append(getters, new(gcs.Getter))
 	getters = append(getters, new(s3.Getter))
copy_dir.go +8 −2 modified
@@ -11,7 +11,7 @@ import (
 // should already exist.
 //
 // If ignoreDot is set to true, then dot-prefixed files/folders are ignored.
-func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask os.FileMode) error {
+func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, disableSymlinks bool, umask os.FileMode) error {
 	src, err := filepath.EvalSymlinks(src)
 	if err != nil {
 		return err
@@ -34,6 +34,12 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask
 			}
 		}
 
+		if disableSymlinks {
+			if info.Mode()&os.ModeSymlink == os.ModeSymlink {
+				return ErrSymlinkCopy
+			}
+		}
+
 		// The "path" has the src prefixed to it. We need to join our
 		// destination with the path without the src on it.
 		dstPath := filepath.Join(dst, path[len(src):])
@@ -54,7 +60,7 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask
 		}
 
 		// If we have a file, copy the contents.
-		_, err = copyFile(ctx, dstPath, path, info.Mode(), umask)
+		_, err = copyFile(ctx, dstPath, path, disableSymlinks, info.Mode(), umask)
 		return err
 	}
 
detect_test.go+6 −5 modified@@ -6,11 +6,12 @@ import ( ) func TestDetect(t *testing.T) { - gitGetter := &GitGetter{[]Detector{ - new(GitDetector), - new(BitBucketDetector), - new(GitHubDetector), - }, + gitGetter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(BitBucketDetector), + new(GitHubDetector), + }, } cases := []struct { Input string
gcs/get_gcs.go+27 −1 modified@@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "time" "cloud.google.com/go/storage" "github.com/hashicorp/go-getter/v2" @@ -15,10 +16,21 @@ import ( // Getter is a Getter implementation that will download a module from // a GCS bucket. -type Getter struct{} +type Getter struct { + + // Timeout sets a deadline which all GCS operations should + // complete within. Zero value means no timeout. + Timeout time.Duration +} func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, err := g.parseURL(u) if err != nil { @@ -54,6 +66,13 @@ func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { } func (g *Getter) Get(ctx context.Context, req *getter.Request) error { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, err := g.parseURL(req.URL()) if err != nil { @@ -111,6 +130,13 @@ func (g *Getter) Get(ctx context.Context, req *getter.Request) error { } func (g *Getter) GetFile(ctx context.Context, req *getter.Request) error { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, err := g.parseURL(req.URL()) if err != nil {
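The optional-timeout pattern that this patch repeats in the GCS, git, and hg getters can be sketched on its own. This is a minimal illustration, assuming only the standard library; `withOptionalTimeout` and `deadlineSet` are hypothetical helper names that do not appear in go-getter.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// withOptionalTimeout returns ctx unchanged (with a no-op cancel) when d is
// zero or negative, and a deadline-bound child context otherwise. This mirrors
// the "zero value means no timeout" convention described in the patch comments.
func withOptionalTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
	if d > 0 {
		return context.WithTimeout(ctx, d)
	}
	return ctx, func() {}
}

// deadlineSet reports whether a context derived with timeout d carries a deadline.
func deadlineSet(d time.Duration) bool {
	ctx, cancel := withOptionalTimeout(context.Background(), d)
	defer cancel() // always release the timer, as the patch does with defer cancel()
	_, ok := ctx.Deadline()
	return ok
}

func main() {
	fmt.Println(deadlineSet(0))               // zero value: no deadline attached
	fmt.Println(deadlineSet(5 * time.Second)) // positive value: deadline attached
}
```

A context built this way only bounds work that actually observes it, which is presumably why the same patch also switches the git and hg invocations from `exec.Command` to `exec.CommandContext`.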
get_file_copy.go +14 −1 modified
@@ -2,6 +2,7 @@ package getter
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"os"
 )
@@ -49,7 +50,19 @@ func copyReader(dst string, src io.Reader, fmode, umask os.FileMode) error {
 }
 
 // copyFile copies a file in chunks from src path to dst path, using umask to create the dst file
-func copyFile(ctx context.Context, dst, src string, fmode, umask os.FileMode) (int64, error) {
+func copyFile(ctx context.Context, dst, src string, disableSymlinks bool, fmode, umask os.FileMode) (int64, error) {
+
+	if disableSymlinks {
+		fileInfo, err := os.Lstat(src)
+		if err != nil {
+			return 0, fmt.Errorf("failed to check copy file source for symlinks: %w", err)
+		}
+
+		if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink {
+			return 0, ErrSymlinkCopy
+		}
+	}
+
 	srcF, err := os.Open(src)
 	if err != nil {
 		return 0, err
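The symlink guard in the copy path relies on `os.Lstat`, which describes the link itself, whereas `os.Stat` would follow it. A self-contained sketch of the same idea follows; `copyRegularFile`, `errSymlinkSource`, and `run` are hypothetical names for illustration, and the sketch assumes a platform where unprivileged symlink creation works (for example Linux or macOS).

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
)

// errSymlinkSource is a hypothetical sentinel, similar in spirit to the
// patch's ErrSymlinkCopy.
var errSymlinkSource = errors.New("refusing to copy: source is a symlink")

// copyRegularFile copies src to dst but refuses when src itself is a symlink.
func copyRegularFile(dst, src string) error {
	info, err := os.Lstat(src) // Lstat does not follow the final symlink
	if err != nil {
		return fmt.Errorf("lstat source: %w", err)
	}
	if info.Mode()&os.ModeSymlink != 0 {
		return errSymlinkSource
	}
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()
	out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, in); err != nil {
		out.Close()
		return err
	}
	return out.Close()
}

// run exercises the guard inside a temporary directory.
func run() error {
	dir, err := os.MkdirTemp("", "symlink-sketch")
	if err != nil {
		return err
	}
	defer os.RemoveAll(dir)

	regular := filepath.Join(dir, "regular.txt")
	if err := os.WriteFile(regular, []byte("hello\n"), 0o600); err != nil {
		return err
	}
	link := filepath.Join(dir, "link.txt")
	if err := os.Symlink(regular, link); err != nil {
		return err
	}
	if err := copyRegularFile(filepath.Join(dir, "copy1.txt"), regular); err != nil {
		return fmt.Errorf("regular file should copy: %w", err)
	}
	err = copyRegularFile(filepath.Join(dir, "copy2.txt"), link)
	if !errors.Is(err, errSymlinkSource) {
		return fmt.Errorf("symlink should be rejected, got: %v", err)
	}
	return nil
}

func main() {
	if err := run(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println("symlink guard behaved as expected")
}
```

Note that a check of this shape only inspects the final path component and leaves a window between the `Lstat` and the `Open`; it is a sketch of the technique, not a complete defense.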
get_git.go+23 −12 modified@@ -14,6 +14,7 @@ import ( "runtime" "strconv" "strings" + "time" urlhelper "github.com/hashicorp/go-getter/v2/helper/url" safetemp "github.com/hashicorp/go-safetemp" @@ -24,6 +25,10 @@ import ( // a git repository. type GitGetter struct { Detectors []Detector + + // Timeout sets a deadline which all hg CLI operations should + // complete within. Defaults to zero which means no timeout. + Timeout time.Duration } var defaultBranchRegexp = regexp.MustCompile(`\s->\sorigin/(.*)`) @@ -71,10 +76,16 @@ func (g *GitGetter) Get(ctx context.Context, req *Request) error { req.u.RawQuery = q.Encode() } + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + var sshKeyFile string if sshKey != "" { // Check that the git version is sufficiently new. - if err := checkGitVersion("2.3"); err != nil { + if err := checkGitVersion(ctx, "2.3"); err != nil { return fmt.Errorf("Error using ssh key: %v", err) } @@ -121,7 +132,7 @@ func (g *GitGetter) Get(ctx context.Context, req *Request) error { // Next: check out the proper tag/branch if it is specified, and checkout if ref != "" { - if err := g.checkout(req.Dst, ref); err != nil { + if err := g.checkout(ctx, req.Dst, ref); err != nil { return err } } @@ -163,8 +174,8 @@ func (g *GitGetter) GetFile(ctx context.Context, req *Request) error { return fg.GetFile(ctx, req) } -func (g *GitGetter) checkout(dst string, ref string) error { - cmd := exec.Command("git", "checkout", ref) +func (g *GitGetter) checkout(ctx context.Context, dst string, ref string) error { + cmd := exec.CommandContext(ctx, "git", "checkout", ref) cmd.Dir = dst return getRunCommand(cmd) } @@ -192,18 +203,18 @@ func (g *GitGetter) update(ctx context.Context, dst, sshKeyFile, ref string, dep // Not a branch, switch to default branch. This will also catch // non-existent branches, in which case we want to switch to default // and then checkout the proper branch later. 
- ref = findDefaultBranch(dst) + ref = findDefaultBranch(ctx, dst) } // We have to be on a branch to pull - if err := g.checkout(dst, ref); err != nil { + if err := g.checkout(ctx, dst, ref); err != nil { return err } if depth > 0 { - cmd = exec.Command("git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") } else { - cmd = exec.Command("git", "pull", "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--ff-only") } cmd.Dir = dst @@ -226,9 +237,9 @@ func (g *GitGetter) fetchSubmodules(ctx context.Context, dst, sshKeyFile string, // findDefaultBranch checks the repo's origin remote for its default branch // (generally "master"). "master" is returned if an origin default branch // can't be determined. -func findDefaultBranch(dst string) string { +func findDefaultBranch(ctx context.Context, dst string) string { var stdoutbuf bytes.Buffer - cmd := exec.Command("git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") + cmd := exec.CommandContext(ctx, "git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") cmd.Dir = dst cmd.Stdout = &stdoutbuf err := cmd.Run() @@ -278,13 +289,13 @@ func setupGitEnv(cmd *exec.Cmd, sshKeyFile string) { // checkGitVersion is used to check the version of git installed on the system // against a known minimum version. Returns an error if the installed version // is older than the given minimum. -func checkGitVersion(min string) error { +func checkGitVersion(ctx context.Context, min string) error { want, err := version.NewVersion(min) if err != nil { return err } - out, err := exec.Command("git", "version").Output() + out, err := exec.CommandContext(ctx, "git", "version").Output() if err != nil { return err }
get_git_test.go+125 −17 modified@@ -4,6 +4,8 @@ import ( "bytes" "context" "encoding/base64" + "errors" + "fmt" "io/ioutil" "net/url" "os" @@ -342,12 +344,13 @@ func TestGitGetter_gitVersion(t *testing.T) { os.Setenv("PATH", dir) // Asking for a higher version throws an error - if err := checkGitVersion("2.3"); err == nil { + ctx := context.Background() + if err := checkGitVersion(ctx, "2.3"); err == nil { t.Fatal("expect git version error") } // Passes when version is satisfied - if err := checkGitVersion("1.9"); err != nil { + if err := checkGitVersion(ctx, "1.9"); err != nil { t.Fatal(err) } } @@ -411,11 +414,12 @@ func TestGitGetter_sshSCPStyle(t *testing.T) { GetMode: ModeDir, } - getter := &GitGetter{[]Detector{ - new(GitDetector), - new(BitBucketDetector), - new(GitHubDetector), - }, + getter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(BitBucketDetector), + new(GitHubDetector), + }, } client := &Client{ Getters: []Getter{getter}, @@ -623,11 +627,12 @@ func TestGitGetter_GitHubDetector(t *testing.T) { } pwd := "/pwd" - f := &GitGetter{[]Detector{ - new(GitDetector), - new(BitBucketDetector), - new(GitHubDetector), - }, + f := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(BitBucketDetector), + new(GitHubDetector), + }, } for i, tc := range cases { req := &Request{ @@ -704,11 +709,12 @@ func TestGitGetter_Detector(t *testing.T) { } pwd := "/pwd" - getter := &GitGetter{[]Detector{ - new(GitDetector), - new(BitBucketDetector), - new(GitHubDetector), - }, + getter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(BitBucketDetector), + new(GitHubDetector), + }, } for _, tc := range cases { t.Run(tc.Input, func(t *testing.T) { @@ -730,6 +736,108 @@ func TestGitGetter_Detector(t *testing.T) { } } +func TestGitGetter_subdirectory_symlink(t *testing.T) { + dst := testing_helper.TempDir(t) + + repo := testGitRepo(t, "repo-with-symlink") + innerDir := filepath.Join(repo.dir, "this-directory-contains-a-symlink") + if 
err := os.Mkdir(innerDir, 0700); err != nil { + t.Fatal(err) + } + path := filepath.Join(innerDir, "this-is-a-symlink") + if err := os.Symlink("/etc/passwd", path); err != nil { + t.Fatal(err) + } + repo.git("add", path) + repo.git("commit", "-m", "Adding "+path) + + u, err := url.Parse(fmt.Sprintf("git::%s//this-directory-contains-a-symlink", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + req := &Request{ + Src: u.String(), + Dst: dst, + Pwd: ".", + GetMode: ModeDir, + } + getter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(GitHubDetector), + }, + } + client := &Client{ + Getters: []Getter{getter}, + DisableSymlinks: true, + } + + ctx := context.Background() + _, err = client.Get(ctx, req) + if runtime.GOOS == "windows" { + // Windows doesn't handle symlinks as one might expect with git. + // + // https://github.com/git-for-windows/git/wiki/Symbolic-Links + filepath.Walk(dst, func(path string, info os.FileInfo, err error) error { + if strings.Contains(path, "this-is-a-symlink") { + if info.Mode()&os.ModeSymlink == os.ModeSymlink { + // If you see this test fail in the future, you've probably enabled + // symlinks within git on your Windows system. Our CI/CD system does + // not do this, so this is the only way we can make this test + // make any sense. + t.Fatalf("windows git should not have cloned a symlink") + } + } + return nil + }) + } else { + // We can rely on POSIX compliant systems running git to do the right thing. 
+ if err == nil { + t.Fatalf("expected client get to fail") + } + if !errors.Is(err, ErrSymlinkCopy) { + t.Fatalf("unexpected error: %v", err) + } + } +} + +func TestGitGetter_subdirectory_traversal(t *testing.T) { + dst := testing_helper.TempDir(t) + + repo := testGitRepo(t, "empty-repo") + u, err := url.Parse(fmt.Sprintf("git::%s//../../../../../../etc/passwd", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + req := &Request{ + Src: u.String(), + Dst: dst, + Pwd: ".", + GetMode: ModeDir, + } + + getter := &GitGetter{ + Detectors: []Detector{ + new(GitDetector), + new(GitHubDetector), + }, + } + client := &Client{ + Getters: []Getter{getter}, + } + + ctx := context.Background() + _, err = client.Get(ctx, req) + if err == nil { + t.Fatalf("expected client get to fail") + } + if !strings.Contains(err.Error(), "subdirectory component contain path traversal out of the repository") { + t.Fatalf("unexpected error: %v", err) + } +} + // gitRepo is a helper struct which controls a single temp git repo. type gitRepo struct { t *testing.T
get.go+7 −6 modified@@ -76,12 +76,13 @@ func init() { // The order of the Getters in the list may affect the result // depending if the Request.Src is detected as valid by multiple getters Getters = []Getter{ - &GitGetter{[]Detector{ - new(GitHubDetector), - new(GitDetector), - new(BitBucketDetector), - new(GitLabDetector), - }, + &GitGetter{ + Detectors: []Detector{ + new(GitHubDetector), + new(GitDetector), + new(BitBucketDetector), + new(GitLabDetector), + }, }, new(HgGetter), new(SmbClientGetter),
get_hg.go+21 −8 modified@@ -8,14 +8,20 @@ import ( "os/exec" "path/filepath" "runtime" + "time" urlhelper "github.com/hashicorp/go-getter/v2/helper/url" safetemp "github.com/hashicorp/go-safetemp" ) // HgGetter is a Getter implementation that will download a module from // a Mercurial repository. -type HgGetter struct{} +type HgGetter struct { + + // Timeout sets a deadline which all hg CLI operations should + // complete within. Defaults to zero which means no timeout. + Timeout time.Duration +} func (g *HgGetter) Mode(ctx context.Context, _ *url.URL) (Mode, error) { return ModeDir, nil @@ -49,13 +55,20 @@ func (g *HgGetter) Get(ctx context.Context, req *Request) error { if err != nil && !os.IsNotExist(err) { return err } + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + if err != nil { - if err := g.clone(req.Dst, newURL); err != nil { + if err := g.clone(ctx, req.Dst, newURL); err != nil { return err } } - if err := g.pull(req.Dst, newURL); err != nil { + if err := g.pull(ctx, req.Dst, newURL); err != nil { return err } @@ -102,21 +115,21 @@ func (g *HgGetter) GetFile(ctx context.Context, req *Request) error { return fg.GetFile(ctx, req) } -func (g *HgGetter) clone(dst string, u *url.URL) error { - cmd := exec.Command("hg", "clone", "-U", u.String(), dst) +func (g *HgGetter) clone(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "clone", "-U", "--", u.String(), dst) return getRunCommand(cmd) } -func (g *HgGetter) pull(dst string, u *url.URL) error { - cmd := exec.Command("hg", "pull") +func (g *HgGetter) pull(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "pull") cmd.Dir = dst return getRunCommand(cmd) } func (g *HgGetter) update(ctx context.Context, dst string, u *url.URL, rev string) error { args := []string{"update"} if rev != "" { - args = append(args, rev) + args = append(args, "--", rev) } 
cmd := exec.CommandContext(ctx, "hg", args...)
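The `--` separator added to the `hg clone` and `hg update` invocations addresses the command injection part of this CVE: without it, a user-controlled value beginning with `-` can be read by the CLI as an option rather than as data. A minimal sketch of the argument construction, where `hgUpdateArgs` is a hypothetical helper name and no command is actually executed:

```go
package main

import "fmt"

// hgUpdateArgs builds the argument list for an "hg update" call. The "--"
// marks the end of options, so a rev value that begins with "-" is passed
// as a positional argument instead of being parsed as a flag.
func hgUpdateArgs(rev string) []string {
	args := []string{"update"}
	if rev != "" {
		args = append(args, "--", rev)
	}
	return args
}

func main() {
	// A hostile rev value taken from the test case in this patch.
	fmt.Println(hgUpdateArgs("--config=alias.update=!false"))
	// An empty rev adds nothing after "update".
	fmt.Println(hgUpdateArgs(""))
}
```

The hg test added in the same patch expects such a value to produce an hg parse error rather than running the aliased command.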
get_hg_test.go+104 −0 modified@@ -2,10 +2,13 @@ package getter import ( "context" + "net/url" "os" "os/exec" "path/filepath" + "strings" "testing" + "time" testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" ) @@ -118,3 +121,104 @@ func TestHgGetter_GetFile(t *testing.T) { } testing_helper.AssertContents(t, dst, "Hello\n") } +func TestHgGetter_HgArgumentsNotAllowed(t *testing.T) { + if !testHasHg { + t.Log("hg not found, skipping") + t.Skip() + } + ctx := context.Background() + + tc := []struct { + name string + req Request + errChk func(testing.TB, error) + }{ + { + // If arguments are allowed in the destination, this request to Get will fail + name: "arguments allowed in destination", + req: Request{ + Dst: "--config=alias.clone=!touch ./TEST", + u: testModuleURL("basic-hg"), + }, + errChk: func(t testing.TB, err error) { + if err != nil { + t.Errorf("Expected no err, got: %s", err) + } + }, + }, + { + // Test arguments passed into the `rev` parameter + // This clone call will fail regardless, but an exit code of 1 indicates + // that the `false` command executed + // We are expecting an hg parse error + name: "arguments passed into rev parameter", + req: Request{ + u: testModuleURL("basic-hg?rev=--config=alias.update=!false"), + }, + errChk: func(t testing.TB, err error) { + if err == nil { + return + } + + if !strings.Contains(err.Error(), "hg: parse error") { + t.Errorf("Expected no err, got: %s", err) + } + }, + }, + { + // Test arguments passed in the repository URL + // This Get call will fail regardless, but it should fail + // because the repository can't be found. 
+ // Other failures indicate that hg interpreted the argument passed in the URL + name: "arguments passed in the repository URL", + req: Request{ + u: &url.URL{Path: "--config=alias.clone=false"}}, + errChk: func(t testing.TB, err error) { + if err == nil { + return + } + + if !strings.Contains(err.Error(), "repository --config=alias.clone=false not found") { + t.Errorf("Expected no err, got: %s", err) + } + }, + }, + } + for _, tt := range tc { + tt := tt + t.Run(tt.name, func(t *testing.T) { + g := new(HgGetter) + + if tt.req.Dst == "" { + dst := testing_helper.TempDir(t) + tt.req.Dst = dst + } + + defer os.RemoveAll(tt.req.Dst) + err := g.Get(ctx, &tt.req) + tt.errChk(t, err) + }) + } +} + +func TestHgGetter_GetWithTimeout(t *testing.T) { + if !testHasHg { + t.Log("hg not found, skipping") + t.Skip() + } + ctx := context.Background() + g := &HgGetter{ + Timeout: 1 * time.Millisecond, + } + + dst := testing_helper.TempDir(t) + defer os.RemoveAll(filepath.Dir(dst)) + req := &Request{ + Dst: dst, + u: testModuleURL("basic-hg/foo.txt"), + } + + if err := g.Get(ctx, req); err == nil { + t.Fatalf("err: %s", err.Error()) + } +}
get_http.go+305 −49 modified@@ -10,6 +10,7 @@ import ( "os" "path/filepath" "strings" + "time" safetemp "github.com/hashicorp/go-safetemp" ) @@ -26,7 +27,9 @@ import ( // wish. The response must be a 2xx. // // First, a header is looked for "X-Terraform-Get" which should contain -// a source URL to download. +// a source URL to download. This source must use one of the configured +// protocols and getters for the client, or "http"/"https" if using +// the HttpGetter directly. // // If the header is not present, then a meta tag is searched for named // "terraform-get" and the content should be a source URL. @@ -49,6 +52,35 @@ type HttpGetter struct { // and as such it needs to be initialized before use, via something like // make(http.Header). Header http.Header + // DoNotCheckHeadFirst configures the client to NOT check if the server + // supports HEAD requests. + DoNotCheckHeadFirst bool + + // HeadFirstTimeout configures the client to enforce a timeout when + // the server supports HEAD requests. + // + // The zero value means no timeout. + HeadFirstTimeout time.Duration + + // ReadTimeout configures the client to enforce a timeout when + // making a request to an HTTP server and reading its response body. + // + // The zero value means no timeout. + ReadTimeout time.Duration + + // MaxBytes limits the number of bytes that will be ready from an HTTP + // response body returned from a server. The zero value means no limit. + MaxBytes int64 + + // XTerraformGetLimit configures how many times the client with follow + // the " X-Terraform-Get" header value. + // + // The zero value means no limit. + XTerraformGetLimit int + + // XTerraformGetDisabled disables the client's usage of the "X-Terraform-Get" + // header value. 
+ XTerraformGetDisabled bool } func (g *HttpGetter) Mode(ctx context.Context, u *url.URL) (Mode, error) { @@ -58,7 +90,112 @@ func (g *HttpGetter) Mode(ctx context.Context, u *url.URL) (Mode, error) { return ModeFile, nil } +type contextKey int + +const ( + xTerraformGetDisable contextKey = 0 + xTerraformGetLimit contextKey = 1 + xTerraformGetLimitCurrentValue contextKey = 2 + httpClientValue contextKey = 3 + httpMaxBytesValue contextKey = 4 +) + +func xTerraformGetDisabled(ctx context.Context) bool { + value, ok := ctx.Value(xTerraformGetDisable).(bool) + if !ok { + return false + } + return value +} + +func xTerraformGetLimitCurrentValueFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimitCurrentValue).(int) + if !ok { + return 1 + } + return value +} + +func xTerraformGetLimiConfiguredtFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimit).(int) + if !ok { + return 0 + } + return value +} + +func httpClientFromContext(ctx context.Context) *http.Client { + value, ok := ctx.Value(httpClientValue).(*http.Client) + if !ok { + return nil + } + return value +} + +func httpMaxBytesFromContext(ctx context.Context) int64 { + value, ok := ctx.Value(httpMaxBytesValue).(int64) + if !ok { + return 0 // no limit + } + return value +} + +type limitedWrappedReaderCloser struct { + underlying io.Reader + closeFn func() error +} + +func (l *limitedWrappedReaderCloser) Read(p []byte) (n int, err error) { + return l.underlying.Read(p) +} + +func (l *limitedWrappedReaderCloser) Close() (err error) { + return l.closeFn() +} + +func newLimitedWrappedReaderCloser(r io.ReadCloser, limit int64) io.ReadCloser { + return &limitedWrappedReaderCloser{ + underlying: io.LimitReader(r, limit), + closeFn: r.Close, + } +} + func (g *HttpGetter) Get(ctx context.Context, req *Request) error { + // Optionally disable any X-Terraform-Get redirects. This is recommended for usage of + // this client outside of Terraform's. 
This feature is likely not required if the + // source server can provider normal HTTP redirects. + if g.XTerraformGetDisabled { + ctx = context.WithValue(ctx, xTerraformGetDisable, g.XTerraformGetDisabled) + } + + // Optionally enforce a limit on X-Terraform-Get redirects. We check this for every + // invocation of this function, because the value is not passed down to subsequent + // client Get function invocations. + if g.XTerraformGetLimit > 0 { + ctx = context.WithValue(ctx, xTerraformGetLimit, g.XTerraformGetLimit) + } + + // If there was a limit on X-Terraform-Get redirects, check what the current count value. + // + // If the value is greater than the limit, return an error. Otherwise, increment the value, + // and include it in the the context to be passed along in all the subsequent client + // Get function invocations. + if limit := xTerraformGetLimiConfiguredtFromContext(ctx); limit > 0 { + currentValue := xTerraformGetLimitCurrentValueFromContext(ctx) + + if currentValue > limit { + return fmt.Errorf("too many X-Terraform-Get redirects: %d", currentValue) + } + + currentValue++ + + ctx = context.WithValue(ctx, xTerraformGetLimitCurrentValue, currentValue) + } + + // Optionally enforce a maxiumum HTTP response body size. + if g.MaxBytes > 0 { + ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes) + } // Copy the URL so we can modify it var newU url.URL = *req.u req.u = &newU @@ -70,17 +207,33 @@ func (g *HttpGetter) Get(ctx context.Context, req *Request) error { } } + // If the HTTP client is nil, check if there is one available in the context, + // otherwise create one using cleanhttp's default transport. if g.Client == nil { - g.Client = httpClient + if client := httpClientFromContext(ctx); client != nil { + g.Client = client + } else { + g.Client = httpClient + } } + // Pass along the configured HTTP client in the context for usage with the X-Terraform-Get feature. 
+ ctx = context.WithValue(ctx, httpClientValue, g.Client) + // Add terraform-get to the parameter. q := req.u.Query() q.Add("terraform-get", "1") req.u.RawQuery = q.Encode() + readCtx := ctx + if g.ReadTimeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout) + defer cancel() + } + // Get the URL - httpReq, err := http.NewRequestWithContext(ctx, "GET", req.u.String(), nil) + httpReq, err := http.NewRequestWithContext(readCtx, "GET", req.u.String(), nil) if err != nil { return err } @@ -92,40 +245,53 @@ func (g *HttpGetter) Get(ctx context.Context, req *Request) error { if err != nil { return err } - defer resp.Body.Close() + + body := resp.Body + if maxBytes := httpMaxBytesFromContext(ctx); maxBytes > 0 { + body = newLimitedWrappedReaderCloser(body, maxBytes) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("bad response code: %d", resp.StatusCode) } + if disabled := xTerraformGetDisabled(ctx); disabled { + return nil + } + + // Get client with configured Getters from the context + // If the client is nil, we know we're using the HttpGetter directly. In this case, + // we don't know exactly which protocols are configured, but we can make a good guess. + // + // This prevents all default getters from being allowed when only using the + // HttpGetter directly. To enable protocol switching, a client "wrapper" must + // be used. 
+ var getterClient *Client + if v := ClientFromContext(ctx); v != nil { + getterClient = v + } else { + getterClient = &Client{ + Getters: []Getter{g}, + } + } + // Extract the source URL var source string if v := resp.Header.Get("X-Terraform-Get"); v != "" { source = v } else { - source, err = g.parseMeta(resp.Body) + source, err = g.parseMeta(readCtx, body) if err != nil { return err } } + if source == "" { return fmt.Errorf("no source URL was returned") } - // If there is a subdir component, then we download the root separately - // into a temporary directory, then copy over the proper subdir. - source, subDir := SourceDirSubdir(source) - req = &Request{ - GetMode: ModeDir, - Src: source, - Dst: req.Dst, - } - if subDir == "" { - _, err = DefaultClient.Get(ctx, req) - return err - } - // We have a subdir, time to jump some hoops - return g.getSubdir(ctx, req, source, subDir) + return g.getXTerraformSource(ctx, req, source, getterClient) } // GetFile fetches the file from src and stores it at dst. @@ -135,6 +301,11 @@ func (g *HttpGetter) Get(ctx context.Context, req *Request) error { // falsely identified as being replaced, or corrupted with extra bytes // appended. func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error { + // Optionally enforce a maxiumum HTTP response body size. + if g.MaxBytes > 0 { + ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes) + } + if g.Netrc { // Add auth from netrc if we can if err := addAuthFromNetrc(req.u); err != nil { @@ -157,38 +328,67 @@ func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error { } var currentFileSize int64 + var httpReq *http.Request + + if g.DoNotCheckHeadFirst == false { + headCtx := ctx - // We first make a HEAD request so we can check - // if the server supports range queries. If the server/URL doesn't - // support HEAD requests, we just fall back to GET. 
- httpReq, err := http.NewRequestWithContext(ctx, "HEAD", req.u.String(), nil) + if g.HeadFirstTimeout > 0 { + var cancel context.CancelFunc + + headCtx, cancel = context.WithTimeout(ctx, g.HeadFirstTimeout) + defer cancel() + } + + // We first make a HEAD request so we can check + // if the server supports range queries. If the server/URL doesn't + // support HEAD requests, we just fall back to GET. + httpReq, err = http.NewRequestWithContext(headCtx, "HEAD", req.u.String(), nil) + if err != nil { + return err + } + if g.Header != nil { + httpReq.Header = g.Header.Clone() + } + headResp, err := g.Client.Do(httpReq) + if err == nil { + headResp.Body.Close() + if headResp.StatusCode == 200 { + // If the HEAD request succeeded, then attempt to set the range + // query if we can. + if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 { + if fi, err := f.Stat(); err == nil { + if _, err = f.Seek(0, io.SeekEnd); err == nil { + currentFileSize = fi.Size() + httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) + if currentFileSize >= headResp.ContentLength { + // file already present + return nil + } + } + } + } + } + } + } + + readCtx := ctx + if g.ReadTimeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout) + defer cancel() + } + + httpReq, err = http.NewRequestWithContext(readCtx, "GET", req.u.String(), nil) if err != nil { return err } if g.Header != nil { httpReq.Header = g.Header.Clone() } - headResp, err := g.Client.Do(httpReq) - if err == nil { - headResp.Body.Close() - if headResp.StatusCode == 200 { - // If the HEAD request succeeded, then attempt to set the range - // query if we can. 
- if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 { - if fi, err := f.Stat(); err == nil { - if _, err = f.Seek(0, io.SeekEnd); err == nil { - currentFileSize = fi.Size() - httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) - if currentFileSize >= headResp.ContentLength { - // file already present - return nil - } - } - } - } - } + if currentFileSize > 0 { + httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) } - httpReq.Method = "GET" resp, err := g.Client.Do(httpReq) if err != nil { @@ -204,23 +404,70 @@ func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error { body := resp.Body + if maxBytes := httpMaxBytesFromContext(readCtx); maxBytes > 0 { + body = newLimitedWrappedReaderCloser(body, maxBytes) + } + if req.ProgressListener != nil { // track download fn := filepath.Base(req.u.EscapedPath()) body = req.ProgressListener.TrackProgress(fn, currentFileSize, currentFileSize+resp.ContentLength, resp.Body) } defer body.Close() - n, err := Copy(ctx, f, body) + n, err := Copy(readCtx, f, body) if err == nil && n < resp.ContentLength { err = io.ErrShortWrite } return err } +// getXTerraformSource downloads the source into the destination +// using a protocol switching capable client. +func (g *HttpGetter) getXTerraformSource(ctx context.Context, req *Request, source string, client *Client) error { + + // If there is a subdir component, then we download the root separately + // into a temporary directory, then copy over the proper subdir. 
+ source, subDir := SourceDirSubdir(source) + req = &Request{ + GetMode: ModeDir, + Src: source, + Dst: req.Dst, + DisableSymlinks: req.DisableSymlinks, + } + + if subDir == "" { + // We have a X-Terraform-Get source lets check for supported Getters + var allowed bool + for _, getter := range client.Getters { + shouldDownload, err := Detect(req, getter) + if err != nil { + return fmt.Errorf("failed to detect the proper Getter to handle %s: %w", source, err) + } + if !shouldDownload { + // the request should not be processed by that getter + continue + } + allowed = true + } + + if !allowed { + protocol := strings.Split(source, ":")[0] + return fmt.Errorf("download not supported for scheme %q", protocol) + } + + _, err := client.Get(ctx, req) + return err + } + + // We have a subdir, time to jump some hoops + return g.getSubdir(ctx, req, source, subDir, client) + +} + // getSubdir downloads the source into the destination, but with // the proper subdir. -func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir string) error { +func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir string, client *Client) error { // Create a temporary directory to store the full source. This has to be // a non-existent directory. 
td, tdcloser, err := safetemp.Dir("", "getter") @@ -229,8 +476,13 @@ func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir } defer tdcloser.Close() - // Download that into the given directory - if _, err := Get(ctx, td, source); err != nil { + tdReq := &Request{ + Src: source, + Dst: td, + GetMode: ModeDir, + DisableSymlinks: req.DisableSymlinks, + } + if _, err := client.Get(ctx, tdReq); err != nil { return err } @@ -256,18 +508,22 @@ func (g *HttpGetter) getSubdir(ctx context.Context, req *Request, source, subDir return err } - return copyDir(ctx, req.Dst, sourcePath, false, req.umask()) + return copyDir(ctx, req.Dst, sourcePath, false, req.DisableSymlinks, req.umask()) } // parseMeta looks for the first meta tag in the given reader that // will give us the source URL. -func (g *HttpGetter) parseMeta(r io.Reader) (string, error) { +func (g *HttpGetter) parseMeta(ctx context.Context, r io.Reader) (string, error) { d := xml.NewDecoder(r) d.CharsetReader = charsetReader d.Strict = false var err error var t xml.Token for { + if ctx.Err() != nil { + return "", fmt.Errorf("context error while parsing meta tag: %w", ctx.Err()) + } + t, err = d.Token() if err != nil { if err == io.EOF {
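The redirect-limit logic in `HttpGetter.Get` above carries a counter through the `context.Context`, because each `X-Terraform-Get` hop re-enters `Get` with no other shared state. The following is a minimal, standalone sketch of that pattern, not the library's code: `hopCountKey`, `nextHop`, and `followLoop` are hypothetical names, and the error text is illustrative.

```go
package main

import (
	"context"
	"fmt"
)

// ctxKey is an unexported key type so the value cannot collide with
// context values set by other packages.
type ctxKey string

// hopCountKey is a hypothetical key name used only in this sketch.
const hopCountKey ctxKey = "hop-count"

// nextHop returns a context whose hop counter has been incremented,
// or an error once the stored counter already exceeds limit.
func nextHop(ctx context.Context, limit int) (context.Context, error) {
	current, _ := ctx.Value(hopCountKey).(int) // zero when no hop was recorded yet
	if current > limit {
		return ctx, fmt.Errorf("too many redirects: %d", current)
	}
	return context.WithValue(ctx, hopCountKey, current+1), nil
}

// followLoop simulates a server that redirects forever and reports how
// many hops were followed before the limit stopped the chain.
func followLoop(limit int) (int, error) {
	ctx := context.Background()
	for hops := 0; ; hops++ {
		var err error
		ctx, err = nextHop(ctx, limit)
		if err != nil {
			return hops, err
		}
	}
}

func main() {
	hops, err := followLoop(10)
	fmt.Println(hops, err)
}
```

Because the check is `current > limit` and the counter starts at zero, a limit of 10 lets eleven hops through before the error is returned. This matches the comparison shown in the diff.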
get_http_test.go+601 −28 modified@@ -9,13 +9,15 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httputil" "net/url" "os" "path/filepath" "strconv" "strings" "testing" + cleanhttp "github.com/hashicorp/go-cleanhttp" testing_helper "github.com/hashicorp/go-getter/v2/helper/testing" ) @@ -38,12 +40,26 @@ func TestHttpGetter_header(t *testing.T) { u.Path = "/header" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -103,12 +119,26 @@ func TestHttpGetter_meta(t *testing.T) { u.Path = "/meta" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -134,12 +164,26 @@ func TestHttpGetter_metaSubdir(t *testing.T) { u.Path = "/meta-subdir" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. 
+ err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "error downloading") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -165,12 +209,26 @@ func TestHttpGetter_metaSubdirGlob(t *testing.T) { u.Path = "/meta-subdir-glob" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "error downloading") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -361,12 +419,26 @@ func TestHttpGetter_auth(t *testing.T) { u.User = url.UserPassword("foo", "bar") req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -397,12 +469,26 @@ func TestHttpGetter_authNetrc(t *testing.T) { defer tempEnv(t, "NETRC", path)() req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. 
+ err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -442,14 +528,29 @@ func TestHttpGetter_cleanhttp(t *testing.T) { u.Path = "/header" req := &Request{ - Dst: dst, - u: &u, + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, } - // Get it! - if err := g.Get(ctx, req); err != nil { + // Get it, which should error because it uses the file protocol. + err := g.Get(ctx, req) + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("unexpected error: %v", err) + } + // But, using a wrapper client with a file getter will work. + c := &Client{ + Getters: []Getter{ + g, + new(FileGetter), + }, + } + + if _, err = c.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } + } func TestHttpGetter__RespectsContextCanceled(t *testing.T) { @@ -491,6 +592,432 @@ func TestHttpGetter__RespectsContextCanceled(t *testing.T) { } } +func TestHttpGetter__XTerraformGetLimit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetLoop(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/loop" + + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.Client = &http.Client{} + + req := Request{ + Dst: dst, + u: &u, + GetMode: ModeDir, + } + + err := g.Get(ctx, &req) + if !strings.Contains(err.Error(), "too many X-Terraform-Get redirects") { + t.Fatalf("too many X-Terraform-Get redirects, got: %v", err) + } +} + +func TestHttpGetter__XTerraformGetDisabled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetLoop(t) + 
+ var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/loop" + dst := testing_helper.TempDir(t) + + g := new(HttpGetter) + g.XTerraformGetDisabled = true + g.Client = &http.Client{} + + req := Request{ + Dst: dst, + u: &u, + GetMode: ModeDir, + } + + err := g.Get(ctx, &req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} +func TestHttpGetter__XTerraformGetProxyBypass(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetProxyBypass(t) + + proxyLn := testHttpServerProxy(t, ln.Addr().String()) + + t.Logf("starting malicious server on: %v", ln.Addr().String()) + t.Logf("starting proxy on: %v", proxyLn.Addr().String()) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/start" + dst := testing_helper.TempDir(t) + + proxy, err := url.Parse(fmt.Sprintf("http://%s/", proxyLn.Addr().String())) + if err != nil { + t.Fatalf("failed to parse proxy URL: %v", err) + } + + transport := cleanhttp.DefaultTransport() + transport.Proxy = http.ProxyURL(proxy) + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.Client = &http.Client{ + Transport: transport, + } + + client := &Client{ + Getters: []Getter{g}, + } + + req := Request{ + Dst: dst, + Src: u.String(), + } + + _, err = client.Get(ctx, &req) + if err != nil { + t.Logf("client get error: %v", err) + } +} + +func TestHttpGetter__XTerraformGetConfiguredGettersBypass(t *testing.T) { + tc := []struct { + name string + configuredGetters []Getter + errExpected bool + }{ + {name: "configured getter for git protocol switch", configuredGetters: []Getter{new(GitGetter)}, errExpected: false}, + {name: "configured getter for multiple protocol switch", configuredGetters: []Getter{new(GitGetter), new(HgGetter), new(FileGetter)}, errExpected: false}, + {name: "configured getter for file protocol switch", configuredGetters: []Getter{new(FileGetter)}, errExpected: true}, + } + + for _, tt := 
range tc { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithXTerraformGetConfiguredGettersBypass(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/start" + + dst := testing_helper.TempDir(t) + + rt := hookableHTTPRoundTripper{ + before: func(req *http.Request) { + t.Logf("making request") + }, + RoundTripper: http.DefaultTransport, + } + + g := new(HttpGetter) + g.XTerraformGetLimit = 10 + g.Client = &http.Client{ + Transport: &rt, + } + + client := &Client{ + Getters: []Getter{g}, + } + client.Getters = append(client.Getters, tt.configuredGetters...) + + t.Logf("%v", u.String()) + + req := Request{ + Dst: dst, + Src: u.String(), + GetMode: ModeDir, + } + + _, err := client.Get(ctx, &req) + // For configured getters that support git, the git repository doesn't exist so error will not be nil. + // If we get a nil error when we expect one other than the git error git exited with -1 we should fail. 
+ if tt.errExpected && err == nil { + t.Fatalf("error expected") + } + // We only care about the error messages that indicate that we can download the git header URL + if tt.errExpected && err != nil { + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("expected download not supported for scheme, got: %v", err) + } + } + }) + } +} + +func TestHttpGetter__endless_body(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerWithEndlessBody(t) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/" + dst := testing_helper.TempDir(t) + + g := new(HttpGetter) + g.MaxBytes = 10 + g.DoNotCheckHeadFirst = true + + client := &Client{ + Getters: []Getter{g}, + } + + t.Logf("%v", u.String()) + + req := Request{ + Dst: dst, + Src: u.String(), + GetMode: ModeFile, + } + + _, err := client.Get(ctx, &req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestHttpGetter_subdirLink(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := testHttpServerSubDir(t) + defer ln.Close() + + dst, err := ioutil.TempDir("", "tf") + if err != nil { + t.Fatalf("err: %s", err) + } + + t.Logf("dst: %q", dst) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/regular-subdir//meta-subdir" + + g := new(HttpGetter) + client := &Client{ + Getters: []Getter{g}, + } + + t.Logf("url: %q", u.String()) + + req := Request{ + Dst: dst, + Src: u.String(), + GetMode: ModeAny, + } + + _, err = client.Get(ctx, &req) + if err != nil { + t.Fatalf("get err: %v", err) + } +} + +func testHttpServerWithXTerraformGetLoop(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("http://%v:%v", ln.Addr().String(), "/loop") + + mux := http.NewServeMux() + mux.HandleFunc("/loop", func(w 
http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving loop") + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func testHttpServerWithXTerraformGetProxyBypass(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("http://%v/bypass", ln.Addr().String()) + + mux := http.NewServeMux() + mux.HandleFunc("/start/start", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving start") + }) + + mux.HandleFunc("/bypass", func(w http.ResponseWriter, r *http.Request) { + t.Fail() + t.Logf("bypassed proxy") + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving HTTP server path: %v", r.URL.Path) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func testHttpServerWithXTerraformGetConfiguredGettersBypass(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + header := fmt.Sprintf("git::http://%v/some/repository.git", ln.Addr().String()) + + mux := http.NewServeMux() + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Terraform-Get", header) + t.Logf("serving start") + }) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving git HTTP server path: %v", r.URL.Path) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + +func TestHttpGetter_XTerraformWithClientFromContext(t *testing.T) { + tc := []struct { + name string + client *Client + errExpected bool + }{ + { + name: "default getters", + client: &Client{ + Getters: Getters, + }, + errExpected: false, + }, + { + name: "client configured with needed getters", + client: &Client{ + Getters: 
[]Getter{ + new(HttpGetter), + new(FileGetter), + }, + }, + errExpected: false, + }, + { + name: "nil client", + errExpected: true, + }, + } + + for _, tt := range tc { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ln := testHttpServer(t) + defer ln.Close() + ctx := context.Background() + + g := new(HttpGetter) + dst := testing_helper.TempDir(t) + defer os.RemoveAll(dst) + + var u url.URL + u.Scheme = "http" + u.Host = ln.Addr().String() + u.Path = "/header" + + req := &Request{ + Dst: dst, + Src: u.String(), + u: &u, + GetMode: ModeDir, + } + + // Using a client stored in the ctx with a file getter should work + ctx = NewContextWithClient(ctx, tt.client) + + err := g.Get(ctx, req) + if tt.errExpected && err == nil { + t.Fatalf("error expected") + } + + if err != nil { + if !strings.Contains(err.Error(), "download not supported for scheme") { + t.Fatalf("expected download not supported for scheme, got: %v", err) + } + return + } + + // Verify the main file exists + mainPath := filepath.Join(dst, "main.tf") + if _, err := os.Stat(mainPath); err != nil { + t.Fatalf("err: %s", err) + } + }) + } +} + +func testHttpServerProxy(t *testing.T, upstreamHost string) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + t.Logf("serving proxy: %v: %#+v", r.URL.Path, r.Header) + // create the reverse proxy + proxy := httputil.NewSingleHostReverseProxy(r.URL) + // Note that ServeHttp is non blocking & uses a go routine under the hood + proxy.ServeHTTP(w, r) + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + func testHttpServer(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -515,6 +1042,29 @@ func testHttpServer(t *testing.T) net.Listener { return ln } +func testHttpServerWithEndlessBody(t *testing.T) net.Listener { 
+ t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + for { + w.Write([]byte(".\n")) + } + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + func testHttpHandlerExpectHeader(w http.ResponseWriter, r *http.Request) { if expected, ok := r.URL.Query()["expected"]; ok { if r.Header.Get(expected[0]) != "" { @@ -598,6 +1148,29 @@ func testHttpHandlerNoRange(w http.ResponseWriter, r *http.Request) { } } +func testHttpServerSubDir(t *testing.T) net.Listener { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + t.Logf("serving: %v: %v: %#+[1]v", r.Method, r.URL.String(), r.Header) + } + }) + + var server http.Server + server.Handler = mux + go server.Serve(ln) + + return ln +} + const testHttpMetaStr = ` <html> <head>
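`TestHttpGetter__endless_body` above expects a download with `MaxBytes = 10` to finish without error even though the server never stops writing. One way such a cap can be built is by wrapping the response body in `io.LimitReader` while still closing the original body. The sketch below assumes that approach; `limitedReadCloser`, `limitBody`, `endless`, and `readCapped` are hypothetical names and are not taken from go-getter.

```go
package main

import (
	"fmt"
	"io"
)

// limitedReadCloser exposes a size-capped view of a body while still
// closing the original body when Close is called.
type limitedReadCloser struct {
	io.Reader              // the capped reader
	closeFn   func() error // Close method of the original body
}

func (l *limitedReadCloser) Close() error { return l.closeFn() }

// limitBody caps the number of bytes that can be read from body.
func limitBody(body io.ReadCloser, maxBytes int64) io.ReadCloser {
	return &limitedReadCloser{
		Reader:  io.LimitReader(body, maxBytes),
		closeFn: body.Close,
	}
}

// endless is a reader that never returns io.EOF, standing in for a
// server that streams data forever.
type endless struct{}

func (endless) Read(p []byte) (int, error) {
	for i := range p {
		p[i] = '.'
	}
	return len(p), nil
}

// readCapped drains an endless body through the cap and returns the
// number of bytes that were actually read.
func readCapped(maxBytes int64) (int, error) {
	body := limitBody(io.NopCloser(endless{}), maxBytes)
	defer body.Close()

	data, err := io.ReadAll(body)
	return len(data), err
}

func main() {
	n, err := readCapped(10)
	fmt.Println(n, err)
}
```

With this design the read is silently truncated at the cap instead of failing, which is consistent with the test asserting a nil error.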
get_test.go +22 −1 modified
@@ -348,6 +348,27 @@ func TestGetFile_archive(t *testing.T) {
 	// Verify the main file exists
 	testing_helper.AssertContents(t, dst, "Hello\n")
 }
+func TestGetFile_filename_path_traversal(t *testing.T) {
+	dst := testing_helper.TempDir(t)
+	u := testModule("basic-file/foo.txt")
+
+	u += "?filename=../../../../../../../../../../../../../tmp/bar.txt"
+
+	ctx := context.Background()
+	op, err := GetAny(ctx, dst, u)
+
+	if op != nil {
+		t.Fatalf("unexpected op: %v", op)
+	}
+
+	if err == nil {
+		t.Fatalf("expected error")
+	}
+
+	if !strings.Contains(err.Error(), "filename query parameter contain path traversal") {
+		t.Fatalf("unexpected err: %s", err)
+	}
+}
 
 func TestGetFile_archiveChecksum(t *testing.T) {
 	ctx := context.Background()
@@ -763,7 +784,7 @@ func TestGetFile_inplace_badChecksum(t *testing.T) {
 	}
 }
 
-func TestgetForcedGetter(t *testing.T) {
+func TestGetForcedGetter(t *testing.T) {
 	type args struct {
 		src string
 	}
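The new test above only shows that a `filename` query parameter containing `../` is rejected; the check itself is not part of this excerpt. A minimal sketch of such a check, assuming the value is split on both slash styles and refused when any element is exactly `..`, could look like this. `containsDotDot` and `checkFilename` are hypothetical names, and the error text is illustrative.

```go
package main

import (
	"fmt"
	"strings"
)

// containsDotDot reports whether any slash- or backslash-separated
// element of v is exactly "..".
func containsDotDot(v string) bool {
	if !strings.Contains(v, "..") {
		return false
	}
	isSeparator := func(r rune) bool { return r == '/' || r == '\\' }
	for _, elem := range strings.FieldsFunc(v, isSeparator) {
		if elem == ".." {
			return true
		}
	}
	return false
}

// checkFilename rejects a user-supplied file name that would escape
// the destination directory.
func checkFilename(name string) error {
	if containsDotDot(name) {
		return fmt.Errorf("filename %q contains path traversal", name)
	}
	return nil
}

func main() {
	for _, name := range []string{"bar.txt", "bar..txt", "../../tmp/bar.txt"} {
		fmt.Printf("%q -> %v\n", name, checkFilename(name))
	}
}
```

Comparing whole path elements rather than searching for the substring `..` keeps legitimate names such as `bar..txt` usable.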
helper/testing/utils.go +1 −1 modified
@@ -34,7 +34,7 @@ func AssertContents(t *testing.T, path string, contents string) {
 	}
 
 	if !reflect.DeepEqual(data, []byte(contents)) {
-		t.Fatalf("bad. expected:\n\n%s\n\nGot:\n\n%s", contents, string(data))
+		t.Fatalf("bad. expected:\n\n%q\n\nGot:\n\n%q", contents, string(data))
 	}
 }
README.md+98 −16 modified@@ -21,8 +21,8 @@ URLs. For example: "github.com/hashicorp/go-getter" would turn into a Git URL. Or "./foo" would turn into a file URL. These are extensible. This library is used by [Terraform](https://terraform.io) for -downloading modules and [Nomad](https://nomadproject.io) for downloading -binaries. +downloading modules, [Packer](https://packer.io) for downloading binaries, and +[Nomad](https://nomadproject.io) for downloading binaries. ## Installation and Usage @@ -47,6 +47,16 @@ $ go-getter github.com/foo/bar ./foo The command is useful for verifying URL structures. +## Security +Fetching resources from user-supplied URLs is an inherently dangerous operation and may +leave your application vulnerable to [server side request forgery](https://owasp.org/www-community/attacks/Server_Side_Request_Forgery), +[path traversal](https://owasp.org/www-community/attacks/Path_Traversal), [denial of service](https://owasp.org/www-community/attacks/Denial_of_Service) +or other security flaws. + +go-getter contains mitigations for some of these security issues, but should still be used with +caution in security-critical contexts. See the available [security options](#Security-Options) that +can be configured to mitigate some of these risks. + ## URL Format go-getter uses a single string URL as input to download from a variety of @@ -83,7 +93,7 @@ is built-in by default: file URLs. * GitHub URLs, such as "github.com/mitchellh/vagrant" are automatically changed to Git protocol over HTTP. - * GitLab URLs, such as "gitlab.com/inkscape/inkscape" are automatically + * GitLab URLs, such as "gitlab.com/inkscape/inkscape" are automatically changed to Git protocol over HTTP. * BitBucket URLs, such as "bitbucket.org/mitchellh/vagrant" are automatically changed to a Git or mercurial protocol using the BitBucket API. @@ -178,7 +188,7 @@ checksum string. 
Examples: ``` ./foo.txt?checksum=file:./foo.txt.sha256sum ``` - + When checksumming from a file - ex: with `checksum=file:url` - go-getter will get the file linked in the URL after `file:` using the same configuration. For example, in `file:http://releases.ubuntu.com/cosmic/MD5SUMS` go-getter will @@ -279,7 +289,7 @@ None from a private key file on disk, you would run `base64 -w0 <file>`. **Note**: Git 2.3+ is required to use this feature. - + * `depth` - The Git clone depth. The provided number specifies the last `n` revisions to clone from the repository. @@ -374,35 +384,107 @@ files from a smb shared folder whenever the url is prefixed with `smb://`. ⚠️ The [`smbclient`](https://www.samba.org/samba/docs/current/man-html/smbclient.1.html) command is available only for Linux. This is the ONLY option for a Linux user and therefore the client must be installed. - + The `smbclient` cli is not available for Windows and MacOS. The go-getter will try to get files using the file system, when this happens the getter uses the FileGetter implementation. -When connecting to a smb server, the OS creates a local mount in a system specific volume folder, and go-getter will +When connecting to a smb server, the OS creates a local mount in a system specific volume folder, and go-getter will try to access the following folders when looking for local mounts. 
- MacOS: /Volumes/<shared_path> - Windows: \\\\\<host>\\\<shared_path> -The following examples work for all the OSes: +The following examples work for all the OSes: - smb://host/shared/dir (downloads directory content) -- smb://host/shared/dir/file (downloads file) +- smb://host/shared/dir/file (downloads file) -The following examples work for Linux: +The following examples work for Linux: - smb://username:password@host/shared/dir (downloads directory content) - smb://username@host/shared/dir - smb://username:password@host/shared/dir/file (downloads file) - smb://username@host/shared/dir/file ⚠️ The above examples also work on the other OSes but the authentication is not used to access the file system. - - + + #### SMB Testing The test for `get_smb.go` requires a smb server running which can be started inside a docker container by -running `make start-smb`. Once the container is up the shared folder can be accessed via `smb://<ip|name>/public/<dir|file>` or -`smb://user:password@<ip|name>/private/<dir|file>` by another container or machine in the same network. +running `make start-smb`. Once the container is up the shared folder can be accessed via `smb://<ip|name>/public/<dir|file>` or +`smb://user:password@<ip|name>/private/<dir|file>` by another container or machine in the same network. -To run the tests inside `get_smb_test.go` and `client_test.go`, prepare the environment with `make smbtests-prepare`. On prepare some +To run the tests inside `get_smb_test.go` and `client_test.go`, prepare the environment with `make smbtests-prepare`. On prepare some mock files and directories will be added to the shared folder and a go-getter container will start together with the samba server. -Once the environment for testing is prepared, run `make smbtests` to run the tests. \ No newline at end of file +Once the environment for testing is prepared, run `make smbtests` to run the tests. 
+
+### Security Options
+
+**Disable Symlinks**
+
+In your getter client config, we recommend using the `DisableSymlinks` option,
+which prevents writing through or copying from symlinks (which may point outside the directory).
+
+```go
+client := getter.Client{
+    // This will prevent copying or writing files through symlinks
+    DisableSymlinks: true,
+}
+```
+
+**Disable or Limit `X-Terraform-Get`**
+
+Go-Getter supports arbitrary redirects via the `X-Terraform-Get` header. This functionality
+exists to support [Terraform use cases](https://www.terraform.io/language/modules/sources#http-urls),
+but is likely not needed in most applications.
+
+For code that uses the `HttpGetter`, add the following configuration options:
+
+```go
+var httpGetter = &getter.HttpGetter{
+    // Most clients should disable X-Terraform-Get
+    // See the note below
+    XTerraformGetDisabled: true,
+    // Your software probably doesn’t rely on X-Terraform-Get, but
+    // if it does, you should set the above field to false, plus
+    // set XTerraformGet Limit to prevent endless redirects
+    // XTerraformGetLimit: 10,
+}
+```
+
+**Enforce Timeouts**
+
+The `HttpGetter` supports timeouts and other resource-constraining configuration options. The `GitGetter` and `HgGetter`
+only support timeouts.
+
+Configuration for the `HttpGetter`:
+
+```go
+var httpGetter = &getter.HttpGetter{
+    // Disable pre-fetch HEAD requests
+    DoNotCheckHeadFirst: true,
+
+    // As an alternative to the above setting, you can
+    // set a reasonable timeout for HEAD requests
+    // HeadFirstTimeout: 10 * time.Second,
+    // Read timeout for HTTP operations
+    ReadTimeout: 30 * time.Second,
+    // Set the maximum number of bytes
+    // that can be read by the getter
+    MaxBytes: 500000000, // 500 MB
+}
+```
+
+For code that uses the `GitGetter` or `HgGetter`, set the `Timeout` option:
+```go
+var gitGetter = &getter.GitGetter{
+    // Set a reasonable timeout for git operations
+    Timeout: 5 * time.Minute,
+}
+```
+
+```go
+var hgGetter = &getter.HgGetter{
+    // Set a reasonable timeout for hg operations
+    Timeout: 5 * time.Minute,
+}
+```
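The timeout options documented in the README section above follow one convention throughout this change: a zero duration means "no timeout", and a positive duration is turned into a context deadline. The sketch below illustrates that convention with the standard library only. `withOptionalTimeout`, `slowOperation`, and `timedOut` are hypothetical helpers, and `slowOperation` merely stands in for a network call.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// withOptionalTimeout applies a deadline only when d is positive, so the
// zero value of a Timeout field keeps the caller's context unchanged
// apart from being cancellable.
func withOptionalTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
	if d > 0 {
		return context.WithTimeout(ctx, d)
	}
	return context.WithCancel(ctx)
}

// slowOperation stands in for a download; it finishes after work has
// elapsed unless the context is cancelled first.
func slowOperation(ctx context.Context, work time.Duration) error {
	select {
	case <-time.After(work):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// timedOut runs slowOperation under the given timeout and reports
// whether it was stopped by the deadline.
func timedOut(timeoutMillis, workMillis int) bool {
	timeout := time.Duration(timeoutMillis) * time.Millisecond
	ctx, cancel := withOptionalTimeout(context.Background(), timeout)
	defer cancel()

	err := slowOperation(ctx, time.Duration(workMillis)*time.Millisecond)
	return errors.Is(err, context.DeadlineExceeded)
}

func main() {
	fmt.Println("short timeout hit:", timedOut(10, 500))
	fmt.Println("no timeout hit:", timedOut(0, 10))
}
```

Always returning a cancel function lets callers write `defer cancel()` unconditionally, whether or not a deadline was set.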
request.go +4 −0 modified
@@ -58,6 +58,10 @@ type Request struct {
 	// By default a no op progress listener is used.
 	ProgressListener ProgressTracker
 
+	// Disable symlinks is used to prevent copying or writing files through symlinks.
+	// When set to true any copying or writing through symlinks will result in a ErrSymlinkCopy error.
+	DisableSymlinks bool
+
 	u               *url.URL
 	subDir, realDst string
 }
s3/get_s3.go+30 −4 modified@@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -20,9 +21,21 @@ import ( // Getter is a Getter implementation that will download a module from // a S3 bucket. -type Getter struct{} +type Getter struct { + + // Timeout sets a deadline which all S3 operations should + // complete within. Zero value means no timeout. + Timeout time.Duration +} func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL region, bucket, path, _, creds, err := g.parseUrl(u) if err != nil { @@ -40,7 +53,7 @@ func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { Bucket: aws.String(bucket), Prefix: aws.String(path), } - resp, err := client.ListObjects(req) + resp, err := client.ListObjectsWithContext(ctx, req) if err != nil { return 0, err } @@ -64,6 +77,12 @@ func (g *Getter) Mode(ctx context.Context, u *url.URL) (getter.Mode, error) { func (g *Getter) Get(ctx context.Context, req *getter.Request) error { + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL region, bucket, path, _, creds, err := g.parseUrl(req.URL()) if err != nil { @@ -105,7 +124,7 @@ func (g *Getter) Get(ctx context.Context, req *getter.Request) error { s3Req.Marker = aws.String(lastMarker) } - resp, err := client.ListObjects(s3Req) + resp, err := client.ListObjectsWithContext(ctx, s3Req) if err != nil { return err } @@ -139,6 +158,13 @@ func (g *Getter) Get(ctx context.Context, req *getter.Request) error { } func (g *Getter) GetFile(ctx context.Context, req *getter.Request) error { + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + region, bucket, path, 
version, creds, err := g.parseUrl(req.URL()) if err != nil { return err @@ -161,7 +187,7 @@ func (g *Getter) getObject(ctx context.Context, client *s3.S3, req *getter.Reque s3req.VersionId = aws.String(version) } - resp, err := client.GetObject(s3req) + resp, err := client.GetObjectWithContext(ctx, s3req) if err != nil { return err }
source.go +3 −1 modified
@@ -58,7 +58,9 @@ func SourceDirSubdir(src string) (string, string) {
 //
 // The returned path is the full absolute path.
 func SubdirGlob(dst, subDir string) (string, error) {
-	matches, err := filepath.Glob(filepath.Join(dst, subDir))
+	pattern := filepath.Join(dst, subDir)
+
+	matches, err := filepath.Glob(pattern)
 	if err != nil {
 		return "", err
 	}
a2ebce998f8d Multiple fixes for go-getter (#359)
16 files changed · +1169 −97
client.go+23 −1 modified@@ -2,6 +2,7 @@ package getter import ( "context" + "errors" "fmt" "io/ioutil" "os" @@ -13,6 +14,9 @@ import ( safetemp "github.com/hashicorp/go-safetemp" ) +// ErrSymlinkCopy means that a copy of a symlink was encountered on a request with DisableSymlinks enabled. +var ErrSymlinkCopy = errors.New("copying of symlinks has been disabled") + // Client is a client for downloading things. // // Top-level functions such as Get are shortcuts for interacting with a client. @@ -76,6 +80,9 @@ type Client struct { // This is identical to tls.Config.InsecureSkipVerify. Insecure bool + // Disable symlinks + DisableSymlinks bool + Options []ClientOption } @@ -123,6 +130,17 @@ func (c *Client) Get() error { dst := c.Dst src, subDir := SourceDirSubdir(src) if subDir != "" { + // Check if the subdirectory is attempting to traverse updwards, outside of + // the cloned repository path. + subDir := filepath.Clean(subDir) + if containsDotDot(subDir) { + return fmt.Errorf("subdirectory component contain path traversal out of the repository") + } + // Prevent absolute paths, remove a leading path separator from the subdirectory + if subDir[0] == os.PathSeparator { + subDir = subDir[1:] + } + td, tdcloser, err := safetemp.Dir("", "getter") if err != nil { return err @@ -230,6 +248,10 @@ func (c *Client) Get() error { filename = v } + if containsDotDot(filename) { + return fmt.Errorf("filename query parameter contain path traversal") + } + dst = filepath.Join(dst, filename) } } @@ -318,7 +340,7 @@ func (c *Client) Get() error { return err } - return copyDir(c.Ctx, realDst, subDir, false, c.umask()) + return copyDir(c.Ctx, realDst, subDir, false, c.DisableSymlinks, c.umask()) } return nil
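The `client.go` changes above call a `containsDotDot` helper whose body is not part of this diff. The sketch below shows one plausible way such a check could work and is an assumption, not the library's actual code: it splits the path on both separator styles and reports whether any element is exactly `..`, so that names merely containing two dots are not rejected.

```go
package main

import (
	"fmt"
	"strings"
)

// containsDotDot reports whether any path element of v is exactly "..".
// This body is a hypothetical sketch; the helper used by the diff is not
// shown there, so its real implementation may differ.
func containsDotDot(v string) bool {
	if !strings.Contains(v, "..") {
		return false
	}
	// Treat both forward and backward slashes as separators.
	isSlash := func(r rune) bool { return r == '/' || r == '\\' }
	for _, elem := range strings.FieldsFunc(v, isSlash) {
		if elem == ".." {
			return true
		}
	}
	return false
}

func main() {
	for _, p := range []string{"modules/vpc", "../../etc/passwd", "release..notes"} {
		fmt.Printf("%q -> %v\n", p, containsDotDot(p))
	}
}
```

Checking whole elements rather than the substring `..` avoids false positives on legitimate names while still catching the traversal sequences exercised by the test added in this commit.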
client_option.go+61 −7 modified@@ -1,46 +1,100 @@ package getter -import "context" +import ( + "context" + "os" +) -// A ClientOption allows to configure a client +// ClientOption is used to configure a client. type ClientOption func(*Client) error -// Configure configures a client with options. +// Configure applies all of the given client options, along with any default +// behavior including context, decompressors, detectors, and getters used by +// the client. func (c *Client) Configure(opts ...ClientOption) error { + // If the context has not been configured use the background context. if c.Ctx == nil { c.Ctx = context.Background() } + + // Store the options used to configure this client. c.Options = opts + + // Apply all of the client options. for _, opt := range opts { err := opt(c) if err != nil { return err } } - // Default decompressor values + + // If the client was not configured with any Decompressors, Detectors, + // or Getters, use the default values for each. if c.Decompressors == nil { c.Decompressors = Decompressors } - // Default detector values if c.Detectors == nil { c.Detectors = Detectors } - // Default getter values if c.Getters == nil { c.Getters = Getters } + // Set the client for each getter, so the top-level client can know + // the getter-specific client functions or progress tracking. for _, getter := range c.Getters { getter.SetClient(c) } + return nil } // WithContext allows to pass a context to operation // in order to be able to cancel a download in progress. -func WithContext(ctx context.Context) func(*Client) error { +func WithContext(ctx context.Context) ClientOption { return func(c *Client) error { c.Ctx = ctx return nil } } + +// WithDecompressors specifies which Decompressor are available. +func WithDecompressors(decompressors map[string]Decompressor) ClientOption { + return func(c *Client) error { + c.Decompressors = decompressors + return nil + } +} + +// WithDecompressors specifies which compressors are available. 
+func WithDetectors(detectors []Detector) ClientOption { + return func(c *Client) error { + c.Detectors = detectors + return nil + } +} + +// WithGetters specifies which getters are available. +func WithGetters(getters map[string]Getter) ClientOption { + return func(c *Client) error { + c.Getters = getters + return nil + } +} + +// WithMode specifies which client mode the getters should operate in. +func WithMode(mode ClientMode) ClientOption { + return func(c *Client) error { + c.Mode = mode + return nil + } +} + +// WithUmask specifies how to mask file permissions when storing local +// files or decompressing an archive. +func WithUmask(mode os.FileMode) ClientOption { + return func(c *Client) error { + c.Umask = mode + return nil + } +}
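The `client_option.go` changes above follow the functional-options pattern: each `With...` function returns a closure that mutates the client, and `Configure` applies them in order before filling in defaults. The sketch below reproduces that flow with hypothetical stand-in types (`client`, `option`, `withMode`) so it compiles without go-getter; the `"any"` default is likewise illustrative only.

```go
package main

import (
	"context"
	"fmt"
)

// client and option are hypothetical stand-ins for Client and ClientOption.
type client struct {
	ctx  context.Context
	mode string
}

type option func(*client) error

// withMode returns an option that sets the mode, rejecting an empty value.
func withMode(mode string) option {
	return func(c *client) error {
		if mode == "" {
			return fmt.Errorf("mode must not be empty")
		}
		c.mode = mode
		return nil
	}
}

// configure applies the options in order, stops at the first error, and
// then fills in defaults for anything the options left unset.
func (c *client) configure(opts ...option) error {
	if c.ctx == nil {
		c.ctx = context.Background()
	}
	for _, opt := range opts {
		if err := opt(c); err != nil {
			return err
		}
	}
	if c.mode == "" {
		c.mode = "any" // illustrative default
	}
	return nil
}

func main() {
	c := &client{}
	if err := c.configure(withMode("dir")); err != nil {
		panic(err)
	}
	fmt.Println(c.mode)
}
```

Applying defaults after the options, as `Configure` does in the diff, means a caller-supplied value always wins over the package-level default.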
copy_dir.go+21 −3 modified@@ -2,6 +2,7 @@ package getter import ( "context" + "fmt" "os" "path/filepath" "strings" @@ -16,8 +17,11 @@ func mode(mode, umask os.FileMode) os.FileMode { // should already exist. // // If ignoreDot is set to true, then dot-prefixed files/folders are ignored. -func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask os.FileMode) error { - src, err := filepath.EvalSymlinks(src) +func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, disableSymlinks bool, umask os.FileMode) error { + // We can safely evaluate the symlinks here, even if disabled, because they + // will be checked before actual use in walkFn and copyFile + var err error + src, err = filepath.EvalSymlinks(src) if err != nil { return err } @@ -26,6 +30,20 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask if err != nil { return err } + + if disableSymlinks { + fileInfo, err := os.Lstat(path) + if err != nil { + return fmt.Errorf("failed to check copy file source for symlinks: %w", err) + } + if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink { + return ErrSymlinkCopy + } + // if info.Mode()&os.ModeSymlink == os.ModeSymlink { + // return ErrSymlinkCopy + // } + } + if path == src { return nil } @@ -59,7 +77,7 @@ func copyDir(ctx context.Context, dst string, src string, ignoreDot bool, umask } // If we have a file, copy the contents. - _, err = copyFile(ctx, dstPath, path, info.Mode(), umask) + _, err = copyFile(ctx, dstPath, path, disableSymlinks, info.Mode(), umask) return err }
get_file_copy.go +12 −1 modified
@@ -2,6 +2,7 @@ package getter
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"os"
 )
@@ -49,7 +50,17 @@ func copyReader(dst string, src io.Reader, fmode, umask os.FileMode) error {
 }
 
 // copyFile copies a file in chunks from src path to dst path, using umask to create the dst file
-func copyFile(ctx context.Context, dst, src string, fmode, umask os.FileMode) (int64, error) {
+func copyFile(ctx context.Context, dst, src string, disableSymlinks bool, fmode, umask os.FileMode) (int64, error) {
+	if disableSymlinks {
+		fileInfo, err := os.Lstat(src)
+		if err != nil {
+			return 0, fmt.Errorf("failed to check copy file source for symlinks: %w", err)
+		}
+		if fileInfo.Mode()&os.ModeSymlink == os.ModeSymlink {
+			return 0, ErrSymlinkCopy
+		}
+	}
+
 	srcF, err := os.Open(src)
 	if err != nil {
 		return 0, err
get_file_unix.go +7 −1 modified
@@ -87,7 +87,13 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error {
 		return os.Symlink(path, dst)
 	}
 
+	var disableSymlinks bool
+
+	if g.client != nil && g.client.DisableSymlinks {
+		disableSymlinks = true
+	}
+
 	// Copy
-	_, err = copyFile(ctx, dst, path, fi.Mode(), g.client.umask())
+	_, err = copyFile(ctx, dst, path, disableSymlinks, fi.Mode(), g.client.umask())
 	return err
 }
get_file_windows.go+7 −1 modified@@ -111,8 +111,14 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error { } } + var disableSymlinks bool + + if g.client != nil && g.client.DisableSymlinks { + disableSymlinks = true + } + // Copy - _, err = copyFile(ctx, dst, path, 0666, g.client.umask()) + _, err = copyFile(ctx, dst, path, disableSymlinks, 0666, g.client.umask()) return err }
get_gcs.go+26 −2 modified@@ -3,13 +3,15 @@ package getter import ( "context" "fmt" - "golang.org/x/oauth2" - "google.golang.org/api/option" "net/url" "os" "path/filepath" "strconv" "strings" + "time" + + "golang.org/x/oauth2" + "google.golang.org/api/option" "cloud.google.com/go/storage" "google.golang.org/api/iterator" @@ -19,11 +21,21 @@ import ( // a GCS bucket. type GCSGetter struct { getter + + // Timeout sets a deadline which all GCS operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, _, err := g.parseURL(u) if err != nil { @@ -61,6 +73,12 @@ func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) { func (g *GCSGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, _, err := g.parseURL(u) if err != nil { @@ -120,6 +138,12 @@ func (g *GCSGetter) Get(dst string, u *url.URL) error { func (g *GCSGetter) GetFile(dst string, u *url.URL) error { ctx := g.Context() + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + // Parse URL bucket, object, fragment, err := g.parseURL(u) if err != nil {
get_git.go+28 −16 modified@@ -14,6 +14,7 @@ import ( "runtime" "strconv" "strings" + "time" urlhelper "github.com/hashicorp/go-getter/helper/url" safetemp "github.com/hashicorp/go-safetemp" @@ -24,6 +25,10 @@ import ( // a git repository. type GitGetter struct { getter + + // Timeout sets a deadline which all git CLI operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } var defaultBranchRegexp = regexp.MustCompile(`\s->\sorigin/(.*)`) @@ -35,6 +40,13 @@ func (g *GitGetter) ClientMode(_ *url.URL) (ClientMode, error) { func (g *GitGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + if _, err := exec.LookPath("git"); err != nil { return fmt.Errorf("git must be available and on the PATH") } @@ -76,7 +88,7 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { var sshKeyFile string if sshKey != "" { // Check that the git version is sufficiently new. 
- if err := checkGitVersion("2.3"); err != nil { + if err := checkGitVersion(ctx, "2.3"); err != nil { return fmt.Errorf("Error using ssh key: %v", err) } @@ -123,7 +135,7 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { // Next: check out the proper tag/branch if it is specified, and checkout if ref != "" { - if err := g.checkout(dst, ref); err != nil { + if err := g.checkout(ctx, dst, ref); err != nil { return err } } @@ -161,8 +173,8 @@ func (g *GitGetter) GetFile(dst string, u *url.URL) error { return fg.GetFile(dst, u) } -func (g *GitGetter) checkout(dst string, ref string) error { - cmd := exec.Command("git", "checkout", ref) +func (g *GitGetter) checkout(ctx context.Context, dst string, ref string) error { + cmd := exec.CommandContext(ctx, "git", "checkout", ref) cmd.Dir = dst return getRunCommand(cmd) } @@ -182,7 +194,7 @@ func (g *GitGetter) clone(ctx context.Context, dst, sshKeyFile string, u *url.UR originalRef := ref // we handle an unspecified ref differently than explicitly selecting the default branch below if ref == "" { - ref = findRemoteDefaultBranch(u) + ref = findRemoteDefaultBranch(ctx, u) } if depth > 0 { args = append(args, "--depth", strconv.Itoa(depth)) @@ -211,7 +223,7 @@ func (g *GitGetter) clone(ctx context.Context, dst, sshKeyFile string, u *url.UR // If we didn't add --depth and --branch above then we will now be // on the remote repository's default branch, rather than the selected // ref, so we'll need to fix that before we return. - return g.checkout(dst, originalRef) + return g.checkout(ctx, dst, originalRef) } return nil } @@ -226,18 +238,18 @@ func (g *GitGetter) update(ctx context.Context, dst, sshKeyFile, ref string, dep // Not a branch, switch to default branch. This will also catch // non-existent branches, in which case we want to switch to default // and then checkout the proper branch later. 
- ref = findDefaultBranch(dst) + ref = findDefaultBranch(ctx, dst) } // We have to be on a branch to pull - if err := g.checkout(dst, ref); err != nil { + if err := g.checkout(ctx, dst, ref); err != nil { return err } if depth > 0 { - cmd = exec.Command("git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--depth", strconv.Itoa(depth), "--ff-only") } else { - cmd = exec.Command("git", "pull", "--ff-only") + cmd = exec.CommandContext(ctx, "git", "pull", "--ff-only") } cmd.Dir = dst @@ -260,9 +272,9 @@ func (g *GitGetter) fetchSubmodules(ctx context.Context, dst, sshKeyFile string, // findDefaultBranch checks the repo's origin remote for its default branch // (generally "master"). "master" is returned if an origin default branch // can't be determined. -func findDefaultBranch(dst string) string { +func findDefaultBranch(ctx context.Context, dst string) string { var stdoutbuf bytes.Buffer - cmd := exec.Command("git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") + cmd := exec.CommandContext(ctx, "git", "branch", "-r", "--points-at", "refs/remotes/origin/HEAD") cmd.Dir = dst cmd.Stdout = &stdoutbuf err := cmd.Run() @@ -275,9 +287,9 @@ func findDefaultBranch(dst string) string { // findRemoteDefaultBranch checks the remote repo's HEAD symref to return the remote repo's // default branch. "master" is returned if no HEAD symref exists. 
-func findRemoteDefaultBranch(u *url.URL) string { +func findRemoteDefaultBranch(ctx context.Context, u *url.URL) string { var stdoutbuf bytes.Buffer - cmd := exec.Command("git", "ls-remote", "--symref", u.String(), "HEAD") + cmd := exec.CommandContext(ctx, "git", "ls-remote", "--symref", u.String(), "HEAD") cmd.Stdout = &stdoutbuf err := cmd.Run() matches := lsRemoteSymRefRegexp.FindStringSubmatch(stdoutbuf.String()) @@ -326,13 +338,13 @@ func setupGitEnv(cmd *exec.Cmd, sshKeyFile string) { // checkGitVersion is used to check the version of git installed on the system // against a known minimum version. Returns an error if the installed version // is older than the given minimum. -func checkGitVersion(min string) error { +func checkGitVersion(ctx context.Context, min string) error { want, err := version.NewVersion(min) if err != nil { return err } - out, err := exec.Command("git", "version").Output() + out, err := exec.CommandContext(ctx, "git", "version").Output() if err != nil { return err }
get_git_test.go+119 −2 modified@@ -2,7 +2,10 @@ package getter import ( "bytes" + "context" "encoding/base64" + "errors" + "fmt" "io/ioutil" "net/url" "os" @@ -436,12 +439,12 @@ func TestGitGetter_gitVersion(t *testing.T) { os.Setenv("PATH", dir) // Asking for a higher version throws an error - if err := checkGitVersion("2.3"); err == nil { + if err := checkGitVersion(context.Background(), "2.3"); err == nil { t.Fatal("expect git version error") } // Passes when version is satisfied - if err := checkGitVersion("1.9"); err != nil { + if err := checkGitVersion(context.Background(), "1.9"); err != nil { t.Fatal(err) } } @@ -693,6 +696,120 @@ func TestGitGetter_setupGitEnvWithExisting_sshKey(t *testing.T) { } } +func TestGitGetter_subdirectory_symlink(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + g := new(GitGetter) + dst := tempDir(t) + + target, err := ioutil.TempFile("", "link-target") + if err != nil { + t.Fatal(err) + } + defer os.Remove(target.Name()) + + repo := testGitRepo(t, "repo-with-symlink") + innerDir := filepath.Join(repo.dir, "this-directory-contains-a-symlink") + if err := os.Mkdir(innerDir, 0700); err != nil { + t.Fatal(err) + } + path := filepath.Join(innerDir, "this-is-a-symlink") + if err := os.Symlink(target.Name(), path); err != nil { + t.Fatal(err) + } + + repo.git("add", path) + repo.git("commit", "-m", "Adding "+path) + + u, err := url.Parse(fmt.Sprintf("git::%s//this-directory-contains-a-symlink", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + client := &Client{ + Src: u.String(), + Dst: dst, + Pwd: ".", + Mode: ClientModeDir, + DisableSymlinks: true, + Detectors: []Detector{ + new(GitDetector), + }, + Getters: map[string]Getter{ + "git": g, + }, + } + + err = client.Get() + + if runtime.GOOS == "windows" { + // Windows doesn't handle symlinks as one might expect with git. 
+ // + // https://github.com/git-for-windows/git/wiki/Symbolic-Links + filepath.Walk(dst, func(path string, info os.FileInfo, err error) error { + if strings.Contains(path, "this-is-a-symlink") { + if info.Mode()&os.ModeSymlink == os.ModeSymlink { + // If you see this test fail in the future, you've probably enabled + // symlinks within git on your Windows system. Our CI/CD system does + // not do this, so this is this is the only way we can make this test + // make any sense. + t.Fatalf("windows git should not have cloned a symlink") + } + } + return nil + }) + } else { + // We can rely on POSIX compliant systems running git to do the right thing. + if err == nil { + t.Fatalf("expected client get to fail") + } + if !errors.Is(err, ErrSymlinkCopy) { + t.Fatalf("unexpected error: %v", err) + } + } + +} + +func TestGitGetter_subdirectory(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + g := new(GitGetter) + dst := tempDir(t) + + repo := testGitRepo(t, "empty-repo") + u, err := url.Parse(fmt.Sprintf("git::%s//../../../../../../etc/passwd", repo.url.String())) + if err != nil { + t.Fatal(err) + } + + client := &Client{ + Src: u.String(), + Dst: dst, + Pwd: ".", + + Mode: ClientModeDir, + + Detectors: []Detector{ + new(GitDetector), + }, + Getters: map[string]Getter{ + "git": g, + }, + } + + err = client.Get() + if err == nil { + t.Fatalf("expected client get to fail") + } + if !strings.Contains(err.Error(), "subdirectory component contain path traversal out of the repository") { + t.Fatalf("unexpected error: %v", err) + } +} + // gitRepo is a helper struct which controls a single temp git repo. type gitRepo struct { t *testing.T
get_hg.go+19 −7 modified@@ -8,6 +8,7 @@ import ( "os/exec" "path/filepath" "runtime" + "time" urlhelper "github.com/hashicorp/go-getter/helper/url" safetemp "github.com/hashicorp/go-safetemp" @@ -17,6 +18,10 @@ import ( // a Mercurial repository. type HgGetter struct { getter + + // Timeout sets a deadline which all hg CLI operations should + // complete within. Zero value means no timeout. + Timeout time.Duration } func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) { @@ -25,6 +30,13 @@ func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) { func (g *HgGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + + if g.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, g.Timeout) + defer cancel() + } + if _, err := exec.LookPath("hg"); err != nil { return fmt.Errorf("hg must be available and on the PATH") } @@ -53,12 +65,12 @@ func (g *HgGetter) Get(dst string, u *url.URL) error { return err } if err != nil { - if err := g.clone(dst, newURL); err != nil { + if err := g.clone(ctx, dst, newURL); err != nil { return err } } - if err := g.pull(dst, newURL); err != nil { + if err := g.pull(ctx, dst, newURL); err != nil { return err } @@ -101,21 +113,21 @@ func (g *HgGetter) GetFile(dst string, u *url.URL) error { return fg.GetFile(dst, u) } -func (g *HgGetter) clone(dst string, u *url.URL) error { - cmd := exec.Command("hg", "clone", "-U", u.String(), dst) +func (g *HgGetter) clone(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "clone", "-U", "--", u.String(), dst) return getRunCommand(cmd) } -func (g *HgGetter) pull(dst string, u *url.URL) error { - cmd := exec.Command("hg", "pull") +func (g *HgGetter) pull(ctx context.Context, dst string, u *url.URL) error { + cmd := exec.CommandContext(ctx, "hg", "pull") cmd.Dir = dst return getRunCommand(cmd) } func (g *HgGetter) update(ctx context.Context, dst string, u *url.URL, rev string) error { args := []string{"update"} if 
rev != "" { - args = append(args, rev) + args = append(args, "--", rev) } cmd := exec.CommandContext(ctx, "hg", args...)
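The `get_hg.go` changes above insert `--` before user-controlled values so that a revision or URL beginning with `-` is treated as a positional argument rather than as an option. The sketch below shows the argument construction in isolation; `updateArgs` is a hypothetical helper that mirrors the logic inlined in `update`.

```go
package main

import "fmt"

// updateArgs is a hypothetical helper that builds the argument list for
// "hg update". The "--" marker conventionally ends option parsing, so the
// value after it is read as a positional argument even if it starts with "-".
func updateArgs(rev string) []string {
	args := []string{"update"}
	if rev != "" {
		args = append(args, "--", rev)
	}
	return args
}

func main() {
	fmt.Println(updateArgs("default"))
	fmt.Println(updateArgs("--config=alias.update=!false"))
}
```

Because the arguments are passed as a slice to `exec.CommandContext`, no shell is involved; the remaining risk is option injection into the tool itself, which is what the separator addresses.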
get_hg_test.go+44 −0 modified@@ -1,9 +1,11 @@ package getter import ( + "net/url" "os" "os/exec" "path/filepath" + "strings" "testing" ) @@ -97,3 +99,45 @@ func TestHgGetter_GetFile(t *testing.T) { } assertContents(t, dst, "Hello\n") } + +func TestHgGetter_HgArgumentsNotAllowed(t *testing.T) { + if !testHasHg { + t.Log("hg not found, skipping") + t.Skip() + } + + g := new(HgGetter) + + // If arguments are allowed in the destination, this Get call will fail + dst := "--config=alias.clone=!false" + defer os.RemoveAll(dst) + err := g.Get(dst, testModuleURL("basic-hg")) + if err != nil { + t.Fatalf("Expected no err, got: %s", err) + } + + dst = tempDir(t) + // Test arguments passed into the `rev` parameter + // This clone call will fail regardless, but an exit code of 1 indicates + // that the `false` command executed + // We are expecting an hg parse error + err = g.Get(dst, testModuleURL("basic-hg?rev=--config=alias.update=!false")) + if err != nil { + if !strings.Contains(err.Error(), "hg: parse error") { + t.Fatalf("Expected no err, got: %s", err) + } + } + + dst = tempDir(t) + // Test arguments passed in the repository URL + // This Get call will fail regardless, but it should fail + // because the repository can't be found. + // Other failures indicate that hg interpretted the argument passed in the URL + err = g.Get(dst, &url.URL{Path: "--config=alias.clone=false"}) + if err != nil { + if !strings.Contains(err.Error(), "repository --config=alias.clone=false not found") { + t.Fatalf("Expected no err, got: %s", err) + } + } + +}
get_http.go+282 −38 modified@@ -11,6 +11,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/hashicorp/go-cleanhttp" safetemp "github.com/hashicorp/go-safetemp" @@ -28,7 +29,9 @@ import ( // wish. The response must be a 2xx. // // First, a header is looked for "X-Terraform-Get" which should contain -// a source URL to download. +// a source URL to download. This source must use one of the configured +// protocols and getters for the client, or "http"/"https" if using +// the HttpGetter directly. // // If the header is not present, then a meta tag is searched for named // "terraform-get" and the content should be a source URL. @@ -52,6 +55,36 @@ type HttpGetter struct { // and as such it needs to be initialized before use, via something like // make(http.Header). Header http.Header + + // DoNotCheckHeadFirst configures the client to NOT check if the server + // supports HEAD requests. + DoNotCheckHeadFirst bool + + // HeadFirstTimeout configures the client to enforce a timeout when + // the server supports HEAD requests. + // + // The zero value means no timeout. + HeadFirstTimeout time.Duration + + // ReadTimeout configures the client to enforce a timeout when + // making a request to an HTTP server and reading its response body. + // + // The zero value means no timeout. + ReadTimeout time.Duration + + // MaxBytes limits the number of bytes that will be ready from an HTTP + // response body returned from a server. The zero value means no limit. + MaxBytes int64 + + // XTerraformGetLimit configures how many times the client with follow + // the " X-Terraform-Get" header value. + // + // The zero value means no limit. + XTerraformGetLimit int + + // XTerraformGetDisabled disables the client's usage of the "X-Terraform-Get" + // header value. 
+ XTerraformGetDisabled bool } func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) { @@ -61,8 +94,115 @@ func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) { return ClientModeFile, nil } +type contextKey int + +const ( + xTerraformGetDisable contextKey = 0 + xTerraformGetLimit contextKey = 1 + xTerraformGetLimitCurrentValue contextKey = 2 + httpClientValue contextKey = 3 + httpMaxBytesValue contextKey = 4 +) + +func xTerraformGetDisabled(ctx context.Context) bool { + value, ok := ctx.Value(xTerraformGetDisable).(bool) + if !ok { + return false + } + return value +} + +func xTerraformGetLimitCurrentValueFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimitCurrentValue).(int) + if !ok { + return 1 + } + return value +} + +func xTerraformGetLimiConfiguredtFromContext(ctx context.Context) int { + value, ok := ctx.Value(xTerraformGetLimit).(int) + if !ok { + return 0 + } + return value +} + +func httpClientFromContext(ctx context.Context) *http.Client { + value, ok := ctx.Value(httpClientValue).(*http.Client) + if !ok { + return nil + } + return value +} + +func httpMaxBytesFromContext(ctx context.Context) int64 { + value, ok := ctx.Value(httpMaxBytesValue).(int64) + if !ok { + return 0 // no limit + } + return value +} + +type limitedWrappedReaderCloser struct { + underlying io.Reader + closeFn func() error +} + +func (l *limitedWrappedReaderCloser) Read(p []byte) (n int, err error) { + return l.underlying.Read(p) +} + +func (l *limitedWrappedReaderCloser) Close() (err error) { + return l.closeFn() +} + +func newLimitedWrappedReaderCloser(r io.ReadCloser, limit int64) io.ReadCloser { + return &limitedWrappedReaderCloser{ + underlying: io.LimitReader(r, limit), + closeFn: r.Close, + } +} + func (g *HttpGetter) Get(dst string, u *url.URL) error { ctx := g.Context() + + // Optionally disable any X-Terraform-Get redirects. This is reccomended for usage of + // this client outside of Terraform's. 
This feature is likely not required if the
+	// source server can provider normal HTTP redirects.
+	if g.XTerraformGetDisabled {
+		ctx = context.WithValue(ctx, xTerraformGetDisable, g.XTerraformGetDisabled)
+	}
+
+	// Optionally enforce a limit on X-Terraform-Get redirects. We check this for every
+	// invocation of this function, because the value is not passed down to subsequent
+	// client Get function invocations.
+	if g.XTerraformGetLimit > 0 {
+		ctx = context.WithValue(ctx, xTerraformGetLimit, g.XTerraformGetLimit)
+	}
+
+	// If there was a limit on X-Terraform-Get redirects, check what the current count value.
+	//
+	// If the value is greater than the limit, return an error. Otherwise, increment the value,
+	// and include it in the the context to be passed along in all the subsequent client
+	// Get function invocations.
+	if limit := xTerraformGetLimiConfiguredtFromContext(ctx); limit > 0 {
+		currentValue := xTerraformGetLimitCurrentValueFromContext(ctx)
+
+		if currentValue > limit {
+			return fmt.Errorf("too many X-Terraform-Get redirects: %d", currentValue)
+		}
+
+		currentValue++
+
+		ctx = context.WithValue(ctx, xTerraformGetLimitCurrentValue, currentValue)
+	}
+
+	// Optionally enforce a maxiumum HTTP response body size.
+	if g.MaxBytes > 0 {
+		ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes)
+	}
+
 	// Copy the URL so we can modify it
 	var newU url.URL = *u
 	u = &newU
@@ -74,22 +214,40 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
 		}
 	}
 
+	// If the HTTP client is nil, check if there is one available in the context,
+	// otherwise create one using cleanhttp's default transport.
 	if g.Client == nil {
-		g.Client = httpClient
-		if g.client != nil && g.client.Insecure {
-			insecureTransport := cleanhttp.DefaultTransport()
-			insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
-			g.Client.Transport = insecureTransport
+		if client := httpClientFromContext(ctx); client != nil {
+			g.Client = client
+		} else {
+			client := httpClient
+			if g.client != nil && g.client.Insecure {
+				insecureTransport := cleanhttp.DefaultTransport()
+				insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
+				client.Transport = insecureTransport
+			}
+			g.Client = client
 		}
 	}
 
+	// Pass along the configured HTTP client in the context for usage with the X-Terraform-Get feature.
+	ctx = context.WithValue(ctx, httpClientValue, g.Client)
+
 	// Add terraform-get to the parameter.
 	q := u.Query()
 	q.Add("terraform-get", "1")
 	u.RawQuery = q.Encode()
 
+	readCtx := ctx
+
+	if g.ReadTimeout > 0 {
+		var cancel context.CancelFunc
+		readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout)
+		defer cancel()
+	}
+
 	// Get the URL
-	req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
+	req, err := http.NewRequestWithContext(readCtx, "GET", u.String(), nil)
 	if err != nil {
 		return err
 	}
@@ -102,18 +260,28 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
 	if err != nil {
 		return err
 	}
-	defer resp.Body.Close()
+
+	body := resp.Body
+
+	if maxBytes := httpMaxBytesFromContext(ctx); maxBytes > 0 {
+		body = newLimitedWrappedReaderCloser(body, maxBytes)
+	}
+
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
 		return fmt.Errorf("bad response code: %d", resp.StatusCode)
 	}
 
-	// Extract the source URL
+	if disabled := xTerraformGetDisabled(ctx); disabled {
+		return nil
+	}
+
+	// Extract the source URL,
 	var source string
 	if v := resp.Header.Get("X-Terraform-Get"); v != "" {
 		source = v
 	} else {
-		source, err = g.parseMeta(resp.Body)
+		source, err = g.parseMeta(readCtx, body)
 		if err != nil {
 			return err
 		}
@@ -127,9 +295,43 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
 	source, subDir := SourceDirSubdir(source)
 	if subDir == "" {
 		var opts []ClientOption
+
+		// Check if the protocol was switched to one which was not configured.
+		//
+		// Otherwise, all default getters are allowed.
+		if g.client != nil && g.client.Getters != nil {
+			protocol := strings.Split(source, ":")[0]
+			_, allowed := g.client.Getters[protocol]
+			if !allowed {
+				return fmt.Errorf("no getter available for X-Terraform-Get source protocol: %q", protocol)
+			}
+		}
+
+		// Add any getter client options.
 		if g.client != nil {
 			opts = g.client.Options
 		}
+
+		// If the client is nil, we know we're using the HttpGetter directly. In this case,
+		// we don't know exactly which protocols are configued, but we can make a good guess.
+		//
+		// This prevents all default getters from being allowed when only using the
+		// HttpGetter directly. To enable protocol switching, a client "wrapper" must
+		// be used.
+		if g.client == nil {
+			opts = append(opts, WithGetters(map[string]Getter{
+				"http":  g,
+				"https": g,
+			}))
+		}
+
+		// Ensure we pass along the context we constructed in this function.
+		//
+		// This is especially important to enforce a limit on X-Terraform-Get redirects
+		// which could be setup, if configured, at the top of this function.
+		opts = append(opts, WithContext(ctx))
+
+		// Note: this allows the protocol to be switched to another configured getters.
 		return Get(dst, source, opts...)
 	}
 
@@ -145,6 +347,12 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error {
 // appended.
 func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
 	ctx := g.Context()
+
+	// Optionally enforce a maxiumum HTTP response body size.
+	if g.MaxBytes > 0 {
+		ctx = context.WithValue(ctx, httpMaxBytesValue, g.MaxBytes)
+	}
+
 	if g.Netrc {
 		// Add auth from netrc if we can
 		if err := addAuthFromNetrc(src); err != nil {
@@ -171,39 +379,61 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
 		}
 	}
 
-	var currentFileSize int64
+	var (
+		currentFileSize int64
+		req             *http.Request
+	)
 
-	// We first make a HEAD request so we can check
-	// if the server supports range queries. If the server/URL doesn't
-	// support HEAD requests, we just fall back to GET.
-	req, err := http.NewRequestWithContext(ctx, "HEAD", src.String(), nil)
-	if err != nil {
-		return err
-	}
-	if g.Header != nil {
-		req.Header = g.Header.Clone()
-	}
-	headResp, err := g.Client.Do(req)
-	if err == nil {
-		headResp.Body.Close()
-		if headResp.StatusCode == 200 {
-			// If the HEAD request succeeded, then attempt to set the range
-			// query if we can.
-			if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 {
-				if fi, err := f.Stat(); err == nil {
-					if _, err = f.Seek(0, io.SeekEnd); err == nil {
-						currentFileSize = fi.Size()
-						if currentFileSize >= headResp.ContentLength {
-							// file already present
-							return nil
+	if !g.DoNotCheckHeadFirst {
+		headCtx := ctx
+
+		if g.HeadFirstTimeout > 0 {
+			var cancel context.CancelFunc
+
+			headCtx, cancel = context.WithTimeout(ctx, g.HeadFirstTimeout)
+			defer cancel()
+		}
+
+		// We first make a HEAD request so we can check
+		// if the server supports range queries. If the server/URL doesn't
+		// support HEAD requests, we just fall back to GET.
+		req, err = http.NewRequestWithContext(headCtx, "HEAD", src.String(), nil)
+		if err != nil {
+			return err
+		}
+		if g.Header != nil {
+			req.Header = g.Header.Clone()
+		}
+		headResp, err := g.Client.Do(req)
+		if err == nil {
+			headResp.Body.Close()
+			if headResp.StatusCode == 200 {
+				// If the HEAD request succeeded, then attempt to set the range
+				// query if we can.
+				if headResp.Header.Get("Accept-Ranges") == "bytes" && headResp.ContentLength >= 0 {
+					if fi, err := f.Stat(); err == nil {
+						if _, err = f.Seek(0, io.SeekEnd); err == nil {
+							currentFileSize = fi.Size()
+							if currentFileSize >= headResp.ContentLength {
+								// file already present
+								return nil
+							}
 						}
 					}
 				}
 			}
 		}
 	}
-	req, err = http.NewRequestWithContext(ctx, "GET", src.String(), nil)
+	readCtx := ctx
+
+	if g.ReadTimeout > 0 {
+		var cancel context.CancelFunc
+		readCtx, cancel = context.WithTimeout(ctx, g.ReadTimeout)
+		defer cancel()
+	}
+
+	req, err = http.NewRequestWithContext(readCtx, "GET", src.String(), nil)
 	if err != nil {
 		return err
 	}
@@ -228,14 +458,18 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
 
 	body := resp.Body
 
+	if maxBytes := httpMaxBytesFromContext(ctx); maxBytes > 0 {
+		body = newLimitedWrappedReaderCloser(body, maxBytes)
+	}
+
 	if g.client != nil && g.client.ProgressListener != nil {
 		// track download
 		fn := filepath.Base(src.EscapedPath())
 		body = g.client.ProgressListener.TrackProgress(fn, currentFileSize, currentFileSize+resp.ContentLength, resp.Body)
 	}
 	defer body.Close()
 
-	n, err := Copy(ctx, f, body)
+	n, err := Copy(readCtx, f, body)
 	if err == nil && n < resp.ContentLength {
 		err = io.ErrShortWrite
 	}
@@ -284,18 +518,28 @@ func (g *HttpGetter) getSubdir(ctx context.Context, dst, source, subDir string)
 		return err
 	}
 
-	return copyDir(ctx, dst, sourcePath, false, g.client.umask())
+	var disableSymlinks bool
+
+	if g.client != nil && g.client.DisableSymlinks {
+		disableSymlinks = true
+	}
+
+	return copyDir(ctx, dst, sourcePath, false, disableSymlinks, g.client.umask())
 }
 
 // parseMeta looks for the first meta tag in the given reader that
 // will give us the source URL.
-func (g *HttpGetter) parseMeta(r io.Reader) (string, error) {
+func (g *HttpGetter) parseMeta(ctx context.Context, r io.Reader) (string, error) {
 	d := xml.NewDecoder(r)
 	d.CharsetReader = charsetReader
 	d.Strict = false
 	var err error
 	var t xml.Token
 	for {
+		if ctx.Err() != nil {
+			return "", fmt.Errorf("context error while parsing meta tag: %w", ctx.Err())
+		}
+
 		t, err = d.Token()
 		if err != nil {
 			if err == io.EOF {
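The get_http.go changes above cap response bodies through newLimitedWrappedReaderCloser, whose implementation is not part of the diff shown here. The following is only a minimal sketch of how such a wrapper could work, assuming it combines io.LimitReader with the original body's Close method. The names limitedReadCloser, newLimitedReadCloser, and readCapped are hypothetical and are not go-getter's API.

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// limitedReadCloser is a hypothetical stand-in for the helper the patch
// calls newLimitedWrappedReaderCloser; the real implementation is not shown
// in the diff, so this is an assumption about its behaviour.
type limitedReadCloser struct {
	io.Reader           // reads report EOF once the byte cap is reached
	closer    io.Closer // the original body, so Close still releases it
}

// Close closes the wrapped body rather than the limiting reader.
func (l *limitedReadCloser) Close() error { return l.closer.Close() }

// newLimitedReadCloser caps the number of bytes that can be read from body.
func newLimitedReadCloser(body io.ReadCloser, maxBytes int64) io.ReadCloser {
	return &limitedReadCloser{Reader: io.LimitReader(body, maxBytes), closer: body}
}

// readCapped reads src through a maxBytes cap and returns what was read.
func readCapped(src string, maxBytes int64) (string, error) {
	body := newLimitedReadCloser(io.NopCloser(strings.NewReader(src)), maxBytes)
	defer body.Close()

	b, err := io.ReadAll(body)
	return string(b), err
}

func main() {
	// A large body stands in for the endless response used in the tests.
	got, err := readCapped(strings.Repeat(".\n", 1000), 10)
	fmt.Println(len(got), err)
}
```

With a cap like this, an endless or oversized response is truncated without an error, which is consistent with TestHttpGetter__endless_body expecting a nil error when MaxBytes is set.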
get_http_test.go +473 −14 modified
@@ -9,12 +9,15 @@ import (
 	"io/ioutil"
 	"net"
 	"net/http"
+	"net/http/httputil"
 	"net/url"
 	"os"
 	"path/filepath"
 	"strconv"
 	"strings"
 	"testing"
+
+	"github.com/hashicorp/go-cleanhttp"
 )
 
 func TestHttpGetter_impl(t *testing.T) {
@@ -34,8 +37,27 @@ func TestHttpGetter_header(t *testing.T) {
 	u.Host = ln.Addr().String()
 	u.Path = "/header"
 
-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}
 
@@ -44,6 +66,7 @@ func TestHttpGetter_header(t *testing.T) {
 	if _, err := os.Stat(mainPath); err != nil {
 		t.Fatalf("err: %s", err)
 	}
+
 }
 
 func TestHttpGetter_requestHeader(t *testing.T) {
@@ -87,8 +110,27 @@ func TestHttpGetter_meta(t *testing.T) {
 	u.Host = ln.Addr().String()
 	u.Path = "/meta"
 
-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}
 
@@ -330,14 +372,27 @@ func TestHttpGetter_auth(t *testing.T) {
 	u.Path = "/meta-auth"
 	u.User = url.UserPassword("foo", "bar")
 
-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
-		t.Fatalf("err: %s", err)
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
 	}
 
-	// Verify the main file exists
-	mainPath := filepath.Join(dst, "main.tf")
-	if _, err := os.Stat(mainPath); err != nil {
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}
 }
@@ -360,8 +415,27 @@ func TestHttpGetter_authNetrc(t *testing.T) {
 	defer closer()
 	defer tempEnv(t, "NETRC", path)()
 
-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}
 
@@ -399,8 +473,27 @@ func TestHttpGetter_cleanhttp(t *testing.T) {
 	u.Host = ln.Addr().String()
 	u.Path = "/header"
 
-	// Get it!
-	if err := g.Get(dst, &u); err != nil {
+	// Get it, which should error because it uses the file protocol.
+	err := g.Get(dst, &u)
+
+	if !strings.Contains(err.Error(), "download not supported for scheme 'file'") {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// But, using a wrapper client with a file getter will work.
+	c := &Client{
+		Getters: map[string]Getter{
+			"http": g,
+			"file": new(FileGetter),
+		},
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeDir,
+	}
+
+	err = c.Get()
+
+	if err != nil {
 		t.Fatalf("err: %s", err)
 	}
 }
@@ -441,6 +534,326 @@ func TestHttpGetter__RespectsContextCanceled(t *testing.T) {
 	}
 }
 
+func TestHttpGetter__XTerraformGetLimit(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithXTerraformGetLoop(t)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/loop"
+	dst := tempDir(t)
+
+	g := new(HttpGetter)
+	g.XTerraformGetLimit = 10
+	g.client = &Client{
+		Ctx: ctx,
+	}
+	g.Client = &http.Client{}
+
+	err := g.Get(dst, &u)
+	if !strings.Contains(err.Error(), "too many X-Terraform-Get redirects") {
+		t.Fatalf("too many X-Terraform-Get redirects, got: %v", err)
+	}
+}
+
+func TestHttpGetter__XTerraformGetDisabled(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithXTerraformGetLoop(t)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/loop"
+	dst := tempDir(t)
+
+	g := new(HttpGetter)
+	g.XTerraformGetDisabled = true
+	g.client = &Client{
+		Ctx: ctx,
+	}
+	g.Client = &http.Client{}
+
+	err := g.Get(dst, &u)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+func TestHttpGetter__XTerraformGetProxyBypass(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithXTerraformGetProxyBypass(t)
+
+	proxyLn := testHttpServerProxy(t, ln.Addr().String())
+
+	t.Logf("starting malicious server on: %v", ln.Addr().String())
+	t.Logf("starting proxy on: %v", proxyLn.Addr().String())
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/start"
+	dst := tempDir(t)
+
+	proxy, err := url.Parse(fmt.Sprintf("http://%s/", proxyLn.Addr().String()))
+	if err != nil {
+		t.Fatalf("failed to parse proxy URL: %v", err)
+	}
+
+	transport := cleanhttp.DefaultTransport()
+	transport.Proxy = http.ProxyURL(proxy)
+
+	httpGetter := new(HttpGetter)
+	httpGetter.XTerraformGetLimit = 10
+	httpGetter.Client = &http.Client{
+		Transport: transport,
+	}
+
+	client := &Client{
+		Ctx: ctx,
+		Getters: map[string]Getter{
+			"http": httpGetter,
+		},
+	}
+
+	client.Src = u.String()
+	client.Dst = dst
+
+	err = client.Get()
+	if err != nil {
+		t.Logf("client get error: %v", err)
+	}
+}
+
+func TestHttpGetter__XTerraformGetConfiguredGettersBypass(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithXTerraformGetConfiguredGettersBypass(t)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/start"
+	dst := tempDir(t)
+
+	rt := hookableHTTPRoundTripper{
+		before: func(req *http.Request) {
+			t.Logf("making request")
+		},
+		RoundTripper: http.DefaultTransport,
+	}
+
+	httpGetter := new(HttpGetter)
+	httpGetter.XTerraformGetLimit = 10
+	httpGetter.Client = &http.Client{
+		Transport: &rt,
+	}
+
+	client := &Client{
+		Ctx:  ctx,
+		Mode: ClientModeDir,
+		Getters: map[string]Getter{
+			"http": httpGetter,
+		},
+	}
+
+	t.Logf("%v", u.String())
+
+	client.Src = u.String()
+	client.Dst = dst
+
+	err := client.Get()
+	if err != nil {
+		if !strings.Contains(err.Error(), "no getter available for X-Terraform-Get source protocol") {
+			t.Fatalf("expected no getter available for X-Terraform-Get source protocol, got: %v", err)
+		}
+	}
+}
+
+func TestHttpGetter__endless_body(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ln := testHttpServerWithEndlessBody(t)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/"
+	dst := tempDir(t)
+
+	httpGetter := new(HttpGetter)
+	httpGetter.MaxBytes = 10
+	httpGetter.DoNotCheckHeadFirst = true
+
+	client := &Client{
+		Ctx:  ctx,
+		Mode: ClientModeFile,
+		// Mode: ClientModeDir,
+		Getters: map[string]Getter{
+			"http": httpGetter,
+		},
+	}
+
+	t.Logf("%v", u.String())
+
+	client.Src = u.String()
+	client.Dst = dst
+
+	err := client.Get()
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+func TestHttpGetter_subdirLink(t *testing.T) {
+	ln := testHttpServerSubDir(t)
+	defer ln.Close()
+
+	httpGetter := new(HttpGetter)
+	dst, err := ioutil.TempDir("", "tf")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	t.Logf("dst: %q", dst)
+
+	var u url.URL
+	u.Scheme = "http"
+	u.Host = ln.Addr().String()
+	u.Path = "/regular-subdir//meta-subdir"
+
+	t.Logf("url: %q", u.String())
+
+	client := &Client{
+		Src:  u.String(),
+		Dst:  dst,
+		Mode: ClientModeAny,
+		Getters: map[string]Getter{
+			"http": httpGetter,
+		},
+	}
+
+	err = client.Get()
+	if err != nil {
+		t.Fatalf("get err: %v", err)
+	}
+}
+
+func testHttpServerWithXTerraformGetLoop(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	header := fmt.Sprintf("http://%v:%v", ln.Addr().String(), "/loop")
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/loop", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("X-Terraform-Get", header)
+		t.Logf("serving loop")
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
+func testHttpServerWithXTerraformGetProxyBypass(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	header := fmt.Sprintf("http://%v/bypass", ln.Addr().String())
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/start/start", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("X-Terraform-Get", header)
+		t.Logf("serving start")
+	})
+
+	mux.HandleFunc("/bypass", func(w http.ResponseWriter, r *http.Request) {
+		t.Fail()
+		t.Logf("bypassed proxy")
+	})
+
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		t.Logf("serving HTTP server path: %v", r.URL.Path)
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
+func testHttpServerWithXTerraformGetConfiguredGettersBypass(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	header := fmt.Sprintf("git::http://%v/some/repository.git", ln.Addr().String())
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("X-Terraform-Get", header)
+		t.Logf("serving start")
+	})
+
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		t.Logf("serving git HTTP server path: %v", r.URL.Path)
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
+func testHttpServerProxy(t *testing.T, upstreamHost string) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	mux := http.NewServeMux()
+
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		t.Logf("serving proxy: %v: %#+v", r.URL.Path, r.Header)
+		// create the reverse proxy
+		proxy := httputil.NewSingleHostReverseProxy(r.URL)
+		// Note that ServeHttp is non blocking & uses a go routine under the hood
+		proxy.ServeHTTP(w, r)
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
 func testHttpServer(t *testing.T) net.Listener {
 	ln, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
@@ -504,6 +917,29 @@ func testHttpHandlerMetaAuth(w http.ResponseWriter, r *http.Request) {
 	w.Write([]byte(fmt.Sprintf(testHttpMetaStr, testModuleURL("basic").String())))
 }
 
+func testHttpServerWithEndlessBody(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		for {
+			w.Write([]byte(".\n"))
+		}
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
 func testHttpHandlerMetaSubdir(w http.ResponseWriter, r *http.Request) {
 	w.Write([]byte(fmt.Sprintf(testHttpMetaStr, testModuleURL("basic//subdir").String())))
 }
@@ -548,6 +984,29 @@ func testHttpHandlerNoRange(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+func testHttpServerSubDir(t *testing.T) net.Listener {
+	t.Helper()
+
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		switch r.Method {
+		case http.MethodGet:
+			t.Logf("serving: %v: %v: %#+[1]v", r.Method, r.URL.String(), r.Header)
+		}
+	})
+
+	var server http.Server
+	server.Handler = mux
+	go server.Serve(ln)
+
+	return ln
+}
+
 const testHttpMetaStr = `
 <html>
 <head>
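The configured-getters test above relies on the protocol check that the patch adds to HttpGetter.Get: the text before the first ':' in the X-Terraform-Get source must name a configured getter. The check can be sketched in isolation as follows. The function name checkSourceProtocol and the simplified map[string]struct{} type are illustrative assumptions; the real code consults the client's Getters map.

```go
package main

import (
	"fmt"
	"strings"
)

// checkSourceProtocol mirrors the check shown in the patch: the part of the
// source before the first ':' must be a key of the configured getters.
// The map value type is simplified to struct{} for this sketch.
func checkSourceProtocol(source string, configured map[string]struct{}) error {
	protocol := strings.Split(source, ":")[0]
	if _, allowed := configured[protocol]; !allowed {
		return fmt.Errorf("no getter available for X-Terraform-Get source protocol: %q", protocol)
	}
	return nil
}

func main() {
	// Only the http getter is configured, as in the bypass test.
	configured := map[string]struct{}{"http": {}}

	sources := []string{
		"http://127.0.0.1/module",
		"git::http://127.0.0.1/some/repository.git",
	}
	for _, src := range sources {
		fmt.Printf("%s -> %v\n", src, checkSourceProtocol(src, configured))
	}
}
```

In this sketch a forced-getter source such as git::http://... yields the protocol "git", so a server cannot switch a client that only configured "http" over to the git getter.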
get_s3.go +29 −3 modified
@@ -7,6 +7,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"time"
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/credentials"
@@ -20,10 +21,22 @@ import (
 // a S3 bucket.
 type S3Getter struct {
 	getter
+
+	// Timeout sets a deadline which all S3 operations should
+	// complete within. Zero value means no timeout.
+	Timeout time.Duration
 }
 
 func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
 	// Parse URL
+	ctx := g.Context()
+
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	region, bucket, path, _, creds, err := g.parseUrl(u)
 	if err != nil {
 		return 0, err
@@ -40,7 +53,7 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
 		Bucket: aws.String(bucket),
 		Prefix: aws.String(path),
 	}
-	resp, err := client.ListObjects(req)
+	resp, err := client.ListObjectsWithContext(ctx, req)
 	if err != nil {
 		return 0, err
 	}
@@ -65,6 +78,12 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
 func (g *S3Getter) Get(dst string, u *url.URL) error {
 	ctx := g.Context()
 
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	// Parse URL
 	region, bucket, path, _, creds, err := g.parseUrl(u)
 	if err != nil {
@@ -106,7 +125,7 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {
 			req.Marker = aws.String(lastMarker)
 		}
 
-		resp, err := client.ListObjects(req)
+		resp, err := client.ListObjectsWithContext(ctx, req)
 		if err != nil {
 			return err
 		}
@@ -141,6 +160,13 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {
 
 func (g *S3Getter) GetFile(dst string, u *url.URL) error {
 	ctx := g.Context()
+
+	if g.Timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, g.Timeout)
+		defer cancel()
+	}
+
 	region, bucket, path, version, creds, err := g.parseUrl(u)
 	if err != nil {
 		return err
@@ -163,7 +189,7 @@ func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, ke
 		req.VersionId = aws.String(version)
 	}
 
-	resp, err := client.GetObject(req)
+	resp, err := client.GetObjectWithContext(ctx, req)
 	if err != nil {
 		return err
 	}
get_test.go +15 −0 modified
@@ -492,6 +492,21 @@ func TestGetFile_filename(t *testing.T) {
 	}
 }
 
+func TestGetFile_filename_path_traversal(t *testing.T) {
+	dst := tempDir(t)
+	u := testModule("basic-file/foo.txt")
+
+	u += "?filename=../../../../../../../../../../../../../tmp/bar.txt"
+
+	err := GetAny(dst, u)
+	if err == nil {
+		t.Fatalf("expected error")
+	}
+	if !strings.Contains(err.Error(), "filename query parameter contain path traversal") {
+		t.Fatalf("unexpected err: %s", err)
+	}
+}
+
 func TestGetFile_checksumSkip(t *testing.T) {
 	dst := tempTestFile(t)
 	defer os.RemoveAll(filepath.Dir(dst))
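The test above expects a filename query parameter containing ../ segments to be rejected with the error "filename query parameter contain path traversal". The check itself is not part of the diff shown here, so the following is only a sketch of one way to implement it, assuming the filename is split on both slash styles and any ".." segment is refused. The helper names hasParentSegment and checkFilename are hypothetical.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// hasParentSegment reports whether any path segment of name is exactly "..".
// Both '/' and '\\' are treated as separators so the check does not depend
// on the host platform.
func hasParentSegment(name string) bool {
	segments := strings.FieldsFunc(name, func(r rune) bool {
		return r == '/' || r == '\\'
	})
	for _, s := range segments {
		if s == ".." {
			return true
		}
	}
	return false
}

// checkFilename rejects filenames that would climb out of the destination.
// The error text matches the message the test above asserts on.
func checkFilename(name string) error {
	if hasParentSegment(name) {
		return errors.New("filename query parameter contain path traversal")
	}
	return nil
}

func main() {
	for _, name := range []string{"bar.txt", "../../tmp/bar.txt", "notes..txt"} {
		fmt.Printf("%q -> %v\n", name, checkFilename(name))
	}
}
```

Comparing whole segments rather than searching for the substring ".." keeps ordinary names such as notes..txt usable while still refusing traversal.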
source.go +3 −1 modified
@@ -58,7 +58,9 @@ func SourceDirSubdir(src string) (string, string) {
 //
 // The returned path is the full absolute path.
 func SubdirGlob(dst, subDir string) (string, error) {
-	matches, err := filepath.Glob(filepath.Join(dst, subDir))
+	pattern := filepath.Join(dst, subDir)
+
+	matches, err := filepath.Glob(pattern)
 	if err != nil {
 		return "", err
 	}
Vulnerability mechanics
Generated on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References (9)
- github.com/advisories/GHSA-fcgg-rvwg-jv58 (ghsa, ADVISORY)
- nvd.nist.gov/vuln/detail/CVE-2022-30321 (ghsa, ADVISORY)
- discuss.hashicorp.com (ghsa, x_refsource_MISC, WEB)
- discuss.hashicorp.com/t/hcsec-2022-13-multiple-vulnerabilities-in-go-getter-library/39930 (ghsa, x_refsource_MISC, WEB)
- github.com/hashicorp/go-getter/commit/38e97387488f5439616be60874979433a12edb48 (ghsa, WEB)
- github.com/hashicorp/go-getter/commit/a2ebce998f8d4105bd4b78d6c99a12803ad97a45 (ghsa, WEB)
- github.com/hashicorp/go-getter/pull/359 (ghsa, WEB)
- github.com/hashicorp/go-getter/pull/361 (ghsa, WEB)
- pkg.go.dev/vuln/GO-2022-0586 (ghsa, WEB)