diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index a01af35..eff9153 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -2,6 +2,7 @@ package auth import ( "crypto/ed25519" + "database/sql" "encoding/base64" "errors" "fmt" @@ -106,15 +107,23 @@ func (ja *JWTAuthenticator) Login(user *User, rw http.ResponseWriter, r *http.Re } } } + if rawrole, ok := claims["roles"].(string); ok { + roles = append(roles, rawrole) + } if user == nil { - user = &User{ - Username: sub, - Roles: roles, - AuthSource: AuthViaToken, - } - if err := ja.auth.AddUser(user); err != nil { + user, err = ja.auth.GetUser(sub) + if err != nil && err != sql.ErrNoRows { return nil, err + } else if user == nil { + user = &User{ + Username: sub, + Roles: roles, + AuthSource: AuthViaToken, + } + if err := ja.auth.AddUser(user); err != nil { + return nil, err + } } } diff --git a/internal/auth/users.go b/internal/auth/users.go index 0de710f..611b051 100644 --- a/internal/auth/users.go +++ b/internal/auth/users.go @@ -38,13 +38,8 @@ func (auth *Authentication) GetUser(username string) (*User, error) { func (auth *Authentication) AddUser(user *User) error { rolesJson, _ := json.Marshal(user.Roles) - password, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) - if err != nil { - return err - } - - cols := []string{"username", "password", "roles"} - vals := []interface{}{user.Username, string(password), string(rolesJson)} + cols := []string{"username", "roles"} + vals := []interface{}{user.Username, string(rolesJson)} if user.Name != "" { cols = append(cols, "name") vals = append(vals, user.Name) @@ -53,6 +48,14 @@ func (auth *Authentication) AddUser(user *User) error { cols = append(cols, "email") vals = append(vals, user.Email) } + if user.Password != "" { + password, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) + if err != nil { + return err + } + cols = append(cols, "password") + vals = append(vals, string(password)) + } if _, err := sq.Insert("user").Columns(cols...).Values(vals...).RunWith(auth.db).Exec(); err != nil { return err