Files
at-container-registry/cmd/record-query/main.go

579 lines
14 KiB
Go

// record-query queries the ATProto relay to find all users with records in a given
// collection, fetches the records from each user's PDS, and optionally filters them.
//
// Usage:
//
// go run ./cmd/record-query --collection io.atcr.sailor.profile --filter "defaultHold!=prefix:did:web"
// go run ./cmd/record-query --collection io.atcr.manifest
// go run ./cmd/record-query --collection io.atcr.sailor.profile --limit 5
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"sort"
"strings"
"time"
)
// ListReposByCollectionResponse is the response from com.atproto.sync.listReposByCollection
type ListReposByCollectionResponse struct {
Repos []RepoRef `json:"repos"`
Cursor string `json:"cursor,omitempty"`
}
// RepoRef is a single repo reference
type RepoRef struct {
DID string `json:"did"`
}
// ListRecordsResponse is the response from com.atproto.repo.listRecords
type ListRecordsResponse struct {
Records []Record `json:"records"`
Cursor string `json:"cursor,omitempty"`
}
// Record is a single ATProto record
type Record struct {
URI string `json:"uri"`
CID string `json:"cid"`
Value json.RawMessage `json:"value"`
}
// MatchResult is a record that passed the filter
type MatchResult struct {
DID string
Handle string
URI string
Fields map[string]any
}
// Filter defines a simple field filter
type Filter struct {
Field string
Operator string // "=", "!="
Mode string // "exact", "prefix", "empty"
Value string
}
var client = &http.Client{Timeout: 30 * time.Second}
func main() {
relay := flag.String("relay", "https://relay1.us-east.bsky.network", "Relay endpoint")
collection := flag.String("collection", "io.atcr.sailor.profile", "ATProto collection to query")
filterStr := flag.String("filter", "", "Filter expression: field=value, field!=value, field=prefix:xxx, field!=prefix:xxx, field=empty, field!=empty")
resolve := flag.Bool("resolve", true, "Resolve DIDs to handles")
limit := flag.Int("limit", 0, "Max repos to process (0 = unlimited)")
flag.Parse()
// Parse filter
var filter *Filter
if *filterStr != "" {
var err error
filter, err = parseFilter(*filterStr)
if err != nil {
fmt.Fprintf(os.Stderr, "Invalid filter: %v\n", err)
os.Exit(1)
}
fmt.Printf("Filter: %s %s %s:%s\n", filter.Field, filter.Operator, filter.Mode, filter.Value)
}
fmt.Printf("Relay: %s\n", *relay)
fmt.Printf("Collection: %s\n", *collection)
if *limit > 0 {
fmt.Printf("Limit: %d repos\n", *limit)
}
fmt.Println()
// Step 1: Enumerate all DIDs with records in this collection
fmt.Println("Enumerating repos from relay...")
dids, err := listAllRepos(*relay, *collection, *limit)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to list repos: %v\n", err)
os.Exit(1)
}
fmt.Printf("Found %d repos with %s records\n\n", len(dids), *collection)
// Step 2: For each DID, fetch records and apply filter
fmt.Println("Fetching records from each user's PDS...")
var results []MatchResult
errorsByCategory := make(map[string][]string) // category -> list of DIDs
for i, did := range dids {
totalErrors := 0
for _, v := range errorsByCategory {
totalErrors += len(v)
}
if (i+1)%10 == 0 || i == len(dids)-1 {
fmt.Printf(" Progress: %d/%d repos (matches: %d, errors: %d)\r", i+1, len(dids), len(results), totalErrors)
}
matches, err := fetchAndFilter(did, *collection, filter)
if err != nil {
cat := categorizeError(err)
errorsByCategory[cat] = append(errorsByCategory[cat], did)
continue
}
results = append(results, matches...)
}
totalErrors := 0
for _, v := range errorsByCategory {
totalErrors += len(v)
}
fmt.Printf(" Progress: %d/%d repos (matches: %d, errors: %d)\n", len(dids), len(dids), len(results), totalErrors)
if len(errorsByCategory) > 0 {
fmt.Println(" Error breakdown:")
var cats []string
for k := range errorsByCategory {
cats = append(cats, k)
}
sort.Strings(cats)
for _, cat := range cats {
dids := errorsByCategory[cat]
fmt.Printf(" %s (%d):\n", cat, len(dids))
for _, did := range dids {
fmt.Printf(" - %s\n", did)
}
}
}
fmt.Println()
// Step 3: Resolve DIDs to handles
if *resolve && len(results) > 0 {
fmt.Println("Resolving DIDs to handles...")
handleCache := make(map[string]string)
for i := range results {
did := results[i].DID
if h, ok := handleCache[did]; ok {
results[i].Handle = h
continue
}
handle, err := resolveDIDToHandle(did)
if err != nil {
handle = did
}
handleCache[did] = handle
results[i].Handle = handle
}
fmt.Println()
}
// Step 4: Print results
if len(results) == 0 {
fmt.Println("No matching records found.")
return
}
// Sort by handle/DID for consistent output
sort.Slice(results, func(i, j int) bool {
return results[i].Handle < results[j].Handle
})
fmt.Println("========================================")
fmt.Printf("RESULTS (%d matches)\n", len(results))
fmt.Println("========================================")
for i, r := range results {
identity := r.Handle
if identity == "" {
identity = r.DID
}
fmt.Printf("\n%3d. %s\n", i+1, identity)
if r.Handle != "" && r.Handle != r.DID {
fmt.Printf(" DID: %s\n", r.DID)
}
fmt.Printf(" URI: %s\n", r.URI)
// Print interesting fields (skip $type, createdAt, updatedAt)
for k, v := range r.Fields {
if k == "$type" || k == "createdAt" || k == "updatedAt" {
continue
}
fmt.Printf(" %s: %v\n", k, v)
}
}
// CSV output
fmt.Println("\n========================================")
fmt.Println("CSV FORMAT")
fmt.Println("========================================")
// Collect all field names for CSV header
fieldSet := make(map[string]bool)
for _, r := range results {
for k := range r.Fields {
if k == "$type" || k == "createdAt" || k == "updatedAt" {
continue
}
fieldSet[k] = true
}
}
var fieldNames []string
for k := range fieldSet {
fieldNames = append(fieldNames, k)
}
sort.Strings(fieldNames)
// Header
fmt.Printf("handle,did,uri")
for _, f := range fieldNames {
fmt.Printf(",%s", f)
}
fmt.Println()
// Rows
for _, r := range results {
identity := r.Handle
if identity == "" {
identity = r.DID
}
fmt.Printf("%s,%s,%s", identity, r.DID, r.URI)
for _, f := range fieldNames {
val := ""
if v, ok := r.Fields[f]; ok {
val = fmt.Sprintf("%v", v)
}
// Escape commas in values
if strings.Contains(val, ",") {
val = "\"" + val + "\""
}
fmt.Printf(",%s", val)
}
fmt.Println()
}
}
// parseFilter parses a filter string like "field!=prefix:did:web"
func parseFilter(s string) (*Filter, error) {
f := &Filter{}
// Check for != first (before =)
if idx := strings.Index(s, "!="); idx > 0 {
f.Field = s[:idx]
f.Operator = "!="
s = s[idx+2:]
} else if idx := strings.Index(s, "="); idx > 0 {
f.Field = s[:idx]
f.Operator = "="
s = s[idx+1:]
} else {
return nil, fmt.Errorf("expected field=value or field!=value, got %q", s)
}
// Check for mode prefix
if s == "empty" {
f.Mode = "empty"
f.Value = ""
} else if strings.HasPrefix(s, "prefix:") {
f.Mode = "prefix"
f.Value = strings.TrimPrefix(s, "prefix:")
} else {
f.Mode = "exact"
f.Value = s
}
return f, nil
}
// matchFilter checks if a record's fields match the filter
func matchFilter(fields map[string]any, filter *Filter) bool {
if filter == nil {
return true
}
val := ""
if v, ok := fields[filter.Field]; ok {
val = fmt.Sprintf("%v", v)
}
switch filter.Mode {
case "empty":
isEmpty := val == "" || val == "<nil>"
if filter.Operator == "=" {
return isEmpty
}
return !isEmpty
case "prefix":
hasPrefix := strings.HasPrefix(val, filter.Value)
if filter.Operator == "=" {
return hasPrefix
}
return !hasPrefix && val != "" && val != "<nil>"
case "exact":
if filter.Operator == "=" {
return val == filter.Value
}
return val != filter.Value
}
return true
}
// categorizeError classifies an error into a human-readable category
func categorizeError(err error) string {
s := err.Error()
// HTTP status codes
for _, code := range []string{"400", "401", "403", "404", "410", "429", "500", "502", "503"} {
if strings.Contains(s, "status "+code) {
switch code {
case "400":
if strings.Contains(s, "RepoDeactivated") || strings.Contains(s, "deactivated") {
return "deactivated (400)"
}
if strings.Contains(s, "RepoTakendown") || strings.Contains(s, "takendown") {
return "takendown (400)"
}
if strings.Contains(s, "RepoNotFound") || strings.Contains(s, "Could not find repo") {
return "repo not found (400)"
}
return "bad request (400)"
case "401":
return "unauthorized (401)"
case "404":
return "not found (404)"
case "410":
return "gone/deleted (410)"
case "429":
return "rate limited (429)"
case "502":
return "bad gateway (502)"
case "503":
return "unavailable (503)"
default:
return fmt.Sprintf("HTTP %s", code)
}
}
}
// Connection errors
if strings.Contains(s, "connection refused") {
return "connection refused"
}
if strings.Contains(s, "no such host") {
return "DNS failure"
}
if strings.Contains(s, "timeout") || strings.Contains(s, "deadline exceeded") {
return "timeout"
}
if strings.Contains(s, "TLS") || strings.Contains(s, "certificate") {
return "TLS error"
}
if strings.Contains(s, "EOF") {
return "connection reset"
}
// PLC/DID errors
if strings.Contains(s, "no PDS found") {
return "no PDS in DID doc"
}
if strings.Contains(s, "unsupported DID method") {
return "unsupported DID method"
}
return "other: " + s
}
// listAllRepos paginates through the relay to get all DIDs with records in a collection
func listAllRepos(relayURL, collection string, limit int) ([]string, error) {
var dids []string
cursor := ""
for {
u := fmt.Sprintf("%s/xrpc/com.atproto.sync.listReposByCollection", relayURL)
params := url.Values{}
params.Set("collection", collection)
params.Set("limit", "1000")
if cursor != "" {
params.Set("cursor", cursor)
}
resp, err := client.Get(u + "?" + params.Encode())
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("status %d: %s", resp.StatusCode, string(body))
}
var result ListReposByCollectionResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
resp.Body.Close()
return nil, fmt.Errorf("decode failed: %w", err)
}
resp.Body.Close()
for _, repo := range result.Repos {
dids = append(dids, repo.DID)
}
fmt.Printf(" Fetched %d repos so far...\r", len(dids))
if limit > 0 && len(dids) >= limit {
dids = dids[:limit]
break
}
if result.Cursor == "" {
break
}
cursor = result.Cursor
}
fmt.Println()
return dids, nil
}
// fetchAndFilter fetches records for a DID and returns those matching the filter
func fetchAndFilter(did, collection string, filter *Filter) ([]MatchResult, error) {
// Resolve DID to PDS
pdsEndpoint, err := resolveDIDToPDS(did)
if err != nil {
return nil, fmt.Errorf("resolve PDS: %w", err)
}
var results []MatchResult
cursor := ""
for {
u := fmt.Sprintf("%s/xrpc/com.atproto.repo.listRecords", pdsEndpoint)
params := url.Values{}
params.Set("repo", did)
params.Set("collection", collection)
params.Set("limit", "100")
if cursor != "" {
params.Set("cursor", cursor)
}
resp, err := client.Get(u + "?" + params.Encode())
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("status %d", resp.StatusCode)
}
var listResp ListRecordsResponse
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
resp.Body.Close()
return nil, fmt.Errorf("decode failed: %w", err)
}
resp.Body.Close()
for _, rec := range listResp.Records {
var fields map[string]any
if err := json.Unmarshal(rec.Value, &fields); err != nil {
continue
}
if matchFilter(fields, filter) {
results = append(results, MatchResult{
DID: did,
URI: rec.URI,
Fields: fields,
})
}
}
if listResp.Cursor == "" || len(listResp.Records) < 100 {
break
}
cursor = listResp.Cursor
}
return results, nil
}
// resolveDIDToHandle resolves a DID to a handle using the PLC directory or did:web
func resolveDIDToHandle(did string) (string, error) {
if strings.HasPrefix(did, "did:web:") {
return strings.TrimPrefix(did, "did:web:"), nil
}
if strings.HasPrefix(did, "did:plc:") {
resp, err := client.Get("https://plc.directory/" + did)
if err != nil {
return "", fmt.Errorf("PLC query failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("PLC returned status %d", resp.StatusCode)
}
var plcDoc struct {
AlsoKnownAs []string `json:"alsoKnownAs"`
}
if err := json.NewDecoder(resp.Body).Decode(&plcDoc); err != nil {
return "", fmt.Errorf("failed to parse PLC response: %w", err)
}
for _, aka := range plcDoc.AlsoKnownAs {
if strings.HasPrefix(aka, "at://") {
return strings.TrimPrefix(aka, "at://"), nil
}
}
return did, nil
}
return did, nil
}
// resolveDIDToPDS resolves a DID to its PDS endpoint
func resolveDIDToPDS(did string) (string, error) {
if strings.HasPrefix(did, "did:web:") {
domain := strings.TrimPrefix(did, "did:web:")
domain = strings.ReplaceAll(domain, "%3A", ":")
scheme := "https"
if strings.Contains(domain, ":") {
scheme = "http"
}
return scheme + "://" + domain, nil
}
if strings.HasPrefix(did, "did:plc:") {
resp, err := client.Get("https://plc.directory/" + did)
if err != nil {
return "", fmt.Errorf("PLC query failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("PLC returned status %d", resp.StatusCode)
}
var plcDoc struct {
Service []struct {
ID string `json:"id"`
Type string `json:"type"`
ServiceEndpoint string `json:"serviceEndpoint"`
} `json:"service"`
}
if err := json.NewDecoder(resp.Body).Decode(&plcDoc); err != nil {
return "", fmt.Errorf("failed to parse PLC response: %w", err)
}
for _, svc := range plcDoc.Service {
if svc.Type == "AtprotoPersonalDataServer" {
return svc.ServiceEndpoint, nil
}
}
return "", fmt.Errorf("no PDS found in DID document")
}
return "", fmt.Errorf("unsupported DID method: %s", did)
}