mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2025-11-26 03:23:07 +01:00
Refactor auth package
Fix security issues Remove redundant code Add documentation Add units tests
This commit is contained in:
@@ -18,7 +18,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,8 +31,19 @@ import (
|
|||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Authenticator is the interface for all authentication methods.
|
||||||
|
// Each authenticator determines if it can handle a login request (CanLogin)
|
||||||
|
// and performs the actual authentication (Login).
|
||||||
type Authenticator interface {
|
type Authenticator interface {
|
||||||
|
// CanLogin determines if this authenticator can handle the login request.
|
||||||
|
// It returns the user object if available and a boolean indicating if this
|
||||||
|
// authenticator should attempt the login. This method should not perform
|
||||||
|
// expensive operations or actual authentication.
|
||||||
CanLogin(user *schema.User, username string, rw http.ResponseWriter, r *http.Request) (*schema.User, bool)
|
CanLogin(user *schema.User, username string, rw http.ResponseWriter, r *http.Request) (*schema.User, bool)
|
||||||
|
|
||||||
|
// Login performs the actually authentication for the user.
|
||||||
|
// It returns the authenticated user or an error if authentication fails.
|
||||||
|
// The user parameter may be nil if the user doesn't exist in the database yet.
|
||||||
Login(user *schema.User, rw http.ResponseWriter, r *http.Request) (*schema.User, error)
|
Login(user *schema.User, rw http.ResponseWriter, r *http.Request) (*schema.User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,27 +52,70 @@ var (
|
|||||||
authInstance *Authentication
|
authInstance *Authentication
|
||||||
)
|
)
|
||||||
|
|
||||||
var ipUserLimiters sync.Map
|
// rateLimiterEntry tracks a rate limiter and its last use time for cleanup
|
||||||
|
type rateLimiterEntry struct {
|
||||||
func getIPUserLimiter(ip, username string) *rate.Limiter {
|
limiter *rate.Limiter
|
||||||
key := ip + ":" + username
|
lastUsed time.Time
|
||||||
limiter, ok := ipUserLimiters.Load(key)
|
|
||||||
if !ok {
|
|
||||||
newLimiter := rate.NewLimiter(rate.Every(time.Hour/10), 10)
|
|
||||||
ipUserLimiters.Store(key, newLimiter)
|
|
||||||
return newLimiter
|
|
||||||
}
|
|
||||||
return limiter.(*rate.Limiter)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ipUserLimiters sync.Map
|
||||||
|
|
||||||
|
// getIPUserLimiter returns a rate limiter for the given IP and username combination.
|
||||||
|
// Rate limiters are created on demand and track 5 attempts per 15 minutes.
|
||||||
|
func getIPUserLimiter(ip, username string) *rate.Limiter {
|
||||||
|
key := ip + ":" + username
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if entry, ok := ipUserLimiters.Load(key); ok {
|
||||||
|
rle := entry.(*rateLimiterEntry)
|
||||||
|
rle.lastUsed = now
|
||||||
|
return rle.limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// More aggressive rate limiting: 5 attempts per 15 minutes
|
||||||
|
newLimiter := rate.NewLimiter(rate.Every(15*time.Minute/5), 5)
|
||||||
|
ipUserLimiters.Store(key, &rateLimiterEntry{
|
||||||
|
limiter: newLimiter,
|
||||||
|
lastUsed: now,
|
||||||
|
})
|
||||||
|
return newLimiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupOldRateLimiters removes rate limiters that haven't been used recently
|
||||||
|
func cleanupOldRateLimiters(olderThan time.Time) {
|
||||||
|
ipUserLimiters.Range(func(key, value any) bool {
|
||||||
|
entry := value.(*rateLimiterEntry)
|
||||||
|
if entry.lastUsed.Before(olderThan) {
|
||||||
|
ipUserLimiters.Delete(key)
|
||||||
|
cclog.Debugf("Cleaned up rate limiter for %v", key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// startRateLimiterCleanup starts a background goroutine to clean up old rate limiters
|
||||||
|
func startRateLimiterCleanup() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(1 * time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
// Clean up limiters not used in the last 24 hours
|
||||||
|
cleanupOldRateLimiters(time.Now().Add(-24 * time.Hour))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthConfig contains configuration for all authentication methods
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
LdapConfig *LdapConfig `json:"ldap"`
|
LdapConfig *LdapConfig `json:"ldap"`
|
||||||
JwtConfig *JWTAuthConfig `json:"jwts"`
|
JwtConfig *JWTAuthConfig `json:"jwts"`
|
||||||
OpenIDConfig *OpenIDConfig `json:"oidc"`
|
OpenIDConfig *OpenIDConfig `json:"oidc"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Keys holds the global authentication configuration
|
||||||
var Keys AuthConfig
|
var Keys AuthConfig
|
||||||
|
|
||||||
|
// Authentication manages all authentication methods and session handling
|
||||||
type Authentication struct {
|
type Authentication struct {
|
||||||
sessionStore *sessions.CookieStore
|
sessionStore *sessions.CookieStore
|
||||||
LdapAuth *LdapAuthenticator
|
LdapAuth *LdapAuthenticator
|
||||||
@@ -86,10 +139,31 @@ func (auth *Authentication) AuthViaSession(
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Check if session keys exist
|
// Validate session data with proper type checking
|
||||||
username, _ := session.Values["username"].(string)
|
username, ok := session.Values["username"].(string)
|
||||||
projects, _ := session.Values["projects"].([]string)
|
if !ok || username == "" {
|
||||||
roles, _ := session.Values["roles"].([]string)
|
cclog.Warn("Invalid session: missing or invalid username")
|
||||||
|
// Invalidate the corrupted session
|
||||||
|
session.Options.MaxAge = -1
|
||||||
|
_ = auth.sessionStore.Save(r, rw, session)
|
||||||
|
return nil, errors.New("invalid session data")
|
||||||
|
}
|
||||||
|
|
||||||
|
projects, ok := session.Values["projects"].([]string)
|
||||||
|
if !ok {
|
||||||
|
cclog.Warn("Invalid session: projects not found or invalid type, using empty list")
|
||||||
|
projects = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
roles, ok := session.Values["roles"].([]string)
|
||||||
|
if !ok || len(roles) == 0 {
|
||||||
|
cclog.Warn("Invalid session: missing or invalid roles")
|
||||||
|
// Invalidate the corrupted session
|
||||||
|
session.Options.MaxAge = -1
|
||||||
|
_ = auth.sessionStore.Save(r, rw, session)
|
||||||
|
return nil, errors.New("invalid session data")
|
||||||
|
}
|
||||||
|
|
||||||
return &schema.User{
|
return &schema.User{
|
||||||
Username: username,
|
Username: username,
|
||||||
Projects: projects,
|
Projects: projects,
|
||||||
@@ -103,6 +177,9 @@ func Init(authCfg *json.RawMessage) {
|
|||||||
initOnce.Do(func() {
|
initOnce.Do(func() {
|
||||||
authInstance = &Authentication{}
|
authInstance = &Authentication{}
|
||||||
|
|
||||||
|
// Start background cleanup of rate limiters
|
||||||
|
startRateLimiterCleanup()
|
||||||
|
|
||||||
sessKey := os.Getenv("SESSION_KEY")
|
sessKey := os.Getenv("SESSION_KEY")
|
||||||
if sessKey == "" {
|
if sessKey == "" {
|
||||||
cclog.Warn("environment variable 'SESSION_KEY' not set (will use non-persistent random key)")
|
cclog.Warn("environment variable 'SESSION_KEY' not set (will use non-persistent random key)")
|
||||||
@@ -185,38 +262,36 @@ func GetAuthInstance() *Authentication {
|
|||||||
return authInstance
|
return authInstance
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleTokenUser(tokenUser *schema.User) {
|
// handleUserSync syncs or updates a user in the database based on configuration.
|
||||||
|
// This is used for both JWT and OIDC authentication when syncUserOnLogin or updateUserOnLogin is enabled.
|
||||||
|
func handleUserSync(user *schema.User, syncUserOnLogin, updateUserOnLogin bool) {
|
||||||
r := repository.GetUserRepository()
|
r := repository.GetUserRepository()
|
||||||
dbUser, err := r.GetUser(tokenUser.Username)
|
dbUser, err := r.GetUser(user.Username)
|
||||||
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
cclog.Errorf("Error while loading user '%s': %v", tokenUser.Username, err)
|
cclog.Errorf("Error while loading user '%s': %v", user.Username, err)
|
||||||
} else if err == sql.ErrNoRows && Keys.JwtConfig.SyncUserOnLogin { // Adds New User
|
return
|
||||||
if err := r.AddUser(tokenUser); err != nil {
|
}
|
||||||
cclog.Errorf("Error while adding user '%s' to DB: %v", tokenUser.Username, err)
|
|
||||||
|
if err == sql.ErrNoRows && syncUserOnLogin { // Add new user
|
||||||
|
if err := r.AddUser(user); err != nil {
|
||||||
|
cclog.Errorf("Error while adding user '%s' to DB: %v", user.Username, err)
|
||||||
}
|
}
|
||||||
} else if err == nil && Keys.JwtConfig.UpdateUserOnLogin { // Update Existing User
|
} else if err == nil && updateUserOnLogin { // Update existing user
|
||||||
if err := r.UpdateUser(dbUser, tokenUser); err != nil {
|
if err := r.UpdateUser(dbUser, user); err != nil {
|
||||||
cclog.Errorf("Error while updating user '%s' to DB: %v", dbUser.Username, err)
|
cclog.Errorf("Error while updating user '%s' in DB: %v", dbUser.Username, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleOIDCUser(OIDCUser *schema.User) {
|
// handleTokenUser syncs JWT token user with database
|
||||||
r := repository.GetUserRepository()
|
func handleTokenUser(tokenUser *schema.User) {
|
||||||
dbUser, err := r.GetUser(OIDCUser.Username)
|
handleUserSync(tokenUser, Keys.JwtConfig.SyncUserOnLogin, Keys.JwtConfig.UpdateUserOnLogin)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
// handleOIDCUser syncs OIDC user with database
|
||||||
cclog.Errorf("Error while loading user '%s': %v", OIDCUser.Username, err)
|
func handleOIDCUser(OIDCUser *schema.User) {
|
||||||
} else if err == sql.ErrNoRows && Keys.OpenIDConfig.SyncUserOnLogin { // Adds New User
|
handleUserSync(OIDCUser, Keys.OpenIDConfig.SyncUserOnLogin, Keys.OpenIDConfig.UpdateUserOnLogin)
|
||||||
if err := r.AddUser(OIDCUser); err != nil {
|
|
||||||
cclog.Errorf("Error while adding user '%s' to DB: %v", OIDCUser.Username, err)
|
|
||||||
}
|
|
||||||
} else if err == nil && Keys.OpenIDConfig.UpdateUserOnLogin { // Update Existing User
|
|
||||||
if err := r.UpdateUser(dbUser, OIDCUser); err != nil {
|
|
||||||
cclog.Errorf("Error while updating user '%s' to DB: %v", dbUser.Username, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request, user *schema.User) error {
|
func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request, user *schema.User) error {
|
||||||
@@ -231,6 +306,7 @@ func (auth *Authentication) SaveSession(rw http.ResponseWriter, r *http.Request,
|
|||||||
session.Options.MaxAge = int(auth.SessionMaxAge.Seconds())
|
session.Options.MaxAge = int(auth.SessionMaxAge.Seconds())
|
||||||
}
|
}
|
||||||
if config.Keys.HTTPSCertFile == "" && config.Keys.HTTPSKeyFile == "" {
|
if config.Keys.HTTPSCertFile == "" && config.Keys.HTTPSKeyFile == "" {
|
||||||
|
cclog.Warn("HTTPS not configured - session cookies will not have Secure flag set (insecure for production)")
|
||||||
session.Options.Secure = false
|
session.Options.Secure = false
|
||||||
}
|
}
|
||||||
session.Options.SameSite = http.SameSiteStrictMode
|
session.Options.SameSite = http.SameSiteStrictMode
|
||||||
@@ -532,10 +608,13 @@ func securedCheck(user *schema.User, r *http.Request) error {
|
|||||||
IPAddress = r.RemoteAddr
|
IPAddress = r.RemoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: IPV6 not handled
|
// Handle both IPv4 and IPv6 addresses properly
|
||||||
if strings.Contains(IPAddress, ":") {
|
// For IPv6, this will strip the port and brackets
|
||||||
IPAddress = strings.Split(IPAddress, ":")[0]
|
// For IPv4, this will strip the port
|
||||||
|
if host, _, err := net.SplitHostPort(IPAddress); err == nil {
|
||||||
|
IPAddress = host
|
||||||
}
|
}
|
||||||
|
// If SplitHostPort fails, IPAddress is already just a host (no port)
|
||||||
|
|
||||||
// If nothing declared in config: deny all request to this api endpoint
|
// If nothing declared in config: deny all request to this api endpoint
|
||||||
if len(config.Keys.APIAllowedIPs) == 0 {
|
if len(config.Keys.APIAllowedIPs) == 0 {
|
||||||
|
|||||||
176
internal/auth/auth_test.go
Normal file
176
internal/auth/auth_test.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||||
|
// All rights reserved. This file is part of cc-backend.
|
||||||
|
// Use of this source code is governed by a MIT-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGetIPUserLimiter tests the rate limiter creation and retrieval
|
||||||
|
func TestGetIPUserLimiter(t *testing.T) {
|
||||||
|
ip := "192.168.1.1"
|
||||||
|
username := "testuser"
|
||||||
|
|
||||||
|
// Get limiter for the first time
|
||||||
|
limiter1 := getIPUserLimiter(ip, username)
|
||||||
|
if limiter1 == nil {
|
||||||
|
t.Fatal("Expected limiter to be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the same limiter again
|
||||||
|
limiter2 := getIPUserLimiter(ip, username)
|
||||||
|
if limiter1 != limiter2 {
|
||||||
|
t.Error("Expected to get the same limiter instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a different limiter for different user
|
||||||
|
limiter3 := getIPUserLimiter(ip, "otheruser")
|
||||||
|
if limiter1 == limiter3 {
|
||||||
|
t.Error("Expected different limiter for different user")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a different limiter for different IP
|
||||||
|
limiter4 := getIPUserLimiter("192.168.1.2", username)
|
||||||
|
if limiter1 == limiter4 {
|
||||||
|
t.Error("Expected different limiter for different IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimiterBehavior tests that rate limiting works correctly
|
||||||
|
func TestRateLimiterBehavior(t *testing.T) {
|
||||||
|
ip := "10.0.0.1"
|
||||||
|
username := "ratelimituser"
|
||||||
|
|
||||||
|
limiter := getIPUserLimiter(ip, username)
|
||||||
|
|
||||||
|
// Should allow first 5 attempts
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
if !limiter.Allow() {
|
||||||
|
t.Errorf("Request %d should be allowed within rate limit", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6th attempt should be blocked
|
||||||
|
if limiter.Allow() {
|
||||||
|
t.Error("Request 6 should be blocked by rate limiter")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCleanupOldRateLimiters tests the cleanup function
|
||||||
|
func TestCleanupOldRateLimiters(t *testing.T) {
|
||||||
|
// Clear all existing limiters first to avoid interference from other tests
|
||||||
|
cleanupOldRateLimiters(time.Now().Add(24 * time.Hour))
|
||||||
|
|
||||||
|
// Create some new rate limiters
|
||||||
|
limiter1 := getIPUserLimiter("1.1.1.1", "user1")
|
||||||
|
limiter2 := getIPUserLimiter("2.2.2.2", "user2")
|
||||||
|
|
||||||
|
if limiter1 == nil || limiter2 == nil {
|
||||||
|
t.Fatal("Failed to create test limiters")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup limiters older than 1 second from now (should keep both)
|
||||||
|
time.Sleep(10 * time.Millisecond) // Small delay to ensure timestamp difference
|
||||||
|
cleanupOldRateLimiters(time.Now().Add(-1 * time.Second))
|
||||||
|
|
||||||
|
// Verify they still exist (should get same instance)
|
||||||
|
if getIPUserLimiter("1.1.1.1", "user1") != limiter1 {
|
||||||
|
t.Error("Limiter 1 was incorrectly cleaned up")
|
||||||
|
}
|
||||||
|
if getIPUserLimiter("2.2.2.2", "user2") != limiter2 {
|
||||||
|
t.Error("Limiter 2 was incorrectly cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup limiters older than 1 hour from now (should remove both)
|
||||||
|
cleanupOldRateLimiters(time.Now().Add(2 * time.Hour))
|
||||||
|
|
||||||
|
// Getting them again should create new instances
|
||||||
|
newLimiter1 := getIPUserLimiter("1.1.1.1", "user1")
|
||||||
|
if newLimiter1 == limiter1 {
|
||||||
|
t.Error("Old limiter should have been cleaned up")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIPv4Extraction tests extracting IPv4 addresses
|
||||||
|
func TestIPv4Extraction(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"IPv4 with port", "192.168.1.1:8080", "192.168.1.1"},
|
||||||
|
{"IPv4 without port", "192.168.1.1", "192.168.1.1"},
|
||||||
|
{"Localhost with port", "127.0.0.1:3000", "127.0.0.1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.input
|
||||||
|
if host, _, err := net.SplitHostPort(result); err == nil {
|
||||||
|
result = host
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIPv6Extraction tests extracting IPv6 addresses
|
||||||
|
func TestIPv6Extraction(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"IPv6 with port", "[2001:db8::1]:8080", "2001:db8::1"},
|
||||||
|
{"IPv6 localhost with port", "[::1]:3000", "::1"},
|
||||||
|
{"IPv6 without port", "2001:db8::1", "2001:db8::1"},
|
||||||
|
{"IPv6 localhost", "::1", "::1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.input
|
||||||
|
if host, _, err := net.SplitHostPort(result); err == nil {
|
||||||
|
result = host
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIPExtractionEdgeCases tests edge cases for IP extraction
|
||||||
|
func TestIPExtractionEdgeCases(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"Hostname without port", "example.com", "example.com"},
|
||||||
|
{"Empty string", "", ""},
|
||||||
|
{"Just port", ":8080", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.input
|
||||||
|
if host, _, err := net.SplitHostPort(result); err == nil {
|
||||||
|
result = host
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ClusterCockpit/cc-backend/internal/repository"
|
|
||||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||||
"github.com/ClusterCockpit/cc-lib/schema"
|
"github.com/ClusterCockpit/cc-lib/schema"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
@@ -102,38 +101,21 @@ func (ja *JWTAuthenticator) AuthViaJWT(
|
|||||||
|
|
||||||
// Token is valid, extract payload
|
// Token is valid, extract payload
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
sub, _ := claims["sub"].(string)
|
|
||||||
|
|
||||||
var roles []string
|
// Use shared helper to get user from JWT claims
|
||||||
|
var user *schema.User
|
||||||
// Validate user + roles from JWT against database?
|
user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthToken, -1)
|
||||||
if Keys.JwtConfig.ValidateUser {
|
if err != nil {
|
||||||
ur := repository.GetUserRepository()
|
return nil, err
|
||||||
user, err := ur.GetUser(sub)
|
|
||||||
// Deny any logins for unknown usernames
|
|
||||||
if err != nil {
|
|
||||||
cclog.Warn("Could not find user from JWT in internal database.")
|
|
||||||
return nil, errors.New("unknown user")
|
|
||||||
}
|
|
||||||
// Take user roles from database instead of trusting the JWT
|
|
||||||
roles = user.Roles
|
|
||||||
} else {
|
|
||||||
// Extract roles from JWT (if present)
|
|
||||||
if rawroles, ok := claims["roles"].([]any); ok {
|
|
||||||
for _, rr := range rawroles {
|
|
||||||
if r, ok := rr.(string); ok {
|
|
||||||
roles = append(roles, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &schema.User{
|
// If not validating user, we only get roles from JWT (no projects for this auth method)
|
||||||
Username: sub,
|
if !Keys.JwtConfig.ValidateUser {
|
||||||
Roles: roles,
|
user.Roles = extractRolesFromClaims(claims, false)
|
||||||
AuthType: schema.AuthToken,
|
user.Projects = nil // Standard JWT auth doesn't include projects
|
||||||
AuthSource: -1,
|
}
|
||||||
}, nil
|
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvideJWT generates a new JWT that can be used for authentication
|
// ProvideJWT generates a new JWT that can be used for authentication
|
||||||
|
|||||||
@@ -7,14 +7,11 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/ClusterCockpit/cc-backend/internal/repository"
|
|
||||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||||
"github.com/ClusterCockpit/cc-lib/schema"
|
"github.com/ClusterCockpit/cc-lib/schema"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
@@ -149,57 +146,16 @@ func (ja *JWTCookieSessionAuthenticator) Login(
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
sub, _ := claims["sub"].(string)
|
|
||||||
|
|
||||||
var roles []string
|
// Use shared helper to get user from JWT claims
|
||||||
projects := make([]string, 0)
|
user, err = getUserFromJWT(claims, jc.ValidateUser, schema.AuthSession, schema.AuthViaToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if jc.ValidateUser {
|
// Sync or update user if configured
|
||||||
var err error
|
if !jc.ValidateUser && (jc.SyncUserOnLogin || jc.UpdateUserOnLogin) {
|
||||||
user, err = repository.GetUserRepository().GetUser(sub)
|
handleTokenUser(user)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
cclog.Errorf("Error while loading user '%v'", sub)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deny any logins for unknown usernames
|
|
||||||
if user == nil {
|
|
||||||
cclog.Warn("Could not find user from JWT in internal database.")
|
|
||||||
return nil, errors.New("unknown user")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var name string
|
|
||||||
if wrap, ok := claims["name"].(map[string]any); ok {
|
|
||||||
if vals, ok := wrap["values"].([]any); ok {
|
|
||||||
if len(vals) != 0 {
|
|
||||||
name = fmt.Sprintf("%v", vals[0])
|
|
||||||
|
|
||||||
for i := 1; i < len(vals); i++ {
|
|
||||||
name += fmt.Sprintf(" %v", vals[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract roles from JWT (if present)
|
|
||||||
if rawroles, ok := claims["roles"].([]any); ok {
|
|
||||||
for _, rr := range rawroles {
|
|
||||||
if r, ok := rr.(string); ok {
|
|
||||||
roles = append(roles, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
user = &schema.User{
|
|
||||||
Username: sub,
|
|
||||||
Name: name,
|
|
||||||
Roles: roles,
|
|
||||||
Projects: projects,
|
|
||||||
AuthType: schema.AuthSession,
|
|
||||||
AuthSource: schema.AuthViaToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
if jc.SyncUserOnLogin || jc.UpdateUserOnLogin {
|
|
||||||
handleTokenUser(user)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// (Ask browser to) Delete JWT cookie
|
// (Ask browser to) Delete JWT cookie
|
||||||
|
|||||||
136
internal/auth/jwtHelpers.go
Normal file
136
internal/auth/jwtHelpers.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||||
|
// All rights reserved. This file is part of cc-backend.
|
||||||
|
// Use of this source code is governed by a MIT-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/ClusterCockpit/cc-backend/internal/repository"
|
||||||
|
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||||
|
"github.com/ClusterCockpit/cc-lib/schema"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// extractStringFromClaims extracts a string value from JWT claims
|
||||||
|
func extractStringFromClaims(claims jwt.MapClaims, key string) string {
|
||||||
|
if val, ok := claims[key].(string); ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractRolesFromClaims extracts roles from JWT claims
|
||||||
|
// If validateRoles is true, only valid roles are returned
|
||||||
|
func extractRolesFromClaims(claims jwt.MapClaims, validateRoles bool) []string {
|
||||||
|
var roles []string
|
||||||
|
|
||||||
|
if rawroles, ok := claims["roles"].([]any); ok {
|
||||||
|
for _, rr := range rawroles {
|
||||||
|
if r, ok := rr.(string); ok {
|
||||||
|
if validateRoles {
|
||||||
|
if schema.IsValidRole(r) {
|
||||||
|
roles = append(roles, r)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
roles = append(roles, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return roles
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractProjectsFromClaims extracts projects from JWT claims
|
||||||
|
func extractProjectsFromClaims(claims jwt.MapClaims) []string {
|
||||||
|
projects := make([]string, 0)
|
||||||
|
|
||||||
|
if rawprojs, ok := claims["projects"].([]any); ok {
|
||||||
|
for _, pp := range rawprojs {
|
||||||
|
if p, ok := pp.(string); ok {
|
||||||
|
projects = append(projects, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if rawprojs, ok := claims["projects"]; ok {
|
||||||
|
if projSlice, ok := rawprojs.([]string); ok {
|
||||||
|
projects = append(projects, projSlice...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return projects
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractNameFromClaims extracts name from JWT claims
|
||||||
|
// Handles both simple string and complex nested structure
|
||||||
|
func extractNameFromClaims(claims jwt.MapClaims) string {
|
||||||
|
// Try simple string first
|
||||||
|
if name, ok := claims["name"].(string); ok {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try nested structure: {name: {values: [...]}}
|
||||||
|
if wrap, ok := claims["name"].(map[string]any); ok {
|
||||||
|
if vals, ok := wrap["values"].([]any); ok {
|
||||||
|
if len(vals) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
name := fmt.Sprintf("%v", vals[0])
|
||||||
|
for i := 1; i < len(vals); i++ {
|
||||||
|
name += fmt.Sprintf(" %v", vals[i])
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserFromJWT creates or retrieves a user based on JWT claims
|
||||||
|
// If validateUser is true, the user must exist in the database
|
||||||
|
// Otherwise, a new user object is created from claims
|
||||||
|
// authSource should be a schema.AuthSource constant (like schema.AuthViaToken)
|
||||||
|
func getUserFromJWT(claims jwt.MapClaims, validateUser bool, authType schema.AuthType, authSource schema.AuthSource) (*schema.User, error) {
|
||||||
|
sub := extractStringFromClaims(claims, "sub")
|
||||||
|
if sub == "" {
|
||||||
|
return nil, errors.New("missing 'sub' claim in JWT")
|
||||||
|
}
|
||||||
|
|
||||||
|
if validateUser {
|
||||||
|
// Validate user against database
|
||||||
|
ur := repository.GetUserRepository()
|
||||||
|
user, err := ur.GetUser(sub)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
cclog.Errorf("Error while loading user '%v': %v", sub, err)
|
||||||
|
return nil, fmt.Errorf("database error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deny any logins for unknown usernames
|
||||||
|
if user == nil || err == sql.ErrNoRows {
|
||||||
|
cclog.Warn("Could not find user from JWT in internal database.")
|
||||||
|
return nil, errors.New("unknown user")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return database user (with database roles)
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create user from JWT claims
|
||||||
|
name := extractNameFromClaims(claims)
|
||||||
|
roles := extractRolesFromClaims(claims, true) // Validate roles
|
||||||
|
projects := extractProjectsFromClaims(claims)
|
||||||
|
|
||||||
|
return &schema.User{
|
||||||
|
Username: sub,
|
||||||
|
Name: name,
|
||||||
|
Roles: roles,
|
||||||
|
Projects: projects,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthSource: authSource,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
281
internal/auth/jwtHelpers_test.go
Normal file
281
internal/auth/jwtHelpers_test.go
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
// Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
|
||||||
|
// All rights reserved. This file is part of cc-backend.
|
||||||
|
// Use of this source code is governed by a MIT-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ClusterCockpit/cc-lib/schema"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestExtractStringFromClaims tests extracting string values from JWT claims
|
||||||
|
func TestExtractStringFromClaims(t *testing.T) {
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"sub": "testuser",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"age": 25, // not a string
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"Existing string", "sub", "testuser"},
|
||||||
|
{"Another string", "email", "test@example.com"},
|
||||||
|
{"Non-existent key", "missing", ""},
|
||||||
|
{"Non-string value", "age", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractStringFromClaims(claims, tt.key)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractRolesFromClaims tests role extraction and validation
|
||||||
|
func TestExtractRolesFromClaims(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims jwt.MapClaims
|
||||||
|
validateRoles bool
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid roles without validation",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"roles": []any{"admin", "user", "invalid_role"},
|
||||||
|
},
|
||||||
|
validateRoles: false,
|
||||||
|
expected: []string{"admin", "user", "invalid_role"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid roles with validation",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"roles": []any{"admin", "user", "api"},
|
||||||
|
},
|
||||||
|
validateRoles: true,
|
||||||
|
expected: []string{"admin", "user", "api"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid roles with validation",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"roles": []any{"invalid_role", "fake_role"},
|
||||||
|
},
|
||||||
|
validateRoles: true,
|
||||||
|
expected: []string{}, // Should filter out invalid roles
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No roles claim",
|
||||||
|
claims: jwt.MapClaims{},
|
||||||
|
validateRoles: false,
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-array roles",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"roles": "admin",
|
||||||
|
},
|
||||||
|
validateRoles: false,
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractRolesFromClaims(tt.claims, tt.validateRoles)
|
||||||
|
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("Expected %d roles, got %d", len(tt.expected), len(result))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, role := range result {
|
||||||
|
if i >= len(tt.expected) || role != tt.expected[i] {
|
||||||
|
t.Errorf("Expected role %s at position %d, got %s", tt.expected[i], i, role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractProjectsFromClaims tests project extraction from claims
|
||||||
|
func TestExtractProjectsFromClaims(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims jwt.MapClaims
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Projects as array of interfaces",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"projects": []any{"project1", "project2", "project3"},
|
||||||
|
},
|
||||||
|
expected: []string{"project1", "project2", "project3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Projects as string array",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"projects": []string{"projectA", "projectB"},
|
||||||
|
},
|
||||||
|
expected: []string{"projectA", "projectB"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No projects claim",
|
||||||
|
claims: jwt.MapClaims{},
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed types in projects array",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"projects": []any{"project1", 123, "project2"},
|
||||||
|
},
|
||||||
|
expected: []string{"project1", "project2"}, // Should skip non-strings
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractProjectsFromClaims(tt.claims)
|
||||||
|
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("Expected %d projects, got %d", len(tt.expected), len(result))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, project := range result {
|
||||||
|
if i >= len(tt.expected) || project != tt.expected[i] {
|
||||||
|
t.Errorf("Expected project %s at position %d, got %s", tt.expected[i], i, project)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractNameFromClaims tests name extraction from various formats
|
||||||
|
func TestExtractNameFromClaims(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
claims jwt.MapClaims
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Simple string name",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"name": "John Doe",
|
||||||
|
},
|
||||||
|
expected: "John Doe",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nested name structure",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"name": map[string]any{
|
||||||
|
"values": []any{"John", "Doe"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "John Doe",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nested name with single value",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"name": map[string]any{
|
||||||
|
"values": []any{"Alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "Alice",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No name claim",
|
||||||
|
claims: jwt.MapClaims{},
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty nested values",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"name": map[string]any{
|
||||||
|
"values": []any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nested with non-string values",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"name": map[string]any{
|
||||||
|
"values": []any{123, "Smith"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "123 Smith", // Should convert to string
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractNameFromClaims(tt.claims)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetUserFromJWT_NoValidation tests getUserFromJWT without database validation
|
||||||
|
func TestGetUserFromJWT_NoValidation(t *testing.T) {
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"sub": "testuser",
|
||||||
|
"name": "Test User",
|
||||||
|
"roles": []any{"user", "admin"},
|
||||||
|
"projects": []any{"project1", "project2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := getUserFromJWT(claims, false, schema.AuthToken, -1)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Username != "testuser" {
|
||||||
|
t.Errorf("Expected username 'testuser', got '%s'", user.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Name != "Test User" {
|
||||||
|
t.Errorf("Expected name 'Test User', got '%s'", user.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(user.Roles) != 2 {
|
||||||
|
t.Errorf("Expected 2 roles, got %d", len(user.Roles))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(user.Projects) != 2 {
|
||||||
|
t.Errorf("Expected 2 projects, got %d", len(user.Projects))
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AuthType != schema.AuthToken {
|
||||||
|
t.Errorf("Expected AuthType %v, got %v", schema.AuthToken, user.AuthType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetUserFromJWT_MissingSub tests error when sub claim is missing
|
||||||
|
func TestGetUserFromJWT_MissingSub(t *testing.T) {
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"name": "Test User",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := getUserFromJWT(claims, false, schema.AuthToken, -1)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for missing sub claim")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err.Error() != "missing 'sub' claim in JWT" {
|
||||||
|
t.Errorf("Expected specific error message, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,7 +6,6 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -14,7 +13,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ClusterCockpit/cc-backend/internal/repository"
|
|
||||||
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
cclog "github.com/ClusterCockpit/cc-lib/ccLogger"
|
||||||
"github.com/ClusterCockpit/cc-lib/schema"
|
"github.com/ClusterCockpit/cc-lib/schema"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
@@ -77,70 +75,16 @@ func (ja *JWTSessionAuthenticator) Login(
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := token.Claims.(jwt.MapClaims)
|
claims := token.Claims.(jwt.MapClaims)
|
||||||
sub, _ := claims["sub"].(string)
|
|
||||||
|
|
||||||
var roles []string
|
// Use shared helper to get user from JWT claims
|
||||||
projects := make([]string, 0)
|
user, err = getUserFromJWT(claims, Keys.JwtConfig.ValidateUser, schema.AuthSession, schema.AuthViaToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if Keys.JwtConfig.ValidateUser {
|
// Sync or update user if configured
|
||||||
var err error
|
if !Keys.JwtConfig.ValidateUser && (Keys.JwtConfig.SyncUserOnLogin || Keys.JwtConfig.UpdateUserOnLogin) {
|
||||||
user, err = repository.GetUserRepository().GetUser(sub)
|
handleTokenUser(user)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
cclog.Errorf("Error while loading user '%v'", sub)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deny any logins for unknown usernames
|
|
||||||
if user == nil {
|
|
||||||
cclog.Warn("Could not find user from JWT in internal database.")
|
|
||||||
return nil, errors.New("unknown user")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var name string
|
|
||||||
if wrap, ok := claims["name"].(map[string]any); ok {
|
|
||||||
if vals, ok := wrap["values"].([]any); ok {
|
|
||||||
if len(vals) != 0 {
|
|
||||||
name = fmt.Sprintf("%v", vals[0])
|
|
||||||
|
|
||||||
for i := 1; i < len(vals); i++ {
|
|
||||||
name += fmt.Sprintf(" %v", vals[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract roles from JWT (if present)
|
|
||||||
if rawroles, ok := claims["roles"].([]any); ok {
|
|
||||||
for _, rr := range rawroles {
|
|
||||||
if r, ok := rr.(string); ok {
|
|
||||||
if schema.IsValidRole(r) {
|
|
||||||
roles = append(roles, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if rawprojs, ok := claims["projects"].([]any); ok {
|
|
||||||
for _, pp := range rawprojs {
|
|
||||||
if p, ok := pp.(string); ok {
|
|
||||||
projects = append(projects, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if rawprojs, ok := claims["projects"]; ok {
|
|
||||||
projects = append(projects, rawprojs.([]string)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
user = &schema.User{
|
|
||||||
Username: sub,
|
|
||||||
Name: name,
|
|
||||||
Roles: roles,
|
|
||||||
Projects: projects,
|
|
||||||
AuthType: schema.AuthSession,
|
|
||||||
AuthSource: schema.AuthViaToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
if Keys.JwtConfig.SyncUserOnLogin || Keys.JwtConfig.UpdateUserOnLogin {
|
|
||||||
handleTokenUser(user)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
|
|||||||
@@ -54,8 +54,13 @@ func setCallbackCookie(w http.ResponseWriter, r *http.Request, name, value strin
|
|||||||
http.SetCookie(w, c)
|
http.SetCookie(w, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewOIDC creates a new OIDC authenticator with the configured provider
|
||||||
func NewOIDC(a *Authentication) *OIDC {
|
func NewOIDC(a *Authentication) *OIDC {
|
||||||
provider, err := oidc.NewProvider(context.Background(), Keys.OpenIDConfig.Provider)
|
// Use context with timeout for provider initialization
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
provider, err := oidc.NewProvider(ctx, Keys.OpenIDConfig.Provider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cclog.Fatal(err)
|
cclog.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -111,13 +116,18 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(rw, "Code not found", http.StatusBadRequest)
|
http.Error(rw, "Code not found", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err := oa.client.Exchange(context.Background(), code, oauth2.VerifierOption(codeVerifier))
|
// Exchange authorization code for token with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
token, err := oa.client.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
http.Error(rw, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userInfo, err := oa.provider.UserInfo(context.Background(), oauth2.StaticTokenSource(token))
|
// Get user info from OIDC provider with same timeout
|
||||||
|
userInfo, err := oa.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError)
|
http.Error(rw, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
@@ -180,8 +190,8 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
oa.authentication.SaveSession(rw, r, user)
|
oa.authentication.SaveSession(rw, r, user)
|
||||||
cclog.Infof("login successfull: user: %#v (roles: %v, projects: %v)", user.Username, user.Roles, user.Projects)
|
cclog.Infof("login successfull: user: %#v (roles: %v, projects: %v)", user.Username, user.Roles, user.Projects)
|
||||||
ctx := context.WithValue(r.Context(), repository.ContextUserKey, user)
|
userCtx := context.WithValue(r.Context(), repository.ContextUserKey, user)
|
||||||
http.RedirectHandler("/", http.StatusTemporaryRedirect).ServeHTTP(rw, r.WithContext(ctx))
|
http.RedirectHandler("/", http.StatusTemporaryRedirect).ServeHTTP(rw, r.WithContext(userCtx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oa *OIDC) OAuth2Login(rw http.ResponseWriter, r *http.Request) {
|
func (oa *OIDC) OAuth2Login(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
Reference in New Issue
Block a user