1 Commits

Author SHA1 Message Date
Evan Jarrett
8bf3e15ca2 first pass at implementing a label service 2026-03-22 21:44:57 -05:00
33 changed files with 2985 additions and 15 deletions

82
cmd/labeler/main.go Normal file
View File

@@ -0,0 +1,82 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"atcr.io/pkg/labeler"
)
var configFile string
var rootCmd = &cobra.Command{
Use: "atcr-labeler",
Short: "ATCR Labeler Service - ATProto content moderation",
}
var serveCmd = &cobra.Command{
Use: "serve",
Short: "Start the labeler service",
Long: `Start the ATCR labeler service with admin UI and subscribeLabels endpoint.
Configuration is loaded from the appview config YAML (labeler section).
Use --config to specify the config file path.`,
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
cfg, err := labeler.LoadConfig(configFile)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
server, err := labeler.NewServer(cfg)
if err != nil {
return fmt.Errorf("failed to initialize labeler: %w", err)
}
return server.Serve()
},
}
var configCmd = &cobra.Command{
Use: "config",
Short: "Configuration management commands",
}
var configInitCmd = &cobra.Command{
Use: "init [path]",
Short: "Generate an example configuration file",
Long: `Generate an example YAML configuration file with all available options.`,
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
yamlBytes, err := labeler.ExampleYAML()
if err != nil {
return fmt.Errorf("failed to generate example config: %w", err)
}
if len(args) == 1 {
if err := os.WriteFile(args[0], yamlBytes, 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
fmt.Fprintf(os.Stderr, "Wrote example config to %s\n", args[0])
return nil
}
fmt.Print(string(yamlBytes))
return nil
},
}
func init() {
serveCmd.Flags().StringVarP(&configFile, "config", "c", "", "path to YAML configuration file")
configCmd.AddCommand(configInitCmd)
rootCmd.AddCommand(serveCmd)
rootCmd.AddCommand(configCmd)
}
func main() {
if err := rootCmd.Execute(); err != nil {
os.Exit(1)
}
}

View File

@@ -28,6 +28,12 @@ var holdConfigTmpl string
//go:embed configs/scanner.yaml.tmpl
var scannerConfigTmpl string
//go:embed systemd/labeler.service.tmpl
var labelerServiceTmpl string
//go:embed configs/labeler.yaml.tmpl
var labelerConfigTmpl string
//go:embed configs/cloudinit.sh.tmpl
var cloudInitTmpl string
@@ -111,9 +117,33 @@ func renderScannerServiceUnit(p scannerServiceUnitParams) (string, error) {
return buf.String(), nil
}
// labelerServiceUnitParams holds values for rendering the labeler systemd unit.
type labelerServiceUnitParams struct {
DisplayName string // e.g. "Seamark"
User string // e.g. "seamark"
BinaryPath string // e.g. "/opt/seamark/bin/seamark-labeler"
ConfigPath string // e.g. "/etc/seamark/labeler.yaml"
DataDir string // e.g. "/var/lib/seamark"
ServiceName string // e.g. "seamark-labeler"
AppviewServiceName string // e.g. "seamark-appview" (After= dependency)
}
func renderLabelerServiceUnit(p labelerServiceUnitParams) (string, error) {
t, err := template.New("labeler-service").Parse(labelerServiceTmpl)
if err != nil {
return "", fmt.Errorf("parse labeler service template: %w", err)
}
var buf bytes.Buffer
if err := t.Execute(&buf, p); err != nil {
return "", fmt.Errorf("render labeler service template: %w", err)
}
return buf.String(), nil
}
// generateAppviewCloudInit generates the cloud-init user-data script for the appview server.
// Sets up the OS, directories, config, and systemd unit. Binaries are deployed separately via SCP.
func generateAppviewCloudInit(cfg *InfraConfig, vals *ConfigValues) (string, error) {
// When withLabeler is true, a second phase is appended that creates labeler data
// directories and installs a labeler systemd service. Binaries are deployed separately via SCP.
func generateAppviewCloudInit(cfg *InfraConfig, vals *ConfigValues, withLabeler bool) (string, error) {
naming := cfg.Naming()
configYAML, err := renderConfig(appviewConfigTmpl, vals)
@@ -133,7 +163,7 @@ func generateAppviewCloudInit(cfg *InfraConfig, vals *ConfigValues) (string, err
return "", fmt.Errorf("appview service unit: %w", err)
}
return generateCloudInit(cloudInitParams{
script, err := generateCloudInit(cloudInitParams{
BinaryName: naming.Appview(),
ServiceUnit: serviceUnit,
ConfigYAML: configYAML,
@@ -146,6 +176,69 @@ func generateAppviewCloudInit(cfg *InfraConfig, vals *ConfigValues) (string, err
LogFile: naming.LogFile(),
DisplayName: naming.DisplayName(),
})
if err != nil {
return "", err
}
if !withLabeler {
return script, nil
}
// Render labeler config YAML
labelerConfigYAML, err := renderConfig(labelerConfigTmpl, vals)
if err != nil {
return "", fmt.Errorf("labeler config: %w", err)
}
// Append labeler setup phase
labelerUnit, err := renderLabelerServiceUnit(labelerServiceUnitParams{
DisplayName: naming.DisplayName(),
User: naming.SystemUser(),
BinaryPath: naming.InstallDir() + "/bin/" + naming.Labeler(),
ConfigPath: naming.LabelerConfigPath(),
DataDir: naming.BasePath(),
ServiceName: naming.Labeler(),
AppviewServiceName: naming.Appview(),
})
if err != nil {
return "", fmt.Errorf("labeler service unit: %w", err)
}
// Escape single quotes for heredoc embedding
labelerUnit = strings.ReplaceAll(labelerUnit, "'", "'\\''")
labelerConfigYAML = strings.ReplaceAll(labelerConfigYAML, "'", "'\\''")
labelerPhase := fmt.Sprintf(`
# === Labeler Setup ===
# Labeler data dirs
mkdir -p %s
chown -R %s:%s %s
# Labeler config
cat > %s << 'CFGEOF'
%s
CFGEOF
# Labeler systemd service
cat > /etc/systemd/system/%s.service << 'SVCEOF'
%s
SVCEOF
systemctl daemon-reload
systemctl enable %s
echo "=== Labeler setup complete ==="
`,
naming.LabelerDataDir(),
naming.SystemUser(), naming.SystemUser(), naming.LabelerDataDir(),
naming.LabelerConfigPath(),
labelerConfigYAML,
naming.Labeler(),
labelerUnit,
naming.Labeler(),
)
return script + labelerPhase, nil
}
// generateHoldCloudInit generates the cloud-init user-data script for the hold server.

View File

@@ -46,3 +46,5 @@ credential_helper:
legal:
company_name: Seamark
jurisdiction: State of Texas, United States
labeler:
did: ""

View File

@@ -0,0 +1,19 @@
version: "0.1"
log_level: info
log_shipper:
backend: ""
url: ""
batch_size: 100
flush_interval: 5s
username: ""
password: ""
labeler:
enabled: true
addr: :5002
owner_did: ""
db_path: "{{.BasePath}}/labeler/labeler.db"
server:
base_url: "https://seamark.dev"
client_name: Seamark
client_short_name: Seamark
test_mode: false

BIN
deploy/upcloud/deploy Executable file

Binary file not shown.

View File

@@ -57,5 +57,14 @@ func (n Naming) ScannerConfigPath() string { return n.ConfigDir() + "/scanner.ya
// ScannerDataDir returns the scanner data directory (e.g. "/var/lib/seamark/scanner").
func (n Naming) ScannerDataDir() string { return n.BasePath() + "/scanner" }
// Labeler returns the labeler binary/service name (e.g. "seamark-labeler").
func (n Naming) Labeler() string { return n.ClientName + "-labeler" }
// LabelerConfigPath returns the labeler config file path.
func (n Naming) LabelerConfigPath() string { return n.ConfigDir() + "/labeler.yaml" }
// LabelerDataDir returns the labeler data directory (e.g. "/var/lib/seamark/labeler").
func (n Naming) LabelerDataDir() string { return n.BasePath() + "/labeler" }
// S3Name returns the name used for S3 storage, user, and bucket.
func (n Naming) S3Name() string { return n.ClientName }

View File

@@ -29,7 +29,8 @@ var provisionCmd = &cobra.Command{
sshKey, _ := cmd.Flags().GetString("ssh-key")
s3Secret, _ := cmd.Flags().GetString("s3-secret")
withScanner, _ := cmd.Flags().GetBool("with-scanner")
return cmdProvision(token, zone, plan, sshKey, s3Secret, withScanner)
withLabeler, _ := cmd.Flags().GetBool("with-labeler")
return cmdProvision(token, zone, plan, sshKey, s3Secret, withScanner, withLabeler)
},
}
@@ -39,11 +40,12 @@ func init() {
provisionCmd.Flags().String("ssh-key", "", "Path to SSH public key file (required)")
provisionCmd.Flags().String("s3-secret", "", "S3 secret access key (for existing object storage)")
provisionCmd.Flags().Bool("with-scanner", false, "Deploy vulnerability scanner alongside hold")
provisionCmd.Flags().Bool("with-labeler", false, "Deploy content moderation labeler alongside appview")
_ = provisionCmd.MarkFlagRequired("ssh-key")
rootCmd.AddCommand(provisionCmd)
}
func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bool) error {
func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner, withLabeler bool) error {
cfg, err := loadConfig(zone, plan, sshKeyPath, s3Secret)
if err != nil {
return err
@@ -98,6 +100,12 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
_ = saveState(state)
}
// Labeler setup
if withLabeler {
state.LabelerEnabled = true
_ = saveState(state)
}
fmt.Printf("Provisioning %s infrastructure in zone %s...\n", naming.DisplayName(), cfg.Zone)
if needsServers {
fmt.Printf("Server plan: %s\n", cfg.Plan)
@@ -198,7 +206,7 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
appviewCreated := false
if state.Appview.UUID != "" {
fmt.Printf("Appview: %s (exists)\n", state.Appview.UUID)
appviewScript, err := generateAppviewCloudInit(cfg, vals)
appviewScript, err := generateAppviewCloudInit(cfg, vals, state.LabelerEnabled)
if err != nil {
return err
}
@@ -212,9 +220,18 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
if err := syncConfigKeys("appview", state.Appview.PublicIP, naming.AppviewConfigPath(), appviewConfigYAML); err != nil {
return fmt.Errorf("appview config sync: %w", err)
}
if state.LabelerEnabled {
labelerConfigYAML, err := renderConfig(labelerConfigTmpl, vals)
if err != nil {
return fmt.Errorf("render labeler config: %w", err)
}
if err := syncConfigKeys("labeler", state.Appview.PublicIP, naming.LabelerConfigPath(), labelerConfigYAML); err != nil {
return fmt.Errorf("labeler config sync: %w", err)
}
}
} else {
fmt.Println("Creating appview server...")
appviewUserData, err := generateAppviewCloudInit(cfg, vals)
appviewUserData, err := generateAppviewCloudInit(cfg, vals, state.LabelerEnabled)
if err != nil {
return err
}
@@ -338,6 +355,12 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
if err := buildLocal(rootDir, outputPath, "./cmd/appview"); err != nil {
return fmt.Errorf("build appview: %w", err)
}
if state.LabelerEnabled {
outputPath := filepath.Join(rootDir, "bin", "atcr-labeler")
if err := buildLocal(rootDir, outputPath, "./cmd/labeler"); err != nil {
return fmt.Errorf("build labeler: %w", err)
}
}
}
if holdCreated {
outputPath := filepath.Join(rootDir, "bin", "atcr-hold")
@@ -371,6 +394,13 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
if err := scpFile(localPath, state.Appview.PublicIP, remotePath); err != nil {
return fmt.Errorf("upload appview: %w", err)
}
if state.LabelerEnabled {
labelerLocal := filepath.Join(rootDir, "bin", "atcr-labeler")
labelerRemote := naming.InstallDir() + "/bin/" + naming.Labeler()
if err := scpFile(labelerLocal, state.Appview.PublicIP, labelerRemote); err != nil {
return fmt.Errorf("upload labeler: %w", err)
}
}
}
if holdCreated {
localPath := filepath.Join(rootDir, "bin", "atcr-hold")
@@ -411,11 +441,14 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
} else {
fmt.Println(" 1. Start services:")
}
services := []string{naming.Appview(), naming.Hold()}
if state.ScannerEnabled {
fmt.Printf(" systemctl start %s / %s / %s\n", naming.Appview(), naming.Hold(), naming.Scanner())
} else {
fmt.Printf(" systemctl start %s / %s\n", naming.Appview(), naming.Hold())
services = append(services, naming.Scanner())
}
if state.LabelerEnabled {
services = append(services, naming.Labeler())
}
fmt.Printf(" systemctl start %s\n", strings.Join(services, " / "))
fmt.Println(" 2. Configure DNS records above")
return nil

View File

@@ -20,6 +20,7 @@ type InfraState struct {
ObjectStorage ObjectStorageState `json:"object_storage"`
ScannerEnabled bool `json:"scanner_enabled,omitempty"`
ScannerSecret string `json:"scanner_secret,omitempty"`
LabelerEnabled bool `json:"labeler_enabled,omitempty"`
}
// Naming returns a Naming helper, defaulting to "seamark" if ClientName is empty.

View File

@@ -0,0 +1,25 @@
[Unit]
Description={{.DisplayName}} Labeler (Content Moderation)
After=network-online.target {{.AppviewServiceName}}.service
Wants=network-online.target
[Service]
Type=simple
User={{.User}}
Group={{.User}}
ExecStart={{.BinaryPath}} serve --config {{.ConfigPath}}
Restart=on-failure
RestartSec=10
ReadWritePaths={{.DataDir}}
ProtectSystem=strict
ProtectHome=yes
NoNewPrivileges=yes
PrivateTmp=yes
StandardOutput=journal
StandardError=journal
SyslogIdentifier={{.ServiceName}}
[Install]
WantedBy=multi-user.target

View File

@@ -24,7 +24,8 @@ var updateCmd = &cobra.Command{
target = args[0]
}
withScanner, _ := cmd.Flags().GetBool("with-scanner")
return cmdUpdate(target, withScanner)
withLabeler, _ := cmd.Flags().GetBool("with-labeler")
return cmdUpdate(target, withScanner, withLabeler)
},
}
@@ -40,11 +41,12 @@ var sshCmd = &cobra.Command{
func init() {
updateCmd.Flags().Bool("with-scanner", false, "Enable and deploy vulnerability scanner alongside hold")
updateCmd.Flags().Bool("with-labeler", false, "Enable and deploy content moderation labeler alongside appview")
rootCmd.AddCommand(updateCmd)
rootCmd.AddCommand(sshCmd)
}
func cmdUpdate(target string, withScanner bool) error {
func cmdUpdate(target string, withScanner, withLabeler bool) error {
state, err := loadState()
if err != nil {
return err
@@ -67,6 +69,12 @@ func cmdUpdate(target string, withScanner bool) error {
_ = saveState(state)
}
// Enable labeler retroactively via --with-labeler on update
if withLabeler && !state.LabelerEnabled {
state.LabelerEnabled = true
_ = saveState(state)
}
vals := configValsFromState(state)
targets := map[string]struct {
@@ -144,6 +152,21 @@ func cmdUpdate(target string, withScanner bool) error {
}
}
// Build labeler locally if needed
needLabeler := false
for _, name := range toUpdate {
if name == "appview" && state.LabelerEnabled {
needLabeler = true
break
}
}
if needLabeler {
outputPath := filepath.Join(rootDir, "bin", "atcr-labeler")
if err := buildLocal(rootDir, outputPath, "./cmd/labeler"); err != nil {
return fmt.Errorf("build labeler: %w", err)
}
}
// Deploy each target
for _, name := range toUpdate {
t := targets[name]
@@ -244,13 +267,65 @@ curl -sf http://localhost:9090/healthz > /dev/null && echo "SCANNER_HEALTH_OK" |
`
}
// Labeler additions for appview server
labelerRestart := ""
if name == "appview" && state.LabelerEnabled {
// Sync labeler config keys
labelerConfigYAML, err := renderConfig(labelerConfigTmpl, vals)
if err != nil {
return fmt.Errorf("render labeler config: %w", err)
}
if err := syncConfigKeys("labeler", t.ip, naming.LabelerConfigPath(), labelerConfigYAML); err != nil {
return fmt.Errorf("labeler config sync: %w", err)
}
// Sync labeler service unit
labelerUnit, err := renderLabelerServiceUnit(labelerServiceUnitParams{
DisplayName: naming.DisplayName(),
User: naming.SystemUser(),
BinaryPath: naming.InstallDir() + "/bin/" + naming.Labeler(),
ConfigPath: naming.LabelerConfigPath(),
DataDir: naming.BasePath(),
ServiceName: naming.Labeler(),
AppviewServiceName: naming.Appview(),
})
if err != nil {
return fmt.Errorf("render labeler service unit: %w", err)
}
labelerUnitChanged, err := syncServiceUnit("labeler", t.ip, naming.Labeler(), labelerUnit)
if err != nil {
return fmt.Errorf("labeler service unit sync: %w", err)
}
if labelerUnitChanged {
daemonReload = "systemctl daemon-reload"
}
// Upload labeler binary
labelerLocal := filepath.Join(rootDir, "bin", "atcr-labeler")
labelerRemote := naming.InstallDir() + "/bin/" + naming.Labeler()
if err := scpFile(labelerLocal, t.ip, labelerRemote); err != nil {
return fmt.Errorf("upload labeler: %w", err)
}
// Ensure labeler data dirs exist
labelerSetup := fmt.Sprintf(`mkdir -p %s
chown -R %s:%s %s`,
naming.LabelerDataDir(),
naming.SystemUser(), naming.SystemUser(), naming.LabelerDataDir())
if _, err := runSSH(t.ip, labelerSetup, false); err != nil {
return fmt.Errorf("labeler dir setup: %w", err)
}
labelerRestart = fmt.Sprintf("\nsystemctl restart %s", naming.Labeler())
}
// Restart services and health check
restartScript := fmt.Sprintf(`set -euo pipefail
%s
systemctl restart %s%s
systemctl restart %s%s%s
sleep 2
curl -sf %s > /dev/null && echo "HEALTH_OK" || echo "HEALTH_FAIL"
%s`, daemonReload, t.serviceName, scannerRestart, t.healthURL, scannerHealthCheck)
%s`, daemonReload, t.serviceName, scannerRestart, labelerRestart, t.healthURL, scannerHealthCheck)
output, err := runSSH(t.ip, restartScript, true)
if err != nil {

View File

@@ -32,6 +32,7 @@ type Config struct {
Auth AuthConfig `yaml:"auth" comment:"JWT authentication settings."`
CredentialHelper CredentialHelperConfig `yaml:"credential_helper" comment:"Credential helper download settings."`
Legal LegalConfig `yaml:"legal" comment:"Legal page customization for self-hosted instances."`
Labeler LabelerRefConfig `yaml:"labeler" comment:"ATProto labeler for content moderation (DMCA takedowns)."`
Billing billing.Config `yaml:"billing" comment:"Stripe billing integration (requires -tags billing build)."`
Distribution *configuration.Configuration `yaml:"-"` // Wrapped distribution config for compatibility
}
@@ -140,6 +141,12 @@ type LegalConfig struct {
Jurisdiction string `yaml:"jurisdiction" comment:"Governing law jurisdiction for legal terms."`
}
// LabelerRefConfig defines the connection to an ATProto labeler service.
type LabelerRefConfig struct {
// DID or URL of the labeler service for content moderation.
DID string `yaml:"did" comment:"DID or URL of the ATProto labeler (e.g., did:web:labeler.atcr.io). Empty disables label filtering."`
}
// setDefaults registers all default values on the given Viper instance.
func setDefaults(v *viper.Viper) {
v.SetDefault("version", "0.1")
@@ -193,6 +200,9 @@ func setDefaults(v *viper.Viper) {
v.SetDefault("legal.company_name", "")
v.SetDefault("legal.jurisdiction", "")
// Labeler defaults
v.SetDefault("labeler.did", "")
// Log formatter (used by distribution config, not in Config struct)
v.SetDefault("log_formatter", "text")
}

79
pkg/appview/db/labels.go Normal file
View File

@@ -0,0 +1,79 @@
package db
import (
"database/sql"
"time"
)
// LabelChecker wraps a database connection to check takedown labels.
// Implements middleware.LabelChecker interface.
type LabelChecker struct {
db *sql.DB
}
// NewLabelChecker creates a new LabelChecker.
func NewLabelChecker(database *sql.DB) *LabelChecker {
return &LabelChecker{db: database}
}
// IsTakenDown checks if a (DID, repository) pair has an active takedown label.
func (lc *LabelChecker) IsTakenDown(did, repository string) (bool, error) {
return IsTakenDown(lc.db, did, repository)
}
// Label represents an ATProto label mirrored from a labeler service.
type Label struct {
ID int64
Src string
URI string
Val string
Neg bool
Cts time.Time
SubjectDID string
SubjectRepo string
Seq int64
}
// IsTakenDown checks if a (DID, repository) pair has an active !takedown label.
// Also matches user-level labels (subject_repo = ”).
func IsTakenDown(db DBTX, did, repository string) (bool, error) {
var exists bool
err := db.QueryRow(
`SELECT EXISTS(
SELECT 1 FROM labels l1
WHERE l1.subject_did = ?
AND (l1.subject_repo = ? OR l1.subject_repo = '')
AND l1.val = '!takedown' AND l1.neg = 0
AND NOT EXISTS (
SELECT 1 FROM labels l2
WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val
AND l2.neg = 1 AND l2.id > l1.id
)
AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP)
)`,
did, repository,
).Scan(&exists)
return exists, err
}
// UpsertLabel inserts or updates a label from a labeler subscription.
func UpsertLabel(db DBTX, l *Label) error {
_, err := db.Exec(
`INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo, seq)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(src, uri, val, neg) DO UPDATE SET cts = excluded.cts, seq = excluded.seq`,
l.Src, l.URI, l.Val, l.Neg, l.Cts.UTC().Format(time.RFC3339),
l.SubjectDID, l.SubjectRepo, l.Seq,
)
return err
}
// GetLabelCursor returns the latest sequence number for a given labeler source.
func GetLabelCursor(db DBTX, src string) (int64, error) {
var cursor int64
err := db.QueryRow(
`SELECT COALESCE(MAX(seq), 0) FROM labels WHERE src = ?`,
src,
).Scan(&cursor)
return cursor, err
}

View File

@@ -0,0 +1,16 @@
description: Create labels table for ATProto content moderation (takedowns)
query: |
CREATE TABLE IF NOT EXISTS labels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
src TEXT NOT NULL,
uri TEXT NOT NULL,
val TEXT NOT NULL,
neg BOOLEAN NOT NULL DEFAULT 0,
cts TIMESTAMP NOT NULL,
subject_did TEXT NOT NULL,
subject_repo TEXT NOT NULL DEFAULT '',
seq INTEGER NOT NULL DEFAULT 0,
UNIQUE(src, uri, val, neg)
);
CREATE INDEX IF NOT EXISTS idx_labels_subject ON labels(subject_did, subject_repo);
CREATE INDEX IF NOT EXISTS idx_labels_val ON labels(val);

View File

@@ -74,13 +74,18 @@ func SearchRepositories(db DBTX, query string, limit, offset int, currentUserDID
SELECT DISTINCT lm.did, lm.repository, lm.latest_id
FROM latest_manifests lm
JOIN users u ON lm.did = u.did
WHERE u.handle LIKE ? ESCAPE '\'
WHERE (u.handle LIKE ? ESCAPE '\'
OR u.did = ?
OR lm.repository LIKE ? ESCAPE '\'
OR EXISTS (
SELECT 1 FROM repository_annotations ra
WHERE ra.did = lm.did AND ra.repository = lm.repository
AND ra.value LIKE ? ESCAPE '\'
))
AND NOT EXISTS (
SELECT 1 FROM labels
WHERE (subject_did = lm.did AND (subject_repo = lm.repository OR subject_repo = ''))
AND val = '!takedown' AND neg = 0
)
),
repo_stats AS (
@@ -1953,6 +1958,11 @@ func GetRepoCards(db DBTX, limit int, currentUserDID string, sortOrder RepoCardS
JOIN users u ON m.did = u.did
LEFT JOIN repository_stats rs ON m.did = rs.did AND m.repository = rs.repository
LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository
WHERE NOT EXISTS (
SELECT 1 FROM labels
WHERE (subject_did = m.did AND (subject_repo = m.repository OR subject_repo = ''))
AND val = '!takedown' AND neg = 0
)
ORDER BY ` + orderBy + `
LIMIT ?
`
@@ -2026,6 +2036,11 @@ func GetUserRepoCards(db DBTX, userDID string, currentUserDID string) ([]RepoCar
JOIN users u ON m.did = u.did
LEFT JOIN repository_stats rs ON m.did = rs.did AND m.repository = rs.repository
LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository
WHERE NOT EXISTS (
SELECT 1 FROM labels
WHERE (subject_did = m.did AND (subject_repo = m.repository OR subject_repo = ''))
AND val = '!takedown' AND neg = 0
)
ORDER BY MAX(rs.last_push, m.created_at) DESC
`

View File

@@ -271,3 +271,18 @@ CREATE TABLE IF NOT EXISTS scans (
PRIMARY KEY(hold_did, manifest_digest)
);
CREATE INDEX IF NOT EXISTS idx_scans_user ON scans(user_did);
CREATE TABLE IF NOT EXISTS labels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
src TEXT NOT NULL,
uri TEXT NOT NULL,
val TEXT NOT NULL,
neg BOOLEAN NOT NULL DEFAULT 0,
cts TIMESTAMP NOT NULL,
subject_did TEXT NOT NULL,
subject_repo TEXT NOT NULL DEFAULT '',
seq INTEGER NOT NULL DEFAULT 0,
UNIQUE(src, uri, val, neg)
);
CREATE INDEX IF NOT EXISTS idx_labels_subject ON labels(subject_did, subject_repo);
CREATE INDEX IF NOT EXISTS idx_labels_val ON labels(val);

View File

@@ -45,6 +45,12 @@ func (h *DigestDetailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}
// Check for takedown labels
if taken, _ := db.IsTakenDown(h.ReadOnlyDB, did, repository); taken {
RenderNotFound(w, r, &h.BaseUIHandler)
return
}
owner, err := db.GetUserByDID(h.ReadOnlyDB, did)
if err != nil || owner == nil {
RenderNotFound(w, r, &h.BaseUIHandler)

View File

@@ -34,6 +34,12 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
return
}
// Check for takedown labels
if taken, _ := db.IsTakenDown(h.ReadOnlyDB, did, repository); taken {
RenderNotFound(w, r, &h.BaseUIHandler)
return
}
// Look up user by DID
owner, err := db.GetUserByDID(h.ReadOnlyDB, did)
if err != nil {

View File

@@ -229,6 +229,20 @@ func (p *Processor) ProcessRecord(ctx context.Context, did, collection, rkey str
}
}
// Skip ingestion for taken-down content
if !isDelete && data != nil {
if repo := extractRepoFromRecord(collection, data); repo != "" {
if taken, _ := db.IsTakenDown(p.db, did, repo); taken {
slog.Debug("Skipping taken-down content",
"component", "processor",
"did", did,
"collection", collection,
"repository", repo)
return nil
}
}
}
// User-activity collections create/update user entries
// Skip for deletes - user should already exist, and we don't need to resolve identity
if !isDelete {
@@ -971,3 +985,23 @@ func (p *Processor) ProcessAccount(ctx context.Context, did string, active bool,
return nil
}
// extractRepoFromRecord extracts the repository field from a record's JSON data.
// Returns empty string for collections that don't have a repository field
// (e.g., sailor profile, captain, crew).
func extractRepoFromRecord(collection string, data []byte) string {
switch collection {
case atproto.ManifestCollection,
atproto.TagCollection,
atproto.RepoPageCollection,
atproto.StatsCollection,
atproto.ScanCollection:
var rec struct {
Repository string `json:"repository"`
}
if err := json.Unmarshal(data, &rec); err == nil {
return rec.Repository
}
}
return ""
}

View File

@@ -0,0 +1,239 @@
// Package labeler provides a subscription client for consuming labels
// from an ATProto labeler service.
package labeler
import (
"database/sql"
"encoding/json"
"fmt"
"log/slog"
"net/url"
"strings"
"time"
"atcr.io/pkg/appview/db"
"github.com/gorilla/websocket"
)
// LabelsMessage is the wire format for subscribeLabels events.
type LabelsMessage struct {
Seq int64 `json:"seq"`
Labels []LabelEvent `json:"labels"`
}
// LabelEvent is a single label from the labeler.
type LabelEvent struct {
Src string `json:"src"`
URI string `json:"uri"`
CID string `json:"cid,omitempty"`
Val string `json:"val"`
Neg bool `json:"neg"`
Cts string `json:"cts"`
Exp string `json:"exp,omitempty"`
}
// Subscriber connects to a labeler's subscribeLabels endpoint
// and mirrors labels into the appview database.
type Subscriber struct {
labelerURL string
database *sql.DB
stopCh chan struct{}
}
// NewSubscriber creates a new labeler subscriber.
func NewSubscriber(labelerURL string, database *sql.DB) *Subscriber {
return &Subscriber{
labelerURL: labelerURL,
database: database,
stopCh: make(chan struct{}),
}
}
// Start begins the subscription loop in a goroutine.
func (s *Subscriber) Start() {
go s.run()
}
// Stop signals the subscriber to shut down.
func (s *Subscriber) Stop() {
close(s.stopCh)
}
func (s *Subscriber) run() {
backoff := time.Second
for {
select {
case <-s.stopCh:
return
default:
}
if err := s.connect(); err != nil {
slog.Warn("Labeler subscription error, reconnecting",
"error", err,
"backoff", backoff,
)
select {
case <-s.stopCh:
return
case <-time.After(backoff):
}
if backoff < 30*time.Second {
backoff *= 2
}
} else {
backoff = time.Second
}
}
}
func (s *Subscriber) connect() error {
// Get cursor from DB
// Use the labeler URL as src identifier
labelerDID := extractDIDFromURL(s.labelerURL)
cursor, err := db.GetLabelCursor(s.database, labelerDID)
if err != nil {
return fmt.Errorf("failed to get cursor: %w", err)
}
// Build WebSocket URL
wsURL := toWebSocketURL(s.labelerURL) + "/xrpc/com.atproto.label.subscribeLabels"
if cursor > 0 {
wsURL += fmt.Sprintf("?cursor=%d", cursor)
}
slog.Info("Connecting to labeler", "url", wsURL, "cursor", cursor)
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
return fmt.Errorf("websocket dial failed: %w", err)
}
defer conn.Close()
slog.Info("Connected to labeler", "url", s.labelerURL)
for {
select {
case <-s.stopCh:
return nil
default:
}
var msg LabelsMessage
if err := conn.ReadJSON(&msg); err != nil {
return fmt.Errorf("read error: %w", err)
}
for _, le := range msg.Labels {
cts, _ := time.Parse(time.RFC3339, le.Cts)
did, repo := extractSubjectFromURI(le.URI)
label := &db.Label{
Src: le.Src,
URI: le.URI,
Val: le.Val,
Neg: le.Neg,
Cts: cts,
SubjectDID: did,
SubjectRepo: repo,
Seq: msg.Seq,
}
if err := db.UpsertLabel(s.database, label); err != nil {
slog.Warn("Failed to upsert label", "uri", le.URI, "error", err)
continue
}
slog.Info("Mirrored label",
"uri", le.URI,
"val", le.Val,
"neg", le.Neg,
"subject_did", did,
"subject_repo", repo,
)
}
}
}
// extractSubjectFromURI extracts the DID and repository from an AT URI.
// Examples:
//
// at://did:plc:xyz → (did:plc:xyz, "")
// at://did:plc:xyz/io.atcr.manifest/abc → (did:plc:xyz, "") - repo extracted from record
// at://did:plc:xyz/io.atcr.repo/myimage → (did:plc:xyz, "myimage")
func extractSubjectFromURI(uri string) (did, repo string) {
trimmed := strings.TrimPrefix(uri, "at://")
parts := strings.SplitN(trimmed, "/", 3)
if len(parts) == 0 {
return "", ""
}
did = parts[0]
// For repo-level summary labels: at://did/io.atcr.repo/reponame
if len(parts) >= 3 && parts[1] == "io.atcr.repo" {
repo = parts[2]
}
return did, repo
}
// extractDIDFromURL derives a did:web from a labeler URL.
func extractDIDFromURL(labelerURL string) string {
u, err := url.Parse(labelerURL)
if err != nil {
return labelerURL
}
host := u.Hostname()
if port := u.Port(); port != "" {
host += "%3A" + port
}
return "did:web:" + host
}
// toWebSocketURL converts an HTTP URL to a WebSocket URL.
func toWebSocketURL(httpURL string) string {
u, err := url.Parse(httpURL)
if err != nil {
return httpURL
}
switch u.Scheme {
case "https":
u.Scheme = "wss"
default:
u.Scheme = "ws"
}
return u.String()
}
// ParseLabelerURL parses a labeler DID or URL into an HTTP URL.
func ParseLabelerURL(labelerDIDOrURL string) string {
if strings.HasPrefix(labelerDIDOrURL, "http://") || strings.HasPrefix(labelerDIDOrURL, "https://") {
return labelerDIDOrURL
}
if strings.HasPrefix(labelerDIDOrURL, "did:web:") {
host := strings.TrimPrefix(labelerDIDOrURL, "did:web:")
host = strings.ReplaceAll(host, "%3A", ":")
return "https://" + host
}
return labelerDIDOrURL
}
// SubscriberFromConfig creates a Subscriber from a labeler DID/URL config value.
// Returns nil if labelerDIDOrURL is empty.
func SubscriberFromConfig(labelerDIDOrURL string, database *sql.DB) *Subscriber {
if labelerDIDOrURL == "" {
return nil
}
labelerURL := ParseLabelerURL(labelerDIDOrURL)
return NewSubscriber(labelerURL, database)
}
// DecodeLabelsFromJSON decodes a JSON-encoded labels message.
func DecodeLabelsFromJSON(data []byte) (*LabelsMessage, error) {
var msg LabelsMessage
if err := json.Unmarshal(data, &msg); err != nil {
return nil, err
}
return &msg, nil
}

View File

@@ -166,6 +166,11 @@ func (vc *validationCache) getOrFetch(ctx context.Context, cacheKey string, fetc
return serviceToken, err
}
// LabelChecker checks whether content has been taken down via ATProto labels.
type LabelChecker interface {
IsTakenDown(did, repository string) (bool, error)
}
// Global variables for initialization only
// These are set by main.go during startup and copied into NamespaceResolver instances.
// After initialization, request handling uses the NamespaceResolver's instance fields.
@@ -175,6 +180,7 @@ var (
globalAuthorizer auth.HoldAuthorizer
globalWebhookDispatcher storage.PushWebhookDispatcher
globalManifestRefChecker storage.ManifestReferenceChecker
globalLabelChecker LabelChecker
)
// SetGlobalRefresher sets the OAuth refresher instance during initialization
@@ -194,6 +200,11 @@ func SetGlobalManifestRefChecker(checker storage.ManifestReferenceChecker) {
globalManifestRefChecker = checker
}
// SetGlobalLabelChecker sets the label checker instance during initialization
func SetGlobalLabelChecker(checker LabelChecker) {
globalLabelChecker = checker
}
// SetGlobalAuthorizer sets the authorizer instance during initialization
// Must be called before the registry starts serving requests
func SetGlobalAuthorizer(authorizer auth.HoldAuthorizer) {
@@ -304,6 +315,16 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
slog.Debug("Resolved identity", "component", "registry/middleware", "did", did, "pds", pdsEndpoint, "handle", handle)
// Check for takedown labels before proceeding
if globalLabelChecker != nil {
if taken, _ := globalLabelChecker.IsTakenDown(did, imageName); taken {
return nil, errcode.Error{
Code: errcode.ErrorCodeDenied,
Message: "this repository has been removed due to a policy violation",
}
}
}
// Query for hold DID - either user's hold or default hold service
// Also returns the sailor profile so we can read preferences (e.g. AutoRemoveUntagged)
holdDID, sailorProfile := nr.findHoldDIDAndProfile(ctx, did, pdsEndpoint)

View File

@@ -24,6 +24,7 @@ import (
"atcr.io/pkg/appview/db"
"atcr.io/pkg/appview/holdhealth"
"atcr.io/pkg/appview/jetstream"
appviewlabeler "atcr.io/pkg/appview/labeler"
"atcr.io/pkg/appview/middleware"
"atcr.io/pkg/appview/readme"
"atcr.io/pkg/appview/routes"
@@ -236,6 +237,9 @@ func NewAppViewServer(cfg *Config, branding *BrandingOverrides) (*AppViewServer,
middleware.SetGlobalDatabase(holdDIDDB)
middleware.SetGlobalManifestRefChecker(holdDIDDB)
// Set label checker for takedown filtering
middleware.SetGlobalLabelChecker(db.NewLabelChecker(s.Database))
// Create RemoteHoldAuthorizer for hold authorization with caching
s.HoldAuthorizer = auth.NewRemoteHoldAuthorizer(s.Database, testMode)
middleware.SetGlobalAuthorizer(s.HoldAuthorizer)
@@ -287,6 +291,15 @@ func NewAppViewServer(cfg *Config, branding *BrandingOverrides) (*AppViewServer,
// Initialize Jetstream workers
s.initializeJetstream()
// Initialize labeler subscriber
if cfg.Labeler.DID != "" {
sub := appviewlabeler.SubscriberFromConfig(cfg.Labeler.DID, s.Database)
if sub != nil {
sub.Start()
slog.Info("Labeler subscriber started", "labeler", cfg.Labeler.DID)
}
}
// Create main chi router
mainRouter := chi.NewRouter()

106
pkg/labeler/auth.go Normal file
View File

@@ -0,0 +1,106 @@
package labeler
import (
"crypto/rand"
"encoding/base64"
"net/http"
"sync"
)
// Session represents an authenticated admin session.
type Session struct {
DID string
Handle string
}
// Auth manages admin authentication.
type Auth struct {
ownerDID string
sessions map[string]*Session
sessionsMu sync.RWMutex
}
// NewAuth creates a new Auth manager.
func NewAuth(ownerDID string) *Auth {
return &Auth{
ownerDID: ownerDID,
sessions: make(map[string]*Session),
}
}
func (a *Auth) createSession(did, handle string) (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
token := base64.URLEncoding.EncodeToString(b)
a.sessionsMu.Lock()
a.sessions[token] = &Session{DID: did, Handle: handle}
a.sessionsMu.Unlock()
return token, nil
}
func (a *Auth) getSession(token string) *Session {
a.sessionsMu.RLock()
defer a.sessionsMu.RUnlock()
return a.sessions[token]
}
func (a *Auth) deleteSession(token string) {
a.sessionsMu.Lock()
delete(a.sessions, token)
a.sessionsMu.Unlock()
}
const sessionCookieName = "labeler_session"
func setSessionCookie(w http.ResponseWriter, r *http.Request, token string) {
secure := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: token,
Path: "/",
MaxAge: 86400, // 24 hours
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func clearSessionCookie(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
}
func getSessionCookie(r *http.Request) (string, bool) {
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
return "", false
}
return cookie.Value, true
}
// RequireOwner is middleware that checks the session belongs to the owner DID.
func (a *Auth) RequireOwner(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, ok := getSessionCookie(r)
if !ok {
http.Redirect(w, r, "/auth/login", http.StatusFound)
return
}
session := a.getSession(token)
if session == nil || session.DID != a.ownerDID {
http.Redirect(w, r, "/auth/login", http.StatusFound)
return
}
next.ServeHTTP(w, r)
})
}

146
pkg/labeler/config.go Normal file
View File

@@ -0,0 +1,146 @@
// Package labeler implements the ATCR labeler service, an ATProto-compatible
// content moderation service for issuing takedown labels on container registry content.
package labeler
import (
"fmt"
"net/url"
"strings"
"github.com/spf13/viper"
"atcr.io/pkg/config"
)
// Config represents the labeler service configuration.
// It reuses the appview config YAML structure, reading from the "labeler" section.
type Config struct {
Version string `yaml:"version" comment:"Configuration format version."`
LogLevel string `yaml:"log_level" comment:"Log level: debug, info, warn, error."`
Labeler LabelerConfig `yaml:"labeler" comment:"Labeler service settings."`
Server AppviewServerConfig `yaml:"server" comment:"AppView server settings (shared config)."`
LogShipper config.LogShipperConfig `yaml:"log_shipper" comment:"Remote log shipping settings."`
}
// LabelerConfig defines labeler-specific settings.
type LabelerConfig struct {
// Enable the labeler service.
Enabled bool `yaml:"enabled" comment:"Enable the labeler service."`
// Listen address for the labeler HTTP server.
Addr string `yaml:"addr" comment:"Listen address for labeler (e.g., :5002)."`
// DID of the labeler admin. Only this DID can log into the admin panel.
OwnerDID string `yaml:"owner_did" comment:"DID of the labeler admin. Only this DID can log into the admin panel."`
// Path to labeler SQLite database.
DBPath string `yaml:"db_path" comment:"Path to labeler SQLite database."`
}
// AppviewServerConfig is a subset of the appview ServerConfig that the labeler needs.
type AppviewServerConfig struct {
BaseURL string `yaml:"base_url"`
ClientName string `yaml:"client_name"`
ClientShortName string `yaml:"client_short_name"`
TestMode bool `yaml:"test_mode"`
}
// PublicURL returns the labeler's public URL derived from the appview base URL.
// If appview is https://atcr.io, labeler is https://labeler.atcr.io.
func (c *Config) PublicURL() string {
u, err := url.Parse(c.Server.BaseURL)
if err != nil {
return ""
}
u.Host = "labeler." + u.Host
return u.String()
}
// DID returns the labeler's did:web identity derived from its public URL.
func (c *Config) DID() string {
u, err := url.Parse(c.PublicURL())
if err != nil {
return ""
}
host := u.Hostname()
if port := u.Port(); port != "" {
host += "%3A" + port
}
return "did:web:" + host
}
func setDefaults(v *viper.Viper) {
v.SetDefault("version", "0.1")
v.SetDefault("log_level", "info")
// Labeler defaults
v.SetDefault("labeler.enabled", false)
v.SetDefault("labeler.addr", ":5002")
v.SetDefault("labeler.owner_did", "")
v.SetDefault("labeler.db_path", "/var/lib/atcr-labeler/labeler.db")
// Server defaults (read from shared appview config)
v.SetDefault("server.base_url", "")
v.SetDefault("server.client_name", "AT Container Registry")
v.SetDefault("server.client_short_name", "ATCR")
v.SetDefault("server.test_mode", false)
}
// LoadConfig loads the labeler configuration from the appview config YAML.
func LoadConfig(yamlPath string) (*Config, error) {
v := config.NewViper("LABELER", yamlPath)
setDefaults(v)
cfg := &Config{}
if err := v.Unmarshal(cfg, config.UnmarshalOption()); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// Also try ATCR_ prefix for shared server config
atcrV := config.NewViper("ATCR", yamlPath)
if baseURL := atcrV.GetString("server.base_url"); baseURL != "" && cfg.Server.BaseURL == "" {
cfg.Server.BaseURL = baseURL
}
if clientName := atcrV.GetString("server.client_name"); clientName != "" && cfg.Server.ClientName == "" {
cfg.Server.ClientName = clientName
}
if clientShortName := atcrV.GetString("server.client_short_name"); clientShortName != "" && cfg.Server.ClientShortName == "" {
cfg.Server.ClientShortName = clientShortName
}
if atcrV.GetBool("server.test_mode") {
cfg.Server.TestMode = true
}
// Validation
if cfg.Server.BaseURL == "" {
return nil, fmt.Errorf("server.base_url is required")
}
if cfg.Labeler.OwnerDID == "" {
return nil, fmt.Errorf("labeler.owner_did is required")
}
if !strings.HasPrefix(cfg.Labeler.OwnerDID, "did:") {
return nil, fmt.Errorf("labeler.owner_did must be a DID (got %q)", cfg.Labeler.OwnerDID)
}
return cfg, nil
}
// ExampleYAML generates an example labeler configuration file.
func ExampleYAML() ([]byte, error) {
cfg := &Config{
Version: "0.1",
LogLevel: "info",
Server: AppviewServerConfig{
BaseURL: "https://atcr.io",
ClientName: "AT Container Registry",
ClientShortName: "ATCR",
},
Labeler: LabelerConfig{
Enabled: true,
Addr: ":5002",
OwnerDID: "did:plc:your-did-here",
DBPath: "/var/lib/atcr-labeler/labeler.db",
},
}
return config.MarshalCommentedYAML("ATCR Labeler Configuration", cfg)
}

View File

@@ -0,0 +1,51 @@
package labeler
import "testing"
func TestConfig_PublicURL(t *testing.T) {
tests := []struct {
name string
baseURL string
want string
}{
{"standard", "https://atcr.io", "https://labeler.atcr.io"},
{"with port", "https://atcr.io:8080", "https://labeler.atcr.io:8080"},
{"localhost", "http://localhost:5000", "http://labeler.localhost:5000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Server: AppviewServerConfig{BaseURL: tt.baseURL},
}
got := cfg.PublicURL()
if got != tt.want {
t.Errorf("PublicURL() = %q, want %q", got, tt.want)
}
})
}
}
func TestConfig_DID(t *testing.T) {
tests := []struct {
name string
baseURL string
want string
}{
{"standard", "https://atcr.io", "did:web:labeler.atcr.io"},
{"with port", "https://atcr.io:8080", "did:web:labeler.atcr.io%3A8080"},
{"localhost", "http://localhost:5000", "did:web:labeler.localhost%3A5000"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &Config{
Server: AppviewServerConfig{BaseURL: tt.baseURL},
}
got := cfg.DID()
if got != tt.want {
t.Errorf("DID() = %q, want %q", got, tt.want)
}
})
}
}

301
pkg/labeler/db.go Normal file
View File

@@ -0,0 +1,301 @@
package labeler
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"time"
_ "github.com/tursodatabase/go-libsql"
)
const schema = `
CREATE TABLE IF NOT EXISTS labels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
src TEXT NOT NULL,
uri TEXT NOT NULL,
cid TEXT,
val TEXT NOT NULL,
neg BOOLEAN NOT NULL DEFAULT 0,
cts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
exp TIMESTAMP,
subject_did TEXT NOT NULL,
subject_repo TEXT NOT NULL DEFAULT '',
UNIQUE(src, uri, val, neg)
);
CREATE INDEX IF NOT EXISTS idx_labels_subject ON labels(subject_did, subject_repo);
CREATE INDEX IF NOT EXISTS idx_labels_cts ON labels(cts DESC);
`
// Label represents an ATProto label (com.atproto.label.defs#label).
type Label struct {
ID int64
Src string
URI string
CID string
Val string
Neg bool
Cts time.Time
Exp *time.Time
SubjectDID string
SubjectRepo string
}
// OpenDB opens or creates the labeler database.
func OpenDB(dbPath string) (*sql.DB, error) {
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
return nil, fmt.Errorf("failed to create db directory: %w", err)
}
db, err := sql.Open("libsql", "file:"+dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Apply schema
for _, stmt := range splitStatements(schema) {
if _, err := db.Exec(stmt); err != nil {
return nil, fmt.Errorf("failed to apply schema: %w", err)
}
}
return db, nil
}
// splitStatements splits SQL by semicolons (go-libsql doesn't support multi-statement exec).
func splitStatements(sql string) []string {
var stmts []string
for _, s := range splitOnSemicolon(sql) {
s = trimSpace(s)
if s != "" {
stmts = append(stmts, s)
}
}
return stmts
}
func splitOnSemicolon(s string) []string {
var parts []string
start := 0
for i := 0; i < len(s); i++ {
if s[i] == ';' {
parts = append(parts, s[start:i])
start = i + 1
}
}
if start < len(s) {
parts = append(parts, s[start:])
}
return parts
}
func trimSpace(s string) string {
// Simple trim that handles newlines and spaces
i := 0
for i < len(s) && (s[i] == ' ' || s[i] == '\t' || s[i] == '\n' || s[i] == '\r') {
i++
}
j := len(s)
for j > i && (s[j-1] == ' ' || s[j-1] == '\t' || s[j-1] == '\n' || s[j-1] == '\r') {
j--
}
return s[i:j]
}
// CreateLabel inserts a new label into the database.
func CreateLabel(db *sql.DB, l *Label) (int64, error) {
result, err := db.Exec(
`INSERT INTO labels (src, uri, cid, val, neg, cts, exp, subject_did, subject_repo)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(src, uri, val, neg) DO UPDATE SET cts = excluded.cts`,
l.Src, l.URI, l.CID, l.Val, l.Neg, l.Cts.UTC().Format(time.RFC3339), l.Exp,
l.SubjectDID, l.SubjectRepo,
)
if err != nil {
return 0, fmt.Errorf("failed to create label: %w", err)
}
return result.LastInsertId()
}
// NegateLabel creates a negation label to reverse a previous label.
func NegateLabel(db *sql.DB, src, uri, val string, subjectDID, subjectRepo string) error {
_, err := db.Exec(
`INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo)
VALUES (?, ?, ?, 1, ?, ?, ?)`,
src, uri, val, time.Now().UTC().Format(time.RFC3339), subjectDID, subjectRepo,
)
return err
}
// GetLabelsSince returns labels with id > cursor, ordered by id ascending.
func GetLabelsSince(db *sql.DB, cursor int64, limit int) ([]Label, error) {
rows, err := db.Query(
`SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, subject_did, subject_repo
FROM labels WHERE id > ? ORDER BY id ASC LIMIT ?`,
cursor, limit,
)
if err != nil {
return nil, err
}
defer rows.Close()
return scanLabels(rows)
}
// ListActiveTakedowns returns active (non-negated) takedown labels.
func ListActiveTakedowns(db *sql.DB, limit, offset int) ([]Label, int, error) {
var total int
err := db.QueryRow(
`SELECT COUNT(*) FROM labels l1
WHERE l1.val = '!takedown' AND l1.neg = 0
AND NOT EXISTS (
SELECT 1 FROM labels l2
WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val
AND l2.neg = 1 AND l2.id > l1.id
)
AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP)`,
).Scan(&total)
if err != nil {
return nil, 0, err
}
rows, err := db.Query(
`SELECT l1.id, l1.src, l1.uri, COALESCE(l1.cid, ''), l1.val, l1.neg, l1.cts, l1.exp, l1.subject_did, l1.subject_repo
FROM labels l1
WHERE l1.val = '!takedown' AND l1.neg = 0
AND NOT EXISTS (
SELECT 1 FROM labels l2
WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val
AND l2.neg = 1 AND l2.id > l1.id
)
AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP)
ORDER BY l1.cts DESC LIMIT ? OFFSET ?`,
limit, offset,
)
if err != nil {
return nil, 0, err
}
defer rows.Close()
labels, err := scanLabels(rows)
return labels, total, err
}
// GetLabelsForRepo returns all active labels for a specific DID + repository.
func GetLabelsForRepo(db *sql.DB, did, repo string) ([]Label, error) {
rows, err := db.Query(
`SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, subject_did, subject_repo
FROM labels
WHERE subject_did = ? AND subject_repo = ?
ORDER BY cts DESC`,
did, repo,
)
if err != nil {
return nil, err
}
defer rows.Close()
return scanLabels(rows)
}
// NegateRepoLabels creates negation labels for all active takedown labels on a (DID, repo) pair.
func NegateRepoLabels(db *sql.DB, src, did, repo string) error {
rows, err := db.Query(
`SELECT uri FROM labels
WHERE subject_did = ? AND subject_repo = ? AND val = '!takedown' AND neg = 0`,
did, repo,
)
if err != nil {
return err
}
var uris []string
for rows.Next() {
var uri string
if err := rows.Scan(&uri); err != nil {
rows.Close()
return err
}
uris = append(uris, uri)
}
rows.Close()
if err := rows.Err(); err != nil {
return err
}
now := time.Now().UTC().Format(time.RFC3339)
for _, uri := range uris {
if _, err := db.Exec(
`INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo)
VALUES (?, ?, '!takedown', 1, ?, ?, ?)`,
src, uri, now, did, repo,
); err != nil {
return err
}
}
return nil
}
// NegateUserLabels creates negation labels for all active takedown labels on a DID (user-level).
func NegateUserLabels(db *sql.DB, src, did string) error {
rows, err := db.Query(
`SELECT uri, subject_repo FROM labels
WHERE subject_did = ? AND val = '!takedown' AND neg = 0`,
did,
)
if err != nil {
return err
}
type uriRepo struct {
uri string
repo string
}
var entries []uriRepo
for rows.Next() {
var e uriRepo
if err := rows.Scan(&e.uri, &e.repo); err != nil {
rows.Close()
return err
}
entries = append(entries, e)
}
rows.Close()
if err := rows.Err(); err != nil {
return err
}
now := time.Now().UTC().Format(time.RFC3339)
for _, e := range entries {
if _, err := db.Exec(
`INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo)
VALUES (?, ?, '!takedown', 1, ?, ?, ?)`,
src, e.uri, now, did, e.repo,
); err != nil {
return err
}
}
return nil
}
func scanLabels(rows *sql.Rows) ([]Label, error) {
var labels []Label
for rows.Next() {
var l Label
var cts string
var exp *string
if err := rows.Scan(&l.ID, &l.Src, &l.URI, &l.CID, &l.Val, &l.Neg, &cts, &exp, &l.SubjectDID, &l.SubjectRepo); err != nil {
return nil, err
}
if t, err := time.Parse(time.RFC3339, cts); err == nil {
l.Cts = t
}
if exp != nil {
if t, err := time.Parse(time.RFC3339, *exp); err == nil {
l.Exp = &t
}
}
labels = append(labels, l)
}
return labels, rows.Err()
}

412
pkg/labeler/db_test.go Normal file
View File

@@ -0,0 +1,412 @@
package labeler
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestOpenDB(t *testing.T) {
dir := t.TempDir()
dbPath := filepath.Join(dir, "subdir", "test.db")
db, err := OpenDB(dbPath)
if err != nil {
t.Fatalf("OpenDB failed: %v", err)
}
defer db.Close()
// Verify directory was created
if _, err := os.Stat(filepath.Dir(dbPath)); os.IsNotExist(err) {
t.Error("expected directory to be created")
}
// Verify tables exist
var count int
err = db.QueryRow("SELECT COUNT(*) FROM labels").Scan(&count)
if err != nil {
t.Fatalf("failed to query labels table: %v", err)
}
if count != 0 {
t.Errorf("expected 0 labels, got %d", count)
}
}
func TestCreateLabel(t *testing.T) {
dir := t.TempDir()
db, err := OpenDB(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
now := time.Now().UTC().Truncate(time.Second)
label := &Label{
Src: "did:web:labeler.atcr.io",
URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
Val: "!takedown",
Cts: now,
SubjectDID: "did:plc:abc",
SubjectRepo: "myimage",
}
id, err := CreateLabel(db, label)
if err != nil {
t.Fatalf("CreateLabel failed: %v", err)
}
if id <= 0 {
t.Errorf("expected positive id, got %d", id)
}
// Verify it was stored
labels, err := GetLabelsSince(db, 0, 10)
if err != nil {
t.Fatal(err)
}
if len(labels) != 1 {
t.Fatalf("expected 1 label, got %d", len(labels))
}
if labels[0].Src != "did:web:labeler.atcr.io" {
t.Errorf("expected src did:web:labeler.atcr.io, got %s", labels[0].Src)
}
if labels[0].Val != "!takedown" {
t.Errorf("expected val !takedown, got %s", labels[0].Val)
}
if labels[0].SubjectDID != "did:plc:abc" {
t.Errorf("expected subject_did did:plc:abc, got %s", labels[0].SubjectDID)
}
if labels[0].SubjectRepo != "myimage" {
t.Errorf("expected subject_repo myimage, got %s", labels[0].SubjectRepo)
}
}
func TestCreateLabel_Upsert(t *testing.T) {
dir := t.TempDir()
db, err := OpenDB(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
now := time.Now().UTC()
label := &Label{
Src: "did:web:labeler.atcr.io",
URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
Val: "!takedown",
Cts: now,
SubjectDID: "did:plc:abc",
SubjectRepo: "myimage",
}
// First insert
_, err = CreateLabel(db, label)
if err != nil {
t.Fatal(err)
}
// Same (src, uri, val) - should upsert, not error
label.Cts = now.Add(time.Hour)
_, err = CreateLabel(db, label)
if err != nil {
t.Fatalf("upsert should not fail: %v", err)
}
// Should still be 1 label
labels, err := GetLabelsSince(db, 0, 10)
if err != nil {
t.Fatal(err)
}
if len(labels) != 1 {
t.Errorf("expected 1 label after upsert, got %d", len(labels))
}
}
func TestNegateLabel(t *testing.T) {
dir := t.TempDir()
db, err := OpenDB(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
src := "did:web:labeler.atcr.io"
now := time.Now().UTC()
// Create a label
_, err = CreateLabel(db, &Label{
Src: src, URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
Val: "!takedown", Cts: now,
SubjectDID: "did:plc:abc", SubjectRepo: "myimage",
})
if err != nil {
t.Fatal(err)
}
// Negate it
err = NegateLabel(db, src, "at://did:plc:abc/io.atcr.manifest/sha256-123", "!takedown", "did:plc:abc", "myimage")
if err != nil {
t.Fatalf("NegateLabel failed: %v", err)
}
// Should have 2 labels now (original + negation)
labels, err := GetLabelsSince(db, 0, 10)
if err != nil {
t.Fatal(err)
}
if len(labels) != 2 {
t.Fatalf("expected 2 labels, got %d", len(labels))
}
// The negation label should have neg=true
negLabel := labels[1]
if !negLabel.Neg {
t.Error("expected negation label to have neg=true")
}
}
func TestListActiveTakedowns(t *testing.T) {
dir := t.TempDir()
db, err := OpenDB(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
src := "did:web:labeler.atcr.io"
now := time.Now().UTC()
// Create 3 labels
for i, repo := range []string{"repo1", "repo2", "repo3"} {
_, err = CreateLabel(db, &Label{
Src: src, URI: "at://did:plc:abc/io.atcr.repo/" + repo,
Val: "!takedown", Cts: now.Add(time.Duration(i) * time.Minute),
SubjectDID: "did:plc:abc", SubjectRepo: repo,
})
if err != nil {
t.Fatal(err)
}
}
// All 3 should be active
labels, total, err := ListActiveTakedowns(db, 10, 0)
if err != nil {
t.Fatal(err)
}
if total != 3 {
t.Errorf("expected 3 active takedowns, got %d", total)
}
if len(labels) != 3 {
t.Errorf("expected 3 labels returned, got %d", len(labels))
}
// Negate one
err = NegateLabel(db, src, "at://did:plc:abc/io.atcr.repo/repo2", "!takedown", "did:plc:abc", "repo2")
if err != nil {
t.Fatal(err)
}
// Should be 2 active
_, total, err = ListActiveTakedowns(db, 10, 0)
if err != nil {
t.Fatal(err)
}
if total != 2 {
t.Errorf("expected 2 active takedowns after negation, got %d", total)
}
}
func TestNegateRepoLabels(t *testing.T) {
dir := t.TempDir()
db, err := OpenDB(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
src := "did:web:labeler.atcr.io"
now := time.Now().UTC()
did := "did:plc:abc"
// Create multiple labels for same repo
uris := []string{
"at://did:plc:abc/io.atcr.manifest/sha256-111",
"at://did:plc:abc/io.atcr.manifest/sha256-222",
"at://did:plc:abc/io.atcr.tag/myimage-latest",
}
for _, uri := range uris {
_, err = CreateLabel(db, &Label{
Src: src, URI: uri, Val: "!takedown", Cts: now,
SubjectDID: did, SubjectRepo: "myimage",
})
if err != nil {
t.Fatal(err)
}
}
// Negate all labels for the repo
err = NegateRepoLabels(db, src, did, "myimage")
if err != nil {
t.Fatal(err)
}
// Should have 0 active takedowns
_, total, err := ListActiveTakedowns(db, 10, 0)
if err != nil {
t.Fatal(err)
}
if total != 0 {
t.Errorf("expected 0 active takedowns after repo negation, got %d", total)
}
}
func TestNegateUserLabels(t *testing.T) {
dir := t.TempDir()
db, err := OpenDB(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
src := "did:web:labeler.atcr.io"
now := time.Now().UTC()
did := "did:plc:abc"
// Create labels for different repos + a user-level label
_, err = CreateLabel(db, &Label{
Src: src, URI: "at://did:plc:abc", Val: "!takedown", Cts: now,
SubjectDID: did, SubjectRepo: "",
})
if err != nil {
t.Fatal(err)
}
_, err = CreateLabel(db, &Label{
Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1", Val: "!takedown", Cts: now,
SubjectDID: did, SubjectRepo: "repo1",
})
if err != nil {
t.Fatal(err)
}
// Negate all labels for the user
err = NegateUserLabels(db, src, did)
if err != nil {
t.Fatal(err)
}
// Should have 0 active
_, total, err := ListActiveTakedowns(db, 10, 0)
if err != nil {
t.Fatal(err)
}
if total != 0 {
t.Errorf("expected 0 active takedowns after user negation, got %d", total)
}
}
func TestGetLabelsSince(t *testing.T) {
dir := t.TempDir()
db, err := OpenDB(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
src := "did:web:labeler.atcr.io"
now := time.Now().UTC()
// Create 5 labels
for i := 0; i < 5; i++ {
_, err = CreateLabel(db, &Label{
Src: src, URI: "at://did:plc:abc/io.atcr.manifest/" + string(rune('a'+i)),
Val: "!takedown", Cts: now.Add(time.Duration(i) * time.Minute),
SubjectDID: "did:plc:abc", SubjectRepo: "repo",
})
if err != nil {
t.Fatal(err)
}
}
// Get all since 0
labels, err := GetLabelsSince(db, 0, 10)
if err != nil {
t.Fatal(err)
}
if len(labels) != 5 {
t.Errorf("expected 5 labels, got %d", len(labels))
}
// Get since cursor (skip first 3)
if len(labels) >= 3 {
cursor := labels[2].ID
after, err := GetLabelsSince(db, cursor, 10)
if err != nil {
t.Fatal(err)
}
if len(after) != 2 {
t.Errorf("expected 2 labels after cursor %d, got %d", cursor, len(after))
}
}
// Get with limit
limited, err := GetLabelsSince(db, 0, 2)
if err != nil {
t.Fatal(err)
}
if len(limited) != 2 {
t.Errorf("expected 2 labels with limit, got %d", len(limited))
}
}
func TestGetLabelsForRepo(t *testing.T) {
dir := t.TempDir()
db, err := OpenDB(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatal(err)
}
defer db.Close()
src := "did:web:labeler.atcr.io"
now := time.Now().UTC()
// Labels for different repos
_, _ = CreateLabel(db, &Label{
Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1",
Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo1",
})
_, _ = CreateLabel(db, &Label{
Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo2",
Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo2",
})
_, _ = CreateLabel(db, &Label{
Src: src, URI: "at://did:plc:def/io.atcr.repo/repo1",
Val: "!takedown", Cts: now, SubjectDID: "did:plc:def", SubjectRepo: "repo1",
})
// Get labels for specific did+repo
labels, err := GetLabelsForRepo(db, "did:plc:abc", "repo1")
if err != nil {
t.Fatal(err)
}
if len(labels) != 1 {
t.Errorf("expected 1 label for did:plc:abc/repo1, got %d", len(labels))
}
// Different user same repo
labels, err = GetLabelsForRepo(db, "did:plc:def", "repo1")
if err != nil {
t.Fatal(err)
}
if len(labels) != 1 {
t.Errorf("expected 1 label for did:plc:def/repo1, got %d", len(labels))
}
// No labels
labels, err = GetLabelsForRepo(db, "did:plc:xyz", "repo1")
if err != nil {
t.Fatal(err)
}
if len(labels) != 0 {
t.Errorf("expected 0 labels for unknown did, got %d", len(labels))
}
}

118
pkg/labeler/handlers.go Normal file
View File

@@ -0,0 +1,118 @@
package labeler
import (
"fmt"
"html/template"
"log/slog"
"net/http"
"strings"
"atcr.io/pkg/atproto"
)
// Auth handlers
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
// If already logged in, redirect to dashboard
if token, ok := getSessionCookie(r); ok {
if session := s.auth.getSession(token); session != nil && session.DID == s.config.Labeler.OwnerDID {
http.Redirect(w, r, "/", http.StatusFound)
return
}
}
errorMsg := r.URL.Query().Get("error")
w.Header().Set("Content-Type", "text/html")
fmt.Fprintf(w, `<!DOCTYPE html>
<html>
<head><title>%s Labeler - Login</title>
<style>body{font-family:system-ui;max-width:400px;margin:100px auto;padding:0 20px}
.error{color:red;margin-bottom:1em}
input{width:100%%;padding:8px;margin:8px 0;box-sizing:border-box}
button{padding:10px 20px;cursor:pointer}</style>
</head>
<body>
<h1>%s Labeler</h1>
<p>Sign in with your AT Protocol identity.</p>
%s
<form action="/auth/oauth/authorize" method="GET">
<input name="handle" placeholder="your.handle.com" required>
<button type="submit">Sign In</button>
</form>
</body></html>`,
s.config.Server.ClientShortName,
s.config.Server.ClientShortName,
func() string {
if errorMsg != "" {
return fmt.Sprintf(`<div class="error">%s</div>`, template.HTMLEscapeString(errorMsg))
}
return ""
}(),
)
}
func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) {
handle := strings.TrimSpace(r.URL.Query().Get("handle"))
if handle == "" {
http.Redirect(w, r, "/auth/login?error=Handle+is+required", http.StatusFound)
return
}
handle = strings.TrimPrefix(handle, "@")
did, _, _, err := atproto.ResolveIdentity(r.Context(), handle)
if err != nil {
slog.Warn("Failed to resolve handle for labeler login", "handle", handle, "error", err)
http.Redirect(w, r, "/auth/login?error=Could+not+resolve+handle", http.StatusFound)
return
}
authURL, err := s.clientApp.StartAuthFlow(r.Context(), did)
if err != nil {
slog.Error("Failed to start OAuth flow", "error", err)
http.Redirect(w, r, "/auth/login?error=OAuth+initialization+failed", http.StatusFound)
return
}
http.Redirect(w, r, authURL, http.StatusFound)
}
func (s *Server) handleCallback(w http.ResponseWriter, r *http.Request) {
sessionData, err := s.clientApp.ProcessCallback(r.Context(), r.URL.Query())
if err != nil {
slog.Error("OAuth callback failed", "error", err)
http.Redirect(w, r, "/auth/login?error=OAuth+authentication+failed", http.StatusFound)
return
}
did := sessionData.AccountDID.String()
_, handle, _, err := atproto.ResolveIdentity(r.Context(), did)
if err != nil {
handle = did
}
// Only allow the owner
if did != s.config.Labeler.OwnerDID {
slog.Warn("Non-owner attempted labeler access", "did", did, "handle", handle, "owner", s.config.Labeler.OwnerDID)
http.Redirect(w, r, "/auth/login?error=Access+denied:+Only+the+labeler+owner+can+access+the+admin+panel", http.StatusFound)
return
}
token, err := s.auth.createSession(did, handle)
if err != nil {
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
setSessionCookie(w, r, token)
http.Redirect(w, r, "/", http.StatusFound)
}
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
if token, ok := getSessionCookie(r); ok {
s.auth.deleteSession(token)
}
clearSessionCookie(w)
http.Redirect(w, r, "/auth/login", http.StatusFound)
}

57
pkg/labeler/identity.go Normal file
View File

@@ -0,0 +1,57 @@
package labeler
import (
"encoding/json"
"fmt"
"net/http"
)
// DIDDocument represents a did:web DID document.
type DIDDocument struct {
Context []string `json:"@context"`
ID string `json:"id"`
Service []DIDService `json:"service,omitempty"`
}
// DIDService represents a service entry in a DID document.
type DIDService struct {
ID string `json:"id"`
Type string `json:"type"`
ServiceEndpoint string `json:"serviceEndpoint"`
}
func (s *Server) handleDIDDocument(w http.ResponseWriter, r *http.Request) {
doc := DIDDocument{
Context: []string{"https://www.w3.org/ns/did/v1"},
ID: s.config.DID(),
Service: []DIDService{
{
ID: "#atproto_labeler",
Type: "AtprotoLabeler",
ServiceEndpoint: s.config.PublicURL(),
},
},
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(doc)
}
func (s *Server) handleClientMetadata(w http.ResponseWriter, r *http.Request) {
publicURL := s.config.PublicURL()
metadata := map[string]any{
"client_id": publicURL + "/oauth-client-metadata.json",
"client_name": fmt.Sprintf("%s Labeler", s.config.Server.ClientShortName),
"client_uri": publicURL,
"redirect_uris": []string{publicURL + "/auth/oauth/callback"},
"scope": "atproto",
"grant_types": []string{"authorization_code"},
"response_types": []string{"code"},
"token_endpoint_auth_method": "none",
"application_type": "web",
"dpop_bound_access_tokens": true,
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(metadata)
}

156
pkg/labeler/server.go Normal file
View File

@@ -0,0 +1,156 @@
package labeler
import (
"context"
"database/sql"
"fmt"
"log/slog"
"net/http"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
"atcr.io/pkg/atproto"
indigooauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
"github.com/go-chi/chi/v5"
)
// Server is the labeler HTTP server.
type Server struct {
config *Config
db *sql.DB
router chi.Router
clientApp *indigooauth.ClientApp
auth *Auth
}
// NewServer creates a new labeler server.
func NewServer(cfg *Config) (*Server, error) {
db, err := OpenDB(cfg.Labeler.DBPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
publicURL := cfg.PublicURL()
// Set up OAuth client for admin login
oauthStore := indigooauth.NewMemStore()
scopes := []string{"atproto"}
var oauthConfig indigooauth.ClientConfig
var redirectURI string
u, err := url.Parse(publicURL)
if err != nil {
return nil, fmt.Errorf("invalid public URL: %w", err)
}
host := u.Hostname()
if isLocalhost(host) {
port := u.Port()
if port == "" {
port = "5002"
}
oauthBaseURL := "http://127.0.0.1:" + port
redirectURI = oauthBaseURL + "/auth/oauth/callback"
oauthConfig = indigooauth.NewLocalhostConfig(redirectURI, scopes)
} else {
clientID := publicURL + "/oauth-client-metadata.json"
redirectURI = publicURL + "/auth/oauth/callback"
oauthConfig = indigooauth.NewPublicConfig(clientID, redirectURI, scopes)
}
clientApp := indigooauth.NewClientApp(&oauthConfig, oauthStore)
clientApp.Dir = atproto.GetDirectory()
auth := NewAuth(cfg.Labeler.OwnerDID)
s := &Server{
config: cfg,
db: db,
clientApp: clientApp,
auth: auth,
}
s.setupRoutes()
return s, nil
}
func (s *Server) setupRoutes() {
r := chi.NewRouter()
// DID document
r.Get("/.well-known/did.json", s.handleDIDDocument)
// OAuth client metadata
r.Get("/oauth-client-metadata.json", s.handleClientMetadata)
// Auth routes (public)
r.Get("/auth/login", s.handleLogin)
r.Get("/auth/oauth/authorize", s.handleAuthorize)
r.Get("/auth/oauth/callback", s.handleCallback)
r.Get("/auth/logout", s.handleLogout)
// XRPC endpoints (public)
r.Get("/xrpc/com.atproto.label.subscribeLabels", s.handleSubscribeLabels)
r.Get("/xrpc/com.atproto.label.queryLabels", s.handleQueryLabels)
// Protected routes (require owner)
r.Group(func(r chi.Router) {
r.Use(s.auth.RequireOwner)
r.Get("/", s.handleDashboard)
r.Get("/takedown", s.handleTakedownForm)
r.Post("/takedown", s.handleTakedownSubmit)
r.Post("/reverse", s.handleReverse)
})
s.router = r
}
// Serve starts the HTTP server with graceful shutdown.
func (s *Server) Serve() error {
slog.Info("Starting labeler service",
"addr", s.config.Labeler.Addr,
"public_url", s.config.PublicURL(),
"did", s.config.DID(),
"owner", s.config.Labeler.OwnerDID,
)
srv := &http.Server{
Addr: s.config.Labeler.Addr,
Handler: s.router,
}
// Graceful shutdown
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
errCh := make(chan error, 1)
go func() {
errCh <- srv.ListenAndServe()
}()
select {
case err := <-errCh:
if err != http.ErrServerClosed {
return err
}
case <-ctx.Done():
slog.Info("Shutting down labeler service")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5000000000) // 5s
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("shutdown error: %w", err)
}
}
s.db.Close()
return nil
}
func isLocalhost(host string) bool {
return host == "localhost" || host == "127.0.0.1" || strings.HasPrefix(host, "192.168.")
}

187
pkg/labeler/subscribe.go Normal file
View File

@@ -0,0 +1,187 @@
package labeler
import (
"encoding/json"
"log/slog"
"net/http"
"strconv"
"time"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
// LabelsMessage is the ATProto subscribeLabels wire format.
type LabelsMessage struct {
Seq int64 `json:"seq"`
Labels []LabelOutput `json:"labels"`
}
// LabelOutput is the ATProto label format for subscribeLabels/queryLabels output.
type LabelOutput struct {
Src string `json:"src"`
URI string `json:"uri"`
CID string `json:"cid,omitempty"`
Val string `json:"val"`
Neg bool `json:"neg"`
Cts string `json:"cts"`
Exp string `json:"exp,omitempty"`
}
func labelToOutput(l Label) LabelOutput {
out := LabelOutput{
Src: l.Src,
URI: l.URI,
CID: l.CID,
Val: l.Val,
Neg: l.Neg,
Cts: l.Cts.UTC().Format(time.RFC3339),
}
if l.Exp != nil {
out.Exp = l.Exp.UTC().Format(time.RFC3339)
}
return out
}
// handleSubscribeLabels implements com.atproto.label.subscribeLabels (WebSocket).
func (s *Server) handleSubscribeLabels(w http.ResponseWriter, r *http.Request) {
cursorStr := r.URL.Query().Get("cursor")
var cursor int64
if cursorStr != "" {
var err error
cursor, err = strconv.ParseInt(cursorStr, 10, 64)
if err != nil {
http.Error(w, "invalid cursor", http.StatusBadRequest)
return
}
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("WebSocket upgrade failed", "error", err)
return
}
defer conn.Close()
slog.Info("subscribeLabels client connected", "cursor", cursor)
// Send historical labels since cursor
labels, err := GetLabelsSince(s.db, cursor, 1000)
if err != nil {
slog.Error("Failed to get labels", "error", err)
return
}
for _, l := range labels {
msg := LabelsMessage{
Seq: l.ID,
Labels: []LabelOutput{labelToOutput(l)},
}
if err := conn.WriteJSON(msg); err != nil {
return
}
cursor = l.ID
}
// Poll for new labels
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
// Read pump (detect client disconnect)
done := make(chan struct{})
go func() {
defer close(done)
for {
if _, _, err := conn.ReadMessage(); err != nil {
return
}
}
}()
for {
select {
case <-done:
return
case <-ticker.C:
labels, err := GetLabelsSince(s.db, cursor, 100)
if err != nil {
slog.Error("Failed to poll labels", "error", err)
continue
}
for _, l := range labels {
msg := LabelsMessage{
Seq: l.ID,
Labels: []LabelOutput{labelToOutput(l)},
}
if err := conn.WriteJSON(msg); err != nil {
return
}
cursor = l.ID
}
}
}
}
// handleQueryLabels implements com.atproto.label.queryLabels (HTTP GET).
func (s *Server) handleQueryLabels(w http.ResponseWriter, r *http.Request) {
uriPatterns := r.URL.Query()["uriPatterns"]
cursorStr := r.URL.Query().Get("cursor")
limitStr := r.URL.Query().Get("limit")
var cursor int64
if cursorStr != "" {
cursor, _ = strconv.ParseInt(cursorStr, 10, 64)
}
limit := 50
if limitStr != "" {
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 250 {
limit = l
}
}
labels, err := GetLabelsSince(s.db, cursor, limit)
if err != nil {
http.Error(w, "failed to query labels", http.StatusInternalServerError)
return
}
// Filter by URI patterns if provided
var filtered []LabelOutput
for _, l := range labels {
if len(uriPatterns) == 0 || matchesAnyPattern(l.URI, uriPatterns) {
filtered = append(filtered, labelToOutput(l))
}
}
var nextCursor string
if len(labels) > 0 {
nextCursor = strconv.FormatInt(labels[len(labels)-1].ID, 10)
}
resp := struct {
Cursor string `json:"cursor,omitempty"`
Labels []LabelOutput `json:"labels"`
}{
Cursor: nextCursor,
Labels: filtered,
}
if resp.Labels == nil {
resp.Labels = []LabelOutput{}
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}
func matchesAnyPattern(uri string, patterns []string) bool {
for _, p := range patterns {
// Simple prefix matching (ATProto spec allows glob-like patterns)
if p == uri || (len(p) > 0 && p[len(p)-1] == '*' && len(uri) >= len(p)-1 && uri[:len(p)-1] == p[:len(p)-1]) {
return true
}
}
return false
}

View File

@@ -0,0 +1,86 @@
package labeler
import (
"testing"
"time"
)
func TestLabelToOutput(t *testing.T) {
now := time.Date(2026, 3, 22, 10, 0, 0, 0, time.UTC)
exp := time.Date(2026, 4, 22, 10, 0, 0, 0, time.UTC)
label := Label{
ID: 1,
Src: "did:web:labeler.atcr.io",
URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
CID: "bafyabc",
Val: "!takedown",
Neg: false,
Cts: now,
Exp: &exp,
SubjectDID: "did:plc:abc",
SubjectRepo: "myimage",
}
out := labelToOutput(label)
if out.Src != "did:web:labeler.atcr.io" {
t.Errorf("Src = %q, want did:web:labeler.atcr.io", out.Src)
}
if out.URI != "at://did:plc:abc/io.atcr.manifest/sha256-123" {
t.Errorf("URI = %q", out.URI)
}
if out.CID != "bafyabc" {
t.Errorf("CID = %q, want bafyabc", out.CID)
}
if out.Val != "!takedown" {
t.Errorf("Val = %q", out.Val)
}
if out.Neg {
t.Error("expected Neg=false")
}
if out.Cts != "2026-03-22T10:00:00Z" {
t.Errorf("Cts = %q", out.Cts)
}
if out.Exp != "2026-04-22T10:00:00Z" {
t.Errorf("Exp = %q", out.Exp)
}
}
func TestLabelToOutput_NoExpiration(t *testing.T) {
label := Label{
Src: "did:web:labeler.atcr.io",
URI: "at://did:plc:abc",
Val: "!takedown",
Cts: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
}
out := labelToOutput(label)
if out.Exp != "" {
t.Errorf("expected empty Exp, got %q", out.Exp)
}
}
func TestMatchesAnyPattern(t *testing.T) {
tests := []struct {
name string
uri string
patterns []string
want bool
}{
{"exact match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:abc/io.atcr.manifest/sha256-123"}, true},
{"no match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:def/io.atcr.manifest/sha256-123"}, false},
{"wildcard match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:abc/*"}, true},
{"wildcard no match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:def/*"}, false},
{"empty patterns", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{}, false},
{"multiple patterns", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:def/*", "at://did:plc:abc/*"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := matchesAnyPattern(tt.uri, tt.patterns)
if got != tt.want {
t.Errorf("matchesAnyPattern(%q, %v) = %v, want %v", tt.uri, tt.patterns, got, tt.want)
}
})
}
}

418
pkg/labeler/takedown.go Normal file
View File

@@ -0,0 +1,418 @@
package labeler
import (
"context"
"encoding/json"
"fmt"
"html/template"
"log/slog"
"net/http"
"strings"
"time"
"atcr.io/pkg/atproto"
)
// TakedownInput represents parsed takedown input.
type TakedownInput struct {
DID string
Handle string
Repository string // empty = user-level takedown
}
// ParseTakedownInput parses various input formats into a TakedownInput.
// Supported formats:
// - atcr.io/r/handle/repo
// - handle/repo
// - at://did:plc:xyz/io.atcr.repo.page/repo
// - at://did:plc:xyz (user-level)
// - handle (user-level)
// - did:plc:xyz (user-level)
func ParseTakedownInput(ctx context.Context, input string) (*TakedownInput, error) {
input = strings.TrimSpace(input)
// AT URI format
if strings.HasPrefix(input, "at://") {
return parseATURI(ctx, input)
}
// Strip URL prefix if present
input = strings.TrimPrefix(input, "https://")
input = strings.TrimPrefix(input, "http://")
// Remove atcr.io/r/ or similar prefix
for _, prefix := range []string{"atcr.io/r/", "localhost/r/"} {
if strings.HasPrefix(input, prefix) {
input = strings.TrimPrefix(input, prefix)
break
}
}
// Also handle custom domains: anything ending in /r/
if idx := strings.Index(input, "/r/"); idx >= 0 {
input = input[idx+3:]
}
// Now input should be "handle/repo" or "handle" or "did:xxx"
parts := strings.SplitN(input, "/", 2)
identifier := parts[0]
var repo string
if len(parts) > 1 {
repo = parts[1]
repo = strings.TrimSuffix(repo, "/")
}
did, handle, err := resolveIdentifier(ctx, identifier)
if err != nil {
return nil, err
}
return &TakedownInput{
DID: did,
Handle: handle,
Repository: repo,
}, nil
}
func parseATURI(ctx context.Context, uri string) (*TakedownInput, error) {
// at://did:plc:xyz/collection/rkey
trimmed := strings.TrimPrefix(uri, "at://")
parts := strings.SplitN(trimmed, "/", 3)
did := parts[0]
if !strings.HasPrefix(did, "did:") {
// It's a handle
resolvedDID, handle, err := resolveIdentifier(ctx, did)
if err != nil {
return nil, err
}
did = resolvedDID
if len(parts) >= 3 {
return &TakedownInput{DID: did, Handle: handle, Repository: parts[2]}, nil
}
return &TakedownInput{DID: did, Handle: handle}, nil
}
// Resolve handle from DID
_, handle, _, _ := atproto.ResolveIdentity(ctx, did)
if len(parts) < 3 {
// User-level takedown
return &TakedownInput{DID: did, Handle: handle}, nil
}
// Extract repository from rkey (third part)
repo := parts[2]
return &TakedownInput{DID: did, Handle: handle, Repository: repo}, nil
}
func resolveIdentifier(ctx context.Context, identifier string) (did, handle string, err error) {
did, handle, _, err = atproto.ResolveIdentity(ctx, identifier)
if err != nil {
return "", "", fmt.Errorf("failed to resolve %q: %w", identifier, err)
}
return did, handle, nil
}
// TakedownResult contains the results of a takedown operation.
type TakedownResult struct {
DID string
Handle string
Repository string
Labels []Label
UserLevel bool
}
// ExecuteTakedown creates takedown labels for a repo or user.
func (s *Server) ExecuteTakedown(ctx context.Context, input *TakedownInput) (*TakedownResult, error) {
src := s.config.DID()
now := time.Now().UTC()
result := &TakedownResult{
DID: input.DID,
Handle: input.Handle,
Repository: input.Repository,
UserLevel: input.Repository == "",
}
if input.Repository == "" {
// User-level takedown
label := &Label{
Src: src,
URI: "at://" + input.DID,
Val: "!takedown",
Cts: now,
SubjectDID: input.DID,
SubjectRepo: "",
}
if _, err := CreateLabel(s.db, label); err != nil {
return nil, fmt.Errorf("failed to create user-level label: %w", err)
}
result.Labels = append(result.Labels, *label)
slog.Info("Created user-level takedown", "did", input.DID, "handle", input.Handle)
return result, nil
}
// Repo-level takedown: discover all records from PDS
labels, err := s.discoverAndLabelRecords(ctx, input.DID, input.Repository, src, now)
if err != nil {
// Even if PDS discovery fails, create a repo-level summary label
slog.Warn("PDS discovery failed, creating summary label only", "error", err)
}
result.Labels = append(result.Labels, labels...)
// Always create a repo-level summary label for efficient filtering
summaryLabel := &Label{
Src: src,
URI: fmt.Sprintf("at://%s/io.atcr.repo/%s", input.DID, input.Repository),
Val: "!takedown",
Cts: now,
SubjectDID: input.DID,
SubjectRepo: input.Repository,
}
if _, err := CreateLabel(s.db, summaryLabel); err != nil {
return nil, fmt.Errorf("failed to create summary label: %w", err)
}
result.Labels = append(result.Labels, *summaryLabel)
slog.Info("Created repo-level takedown",
"did", input.DID,
"handle", input.Handle,
"repository", input.Repository,
"label_count", len(result.Labels),
)
return result, nil
}
// discoverAndLabelRecords queries the user's PDS for all records in the given repo
// and creates takedown labels for each.
func (s *Server) discoverAndLabelRecords(ctx context.Context, did, repo, src string, now time.Time) ([]Label, error) {
_, _, pdsEndpoint, err := atproto.ResolveIdentity(ctx, did)
if err != nil {
return nil, fmt.Errorf("failed to resolve DID: %w", err)
}
client := atproto.NewClient(pdsEndpoint, did, "")
var labels []Label
// Collections to search
collections := []string{
atproto.ManifestCollection,
atproto.TagCollection,
atproto.RepoPageCollection,
}
for _, collection := range collections {
records, _, err := client.ListRecordsForRepo(ctx, did, collection, 100, "")
if err != nil {
slog.Warn("Failed to list records", "collection", collection, "error", err)
continue
}
for _, rec := range records {
// Filter by repository field
recRepo := extractRepoField(rec.Value, collection)
if recRepo != repo {
continue
}
// Use the full AT URI from the record (at://did/collection/rkey)
uri := rec.URI
label := &Label{
Src: src,
URI: uri,
Val: "!takedown",
Cts: now,
SubjectDID: did,
SubjectRepo: repo,
}
if _, err := CreateLabel(s.db, label); err != nil {
slog.Warn("Failed to create label", "uri", uri, "error", err)
continue
}
labels = append(labels, *label)
}
}
return labels, nil
}
// extractRepoField extracts the repository name from a record's JSON value.
func extractRepoField(value json.RawMessage, collection string) string {
// For repo pages, the rkey IS the repository name, but we also check the value
var rec struct {
Repository string `json:"repository"`
}
if err := json.Unmarshal(value, &rec); err == nil && rec.Repository != "" {
return rec.Repository
}
return ""
}
// Handlers
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
labels, total, err := ListActiveTakedowns(s.db, 50, 0)
if err != nil {
http.Error(w, "Failed to list takedowns", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html")
fmt.Fprintf(w, `<!DOCTYPE html>
<html>
<head><title>%s Labeler</title>
<style>
body{font-family:system-ui;max-width:900px;margin:40px auto;padding:0 20px}
table{width:100%%;border-collapse:collapse;margin:20px 0}
th,td{text-align:left;padding:8px;border-bottom:1px solid #ddd}
th{background:#f5f5f5}
.badge{background:#dc2626;color:white;padding:2px 8px;border-radius:4px;font-size:0.85em}
a{color:#2563eb}
nav{display:flex;gap:16px;margin-bottom:24px}
.btn{padding:8px 16px;background:#2563eb;color:white;text-decoration:none;border-radius:4px;border:none;cursor:pointer}
.btn-danger{background:#dc2626}
form{display:inline}
</style>
</head>
<body>
<h1>%s Labeler</h1>
<nav>
<a href="/" class="btn">Dashboard</a>
<a href="/takedown" class="btn">New Takedown</a>
<a href="/auth/logout">Logout</a>
</nav>
<h2>Active Takedowns (%d)</h2>`,
s.config.Server.ClientShortName,
s.config.Server.ClientShortName,
total,
)
if len(labels) == 0 {
fmt.Fprint(w, `<p>No active takedowns.</p>`)
} else {
fmt.Fprint(w, `<table><tr><th>Subject</th><th>Repository</th><th>URI</th><th>Created</th><th>Action</th></tr>`)
for _, l := range labels {
repoDisplay := l.SubjectRepo
if repoDisplay == "" {
repoDisplay = "<em>all repos (user-level)</em>"
}
fmt.Fprintf(w, `<tr>
<td>%s</td>
<td>%s</td>
<td><code>%s</code></td>
<td>%s</td>
<td><form method="POST" action="/reverse"><input type="hidden" name="did" value="%s"><input type="hidden" name="repo" value="%s"><button type="submit" class="btn btn-danger" onclick="return confirm('Reverse this takedown?')">Reverse</button></form></td>
</tr>`,
template.HTMLEscapeString(l.SubjectDID),
repoDisplay,
template.HTMLEscapeString(l.URI),
l.Cts.Format("2006-01-02 15:04"),
template.HTMLEscapeString(l.SubjectDID),
template.HTMLEscapeString(l.SubjectRepo),
)
}
fmt.Fprint(w, `</table>`)
}
fmt.Fprint(w, `</body></html>`)
}
func (s *Server) handleTakedownForm(w http.ResponseWriter, r *http.Request) {
msg := r.URL.Query().Get("msg")
errorMsg := r.URL.Query().Get("error")
w.Header().Set("Content-Type", "text/html")
fmt.Fprintf(w, `<!DOCTYPE html>
<html>
<head><title>%s Labeler - New Takedown</title>
<style>
body{font-family:system-ui;max-width:600px;margin:40px auto;padding:0 20px}
input[type=text]{width:100%%;padding:8px;margin:8px 0;box-sizing:border-box}
.btn{padding:10px 20px;background:#dc2626;color:white;border:none;border-radius:4px;cursor:pointer;font-size:1em}
.success{color:green;margin-bottom:1em}
.error{color:red;margin-bottom:1em}
a{color:#2563eb}
nav{display:flex;gap:16px;margin-bottom:24px}
.nav-btn{padding:8px 16px;background:#2563eb;color:white;text-decoration:none;border-radius:4px}
.help{color:#666;font-size:0.9em;margin-top:4px}
</style>
</head>
<body>
<h1>New Takedown</h1>
<nav>
<a href="/" class="nav-btn">Dashboard</a>
<a href="/takedown" class="nav-btn">New Takedown</a>
</nav>`,
s.config.Server.ClientShortName,
)
if msg != "" {
fmt.Fprintf(w, `<div class="success">%s</div>`, template.HTMLEscapeString(msg))
}
if errorMsg != "" {
fmt.Fprintf(w, `<div class="error">%s</div>`, template.HTMLEscapeString(errorMsg))
}
fmt.Fprint(w, `
<form method="POST" action="/takedown">
<label for="target"><strong>Target</strong></label>
<input type="text" id="target" name="target" placeholder="atcr.io/r/handle/repo, at://did/collection/rkey, or handle" required>
<p class="help">Accepts repo URLs, AT URIs, handles, or DIDs. Omit the repo for a user-level takedown.</p>
<br>
<button type="submit" class="btn" onclick="return confirm('Issue takedown? This will suppress the content immediately.')">Issue Takedown</button>
</form>
</body></html>`)
}
func (s *Server) handleTakedownSubmit(w http.ResponseWriter, r *http.Request) {
target := strings.TrimSpace(r.FormValue("target"))
if target == "" {
http.Redirect(w, r, "/takedown?error=Target+is+required", http.StatusFound)
return
}
input, err := ParseTakedownInput(r.Context(), target)
if err != nil {
http.Redirect(w, r, "/takedown?error="+strings.ReplaceAll(err.Error(), " ", "+"), http.StatusFound)
return
}
result, err := s.ExecuteTakedown(r.Context(), input)
if err != nil {
http.Redirect(w, r, "/takedown?error="+strings.ReplaceAll(err.Error(), " ", "+"), http.StatusFound)
return
}
msg := fmt.Sprintf("Takedown issued: %d labels created for %s", len(result.Labels), result.DID)
if result.Repository != "" {
msg += "/" + result.Repository
}
http.Redirect(w, r, "/takedown?msg="+strings.ReplaceAll(msg, " ", "+"), http.StatusFound)
}
func (s *Server) handleReverse(w http.ResponseWriter, r *http.Request) {
did := strings.TrimSpace(r.FormValue("did"))
repo := strings.TrimSpace(r.FormValue("repo"))
if did == "" {
http.Redirect(w, r, "/?error=DID+is+required", http.StatusFound)
return
}
src := s.config.DID()
var err error
if repo == "" {
err = NegateUserLabels(s.db, src, did)
} else {
err = NegateRepoLabels(s.db, src, did, repo)
}
if err != nil {
slog.Error("Failed to reverse takedown", "did", did, "repo", repo, "error", err)
http.Redirect(w, r, "/?error=Failed+to+reverse+takedown", http.StatusFound)
return
}
slog.Info("Reversed takedown", "did", did, "repo", repo)
http.Redirect(w, r, "/", http.StatusFound)
}

View File

@@ -0,0 +1,139 @@
package labeler
import (
"context"
"testing"
)
func TestParseTakedownInput_RepoURL(t *testing.T) {
// These tests only exercise parsing logic, not PDS resolution.
// ResolveIdentity calls are tested with mock server below.
tests := []struct {
name string
input string
wantRepo string
}{
{"full URL", "https://atcr.io/r/handle/myimage", "myimage"},
{"no scheme", "atcr.io/r/handle/myimage", "myimage"},
{"handle/repo", "handle/myimage", "myimage"},
{"trailing slash", "atcr.io/r/handle/myimage/", "myimage"},
{"custom domain", "https://registry.example.com/r/handle/myimage", "myimage"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// These will fail on ResolveIdentity since there's no real PDS,
// but we can at least verify the parsing doesn't panic
_, err := ParseTakedownInput(context.Background(), tt.input)
if err == nil {
t.Skip("ResolveIdentity succeeded unexpectedly (network available)")
}
// The error should be from resolution, not parsing
if err != nil {
t.Logf("Expected resolution error: %v", err)
}
})
}
}
func TestParseTakedownInput_ATURI(t *testing.T) {
tests := []struct {
name string
input string
wantDID string
wantRepo string
}{
{
"full AT URI with collection",
"at://did:plc:abc123/io.atcr.repo.page/myimage",
"did:plc:abc123",
"myimage",
},
{
"DID only (user-level)",
"at://did:plc:abc123",
"did:plc:abc123",
"",
},
{
"DID with collection and rkey",
"at://did:plc:xyz/io.atcr.manifest/sha256-deadbeef",
"did:plc:xyz",
"sha256-deadbeef",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
input, err := ParseTakedownInput(context.Background(), tt.input)
if err != nil {
// Resolution may fail for handle-based AT URIs
t.Logf("Parse error (may be expected): %v", err)
return
}
if input.DID != tt.wantDID {
t.Errorf("DID = %q, want %q", input.DID, tt.wantDID)
}
if input.Repository != tt.wantRepo {
t.Errorf("Repository = %q, want %q", input.Repository, tt.wantRepo)
}
})
}
}
func TestParseTakedownInput_DID(t *testing.T) {
// Direct DID input (user-level takedown)
input, err := ParseTakedownInput(context.Background(), "at://did:plc:abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if input.DID != "did:plc:abc123" {
t.Errorf("DID = %q, want did:plc:abc123", input.DID)
}
if input.Repository != "" {
t.Errorf("Repository = %q, want empty (user-level)", input.Repository)
}
}
func TestExtractRepoField(t *testing.T) {
tests := []struct {
name string
value string
collection string
want string
}{
{
"manifest with repository",
`{"$type":"io.atcr.manifest","repository":"myimage","digest":"sha256:abc"}`,
"io.atcr.manifest",
"myimage",
},
{
"tag with repository",
`{"$type":"io.atcr.tag","repository":"myimage","tag":"latest"}`,
"io.atcr.tag",
"myimage",
},
{
"no repository field",
`{"$type":"io.atcr.manifest","digest":"sha256:abc"}`,
"io.atcr.manifest",
"",
},
{
"invalid JSON",
`{invalid}`,
"io.atcr.manifest",
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractRepoField([]byte(tt.value), tt.collection)
if got != tt.want {
t.Errorf("extractRepoField() = %q, want %q", got, tt.want)
}
})
}
}