mirror of
https://github.com/ClusterCockpit/cc-backend
synced 2024-11-10 08:57:25 +01:00
Glue authenticators together
This commit is contained in:
parent
23f6015494
commit
db86d2cf7e
@ -1,13 +1,12 @@
|
|||||||
package authv2
|
package authv2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ClusterCockpit/cc-backend/pkg/log"
|
"github.com/ClusterCockpit/cc-backend/pkg/log"
|
||||||
sq "github.com/Masterminds/squirrel"
|
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
@ -43,10 +42,19 @@ func (u *User) HasRole(role string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetUser(ctx context.Context) *User {
|
||||||
|
x := ctx.Value(ContextUserKey)
|
||||||
|
if x == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return x.(*User)
|
||||||
|
}
|
||||||
|
|
||||||
type Authenticator interface {
|
type Authenticator interface {
|
||||||
Init(auth *Authentication, config json.RawMessage) error
|
Init(auth *Authentication, config json.RawMessage) error
|
||||||
CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool
|
CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool
|
||||||
Login(user *User, password string, rw http.ResponseWriter, r *http.Request) (*User, error)
|
Login(user *User, rw http.ResponseWriter, r *http.Request) (*User, error)
|
||||||
Auth(rw http.ResponseWriter, r *http.Request) (*User, error)
|
Auth(rw http.ResponseWriter, r *http.Request) (*User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,10 +65,15 @@ const ContextUserKey ContextKey = "user"
|
|||||||
type Authentication struct {
|
type Authentication struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
sessionStore *sessions.CookieStore
|
sessionStore *sessions.CookieStore
|
||||||
|
SessionMaxAge time.Duration
|
||||||
|
|
||||||
authenticators []Authenticator
|
authenticators []Authenticator
|
||||||
|
LdapAuth *LdapAutnenticator
|
||||||
|
JwtAuth *JWTAuthenticator
|
||||||
|
LocalAuth *LocalAuthenticator
|
||||||
}
|
}
|
||||||
|
|
||||||
func Init(db *sqlx.DB) (*Authentication, error) {
|
func Init(db *sqlx.DB, configs map[string]json.RawMessage) (*Authentication, error) {
|
||||||
auth := &Authentication{}
|
auth := &Authentication{}
|
||||||
auth.db = db
|
auth.db = db
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
@ -75,51 +88,29 @@ func Init(db *sqlx.DB) (*Authentication, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auth.LocalAuth = &LocalAuthenticator{}
|
||||||
|
if err := auth.LocalAuth.Init(auth, nil); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
auth.authenticators = append(auth.authenticators, auth.LocalAuth)
|
||||||
|
|
||||||
|
auth.JwtAuth = &JWTAuthenticator{}
|
||||||
|
if err := auth.JwtAuth.Init(auth, nil); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
auth.authenticators = append(auth.authenticators, auth.JwtAuth)
|
||||||
|
|
||||||
|
if config, ok := configs["ldap"]; ok {
|
||||||
|
auth.LdapAuth = &LdapAutnenticator{}
|
||||||
|
if err := auth.LdapAuth.Init(auth, config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
auth.authenticators = append(auth.authenticators, auth.LdapAuth)
|
||||||
|
}
|
||||||
|
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Authentication) GetUser(username string) (*User, error) {
|
|
||||||
user := &User{Username: username}
|
|
||||||
var hashedPassword, name, rawRoles, email sql.NullString
|
|
||||||
if err := sq.Select("password", "ldap", "name", "roles", "email").From("user").
|
|
||||||
Where("user.username = ?", username).RunWith(auth.db).
|
|
||||||
QueryRow().Scan(&hashedPassword, &user.AuthSource, &name, &rawRoles, &email); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user.Password = hashedPassword.String
|
|
||||||
user.Name = name.String
|
|
||||||
user.Email = email.String
|
|
||||||
if rawRoles.Valid {
|
|
||||||
if err := json.Unmarshal([]byte(rawRoles.String), &user.Roles); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *Authentication) AddUser(user *User) error {
|
|
||||||
rolesJson, _ := json.Marshal(user.Roles)
|
|
||||||
cols := []string{"username", "password", "roles"}
|
|
||||||
vals := []interface{}{user.Username, user.Password, string(rolesJson)}
|
|
||||||
if user.Name != "" {
|
|
||||||
cols = append(cols, "name")
|
|
||||||
vals = append(vals, user.Name)
|
|
||||||
}
|
|
||||||
if user.Email != "" {
|
|
||||||
cols = append(cols, "email")
|
|
||||||
vals = append(vals, user.Email)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := sq.Insert("user").Columns(cols...).Values(vals...).RunWith(auth.db).Exec(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("new user %#v created (roles: %s)", user.Username, rolesJson)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *Authentication) AuthViaSession(rw http.ResponseWriter, r *http.Request) (*User, error) {
|
func (auth *Authentication) AuthViaSession(rw http.ResponseWriter, r *http.Request) (*User, error) {
|
||||||
session, err := auth.sessionStore.Get(r, "session")
|
session, err := auth.sessionStore.Get(r, "session")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -135,5 +126,104 @@ func (auth *Authentication) AuthViaSession(rw http.ResponseWriter, r *http.Reque
|
|||||||
return &User{
|
return &User{
|
||||||
Username: username,
|
Username: username,
|
||||||
Roles: roles,
|
Roles: roles,
|
||||||
|
AuthSource: -1,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle a POST request that should log the user in, starting a new session.
|
||||||
|
func (auth *Authentication) Login(onsuccess http.Handler, onfailure func(rw http.ResponseWriter, r *http.Request, loginErr error)) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
var err error
|
||||||
|
username := r.FormValue("username")
|
||||||
|
user := (*User)(nil)
|
||||||
|
if username != "" {
|
||||||
|
if user, _ = auth.GetUser(username); err != nil {
|
||||||
|
log.Warnf("login of unkown user %#v", username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, authenticator := range auth.authenticators {
|
||||||
|
if !authenticator.CanLogin(user, rw, r) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err = authenticator.Login(user, rw, r)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("login failed: %s", err.Error())
|
||||||
|
onfailure(rw, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := auth.sessionStore.New(r, "session")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("session creation failed: %s", err.Error())
|
||||||
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth.SessionMaxAge != 0 {
|
||||||
|
session.Options.MaxAge = int(auth.SessionMaxAge.Seconds())
|
||||||
|
}
|
||||||
|
session.Values["username"] = user.Username
|
||||||
|
session.Values["roles"] = user.Roles
|
||||||
|
if err := auth.sessionStore.Save(r, rw, session); err != nil {
|
||||||
|
log.Errorf("session save failed: %s", err.Error())
|
||||||
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("login successfull: user: %#v (roles: %v)", user.Username, user.Roles)
|
||||||
|
ctx := context.WithValue(r.Context(), ContextUserKey, user)
|
||||||
|
onsuccess.ServeHTTP(rw, r.WithContext(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warn("login failed: no authenticator applied")
|
||||||
|
onfailure(rw, r, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate the user and put a User object in the
|
||||||
|
// context of the request. If authentication fails,
|
||||||
|
// do not continue but send client to the login screen.
|
||||||
|
func (auth *Authentication) Auth(onsuccess http.Handler, onfailure func(rw http.ResponseWriter, r *http.Request, authErr error)) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
for _, authenticator := range auth.authenticators {
|
||||||
|
user, err := authenticator.Auth(rw, r)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("authentication failed: %s", err.Error())
|
||||||
|
http.Error(rw, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), ContextUserKey, user)
|
||||||
|
onsuccess.ServeHTTP(rw, r.WithContext(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("authentication failed: %s", "no authenticator applied")
|
||||||
|
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clears the session cookie
|
||||||
|
func (auth *Authentication) Logout(onsuccess http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
session, err := auth.sessionStore.Get(r, "session")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !session.IsNew {
|
||||||
|
session.Options.MaxAge = -1
|
||||||
|
if err := auth.sessionStore.Save(r, rw, session); err != nil {
|
||||||
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onsuccess.ServeHTTP(rw, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -2,7 +2,6 @@ package authv2
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@ -48,10 +47,10 @@ func (ja *JWTAuthenticator) Init(auth *Authentication, rawConfig json.RawMessage
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ja *JWTAuthenticator) CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool {
|
func (ja *JWTAuthenticator) CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool {
|
||||||
return user.AuthSource == AuthViaToken || r.Header.Get("Authorization") != ""
|
return (user != nil && user.AuthSource == AuthViaToken) || r.Header.Get("Authorization") != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ja *JWTAuthenticator) Login(_ *User, password string, rw http.ResponseWriter, r *http.Request) (*User, error) {
|
func (ja *JWTAuthenticator) Login(user *User, rw http.ResponseWriter, r *http.Request) (*User, error) {
|
||||||
rawtoken := r.Header.Get("X-Auth-Token")
|
rawtoken := r.Header.Get("X-Auth-Token")
|
||||||
if rawtoken == "" {
|
if rawtoken == "" {
|
||||||
rawtoken = r.Header.Get("Authorization")
|
rawtoken = r.Header.Get("Authorization")
|
||||||
@ -84,14 +83,9 @@ func (ja *JWTAuthenticator) Login(_ *User, password string, rw http.ResponseWrit
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := ja.auth.GetUser(sub)
|
if user == nil {
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil && err == sql.ErrNoRows {
|
|
||||||
user = &User{
|
user = &User{
|
||||||
Username: user.Username,
|
Username: sub,
|
||||||
Roles: roles,
|
Roles: roles,
|
||||||
AuthSource: AuthViaToken,
|
AuthSource: AuthViaToken,
|
||||||
}
|
}
|
||||||
@ -114,13 +108,7 @@ func (ja *JWTAuthenticator) Auth(rw http.ResponseWriter, r *http.Request) (*User
|
|||||||
// Because a user can also log in via a token, the
|
// Because a user can also log in via a token, the
|
||||||
// session cookie must be checked here as well:
|
// session cookie must be checked here as well:
|
||||||
if rawtoken == "" {
|
if rawtoken == "" {
|
||||||
user, err := ja.auth.AuthViaSession(rw, r)
|
return ja.auth.AuthViaSession(rw, r)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user.AuthSource = AuthViaToken
|
|
||||||
return user, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := jwt.Parse(rawtoken, func(t *jwt.Token) (interface{}, error) {
|
token, err := jwt.Parse(rawtoken, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
@ -67,10 +67,10 @@ func (la *LdapAutnenticator) Init(auth *Authentication, rawConfig json.RawMessag
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (la *LdapAutnenticator) CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool {
|
func (la *LdapAutnenticator) CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool {
|
||||||
return user.AuthSource == AuthViaLDAP
|
return user != nil && user.AuthSource == AuthViaLDAP
|
||||||
}
|
}
|
||||||
|
|
||||||
func (la *LdapAutnenticator) Login(user *User, password string, rw http.ResponseWriter, r *http.Request) (*User, error) {
|
func (la *LdapAutnenticator) Login(user *User, rw http.ResponseWriter, r *http.Request) (*User, error) {
|
||||||
l, err := la.getLdapConnection(false)
|
l, err := la.getLdapConnection(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -78,7 +78,7 @@ func (la *LdapAutnenticator) Login(user *User, password string, rw http.Response
|
|||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
|
||||||
userDn := strings.Replace(la.config.UserBind, "{username}", user.Username, -1)
|
userDn := strings.Replace(la.config.UserBind, "{username}", user.Username, -1)
|
||||||
if err := l.Bind(userDn, password); err != nil {
|
if err := l.Bind(userDn, r.FormValue("password")); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,13 +86,7 @@ func (la *LdapAutnenticator) Login(user *User, password string, rw http.Response
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (la *LdapAutnenticator) Auth(rw http.ResponseWriter, r *http.Request) (*User, error) {
|
func (la *LdapAutnenticator) Auth(rw http.ResponseWriter, r *http.Request) (*User, error) {
|
||||||
user, err := la.auth.AuthViaSession(rw, r)
|
return la.auth.AuthViaSession(rw, r)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user.AuthSource = AuthViaLDAP
|
|
||||||
return user, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (la *LdapAutnenticator) Sync() error {
|
func (la *LdapAutnenticator) Sync() error {
|
||||||
|
@ -20,11 +20,11 @@ func (la *LocalAuthenticator) Init(auth *Authentication, rawConfig json.RawMessa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (la *LocalAuthenticator) CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool {
|
func (la *LocalAuthenticator) CanLogin(user *User, rw http.ResponseWriter, r *http.Request) bool {
|
||||||
return user.AuthSource == AuthViaLocalPassword
|
return user != nil && user.AuthSource == AuthViaLocalPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
func (la *LocalAuthenticator) Login(user *User, password string, rw http.ResponseWriter, r *http.Request) (*User, error) {
|
func (la *LocalAuthenticator) Login(user *User, rw http.ResponseWriter, r *http.Request) (*User, error) {
|
||||||
if e := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); e != nil {
|
if e := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(r.FormValue("password"))); e != nil {
|
||||||
return nil, fmt.Errorf("user '%s' provided the wrong password (%w)", user.Username, e)
|
return nil, fmt.Errorf("user '%s' provided the wrong password (%w)", user.Username, e)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,11 +32,5 @@ func (la *LocalAuthenticator) Login(user *User, password string, rw http.Respons
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (la *LocalAuthenticator) Auth(rw http.ResponseWriter, r *http.Request) (*User, error) {
|
func (la *LocalAuthenticator) Auth(rw http.ResponseWriter, r *http.Request) (*User, error) {
|
||||||
user, err := la.auth.AuthViaSession(rw, r)
|
return la.auth.AuthViaSession(rw, r)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user.AuthSource = AuthViaLocalPassword
|
|
||||||
return user, nil
|
|
||||||
}
|
}
|
||||||
|
138
internal/auth-v2/users.go
Normal file
138
internal/auth-v2/users.go
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
package authv2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/ClusterCockpit/cc-backend/internal/graph/model"
|
||||||
|
"github.com/ClusterCockpit/cc-backend/pkg/log"
|
||||||
|
sq "github.com/Masterminds/squirrel"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (auth *Authentication) GetUser(username string) (*User, error) {
|
||||||
|
user := &User{Username: username}
|
||||||
|
var hashedPassword, name, rawRoles, email sql.NullString
|
||||||
|
if err := sq.Select("password", "ldap", "name", "roles", "email").From("user").
|
||||||
|
Where("user.username = ?", username).RunWith(auth.db).
|
||||||
|
QueryRow().Scan(&hashedPassword, &user.AuthSource, &name, &rawRoles, &email); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Password = hashedPassword.String
|
||||||
|
user.Name = name.String
|
||||||
|
user.Email = email.String
|
||||||
|
if rawRoles.Valid {
|
||||||
|
if err := json.Unmarshal([]byte(rawRoles.String), &user.Roles); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *Authentication) AddUser(user *User) error {
|
||||||
|
rolesJson, _ := json.Marshal(user.Roles)
|
||||||
|
cols := []string{"username", "password", "roles"}
|
||||||
|
vals := []interface{}{user.Username, user.Password, string(rolesJson)}
|
||||||
|
if user.Name != "" {
|
||||||
|
cols = append(cols, "name")
|
||||||
|
vals = append(vals, user.Name)
|
||||||
|
}
|
||||||
|
if user.Email != "" {
|
||||||
|
cols = append(cols, "email")
|
||||||
|
vals = append(vals, user.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := sq.Insert("user").Columns(cols...).Values(vals...).RunWith(auth.db).Exec(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("new user %#v created (roles: %s, auth-source: %d)", user.Username, rolesJson, user.AuthSource)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *Authentication) DelUser(username string) error {
|
||||||
|
_, err := auth.db.Exec(`DELETE FROM user WHERE user.username = ?`, username)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *Authentication) ListUsers(specialsOnly bool) ([]*User, error) {
|
||||||
|
q := sq.Select("username", "name", "email", "roles").From("user")
|
||||||
|
if specialsOnly {
|
||||||
|
q = q.Where("(roles != '[\"user\"]' AND roles != '[]')")
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := q.RunWith(auth.db).Query()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
users := make([]*User, 0)
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
rawroles := ""
|
||||||
|
user := &User{}
|
||||||
|
var name, email sql.NullString
|
||||||
|
if err := rows.Scan(&user.Username, &name, &email, &rawroles); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(rawroles), &user.Roles); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Name = name.String
|
||||||
|
user.Email = email.String
|
||||||
|
users = append(users, user)
|
||||||
|
}
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *Authentication) AddRole(ctx context.Context, username string, role string) error {
|
||||||
|
user, err := auth.GetUser(username)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if role != RoleAdmin && role != RoleApi && role != RoleUser {
|
||||||
|
return fmt.Errorf("invalid user role: %#v", role)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range user.Roles {
|
||||||
|
if r == role {
|
||||||
|
return fmt.Errorf("user %#v already has role %#v", username, role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
roles, _ := json.Marshal(append(user.Roles, role))
|
||||||
|
if _, err := sq.Update("user").Set("roles", roles).Where("user.username = ?", username).RunWith(auth.db).Exec(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchUser(ctx context.Context, db *sqlx.DB, username string) (*model.User, error) {
|
||||||
|
me := GetUser(ctx)
|
||||||
|
if me != nil && !me.HasRole(RoleAdmin) && me.Username != username {
|
||||||
|
return nil, errors.New("forbidden")
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &model.User{Username: username}
|
||||||
|
var name, email sql.NullString
|
||||||
|
if err := sq.Select("name", "email").From("user").Where("user.username = ?", username).
|
||||||
|
RunWith(db).QueryRow().Scan(&name, &email); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Name = name.String
|
||||||
|
user.Email = email.String
|
||||||
|
return user, nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user