diff options
Diffstat (limited to 'walrss/internal')
| -rw-r--r-- | walrss/internal/core/users.go | 22 | ||||
| -rw-r--r-- | walrss/internal/http/auth.go | 113 | ||||
| -rw-r--r-- | walrss/internal/http/http.go | 28 | ||||
| -rw-r--r-- | walrss/internal/http/views/signin.qtpl.html | 5 | ||||
| -rw-r--r-- | walrss/internal/http/views/signin.qtpl.html.go | 13 | ||||
| -rw-r--r-- | walrss/internal/state/state.go | 6 | ||||
| -rw-r--r-- | walrss/internal/urls/urls.go | 9 |
7 files changed, 191 insertions, 5 deletions
diff --git a/walrss/internal/core/users.go b/walrss/internal/core/users.go index 5c17251..dfd2c86 100644 --- a/walrss/internal/core/users.go +++ b/walrss/internal/core/users.go @@ -45,12 +45,34 @@ func RegisterUser(st *state.State, email, password string) (*db.User, error) { return u, nil } +func RegisterUserOIDC(st *state.State, email string) (*db.User, error) { + u := &db.User{ + ID: shortuuid.New(), + Email: email, + } + + if _, err := st.Data.NewInsert().Model(u).Exec(context.Background()); err != nil { + if e, ok := err.(*sqlite3.Error); ok { + if e.Code == sqlite3.ErrConstraint { + return nil, NewUserError("email address in use") + } + } + return nil, err + } + + return u, nil +} + func AreUserCredentialsCorrect(st *state.State, email, password string) (bool, error) { user, err := GetUserByEmail(st, email) if err != nil { return false, err } + if len(user.Password) == 0 { + return false, nil + } + if err := bcrypt.CompareHashAndPassword(user.Password, combineStringAndSalt(password, user.Salt)); err != nil { if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { return false, nil diff --git a/walrss/internal/http/auth.go b/walrss/internal/http/auth.go index a9ceca6..af098fa 100644 --- a/walrss/internal/http/auth.go +++ b/walrss/internal/http/auth.go @@ -1,12 +1,15 @@ package http import ( + "context" "errors" "github.com/codemicro/walrss/walrss/internal/core" "github.com/codemicro/walrss/walrss/internal/http/views" "github.com/codemicro/walrss/walrss/internal/urls" "github.com/gofiber/fiber/v2" "github.com/stevelacy/daz" + "math/rand" + "sync" "time" ) @@ -74,7 +77,8 @@ success: func (s *Server) authSignIn(ctx *fiber.Ctx) error { page := &views.SignInPage{ - Problem: ctx.Query("problem"), + Problem: ctx.Query("problem"), + OIDCEnabled: s.state.Config.OIDC.Enable, } if getCurrentUserID(ctx) != "" { @@ -127,3 +131,110 @@ incorrectUsernameOrPassword: ctx.Status(fiber.StatusUnauthorized) return views.SendPage(ctx, &views.SignInPage{Problem: "Incorrect username or password"}) } + +var ( + knownStates = make(map[string]time.Time) + stateLock sync.Mutex +) + +func init() { + rand.Seed(time.Now().Unix()) + + go func() { + time.Sleep(time.Minute * 5) + stateLock.Lock() + + var toDelete []string + + for k, v := range knownStates { + if !v.After(time.Now().UTC()) { + toDelete = append(toDelete, k) + } + } + + for _, k := range toDelete { + delete(knownStates, k) + } + + stateLock.Unlock() + }() +} + +func (s *Server) authOIDCOutbound(ctx *fiber.Ctx) error { + if !s.state.Config.OIDC.Enable { + return core.NewUserErrorWithStatus(fiber.StatusForbidden, "OIDC is disabled") + } + + b := make([]byte, 30) + for i := 0; i < len(b); i++ { + b[i] = byte(65 + rand.Intn(25)) + } + knownStates[string(b)] = time.Now().UTC().Add(time.Minute * 2) + + return ctx.Redirect(s.oauth2Config.AuthCodeURL(string(b))) +} + +func (s *Server) authOIDCCallback(ctx *fiber.Ctx) error { + if !s.state.Config.OIDC.Enable { + return core.NewUserErrorWithStatus(fiber.StatusForbidden, "OIDC is disabled") + } + + providedState := ctx.Query("state") + stateLock.Lock() + if exp, ok := knownStates[providedState]; ok && exp.After(time.Now().UTC()) { + delete(knownStates, providedState) + stateLock.Unlock() + } else { + stateLock.Unlock() + return core.NewUserError("Invalid state") + } + + oauth2Token, err := s.oauth2Config.Exchange(context.Background(), ctx.Query("code")) + if err != nil { + return err + } + + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return errors.New("missing ID token") + } + + idToken, err := s.oidcVerifier.Verify(context.Background(), rawIDToken) + if err != nil { + return err + } + + var claims struct { + Email string `json:"email"` + } + if err := idToken.Claims(&claims); err != nil { + return err + } + + user, err := core.GetUserByEmail(s.state, claims.Email) + if err != nil { + if errors.Is(err, core.ErrNotFound) { + if s.state.Config.Platform.DisableRegistration { + return core.NewUserError("Cannot register user on-demand as registrations are disabled.") + } + user, err = core.RegisterUserOIDC(s.state, claims.Email) + if err != nil { + return err + } + } else { + return err + } + } + + token := core.GenerateSessionToken(user.ID) + + ctx.Cookie(&fiber.Cookie{ + Name: sessionCookieKey, + Value: token, + Expires: time.Now().UTC().Add(sessionDuration), + Secure: s.state.Config.EnableSecureCookies(), + HTTPOnly: true, + }) + + return ctx.Redirect(urls.Index) +} diff --git a/walrss/internal/http/http.go b/walrss/internal/http/http.go index 36255bc..892b261 100644 --- a/walrss/internal/http/http.go +++ b/walrss/internal/http/http.go @@ -1,15 +1,19 @@ package http import ( + "context" "github.com/codemicro/walrss/walrss/internal/core" "github.com/codemicro/walrss/walrss/internal/http/views" "github.com/codemicro/walrss/walrss/internal/state" "github.com/codemicro/walrss/walrss/internal/static" "github.com/codemicro/walrss/walrss/internal/urls" + "github.com/coreos/go-oidc" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" "github.com/stevelacy/daz" + "golang.org/x/oauth2" "net/url" + "strings" "time" ) @@ -22,6 +26,10 @@ const ( type Server struct { state *state.State app *fiber.App + + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + oauth2Config *oauth2.Config } func New(st *state.State) (*Server, error) { @@ -53,6 +61,23 @@ func New(st *state.State) (*Server, error) { app: app, } + if st.Config.OIDC.Enable { + provider, err := oidc.NewProvider(context.Background(), st.Config.OIDC.Issuer) + if err != nil { + return nil, err + } + + s.oidcProvider = provider + s.oidcVerifier = provider.Verifier(&oidc.Config{ClientID: st.Config.OIDC.ClientID}) + s.oauth2Config = &oauth2.Config{ + ClientID: st.Config.OIDC.ClientID, + ClientSecret: st.Config.OIDC.ClientSecret, + Endpoint: provider.Endpoint(), + RedirectURL: strings.TrimSuffix(st.Config.Server.ExternalURL, "/") + urls.AuthOIDCCallback, + Scopes: []string{"email", "profile", "openid"}, + } + } + s.registerHandlers() return s, nil @@ -78,6 +103,9 @@ func (s *Server) registerHandlers() { s.app.Get(urls.AuthSignIn, s.authSignIn) s.app.Post(urls.AuthSignIn, s.authSignIn) + s.app.Get(urls.AuthOIDCOutbound, s.authOIDCOutbound) + s.app.Get(urls.AuthOIDCCallback, s.authOIDCCallback) + s.app.Put(urls.EditEnabledState, s.editEnabledState) s.app.Put(urls.EditTimings, s.editTimings) diff --git a/walrss/internal/http/views/signin.qtpl.html b/walrss/internal/http/views/signin.qtpl.html index 3022f7f..b0c7d4a 100644 --- a/walrss/internal/http/views/signin.qtpl.html +++ b/walrss/internal/http/views/signin.qtpl.html @@ -3,6 +3,7 @@ {% code type SignInPage struct { BasePage Problem string + OIDCEnabled bool } %} {% func (p *SignInPage) Title() %}Sign in{% endfunc %} @@ -29,5 +30,9 @@ </form> <br> <a href="{%s= urls.AuthRegister %}">No account? Click here to register</a> + <br> + {% if p.OIDCEnabled %} + <a href="{%s= urls.AuthOIDCOutbound %}">Click here to login with OIDC</a> + {% endif %} </div> {% endfunc %}
\ No newline at end of file diff --git a/walrss/internal/http/views/signin.qtpl.html.go b/walrss/internal/http/views/signin.qtpl.html.go index ba433dd..35e9fb7 100644 --- a/walrss/internal/http/views/signin.qtpl.html.go +++ b/walrss/internal/http/views/signin.qtpl.html.go @@ -18,7 +18,8 @@ var ( type SignInPage struct { BasePage - Problem string + Problem string + OIDCEnabled bool } func (p *SignInPage) StreamTitle(qw422016 *qt422016.Writer) { @@ -73,6 +74,16 @@ func (p *SignInPage) StreamBody(qw422016 *qt422016.Writer) { <a href="`) qw422016.N().S(urls.AuthRegister) qw422016.N().S(`">No account? Click here to register</a> + <br> + `) + if p.OIDCEnabled { + qw422016.N().S(` + <a href="`) + qw422016.N().S(urls.AuthOIDCOutbound) + qw422016.N().S(`">Click here to login with OIDC</a> + `) + } + qw422016.N().S(` </div> `) } diff --git a/walrss/internal/state/state.go b/walrss/internal/state/state.go index a65a261..227f49e 100644 --- a/walrss/internal/state/state.go +++ b/walrss/internal/state/state.go @@ -37,6 +37,12 @@ type Config struct { DisableRegistration bool `fig:"disableRegistration"` DisableSecureCookies bool `fig:"disableSecureCookies"` } + OIDC struct { + Enable bool `fig:"enable"` + ClientID string `fig:"clientID"` + ClientSecret string `fig:"clientSecret"` + Issuer string `fig:"issuer"` + } Debug bool `fig:"debug"` } diff --git a/walrss/internal/urls/urls.go b/walrss/internal/urls/urls.go index 5cf166b..5b2fc69 100644 --- a/walrss/internal/urls/urls.go +++ b/walrss/internal/urls/urls.go @@ -8,9 +8,12 @@ import ( const ( Index = "/" - Auth = "/auth" - AuthSignIn = Auth + "/signin" - AuthRegister = Auth + "/register" + Auth = "/auth" + AuthSignIn = Auth + "/signin" + AuthRegister = Auth + "/register" + AuthOIDC = Auth + "/oidc" + AuthOIDCOutbound = AuthOIDC + "/outbound" + AuthOIDCCallback = AuthOIDC + "/callback" Edit = "/edit" EditEnabledState = Edit + "/enabled" |
