mirror of
https://tangled.org/evan.jarrett.net/at-container-registry
synced 2026-04-20 08:30:29 +00:00
777 lines
21 KiB
Go
777 lines
21 KiB
Go
//go:build billing
|
|
|
|
package billing
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"atcr.io/pkg/appview/holdclient"
|
|
"github.com/bluesky-social/indigo/atproto/atcrypto"
|
|
|
|
"github.com/stripe/stripe-go/v84"
|
|
portalsession "github.com/stripe/stripe-go/v84/billingportal/session"
|
|
"github.com/stripe/stripe-go/v84/checkout/session"
|
|
"github.com/stripe/stripe-go/v84/customer"
|
|
"github.com/stripe/stripe-go/v84/price"
|
|
"github.com/stripe/stripe-go/v84/subscription"
|
|
"github.com/stripe/stripe-go/v84/webhook"
|
|
)
|
|
|
|
// Manager handles Stripe billing and pushes tier updates to managed holds.
|
|
type Manager struct {
|
|
cfg *Config
|
|
privateKey *atcrypto.PrivateKeyP256
|
|
appviewDID string
|
|
managedHolds []string
|
|
baseURL string
|
|
stripeKey string
|
|
webhookSecret string
|
|
|
|
// Captain checker: bypasses billing for hold owners
|
|
captainChecker CaptainChecker
|
|
|
|
// Customer cache: DID → Stripe customer
|
|
customerCache map[string]*cachedCustomer
|
|
customerCacheMu sync.RWMutex
|
|
|
|
// Price cache: Stripe price ID → unit amount in cents
|
|
priceCache map[string]*cachedPrice
|
|
priceCacheMu sync.RWMutex
|
|
|
|
// Hold tier cache: holdDID → tier list
|
|
holdTierCache map[string]*cachedHoldTiers
|
|
holdTierCacheMu sync.RWMutex
|
|
}
|
|
|
|
type cachedHoldTiers struct {
|
|
tiers []holdclient.HoldTierInfo
|
|
expiresAt time.Time
|
|
}
|
|
|
|
type cachedCustomer struct {
|
|
customer *stripe.Customer
|
|
expiresAt time.Time
|
|
}
|
|
|
|
type cachedPrice struct {
|
|
unitAmount int64
|
|
expiresAt time.Time
|
|
}
|
|
|
|
// Lifetimes for the in-memory Stripe caches.
const (
	customerCacheTTL = time.Minute * 10 // Stripe customer lookups by DID
	priceCacheTTL    = time.Hour        // Stripe price unit amounts
)
|
|
|
|
// New creates a new billing manager with Stripe integration.
|
|
// Env vars STRIPE_SECRET_KEY and STRIPE_WEBHOOK_SECRET take precedence over config values.
|
|
func New(cfg *Config, privateKey *atcrypto.PrivateKeyP256, appviewDID string, managedHolds []string, baseURL string) *Manager {
|
|
stripeKey := os.Getenv("STRIPE_SECRET_KEY")
|
|
if stripeKey == "" {
|
|
stripeKey = cfg.StripeSecretKey
|
|
}
|
|
if stripeKey != "" {
|
|
stripe.Key = stripeKey
|
|
}
|
|
|
|
webhookSecret := os.Getenv("STRIPE_WEBHOOK_SECRET")
|
|
if webhookSecret == "" {
|
|
webhookSecret = cfg.WebhookSecret
|
|
}
|
|
|
|
return &Manager{
|
|
cfg: cfg,
|
|
privateKey: privateKey,
|
|
appviewDID: appviewDID,
|
|
managedHolds: managedHolds,
|
|
baseURL: baseURL,
|
|
stripeKey: stripeKey,
|
|
webhookSecret: webhookSecret,
|
|
customerCache: make(map[string]*cachedCustomer),
|
|
priceCache: make(map[string]*cachedPrice),
|
|
holdTierCache: make(map[string]*cachedHoldTiers),
|
|
}
|
|
}
|
|
|
|
// SetCaptainChecker sets a callback that checks if a user is a hold captain.
|
|
// Captains bypass all billing feature gates.
|
|
func (m *Manager) SetCaptainChecker(fn CaptainChecker) {
|
|
m.captainChecker = fn
|
|
}
|
|
|
|
func (m *Manager) isCaptain(userDID string) bool {
|
|
return m.captainChecker != nil && userDID != "" && m.captainChecker(userDID)
|
|
}
|
|
|
|
// Enabled returns true if billing is properly configured.
|
|
func (m *Manager) Enabled() bool {
|
|
return m.cfg != nil && m.stripeKey != "" && len(m.cfg.Tiers) > 0
|
|
}
|
|
|
|
// GetWebhookLimits returns webhook limits for a user based on their subscription tier.
|
|
// Returns (maxWebhooks, allTriggers). Defaults to the lowest tier's limits.
|
|
// Hold captains get unlimited webhooks with all triggers.
|
|
func (m *Manager) GetWebhookLimits(userDID string) (int, bool) {
|
|
if m.isCaptain(userDID) {
|
|
return -1, true // unlimited
|
|
}
|
|
if !m.Enabled() {
|
|
return 1, false
|
|
}
|
|
|
|
info, err := m.GetSubscriptionInfo(userDID)
|
|
if err != nil || info == nil {
|
|
return m.cfg.Tiers[0].MaxWebhooks, m.cfg.Tiers[0].WebhookAllTriggers
|
|
}
|
|
|
|
rank := info.TierRank
|
|
if rank >= 0 && rank < len(m.cfg.Tiers) {
|
|
return m.cfg.Tiers[rank].MaxWebhooks, m.cfg.Tiers[rank].WebhookAllTriggers
|
|
}
|
|
|
|
return m.cfg.Tiers[0].MaxWebhooks, m.cfg.Tiers[0].WebhookAllTriggers
|
|
}
|
|
|
|
// HasAIAdvisor returns whether a user has access to the AI Image Advisor based on their subscription tier.
|
|
// Hold captains always have access.
|
|
func (m *Manager) HasAIAdvisor(userDID string) bool {
|
|
if m.isCaptain(userDID) {
|
|
return true
|
|
}
|
|
if !m.Enabled() {
|
|
return false
|
|
}
|
|
|
|
info, err := m.GetSubscriptionInfo(userDID)
|
|
if err != nil || info == nil {
|
|
return m.cfg.Tiers[0].AIAdvisor
|
|
}
|
|
|
|
rank := info.TierRank
|
|
if rank >= 0 && rank < len(m.cfg.Tiers) {
|
|
return m.cfg.Tiers[rank].AIAdvisor
|
|
}
|
|
|
|
return m.cfg.Tiers[0].AIAdvisor
|
|
}
|
|
|
|
// GetSupporterBadge returns the supporter badge tier name for a user based on their subscription.
|
|
// Returns the tier name if the user's current tier has supporter badges enabled, empty string otherwise.
|
|
// Hold captains get a "Captain" badge.
|
|
func (m *Manager) GetSupporterBadge(userDID string) string {
|
|
if m.isCaptain(userDID) {
|
|
return "Captain"
|
|
}
|
|
if !m.Enabled() {
|
|
return ""
|
|
}
|
|
|
|
info, err := m.GetSubscriptionInfo(userDID)
|
|
if err != nil || info == nil {
|
|
return ""
|
|
}
|
|
|
|
for _, tier := range info.Tiers {
|
|
if tier.ID == info.CurrentTier && tier.SupporterBadge {
|
|
return info.CurrentTier
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// GetSubscriptionInfo returns subscription and tier information for a user.
|
|
// Hold captains see a special "Captain" tier with all features unlocked.
|
|
func (m *Manager) GetSubscriptionInfo(userDID string) (*SubscriptionInfo, error) {
|
|
if m.isCaptain(userDID) {
|
|
return &SubscriptionInfo{
|
|
UserDID: userDID,
|
|
CurrentTier: "Captain",
|
|
TierRank: -1, // above all configured tiers
|
|
Tiers: []TierInfo{{
|
|
ID: "Captain",
|
|
Name: "Captain",
|
|
Description: "Hold operator",
|
|
Features: []string{"Unlimited storage", "Unlimited webhooks", "All webhook triggers", "Scan on push"},
|
|
Rank: -1,
|
|
MaxWebhooks: -1,
|
|
WebhookAllTriggers: true,
|
|
SupporterBadge: true,
|
|
IsCurrent: true,
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
if !m.Enabled() {
|
|
return nil, ErrBillingDisabled
|
|
}
|
|
|
|
info := &SubscriptionInfo{
|
|
UserDID: userDID,
|
|
PaymentsEnabled: true,
|
|
CurrentTier: m.cfg.Tiers[0].Name, // default to lowest
|
|
TierRank: 0,
|
|
}
|
|
|
|
// Build tier list with live Stripe prices
|
|
info.Tiers = make([]TierInfo, len(m.cfg.Tiers))
|
|
for i, tier := range m.cfg.Tiers {
|
|
// Dynamic features: hold-derived first, then webhook limits, then static config
|
|
features := m.aggregateHoldFeatures(i)
|
|
features = append(features, webhookFeatures(tier.MaxWebhooks, tier.WebhookAllTriggers)...)
|
|
features = append(features, aiAdvisorFeatures(tier.AIAdvisor)...)
|
|
if tier.SupporterBadge {
|
|
features = append(features, "Supporter badge")
|
|
}
|
|
features = append(features, tier.Features...)
|
|
info.Tiers[i] = TierInfo{
|
|
ID: tier.Name,
|
|
Name: tier.Name,
|
|
Description: tier.Description,
|
|
Features: features,
|
|
Rank: i,
|
|
MaxWebhooks: tier.MaxWebhooks,
|
|
WebhookAllTriggers: tier.WebhookAllTriggers,
|
|
SupporterBadge: tier.SupporterBadge,
|
|
}
|
|
if tier.StripePriceMonthly != "" {
|
|
if amount, err := m.fetchPrice(tier.StripePriceMonthly); err == nil {
|
|
info.Tiers[i].PriceCentsMonthly = int(amount)
|
|
}
|
|
}
|
|
if tier.StripePriceYearly != "" {
|
|
if amount, err := m.fetchPrice(tier.StripePriceYearly); err == nil {
|
|
info.Tiers[i].PriceCentsYearly = int(amount)
|
|
}
|
|
}
|
|
}
|
|
|
|
if userDID == "" {
|
|
return info, nil
|
|
}
|
|
|
|
// Find Stripe customer for this user
|
|
cust, err := m.findCustomerByDID(userDID)
|
|
if err != nil {
|
|
slog.Debug("No Stripe customer found", "userDID", userDID, "error", err)
|
|
return info, nil
|
|
}
|
|
info.CustomerID = cust.ID
|
|
|
|
// Find active subscription
|
|
params := &stripe.SubscriptionListParams{}
|
|
params.Filters.AddFilter("customer", "", cust.ID)
|
|
params.Filters.AddFilter("status", "", "active")
|
|
iter := subscription.List(params)
|
|
|
|
for iter.Next() {
|
|
sub := iter.Subscription()
|
|
info.SubscriptionID = sub.ID
|
|
|
|
if sub.Items != nil && len(sub.Items.Data) > 0 {
|
|
priceID := sub.Items.Data[0].Price.ID
|
|
tierName, tierRank := m.cfg.GetTierByPriceID(priceID)
|
|
if tierName != "" {
|
|
info.CurrentTier = tierName
|
|
info.TierRank = tierRank
|
|
}
|
|
|
|
if sub.Items.Data[0].Price.Recurring != nil {
|
|
switch sub.Items.Data[0].Price.Recurring.Interval {
|
|
case stripe.PriceRecurringIntervalMonth:
|
|
info.BillingInterval = "monthly"
|
|
case stripe.PriceRecurringIntervalYear:
|
|
info.BillingInterval = "yearly"
|
|
}
|
|
}
|
|
}
|
|
break
|
|
}
|
|
|
|
// Mark current tier
|
|
for i := range info.Tiers {
|
|
info.Tiers[i].IsCurrent = info.Tiers[i].ID == info.CurrentTier
|
|
}
|
|
|
|
return info, nil
|
|
}
|
|
|
|
// CreateCheckoutSession creates a Stripe checkout session for a subscription.
|
|
func (m *Manager) CreateCheckoutSession(r *http.Request, userDID, userHandle string, req *CheckoutSessionRequest) (*CheckoutSessionResponse, error) {
|
|
if !m.Enabled() {
|
|
return nil, ErrBillingDisabled
|
|
}
|
|
|
|
// Find the tier config
|
|
rank := m.cfg.TierRank(req.Tier)
|
|
if rank < 0 {
|
|
return nil, fmt.Errorf("unknown tier: %s", req.Tier)
|
|
}
|
|
tierCfg := m.cfg.Tiers[rank]
|
|
|
|
// Determine price ID: prefer monthly so Stripe upsell can offer yearly toggle,
|
|
// fall back to yearly if no monthly price exists.
|
|
var priceID string
|
|
if req.Interval == "yearly" && tierCfg.StripePriceYearly != "" {
|
|
priceID = tierCfg.StripePriceYearly
|
|
} else if tierCfg.StripePriceMonthly != "" {
|
|
priceID = tierCfg.StripePriceMonthly
|
|
} else if tierCfg.StripePriceYearly != "" {
|
|
priceID = tierCfg.StripePriceYearly
|
|
}
|
|
if priceID == "" {
|
|
return nil, fmt.Errorf("tier %s has no Stripe price configured", req.Tier)
|
|
}
|
|
|
|
// Get or create Stripe customer
|
|
cust, err := m.getOrCreateCustomer(userDID, userHandle)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get/create customer: %w", err)
|
|
}
|
|
|
|
// Build success/cancel URLs
|
|
successURL := strings.ReplaceAll(m.cfg.SuccessURL, "{base_url}", m.baseURL)
|
|
cancelURL := strings.ReplaceAll(m.cfg.CancelURL, "{base_url}", m.baseURL)
|
|
|
|
params := &stripe.CheckoutSessionParams{
|
|
Customer: stripe.String(cust.ID),
|
|
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
|
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
|
{
|
|
Price: stripe.String(priceID),
|
|
Quantity: stripe.Int64(1),
|
|
},
|
|
},
|
|
SuccessURL: stripe.String(successURL),
|
|
CancelURL: stripe.String(cancelURL),
|
|
}
|
|
|
|
s, err := session.New(params)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create checkout session: %w", err)
|
|
}
|
|
|
|
return &CheckoutSessionResponse{
|
|
CheckoutURL: s.URL,
|
|
SessionID: s.ID,
|
|
}, nil
|
|
}
|
|
|
|
// GetBillingPortalURL creates a Stripe billing portal session.
|
|
func (m *Manager) GetBillingPortalURL(userDID, returnURL string) (*BillingPortalResponse, error) {
|
|
if !m.Enabled() {
|
|
return nil, ErrBillingDisabled
|
|
}
|
|
|
|
cust, err := m.findCustomerByDID(userDID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("no billing account found")
|
|
}
|
|
|
|
params := &stripe.BillingPortalSessionParams{
|
|
Customer: stripe.String(cust.ID),
|
|
ReturnURL: stripe.String(returnURL),
|
|
}
|
|
|
|
s, err := portalsession.New(params)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create portal session: %w", err)
|
|
}
|
|
|
|
return &BillingPortalResponse{PortalURL: s.URL}, nil
|
|
}
|
|
|
|
// HandleWebhook processes a Stripe webhook event.
|
|
// On subscription changes, it pushes tier updates to all managed holds.
|
|
func (m *Manager) HandleWebhook(r *http.Request) error {
|
|
if !m.Enabled() {
|
|
return ErrBillingDisabled
|
|
}
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read webhook body: %w", err)
|
|
}
|
|
|
|
event, err := webhook.ConstructEvent(body, r.Header.Get("Stripe-Signature"), m.webhookSecret)
|
|
if err != nil {
|
|
return fmt.Errorf("webhook signature verification failed: %w", err)
|
|
}
|
|
|
|
switch event.Type {
|
|
case "checkout.session.completed":
|
|
m.handleCheckoutCompleted(event)
|
|
case "customer.subscription.created",
|
|
"customer.subscription.updated",
|
|
"customer.subscription.deleted",
|
|
"customer.subscription.paused",
|
|
"customer.subscription.resumed":
|
|
m.handleSubscriptionChange(event)
|
|
default:
|
|
slog.Debug("Ignoring Stripe event", "type", event.Type)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// handleCheckoutCompleted processes a checkout.session.completed event.
|
|
func (m *Manager) handleCheckoutCompleted(event stripe.Event) {
|
|
var cs stripe.CheckoutSession
|
|
if err := json.Unmarshal(event.Data.Raw, &cs); err != nil {
|
|
slog.Error("Failed to parse checkout session", "error", err)
|
|
return
|
|
}
|
|
|
|
slog.Info("Checkout completed", "customerID", cs.Customer.ID, "subscriptionID", cs.Subscription.ID)
|
|
|
|
// The subscription.created event will handle the tier update
|
|
}
|
|
|
|
// handleSubscriptionChange processes subscription lifecycle events.
|
|
func (m *Manager) handleSubscriptionChange(event stripe.Event) {
|
|
var sub stripe.Subscription
|
|
if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
|
|
slog.Error("Failed to parse subscription", "error", err)
|
|
return
|
|
}
|
|
|
|
// Get user DID from customer metadata
|
|
userDID := m.getCustomerDID(sub.Customer.ID)
|
|
if userDID == "" {
|
|
slog.Warn("No user DID found for Stripe customer", "customerID", sub.Customer.ID)
|
|
return
|
|
}
|
|
|
|
// Determine new tier from subscription
|
|
var tierName string
|
|
var tierRank int
|
|
|
|
switch sub.Status {
|
|
case stripe.SubscriptionStatusActive:
|
|
if sub.Items != nil && len(sub.Items.Data) > 0 {
|
|
priceID := sub.Items.Data[0].Price.ID
|
|
tierName, tierRank = m.cfg.GetTierByPriceID(priceID)
|
|
}
|
|
case stripe.SubscriptionStatusCanceled, stripe.SubscriptionStatusPaused:
|
|
// Revert to free tier (rank 0)
|
|
tierName = m.cfg.Tiers[0].Name
|
|
tierRank = 0
|
|
default:
|
|
slog.Debug("Ignoring subscription status", "status", sub.Status)
|
|
return
|
|
}
|
|
|
|
if tierName == "" {
|
|
slog.Warn("Could not resolve tier from subscription", "priceID", sub.Items.Data[0].Price.ID)
|
|
return
|
|
}
|
|
|
|
slog.Info("Pushing tier update to managed holds",
|
|
"userDID", userDID,
|
|
"tierName", tierName,
|
|
"tierRank", tierRank,
|
|
"event", event.Type,
|
|
)
|
|
|
|
// Push tier update to all managed holds
|
|
go holdclient.UpdateCrewTierOnAllHolds(
|
|
context.Background(),
|
|
m.managedHolds,
|
|
userDID,
|
|
tierRank,
|
|
m.privateKey,
|
|
m.appviewDID,
|
|
)
|
|
|
|
// Invalidate customer cache
|
|
m.customerCacheMu.Lock()
|
|
delete(m.customerCache, userDID)
|
|
m.customerCacheMu.Unlock()
|
|
}
|
|
|
|
// getOrCreateCustomer finds or creates a Stripe customer for a DID.
|
|
func (m *Manager) getOrCreateCustomer(userDID, userHandle string) (*stripe.Customer, error) {
|
|
// Check cache
|
|
m.customerCacheMu.RLock()
|
|
if cached, ok := m.customerCache[userDID]; ok && time.Now().Before(cached.expiresAt) {
|
|
m.customerCacheMu.RUnlock()
|
|
return cached.customer, nil
|
|
}
|
|
m.customerCacheMu.RUnlock()
|
|
|
|
// Search Stripe
|
|
cust, err := m.findCustomerByDID(userDID)
|
|
if err == nil {
|
|
m.cacheCustomer(userDID, cust)
|
|
return cust, nil
|
|
}
|
|
|
|
// Create new customer
|
|
params := &stripe.CustomerParams{
|
|
Params: stripe.Params{
|
|
Metadata: map[string]string{
|
|
"user_did": userDID,
|
|
},
|
|
},
|
|
}
|
|
if userHandle != "" {
|
|
params.Name = stripe.String(userHandle)
|
|
}
|
|
|
|
cust, err = customer.New(params)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Stripe customer: %w", err)
|
|
}
|
|
|
|
m.cacheCustomer(userDID, cust)
|
|
return cust, nil
|
|
}
|
|
|
|
// findCustomerByDID searches Stripe for a customer with matching DID metadata.
|
|
func (m *Manager) findCustomerByDID(userDID string) (*stripe.Customer, error) {
|
|
params := &stripe.CustomerSearchParams{
|
|
SearchParams: stripe.SearchParams{
|
|
Query: fmt.Sprintf("metadata['user_did']:'%s'", userDID),
|
|
},
|
|
}
|
|
|
|
iter := customer.Search(params)
|
|
for iter.Next() {
|
|
return iter.Customer(), nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("customer not found for DID %s", userDID)
|
|
}
|
|
|
|
// getCustomerDID retrieves the user DID from a Stripe customer's metadata.
|
|
func (m *Manager) getCustomerDID(customerID string) string {
|
|
cust, err := customer.Get(customerID, nil)
|
|
if err != nil {
|
|
slog.Error("Failed to get customer", "customerID", customerID, "error", err)
|
|
return ""
|
|
}
|
|
return cust.Metadata["user_did"]
|
|
}
|
|
|
|
// cacheCustomer stores a customer in the in-memory cache.
|
|
func (m *Manager) cacheCustomer(userDID string, cust *stripe.Customer) {
|
|
m.customerCacheMu.Lock()
|
|
m.customerCache[userDID] = &cachedCustomer{
|
|
customer: cust,
|
|
expiresAt: time.Now().Add(customerCacheTTL),
|
|
}
|
|
m.customerCacheMu.Unlock()
|
|
}
|
|
|
|
// holdTierCacheTTL is the period at which managed holds are re-polled for their
// tier definitions. Cached hold tier entries are also timed from it.
const holdTierCacheTTL time.Duration = time.Minute * 30
|
|
|
|
// RefreshHoldTiers queries all managed holds for their tier definitions and caches the results.
|
|
// It runs once immediately (with retries for holds that aren't ready yet) and then
|
|
// periodically in the background.
|
|
// Safe to call from a goroutine.
|
|
func (m *Manager) RefreshHoldTiers() {
|
|
if !m.Enabled() || len(m.managedHolds) == 0 {
|
|
return
|
|
}
|
|
|
|
// On startup, retry a few times with backoff in case holds aren't ready yet.
|
|
// This is common in docker-compose where appview starts before the hold.
|
|
const maxRetries = 5
|
|
const initialDelay = 3 * time.Second
|
|
|
|
for attempt := range maxRetries {
|
|
m.refreshHoldTiersOnce()
|
|
|
|
// Check if all managed holds are cached
|
|
m.holdTierCacheMu.RLock()
|
|
allCached := len(m.holdTierCache) == len(m.managedHolds)
|
|
m.holdTierCacheMu.RUnlock()
|
|
|
|
if allCached {
|
|
break
|
|
}
|
|
|
|
if attempt < maxRetries-1 {
|
|
delay := initialDelay * time.Duration(1<<attempt) // 3s, 6s, 12s, 24s
|
|
slog.Info("Some managed holds not yet reachable, retrying",
|
|
"attempt", attempt+1, "maxRetries", maxRetries, "retryIn", delay)
|
|
time.Sleep(delay)
|
|
}
|
|
}
|
|
|
|
ticker := time.NewTicker(holdTierCacheTTL)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
m.refreshHoldTiersOnce()
|
|
}
|
|
}
|
|
|
|
func (m *Manager) refreshHoldTiersOnce() {
|
|
for _, holdDID := range m.managedHolds {
|
|
resp, err := holdclient.ListTiers(context.Background(), holdDID)
|
|
if err != nil {
|
|
slog.Warn("Failed to fetch tiers from hold", "holdDID", holdDID, "error", err)
|
|
continue
|
|
}
|
|
|
|
m.holdTierCacheMu.Lock()
|
|
m.holdTierCache[holdDID] = &cachedHoldTiers{
|
|
tiers: resp.Tiers,
|
|
expiresAt: time.Now().Add(holdTierCacheTTL),
|
|
}
|
|
m.holdTierCacheMu.Unlock()
|
|
|
|
slog.Debug("Cached tier data from hold", "holdDID", holdDID, "tierCount", len(resp.Tiers))
|
|
}
|
|
}
|
|
|
|
// aggregateHoldFeatures generates dynamic feature strings for a tier rank
|
|
// by aggregating data from all cached managed holds.
|
|
// Returns nil if no hold data is available.
|
|
func (m *Manager) aggregateHoldFeatures(rank int) []string {
|
|
m.holdTierCacheMu.RLock()
|
|
defer m.holdTierCacheMu.RUnlock()
|
|
|
|
if len(m.holdTierCache) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var (
|
|
minQuota int64 = -1
|
|
maxQuota int64
|
|
scanCount int
|
|
totalHolds int
|
|
)
|
|
|
|
for _, cached := range m.holdTierCache {
|
|
if time.Now().After(cached.expiresAt) {
|
|
continue
|
|
}
|
|
if rank >= len(cached.tiers) {
|
|
continue
|
|
}
|
|
totalHolds++
|
|
tier := cached.tiers[rank]
|
|
|
|
if minQuota < 0 || tier.QuotaBytes < minQuota {
|
|
minQuota = tier.QuotaBytes
|
|
}
|
|
if tier.QuotaBytes > maxQuota {
|
|
maxQuota = tier.QuotaBytes
|
|
}
|
|
if tier.ScanOnPush {
|
|
scanCount++
|
|
}
|
|
}
|
|
|
|
if totalHolds == 0 {
|
|
return nil
|
|
}
|
|
|
|
var features []string
|
|
|
|
// Storage feature
|
|
if minQuota == maxQuota {
|
|
features = append(features, formatBytes(minQuota)+" storage")
|
|
} else {
|
|
features = append(features, formatBytes(minQuota)+"-"+formatBytes(maxQuota)+" storage")
|
|
}
|
|
|
|
// Scan on push feature
|
|
if scanCount == totalHolds {
|
|
features = append(features, "Scan on push")
|
|
} else if scanCount*2 >= totalHolds {
|
|
features = append(features, "Scan on push (most regions)")
|
|
} else if scanCount > 0 {
|
|
features = append(features, "Scan on push (some regions)")
|
|
}
|
|
|
|
return features
|
|
}
|
|
|
|
// webhookFeatures renders the webhook-related feature bullets for a tier.
//
// The first bullet states the count: "Unlimited webhooks" for a negative
// limit, "1 webhook", or "N webhooks". No count bullet is emitted for a limit
// of zero. "All webhook triggers" follows when allTriggers is set. The result
// is nil when there is nothing to say.
func webhookFeatures(maxWebhooks int, allTriggers bool) []string {
	var count string
	switch {
	case maxWebhooks < 0:
		count = "Unlimited webhooks"
	case maxWebhooks == 1:
		count = "1 webhook"
	case maxWebhooks > 1:
		count = fmt.Sprintf("%d webhooks", maxWebhooks)
	}

	var bullets []string
	if count != "" {
		bullets = append(bullets, count)
	}
	if allTriggers {
		bullets = append(bullets, "All webhook triggers")
	}
	return bullets
}
|
|
|
|
// aiAdvisorFeatures returns the single "AI Image Advisor" bullet when the
// tier includes the advisor, and nil otherwise.
func aiAdvisorFeatures(enabled bool) []string {
	if !enabled {
		return nil
	}
	return []string{"AI Image Advisor"}
}
|
|
|
|
// formatBytes formats b as a human-readable size using binary (1024-based)
// steps with one decimal place, e.g. 5<<30 → "5.0 GB". Values below 1024,
// including negatives, are printed as a plain byte count such as "512 B".
func formatBytes(b int64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}

	// "EB" is required. int64 reaches about 8 EiB, and without it any value
	// of 1 EiB or more indexed past the end of the unit list and panicked.
	units := [...]string{"KB", "MB", "GB", "TB", "PB", "EB"}

	div, exp := int64(unit), 0
	for n := b / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %s", float64(b)/float64(div), units[exp])
}
|
|
|
|
// GetFirstTierWithAllTriggers returns the name of the lowest-rank tier that has
|
|
// webhook_all_triggers enabled. Returns empty string if none found.
|
|
func (m *Manager) GetFirstTierWithAllTriggers() string {
|
|
if !m.Enabled() {
|
|
return ""
|
|
}
|
|
for _, tier := range m.cfg.Tiers {
|
|
if tier.WebhookAllTriggers {
|
|
return tier.Name
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// fetchPrice returns the unit amount in cents for a Stripe price ID, using a cache.
|
|
func (m *Manager) fetchPrice(priceID string) (int64, error) {
|
|
m.priceCacheMu.RLock()
|
|
if cached, ok := m.priceCache[priceID]; ok && time.Now().Before(cached.expiresAt) {
|
|
m.priceCacheMu.RUnlock()
|
|
return cached.unitAmount, nil
|
|
}
|
|
m.priceCacheMu.RUnlock()
|
|
|
|
p, err := price.Get(priceID, nil)
|
|
if err != nil {
|
|
slog.Warn("Failed to fetch Stripe price", "priceID", priceID, "error", err)
|
|
return 0, err
|
|
}
|
|
|
|
m.priceCacheMu.Lock()
|
|
m.priceCache[priceID] = &cachedPrice{
|
|
unitAmount: p.UnitAmount,
|
|
expiresAt: time.Now().Add(priceCacheTTL),
|
|
}
|
|
m.priceCacheMu.Unlock()
|
|
|
|
return p.UnitAmount, nil
|
|
}
|