diff options
Diffstat (limited to 'pkg/auth/oauth.go')
| -rw-r--r-- | pkg/auth/oauth.go | 191 |
1 files changed, 191 insertions, 0 deletions
diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go new file mode 100644 index 0000000..9c45e7a --- /dev/null +++ b/pkg/auth/oauth.go @@ -0,0 +1,191 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/LMBishop/confplanner/pkg/database/sqlc" + "github.com/LMBishop/confplanner/pkg/user" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/tidwall/gjson" + "golang.org/x/oauth2" +) + +type OIDCAuthProvider struct { + name string + userService user.Service + oauthConfig *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + loginFilter string + loginFilterAllowedValues []string + userSyncFilter string + states map[string]*oidcState + lock sync.RWMutex +} + +type oidcState struct { + expiry time.Time + ip string + userAgent string +} + +var ( + ErrStateVerificationFailed = errors.New("state verification failed") + ErrInvalidState = errors.New("invalid state") + ErrMissingIDToken = errors.New("missing ID token") + ErrNotAuthorised = errors.New("not authorised") + ErrUserSyncFailed = errors.New("user sync failed") +) + +func NewOIDCAuthProvider(userService user.Service, name, clientID, clientSecret, endpoint, callbackURL, loginFilter, userSyncFilter string, loginFilterAllowedValues []string) (AuthProvider, error) { + provider, err := oidc.NewProvider(context.Background(), endpoint) + if err != nil { + return nil, err + } + + return &OIDCAuthProvider{ + name: name, + userService: userService, + oauthConfig: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: provider.Endpoint(), + RedirectURL: callbackURL, + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + }, + oidcProvider: provider, + oidcVerifier: provider.Verifier(&oidc.Config{ClientID: clientID}), + loginFilter: loginFilter, + loginFilterAllowedValues: loginFilterAllowedValues, + userSyncFilter: userSyncFilter, + states: make(map[string]*oidcState), + }, nil +} + +func (p *OIDCAuthProvider) StartJourney(ip string, userAgent string) (string, error) { + b := make([]byte, 50) + if _, err := rand.Read(b); err != nil { + return "", err + } + + state := base64.URLEncoding.EncodeToString(b) + + p.lock.Lock() + defer p.lock.Unlock() + + p.states[state] = &oidcState{ + expiry: time.Now().Add(time.Minute * 5), + ip: ip, + userAgent: userAgent, + } + + return p.oauthConfig.AuthCodeURL(state), nil +} + +func (p *OIDCAuthProvider) CompleteJourney(ctx context.Context, authCode string, state string, ip string, userAgent string) (*sqlc.User, error) { + var s *oidcState + + p.lock.Lock() + s = p.states[state] + delete(p.states, state) + p.lock.Unlock() + + if s == nil { + return nil, ErrInvalidState + } + + //if time.Now().After(s.expiry) || s.ip != ip || s.userAgent != userAgent { + // return nil, ErrStateVerificationFailed + //} + if time.Now().After(s.expiry) || s.userAgent != userAgent { + return nil, ErrStateVerificationFailed + } + + oauth2Token, err := p.oauthConfig.Exchange(ctx, authCode) + if err != nil { + return nil, err + } + + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return nil, ErrMissingIDToken + } + + _, err = p.oidcVerifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, err + } + + claims, err := getRawClaims(rawIDToken) + if err != nil { + return nil, err + } + + if p.loginFilter != "" { + rolesClaim := gjson.Get(claims, p.loginFilter) + if !rolesClaim.Exists() { + return nil, fmt.Errorf("cannot verify authorisation as '%s' is missing from claims", p.loginFilter) + } + roles := rolesClaim.Array() + var authorisation bool + out: + for _, allowedRole := range p.loginFilterAllowedValues { + for _, role := range roles { + if role.Str == allowedRole { + authorisation = true + break out + } + } + } + if !authorisation { + return nil, ErrNotAuthorised + } + } + + usernameClaim := gjson.Get(claims, p.userSyncFilter) + if !usernameClaim.Exists() { + return nil, fmt.Errorf("cannot sync user as '%s' is missing from claims", p.userSyncFilter) + } + username := usernameClaim.Str + + u, err := p.userService.GetUserByName(username) + if err != nil { + if errors.Is(err, user.ErrUserNotFound) { + u, err = p.userService.CreateUser(username, "") + if err != nil { + return nil, errors.Join(ErrUserSyncFailed, err) + } + } else { + return nil, errors.Join(ErrUserSyncFailed, err) + } + } + + return u, nil +} + +func (p *OIDCAuthProvider) Name() string { + return p.name +} + +func (p *OIDCAuthProvider) Type() string { + return "oidc" +} + +func getRawClaims(p string) (string, error) { + parts := strings.Split(p, ".") + if len(parts) < 2 { + return "", fmt.Errorf("malformed jwt, expected 3 parts got %d", len(parts)) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("malformed jwt payload: %w", err) + } + return string(payload[:]), nil +} |
