Files
at-container-registry/pkg/auth/hold_local.go

102 lines
3.0 KiB
Go

package auth
import (
"context"
"fmt"
"atcr.io/pkg/atproto"
"atcr.io/pkg/hold/pds"
)
// LocalHoldAuthorizer authorizes access to a hold's storage by reading the
// captain and crew records straight from that hold's own embedded PDS.
// The hold service uses it to authorize requests against its own storage.
type LocalHoldAuthorizer struct {
	// pds is the hold's embedded PDS; all authorization decisions made by
	// this type are derived from records fetched through it.
	pds *pds.HoldPDS
}
// NewLocalHoldAuthorizer returns a HoldAuthorizer backed by the given
// embedded hold PDS, for use inside the hold service itself.
func NewLocalHoldAuthorizer(holdPDS *pds.HoldPDS) HoldAuthorizer {
	authz := &LocalHoldAuthorizer{pds: holdPDS}
	return authz
}
// NewLocalHoldAuthorizerFromInterface creates a local authorizer from an
// untyped value. It exists to avoid import cycles: callers that cannot
// import the pds package pass their *pds.HoldPDS as an any.
//
// It returns nil when holdPDS is not a *pds.HoldPDS or is a nil
// *pds.HoldPDS; callers must check the result before using it.
func NewLocalHoldAuthorizerFromInterface(holdPDS any) HoldAuthorizer {
	pdsTyped, ok := holdPDS.(*pds.HoldPDS)
	if !ok || pdsTyped == nil {
		// A typed-nil *pds.HoldPDS would satisfy the assertion above and
		// yield an authorizer that panics on first use (a.pds.DID()).
		// Returning the untyped nil interface lets the caller's `== nil`
		// check catch both failure cases.
		return nil
	}
	return &LocalHoldAuthorizer{pds: pdsTyped}
}
// GetCaptainRecord returns the captain record stored in this hold's
// embedded PDS.
//
// holdDID must equal the DID of the embedded PDS. A request naming any other
// hold is rejected with an error. Errors from the PDS lookup are wrapped
// with %w.
func (a *LocalHoldAuthorizer) GetCaptainRecord(ctx context.Context, holdDID string) (*atproto.CaptainRecord, error) {
	// This authorizer can only answer for the hold it is embedded in.
	if ownDID := a.pds.DID(); holdDID != ownDID {
		return nil, fmt.Errorf("holdDID mismatch: requested %s, this hold is %s", holdDID, ownDID)
	}

	// The PDS returns an extra leading value that is not needed here; only
	// the captain record itself is passed on.
	_, captain, err := a.pds.GetCaptainRecord(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get captain record: %w", err)
	}
	return captain, nil
}
// IsCrewMember reports whether userDID appears as the member of any crew
// record in this hold's embedded PDS.
//
// holdDID must equal the DID of the embedded PDS; otherwise an error is
// returned. Crew record expiration is not yet taken into account.
func (a *LocalHoldAuthorizer) IsCrewMember(ctx context.Context, holdDID, userDID string) (bool, error) {
	if holdDID != a.pds.DID() {
		return false, fmt.Errorf("holdDID mismatch: requested %s, this hold is %s", holdDID, a.pds.DID())
	}

	crew, err := a.pds.ListCrewMembers(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to list crew members: %w", err)
	}

	found := false
	for _, entry := range crew {
		if entry.Record.Member == userDID {
			// TODO: Check expiration if set
			found = true
			break
		}
	}
	return found, nil
}
// CheckReadAccess reports whether userDID may read from the hold holdDID.
// It fetches this hold's captain record and delegates the decision to the
// shared CheckReadAccessWithCaptain logic. Errors from fetching the captain
// record are returned unchanged.
func (a *LocalHoldAuthorizer) CheckReadAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
	captain, err := a.GetCaptainRecord(ctx, holdDID)
	if err != nil {
		return false, err
	}
	allowed := CheckReadAccessWithCaptain(captain, userDID)
	return allowed, nil
}
// CheckWriteAccess reports whether userDID may write to the hold holdDID.
// It fetches the captain record, then the user's crew membership, and
// delegates the decision to the shared CheckWriteAccessWithCaptain logic.
// The first error encountered is returned unchanged, and the crew lookup
// is skipped if the captain record cannot be fetched.
func (a *LocalHoldAuthorizer) CheckWriteAccess(ctx context.Context, holdDID, userDID string) (bool, error) {
	captain, err := a.GetCaptainRecord(ctx, holdDID)
	if err != nil {
		return false, err
	}

	isCrew, crewErr := a.IsCrewMember(ctx, holdDID, userDID)
	if crewErr != nil {
		return false, crewErr
	}

	allowed := CheckWriteAccessWithCaptain(captain, userDID, isCrew)
	return allowed, nil
}