diff --git a/internal/auth/ldap.go b/internal/auth/ldap.go index b800ca7..d9888ca 100644 --- a/internal/auth/ldap.go +++ b/internal/auth/ldap.go @@ -21,7 +21,7 @@ import ( type LdapAuthenticator struct { syncPassword string - UserAttr string + UserAttr string } var _ Authenticator = (*LdapAuthenticator)(nil) @@ -74,8 +74,8 @@ func (la *LdapAuthenticator) CanLogin( user *schema.User, username string, rw http.ResponseWriter, - r *http.Request) (*schema.User, bool) { - + r *http.Request, +) (*schema.User, bool) { lc := config.Keys.LdapConfig if user != nil { @@ -138,8 +138,8 @@ func (la *LdapAuthenticator) CanLogin( func (la *LdapAuthenticator) Login( user *schema.User, rw http.ResponseWriter, - r *http.Request) (*schema.User, error) { - + r *http.Request, +) (*schema.User, error) { l, err := la.getLdapConnection(false) if err != nil { log.Warn("Error while getting ldap connection") @@ -238,7 +238,6 @@ func (la *LdapAuthenticator) Sync() error { } func (la *LdapAuthenticator) getLdapConnection(admin bool) (*ldap.Conn, error) { - lc := config.Keys.LdapConfig conn, err := ldap.DialURL(lc.Url) if err != nil { diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 480b212..cfcf5b6 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -6,38 +6,59 @@ package auth import ( "context" + "crypto/rand" + "encoding/base64" + "io" "log" "net/http" + "strings" + "time" + "github.com/ClusterCockpit/cc-backend/internal/config" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" "golang.org/x/oauth2" ) type OIDC struct { - client *oauth2.Config - provider *oidc.Provider - state string - codeVerifier string + client *oauth2.Config + provider *oidc.Provider +} + +func randString(nByte int) (string, error) { + b := make([]byte, nByte) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func setCallbackCookie(w http.ResponseWriter, r *http.Request, name, value string) { + c := &http.Cookie{ + Name: name, + Value: value, + MaxAge: int(time.Hour.Seconds()), + Secure: r.TLS != nil, + HttpOnly: true, + } + http.SetCookie(w, c) } func (oa *OIDC) Init(r *mux.Router) error { - oa.client = &oauth2.Config{ - ClientID: "YOUR_CLIENT_ID", - ClientSecret: "YOUR_CLIENT_SECRET", - Endpoint: oauth2.Endpoint{ - AuthURL: "https://provider.com/o/oauth2/auth", - TokenURL: "https://provider.com/o/oauth2/token", - }, - } - provider, err := oidc.NewProvider(context.Background(), "https://provider") if err != nil { log.Fatal(err) } - oa.provider = provider + oa.client = &oauth2.Config{ + ClientID: "YOUR_CLIENT_ID", + ClientSecret: "YOUR_CLIENT_SECRET", + Endpoint: provider.Endpoint(), + RedirectURL: "https://" + config.Keys.Addr + "/oidc-callback", + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + } + r.HandleFunc("/oidc-login", oa.OAuth2Login) r.HandleFunc("/oidc-callback", oa.OAuth2Callback) @@ -45,9 +66,18 @@ func (oa *OIDC) Init(r *mux.Router) error { } func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) { + c, err := r.Cookie("state") + if err != nil { + http.Error(rw, "state not found", http.StatusBadRequest) + return + } + + str := strings.Split(c.Value, " ") + state := str[0] + codeVerifier := str[1] + _ = r.ParseForm() - state := r.Form.Get("state") - if state != oa.state { + if r.Form.Get("state") != state { http.Error(rw, "State invalid", http.StatusBadRequest) return } @@ -56,18 +86,32 @@ func (oa *OIDC) OAuth2Callback(rw http.ResponseWriter, r *http.Request) { http.Error(rw, "Code not found", http.StatusBadRequest) return } - token, err := oa.client.Exchange(context.Background(), code, oauth2.VerifierOption(oa.codeVerifier)) + token, err := oa.client.Exchange(context.Background(), code, oauth2.VerifierOption(codeVerifier)) if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + http.Error(rw, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) + return + } + + userInfo, err := oa.provider.UserInfo(context.Background(), oauth2.StaticTokenSource(token)) + if err != nil { + http.Error(rw, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError) return } } func (oa *OIDC) OAuth2Login(rw http.ResponseWriter, r *http.Request) { + state, err := randString(16) + if err != nil { + http.Error(rw, "Internal error", http.StatusInternalServerError) + return + } + // use PKCE to protect against CSRF attacks - oa.codeVerifier = oauth2.GenerateVerifier() + codeVerifier := oauth2.GenerateVerifier() + + setCallbackCookie(rw, r, "state", strings.Join([]string{state, codeVerifier}, " ")) // Redirect user to consent page to ask for permission - url := oa.client.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oa.codeVerifier)) + url := oa.client.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(codeVerifier)) http.Redirect(rw, r, url, http.StatusFound) }