diff --git a/cmd/stbak/cmd/archive.go b/cmd/stbak/cmd/archive.go index ef4c326..84bd97e 100644 --- a/cmd/stbak/cmd/archive.go +++ b/cmd/stbak/cmd/archive.go @@ -59,6 +59,8 @@ var ( errKeyNotAccessible = errors.New("key not found or accessible") errMissingTarHeader = errors.New("tar header is missing") + + errRecipientUnparsable = errors.New("recipient could not be parsed") ) var archiveCmd = &cobra.Command{ @@ -103,6 +105,11 @@ var archiveCmd = &cobra.Command{ return err } + recipient, err := parseRecipient(viper.GetString(encryptionFlag), pubkey) + if err != nil { + return err + } + hdrs, err := archive( viper.GetString(tapeFlag), viper.GetInt(recordSizeFlag), @@ -111,7 +118,7 @@ var archiveCmd = &cobra.Command{ viper.GetString(compressionFlag), viper.GetString(compressionLevelFlag), viper.GetString(encryptionFlag), - pubkey, + recipient, ) if err != nil { return err @@ -149,7 +156,7 @@ func archive( compressionFormat string, compressionLevel string, encryptionFormat string, - pubkey []byte, + recipient interface{}, ) ([]*tar.Header, error) { dirty := false tw, isRegular, cleanup, err := openTapeWriter(tape) @@ -241,7 +248,7 @@ func archive( Writer: io.Discard, } - encryptor, err := encrypt(fileSizeCounter, encryptionFormat, pubkey) + encryptor, err := encrypt(fileSizeCounter, encryptionFormat, recipient) if err != nil { return err } @@ -316,7 +323,7 @@ func archive( hdrToAppend := *hdr headers = append(headers, &hdrToAppend) - if err := encryptHeader(hdr, encryptionFormat, pubkey); err != nil { + if err := encryptHeader(hdr, encryptionFormat, recipient); err != nil { return err } @@ -329,7 +336,7 @@ func archive( } // Compress and write the file - encryptor, err := encrypt(tw, encryptionFormat, pubkey) + encryptor, err := encrypt(tw, encryptionFormat, recipient) if err != nil { return err } @@ -422,7 +429,7 @@ func checkCompressionLevel(compressionLevel string) error { func encryptHeader( hdr *tar.Header, encryptionFormat string, - pubkey []byte, + recipient interface{}, ) error { if encryptionFormat == encryptionFormatNoneKey { return nil @@ -439,7 +446,7 @@ func encryptHeader( return err } - newHdr.PAXRecords[pax.STFSEmbeddedHeader], err = encryptString(string(wrappedHeader), encryptionFormat, pubkey) + newHdr.PAXRecords[pax.STFSEmbeddedHeader], err = encryptString(string(wrappedHeader), encryptionFormat, recipient) if err != nil { return err } @@ -483,23 +490,39 @@ func addSuffix(name string, compressionFormat string, encryptionFormat string) ( return name, nil } +func parseRecipient( + encryptionFormat string, + pubkey []byte, +) (interface{}, error) { + switch encryptionFormat { + case encryptionFormatAgeKey: + return age.ParseX25519Recipient(string(pubkey)) + case encryptionFormatPGPKey: + return openpgp.ReadKeyRing(bytes.NewBuffer(pubkey)) + case encryptionFormatNoneKey: + return pubkey, nil + default: + return nil, errUnsupportedEncryptionFormat + } +} + func encrypt( dst io.Writer, encryptionFormat string, - pubkey []byte, + recipient interface{}, ) (io.WriteCloser, error) { switch encryptionFormat { case encryptionFormatAgeKey: - recipient, err := age.ParseX25519Recipient(string(pubkey)) - if err != nil { - return nil, err + recipient, ok := recipient.(*age.X25519Recipient) + if !ok { + return nil, errRecipientUnparsable } return age.Encrypt(dst, recipient) case encryptionFormatPGPKey: - recipient, err := openpgp.ReadKeyRing(bytes.NewBuffer(pubkey)) - if err != nil { - return nil, err + recipient, ok := recipient.(openpgp.EntityList) + if !ok { + return nil, errRecipientUnparsable } return openpgp.Encrypt(dst, recipient, nil, nil, nil) @@ -513,13 +536,13 @@ func encrypt( func encryptString( src string, encryptionFormat string, - pubkey []byte, + recipient interface{}, ) (string, error) { switch encryptionFormat { case encryptionFormatAgeKey: - recipient, err := age.ParseX25519Recipient(string(pubkey)) - if err != nil { - return "", err + recipient, ok := recipient.(*age.X25519Recipient) + if !ok { + return "", errRecipientUnparsable } out := &bytes.Buffer{} @@ -538,9 +561,9 @@ func encryptString( return base64.StdEncoding.EncodeToString(out.Bytes()), nil case encryptionFormatPGPKey: - recipient, err := openpgp.ReadKeyRing(bytes.NewBuffer(pubkey)) - if err != nil { - return "", err + recipient, ok := recipient.(openpgp.EntityList) + if !ok { + return "", errRecipientUnparsable } out := &bytes.Buffer{} diff --git a/cmd/stbak/cmd/delete.go b/cmd/stbak/cmd/delete.go index 1997305..831563b 100644 --- a/cmd/stbak/cmd/delete.go +++ b/cmd/stbak/cmd/delete.go @@ -47,12 +47,17 @@ var deleteCmd = &cobra.Command{ return err } + recipient, err := parseRecipient(viper.GetString(encryptionFlag), pubkey) + if err != nil { + return err + } + return delete( viper.GetString(tapeFlag), viper.GetString(metadataFlag), viper.GetString(nameFlag), viper.GetString(encryptionFlag), - pubkey, + recipient, ) }, } @@ -62,7 +67,7 @@ func delete( metadata string, name string, encryptionFormat string, - pubkey []byte, + recipient interface{}, ) error { dirty := false tw, _, cleanup, err := openTapeWriter(tape) @@ -113,7 +118,7 @@ func delete( hdr.PAXRecords[pax.STFSRecordVersion] = pax.STFSRecordVersion1 hdr.PAXRecords[pax.STFSRecordAction] = pax.STFSRecordActionDelete - if err := encryptHeader(hdr, encryptionFormat, pubkey); err != nil { + if err := encryptHeader(hdr, encryptionFormat, recipient); err != nil { return err } diff --git a/cmd/stbak/cmd/move.go b/cmd/stbak/cmd/move.go index caa4bdd..f22282d 100644 --- a/cmd/stbak/cmd/move.go +++ b/cmd/stbak/cmd/move.go @@ -40,13 +40,18 @@ var moveCmd = &cobra.Command{ return err } + recipient, err := parseRecipient(viper.GetString(encryptionFlag), pubkey) + if err != nil { + return err + } + return move( viper.GetString(tapeFlag), viper.GetString(metadataFlag), viper.GetString(srcFlag), viper.GetString(dstFlag), viper.GetString(encryptionFlag), - pubkey, + recipient, ) }, } @@ -57,7 +62,7 @@ func move( src string, dst string, encryptionFormat string, - pubkey []byte, + recipient interface{}, ) error { dirty := false tw, _, cleanup, err := openTapeWriter(tape) @@ -110,7 +115,7 @@ func move( hdr.PAXRecords[pax.STFSRecordAction] = pax.STFSRecordActionUpdate hdr.PAXRecords[pax.STFSRecordReplacesName] = dbhdr.Name - if err := encryptHeader(hdr, encryptionFormat, pubkey); err != nil { + if err := encryptHeader(hdr, encryptionFormat, recipient); err != nil { return err } diff --git a/cmd/stbak/cmd/update.go b/cmd/stbak/cmd/update.go index 27ac10a..cba8554 100644 --- a/cmd/stbak/cmd/update.go +++ b/cmd/stbak/cmd/update.go @@ -60,6 +60,11 @@ var updateCmd = &cobra.Command{ return err } + recipient, err := parseRecipient(viper.GetString(encryptionFlag), pubkey) + if err != nil { + return err + } + hdrs, err := update( viper.GetString(tapeFlag), viper.GetInt(recordSizeFlag), @@ -68,7 +73,7 @@ var updateCmd = &cobra.Command{ viper.GetString(compressionFlag), viper.GetString(compressionLevelFlag), viper.GetString(encryptionFlag), - pubkey, + recipient, ) if err != nil { return err @@ -106,7 +111,7 @@ func update( compressionFormat string, compressionLevel string, encryptionFormat string, - pubkey []byte, + recipient interface{}, ) ([]*tar.Header, error) { dirty := false tw, isRegular, cleanup, err := openTapeWriter(tape) @@ -152,7 +157,7 @@ func update( Writer: io.Discard, } - encryptor, err := encrypt(fileSizeCounter, encryptionFormat, pubkey) + encryptor, err := encrypt(fileSizeCounter, encryptionFormat, recipient) if err != nil { return err } @@ -230,7 +235,7 @@ func update( hdrToAppend := *hdr headers = append(headers, &hdrToAppend) - if err := encryptHeader(hdr, encryptionFormat, pubkey); err != nil { + if err := encryptHeader(hdr, encryptionFormat, recipient); err != nil { return err } @@ -243,7 +248,7 @@ func update( } // Compress and write the file - encryptor, err := encrypt(tw, encryptionFormat, pubkey) + encryptor, err := encrypt(tw, encryptionFormat, recipient) if err != nil { return err } @@ -300,7 +305,7 @@ func update( hdrToAppend := *hdr headers = append(headers, &hdrToAppend) - if err := encryptHeader(hdr, encryptionFormat, pubkey); err != nil { + if err := encryptHeader(hdr, encryptionFormat, recipient); err != nil { return err }