4 Commits

Author SHA1 Message Date
Yoshiyuki Mineo
f167a9d531 Add Compose method and corresponding tests (#79)
* Implemented the Compose function to create a Sonyflake ID from its components, including validation for start time, sequence, and machine ID.
* Added a new test, TestCompose, to verify the functionality of the Compose method, ensuring correct decomposition of generated IDs.
* Introduced a new error for invalid sequence numbers to enhance error handling in the Sonyflake package.
2025-07-05 17:50:03 +09:00
Yoshiyuki Mineo
5c401f9c06 Make unit tests stabler (#78)
* Refactor Sonyflake to use a customizable time function for improved testability. Update currentElapsedTime and sleep methods to utilize the new time function instead of directly calling time.Now().

* Update TimeUnit in ToTime test to use time.Millisecond for improved accuracy in timestamp validation. Adjust time duration comparison to ensure correct validation of generated timestamps.

* Refactor TestNextID to use a customizable time function for improved accuracy in timestamp validation. Adjust time comparison to ensure expected results in generated IDs.

* Refactor TestNextID_InSequence to use a consistent start time for improved clarity in timestamp validation. Update max sequence comparison to ensure accurate validation of generated IDs.

* Refactor tests in sonyflake_test.go to replace fmt.Println with t.Log for better test output management. Update error handling to use errors.New for consistency.
2025-06-28 19:24:22 +09:00
Yoshiyuki Mineo
0cdef9e4fe Update TimeUnit in ToTime test to use 100 milliseconds for improved accuracy in timestamp validation. (#77) 2025-06-23 23:07:41 +09:00
Yoshiyuki Mineo
114716564a Fix time duration comparison in ToTime test to ensure correct validation of generated timestamps. (#76) 2025-06-23 21:45:29 +09:00
4 changed files with 75 additions and 18 deletions

View File

@@ -57,6 +57,7 @@ var (
ErrNoPrivateAddress = errors.New("no private ip address") ErrNoPrivateAddress = errors.New("no private ip address")
ErrOverTimeLimit = errors.New("over the time limit") ErrOverTimeLimit = errors.New("over the time limit")
ErrInvalidMachineID = errors.New("invalid machine id") ErrInvalidMachineID = errors.New("invalid machine id")
ErrInvalidSequence = errors.New("invalid sequence number")
) )
var defaultInterfaceAddrs = net.InterfaceAddrs var defaultInterfaceAddrs = net.InterfaceAddrs
@@ -213,6 +214,25 @@ func MachineID(id uint64) uint64 {
return id & maskMachineID return id & maskMachineID
} }
// Compose creates a Sonyflake ID from its parts.
func Compose(sf *Sonyflake, t time.Time, sequence uint16, machineID uint16) (uint64, error) {
elapsedTime := toSonyflakeTime(t.UTC()) - sf.startTime
if elapsedTime < 0 {
return 0, ErrStartTimeAhead
}
if elapsedTime >= 1<<BitLenTime {
return 0, ErrOverTimeLimit
}
if sequence >= 1<<BitLenSequence {
return 0, ErrInvalidSequence
}
return uint64(elapsedTime)<<(BitLenSequence+BitLenMachineID) |
uint64(sequence)<<BitLenMachineID |
uint64(machineID), nil
}
// Decompose returns a set of Sonyflake ID parts. // Decompose returns a set of Sonyflake ID parts.
func Decompose(id uint64) map[string]uint64 { func Decompose(id uint64) map[string]uint64 {
msb := id >> 63 msb := id >> 63

View File

@@ -312,3 +312,36 @@ func TestSonyflakeTimeUnit(t *testing.T) {
t.Errorf("unexpected time unit") t.Errorf("unexpected time unit")
} }
} }
func TestCompose(t *testing.T) {
var st Settings
st.StartTime = time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
sf, err := New(st)
if err != nil {
t.Fatal(err)
}
now := time.Now()
sequence := uint16(123)
machineID := uint16(456)
id, err := Compose(sf, now, sequence, machineID)
if err != nil {
t.Fatal(err)
}
parts := Decompose(id)
actualTime := toSonyflakeTime(now) - toSonyflakeTime(st.StartTime)
if parts["time"] != uint64(actualTime) {
t.Errorf("unexpected time: %d", parts["time"])
}
if parts["sequence"] != uint64(sequence) {
t.Errorf("unexpected sequence: %d", parts["sequence"])
}
if parts["machine-id"] != uint64(machineID) {
t.Errorf("unexpected machine id: %d", parts["machine-id"])
}
}

View File

@@ -67,6 +67,8 @@ type Sonyflake struct {
sequence int sequence int
machine int machine int
now func() time.Time
} }
var ( var (
@@ -116,6 +118,7 @@ func New(st Settings) (*Sonyflake, error) {
sf := new(Sonyflake) sf := new(Sonyflake)
sf.mutex = new(sync.Mutex) sf.mutex = new(sync.Mutex)
sf.now = time.Now
if st.BitsSequence == 0 { if st.BitsSequence == 0 {
sf.bitsSequence = defaultBitsSequence sf.bitsSequence = defaultBitsSequence
@@ -198,12 +201,12 @@ func (sf *Sonyflake) toInternalTime(t time.Time) int64 {
} }
func (sf *Sonyflake) currentElapsedTime() int64 { func (sf *Sonyflake) currentElapsedTime() int64 {
return sf.toInternalTime(time.Now()) - sf.startTime return sf.toInternalTime(sf.now()) - sf.startTime
} }
func (sf *Sonyflake) sleep(overtime int64) { func (sf *Sonyflake) sleep(overtime int64) {
sleepTime := time.Duration(overtime*sf.timeUnit) - sleepTime := time.Duration(overtime*sf.timeUnit) -
time.Duration(time.Now().UTC().UnixNano()%sf.timeUnit) time.Duration(sf.now().UTC().UnixNano()%sf.timeUnit)
time.Sleep(sleepTime) time.Sleep(sleepTime)
} }

View File

@@ -2,7 +2,6 @@ package sonyflake
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"runtime" "runtime"
"testing" "testing"
@@ -13,7 +12,7 @@ import (
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
errGetMachineID := fmt.Errorf("failed to get machine id") errGetMachineID := errors.New("failed to get machine id")
testCases := []struct { testCases := []struct {
name string name string
@@ -138,15 +137,16 @@ func defaultMachineID(t *testing.T) int {
} }
func TestNextID(t *testing.T) { func TestNextID(t *testing.T) {
sf := newSonyflake(t, Settings{StartTime: time.Now()}) start := time.Now()
sf := newSonyflake(t, Settings{StartTime: start})
sleepTime := int64(50) sleepTime := int64(50)
time.Sleep(time.Duration(sleepTime * sf.timeUnit)) sf.now = func() time.Time { return start.Add(time.Duration(sleepTime * sf.timeUnit)) }
id := nextID(t, sf) id := nextID(t, sf)
actualTime := sf.timePart(id) actualTime := sf.timePart(id)
if actualTime < sleepTime || actualTime > sleepTime+1 { if actualTime != sleepTime {
t.Errorf("unexpected time: %d", actualTime) t.Errorf("unexpected time: %d", actualTime)
} }
@@ -160,17 +160,17 @@ func TestNextID(t *testing.T) {
t.Errorf("unexpected machine: %d", actualMachine) t.Errorf("unexpected machine: %d", actualMachine)
} }
fmt.Println("sonyflake id:", id) t.Log("sonyflake id:", id)
fmt.Println("decompose:", sf.Decompose(id)) t.Log("decompose:", sf.Decompose(id))
} }
func TestNextID_InSequence(t *testing.T) { func TestNextID_InSequence(t *testing.T) {
now := time.Now() start := time.Now()
sf := newSonyflake(t, Settings{ sf := newSonyflake(t, Settings{
TimeUnit: time.Millisecond, TimeUnit: time.Millisecond,
StartTime: now, StartTime: start,
}) })
startTime := sf.toInternalTime(now) startTime := sf.toInternalTime(start)
machineID := int64(defaultMachineID(t)) machineID := int64(defaultMachineID(t))
var numID int var numID int
@@ -210,11 +210,11 @@ func TestNextID_InSequence(t *testing.T) {
} }
} }
if maxSeq != 1<<sf.bitsSequence-1 { if maxSeq > 1<<sf.bitsSequence-1 {
t.Errorf("unexpected max sequence: %d", maxSeq) t.Errorf("unexpected max sequence: %d", maxSeq)
} }
fmt.Println("max sequence:", maxSeq) t.Log("max sequence:", maxSeq)
fmt.Println("number of id:", numID) t.Log("number of id:", numID)
} }
func TestNextID_InParallel(t *testing.T) { func TestNextID_InParallel(t *testing.T) {
@@ -223,7 +223,7 @@ func TestNextID_InParallel(t *testing.T) {
numCPU := runtime.NumCPU() numCPU := runtime.NumCPU()
runtime.GOMAXPROCS(numCPU) runtime.GOMAXPROCS(numCPU)
fmt.Println("number of cpu:", numCPU) t.Log("number of cpu:", numCPU)
consumer := make(chan int64) consumer := make(chan int64)
@@ -250,7 +250,7 @@ func TestNextID_InParallel(t *testing.T) {
} }
set[id] = struct{}{} set[id] = struct{}{}
} }
fmt.Println("number of id:", len(set)) t.Log("number of id:", len(set))
} }
func pseudoSleep(sf *Sonyflake, period time.Duration) { func pseudoSleep(sf *Sonyflake, period time.Duration) {
@@ -362,12 +362,13 @@ func TestToTime(t *testing.T) {
StartTime: start, StartTime: start,
}) })
sf.now = func() time.Time { return start }
id := nextID(t, sf) id := nextID(t, sf)
tm := sf.ToTime(id) tm := sf.ToTime(id)
diff := tm.Sub(start) diff := tm.Sub(start)
if diff < 0 || diff >= time.Duration(sf.timeUnit) { if diff < 0 || diff >= time.Duration(sf.timeUnit) {
t.Errorf("unexpected time: %v", tm) t.Errorf("unexpected time: %v", diff)
} }
} }