🚀 Fixed bug for concurrent security in elapsed time

This commit is contained in:
Godruoyi
2021-04-15 15:42:13 +08:00
parent aeeee0fbb9
commit 96659ed81e
2 changed files with 54 additions and 5 deletions

View File

@@ -61,12 +61,12 @@ func NextID() (uint64, error) {
}
}
df := int(elapsedTime(startTime))
df := int(elapsedTime(c, startTime))
if df < 0 || df > MaxTimestamp {
return 0, errors.New("The maximum life cycle of the snowflake algorithm is 2^41-1(millis), please check starttime")
}
id := uint64(df<<timestampMoveLength | machineID<<machineIDMoveLength | int(seq))
id := uint64((df << timestampMoveLength) | (machineID << machineIDMoveLength) | int(seq))
return id, nil
}
@@ -89,7 +89,7 @@ func SetStartTime(s time.Time) {
}
// Because s must after now, so the `df` not < 0.
df := elapsedTime(s)
df := elapsedTime(currentMillis(), s)
if df > MaxTimestamp {
panic("The maximum life cycle of the snowflake algorithm is 69 years")
}
@@ -163,8 +163,8 @@ func callSequenceResolver() SequenceResolver {
return resolver
}
func elapsedTime(s time.Time) int64 {
return currentMillis() - s.UTC().UnixNano()/1e6
func elapsedTime(nowms int64, s time.Time) int64 {
return nowms - s.UTC().UnixNano()/1e6
}
// currentMillis get current millisecond.

View File

@@ -2,6 +2,7 @@ package snowflake_test
import (
"errors"
"sync"
"testing"
"time"
@@ -14,6 +15,54 @@ func TestID(t *testing.T) {
if id <= 0 {
t.Error("The snowflake should't < 0.")
}
id2 := 1644633515267981312
df := 392111185853
if id2 != ((df << (snowflake.MachineIDLength + snowflake.SequenceLength)) | 0 | 0) {
t.Error("Create snowflake should be equal 1644633515267981312")
}
mp := make(map[uint64]bool)
for i := 0; i < 100000; i++ {
id, e := snowflake.NextID()
if e != nil {
t.Error(e)
continue
}
if _, ok := mp[id]; ok {
t.Error("ID should't repeat", id)
break
}
mp[id] = true
}
}
func TestID_bitch(t *testing.T) {
le := 100000
ch := make(chan uint64, le)
var wg sync.WaitGroup
for i := 0; i < le; i++ {
wg.Add(1)
go func() {
defer wg.Done()
id := snowflake.ID()
ch <- id
}()
}
wg.Wait()
close(ch)
mp := make(map[uint64]bool)
for id := range ch {
if _, ok := mp[id]; ok {
t.Error("It should not be repeated")
break
}
mp[id] = true
}
if len(mp) != le {
t.Error("map length should be equal", le)
}
}
func TestSetStartTime(t *testing.T) {