mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2026-01-15 17:21:46 +01:00
Reformat with gofumpt
This commit is contained in:
@@ -40,7 +40,7 @@ type Authenticator interface {
|
||||
// authenticator should attempt the login. This method should not perform
|
||||
// expensive operations or actual authentication.
|
||||
CanLogin(user *schema.User, username string, rw http.ResponseWriter, r *http.Request) (*schema.User, bool)
|
||||
|
||||
|
||||
// Login performs the actually authentication for the user.
|
||||
// It returns the authenticated user or an error if authentication fails.
|
||||
// The user parameter may be nil if the user doesn't exist in the database yet.
|
||||
@@ -65,13 +65,13 @@ var ipUserLimiters sync.Map
|
||||
func getIPUserLimiter(ip, username string) *rate.Limiter {
|
||||
key := ip + ":" + username
|
||||
now := time.Now()
|
||||
|
||||
|
||||
if entry, ok := ipUserLimiters.Load(key); ok {
|
||||
rle := entry.(*rateLimiterEntry)
|
||||
rle.lastUsed = now
|
||||
return rle.limiter
|
||||
}
|
||||
|
||||
|
||||
// More aggressive rate limiting: 5 attempts per 15 minutes
|
||||
newLimiter := rate.NewLimiter(rate.Every(15*time.Minute/5), 5)
|
||||
ipUserLimiters.Store(key, &rateLimiterEntry{
|
||||
@@ -176,7 +176,7 @@ func (auth *Authentication) AuthViaSession(
|
||||
func Init(authCfg *json.RawMessage) {
|
||||
initOnce.Do(func() {
|
||||
authInstance = &Authentication{}
|
||||
|
||||
|
||||
// Start background cleanup of rate limiters
|
||||
startRateLimiterCleanup()
|
||||
|
||||
@@ -272,7 +272,7 @@ func handleUserSync(user *schema.User, syncUserOnLogin, updateUserOnLogin bool)
|
||||
cclog.Errorf("Error while loading user '%s': %v", user.Username, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
if err == sql.ErrNoRows && syncUserOnLogin { // Add new user
|
||||
if err := r.AddUser(user); err != nil {
|
||||
cclog.Errorf("Error while adding user '%s' to DB: %v", user.Username, err)
|
||||
|
||||
@@ -15,25 +15,25 @@ import (
|
||||
func TestGetIPUserLimiter(t *testing.T) {
|
||||
ip := "192.168.1.1"
|
||||
username := "testuser"
|
||||
|
||||
|
||||
// Get limiter for the first time
|
||||
limiter1 := getIPUserLimiter(ip, username)
|
||||
if limiter1 == nil {
|
||||
t.Fatal("Expected limiter to be created")
|
||||
}
|
||||
|
||||
|
||||
// Get the same limiter again
|
||||
limiter2 := getIPUserLimiter(ip, username)
|
||||
if limiter1 != limiter2 {
|
||||
t.Error("Expected to get the same limiter instance")
|
||||
}
|
||||
|
||||
|
||||
// Get a different limiter for different user
|
||||
limiter3 := getIPUserLimiter(ip, "otheruser")
|
||||
if limiter1 == limiter3 {
|
||||
t.Error("Expected different limiter for different user")
|
||||
}
|
||||
|
||||
|
||||
// Get a different limiter for different IP
|
||||
limiter4 := getIPUserLimiter("192.168.1.2", username)
|
||||
if limiter1 == limiter4 {
|
||||
@@ -45,16 +45,16 @@ func TestGetIPUserLimiter(t *testing.T) {
|
||||
func TestRateLimiterBehavior(t *testing.T) {
|
||||
ip := "10.0.0.1"
|
||||
username := "ratelimituser"
|
||||
|
||||
|
||||
limiter := getIPUserLimiter(ip, username)
|
||||
|
||||
|
||||
// Should allow first 5 attempts
|
||||
for i := 0; i < 5; i++ {
|
||||
if !limiter.Allow() {
|
||||
t.Errorf("Request %d should be allowed within rate limit", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 6th attempt should be blocked
|
||||
if limiter.Allow() {
|
||||
t.Error("Request 6 should be blocked by rate limiter")
|
||||
@@ -65,19 +65,19 @@ func TestRateLimiterBehavior(t *testing.T) {
|
||||
func TestCleanupOldRateLimiters(t *testing.T) {
|
||||
// Clear all existing limiters first to avoid interference from other tests
|
||||
cleanupOldRateLimiters(time.Now().Add(24 * time.Hour))
|
||||
|
||||
|
||||
// Create some new rate limiters
|
||||
limiter1 := getIPUserLimiter("1.1.1.1", "user1")
|
||||
limiter2 := getIPUserLimiter("2.2.2.2", "user2")
|
||||
|
||||
|
||||
if limiter1 == nil || limiter2 == nil {
|
||||
t.Fatal("Failed to create test limiters")
|
||||
}
|
||||
|
||||
|
||||
// Cleanup limiters older than 1 second from now (should keep both)
|
||||
time.Sleep(10 * time.Millisecond) // Small delay to ensure timestamp difference
|
||||
cleanupOldRateLimiters(time.Now().Add(-1 * time.Second))
|
||||
|
||||
|
||||
// Verify they still exist (should get same instance)
|
||||
if getIPUserLimiter("1.1.1.1", "user1") != limiter1 {
|
||||
t.Error("Limiter 1 was incorrectly cleaned up")
|
||||
@@ -85,10 +85,10 @@ func TestCleanupOldRateLimiters(t *testing.T) {
|
||||
if getIPUserLimiter("2.2.2.2", "user2") != limiter2 {
|
||||
t.Error("Limiter 2 was incorrectly cleaned up")
|
||||
}
|
||||
|
||||
|
||||
// Cleanup limiters older than 1 hour from now (should remove both)
|
||||
cleanupOldRateLimiters(time.Now().Add(2 * time.Hour))
|
||||
|
||||
|
||||
// Getting them again should create new instances
|
||||
newLimiter1 := getIPUserLimiter("1.1.1.1", "user1")
|
||||
if newLimiter1 == limiter1 {
|
||||
@@ -107,14 +107,14 @@ func TestIPv4Extraction(t *testing.T) {
|
||||
{"IPv4 without port", "192.168.1.1", "192.168.1.1"},
|
||||
{"Localhost with port", "127.0.0.1:3000", "127.0.0.1"},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.input
|
||||
if host, _, err := net.SplitHostPort(result); err == nil {
|
||||
result = host
|
||||
}
|
||||
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func TestIPv4Extraction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestIPv6Extraction tests extracting IPv6 addresses
|
||||
// TestIPv6Extraction tests extracting IPv6 addresses
|
||||
func TestIPv6Extraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -134,14 +134,14 @@ func TestIPv6Extraction(t *testing.T) {
|
||||
{"IPv6 without port", "2001:db8::1", "2001:db8::1"},
|
||||
{"IPv6 localhost", "::1", "::1"},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.input
|
||||
if host, _, err := net.SplitHostPort(result); err == nil {
|
||||
result = host
|
||||
}
|
||||
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
@@ -160,14 +160,14 @@ func TestIPExtractionEdgeCases(t *testing.T) {
|
||||
{"Empty string", "", ""},
|
||||
{"Just port", ":8080", ""},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.input
|
||||
if host, _, err := net.SplitHostPort(result); err == nil {
|
||||
result = host
|
||||
}
|
||||
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
|
||||
@@ -101,20 +101,20 @@ func (ja *JWTAuthenticator) AuthViaJWT(
|
||||
|
||||
// Token is valid, extract payload
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
|
||||
|
||||
// Use shared helper to get user from JWT claims
|
||||
var user *schema.User
|
||||
user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthToken, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// If not validating user, we only get roles from JWT (no projects for this auth method)
|
||||
if !Keys.JwtConfig.ValidateUser {
|
||||
user.Roles = extractRolesFromClaims(claims, false)
|
||||
user.Projects = nil // Standard JWT auth doesn't include projects
|
||||
}
|
||||
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -146,13 +146,13 @@ func (ja *JWTCookieSessionAuthenticator) Login(
|
||||
}
|
||||
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
|
||||
|
||||
// Use shared helper to get user from JWT claims
|
||||
user, err = getUserFromJWT(claims, jc.ValidateUser, schema.AuthSession, schema.AuthViaToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// Sync or update user if configured
|
||||
if !jc.ValidateUser && (jc.SyncUserOnLogin || jc.UpdateUserOnLogin) {
|
||||
handleTokenUser(user)
|
||||
|
||||
@@ -28,7 +28,7 @@ func extractStringFromClaims(claims jwt.MapClaims, key string) string {
|
||||
// If validateRoles is true, only valid roles are returned
|
||||
func extractRolesFromClaims(claims jwt.MapClaims, validateRoles bool) []string {
|
||||
var roles []string
|
||||
|
||||
|
||||
if rawroles, ok := claims["roles"].([]any); ok {
|
||||
for _, rr := range rawroles {
|
||||
if r, ok := rr.(string); ok {
|
||||
@@ -42,14 +42,14 @@ func extractRolesFromClaims(claims jwt.MapClaims, validateRoles bool) []string {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return roles
|
||||
}
|
||||
|
||||
// extractProjectsFromClaims extracts projects from JWT claims
|
||||
func extractProjectsFromClaims(claims jwt.MapClaims) []string {
|
||||
projects := make([]string, 0)
|
||||
|
||||
|
||||
if rawprojs, ok := claims["projects"].([]any); ok {
|
||||
for _, pp := range rawprojs {
|
||||
if p, ok := pp.(string); ok {
|
||||
@@ -61,7 +61,7 @@ func extractProjectsFromClaims(claims jwt.MapClaims) []string {
|
||||
projects = append(projects, projSlice...)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return projects
|
||||
}
|
||||
|
||||
@@ -72,14 +72,14 @@ func extractNameFromClaims(claims jwt.MapClaims) string {
|
||||
if name, ok := claims["name"].(string); ok {
|
||||
return name
|
||||
}
|
||||
|
||||
|
||||
// Try nested structure: {name: {values: [...]}}
|
||||
if wrap, ok := claims["name"].(map[string]any); ok {
|
||||
if vals, ok := wrap["values"].([]any); ok {
|
||||
if len(vals) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
name := fmt.Sprintf("%v", vals[0])
|
||||
for i := 1; i < len(vals); i++ {
|
||||
name += fmt.Sprintf(" %v", vals[i])
|
||||
@@ -87,7 +87,7 @@ func extractNameFromClaims(claims jwt.MapClaims) string {
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func getUserFromJWT(claims jwt.MapClaims, validateUser bool, authType schema.Aut
|
||||
if sub == "" {
|
||||
return nil, errors.New("missing 'sub' claim in JWT")
|
||||
}
|
||||
|
||||
|
||||
if validateUser {
|
||||
// Validate user against database
|
||||
ur := repository.GetUserRepository()
|
||||
@@ -109,22 +109,22 @@ func getUserFromJWT(claims jwt.MapClaims, validateUser bool, authType schema.Aut
|
||||
cclog.Errorf("Error while loading user '%v': %v", sub, err)
|
||||
return nil, fmt.Errorf("database error: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// Deny any logins for unknown usernames
|
||||
if user == nil || err == sql.ErrNoRows {
|
||||
cclog.Warn("Could not find user from JWT in internal database.")
|
||||
return nil, errors.New("unknown user")
|
||||
}
|
||||
|
||||
|
||||
// Return database user (with database roles)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
|
||||
// Create user from JWT claims
|
||||
name := extractNameFromClaims(claims)
|
||||
roles := extractRolesFromClaims(claims, true) // Validate roles
|
||||
projects := extractProjectsFromClaims(claims)
|
||||
|
||||
|
||||
return &schema.User{
|
||||
Username: sub,
|
||||
Name: name,
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestExtractStringFromClaims(t *testing.T) {
|
||||
"email": "test@example.com",
|
||||
"age": 25, // not a string
|
||||
}
|
||||
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
@@ -30,7 +30,7 @@ func TestExtractStringFromClaims(t *testing.T) {
|
||||
{"Non-existent key", "missing", ""},
|
||||
{"Non-string value", "age", ""},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractStringFromClaims(claims, tt.key)
|
||||
@@ -88,16 +88,16 @@ func TestExtractRolesFromClaims(t *testing.T) {
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractRolesFromClaims(tt.claims, tt.validateRoles)
|
||||
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d roles, got %d", len(tt.expected), len(result))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
for i, role := range result {
|
||||
if i >= len(tt.expected) || role != tt.expected[i] {
|
||||
t.Errorf("Expected role %s at position %d, got %s", tt.expected[i], i, role)
|
||||
@@ -141,16 +141,16 @@ func TestExtractProjectsFromClaims(t *testing.T) {
|
||||
expected: []string{"project1", "project2"}, // Should skip non-strings
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractProjectsFromClaims(tt.claims)
|
||||
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d projects, got %d", len(tt.expected), len(result))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
for i, project := range result {
|
||||
if i >= len(tt.expected) || project != tt.expected[i] {
|
||||
t.Errorf("Expected project %s at position %d, got %s", tt.expected[i], i, project)
|
||||
@@ -216,7 +216,7 @@ func TestExtractNameFromClaims(t *testing.T) {
|
||||
expected: "123 Smith", // Should convert to string
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractNameFromClaims(tt.claims)
|
||||
@@ -235,29 +235,28 @@ func TestGetUserFromJWT_NoValidation(t *testing.T) {
|
||||
"roles": []any{"user", "admin"},
|
||||
"projects": []any{"project1", "project2"},
|
||||
}
|
||||
|
||||
|
||||
user, err := getUserFromJWT(claims, false, schema.AuthToken, -1)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if user.Username != "testuser" {
|
||||
t.Errorf("Expected username 'testuser', got '%s'", user.Username)
|
||||
}
|
||||
|
||||
|
||||
if user.Name != "Test User" {
|
||||
t.Errorf("Expected name 'Test User', got '%s'", user.Name)
|
||||
}
|
||||
|
||||
|
||||
if len(user.Roles) != 2 {
|
||||
t.Errorf("Expected 2 roles, got %d", len(user.Roles))
|
||||
}
|
||||
|
||||
|
||||
if len(user.Projects) != 2 {
|
||||
t.Errorf("Expected 2 projects, got %d", len(user.Projects))
|
||||
}
|
||||
|
||||
|
||||
if user.AuthType != schema.AuthToken {
|
||||
t.Errorf("Expected AuthType %v, got %v", schema.AuthToken, user.AuthType)
|
||||
}
|
||||
@@ -268,13 +267,13 @@ func TestGetUserFromJWT_MissingSub(t *testing.T) {
|
||||
claims := jwt.MapClaims{
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
|
||||
_, err := getUserFromJWT(claims, false, schema.AuthToken, -1)
|
||||
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing sub claim")
|
||||
}
|
||||
|
||||
|
||||
if err.Error() != "missing 'sub' claim in JWT" {
|
||||
t.Errorf("Expected specific error message, got: %v", err)
|
||||
}
|
||||
|
||||
@@ -75,13 +75,13 @@ func (ja *JWTSessionAuthenticator) Login(
|
||||
}
|
||||
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
|
||||
|
||||
// Use shared helper to get user from JWT claims
|
||||
user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthSession, schema.AuthViaToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// Sync or update user if configured
|
||||
if !Keys.JwtConfig.ValidateUser && (Keys.JwtConfig.SyncUserOnLogin || Keys.JwtConfig.UpdateUserOnLogin) {
|
||||
handleTokenUser(user)
|
||||
|
||||
@@ -59,7 +59,7 @@ func NewOIDC(a *Authentication) *OIDC {
|
||||
// Use context with timeout for provider initialization
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, Keys.OpenIDConfig.Provider)
|
||||
if err != nil {
|
||||
cclog.Fatal(err)
|
||||
@@ -119,7 +119,7 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) {
|
||||
// Exchange authorization code for token with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
||||
token, err := oa.client.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||
|
||||
Reference in New Issue
Block a user