Compare commits
1 Commits
main
...
label-serv
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8bf3e15ca2 |
82
cmd/labeler/main.go
Normal file
82
cmd/labeler/main.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"atcr.io/pkg/labeler"
|
||||
)
|
||||
|
||||
var configFile string
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "atcr-labeler",
|
||||
Short: "ATCR Labeler Service - ATProto content moderation",
|
||||
}
|
||||
|
||||
var serveCmd = &cobra.Command{
|
||||
Use: "serve",
|
||||
Short: "Start the labeler service",
|
||||
Long: `Start the ATCR labeler service with admin UI and subscribeLabels endpoint.
|
||||
|
||||
Configuration is loaded from the appview config YAML (labeler section).
|
||||
Use --config to specify the config file path.`,
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg, err := labeler.LoadConfig(configFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
server, err := labeler.NewServer(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize labeler: %w", err)
|
||||
}
|
||||
|
||||
return server.Serve()
|
||||
},
|
||||
}
|
||||
|
||||
var configCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Short: "Configuration management commands",
|
||||
}
|
||||
|
||||
var configInitCmd = &cobra.Command{
|
||||
Use: "init [path]",
|
||||
Short: "Generate an example configuration file",
|
||||
Long: `Generate an example YAML configuration file with all available options.`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
yamlBytes, err := labeler.ExampleYAML()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate example config: %w", err)
|
||||
}
|
||||
if len(args) == 1 {
|
||||
if err := os.WriteFile(args[0], yamlBytes, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write config file: %w", err)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Wrote example config to %s\n", args[0])
|
||||
return nil
|
||||
}
|
||||
fmt.Print(string(yamlBytes))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
serveCmd.Flags().StringVarP(&configFile, "config", "c", "", "path to YAML configuration file")
|
||||
|
||||
configCmd.AddCommand(configInitCmd)
|
||||
|
||||
rootCmd.AddCommand(serveCmd)
|
||||
rootCmd.AddCommand(configCmd)
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,12 @@ var holdConfigTmpl string
|
||||
//go:embed configs/scanner.yaml.tmpl
|
||||
var scannerConfigTmpl string
|
||||
|
||||
//go:embed systemd/labeler.service.tmpl
|
||||
var labelerServiceTmpl string
|
||||
|
||||
//go:embed configs/labeler.yaml.tmpl
|
||||
var labelerConfigTmpl string
|
||||
|
||||
//go:embed configs/cloudinit.sh.tmpl
|
||||
var cloudInitTmpl string
|
||||
|
||||
@@ -111,9 +117,33 @@ func renderScannerServiceUnit(p scannerServiceUnitParams) (string, error) {
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// labelerServiceUnitParams holds values for rendering the labeler systemd unit.
|
||||
type labelerServiceUnitParams struct {
|
||||
DisplayName string // e.g. "Seamark"
|
||||
User string // e.g. "seamark"
|
||||
BinaryPath string // e.g. "/opt/seamark/bin/seamark-labeler"
|
||||
ConfigPath string // e.g. "/etc/seamark/labeler.yaml"
|
||||
DataDir string // e.g. "/var/lib/seamark"
|
||||
ServiceName string // e.g. "seamark-labeler"
|
||||
AppviewServiceName string // e.g. "seamark-appview" (After= dependency)
|
||||
}
|
||||
|
||||
func renderLabelerServiceUnit(p labelerServiceUnitParams) (string, error) {
|
||||
t, err := template.New("labeler-service").Parse(labelerServiceTmpl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse labeler service template: %w", err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := t.Execute(&buf, p); err != nil {
|
||||
return "", fmt.Errorf("render labeler service template: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
// generateAppviewCloudInit generates the cloud-init user-data script for the appview server.
|
||||
// Sets up the OS, directories, config, and systemd unit. Binaries are deployed separately via SCP.
|
||||
func generateAppviewCloudInit(cfg *InfraConfig, vals *ConfigValues) (string, error) {
|
||||
// When withLabeler is true, a second phase is appended that creates labeler data
|
||||
// directories and installs a labeler systemd service. Binaries are deployed separately via SCP.
|
||||
func generateAppviewCloudInit(cfg *InfraConfig, vals *ConfigValues, withLabeler bool) (string, error) {
|
||||
naming := cfg.Naming()
|
||||
|
||||
configYAML, err := renderConfig(appviewConfigTmpl, vals)
|
||||
@@ -133,7 +163,7 @@ func generateAppviewCloudInit(cfg *InfraConfig, vals *ConfigValues) (string, err
|
||||
return "", fmt.Errorf("appview service unit: %w", err)
|
||||
}
|
||||
|
||||
return generateCloudInit(cloudInitParams{
|
||||
script, err := generateCloudInit(cloudInitParams{
|
||||
BinaryName: naming.Appview(),
|
||||
ServiceUnit: serviceUnit,
|
||||
ConfigYAML: configYAML,
|
||||
@@ -146,6 +176,69 @@ func generateAppviewCloudInit(cfg *InfraConfig, vals *ConfigValues) (string, err
|
||||
LogFile: naming.LogFile(),
|
||||
DisplayName: naming.DisplayName(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !withLabeler {
|
||||
return script, nil
|
||||
}
|
||||
|
||||
// Render labeler config YAML
|
||||
labelerConfigYAML, err := renderConfig(labelerConfigTmpl, vals)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("labeler config: %w", err)
|
||||
}
|
||||
|
||||
// Append labeler setup phase
|
||||
labelerUnit, err := renderLabelerServiceUnit(labelerServiceUnitParams{
|
||||
DisplayName: naming.DisplayName(),
|
||||
User: naming.SystemUser(),
|
||||
BinaryPath: naming.InstallDir() + "/bin/" + naming.Labeler(),
|
||||
ConfigPath: naming.LabelerConfigPath(),
|
||||
DataDir: naming.BasePath(),
|
||||
ServiceName: naming.Labeler(),
|
||||
AppviewServiceName: naming.Appview(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("labeler service unit: %w", err)
|
||||
}
|
||||
|
||||
// Escape single quotes for heredoc embedding
|
||||
labelerUnit = strings.ReplaceAll(labelerUnit, "'", "'\\''")
|
||||
labelerConfigYAML = strings.ReplaceAll(labelerConfigYAML, "'", "'\\''")
|
||||
|
||||
labelerPhase := fmt.Sprintf(`
|
||||
# === Labeler Setup ===
|
||||
|
||||
# Labeler data dirs
|
||||
mkdir -p %s
|
||||
chown -R %s:%s %s
|
||||
|
||||
# Labeler config
|
||||
cat > %s << 'CFGEOF'
|
||||
%s
|
||||
CFGEOF
|
||||
|
||||
# Labeler systemd service
|
||||
cat > /etc/systemd/system/%s.service << 'SVCEOF'
|
||||
%s
|
||||
SVCEOF
|
||||
systemctl daemon-reload
|
||||
systemctl enable %s
|
||||
|
||||
echo "=== Labeler setup complete ==="
|
||||
`,
|
||||
naming.LabelerDataDir(),
|
||||
naming.SystemUser(), naming.SystemUser(), naming.LabelerDataDir(),
|
||||
naming.LabelerConfigPath(),
|
||||
labelerConfigYAML,
|
||||
naming.Labeler(),
|
||||
labelerUnit,
|
||||
naming.Labeler(),
|
||||
)
|
||||
|
||||
return script + labelerPhase, nil
|
||||
}
|
||||
|
||||
// generateHoldCloudInit generates the cloud-init user-data script for the hold server.
|
||||
|
||||
@@ -46,3 +46,5 @@ credential_helper:
|
||||
legal:
|
||||
company_name: Seamark
|
||||
jurisdiction: State of Texas, United States
|
||||
labeler:
|
||||
did: ""
|
||||
|
||||
19
deploy/upcloud/configs/labeler.yaml.tmpl
Normal file
19
deploy/upcloud/configs/labeler.yaml.tmpl
Normal file
@@ -0,0 +1,19 @@
|
||||
version: "0.1"
|
||||
log_level: info
|
||||
log_shipper:
|
||||
backend: ""
|
||||
url: ""
|
||||
batch_size: 100
|
||||
flush_interval: 5s
|
||||
username: ""
|
||||
password: ""
|
||||
labeler:
|
||||
enabled: true
|
||||
addr: :5002
|
||||
owner_did: ""
|
||||
db_path: "{{.BasePath}}/labeler/labeler.db"
|
||||
server:
|
||||
base_url: "https://seamark.dev"
|
||||
client_name: Seamark
|
||||
client_short_name: Seamark
|
||||
test_mode: false
|
||||
BIN
deploy/upcloud/deploy
Executable file
BIN
deploy/upcloud/deploy
Executable file
Binary file not shown.
@@ -57,5 +57,14 @@ func (n Naming) ScannerConfigPath() string { return n.ConfigDir() + "/scanner.ya
|
||||
// ScannerDataDir returns the scanner data directory (e.g. "/var/lib/seamark/scanner").
|
||||
func (n Naming) ScannerDataDir() string { return n.BasePath() + "/scanner" }
|
||||
|
||||
// Labeler returns the labeler binary/service name (e.g. "seamark-labeler").
|
||||
func (n Naming) Labeler() string { return n.ClientName + "-labeler" }
|
||||
|
||||
// LabelerConfigPath returns the labeler config file path.
|
||||
func (n Naming) LabelerConfigPath() string { return n.ConfigDir() + "/labeler.yaml" }
|
||||
|
||||
// LabelerDataDir returns the labeler data directory (e.g. "/var/lib/seamark/labeler").
|
||||
func (n Naming) LabelerDataDir() string { return n.BasePath() + "/labeler" }
|
||||
|
||||
// S3Name returns the name used for S3 storage, user, and bucket.
|
||||
func (n Naming) S3Name() string { return n.ClientName }
|
||||
|
||||
@@ -29,7 +29,8 @@ var provisionCmd = &cobra.Command{
|
||||
sshKey, _ := cmd.Flags().GetString("ssh-key")
|
||||
s3Secret, _ := cmd.Flags().GetString("s3-secret")
|
||||
withScanner, _ := cmd.Flags().GetBool("with-scanner")
|
||||
return cmdProvision(token, zone, plan, sshKey, s3Secret, withScanner)
|
||||
withLabeler, _ := cmd.Flags().GetBool("with-labeler")
|
||||
return cmdProvision(token, zone, plan, sshKey, s3Secret, withScanner, withLabeler)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -39,11 +40,12 @@ func init() {
|
||||
provisionCmd.Flags().String("ssh-key", "", "Path to SSH public key file (required)")
|
||||
provisionCmd.Flags().String("s3-secret", "", "S3 secret access key (for existing object storage)")
|
||||
provisionCmd.Flags().Bool("with-scanner", false, "Deploy vulnerability scanner alongside hold")
|
||||
provisionCmd.Flags().Bool("with-labeler", false, "Deploy content moderation labeler alongside appview")
|
||||
_ = provisionCmd.MarkFlagRequired("ssh-key")
|
||||
rootCmd.AddCommand(provisionCmd)
|
||||
}
|
||||
|
||||
func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bool) error {
|
||||
func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner, withLabeler bool) error {
|
||||
cfg, err := loadConfig(zone, plan, sshKeyPath, s3Secret)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -98,6 +100,12 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
|
||||
_ = saveState(state)
|
||||
}
|
||||
|
||||
// Labeler setup
|
||||
if withLabeler {
|
||||
state.LabelerEnabled = true
|
||||
_ = saveState(state)
|
||||
}
|
||||
|
||||
fmt.Printf("Provisioning %s infrastructure in zone %s...\n", naming.DisplayName(), cfg.Zone)
|
||||
if needsServers {
|
||||
fmt.Printf("Server plan: %s\n", cfg.Plan)
|
||||
@@ -198,7 +206,7 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
|
||||
appviewCreated := false
|
||||
if state.Appview.UUID != "" {
|
||||
fmt.Printf("Appview: %s (exists)\n", state.Appview.UUID)
|
||||
appviewScript, err := generateAppviewCloudInit(cfg, vals)
|
||||
appviewScript, err := generateAppviewCloudInit(cfg, vals, state.LabelerEnabled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -212,9 +220,18 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
|
||||
if err := syncConfigKeys("appview", state.Appview.PublicIP, naming.AppviewConfigPath(), appviewConfigYAML); err != nil {
|
||||
return fmt.Errorf("appview config sync: %w", err)
|
||||
}
|
||||
if state.LabelerEnabled {
|
||||
labelerConfigYAML, err := renderConfig(labelerConfigTmpl, vals)
|
||||
if err != nil {
|
||||
return fmt.Errorf("render labeler config: %w", err)
|
||||
}
|
||||
if err := syncConfigKeys("labeler", state.Appview.PublicIP, naming.LabelerConfigPath(), labelerConfigYAML); err != nil {
|
||||
return fmt.Errorf("labeler config sync: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Creating appview server...")
|
||||
appviewUserData, err := generateAppviewCloudInit(cfg, vals)
|
||||
appviewUserData, err := generateAppviewCloudInit(cfg, vals, state.LabelerEnabled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -338,6 +355,12 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
|
||||
if err := buildLocal(rootDir, outputPath, "./cmd/appview"); err != nil {
|
||||
return fmt.Errorf("build appview: %w", err)
|
||||
}
|
||||
if state.LabelerEnabled {
|
||||
outputPath := filepath.Join(rootDir, "bin", "atcr-labeler")
|
||||
if err := buildLocal(rootDir, outputPath, "./cmd/labeler"); err != nil {
|
||||
return fmt.Errorf("build labeler: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if holdCreated {
|
||||
outputPath := filepath.Join(rootDir, "bin", "atcr-hold")
|
||||
@@ -371,6 +394,13 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
|
||||
if err := scpFile(localPath, state.Appview.PublicIP, remotePath); err != nil {
|
||||
return fmt.Errorf("upload appview: %w", err)
|
||||
}
|
||||
if state.LabelerEnabled {
|
||||
labelerLocal := filepath.Join(rootDir, "bin", "atcr-labeler")
|
||||
labelerRemote := naming.InstallDir() + "/bin/" + naming.Labeler()
|
||||
if err := scpFile(labelerLocal, state.Appview.PublicIP, labelerRemote); err != nil {
|
||||
return fmt.Errorf("upload labeler: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if holdCreated {
|
||||
localPath := filepath.Join(rootDir, "bin", "atcr-hold")
|
||||
@@ -411,11 +441,14 @@ func cmdProvision(token, zone, plan, sshKeyPath, s3Secret string, withScanner bo
|
||||
} else {
|
||||
fmt.Println(" 1. Start services:")
|
||||
}
|
||||
services := []string{naming.Appview(), naming.Hold()}
|
||||
if state.ScannerEnabled {
|
||||
fmt.Printf(" systemctl start %s / %s / %s\n", naming.Appview(), naming.Hold(), naming.Scanner())
|
||||
} else {
|
||||
fmt.Printf(" systemctl start %s / %s\n", naming.Appview(), naming.Hold())
|
||||
services = append(services, naming.Scanner())
|
||||
}
|
||||
if state.LabelerEnabled {
|
||||
services = append(services, naming.Labeler())
|
||||
}
|
||||
fmt.Printf(" systemctl start %s\n", strings.Join(services, " / "))
|
||||
fmt.Println(" 2. Configure DNS records above")
|
||||
|
||||
return nil
|
||||
|
||||
@@ -20,6 +20,7 @@ type InfraState struct {
|
||||
ObjectStorage ObjectStorageState `json:"object_storage"`
|
||||
ScannerEnabled bool `json:"scanner_enabled,omitempty"`
|
||||
ScannerSecret string `json:"scanner_secret,omitempty"`
|
||||
LabelerEnabled bool `json:"labeler_enabled,omitempty"`
|
||||
}
|
||||
|
||||
// Naming returns a Naming helper, defaulting to "seamark" if ClientName is empty.
|
||||
|
||||
25
deploy/upcloud/systemd/labeler.service.tmpl
Normal file
25
deploy/upcloud/systemd/labeler.service.tmpl
Normal file
@@ -0,0 +1,25 @@
|
||||
[Unit]
|
||||
Description={{.DisplayName}} Labeler (Content Moderation)
|
||||
After=network-online.target {{.AppviewServiceName}}.service
|
||||
Wants=network-online.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User={{.User}}
|
||||
Group={{.User}}
|
||||
ExecStart={{.BinaryPath}} serve --config {{.ConfigPath}}
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
|
||||
ReadWritePaths={{.DataDir}}
|
||||
ProtectSystem=strict
|
||||
ProtectHome=yes
|
||||
NoNewPrivileges=yes
|
||||
PrivateTmp=yes
|
||||
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
SyslogIdentifier={{.ServiceName}}
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
@@ -24,7 +24,8 @@ var updateCmd = &cobra.Command{
|
||||
target = args[0]
|
||||
}
|
||||
withScanner, _ := cmd.Flags().GetBool("with-scanner")
|
||||
return cmdUpdate(target, withScanner)
|
||||
withLabeler, _ := cmd.Flags().GetBool("with-labeler")
|
||||
return cmdUpdate(target, withScanner, withLabeler)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -40,11 +41,12 @@ var sshCmd = &cobra.Command{
|
||||
|
||||
func init() {
|
||||
updateCmd.Flags().Bool("with-scanner", false, "Enable and deploy vulnerability scanner alongside hold")
|
||||
updateCmd.Flags().Bool("with-labeler", false, "Enable and deploy content moderation labeler alongside appview")
|
||||
rootCmd.AddCommand(updateCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
}
|
||||
|
||||
func cmdUpdate(target string, withScanner bool) error {
|
||||
func cmdUpdate(target string, withScanner, withLabeler bool) error {
|
||||
state, err := loadState()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -67,6 +69,12 @@ func cmdUpdate(target string, withScanner bool) error {
|
||||
_ = saveState(state)
|
||||
}
|
||||
|
||||
// Enable labeler retroactively via --with-labeler on update
|
||||
if withLabeler && !state.LabelerEnabled {
|
||||
state.LabelerEnabled = true
|
||||
_ = saveState(state)
|
||||
}
|
||||
|
||||
vals := configValsFromState(state)
|
||||
|
||||
targets := map[string]struct {
|
||||
@@ -144,6 +152,21 @@ func cmdUpdate(target string, withScanner bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Build labeler locally if needed
|
||||
needLabeler := false
|
||||
for _, name := range toUpdate {
|
||||
if name == "appview" && state.LabelerEnabled {
|
||||
needLabeler = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if needLabeler {
|
||||
outputPath := filepath.Join(rootDir, "bin", "atcr-labeler")
|
||||
if err := buildLocal(rootDir, outputPath, "./cmd/labeler"); err != nil {
|
||||
return fmt.Errorf("build labeler: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Deploy each target
|
||||
for _, name := range toUpdate {
|
||||
t := targets[name]
|
||||
@@ -244,13 +267,65 @@ curl -sf http://localhost:9090/healthz > /dev/null && echo "SCANNER_HEALTH_OK" |
|
||||
`
|
||||
}
|
||||
|
||||
// Labeler additions for appview server
|
||||
labelerRestart := ""
|
||||
if name == "appview" && state.LabelerEnabled {
|
||||
// Sync labeler config keys
|
||||
labelerConfigYAML, err := renderConfig(labelerConfigTmpl, vals)
|
||||
if err != nil {
|
||||
return fmt.Errorf("render labeler config: %w", err)
|
||||
}
|
||||
if err := syncConfigKeys("labeler", t.ip, naming.LabelerConfigPath(), labelerConfigYAML); err != nil {
|
||||
return fmt.Errorf("labeler config sync: %w", err)
|
||||
}
|
||||
|
||||
// Sync labeler service unit
|
||||
labelerUnit, err := renderLabelerServiceUnit(labelerServiceUnitParams{
|
||||
DisplayName: naming.DisplayName(),
|
||||
User: naming.SystemUser(),
|
||||
BinaryPath: naming.InstallDir() + "/bin/" + naming.Labeler(),
|
||||
ConfigPath: naming.LabelerConfigPath(),
|
||||
DataDir: naming.BasePath(),
|
||||
ServiceName: naming.Labeler(),
|
||||
AppviewServiceName: naming.Appview(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("render labeler service unit: %w", err)
|
||||
}
|
||||
labelerUnitChanged, err := syncServiceUnit("labeler", t.ip, naming.Labeler(), labelerUnit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("labeler service unit sync: %w", err)
|
||||
}
|
||||
if labelerUnitChanged {
|
||||
daemonReload = "systemctl daemon-reload"
|
||||
}
|
||||
|
||||
// Upload labeler binary
|
||||
labelerLocal := filepath.Join(rootDir, "bin", "atcr-labeler")
|
||||
labelerRemote := naming.InstallDir() + "/bin/" + naming.Labeler()
|
||||
if err := scpFile(labelerLocal, t.ip, labelerRemote); err != nil {
|
||||
return fmt.Errorf("upload labeler: %w", err)
|
||||
}
|
||||
|
||||
// Ensure labeler data dirs exist
|
||||
labelerSetup := fmt.Sprintf(`mkdir -p %s
|
||||
chown -R %s:%s %s`,
|
||||
naming.LabelerDataDir(),
|
||||
naming.SystemUser(), naming.SystemUser(), naming.LabelerDataDir())
|
||||
if _, err := runSSH(t.ip, labelerSetup, false); err != nil {
|
||||
return fmt.Errorf("labeler dir setup: %w", err)
|
||||
}
|
||||
|
||||
labelerRestart = fmt.Sprintf("\nsystemctl restart %s", naming.Labeler())
|
||||
}
|
||||
|
||||
// Restart services and health check
|
||||
restartScript := fmt.Sprintf(`set -euo pipefail
|
||||
%s
|
||||
systemctl restart %s%s
|
||||
systemctl restart %s%s%s
|
||||
sleep 2
|
||||
curl -sf %s > /dev/null && echo "HEALTH_OK" || echo "HEALTH_FAIL"
|
||||
%s`, daemonReload, t.serviceName, scannerRestart, t.healthURL, scannerHealthCheck)
|
||||
%s`, daemonReload, t.serviceName, scannerRestart, labelerRestart, t.healthURL, scannerHealthCheck)
|
||||
|
||||
output, err := runSSH(t.ip, restartScript, true)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,6 +32,7 @@ type Config struct {
|
||||
Auth AuthConfig `yaml:"auth" comment:"JWT authentication settings."`
|
||||
CredentialHelper CredentialHelperConfig `yaml:"credential_helper" comment:"Credential helper download settings."`
|
||||
Legal LegalConfig `yaml:"legal" comment:"Legal page customization for self-hosted instances."`
|
||||
Labeler LabelerRefConfig `yaml:"labeler" comment:"ATProto labeler for content moderation (DMCA takedowns)."`
|
||||
Billing billing.Config `yaml:"billing" comment:"Stripe billing integration (requires -tags billing build)."`
|
||||
Distribution *configuration.Configuration `yaml:"-"` // Wrapped distribution config for compatibility
|
||||
}
|
||||
@@ -140,6 +141,12 @@ type LegalConfig struct {
|
||||
Jurisdiction string `yaml:"jurisdiction" comment:"Governing law jurisdiction for legal terms."`
|
||||
}
|
||||
|
||||
// LabelerRefConfig defines the connection to an ATProto labeler service.
|
||||
type LabelerRefConfig struct {
|
||||
// DID or URL of the labeler service for content moderation.
|
||||
DID string `yaml:"did" comment:"DID or URL of the ATProto labeler (e.g., did:web:labeler.atcr.io). Empty disables label filtering."`
|
||||
}
|
||||
|
||||
// setDefaults registers all default values on the given Viper instance.
|
||||
func setDefaults(v *viper.Viper) {
|
||||
v.SetDefault("version", "0.1")
|
||||
@@ -193,6 +200,9 @@ func setDefaults(v *viper.Viper) {
|
||||
v.SetDefault("legal.company_name", "")
|
||||
v.SetDefault("legal.jurisdiction", "")
|
||||
|
||||
// Labeler defaults
|
||||
v.SetDefault("labeler.did", "")
|
||||
|
||||
// Log formatter (used by distribution config, not in Config struct)
|
||||
v.SetDefault("log_formatter", "text")
|
||||
}
|
||||
|
||||
79
pkg/appview/db/labels.go
Normal file
79
pkg/appview/db/labels.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LabelChecker wraps a database connection to check takedown labels.
|
||||
// Implements middleware.LabelChecker interface.
|
||||
type LabelChecker struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewLabelChecker creates a new LabelChecker.
|
||||
func NewLabelChecker(database *sql.DB) *LabelChecker {
|
||||
return &LabelChecker{db: database}
|
||||
}
|
||||
|
||||
// IsTakenDown checks if a (DID, repository) pair has an active takedown label.
|
||||
func (lc *LabelChecker) IsTakenDown(did, repository string) (bool, error) {
|
||||
return IsTakenDown(lc.db, did, repository)
|
||||
}
|
||||
|
||||
// Label represents an ATProto label mirrored from a labeler service.
|
||||
type Label struct {
|
||||
ID int64
|
||||
Src string
|
||||
URI string
|
||||
Val string
|
||||
Neg bool
|
||||
Cts time.Time
|
||||
SubjectDID string
|
||||
SubjectRepo string
|
||||
Seq int64
|
||||
}
|
||||
|
||||
// IsTakenDown checks if a (DID, repository) pair has an active !takedown label.
|
||||
// Also matches user-level labels (subject_repo = ”).
|
||||
func IsTakenDown(db DBTX, did, repository string) (bool, error) {
|
||||
var exists bool
|
||||
err := db.QueryRow(
|
||||
`SELECT EXISTS(
|
||||
SELECT 1 FROM labels l1
|
||||
WHERE l1.subject_did = ?
|
||||
AND (l1.subject_repo = ? OR l1.subject_repo = '')
|
||||
AND l1.val = '!takedown' AND l1.neg = 0
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM labels l2
|
||||
WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val
|
||||
AND l2.neg = 1 AND l2.id > l1.id
|
||||
)
|
||||
AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP)
|
||||
)`,
|
||||
did, repository,
|
||||
).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// UpsertLabel inserts or updates a label from a labeler subscription.
|
||||
func UpsertLabel(db DBTX, l *Label) error {
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo, seq)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(src, uri, val, neg) DO UPDATE SET cts = excluded.cts, seq = excluded.seq`,
|
||||
l.Src, l.URI, l.Val, l.Neg, l.Cts.UTC().Format(time.RFC3339),
|
||||
l.SubjectDID, l.SubjectRepo, l.Seq,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetLabelCursor returns the latest sequence number for a given labeler source.
|
||||
func GetLabelCursor(db DBTX, src string) (int64, error) {
|
||||
var cursor int64
|
||||
err := db.QueryRow(
|
||||
`SELECT COALESCE(MAX(seq), 0) FROM labels WHERE src = ?`,
|
||||
src,
|
||||
).Scan(&cursor)
|
||||
return cursor, err
|
||||
}
|
||||
16
pkg/appview/db/migrations/0017_create_labels.yaml
Normal file
16
pkg/appview/db/migrations/0017_create_labels.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
description: Create labels table for ATProto content moderation (takedowns)
|
||||
query: |
|
||||
CREATE TABLE IF NOT EXISTS labels (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
src TEXT NOT NULL,
|
||||
uri TEXT NOT NULL,
|
||||
val TEXT NOT NULL,
|
||||
neg BOOLEAN NOT NULL DEFAULT 0,
|
||||
cts TIMESTAMP NOT NULL,
|
||||
subject_did TEXT NOT NULL,
|
||||
subject_repo TEXT NOT NULL DEFAULT '',
|
||||
seq INTEGER NOT NULL DEFAULT 0,
|
||||
UNIQUE(src, uri, val, neg)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_labels_subject ON labels(subject_did, subject_repo);
|
||||
CREATE INDEX IF NOT EXISTS idx_labels_val ON labels(val);
|
||||
@@ -74,13 +74,18 @@ func SearchRepositories(db DBTX, query string, limit, offset int, currentUserDID
|
||||
SELECT DISTINCT lm.did, lm.repository, lm.latest_id
|
||||
FROM latest_manifests lm
|
||||
JOIN users u ON lm.did = u.did
|
||||
WHERE u.handle LIKE ? ESCAPE '\'
|
||||
WHERE (u.handle LIKE ? ESCAPE '\'
|
||||
OR u.did = ?
|
||||
OR lm.repository LIKE ? ESCAPE '\'
|
||||
OR EXISTS (
|
||||
SELECT 1 FROM repository_annotations ra
|
||||
WHERE ra.did = lm.did AND ra.repository = lm.repository
|
||||
AND ra.value LIKE ? ESCAPE '\'
|
||||
))
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM labels
|
||||
WHERE (subject_did = lm.did AND (subject_repo = lm.repository OR subject_repo = ''))
|
||||
AND val = '!takedown' AND neg = 0
|
||||
)
|
||||
),
|
||||
repo_stats AS (
|
||||
@@ -1953,6 +1958,11 @@ func GetRepoCards(db DBTX, limit int, currentUserDID string, sortOrder RepoCardS
|
||||
JOIN users u ON m.did = u.did
|
||||
LEFT JOIN repository_stats rs ON m.did = rs.did AND m.repository = rs.repository
|
||||
LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM labels
|
||||
WHERE (subject_did = m.did AND (subject_repo = m.repository OR subject_repo = ''))
|
||||
AND val = '!takedown' AND neg = 0
|
||||
)
|
||||
ORDER BY ` + orderBy + `
|
||||
LIMIT ?
|
||||
`
|
||||
@@ -2026,6 +2036,11 @@ func GetUserRepoCards(db DBTX, userDID string, currentUserDID string) ([]RepoCar
|
||||
JOIN users u ON m.did = u.did
|
||||
LEFT JOIN repository_stats rs ON m.did = rs.did AND m.repository = rs.repository
|
||||
LEFT JOIN repo_pages rp ON m.did = rp.did AND m.repository = rp.repository
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM labels
|
||||
WHERE (subject_did = m.did AND (subject_repo = m.repository OR subject_repo = ''))
|
||||
AND val = '!takedown' AND neg = 0
|
||||
)
|
||||
ORDER BY MAX(rs.last_push, m.created_at) DESC
|
||||
`
|
||||
|
||||
|
||||
@@ -271,3 +271,18 @@ CREATE TABLE IF NOT EXISTS scans (
|
||||
PRIMARY KEY(hold_did, manifest_digest)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_scans_user ON scans(user_did);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS labels (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
src TEXT NOT NULL,
|
||||
uri TEXT NOT NULL,
|
||||
val TEXT NOT NULL,
|
||||
neg BOOLEAN NOT NULL DEFAULT 0,
|
||||
cts TIMESTAMP NOT NULL,
|
||||
subject_did TEXT NOT NULL,
|
||||
subject_repo TEXT NOT NULL DEFAULT '',
|
||||
seq INTEGER NOT NULL DEFAULT 0,
|
||||
UNIQUE(src, uri, val, neg)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_labels_subject ON labels(subject_did, subject_repo);
|
||||
CREATE INDEX IF NOT EXISTS idx_labels_val ON labels(val);
|
||||
|
||||
@@ -45,6 +45,12 @@ func (h *DigestDetailHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for takedown labels
|
||||
if taken, _ := db.IsTakenDown(h.ReadOnlyDB, did, repository); taken {
|
||||
RenderNotFound(w, r, &h.BaseUIHandler)
|
||||
return
|
||||
}
|
||||
|
||||
owner, err := db.GetUserByDID(h.ReadOnlyDB, did)
|
||||
if err != nil || owner == nil {
|
||||
RenderNotFound(w, r, &h.BaseUIHandler)
|
||||
|
||||
@@ -34,6 +34,12 @@ func (h *RepositoryPageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
// Check for takedown labels
|
||||
if taken, _ := db.IsTakenDown(h.ReadOnlyDB, did, repository); taken {
|
||||
RenderNotFound(w, r, &h.BaseUIHandler)
|
||||
return
|
||||
}
|
||||
|
||||
// Look up user by DID
|
||||
owner, err := db.GetUserByDID(h.ReadOnlyDB, did)
|
||||
if err != nil {
|
||||
|
||||
@@ -229,6 +229,20 @@ func (p *Processor) ProcessRecord(ctx context.Context, did, collection, rkey str
|
||||
}
|
||||
}
|
||||
|
||||
// Skip ingestion for taken-down content
|
||||
if !isDelete && data != nil {
|
||||
if repo := extractRepoFromRecord(collection, data); repo != "" {
|
||||
if taken, _ := db.IsTakenDown(p.db, did, repo); taken {
|
||||
slog.Debug("Skipping taken-down content",
|
||||
"component", "processor",
|
||||
"did", did,
|
||||
"collection", collection,
|
||||
"repository", repo)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// User-activity collections create/update user entries
|
||||
// Skip for deletes - user should already exist, and we don't need to resolve identity
|
||||
if !isDelete {
|
||||
@@ -971,3 +985,23 @@ func (p *Processor) ProcessAccount(ctx context.Context, did string, active bool,
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractRepoFromRecord extracts the repository field from a record's JSON data.
|
||||
// Returns empty string for collections that don't have a repository field
|
||||
// (e.g., sailor profile, captain, crew).
|
||||
func extractRepoFromRecord(collection string, data []byte) string {
|
||||
switch collection {
|
||||
case atproto.ManifestCollection,
|
||||
atproto.TagCollection,
|
||||
atproto.RepoPageCollection,
|
||||
atproto.StatsCollection,
|
||||
atproto.ScanCollection:
|
||||
var rec struct {
|
||||
Repository string `json:"repository"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &rec); err == nil {
|
||||
return rec.Repository
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
239
pkg/appview/labeler/subscriber.go
Normal file
239
pkg/appview/labeler/subscriber.go
Normal file
@@ -0,0 +1,239 @@
|
||||
// Package labeler provides a subscription client for consuming labels
|
||||
// from an ATProto labeler service.
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/appview/db"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// LabelsMessage is the wire format for subscribeLabels events.
|
||||
type LabelsMessage struct {
|
||||
Seq int64 `json:"seq"`
|
||||
Labels []LabelEvent `json:"labels"`
|
||||
}
|
||||
|
||||
// LabelEvent is a single label from the labeler.
|
||||
type LabelEvent struct {
|
||||
Src string `json:"src"`
|
||||
URI string `json:"uri"`
|
||||
CID string `json:"cid,omitempty"`
|
||||
Val string `json:"val"`
|
||||
Neg bool `json:"neg"`
|
||||
Cts string `json:"cts"`
|
||||
Exp string `json:"exp,omitempty"`
|
||||
}
|
||||
|
||||
// Subscriber connects to a labeler's subscribeLabels endpoint
|
||||
// and mirrors labels into the appview database.
|
||||
type Subscriber struct {
|
||||
labelerURL string
|
||||
database *sql.DB
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSubscriber creates a new labeler subscriber.
|
||||
func NewSubscriber(labelerURL string, database *sql.DB) *Subscriber {
|
||||
return &Subscriber{
|
||||
labelerURL: labelerURL,
|
||||
database: database,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the subscription loop in a goroutine.
|
||||
func (s *Subscriber) Start() {
|
||||
go s.run()
|
||||
}
|
||||
|
||||
// Stop signals the subscriber to shut down.
|
||||
func (s *Subscriber) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
func (s *Subscriber) run() {
|
||||
backoff := time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err := s.connect(); err != nil {
|
||||
slog.Warn("Labeler subscription error, reconnecting",
|
||||
"error", err,
|
||||
"backoff", backoff,
|
||||
)
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
if backoff < 30*time.Second {
|
||||
backoff *= 2
|
||||
}
|
||||
} else {
|
||||
backoff = time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscriber) connect() error {
|
||||
// Get cursor from DB
|
||||
// Use the labeler URL as src identifier
|
||||
labelerDID := extractDIDFromURL(s.labelerURL)
|
||||
cursor, err := db.GetLabelCursor(s.database, labelerDID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get cursor: %w", err)
|
||||
}
|
||||
|
||||
// Build WebSocket URL
|
||||
wsURL := toWebSocketURL(s.labelerURL) + "/xrpc/com.atproto.label.subscribeLabels"
|
||||
if cursor > 0 {
|
||||
wsURL += fmt.Sprintf("?cursor=%d", cursor)
|
||||
}
|
||||
|
||||
slog.Info("Connecting to labeler", "url", wsURL, "cursor", cursor)
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("websocket dial failed: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
slog.Info("Connected to labeler", "url", s.labelerURL)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
var msg LabelsMessage
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
return fmt.Errorf("read error: %w", err)
|
||||
}
|
||||
|
||||
for _, le := range msg.Labels {
|
||||
cts, _ := time.Parse(time.RFC3339, le.Cts)
|
||||
did, repo := extractSubjectFromURI(le.URI)
|
||||
|
||||
label := &db.Label{
|
||||
Src: le.Src,
|
||||
URI: le.URI,
|
||||
Val: le.Val,
|
||||
Neg: le.Neg,
|
||||
Cts: cts,
|
||||
SubjectDID: did,
|
||||
SubjectRepo: repo,
|
||||
Seq: msg.Seq,
|
||||
}
|
||||
|
||||
if err := db.UpsertLabel(s.database, label); err != nil {
|
||||
slog.Warn("Failed to upsert label", "uri", le.URI, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("Mirrored label",
|
||||
"uri", le.URI,
|
||||
"val", le.Val,
|
||||
"neg", le.Neg,
|
||||
"subject_did", did,
|
||||
"subject_repo", repo,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractSubjectFromURI extracts the DID and repository from an AT URI.
|
||||
// Examples:
|
||||
//
|
||||
// at://did:plc:xyz → (did:plc:xyz, "")
|
||||
// at://did:plc:xyz/io.atcr.manifest/abc → (did:plc:xyz, "") - repo extracted from record
|
||||
// at://did:plc:xyz/io.atcr.repo/myimage → (did:plc:xyz, "myimage")
|
||||
func extractSubjectFromURI(uri string) (did, repo string) {
|
||||
trimmed := strings.TrimPrefix(uri, "at://")
|
||||
parts := strings.SplitN(trimmed, "/", 3)
|
||||
if len(parts) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
did = parts[0]
|
||||
|
||||
// For repo-level summary labels: at://did/io.atcr.repo/reponame
|
||||
if len(parts) >= 3 && parts[1] == "io.atcr.repo" {
|
||||
repo = parts[2]
|
||||
}
|
||||
return did, repo
|
||||
}
|
||||
|
||||
// extractDIDFromURL derives a did:web from a labeler URL.
|
||||
func extractDIDFromURL(labelerURL string) string {
|
||||
u, err := url.Parse(labelerURL)
|
||||
if err != nil {
|
||||
return labelerURL
|
||||
}
|
||||
host := u.Hostname()
|
||||
if port := u.Port(); port != "" {
|
||||
host += "%3A" + port
|
||||
}
|
||||
return "did:web:" + host
|
||||
}
|
||||
|
||||
// toWebSocketURL converts an HTTP URL to a WebSocket URL.
|
||||
func toWebSocketURL(httpURL string) string {
|
||||
u, err := url.Parse(httpURL)
|
||||
if err != nil {
|
||||
return httpURL
|
||||
}
|
||||
switch u.Scheme {
|
||||
case "https":
|
||||
u.Scheme = "wss"
|
||||
default:
|
||||
u.Scheme = "ws"
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// ParseLabelerURL parses a labeler DID or URL into an HTTP URL.
|
||||
func ParseLabelerURL(labelerDIDOrURL string) string {
|
||||
if strings.HasPrefix(labelerDIDOrURL, "http://") || strings.HasPrefix(labelerDIDOrURL, "https://") {
|
||||
return labelerDIDOrURL
|
||||
}
|
||||
if strings.HasPrefix(labelerDIDOrURL, "did:web:") {
|
||||
host := strings.TrimPrefix(labelerDIDOrURL, "did:web:")
|
||||
host = strings.ReplaceAll(host, "%3A", ":")
|
||||
return "https://" + host
|
||||
}
|
||||
return labelerDIDOrURL
|
||||
}
|
||||
|
||||
// SubscriberFromConfig creates a Subscriber from a labeler DID/URL config value.
|
||||
// Returns nil if labelerDIDOrURL is empty.
|
||||
func SubscriberFromConfig(labelerDIDOrURL string, database *sql.DB) *Subscriber {
|
||||
if labelerDIDOrURL == "" {
|
||||
return nil
|
||||
}
|
||||
labelerURL := ParseLabelerURL(labelerDIDOrURL)
|
||||
return NewSubscriber(labelerURL, database)
|
||||
}
|
||||
|
||||
// DecodeLabelsFromJSON decodes a JSON-encoded labels message.
|
||||
func DecodeLabelsFromJSON(data []byte) (*LabelsMessage, error) {
|
||||
var msg LabelsMessage
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
@@ -166,6 +166,11 @@ func (vc *validationCache) getOrFetch(ctx context.Context, cacheKey string, fetc
|
||||
return serviceToken, err
|
||||
}
|
||||
|
||||
// LabelChecker checks whether content has been taken down via ATProto labels.
|
||||
type LabelChecker interface {
|
||||
IsTakenDown(did, repository string) (bool, error)
|
||||
}
|
||||
|
||||
// Global variables for initialization only
|
||||
// These are set by main.go during startup and copied into NamespaceResolver instances.
|
||||
// After initialization, request handling uses the NamespaceResolver's instance fields.
|
||||
@@ -175,6 +180,7 @@ var (
|
||||
globalAuthorizer auth.HoldAuthorizer
|
||||
globalWebhookDispatcher storage.PushWebhookDispatcher
|
||||
globalManifestRefChecker storage.ManifestReferenceChecker
|
||||
globalLabelChecker LabelChecker
|
||||
)
|
||||
|
||||
// SetGlobalRefresher sets the OAuth refresher instance during initialization
|
||||
@@ -194,6 +200,11 @@ func SetGlobalManifestRefChecker(checker storage.ManifestReferenceChecker) {
|
||||
globalManifestRefChecker = checker
|
||||
}
|
||||
|
||||
// SetGlobalLabelChecker sets the label checker instance during initialization
|
||||
func SetGlobalLabelChecker(checker LabelChecker) {
|
||||
globalLabelChecker = checker
|
||||
}
|
||||
|
||||
// SetGlobalAuthorizer sets the authorizer instance during initialization
|
||||
// Must be called before the registry starts serving requests
|
||||
func SetGlobalAuthorizer(authorizer auth.HoldAuthorizer) {
|
||||
@@ -304,6 +315,16 @@ func (nr *NamespaceResolver) Repository(ctx context.Context, name reference.Name
|
||||
|
||||
slog.Debug("Resolved identity", "component", "registry/middleware", "did", did, "pds", pdsEndpoint, "handle", handle)
|
||||
|
||||
// Check for takedown labels before proceeding
|
||||
if globalLabelChecker != nil {
|
||||
if taken, _ := globalLabelChecker.IsTakenDown(did, imageName); taken {
|
||||
return nil, errcode.Error{
|
||||
Code: errcode.ErrorCodeDenied,
|
||||
Message: "this repository has been removed due to a policy violation",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query for hold DID - either user's hold or default hold service
|
||||
// Also returns the sailor profile so we can read preferences (e.g. AutoRemoveUntagged)
|
||||
holdDID, sailorProfile := nr.findHoldDIDAndProfile(ctx, did, pdsEndpoint)
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"atcr.io/pkg/appview/db"
|
||||
"atcr.io/pkg/appview/holdhealth"
|
||||
"atcr.io/pkg/appview/jetstream"
|
||||
appviewlabeler "atcr.io/pkg/appview/labeler"
|
||||
"atcr.io/pkg/appview/middleware"
|
||||
"atcr.io/pkg/appview/readme"
|
||||
"atcr.io/pkg/appview/routes"
|
||||
@@ -236,6 +237,9 @@ func NewAppViewServer(cfg *Config, branding *BrandingOverrides) (*AppViewServer,
|
||||
middleware.SetGlobalDatabase(holdDIDDB)
|
||||
middleware.SetGlobalManifestRefChecker(holdDIDDB)
|
||||
|
||||
// Set label checker for takedown filtering
|
||||
middleware.SetGlobalLabelChecker(db.NewLabelChecker(s.Database))
|
||||
|
||||
// Create RemoteHoldAuthorizer for hold authorization with caching
|
||||
s.HoldAuthorizer = auth.NewRemoteHoldAuthorizer(s.Database, testMode)
|
||||
middleware.SetGlobalAuthorizer(s.HoldAuthorizer)
|
||||
@@ -287,6 +291,15 @@ func NewAppViewServer(cfg *Config, branding *BrandingOverrides) (*AppViewServer,
|
||||
// Initialize Jetstream workers
|
||||
s.initializeJetstream()
|
||||
|
||||
// Initialize labeler subscriber
|
||||
if cfg.Labeler.DID != "" {
|
||||
sub := appviewlabeler.SubscriberFromConfig(cfg.Labeler.DID, s.Database)
|
||||
if sub != nil {
|
||||
sub.Start()
|
||||
slog.Info("Labeler subscriber started", "labeler", cfg.Labeler.DID)
|
||||
}
|
||||
}
|
||||
|
||||
// Create main chi router
|
||||
mainRouter := chi.NewRouter()
|
||||
|
||||
|
||||
106
pkg/labeler/auth.go
Normal file
106
pkg/labeler/auth.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Session represents an authenticated admin session.
|
||||
type Session struct {
|
||||
DID string
|
||||
Handle string
|
||||
}
|
||||
|
||||
// Auth manages admin authentication.
|
||||
type Auth struct {
|
||||
ownerDID string
|
||||
sessions map[string]*Session
|
||||
sessionsMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewAuth creates a new Auth manager.
|
||||
func NewAuth(ownerDID string) *Auth {
|
||||
return &Auth{
|
||||
ownerDID: ownerDID,
|
||||
sessions: make(map[string]*Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth) createSession(did, handle string) (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
token := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
a.sessionsMu.Lock()
|
||||
a.sessions[token] = &Session{DID: did, Handle: handle}
|
||||
a.sessionsMu.Unlock()
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (a *Auth) getSession(token string) *Session {
|
||||
a.sessionsMu.RLock()
|
||||
defer a.sessionsMu.RUnlock()
|
||||
return a.sessions[token]
|
||||
}
|
||||
|
||||
func (a *Auth) deleteSession(token string) {
|
||||
a.sessionsMu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.sessionsMu.Unlock()
|
||||
}
|
||||
|
||||
const sessionCookieName = "labeler_session"
|
||||
|
||||
func setSessionCookie(w http.ResponseWriter, r *http.Request, token string) {
|
||||
secure := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: 86400, // 24 hours
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func clearSessionCookie(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func getSessionCookie(r *http.Request) (string, bool) {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return cookie.Value, true
|
||||
}
|
||||
|
||||
// RequireOwner is middleware that checks the session belongs to the owner DID.
|
||||
func (a *Auth) RequireOwner(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, ok := getSessionCookie(r)
|
||||
if !ok {
|
||||
http.Redirect(w, r, "/auth/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
session := a.getSession(token)
|
||||
if session == nil || session.DID != a.ownerDID {
|
||||
http.Redirect(w, r, "/auth/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
146
pkg/labeler/config.go
Normal file
146
pkg/labeler/config.go
Normal file
@@ -0,0 +1,146 @@
|
||||
// Package labeler implements the ATCR labeler service, an ATProto-compatible
|
||||
// content moderation service for issuing takedown labels on container registry content.
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"atcr.io/pkg/config"
|
||||
)
|
||||
|
||||
// Config represents the labeler service configuration.
|
||||
// It reuses the appview config YAML structure, reading from the "labeler" section.
|
||||
type Config struct {
|
||||
Version string `yaml:"version" comment:"Configuration format version."`
|
||||
LogLevel string `yaml:"log_level" comment:"Log level: debug, info, warn, error."`
|
||||
Labeler LabelerConfig `yaml:"labeler" comment:"Labeler service settings."`
|
||||
Server AppviewServerConfig `yaml:"server" comment:"AppView server settings (shared config)."`
|
||||
LogShipper config.LogShipperConfig `yaml:"log_shipper" comment:"Remote log shipping settings."`
|
||||
}
|
||||
|
||||
// LabelerConfig defines labeler-specific settings.
|
||||
type LabelerConfig struct {
|
||||
// Enable the labeler service.
|
||||
Enabled bool `yaml:"enabled" comment:"Enable the labeler service."`
|
||||
|
||||
// Listen address for the labeler HTTP server.
|
||||
Addr string `yaml:"addr" comment:"Listen address for labeler (e.g., :5002)."`
|
||||
|
||||
// DID of the labeler admin. Only this DID can log into the admin panel.
|
||||
OwnerDID string `yaml:"owner_did" comment:"DID of the labeler admin. Only this DID can log into the admin panel."`
|
||||
|
||||
// Path to labeler SQLite database.
|
||||
DBPath string `yaml:"db_path" comment:"Path to labeler SQLite database."`
|
||||
}
|
||||
|
||||
// AppviewServerConfig is a subset of the appview ServerConfig that the labeler needs.
|
||||
type AppviewServerConfig struct {
|
||||
BaseURL string `yaml:"base_url"`
|
||||
ClientName string `yaml:"client_name"`
|
||||
ClientShortName string `yaml:"client_short_name"`
|
||||
TestMode bool `yaml:"test_mode"`
|
||||
}
|
||||
|
||||
// PublicURL returns the labeler's public URL derived from the appview base URL.
|
||||
// If appview is https://atcr.io, labeler is https://labeler.atcr.io.
|
||||
func (c *Config) PublicURL() string {
|
||||
u, err := url.Parse(c.Server.BaseURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
u.Host = "labeler." + u.Host
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// DID returns the labeler's did:web identity derived from its public URL.
|
||||
func (c *Config) DID() string {
|
||||
u, err := url.Parse(c.PublicURL())
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
host := u.Hostname()
|
||||
if port := u.Port(); port != "" {
|
||||
host += "%3A" + port
|
||||
}
|
||||
return "did:web:" + host
|
||||
}
|
||||
|
||||
func setDefaults(v *viper.Viper) {
|
||||
v.SetDefault("version", "0.1")
|
||||
v.SetDefault("log_level", "info")
|
||||
|
||||
// Labeler defaults
|
||||
v.SetDefault("labeler.enabled", false)
|
||||
v.SetDefault("labeler.addr", ":5002")
|
||||
v.SetDefault("labeler.owner_did", "")
|
||||
v.SetDefault("labeler.db_path", "/var/lib/atcr-labeler/labeler.db")
|
||||
|
||||
// Server defaults (read from shared appview config)
|
||||
v.SetDefault("server.base_url", "")
|
||||
v.SetDefault("server.client_name", "AT Container Registry")
|
||||
v.SetDefault("server.client_short_name", "ATCR")
|
||||
v.SetDefault("server.test_mode", false)
|
||||
}
|
||||
|
||||
// LoadConfig loads the labeler configuration from the appview config YAML.
|
||||
func LoadConfig(yamlPath string) (*Config, error) {
|
||||
v := config.NewViper("LABELER", yamlPath)
|
||||
setDefaults(v)
|
||||
|
||||
cfg := &Config{}
|
||||
if err := v.Unmarshal(cfg, config.UnmarshalOption()); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Also try ATCR_ prefix for shared server config
|
||||
atcrV := config.NewViper("ATCR", yamlPath)
|
||||
if baseURL := atcrV.GetString("server.base_url"); baseURL != "" && cfg.Server.BaseURL == "" {
|
||||
cfg.Server.BaseURL = baseURL
|
||||
}
|
||||
if clientName := atcrV.GetString("server.client_name"); clientName != "" && cfg.Server.ClientName == "" {
|
||||
cfg.Server.ClientName = clientName
|
||||
}
|
||||
if clientShortName := atcrV.GetString("server.client_short_name"); clientShortName != "" && cfg.Server.ClientShortName == "" {
|
||||
cfg.Server.ClientShortName = clientShortName
|
||||
}
|
||||
if atcrV.GetBool("server.test_mode") {
|
||||
cfg.Server.TestMode = true
|
||||
}
|
||||
|
||||
// Validation
|
||||
if cfg.Server.BaseURL == "" {
|
||||
return nil, fmt.Errorf("server.base_url is required")
|
||||
}
|
||||
if cfg.Labeler.OwnerDID == "" {
|
||||
return nil, fmt.Errorf("labeler.owner_did is required")
|
||||
}
|
||||
if !strings.HasPrefix(cfg.Labeler.OwnerDID, "did:") {
|
||||
return nil, fmt.Errorf("labeler.owner_did must be a DID (got %q)", cfg.Labeler.OwnerDID)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// ExampleYAML generates an example labeler configuration file.
|
||||
func ExampleYAML() ([]byte, error) {
|
||||
cfg := &Config{
|
||||
Version: "0.1",
|
||||
LogLevel: "info",
|
||||
Server: AppviewServerConfig{
|
||||
BaseURL: "https://atcr.io",
|
||||
ClientName: "AT Container Registry",
|
||||
ClientShortName: "ATCR",
|
||||
},
|
||||
Labeler: LabelerConfig{
|
||||
Enabled: true,
|
||||
Addr: ":5002",
|
||||
OwnerDID: "did:plc:your-did-here",
|
||||
DBPath: "/var/lib/atcr-labeler/labeler.db",
|
||||
},
|
||||
}
|
||||
return config.MarshalCommentedYAML("ATCR Labeler Configuration", cfg)
|
||||
}
|
||||
51
pkg/labeler/config_test.go
Normal file
51
pkg/labeler/config_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package labeler
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestConfig_PublicURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
want string
|
||||
}{
|
||||
{"standard", "https://atcr.io", "https://labeler.atcr.io"},
|
||||
{"with port", "https://atcr.io:8080", "https://labeler.atcr.io:8080"},
|
||||
{"localhost", "http://localhost:5000", "http://labeler.localhost:5000"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Server: AppviewServerConfig{BaseURL: tt.baseURL},
|
||||
}
|
||||
got := cfg.PublicURL()
|
||||
if got != tt.want {
|
||||
t.Errorf("PublicURL() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_DID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
want string
|
||||
}{
|
||||
{"standard", "https://atcr.io", "did:web:labeler.atcr.io"},
|
||||
{"with port", "https://atcr.io:8080", "did:web:labeler.atcr.io%3A8080"},
|
||||
{"localhost", "http://localhost:5000", "did:web:labeler.localhost%3A5000"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Server: AppviewServerConfig{BaseURL: tt.baseURL},
|
||||
}
|
||||
got := cfg.DID()
|
||||
if got != tt.want {
|
||||
t.Errorf("DID() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
301
pkg/labeler/db.go
Normal file
301
pkg/labeler/db.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
_ "github.com/tursodatabase/go-libsql"
|
||||
)
|
||||
|
||||
const schema = `
|
||||
CREATE TABLE IF NOT EXISTS labels (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
src TEXT NOT NULL,
|
||||
uri TEXT NOT NULL,
|
||||
cid TEXT,
|
||||
val TEXT NOT NULL,
|
||||
neg BOOLEAN NOT NULL DEFAULT 0,
|
||||
cts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
exp TIMESTAMP,
|
||||
subject_did TEXT NOT NULL,
|
||||
subject_repo TEXT NOT NULL DEFAULT '',
|
||||
UNIQUE(src, uri, val, neg)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_labels_subject ON labels(subject_did, subject_repo);
|
||||
CREATE INDEX IF NOT EXISTS idx_labels_cts ON labels(cts DESC);
|
||||
`
|
||||
|
||||
// Label represents an ATProto label (com.atproto.label.defs#label).
|
||||
type Label struct {
|
||||
ID int64
|
||||
Src string
|
||||
URI string
|
||||
CID string
|
||||
Val string
|
||||
Neg bool
|
||||
Cts time.Time
|
||||
Exp *time.Time
|
||||
SubjectDID string
|
||||
SubjectRepo string
|
||||
}
|
||||
|
||||
// OpenDB opens or creates the labeler database.
|
||||
func OpenDB(dbPath string) (*sql.DB, error) {
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create db directory: %w", err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("libsql", "file:"+dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Apply schema
|
||||
for _, stmt := range splitStatements(schema) {
|
||||
if _, err := db.Exec(stmt); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply schema: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// splitStatements splits SQL by semicolons (go-libsql doesn't support multi-statement exec).
|
||||
func splitStatements(sql string) []string {
|
||||
var stmts []string
|
||||
for _, s := range splitOnSemicolon(sql) {
|
||||
s = trimSpace(s)
|
||||
if s != "" {
|
||||
stmts = append(stmts, s)
|
||||
}
|
||||
}
|
||||
return stmts
|
||||
}
|
||||
|
||||
func splitOnSemicolon(s string) []string {
|
||||
var parts []string
|
||||
start := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == ';' {
|
||||
parts = append(parts, s[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(s) {
|
||||
parts = append(parts, s[start:])
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func trimSpace(s string) string {
|
||||
// Simple trim that handles newlines and spaces
|
||||
i := 0
|
||||
for i < len(s) && (s[i] == ' ' || s[i] == '\t' || s[i] == '\n' || s[i] == '\r') {
|
||||
i++
|
||||
}
|
||||
j := len(s)
|
||||
for j > i && (s[j-1] == ' ' || s[j-1] == '\t' || s[j-1] == '\n' || s[j-1] == '\r') {
|
||||
j--
|
||||
}
|
||||
return s[i:j]
|
||||
}
|
||||
|
||||
// CreateLabel inserts a new label into the database.
|
||||
func CreateLabel(db *sql.DB, l *Label) (int64, error) {
|
||||
result, err := db.Exec(
|
||||
`INSERT INTO labels (src, uri, cid, val, neg, cts, exp, subject_did, subject_repo)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(src, uri, val, neg) DO UPDATE SET cts = excluded.cts`,
|
||||
l.Src, l.URI, l.CID, l.Val, l.Neg, l.Cts.UTC().Format(time.RFC3339), l.Exp,
|
||||
l.SubjectDID, l.SubjectRepo,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create label: %w", err)
|
||||
}
|
||||
return result.LastInsertId()
|
||||
}
|
||||
|
||||
// NegateLabel creates a negation label to reverse a previous label.
|
||||
func NegateLabel(db *sql.DB, src, uri, val string, subjectDID, subjectRepo string) error {
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo)
|
||||
VALUES (?, ?, ?, 1, ?, ?, ?)`,
|
||||
src, uri, val, time.Now().UTC().Format(time.RFC3339), subjectDID, subjectRepo,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetLabelsSince returns labels with id > cursor, ordered by id ascending.
|
||||
func GetLabelsSince(db *sql.DB, cursor int64, limit int) ([]Label, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, subject_did, subject_repo
|
||||
FROM labels WHERE id > ? ORDER BY id ASC LIMIT ?`,
|
||||
cursor, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanLabels(rows)
|
||||
}
|
||||
|
||||
// ListActiveTakedowns returns active (non-negated) takedown labels.
|
||||
func ListActiveTakedowns(db *sql.DB, limit, offset int) ([]Label, int, error) {
|
||||
var total int
|
||||
err := db.QueryRow(
|
||||
`SELECT COUNT(*) FROM labels l1
|
||||
WHERE l1.val = '!takedown' AND l1.neg = 0
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM labels l2
|
||||
WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val
|
||||
AND l2.neg = 1 AND l2.id > l1.id
|
||||
)
|
||||
AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP)`,
|
||||
).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
rows, err := db.Query(
|
||||
`SELECT l1.id, l1.src, l1.uri, COALESCE(l1.cid, ''), l1.val, l1.neg, l1.cts, l1.exp, l1.subject_did, l1.subject_repo
|
||||
FROM labels l1
|
||||
WHERE l1.val = '!takedown' AND l1.neg = 0
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM labels l2
|
||||
WHERE l2.src = l1.src AND l2.uri = l1.uri AND l2.val = l1.val
|
||||
AND l2.neg = 1 AND l2.id > l1.id
|
||||
)
|
||||
AND (l1.exp IS NULL OR l1.exp > CURRENT_TIMESTAMP)
|
||||
ORDER BY l1.cts DESC LIMIT ? OFFSET ?`,
|
||||
limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
labels, err := scanLabels(rows)
|
||||
return labels, total, err
|
||||
}
|
||||
|
||||
// GetLabelsForRepo returns all active labels for a specific DID + repository.
|
||||
func GetLabelsForRepo(db *sql.DB, did, repo string) ([]Label, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT id, src, uri, COALESCE(cid, ''), val, neg, cts, exp, subject_did, subject_repo
|
||||
FROM labels
|
||||
WHERE subject_did = ? AND subject_repo = ?
|
||||
ORDER BY cts DESC`,
|
||||
did, repo,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanLabels(rows)
|
||||
}
|
||||
|
||||
// NegateRepoLabels creates negation labels for all active takedown labels on a (DID, repo) pair.
|
||||
func NegateRepoLabels(db *sql.DB, src, did, repo string) error {
|
||||
rows, err := db.Query(
|
||||
`SELECT uri FROM labels
|
||||
WHERE subject_did = ? AND subject_repo = ? AND val = '!takedown' AND neg = 0`,
|
||||
did, repo,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var uris []string
|
||||
for rows.Next() {
|
||||
var uri string
|
||||
if err := rows.Scan(&uri); err != nil {
|
||||
rows.Close()
|
||||
return err
|
||||
}
|
||||
uris = append(uris, uri)
|
||||
}
|
||||
rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
for _, uri := range uris {
|
||||
if _, err := db.Exec(
|
||||
`INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo)
|
||||
VALUES (?, ?, '!takedown', 1, ?, ?, ?)`,
|
||||
src, uri, now, did, repo,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NegateUserLabels creates negation labels for all active takedown labels on a DID (user-level).
|
||||
func NegateUserLabels(db *sql.DB, src, did string) error {
|
||||
rows, err := db.Query(
|
||||
`SELECT uri, subject_repo FROM labels
|
||||
WHERE subject_did = ? AND val = '!takedown' AND neg = 0`,
|
||||
did,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
type uriRepo struct {
|
||||
uri string
|
||||
repo string
|
||||
}
|
||||
var entries []uriRepo
|
||||
for rows.Next() {
|
||||
var e uriRepo
|
||||
if err := rows.Scan(&e.uri, &e.repo); err != nil {
|
||||
rows.Close()
|
||||
return err
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
for _, e := range entries {
|
||||
if _, err := db.Exec(
|
||||
`INSERT INTO labels (src, uri, val, neg, cts, subject_did, subject_repo)
|
||||
VALUES (?, ?, '!takedown', 1, ?, ?, ?)`,
|
||||
src, e.uri, now, did, e.repo,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanLabels(rows *sql.Rows) ([]Label, error) {
|
||||
var labels []Label
|
||||
for rows.Next() {
|
||||
var l Label
|
||||
var cts string
|
||||
var exp *string
|
||||
if err := rows.Scan(&l.ID, &l.Src, &l.URI, &l.CID, &l.Val, &l.Neg, &cts, &exp, &l.SubjectDID, &l.SubjectRepo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, cts); err == nil {
|
||||
l.Cts = t
|
||||
}
|
||||
if exp != nil {
|
||||
if t, err := time.Parse(time.RFC3339, *exp); err == nil {
|
||||
l.Exp = &t
|
||||
}
|
||||
}
|
||||
labels = append(labels, l)
|
||||
}
|
||||
return labels, rows.Err()
|
||||
}
|
||||
412
pkg/labeler/db_test.go
Normal file
412
pkg/labeler/db_test.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOpenDB(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dbPath := filepath.Join(dir, "subdir", "test.db")
|
||||
|
||||
db, err := OpenDB(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("OpenDB failed: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Verify directory was created
|
||||
if _, err := os.Stat(filepath.Dir(dbPath)); os.IsNotExist(err) {
|
||||
t.Error("expected directory to be created")
|
||||
}
|
||||
|
||||
// Verify tables exist
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM labels").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query labels table: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Errorf("expected 0 labels, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateLabel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := OpenDB(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
label := &Label{
|
||||
Src: "did:web:labeler.atcr.io",
|
||||
URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
|
||||
Val: "!takedown",
|
||||
Cts: now,
|
||||
SubjectDID: "did:plc:abc",
|
||||
SubjectRepo: "myimage",
|
||||
}
|
||||
|
||||
id, err := CreateLabel(db, label)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateLabel failed: %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Errorf("expected positive id, got %d", id)
|
||||
}
|
||||
|
||||
// Verify it was stored
|
||||
labels, err := GetLabelsSince(db, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(labels) != 1 {
|
||||
t.Fatalf("expected 1 label, got %d", len(labels))
|
||||
}
|
||||
if labels[0].Src != "did:web:labeler.atcr.io" {
|
||||
t.Errorf("expected src did:web:labeler.atcr.io, got %s", labels[0].Src)
|
||||
}
|
||||
if labels[0].Val != "!takedown" {
|
||||
t.Errorf("expected val !takedown, got %s", labels[0].Val)
|
||||
}
|
||||
if labels[0].SubjectDID != "did:plc:abc" {
|
||||
t.Errorf("expected subject_did did:plc:abc, got %s", labels[0].SubjectDID)
|
||||
}
|
||||
if labels[0].SubjectRepo != "myimage" {
|
||||
t.Errorf("expected subject_repo myimage, got %s", labels[0].SubjectRepo)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateLabel_Upsert(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := OpenDB(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
now := time.Now().UTC()
|
||||
label := &Label{
|
||||
Src: "did:web:labeler.atcr.io",
|
||||
URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
|
||||
Val: "!takedown",
|
||||
Cts: now,
|
||||
SubjectDID: "did:plc:abc",
|
||||
SubjectRepo: "myimage",
|
||||
}
|
||||
|
||||
// First insert
|
||||
_, err = CreateLabel(db, label)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Same (src, uri, val) - should upsert, not error
|
||||
label.Cts = now.Add(time.Hour)
|
||||
_, err = CreateLabel(db, label)
|
||||
if err != nil {
|
||||
t.Fatalf("upsert should not fail: %v", err)
|
||||
}
|
||||
|
||||
// Should still be 1 label
|
||||
labels, err := GetLabelsSince(db, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(labels) != 1 {
|
||||
t.Errorf("expected 1 label after upsert, got %d", len(labels))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNegateLabel(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := OpenDB(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
src := "did:web:labeler.atcr.io"
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Create a label
|
||||
_, err = CreateLabel(db, &Label{
|
||||
Src: src, URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
|
||||
Val: "!takedown", Cts: now,
|
||||
SubjectDID: "did:plc:abc", SubjectRepo: "myimage",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Negate it
|
||||
err = NegateLabel(db, src, "at://did:plc:abc/io.atcr.manifest/sha256-123", "!takedown", "did:plc:abc", "myimage")
|
||||
if err != nil {
|
||||
t.Fatalf("NegateLabel failed: %v", err)
|
||||
}
|
||||
|
||||
// Should have 2 labels now (original + negation)
|
||||
labels, err := GetLabelsSince(db, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(labels) != 2 {
|
||||
t.Fatalf("expected 2 labels, got %d", len(labels))
|
||||
}
|
||||
|
||||
// The negation label should have neg=true
|
||||
negLabel := labels[1]
|
||||
if !negLabel.Neg {
|
||||
t.Error("expected negation label to have neg=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListActiveTakedowns(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := OpenDB(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
src := "did:web:labeler.atcr.io"
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Create 3 labels
|
||||
for i, repo := range []string{"repo1", "repo2", "repo3"} {
|
||||
_, err = CreateLabel(db, &Label{
|
||||
Src: src, URI: "at://did:plc:abc/io.atcr.repo/" + repo,
|
||||
Val: "!takedown", Cts: now.Add(time.Duration(i) * time.Minute),
|
||||
SubjectDID: "did:plc:abc", SubjectRepo: repo,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// All 3 should be active
|
||||
labels, total, err := ListActiveTakedowns(db, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if total != 3 {
|
||||
t.Errorf("expected 3 active takedowns, got %d", total)
|
||||
}
|
||||
if len(labels) != 3 {
|
||||
t.Errorf("expected 3 labels returned, got %d", len(labels))
|
||||
}
|
||||
|
||||
// Negate one
|
||||
err = NegateLabel(db, src, "at://did:plc:abc/io.atcr.repo/repo2", "!takedown", "did:plc:abc", "repo2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should be 2 active
|
||||
_, total, err = ListActiveTakedowns(db, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("expected 2 active takedowns after negation, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNegateRepoLabels(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := OpenDB(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
src := "did:web:labeler.atcr.io"
|
||||
now := time.Now().UTC()
|
||||
did := "did:plc:abc"
|
||||
|
||||
// Create multiple labels for same repo
|
||||
uris := []string{
|
||||
"at://did:plc:abc/io.atcr.manifest/sha256-111",
|
||||
"at://did:plc:abc/io.atcr.manifest/sha256-222",
|
||||
"at://did:plc:abc/io.atcr.tag/myimage-latest",
|
||||
}
|
||||
for _, uri := range uris {
|
||||
_, err = CreateLabel(db, &Label{
|
||||
Src: src, URI: uri, Val: "!takedown", Cts: now,
|
||||
SubjectDID: did, SubjectRepo: "myimage",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Negate all labels for the repo
|
||||
err = NegateRepoLabels(db, src, did, "myimage")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should have 0 active takedowns
|
||||
_, total, err := ListActiveTakedowns(db, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if total != 0 {
|
||||
t.Errorf("expected 0 active takedowns after repo negation, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNegateUserLabels(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := OpenDB(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
src := "did:web:labeler.atcr.io"
|
||||
now := time.Now().UTC()
|
||||
did := "did:plc:abc"
|
||||
|
||||
// Create labels for different repos + a user-level label
|
||||
_, err = CreateLabel(db, &Label{
|
||||
Src: src, URI: "at://did:plc:abc", Val: "!takedown", Cts: now,
|
||||
SubjectDID: did, SubjectRepo: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = CreateLabel(db, &Label{
|
||||
Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1", Val: "!takedown", Cts: now,
|
||||
SubjectDID: did, SubjectRepo: "repo1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Negate all labels for the user
|
||||
err = NegateUserLabels(db, src, did)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should have 0 active
|
||||
_, total, err := ListActiveTakedowns(db, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if total != 0 {
|
||||
t.Errorf("expected 0 active takedowns after user negation, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLabelsSince(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := OpenDB(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
src := "did:web:labeler.atcr.io"
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Create 5 labels
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = CreateLabel(db, &Label{
|
||||
Src: src, URI: "at://did:plc:abc/io.atcr.manifest/" + string(rune('a'+i)),
|
||||
Val: "!takedown", Cts: now.Add(time.Duration(i) * time.Minute),
|
||||
SubjectDID: "did:plc:abc", SubjectRepo: "repo",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all since 0
|
||||
labels, err := GetLabelsSince(db, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(labels) != 5 {
|
||||
t.Errorf("expected 5 labels, got %d", len(labels))
|
||||
}
|
||||
|
||||
// Get since cursor (skip first 3)
|
||||
if len(labels) >= 3 {
|
||||
cursor := labels[2].ID
|
||||
after, err := GetLabelsSince(db, cursor, 10)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(after) != 2 {
|
||||
t.Errorf("expected 2 labels after cursor %d, got %d", cursor, len(after))
|
||||
}
|
||||
}
|
||||
|
||||
// Get with limit
|
||||
limited, err := GetLabelsSince(db, 0, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(limited) != 2 {
|
||||
t.Errorf("expected 2 labels with limit, got %d", len(limited))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLabelsForRepo(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := OpenDB(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
src := "did:web:labeler.atcr.io"
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Labels for different repos
|
||||
_, _ = CreateLabel(db, &Label{
|
||||
Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo1",
|
||||
Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo1",
|
||||
})
|
||||
_, _ = CreateLabel(db, &Label{
|
||||
Src: src, URI: "at://did:plc:abc/io.atcr.repo/repo2",
|
||||
Val: "!takedown", Cts: now, SubjectDID: "did:plc:abc", SubjectRepo: "repo2",
|
||||
})
|
||||
_, _ = CreateLabel(db, &Label{
|
||||
Src: src, URI: "at://did:plc:def/io.atcr.repo/repo1",
|
||||
Val: "!takedown", Cts: now, SubjectDID: "did:plc:def", SubjectRepo: "repo1",
|
||||
})
|
||||
|
||||
// Get labels for specific did+repo
|
||||
labels, err := GetLabelsForRepo(db, "did:plc:abc", "repo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(labels) != 1 {
|
||||
t.Errorf("expected 1 label for did:plc:abc/repo1, got %d", len(labels))
|
||||
}
|
||||
|
||||
// Different user same repo
|
||||
labels, err = GetLabelsForRepo(db, "did:plc:def", "repo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(labels) != 1 {
|
||||
t.Errorf("expected 1 label for did:plc:def/repo1, got %d", len(labels))
|
||||
}
|
||||
|
||||
// No labels
|
||||
labels, err = GetLabelsForRepo(db, "did:plc:xyz", "repo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(labels) != 0 {
|
||||
t.Errorf("expected 0 labels for unknown did, got %d", len(labels))
|
||||
}
|
||||
}
|
||||
118
pkg/labeler/handlers.go
Normal file
118
pkg/labeler/handlers.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// Auth handlers
|
||||
|
||||
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
// If already logged in, redirect to dashboard
|
||||
if token, ok := getSessionCookie(r); ok {
|
||||
if session := s.auth.getSession(token); session != nil && session.DID == s.config.Labeler.OwnerDID {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
errorMsg := r.URL.Query().Get("error")
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>%s Labeler - Login</title>
|
||||
<style>body{font-family:system-ui;max-width:400px;margin:100px auto;padding:0 20px}
|
||||
.error{color:red;margin-bottom:1em}
|
||||
input{width:100%%;padding:8px;margin:8px 0;box-sizing:border-box}
|
||||
button{padding:10px 20px;cursor:pointer}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>%s Labeler</h1>
|
||||
<p>Sign in with your AT Protocol identity.</p>
|
||||
%s
|
||||
<form action="/auth/oauth/authorize" method="GET">
|
||||
<input name="handle" placeholder="your.handle.com" required>
|
||||
<button type="submit">Sign In</button>
|
||||
</form>
|
||||
</body></html>`,
|
||||
s.config.Server.ClientShortName,
|
||||
s.config.Server.ClientShortName,
|
||||
func() string {
|
||||
if errorMsg != "" {
|
||||
return fmt.Sprintf(`<div class="error">%s</div>`, template.HTMLEscapeString(errorMsg))
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||
handle := strings.TrimSpace(r.URL.Query().Get("handle"))
|
||||
if handle == "" {
|
||||
http.Redirect(w, r, "/auth/login?error=Handle+is+required", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
handle = strings.TrimPrefix(handle, "@")
|
||||
|
||||
did, _, _, err := atproto.ResolveIdentity(r.Context(), handle)
|
||||
if err != nil {
|
||||
slog.Warn("Failed to resolve handle for labeler login", "handle", handle, "error", err)
|
||||
http.Redirect(w, r, "/auth/login?error=Could+not+resolve+handle", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
authURL, err := s.clientApp.StartAuthFlow(r.Context(), did)
|
||||
if err != nil {
|
||||
slog.Error("Failed to start OAuth flow", "error", err)
|
||||
http.Redirect(w, r, "/auth/login?error=OAuth+initialization+failed", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
sessionData, err := s.clientApp.ProcessCallback(r.Context(), r.URL.Query())
|
||||
if err != nil {
|
||||
slog.Error("OAuth callback failed", "error", err)
|
||||
http.Redirect(w, r, "/auth/login?error=OAuth+authentication+failed", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
did := sessionData.AccountDID.String()
|
||||
|
||||
_, handle, _, err := atproto.ResolveIdentity(r.Context(), did)
|
||||
if err != nil {
|
||||
handle = did
|
||||
}
|
||||
|
||||
// Only allow the owner
|
||||
if did != s.config.Labeler.OwnerDID {
|
||||
slog.Warn("Non-owner attempted labeler access", "did", did, "handle", handle, "owner", s.config.Labeler.OwnerDID)
|
||||
http.Redirect(w, r, "/auth/login?error=Access+denied:+Only+the+labeler+owner+can+access+the+admin+panel", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := s.auth.createSession(did, handle)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
setSessionCookie(w, r, token)
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
if token, ok := getSessionCookie(r); ok {
|
||||
s.auth.deleteSession(token)
|
||||
}
|
||||
clearSessionCookie(w)
|
||||
http.Redirect(w, r, "/auth/login", http.StatusFound)
|
||||
}
|
||||
57
pkg/labeler/identity.go
Normal file
57
pkg/labeler/identity.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// DIDDocument represents a did:web DID document.
|
||||
type DIDDocument struct {
|
||||
Context []string `json:"@context"`
|
||||
ID string `json:"id"`
|
||||
Service []DIDService `json:"service,omitempty"`
|
||||
}
|
||||
|
||||
// DIDService represents a service entry in a DID document.
|
||||
type DIDService struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
ServiceEndpoint string `json:"serviceEndpoint"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDIDDocument(w http.ResponseWriter, r *http.Request) {
|
||||
doc := DIDDocument{
|
||||
Context: []string{"https://www.w3.org/ns/did/v1"},
|
||||
ID: s.config.DID(),
|
||||
Service: []DIDService{
|
||||
{
|
||||
ID: "#atproto_labeler",
|
||||
Type: "AtprotoLabeler",
|
||||
ServiceEndpoint: s.config.PublicURL(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(doc)
|
||||
}
|
||||
|
||||
func (s *Server) handleClientMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
publicURL := s.config.PublicURL()
|
||||
metadata := map[string]any{
|
||||
"client_id": publicURL + "/oauth-client-metadata.json",
|
||||
"client_name": fmt.Sprintf("%s Labeler", s.config.Server.ClientShortName),
|
||||
"client_uri": publicURL,
|
||||
"redirect_uris": []string{publicURL + "/auth/oauth/callback"},
|
||||
"scope": "atproto",
|
||||
"grant_types": []string{"authorization_code"},
|
||||
"response_types": []string{"code"},
|
||||
"token_endpoint_auth_method": "none",
|
||||
"application_type": "web",
|
||||
"dpop_bound_access_tokens": true,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(metadata)
|
||||
}
|
||||
156
pkg/labeler/server.go
Normal file
156
pkg/labeler/server.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
indigooauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// Server is the labeler HTTP server.
|
||||
type Server struct {
|
||||
config *Config
|
||||
db *sql.DB
|
||||
router chi.Router
|
||||
clientApp *indigooauth.ClientApp
|
||||
auth *Auth
|
||||
}
|
||||
|
||||
// NewServer creates a new labeler server.
|
||||
func NewServer(cfg *Config) (*Server, error) {
|
||||
db, err := OpenDB(cfg.Labeler.DBPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
publicURL := cfg.PublicURL()
|
||||
|
||||
// Set up OAuth client for admin login
|
||||
oauthStore := indigooauth.NewMemStore()
|
||||
scopes := []string{"atproto"}
|
||||
|
||||
var oauthConfig indigooauth.ClientConfig
|
||||
var redirectURI string
|
||||
|
||||
u, err := url.Parse(publicURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid public URL: %w", err)
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if isLocalhost(host) {
|
||||
port := u.Port()
|
||||
if port == "" {
|
||||
port = "5002"
|
||||
}
|
||||
oauthBaseURL := "http://127.0.0.1:" + port
|
||||
redirectURI = oauthBaseURL + "/auth/oauth/callback"
|
||||
oauthConfig = indigooauth.NewLocalhostConfig(redirectURI, scopes)
|
||||
} else {
|
||||
clientID := publicURL + "/oauth-client-metadata.json"
|
||||
redirectURI = publicURL + "/auth/oauth/callback"
|
||||
oauthConfig = indigooauth.NewPublicConfig(clientID, redirectURI, scopes)
|
||||
}
|
||||
|
||||
clientApp := indigooauth.NewClientApp(&oauthConfig, oauthStore)
|
||||
clientApp.Dir = atproto.GetDirectory()
|
||||
|
||||
auth := NewAuth(cfg.Labeler.OwnerDID)
|
||||
|
||||
s := &Server{
|
||||
config: cfg,
|
||||
db: db,
|
||||
clientApp: clientApp,
|
||||
auth: auth,
|
||||
}
|
||||
|
||||
s.setupRoutes()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
r := chi.NewRouter()
|
||||
|
||||
// DID document
|
||||
r.Get("/.well-known/did.json", s.handleDIDDocument)
|
||||
|
||||
// OAuth client metadata
|
||||
r.Get("/oauth-client-metadata.json", s.handleClientMetadata)
|
||||
|
||||
// Auth routes (public)
|
||||
r.Get("/auth/login", s.handleLogin)
|
||||
r.Get("/auth/oauth/authorize", s.handleAuthorize)
|
||||
r.Get("/auth/oauth/callback", s.handleCallback)
|
||||
r.Get("/auth/logout", s.handleLogout)
|
||||
|
||||
// XRPC endpoints (public)
|
||||
r.Get("/xrpc/com.atproto.label.subscribeLabels", s.handleSubscribeLabels)
|
||||
r.Get("/xrpc/com.atproto.label.queryLabels", s.handleQueryLabels)
|
||||
|
||||
// Protected routes (require owner)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(s.auth.RequireOwner)
|
||||
|
||||
r.Get("/", s.handleDashboard)
|
||||
r.Get("/takedown", s.handleTakedownForm)
|
||||
r.Post("/takedown", s.handleTakedownSubmit)
|
||||
r.Post("/reverse", s.handleReverse)
|
||||
})
|
||||
|
||||
s.router = r
|
||||
}
|
||||
|
||||
// Serve starts the HTTP server with graceful shutdown.
|
||||
func (s *Server) Serve() error {
|
||||
slog.Info("Starting labeler service",
|
||||
"addr", s.config.Labeler.Addr,
|
||||
"public_url", s.config.PublicURL(),
|
||||
"did", s.config.DID(),
|
||||
"owner", s.config.Labeler.OwnerDID,
|
||||
)
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: s.config.Labeler.Addr,
|
||||
Handler: s.router,
|
||||
}
|
||||
|
||||
// Graceful shutdown
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- srv.ListenAndServe()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
slog.Info("Shutting down labeler service")
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5000000000) // 5s
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
return fmt.Errorf("shutdown error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.db.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func isLocalhost(host string) bool {
|
||||
return host == "localhost" || host == "127.0.0.1" || strings.HasPrefix(host, "192.168.")
|
||||
}
|
||||
187
pkg/labeler/subscribe.go
Normal file
187
pkg/labeler/subscribe.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// LabelsMessage is the ATProto subscribeLabels wire format.
|
||||
type LabelsMessage struct {
|
||||
Seq int64 `json:"seq"`
|
||||
Labels []LabelOutput `json:"labels"`
|
||||
}
|
||||
|
||||
// LabelOutput is the ATProto label format for subscribeLabels/queryLabels output.
|
||||
type LabelOutput struct {
|
||||
Src string `json:"src"`
|
||||
URI string `json:"uri"`
|
||||
CID string `json:"cid,omitempty"`
|
||||
Val string `json:"val"`
|
||||
Neg bool `json:"neg"`
|
||||
Cts string `json:"cts"`
|
||||
Exp string `json:"exp,omitempty"`
|
||||
}
|
||||
|
||||
func labelToOutput(l Label) LabelOutput {
|
||||
out := LabelOutput{
|
||||
Src: l.Src,
|
||||
URI: l.URI,
|
||||
CID: l.CID,
|
||||
Val: l.Val,
|
||||
Neg: l.Neg,
|
||||
Cts: l.Cts.UTC().Format(time.RFC3339),
|
||||
}
|
||||
if l.Exp != nil {
|
||||
out.Exp = l.Exp.UTC().Format(time.RFC3339)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// handleSubscribeLabels implements com.atproto.label.subscribeLabels (WebSocket).
|
||||
func (s *Server) handleSubscribeLabels(w http.ResponseWriter, r *http.Request) {
|
||||
cursorStr := r.URL.Query().Get("cursor")
|
||||
var cursor int64
|
||||
if cursorStr != "" {
|
||||
var err error
|
||||
cursor, err = strconv.ParseInt(cursorStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid cursor", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("WebSocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
slog.Info("subscribeLabels client connected", "cursor", cursor)
|
||||
|
||||
// Send historical labels since cursor
|
||||
labels, err := GetLabelsSince(s.db, cursor, 1000)
|
||||
if err != nil {
|
||||
slog.Error("Failed to get labels", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, l := range labels {
|
||||
msg := LabelsMessage{
|
||||
Seq: l.ID,
|
||||
Labels: []LabelOutput{labelToOutput(l)},
|
||||
}
|
||||
if err := conn.WriteJSON(msg); err != nil {
|
||||
return
|
||||
}
|
||||
cursor = l.ID
|
||||
}
|
||||
|
||||
// Poll for new labels
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Read pump (detect client disconnect)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
labels, err := GetLabelsSince(s.db, cursor, 100)
|
||||
if err != nil {
|
||||
slog.Error("Failed to poll labels", "error", err)
|
||||
continue
|
||||
}
|
||||
for _, l := range labels {
|
||||
msg := LabelsMessage{
|
||||
Seq: l.ID,
|
||||
Labels: []LabelOutput{labelToOutput(l)},
|
||||
}
|
||||
if err := conn.WriteJSON(msg); err != nil {
|
||||
return
|
||||
}
|
||||
cursor = l.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleQueryLabels implements com.atproto.label.queryLabels (HTTP GET).
|
||||
func (s *Server) handleQueryLabels(w http.ResponseWriter, r *http.Request) {
|
||||
uriPatterns := r.URL.Query()["uriPatterns"]
|
||||
cursorStr := r.URL.Query().Get("cursor")
|
||||
limitStr := r.URL.Query().Get("limit")
|
||||
|
||||
var cursor int64
|
||||
if cursorStr != "" {
|
||||
cursor, _ = strconv.ParseInt(cursorStr, 10, 64)
|
||||
}
|
||||
limit := 50
|
||||
if limitStr != "" {
|
||||
if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 250 {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
|
||||
labels, err := GetLabelsSince(s.db, cursor, limit)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to query labels", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Filter by URI patterns if provided
|
||||
var filtered []LabelOutput
|
||||
for _, l := range labels {
|
||||
if len(uriPatterns) == 0 || matchesAnyPattern(l.URI, uriPatterns) {
|
||||
filtered = append(filtered, labelToOutput(l))
|
||||
}
|
||||
}
|
||||
|
||||
var nextCursor string
|
||||
if len(labels) > 0 {
|
||||
nextCursor = strconv.FormatInt(labels[len(labels)-1].ID, 10)
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
Cursor string `json:"cursor,omitempty"`
|
||||
Labels []LabelOutput `json:"labels"`
|
||||
}{
|
||||
Cursor: nextCursor,
|
||||
Labels: filtered,
|
||||
}
|
||||
if resp.Labels == nil {
|
||||
resp.Labels = []LabelOutput{}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func matchesAnyPattern(uri string, patterns []string) bool {
|
||||
for _, p := range patterns {
|
||||
// Simple prefix matching (ATProto spec allows glob-like patterns)
|
||||
if p == uri || (len(p) > 0 && p[len(p)-1] == '*' && len(uri) >= len(p)-1 && uri[:len(p)-1] == p[:len(p)-1]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
86
pkg/labeler/subscribe_test.go
Normal file
86
pkg/labeler/subscribe_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLabelToOutput(t *testing.T) {
|
||||
now := time.Date(2026, 3, 22, 10, 0, 0, 0, time.UTC)
|
||||
exp := time.Date(2026, 4, 22, 10, 0, 0, 0, time.UTC)
|
||||
|
||||
label := Label{
|
||||
ID: 1,
|
||||
Src: "did:web:labeler.atcr.io",
|
||||
URI: "at://did:plc:abc/io.atcr.manifest/sha256-123",
|
||||
CID: "bafyabc",
|
||||
Val: "!takedown",
|
||||
Neg: false,
|
||||
Cts: now,
|
||||
Exp: &exp,
|
||||
SubjectDID: "did:plc:abc",
|
||||
SubjectRepo: "myimage",
|
||||
}
|
||||
|
||||
out := labelToOutput(label)
|
||||
if out.Src != "did:web:labeler.atcr.io" {
|
||||
t.Errorf("Src = %q, want did:web:labeler.atcr.io", out.Src)
|
||||
}
|
||||
if out.URI != "at://did:plc:abc/io.atcr.manifest/sha256-123" {
|
||||
t.Errorf("URI = %q", out.URI)
|
||||
}
|
||||
if out.CID != "bafyabc" {
|
||||
t.Errorf("CID = %q, want bafyabc", out.CID)
|
||||
}
|
||||
if out.Val != "!takedown" {
|
||||
t.Errorf("Val = %q", out.Val)
|
||||
}
|
||||
if out.Neg {
|
||||
t.Error("expected Neg=false")
|
||||
}
|
||||
if out.Cts != "2026-03-22T10:00:00Z" {
|
||||
t.Errorf("Cts = %q", out.Cts)
|
||||
}
|
||||
if out.Exp != "2026-04-22T10:00:00Z" {
|
||||
t.Errorf("Exp = %q", out.Exp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLabelToOutput_NoExpiration(t *testing.T) {
|
||||
label := Label{
|
||||
Src: "did:web:labeler.atcr.io",
|
||||
URI: "at://did:plc:abc",
|
||||
Val: "!takedown",
|
||||
Cts: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
out := labelToOutput(label)
|
||||
if out.Exp != "" {
|
||||
t.Errorf("expected empty Exp, got %q", out.Exp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesAnyPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
patterns []string
|
||||
want bool
|
||||
}{
|
||||
{"exact match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:abc/io.atcr.manifest/sha256-123"}, true},
|
||||
{"no match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:def/io.atcr.manifest/sha256-123"}, false},
|
||||
{"wildcard match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:abc/*"}, true},
|
||||
{"wildcard no match", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:def/*"}, false},
|
||||
{"empty patterns", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{}, false},
|
||||
{"multiple patterns", "at://did:plc:abc/io.atcr.manifest/sha256-123", []string{"at://did:plc:def/*", "at://did:plc:abc/*"}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := matchesAnyPattern(tt.uri, tt.patterns)
|
||||
if got != tt.want {
|
||||
t.Errorf("matchesAnyPattern(%q, %v) = %v, want %v", tt.uri, tt.patterns, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
418
pkg/labeler/takedown.go
Normal file
418
pkg/labeler/takedown.go
Normal file
@@ -0,0 +1,418 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"atcr.io/pkg/atproto"
|
||||
)
|
||||
|
||||
// TakedownInput represents parsed takedown input.
|
||||
type TakedownInput struct {
|
||||
DID string
|
||||
Handle string
|
||||
Repository string // empty = user-level takedown
|
||||
}
|
||||
|
||||
// ParseTakedownInput parses various input formats into a TakedownInput.
|
||||
// Supported formats:
|
||||
// - atcr.io/r/handle/repo
|
||||
// - handle/repo
|
||||
// - at://did:plc:xyz/io.atcr.repo.page/repo
|
||||
// - at://did:plc:xyz (user-level)
|
||||
// - handle (user-level)
|
||||
// - did:plc:xyz (user-level)
|
||||
func ParseTakedownInput(ctx context.Context, input string) (*TakedownInput, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
// AT URI format
|
||||
if strings.HasPrefix(input, "at://") {
|
||||
return parseATURI(ctx, input)
|
||||
}
|
||||
|
||||
// Strip URL prefix if present
|
||||
input = strings.TrimPrefix(input, "https://")
|
||||
input = strings.TrimPrefix(input, "http://")
|
||||
|
||||
// Remove atcr.io/r/ or similar prefix
|
||||
for _, prefix := range []string{"atcr.io/r/", "localhost/r/"} {
|
||||
if strings.HasPrefix(input, prefix) {
|
||||
input = strings.TrimPrefix(input, prefix)
|
||||
break
|
||||
}
|
||||
}
|
||||
// Also handle custom domains: anything ending in /r/
|
||||
if idx := strings.Index(input, "/r/"); idx >= 0 {
|
||||
input = input[idx+3:]
|
||||
}
|
||||
|
||||
// Now input should be "handle/repo" or "handle" or "did:xxx"
|
||||
parts := strings.SplitN(input, "/", 2)
|
||||
identifier := parts[0]
|
||||
var repo string
|
||||
if len(parts) > 1 {
|
||||
repo = parts[1]
|
||||
repo = strings.TrimSuffix(repo, "/")
|
||||
}
|
||||
|
||||
did, handle, err := resolveIdentifier(ctx, identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TakedownInput{
|
||||
DID: did,
|
||||
Handle: handle,
|
||||
Repository: repo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseATURI(ctx context.Context, uri string) (*TakedownInput, error) {
|
||||
// at://did:plc:xyz/collection/rkey
|
||||
trimmed := strings.TrimPrefix(uri, "at://")
|
||||
parts := strings.SplitN(trimmed, "/", 3)
|
||||
|
||||
did := parts[0]
|
||||
if !strings.HasPrefix(did, "did:") {
|
||||
// It's a handle
|
||||
resolvedDID, handle, err := resolveIdentifier(ctx, did)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
did = resolvedDID
|
||||
if len(parts) >= 3 {
|
||||
return &TakedownInput{DID: did, Handle: handle, Repository: parts[2]}, nil
|
||||
}
|
||||
return &TakedownInput{DID: did, Handle: handle}, nil
|
||||
}
|
||||
|
||||
// Resolve handle from DID
|
||||
_, handle, _, _ := atproto.ResolveIdentity(ctx, did)
|
||||
|
||||
if len(parts) < 3 {
|
||||
// User-level takedown
|
||||
return &TakedownInput{DID: did, Handle: handle}, nil
|
||||
}
|
||||
|
||||
// Extract repository from rkey (third part)
|
||||
repo := parts[2]
|
||||
return &TakedownInput{DID: did, Handle: handle, Repository: repo}, nil
|
||||
}
|
||||
|
||||
func resolveIdentifier(ctx context.Context, identifier string) (did, handle string, err error) {
|
||||
did, handle, _, err = atproto.ResolveIdentity(ctx, identifier)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to resolve %q: %w", identifier, err)
|
||||
}
|
||||
return did, handle, nil
|
||||
}
|
||||
|
||||
// TakedownResult contains the results of a takedown operation.
|
||||
type TakedownResult struct {
|
||||
DID string
|
||||
Handle string
|
||||
Repository string
|
||||
Labels []Label
|
||||
UserLevel bool
|
||||
}
|
||||
|
||||
// ExecuteTakedown creates takedown labels for a repo or user.
|
||||
func (s *Server) ExecuteTakedown(ctx context.Context, input *TakedownInput) (*TakedownResult, error) {
|
||||
src := s.config.DID()
|
||||
now := time.Now().UTC()
|
||||
result := &TakedownResult{
|
||||
DID: input.DID,
|
||||
Handle: input.Handle,
|
||||
Repository: input.Repository,
|
||||
UserLevel: input.Repository == "",
|
||||
}
|
||||
|
||||
if input.Repository == "" {
|
||||
// User-level takedown
|
||||
label := &Label{
|
||||
Src: src,
|
||||
URI: "at://" + input.DID,
|
||||
Val: "!takedown",
|
||||
Cts: now,
|
||||
SubjectDID: input.DID,
|
||||
SubjectRepo: "",
|
||||
}
|
||||
if _, err := CreateLabel(s.db, label); err != nil {
|
||||
return nil, fmt.Errorf("failed to create user-level label: %w", err)
|
||||
}
|
||||
result.Labels = append(result.Labels, *label)
|
||||
slog.Info("Created user-level takedown", "did", input.DID, "handle", input.Handle)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Repo-level takedown: discover all records from PDS
|
||||
labels, err := s.discoverAndLabelRecords(ctx, input.DID, input.Repository, src, now)
|
||||
if err != nil {
|
||||
// Even if PDS discovery fails, create a repo-level summary label
|
||||
slog.Warn("PDS discovery failed, creating summary label only", "error", err)
|
||||
}
|
||||
result.Labels = append(result.Labels, labels...)
|
||||
|
||||
// Always create a repo-level summary label for efficient filtering
|
||||
summaryLabel := &Label{
|
||||
Src: src,
|
||||
URI: fmt.Sprintf("at://%s/io.atcr.repo/%s", input.DID, input.Repository),
|
||||
Val: "!takedown",
|
||||
Cts: now,
|
||||
SubjectDID: input.DID,
|
||||
SubjectRepo: input.Repository,
|
||||
}
|
||||
if _, err := CreateLabel(s.db, summaryLabel); err != nil {
|
||||
return nil, fmt.Errorf("failed to create summary label: %w", err)
|
||||
}
|
||||
result.Labels = append(result.Labels, *summaryLabel)
|
||||
|
||||
slog.Info("Created repo-level takedown",
|
||||
"did", input.DID,
|
||||
"handle", input.Handle,
|
||||
"repository", input.Repository,
|
||||
"label_count", len(result.Labels),
|
||||
)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// discoverAndLabelRecords queries the user's PDS for all records in the given repo
|
||||
// and creates takedown labels for each.
|
||||
func (s *Server) discoverAndLabelRecords(ctx context.Context, did, repo, src string, now time.Time) ([]Label, error) {
|
||||
_, _, pdsEndpoint, err := atproto.ResolveIdentity(ctx, did)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve DID: %w", err)
|
||||
}
|
||||
|
||||
client := atproto.NewClient(pdsEndpoint, did, "")
|
||||
var labels []Label
|
||||
|
||||
// Collections to search
|
||||
collections := []string{
|
||||
atproto.ManifestCollection,
|
||||
atproto.TagCollection,
|
||||
atproto.RepoPageCollection,
|
||||
}
|
||||
|
||||
for _, collection := range collections {
|
||||
records, _, err := client.ListRecordsForRepo(ctx, did, collection, 100, "")
|
||||
if err != nil {
|
||||
slog.Warn("Failed to list records", "collection", collection, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rec := range records {
|
||||
// Filter by repository field
|
||||
recRepo := extractRepoField(rec.Value, collection)
|
||||
if recRepo != repo {
|
||||
continue
|
||||
}
|
||||
|
||||
// Use the full AT URI from the record (at://did/collection/rkey)
|
||||
uri := rec.URI
|
||||
label := &Label{
|
||||
Src: src,
|
||||
URI: uri,
|
||||
Val: "!takedown",
|
||||
Cts: now,
|
||||
SubjectDID: did,
|
||||
SubjectRepo: repo,
|
||||
}
|
||||
if _, err := CreateLabel(s.db, label); err != nil {
|
||||
slog.Warn("Failed to create label", "uri", uri, "error", err)
|
||||
continue
|
||||
}
|
||||
labels = append(labels, *label)
|
||||
}
|
||||
}
|
||||
|
||||
return labels, nil
|
||||
}
|
||||
|
||||
// extractRepoField extracts the repository name from a record's JSON value.
|
||||
func extractRepoField(value json.RawMessage, collection string) string {
|
||||
// For repo pages, the rkey IS the repository name, but we also check the value
|
||||
var rec struct {
|
||||
Repository string `json:"repository"`
|
||||
}
|
||||
if err := json.Unmarshal(value, &rec); err == nil && rec.Repository != "" {
|
||||
return rec.Repository
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handlers
|
||||
|
||||
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
labels, total, err := ListActiveTakedowns(s.db, 50, 0)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to list takedowns", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>%s Labeler</title>
|
||||
<style>
|
||||
body{font-family:system-ui;max-width:900px;margin:40px auto;padding:0 20px}
|
||||
table{width:100%%;border-collapse:collapse;margin:20px 0}
|
||||
th,td{text-align:left;padding:8px;border-bottom:1px solid #ddd}
|
||||
th{background:#f5f5f5}
|
||||
.badge{background:#dc2626;color:white;padding:2px 8px;border-radius:4px;font-size:0.85em}
|
||||
a{color:#2563eb}
|
||||
nav{display:flex;gap:16px;margin-bottom:24px}
|
||||
.btn{padding:8px 16px;background:#2563eb;color:white;text-decoration:none;border-radius:4px;border:none;cursor:pointer}
|
||||
.btn-danger{background:#dc2626}
|
||||
form{display:inline}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>%s Labeler</h1>
|
||||
<nav>
|
||||
<a href="/" class="btn">Dashboard</a>
|
||||
<a href="/takedown" class="btn">New Takedown</a>
|
||||
<a href="/auth/logout">Logout</a>
|
||||
</nav>
|
||||
<h2>Active Takedowns (%d)</h2>`,
|
||||
s.config.Server.ClientShortName,
|
||||
s.config.Server.ClientShortName,
|
||||
total,
|
||||
)
|
||||
|
||||
if len(labels) == 0 {
|
||||
fmt.Fprint(w, `<p>No active takedowns.</p>`)
|
||||
} else {
|
||||
fmt.Fprint(w, `<table><tr><th>Subject</th><th>Repository</th><th>URI</th><th>Created</th><th>Action</th></tr>`)
|
||||
for _, l := range labels {
|
||||
repoDisplay := l.SubjectRepo
|
||||
if repoDisplay == "" {
|
||||
repoDisplay = "<em>all repos (user-level)</em>"
|
||||
}
|
||||
fmt.Fprintf(w, `<tr>
|
||||
<td>%s</td>
|
||||
<td>%s</td>
|
||||
<td><code>%s</code></td>
|
||||
<td>%s</td>
|
||||
<td><form method="POST" action="/reverse"><input type="hidden" name="did" value="%s"><input type="hidden" name="repo" value="%s"><button type="submit" class="btn btn-danger" onclick="return confirm('Reverse this takedown?')">Reverse</button></form></td>
|
||||
</tr>`,
|
||||
template.HTMLEscapeString(l.SubjectDID),
|
||||
repoDisplay,
|
||||
template.HTMLEscapeString(l.URI),
|
||||
l.Cts.Format("2006-01-02 15:04"),
|
||||
template.HTMLEscapeString(l.SubjectDID),
|
||||
template.HTMLEscapeString(l.SubjectRepo),
|
||||
)
|
||||
}
|
||||
fmt.Fprint(w, `</table>`)
|
||||
}
|
||||
|
||||
fmt.Fprint(w, `</body></html>`)
|
||||
}
|
||||
|
||||
func (s *Server) handleTakedownForm(w http.ResponseWriter, r *http.Request) {
|
||||
msg := r.URL.Query().Get("msg")
|
||||
errorMsg := r.URL.Query().Get("error")
|
||||
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>%s Labeler - New Takedown</title>
|
||||
<style>
|
||||
body{font-family:system-ui;max-width:600px;margin:40px auto;padding:0 20px}
|
||||
input[type=text]{width:100%%;padding:8px;margin:8px 0;box-sizing:border-box}
|
||||
.btn{padding:10px 20px;background:#dc2626;color:white;border:none;border-radius:4px;cursor:pointer;font-size:1em}
|
||||
.success{color:green;margin-bottom:1em}
|
||||
.error{color:red;margin-bottom:1em}
|
||||
a{color:#2563eb}
|
||||
nav{display:flex;gap:16px;margin-bottom:24px}
|
||||
.nav-btn{padding:8px 16px;background:#2563eb;color:white;text-decoration:none;border-radius:4px}
|
||||
.help{color:#666;font-size:0.9em;margin-top:4px}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>New Takedown</h1>
|
||||
<nav>
|
||||
<a href="/" class="nav-btn">Dashboard</a>
|
||||
<a href="/takedown" class="nav-btn">New Takedown</a>
|
||||
</nav>`,
|
||||
s.config.Server.ClientShortName,
|
||||
)
|
||||
|
||||
if msg != "" {
|
||||
fmt.Fprintf(w, `<div class="success">%s</div>`, template.HTMLEscapeString(msg))
|
||||
}
|
||||
if errorMsg != "" {
|
||||
fmt.Fprintf(w, `<div class="error">%s</div>`, template.HTMLEscapeString(errorMsg))
|
||||
}
|
||||
|
||||
fmt.Fprint(w, `
|
||||
<form method="POST" action="/takedown">
|
||||
<label for="target"><strong>Target</strong></label>
|
||||
<input type="text" id="target" name="target" placeholder="atcr.io/r/handle/repo, at://did/collection/rkey, or handle" required>
|
||||
<p class="help">Accepts repo URLs, AT URIs, handles, or DIDs. Omit the repo for a user-level takedown.</p>
|
||||
<br>
|
||||
<button type="submit" class="btn" onclick="return confirm('Issue takedown? This will suppress the content immediately.')">Issue Takedown</button>
|
||||
</form>
|
||||
</body></html>`)
|
||||
}
|
||||
|
||||
func (s *Server) handleTakedownSubmit(w http.ResponseWriter, r *http.Request) {
|
||||
target := strings.TrimSpace(r.FormValue("target"))
|
||||
if target == "" {
|
||||
http.Redirect(w, r, "/takedown?error=Target+is+required", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
input, err := ParseTakedownInput(r.Context(), target)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/takedown?error="+strings.ReplaceAll(err.Error(), " ", "+"), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := s.ExecuteTakedown(r.Context(), input)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/takedown?error="+strings.ReplaceAll(err.Error(), " ", "+"), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Takedown issued: %d labels created for %s", len(result.Labels), result.DID)
|
||||
if result.Repository != "" {
|
||||
msg += "/" + result.Repository
|
||||
}
|
||||
http.Redirect(w, r, "/takedown?msg="+strings.ReplaceAll(msg, " ", "+"), http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *Server) handleReverse(w http.ResponseWriter, r *http.Request) {
|
||||
did := strings.TrimSpace(r.FormValue("did"))
|
||||
repo := strings.TrimSpace(r.FormValue("repo"))
|
||||
|
||||
if did == "" {
|
||||
http.Redirect(w, r, "/?error=DID+is+required", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
src := s.config.DID()
|
||||
var err error
|
||||
if repo == "" {
|
||||
err = NegateUserLabels(s.db, src, did)
|
||||
} else {
|
||||
err = NegateRepoLabels(s.db, src, did, repo)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
slog.Error("Failed to reverse takedown", "did", did, "repo", repo, "error", err)
|
||||
http.Redirect(w, r, "/?error=Failed+to+reverse+takedown", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("Reversed takedown", "did", did, "repo", repo)
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
139
pkg/labeler/takedown_test.go
Normal file
139
pkg/labeler/takedown_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package labeler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseTakedownInput_RepoURL(t *testing.T) {
|
||||
// These tests only exercise parsing logic, not PDS resolution.
|
||||
// ResolveIdentity calls are tested with mock server below.
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantRepo string
|
||||
}{
|
||||
{"full URL", "https://atcr.io/r/handle/myimage", "myimage"},
|
||||
{"no scheme", "atcr.io/r/handle/myimage", "myimage"},
|
||||
{"handle/repo", "handle/myimage", "myimage"},
|
||||
{"trailing slash", "atcr.io/r/handle/myimage/", "myimage"},
|
||||
{"custom domain", "https://registry.example.com/r/handle/myimage", "myimage"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// These will fail on ResolveIdentity since there's no real PDS,
|
||||
// but we can at least verify the parsing doesn't panic
|
||||
_, err := ParseTakedownInput(context.Background(), tt.input)
|
||||
if err == nil {
|
||||
t.Skip("ResolveIdentity succeeded unexpectedly (network available)")
|
||||
}
|
||||
// The error should be from resolution, not parsing
|
||||
if err != nil {
|
||||
t.Logf("Expected resolution error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTakedownInput_ATURI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantDID string
|
||||
wantRepo string
|
||||
}{
|
||||
{
|
||||
"full AT URI with collection",
|
||||
"at://did:plc:abc123/io.atcr.repo.page/myimage",
|
||||
"did:plc:abc123",
|
||||
"myimage",
|
||||
},
|
||||
{
|
||||
"DID only (user-level)",
|
||||
"at://did:plc:abc123",
|
||||
"did:plc:abc123",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"DID with collection and rkey",
|
||||
"at://did:plc:xyz/io.atcr.manifest/sha256-deadbeef",
|
||||
"did:plc:xyz",
|
||||
"sha256-deadbeef",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
input, err := ParseTakedownInput(context.Background(), tt.input)
|
||||
if err != nil {
|
||||
// Resolution may fail for handle-based AT URIs
|
||||
t.Logf("Parse error (may be expected): %v", err)
|
||||
return
|
||||
}
|
||||
if input.DID != tt.wantDID {
|
||||
t.Errorf("DID = %q, want %q", input.DID, tt.wantDID)
|
||||
}
|
||||
if input.Repository != tt.wantRepo {
|
||||
t.Errorf("Repository = %q, want %q", input.Repository, tt.wantRepo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTakedownInput_DID(t *testing.T) {
|
||||
// Direct DID input (user-level takedown)
|
||||
input, err := ParseTakedownInput(context.Background(), "at://did:plc:abc123")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if input.DID != "did:plc:abc123" {
|
||||
t.Errorf("DID = %q, want did:plc:abc123", input.DID)
|
||||
}
|
||||
if input.Repository != "" {
|
||||
t.Errorf("Repository = %q, want empty (user-level)", input.Repository)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRepoField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
collection string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"manifest with repository",
|
||||
`{"$type":"io.atcr.manifest","repository":"myimage","digest":"sha256:abc"}`,
|
||||
"io.atcr.manifest",
|
||||
"myimage",
|
||||
},
|
||||
{
|
||||
"tag with repository",
|
||||
`{"$type":"io.atcr.tag","repository":"myimage","tag":"latest"}`,
|
||||
"io.atcr.tag",
|
||||
"myimage",
|
||||
},
|
||||
{
|
||||
"no repository field",
|
||||
`{"$type":"io.atcr.manifest","digest":"sha256:abc"}`,
|
||||
"io.atcr.manifest",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid JSON",
|
||||
`{invalid}`,
|
||||
"io.atcr.manifest",
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractRepoField([]byte(tt.value), tt.collection)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractRepoField() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user