diff options
| author | Leonardo Bishop <me@leonardobishop.com> | 2025-08-15 19:20:48 +0100 |
|---|---|---|
| committer | Leonardo Bishop <me@leonardobishop.com> | 2025-08-15 19:20:48 +0100 |
| commit | 8f7dec8ba6b2f9bde01afd0a110596ebbd43e0ed (patch) | |
| tree | 7b4f203d92f4b99b1e98fac314415e293984196b /pkg/auth | |
| parent | 4697556cac819c47d068819b9fc9c3b4ea84e279 (diff) | |
Implement OIDC
Diffstat (limited to 'pkg/auth')
| -rw-r--r-- | pkg/auth/basic.go | 56 | ||||
| -rw-r--r-- | pkg/auth/oauth.go | 191 | ||||
| -rw-r--r-- | pkg/auth/service.go | 49 |
3 files changed, 296 insertions, 0 deletions
diff --git a/pkg/auth/basic.go b/pkg/auth/basic.go new file mode 100644 index 0000000..dafd93f --- /dev/null +++ b/pkg/auth/basic.go @@ -0,0 +1,56 @@ +package auth + +import ( + "errors" + + "github.com/LMBishop/confplanner/pkg/database/sqlc" + "github.com/LMBishop/confplanner/pkg/user" + "golang.org/x/crypto/bcrypt" +) + +type BasicAuthProvider struct { + userService user.Service +} + +func NewBasicAuthProvider(userService user.Service) AuthProvider { + return &BasicAuthProvider{ + userService: userService, + } +} + +func (p *BasicAuthProvider) Authenticate(username string, password string) (*sqlc.User, error) { + random, err := bcrypt.GenerateFromPassword([]byte("00000000"), bcrypt.DefaultCost) + if err != nil { + return nil, err + } + + u, err := p.userService.GetUserByName(username) + if err != nil { + if errors.Is(err, user.ErrUserNotFound) { + bcrypt.CompareHashAndPassword(random, []byte(password)) + return nil, nil + } + return nil, err + } + if !u.Password.Valid { + bcrypt.CompareHashAndPassword(random, []byte(password)) + return nil, nil + } + + if err = bcrypt.CompareHashAndPassword([]byte(u.Password.String), []byte(password)); err != nil { + if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { + return nil, nil + } + return nil, err + } + + return u, nil +} + +func (p *BasicAuthProvider) Name() string { + return "Basic" +} + +func (p *BasicAuthProvider) Type() string { + return "basic" +} diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go new file mode 100644 index 0000000..9c45e7a --- /dev/null +++ b/pkg/auth/oauth.go @@ -0,0 +1,191 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/LMBishop/confplanner/pkg/database/sqlc" + "github.com/LMBishop/confplanner/pkg/user" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/tidwall/gjson" + "golang.org/x/oauth2" +) + +type OIDCAuthProvider struct { + name string + userService user.Service + oauthConfig *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + loginFilter string + loginFilterAllowedValues []string + userSyncFilter string + states map[string]*oidcState + lock sync.RWMutex +} + +type oidcState struct { + expiry time.Time + ip string + userAgent string +} + +var ( + ErrStateVerificationFailed = errors.New("state verification failed") + ErrInvalidState = errors.New("invalid state") + ErrMissingIDToken = errors.New("missing ID token") + ErrNotAuthorised = errors.New("not authorised") + ErrUserSyncFailed = errors.New("user sync failed") +) + +func NewOIDCAuthProvider(userService user.Service, name, clientID, clientSecret, endpoint, callbackURL, loginFilter, userSyncFilter string, loginFilterAllowedValues []string) (AuthProvider, error) { + provider, err := oidc.NewProvider(context.Background(), endpoint) + if err != nil { + return nil, err + } + + return &OIDCAuthProvider{ + name: name, + userService: userService, + oauthConfig: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: provider.Endpoint(), + RedirectURL: callbackURL, + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + }, + oidcProvider: provider, + oidcVerifier: provider.Verifier(&oidc.Config{ClientID: clientID}), + loginFilter: loginFilter, + loginFilterAllowedValues: loginFilterAllowedValues, + userSyncFilter: userSyncFilter, + states: make(map[string]*oidcState), + }, nil +} + +func (p *OIDCAuthProvider) StartJourney(ip string, userAgent string) (string, error) { + b := make([]byte, 50) + if _, err := rand.Read(b); err != nil { + return "", err + } + + state := base64.URLEncoding.EncodeToString(b) + + p.lock.Lock() + defer p.lock.Unlock() + + p.states[state] = &oidcState{ + expiry: time.Now().Add(time.Minute * 5), + ip: ip, + userAgent: userAgent, + } + + return p.oauthConfig.AuthCodeURL(state), nil +} + +func (p *OIDCAuthProvider) CompleteJourney(ctx context.Context, authCode string, state string, ip string, userAgent string) (*sqlc.User, error) { + var s *oidcState + + p.lock.Lock() + s = p.states[state] + delete(p.states, state) + p.lock.Unlock() + + if s == nil { + return nil, ErrInvalidState + } + + //if time.Now().After(s.expiry) || s.ip != ip || s.userAgent != userAgent { + // return nil, ErrStateVerificationFailed + //} + if time.Now().After(s.expiry) || s.userAgent != userAgent { + return nil, ErrStateVerificationFailed + } + + oauth2Token, err := p.oauthConfig.Exchange(ctx, authCode) + if err != nil { + return nil, err + } + + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return nil, ErrMissingIDToken + } + + _, err = p.oidcVerifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, err + } + + claims, err := getRawClaims(rawIDToken) + if err != nil { + return nil, err + } + + if p.loginFilter != "" { + rolesClaim := gjson.Get(claims, p.loginFilter) + if !rolesClaim.Exists() { + return nil, fmt.Errorf("cannot verify authorisation as '%s' is missing from claims", p.loginFilter) + } + roles := rolesClaim.Array() + var authorisation bool + out: + for _, allowedRole := range p.loginFilterAllowedValues { + for _, role := range roles { + if role.Str == allowedRole { + authorisation = true + break out + } + } + } + if !authorisation { + return nil, ErrNotAuthorised + } + } + + usernameClaim := gjson.Get(claims, p.userSyncFilter) + if !usernameClaim.Exists() { + return nil, fmt.Errorf("cannot sync user as '%s' is missing from claims", p.userSyncFilter) + } + username := usernameClaim.Str + + u, err := p.userService.GetUserByName(username) + if err != nil { + if errors.Is(err, user.ErrUserNotFound) { + u, err = p.userService.CreateUser(username, "") + if err != nil { + return nil, errors.Join(ErrUserSyncFailed, err) + } + } else { + return nil, errors.Join(ErrUserSyncFailed, err) + } + } + + return u, nil +} + +func (p *OIDCAuthProvider) Name() string { + return p.name +} + +func (p *OIDCAuthProvider) Type() string { + return "oidc" +} + +func getRawClaims(p string) (string, error) { + parts := strings.Split(p, ".") + if len(parts) < 2 { + return "", fmt.Errorf("malformed jwt, expected 3 parts got %d", len(parts)) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("malformed jwt payload: %w", err) + } + return string(payload[:]), nil +} diff --git a/pkg/auth/service.go b/pkg/auth/service.go new file mode 100644 index 0000000..be1d6e7 --- /dev/null +++ b/pkg/auth/service.go @@ -0,0 +1,49 @@ +package auth + +import ( + "fmt" + "sync" +) + +type Service interface { + GetAuthProvider(string) AuthProvider + GetAuthProviders() []string + RegisterAuthProvider(string, AuthProvider) error +} + +type AuthProvider interface { + Name() string + Type() string +} + +type service struct { + authProviders map[string]AuthProvider + order []string + lock sync.Mutex +} + +func NewService() Service { + return &service{ + authProviders: make(map[string]AuthProvider), + } +} + +func (s *service) GetAuthProvider(name string) AuthProvider { + return s.authProviders[name] +} + +func (s *service) GetAuthProviders() []string { + return s.order +} + +func (s *service) RegisterAuthProvider(name string, provider AuthProvider) error { + s.lock.Lock() + defer s.lock.Unlock() + + if _, ok := s.authProviders[name]; ok { + return fmt.Errorf("duplicate auth provider: %s", name) + } + s.order = append(s.order, name) + s.authProviders[name] = provider + return nil +} |
