// Mirror of https://tangled.org/evan.jarrett.net/at-container-registry
// (synced 2026-04-20 16:40:29 +00:00; 185 lines, 5.4 KiB, Go).

package auth
|
|
|
|
import (
|
|
"atcr.io/pkg/atproto"
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/bluesky-social/indigo/atproto/identity"
|
|
"github.com/bluesky-social/indigo/atproto/syntax"
|
|
)
|
|
|
|
// CachedSession is one entry in SessionValidator's credential cache: the
// identity and access token obtained from a successful createSession call,
// plus the instant after which the entry must no longer be served.
type CachedSession struct {
	// DID, Handle, PDS and AccessToken describe the authenticated session;
	// PDS is the endpoint the session was created against.
	DID, Handle, PDS, AccessToken string

	// ExpiresAt is the local time at which this cache entry lapses.
	ExpiresAt time.Time
}
|
|
|
|
// SessionValidator validates ATProto credentials by creating a session on
// the account's PDS and caching successful results in memory.
type SessionValidator struct {
	// directory resolves an AT identifier (handle or DID) to an identity,
	// including its PDS endpoint.
	directory identity.Directory

	// httpClient sends the createSession request to the PDS.
	httpClient *http.Client

	// cache maps a getCacheKey(identifier, password) digest to the session
	// obtained for those credentials. Guarded by cacheMu.
	cache map[string]*CachedSession

	// cacheMu guards cache: RLock for lookups, Lock for writes.
	cacheMu sync.RWMutex
}
|
|
|
|
// NewSessionValidator creates a new ATProto session validator
|
|
func NewSessionValidator() *SessionValidator {
|
|
return &SessionValidator{
|
|
directory: identity.DefaultDirectory(),
|
|
httpClient: &http.Client{},
|
|
cache: make(map[string]*CachedSession),
|
|
}
|
|
}
|
|
|
|
// getCacheKey derives the cache key for a username/password pair: the
// lowercase hex SHA-256 digest of username + ":" + password. Hashing keeps
// the plaintext password out of the cache's map keys.
func getCacheKey(username, password string) string {
	digest := sha256.Sum256([]byte(username + ":" + password))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
// getCachedSession retrieves a cached session if valid
|
|
func (v *SessionValidator) getCachedSession(cacheKey string) (*CachedSession, bool) {
|
|
v.cacheMu.RLock()
|
|
defer v.cacheMu.RUnlock()
|
|
|
|
session, ok := v.cache[cacheKey]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
|
|
// Check if expired (with 5 minute buffer)
|
|
if time.Now().After(session.ExpiresAt.Add(-5 * time.Minute)) {
|
|
return nil, false
|
|
}
|
|
|
|
return session, true
|
|
}
|
|
|
|
// setCachedSession stores a session in the cache
|
|
func (v *SessionValidator) setCachedSession(cacheKey string, session *CachedSession) {
|
|
v.cacheMu.Lock()
|
|
defer v.cacheMu.Unlock()
|
|
v.cache[cacheKey] = session
|
|
}
|
|
|
|
// SessionResponse represents the JSON response body from
// com.atproto.server.createSession, as decoded by createSession. The JSON
// tags below define the field names read from that body.
type SessionResponse struct {
	DID         string `json:"did"`                    // account DID reported by the PDS
	Handle      string `json:"handle"`                 // account handle reported by the PDS
	AccessJWT   string `json:"accessJwt"`              // access token (camelCase field)
	RefreshJWT  string `json:"refreshJwt"`             // refresh token (camelCase field)
	Email       string `json:"email,omitempty"`        // account email, if the PDS includes it
	AccessToken string `json:"access_token,omitempty"` // Alternative field name
}
|
|
|
|
// CreateSessionAndGetToken creates a session and returns the DID, handle, and access token
|
|
func (v *SessionValidator) CreateSessionAndGetToken(ctx context.Context, identifier, password string) (did, handle, accessToken string, err error) {
|
|
// Check cache first
|
|
cacheKey := getCacheKey(identifier, password)
|
|
if cached, ok := v.getCachedSession(cacheKey); ok {
|
|
fmt.Printf("DEBUG [atproto/session]: Using cached session for %s (DID=%s)\n", identifier, cached.DID)
|
|
return cached.DID, cached.Handle, cached.AccessToken, nil
|
|
}
|
|
|
|
fmt.Printf("DEBUG [atproto/session]: No cached session for %s, creating new session\n", identifier)
|
|
|
|
// Resolve identifier to PDS endpoint
|
|
atID, err := syntax.ParseAtIdentifier(identifier)
|
|
if err != nil {
|
|
return "", "", "", fmt.Errorf("invalid identifier %q: %w", identifier, err)
|
|
}
|
|
|
|
ident, err := v.directory.Lookup(ctx, *atID)
|
|
if err != nil {
|
|
return "", "", "", fmt.Errorf("failed to resolve identity %q: %w", identifier, err)
|
|
}
|
|
|
|
did = ident.DID.String()
|
|
pds := ident.PDSEndpoint()
|
|
if pds == "" {
|
|
return "", "", "", fmt.Errorf("no PDS endpoint found for %q", identifier)
|
|
}
|
|
|
|
// Create session
|
|
sessionResp, err := v.createSession(ctx, pds, identifier, password)
|
|
if err != nil {
|
|
return "", "", "", fmt.Errorf("authentication failed: %w", err)
|
|
}
|
|
|
|
// Cache the session (ATProto sessions typically last 2 hours)
|
|
v.setCachedSession(cacheKey, &CachedSession{
|
|
DID: sessionResp.DID,
|
|
Handle: sessionResp.Handle,
|
|
PDS: pds,
|
|
AccessToken: sessionResp.AccessJWT,
|
|
ExpiresAt: time.Now().Add(2 * time.Hour),
|
|
})
|
|
fmt.Printf("DEBUG [atproto/session]: Cached session for %s (expires in 2 hours)\n", identifier)
|
|
|
|
return sessionResp.DID, sessionResp.Handle, sessionResp.AccessJWT, nil
|
|
}
|
|
|
|
// createSession calls com.atproto.server.createSession
|
|
func (v *SessionValidator) createSession(ctx context.Context, pdsEndpoint, identifier, password string) (*SessionResponse, error) {
|
|
payload := map[string]string{
|
|
"identifier": identifier,
|
|
"password": password,
|
|
}
|
|
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
url := fmt.Sprintf("%s%s", pdsEndpoint, atproto.ServerCreateSession)
|
|
fmt.Printf("DEBUG [atproto/session]: POST %s\n", url)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := v.httpClient.Do(req)
|
|
if err != nil {
|
|
fmt.Printf("DEBUG [atproto/session]: HTTP request failed: %v\n", err)
|
|
return nil, fmt.Errorf("failed to create session: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
fmt.Printf("DEBUG [atproto/session]: Got HTTP status %d\n", resp.StatusCode)
|
|
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
fmt.Printf("DEBUG [atproto/session]: Unauthorized response: %s\n", string(bodyBytes))
|
|
return nil, fmt.Errorf("invalid credentials")
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
fmt.Printf("DEBUG [atproto/session]: Error response: %s\n", string(bodyBytes))
|
|
return nil, fmt.Errorf("create session failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
var sessionResp SessionResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
return &sessionResp, nil
|
|
}
|