mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2025-11-26 03:23:07 +01:00
Refactor auth package
Fix security issues Remove redundant code Add documentation Add units tests
This commit is contained in:
@@ -18,7 +18,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -32,8 +31,19 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// Authenticator is the interface for all authentication methods.
|
||||
// Each authenticator determines if it can handle a login request (CanLogin)
|
||||
// and performs the actual authentication (Login).
|
||||
type Authenticator interface {
|
||||
// CanLogin determines if this authenticator can handle the login request.
|
||||
// It returns the user object if available and a boolean indicating if this
|
||||
// authenticator should attempt the login. This method should not perform
|
||||
// expensive operations or actual authentication.
|
||||
CanLogin(user *schema.User, username string, rw http.ResponseWriter, r *http.Request) (*schema.User, bool)
|
||||
|
||||
// Login performs the actually authentication for the user.
|
||||
// It returns the authenticated user or an error if authentication fails.
|
||||
// The user parameter may be nil if the user doesn't exist in the database yet.
|
||||
Login(user *schema.User, rw http.ResponseWriter, r *http.Request) (*schema.User, error)
|
||||
}
|
||||
|
||||
@@ -42,27 +52,70 @@ var (
|
||||
authInstance *Authentication
|
||||
)
|
||||
|
||||
var ipUserLimiters sync.Map
|
||||
|
||||
func getIPUserLimiter(ip, username string) *rate.Limiter {
|
||||
key := ip + ":" + username
|
||||
limiter, ok := ipUserLimiters.Load(key)
|
||||
if !ok {
|
||||
newLimiter := rate.NewLimiter(rate.Every(time.Hour/10), 10)
|
||||
ipUserLimiters.Store(key, newLimiter)
|
||||
return newLimiter
|
||||
}
|
||||
return limiter.(*rate.Limiter)
|
||||
// rateLimiterEntry tracks a rate limiter and its last use time for cleanup
|
||||
type rateLimiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
var ipUserLimiters sync.Map
|
||||
|
||||
// getIPUserLimiter returns a rate limiter for the given IP and username combination.
|
||||
// Rate limiters are created on demand and track 5 attempts per 15 minutes.
|
||||
func getIPUserLimiter(ip, username string) *rate.Limiter {
|
||||
key := ip + ":" + username
|
||||
now := time.Now()
|
||||
|
||||
if entry, ok := ipUserLimiters.Load(key); ok {
|
||||
rle := entry.(*rateLimiterEntry)
|
||||
rle.lastUsed = now
|
||||
return rle.limiter
|
||||
}
|
||||
|
||||
// More aggressive rate limiting: 5 attempts per 15 minutes
|
||||
newLimiter := rate.NewLimiter(rate.Every(15*time.Minute/5), 5)
|
||||
ipUserLimiters.Store(key, &rateLimiterEntry{
|
||||
limiter: newLimiter,
|
||||
lastUsed: now,
|
||||
})
|
||||
return newLimiter
|
||||
}
|
||||
|
||||
// cleanupOldRateLimiters removes rate limiters that haven't been used recently
|
||||
func cleanupOldRateLimiters(olderThan time.Time) {
|
||||
ipUserLimiters.Range(func(key, value any) bool {
|
||||
entry := value.(*rateLimiterEntry)
|
||||
if entry.lastUsed.Before(olderThan) {
|
||||
ipUserLimiters.Delete(key)
|
||||
cclog.Debugf("Cleaned up rate limiter for %v", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// startRateLimiterCleanup starts a background goroutine to clean up old rate limiters
|
||||
func startRateLimiterCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
// Clean up limiters not used in the last 24 hours
|
||||
cleanupOldRateLimiters(time.Now().Add(-24 * time.Hour))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// AuthConfig contains configuration for all authentication methods
|
||||
type AuthConfig struct {
|
||||
LdapConfig *LdapConfig `json:"ldap"`
|
||||
JwtConfig *JWTAuthConfig `json:"jwts"`
|
||||
OpenIDConfig *OpenIDConfig `json:"oidc"`
|
||||
}
|
||||
|
||||
// Keys holds the global authentication configuration
|
||||
var Keys AuthConfig
|
||||
|
||||
// Authentication manages all authentication methods and session handling
|
||||
type Authentication struct {
|
||||
sessionStore *sessions.CookieStore
|
||||
LdapAuth *LdapAuthenticator
|
||||
@@ -86,10 +139,31 @@ func (auth *Authentication) AuthViaSession(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TODO: Check if session keys exist
|
||||
username, _ := session.Values["username"].(string)
|
||||
projects, _ := session.Values["projects"].([]string)
|
||||
roles, _ := session.Values["roles"].([]string)
|
||||
// Validate session data with proper type checking
|
||||
username, ok := session.Values["username"].(string)
|
||||
if !ok || username == "" {
|
||||
cclog.Warn("Invalid session: missing or invalid username")
|
||||
// Invalidate the corrupted session
|
||||
session.Options.MaxAge = -1
|
||||
_ = auth.sessionStore.Save(r, rw, session)
|
||||
return nil, errors.New("invalid session data")
|
||||
}
|
||||
|
||||
projects, ok := session.Values["projects"].([]string)
|
||||
if !ok {
|
||||
cclog.Warn("Invalid session: projects not found or invalid type, using empty list")
|
||||
projects = []string{}
|
||||
}
|
||||
|
||||
roles, ok := session.Values["roles"].([]string)
|
||||
if !ok || len(roles) == 0 {
|
||||
cclog.Warn("Invalid session: missing or invalid roles")
|
||||
// Invalidate the corrupted session
|
||||
session.Options.MaxAge = -1
|
||||
_ = auth.sessionStore.Save(r, rw, session)
|
||||
return nil, errors.New("invalid session data")
|
||||
}
|
||||
|
||||
return &schema.User{
|
||||
Username: username,
|
||||
Projects: projects,
|
||||
@@ -102,6 +176,9 @@ func (auth *Authentication) AuthViaSession(
|
||||
func Init(authCfg *json.RawMessage) {
|
||||
initOnce.Do(func() {
|
||||
authInstance = &Authentication{}
|
||||
|
||||
// Start background cleanup of rate limiters
|
||||
startRateLimiterCleanup()
|
||||
|
||||
sessKey := os.Getenv("SESSION_KEY")
|
||||
if sessKey == "" {
|
||||
@@ -185,38 +262,36 @@ func GetAuthInstance() *Authentication {
|
||||
return authInstance
|
||||
}
|
||||
|
||||
func handleTokenUser(tokenUser *schema.User) {
|
||||
// handleUserSync syncs or updates a user in the database based on configuration.
|
||||
// This is used for both JWT and OIDC authentication when syncUserOnLogin or updateUserOnLogin is enabled.
|
||||
func handleUserSync(user *schema.User, syncUserOnLogin, updateUserOnLogin bool) {
|
||||
r := repository.GetUserRepository()
|
||||
dbUser, err := r.GetUser(tokenUser.Username)
|
||||
dbUser, err := r.GetUser(user.Username)
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
cclog.Errorf("Error while loading user '%s': %v", tokenUser.Username, err)
|
||||
} else if err == sql.ErrNoRows && Keys.JwtConfig.SyncUserOnLogin { // Adds New User
|
||||
if err := r.AddUser(tokenUser); err != nil {
|
||||
cclog.Errorf("Error while adding user '%s' to DB: %v", tokenUser.Username, err)
|
||||
cclog.Errorf("Error while loading user '%s': %v", user.Username, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err == sql.ErrNoRows && syncUserOnLogin { // Add new user
|
||||
if err := r.AddUser(user); err != nil {
|
||||
cclog.Errorf("Error while adding user '%s' to DB: %v", user.Username, err)
|
||||
}
|
||||
} else if err == nil && Keys.JwtConfig.UpdateUserOnLogin { // Update Existing User
|
||||
if err := r.UpdateUser(dbUser, tokenUser); err != nil {
|
||||
cclog.Errorf("Error while updating user '%s' to DB: %v", dbUser.Username, err)
|
||||
} else if err == nil && updateUserOnLogin { // Update existing user
|
||||
if err := r.UpdateUser(dbUser, user); err != nil {
|
||||
cclog.Errorf("Error while updating user '%s' in DB: %v", dbUser.Username, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleOIDCUser(OIDCUser *schema.User) {
|
||||
r := repository.GetUserRepository()
|
||||
dbUser, err := r.GetUser(OIDCUser.Username)
|
||||
// handleTokenUser syncs JWT token user with database
|
||||
func handleTokenUser(tokenUser *schema.User) {
|
||||
handleUserSync(tokenUser, Keys.JwtConfig.SyncUserOnLogin, Keys.JwtConfig.UpdateUserOnLogin)
|
||||
}
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
cclog.Errorf("Error while loading user '%s': %v", OIDCUser.Username, err)
|
||||
} else if err == sql.ErrNoRows && Keys.OpenIDConfig.SyncUserOnLogin { // Adds New User
|
||||
if err := r.AddUser(OIDCUser); err != nil {
|
||||
cclog.Errorf("Error while adding user '%s' to DB: %v", OIDCUser.Username, err)
|
||||
}
|
||||
} else if err == nil && Keys.OpenIDConfig.UpdateUserOnLogin { // Update Existing User
|
||||
if err := r.UpdateUser(dbUser, OIDCUser); err != nil {
|
||||
cclog.Errorf("Error while updating user '%s' to DB: %v", dbUser.Username, err)
|
||||
}
|
||||
}
|
||||
// handleOIDCUser syncs OIDC user with database
|
||||
func handleOIDCUser(OIDCUser *schema.User) {
|
||||
handleUserSync(OIDCUser, Keys.OpenIDConfig.SyncUserOnLogin, Keys.OpenIDConfig.UpdateUserOnLogin)
|
||||
}
|
||||
|
||||
func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request, user *schema.User) error {
|
||||
@@ -231,6 +306,7 @@ func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request,
|
||||
session.Options.MaxAge = int(auth.SessionMaxAge.Seconds())
|
||||
}
|
||||
if config.Keys.HTTPSCertFile == "" && config.Keys.HTTPSKeyFile == "" {
|
||||
cclog.Warn("HTTPS not configured - session cookies will not have Secure flag set (insecure for production)")
|
||||
session.Options.Secure = false
|
||||
}
|
||||
session.Options.SameSite = http.SameSiteStrictMode
|
||||
@@ -532,10 +608,13 @@ func securedCheck(user *schema.User, r *http.Request) error {
|
||||
IPAddress = r.RemoteAddr
|
||||
}
|
||||
|
||||
// FIXME: IPV6 not handled
|
||||
if strings.Contains(IPAddress, ":") {
|
||||
IPAddress = strings.Split(IPAddress, ":")[0]
|
||||
// Handle both IPv4 and IPv6 addresses properly
|
||||
// For IPv6, this will strip the port and brackets
|
||||
// For IPv4, this will strip the port
|
||||
if host, _, err := net.SplitHostPort(IPAddress); err == nil {
|
||||
IPAddress = host
|
||||
}
|
||||
// If SplitHostPort fails, IPAddress is already just a host (no port)
|
||||
|
||||
// If nothing declared in config: deny all request to this api endpoint
|
||||
if len(config.Keys.APIAllowedIPs) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user