From cc3c18a6a7d89ca1ced6053607a2b480d9cded21 Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Wed, 23 Dec 2020 13:32:30 +0100 Subject: [PATCH] p2p: add NodeID.Validate(), replaces validateID() --- p2p/key.go | 30 +++++++++++++++++++++++++++++- p2p/netaddress.go | 21 +++------------------ p2p/switch.go | 4 ++-- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/p2p/key.go b/p2p/key.go index d970794d4..3ebae995d 100644 --- a/p2p/key.go +++ b/p2p/key.go @@ -2,6 +2,8 @@ package p2p import ( "encoding/hex" + "errors" + "fmt" "io/ioutil" "github.com/tendermint/tendermint/crypto" @@ -11,12 +13,38 @@ import ( ) // NodeID is a hex-encoded crypto.Address. +// FIXME: We should either ensure this is always lowercased, or add an Equal() +// for comparison that decodes to the binary byte slice first. type NodeID string // NodeIDByteLength is the length of a crypto.Address. Currently only 20. -// TODO: support other length addresses? +// FIXME: support other length addresses? const NodeIDByteLength = crypto.AddressSize +// Bytes converts the node ID to it's binary byte representation. +func (id NodeID) Bytes() ([]byte, error) { + bz, err := hex.DecodeString(string(id)) + if err != nil { + return nil, fmt.Errorf("invalid node ID encoding: %w", err) + } + return bz, nil +} + +// Validate validates the NodeID. +func (id NodeID) Validate() error { + if len(id) == 0 { + return errors.New("no ID") + } + bz, err := id.Bytes() + if err != nil { + return err + } + if len(bz) != NodeIDByteLength { + return fmt.Errorf("invalid ID length - got %d, expected %d", len(bz), NodeIDByteLength) + } + return nil +} + //------------------------------------------------------------------------------ // Persistent peer ID // TODO: encrypt on disk diff --git a/p2p/netaddress.go b/p2p/netaddress.go index 5da34507c..b7c860ec3 100644 --- a/p2p/netaddress.go +++ b/p2p/netaddress.go @@ -5,7 +5,6 @@ package p2p import ( - "encoding/hex" "errors" "flag" "fmt" @@ -52,7 +51,7 @@ func NewNetAddress(id NodeID, addr net.Addr) *NetAddress { } } - if err := validateID(id); err != nil { + if err := id.Validate(); err != nil { panic(fmt.Sprintf("Invalid ID %v: %v (addr: %v)", id, err, addr)) } @@ -75,7 +74,7 @@ func NewNetAddressString(addr string) (*NetAddress, error) { } // get ID - if err := validateID(NodeID(spl[0])); err != nil { + if err := NodeID(spl[0]).Validate(); err != nil { return nil, ErrNetAddressInvalid{addrWithoutProtocol, err} } var id NodeID @@ -262,7 +261,7 @@ func (na *NetAddress) Routable() bool { // For IPv4 these are either a 0 or all bits set address. For IPv6 a zero // address or one that matches the RFC3849 documentation address format. func (na *NetAddress) Valid() error { - if err := validateID(na.ID); err != nil { + if err := na.ID.Validate(); err != nil { return fmt.Errorf("invalid ID: %w", err) } @@ -414,17 +413,3 @@ func removeProtocolIfDefined(addr string) string { return addr } - -func validateID(id NodeID) error { - if len(id) == 0 { - return errors.New("no ID") - } - idBytes, err := hex.DecodeString(string(id)) - if err != nil { - return err - } - if len(idBytes) != NodeIDByteLength { - return fmt.Errorf("invalid hex length - got %d, expected %d", len(idBytes), NodeIDByteLength) - } - return nil -} diff --git a/p2p/switch.go b/p2p/switch.go index 54841f9c0..d4c76cd38 100644 --- a/p2p/switch.go +++ b/p2p/switch.go @@ -592,7 +592,7 @@ func (sw *Switch) AddPersistentPeers(addrs []string) error { func (sw *Switch) AddUnconditionalPeerIDs(ids []string) error { sw.Logger.Info("Adding unconditional peer ids", "ids", ids) for i, id := range ids { - err := validateID(NodeID(id)) + err := NodeID(id).Validate() if err != nil { return fmt.Errorf("wrong ID #%d: %w", i, err) } @@ -604,7 +604,7 @@ func (sw *Switch) AddUnconditionalPeerIDs(ids []string) error { func (sw *Switch) AddPrivatePeerIDs(ids []string) error { validIDs := make([]string, 0, len(ids)) for i, id := range ids { - err := validateID(NodeID(id)) + err := NodeID(id).Validate() if err != nil { return fmt.Errorf("wrong ID #%d: %w", i, err) }