diff --git a/internal/auth/auth.go b/internal/auth/auth.go index c6cb90e..5d94735 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -18,7 +18,6 @@ import ( "net" "net/http" "os" - "strings" "sync" "time" @@ -32,8 +31,19 @@ import ( "github.com/gorilla/sessions" ) +// Authenticator is the interface for all authentication methods. +// Each authenticator determines if it can handle a login request (CanLogin) +// and performs the actual authentication (Login). type Authenticator interface { + // CanLogin determines if this authenticator can handle the login request. + // It returns the user object if available and a boolean indicating if this + // authenticator should attempt the login. This method should not perform + // expensive operations or actual authentication. CanLogin(user *schema.User, username string, rw http.ResponseWriter, r *http.Request) (*schema.User, bool) + + // Login performs the actually authentication for the user. + // It returns the authenticated user or an error if authentication fails. + // The user parameter may be nil if the user doesn't exist in the database yet. Login(user *schema.User, rw http.ResponseWriter, r *http.Request) (*schema.User, error) } @@ -42,27 +52,70 @@ var ( authInstance *Authentication ) -var ipUserLimiters sync.Map - -func getIPUserLimiter(ip, username string) *rate.Limiter { - key := ip + ":" + username - limiter, ok := ipUserLimiters.Load(key) - if !ok { - newLimiter := rate.NewLimiter(rate.Every(time.Hour/10), 10) - ipUserLimiters.Store(key, newLimiter) - return newLimiter - } - return limiter.(*rate.Limiter) +// rateLimiterEntry tracks a rate limiter and its last use time for cleanup +type rateLimiterEntry struct { + limiter *rate.Limiter + lastUsed time.Time } +var ipUserLimiters sync.Map + +// getIPUserLimiter returns a rate limiter for the given IP and username combination. +// Rate limiters are created on demand and track 5 attempts per 15 minutes. +func getIPUserLimiter(ip, username string) *rate.Limiter { + key := ip + ":" + username + now := time.Now() + + if entry, ok := ipUserLimiters.Load(key); ok { + rle := entry.(*rateLimiterEntry) + rle.lastUsed = now + return rle.limiter + } + + // More aggressive rate limiting: 5 attempts per 15 minutes + newLimiter := rate.NewLimiter(rate.Every(15*time.Minute/5), 5) + ipUserLimiters.Store(key, &rateLimiterEntry{ + limiter: newLimiter, + lastUsed: now, + }) + return newLimiter +} + +// cleanupOldRateLimiters removes rate limiters that haven't been used recently +func cleanupOldRateLimiters(olderThan time.Time) { + ipUserLimiters.Range(func(key, value any) bool { + entry := value.(*rateLimiterEntry) + if entry.lastUsed.Before(olderThan) { + ipUserLimiters.Delete(key) + cclog.Debugf("Cleaned up rate limiter for %v", key) + } + return true + }) +} + +// startRateLimiterCleanup starts a background goroutine to clean up old rate limiters +func startRateLimiterCleanup() { + go func() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + for range ticker.C { + // Clean up limiters not used in the last 24 hours + cleanupOldRateLimiters(time.Now().Add(-24 * time.Hour)) + } + }() +} + +// AuthConfig contains configuration for all authentication methods type AuthConfig struct { LdapConfig *LdapConfig `json:"ldap"` JwtConfig *JWTAuthConfig `json:"jwts"` OpenIDConfig *OpenIDConfig `json:"oidc"` } +// Keys holds the global authentication configuration var Keys AuthConfig +// Authentication manages all authentication methods and session handling type Authentication struct { sessionStore *sessions.CookieStore LdapAuth *LdapAuthenticator @@ -86,10 +139,31 @@ func (auth *Authentication) AuthViaSession( return nil, nil } - // TODO: Check if session keys exist - username, _ := session.Values["username"].(string) - projects, _ := session.Values["projects"].([]string) - roles, _ := session.Values["roles"].([]string) + // Validate session data with proper type checking + username, ok := session.Values["username"].(string) + if !ok || username == "" { + cclog.Warn("Invalid session: missing or invalid username") + // Invalidate the corrupted session + session.Options.MaxAge = -1 + _ = auth.sessionStore.Save(r, rw, session) + return nil, errors.New("invalid session data") + } + + projects, ok := session.Values["projects"].([]string) + if !ok { + cclog.Warn("Invalid session: projects not found or invalid type, using empty list") + projects = []string{} + } + + roles, ok := session.Values["roles"].([]string) + if !ok || len(roles) == 0 { + cclog.Warn("Invalid session: missing or invalid roles") + // Invalidate the corrupted session + session.Options.MaxAge = -1 + _ = auth.sessionStore.Save(r, rw, session) + return nil, errors.New("invalid session data") + } + return &schema.User{ Username: username, Projects: projects, @@ -102,6 +176,9 @@ func (auth *Authentication) AuthViaSession( func Init(authCfg *json.RawMessage) { initOnce.Do(func() { authInstance = &Authentication{} + + // Start background cleanup of rate limiters + startRateLimiterCleanup() sessKey := os.Getenv("SESSION_KEY") if sessKey == "" { @@ -185,38 +262,36 @@ func GetAuthInstance() *Authentication { return authInstance } -func handleTokenUser(tokenUser *schema.User) { +// handleUserSync syncs or updates a user in the database based on configuration. +// This is used for both JWT and OIDC authentication when syncUserOnLogin or updateUserOnLogin is enabled. +func handleUserSync(user *schema.User, syncUserOnLogin, updateUserOnLogin bool) { r := repository.GetUserRepository() - dbUser, err := r.GetUser(tokenUser.Username) + dbUser, err := r.GetUser(user.Username) if err != nil && err != sql.ErrNoRows { - cclog.Errorf("Error while loading user '%s': %v", tokenUser.Username, err) - } else if err == sql.ErrNoRows && Keys.JwtConfig.SyncUserOnLogin { // Adds New User - if err := r.AddUser(tokenUser); err != nil { - cclog.Errorf("Error while adding user '%s' to DB: %v", tokenUser.Username, err) + cclog.Errorf("Error while loading user '%s': %v", user.Username, err) + return + } + + if err == sql.ErrNoRows && syncUserOnLogin { // Add new user + if err := r.AddUser(user); err != nil { + cclog.Errorf("Error while adding user '%s' to DB: %v", user.Username, err) } - } else if err == nil && Keys.JwtConfig.UpdateUserOnLogin { // Update Existing User - if err := r.UpdateUser(dbUser, tokenUser); err != nil { - cclog.Errorf("Error while updating user '%s' to DB: %v", dbUser.Username, err) + } else if err == nil && updateUserOnLogin { // Update existing user + if err := r.UpdateUser(dbUser, user); err != nil { + cclog.Errorf("Error while updating user '%s' in DB: %v", dbUser.Username, err) } } } -func handleOIDCUser(OIDCUser *schema.User) { - r := repository.GetUserRepository() - dbUser, err := r.GetUser(OIDCUser.Username) +// handleTokenUser syncs JWT token user with database +func handleTokenUser(tokenUser *schema.User) { + handleUserSync(tokenUser, Keys.JwtConfig.SyncUserOnLogin, Keys.JwtConfig.UpdateUserOnLogin) +} - if err != nil && err != sql.ErrNoRows { - cclog.Errorf("Error while loading user '%s': %v", OIDCUser.Username, err) - } else if err == sql.ErrNoRows && Keys.OpenIDConfig.SyncUserOnLogin { // Adds New User - if err := r.AddUser(OIDCUser); err != nil { - cclog.Errorf("Error while adding user '%s' to DB: %v", OIDCUser.Username, err) - } - } else if err == nil && Keys.OpenIDConfig.UpdateUserOnLogin { // Update Existing User - if err := r.UpdateUser(dbUser, OIDCUser); err != nil { - cclog.Errorf("Error while updating user '%s' to DB: %v", dbUser.Username, err) - } - } +// handleOIDCUser syncs OIDC user with database +func handleOIDCUser(OIDCUser *schema.User) { + handleUserSync(OIDCUser, Keys.OpenIDConfig.SyncUserOnLogin, Keys.OpenIDConfig.UpdateUserOnLogin) } func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request, user *schema.User) error { @@ -231,6 +306,7 @@ func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request, session.Options.MaxAge = int(auth.SessionMaxAge.Seconds()) } if config.Keys.HTTPSCertFile == "" && config.Keys.HTTPSKeyFile == "" { + cclog.Warn("HTTPS not configured - session cookies will not have Secure flag set (insecure for production)") session.Options.Secure = false } session.Options.SameSite = http.SameSiteStrictMode @@ -532,10 +608,13 @@ func securedCheck(user *schema.User, r *http.Request) error { IPAddress = r.RemoteAddr } - // FIXME: IPV6 not handled - if strings.Contains(IPAddress, ":") { - IPAddress = strings.Split(IPAddress, ":")[0] + // Handle both IPv4 and IPv6 addresses properly + // For IPv6, this will strip the port and brackets + // For IPv4, this will strip the port + if host, _, err := net.SplitHostPort(IPAddress); err == nil { + IPAddress = host } + // If SplitHostPort fails, IPAddress is already just a host (no port) // If nothing declared in config: deny all request to this api endpoint if len(config.Keys.APIAllowedIPs) == 0 { diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..15f153e --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,176 @@ +// Copyright (C) NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. This file is part of cc-backend. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "net" + "testing" + "time" +) + +// TestGetIPUserLimiter tests the rate limiter creation and retrieval +func TestGetIPUserLimiter(t *testing.T) { + ip := "192.168.1.1" + username := "testuser" + + // Get limiter for the first time + limiter1 := getIPUserLimiter(ip, username) + if limiter1 == nil { + t.Fatal("Expected limiter to be created") + } + + // Get the same limiter again + limiter2 := getIPUserLimiter(ip, username) + if limiter1 != limiter2 { + t.Error("Expected to get the same limiter instance") + } + + // Get a different limiter for different user + limiter3 := getIPUserLimiter(ip, "otheruser") + if limiter1 == limiter3 { + t.Error("Expected different limiter for different user") + } + + // Get a different limiter for different IP + limiter4 := getIPUserLimiter("192.168.1.2", username) + if limiter1 == limiter4 { + t.Error("Expected different limiter for different IP") + } +} + +// TestRateLimiterBehavior tests that rate limiting works correctly +func TestRateLimiterBehavior(t *testing.T) { + ip := "10.0.0.1" + username := "ratelimituser" + + limiter := getIPUserLimiter(ip, username) + + // Should allow first 5 attempts + for i := 0; i < 5; i++ { + if !limiter.Allow() { + t.Errorf("Request %d should be allowed within rate limit", i+1) + } + } + + // 6th attempt should be blocked + if limiter.Allow() { + t.Error("Request 6 should be blocked by rate limiter") + } +} + +// TestCleanupOldRateLimiters tests the cleanup function +func TestCleanupOldRateLimiters(t *testing.T) { + // Clear all existing limiters first to avoid interference from other tests + cleanupOldRateLimiters(time.Now().Add(24 * time.Hour)) + + // Create some new rate limiters + limiter1 := getIPUserLimiter("1.1.1.1", "user1") + limiter2 := getIPUserLimiter("2.2.2.2", "user2") + + if limiter1 == nil || limiter2 == nil { + t.Fatal("Failed to create test limiters") + } + + // Cleanup limiters older than 1 second from now (should keep both) + time.Sleep(10 * time.Millisecond) // Small delay to ensure timestamp difference + cleanupOldRateLimiters(time.Now().Add(-1 * time.Second)) + + // Verify they still exist (should get same instance) + if getIPUserLimiter("1.1.1.1", "user1") != limiter1 { + t.Error("Limiter 1 was incorrectly cleaned up") + } + if getIPUserLimiter("2.2.2.2", "user2") != limiter2 { + t.Error("Limiter 2 was incorrectly cleaned up") + } + + // Cleanup limiters older than 1 hour from now (should remove both) + cleanupOldRateLimiters(time.Now().Add(2 * time.Hour)) + + // Getting them again should create new instances + newLimiter1 := getIPUserLimiter("1.1.1.1", "user1") + if newLimiter1 == limiter1 { + t.Error("Old limiter should have been cleaned up") + } +} + +// TestIPv4Extraction tests extracting IPv4 addresses +func TestIPv4Extraction(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"IPv4 with port", "192.168.1.1:8080", "192.168.1.1"}, + {"IPv4 without port", "192.168.1.1", "192.168.1.1"}, + {"Localhost with port", "127.0.0.1:3000", "127.0.0.1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.input + if host, _, err := net.SplitHostPort(result); err == nil { + result = host + } + + if result != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, result) + } + }) + } +} + +// TestIPv6Extraction tests extracting IPv6 addresses +func TestIPv6Extraction(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"IPv6 with port", "[2001:db8::1]:8080", "2001:db8::1"}, + {"IPv6 localhost with port", "[::1]:3000", "::1"}, + {"IPv6 without port", "2001:db8::1", "2001:db8::1"}, + {"IPv6 localhost", "::1", "::1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.input + if host, _, err := net.SplitHostPort(result); err == nil { + result = host + } + + if result != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, result) + } + }) + } +} + +// TestIPExtractionEdgeCases tests edge cases for IP extraction +func TestIPExtractionEdgeCases(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"Hostname without port", "example.com", "example.com"}, + {"Empty string", "", ""}, + {"Just port", ":8080", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.input + if host, _, err := net.SplitHostPort(result); err == nil { + result = host + } + + if result != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, result) + } + }) + } +} diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 7d7a047..4f1f3f5 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -14,7 +14,6 @@ import ( "strings" "time" - "github.com/ClusterCockpit/cc-backend/internal/repository" cclog "github.com/ClusterCockpit/cc-lib/ccLogger" "github.com/ClusterCockpit/cc-lib/schema" "github.com/golang-jwt/jwt/v5" @@ -102,38 +101,21 @@ func (ja *JWTAuthenticator) AuthViaJWT( // Token is valid, extract payload claims := token.Claims.(jwt.MapClaims) - sub, _ := claims["sub"].(string) - - var roles []string - - // Validate user + roles from JWT against database? - if Keys.JwtConfig.ValidateUser { - ur := repository.GetUserRepository() - user, err := ur.GetUser(sub) - // Deny any logins for unknown usernames - if err != nil { - cclog.Warn("Could not find user from JWT in internal database.") - return nil, errors.New("unknown user") - } - // Take user roles from database instead of trusting the JWT - roles = user.Roles - } else { - // Extract roles from JWT (if present) - if rawroles, ok := claims["roles"].([]any); ok { - for _, rr := range rawroles { - if r, ok := rr.(string); ok { - roles = append(roles, r) - } - } - } + + // Use shared helper to get user from JWT claims + var user *schema.User + user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthToken, -1) + if err != nil { + return nil, err } - - return &schema.User{ - Username: sub, - Roles: roles, - AuthType: schema.AuthToken, - AuthSource: -1, - }, nil + + // If not validating user, we only get roles from JWT (no projects for this auth method) + if !Keys.JwtConfig.ValidateUser { + user.Roles = extractRolesFromClaims(claims, false) + user.Projects = nil // Standard JWT auth doesn't include projects + } + + return user, nil } // ProvideJWT generates a new JWT that can be used for authentication diff --git a/internal/auth/jwtCookieSession.go b/internal/auth/jwtCookieSession.go index 300f875..44c64a0 100644 --- a/internal/auth/jwtCookieSession.go +++ b/internal/auth/jwtCookieSession.go @@ -7,14 +7,11 @@ package auth import ( "crypto/ed25519" - "database/sql" "encoding/base64" "errors" - "fmt" "net/http" "os" - "github.com/ClusterCockpit/cc-backend/internal/repository" cclog "github.com/ClusterCockpit/cc-lib/ccLogger" "github.com/ClusterCockpit/cc-lib/schema" "github.com/golang-jwt/jwt/v5" @@ -149,57 +146,16 @@ func (ja *JWTCookieSessionAuthenticator) Login( } claims := token.Claims.(jwt.MapClaims) - sub, _ := claims["sub"].(string) - - var roles []string - projects := make([]string, 0) - - if jc.ValidateUser { - var err error - user, err = repository.GetUserRepository().GetUser(sub) - if err != nil && err != sql.ErrNoRows { - cclog.Errorf("Error while loading user '%v'", sub) - } - - // Deny any logins for unknown usernames - if user == nil { - cclog.Warn("Could not find user from JWT in internal database.") - return nil, errors.New("unknown user") - } - } else { - var name string - if wrap, ok := claims["name"].(map[string]any); ok { - if vals, ok := wrap["values"].([]any); ok { - if len(vals) != 0 { - name = fmt.Sprintf("%v", vals[0]) - - for i := 1; i < len(vals); i++ { - name += fmt.Sprintf(" %v", vals[i]) - } - } - } - } - - // Extract roles from JWT (if present) - if rawroles, ok := claims["roles"].([]any); ok { - for _, rr := range rawroles { - if r, ok := rr.(string); ok { - roles = append(roles, r) - } - } - } - user = &schema.User{ - Username: sub, - Name: name, - Roles: roles, - Projects: projects, - AuthType: schema.AuthSession, - AuthSource: schema.AuthViaToken, - } - - if jc.SyncUserOnLogin || jc.UpdateUserOnLogin { - handleTokenUser(user) - } + + // Use shared helper to get user from JWT claims + user, err = getUserFromJWT(claims, jc.ValidateUser, schema.AuthSession, schema.AuthViaToken) + if err != nil { + return nil, err + } + + // Sync or update user if configured + if !jc.ValidateUser && (jc.SyncUserOnLogin || jc.UpdateUserOnLogin) { + handleTokenUser(user) } // (Ask browser to) Delete JWT cookie diff --git a/internal/auth/jwtHelpers.go b/internal/auth/jwtHelpers.go new file mode 100644 index 0000000..792722a --- /dev/null +++ b/internal/auth/jwtHelpers.go @@ -0,0 +1,136 @@ +// Copyright (C) NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. This file is part of cc-backend. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "database/sql" + "errors" + "fmt" + + "github.com/ClusterCockpit/cc-backend/internal/repository" + cclog "github.com/ClusterCockpit/cc-lib/ccLogger" + "github.com/ClusterCockpit/cc-lib/schema" + "github.com/golang-jwt/jwt/v5" +) + +// extractStringFromClaims extracts a string value from JWT claims +func extractStringFromClaims(claims jwt.MapClaims, key string) string { + if val, ok := claims[key].(string); ok { + return val + } + return "" +} + +// extractRolesFromClaims extracts roles from JWT claims +// If validateRoles is true, only valid roles are returned +func extractRolesFromClaims(claims jwt.MapClaims, validateRoles bool) []string { + var roles []string + + if rawroles, ok := claims["roles"].([]any); ok { + for _, rr := range rawroles { + if r, ok := rr.(string); ok { + if validateRoles { + if schema.IsValidRole(r) { + roles = append(roles, r) + } + } else { + roles = append(roles, r) + } + } + } + } + + return roles +} + +// extractProjectsFromClaims extracts projects from JWT claims +func extractProjectsFromClaims(claims jwt.MapClaims) []string { + projects := make([]string, 0) + + if rawprojs, ok := claims["projects"].([]any); ok { + for _, pp := range rawprojs { + if p, ok := pp.(string); ok { + projects = append(projects, p) + } + } + } else if rawprojs, ok := claims["projects"]; ok { + if projSlice, ok := rawprojs.([]string); ok { + projects = append(projects, projSlice...) + } + } + + return projects +} + +// extractNameFromClaims extracts name from JWT claims +// Handles both simple string and complex nested structure +func extractNameFromClaims(claims jwt.MapClaims) string { + // Try simple string first + if name, ok := claims["name"].(string); ok { + return name + } + + // Try nested structure: {name: {values: [...]}} + if wrap, ok := claims["name"].(map[string]any); ok { + if vals, ok := wrap["values"].([]any); ok { + if len(vals) == 0 { + return "" + } + + name := fmt.Sprintf("%v", vals[0]) + for i := 1; i < len(vals); i++ { + name += fmt.Sprintf(" %v", vals[i]) + } + return name + } + } + + return "" +} + +// getUserFromJWT creates or retrieves a user based on JWT claims +// If validateUser is true, the user must exist in the database +// Otherwise, a new user object is created from claims +// authSource should be a schema.AuthSource constant (like schema.AuthViaToken) +func getUserFromJWT(claims jwt.MapClaims, validateUser bool, authType schema.AuthType, authSource schema.AuthSource) (*schema.User, error) { + sub := extractStringFromClaims(claims, "sub") + if sub == "" { + return nil, errors.New("missing 'sub' claim in JWT") + } + + if validateUser { + // Validate user against database + ur := repository.GetUserRepository() + user, err := ur.GetUser(sub) + if err != nil && err != sql.ErrNoRows { + cclog.Errorf("Error while loading user '%v': %v", sub, err) + return nil, fmt.Errorf("database error: %w", err) + } + + // Deny any logins for unknown usernames + if user == nil || err == sql.ErrNoRows { + cclog.Warn("Could not find user from JWT in internal database.") + return nil, errors.New("unknown user") + } + + // Return database user (with database roles) + return user, nil + } + + // Create user from JWT claims + name := extractNameFromClaims(claims) + roles := extractRolesFromClaims(claims, true) // Validate roles + projects := extractProjectsFromClaims(claims) + + return &schema.User{ + Username: sub, + Name: name, + Roles: roles, + Projects: projects, + AuthType: authType, + AuthSource: authSource, + }, nil +} diff --git a/internal/auth/jwtHelpers_test.go b/internal/auth/jwtHelpers_test.go new file mode 100644 index 0000000..5cee1df --- /dev/null +++ b/internal/auth/jwtHelpers_test.go @@ -0,0 +1,281 @@ +// Copyright (C) NHR@FAU, University Erlangen-Nuremberg. +// All rights reserved. This file is part of cc-backend. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package auth + +import ( + "testing" + + "github.com/ClusterCockpit/cc-lib/schema" + "github.com/golang-jwt/jwt/v5" +) + +// TestExtractStringFromClaims tests extracting string values from JWT claims +func TestExtractStringFromClaims(t *testing.T) { + claims := jwt.MapClaims{ + "sub": "testuser", + "email": "test@example.com", + "age": 25, // not a string + } + + tests := []struct { + name string + key string + expected string + }{ + {"Existing string", "sub", "testuser"}, + {"Another string", "email", "test@example.com"}, + {"Non-existent key", "missing", ""}, + {"Non-string value", "age", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractStringFromClaims(claims, tt.key) + if result != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, result) + } + }) + } +} + +// TestExtractRolesFromClaims tests role extraction and validation +func TestExtractRolesFromClaims(t *testing.T) { + tests := []struct { + name string + claims jwt.MapClaims + validateRoles bool + expected []string + }{ + { + name: "Valid roles without validation", + claims: jwt.MapClaims{ + "roles": []any{"admin", "user", "invalid_role"}, + }, + validateRoles: false, + expected: []string{"admin", "user", "invalid_role"}, + }, + { + name: "Valid roles with validation", + claims: jwt.MapClaims{ + "roles": []any{"admin", "user", "api"}, + }, + validateRoles: true, + expected: []string{"admin", "user", "api"}, + }, + { + name: "Invalid roles with validation", + claims: jwt.MapClaims{ + "roles": []any{"invalid_role", "fake_role"}, + }, + validateRoles: true, + expected: []string{}, // Should filter out invalid roles + }, + { + name: "No roles claim", + claims: jwt.MapClaims{}, + validateRoles: false, + expected: []string{}, + }, + { + name: "Non-array roles", + claims: jwt.MapClaims{ + "roles": "admin", + }, + validateRoles: false, + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractRolesFromClaims(tt.claims, tt.validateRoles) + + if len(result) != len(tt.expected) { + t.Errorf("Expected %d roles, got %d", len(tt.expected), len(result)) + return + } + + for i, role := range result { + if i >= len(tt.expected) || role != tt.expected[i] { + t.Errorf("Expected role %s at position %d, got %s", tt.expected[i], i, role) + } + } + }) + } +} + +// TestExtractProjectsFromClaims tests project extraction from claims +func TestExtractProjectsFromClaims(t *testing.T) { + tests := []struct { + name string + claims jwt.MapClaims + expected []string + }{ + { + name: "Projects as array of interfaces", + claims: jwt.MapClaims{ + "projects": []any{"project1", "project2", "project3"}, + }, + expected: []string{"project1", "project2", "project3"}, + }, + { + name: "Projects as string array", + claims: jwt.MapClaims{ + "projects": []string{"projectA", "projectB"}, + }, + expected: []string{"projectA", "projectB"}, + }, + { + name: "No projects claim", + claims: jwt.MapClaims{}, + expected: []string{}, + }, + { + name: "Mixed types in projects array", + claims: jwt.MapClaims{ + "projects": []any{"project1", 123, "project2"}, + }, + expected: []string{"project1", "project2"}, // Should skip non-strings + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractProjectsFromClaims(tt.claims) + + if len(result) != len(tt.expected) { + t.Errorf("Expected %d projects, got %d", len(tt.expected), len(result)) + return + } + + for i, project := range result { + if i >= len(tt.expected) || project != tt.expected[i] { + t.Errorf("Expected project %s at position %d, got %s", tt.expected[i], i, project) + } + } + }) + } +} + +// TestExtractNameFromClaims tests name extraction from various formats +func TestExtractNameFromClaims(t *testing.T) { + tests := []struct { + name string + claims jwt.MapClaims + expected string + }{ + { + name: "Simple string name", + claims: jwt.MapClaims{ + "name": "John Doe", + }, + expected: "John Doe", + }, + { + name: "Nested name structure", + claims: jwt.MapClaims{ + "name": map[string]any{ + "values": []any{"John", "Doe"}, + }, + }, + expected: "John Doe", + }, + { + name: "Nested name with single value", + claims: jwt.MapClaims{ + "name": map[string]any{ + "values": []any{"Alice"}, + }, + }, + expected: "Alice", + }, + { + name: "No name claim", + claims: jwt.MapClaims{}, + expected: "", + }, + { + name: "Empty nested values", + claims: jwt.MapClaims{ + "name": map[string]any{ + "values": []any{}, + }, + }, + expected: "", + }, + { + name: "Nested with non-string values", + claims: jwt.MapClaims{ + "name": map[string]any{ + "values": []any{123, "Smith"}, + }, + }, + expected: "123 Smith", // Should convert to string + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractNameFromClaims(tt.claims) + if result != tt.expected { + t.Errorf("Expected '%s', got '%s'", tt.expected, result) + } + }) + } +} + +// TestGetUserFromJWT_NoValidation tests getUserFromJWT without database validation +func TestGetUserFromJWT_NoValidation(t *testing.T) { + claims := jwt.MapClaims{ + "sub": "testuser", + "name": "Test User", + "roles": []any{"user", "admin"}, + "projects": []any{"project1", "project2"}, + } + + user, err := getUserFromJWT(claims, false, schema.AuthToken, -1) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if user.Username != "testuser" { + t.Errorf("Expected username 'testuser', got '%s'", user.Username) + } + + if user.Name != "Test User" { + t.Errorf("Expected name 'Test User', got '%s'", user.Name) + } + + if len(user.Roles) != 2 { + t.Errorf("Expected 2 roles, got %d", len(user.Roles)) + } + + if len(user.Projects) != 2 { + t.Errorf("Expected 2 projects, got %d", len(user.Projects)) + } + + if user.AuthType != schema.AuthToken { + t.Errorf("Expected AuthType %v, got %v", schema.AuthToken, user.AuthType) + } +} + +// TestGetUserFromJWT_MissingSub tests error when sub claim is missing +func TestGetUserFromJWT_MissingSub(t *testing.T) { + claims := jwt.MapClaims{ + "name": "Test User", + } + + _, err := getUserFromJWT(claims, false, schema.AuthToken, -1) + + if err == nil { + t.Error("Expected error for missing sub claim") + } + + if err.Error() != "missing 'sub' claim in JWT" { + t.Errorf("Expected specific error message, got: %v", err) + } +} diff --git a/internal/auth/jwtSession.go b/internal/auth/jwtSession.go index 714b986..15e5834 100644 --- a/internal/auth/jwtSession.go +++ b/internal/auth/jwtSession.go @@ -6,7 +6,6 @@ package auth import ( - "database/sql" "encoding/base64" "errors" "fmt" @@ -14,7 +13,6 @@ import ( "os" "strings" - "github.com/ClusterCockpit/cc-backend/internal/repository" cclog "github.com/ClusterCockpit/cc-lib/ccLogger" "github.com/ClusterCockpit/cc-lib/schema" "github.com/golang-jwt/jwt/v5" @@ -77,70 +75,16 @@ func (ja *JWTSessionAuthenticator) Login( } claims := token.Claims.(jwt.MapClaims) - sub, _ := claims["sub"].(string) - - var roles []string - projects := make([]string, 0) - - if Keys.JwtConfig.ValidateUser { - var err error - user, err = repository.GetUserRepository().GetUser(sub) - if err != nil && err != sql.ErrNoRows { - cclog.Errorf("Error while loading user '%v'", sub) - } - - // Deny any logins for unknown usernames - if user == nil { - cclog.Warn("Could not find user from JWT in internal database.") - return nil, errors.New("unknown user") - } - } else { - var name string - if wrap, ok := claims["name"].(map[string]any); ok { - if vals, ok := wrap["values"].([]any); ok { - if len(vals) != 0 { - name = fmt.Sprintf("%v", vals[0]) - - for i := 1; i < len(vals); i++ { - name += fmt.Sprintf(" %v", vals[i]) - } - } - } - } - - // Extract roles from JWT (if present) - if rawroles, ok := claims["roles"].([]any); ok { - for _, rr := range rawroles { - if r, ok := rr.(string); ok { - if schema.IsValidRole(r) { - roles = append(roles, r) - } - } - } - } - - if rawprojs, ok := claims["projects"].([]any); ok { - for _, pp := range rawprojs { - if p, ok := pp.(string); ok { - projects = append(projects, p) - } - } - } else if rawprojs, ok := claims["projects"]; ok { - projects = append(projects, rawprojs.([]string)...) - } - - user = &schema.User{ - Username: sub, - Name: name, - Roles: roles, - Projects: projects, - AuthType: schema.AuthSession, - AuthSource: schema.AuthViaToken, - } - - if Keys.JwtConfig.SyncUserOnLogin || Keys.JwtConfig.UpdateUserOnLogin { - handleTokenUser(user) - } + + // Use shared helper to get user from JWT claims + user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthSession, schema.AuthViaToken) + if err != nil { + return nil, err + } + + // Sync or update user if configured + if !Keys.JwtConfig.ValidateUser && (Keys.JwtConfig.SyncUserOnLogin || Keys.JwtConfig.UpdateUserOnLogin) { + handleTokenUser(user) } return user, nil diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index b3def70..9e36130 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -54,8 +54,13 @@ func setCallbackCookie(w http.ResponseWriter, r *http.Request, name, value strin http.SetCookie(w, c) } +// NewOIDC creates a new OIDC authenticator with the configured provider func NewOIDC(a *Authentication) *OIDC { - provider, err := oidc.NewProvider(context.Background(), Keys.OpenIDConfig.Provider) + // Use context with timeout for provider initialization + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + provider, err := oidc.NewProvider(ctx, Keys.OpenIDConfig.Provider) if err != nil { cclog.Fatal(err) } @@ -111,13 +116,18 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) { http.Error(rw, "Code not found", http.StatusBadRequest) return } - token, err := oa.client.Exchange(context.Background(), code, oauth2.VerifierOption(codeVerifier)) + // Exchange authorization code for token with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + token, err := oa.client.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier)) if err != nil { http.Error(rw, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) return } - userInfo, err := oa.provider.UserInfo(context.Background(), oauth2.StaticTokenSource(token)) + // Get user info from OIDC provider with same timeout + userInfo, err := oa.provider.UserInfo(ctx, oauth2.StaticTokenSource(token)) if err != nil { http.Error(rw, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError) return @@ -180,8 +190,8 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) { oa.authentication.SaveSession(rw, r, user) cclog.Infof("login successfull: user: %#v (roles: %v, projects: %v)", user.Username, user.Roles, user.Projects) - ctx := context.WithValue(r.Context(), repository.ContextUserKey, user) - http.RedirectHandler("/", http.StatusTemporaryRedirect).ServeHTTP(rw, r.WithContext(ctx)) + userCtx := context.WithValue(r.Context(), repository.ContextUserKey, user) + http.RedirectHandler("/", http.StatusTemporaryRedirect).ServeHTTP(rw, r.WithContext(userCtx)) } func (oa *OIDC) OAuth2Login(rw http.ResponseWriter, r *http.Request) {