Files
at-container-registry/pkg/billing/billing.go

777 lines
21 KiB
Go

//go:build billing
package billing
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strings"
"sync"
"time"
"atcr.io/pkg/appview/holdclient"
"github.com/bluesky-social/indigo/atproto/atcrypto"
"github.com/stripe/stripe-go/v84"
portalsession "github.com/stripe/stripe-go/v84/billingportal/session"
"github.com/stripe/stripe-go/v84/checkout/session"
"github.com/stripe/stripe-go/v84/customer"
"github.com/stripe/stripe-go/v84/price"
"github.com/stripe/stripe-go/v84/subscription"
"github.com/stripe/stripe-go/v84/webhook"
)
// Manager handles Stripe billing and pushes tier updates to managed holds.
//
// It resolves a user's subscription tier from Stripe, gates features
// (webhook limits, AI advisor, supporter badge) on that tier, and keeps three
// in-memory TTL caches, each guarded by its own RWMutex.
type Manager struct {
	// cfg holds tier definitions, Stripe price IDs and checkout URLs.
	// Enabled treats a nil cfg as "billing disabled".
	cfg *Config
	// privateKey and appviewDID are handed to
	// holdclient.UpdateCrewTierOnAllHolds when a tier change is pushed;
	// presumably they authenticate the appview to the holds — confirm in holdclient.
	privateKey *atcrypto.PrivateKeyP256
	appviewDID string
	// managedHolds lists the hold DIDs that receive tier updates and whose
	// tier definitions RefreshHoldTiers polls.
	managedHolds []string
	// baseURL replaces the "{base_url}" placeholder in checkout success/cancel URLs.
	baseURL string
	// stripeKey is the resolved Stripe secret key; empty means billing is disabled.
	stripeKey string
	// webhookSecret verifies the Stripe-Signature header in HandleWebhook.
	webhookSecret string
	// Captain checker: bypasses billing for hold owners
	captainChecker CaptainChecker
	// Customer cache: DID → Stripe customer
	customerCache   map[string]*cachedCustomer
	customerCacheMu sync.RWMutex
	// Price cache: Stripe price ID → unit amount in cents
	priceCache   map[string]*cachedPrice
	priceCacheMu sync.RWMutex
	// Hold tier cache: holdDID → tier list
	holdTierCache   map[string]*cachedHoldTiers
	holdTierCacheMu sync.RWMutex
}
// Cache entry types. Each one pairs a cached value with the instant after
// which it must be treated as stale.
type (
	// cachedHoldTiers is the tier list most recently fetched from one hold.
	cachedHoldTiers struct {
		tiers     []holdclient.HoldTierInfo
		expiresAt time.Time
	}

	// cachedCustomer is the Stripe customer looked up or created for a DID.
	cachedCustomer struct {
		customer  *stripe.Customer
		expiresAt time.Time
	}

	// cachedPrice is a Stripe price's unit amount, in cents.
	cachedPrice struct {
		unitAmount int64
		expiresAt  time.Time
	}
)
// Lifetimes of the in-memory Stripe lookup caches.
const (
	// customerCacheTTL bounds how long a DID → customer mapping is reused.
	customerCacheTTL = 10 * time.Minute
	// priceCacheTTL bounds how long a price's unit amount is reused.
	priceCacheTTL = 1 * time.Hour
)
// New creates a billing manager backed by Stripe.
//
// The Stripe secret key and webhook signing secret come from the
// STRIPE_SECRET_KEY and STRIPE_WEBHOOK_SECRET environment variables. When a
// variable is unset or empty, New falls back to cfg.StripeSecretKey or
// cfg.WebhookSecret.
//
// A nil cfg is tolerated: the manager is still created, but Enabled reports
// false. When a secret key is resolved, it is also installed as the
// process-wide stripe.Key, which the package-level stripe-go calls in this
// package use.
func New(cfg *Config, privateKey *atcrypto.PrivateKeyP256, appviewDID string, managedHolds []string, baseURL string) *Manager {
	stripeKey := os.Getenv("STRIPE_SECRET_KEY")
	webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")

	// Enabled already treats a nil cfg as "billing off". Guard here as well,
	// so that constructing a disabled manager cannot panic.
	if cfg != nil {
		if stripeKey == "" {
			stripeKey = cfg.StripeSecretKey
		}
		if webhookSecret == "" {
			webhookSecret = cfg.WebhookSecret
		}
	}

	if stripeKey != "" {
		stripe.Key = stripeKey
	}

	return &Manager{
		cfg:           cfg,
		privateKey:    privateKey,
		appviewDID:    appviewDID,
		managedHolds:  managedHolds,
		baseURL:       baseURL,
		stripeKey:     stripeKey,
		webhookSecret: webhookSecret,
		customerCache: make(map[string]*cachedCustomer),
		priceCache:    make(map[string]*cachedPrice),
		holdTierCache: make(map[string]*cachedHoldTiers),
	}
}
// SetCaptainChecker installs the callback used to recognise hold captains.
// Users for whom fn reports true skip every billing feature gate.
func (m *Manager) SetCaptainChecker(fn CaptainChecker) {
	m.captainChecker = fn
}

// isCaptain reports whether userDID belongs to a hold captain. It returns
// false when no checker is installed or the DID is empty, without calling
// the checker.
func (m *Manager) isCaptain(userDID string) bool {
	if m.captainChecker == nil || userDID == "" {
		return false
	}
	return m.captainChecker(userDID)
}
// Enabled reports whether billing is usable. That requires a config, a
// resolved Stripe secret key, and at least one configured tier.
func (m *Manager) Enabled() bool {
	switch {
	case m.cfg == nil, m.stripeKey == "":
		return false
	default:
		return len(m.cfg.Tiers) > 0
	}
}
// GetWebhookLimits reports the webhook allowance for userDID as
// (maxWebhooks, allTriggers).
//
// Hold captains get (-1, true), meaning unlimited webhooks with every
// trigger. With billing disabled the answer is (1, false). Otherwise the
// limits of the user's subscription tier are returned. If the tier cannot be
// resolved (lookup error or out-of-range rank), the lowest tier's limits
// apply.
func (m *Manager) GetWebhookLimits(userDID string) (int, bool) {
	if m.isCaptain(userDID) {
		return -1, true // unlimited
	}
	if !m.Enabled() {
		return 1, false
	}

	idx := 0 // lowest tier unless the subscription says otherwise
	if info, err := m.GetSubscriptionInfo(userDID); err == nil && info != nil {
		if r := info.TierRank; r >= 0 && r < len(m.cfg.Tiers) {
			idx = r
		}
	}
	return m.cfg.Tiers[idx].MaxWebhooks, m.cfg.Tiers[idx].WebhookAllTriggers
}
// HasAIAdvisor reports whether userDID may use the AI Image Advisor.
//
// Hold captains always may. With billing disabled nobody may. Otherwise the
// user's subscription tier decides. If the tier cannot be resolved (lookup
// error or out-of-range rank), the lowest tier's setting applies.
func (m *Manager) HasAIAdvisor(userDID string) bool {
	if m.isCaptain(userDID) {
		return true
	}
	if !m.Enabled() {
		return false
	}

	idx := 0 // lowest tier unless the subscription says otherwise
	if info, err := m.GetSubscriptionInfo(userDID); err == nil && info != nil {
		if r := info.TierRank; r >= 0 && r < len(m.cfg.Tiers) {
			idx = r
		}
	}
	return m.cfg.Tiers[idx].AIAdvisor
}
// GetSupporterBadge returns the badge label to show for userDID.
//
// Hold captains get "Captain". For everyone else the label is the name of
// the user's current tier, provided that tier has supporter badges enabled.
// In all other cases the result is empty: billing disabled, subscription
// lookup failed, or the tier has no badge.
func (m *Manager) GetSupporterBadge(userDID string) string {
	if m.isCaptain(userDID) {
		return "Captain"
	}
	if !m.Enabled() {
		return ""
	}

	info, err := m.GetSubscriptionInfo(userDID)
	if err != nil || info == nil {
		return ""
	}

	for i := range info.Tiers {
		if info.Tiers[i].ID != info.CurrentTier {
			continue
		}
		if info.Tiers[i].SupporterBadge {
			return info.CurrentTier
		}
	}
	return ""
}
// GetSubscriptionInfo returns the tier catalogue together with the user's
// current subscription state.
//
// Hold captains get a synthetic "Captain" tier (rank -1, above every
// configured tier) with all features unlocked. No Stripe calls are made for
// them. When billing is disabled, ErrBillingDisabled is returned.
//
// Otherwise the result lists every configured tier. Each tier's features
// come, in order, from the managed holds, the webhook/AI-advisor/badge
// settings, and static config. Each tier also carries live Stripe prices; a
// price that cannot be fetched is left at 0.
//
// The user defaults to the lowest tier. If a Stripe customer exists for
// userDID and has an active subscription whose price maps to a configured
// tier, that tier is reported instead. Stripe lookup failures are logged and
// degrade to the default tier; they are not returned as errors.
func (m *Manager) GetSubscriptionInfo(userDID string) (*SubscriptionInfo, error) {
	if m.isCaptain(userDID) {
		return &SubscriptionInfo{
			UserDID:     userDID,
			CurrentTier: "Captain",
			TierRank:    -1, // above all configured tiers
			Tiers: []TierInfo{{
				ID:          "Captain",
				Name:        "Captain",
				Description: "Hold operator",
				// Matches the captain bypasses in GetWebhookLimits and HasAIAdvisor.
				Features:           []string{"Unlimited storage", "Unlimited webhooks", "All webhook triggers", "Scan on push", "AI Image Advisor"},
				Rank:               -1,
				MaxWebhooks:        -1,
				WebhookAllTriggers: true,
				SupporterBadge:     true,
				IsCurrent:          true,
			}},
		}, nil
	}
	if !m.Enabled() {
		return nil, ErrBillingDisabled
	}

	info := &SubscriptionInfo{
		UserDID:         userDID,
		PaymentsEnabled: true,
		CurrentTier:     m.cfg.Tiers[0].Name, // default to lowest
		TierRank:        0,
	}

	// Build the tier list with live Stripe prices.
	info.Tiers = make([]TierInfo, len(m.cfg.Tiers))
	for i, tier := range m.cfg.Tiers {
		// Dynamic features: hold-derived first, then webhook limits, AI
		// advisor and supporter badge, then static config.
		features := m.aggregateHoldFeatures(i)
		features = append(features, webhookFeatures(tier.MaxWebhooks, tier.WebhookAllTriggers)...)
		features = append(features, aiAdvisorFeatures(tier.AIAdvisor)...)
		if tier.SupporterBadge {
			features = append(features, "Supporter badge")
		}
		features = append(features, tier.Features...)

		info.Tiers[i] = TierInfo{
			ID:                 tier.Name,
			Name:               tier.Name,
			Description:        tier.Description,
			Features:           features,
			Rank:               i,
			MaxWebhooks:        tier.MaxWebhooks,
			WebhookAllTriggers: tier.WebhookAllTriggers,
			SupporterBadge:     tier.SupporterBadge,
		}
		if tier.StripePriceMonthly != "" {
			if amount, err := m.fetchPrice(tier.StripePriceMonthly); err == nil {
				info.Tiers[i].PriceCentsMonthly = int(amount)
			}
		}
		if tier.StripePriceYearly != "" {
			if amount, err := m.fetchPrice(tier.StripePriceYearly); err == nil {
				info.Tiers[i].PriceCentsYearly = int(amount)
			}
		}
	}

	if userDID == "" {
		return info, nil
	}

	// Resolve the Stripe customer through the in-memory cache first. This
	// method backs every feature gate (webhooks, AI advisor, badge), so an
	// uncached lookup would hit the Stripe Search API on every call.
	var cust *stripe.Customer
	m.customerCacheMu.RLock()
	cached, ok := m.customerCache[userDID]
	m.customerCacheMu.RUnlock()
	if ok && time.Now().Before(cached.expiresAt) {
		cust = cached.customer
	} else {
		found, err := m.findCustomerByDID(userDID)
		if err != nil {
			slog.Debug("No Stripe customer found", "userDID", userDID, "error", err)
			return info, nil
		}
		m.cacheCustomer(userDID, found)
		cust = found
	}
	info.CustomerID = cust.ID

	// Only the first active subscription is considered.
	params := &stripe.SubscriptionListParams{}
	params.Filters.AddFilter("customer", "", cust.ID)
	params.Filters.AddFilter("status", "", "active")
	iter := subscription.List(params)
	if iter.Next() {
		sub := iter.Subscription()
		info.SubscriptionID = sub.ID
		// Guard Price as well as Items: dereferencing a nil price would panic.
		if sub.Items != nil && len(sub.Items.Data) > 0 && sub.Items.Data[0].Price != nil {
			p := sub.Items.Data[0].Price
			if tierName, tierRank := m.cfg.GetTierByPriceID(p.ID); tierName != "" {
				info.CurrentTier = tierName
				info.TierRank = tierRank
			}
			if p.Recurring != nil {
				switch p.Recurring.Interval {
				case stripe.PriceRecurringIntervalMonth:
					info.BillingInterval = "monthly"
				case stripe.PriceRecurringIntervalYear:
					info.BillingInterval = "yearly"
				}
			}
		}
	} else if err := iter.Err(); err != nil {
		// Without this, a failed listing is indistinguishable from having no
		// subscription. Still degrade to the default tier, but record why.
		slog.Warn("Failed to list Stripe subscriptions", "userDID", userDID, "customerID", cust.ID, "error", err)
	}

	// Mark current tier
	for i := range info.Tiers {
		info.Tiers[i].IsCurrent = info.Tiers[i].ID == info.CurrentTier
	}
	return info, nil
}
// CreateCheckoutSession creates a Stripe Checkout session that subscribes
// the user to req.Tier, and returns the hosted checkout URL and session ID.
//
// Price selection:
//   - the tier's yearly price, when req.Interval is "yearly" and one is configured;
//   - otherwise the monthly price, so Stripe's checkout upsell can still offer
//     the yearly toggle;
//   - otherwise the yearly price as a last resort.
//
// An error is returned for an unknown tier or a tier with no price.
//
// An existing Stripe customer for userDID is reused. If none exists, one is
// created with userHandle as its name. "{base_url}" in the configured
// success/cancel URLs is replaced with the manager's base URL. r is not used
// by this implementation.
func (m *Manager) CreateCheckoutSession(r *http.Request, userDID, userHandle string, req *CheckoutSessionRequest) (*CheckoutSessionResponse, error) {
	if !m.Enabled() {
		return nil, ErrBillingDisabled
	}
	if req == nil {
		return nil, fmt.Errorf("missing checkout session request")
	}

	// Find the tier config.
	rank := m.cfg.TierRank(req.Tier)
	if rank < 0 || rank >= len(m.cfg.Tiers) {
		return nil, fmt.Errorf("unknown tier: %s", req.Tier)
	}
	tierCfg := m.cfg.Tiers[rank]

	var priceID string
	switch {
	case req.Interval == "yearly" && tierCfg.StripePriceYearly != "":
		priceID = tierCfg.StripePriceYearly
	case tierCfg.StripePriceMonthly != "":
		priceID = tierCfg.StripePriceMonthly
	default:
		priceID = tierCfg.StripePriceYearly // may be empty; checked below
	}
	if priceID == "" {
		return nil, fmt.Errorf("tier %s has no Stripe price configured", req.Tier)
	}

	cust, err := m.getOrCreateCustomer(userDID, userHandle)
	if err != nil {
		return nil, fmt.Errorf("failed to get/create customer: %w", err)
	}

	successURL := strings.ReplaceAll(m.cfg.SuccessURL, "{base_url}", m.baseURL)
	cancelURL := strings.ReplaceAll(m.cfg.CancelURL, "{base_url}", m.baseURL)

	params := &stripe.CheckoutSessionParams{
		Customer: stripe.String(cust.ID),
		Mode:     stripe.String(string(stripe.CheckoutSessionModeSubscription)),
		LineItems: []*stripe.CheckoutSessionLineItemParams{
			{
				Price:    stripe.String(priceID),
				Quantity: stripe.Int64(1),
			},
		},
		SuccessURL: stripe.String(successURL),
		CancelURL:  stripe.String(cancelURL),
	}
	s, err := session.New(params)
	if err != nil {
		return nil, fmt.Errorf("failed to create checkout session: %w", err)
	}
	return &CheckoutSessionResponse{
		CheckoutURL: s.URL,
		SessionID:   s.ID,
	}, nil
}
// GetBillingPortalURL creates a Stripe billing portal session for the
// customer linked to userDID. The portal sends the user back to returnURL.
//
// The returned error always starts with "no billing account found" when the
// customer lookup fails. It wraps the underlying cause, so callers can tell
// a missing customer from a Stripe outage.
func (m *Manager) GetBillingPortalURL(userDID, returnURL string) (*BillingPortalResponse, error) {
	if !m.Enabled() {
		return nil, ErrBillingDisabled
	}

	cust, err := m.findCustomerByDID(userDID)
	if err != nil {
		return nil, fmt.Errorf("no billing account found: %w", err)
	}

	params := &stripe.BillingPortalSessionParams{
		Customer:  stripe.String(cust.ID),
		ReturnURL: stripe.String(returnURL),
	}
	s, err := portalsession.New(params)
	if err != nil {
		return nil, fmt.Errorf("failed to create portal session: %w", err)
	}
	return &BillingPortalResponse{PortalURL: s.URL}, nil
}
// HandleWebhook verifies and dispatches a Stripe webhook request.
//
// The body is read (at most 1 MiB) and checked against the Stripe-Signature
// header using the webhook signing secret. Dispatch:
//   - checkout.session.completed is logged;
//   - customer.subscription.created/updated/deleted/paused/resumed events go
//     to handleSubscriptionChange;
//   - all other event types are ignored.
//
// An error is returned only when billing is disabled, the body is
// unreadable or oversized, or signature verification fails. Problems while
// handling a verified event are logged instead.
func (m *Manager) HandleWebhook(r *http.Request) error {
	if !m.Enabled() {
		return ErrBillingDisabled
	}

	// The request is unauthenticated until its signature is checked, so bound
	// how much of it is buffered. Stripe's own Go example caps at 64 KiB;
	// 1 MiB leaves headroom for large event payloads.
	const maxBodyBytes = 1 << 20
	body, err := io.ReadAll(io.LimitReader(r.Body, maxBodyBytes+1))
	if err != nil {
		return fmt.Errorf("failed to read webhook body: %w", err)
	}
	if len(body) > maxBodyBytes {
		return fmt.Errorf("webhook body exceeds %d bytes", maxBodyBytes)
	}

	event, err := webhook.ConstructEvent(body, r.Header.Get("Stripe-Signature"), m.webhookSecret)
	if err != nil {
		return fmt.Errorf("webhook signature verification failed: %w", err)
	}

	switch event.Type {
	case "checkout.session.completed":
		m.handleCheckoutCompleted(event)
	case "customer.subscription.created",
		"customer.subscription.updated",
		"customer.subscription.deleted",
		"customer.subscription.paused",
		"customer.subscription.resumed":
		m.handleSubscriptionChange(event)
	default:
		slog.Debug("Ignoring Stripe event", "type", event.Type)
	}
	return nil
}
// handleCheckoutCompleted logs a checkout.session.completed event. Nothing
// else is done here: the tier update is driven by the subscription.created
// event for the same checkout.
func (m *Manager) handleCheckoutCompleted(event stripe.Event) {
	var cs stripe.CheckoutSession
	if err := json.Unmarshal(event.Data.Raw, &cs); err != nil {
		slog.Error("Failed to parse checkout session", "error", err)
		return
	}

	// Customer and Subscription are nil when the session did not create or
	// attach them. Dereferencing them unconditionally would panic inside the
	// webhook handler.
	var customerID, subscriptionID string
	if cs.Customer != nil {
		customerID = cs.Customer.ID
	}
	if cs.Subscription != nil {
		subscriptionID = cs.Subscription.ID
	}
	slog.Info("Checkout completed", "customerID", customerID, "subscriptionID", subscriptionID)
}
// handleSubscriptionChange maps a customer.subscription.* event to a tier
// and pushes that tier to all managed holds.
//
// The user is identified by the "user_did" metadata on the Stripe customer.
// Tier resolution by subscription status:
//   - active: the tier configured for the first item's price;
//   - canceled or paused: the lowest tier (rank 0);
//   - any other status: the event is ignored.
//
// Events whose customer, user DID or tier cannot be resolved are logged and
// dropped. The hold update runs in a background goroutine, and the user's
// cached Stripe customer is invalidated.
func (m *Manager) handleSubscriptionChange(event stripe.Event) {
	var sub stripe.Subscription
	if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
		slog.Error("Failed to parse subscription", "error", err)
		return
	}

	// Guard before dereferencing: a payload without a customer would
	// otherwise panic the webhook handler.
	if sub.Customer == nil || sub.Customer.ID == "" {
		slog.Warn("Subscription event has no customer", "subscriptionID", sub.ID, "event", event.Type)
		return
	}
	userDID := m.getCustomerDID(sub.Customer.ID)
	if userDID == "" {
		slog.Warn("No user DID found for Stripe customer", "customerID", sub.Customer.ID)
		return
	}

	var (
		tierName string
		tierRank int
		priceID  string
	)
	switch sub.Status {
	case stripe.SubscriptionStatusActive:
		if sub.Items != nil && len(sub.Items.Data) > 0 && sub.Items.Data[0].Price != nil {
			priceID = sub.Items.Data[0].Price.ID
			tierName, tierRank = m.cfg.GetTierByPriceID(priceID)
		}
	case stripe.SubscriptionStatusCanceled, stripe.SubscriptionStatusPaused:
		// Revert to free tier (rank 0).
		// NOTE(review): this fires even if the customer still has another
		// active subscription — confirm users can only hold one at a time.
		tierName = m.cfg.Tiers[0].Name
		tierRank = 0
	default:
		slog.Debug("Ignoring subscription status", "status", sub.Status)
		return
	}
	if tierName == "" {
		// Log the locally captured priceID. An active subscription with no
		// items leaves it empty, whereas indexing sub.Items.Data[0] would panic.
		slog.Warn("Could not resolve tier from subscription", "priceID", priceID, "subscriptionID", sub.ID)
		return
	}

	slog.Info("Pushing tier update to managed holds",
		"userDID", userDID,
		"tierName", tierName,
		"tierRank", tierRank,
		"event", event.Type,
	)

	// Push tier update to all managed holds
	go holdclient.UpdateCrewTierOnAllHolds(
		context.Background(),
		m.managedHolds,
		userDID,
		tierRank,
		m.privateKey,
		m.appviewDID,
	)

	// Invalidate customer cache
	m.customerCacheMu.Lock()
	delete(m.customerCache, userDID)
	m.customerCacheMu.Unlock()
}
// customerNotFoundError reports that a Stripe customer search succeeded but
// matched no customer for the DID. findCustomerByDID returns it unwrapped,
// so callers can detect it with a type assertion.
type customerNotFoundError struct {
	did string
}

func (e *customerNotFoundError) Error() string {
	return "customer not found for DID " + e.did
}

// getOrCreateCustomer returns the Stripe customer for userDID, creating one
// when none exists.
//
// Lookup order: the in-memory customer cache, then a Stripe search on the
// "user_did" metadata. A new customer is created only when that search
// positively finds nothing. The new customer is tagged with user_did and
// named userHandle when userHandle is non-empty.
//
// Any other lookup failure is returned as an error, so a transient Stripe
// error cannot create a duplicate customer for a DID that already has one.
// Found and created customers are cached for customerCacheTTL.
func (m *Manager) getOrCreateCustomer(userDID, userHandle string) (*stripe.Customer, error) {
	// Check cache
	m.customerCacheMu.RLock()
	if cached, ok := m.customerCache[userDID]; ok && time.Now().Before(cached.expiresAt) {
		m.customerCacheMu.RUnlock()
		return cached.customer, nil
	}
	m.customerCacheMu.RUnlock()

	// Search Stripe
	cust, err := m.findCustomerByDID(userDID)
	if err == nil {
		m.cacheCustomer(userDID, cust)
		return cust, nil
	}
	if _, notFound := err.(*customerNotFoundError); !notFound {
		return nil, err
	}

	// Create new customer
	params := &stripe.CustomerParams{
		Params: stripe.Params{
			Metadata: map[string]string{
				"user_did": userDID,
			},
		},
	}
	if userHandle != "" {
		params.Name = stripe.String(userHandle)
	}
	cust, err = customer.New(params)
	if err != nil {
		return nil, fmt.Errorf("failed to create Stripe customer: %w", err)
	}
	m.cacheCustomer(userDID, cust)
	return cust, nil
}

// findCustomerByDID searches Stripe for the first customer whose "user_did"
// metadata equals userDID.
//
// Errors:
//   - *customerNotFoundError when the search succeeds but matches nothing;
//   - a different error when userDID is empty, when it contains a character
//     that would break out of the quoted query term, or when the Stripe
//     search itself fails.
//
// Stripe customer search is eventually consistent, so a customer created
// moments ago may not be found yet.
func (m *Manager) findCustomerByDID(userDID string) (*stripe.Customer, error) {
	// DID syntax never contains quotes or backslashes. Rejecting them keeps
	// the value from escaping the quoted search term below.
	if userDID == "" || strings.ContainsAny(userDID, `'\`) {
		return nil, fmt.Errorf("invalid DID for customer search: %q", userDID)
	}

	params := &stripe.CustomerSearchParams{
		SearchParams: stripe.SearchParams{
			Query: fmt.Sprintf("metadata['user_did']:'%s'", userDID),
		},
	}
	iter := customer.Search(params)
	if iter.Next() {
		return iter.Customer(), nil
	}
	if err := iter.Err(); err != nil {
		return nil, fmt.Errorf("searching Stripe customers for DID %s: %w", userDID, err)
	}
	return nil, &customerNotFoundError{did: userDID}
}
// getCustomerDID fetches the Stripe customer with the given ID and returns
// its "user_did" metadata value. The result is empty when the fetch fails
// (the failure is logged) or when the metadata key is absent.
func (m *Manager) getCustomerDID(customerID string) string {
	c, err := customer.Get(customerID, nil)
	if err == nil {
		return c.Metadata["user_did"]
	}
	slog.Error("Failed to get customer", "customerID", customerID, "error", err)
	return ""
}
// cacheCustomer records cust as the Stripe customer for userDID. The entry
// stays valid for customerCacheTTL and replaces any existing one.
func (m *Manager) cacheCustomer(userDID string, cust *stripe.Customer) {
	entry := &cachedCustomer{
		customer:  cust,
		expiresAt: time.Now().Add(customerCacheTTL),
	}

	m.customerCacheMu.Lock()
	defer m.customerCacheMu.Unlock()
	m.customerCache[userDID] = entry
}
// holdTierCacheTTL is how long a hold's fetched tier list is considered fresh.
const holdTierCacheTTL time.Duration = 30 * time.Minute
// RefreshHoldTiers loads tier definitions from every managed hold into the
// hold tier cache, then keeps them fresh for the life of the process.
//
// It returns immediately when billing is disabled or no holds are managed.
// Otherwise:
//   - The initial load is retried up to 5 times with exponential backoff
//     (3s, 6s, 12s, 24s) until every managed hold has a cache entry. Holds
//     may start after the appview, e.g. under docker-compose.
//   - After that, the cache is refreshed every holdTierCacheTTL/2.
//
// In that case the method never returns, so run it in its own goroutine.
func (m *Manager) RefreshHoldTiers() {
	if !m.Enabled() || len(m.managedHolds) == 0 {
		return
	}

	const maxRetries = 5
	const initialDelay = 3 * time.Second
	for attempt := range maxRetries {
		m.refreshHoldTiersOnce()

		// Check if all managed holds are cached
		m.holdTierCacheMu.RLock()
		allCached := len(m.holdTierCache) == len(m.managedHolds)
		m.holdTierCacheMu.RUnlock()
		if allCached {
			break
		}

		if attempt == maxRetries-1 {
			slog.Warn("Some managed holds still unreachable after startup retries; relying on periodic refresh",
				"maxRetries", maxRetries)
			break
		}
		delay := initialDelay * time.Duration(1<<attempt) // 3s, 6s, 12s, 24s
		slog.Info("Some managed holds not yet reachable, retrying",
			"attempt", attempt+1, "maxRetries", maxRetries, "retryIn", delay)
		time.Sleep(delay)
	}

	// Each cache entry expires holdTierCacheTTL after it was fetched.
	// Refreshing at exactly that interval lets every entry lapse just before
	// its replacement arrives, and during that gap aggregateHoldFeatures
	// drops all hold-derived features. Refreshing at half the TTL replaces
	// entries while they are still valid.
	ticker := time.NewTicker(holdTierCacheTTL / 2)
	defer ticker.Stop()
	for range ticker.C {
		m.refreshHoldTiersOnce()
	}
}
// refreshHoldTiersOnce fetches the tier list from each managed hold in turn.
// Each successful result is stored in the hold tier cache with an expiry
// holdTierCacheTTL from now.
//
// A hold that fails is logged and skipped, and its existing entry (if any)
// is left untouched. aggregateHoldFeatures ignores such an entry once it has
// expired.
func (m *Manager) refreshHoldTiersOnce() {
	// Bound each request, so one unresponsive hold cannot stall the whole
	// refresh (including the startup retry loop in RefreshHoldTiers).
	const fetchTimeout = 30 * time.Second

	for _, holdDID := range m.managedHolds {
		ctx, cancel := context.WithTimeout(context.Background(), fetchTimeout)
		resp, err := holdclient.ListTiers(ctx, holdDID)
		cancel()
		if err != nil {
			slog.Warn("Failed to fetch tiers from hold", "holdDID", holdDID, "error", err)
			continue
		}

		m.holdTierCacheMu.Lock()
		m.holdTierCache[holdDID] = &cachedHoldTiers{
			tiers:     resp.Tiers,
			expiresAt: time.Now().Add(holdTierCacheTTL),
		}
		m.holdTierCacheMu.Unlock()

		slog.Debug("Cached tier data from hold", "holdDID", holdDID, "tierCount", len(resp.Tiers))
	}
}
// aggregateHoldFeatures summarises, for the tier at index rank, the storage
// quota and scan-on-push settings reported by every managed hold whose
// cached tier list is still fresh.
//
// It returns nil when rank is negative or when no fresh hold data covers
// that rank.
//
// NOTE(review): this assumes every hold orders its tier list by the same
// rank as the appview's configured tiers — confirm against the hold's
// ListTiers contract.
//
// Storage is reported as one size when all holds agree, and otherwise as a
// "min-max" range. Scan on push is reported:
//   - unqualified when every hold enables it;
//   - as "(most regions)" when at least half do;
//   - as "(some regions)" when at least one does.
func (m *Manager) aggregateHoldFeatures(rank int) []string {
	if rank < 0 {
		return nil
	}

	m.holdTierCacheMu.RLock()
	defer m.holdTierCacheMu.RUnlock()

	if len(m.holdTierCache) == 0 {
		return nil
	}

	now := time.Now()
	var (
		minQuota, maxQuota    int64
		scanCount, totalHolds int
	)
	for _, cached := range m.holdTierCache {
		if now.After(cached.expiresAt) || rank >= len(cached.tiers) {
			continue
		}
		tier := cached.tiers[rank]
		// Seed min/max from the first contributing hold instead of using a
		// -1 sentinel. With a sentinel, a negative quota would corrupt both
		// bounds.
		if totalHolds == 0 {
			minQuota, maxQuota = tier.QuotaBytes, tier.QuotaBytes
		} else {
			minQuota = min(minQuota, tier.QuotaBytes)
			maxQuota = max(maxQuota, tier.QuotaBytes)
		}
		totalHolds++
		if tier.ScanOnPush {
			scanCount++
		}
	}
	if totalHolds == 0 {
		return nil
	}

	var features []string

	// Storage feature
	if minQuota == maxQuota {
		features = append(features, formatBytes(minQuota)+" storage")
	} else {
		features = append(features, formatBytes(minQuota)+"-"+formatBytes(maxQuota)+" storage")
	}

	// Scan on push feature
	switch {
	case scanCount == totalHolds:
		features = append(features, "Scan on push")
	case scanCount*2 >= totalHolds:
		features = append(features, "Scan on push (most regions)")
	case scanCount > 0:
		features = append(features, "Scan on push (some regions)")
	}
	return features
}
// webhookFeatures turns a tier's webhook settings into feature bullets.
//
// The limit is rendered as follows:
//   - negative: "Unlimited webhooks";
//   - 1: "1 webhook";
//   - above 1: "N webhooks";
//   - 0: no limit bullet.
//
// "All webhook triggers" is appended when allTriggers is set. The result is
// nil when neither bullet applies.
func webhookFeatures(maxWebhooks int, allTriggers bool) []string {
	var bullets []string
	switch n := maxWebhooks; {
	case n == 1:
		bullets = append(bullets, "1 webhook")
	case n > 1:
		bullets = append(bullets, fmt.Sprintf("%d webhooks", n))
	case n < 0:
		bullets = append(bullets, "Unlimited webhooks")
	}
	if !allTriggers {
		return bullets
	}
	return append(bullets, "All webhook triggers")
}
// aiAdvisorFeatures returns the "AI Image Advisor" bullet when the tier
// enables the advisor. Otherwise it returns nil.
func aiAdvisorFeatures(enabled bool) []string {
	if !enabled {
		return nil
	}
	return []string{"AI Image Advisor"}
}
// formatBytes renders b as a human-readable size, for example "5.0 GB".
//
// Sizes use 1024-based multiples with one decimal place. Values below 1024,
// including every negative value, are printed as whole bytes ("512 B"). The
// whole int64 range is supported; sizes of 1024^6 bytes and above are shown
// in EB.
func formatBytes(b int64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}

	// Labels for successive powers of 1024, starting at 1024^1. int64 reaches
	// about 8 * 1024^6, so "EB" is needed. Without it, any value at or above
	// 1024^6 would index past the end of this array and panic.
	units := [...]string{"KB", "MB", "GB", "TB", "PB", "EB"}

	div, exp := int64(unit), 0
	for n := b / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %s", float64(b)/float64(div), units[exp])
}
// GetFirstTierWithAllTriggers returns the name of the lowest-ranked
// configured tier that has webhook_all_triggers enabled. The result is empty
// when billing is disabled or no tier enables it.
func (m *Manager) GetFirstTierWithAllTriggers() string {
	if m.Enabled() {
		for i := range m.cfg.Tiers {
			if m.cfg.Tiers[i].WebhookAllTriggers {
				return m.cfg.Tiers[i].Name
			}
		}
	}
	return ""
}
// fetchPrice returns the unit amount, in cents, of the Stripe price priceID.
//
// Successful lookups are cached for priceCacheTTL. Failures are logged,
// returned to the caller, and not cached, so the next call retries.
func (m *Manager) fetchPrice(priceID string) (int64, error) {
	m.priceCacheMu.RLock()
	entry, hit := m.priceCache[priceID]
	m.priceCacheMu.RUnlock()
	// Entries are replaced, never mutated, so reading one after unlocking is safe.
	if hit && time.Now().Before(entry.expiresAt) {
		return entry.unitAmount, nil
	}

	p, err := price.Get(priceID, nil)
	if err != nil {
		slog.Warn("Failed to fetch Stripe price", "priceID", priceID, "error", err)
		return 0, err
	}

	amount := p.UnitAmount
	m.priceCacheMu.Lock()
	m.priceCache[priceID] = &cachedPrice{
		unitAmount: amount,
		expiresAt:  time.Now().Add(priceCacheTTL),
	}
	m.priceCacheMu.Unlock()
	return amount, nil
}