From 8f7dec8ba6b2f9bde01afd0a110596ebbd43e0ed Mon Sep 17 00:00:00 2001 From: Leonardo Bishop Date: Fri, 15 Aug 2025 19:20:48 +0100 Subject: Implement OIDC --- pkg/auth/basic.go | 56 ++++++ pkg/auth/oauth.go | 191 +++++++++++++++++++++ pkg/auth/service.go | 49 ++++++ .../migrations/0002_nullable_passwords.sql | 4 + pkg/database/sqlc/calendars.sql.go | 2 +- pkg/database/sqlc/db.go | 2 +- pkg/database/sqlc/favourites.sql.go | 2 +- pkg/database/sqlc/models.go | 8 +- pkg/database/sqlc/users.sql.go | 8 +- pkg/session/memory.go | 4 +- pkg/user/service.go | 49 ++---- 11 files changed, 332 insertions(+), 43 deletions(-) create mode 100644 pkg/auth/basic.go create mode 100644 pkg/auth/oauth.go create mode 100644 pkg/auth/service.go create mode 100644 pkg/database/migrations/0002_nullable_passwords.sql (limited to 'pkg') diff --git a/pkg/auth/basic.go b/pkg/auth/basic.go new file mode 100644 index 0000000..dafd93f --- /dev/null +++ b/pkg/auth/basic.go @@ -0,0 +1,56 @@ +package auth + +import ( + "errors" + + "github.com/LMBishop/confplanner/pkg/database/sqlc" + "github.com/LMBishop/confplanner/pkg/user" + "golang.org/x/crypto/bcrypt" +) + +type BasicAuthProvider struct { + userService user.Service +} + +func NewBasicAuthProvider(userService user.Service) AuthProvider { + return &BasicAuthProvider{ + userService: userService, + } +} + +func (p *BasicAuthProvider) Authenticate(username string, password string) (*sqlc.User, error) { + random, err := bcrypt.GenerateFromPassword([]byte("00000000"), bcrypt.DefaultCost) + if err != nil { + return nil, err + } + + u, err := p.userService.GetUserByName(username) + if err != nil { + if errors.Is(err, user.ErrUserNotFound) { + bcrypt.CompareHashAndPassword(random, []byte(password)) + return nil, nil + } + return nil, err + } + if !u.Password.Valid { + bcrypt.CompareHashAndPassword(random, []byte(password)) + return nil, nil + } + + if err = bcrypt.CompareHashAndPassword([]byte(u.Password.String), []byte(password)); err != nil { + if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { + return nil, nil + } + return nil, err + } + + return u, nil +} + +func (p *BasicAuthProvider) Name() string { + return "Basic" +} + +func (p *BasicAuthProvider) Type() string { + return "basic" +} diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go new file mode 100644 index 0000000..9c45e7a --- /dev/null +++ b/pkg/auth/oauth.go @@ -0,0 +1,191 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/LMBishop/confplanner/pkg/database/sqlc" + "github.com/LMBishop/confplanner/pkg/user" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/tidwall/gjson" + "golang.org/x/oauth2" +) + +type OIDCAuthProvider struct { + name string + userService user.Service + oauthConfig *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + loginFilter string + loginFilterAllowedValues []string + userSyncFilter string + states map[string]*oidcState + lock sync.RWMutex +} + +type oidcState struct { + expiry time.Time + ip string + userAgent string +} + +var ( + ErrStateVerificationFailed = errors.New("state verification failed") + ErrInvalidState = errors.New("invalid state") + ErrMissingIDToken = errors.New("missing ID token") + ErrNotAuthorised = errors.New("not authorised") + ErrUserSyncFailed = errors.New("user sync failed") +) + +func NewOIDCAuthProvider(userService user.Service, name, clientID, clientSecret, endpoint, callbackURL, loginFilter, userSyncFilter string, loginFilterAllowedValues []string) (AuthProvider, error) { + provider, err := oidc.NewProvider(context.Background(), endpoint) + if err != nil { + return nil, err + } + + return &OIDCAuthProvider{ + name: name, + userService: userService, + oauthConfig: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: provider.Endpoint(), + RedirectURL: callbackURL, + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + }, + oidcProvider: provider, + oidcVerifier: provider.Verifier(&oidc.Config{ClientID: clientID}), + loginFilter: loginFilter, + loginFilterAllowedValues: loginFilterAllowedValues, + userSyncFilter: userSyncFilter, + states: make(map[string]*oidcState), + }, nil +} + +func (p *OIDCAuthProvider) StartJourney(ip string, userAgent string) (string, error) { + b := make([]byte, 50) + if _, err := rand.Read(b); err != nil { + return "", err + } + + state := base64.URLEncoding.EncodeToString(b) + + p.lock.Lock() + defer p.lock.Unlock() + + p.states[state] = &oidcState{ + expiry: time.Now().Add(time.Minute * 5), + ip: ip, + userAgent: userAgent, + } + + return p.oauthConfig.AuthCodeURL(state), nil +} + +func (p *OIDCAuthProvider) CompleteJourney(ctx context.Context, authCode string, state string, ip string, userAgent string) (*sqlc.User, error) { + var s *oidcState + + p.lock.Lock() + s = p.states[state] + delete(p.states, state) + p.lock.Unlock() + + if s == nil { + return nil, ErrInvalidState + } + + //if time.Now().After(s.expiry) || s.ip != ip || s.userAgent != userAgent { + // return nil, ErrStateVerificationFailed + //} + if time.Now().After(s.expiry) || s.userAgent != userAgent { + return nil, ErrStateVerificationFailed + } + + oauth2Token, err := p.oauthConfig.Exchange(ctx, authCode) + if err != nil { + return nil, err + } + + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return nil, ErrMissingIDToken + } + + _, err = p.oidcVerifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, err + } + + claims, err := getRawClaims(rawIDToken) + if err != nil { + return nil, err + } + + if p.loginFilter != "" { + rolesClaim := gjson.Get(claims, p.loginFilter) + if !rolesClaim.Exists() { + return nil, fmt.Errorf("cannot verify authorisation as '%s' is missing from claims", p.loginFilter) + } + roles := rolesClaim.Array() + var authorisation bool + out: + for _, allowedRole := range p.loginFilterAllowedValues { + for _, role := range roles { + if role.Str == allowedRole { + authorisation = true + break out + } + } + } + if !authorisation { + return nil, ErrNotAuthorised + } + } + + usernameClaim := gjson.Get(claims, p.userSyncFilter) + if !usernameClaim.Exists() { + return nil, fmt.Errorf("cannot sync user as '%s' is missing from claims", p.userSyncFilter) + } + username := usernameClaim.Str + + u, err := p.userService.GetUserByName(username) + if err != nil { + if errors.Is(err, user.ErrUserNotFound) { + u, err = p.userService.CreateUser(username, "") + if err != nil { + return nil, errors.Join(ErrUserSyncFailed, err) + } + } else { + return nil, errors.Join(ErrUserSyncFailed, err) + } + } + + return u, nil +} + +func (p *OIDCAuthProvider) Name() string { + return p.name +} + +func (p *OIDCAuthProvider) Type() string { + return "oidc" +} + +func getRawClaims(p string) (string, error) { + parts := strings.Split(p, ".") + if len(parts) < 2 { + return "", fmt.Errorf("malformed jwt, expected 3 parts got %d", len(parts)) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("malformed jwt payload: %w", err) + } + return string(payload[:]), nil +} diff --git a/pkg/auth/service.go b/pkg/auth/service.go new file mode 100644 index 0000000..be1d6e7 --- /dev/null +++ b/pkg/auth/service.go @@ -0,0 +1,49 @@ +package auth + +import ( + "fmt" + "sync" +) + +type Service interface { + GetAuthProvider(string) AuthProvider + GetAuthProviders() []string + RegisterAuthProvider(string, AuthProvider) error +} + +type AuthProvider interface { + Name() string + Type() string +} + +type service struct { + authProviders map[string]AuthProvider + order []string + lock sync.Mutex +} + +func NewService() Service { + return &service{ + authProviders: make(map[string]AuthProvider), + } +} + +func (s *service) GetAuthProvider(name string) AuthProvider { + return s.authProviders[name] +} + +func (s *service) GetAuthProviders() []string { + return s.order +} + +func (s *service) RegisterAuthProvider(name string, provider AuthProvider) error { + s.lock.Lock() + defer s.lock.Unlock() + + if _, ok := s.authProviders[name]; ok { + return fmt.Errorf("duplicate auth provider: %s", name) + } + s.order = append(s.order, name) + s.authProviders[name] = provider + return nil +} diff --git a/pkg/database/migrations/0002_nullable_passwords.sql b/pkg/database/migrations/0002_nullable_passwords.sql new file mode 100644 index 0000000..2f31366 --- /dev/null +++ b/pkg/database/migrations/0002_nullable_passwords.sql @@ -0,0 +1,4 @@ +-- +goose Up +ALTER TABLE users DROP CONSTRAINT valid_hash; +ALTER TABLE users ALTER COLUMN password DROP NOT NULL; +ALTER TABLE users ADD CONSTRAINT valid_hash CHECK (length(password) = 60 OR password IS NULL); diff --git a/pkg/database/sqlc/calendars.sql.go b/pkg/database/sqlc/calendars.sql.go index 47ae37f..ad55a51 100644 --- a/pkg/database/sqlc/calendars.sql.go +++ b/pkg/database/sqlc/calendars.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 // source: calendars.sql package sqlc diff --git a/pkg/database/sqlc/db.go b/pkg/database/sqlc/db.go index b931bc5..2725108 100644 --- a/pkg/database/sqlc/db.go +++ b/pkg/database/sqlc/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 package sqlc diff --git a/pkg/database/sqlc/favourites.sql.go b/pkg/database/sqlc/favourites.sql.go index 359ae9d..b13261f 100644 --- a/pkg/database/sqlc/favourites.sql.go +++ b/pkg/database/sqlc/favourites.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 // source: favourites.sql package sqlc diff --git a/pkg/database/sqlc/models.go b/pkg/database/sqlc/models.go index e38851a..57fd082 100644 --- a/pkg/database/sqlc/models.go +++ b/pkg/database/sqlc/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 package sqlc @@ -23,7 +23,7 @@ type Favourite struct { } type User struct { - ID int32 `json:"id"` - Username string `json:"username"` - Password string `json:"password"` + ID int32 `json:"id"` + Username string `json:"username"` + Password pgtype.Text `json:"password"` } diff --git a/pkg/database/sqlc/users.sql.go b/pkg/database/sqlc/users.sql.go index dfd2c2f..cf0aeb9 100644 --- a/pkg/database/sqlc/users.sql.go +++ b/pkg/database/sqlc/users.sql.go @@ -1,12 +1,14 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 // source: users.sql package sqlc import ( "context" + + "github.com/jackc/pgx/v5/pgtype" ) const createUser = `-- name: CreateUser :one @@ -19,8 +21,8 @@ RETURNING id, username, password ` type CreateUserParams struct { - Username string `json:"username"` - Password string `json:"password"` + Username string `json:"username"` + Password pgtype.Text `json:"password"` } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { diff --git a/pkg/session/memory.go b/pkg/session/memory.go index 96e416b..f02b792 100644 --- a/pkg/session/memory.go +++ b/pkg/session/memory.go @@ -2,7 +2,7 @@ package session import ( "crypto/rand" - "encoding/hex" + "encoding/base64" "fmt" "sync" "time" @@ -91,5 +91,5 @@ func generateSessionToken() string { if _, err := rand.Read(b); err != nil { return "" } - return hex.EncodeToString(b) + return base64.StdEncoding.EncodeToString(b) } diff --git a/pkg/user/service.go b/pkg/user/service.go index 7784811..21cfa9e 100644 --- a/pkg/user/service.go +++ b/pkg/user/service.go @@ -9,6 +9,7 @@ import ( "github.com/LMBishop/confplanner/pkg/database/sqlc" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "golang.org/x/crypto/bcrypt" ) @@ -17,7 +18,6 @@ type Service interface { CreateUser(username string, password string) (*sqlc.User, error) GetUserByName(username string) (*sqlc.User, error) GetUserByID(id int32) (*sqlc.User, error) - Authenticate(username string, password string) (*sqlc.User, error) } var ( @@ -43,18 +43,30 @@ func (s *service) CreateUser(username string, password string) (*sqlc.User, erro return nil, ErrNotAcceptingRegistrations } + var passwordHash pgtype.Text queries := sqlc.New(s.pool) - var passwordBytes = []byte(password) + if password != "" { + var passwordBytes = []byte(password) - hash, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost) - if err != nil { - return nil, fmt.Errorf("could not hash password: %w", err) + hash, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("could not hash password: %w", err) + } + + passwordHash = pgtype.Text{ + String: string(hash), + Valid: true, + } + } else { + passwordHash = pgtype.Text{ + Valid: false, + } } user, err := queries.CreateUser(context.Background(), sqlc.CreateUserParams{ Username: strings.ToLower(username), - Password: string(hash), + Password: passwordHash, }) if err != nil { var pgErr *pgconn.PgError @@ -94,28 +106,3 @@ func (s *service) GetUserByID(id int32) (*sqlc.User, error) { return &user, nil } - -func (s *service) Authenticate(username string, password string) (*sqlc.User, error) { - random, err := bcrypt.GenerateFromPassword([]byte("00000000"), bcrypt.DefaultCost) - if err != nil { - return nil, err - } - - user, err := s.GetUserByName(username) - if err != nil { - if errors.Is(err, ErrUserNotFound) { - bcrypt.CompareHashAndPassword(random, []byte(password)) - return nil, nil - } - return nil, err - } - - if err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { - if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { - return nil, nil - } - return nil, err - } - - return user, nil -} -- cgit v1.2.3-70-g09d2