From 96659ed81ec1b158904eff9137156d466651e97c Mon Sep 17 00:00:00 2001 From: Godruoyi Date: Thu, 15 Apr 2021 15:42:13 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20Fixed=20bug=20for=20concurrent?= =?UTF-8?q?=20security=20in=20elapsed=20time?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- snowflake.go | 10 +++++----- snowflake_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/snowflake.go b/snowflake.go index 7ff668d..cab2d38 100644 --- a/snowflake.go +++ b/snowflake.go @@ -61,12 +61,12 @@ func NextID() (uint64, error) { } } - df := int(elapsedTime(startTime)) + df := int(elapsedTime(c, startTime)) if df < 0 || df > MaxTimestamp { return 0, errors.New("The maximum life cycle of the snowflake algorithm is 2^41-1(millis), please check starttime") } - id := uint64(df< MaxTimestamp { panic("The maximum life cycle of the snowflake algorithm is 69 years") } @@ -163,8 +163,8 @@ func callSequenceResolver() SequenceResolver { return resolver } -func elapsedTime(s time.Time) int64 { - return currentMillis() - s.UTC().UnixNano()/1e6 +func elapsedTime(nowms int64, s time.Time) int64 { + return nowms - s.UTC().UnixNano()/1e6 } // currentMillis get current millisecond. diff --git a/snowflake_test.go b/snowflake_test.go index 0870c18..e3693f1 100644 --- a/snowflake_test.go +++ b/snowflake_test.go @@ -2,6 +2,7 @@ package snowflake_test import ( "errors" + "sync" "testing" "time" @@ -14,6 +15,54 @@ func TestID(t *testing.T) { if id <= 0 { t.Error("The snowflake should't < 0.") } + + id2 := 1644633515267981312 + df := 392111185853 + if id2 != ((df << (snowflake.MachineIDLength + snowflake.SequenceLength)) | 0 | 0) { + t.Error("Create snowflake should be equal 1644633515267981312") + } + + mp := make(map[uint64]bool) + for i := 0; i < 100000; i++ { + id, e := snowflake.NextID() + if e != nil { + t.Error(e) + continue + } + if _, ok := mp[id]; ok { + t.Error("ID should't repeat", id) + break + } + mp[id] = true + } +} + +func TestID_bitch(t *testing.T) { + le := 100000 + ch := make(chan uint64, le) + var wg sync.WaitGroup + for i := 0; i < le; i++ { + wg.Add(1) + go func() { + defer wg.Done() + id := snowflake.ID() + ch <- id + }() + } + wg.Wait() + close(ch) + + mp := make(map[uint64]bool) + for id := range ch { + if _, ok := mp[id]; ok { + t.Error("It should not be repeated") + break + } + mp[id] = true + } + if len(mp) != le { + t.Error("map length should be equal", le) + } } func TestSetStartTime(t *testing.T) {