Merge branch 'master' into fix

This commit is contained in:
ccfuncy
2023-11-30 09:17:48 +08:00
committed by GitHub
4 changed files with 91 additions and 39 deletions

View File

@@ -4,7 +4,7 @@ jobs:
test: test:
strategy: strategy:
matrix: matrix:
go-version: [1.19.x, 1.20.x] go-version: [1.20.x, 1.21.x]
os: [ubuntu-latest] os: [ubuntu-latest]
runs-on: ${{matrix.os}} runs-on: ${{matrix.os}}
steps: steps:

View File

@@ -33,10 +33,10 @@ go get github.com/sony/sonyflake
Usage Usage
----- -----
The function NewSonyflake creates a new Sonyflake instance. The function New creates a new Sonyflake instance.
```go ```go
func NewSonyflake(st Settings) *Sonyflake func New(st Settings) (*Sonyflake, error)
``` ```
You can configure Sonyflake by the struct Settings: You can configure Sonyflake by the struct Settings:

View File

@@ -53,21 +53,29 @@ type Sonyflake struct {
machineID uint16 machineID uint16
} }
var (
ErrStartTimeAhead = errors.New("start time is ahead of now")
ErrNoPrivateAddress = errors.New("no private ip address")
ErrOverTimeLimit = errors.New("over the time limit")
ErrInvalidMachineID = errors.New("invalid machine id")
)
var defaultInterfaceAddrs = net.InterfaceAddrs var defaultInterfaceAddrs = net.InterfaceAddrs
// NewSonyflake returns a new Sonyflake configured with the given Settings. // New returns a new Sonyflake configured with the given Settings.
// NewSonyflake returns nil in the following cases: // New returns an error in the following cases:
// - Settings.StartTime is ahead of the current time. // - Settings.StartTime is ahead of the current time.
// - Settings.MachineID returns an error. // - Settings.MachineID returns an error.
// - Settings.CheckMachineID returns false. // - Settings.CheckMachineID returns false.
func NewSonyflake(st Settings) *Sonyflake { func New(st Settings) (*Sonyflake, error) {
if st.StartTime.After(time.Now()) {
return nil, ErrStartTimeAhead
}
sf := new(Sonyflake) sf := new(Sonyflake)
sf.mutex = new(sync.Mutex) sf.mutex = new(sync.Mutex)
sf.sequence = uint16(1<<BitLenSequence - 1) sf.sequence = uint16(1<<BitLenSequence - 1)
if st.StartTime.After(time.Now()) {
return nil
}
if st.StartTime.IsZero() { if st.StartTime.IsZero() {
sf.startTime = toSonyflakeTime(time.Date(2014, 9, 1, 0, 0, 0, 0, time.UTC)) sf.startTime = toSonyflakeTime(time.Date(2014, 9, 1, 0, 0, 0, 0, time.UTC))
sf.startTimeMono = time.Now() sf.startTimeMono = time.Now()
@@ -82,10 +90,24 @@ func NewSonyflake(st Settings) *Sonyflake {
} else { } else {
sf.machineID, err = st.MachineID() sf.machineID, err = st.MachineID()
} }
if err != nil || (st.CheckMachineID != nil && !st.CheckMachineID(sf.machineID)) { if err != nil {
return nil return nil, err
} }
if st.CheckMachineID != nil && !st.CheckMachineID(sf.machineID) {
return nil, ErrInvalidMachineID
}
return sf, nil
}
// NewSonyflake returns a new Sonyflake configured with the given Settings.
// NewSonyflake returns nil in the following cases:
// - Settings.StartTime is ahead of the current time.
// - Settings.MachineID returns an error.
// - Settings.CheckMachineID returns false.
func NewSonyflake(st Settings) *Sonyflake {
sf, _ := New(st)
return sf return sf
} }
@@ -158,7 +180,7 @@ func sleepTime(overtime int64) time.Duration {
func (sf *Sonyflake) toID() (uint64, error) { func (sf *Sonyflake) toID() (uint64, error) {
if sf.elapsedTime >= 1<<BitLenTime { if sf.elapsedTime >= 1<<BitLenTime {
return 0, errors.New("over the time limit") return 0, ErrOverTimeLimit
} }
return uint64(sf.elapsedTime)<<(BitLenSequence+BitLenMachineID) | return uint64(sf.elapsedTime)<<(BitLenSequence+BitLenMachineID) |
@@ -183,7 +205,7 @@ func privateIPv4(interfaceAddrs types.InterfaceAddrs) (net.IP, error) {
return ip, nil return ip, nil
} }
} }
return nil, errors.New("no private ip address") return nil, ErrNoPrivateAddress
} }
func isPrivateIPv4(ip net.IP) bool { func isPrivateIPv4(ip net.IP) bool {

View File

@@ -1,7 +1,7 @@
package sonyflake package sonyflake
import ( import (
"bytes" "errors"
"fmt" "fmt"
"net" "net"
"runtime" "runtime"
@@ -40,6 +40,60 @@ func nextID(t *testing.T) uint64 {
return id return id
} }
func TestNew(t *testing.T) {
genError := fmt.Errorf("an error occurred while generating ID")
tests := []struct {
name string
settings Settings
err error
}{
{
name: "failure: time ahead",
settings: Settings{
StartTime: time.Now().Add(time.Minute),
},
err: ErrStartTimeAhead,
},
{
name: "failure: machine ID",
settings: Settings{
MachineID: func() (uint16, error) {
return 0, genError
},
},
err: genError,
},
{
name: "failure: invalid machine ID",
settings: Settings{
CheckMachineID: func(uint16) bool {
return false
},
},
err: ErrInvalidMachineID,
},
{
name: "success",
settings: Settings{},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
sonyflake, err := New(test.settings)
if !errors.Is(err, test.err) {
t.Fatalf("unexpected value, want %#v, got %#v", test.err, err)
}
if sonyflake == nil && err == nil {
t.Fatal("unexpected value, sonyflake should not be nil")
}
})
}
}
func TestSonyflakeOnce(t *testing.T) { func TestSonyflakeOnce(t *testing.T) {
sleepTime := time.Duration(50 * sonyflakeTimeUnit) sleepTime := time.Duration(50 * sonyflakeTimeUnit)
time.Sleep(sleepTime) time.Sleep(sleepTime)
@@ -150,30 +204,6 @@ func TestSonyflakeInParallel(t *testing.T) {
fmt.Println("number of id:", len(set)) fmt.Println("number of id:", len(set))
} }
func TestNilSonyflake(t *testing.T) {
var startInFuture Settings
startInFuture.StartTime = time.Now().Add(time.Duration(1) * time.Minute)
if NewSonyflake(startInFuture) != nil {
t.Errorf("sonyflake starting in the future")
}
var noMachineID Settings
noMachineID.MachineID = func() (uint16, error) {
return 0, fmt.Errorf("no machine id")
}
if NewSonyflake(noMachineID) != nil {
t.Errorf("sonyflake with no machine id")
}
var invalidMachineID Settings
invalidMachineID.CheckMachineID = func(uint16) bool {
return false
}
if NewSonyflake(invalidMachineID) != nil {
t.Errorf("sonyflake with invalid machine id")
}
}
func pseudoSleep(period time.Duration) { func pseudoSleep(period time.Duration) {
sf.startTime -= int64(period) / sonyflakeTimeUnit sf.startTime -= int64(period) / sonyflakeTimeUnit
} }
@@ -228,7 +258,7 @@ func TestPrivateIPv4(t *testing.T) {
return return
} }
if bytes.Equal(actual, tc.expected) { if net.IP.Equal(actual, tc.expected) {
return return
} else { } else {
t.Errorf("error: expected: %s, but got: %s", tc.expected, actual) t.Errorf("error: expected: %s, but got: %s", tc.expected, actual)