diff options
| author | Leonardo Bishop <me@leonardobishop.com> | 2025-08-15 19:20:48 +0100 |
|---|---|---|
| committer | Leonardo Bishop <me@leonardobishop.com> | 2025-08-15 19:20:48 +0100 |
| commit | 8f7dec8ba6b2f9bde01afd0a110596ebbd43e0ed (patch) | |
| tree | 7b4f203d92f4b99b1e98fac314415e293984196b /api/handlers/auth.go | |
| parent | 4697556cac819c47d068819b9fc9c3b4ea84e279 (diff) | |
Implement OIDC
Diffstat (limited to 'api/handlers/auth.go')
| -rw-r--r-- | api/handlers/auth.go | 148 |
1 files changed, 148 insertions, 0 deletions
diff --git a/api/handlers/auth.go b/api/handlers/auth.go new file mode 100644 index 0000000..c19fc3a --- /dev/null +++ b/api/handlers/auth.go @@ -0,0 +1,148 @@ +package handlers + +import ( + "errors" + "log/slog" + "net/http" + + "github.com/LMBishop/confplanner/api/dto" + "github.com/LMBishop/confplanner/pkg/auth" + "github.com/LMBishop/confplanner/pkg/database/sqlc" + "github.com/LMBishop/confplanner/pkg/session" +) + +func Login(authService auth.Service, store session.Service) http.HandlerFunc { + return dto.WrapResponseFunc(func(w http.ResponseWriter, r *http.Request) error { + provider := authService.GetAuthProvider(r.PathValue("provider")) + + var user *sqlc.User + var err error + switch p := provider.(type) { + case *auth.BasicAuthProvider: + user, err = doBasicAuth(r, p) + case *auth.OIDCAuthProvider: + user, err = doOIDCAuthJourney(r, p) + default: + return &dto.ErrorResponse{ + Code: http.StatusBadRequest, + Message: "Unknown auth provider", + } + } + + if err != nil { + return err + } + + // TODO X-Forwarded-For + session, err := store.Create(user.ID, user.Username, r.RemoteAddr, r.UserAgent()) + if err != nil { + return err + } + + cookie := &http.Cookie{ + Name: "confplanner_session", + Value: session.Token, + Path: "/api", + } + http.SetCookie(w, cookie) + + return &dto.OkResponse{ + Code: http.StatusOK, + Data: &dto.LoginResponse{ + ID: user.ID, + Username: user.Username, + }, + } + }) +} + +func GetLoginOptions(authService auth.Service) http.HandlerFunc { + return dto.WrapResponseFunc(func(w http.ResponseWriter, r *http.Request) error { + var loginOptions []dto.LoginOption + + for _, identifier := range authService.GetAuthProviders() { + provider := authService.GetAuthProvider(identifier) + loginOptions = append(loginOptions, dto.LoginOption{ + Name: provider.Name(), + Identifier: identifier, + Type: provider.Type(), + }) + } + return &dto.OkResponse{ + Code: http.StatusOK, + Data: &dto.LoginOptionsResponse{ + Options: loginOptions, + }, + } + }) +} + +func doBasicAuth(r *http.Request, p *auth.BasicAuthProvider) (*sqlc.User, error) { + var request dto.LoginBasicRequest + if err := dto.ReadDto(r, &request); err != nil { + return nil, err + } + + user, err := p.Authenticate(request.Username, request.Password) + if err != nil { + return nil, err + } + + if user == nil { + return nil, &dto.ErrorResponse{ + Code: http.StatusBadRequest, + Message: "Username and password combination not found", + } + } + + return user, nil +} + +func doOIDCAuthJourney(r *http.Request, p *auth.OIDCAuthProvider) (*sqlc.User, error) { + var request dto.LoginOAuthCallbackRequest + if err := dto.ReadDto(r, &request); err != nil { + url, err := p.StartJourney(r.RemoteAddr, r.UserAgent()) + if err != nil { + return nil, &dto.ErrorResponse{ + Code: http.StatusInternalServerError, + Message: "Could not start OAuth journey", + } + } + + return nil, &dto.OkResponse{ + Code: http.StatusTemporaryRedirect, + Data: &dto.LoginOAuthOutboundResponse{ + URL: url, + }, + } + } + + user, err := p.CompleteJourney(r.Context(), request.Code, request.State, r.RemoteAddr, r.UserAgent()) + if err != nil { + if errors.Is(err, auth.ErrNotAuthorised) { + return nil, &dto.ErrorResponse{ + Code: http.StatusForbidden, + Message: "You are not authorised to use this service", + } + } else if errors.Is(err, auth.ErrInvalidState) { + return nil, &dto.ErrorResponse{ + Code: http.StatusBadRequest, + Message: "Invalid state", + } + } else if errors.Is(err, auth.ErrStateVerificationFailed) { + return nil, &dto.ErrorResponse{ + Code: http.StatusBadRequest, + Message: "State verification failed", + } + } else if errors.Is(err, auth.ErrUserSyncFailed) { + return nil, &dto.ErrorResponse{ + Code: http.StatusInternalServerError, + Message: "User sync failed", + } + } + slog.Error("error completing oidc journey", "error", err, "ip", r.RemoteAddr) + return nil, err + } + + return user, nil +} |
