summaryrefslogtreecommitdiffstats
path: root/pkg/auth
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/auth')
-rw-r--r--pkg/auth/oidc.go162
1 files changed, 162 insertions, 0 deletions
diff --git a/pkg/auth/oidc.go b/pkg/auth/oidc.go
new file mode 100644
index 0000000..674332e
--- /dev/null
+++ b/pkg/auth/oidc.go
@@ -0,0 +1,162 @@
+// Adapted from
+// https://git.leonardobishop.net/confplanner/plain/pkg/auth/oauth.go
+
+package auth
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "errors"
+ "sync"
+ "time"
+
+ "git.leonardobishop.net/instancer/pkg/session"
+ "github.com/coreos/go-oidc/v3/oidc"
+ "golang.org/x/oauth2"
+)
+
+type OIDCAuthProvider struct {
+ Name string
+
+ oauthConfig *oauth2.Config
+ oidcProvider *oidc.Provider
+ oidcVerifier *oidc.IDTokenVerifier
+
+ states map[string]*oidcState
+ lock sync.RWMutex
+}
+
+type oidcState struct {
+ expiry time.Time
+ ip string
+ userAgent string
+}
+
+var (
+ ErrStateVerificationFailed = errors.New("state verification failed")
+ ErrInvalidState = errors.New("invalid state")
+ ErrMissingIDToken = errors.New("missing ID token")
+ ErrInvalidIDToken = errors.New("invalid ID token")
+ ErrNotAuthorised = errors.New("not authorised")
+ ErrUserSyncFailed = errors.New("user sync failed")
+
+ ErrInvalidToken = errors.New("invalid token")
+)
+
+func NewOIDCAuthProvider(name, clientID, clientSecret, endpoint, callbackURL string) (OIDCAuthProvider, error) {
+ provider, err := oidc.NewProvider(context.Background(), endpoint)
+ if err != nil {
+ return OIDCAuthProvider{}, err
+ }
+
+ return OIDCAuthProvider{
+ Name: name,
+ oauthConfig: &oauth2.Config{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ Endpoint: provider.Endpoint(),
+ RedirectURL: callbackURL,
+ Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
+ },
+ oidcProvider: provider,
+ oidcVerifier: provider.Verifier(&oidc.Config{ClientID: clientID}),
+ states: make(map[string]*oidcState),
+ }, nil
+}
+
+func (p *OIDCAuthProvider) StartJourney(ip string, userAgent string) (string, error) {
+ b := make([]byte, 50)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+
+ state := base64.URLEncoding.EncodeToString(b)
+
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ p.states[state] = &oidcState{
+ expiry: time.Now().Add(time.Minute * 5),
+ ip: ip,
+ userAgent: userAgent,
+ }
+
+ return p.oauthConfig.AuthCodeURL(state), nil
+}
+
+func (p *OIDCAuthProvider) CompleteJourney(ctx context.Context, authCode string, state string, ip string, userAgent string, session *session.UserSession) error {
+ var s *oidcState
+
+ p.lock.Lock()
+ s = p.states[state]
+ delete(p.states, state)
+ p.lock.Unlock()
+
+ if s == nil {
+ return ErrInvalidState
+ }
+
+ //if time.Now().After(s.expiry) || s.ip != ip || s.userAgent != userAgent {
+ // return nil, ErrStateVerificationFailed
+ //}
+ if time.Now().After(s.expiry) || s.userAgent != userAgent {
+ return ErrStateVerificationFailed
+ }
+
+ oauth2Token, err := p.oauthConfig.Exchange(ctx, authCode)
+ if err != nil {
+ return err
+ }
+
+ rawIDToken, ok := oauth2Token.Extra("id_token").(string)
+ if !ok {
+ return ErrMissingIDToken
+ }
+
+ idToken, err := p.oidcVerifier.Verify(ctx, rawIDToken)
+ if err != nil {
+ return ErrInvalidIDToken
+ }
+
+ var claims struct {
+ Subject string `json:"sub"`
+ Email string `json:"email"`
+ Name string `json:"name"`
+ }
+ if err := idToken.Claims(&claims); err != nil {
+ return ErrInvalidIDToken
+ }
+
+ session.OAuthTokenSource = p.oauthConfig.TokenSource(context.Background(), oauth2Token)
+ session.Subject = claims.Subject
+ session.Email = claims.Email
+ session.Name = claims.Name
+
+ return nil
+}
+
+func (p *OIDCAuthProvider) UpdateUserInfo(ctx context.Context, session *session.UserSession) error {
+ userInfo, err := p.oidcProvider.UserInfo(ctx, session.OAuthTokenSource)
+ if err != nil {
+ return ErrInvalidToken
+ }
+
+ var claims struct {
+ Name string `json:"name"`
+ TeamID string `json:"team_id"`
+ TeamName string `json:"team_name"`
+ }
+ err = userInfo.Claims(&claims)
+ if err != nil {
+ return err
+ }
+
+ session.Subject = userInfo.Subject
+ session.Email = userInfo.Email
+ session.Name = claims.Name
+ session.TeamID = claims.TeamID
+ session.TeamName = claims.TeamName
+
+ return nil
+}