Files
at-container-registry/pkg/atproto/client_test.go
2026-01-14 23:18:35 -06:00

1046 lines
30 KiB
Go

package atproto
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// TestNewClient checks that NewClient records the PDS endpoint and DID it was
// given and installs a *BasicAuthClientProvider as the client provider.
func TestNewClient(t *testing.T) {
	const (
		wantEndpoint = "https://pds.example.com"
		wantDID      = "did:plc:test123"
	)
	c := NewClient(wantEndpoint, wantDID, "token123")

	if got := c.pdsEndpoint; got != wantEndpoint {
		t.Errorf("pdsEndpoint = %v, want https://pds.example.com", got)
	}
	if got := c.did; got != wantDID {
		t.Errorf("did = %v, want did:plc:test123", got)
	}

	// The provider must be set and must be the Basic Auth implementation.
	provider := c.clientProvider
	if provider == nil {
		t.Error("clientProvider should not be nil")
	}
	if _, isBasic := provider.(*BasicAuthClientProvider); !isBasic {
		t.Errorf("clientProvider should be *BasicAuthClientProvider, got %T", provider)
	}
}
// TestPutRecord drives Client.PutRecord against a stub PDS.
//
// The stub checks the outgoing request:
//   - it is a POST to RepoPutRecord;
//   - it carries a "Bearer " Authorization header;
//   - its JSON body contains the expected repo, collection and rkey.
//
// Each case then checks either the returned *Record or that an error was
// reported.
func TestPutRecord(t *testing.T) {
	const repoDID = "did:plc:test123"

	cases := []struct {
		name           string
		collection     string
		rkey           string
		record         any
		serverResponse string
		serverStatus   int
		wantErr        bool
		checkFunc      func(*testing.T, *Record)
	}{
		{
			name:       "successful put",
			collection: ManifestCollection,
			rkey:       "abc123",
			record: map[string]string{
				"$type": ManifestCollection,
				"test":  "value",
			},
			serverResponse: `{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafytest"}`,
			serverStatus:   http.StatusOK,
			wantErr:        false,
			checkFunc: func(t *testing.T, rec *Record) {
				if rec.URI != "at://did:plc:test123/io.atcr.manifest/abc123" {
					t.Errorf("URI = %v, want at://did:plc:test123/io.atcr.manifest/abc123", rec.URI)
				}
				if rec.CID != "bafytest" {
					t.Errorf("CID = %v, want bafytest", rec.CID)
				}
			},
		},
		{
			name:           "server error",
			collection:     ManifestCollection,
			rkey:           "abc123",
			record:         map[string]string{"test": "value"},
			serverResponse: `{"error":"InvalidRequest"}`,
			serverStatus:   http.StatusBadRequest,
			wantErr:        true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			handler := func(w http.ResponseWriter, req *http.Request) {
				if req.Method != "POST" {
					t.Errorf("Method = %v, want POST", req.Method)
				}
				if req.URL.Path != RepoPutRecord {
					t.Errorf("Path = %v, want %v", req.URL.Path, RepoPutRecord)
				}
				if authz := req.Header.Get("Authorization"); !strings.HasPrefix(authz, "Bearer ") {
					t.Errorf("Authorization header missing or malformed: %v", authz)
				}

				var payload map[string]any
				if decodeErr := json.NewDecoder(req.Body).Decode(&payload); decodeErr != nil {
					t.Errorf("Failed to decode request body: %v", decodeErr)
				}
				if got := payload["repo"]; got != repoDID {
					t.Errorf("repo = %v, want did:plc:test123", got)
				}
				if got := payload["collection"]; got != tc.collection {
					t.Errorf("collection = %v, want %v", got, tc.collection)
				}
				if got := payload["rkey"]; got != tc.rkey {
					t.Errorf("rkey = %v, want %v", got, tc.rkey)
				}

				w.WriteHeader(tc.serverStatus)
				w.Write([]byte(tc.serverResponse))
			}
			stub := httptest.NewServer(http.HandlerFunc(handler))
			defer stub.Close()

			c := NewClient(stub.URL, repoDID, "test-token")
			rec, err := c.PutRecord(context.Background(), tc.collection, tc.rkey, tc.record)

			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("PutRecord() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr || tc.checkFunc == nil {
				return
			}
			tc.checkFunc(t, rec)
		})
	}
}
// TestGetRecord drives Client.GetRecord against a stub PDS. The stub expects a
// GET to RepoGetRecord with repo/collection/rkey query parameters.
//
// Cases covered:
//   - a successful fetch, checking the URI and the decoded $type;
//   - a bare 404;
//   - a 400 whose body names RecordNotFound.
//
// Both failure cases must yield ErrRecordNotFound.
func TestGetRecord(t *testing.T) {
	const repoDID = "did:plc:test123"

	cases := []struct {
		name           string
		collection     string
		rkey           string
		serverResponse string
		serverStatus   int
		wantErr        bool
		wantNotFound   bool
		checkFunc      func(*testing.T, *Record)
	}{
		{
			name:           "successful get",
			collection:     ManifestCollection,
			rkey:           "abc123",
			serverResponse: `{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafytest","value":{"$type":"io.atcr.manifest","repository":"myapp"}}`,
			serverStatus:   http.StatusOK,
			wantErr:        false,
			checkFunc: func(t *testing.T, rec *Record) {
				if rec.URI != "at://did:plc:test123/io.atcr.manifest/abc123" {
					t.Errorf("URI = %v, want at://did:plc:test123/io.atcr.manifest/abc123", rec.URI)
				}
				var decoded map[string]any
				if err := json.Unmarshal(rec.Value, &decoded); err != nil {
					t.Errorf("Failed to unmarshal value: %v", err)
				}
				if got := decoded["$type"]; got != ManifestCollection {
					t.Errorf("value.$type = %v, want %v", got, ManifestCollection)
				}
			},
		},
		{
			name:           "record not found - 404",
			collection:     ManifestCollection,
			rkey:           "notfound",
			serverResponse: ``,
			serverStatus:   http.StatusNotFound,
			wantErr:        true,
			wantNotFound:   true,
		},
		{
			name:           "record not found - error message",
			collection:     ManifestCollection,
			rkey:           "notfound",
			serverResponse: `{"error":"RecordNotFound","message":"Record not found"}`,
			serverStatus:   http.StatusBadRequest,
			wantErr:        true,
			wantNotFound:   true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			handler := func(w http.ResponseWriter, req *http.Request) {
				if req.Method != "GET" {
					t.Errorf("Method = %v, want GET", req.Method)
				}
				if req.URL.Path != RepoGetRecord {
					t.Errorf("Path = %v, want %v", req.URL.Path, RepoGetRecord)
				}

				params := req.URL.Query()
				if got := params.Get("repo"); got != repoDID {
					t.Errorf("repo = %v, want did:plc:test123", got)
				}
				if got := params.Get("collection"); got != tc.collection {
					t.Errorf("collection = %v, want %v", got, tc.collection)
				}
				if got := params.Get("rkey"); got != tc.rkey {
					t.Errorf("rkey = %v, want %v", got, tc.rkey)
				}

				w.WriteHeader(tc.serverStatus)
				w.Write([]byte(tc.serverResponse))
			}
			stub := httptest.NewServer(http.HandlerFunc(handler))
			defer stub.Close()

			c := NewClient(stub.URL, repoDID, "test-token")
			rec, err := c.GetRecord(context.Background(), tc.collection, tc.rkey)

			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("GetRecord() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			// Not-found responses must map onto the package sentinel.
			if tc.wantNotFound && err != ErrRecordNotFound {
				t.Errorf("Expected ErrRecordNotFound, got %v", err)
			}
			if tc.wantErr || tc.checkFunc == nil {
				return
			}
			tc.checkFunc(t, rec)
		})
	}
}
// TestDeleteRecord drives Client.DeleteRecord against a stub PDS. The stub
// expects a POST to RepoDeleteRecord whose JSON body carries repo, collection
// and rkey. Each case checks only whether an error was returned.
func TestDeleteRecord(t *testing.T) {
	const repoDID = "did:plc:test123"

	cases := []struct {
		name           string
		collection     string
		rkey           string
		serverResponse string
		serverStatus   int
		wantErr        bool
	}{
		{
			name:           "successful delete",
			collection:     ManifestCollection,
			rkey:           "abc123",
			serverResponse: `{}`,
			serverStatus:   http.StatusOK,
			wantErr:        false,
		},
		{
			name:           "server error",
			collection:     ManifestCollection,
			rkey:           "abc123",
			serverResponse: `{"error":"InvalidRequest"}`,
			serverStatus:   http.StatusBadRequest,
			wantErr:        true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			handler := func(w http.ResponseWriter, req *http.Request) {
				if req.Method != "POST" {
					t.Errorf("Method = %v, want POST", req.Method)
				}
				if req.URL.Path != RepoDeleteRecord {
					t.Errorf("Path = %v, want %v", req.URL.Path, RepoDeleteRecord)
				}

				var payload map[string]any
				if decodeErr := json.NewDecoder(req.Body).Decode(&payload); decodeErr != nil {
					t.Errorf("Failed to decode request body: %v", decodeErr)
				}
				if got := payload["repo"]; got != repoDID {
					t.Errorf("repo = %v, want did:plc:test123", got)
				}
				if got := payload["collection"]; got != tc.collection {
					t.Errorf("collection = %v, want %v", got, tc.collection)
				}
				if got := payload["rkey"]; got != tc.rkey {
					t.Errorf("rkey = %v, want %v", got, tc.rkey)
				}

				w.WriteHeader(tc.serverStatus)
				w.Write([]byte(tc.serverResponse))
			}
			stub := httptest.NewServer(http.HandlerFunc(handler))
			defer stub.Close()

			c := NewClient(stub.URL, repoDID, "test-token")
			err := c.DeleteRecord(context.Background(), tc.collection, tc.rkey)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("DeleteRecord() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}
// TestListRecords checks that Client.ListRecords forwards the client's DID as
// repo, along with the collection and limit, and decodes every record in the
// response.
func TestListRecords(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Verify query parameters
		query := r.URL.Query()
		if query.Get("repo") != "did:plc:test123" {
			t.Errorf("repo = %v, want did:plc:test123", query.Get("repo"))
		}
		if query.Get("collection") != ManifestCollection {
			t.Errorf("collection = %v, want %v", query.Get("collection"), ManifestCollection)
		}
		if query.Get("limit") != "10" {
			t.Errorf("limit = %v, want 10", query.Get("limit"))
		}
		// Send response
		response := `{
"records": [
{"uri":"at://did:plc:test123/io.atcr.manifest/abc1","cid":"bafytest1","value":{"$type":"io.atcr.manifest"}},
{"uri":"at://did:plc:test123/io.atcr.manifest/abc2","cid":"bafytest2","value":{"$type":"io.atcr.manifest"}}
]
}`
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(response))
	}))
	defer server.Close()

	client := NewClient(server.URL, "did:plc:test123", "test-token")
	records, err := client.ListRecords(context.Background(), ManifestCollection, 10)
	if err != nil {
		t.Fatalf("ListRecords() error = %v", err)
	}

	wantURIs := []string{
		"at://did:plc:test123/io.atcr.manifest/abc1",
		"at://did:plc:test123/io.atcr.manifest/abc2",
	}
	// Fatal, not Error: the loop below indexes records and would panic on a
	// short slice.
	if len(records) != len(wantURIs) {
		t.Fatalf("len(records) = %v, want %d", len(records), len(wantURIs))
	}
	for i, want := range wantURIs {
		if records[i].URI != want {
			t.Errorf("records[%d].URI = %v, want %v", i, records[i].URI, want)
		}
	}
}
// TestUploadBlob verifies the request and response handling of
// Client.UploadBlob. The request must be a POST to RepoUploadBlob whose
// Content-Type is the caller's MIME type. The blob reference (type, CID link,
// size) must be decoded from the stub's response.
func TestUploadBlob(t *testing.T) {
	content := []byte("test blob content")
	mime := "application/octet-stream"

	stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.Method != "POST" {
			t.Errorf("Method = %v, want POST", req.Method)
		}
		if req.URL.Path != RepoUploadBlob {
			t.Errorf("Path = %v, want %s", req.URL.Path, RepoUploadBlob)
		}
		if ct := req.Header.Get("Content-Type"); ct != mime {
			t.Errorf("Content-Type = %v, want %v", ct, mime)
		}

		body := `{
"blob": {
"$type": "blob",
"ref": {"$link": "bafytest123"},
"mimeType": "application/octet-stream",
"size": 17
}
}`
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(body))
	}))
	defer stub.Close()

	c := NewClient(stub.URL, "did:plc:test123", "test-token")
	ref, err := c.UploadBlob(context.Background(), content, mime)
	if err != nil {
		t.Fatalf("UploadBlob() error = %v", err)
	}

	if ref.Type != "blob" {
		t.Errorf("Type = %v, want blob", ref.Type)
	}
	if ref.Ref.Link != "bafytest123" {
		t.Errorf("Ref.Link = %v, want bafytest123", ref.Ref.Link)
	}
	if ref.Size != 17 {
		t.Errorf("Size = %v, want 17", ref.Size)
	}
}
// TestGetBlob checks Client.GetBlob against a stub that expects did and cid
// query parameters. A 200 response must yield the raw body bytes; a 404 must
// surface as an error.
func TestGetBlob(t *testing.T) {
	cases := []struct {
		name           string
		cid            string
		serverResponse string
		contentType    string
		wantData       []byte
		wantErr        bool
	}{
		{
			name:           "raw blob response",
			cid:            "bafytest123",
			serverResponse: "test blob content",
			contentType:    "application/octet-stream",
			wantData:       []byte("test blob content"),
			wantErr:        false,
		},
		{
			name:           "blob not found",
			cid:            "notfound",
			serverResponse: "",
			contentType:    "text/plain",
			wantData:       nil,
			wantErr:        true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
				params := req.URL.Query()
				if got := params.Get("did"); got != "did:plc:test123" {
					t.Errorf("did = %v, want did:plc:test123", got)
				}
				if got := params.Get("cid"); got != tc.cid {
					t.Errorf("cid = %v, want %v", got, tc.cid)
				}

				// Error cases answer with a bare 404; success cases stream the blob.
				if tc.wantErr {
					w.WriteHeader(http.StatusNotFound)
					return
				}
				w.Header().Set("Content-Type", tc.contentType)
				w.WriteHeader(http.StatusOK)
				w.Write([]byte(tc.serverResponse))
			}))
			defer stub.Close()

			c := NewClient(stub.URL, "did:plc:test123", "test-token")
			data, err := c.GetBlob(context.Background(), tc.cid)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("GetBlob() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr {
				return
			}
			if got, want := string(data), string(tc.wantData); got != want {
				t.Errorf("GetBlob() data = %v, want %v", got, want)
			}
		})
	}
}
// TestBlobCDNURL checks that BlobCDNURL produces
// https://imgs.blue/<did-or-handle>/<cid> for both a DID and a handle.
func TestBlobCDNURL(t *testing.T) {
	cases := []struct {
		name        string
		didOrHandle string
		cid         string
		want        string
	}{
		{
			name:        "with DID",
			didOrHandle: "did:plc:alice123",
			cid:         "bafytest123",
			want:        "https://imgs.blue/did:plc:alice123/bafytest123",
		},
		{
			name:        "with handle",
			didOrHandle: "alice.bsky.social",
			cid:         "bafytest456",
			want:        "https://imgs.blue/alice.bsky.social/bafytest456",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := BlobCDNURL(tc.didOrHandle, tc.cid); got != tc.want {
				t.Errorf("BlobCDNURL() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestFetchDIDDocument points Client.FetchDIDDocument at a stub URL. A valid
// document must be parsed (ID and the single PDS service entry are checked).
// A 404 or a non-JSON body must produce an error.
func TestFetchDIDDocument(t *testing.T) {
	cases := []struct {
		name           string
		serverResponse string
		serverStatus   int
		wantErr        bool
		checkFunc      func(*testing.T, *DIDDocument)
	}{
		{
			name: "valid DID document",
			serverResponse: `{
"@context": ["https://www.w3.org/ns/did/v1"],
"id": "did:web:example.com",
"service": [
{
"id": "#atproto_pds",
"type": "AtprotoPersonalDataServer",
"serviceEndpoint": "https://pds.example.com"
}
]
}`,
			serverStatus: http.StatusOK,
			wantErr:      false,
			checkFunc: func(t *testing.T, doc *DIDDocument) {
				if doc.ID != "did:web:example.com" {
					t.Errorf("ID = %v, want did:web:example.com", doc.ID)
				}
				if n := len(doc.Service); n != 1 {
					t.Fatalf("len(Service) = %v, want 1", n)
				}
				svc := doc.Service[0]
				if svc.Type != "AtprotoPersonalDataServer" {
					t.Errorf("Service[0].Type = %v", svc.Type)
				}
				if svc.ServiceEndpoint != "https://pds.example.com" {
					t.Errorf("Service[0].ServiceEndpoint = %v", svc.ServiceEndpoint)
				}
			},
		},
		{
			name:           "404 not found",
			serverResponse: "",
			serverStatus:   http.StatusNotFound,
			wantErr:        true,
		},
		{
			name:           "invalid JSON",
			serverResponse: "not json",
			serverStatus:   http.StatusOK,
			wantErr:        true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(tc.serverStatus)
				w.Write([]byte(tc.serverResponse))
			}))
			defer stub.Close()

			// The client's own PDS is irrelevant here; the document URL is passed explicitly.
			c := NewClient("https://pds.example.com", "did:plc:test123", "")
			doc, err := c.FetchDIDDocument(context.Background(), stub.URL)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("FetchDIDDocument() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr || tc.checkFunc == nil {
				return
			}
			tc.checkFunc(t, doc)
		})
	}
}
// TestClientWithEmptyToken checks that a client built with an empty token
// sends no Authorization header on a GetRecord call.
func TestClientWithEmptyToken(t *testing.T) {
	stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if authz := req.Header.Get("Authorization"); authz != "" {
			t.Errorf("Authorization header should not be set with empty token, got: %v", authz)
		}
		body := `{"uri":"at://did:plc:test123/io.atcr.manifest/abc123","cid":"bafytest","value":{}}`
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(body))
	}))
	defer stub.Close()

	anon := NewClient(stub.URL, "did:plc:test123", "")
	if _, err := anon.GetRecord(context.Background(), ManifestCollection, "abc123"); err != nil {
		t.Fatalf("GetRecord() error = %v", err)
	}
}
// TestListRecordsForRepo checks that Client.ListRecordsForRepo targets the
// repo passed by the caller rather than the client's own DID. It must forward
// collection, limit and cursor, and return the server's next-page cursor.
func TestListRecordsForRepo(t *testing.T) {
	const otherRepo = "did:plc:alice123"

	stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		q := req.URL.Query()
		if got := q.Get("repo"); got != otherRepo {
			t.Errorf("repo = %v, want did:plc:alice123", got)
		}
		if got := q.Get("collection"); got != ManifestCollection {
			t.Errorf("collection = %v, want %v", got, ManifestCollection)
		}
		if got := q.Get("limit"); got != "50" {
			t.Errorf("limit = %v, want 50", got)
		}
		if got := q.Get("cursor"); got != "cursor123" {
			t.Errorf("cursor = %v, want cursor123", got)
		}

		body := `{
"records": [
{"uri":"at://did:plc:alice123/io.atcr.manifest/abc1","cid":"bafytest1","value":{}}
],
"cursor": "nextcursor456"
}`
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(body))
	}))
	defer stub.Close()

	c := NewClient(stub.URL, "did:plc:test123", "test-token")
	records, next, err := c.ListRecordsForRepo(context.Background(), otherRepo, ManifestCollection, 50, "cursor123")
	if err != nil {
		t.Fatalf("ListRecordsForRepo() error = %v", err)
	}

	if n := len(records); n != 1 {
		t.Errorf("len(records) = %v, want 1", n)
	}
	if next != "nextcursor456" {
		t.Errorf("cursor = %v, want nextcursor456", next)
	}
}
// TestContextCancellation checks that GetRecord returns an error when called
// with a context that was cancelled before the call. The stub handler sleeps
// briefly before replying.
func TestContextCancellation(t *testing.T) {
	slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		time.Sleep(100 * time.Millisecond)
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{}`))
	}))
	defer slow.Close()

	c := NewClient(slow.URL, "did:plc:test123", "test-token")

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancelled up front, before any request is issued

	if _, err := c.GetRecord(ctx, ManifestCollection, "abc123"); err == nil {
		t.Error("Expected error due to context cancellation, got nil")
	}
}
// TestListReposByCollection exercises Client.ListReposByCollection against a
// stub server. The stub checks that the collection, limit and cursor query
// parameters are forwarded. Each case then validates the decoded result, or,
// for a non-2xx status, that an error is returned.
func TestListReposByCollection(t *testing.T) {
	tests := []struct {
		name       string
		collection string
		limit      int
		// wantLimit is the decimal text the "limit" query parameter must carry
		// when limit > 0. It is spelled out explicitly because converting an
		// int via string(rune(n)) yields a single code point (e.g. "d" for
		// 100), not its digits.
		wantLimit      string
		cursor         string
		serverResponse string
		serverStatus   int
		wantErr        bool
		checkFunc      func(*testing.T, *ListReposByCollectionResult)
	}{
		{
			name:       "successful list with results",
			collection: ManifestCollection,
			limit:      100,
			wantLimit:  "100",
			cursor:     "",
			serverResponse: `{
"repos": [
{"did": "did:plc:alice123"},
{"did": "did:plc:bob456"}
],
"cursor": "nextcursor789"
}`,
			serverStatus: http.StatusOK,
			wantErr:      false,
			checkFunc: func(t *testing.T, result *ListReposByCollectionResult) {
				// Fatal so the index below cannot panic on a short slice.
				if len(result.Repos) != 2 {
					t.Fatalf("len(Repos) = %v, want 2", len(result.Repos))
				}
				if result.Repos[0].DID != "did:plc:alice123" {
					t.Errorf("Repos[0].DID = %v, want did:plc:alice123", result.Repos[0].DID)
				}
				if result.Cursor != "nextcursor789" {
					t.Errorf("Cursor = %v, want nextcursor789", result.Cursor)
				}
			},
		},
		{
			name:           "empty results",
			collection:     ManifestCollection,
			limit:          50,
			wantLimit:      "50",
			cursor:         "cursor123",
			serverResponse: `{"repos": []}`,
			serverStatus:   http.StatusOK,
			wantErr:        false,
			checkFunc: func(t *testing.T, result *ListReposByCollectionResult) {
				if len(result.Repos) != 0 {
					t.Errorf("len(Repos) = %v, want 0", len(result.Repos))
				}
			},
		},
		{
			name:           "server error",
			collection:     ManifestCollection,
			limit:          100,
			wantLimit:      "100",
			cursor:         "",
			serverResponse: `{"error":"InternalError"}`,
			serverStatus:   http.StatusInternalServerError,
			wantErr:        true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				query := r.URL.Query()
				if got := query.Get("collection"); got != tt.collection {
					t.Errorf("collection = %v, want %v", got, tt.collection)
				}
				if tt.limit > 0 {
					if got := query.Get("limit"); got != tt.wantLimit {
						t.Errorf("limit = %q, want %q", got, tt.wantLimit)
					}
				}
				// Get returns "" for an absent parameter, so an empty tt.cursor
				// also asserts that no stray cursor value was sent.
				if got := query.Get("cursor"); got != tt.cursor {
					t.Errorf("cursor = %v, want %v", got, tt.cursor)
				}

				w.WriteHeader(tt.serverStatus)
				w.Write([]byte(tt.serverResponse))
			}))
			defer server.Close()

			client := NewClient(server.URL, "did:plc:test123", "test-token")
			result, err := client.ListReposByCollection(context.Background(), tt.collection, tt.limit, tt.cursor)
			if (err != nil) != tt.wantErr {
				t.Errorf("ListReposByCollection() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !tt.wantErr && tt.checkFunc != nil {
				tt.checkFunc(t, result)
			}
		})
	}
}
// TestGetActorProfile checks Client.GetActorProfile against a stub. The stub
// expects the actor query parameter and a path containing
// app.bsky.actor.getProfile.
//
// Cases covered:
//   - lookup by handle, with profile fields decoded;
//   - lookup by DID, with profile fields decoded;
//   - a 404, which must be reported as an error;
//   - a 500, which must be reported as an error.
func TestGetActorProfile(t *testing.T) {
	cases := []struct {
		name           string
		actor          string
		serverResponse string
		serverStatus   int
		wantErr        bool
		checkFunc      func(*testing.T, *ActorProfile)
	}{
		{
			name:  "successful profile fetch by handle",
			actor: "alice.bsky.social",
			serverResponse: `{
"did": "did:plc:alice123",
"handle": "alice.bsky.social",
"displayName": "Alice Smith",
"description": "Test user",
"avatar": "https://cdn.example.com/avatar.jpg"
}`,
			serverStatus: http.StatusOK,
			wantErr:      false,
			checkFunc: func(t *testing.T, p *ActorProfile) {
				if p.DID != "did:plc:alice123" {
					t.Errorf("DID = %v, want did:plc:alice123", p.DID)
				}
				if p.Handle != "alice.bsky.social" {
					t.Errorf("Handle = %v, want alice.bsky.social", p.Handle)
				}
				if p.DisplayName != "Alice Smith" {
					t.Errorf("DisplayName = %v, want Alice Smith", p.DisplayName)
				}
			},
		},
		{
			name:  "successful profile fetch by DID",
			actor: "did:plc:bob456",
			serverResponse: `{
"did": "did:plc:bob456",
"handle": "bob.example.com"
}`,
			serverStatus: http.StatusOK,
			wantErr:      false,
			checkFunc: func(t *testing.T, p *ActorProfile) {
				if p.DID != "did:plc:bob456" {
					t.Errorf("DID = %v, want did:plc:bob456", p.DID)
				}
			},
		},
		{
			name:           "profile not found",
			actor:          "nonexistent.example.com",
			serverResponse: "",
			serverStatus:   http.StatusNotFound,
			wantErr:        true,
		},
		{
			name:           "server error",
			actor:          "error.example.com",
			serverResponse: `{"error":"InternalError"}`,
			serverStatus:   http.StatusInternalServerError,
			wantErr:        true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
				if got := req.URL.Query().Get("actor"); got != tc.actor {
					t.Errorf("actor = %v, want %v", got, tc.actor)
				}
				if p := req.URL.Path; !strings.Contains(p, "app.bsky.actor.getProfile") {
					t.Errorf("Path = %v, should contain app.bsky.actor.getProfile", p)
				}
				w.WriteHeader(tc.serverStatus)
				w.Write([]byte(tc.serverResponse))
			}))
			defer stub.Close()

			c := NewClient(stub.URL, "did:plc:test123", "test-token")
			profile, err := c.GetActorProfile(context.Background(), tc.actor)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("GetActorProfile() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr || tc.checkFunc == nil {
				return
			}
			tc.checkFunc(t, profile)
		})
	}
}
// TestGetProfileRecord checks Client.GetProfileRecord against a stub that
// expects repo=<did>, collection=app.bsky.actor.profile and rkey=self. A
// successful response must decode displayName, description and the avatar
// blob's CID link. A 404 must produce an error.
func TestGetProfileRecord(t *testing.T) {
	cases := []struct {
		name           string
		did            string
		serverResponse string
		serverStatus   int
		wantErr        bool
		checkFunc      func(*testing.T, *ProfileRecord)
	}{
		{
			name: "successful profile record fetch",
			did:  "did:plc:alice123",
			serverResponse: `{
"uri": "at://did:plc:alice123/app.bsky.actor.profile/self",
"cid": "bafytest",
"value": {
"displayName": "Alice Smith",
"description": "Test description",
"avatar": {
"$type": "blob",
"ref": {"$link": "bafyavatar"},
"mimeType": "image/jpeg",
"size": 12345
}
}
}`,
			serverStatus: http.StatusOK,
			wantErr:      false,
			checkFunc: func(t *testing.T, p *ProfileRecord) {
				if p.DisplayName != "Alice Smith" {
					t.Errorf("DisplayName = %v, want Alice Smith", p.DisplayName)
				}
				if p.Description != "Test description" {
					t.Errorf("Description = %v, want Test description", p.Description)
				}
				if p.Avatar == nil {
					t.Fatal("Avatar should not be nil")
				}
				if link := p.Avatar.Ref.Link; link != "bafyavatar" {
					t.Errorf("Avatar.Ref.Link = %v, want bafyavatar", link)
				}
			},
		},
		{
			name:           "profile record not found",
			did:            "did:plc:nonexistent",
			serverResponse: "",
			serverStatus:   http.StatusNotFound,
			wantErr:        true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
				q := req.URL.Query()
				if got := q.Get("repo"); got != tc.did {
					t.Errorf("repo = %v, want %v", got, tc.did)
				}
				if got := q.Get("collection"); got != "app.bsky.actor.profile" {
					t.Errorf("collection = %v, want app.bsky.actor.profile", got)
				}
				if got := q.Get("rkey"); got != "self" {
					t.Errorf("rkey = %v, want self", got)
				}
				w.WriteHeader(tc.serverStatus)
				w.Write([]byte(tc.serverResponse))
			}))
			defer stub.Close()

			c := NewClient(stub.URL, "did:plc:test123", "test-token")
			profile, err := c.GetProfileRecord(context.Background(), tc.did)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("GetProfileRecord() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr || tc.checkFunc == nil {
				return
			}
			tc.checkFunc(t, profile)
		})
	}
}
// TestClientDID checks that DID() returns the DID passed to NewClient.
func TestClientDID(t *testing.T) {
	want := "did:plc:test123"
	c := NewClient("https://pds.example.com", want, "token")
	if got := c.DID(); got != want {
		t.Errorf("DID() = %v, want %v", got, want)
	}
}
// TestClientPDSEndpoint checks that PDSEndpoint() returns the endpoint passed
// to NewClient.
func TestClientPDSEndpoint(t *testing.T) {
	want := "https://pds.example.com"
	c := NewClient(want, "did:plc:test123", "token")
	if got := c.PDSEndpoint(); got != want {
		t.Errorf("PDSEndpoint() = %v, want %v", got, want)
	}
}
// TestListRecordsError checks that ListRecords reports an error when the PDS
// answers with HTTP 500.
func TestListRecordsError(t *testing.T) {
	failing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"InternalError"}`))
	}))
	defer failing.Close()

	c := NewClient(failing.URL, "did:plc:test123", "test-token")
	if _, err := c.ListRecords(context.Background(), ManifestCollection, 10); err == nil {
		t.Error("Expected error from ListRecords, got nil")
	}
}
// TestUploadBlobError checks that UploadBlob reports an error when the PDS
// rejects the upload with HTTP 400.
func TestUploadBlobError(t *testing.T) {
	rejecting := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(`{"error":"InvalidBlob"}`))
	}))
	defer rejecting.Close()

	c := NewClient(rejecting.URL, "did:plc:test123", "test-token")
	if _, err := c.UploadBlob(context.Background(), []byte("test"), "application/octet-stream"); err == nil {
		t.Error("Expected error from UploadBlob, got nil")
	}
}
// TestGetBlobServerError checks that GetBlob reports an error for a non-404
// failure (HTTP 500), not just for a missing blob.
func TestGetBlobServerError(t *testing.T) {
	failing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"InternalError"}`))
	}))
	defer failing.Close()

	c := NewClient(failing.URL, "did:plc:test123", "test-token")
	if _, err := c.GetBlob(context.Background(), "bafytest"); err == nil {
		t.Error("Expected error from GetBlob, got nil")
	}
}