diff --git a/auth/auth.go b/auth/auth.go index a45e57d..4e31afc 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -14,7 +14,6 @@ import ( "strings" "github.com/ClusterCockpit/cc-backend/log" - "github.com/ClusterCockpit/cc-backend/templates" sq "github.com/Masterminds/squirrel" "github.com/golang-jwt/jwt/v4" "github.com/gorilla/sessions" @@ -22,8 +21,9 @@ import ( "golang.org/x/crypto/bcrypt" ) -// TODO: Properly do this "roles" stuff. -// Add a roles array and `user.HasRole(...)` functions. +// Only Username and Roles will always be filled in when returned by `GetUser`. +// If Name and Email is needed as well, use auth.FetchUser(), which does a database +// query for all fields. type User struct { Username string Password string @@ -52,12 +52,18 @@ type ContextKey string const ContextUserKey ContextKey = "user" -var JwtPublicKey ed25519.PublicKey -var JwtPrivateKey ed25519.PrivateKey +type Authentication struct { + db *sqlx.DB + sessionStore *sessions.CookieStore + jwtPublicKey ed25519.PublicKey + jwtPrivateKey ed25519.PrivateKey -var sessionStore *sessions.CookieStore + ldapConfig *LdapConfig + ldapSyncUserPassword string +} -func Init(db *sqlx.DB, ldapConfig *LdapConfig) error { +func (auth *Authentication) Init(db *sqlx.DB, ldapConfig *LdapConfig) error { + auth.db = db _, err := db.Exec(` CREATE TABLE IF NOT EXISTS user ( username varchar(255) PRIMARY KEY, @@ -77,13 +83,13 @@ func Init(db *sqlx.DB, ldapConfig *LdapConfig) error { if _, err := rand.Read(bytes); err != nil { return err } - sessionStore = sessions.NewCookieStore(bytes) + auth.sessionStore = sessions.NewCookieStore(bytes) } else { bytes, err := base64.StdEncoding.DecodeString(sessKey) if err != nil { return err } - sessionStore = sessions.NewCookieStore(bytes) + auth.sessionStore = sessions.NewCookieStore(bytes) } pubKey, privKey := os.Getenv("JWT_PUBLIC_KEY"), os.Getenv("JWT_PRIVATE_KEY") @@ -94,16 +100,17 @@ func Init(db *sqlx.DB, ldapConfig *LdapConfig) error { if err != nil { return err } - JwtPublicKey = ed25519.PublicKey(bytes) + auth.jwtPublicKey = ed25519.PublicKey(bytes) bytes, err = base64.StdEncoding.DecodeString(privKey) if err != nil { return err } - JwtPrivateKey = ed25519.PrivateKey(bytes) + auth.jwtPrivateKey = ed25519.PrivateKey(bytes) } if ldapConfig != nil { - if err := initLdap(ldapConfig); err != nil { + auth.ldapConfig = ldapConfig + if err := auth.initLdap(); err != nil { return err } } @@ -111,8 +118,8 @@ func Init(db *sqlx.DB, ldapConfig *LdapConfig) error { return nil } -// arg must be formated like this: ":[admin]:" -func AddUserToDB(db *sqlx.DB, arg string) error { +// arg must be formated like this: ":[admin|api|]:" +func (auth *Authentication) AddUser(arg string) error { parts := strings.SplitN(arg, ":", 3) if len(parts) != 3 || len(parts[0]) == 0 { return errors.New("invalid argument format") @@ -139,7 +146,7 @@ func AddUserToDB(db *sqlx.DB, arg string) error { } rolesJson, _ := json.Marshal(roles) - _, err := sq.Insert("user").Columns("username", "password", "roles").Values(parts[0], password, string(rolesJson)).RunWith(db).Exec() + _, err := sq.Insert("user").Columns("username", "password", "roles").Values(parts[0], password, string(rolesJson)).RunWith(auth.db).Exec() if err != nil { return err } @@ -147,16 +154,16 @@ func AddUserToDB(db *sqlx.DB, arg string) error { return nil } -func DelUserFromDB(db *sqlx.DB, username string) error { - _, err := db.Exec(`DELETE FROM user WHERE user.username = ?`, username) +func (auth *Authentication) DelUser(username string) error { + _, err := auth.db.Exec(`DELETE FROM user WHERE user.username = ?`, username) return err } -func FetchUserFromDB(db *sqlx.DB, username string) (*User, error) { +func (auth *Authentication) FetchUser(username string) (*User, error) { user := &User{Username: username} var hashedPassword, name, rawRoles, email sql.NullString if err := sq.Select("password", "ldap", "name", "roles", "email").From("user"). - Where("user.username = ?", username).RunWith(db). + Where("user.username = ?", username).RunWith(auth.db). QueryRow().Scan(&hashedPassword, &user.ViaLdap, &name, &rawRoles, &email); err != nil { return nil, fmt.Errorf("user '%s' not found (%s)", username, err.Error()) } @@ -165,20 +172,21 @@ func FetchUserFromDB(db *sqlx.DB, username string) (*User, error) { user.Name = name.String user.Email = email.String if rawRoles.Valid { - json.Unmarshal([]byte(rawRoles.String), &user.Roles) + if err := json.Unmarshal([]byte(rawRoles.String), &user.Roles); err != nil { + return nil, err + } } return user, nil } -// Handle a POST request that should log the user in, -// starting a new session. -func Login(db *sqlx.DB) http.Handler { +// Handle a POST request that should log the user in, starting a new session. +func (auth *Authentication) Login(onsuccess http.Handler, onfailure func(rw http.ResponseWriter, r *http.Request, loginErr error)) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { username, password := r.FormValue("username"), r.FormValue("password") - user, err := FetchUserFromDB(db, username) - if err == nil && user.ViaLdap && ldapAuthEnabled { - err = loginViaLdap(user, password) + user, err := auth.FetchUser(username) + if err == nil && user.ViaLdap && auth.ldapConfig != nil { + err = auth.loginViaLdap(user, password) } else if err == nil && !user.ViaLdap && user.Password != "" { if e := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); e != nil { err = fmt.Errorf("user '%s' provided the wrong password (%s)", username, e.Error()) @@ -189,15 +197,11 @@ func Login(db *sqlx.DB) http.Handler { if err != nil { log.Warnf("login of user %#v failed: %s", username, err.Error()) - rw.WriteHeader(http.StatusUnauthorized) - templates.Render(rw, r, "login.tmpl", &templates.Page{ - Title: "Login failed", - Error: "Username or password incorrect", - }) + onfailure(rw, r, err) return } - session, err := sessionStore.New(r, "session") + session, err := auth.sessionStore.New(r, "session") if err != nil { log.Errorf("session creation failed: %s", err.Error()) http.Error(rw, err.Error(), http.StatusInternalServerError) @@ -207,21 +211,22 @@ func Login(db *sqlx.DB) http.Handler { session.Options.MaxAge = 30 * 24 * 60 * 60 session.Values["username"] = user.Username session.Values["roles"] = user.Roles - if err := sessionStore.Save(r, rw, session); err != nil { + if err := auth.sessionStore.Save(r, rw, session); err != nil { log.Errorf("session save failed: %s", err.Error()) http.Error(rw, err.Error(), http.StatusInternalServerError) return } log.Infof("login successfull: user: %#v (roles: %v)", user.Username, user.Roles) - http.Redirect(rw, r, "/", http.StatusTemporaryRedirect) + ctx := context.WithValue(r.Context(), ContextUserKey, user) + onsuccess.ServeHTTP(rw, r.WithContext(ctx)) }) } var ErrTokenInvalid error = errors.New("invalid token") -func authViaToken(r *http.Request) (*User, error) { - if JwtPublicKey == nil { +func (auth *Authentication) authViaToken(r *http.Request) (*User, error) { + if auth.jwtPublicKey == nil { return nil, nil } @@ -239,7 +244,7 @@ func authViaToken(r *http.Request) (*User, error) { if t.Method != jwt.SigningMethodEdDSA { return nil, errors.New("only Ed25519/EdDSA supported") } - return JwtPublicKey, nil + return auth.jwtPublicKey, nil }) if err != nil { return nil, ErrTokenInvalid @@ -263,9 +268,9 @@ func authViaToken(r *http.Request) (*User, error) { // Authenticate the user and put a User object in the // context of the request. If authentication fails, // do not continue but send client to the login screen. -func Auth(next http.Handler) http.Handler { +func (auth *Authentication) Auth(onsuccess http.Handler, onfailure func(rw http.ResponseWriter, r *http.Request, authErr error)) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - user, err := authViaToken(r) + user, err := auth.authViaToken(r) if err == ErrTokenInvalid { log.Warn("authentication failed: invalid token") http.Error(rw, err.Error(), http.StatusUnauthorized) @@ -274,11 +279,11 @@ func Auth(next http.Handler) http.Handler { if user != nil { // Successfull authentication using a token ctx := context.WithValue(r.Context(), ContextUserKey, user) - next.ServeHTTP(rw, r.WithContext(ctx)) + onsuccess.ServeHTTP(rw, r.WithContext(ctx)) return } - session, err := sessionStore.Get(r, "session") + session, err := auth.sessionStore.Get(r, "session") if err != nil { // sessionStore.Get will return a new session if no current one is attached to this request. http.Error(rw, err.Error(), http.StatusInternalServerError) @@ -287,12 +292,7 @@ func Auth(next http.Handler) http.Handler { if session.IsNew { log.Warn("authentication failed: no session or jwt found") - - rw.WriteHeader(http.StatusUnauthorized) - templates.Render(rw, r, "login.tmpl", &templates.Page{ - Title: "Authentication failed", - Error: "No valid session or JWT provided", - }) + onfailure(rw, r, errors.New("no valid session or JWT provided")) return } @@ -302,13 +302,13 @@ func Auth(next http.Handler) http.Handler { Username: username, Roles: roles, }) - next.ServeHTTP(rw, r.WithContext(ctx)) + onsuccess.ServeHTTP(rw, r.WithContext(ctx)) }) } // Generate a new JWT that can be used for authentication -func ProvideJWT(user *User) (string, error) { - if JwtPrivateKey == nil { +func (auth *Authentication) ProvideJWT(user *User) (string, error) { + if auth.jwtPrivateKey == nil { return "", errors.New("environment variable 'JWT_PRIVATE_KEY' not set") } @@ -317,7 +317,7 @@ func ProvideJWT(user *User) (string, error) { "roles": user.Roles, }) - return tok.SignedString(JwtPrivateKey) + return tok.SignedString(auth.jwtPrivateKey) } func GetUser(ctx context.Context) *User { @@ -330,23 +330,22 @@ func GetUser(ctx context.Context) *User { } // Clears the session cookie -func Logout(rw http.ResponseWriter, r *http.Request) { - session, err := sessionStore.Get(r, "session") - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - - if !session.IsNew { - session.Options.MaxAge = -1 - if err := sessionStore.Save(r, rw, session); err != nil { +func (auth *Authentication) Logout(onsuccess http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + session, err := auth.sessionStore.Get(r, "session") + if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return } - } - templates.Render(rw, r, "login.tmpl", &templates.Page{ - Title: "Logout successful", - Info: "Logout successful", + if !session.IsNew { + session.Options.MaxAge = -1 + if err := auth.sessionStore.Save(r, rw, session); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + } + + onsuccess.ServeHTTP(rw, r) }) } diff --git a/auth/ldap.go b/auth/ldap.go index 9228e07..a321c1e 100644 --- a/auth/ldap.go +++ b/auth/ldap.go @@ -9,7 +9,6 @@ import ( "github.com/ClusterCockpit/cc-backend/log" "github.com/go-ldap/ldap/v3" - "github.com/jmoiron/sqlx" ) type LdapConfig struct { @@ -21,30 +20,23 @@ type LdapConfig struct { TLS bool `json:"tls"` } -var ldapAuthEnabled bool = false -var ldapConfig *LdapConfig -var ldapAdminPassword string - -func initLdap(config *LdapConfig) error { - ldapAdminPassword = os.Getenv("LDAP_ADMIN_PASSWORD") - if ldapAdminPassword == "" { +func (auth *Authentication) initLdap() error { + auth.ldapSyncUserPassword = os.Getenv("LDAP_ADMIN_PASSWORD") + if auth.ldapSyncUserPassword == "" { log.Warn("environment variable 'LDAP_ADMIN_PASSWORD' not set (ldap sync or authentication will not work)") } - - ldapConfig = config - ldapAuthEnabled = true return nil } // TODO: Add a connection pool or something like // that so that connections can be reused/cached. -func getLdapConnection(admin bool) (*ldap.Conn, error) { - conn, err := ldap.DialURL(ldapConfig.Url) +func (auth *Authentication) getLdapConnection(admin bool) (*ldap.Conn, error) { + conn, err := ldap.DialURL(auth.ldapConfig.Url) if err != nil { return nil, err } - if ldapConfig.TLS { + if auth.ldapConfig.TLS { if err := conn.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { conn.Close() return nil, err @@ -52,7 +44,7 @@ func getLdapConnection(admin bool) (*ldap.Conn, error) { } if admin { - if err := conn.Bind(ldapConfig.SearchDN, ldapAdminPassword); err != nil { + if err := conn.Bind(auth.ldapConfig.SearchDN, auth.ldapSyncUserPassword); err != nil { conn.Close() return nil, err } @@ -61,18 +53,14 @@ func getLdapConnection(admin bool) (*ldap.Conn, error) { return conn, nil } -func releaseConnection(conn *ldap.Conn) { - conn.Close() -} - -func loginViaLdap(user *User, password string) error { - l, err := getLdapConnection(false) +func (auth *Authentication) loginViaLdap(user *User, password string) error { + l, err := auth.getLdapConnection(false) if err != nil { return err } - defer releaseConnection(l) + defer l.Close() - userDn := strings.Replace(ldapConfig.UserBind, "{username}", user.Username, -1) + userDn := strings.Replace(auth.ldapConfig.UserBind, "{username}", user.Username, -1) if err := l.Bind(userDn, password); err != nil { return err } @@ -83,8 +71,8 @@ func loginViaLdap(user *User, password string) error { // Delete users where user.ldap is 1 and that do not show up in the ldap search results. // Add users to the users table that are new in the ldap search results. -func SyncWithLDAP(db *sqlx.DB) error { - if !ldapAuthEnabled { +func (auth *Authentication) SyncWithLDAP(deleteOldUsers bool) error { + if auth.ldapConfig == nil { return errors.New("ldap not enabled") } @@ -93,7 +81,7 @@ func SyncWithLDAP(db *sqlx.DB) error { const IN_BOTH int = 3 users := map[string]int{} - rows, err := db.Query(`SELECT username FROM user WHERE user.ldap = 1`) + rows, err := auth.db.Query(`SELECT username FROM user WHERE user.ldap = 1`) if err != nil { return err } @@ -107,15 +95,15 @@ func SyncWithLDAP(db *sqlx.DB) error { users[username] = IN_DB } - l, err := getLdapConnection(true) + l, err := auth.getLdapConnection(true) if err != nil { return err } - defer releaseConnection(l) + defer l.Close() ldapResults, err := l.Search(ldap.NewSearchRequest( - ldapConfig.UserBase, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, - ldapConfig.UserFilter, []string{"dn", "uid", "gecos"}, nil)) + auth.ldapConfig.UserBase, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + auth.ldapConfig.UserFilter, []string{"dn", "uid", "gecos"}, nil)) if err != nil { return err } @@ -137,15 +125,15 @@ func SyncWithLDAP(db *sqlx.DB) error { } for username, where := range users { - if where == IN_DB { + if where == IN_DB && deleteOldUsers { log.Infof("ldap-sync: remove %#v (does not show up in LDAP anymore)", username) - if _, err := db.Exec(`DELETE FROM user WHERE user.username = ?`, username); err != nil { + if _, err := auth.db.Exec(`DELETE FROM user WHERE user.username = ?`, username); err != nil { return err } } else if where == IN_LDAP { name := newnames[username] log.Infof("ldap-sync: add %#v (name: %#v, roles: [], ldap: true)", username, name) - if _, err := db.Exec(`INSERT INTO user (username, ldap, name, roles) VALUES (?, ?, ?, ?)`, + if _, err := auth.db.Exec(`INSERT INTO user (username, ldap, name, roles) VALUES (?, ?, ?, ?)`, username, 1, name, "[]"); err != nil { return err } diff --git a/server.go b/server.go index b062313..ed9cf39 100644 --- a/server.go +++ b/server.go @@ -247,28 +247,31 @@ func main() { // Initialize sub-modules... + authentication := &auth.Authentication{} if !programConfig.DisableAuthentication { - if err := auth.Init(db, programConfig.LdapConfig); err != nil { + if err := authentication.Init(db, programConfig.LdapConfig); err != nil { log.Fatal(err) } if flagNewUser != "" { - if err := auth.AddUserToDB(db, flagNewUser); err != nil { + if err := authentication.AddUser(flagNewUser); err != nil { log.Fatal(err) } } if flagDelUser != "" { - if err := auth.DelUserFromDB(db, flagDelUser); err != nil { + if err := authentication.DelUser(flagDelUser); err != nil { log.Fatal(err) } } if flagSyncLDAP { - auth.SyncWithLDAP(db) + if err := authentication.SyncWithLDAP(true); err != nil { + log.Fatal(err) + } } if flagGenJWT != "" { - user, err := auth.FetchUserFromDB(db, flagGenJWT) + user, err := authentication.FetchUser(flagGenJWT) if err != nil { log.Fatal(err) } @@ -277,7 +280,7 @@ func main() { log.Warn("that user does not have the API role") } - jwt, err := auth.ProvideJWT(user) + jwt, err := authentication.ProvideJWT(user) if err != nil { log.Fatal(err) } @@ -350,10 +353,8 @@ func main() { }) r.Handle("/playground", graphQLPlayground) - r.Handle("/login", auth.Login(db)).Methods(http.MethodPost) r.HandleFunc("/login", handleGetLogin).Methods(http.MethodGet) - r.HandleFunc("/logout", auth.Logout).Methods(http.MethodPost) r.HandleFunc("/imprint", func(rw http.ResponseWriter, r *http.Request) { templates.Render(rw, r, "imprint.tmpl", &templates.Page{ Title: "Imprint", @@ -367,7 +368,35 @@ func main() { secured := r.PathPrefix("/").Subrouter() if !programConfig.DisableAuthentication { - secured.Use(auth.Auth) + r.Handle("/login", authentication.Login( + // On success: + http.RedirectHandler("/", http.StatusTemporaryRedirect), + + // On failure: + func(rw http.ResponseWriter, r *http.Request, loginErr error) { + rw.WriteHeader(http.StatusUnauthorized) + templates.Render(rw, r, "login.tmpl", &templates.Page{ + Title: "Login failed - ClusterCockpit", + Error: err.Error(), + }) + })).Methods(http.MethodPost) + + r.Handle("/logout", authentication.Logout(http.RedirectHandler("/login", http.StatusTemporaryRedirect))).Methods(http.MethodPost) + + secured.Use(func(next http.Handler) http.Handler { + return authentication.Auth( + // On success; + next, + + // On failure: + func(rw http.ResponseWriter, r *http.Request, authErr error) { + rw.WriteHeader(http.StatusUnauthorized) + templates.Render(rw, r, "login.tmpl", &templates.Page{ + Title: "Authentication failed - ClusterCockpit", + Error: err.Error(), + }) + }) + }) } secured.Handle("/query", graphQLEndpoint)