Files
at-container-registry/pkg/labeler/subscribe.go

338 lines
9.3 KiB
Go

package labeler
import (
"bytes"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strconv"
"strings"
comatproto "github.com/bluesky-social/indigo/api/atproto"
"github.com/bluesky-social/indigo/events"
"github.com/gorilla/websocket"
cbg "github.com/whyrusleeping/cbor-gen"
)
const (
subscriberBuffer = 64
backfillPageLimit = 200
)
var upgrader = websocket.Upgrader{
// CheckOrigin is permissive: the firehose is a public stream by design and ATProto
// consumers are not browsers, so the same-origin policy doesn't apply to them anyway.
CheckOrigin: func(r *http.Request) bool { return true },
}
// frameLabels builds the binary frame for a labels event: CBOR-encoded
// {op:1, t:"#labels"} header concatenated with CBOR-encoded {seq, labels:[...]} body.
func frameLabels(seq int64, labels []*comatproto.LabelDefs_Label) ([]byte, error) {
var buf bytes.Buffer
w := cbg.NewCborWriter(&buf)
header := events.EventHeader{Op: events.EvtKindMessage, MsgType: "#labels"}
if err := header.MarshalCBOR(w); err != nil {
return nil, fmt.Errorf("marshal header: %w", err)
}
body := comatproto.LabelSubscribeLabels_Labels{Seq: seq, Labels: labels}
if err := body.MarshalCBOR(w); err != nil {
return nil, fmt.Errorf("marshal body: %w", err)
}
return buf.Bytes(), nil
}
// frameInfo builds the binary frame for an info event: header {op:1, t:"#info"} plus body.
func frameInfo(name, message string) ([]byte, error) {
var buf bytes.Buffer
w := cbg.NewCborWriter(&buf)
header := events.EventHeader{Op: events.EvtKindMessage, MsgType: "#info"}
if err := header.MarshalCBOR(w); err != nil {
return nil, err
}
body := comatproto.LabelSubscribeLabels_Info{Name: name}
if message != "" {
body.Message = &message
}
if err := body.MarshalCBOR(w); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// frameError builds an error frame: header {op:-1} plus {error, message}.
func frameError(name, message string) ([]byte, error) {
var buf bytes.Buffer
w := cbg.NewCborWriter(&buf)
header := events.EventHeader{Op: events.EvtKindErrorFrame}
if err := header.MarshalCBOR(w); err != nil {
return nil, err
}
body := events.ErrorFrame{Error: name, Message: message}
if err := body.MarshalCBOR(w); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// labelToLexicon converts a stored row into the indigo lexicon type used in the wire format.
func labelToLexicon(l *Label) *comatproto.LabelDefs_Label {
tmp := l.ToLabeling()
lex := tmp.ToLexicon()
return &lex
}
// handleSubscribeLabels implements com.atproto.label.subscribeLabels.
//
// Wire format: each WebSocket binary message is two concatenated CBOR objects (header
// + body) matching the firehose convention. Backfill pages historical labels since the
// cursor, then the connection joins the broadcast hub for live deliveries.
func (s *Server) handleSubscribeLabels(w http.ResponseWriter, r *http.Request) {
cursorStr := r.URL.Query().Get("cursor")
var cursor int64
if cursorStr != "" {
v, err := strconv.ParseInt(cursorStr, 10, 64)
if err != nil {
http.Error(w, "invalid cursor", http.StatusBadRequest)
return
}
cursor = v
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("WebSocket upgrade failed", "error", err)
return
}
defer conn.Close()
slog.Info("subscribeLabels client connected", "cursor", cursor)
latest, err := LatestSeq(s.db)
if err != nil {
slog.Error("Failed to read latest seq", "error", err)
return
}
if cursor > latest {
if frame, ferr := frameError("FutureCursor", "cursor is in the future"); ferr == nil {
_ = conn.WriteMessage(websocket.BinaryMessage, frame)
}
return
}
// Subscribe to the broadcast hub BEFORE backfilling so we don't lose events
// that arrive while we're streaming the historical tail.
sub, cancel := s.hub.subscribe(subscriberBuffer)
defer cancel()
if cursor > 0 {
if frame, ferr := frameInfo("OutdatedCursor", "starting backfill from cursor"); ferr == nil {
if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
return
}
}
}
// Backfill historical labels in pages until we catch up.
for {
labels, err := GetLabelsSince(s.db, cursor, backfillPageLimit)
if err != nil {
slog.Error("Failed to read labels for backfill", "error", err)
return
}
if len(labels) == 0 {
break
}
for i := range labels {
frame, ferr := frameLabels(labels[i].ID, []*comatproto.LabelDefs_Label{labelToLexicon(&labels[i])})
if ferr != nil {
slog.Error("Failed to encode label frame", "error", ferr)
return
}
if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
return
}
cursor = labels[i].ID
}
if len(labels) < backfillPageLimit {
break
}
}
// Live delivery: a goroutine monitors the read side so we notice client disconnects;
// the main loop pulls from the hub and writes frames until either side closes.
done := make(chan struct{})
go func() {
defer close(done)
for {
if _, _, rerr := conn.ReadMessage(); rerr != nil {
return
}
}
}()
for {
select {
case <-done:
return
case lbl, ok := <-sub.ch:
if !ok {
return
}
if lbl.ID <= cursor {
continue // already delivered during backfill
}
frame, ferr := frameLabels(lbl.ID, []*comatproto.LabelDefs_Label{labelToLexicon(lbl)})
if ferr != nil {
slog.Error("Failed to encode label frame", "error", ferr)
return
}
if err := conn.WriteMessage(websocket.BinaryMessage, frame); err != nil {
return
}
cursor = lbl.ID
}
}
}
// queryLabelsResponse mirrors the lexicon JSON shape for queryLabels.
type queryLabelsResponse struct {
Cursor string `json:"cursor,omitempty"`
Labels []*comatproto.LabelDefs_Label `json:"labels"`
}
// handleQueryLabels implements com.atproto.label.queryLabels.
//
// Filters (uriPatterns, sources) are applied in SQL so the LIMIT cap operates on the
// filtered result, not the raw scan. URI patterns support a single trailing `*` glob
// (LIKE), with `%` and `_` escaped to remain literal.
func (s *Server) handleQueryLabels(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
patterns := q["uriPatterns"]
sources := q["sources"]
var cursor int64
if cs := q.Get("cursor"); cs != "" {
v, err := strconv.ParseInt(cs, 10, 64)
if err != nil {
http.Error(w, "invalid cursor", http.StatusBadRequest)
return
}
cursor = v
}
limit := 50
if ls := q.Get("limit"); ls != "" {
if l, err := strconv.Atoi(ls); err == nil && l > 0 && l <= 250 {
limit = l
}
}
rows, err := queryLabelsSQL(s.db, patterns, sources, cursor, limit)
if err != nil {
if errors.Is(err, errInvalidPattern) {
http.Error(w, "invalid uriPattern: wildcard '*' must be at the end", http.StatusBadRequest)
return
}
slog.Error("queryLabels failed", "error", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
out := &queryLabelsResponse{Labels: make([]*comatproto.LabelDefs_Label, 0, len(rows))}
for i := range rows {
out.Labels = append(out.Labels, labelToLexicon(&rows[i]))
}
if len(rows) > 0 {
out.Cursor = strconv.FormatInt(rows[len(rows)-1].ID, 10)
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(out)
}
var errInvalidPattern = errors.New("invalid uriPattern")
// queryLabelsSQL builds the WHERE clause from filter args and runs the query. All
// filtering happens in SQL so LIMIT operates on already-filtered rows.
func queryLabelsSQL(db *sql.DB, patterns, sources []string, cursor int64, limit int) ([]Label, error) {
var (
where []string
args []any
)
where = append(where, "id > ?")
args = append(args, cursor)
if len(patterns) > 0 {
var ors []string
var matchAll bool
for _, p := range patterns {
if p == "" {
continue
}
if p == "*" {
matchAll = true
break
}
like, err := patternToLike(p)
if err != nil {
return nil, err
}
if strings.ContainsAny(like, `%_\`) {
ors = append(ors, "uri LIKE ? ESCAPE '\\'")
} else {
ors = append(ors, "uri = ?")
}
args = append(args, like)
}
if !matchAll && len(ors) > 0 {
where = append(where, "("+strings.Join(ors, " OR ")+")")
}
}
if len(sources) > 0 {
placeholders := strings.Repeat("?,", len(sources))
placeholders = placeholders[:len(placeholders)-1]
where = append(where, "src IN ("+placeholders+")")
for _, s := range sources {
args = append(args, s)
}
}
args = append(args, limit)
q := `SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, ver, sig, subject_did, subject_repo
FROM labels WHERE ` + strings.Join(where, " AND ") +
` ORDER BY id ASC LIMIT ?`
rows, err := db.Query(q, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanLabels(rows)
}
// patternToLike converts a uriPattern into a SQLite LIKE expression. The only
// wildcard supported is a trailing `*`, which becomes `%`. Literal `%`, `_`, and `\`
// in the input are escaped via the LIKE ESCAPE clause used at query time.
func patternToLike(p string) (string, error) {
if idx := strings.Index(p, "*"); idx >= 0 && idx != len(p)-1 {
return "", errInvalidPattern
}
literal := p
suffix := ""
if strings.HasSuffix(p, "*") {
literal = p[:len(p)-1]
suffix = "%"
}
literal = strings.ReplaceAll(literal, `\`, `\\`)
literal = strings.ReplaceAll(literal, `%`, `\%`)
literal = strings.ReplaceAll(literal, `_`, `\_`)
return literal + suffix, nil
}