diff --git a/src/backend.go b/src/backend.go index f4ee04b..afe8405 100644 --- a/src/backend.go +++ b/src/backend.go @@ -151,8 +151,11 @@ type Backend interface { // Retrieve a single record from the audit log. QueryAuditLog(ctx context.Context, id AuditID) (record *AuditRecord, err error) - // Retrieve records from the audit log by time range. + // Retrieve record IDs from the audit log by time range. SearchAuditLog(ctx context.Context, opts SearchAuditLogOptions) iter.Seq2[AuditID, error] + + // Retrieve audit record contents for given IDs. + GetAuditLogRecords(ctx context.Context, ids iter.Seq2[AuditID, error]) iter.Seq2[*AuditRecord, error] } func CreateBackend(ctx context.Context, config *StorageConfig) (backend Backend, err error) { diff --git a/src/backend_fs.go b/src/backend_fs.go index be93049..612d587 100644 --- a/src/backend_fs.go +++ b/src/backend_fs.go @@ -524,3 +524,19 @@ func (fs *FSBackend) SearchAuditLog( }) } } + +func (fs *FSBackend) GetAuditLogRecords( + ctx context.Context, ids iter.Seq2[AuditID, error], +) iter.Seq2[*AuditRecord, error] { + return func(yield func(*AuditRecord, error) bool) { + for id, err := range ids { + var record *AuditRecord + if err == nil { + record, err = fs.QueryAuditLog(ctx, id) + } + if !yield(record, err) { + break + } + } + } +} diff --git a/src/backend_s3.go b/src/backend_s3.go index 5f73700..0d1341f 100644 --- a/src/backend_s3.go +++ b/src/backend_s3.go @@ -848,3 +848,40 @@ func (s3 *S3Backend) SearchAuditLog( } } } + +var getAuditLogRecordsSemaphore = make(chan struct{}, 64) + +func (s3 *S3Backend) GetAuditLogRecords( + ctx context.Context, ids iter.Seq2[AuditID, error], +) iter.Seq2[*AuditRecord, error] { + return func(yield func(*AuditRecord, error) bool) { + resultsChan := make(chan tuple[*AuditRecord, error]) + enumeratorCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func(ctx context.Context) { + wg := sync.WaitGroup{} + for id, err := range ids { + if err != nil { + resultsChan <- tuple[*AuditRecord, error]{nil, err} + } else { + getAuditLogRecordsSemaphore <- struct{}{} // acquire + wg.Go(func() { + defer func() { <-getAuditLogRecordsSemaphore }() // release + record, err := s3.QueryAuditLog(ctx, id) + resultsChan <- tuple[*AuditRecord, error]{record, err} + }) + } + } + wg.Wait() + close(resultsChan) + }(enumeratorCtx) + + for result := range resultsChan { + record, err := result.Splat() + if !yield(record, err) { + break + } + } + } +} diff --git a/src/main.go b/src/main.go index ca94635..085d179 100644 --- a/src/main.go +++ b/src/main.go @@ -462,30 +462,20 @@ func Main(versionInfo string) { } case *auditLog: - ch := make(chan *AuditRecord) - ids := []AuditID{} - for id, err := range backend.SearchAuditLog(ctx, SearchAuditLogOptions{}) { + records := []*AuditRecord{} + ids := backend.SearchAuditLog(ctx, SearchAuditLogOptions{}) + for record, err := range backend.GetAuditLogRecords(ctx, ids) { if err != nil { logc.Fatalln(ctx, err) } - go func() { - if record, err := backend.QueryAuditLog(ctx, id); err != nil { - logc.Fatalln(ctx, err) - } else { - ch <- record - } - }() - ids = append(ids, id) + records = append(records, record) } - records := map[AuditID]*AuditRecord{} - for len(records) < len(ids) { - record := <-ch - records[record.GetAuditID()] = record - } + slices.SortFunc(records, func(a, b *AuditRecord) int { + return cmp.Compare(a.GetAuditID(), b.GetAuditID()) + }) - for _, id := range ids { - record := records[id] + for _, record := range records { fmt.Fprintf(color.Output, "%s %s %s %s %s\n", record.GetAuditID().String(), color.HiWhiteString(record.GetTimestamp().AsTime().UTC().Format(time.RFC3339)), diff --git a/src/observe.go b/src/observe.go index da744c9..a21edca 100644 --- a/src/observe.go +++ b/src/observe.go @@ -376,3 +376,17 @@ func (backend *observedBackend) SearchAuditLog( span.Finish() } } + +func (backend *observedBackend) GetAuditLogRecords( + ctx context.Context, ids iter.Seq2[AuditID, error], +) iter.Seq2[*AuditRecord, error] { + return func(yield func(*AuditRecord, error) bool) { + span, ctx := ObserveFunction(ctx, "GetAuditLogRecords") + for item, err := range backend.inner.GetAuditLogRecords(ctx, ids) { + if !yield(item, err) { + break + } + } + span.Finish() + } +}