mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-05-24 17:01:31 +00:00
338 lines
9.3 KiB
Go
338 lines
9.3 KiB
Go
package labeler
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
comatproto "github.com/bluesky-social/indigo/api/atproto"
|
|
"github.com/bluesky-social/indigo/events"
|
|
"github.com/gorilla/websocket"
|
|
cbg "github.com/whyrusleeping/cbor-gen"
|
|
)
|
|
|
|
// Sizing parameters for the subscribeLabels stream.
const (
	// subscriberBuffer is the buffer size handed to the broadcast hub when a
	// subscribeLabels connection registers for live deliveries.
	subscriberBuffer = 64

	// backfillPageLimit is the number of rows requested per page while replaying
	// stored labels to a subscriber; a short page ends the replay.
	backfillPageLimit = 200
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
// CheckOrigin is permissive: the firehose is a public stream by design and ATProto
|
|
// consumers are not browsers, so the same-origin policy doesn't apply to them anyway.
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
|
|
// frameLabels builds the binary frame for a labels event: CBOR-encoded
|
|
// {op:1, t:"#labels"} header concatenated with CBOR-encoded {seq, labels:[...]} body.
|
|
func frameLabels(seq int64, labels []*comatproto.LabelDefs_Label) ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
w := cbg.NewCborWriter(&buf)
|
|
|
|
header := events.EventHeader{Op: events.EvtKindMessage, MsgType: "#labels"}
|
|
if err := header.MarshalCBOR(w); err != nil {
|
|
return nil, fmt.Errorf("marshal header: %w", err)
|
|
}
|
|
body := comatproto.LabelSubscribeLabels_Labels{Seq: seq, Labels: labels}
|
|
if err := body.MarshalCBOR(w); err != nil {
|
|
return nil, fmt.Errorf("marshal body: %w", err)
|
|
}
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
// frameInfo builds the binary frame for an info event: header {op:1, t:"#info"} plus body.
|
|
func frameInfo(name, message string) ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
w := cbg.NewCborWriter(&buf)
|
|
|
|
header := events.EventHeader{Op: events.EvtKindMessage, MsgType: "#info"}
|
|
if err := header.MarshalCBOR(w); err != nil {
|
|
return nil, err
|
|
}
|
|
body := comatproto.LabelSubscribeLabels_Info{Name: name}
|
|
if message != "" {
|
|
body.Message = &message
|
|
}
|
|
if err := body.MarshalCBOR(w); err != nil {
|
|
return nil, err
|
|
}
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
// frameError builds an error frame: header {op:-1} plus {error, message}.
|
|
func frameError(name, message string) ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
w := cbg.NewCborWriter(&buf)
|
|
|
|
header := events.EventHeader{Op: events.EvtKindErrorFrame}
|
|
if err := header.MarshalCBOR(w); err != nil {
|
|
return nil, err
|
|
}
|
|
body := events.ErrorFrame{Error: name, Message: message}
|
|
if err := body.MarshalCBOR(w); err != nil {
|
|
return nil, err
|
|
}
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
// labelToLexicon converts a stored row into the indigo lexicon type used in the wire format.
|
|
func labelToLexicon(l *Label) *comatproto.LabelDefs_Label {
|
|
tmp := l.ToLabeling()
|
|
lex := tmp.ToLexicon()
|
|
return &lex
|
|
}
|
|
|
|
// handleSubscribeLabels implements com.atproto.label.subscribeLabels.
|
|
//
|
|
// Wire format: each WebSocket binary message is two concatenated CBOR objects (header
|
|
// + body) matching the firehose convention. Backfill pages historical labels since the
|
|
// cursor, then the connection joins the broadcast hub for live deliveries.
|
|
func (s *Server) handleSubscribeLabels(w http.ResponseWriter, r *http.Request) {
|
|
cursorStr := r.URL.Query().Get("cursor")
|
|
var cursor int64
|
|
if cursorStr != "" {
|
|
v, err := strconv.ParseInt(cursorStr, 10, 64)
|
|
if err != nil {
|
|
http.Error(w, "invalid cursor", http.StatusBadRequest)
|
|
return
|
|
}
|
|
cursor = v
|
|
}
|
|
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
slog.Error("WebSocket upgrade failed", "error", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
slog.Info("subscribeLabels client connected", "cursor", cursor)
|
|
|
|
latest, err := LatestSeq(s.db)
|
|
if err != nil {
|
|
slog.Error("Failed to read latest seq", "error", err)
|
|
return
|
|
}
|
|
if cursor > latest {
|
|
if frame, ferr := frameError("FutureCursor", "cursor is in the future"); ferr == nil {
|
|
_ = conn.WriteMessage(websocket.BinaryMessage, frame)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Subscribe to the broadcast hub BEFORE backfilling so we don't lose events
|
|
// that arrive while we're streaming the historical tail.
|
|
sub, cancel := s.hub.subscribe(subscriberBuffer)
|
|
defer cancel()
|
|
|
|
if cursor > 0 {
|
|
if frame, ferr := frameInfo("OutdatedCursor", "starting backfill from cursor"); ferr == nil {
|
|
if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Backfill historical labels in pages until we catch up.
|
|
for {
|
|
labels, err := GetLabelsSince(s.db, cursor, backfillPageLimit)
|
|
if err != nil {
|
|
slog.Error("Failed to read labels for backfill", "error", err)
|
|
return
|
|
}
|
|
if len(labels) == 0 {
|
|
break
|
|
}
|
|
for i := range labels {
|
|
frame, ferr := frameLabels(labels[i].ID, []*comatproto.LabelDefs_Label{labelToLexicon(&labels[i])})
|
|
if ferr != nil {
|
|
slog.Error("Failed to encode label frame", "error", ferr)
|
|
return
|
|
}
|
|
if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
|
|
return
|
|
}
|
|
cursor = labels[i].ID
|
|
}
|
|
if len(labels) < backfillPageLimit {
|
|
break
|
|
}
|
|
}
|
|
|
|
// Live delivery: a goroutine monitors the read side so we notice client disconnects;
|
|
// the main loop pulls from the hub and writes frames until either side closes.
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
for {
|
|
if _, _, rerr := conn.ReadMessage(); rerr != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-done:
|
|
return
|
|
case lbl, ok := <-sub.ch:
|
|
if !ok {
|
|
return
|
|
}
|
|
if lbl.ID <= cursor {
|
|
continue // already delivered during backfill
|
|
}
|
|
frame, ferr := frameLabels(lbl.ID, []*comatproto.LabelDefs_Label{labelToLexicon(lbl)})
|
|
if ferr != nil {
|
|
slog.Error("Failed to encode label frame", "error", ferr)
|
|
return
|
|
}
|
|
if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
|
|
return
|
|
}
|
|
cursor = lbl.ID
|
|
}
|
|
}
|
|
}
|
|
|
|
// queryLabelsResponse mirrors the lexicon JSON shape for queryLabels.
|
|
type queryLabelsResponse struct {
|
|
Cursor string `json:"cursor,omitempty"`
|
|
Labels []*comatproto.LabelDefs_Label `json:"labels"`
|
|
}
|
|
|
|
// handleQueryLabels implements com.atproto.label.queryLabels.
|
|
//
|
|
// Filters (uriPatterns, sources) are applied in SQL so the LIMIT cap operates on the
|
|
// filtered result, not the raw scan. URI patterns support a single trailing `*` glob
|
|
// (LIKE), with `%` and `_` escaped to remain literal.
|
|
func (s *Server) handleQueryLabels(w http.ResponseWriter, r *http.Request) {
|
|
q := r.URL.Query()
|
|
patterns := q["uriPatterns"]
|
|
sources := q["sources"]
|
|
|
|
var cursor int64
|
|
if cs := q.Get("cursor"); cs != "" {
|
|
v, err := strconv.ParseInt(cs, 10, 64)
|
|
if err != nil {
|
|
http.Error(w, "invalid cursor", http.StatusBadRequest)
|
|
return
|
|
}
|
|
cursor = v
|
|
}
|
|
|
|
limit := 50
|
|
if ls := q.Get("limit"); ls != "" {
|
|
if l, err := strconv.Atoi(ls); err == nil && l > 0 && l <= 250 {
|
|
limit = l
|
|
}
|
|
}
|
|
|
|
rows, err := queryLabelsSQL(s.db, patterns, sources, cursor, limit)
|
|
if err != nil {
|
|
if errors.Is(err, errInvalidPattern) {
|
|
http.Error(w, "invalid uriPattern: wildcard '*' must be at the end", http.StatusBadRequest)
|
|
return
|
|
}
|
|
slog.Error("queryLabels failed", "error", err)
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
out := &queryLabelsResponse{Labels: make([]*comatproto.LabelDefs_Label, 0, len(rows))}
|
|
for i := range rows {
|
|
out.Labels = append(out.Labels, labelToLexicon(&rows[i]))
|
|
}
|
|
if len(rows) > 0 {
|
|
out.Cursor = strconv.FormatInt(rows[len(rows)-1].ID, 10)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_ = json.NewEncoder(w).Encode(out)
|
|
}
|
|
|
|
// errInvalidPattern is the sentinel returned by patternToLike when a uriPattern has
// a `*` anywhere other than its final position. handleQueryLabels detects it with
// errors.Is and answers 400 instead of 500.
var errInvalidPattern = errors.New("invalid uriPattern")
|
|
|
|
// queryLabelsSQL builds the WHERE clause from filter args and runs the query. All
|
|
// filtering happens in SQL so LIMIT operates on already-filtered rows.
|
|
func queryLabelsSQL(db *sql.DB, patterns, sources []string, cursor int64, limit int) ([]Label, error) {
|
|
var (
|
|
where []string
|
|
args []any
|
|
)
|
|
where = append(where, "id > ?")
|
|
args = append(args, cursor)
|
|
|
|
if len(patterns) > 0 {
|
|
var ors []string
|
|
var matchAll bool
|
|
for _, p := range patterns {
|
|
if p == "" {
|
|
continue
|
|
}
|
|
if p == "*" {
|
|
matchAll = true
|
|
break
|
|
}
|
|
like, err := patternToLike(p)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.ContainsAny(like, `%_\`) {
|
|
ors = append(ors, "uri LIKE ? ESCAPE '\\'")
|
|
} else {
|
|
ors = append(ors, "uri = ?")
|
|
}
|
|
args = append(args, like)
|
|
}
|
|
if !matchAll && len(ors) > 0 {
|
|
where = append(where, "("+strings.Join(ors, " OR ")+")")
|
|
}
|
|
}
|
|
|
|
if len(sources) > 0 {
|
|
placeholders := strings.Repeat("?,", len(sources))
|
|
placeholders = placeholders[:len(placeholders)-1]
|
|
where = append(where, "src IN ("+placeholders+")")
|
|
for _, s := range sources {
|
|
args = append(args, s)
|
|
}
|
|
}
|
|
|
|
args = append(args, limit)
|
|
q := `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo
|
|
FROM labels WHERE ` + strings.Join(where, " AND ") +
|
|
` ORDER BY id ASC LIMIT ?`
|
|
|
|
rows, err := db.Query(q, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanLabels(rows)
|
|
}
|
|
|
|
// patternToLike converts a uriPattern into a SQLite LIKE expression. The only
|
|
// wildcard supported is a trailing `*`, which becomes `%`. Literal `%`, `_`, and `\`
|
|
// in the input are escaped via the LIKE ESCAPE clause used at query time.
|
|
func patternToLike(p string) (string, error) {
|
|
if idx := strings.Index(p, "*"); idx >= 0 && idx != len(p)-1 {
|
|
return "", errInvalidPattern
|
|
}
|
|
literal := p
|
|
suffix := ""
|
|
if strings.HasSuffix(p, "*") {
|
|
literal = p[:len(p)-1]
|
|
suffix = "%"
|
|
}
|
|
literal = strings.ReplaceAll(literal, `\`, `\\`)
|
|
literal = strings.ReplaceAll(literal, `%`, `\%`)
|
|
literal = strings.ReplaceAll(literal, `_`, `\_`)
|
|
return literal + suffix, nil
|
|
}
|