Reformat with gofumpt

This commit is contained in:
2026-01-13 09:50:57 +01:00
parent a9366d14c6
commit 2ebab1e2e2
8 changed files with 64 additions and 65 deletions

View File

@@ -40,7 +40,7 @@ type Authenticator interface {
// authenticator should attempt the login. This method should not perform // authenticator should attempt the login. This method should not perform
// expensive operations or actual authentication. // expensive operations or actual authentication.
CanLogin(user *schema.User, username string, rw http.ResponseWriter, r *http.Request) (*schema.User, bool) CanLogin(user *schema.User, username string, rw http.ResponseWriter, r *http.Request) (*schema.User, bool)
// Login performs the actually authentication for the user. // Login performs the actually authentication for the user.
// It returns the authenticated user or an error if authentication fails. // It returns the authenticated user or an error if authentication fails.
// The user parameter may be nil if the user doesn't exist in the database yet. // The user parameter may be nil if the user doesn't exist in the database yet.
@@ -65,13 +65,13 @@ var ipUserLimiters sync.Map
func getIPUserLimiter(ip, username string) *rate.Limiter { func getIPUserLimiter(ip, username string) *rate.Limiter {
key := ip + ":" + username key := ip + ":" + username
now := time.Now() now := time.Now()
if entry, ok := ipUserLimiters.Load(key); ok { if entry, ok := ipUserLimiters.Load(key); ok {
rle := entry.(*rateLimiterEntry) rle := entry.(*rateLimiterEntry)
rle.lastUsed = now rle.lastUsed = now
return rle.limiter return rle.limiter
} }
// More aggressive rate limiting: 5 attempts per 15 minutes // More aggressive rate limiting: 5 attempts per 15 minutes
newLimiter := rate.NewLimiter(rate.Every(15*time.Minute/5), 5) newLimiter := rate.NewLimiter(rate.Every(15*time.Minute/5), 5)
ipUserLimiters.Store(key, &rateLimiterEntry{ ipUserLimiters.Store(key, &rateLimiterEntry{
@@ -176,7 +176,7 @@ func (auth *Authentication) AuthViaSession(
func Init(authCfg *json.RawMessage) { func Init(authCfg *json.RawMessage) {
initOnce.Do(func() { initOnce.Do(func() {
authInstance = &Authentication{} authInstance = &Authentication{}
// Start background cleanup of rate limiters // Start background cleanup of rate limiters
startRateLimiterCleanup() startRateLimiterCleanup()
@@ -272,7 +272,7 @@ func handleUserSync(user *schema.User, syncUserOnLogin, updateUserOnLogin bool)
cclog.Errorf("Error while loading user '%s': %v", user.Username, err) cclog.Errorf("Error while loading user '%s': %v", user.Username, err)
return return
} }
if err == sql.ErrNoRows && syncUserOnLogin { // Add new user if err == sql.ErrNoRows && syncUserOnLogin { // Add new user
if err := r.AddUser(user); err != nil { if err := r.AddUser(user); err != nil {
cclog.Errorf("Error while adding user '%s' to DB: %v", user.Username, err) cclog.Errorf("Error while adding user '%s' to DB: %v", user.Username, err)

View File

@@ -15,25 +15,25 @@ import (
func TestGetIPUserLimiter(t *testing.T) { func TestGetIPUserLimiter(t *testing.T) {
ip := "192.168.1.1" ip := "192.168.1.1"
username := "testuser" username := "testuser"
// Get limiter for the first time // Get limiter for the first time
limiter1 := getIPUserLimiter(ip, username) limiter1 := getIPUserLimiter(ip, username)
if limiter1 == nil { if limiter1 == nil {
t.Fatal("Expected limiter to be created") t.Fatal("Expected limiter to be created")
} }
// Get the same limiter again // Get the same limiter again
limiter2 := getIPUserLimiter(ip, username) limiter2 := getIPUserLimiter(ip, username)
if limiter1 != limiter2 { if limiter1 != limiter2 {
t.Error("Expected to get the same limiter instance") t.Error("Expected to get the same limiter instance")
} }
// Get a different limiter for different user // Get a different limiter for different user
limiter3 := getIPUserLimiter(ip, "otheruser") limiter3 := getIPUserLimiter(ip, "otheruser")
if limiter1 == limiter3 { if limiter1 == limiter3 {
t.Error("Expected different limiter for different user") t.Error("Expected different limiter for different user")
} }
// Get a different limiter for different IP // Get a different limiter for different IP
limiter4 := getIPUserLimiter("192.168.1.2", username) limiter4 := getIPUserLimiter("192.168.1.2", username)
if limiter1 == limiter4 { if limiter1 == limiter4 {
@@ -45,16 +45,16 @@ func TestGetIPUserLimiter(t *testing.T) {
func TestRateLimiterBehavior(t *testing.T) { func TestRateLimiterBehavior(t *testing.T) {
ip := "10.0.0.1" ip := "10.0.0.1"
username := "ratelimituser" username := "ratelimituser"
limiter := getIPUserLimiter(ip, username) limiter := getIPUserLimiter(ip, username)
// Should allow first 5 attempts // Should allow first 5 attempts
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
if !limiter.Allow() { if !limiter.Allow() {
t.Errorf("Request %d should be allowed within rate limit", i+1) t.Errorf("Request %d should be allowed within rate limit", i+1)
} }
} }
// 6th attempt should be blocked // 6th attempt should be blocked
if limiter.Allow() { if limiter.Allow() {
t.Error("Request 6 should be blocked by rate limiter") t.Error("Request 6 should be blocked by rate limiter")
@@ -65,19 +65,19 @@ func TestRateLimiterBehavior(t *testing.T) {
func TestCleanupOldRateLimiters(t *testing.T) { func TestCleanupOldRateLimiters(t *testing.T) {
// Clear all existing limiters first to avoid interference from other tests // Clear all existing limiters first to avoid interference from other tests
cleanupOldRateLimiters(time.Now().Add(24 * time.Hour)) cleanupOldRateLimiters(time.Now().Add(24 * time.Hour))
// Create some new rate limiters // Create some new rate limiters
limiter1 := getIPUserLimiter("1.1.1.1", "user1") limiter1 := getIPUserLimiter("1.1.1.1", "user1")
limiter2 := getIPUserLimiter("2.2.2.2", "user2") limiter2 := getIPUserLimiter("2.2.2.2", "user2")
if limiter1 == nil || limiter2 == nil { if limiter1 == nil || limiter2 == nil {
t.Fatal("Failed to create test limiters") t.Fatal("Failed to create test limiters")
} }
// Cleanup limiters older than 1 second from now (should keep both) // Cleanup limiters older than 1 second from now (should keep both)
time.Sleep(10 * time.Millisecond) // Small delay to ensure timestamp difference time.Sleep(10 * time.Millisecond) // Small delay to ensure timestamp difference
cleanupOldRateLimiters(time.Now().Add(-1 * time.Second)) cleanupOldRateLimiters(time.Now().Add(-1 * time.Second))
// Verify they still exist (should get same instance) // Verify they still exist (should get same instance)
if getIPUserLimiter("1.1.1.1", "user1") != limiter1 { if getIPUserLimiter("1.1.1.1", "user1") != limiter1 {
t.Error("Limiter 1 was incorrectly cleaned up") t.Error("Limiter 1 was incorrectly cleaned up")
@@ -85,10 +85,10 @@ func TestCleanupOldRateLimiters(t *testing.T) {
if getIPUserLimiter("2.2.2.2", "user2") != limiter2 { if getIPUserLimiter("2.2.2.2", "user2") != limiter2 {
t.Error("Limiter 2 was incorrectly cleaned up") t.Error("Limiter 2 was incorrectly cleaned up")
} }
// Cleanup limiters older than 1 hour from now (should remove both) // Cleanup limiters older than 1 hour from now (should remove both)
cleanupOldRateLimiters(time.Now().Add(2 * time.Hour)) cleanupOldRateLimiters(time.Now().Add(2 * time.Hour))
// Getting them again should create new instances // Getting them again should create new instances
newLimiter1 := getIPUserLimiter("1.1.1.1", "user1") newLimiter1 := getIPUserLimiter("1.1.1.1", "user1")
if newLimiter1 == limiter1 { if newLimiter1 == limiter1 {
@@ -107,14 +107,14 @@ func TestIPv4Extraction(t *testing.T) {
{"IPv4 without port", "192.168.1.1", "192.168.1.1"}, {"IPv4 without port", "192.168.1.1", "192.168.1.1"},
{"Localhost with port", "127.0.0.1:3000", "127.0.0.1"}, {"Localhost with port", "127.0.0.1:3000", "127.0.0.1"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := tt.input result := tt.input
if host, _, err := net.SplitHostPort(result); err == nil { if host, _, err := net.SplitHostPort(result); err == nil {
result = host result = host
} }
if result != tt.expected { if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result) t.Errorf("Expected %s, got %s", tt.expected, result)
} }
@@ -122,7 +122,7 @@ func TestIPv4Extraction(t *testing.T) {
} }
} }
// TestIPv6Extraction tests extracting IPv6 addresses // TestIPv6Extraction tests extracting IPv6 addresses
func TestIPv6Extraction(t *testing.T) { func TestIPv6Extraction(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -134,14 +134,14 @@ func TestIPv6Extraction(t *testing.T) {
{"IPv6 without port", "2001:db8::1", "2001:db8::1"}, {"IPv6 without port", "2001:db8::1", "2001:db8::1"},
{"IPv6 localhost", "::1", "::1"}, {"IPv6 localhost", "::1", "::1"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := tt.input result := tt.input
if host, _, err := net.SplitHostPort(result); err == nil { if host, _, err := net.SplitHostPort(result); err == nil {
result = host result = host
} }
if result != tt.expected { if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result) t.Errorf("Expected %s, got %s", tt.expected, result)
} }
@@ -160,14 +160,14 @@ func TestIPExtractionEdgeCases(t *testing.T) {
{"Empty string", "", ""}, {"Empty string", "", ""},
{"Just port", ":8080", ""}, {"Just port", ":8080", ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := tt.input result := tt.input
if host, _, err := net.SplitHostPort(result); err == nil { if host, _, err := net.SplitHostPort(result); err == nil {
result = host result = host
} }
if result != tt.expected { if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result) t.Errorf("Expected %s, got %s", tt.expected, result)
} }

View File

@@ -101,20 +101,20 @@ func (ja *JWTAuthenticator) AuthViaJWT(
// Token is valid, extract payload // Token is valid, extract payload
claims := token.Claims.(jwt.MapClaims) claims := token.Claims.(jwt.MapClaims)
// Use shared helper to get user from JWT claims // Use shared helper to get user from JWT claims
var user *schema.User var user *schema.User
user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthToken, -1) user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthToken, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// If not validating user, we only get roles from JWT (no projects for this auth method) // If not validating user, we only get roles from JWT (no projects for this auth method)
if !Keys.JwtConfig.ValidateUser { if !Keys.JwtConfig.ValidateUser {
user.Roles = extractRolesFromClaims(claims, false) user.Roles = extractRolesFromClaims(claims, false)
user.Projects = nil // Standard JWT auth doesn't include projects user.Projects = nil // Standard JWT auth doesn't include projects
} }
return user, nil return user, nil
} }

View File

@@ -146,13 +146,13 @@ func (ja *JWTCookieSessionAuthenticator) Login(
} }
claims := token.Claims.(jwt.MapClaims) claims := token.Claims.(jwt.MapClaims)
// Use shared helper to get user from JWT claims // Use shared helper to get user from JWT claims
user, err = getUserFromJWT(claims, jc.ValidateUser, schema.AuthSession, schema.AuthViaToken) user, err = getUserFromJWT(claims, jc.ValidateUser, schema.AuthSession, schema.AuthViaToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Sync or update user if configured // Sync or update user if configured
if !jc.ValidateUser && (jc.SyncUserOnLogin || jc.UpdateUserOnLogin) { if !jc.ValidateUser && (jc.SyncUserOnLogin || jc.UpdateUserOnLogin) {
handleTokenUser(user) handleTokenUser(user)

View File

@@ -28,7 +28,7 @@ func extractStringFromClaims(claims jwt.MapClaims, key string) string {
// If validateRoles is true, only valid roles are returned // If validateRoles is true, only valid roles are returned
func extractRolesFromClaims(claims jwt.MapClaims, validateRoles bool) []string { func extractRolesFromClaims(claims jwt.MapClaims, validateRoles bool) []string {
var roles []string var roles []string
if rawroles, ok := claims["roles"].([]any); ok { if rawroles, ok := claims["roles"].([]any); ok {
for _, rr := range rawroles { for _, rr := range rawroles {
if r, ok := rr.(string); ok { if r, ok := rr.(string); ok {
@@ -42,14 +42,14 @@ func extractRolesFromClaims(claims jwt.MapClaims, validateRoles bool) []string {
} }
} }
} }
return roles return roles
} }
// extractProjectsFromClaims extracts projects from JWT claims // extractProjectsFromClaims extracts projects from JWT claims
func extractProjectsFromClaims(claims jwt.MapClaims) []string { func extractProjectsFromClaims(claims jwt.MapClaims) []string {
projects := make([]string, 0) projects := make([]string, 0)
if rawprojs, ok := claims["projects"].([]any); ok { if rawprojs, ok := claims["projects"].([]any); ok {
for _, pp := range rawprojs { for _, pp := range rawprojs {
if p, ok := pp.(string); ok { if p, ok := pp.(string); ok {
@@ -61,7 +61,7 @@ func extractProjectsFromClaims(claims jwt.MapClaims) []string {
projects = append(projects, projSlice...) projects = append(projects, projSlice...)
} }
} }
return projects return projects
} }
@@ -72,14 +72,14 @@ func extractNameFromClaims(claims jwt.MapClaims) string {
if name, ok := claims["name"].(string); ok { if name, ok := claims["name"].(string); ok {
return name return name
} }
// Try nested structure: {name: {values: [...]}} // Try nested structure: {name: {values: [...]}}
if wrap, ok := claims["name"].(map[string]any); ok { if wrap, ok := claims["name"].(map[string]any); ok {
if vals, ok := wrap["values"].([]any); ok { if vals, ok := wrap["values"].([]any); ok {
if len(vals) == 0 { if len(vals) == 0 {
return "" return ""
} }
name := fmt.Sprintf("%v", vals[0]) name := fmt.Sprintf("%v", vals[0])
for i := 1; i < len(vals); i++ { for i := 1; i < len(vals); i++ {
name += fmt.Sprintf(" %v", vals[i]) name += fmt.Sprintf(" %v", vals[i])
@@ -87,7 +87,7 @@ func extractNameFromClaims(claims jwt.MapClaims) string {
return name return name
} }
} }
return "" return ""
} }
@@ -100,7 +100,7 @@ func getUserFromJWT(claims jwt.MapClaims, validateUser bool, authType schema.Aut
if sub == "" { if sub == "" {
return nil, errors.New("missing 'sub' claim in JWT") return nil, errors.New("missing 'sub' claim in JWT")
} }
if validateUser { if validateUser {
// Validate user against database // Validate user against database
ur := repository.GetUserRepository() ur := repository.GetUserRepository()
@@ -109,22 +109,22 @@ func getUserFromJWT(claims jwt.MapClaims, validateUser bool, authType schema.Aut
cclog.Errorf("Error while loading user '%v': %v", sub, err) cclog.Errorf("Error while loading user '%v': %v", sub, err)
return nil, fmt.Errorf("database error: %w", err) return nil, fmt.Errorf("database error: %w", err)
} }
// Deny any logins for unknown usernames // Deny any logins for unknown usernames
if user == nil || err == sql.ErrNoRows { if user == nil || err == sql.ErrNoRows {
cclog.Warn("Could not find user from JWT in internal database.") cclog.Warn("Could not find user from JWT in internal database.")
return nil, errors.New("unknown user") return nil, errors.New("unknown user")
} }
// Return database user (with database roles) // Return database user (with database roles)
return user, nil return user, nil
} }
// Create user from JWT claims // Create user from JWT claims
name := extractNameFromClaims(claims) name := extractNameFromClaims(claims)
roles := extractRolesFromClaims(claims, true) // Validate roles roles := extractRolesFromClaims(claims, true) // Validate roles
projects := extractProjectsFromClaims(claims) projects := extractProjectsFromClaims(claims)
return &schema.User{ return &schema.User{
Username: sub, Username: sub,
Name: name, Name: name,

View File

@@ -19,7 +19,7 @@ func TestExtractStringFromClaims(t *testing.T) {
"email": "test@example.com", "email": "test@example.com",
"age": 25, // not a string "age": 25, // not a string
} }
tests := []struct { tests := []struct {
name string name string
key string key string
@@ -30,7 +30,7 @@ func TestExtractStringFromClaims(t *testing.T) {
{"Non-existent key", "missing", ""}, {"Non-existent key", "missing", ""},
{"Non-string value", "age", ""}, {"Non-string value", "age", ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := extractStringFromClaims(claims, tt.key) result := extractStringFromClaims(claims, tt.key)
@@ -88,16 +88,16 @@ func TestExtractRolesFromClaims(t *testing.T) {
expected: []string{}, expected: []string{},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := extractRolesFromClaims(tt.claims, tt.validateRoles) result := extractRolesFromClaims(tt.claims, tt.validateRoles)
if len(result) != len(tt.expected) { if len(result) != len(tt.expected) {
t.Errorf("Expected %d roles, got %d", len(tt.expected), len(result)) t.Errorf("Expected %d roles, got %d", len(tt.expected), len(result))
return return
} }
for i, role := range result { for i, role := range result {
if i >= len(tt.expected) || role != tt.expected[i] { if i >= len(tt.expected) || role != tt.expected[i] {
t.Errorf("Expected role %s at position %d, got %s", tt.expected[i], i, role) t.Errorf("Expected role %s at position %d, got %s", tt.expected[i], i, role)
@@ -141,16 +141,16 @@ func TestExtractProjectsFromClaims(t *testing.T) {
expected: []string{"project1", "project2"}, // Should skip non-strings expected: []string{"project1", "project2"}, // Should skip non-strings
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := extractProjectsFromClaims(tt.claims) result := extractProjectsFromClaims(tt.claims)
if len(result) != len(tt.expected) { if len(result) != len(tt.expected) {
t.Errorf("Expected %d projects, got %d", len(tt.expected), len(result)) t.Errorf("Expected %d projects, got %d", len(tt.expected), len(result))
return return
} }
for i, project := range result { for i, project := range result {
if i >= len(tt.expected) || project != tt.expected[i] { if i >= len(tt.expected) || project != tt.expected[i] {
t.Errorf("Expected project %s at position %d, got %s", tt.expected[i], i, project) t.Errorf("Expected project %s at position %d, got %s", tt.expected[i], i, project)
@@ -216,7 +216,7 @@ func TestExtractNameFromClaims(t *testing.T) {
expected: "123 Smith", // Should convert to string expected: "123 Smith", // Should convert to string
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := extractNameFromClaims(tt.claims) result := extractNameFromClaims(tt.claims)
@@ -235,29 +235,28 @@ func TestGetUserFromJWT_NoValidation(t *testing.T) {
"roles": []any{"user", "admin"}, "roles": []any{"user", "admin"},
"projects": []any{"project1", "project2"}, "projects": []any{"project1", "project2"},
} }
user, err := getUserFromJWT(claims, false, schema.AuthToken, -1) user, err := getUserFromJWT(claims, false, schema.AuthToken, -1)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
if user.Username != "testuser" { if user.Username != "testuser" {
t.Errorf("Expected username 'testuser', got '%s'", user.Username) t.Errorf("Expected username 'testuser', got '%s'", user.Username)
} }
if user.Name != "Test User" { if user.Name != "Test User" {
t.Errorf("Expected name 'Test User', got '%s'", user.Name) t.Errorf("Expected name 'Test User', got '%s'", user.Name)
} }
if len(user.Roles) != 2 { if len(user.Roles) != 2 {
t.Errorf("Expected 2 roles, got %d", len(user.Roles)) t.Errorf("Expected 2 roles, got %d", len(user.Roles))
} }
if len(user.Projects) != 2 { if len(user.Projects) != 2 {
t.Errorf("Expected 2 projects, got %d", len(user.Projects)) t.Errorf("Expected 2 projects, got %d", len(user.Projects))
} }
if user.AuthType != schema.AuthToken { if user.AuthType != schema.AuthToken {
t.Errorf("Expected AuthType %v, got %v", schema.AuthToken, user.AuthType) t.Errorf("Expected AuthType %v, got %v", schema.AuthToken, user.AuthType)
} }
@@ -268,13 +267,13 @@ func TestGetUserFromJWT_MissingSub(t *testing.T) {
claims := jwt.MapClaims{ claims := jwt.MapClaims{
"name": "Test User", "name": "Test User",
} }
_, err := getUserFromJWT(claims, false, schema.AuthToken, -1) _, err := getUserFromJWT(claims, false, schema.AuthToken, -1)
if err == nil { if err == nil {
t.Error("Expected error for missing sub claim") t.Error("Expected error for missing sub claim")
} }
if err.Error() != "missing 'sub' claim in JWT" { if err.Error() != "missing 'sub' claim in JWT" {
t.Errorf("Expected specific error message, got: %v", err) t.Errorf("Expected specific error message, got: %v", err)
} }

View File

@@ -75,13 +75,13 @@ func (ja *JWTSessionAuthenticator) Login(
} }
claims := token.Claims.(jwt.MapClaims) claims := token.Claims.(jwt.MapClaims)
// Use shared helper to get user from JWT claims // Use shared helper to get user from JWT claims
user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthSession, schema.AuthViaToken) user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthSession, schema.AuthViaToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Sync or update user if configured // Sync or update user if configured
if !Keys.JwtConfig.ValidateUser && (Keys.JwtConfig.SyncUserOnLogin || Keys.JwtConfig.UpdateUserOnLogin) { if !Keys.JwtConfig.ValidateUser && (Keys.JwtConfig.SyncUserOnLogin || Keys.JwtConfig.UpdateUserOnLogin) {
handleTokenUser(user) handleTokenUser(user)

View File

@@ -59,7 +59,7 @@ func NewOIDC(a *Authentication) *OIDC {
// Use context with timeout for provider initialization // Use context with timeout for provider initialization
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
provider, err := oidc.NewProvider(ctx, Keys.OpenIDConfig.Provider) provider, err := oidc.NewProvider(ctx, Keys.OpenIDConfig.Provider)
if err != nil { if err != nil {
cclog.Fatal(err) cclog.Fatal(err)
@@ -119,7 +119,7 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) {
// Exchange authorization code for token with timeout // Exchange authorization code for token with timeout
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
token, err := oa.client.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier)) token, err := oa.client.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
if err != nil { if err != nil {
http.Error(rw, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) http.Error(rw, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)