Files
2025-12-29 17:02:07 -06:00

159 lines
4.6 KiB
Go

// Package middleware provides HTTP middleware for AppView, including
// authentication (session-based for web UI, token-based for registry),
// identity resolution (handle/DID to PDS endpoint), and hold discovery
// for routing blobs to storage endpoints.
package middleware
import (
"context"
"database/sql"
"net/http"
"net/url"
"atcr.io/pkg/appview/db"
"atcr.io/pkg/auth"
"atcr.io/pkg/auth/oauth"
)
// contextKey is a private type for this package's context keys. Using a
// distinct named type rather than a plain string keeps these keys from
// colliding with context values stored by other packages.
type contextKey string
// userKey is the context key under which the auth middleware stores the
// authenticated *db.User; GetUser reads the value back.
const userKey contextKey = "user"
// WebAuthDeps contains dependencies for web auth middleware
// (RequireAuthWithDeps and OptionalAuthWithDeps).
type WebAuthDeps struct {
// SessionStore is queried with the session ID taken from the request cookie.
SessionStore *db.SessionStore
// Database is passed to db.GetUserByDID to load the full user record.
Database *sql.DB
// Refresher is optional. When nil, the middleware does not create an
// auth.UserContext; otherwise it is handed to auth.Dependencies.
Refresher *oauth.Refresher
// DefaultHoldDID is forwarded unchanged to auth.Dependencies.
DefaultHoldDID string
}
// RequireAuth returns middleware that requires authentication.
//
// It is a convenience wrapper around RequireAuthWithDeps that supplies only
// the session store and the database; the remaining WebAuthDeps fields keep
// their zero values.
func RequireAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) http.Handler {
	deps := WebAuthDeps{SessionStore: store, Database: database}
	return RequireAuthWithDeps(deps)
}
// RequireAuthWithDeps returns middleware that only lets authenticated requests
// reach the wrapped handler.
//
// A request counts as authenticated when it carries a session cookie whose ID
// is known to deps.SessionStore. Every other request is redirected
// (302 Found) to /auth/oauth/login with a return_to parameter holding the
// requested path and query string.
//
// For authenticated requests the user is stored in the request context and
// can be read back with GetUser. When deps.Refresher is non-nil, an
// auth.UserContext is also created, EnsureUserSetup is called on it, and it is
// attached to the context via auth.WithUserContext.
func RequireAuthWithDeps(deps WebAuthDeps) func(http.Handler) http.Handler {
	// redirectToLogin sends the client to the OAuth login page and records
	// where it was heading so it can be returned there afterwards.
	redirectToLogin := func(w http.ResponseWriter, r *http.Request) {
		// Use the escaped form of the path. The decoded r.URL.Path turns an
		// escaped "%3F" or "%23" into a literal "?" or "#", which would be
		// misread as the start of the query or fragment when the return URL
		// is later followed.
		returnTo := r.URL.EscapedPath()
		if r.URL.RawQuery != "" {
			returnTo += "?" + r.URL.RawQuery
		}
		http.Redirect(w, r, "/auth/oauth/login?return_to="+url.QueryEscape(returnTo), http.StatusFound)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			sessionID, ok := getSessionID(r)
			if !ok {
				redirectToLogin(w, r)
				return
			}

			sess, ok := deps.SessionStore.Get(sessionID)
			if !ok {
				redirectToLogin(w, r)
				return
			}

			// Prefer the full user record from the database (e.g. for the
			// avatar); fall back to the identity held in the session when the
			// lookup fails or finds nothing.
			user, err := db.GetUserByDID(deps.Database, sess.DID)
			if err != nil || user == nil {
				user = &db.User{
					DID:         sess.DID,
					Handle:      sess.Handle,
					PDSEndpoint: sess.PDSEndpoint,
				}
			}

			ctx := context.WithValue(r.Context(), userKey, user)

			// Create UserContext for authenticated users (enables EnsureUserSetup).
			if deps.Refresher != nil {
				userCtx := auth.NewUserContext(sess.DID, auth.AuthMethodOAuth, r.Method, &auth.Dependencies{
					Refresher:      deps.Refresher,
					DefaultHoldDID: deps.DefaultHoldDID,
				})
				userCtx.SetPDS(sess.Handle, sess.PDSEndpoint)
				userCtx.EnsureUserSetup()
				ctx = auth.WithUserContext(ctx, userCtx)
			}

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// OptionalAuth returns middleware that attaches the user to the request when
// the caller is authenticated and otherwise lets the request through as is.
//
// It is a convenience wrapper around OptionalAuthWithDeps that supplies only
// the session store and the database; the remaining WebAuthDeps fields keep
// their zero values.
func OptionalAuth(store *db.SessionStore, database *sql.DB) func(http.Handler) http.Handler {
	deps := WebAuthDeps{SessionStore: store, Database: database}
	return OptionalAuthWithDeps(deps)
}
// OptionalAuthWithDeps returns middleware that enriches the request context
// with the user (and, when deps.Refresher is set, an auth.UserContext) if the
// request carries a session cookie known to deps.SessionStore.
//
// Requests without a usable session are passed to the wrapped handler
// unchanged; this middleware never redirects or rejects.
func OptionalAuthWithDeps(deps WebAuthDeps) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handler := func(w http.ResponseWriter, r *http.Request) {
			id, found := getSessionID(r)
			if !found {
				// No session cookie: continue anonymously.
				next.ServeHTTP(w, r)
				return
			}

			sess, found := deps.SessionStore.Get(id)
			if !found {
				// Cookie present but the session is unknown: continue anonymously.
				next.ServeHTTP(w, r)
				return
			}

			// Prefer the full user record from the database (e.g. for the
			// avatar); fall back to the identity held in the session when the
			// lookup fails or finds nothing.
			user, err := db.GetUserByDID(deps.Database, sess.DID)
			if user == nil || err != nil {
				user = &db.User{
					DID:         sess.DID,
					Handle:      sess.Handle,
					PDSEndpoint: sess.PDSEndpoint,
				}
			}

			ctx := context.WithValue(r.Context(), userKey, user)

			// Create UserContext for authenticated users (enables EnsureUserSetup).
			if deps.Refresher != nil {
				authDeps := &auth.Dependencies{
					Refresher:      deps.Refresher,
					DefaultHoldDID: deps.DefaultHoldDID,
				}
				userCtx := auth.NewUserContext(sess.DID, auth.AuthMethodOAuth, r.Method, authDeps)
				userCtx.SetPDS(sess.Handle, sess.PDSEndpoint)
				userCtx.EnsureUserSetup()
				ctx = auth.WithUserContext(ctx, userCtx)
			}

			next.ServeHTTP(w, r.WithContext(ctx))
		}
		return http.HandlerFunc(handler)
	}
}
// getSessionID reads the session ID from the "atcr_session" cookie.
// The boolean result is false when the request has no such cookie.
func getSessionID(r *http.Request) (string, bool) {
	if c, err := r.Cookie("atcr_session"); err == nil {
		return c.Value, true
	}
	return "", false
}
// GetUser returns the user stored in the request context under this
// package's user key, or nil when the context holds no *db.User there.
func GetUser(r *http.Request) *db.User {
	if u, ok := r.Context().Value(userKey).(*db.User); ok {
		return u
	}
	return nil
}