Glue authenticators together

This commit is contained in:
Lou Knauer 2022-07-07 13:40:38 +02:00
parent 23f6015494
commit db86d2cf7e
5 changed files with 285 additions and 81 deletions

View File

@ -1,13 +1,12 @@
package authv2
import (
"database/sql"
"context"
"encoding/json"
"net/http"
"time"
"github.com/ClusterCockpit/cc-backend/pkg/log"
sq "github.com/Masterminds/squirrel"
"github.com/gorilla/sessions"
"github.com/jmoiron/sqlx"
)
@ -43,10 +42,19 @@ func (u *User) HasRole(role string) bool {
return false
}
func GetUser(ctx context.Context) *User {
x := ctx.Value(ContextUserKey)
if x == nil {
return nil
}
return x.(*User)
}
type Authenticator interface {
Init(auth *Authentication, config json.RawMessage) error
CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool
Login(user *User, password string, rw http.ResponseWriter, r *http.Request) (*User, error)
Login(user *User, rw http.ResponseWriter, r *http.Request) (*User, error)
Auth(rw http.ResponseWriter, r *http.Request) (*User, error)
}
@ -55,12 +63,17 @@ type ContextKey string
const ContextUserKey ContextKey = "user"
type Authentication struct {
db *sqlx.DB
sessionStore *sessions.CookieStore
db *sqlx.DB
sessionStore *sessions.CookieStore
SessionMaxAge time.Duration
authenticators []Authenticator
LdapAuth *LdapAutnenticator
JwtAuth *JWTAuthenticator
LocalAuth *LocalAuthenticator
}
func Init(db *sqlx.DB) (*Authentication, error) {
func Init(db *sqlx.DB, configs map[string]json.RawMessage) (*Authentication, error) {
auth := &Authentication{}
auth.db = db
_, err := db.Exec(`
@ -75,49 +88,27 @@ func Init(db *sqlx.DB) (*Authentication, error) {
return nil, err
}
return auth, nil
}
func (auth *Authentication) GetUser(username string) (*User, error) {
user := &User{Username: username}
var hashedPassword, name, rawRoles, email sql.NullString
if err := sq.Select("password", "ldap", "name", "roles", "email").From("user").
Where("user.username = ?", username).RunWith(auth.db).
QueryRow().Scan(&hashedPassword, &user.AuthSource, &name, &rawRoles, &email); err != nil {
auth.LocalAuth = &LocalAuthenticator{}
if err := auth.LocalAuth.Init(auth, nil); err != nil {
return nil, err
}
auth.authenticators = append(auth.authenticators, auth.LocalAuth)
user.Password = hashedPassword.String
user.Name = name.String
user.Email = email.String
if rawRoles.Valid {
if err := json.Unmarshal([]byte(rawRoles.String), &user.Roles); err != nil {
auth.JwtAuth = &JWTAuthenticator{}
if err := auth.JwtAuth.Init(auth, nil); err != nil {
return nil, err
}
auth.authenticators = append(auth.authenticators, auth.JwtAuth)
if config, ok := configs["ldap"]; ok {
auth.LdapAuth = &LdapAutnenticator{}
if err := auth.LdapAuth.Init(auth, config); err != nil {
return nil, err
}
auth.authenticators = append(auth.authenticators, auth.LdapAuth)
}
return user, nil
}
func (auth *Authentication) AddUser(user *User) error {
rolesJson, _ := json.Marshal(user.Roles)
cols := []string{"username", "password", "roles"}
vals := []interface{}{user.Username, user.Password, string(rolesJson)}
if user.Name != "" {
cols = append(cols, "name")
vals = append(vals, user.Name)
}
if user.Email != "" {
cols = append(cols, "email")
vals = append(vals, user.Email)
}
if _, err := sq.Insert("user").Columns(cols...).Values(vals...).RunWith(auth.db).Exec(); err != nil {
return err
}
log.Infof("new user %#v created (roles: %s)", user.Username, rolesJson)
return nil
return auth, nil
}
func (auth *Authentication) AuthViaSession(rw http.ResponseWriter, r *http.Request) (*User, error) {
@ -133,7 +124,106 @@ func (auth *Authentication) AuthViaSession(rw http.ResponseWriter, r *http.Reque
username, _ := session.Values["username"].(string)
roles, _ := session.Values["roles"].([]string)
return &User{
Username: username,
Roles: roles,
Username: username,
Roles: roles,
AuthSource: -1,
}, nil
}
// Handle a POST request that should log the user in, starting a new session.
func (auth *Authentication) Login(onsuccess http.Handler, onfailure func(rw http.ResponseWriter, r *http.Request, loginErr error)) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
var err error
username := r.FormValue("username")
user := (*User)(nil)
if username != "" {
if user, _ = auth.GetUser(username); err != nil {
log.Warnf("login of unkown user %#v", username)
}
}
for _, authenticator := range auth.authenticators {
if !authenticator.CanLogin(user, rw, r) {
continue
}
user, err = authenticator.Login(user, rw, r)
if err != nil {
log.Warnf("login failed: %s", err.Error())
onfailure(rw, r, err)
return
}
session, err := auth.sessionStore.New(r, "session")
if err != nil {
log.Errorf("session creation failed: %s", err.Error())
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
if auth.SessionMaxAge != 0 {
session.Options.MaxAge = int(auth.SessionMaxAge.Seconds())
}
session.Values["username"] = user.Username
session.Values["roles"] = user.Roles
if err := auth.sessionStore.Save(r, rw, session); err != nil {
log.Errorf("session save failed: %s", err.Error())
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
log.Infof("login successfull: user: %#v (roles: %v)", user.Username, user.Roles)
ctx := context.WithValue(r.Context(), ContextUserKey, user)
onsuccess.ServeHTTP(rw, r.WithContext(ctx))
}
log.Warn("login failed: no authenticator applied")
onfailure(rw, r, err)
})
}
// Authenticate the user and put a User object in the
// context of the request. If authentication fails,
// do not continue but send client to the login screen.
func (auth *Authentication) Auth(onsuccess http.Handler, onfailure func(rw http.ResponseWriter, r *http.Request, authErr error)) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
for _, authenticator := range auth.authenticators {
user, err := authenticator.Auth(rw, r)
if err != nil {
log.Warnf("authentication failed: %s", err.Error())
http.Error(rw, err.Error(), http.StatusUnauthorized)
return
}
if user == nil {
continue
}
ctx := context.WithValue(r.Context(), ContextUserKey, user)
onsuccess.ServeHTTP(rw, r.WithContext(ctx))
}
log.Warnf("authentication failed: %s", "no authenticator applied")
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
})
}
// Clears the session cookie
func (auth *Authentication) Logout(onsuccess http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
session, err := auth.sessionStore.Get(r, "session")
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
if !session.IsNew {
session.Options.MaxAge = -1
if err := auth.sessionStore.Save(r, rw, session); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
}
onsuccess.ServeHTTP(rw, r)
})
}

View File

@ -2,7 +2,6 @@ package authv2
import (
"crypto/ed25519"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
@ -48,10 +47,10 @@ func (ja *JWTAuthenticator) Init(auth *Authentication, rawConfig json.RawMessage
}
func (ja *JWTAuthenticator) CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool {
return user.AuthSource == AuthViaToken || r.Header.Get("Authorization") != ""
return (user != nil && user.AuthSource == AuthViaToken) || r.Header.Get("Authorization") != ""
}
func (ja *JWTAuthenticator) Login(_ *User, password string, rw http.ResponseWriter, r *http.Request) (*User, error) {
func (ja *JWTAuthenticator) Login(user *User, rw http.ResponseWriter, r *http.Request) (*User, error) {
rawtoken := r.Header.Get("X-Auth-Token")
if rawtoken == "" {
rawtoken = r.Header.Get("Authorization")
@ -84,14 +83,9 @@ func (ja *JWTAuthenticator) Login(_ *User, password string, rw http.ResponseWrit
}
}
user, err := ja.auth.GetUser(sub)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
if err != nil && err == sql.ErrNoRows {
if user == nil {
user = &User{
Username: user.Username,
Username: sub,
Roles: roles,
AuthSource: AuthViaToken,
}
@ -114,13 +108,7 @@ func (ja *JWTAuthenticator) Auth(rw http.ResponseWriter, r *http.Request) (*User
// Because a user can also log in via a token, the
// session cookie must be checked here as well:
if rawtoken == "" {
user, err := ja.auth.AuthViaSession(rw, r)
if err != nil {
return nil, err
}
user.AuthSource = AuthViaToken
return user, nil
return ja.auth.AuthViaSession(rw, r)
}
token, err := jwt.Parse(rawtoken, func(t *jwt.Token) (interface{}, error) {

View File

@ -67,10 +67,10 @@ func (la *LdapAutnenticator) Init(auth *Authentication, rawConfig json.RawMessag
}
func (la *LdapAutnenticator) CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool {
return user.AuthSource == AuthViaLDAP
return user != nil && user.AuthSource == AuthViaLDAP
}
func (la *LdapAutnenticator) Login(user *User, password string, rw http.ResponseWriter, r *http.Request) (*User, error) {
func (la *LdapAutnenticator) Login(user *User, rw http.ResponseWriter, r *http.Request) (*User, error) {
l, err := la.getLdapConnection(false)
if err != nil {
return nil, err
@ -78,7 +78,7 @@ func (la *LdapAutnenticator) Login(user *User, password string, rw http.Response
defer l.Close()
userDn := strings.Replace(la.config.UserBind, "{username}", user.Username, -1)
if err := l.Bind(userDn, password); err != nil {
if err := l.Bind(userDn, r.FormValue("password")); err != nil {
return nil, err
}
@ -86,13 +86,7 @@ func (la *LdapAutnenticator) Login(user *User, password string, rw http.Response
}
func (la *LdapAutnenticator) Auth(rw http.ResponseWriter, r *http.Request) (*User, error) {
user, err := la.auth.AuthViaSession(rw, r)
if err != nil {
return nil, err
}
user.AuthSource = AuthViaLDAP
return user, nil
return la.auth.AuthViaSession(rw, r)
}
func (la *LdapAutnenticator) Sync() error {

View File

@ -20,11 +20,11 @@ func (la *LocalAuthenticator) Init(auth *Authentication, rawConfig json.RawMessa
}
func (la *LocalAuthenticator) CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool {
return user.AuthSource == AuthViaLocalPassword
return user != nil && user.AuthSource == AuthViaLocalPassword
}
func (la *LocalAuthenticator) Login(user *User, password string, rw http.ResponseWriter, r *http.Request) (*User, error) {
if e := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); e != nil {
func (la *LocalAuthenticator) Login(user *User, rw http.ResponseWriter, r *http.Request) (*User, error) {
if e := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(r.FormValue("password"))); e != nil {
return nil, fmt.Errorf("user '%s' provided the wrong password (%w)", user.Username, e)
}
@ -32,11 +32,5 @@ func (la *LocalAuthenticator) Login(user *User, password string, rw http.Respons
}
func (la *LocalAuthenticator) Auth(rw http.ResponseWriter, r *http.Request) (*User, error) {
user, err := la.auth.AuthViaSession(rw, r)
if err != nil {
return nil, err
}
user.AuthSource = AuthViaLocalPassword
return user, nil
return la.auth.AuthViaSession(rw, r)
}

138
internal/auth-v2/users.go Normal file
View File

@ -0,0 +1,138 @@
package authv2
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/ClusterCockpit/cc-backend/internal/graph/model"
"github.com/ClusterCockpit/cc-backend/pkg/log"
sq "github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"
)
func (auth *Authentication) GetUser(username string) (*User, error) {
user := &User{Username: username}
var hashedPassword, name, rawRoles, email sql.NullString
if err := sq.Select("password", "ldap", "name", "roles", "email").From("user").
Where("user.username = ?", username).RunWith(auth.db).
QueryRow().Scan(&hashedPassword, &user.AuthSource, &name, &rawRoles, &email); err != nil {
return nil, err
}
user.Password = hashedPassword.String
user.Name = name.String
user.Email = email.String
if rawRoles.Valid {
if err := json.Unmarshal([]byte(rawRoles.String), &user.Roles); err != nil {
return nil, err
}
}
return user, nil
}
func (auth *Authentication) AddUser(user *User) error {
rolesJson, _ := json.Marshal(user.Roles)
cols := []string{"username", "password", "roles"}
vals := []interface{}{user.Username, user.Password, string(rolesJson)}
if user.Name != "" {
cols = append(cols, "name")
vals = append(vals, user.Name)
}
if user.Email != "" {
cols = append(cols, "email")
vals = append(vals, user.Email)
}
if _, err := sq.Insert("user").Columns(cols...).Values(vals...).RunWith(auth.db).Exec(); err != nil {
return err
}
log.Infof("new user %#v created (roles: %s, auth-source: %d)", user.Username, rolesJson, user.AuthSource)
return nil
}
func (auth *Authentication) DelUser(username string) error {
_, err := auth.db.Exec(`DELETE FROM user WHERE user.username = ?`, username)
return err
}
func (auth *Authentication) ListUsers(specialsOnly bool) ([]*User, error) {
q := sq.Select("username", "name", "email", "roles").From("user")
if specialsOnly {
q = q.Where("(roles != '[\"user\"]' AND roles != '[]')")
}
rows, err := q.RunWith(auth.db).Query()
if err != nil {
return nil, err
}
users := make([]*User, 0)
defer rows.Close()
for rows.Next() {
rawroles := ""
user := &User{}
var name, email sql.NullString
if err := rows.Scan(&user.Username, &name, &email, &rawroles); err != nil {
return nil, err
}
if err := json.Unmarshal([]byte(rawroles), &user.Roles); err != nil {
return nil, err
}
user.Name = name.String
user.Email = email.String
users = append(users, user)
}
return users, nil
}
func (auth *Authentication) AddRole(ctx context.Context, username string, role string) error {
user, err := auth.GetUser(username)
if err != nil {
return err
}
if role != RoleAdmin && role != RoleApi && role != RoleUser {
return fmt.Errorf("invalid user role: %#v", role)
}
for _, r := range user.Roles {
if r == role {
return fmt.Errorf("user %#v already has role %#v", username, role)
}
}
roles, _ := json.Marshal(append(user.Roles, role))
if _, err := sq.Update("user").Set("roles", roles).Where("user.username = ?", username).RunWith(auth.db).Exec(); err != nil {
return err
}
return nil
}
func FetchUser(ctx context.Context, db *sqlx.DB, username string) (*model.User, error) {
me := GetUser(ctx)
if me != nil && !me.HasRole(RoleAdmin) && me.Username != username {
return nil, errors.New("forbidden")
}
user := &model.User{Username: username}
var name, email sql.NullString
if err := sq.Select("name", "email").From("user").Where("user.username = ?", username).
RunWith(db).QueryRow().Scan(&name, &email); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
user.Name = name.String
user.Email = email.String
return user, nil
}