// Adapted from // https://git.leonardobishop.net/confplanner/plain/pkg/auth/oauth.go package auth import ( "context" "crypto/rand" "encoding/base64" "errors" "sync" "time" "git.leonardobishop.net/instancer/pkg/session" "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" ) type OIDCAuthProvider struct { Name string oauthConfig *oauth2.Config oidcProvider *oidc.Provider oidcVerifier *oidc.IDTokenVerifier states map[string]*oidcState lock sync.RWMutex } type oidcState struct { expiry time.Time ip string userAgent string } var ( ErrStateVerificationFailed = errors.New("state verification failed") ErrInvalidState = errors.New("invalid state") ErrMissingIDToken = errors.New("missing ID token") ErrInvalidIDToken = errors.New("invalid ID token") ErrNotAuthorised = errors.New("not authorised") ErrUserSyncFailed = errors.New("user sync failed") ErrInvalidToken = errors.New("invalid token") ) func NewOIDCAuthProvider(name, clientID, clientSecret, endpoint, callbackURL string) (OIDCAuthProvider, error) { provider, err := oidc.NewProvider(context.Background(), endpoint) if err != nil { return OIDCAuthProvider{}, err } return OIDCAuthProvider{ Name: name, oauthConfig: &oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, Endpoint: provider.Endpoint(), RedirectURL: callbackURL, Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, }, oidcProvider: provider, oidcVerifier: provider.Verifier(&oidc.Config{ClientID: clientID}), states: make(map[string]*oidcState), }, nil } func (p *OIDCAuthProvider) StartJourney(ip string, userAgent string) (string, error) { b := make([]byte, 50) if _, err := rand.Read(b); err != nil { return "", err } state := base64.URLEncoding.EncodeToString(b) p.lock.Lock() defer p.lock.Unlock() p.states[state] = &oidcState{ expiry: time.Now().Add(time.Minute * 5), ip: ip, userAgent: userAgent, } return p.oauthConfig.AuthCodeURL(state), nil } func (p *OIDCAuthProvider) CompleteJourney(ctx context.Context, authCode string, state string, ip string, userAgent string, session *session.UserSession) error { var s *oidcState p.lock.Lock() s = p.states[state] delete(p.states, state) p.lock.Unlock() if s == nil { return ErrInvalidState } //if time.Now().After(s.expiry) || s.ip != ip || s.userAgent != userAgent { // return nil, ErrStateVerificationFailed //} if time.Now().After(s.expiry) || s.userAgent != userAgent { return ErrStateVerificationFailed } oauth2Token, err := p.oauthConfig.Exchange(ctx, authCode) if err != nil { return err } rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { return ErrMissingIDToken } idToken, err := p.oidcVerifier.Verify(ctx, rawIDToken) if err != nil { return ErrInvalidIDToken } var claims struct { Subject string `json:"sub"` Email string `json:"email"` Name string `json:"name"` } if err := idToken.Claims(&claims); err != nil { return ErrInvalidIDToken } session.OAuthTokenSource = p.oauthConfig.TokenSource(context.Background(), oauth2Token) session.Subject = claims.Subject session.Email = claims.Email session.Name = claims.Name return nil } func (p *OIDCAuthProvider) UpdateUserInfo(ctx context.Context, session *session.UserSession) error { userInfo, err := p.oidcProvider.UserInfo(ctx, session.OAuthTokenSource) if err != nil { return ErrInvalidToken } var claims struct { Name string `json:"name"` TeamID string `json:"team_id"` TeamName string `json:"team_name"` } err = userInfo.Claims(&claims) if err != nil { return err } session.Subject = userInfo.Subject session.Email = userInfo.Email session.Name = claims.Name session.TeamID = claims.TeamID session.TeamName = claims.TeamName return nil }