Check compose args (#74)

* v2: take only 'bitsMachine' least significant bits from generated machine IDs so it doesn't corrupt the other ID parts (#72)

Co-authored-by: ok32 <e.starikoff@gmail.com>
Co-authored-by: Yoshiyuki Mineo <Yoshiyuki.Mineo@jp.sony.com>

* Refactor Sonyflake Compose method to return errors for invalid parameters and update related tests. Consolidate error handling for start time, sequence, and machine ID validations. Remove unused variable in tests for clarity.

---------

Co-authored-by: ok32 <artuh.gubanova@gmail.com>
Co-authored-by: ok32 <e.starikoff@gmail.com>
This commit is contained in:
Yoshiyuki Mineo
2025-05-18 06:56:36 +00:00
committed by GitHub
parent 5347433c8c
commit 2343cac676
2 changed files with 77 additions and 5 deletions

View File

@@ -74,8 +74,9 @@ var (
ErrInvalidBitsSequence = errors.New("invalid bit length for sequence number") ErrInvalidBitsSequence = errors.New("invalid bit length for sequence number")
ErrInvalidBitsMachineID = errors.New("invalid bit length for machine id") ErrInvalidBitsMachineID = errors.New("invalid bit length for machine id")
ErrInvalidTimeUnit = errors.New("invalid time unit") ErrInvalidTimeUnit = errors.New("invalid time unit")
ErrInvalidSequence = errors.New("invalid sequence number")
ErrInvalidMachineID = errors.New("invalid machine id") ErrInvalidMachineID = errors.New("invalid machine id")
ErrStartTimeAhead = errors.New("start time is ahead of now") ErrStartTimeAhead = errors.New("start time is ahead")
ErrOverTimeLimit = errors.New("over the time limit") ErrOverTimeLimit = errors.New("over the time limit")
ErrNoPrivateAddress = errors.New("no private ip address") ErrNoPrivateAddress = errors.New("no private ip address")
) )
@@ -260,11 +261,26 @@ func (sf *Sonyflake) ToTime(id int64) time.Time {
// The time parameter should be the time when the ID was generated. // The time parameter should be the time when the ID was generated.
// The sequence parameter should be between 0 and 2^BitsSequence-1 (inclusive). // The sequence parameter should be between 0 and 2^BitsSequence-1 (inclusive).
// The machineID parameter should be between 0 and 2^BitsMachineID-1 (inclusive). // The machineID parameter should be between 0 and 2^BitsMachineID-1 (inclusive).
func (sf *Sonyflake) Compose(t time.Time, sequence, machineID int) int64 { func (sf *Sonyflake) Compose(t time.Time, sequence, machineID int) (int64, error) {
elapsedTime := sf.toInternalTime(t.UTC()) - sf.startTime elapsedTime := sf.toInternalTime(t.UTC()) - sf.startTime
if elapsedTime < 0 {
return 0, ErrStartTimeAhead
}
if elapsedTime >= 1<<sf.bitsTime {
return 0, ErrOverTimeLimit
}
if sequence < 0 || sequence >= 1<<sf.bitsSequence {
return 0, ErrInvalidSequence
}
if machineID < 0 || machineID >= 1<<sf.bitsMachine {
return 0, ErrInvalidMachineID
}
return elapsedTime<<(sf.bitsSequence+sf.bitsMachine) | return elapsedTime<<(sf.bitsSequence+sf.bitsMachine) |
int64(sequence)<<sf.bitsMachine | int64(sequence)<<sf.bitsMachine |
int64(machineID) int64(machineID), nil
} }
// Decompose returns a set of Sonyflake ID parts. // Decompose returns a set of Sonyflake ID parts.

View File

@@ -257,10 +257,11 @@ func pseudoSleep(sf *Sonyflake, period time.Duration) {
sf.startTime -= int64(period) / sf.timeUnit sf.startTime -= int64(period) / sf.timeUnit
} }
const year = time.Duration(365*24) * time.Hour
func TestNextID_ReturnsError(t *testing.T) { func TestNextID_ReturnsError(t *testing.T) {
sf := newSonyflake(t, Settings{StartTime: time.Now()}) sf := newSonyflake(t, Settings{StartTime: time.Now()})
year := time.Duration(365*24) * time.Hour
pseudoSleep(sf, time.Duration(174)*year) pseudoSleep(sf, time.Duration(174)*year)
nextID(t, sf) nextID(t, sf)
@@ -411,7 +412,11 @@ func TestComposeAndDecompose(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
id := sf.Compose(tc.time, tc.sequence, tc.machineID) id, err := sf.Compose(tc.time, tc.sequence, tc.machineID)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
parts := sf.Decompose(id) parts := sf.Decompose(id)
// Verify time part // Verify time part
@@ -437,3 +442,54 @@ func TestComposeAndDecompose(t *testing.T) {
}) })
} }
} }
func TestCompose_ReturnsError(t *testing.T) {
start := time.Now()
sf := newSonyflake(t, Settings{StartTime: start})
testCases := []struct {
name string
time time.Time
sequence int
machineID int
err error
}{
{
name: "start time ahead",
time: start.Add(-time.Second),
sequence: 0,
machineID: 0,
err: ErrStartTimeAhead,
},
{
name: "over time limit",
time: start.Add(time.Duration(175) * year),
sequence: 0,
machineID: 0,
err: ErrOverTimeLimit,
},
{
name: "invalid sequence",
time: start,
sequence: 1 << sf.bitsSequence,
machineID: 0,
err: ErrInvalidSequence,
},
{
name: "invalid machine id",
time: start,
sequence: 0,
machineID: 1 << sf.bitsMachine,
err: ErrInvalidMachineID,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := sf.Compose(tc.time, tc.sequence, tc.machineID)
if !errors.Is(err, tc.err) {
t.Errorf("unexpected error: %v", err)
}
})
}
}