From 2343cac6764e32c756da0bea2b913a4ba188fc2d Mon Sep 17 00:00:00 2001 From: Yoshiyuki Mineo Date: Sun, 18 May 2025 06:56:36 +0000 Subject: [PATCH] Check compose args (#74) * v2: take only 'bitsMachine' least significant bits from generated machine IDs so it doesn't corrupt the other ID parts (#72) Co-authored-by: ok32 Co-authored-by: Yoshiyuki Mineo * Refactor Sonyflake Compose method to return errors for invalid parameters and update related tests. Consolidate error handling for start time, sequence, and machine ID validations. Remove unused variable in tests for clarity. --------- Co-authored-by: ok32 Co-authored-by: ok32 --- v2/sonyflake.go | 22 +++++++++++++--- v2/sonyflake_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 77 insertions(+), 5 deletions(-) diff --git a/v2/sonyflake.go b/v2/sonyflake.go index 7e00701..00bf609 100644 --- a/v2/sonyflake.go +++ b/v2/sonyflake.go @@ -74,8 +74,9 @@ var ( ErrInvalidBitsSequence = errors.New("invalid bit length for sequence number") ErrInvalidBitsMachineID = errors.New("invalid bit length for machine id") ErrInvalidTimeUnit = errors.New("invalid time unit") + ErrInvalidSequence = errors.New("invalid sequence number") ErrInvalidMachineID = errors.New("invalid machine id") - ErrStartTimeAhead = errors.New("start time is ahead of now") + ErrStartTimeAhead = errors.New("start time is ahead") ErrOverTimeLimit = errors.New("over the time limit") ErrNoPrivateAddress = errors.New("no private ip address") ) @@ -260,11 +261,26 @@ func (sf *Sonyflake) ToTime(id int64) time.Time { // The time parameter should be the time when the ID was generated. // The sequence parameter should be between 0 and 2^BitsSequence-1 (inclusive). // The machineID parameter should be between 0 and 2^BitsMachineID-1 (inclusive). -func (sf *Sonyflake) Compose(t time.Time, sequence, machineID int) int64 { +func (sf *Sonyflake) Compose(t time.Time, sequence, machineID int) (int64, error) { elapsedTime := sf.toInternalTime(t.UTC()) - sf.startTime + if elapsedTime < 0 { + return 0, ErrStartTimeAhead + } + if elapsedTime >= 1<= 1<= 1<