aboutsummaryrefslogtreecommitdiffstats
path: root/walrss/internal
diff options
context:
space:
mode:
Diffstat (limited to 'walrss/internal')
-rw-r--r--walrss/internal/core/users.go22
-rw-r--r--walrss/internal/http/auth.go113
-rw-r--r--walrss/internal/http/http.go28
-rw-r--r--walrss/internal/http/views/signin.qtpl.html5
-rw-r--r--walrss/internal/http/views/signin.qtpl.html.go13
-rw-r--r--walrss/internal/state/state.go6
-rw-r--r--walrss/internal/urls/urls.go9
7 files changed, 191 insertions, 5 deletions
diff --git a/walrss/internal/core/users.go b/walrss/internal/core/users.go
index 5c17251..dfd2c86 100644
--- a/walrss/internal/core/users.go
+++ b/walrss/internal/core/users.go
@@ -45,12 +45,34 @@ func RegisterUser(st *state.State, email, password string) (*db.User, error) {
return u, nil
}
+func RegisterUserOIDC(st *state.State, email string) (*db.User, error) {
+ u := &db.User{
+ ID: shortuuid.New(),
+ Email: email,
+ }
+
+ if _, err := st.Data.NewInsert().Model(u).Exec(context.Background()); err != nil {
+ if e, ok := err.(*sqlite3.Error); ok {
+ if e.Code == sqlite3.ErrConstraint {
+ return nil, NewUserError("email address in use")
+ }
+ }
+ return nil, err
+ }
+
+ return u, nil
+}
+
func AreUserCredentialsCorrect(st *state.State, email, password string) (bool, error) {
user, err := GetUserByEmail(st, email)
if err != nil {
return false, err
}
+ if len(user.Password) == 0 {
+ return false, nil
+ }
+
if err := bcrypt.CompareHashAndPassword(user.Password, combineStringAndSalt(password, user.Salt)); err != nil {
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
return false, nil
diff --git a/walrss/internal/http/auth.go b/walrss/internal/http/auth.go
index a9ceca6..af098fa 100644
--- a/walrss/internal/http/auth.go
+++ b/walrss/internal/http/auth.go
@@ -1,12 +1,15 @@
package http
import (
+ "context"
"errors"
"github.com/codemicro/walrss/walrss/internal/core"
"github.com/codemicro/walrss/walrss/internal/http/views"
"github.com/codemicro/walrss/walrss/internal/urls"
"github.com/gofiber/fiber/v2"
"github.com/stevelacy/daz"
+ "math/rand"
+ "sync"
"time"
)
@@ -74,7 +77,8 @@ success:
func (s *Server) authSignIn(ctx *fiber.Ctx) error {
page := &views.SignInPage{
- Problem: ctx.Query("problem"),
+ Problem: ctx.Query("problem"),
+ OIDCEnabled: s.state.Config.OIDC.Enable,
}
if getCurrentUserID(ctx) != "" {
@@ -127,3 +131,110 @@ incorrectUsernameOrPassword:
ctx.Status(fiber.StatusUnauthorized)
return views.SendPage(ctx, &views.SignInPage{Problem: "Incorrect username or password"})
}
+
+var (
+ knownStates = make(map[string]time.Time)
+ stateLock sync.Mutex
+)
+
+func init() {
+ rand.Seed(time.Now().Unix())
+
+ go func() {
+ time.Sleep(time.Minute * 5)
+ stateLock.Lock()
+
+ var toDelete []string
+
+ for k, v := range knownStates {
+ if !v.After(time.Now().UTC()) {
+ toDelete = append(toDelete, k)
+ }
+ }
+
+ for _, k := range toDelete {
+ delete(knownStates, k)
+ }
+
+ stateLock.Unlock()
+ }()
+}
+
+func (s *Server) authOIDCOutbound(ctx *fiber.Ctx) error {
+ if !s.state.Config.OIDC.Enable {
+ return core.NewUserErrorWithStatus(fiber.StatusForbidden, "OIDC is disabled")
+ }
+
+ b := make([]byte, 30)
+ for i := 0; i < len(b); i++ {
+ b[i] = byte(65 + rand.Intn(25))
+ }
+ knownStates[string(b)] = time.Now().UTC().Add(time.Minute * 2)
+
+ return ctx.Redirect(s.oauth2Config.AuthCodeURL(string(b)))
+}
+
+func (s *Server) authOIDCCallback(ctx *fiber.Ctx) error {
+ if !s.state.Config.OIDC.Enable {
+ return core.NewUserErrorWithStatus(fiber.StatusForbidden, "OIDC is disabled")
+ }
+
+ providedState := ctx.Query("state")
+ stateLock.Lock()
+ if exp, ok := knownStates[providedState]; ok && exp.After(time.Now().UTC()) {
+ delete(knownStates, providedState)
+ stateLock.Unlock()
+ } else {
+ stateLock.Unlock()
+ return core.NewUserError("Invalid state")
+ }
+
+ oauth2Token, err := s.oauth2Config.Exchange(context.Background(), ctx.Query("code"))
+ if err != nil {
+ return err
+ }
+
+ rawIDToken, ok := oauth2Token.Extra("id_token").(string)
+ if !ok {
+ return errors.New("missing ID token")
+ }
+
+ idToken, err := s.oidcVerifier.Verify(context.Background(), rawIDToken)
+ if err != nil {
+ return err
+ }
+
+ var claims struct {
+ Email string `json:"email"`
+ }
+ if err := idToken.Claims(&claims); err != nil {
+ return err
+ }
+
+ user, err := core.GetUserByEmail(s.state, claims.Email)
+ if err != nil {
+ if errors.Is(err, core.ErrNotFound) {
+ if s.state.Config.Platform.DisableRegistration {
+ return core.NewUserError("Cannot register user on-demand as registrations are disabled.")
+ }
+ user, err = core.RegisterUserOIDC(s.state, claims.Email)
+ if err != nil {
+ return err
+ }
+ } else {
+ return err
+ }
+ }
+
+ token := core.GenerateSessionToken(user.ID)
+
+ ctx.Cookie(&fiber.Cookie{
+ Name: sessionCookieKey,
+ Value: token,
+ Expires: time.Now().UTC().Add(sessionDuration),
+ Secure: s.state.Config.EnableSecureCookies(),
+ HTTPOnly: true,
+ })
+
+ return ctx.Redirect(urls.Index)
+}
diff --git a/walrss/internal/http/http.go b/walrss/internal/http/http.go
index 36255bc..892b261 100644
--- a/walrss/internal/http/http.go
+++ b/walrss/internal/http/http.go
@@ -1,15 +1,19 @@
package http
import (
+ "context"
"github.com/codemicro/walrss/walrss/internal/core"
"github.com/codemicro/walrss/walrss/internal/http/views"
"github.com/codemicro/walrss/walrss/internal/state"
"github.com/codemicro/walrss/walrss/internal/static"
"github.com/codemicro/walrss/walrss/internal/urls"
+ "github.com/coreos/go-oidc"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
"github.com/stevelacy/daz"
+ "golang.org/x/oauth2"
"net/url"
+ "strings"
"time"
)
@@ -22,6 +26,10 @@ const (
type Server struct {
state *state.State
app *fiber.App
+
+ oidcProvider *oidc.Provider
+ oidcVerifier *oidc.IDTokenVerifier
+ oauth2Config *oauth2.Config
}
func New(st *state.State) (*Server, error) {
@@ -53,6 +61,23 @@ func New(st *state.State) (*Server, error) {
app: app,
}
+ if st.Config.OIDC.Enable {
+ provider, err := oidc.NewProvider(context.Background(), st.Config.OIDC.Issuer)
+ if err != nil {
+ return nil, err
+ }
+
+ s.oidcProvider = provider
+ s.oidcVerifier = provider.Verifier(&oidc.Config{ClientID: st.Config.OIDC.ClientID})
+ s.oauth2Config = &oauth2.Config{
+ ClientID: st.Config.OIDC.ClientID,
+ ClientSecret: st.Config.OIDC.ClientSecret,
+ Endpoint: provider.Endpoint(),
+ RedirectURL: strings.TrimSuffix(st.Config.Server.ExternalURL, "/") + urls.AuthOIDCCallback,
+ Scopes: []string{"email", "profile", "openid"},
+ }
+ }
+
s.registerHandlers()
return s, nil
@@ -78,6 +103,9 @@ func (s *Server) registerHandlers() {
s.app.Get(urls.AuthSignIn, s.authSignIn)
s.app.Post(urls.AuthSignIn, s.authSignIn)
+ s.app.Get(urls.AuthOIDCOutbound, s.authOIDCOutbound)
+ s.app.Get(urls.AuthOIDCCallback, s.authOIDCCallback)
+
s.app.Put(urls.EditEnabledState, s.editEnabledState)
s.app.Put(urls.EditTimings, s.editTimings)
diff --git a/walrss/internal/http/views/signin.qtpl.html b/walrss/internal/http/views/signin.qtpl.html
index 3022f7f..b0c7d4a 100644
--- a/walrss/internal/http/views/signin.qtpl.html
+++ b/walrss/internal/http/views/signin.qtpl.html
@@ -3,6 +3,7 @@
{% code type SignInPage struct {
BasePage
Problem string
+ OIDCEnabled bool
} %}
{% func (p *SignInPage) Title() %}Sign in{% endfunc %}
@@ -29,5 +30,9 @@
</form>
<br>
<a href="{%s= urls.AuthRegister %}">No account? Click here to register</a>
+ <br>
+ {% if p.OIDCEnabled %}
+ <a href="{%s= urls.AuthOIDCOutbound %}">Click here to login with OIDC</a>
+ {% endif %}
</div>
{% endfunc %} \ No newline at end of file
diff --git a/walrss/internal/http/views/signin.qtpl.html.go b/walrss/internal/http/views/signin.qtpl.html.go
index ba433dd..35e9fb7 100644
--- a/walrss/internal/http/views/signin.qtpl.html.go
+++ b/walrss/internal/http/views/signin.qtpl.html.go
@@ -18,7 +18,8 @@ var (
type SignInPage struct {
BasePage
- Problem string
+ Problem string
+ OIDCEnabled bool
}
func (p *SignInPage) StreamTitle(qw422016 *qt422016.Writer) {
@@ -73,6 +74,16 @@ func (p *SignInPage) StreamBody(qw422016 *qt422016.Writer) {
<a href="`)
qw422016.N().S(urls.AuthRegister)
qw422016.N().S(`">No account? Click here to register</a>
+ <br>
+ `)
+ if p.OIDCEnabled {
+ qw422016.N().S(`
+ <a href="`)
+ qw422016.N().S(urls.AuthOIDCOutbound)
+ qw422016.N().S(`">Click here to login with OIDC</a>
+ `)
+ }
+ qw422016.N().S(`
</div>
`)
}
diff --git a/walrss/internal/state/state.go b/walrss/internal/state/state.go
index a65a261..227f49e 100644
--- a/walrss/internal/state/state.go
+++ b/walrss/internal/state/state.go
@@ -37,6 +37,12 @@ type Config struct {
DisableRegistration bool `fig:"disableRegistration"`
DisableSecureCookies bool `fig:"disableSecureCookies"`
}
+ OIDC struct {
+ Enable bool `fig:"enable"`
+ ClientID string `fig:"clientID"`
+ ClientSecret string `fig:"clientSecret"`
+ Issuer string `fig:"issuer"`
+ }
Debug bool `fig:"debug"`
}
diff --git a/walrss/internal/urls/urls.go b/walrss/internal/urls/urls.go
index 5cf166b..5b2fc69 100644
--- a/walrss/internal/urls/urls.go
+++ b/walrss/internal/urls/urls.go
@@ -8,9 +8,12 @@ import (
const (
Index = "/"
- Auth = "/auth"
- AuthSignIn = Auth + "/signin"
- AuthRegister = Auth + "/register"
+ Auth = "/auth"
+ AuthSignIn = Auth + "/signin"
+ AuthRegister = Auth + "/register"
+ AuthOIDC = Auth + "/oidc"
+ AuthOIDCOutbound = AuthOIDC + "/outbound"
+ AuthOIDCCallback = AuthOIDC + "/callback"
Edit = "/edit"
EditEnabledState = Edit + "/enabled"