Files
2026-05-11 19:53:13 -05:00

366 lines
11 KiB
Go

package testpds
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"testing"
"time"
"atcr.io/pkg/atproto"
)
// Server is the fake PDS. Each test calls New(t) to spin one up, then
// AddIdentity() to register users. The Directory() must be passed to
// atproto.SetDirectory() so AppView and Hold resolve DIDs through it.
//
// The three identity maps are written by AddIdentity and read by the XRPC
// handlers in this file; every access to them is made while holding mu.
type Server struct {
// t is the test that created this server; New registers cleanup on it.
t *testing.T
// httptest is the listening HTTP server started by New.
httptest *httptest.Server
// dir is the in-memory directory returned by Directory().
dir *Directory
// records backs the com.atproto.repo.* record handlers.
records *recordStore
// blobs backs uploadBlob and sync.getBlob.
blobs *blobStore
// mu guards byDID, byHandle, and byBearer.
mu sync.RWMutex
// byDID indexes registered identities by DID string.
byDID map[string]*Identity
// byHandle indexes registered identities by handle string.
byHandle map[string]*Identity
byBearer map[string]*Identity // accessJwt → identity
didHostEsc string // percent-encoded host:port for synthesized DIDs
}
// New boots a fake PDS on a random local port using httptest.NewServer and
// registers its shutdown with t.Cleanup, so tests never close it by hand.
// Right after calling New, install s.Directory() through
// atproto.SetDirectory() so DID resolution is answered from the in-memory
// store.
func New(t *testing.T) *Server {
	t.Helper()

	srv := &Server{
		t:        t,
		dir:      newDirectory(),
		records:  newRecordStore(),
		blobs:    newBlobStore(),
		byDID:    make(map[string]*Identity),
		byHandle: make(map[string]*Identity),
		byBearer: make(map[string]*Identity),
	}

	// One entry per XRPC method the fake understands.
	routes := []struct {
		path    string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"/xrpc/com.atproto.server.createSession", srv.handleCreateSession},
		{"/xrpc/com.atproto.server.getServiceAuth", srv.handleGetServiceAuth},
		{"/xrpc/com.atproto.repo.putRecord", srv.handlePutRecord},
		{"/xrpc/com.atproto.repo.getRecord", srv.handleGetRecord},
		{"/xrpc/com.atproto.repo.listRecords", srv.handleListRecords},
		{"/xrpc/com.atproto.repo.deleteRecord", srv.handleDeleteRecord},
		{"/xrpc/com.atproto.identity.resolveHandle", srv.handleResolveHandle},
		{"/xrpc/com.atproto.repo.uploadBlob", srv.handleUploadBlob},
		{"/xrpc/com.atproto.sync.getBlob", srv.handleSyncGetBlob},
	}
	router := http.NewServeMux()
	for _, rt := range routes {
		router.HandleFunc(rt.path, rt.handler)
	}

	srv.httptest = httptest.NewServer(router)
	t.Cleanup(srv.httptest.Close)

	// The DID host component is derived from the listener address, which is
	// only known once the server is up.
	hostPort := strings.TrimPrefix(srv.httptest.URL, "http://")
	srv.didHostEsc = didWebForHost(hostPort)
	return srv
}
// URL reports the base URL the fake PDS is listening on
// (http://127.0.0.1:NNNN).
func (s *Server) URL() string {
	base := s.httptest.URL
	return base
}
// Directory exposes the identity.Directory that backs this fake PDS. It must
// be installed with atproto.SetDirectory() before anything resolves a DID.
func (s *Server) Directory() *Directory {
	return s.dir
}
// GetRecord reads a record's raw JSON value straight from the fake PDS's
// record store, bypassing the HTTP XRPC surface. Tests use it to assert that
// a downstream component wrote a record. A miss yields (nil, false).
func (s *Server) GetRecord(did, collection, rkey string) (json.RawMessage, bool) {
	raw, found := s.records.get(did, collection, rkey)
	return raw, found
}
// AddIdentity provisions a new account for handle and returns its Identity
// (DID, handle, signing key, synthetic password / accessJwt). The identity is
// registered with the directory and indexed by DID, handle, and access token
// so that later lookups find it.
func (s *Server) AddIdentity(handle string) (*Identity, error) {
	account, err := newIdentity(s.URL(), s.didHostEsc, handle)
	if err != nil {
		return nil, err
	}

	resolved, err := account.toIndigoIdentity()
	if err != nil {
		return nil, err
	}
	s.dir.Register(resolved)

	didKey := account.DID.String()
	handleKey := account.Handle.String()

	s.mu.Lock()
	defer s.mu.Unlock()
	s.byDID[didKey] = account
	s.byHandle[handleKey] = account
	s.byBearer[account.AccessToken] = account
	return account, nil
}
// --- handlers --------------------------------------------------------------
// handleCreateSession serves com.atproto.server.createSession. It decodes an
// {identifier, password} JSON body, looks the identifier up as either a DID
// or a handle, and on a password match returns the identity's DID, handle,
// access token, and a derived refresh token.
//
// Responses:
//   - 400 InvalidRequest when the body is not valid JSON.
//   - 401 AuthenticationRequired when the identifier is unknown or the
//     password does not match.
//   - 200 with the session payload otherwise.
func (s *Server) handleCreateSession(w http.ResponseWriter, r *http.Request) {
	var body struct {
		Identifier string `json:"identifier"`
		Password   string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
		return
	}

	s.mu.RLock()
	ident := s.lookupIdentifier(body.Identifier)
	s.mu.RUnlock()

	if ident == nil || ident.Password != body.Password {
		// Bad credentials are a plain authentication failure.
		// AuthFactorTokenRequired is the createSession error that asks the
		// client for an email 2FA token, so returning it here would push
		// clients down the wrong path.
		writeXRPCError(w, http.StatusUnauthorized, "AuthenticationRequired", "invalid identifier/password")
		return
	}

	writeJSON(w, http.StatusOK, map[string]any{
		"did":        ident.DID.String(),
		"handle":     ident.Handle.String(),
		"accessJwt":  ident.AccessToken,
		"refreshJwt": ident.AccessToken + "-refresh",
		"active":     true,
	})
}
// handleGetServiceAuth serves com.atproto.server.getServiceAuth. The caller
// must present a known bearer token; the aud, lxm, and exp query parameters
// are all mandatory here. The requested expiry is capped at one hour from
// now before the identity signs the service-auth JWT.
func (s *Server) handleGetServiceAuth(w http.ResponseWriter, r *http.Request) {
	caller := s.authenticate(r)
	if caller == nil {
		writeXRPCError(w, http.StatusUnauthorized, "AuthenticationRequired", "missing or invalid bearer token")
		return
	}

	params := r.URL.Query()
	audience, method, rawExp := params.Get("aud"), params.Get("lxm"), params.Get("exp")
	if audience == "" || method == "" || rawExp == "" {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest",
			"aud, lxm, and exp are required; got "+params.Encode())
		return
	}

	seconds, err := strconv.ParseInt(rawExp, 10, 64)
	if err != nil {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", "exp must be a Unix timestamp")
		return
	}

	// Clamp to <=1h grant, as real PDSes do.
	expiry := time.Unix(seconds, 0)
	if ceiling := time.Now().Add(1 * time.Hour); expiry.After(ceiling) {
		expiry = ceiling
	}

	signed, err := caller.signServiceAuthJWT(audience, method, expiry)
	if err != nil {
		writeXRPCError(w, http.StatusInternalServerError, "InternalServerError", err.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]string{"token": signed})
}
// handlePutRecord serves com.atproto.repo.putRecord. It requires a known
// bearer token, decodes a {repo, collection, rkey, record} JSON body, and
// stores the raw record JSON.
//
// The repo field may be a DID or a registered handle. It is normalized with
// didForRepo before storing, the same way getRecord and listRecords normalize
// their repo parameter, so a record written with a handle can be read back
// through either form and through Server.GetRecord by DID.
//
// Responses:
//   - 401 AuthenticationRequired without a valid bearer token.
//   - 400 InvalidRequest for undecodable JSON or a missing repo, collection,
//     or rkey.
//   - 200 with the record's at:// URI and fake CID otherwise.
func (s *Server) handlePutRecord(w http.ResponseWriter, r *http.Request) {
	if !s.authenticated(r) {
		writeXRPCError(w, http.StatusUnauthorized, "AuthenticationRequired", "")
		return
	}

	var body struct {
		Repo       string          `json:"repo"`
		Collection string          `json:"collection"`
		RKey       string          `json:"rkey"`
		Record     json.RawMessage `json:"record"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
		return
	}
	if body.Repo == "" || body.Collection == "" || body.RKey == "" {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", "repo, collection, rkey required")
		return
	}

	// Store under the DID so the key matches what the read handlers look up.
	// Unregistered repos pass through unchanged.
	repoDID := s.didForRepo(body.Repo)
	s.records.put(repoDID, body.Collection, body.RKey, body.Record)

	writeJSON(w, http.StatusOK, map[string]string{
		"uri": fmt.Sprintf("at://%s/%s/%s", repoDID, body.Collection, body.RKey),
		"cid": fakeCID(body.Record),
	})
}
// handleGetRecord serves com.atproto.repo.getRecord. All three of repo,
// collection, and rkey must be present in the query string. A handle passed
// as repo is translated to its DID before the lookup. No authentication is
// required.
func (s *Server) handleGetRecord(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()
	repo, collection, rkey := params.Get("repo"), params.Get("collection"), params.Get("rkey")
	if repo == "" || collection == "" || rkey == "" {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", "repo, collection, rkey required")
		return
	}

	// Callers may address the repo by handle; the store is keyed by DID.
	owner := s.didForRepo(repo)

	value, found := s.records.get(owner, collection, rkey)
	if !found {
		missing := fmt.Sprintf("%s/%s/%s", owner, collection, rkey)
		writeXRPCError(w, http.StatusNotFound, "RecordNotFound", missing)
		return
	}

	uri := fmt.Sprintf("at://%s/%s/%s", owner, collection, rkey)
	writeJSON(w, http.StatusOK, map[string]any{
		"uri":   uri,
		"cid":   fakeCID(value),
		"value": value,
	})
}
// handleListRecords serves com.atproto.repo.listRecords. It needs repo and
// collection query parameters, resolves a handle to its DID, and returns
// every stored record of that collection as {uri, cid, value} objects under
// the "records" key. An empty collection produces an empty JSON array.
func (s *Server) handleListRecords(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()
	repo, collection := params.Get("repo"), params.Get("collection")
	if repo == "" || collection == "" {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", "repo, collection required")
		return
	}

	owner := s.didForRepo(repo)
	stored := s.records.list(owner, collection)

	// Allocate a non-nil slice so zero results encode as [] rather than null.
	out := make([]map[string]any, 0, len(stored))
	for _, rec := range stored {
		item := map[string]any{
			"uri":   fmt.Sprintf("at://%s/%s/%s", owner, collection, rec.RKey),
			"cid":   fakeCID(rec.Value),
			"value": rec.Value,
		}
		out = append(out, item)
	}

	writeJSON(w, http.StatusOK, map[string]any{"records": out})
}
// handleDeleteRecord serves com.atproto.repo.deleteRecord. It requires a
// known bearer token and a {repo, collection, rkey} JSON body, then removes
// the addressed record from the store.
//
// The repo field may be a DID or a registered handle. It is normalized with
// didForRepo so the delete targets the same key the read handlers use.
// Missing fields are rejected, matching handlePutRecord's validation.
//
// Responses:
//   - 401 AuthenticationRequired without a valid bearer token.
//   - 400 InvalidRequest for undecodable JSON or a missing repo, collection,
//     or rkey.
//   - 200 with an empty JSON object otherwise.
func (s *Server) handleDeleteRecord(w http.ResponseWriter, r *http.Request) {
	if !s.authenticated(r) {
		writeXRPCError(w, http.StatusUnauthorized, "AuthenticationRequired", "")
		return
	}

	var body struct {
		Repo       string `json:"repo"`
		Collection string `json:"collection"`
		RKey       string `json:"rkey"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
		return
	}
	if body.Repo == "" || body.Collection == "" || body.RKey == "" {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", "repo, collection, rkey required")
		return
	}

	// Resolve a handle to its DID; unregistered repos pass through unchanged.
	repoDID := s.didForRepo(body.Repo)
	s.records.delete(repoDID, body.Collection, body.RKey)

	writeJSON(w, http.StatusOK, map[string]string{})
}
// handleUploadBlob serves com.atproto.repo.uploadBlob. With a valid bearer
// token it reads the whole request body, stores it under a fake CID computed
// from the bytes, and answers with a blob reference that echoes the request's
// Content-Type header and the byte length.
func (s *Server) handleUploadBlob(w http.ResponseWriter, r *http.Request) {
	if !s.authenticated(r) {
		writeXRPCError(w, http.StatusUnauthorized, "AuthenticationRequired", "")
		return
	}

	defer r.Body.Close()
	payload, err := io.ReadAll(r.Body)
	if err != nil {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", err.Error())
		return
	}

	ref := fakeCID(payload)
	s.blobs.put(ref, payload)

	blob := map[string]any{
		"$type":    "blob",
		"ref":      map[string]string{"$link": ref},
		"mimeType": r.Header.Get("Content-Type"),
		"size":     len(payload),
	}
	writeJSON(w, http.StatusOK, map[string]any{"blob": blob})
}
// handleSyncGetBlob serves com.atproto.sync.getBlob. It looks the blob up by
// the cid query parameter alone and streams the stored bytes back as
// application/octet-stream. No authentication is required.
func (s *Server) handleSyncGetBlob(w http.ResponseWriter, r *http.Request) {
	ref := r.URL.Query().Get("cid")
	if ref == "" {
		writeXRPCError(w, http.StatusBadRequest, "InvalidRequest", "cid required")
		return
	}

	payload, found := s.blobs.get(ref)
	if !found {
		writeXRPCError(w, http.StatusNotFound, "BlobNotFound", ref)
		return
	}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/octet-stream")
	w.WriteHeader(http.StatusOK)
	// The status line is already sent, so a failed write cannot be reported.
	_, _ = w.Write(payload)
}
// handleResolveHandle serves com.atproto.identity.resolveHandle. It maps the
// handle query parameter to the DID of a registered identity, or answers 404
// HandleNotFound when no identity carries that handle.
func (s *Server) handleResolveHandle(w http.ResponseWriter, r *http.Request) {
	name := r.URL.Query().Get("handle")

	s.mu.RLock()
	owner, known := s.byHandle[name]
	s.mu.RUnlock()

	if known {
		writeJSON(w, http.StatusOK, map[string]string{"did": owner.DID.String()})
		return
	}
	writeXRPCError(w, http.StatusNotFound, "HandleNotFound", name)
}
// --- helpers ---------------------------------------------------------------
// lookupIdentifier maps a login identifier — a DID string or a handle — to
// a registered identity, trying the DID index first and the handle index
// second. It returns nil when neither index has the key. The caller is
// responsible for holding s.mu (a read lock is enough).
func (s *Server) lookupIdentifier(id string) *Identity {
	for _, index := range []map[string]*Identity{s.byDID, s.byHandle} {
		if match, found := index[id]; found {
			return match
		}
	}
	return nil
}
// didForRepo normalizes a repo reference to a DID. Anything that already
// starts with "did:" is returned as is; otherwise the value is treated as a
// handle and swapped for the owning identity's DID. An unregistered handle is
// returned untouched, so the record store reports a miss as expected.
func (s *Server) didForRepo(repo string) string {
	if !strings.HasPrefix(repo, "did:") {
		s.mu.RLock()
		owner, known := s.byHandle[repo]
		s.mu.RUnlock()
		if known {
			return owner.DID.String()
		}
	}
	return repo
}
// authenticate maps the request's "Authorization: Bearer <token>" header to
// the identity that owns the token. The result is nil when the header is
// absent, does not use the Bearer scheme, or carries an unknown token.
func (s *Server) authenticate(r *http.Request) *Identity {
	const scheme = "Bearer "

	header := r.Header.Get("Authorization")
	if !strings.HasPrefix(header, scheme) {
		return nil
	}
	token := header[len(scheme):]

	s.mu.RLock()
	owner := s.byBearer[token]
	s.mu.RUnlock()
	return owner
}
// authenticated reports whether the request carries a bearer token that
// belongs to a registered identity.
func (s *Server) authenticated(r *http.Request) bool {
	if s.authenticate(r) == nil {
		return false
	}
	return true
}
// writeJSON sends body as a JSON response with the given HTTP status and an
// application/json Content-Type.
func writeJSON(w http.ResponseWriter, status int, body any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	enc := json.NewEncoder(w)
	// The status line is already on the wire, so an encoding failure cannot
	// be turned into a different response; it is dropped on purpose.
	_ = enc.Encode(body)
}
// writeXRPCError sends an XRPC-style error response: a JSON object with the
// error name under "error" and the human-readable text under "message".
func writeXRPCError(w http.ResponseWriter, status int, name, msg string) {
	payload := map[string]string{
		"error":   name,
		"message": msg,
	}
	writeJSON(w, status, payload)
}
// Reference atproto.SetDirectory so the atcr.io/pkg/atproto import stays in
// use and this file stops compiling if that function is removed or renamed.
// NOTE(review): this statement does not check that *Directory is assignable
// to SetDirectory's parameter type; a typed assertion of the form
// `var _ <interface> = (*Directory)(nil)` would be needed for that — confirm
// the intended interface before adding one.
var _ = atproto.SetDirectory