aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorLeonardo Bishop <me@leonardobishop.com>2025-08-15 19:20:48 +0100
committerLeonardo Bishop <me@leonardobishop.com>2025-08-15 19:20:48 +0100
commit8f7dec8ba6b2f9bde01afd0a110596ebbd43e0ed (patch)
tree7b4f203d92f4b99b1e98fac314415e293984196b /pkg
parent4697556cac819c47d068819b9fc9c3b4ea84e279 (diff)
Implement OIDC
Diffstat (limited to 'pkg')
-rw-r--r--pkg/auth/basic.go56
-rw-r--r--pkg/auth/oauth.go191
-rw-r--r--pkg/auth/service.go49
-rw-r--r--pkg/database/migrations/0002_nullable_passwords.sql4
-rw-r--r--pkg/database/sqlc/calendars.sql.go2
-rw-r--r--pkg/database/sqlc/db.go2
-rw-r--r--pkg/database/sqlc/favourites.sql.go2
-rw-r--r--pkg/database/sqlc/models.go8
-rw-r--r--pkg/database/sqlc/users.sql.go8
-rw-r--r--pkg/session/memory.go4
-rw-r--r--pkg/user/service.go49
11 files changed, 332 insertions, 43 deletions
diff --git a/pkg/auth/basic.go b/pkg/auth/basic.go
new file mode 100644
index 0000000..dafd93f
--- /dev/null
+++ b/pkg/auth/basic.go
@@ -0,0 +1,56 @@
+package auth
+
+import (
+ "errors"
+
+ "github.com/LMBishop/confplanner/pkg/database/sqlc"
+ "github.com/LMBishop/confplanner/pkg/user"
+ "golang.org/x/crypto/bcrypt"
+)
+
+type BasicAuthProvider struct {
+ userService user.Service
+}
+
+func NewBasicAuthProvider(userService user.Service) AuthProvider {
+ return &BasicAuthProvider{
+ userService: userService,
+ }
+}
+
+func (p *BasicAuthProvider) Authenticate(username string, password string) (*sqlc.User, error) {
+ random, err := bcrypt.GenerateFromPassword([]byte("00000000"), bcrypt.DefaultCost)
+ if err != nil {
+ return nil, err
+ }
+
+ u, err := p.userService.GetUserByName(username)
+ if err != nil {
+ if errors.Is(err, user.ErrUserNotFound) {
+ bcrypt.CompareHashAndPassword(random, []byte(password))
+ return nil, nil
+ }
+ return nil, err
+ }
+ if !u.Password.Valid {
+ bcrypt.CompareHashAndPassword(random, []byte(password))
+ return nil, nil
+ }
+
+ if err = bcrypt.CompareHashAndPassword([]byte(u.Password.String), []byte(password)); err != nil {
+ if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ return u, nil
+}
+
+func (p *BasicAuthProvider) Name() string {
+ return "Basic"
+}
+
+func (p *BasicAuthProvider) Type() string {
+ return "basic"
+}
diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go
new file mode 100644
index 0000000..9c45e7a
--- /dev/null
+++ b/pkg/auth/oauth.go
@@ -0,0 +1,191 @@
+package auth
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/LMBishop/confplanner/pkg/database/sqlc"
+ "github.com/LMBishop/confplanner/pkg/user"
+ "github.com/coreos/go-oidc/v3/oidc"
+ "github.com/tidwall/gjson"
+ "golang.org/x/oauth2"
+)
+
+type OIDCAuthProvider struct {
+ name string
+ userService user.Service
+ oauthConfig *oauth2.Config
+ oidcProvider *oidc.Provider
+ oidcVerifier *oidc.IDTokenVerifier
+ loginFilter string
+ loginFilterAllowedValues []string
+ userSyncFilter string
+ states map[string]*oidcState
+ lock sync.RWMutex
+}
+
+type oidcState struct {
+ expiry time.Time
+ ip string
+ userAgent string
+}
+
+var (
+ ErrStateVerificationFailed = errors.New("state verification failed")
+ ErrInvalidState = errors.New("invalid state")
+ ErrMissingIDToken = errors.New("missing ID token")
+ ErrNotAuthorised = errors.New("not authorised")
+ ErrUserSyncFailed = errors.New("user sync failed")
+)
+
+func NewOIDCAuthProvider(userService user.Service, name, clientID, clientSecret, endpoint, callbackURL, loginFilter, userSyncFilter string, loginFilterAllowedValues []string) (AuthProvider, error) {
+ provider, err := oidc.NewProvider(context.Background(), endpoint)
+ if err != nil {
+ return nil, err
+ }
+
+ return &OIDCAuthProvider{
+ name: name,
+ userService: userService,
+ oauthConfig: &oauth2.Config{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ Endpoint: provider.Endpoint(),
+ RedirectURL: callbackURL,
+ Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
+ },
+ oidcProvider: provider,
+ oidcVerifier: provider.Verifier(&oidc.Config{ClientID: clientID}),
+ loginFilter: loginFilter,
+ loginFilterAllowedValues: loginFilterAllowedValues,
+ userSyncFilter: userSyncFilter,
+ states: make(map[string]*oidcState),
+ }, nil
+}
+
+func (p *OIDCAuthProvider) StartJourney(ip string, userAgent string) (string, error) {
+ b := make([]byte, 50)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+
+ state := base64.URLEncoding.EncodeToString(b)
+
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ p.states[state] = &oidcState{
+ expiry: time.Now().Add(time.Minute * 5),
+ ip: ip,
+ userAgent: userAgent,
+ }
+
+ return p.oauthConfig.AuthCodeURL(state), nil
+}
+
+func (p *OIDCAuthProvider) CompleteJourney(ctx context.Context, authCode string, state string, ip string, userAgent string) (*sqlc.User, error) {
+ var s *oidcState
+
+ p.lock.Lock()
+ s = p.states[state]
+ delete(p.states, state)
+ p.lock.Unlock()
+
+ if s == nil {
+ return nil, ErrInvalidState
+ }
+
+ //if time.Now().After(s.expiry) || s.ip != ip || s.userAgent != userAgent {
+ // return nil, ErrStateVerificationFailed
+ //}
+ if time.Now().After(s.expiry) || s.userAgent != userAgent {
+ return nil, ErrStateVerificationFailed
+ }
+
+ oauth2Token, err := p.oauthConfig.Exchange(ctx, authCode)
+ if err != nil {
+ return nil, err
+ }
+
+ rawIDToken, ok := oauth2Token.Extra("id_token").(string)
+ if !ok {
+ return nil, ErrMissingIDToken
+ }
+
+ _, err = p.oidcVerifier.Verify(ctx, rawIDToken)
+ if err != nil {
+ return nil, err
+ }
+
+ claims, err := getRawClaims(rawIDToken)
+ if err != nil {
+ return nil, err
+ }
+
+ if p.loginFilter != "" {
+ rolesClaim := gjson.Get(claims, p.loginFilter)
+ if !rolesClaim.Exists() {
+ return nil, fmt.Errorf("cannot verify authorisation as '%s' is missing from claims", p.loginFilter)
+ }
+ roles := rolesClaim.Array()
+ var authorisation bool
+ out:
+ for _, allowedRole := range p.loginFilterAllowedValues {
+ for _, role := range roles {
+ if role.Str == allowedRole {
+ authorisation = true
+ break out
+ }
+ }
+ }
+ if !authorisation {
+ return nil, ErrNotAuthorised
+ }
+ }
+
+ usernameClaim := gjson.Get(claims, p.userSyncFilter)
+ if !usernameClaim.Exists() {
+ return nil, fmt.Errorf("cannot sync user as '%s' is missing from claims", p.userSyncFilter)
+ }
+ username := usernameClaim.Str
+
+ u, err := p.userService.GetUserByName(username)
+ if err != nil {
+ if errors.Is(err, user.ErrUserNotFound) {
+ u, err = p.userService.CreateUser(username, "")
+ if err != nil {
+ return nil, errors.Join(ErrUserSyncFailed, err)
+ }
+ } else {
+ return nil, errors.Join(ErrUserSyncFailed, err)
+ }
+ }
+
+ return u, nil
+}
+
+func (p *OIDCAuthProvider) Name() string {
+ return p.name
+}
+
+func (p *OIDCAuthProvider) Type() string {
+ return "oidc"
+}
+
+func getRawClaims(p string) (string, error) {
+ parts := strings.Split(p, ".")
+ if len(parts) < 2 {
+ return "", fmt.Errorf("malformed jwt, expected 3 parts got %d", len(parts))
+ }
+ payload, err := base64.RawURLEncoding.DecodeString(parts[1])
+ if err != nil {
+ return "", fmt.Errorf("malformed jwt payload: %w", err)
+ }
+ return string(payload[:]), nil
+}
diff --git a/pkg/auth/service.go b/pkg/auth/service.go
new file mode 100644
index 0000000..be1d6e7
--- /dev/null
+++ b/pkg/auth/service.go
@@ -0,0 +1,49 @@
+package auth
+
+import (
+ "fmt"
+ "sync"
+)
+
+type Service interface {
+ GetAuthProvider(string) AuthProvider
+ GetAuthProviders() []string
+ RegisterAuthProvider(string, AuthProvider) error
+}
+
+type AuthProvider interface {
+ Name() string
+ Type() string
+}
+
+type service struct {
+ authProviders map[string]AuthProvider
+ order []string
+ lock sync.Mutex
+}
+
+func NewService() Service {
+ return &service{
+ authProviders: make(map[string]AuthProvider),
+ }
+}
+
+func (s *service) GetAuthProvider(name string) AuthProvider {
+ return s.authProviders[name]
+}
+
+func (s *service) GetAuthProviders() []string {
+ return s.order
+}
+
+func (s *service) RegisterAuthProvider(name string, provider AuthProvider) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if _, ok := s.authProviders[name]; ok {
+ return fmt.Errorf("duplicate auth provider: %s", name)
+ }
+ s.order = append(s.order, name)
+ s.authProviders[name] = provider
+ return nil
+}
diff --git a/pkg/database/migrations/0002_nullable_passwords.sql b/pkg/database/migrations/0002_nullable_passwords.sql
new file mode 100644
index 0000000..2f31366
--- /dev/null
+++ b/pkg/database/migrations/0002_nullable_passwords.sql
@@ -0,0 +1,4 @@
+-- +goose Up
+ALTER TABLE users DROP CONSTRAINT valid_hash;
+ALTER TABLE users ALTER COLUMN password DROP NOT NULL;
+ALTER TABLE users ADD CONSTRAINT valid_hash CHECK (length(password) = 60 OR password IS NULL);
diff --git a/pkg/database/sqlc/calendars.sql.go b/pkg/database/sqlc/calendars.sql.go
index 47ae37f..ad55a51 100644
--- a/pkg/database/sqlc/calendars.sql.go
+++ b/pkg/database/sqlc/calendars.sql.go
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
-// sqlc v1.27.0
+// sqlc v1.29.0
// source: calendars.sql
package sqlc
diff --git a/pkg/database/sqlc/db.go b/pkg/database/sqlc/db.go
index b931bc5..2725108 100644
--- a/pkg/database/sqlc/db.go
+++ b/pkg/database/sqlc/db.go
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
-// sqlc v1.27.0
+// sqlc v1.29.0
package sqlc
diff --git a/pkg/database/sqlc/favourites.sql.go b/pkg/database/sqlc/favourites.sql.go
index 359ae9d..b13261f 100644
--- a/pkg/database/sqlc/favourites.sql.go
+++ b/pkg/database/sqlc/favourites.sql.go
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
-// sqlc v1.27.0
+// sqlc v1.29.0
// source: favourites.sql
package sqlc
diff --git a/pkg/database/sqlc/models.go b/pkg/database/sqlc/models.go
index e38851a..57fd082 100644
--- a/pkg/database/sqlc/models.go
+++ b/pkg/database/sqlc/models.go
@@ -1,6 +1,6 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
-// sqlc v1.27.0
+// sqlc v1.29.0
package sqlc
@@ -23,7 +23,7 @@ type Favourite struct {
}
type User struct {
- ID int32 `json:"id"`
- Username string `json:"username"`
- Password string `json:"password"`
+ ID int32 `json:"id"`
+ Username string `json:"username"`
+ Password pgtype.Text `json:"password"`
}
diff --git a/pkg/database/sqlc/users.sql.go b/pkg/database/sqlc/users.sql.go
index dfd2c2f..cf0aeb9 100644
--- a/pkg/database/sqlc/users.sql.go
+++ b/pkg/database/sqlc/users.sql.go
@@ -1,12 +1,14 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
-// sqlc v1.27.0
+// sqlc v1.29.0
// source: users.sql
package sqlc
import (
"context"
+
+ "github.com/jackc/pgx/v5/pgtype"
)
const createUser = `-- name: CreateUser :one
@@ -19,8 +21,8 @@ RETURNING id, username, password
`
type CreateUserParams struct {
- Username string `json:"username"`
- Password string `json:"password"`
+ Username string `json:"username"`
+ Password pgtype.Text `json:"password"`
}
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
diff --git a/pkg/session/memory.go b/pkg/session/memory.go
index 96e416b..f02b792 100644
--- a/pkg/session/memory.go
+++ b/pkg/session/memory.go
@@ -2,7 +2,7 @@ package session
import (
"crypto/rand"
- "encoding/hex"
+ "encoding/base64"
"fmt"
"sync"
"time"
@@ -91,5 +91,5 @@ func generateSessionToken() string {
if _, err := rand.Read(b); err != nil {
return ""
}
- return hex.EncodeToString(b)
+ return base64.StdEncoding.EncodeToString(b)
}
diff --git a/pkg/user/service.go b/pkg/user/service.go
index 7784811..21cfa9e 100644
--- a/pkg/user/service.go
+++ b/pkg/user/service.go
@@ -9,6 +9,7 @@ import (
"github.com/LMBishop/confplanner/pkg/database/sqlc"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
+ "github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/crypto/bcrypt"
)
@@ -17,7 +18,6 @@ type Service interface {
CreateUser(username string, password string) (*sqlc.User, error)
GetUserByName(username string) (*sqlc.User, error)
GetUserByID(id int32) (*sqlc.User, error)
- Authenticate(username string, password string) (*sqlc.User, error)
}
var (
@@ -43,18 +43,30 @@ func (s *service) CreateUser(username string, password string) (*sqlc.User, erro
return nil, ErrNotAcceptingRegistrations
}
+ var passwordHash pgtype.Text
queries := sqlc.New(s.pool)
- var passwordBytes = []byte(password)
+ if password != "" {
+ var passwordBytes = []byte(password)
- hash, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
- if err != nil {
- return nil, fmt.Errorf("could not hash password: %w", err)
+ hash, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
+ if err != nil {
+ return nil, fmt.Errorf("could not hash password: %w", err)
+ }
+
+ passwordHash = pgtype.Text{
+ String: string(hash),
+ Valid: true,
+ }
+ } else {
+ passwordHash = pgtype.Text{
+ Valid: false,
+ }
}
user, err := queries.CreateUser(context.Background(), sqlc.CreateUserParams{
Username: strings.ToLower(username),
- Password: string(hash),
+ Password: passwordHash,
})
if err != nil {
var pgErr *pgconn.PgError
@@ -94,28 +106,3 @@ func (s *service) GetUserByID(id int32) (*sqlc.User, error) {
return &user, nil
}
-
-func (s *service) Authenticate(username string, password string) (*sqlc.User, error) {
- random, err := bcrypt.GenerateFromPassword([]byte("00000000"), bcrypt.DefaultCost)
- if err != nil {
- return nil, err
- }
-
- user, err := s.GetUserByName(username)
- if err != nil {
- if errors.Is(err, ErrUserNotFound) {
- bcrypt.CompareHashAndPassword(random, []byte(password))
- return nil, nil
- }
- return nil, err
- }
-
- if err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
- if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
- return nil, nil
- }
- return nil, err
- }
-
- return user, nil
-}