diff options
Diffstat (limited to 'pkg/auth')
| -rw-r--r-- | pkg/auth/oidc.go | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go new file mode 100644 index 0000000..674332e --- /dev/null +++ b/pkg/auth/oidc.go @@ -0,0 +1,162 @@ +// Adapted from +// https://git.leonardobishop.net/confplanner/plain/pkg/auth/oauth.go + +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "sync" + "time" + + "git.leonardobishop.net/instancer/pkg/session" + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +type OIDCAuthProvider struct { + Name string + + oauthConfig *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + + states map[string]*oidcState + lock sync.RWMutex +} + +type oidcState struct { + expiry time.Time + ip string + userAgent string +} + +var ( + ErrStateVerificationFailed = errors.New("state verification failed") + ErrInvalidState = errors.New("invalid state") + ErrMissingIDToken = errors.New("missing ID token") + ErrInvalidIDToken = errors.New("invalid ID token") + ErrNotAuthorised = errors.New("not authorised") + ErrUserSyncFailed = errors.New("user sync failed") + + ErrInvalidToken = errors.New("invalid token") +) + +func NewOIDCAuthProvider(name, clientID, clientSecret, endpoint, callbackURL string) (OIDCAuthProvider, error) { + provider, err := oidc.NewProvider(context.Background(), endpoint) + if err != nil { + return OIDCAuthProvider{}, err + } + + return OIDCAuthProvider{ + Name: name, + oauthConfig: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: provider.Endpoint(), + RedirectURL: callbackURL, + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + }, + oidcProvider: provider, + oidcVerifier: provider.Verifier(&oidc.Config{ClientID: clientID}), + states: make(map[string]*oidcState), + }, nil +} + +func (p *OIDCAuthProvider) StartJourney(ip string, userAgent string) (string, error) { + b := make([]byte, 50) + if _, err := rand.Read(b); err != nil { + return "", err + } + + state := base64.URLEncoding.EncodeToString(b) + + p.lock.Lock() + defer p.lock.Unlock() + + p.states[state] = &oidcState{ + expiry: time.Now().Add(time.Minute * 5), + ip: ip, + userAgent: userAgent, + } + + return p.oauthConfig.AuthCodeURL(state), nil +} + +func (p *OIDCAuthProvider) CompleteJourney(ctx context.Context, authCode string, state string, ip string, userAgent string, session *session.UserSession) error { + var s *oidcState + + p.lock.Lock() + s = p.states[state] + delete(p.states, state) + p.lock.Unlock() + + if s == nil { + return ErrInvalidState + } + + //if time.Now().After(s.expiry) || s.ip != ip || s.userAgent != userAgent { + // return nil, ErrStateVerificationFailed + //} + if time.Now().After(s.expiry) || s.userAgent != userAgent { + return ErrStateVerificationFailed + } + + oauth2Token, err := p.oauthConfig.Exchange(ctx, authCode) + if err != nil { + return err + } + + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return ErrMissingIDToken + } + + idToken, err := p.oidcVerifier.Verify(ctx, rawIDToken) + if err != nil { + return ErrInvalidIDToken + } + + var claims struct { + Subject string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + } + if err := idToken.Claims(&claims); err != nil { + return ErrInvalidIDToken + } + + session.OAuthTokenSource = p.oauthConfig.TokenSource(context.Background(), oauth2Token) + session.Subject = claims.Subject + session.Email = claims.Email + session.Name = claims.Name + + return nil +} + +func (p *OIDCAuthProvider) UpdateUserInfo(ctx context.Context, session *session.UserSession) error { + userInfo, err := p.oidcProvider.UserInfo(ctx, session.OAuthTokenSource) + if err != nil { + return ErrInvalidToken + } + + var claims struct { + Name string `json:"name"` + TeamID string `json:"team_id"` + TeamName string `json:"team_name"` + } + err = userInfo.Claims(&claims) + if err != nil { + return err + } + + session.Subject = userInfo.Subject + session.Email = userInfo.Email + session.Name = claims.Name + session.TeamID = claims.TeamID + session.TeamName = claims.TeamName + + return nil +} |
