diff --git a/cmd/age/age.go b/cmd/age/age.go index 03be199..fc2109d 100644 --- a/cmd/age/age.go +++ b/cmd/age/age.go @@ -12,7 +12,6 @@ import ( "io" "log" "os" - "path/filepath" "time" "github.com/FiloSottile/age/internal/age" @@ -109,19 +108,10 @@ func encrypt(in io.Reader, out io.Writer) { func decrypt(in io.Reader, out io.Writer) { var identities []age.Identity - // TODO: use the default location if no arguments are provided. + // TODO: use the default location if no arguments are provided: + // os.UserConfigDir()/age/keys.txt, ~/.ssh/id_rsa, ~/.ssh/id_ed25519 for _, name := range flag.Args() { - var ( - ids []age.Identity - err error - ) - - // TODO: smarter detection logic than looking for .ssh/* in the path. - if filepath.Base(filepath.Dir(name)) == ".ssh" { - ids, err = parseSSHIdentity(name) - } else { - ids, err = parseIdentitiesFile(name) - } + ids, err := parseIdentitiesFile(name) if err != nil { log.Fatalf("Error: %v", err) } diff --git a/cmd/age/parse.go b/cmd/age/parse.go index d1051fb..62557cd 100644 --- a/cmd/age/parse.go +++ b/cmd/age/parse.go @@ -8,6 +8,7 @@ package main import ( "bufio" + "bytes" "fmt" "io" "io/ioutil" @@ -28,6 +29,8 @@ func parseRecipient(arg string) (age.Recipient, error) { return nil, fmt.Errorf("unknown recipient type: %q", arg) } +const privateKeySizeLimit = 1 << 24 // 16 MiB + func parseIdentitiesFile(name string) ([]age.Identity, error) { f, err := os.Open(name) if err != nil { @@ -35,22 +38,37 @@ func parseIdentitiesFile(name string) ([]age.Identity, error) { } defer f.Close() + buf := &bytes.Buffer{} + limitF := io.LimitReader(f, privateKeySizeLimit) + if _, err := io.Copy(buf, limitF); err != nil { + return nil, fmt.Errorf("failed to read %q: %v", name, err) + } + var ids []age.Identity - scanner := bufio.NewScanner(f) + var ageParsingError error + scanner := bufio.NewScanner(bytes.NewReader(buf.Bytes())) for scanner.Scan() { line := scanner.Text() if strings.HasPrefix(line, "#") || line == "" { continue } - i, err := age.ParseX25519Identity(line) - if err != nil { - return nil, fmt.Errorf("malformed secret keys file %q: %v", name, err) + if strings.HasPrefix(line, "-----BEGIN") { + return parseSSHIdentity(name, bytes.NewReader(buf.Bytes())) + } + if ageParsingError == nil { + i, err := age.ParseX25519Identity(line) + if err != nil { + ageParsingError = fmt.Errorf("malformed secret keys file %q: %v", name, err) + } + ids = append(ids, i) } - ids = append(ids, i) } if err := scanner.Err(); err != nil { return nil, fmt.Errorf("failed to read %q: %v", name, err) } + if ageParsingError != nil { + return nil, ageParsingError + } if len(ids) == 0 { return nil, fmt.Errorf("no secret keys found in %q", name) @@ -58,17 +76,8 @@ func parseIdentitiesFile(name string) ([]age.Identity, error) { return ids, nil } -func parseSSHIdentity(name string) ([]age.Identity, error) { - f, err := os.Open(name) - if err != nil { - return nil, fmt.Errorf("failed to open file: %v", err) - } - defer f.Close() - - // Don't allow unbounded reads. - // TODO: support for multiple keys in the same stream, such as user.keys - // on GitHub. - pemBytes, err := ioutil.ReadAll(io.LimitReader(f, 1<<20)) +func parseSSHIdentity(name string, f io.Reader) ([]age.Identity, error) { + pemBytes, err := ioutil.ReadAll(f) if err != nil { return nil, fmt.Errorf("failed to read %q: %v", name, err) }