diff --git a/internal/age/age_test.go b/internal/age/age_test.go index 497159c..0697a05 100644 --- a/internal/age/age_test.go +++ b/internal/age/age_test.go @@ -76,7 +76,7 @@ func TestEncryptDecryptScrypt(t *testing.T) { if err != nil { t.Fatal(err) } - r.SetWorkFactor(1 << 15) + r.SetWorkFactor(15) buf := &bytes.Buffer{} w, err := age.Encrypt(buf, r) if err != nil { diff --git a/internal/age/recipients_test.go b/internal/age/recipients_test.go index 97f887e..e37d2ae 100644 --- a/internal/age/recipients_test.go +++ b/internal/age/recipients_test.go @@ -64,7 +64,7 @@ func TestScryptRoundTrip(t *testing.T) { if err != nil { t.Fatal(err) } - r.SetWorkFactor(1 << 15) + r.SetWorkFactor(15) i, err := age.NewScryptIdentity(password) if err != nil { t.Fatal(err) diff --git a/internal/age/scrypt.go b/internal/age/scrypt.go index 277aa4f..0bacdef 100644 --- a/internal/age/scrypt.go +++ b/internal/age/scrypt.go @@ -31,15 +31,20 @@ func NewScryptRecipient(password string) (*ScryptRecipient, error) { return nil, errors.New("empty scrypt password") } r := &ScryptRecipient{ - password: []byte(password), - workFactor: 1 << 18, // 1s on a modern machine + password: []byte(password), + // TODO: automatically scale this to 1s (with a min) in the CLI. + workFactor: 18, // 1s on a modern machine } return r, nil } -func (r *ScryptRecipient) SetWorkFactor(N int) { - // TODO: automatically scale this to 1s (with a min) in the CLI. - r.workFactor = N +// SetWorkFactor sets the scrypt work factor to 2^logN. +// It must be called before Wrap. +func (r *ScryptRecipient) SetWorkFactor(logN int) { + if logN > 30 || logN < 1 { + panic("age: SetWorkFactor called with illegal value") + } + r.workFactor = logN } func (r *ScryptRecipient) Wrap(fileKey []byte) (*format.Recipient, error) { @@ -48,13 +53,13 @@ func (r *ScryptRecipient) Wrap(fileKey []byte) (*format.Recipient, error) { return nil, err } - N := r.workFactor + logN := r.workFactor l := &format.Recipient{ Type: "scrypt", - Args: []string{format.EncodeToString(salt), strconv.Itoa(N)}, + Args: []string{format.EncodeToString(salt), strconv.Itoa(logN)}, } - k, err := scrypt.Key(r.password, salt, N, 8, 1, chacha20poly1305.KeySize) + k, err := scrypt.Key(r.password, salt, 1< 30 || logN < 1 { + panic("age: SetMaxWorkFactor called with illegal value") + } + i.maxWorkFactor = logN } func (i *ScryptIdentity) Unwrap(block *format.Recipient) ([]byte, error) { @@ -106,19 +116,22 @@ func (i *ScryptIdentity) Unwrap(block *format.Recipient) ([]byte, error) { if len(salt) != 16 { return nil, errors.New("invalid scrypt recipient block") } - N, err := strconv.Atoi(block.Args[1]) + logN, err := strconv.Atoi(block.Args[1]) if err != nil { return nil, fmt.Errorf("failed to parse scrypt work factor: %v", err) } - if N > i.maxWorkFactor { - return nil, fmt.Errorf("scrypt work factor too large: %v", N) + if logN > i.maxWorkFactor { + return nil, fmt.Errorf("scrypt work factor too large: %v", logN) + } + if logN <= 0 { + return nil, fmt.Errorf("invalid scrypt work factor: %v", logN) } wrappedKey, err := format.DecodeString(string(block.Body)) if err != nil { return nil, fmt.Errorf("failed to parse scrypt recipient: %v", err) } - k, err := scrypt.Key(i.password, salt, N, 8, 1, chacha20poly1305.KeySize) + k, err := scrypt.Key(i.password, salt, 1<