Files
at-container-registry/pkg/appview/storage/profile_test.go
2026-02-15 22:28:36 -06:00

734 lines
21 KiB
Go

package storage
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"atcr.io/pkg/atproto"
)
// TestEnsureProfile_Create tests creating a new profile when one doesn't exist.
//
// Every case runs EnsureProfile against a fake server that answers GET
// requests with 404 and captures the record sent to putRecord, so both the raw
// JSON payload and its typed form can be verified.
func TestEnsureProfile_Create(t *testing.T) {
	tests := []struct {
		name           string
		defaultHoldDID string
		wantNormalized string // Expected defaultHold value after normalization
	}{
		{
			name:           "with DID",
			defaultHoldDID: "did:web:hold01.atcr.io",
			wantNormalized: "did:web:hold01.atcr.io",
		},
		{
			name:           "empty default hold",
			defaultHoldDID: "",
			wantNormalized: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var createdProfile *atproto.SailorProfileRecord

			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// First request: GetRecord (should 404)
				if r.Method == "GET" {
					w.WriteHeader(http.StatusNotFound)
					return
				}

				// Second request: PutRecord (create profile)
				if r.Method == "POST" && strings.Contains(r.URL.Path, "putRecord") {
					recordData, profile, ok := decodeCreatedProfile(t, r)
					if !ok {
						// The failure is already reported; reject the request
						// so the call under test does not look successful.
						w.WriteHeader(http.StatusBadRequest)
						return
					}

					// Verify profile data
					if recordData["$type"] != atproto.SailorProfileCollection {
						t.Errorf("$type = %v, want %v", recordData["$type"], atproto.SailorProfileCollection)
					}

					// Check defaultHold normalization
					if got := createdDefaultHold(t, recordData); got != tt.wantNormalized {
						t.Errorf("defaultHold = %v, want %v", got, tt.wantNormalized)
					}

					// Store for later verification
					createdProfile = profile

					w.WriteHeader(http.StatusOK)
					w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.sailor.profile/self","cid":"bafytest"}`))
					return
				}

				w.WriteHeader(http.StatusBadRequest)
			}))
			defer server.Close()

			client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
			err := EnsureProfile(context.Background(), client, tt.defaultHoldDID)
			if err != nil {
				t.Fatalf("EnsureProfile() error = %v", err)
			}

			// Verify created profile
			if createdProfile == nil {
				t.Fatal("Profile was not created")
			}
			if createdProfile.Type != atproto.SailorProfileCollection {
				t.Errorf("Type = %v, want %v", createdProfile.Type, atproto.SailorProfileCollection)
			}
			if createdProfile.DefaultHold != tt.wantNormalized {
				t.Errorf("DefaultHold = %v, want %v", createdProfile.DefaultHold, tt.wantNormalized)
			}
		})
	}

	// URL normalization test uses a local test server for /.well-known/atproto-did
	t.Run("with URL - should normalize to DID", func(t *testing.T) {
		const wantDID = "did:web:hold01.atcr.io"

		var createdProfile *atproto.SailorProfileRecord

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Handle hold DID resolution
			if r.URL.Path == "/.well-known/atproto-did" {
				w.Write([]byte(wantDID))
				return
			}

			// GetRecord: profile doesn't exist
			if r.Method == "GET" {
				w.WriteHeader(http.StatusNotFound)
				return
			}

			// PutRecord: create profile
			if r.Method == "POST" && strings.Contains(r.URL.Path, "putRecord") {
				recordData, profile, ok := decodeCreatedProfile(t, r)
				if !ok {
					w.WriteHeader(http.StatusBadRequest)
					return
				}

				if got := createdDefaultHold(t, recordData); got != wantDID {
					t.Errorf("defaultHold = %v, want %v", got, wantDID)
				}

				createdProfile = profile

				w.WriteHeader(http.StatusOK)
				w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.sailor.profile/self","cid":"bafytest"}`))
				return
			}

			w.WriteHeader(http.StatusBadRequest)
		}))
		defer server.Close()

		// The server URL doubles as the hold URL, so the DID lookup is served
		// by the /.well-known/atproto-did branch of the handler above.
		client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
		err := EnsureProfile(context.Background(), client, server.URL)
		if err != nil {
			t.Fatalf("EnsureProfile() error = %v", err)
		}

		if createdProfile == nil {
			t.Fatal("Profile was not created")
		}
		if createdProfile.DefaultHold != wantDID {
			t.Errorf("DefaultHold = %v, want %v", createdProfile.DefaultHold, wantDID)
		}
	})
}

// decodeCreatedProfile reads a putRecord request body and returns its "record"
// object both as raw JSON data and decoded into a SailorProfileRecord.
//
// It runs on the HTTP handler goroutine, so problems are reported with
// t.Errorf (never Fatal) and signalled to the caller through ok == false.
func decodeCreatedProfile(t *testing.T, r *http.Request) (recordData map[string]any, profile *atproto.SailorProfileRecord, ok bool) {
	t.Helper()

	var body map[string]any
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		t.Errorf("decoding putRecord body: %v", err)
		return nil, nil, false
	}

	recordData, isObject := body["record"].(map[string]any)
	if !isObject {
		t.Errorf("putRecord record has type %T, want JSON object", body["record"])
		return nil, nil, false
	}

	// Round-trip through JSON to obtain the typed view of the same record.
	profileBytes, err := json.Marshal(recordData)
	if err != nil {
		t.Errorf("re-encoding putRecord record: %v", err)
		return nil, nil, false
	}
	if err := json.Unmarshal(profileBytes, &profile); err != nil {
		t.Errorf("decoding putRecord record as profile: %v", err)
		return nil, nil, false
	}

	return recordData, profile, true
}

// createdDefaultHold returns the record's defaultHold as a string. A missing
// or null value is treated as the empty string; any other non-string value is
// reported as a test error instead of panicking.
func createdDefaultHold(t *testing.T, recordData map[string]any) string {
	t.Helper()

	raw := recordData["defaultHold"]
	if raw == nil {
		return ""
	}
	hold, isString := raw.(string)
	if !isString {
		t.Errorf("defaultHold has type %T, want string", raw)
	}
	return hold
}
// TestEnsureProfile_Exists checks that EnsureProfile leaves an already
// existing profile alone: the fake server returns a profile for every GET and
// flags any putRecord call as a failure.
func TestEnsureProfile_Exists(t *testing.T) {
	// Canned GetRecord answer describing a profile that is already stored.
	const existingProfile = `{
"uri": "at://did:plc:test123/io.atcr.sailor.profile/self",
"cid": "bafytest",
"value": {
"$type": "io.atcr.sailor.profile",
"defaultHold": "did:web:hold01.atcr.io",
"createdAt": "2025-01-01T00:00:00Z",
"updatedAt": "2025-01-01T00:00:00Z"
}
}`

	sawPut := false
	handler := func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.Method == "GET":
			// GetRecord: profile exists
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(existingProfile))
		case r.Method == "POST" && strings.Contains(r.URL.Path, "putRecord"):
			// PutRecord: must never be reached in this scenario
			sawPut = true
			t.Error("PutRecord should not be called when profile exists")
		}
	}

	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	pds := atproto.NewClient(srv.URL, "did:plc:test123", "test-token")
	if err := EnsureProfile(context.Background(), pds, "did:web:hold01.atcr.io"); err != nil {
		t.Fatalf("EnsureProfile() error = %v", err)
	}

	if sawPut {
		t.Error("PutRecord was called when profile already exists")
	}
}
// TestGetProfile tests retrieving a user's profile
func TestGetProfile(t *testing.T) {
tests := []struct {
name string
serverResponse string
serverStatus int
wantProfile *atproto.SailorProfileRecord
wantNil bool
wantErr bool
expectMigration bool // Whether URL-to-DID migration should happen
originalHoldURL string
expectedHoldDID string
}{
{
name: "profile with DID (no migration needed)",
serverResponse: `{
"uri": "at://did:plc:test123/io.atcr.sailor.profile/self",
"value": {
"$type": "io.atcr.sailor.profile",
"defaultHold": "did:web:hold01.atcr.io",
"createdAt": "2025-01-01T00:00:00Z",
"updatedAt": "2025-01-01T00:00:00Z"
}
}`,
serverStatus: http.StatusOK,
wantNil: false,
wantErr: false,
expectMigration: false,
expectedHoldDID: "did:web:hold01.atcr.io",
},
{
name: "profile doesn't exist - return nil",
serverResponse: "",
serverStatus: http.StatusNotFound,
wantNil: true,
wantErr: false,
expectMigration: false,
},
{
name: "server error",
serverResponse: `{"error":"InternalServerError"}`,
serverStatus: http.StatusInternalServerError,
wantNil: false,
wantErr: true,
expectMigration: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Clear migration locks before each test
migrationLocks = sync.Map{}
var mu sync.Mutex
putRecordCalled := false
var migrationRequest map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// GetRecord
if r.Method == "GET" {
w.WriteHeader(tt.serverStatus)
w.Write([]byte(tt.serverResponse))
return
}
// PutRecord (migration)
if r.Method == "POST" && strings.Contains(r.URL.Path, "putRecord") {
mu.Lock()
putRecordCalled = true
json.NewDecoder(r.Body).Decode(&migrationRequest)
mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.sailor.profile/self","cid":"bafytest"}`))
return
}
}))
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
profile, err := GetProfile(context.Background(), client)
if (err != nil) != tt.wantErr {
t.Errorf("GetProfile() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantNil {
if profile != nil {
t.Errorf("GetProfile() = %v, want nil", profile)
}
return
}
if !tt.wantErr {
if profile == nil {
t.Fatal("GetProfile() returned nil, want profile")
}
// Check that defaultHold is migrated to DID in returned profile
if profile.DefaultHold != tt.expectedHoldDID {
t.Errorf("DefaultHold = %v, want %v", profile.DefaultHold, tt.expectedHoldDID)
}
if tt.expectMigration {
// Give goroutine time to execute
time.Sleep(50 * time.Millisecond)
mu.Lock()
called := putRecordCalled
request := migrationRequest
mu.Unlock()
if !called {
t.Error("Expected migration PutRecord to be called")
}
if request != nil {
recordData := request["record"].(map[string]any)
migratedHold := recordData["defaultHold"]
if migratedHold != tt.expectedHoldDID {
t.Errorf("Migrated defaultHold = %v, want %v", migratedHold, tt.expectedHoldDID)
}
}
}
}
})
}
// URL migration test uses a local test server for /.well-known/atproto-did
t.Run("profile with URL (migration needed)", func(t *testing.T) {
migrationLocks = sync.Map{}
var mu sync.Mutex
putRecordCalled := false
var migrationRequest map[string]any
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Handle hold DID resolution
if r.URL.Path == "/.well-known/atproto-did" {
w.Write([]byte("did:web:hold01.atcr.io"))
return
}
// GetRecord - return profile with URL pointing to this server
if r.Method == "GET" {
response := fmt.Sprintf(`{
"uri": "at://did:plc:test123/io.atcr.sailor.profile/self",
"value": {
"$type": "io.atcr.sailor.profile",
"defaultHold": %q,
"createdAt": "2025-01-01T00:00:00Z",
"updatedAt": "2025-01-01T00:00:00Z"
}
}`, server.URL)
w.WriteHeader(http.StatusOK)
w.Write([]byte(response))
return
}
// PutRecord (migration)
if r.Method == "POST" && strings.Contains(r.URL.Path, "putRecord") {
mu.Lock()
putRecordCalled = true
json.NewDecoder(r.Body).Decode(&migrationRequest)
mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.sailor.profile/self","cid":"bafytest"}`))
return
}
}))
defer server.Close()
client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
profile, err := GetProfile(context.Background(), client)
if err != nil {
t.Fatalf("GetProfile() error = %v", err)
}
if profile == nil {
t.Fatal("GetProfile() returned nil, want profile")
}
if profile.DefaultHold != "did:web:hold01.atcr.io" {
t.Errorf("DefaultHold = %v, want did:web:hold01.atcr.io", profile.DefaultHold)
}
// Give migration goroutine time to execute
time.Sleep(50 * time.Millisecond)
mu.Lock()
called := putRecordCalled
request := migrationRequest
mu.Unlock()
if !called {
t.Error("Expected migration PutRecord to be called")
}
if request != nil {
recordData := request["record"].(map[string]any)
migratedHold := recordData["defaultHold"]
if migratedHold != "did:web:hold01.atcr.io" {
t.Errorf("Migrated defaultHold = %v, want did:web:hold01.atcr.io", migratedHold)
}
}
})
}
// TestGetProfile_MigrationLocking fires several GetProfile calls at once
// against a profile that still stores a hold URL and expects exactly one
// putRecord to reach the server.
func TestGetProfile_MigrationLocking(t *testing.T) {
	// Start from an empty lock table.
	migrationLocks = sync.Map{}

	var (
		countMu    sync.Mutex
		migrations int
		srv        *httptest.Server // declared first: the handler reads srv.URL
	)

	srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.URL.Path == "/.well-known/atproto-did":
			// Hold DID resolution
			w.Write([]byte("did:web:hold01.atcr.io"))

		case r.Method == "GET":
			// GetRecord: the stored defaultHold is this server's own URL
			payload := fmt.Sprintf(`{
"uri": "at://did:plc:test123/io.atcr.sailor.profile/self",
"value": {
"$type": "io.atcr.sailor.profile",
"defaultHold": %q,
"createdAt": "2025-01-01T00:00:00Z",
"updatedAt": "2025-01-01T00:00:00Z"
}
}`, srv.URL)
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(payload))

		case r.Method == "POST" && strings.Contains(r.URL.Path, "putRecord"):
			// PutRecord: tally every migration write
			countMu.Lock()
			migrations++
			countMu.Unlock()

			// Hold the request open briefly so overlapping writes would show up.
			time.Sleep(10 * time.Millisecond)

			w.WriteHeader(http.StatusOK)
			w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.sailor.profile/self","cid":"bafytest"}`))
		}
	}))
	defer srv.Close()

	pds := atproto.NewClient(srv.URL, "did:plc:test123", "test-token")

	// Launch the concurrent readers.
	const callers = 5
	var wg sync.WaitGroup
	for i := 0; i < callers; i++ {
		wg.Go(func() {
			if _, err := GetProfile(context.Background(), pds); err != nil {
				t.Errorf("GetProfile() error = %v", err)
			}
		})
	}
	wg.Wait()

	// Leave room for any migration still in flight to hit the server.
	time.Sleep(200 * time.Millisecond)

	countMu.Lock()
	got := migrations
	countMu.Unlock()

	// Locking should let a single migration through.
	if got != 1 {
		t.Errorf("PutRecord called %d times, want 1 (locking should prevent concurrent migrations)", got)
	}
}
// TestUpdateProfile tests updating a user's profile.
//
// Each case captures the putRecord body sent by UpdateProfile and checks that
// the record key is ProfileRKey, that the transmitted defaultHold matches the
// expected normalized value, and that the caller's profile struct carries the
// same value afterwards.
func TestUpdateProfile(t *testing.T) {
	tests := []struct {
		name           string
		profile        *atproto.SailorProfileRecord
		wantNormalized string // Expected defaultHold after normalization
		wantErr        bool
	}{
		{
			name: "update with DID",
			profile: &atproto.SailorProfileRecord{
				Type:        atproto.SailorProfileCollection,
				DefaultHold: "did:web:hold02.atcr.io",
				CreatedAt:   time.Now(),
				UpdatedAt:   time.Now(),
			},
			wantNormalized: "did:web:hold02.atcr.io",
			wantErr:        false,
		},
		{
			name: "clear default hold",
			profile: &atproto.SailorProfileRecord{
				Type:        atproto.SailorProfileCollection,
				DefaultHold: "",
				CreatedAt:   time.Now(),
				UpdatedAt:   time.Now(),
			},
			wantNormalized: "",
			wantErr:        false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var sentProfile map[string]any

			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if r.Method == "POST" && strings.Contains(r.URL.Path, "putRecord") {
					var body map[string]any
					if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
						t.Errorf("decoding putRecord body: %v", err)
						w.WriteHeader(http.StatusBadRequest)
						return
					}
					sentProfile = body

					// Verify rkey is "self"
					if body["rkey"] != ProfileRKey {
						t.Errorf("rkey = %v, want %v", body["rkey"], ProfileRKey)
					}

					w.WriteHeader(http.StatusOK)
					w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.sailor.profile/self","cid":"bafytest"}`))
					return
				}
				w.WriteHeader(http.StatusBadRequest)
			}))
			defer server.Close()

			client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
			err := UpdateProfile(context.Background(), client, tt.profile)

			if (err != nil) != tt.wantErr {
				t.Errorf("UpdateProfile() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			if !tt.wantErr {
				// Verify normalization happened
				if got := sentDefaultHold(t, sentProfile); got != tt.wantNormalized {
					t.Errorf("defaultHold = %v, want %v", got, tt.wantNormalized)
				}

				// Verify normalization also updated the profile object
				if tt.profile.DefaultHold != tt.wantNormalized {
					t.Errorf("profile.DefaultHold = %v, want %v (should be updated in-place)", tt.profile.DefaultHold, tt.wantNormalized)
				}
			}
		})
	}

	// URL normalization test uses a local test server for /.well-known/atproto-did
	t.Run("update with URL - should normalize", func(t *testing.T) {
		var sentProfile map[string]any

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Handle hold DID resolution
			if r.URL.Path == "/.well-known/atproto-did" {
				w.Write([]byte("did:web:hold02.atcr.io"))
				return
			}

			if r.Method == "POST" && strings.Contains(r.URL.Path, "putRecord") {
				var body map[string]any
				if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
					t.Errorf("decoding putRecord body: %v", err)
					w.WriteHeader(http.StatusBadRequest)
					return
				}
				sentProfile = body

				if body["rkey"] != ProfileRKey {
					t.Errorf("rkey = %v, want %v", body["rkey"], ProfileRKey)
				}

				w.WriteHeader(http.StatusOK)
				w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.sailor.profile/self","cid":"bafytest"}`))
				return
			}
			w.WriteHeader(http.StatusBadRequest)
		}))
		defer server.Close()

		profile := &atproto.SailorProfileRecord{
			Type:        atproto.SailorProfileCollection,
			DefaultHold: server.URL, // URL pointing to test server with /.well-known/atproto-did
			CreatedAt:   time.Now(),
			UpdatedAt:   time.Now(),
		}

		client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
		err := UpdateProfile(context.Background(), client, profile)
		if err != nil {
			t.Errorf("UpdateProfile() error = %v", err)
			return
		}

		if got := sentDefaultHold(t, sentProfile); got != "did:web:hold02.atcr.io" {
			t.Errorf("defaultHold = %v, want did:web:hold02.atcr.io", got)
		}

		if profile.DefaultHold != "did:web:hold02.atcr.io" {
			t.Errorf("profile.DefaultHold = %v, want did:web:hold02.atcr.io (should be updated in-place)", profile.DefaultHold)
		}
	})
}

// sentDefaultHold extracts record.defaultHold from a captured putRecord body.
//
// It must be called from the test goroutine: a body that was never captured, a
// record that is not a JSON object, or a non-string defaultHold fails the test
// immediately instead of panicking on a type assertion. A missing or null
// defaultHold is returned as the empty string.
func sentDefaultHold(t *testing.T, body map[string]any) string {
	t.Helper()

	if body == nil {
		t.Fatal("no putRecord body was captured")
	}

	recordData, isObject := body["record"].(map[string]any)
	if !isObject {
		t.Fatalf("putRecord record has type %T, want JSON object", body["record"])
	}

	raw := recordData["defaultHold"]
	if raw == nil {
		return ""
	}
	hold, isString := raw.(string)
	if !isString {
		t.Fatalf("defaultHold has type %T, want string", raw)
	}
	return hold
}
// TestProfileRKey pins the profile record key to the literal "self".
func TestProfileRKey(t *testing.T) {
	const want = "self"
	if got := ProfileRKey; got != want {
		t.Errorf("ProfileRKey = %v, want self", got)
	}
}
// TestEnsureProfile_Error checks that EnsureProfile surfaces an error when the
// profile is missing (GET answers 404) and the subsequent write fails (POST
// answers 500).
func TestEnsureProfile_Error(t *testing.T) {
	failingPDS := func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet:
			// GetRecord: profile doesn't exist
			w.WriteHeader(http.StatusNotFound)
		case http.MethodPost:
			// PutRecord: fail with server error
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte(`{"error":"InternalServerError"}`))
		}
	}

	srv := httptest.NewServer(http.HandlerFunc(failingPDS))
	defer srv.Close()

	pds := atproto.NewClient(srv.URL, "did:plc:test123", "test-token")
	if err := EnsureProfile(context.Background(), pds, "did:web:hold01.atcr.io"); err == nil {
		t.Error("EnsureProfile() should return error when PutRecord fails")
	}
}
// TestGetProfile_InvalidJSON checks that GetProfile returns an error when the
// record's "value" is a JSON string rather than an object.
func TestGetProfile_InvalidJSON(t *testing.T) {
	// Well-formed envelope whose "value" cannot be read as a profile object.
	const malformed = `{
"uri": "at://did:plc:test123/io.atcr.sailor.profile/self",
"value": "not-valid-json-object"
}`

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(malformed))
	}))
	defer srv.Close()

	pds := atproto.NewClient(srv.URL, "did:plc:test123", "test-token")
	if _, err := GetProfile(context.Background(), pds); err == nil {
		t.Error("GetProfile() should return error for invalid JSON")
	}
}
// TestGetProfile_EmptyDefaultHold tests a profile with an empty defaultHold:
// GetProfile must succeed and return the empty string unchanged.
func TestGetProfile_EmptyDefaultHold(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		response := `{
"uri": "at://did:plc:test123/io.atcr.sailor.profile/self",
"value": {
"$type": "io.atcr.sailor.profile",
"defaultHold": "",
"createdAt": "2025-01-01T00:00:00Z",
"updatedAt": "2025-01-01T00:00:00Z"
}
}`
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(response))
	}))
	defer server.Close()

	client := atproto.NewClient(server.URL, "did:plc:test123", "test-token")
	profile, err := GetProfile(context.Background(), client)
	if err != nil {
		t.Fatalf("GetProfile() error = %v", err)
	}

	// Guard the dereference below so a nil result fails the test cleanly
	// rather than panicking.
	if profile == nil {
		t.Fatal("GetProfile() returned nil, want profile")
	}

	if profile.DefaultHold != "" {
		t.Errorf("DefaultHold = %v, want empty string", profile.DefaultHold)
	}
}
// TestUpdateProfile_ServerError checks that UpdateProfile returns an error when
// every request to the server is answered with 500.
func TestUpdateProfile_ServerError(t *testing.T) {
	alwaysFail := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(`{"error":"InternalServerError"}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(alwaysFail))
	defer srv.Close()

	pds := atproto.NewClient(srv.URL, "did:plc:test123", "test-token")

	record := &atproto.SailorProfileRecord{
		Type:        atproto.SailorProfileCollection,
		DefaultHold: "did:web:hold01.atcr.io",
		CreatedAt:   time.Now(),
		UpdatedAt:   time.Now(),
	}

	if err := UpdateProfile(context.Background(), pds, record); err == nil {
		t.Error("UpdateProfile() should return error when server fails")
	}
}