312 lines
8.7 KiB
Go
312 lines
8.7 KiB
Go
package snowflake
|
|
|
|
import (
|
|
"errors"
|
|
"io/ioutil"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Bit widths of the fields packed into a SnowFlake ID. The three widths
// always add up to 63, leaving the most significant bit of the uint64 unused.
const (
	// BitLenTime is the width of the elapsed-time field.
	BitLenTime = 39

	// BitLenSequence is the width of the per-time-slot sequence number.
	BitLenSequence = 8

	// BitLenMachineID is the width of the machine id field; it takes whatever
	// is left of the 63 bits (16 bits with the values above).
	BitLenMachineID = 63 - BitLenTime - BitLenSequence
)
|
|
|
|
// Settings carries the options NewSnowFlake uses to build a SnowFlake.
// Every field is optional; the zero value selects the defaults described on
// each field.
type Settings struct {
	// StartTime is the epoch from which SnowFlake time is measured.
	// The zero value selects "2014-09-01 00:00:00 +0000 UTC".
	// A StartTime later than the current time makes NewSnowFlake return nil.
	StartTime time.Time

	// MachineID returns the unique id of this SnowFlake instance.
	// When nil, the default is used: the lower 16 bits of the private IP
	// address. If the function returns an error, NewSnowFlake returns nil.
	MachineID func() (uint16, error)

	// CheckMachineID validates the uniqueness of the machine id.
	// When nil, no validation is done. If it returns false, NewSnowFlake
	// returns nil.
	CheckMachineID func(uint16) bool
}
|
|
|
|
// SnowFlake is a distributed unique ID generator. Its ID-producing methods
// lock mutex before touching the remaining fields.
type SnowFlake struct {
	mutex      *sync.Mutex // guards the fields below
	startTime  int64       // epoch, in 10 msec SnowFlake time units
	recentTime int64       // elapsed time (in SnowFlake units) of the slot most recently used
	sequence   uint16      // sequence number within the recentTime slot
	machineID  uint16      // id of this generator instance
}
|
|
|
|
// IDList is the result returned by GenerateID: a batch of generated IDs
// together with the machine id of the generator that produced them.
type IDList struct {
	// List holds the generated IDs; serialised as "id_list".
	List []uint64 `json:"id_list"`

	// MachineId is the generator's machine id; serialised as "machine_id".
	MachineId uint16 `json:"machine_id"`
}
|
|
|
|
func initSnowFlake(st *Settings) *SnowFlake {
|
|
if st == nil {
|
|
st = &Settings{}
|
|
// Default start-time is the Jan 01, 2014 and the ID generator should work for 174 yeard from then
|
|
st.StartTime = time.Date(2014, 1, 1, 0, 0, 0, 0, time.UTC)
|
|
}
|
|
sf := NewSnowFlake(*st)
|
|
if sf == nil {
|
|
log.Println("snowFlake not created")
|
|
}
|
|
return sf
|
|
|
|
}
|
|
|
|
func GenerateID(settings *Settings) (*IDList, error) {
|
|
snowFlake := initSnowFlake(settings)
|
|
ids, err := snowFlake.NextIDs()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
idList := &IDList{List: ids, MachineId: snowFlake.machineID}
|
|
return idList, nil
|
|
}
|
|
|
|
// NewSnowFlake returns a new SnowFlake configured with the given Settings.
|
|
// NewSnowFlake returns nil in the following cases:
|
|
// - Settings.StartTime is ahead of the current time.
|
|
// - Settings.MachineID returns an error.
|
|
// - Settings.CheckMachineID returns false.
|
|
func NewSnowFlake(st Settings) *SnowFlake {
|
|
sf := new(SnowFlake)
|
|
sf.mutex = new(sync.Mutex)
|
|
// why is it set to max value ?
|
|
sf.sequence = uint16(1<<BitLenSequence - 1)
|
|
if st.StartTime.After(time.Now()) {
|
|
return nil
|
|
}
|
|
if st.StartTime.IsZero() {
|
|
sf.startTime = toSnowFlakeTime(time.Date(2014, 9, 1, 0, 0, 0, 0, time.UTC))
|
|
} else {
|
|
sf.startTime = toSnowFlakeTime(st.StartTime)
|
|
}
|
|
|
|
var err error
|
|
if st.MachineID == nil {
|
|
sf.machineID, err = lower16BitPrivateIP()
|
|
} else {
|
|
sf.machineID, err = st.MachineID()
|
|
}
|
|
if err != nil || (st.CheckMachineID != nil && !st.CheckMachineID(sf.machineID)) {
|
|
return nil
|
|
}
|
|
|
|
return sf
|
|
}
|
|
|
|
// elapsedTime, machine-id and sequence
|
|
func (sf *SnowFlake) NextIDs() ([]uint64, error) {
|
|
sf.mutex.Lock()
|
|
defer sf.mutex.Unlock()
|
|
//sf.elapsedTime = currentElapsedTime(sf.startTime)
|
|
sf.validateTime()
|
|
sf.sequence = 0
|
|
const maxSequence = uint16(1<<BitLenSequence - 1)
|
|
idList := make([]uint64, 0, maxSequence+1)
|
|
for sf.sequence <= maxSequence {
|
|
id, err := sf.toID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
idList = append(idList, id)
|
|
sf.sequence = (sf.sequence + 1)
|
|
}
|
|
return idList, nil
|
|
}
|
|
|
|
// NextID generates a next unique ID.
|
|
// After the SnowFlake time overflows, NextID returns an error.
|
|
// ONLY USED in Testing ??
|
|
func (sf *SnowFlake) NextID() (uint64, error) {
|
|
sf.mutex.Lock()
|
|
defer sf.mutex.Unlock()
|
|
sf.validateTime()
|
|
return sf.toID()
|
|
}
|
|
|
|
// checks if the current time (time elapsed since the start of this snowflake instances start) is less than the most recent
|
|
// snowflake time. If the recentTime is less than the current time -- meaning the id has not been generated in a while --
|
|
// update the recent time to current time and set the sequence to 0. The sequence is set to zero since new ids will be generated in this time.
|
|
// if recentTime time is equal to or greater than current time -- find the number of ids that have already been generated by updating the sequence.
|
|
// if the ids is 0 -- meaning all ids at the current time has been generated -- sleep for the time until the next time slot is available
|
|
func (sf *SnowFlake) validateTime() {
|
|
current := currentElapsedTime(sf.startTime)
|
|
if sf.recentTime < current {
|
|
// this is only executed the first time
|
|
// this will be executed if the elapsedTime is not set correctly to current time
|
|
sf.recentTime = current
|
|
sf.sequence = 0
|
|
} else if sf.recentTime == current {
|
|
const maskSequence = uint16(1<<BitLenSequence - 1)
|
|
sf.sequence = (sf.sequence + 1) & maskSequence
|
|
if sf.sequence == 0 {
|
|
sf.recentTime++
|
|
overtime := sf.recentTime - current
|
|
time.Sleep(sleepTime((overtime)))
|
|
}
|
|
} else {
|
|
log.Fatal("recent time can never be greater than current time")
|
|
}
|
|
}
|
|
func (sf *SnowFlake) NextIDRange() (uint64, uint64, error) {
|
|
sf.mutex.Lock()
|
|
defer sf.mutex.Unlock()
|
|
sf.validateTime()
|
|
lower, err := sf.toID()
|
|
if err != nil {
|
|
return 0, 0, err
|
|
}
|
|
sf.sequence = uint16(1<<BitLenSequence - 1)
|
|
upper, err := sf.toID()
|
|
if err != nil {
|
|
return 0, 0, err
|
|
}
|
|
return lower, upper, nil
|
|
|
|
}
|
|
|
|
// snowFlakeTimeUnitScaleFactor is the number of nanoseconds in one SnowFlake
// time unit. Dividing a nanosecond timestamp by it converts the timestamp to
// units of 10 msec.
const snowFlakeTimeUnitScaleFactor = 1e7 // 1e7 nsec = 10 msec
|
|
|
|
func toSnowFlakeTime(t time.Time) int64 {
|
|
return t.UTC().UnixNano() / snowFlakeTimeUnitScaleFactor
|
|
}
|
|
|
|
func currentElapsedTime(startTime int64) int64 {
|
|
return toSnowFlakeTime(time.Now()) - startTime
|
|
}
|
|
|
|
func sleepTime(overtime int64) time.Duration {
|
|
return time.Duration(overtime)*10*time.Millisecond -
|
|
time.Duration(time.Now().UTC().UnixNano()%snowFlakeTimeUnitScaleFactor)*time.Nanosecond
|
|
}
|
|
|
|
func (sf *SnowFlake) toID() (uint64, error) {
|
|
if sf.recentTime >= 1<<BitLenTime {
|
|
return 0, errors.New("over the time limit")
|
|
}
|
|
// Time-Sequence-MachineID
|
|
//return uint64(sf.elapsedTime)<<(BitLenSequence+BitLenMachineID) |
|
|
// uint64(sf.sequence)<<BitLenMachineID |
|
|
// uint64(sf.machineID), nil
|
|
|
|
// Time-MachineID-Sequence
|
|
return uint64(sf.recentTime)<<(BitLenSequence+BitLenMachineID) |
|
|
uint64(sf.machineID)<<BitLenSequence |
|
|
uint64(sf.sequence), nil
|
|
}
|
|
|
|
func privateIPv4() (net.IP, error) {
|
|
as, err := net.InterfaceAddrs()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, a := range as {
|
|
ipnet, ok := a.(*net.IPNet)
|
|
if !ok || ipnet.IP.IsLoopback() {
|
|
continue
|
|
}
|
|
|
|
ip := ipnet.IP.To4()
|
|
if isPrivateIPv4(ip) {
|
|
return ip, nil
|
|
}
|
|
}
|
|
return nil, errors.New("no private ip address")
|
|
}
|
|
|
|
// amazonEC2PrivateIPv4 asks the AWS EC2 instance metadata service for the
// instance's local IPv4 address. See
// http://docs.aws.amazon.com/en_us/AWSEC2/latest/UserGuide/ec2-instance-metadata.html
//
// The request uses a 10 msec timeout, so it fails quickly when the metadata
// endpoint is not reachable. An error is returned when the request fails,
// the service does not answer 200 OK, or the body is not an IPv4 address;
// the returned IP is never nil when the error is nil.
func amazonEC2PrivateIPv4() (net.IP, error) {
	client := http.Client{Timeout: 10 * time.Millisecond}

	res, err := client.Get("http://169.254.169.254/latest/meta-data/local-ipv4")
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	// Reject error pages instead of trying to parse them as an address.
	if res.StatusCode != http.StatusOK {
		return nil, errors.New("unexpected metadata response status: " + res.Status)
	}

	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}

	ip := net.ParseIP(string(body))
	if ip == nil {
		return nil, errors.New("invalid ip address")
	}
	// To4 returns nil for addresses that are not IPv4; never hand a nil IP
	// back together with a nil error.
	ip4 := ip.To4()
	if ip4 == nil {
		return nil, errors.New("not an ipv4 address")
	}
	return ip4, nil
}
|
|
|
|
// k8sPodIPFromEnvVariable reads the pod IP from the UNIQUE_ID_POD_IP
// environment variable and returns it in 4-byte form.
//
// It returns an error when the variable is unset or empty, when its value
// does not parse as an IP address, or when the address is not IPv4. The
// returned IP is never nil when the error is nil.
func k8sPodIPFromEnvVariable() (net.IP, error) {
	const podIPEnvKey = "UNIQUE_ID_POD_IP"

	raw := os.Getenv(podIPEnvKey)
	if raw == "" {
		return nil, errors.New("Env Variable Not Present")
	}

	ip := net.ParseIP(raw)
	if ip == nil {
		return nil, errors.New("invalid ip address")
	}
	// To4 returns nil for IPv6 addresses; returning that with a nil error
	// made callers index into a nil slice.
	ip4 := ip.To4()
	if ip4 == nil {
		return nil, errors.New("not an ipv4 address")
	}
	return ip4, nil
}
|
|
|
|
// isPrivateIPv4 reports whether ip is an IPv4 address inside one of the
// private ranges 10.0.0.0/8, 172.16.0.0/12 or 192.168.0.0/16.
//
// ip is normalised with To4 first, so nil, empty or malformed slices and
// genuine IPv6 addresses yield false instead of panicking or being judged
// by their first bytes; the 16-byte form of an IPv4 address is accepted.
func isPrivateIPv4(ip net.IP) bool {
	ip4 := ip.To4()
	if ip4 == nil {
		return false
	}

	switch ip4[0] {
	case 10:
		return true
	case 172:
		return ip4[1] >= 16 && ip4[1] < 32
	case 192:
		return ip4[1] == 168
	}
	return false
}
|
|
|
|
func lower16BitPrivateIP() (uint16, error) {
|
|
ip, err := k8sPodIPFromEnvVariable()
|
|
if err != nil {
|
|
ip, err = amazonEC2PrivateIPv4()
|
|
}
|
|
if err != nil {
|
|
ip, err = privateIPv4()
|
|
}
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return uint16(ip[2])<<8 + uint16(ip[3]), nil
|
|
}
|
|
|
|
// Decompose returns a set of SnowFlake ID parts.
|
|
// Time-MachineID-Sequence
|
|
func decompose(id uint64) map[string]uint64 {
|
|
// const maskSequence = uint64((1<<BitLenSequence - 1) << BitLenMachineID)
|
|
const maskSequence = uint64((1<<BitLenSequence - 1))
|
|
//const maskMachineID = uint64(1<<BitLenMachineID - 1)
|
|
const maskMachineID = uint64((1<<BitLenMachineID - 1) << BitLenSequence)
|
|
msb := id >> 63
|
|
time := id >> (BitLenSequence + BitLenMachineID)
|
|
//sequence := id & maskSequence >> BitLenMachineID
|
|
sequence := id & maskSequence
|
|
|
|
machineID := id & maskMachineID >> BitLenSequence
|
|
|
|
return map[string]uint64{
|
|
"id": id,
|
|
"msb": msb,
|
|
"time": time,
|
|
"sequence": sequence,
|
|
"machine-id": machineID,
|
|
}
|
|
}
|