aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/user
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/user')
-rw-r--r--pkg/user/service.go121
1 files changed, 121 insertions, 0 deletions
diff --git a/pkg/user/service.go b/pkg/user/service.go
new file mode 100644
index 0000000..7784811
--- /dev/null
+++ b/pkg/user/service.go
@@ -0,0 +1,121 @@
+package user
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/LMBishop/confplanner/pkg/database/sqlc"
+ "github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgconn"
+ "github.com/jackc/pgx/v5/pgxpool"
+ "golang.org/x/crypto/bcrypt"
+)
+
+type Service interface {
+ CreateUser(username string, password string) (*sqlc.User, error)
+ GetUserByName(username string) (*sqlc.User, error)
+ GetUserByID(id int32) (*sqlc.User, error)
+ Authenticate(username string, password string) (*sqlc.User, error)
+}
+
+var (
+ ErrUserExists = errors.New("user already exists")
+ ErrUserNotFound = errors.New("user not found")
+ ErrNotAcceptingRegistrations = errors.New("not currently accepting registrations")
+)
+
+type service struct {
+ pool *pgxpool.Pool
+ acceptingRegistrations bool
+}
+
+func NewService(pool *pgxpool.Pool, acceptingRegistrations bool) Service {
+ return &service{
+ pool: pool,
+ acceptingRegistrations: acceptingRegistrations,
+ }
+}
+
+func (s *service) CreateUser(username string, password string) (*sqlc.User, error) {
+ if !s.acceptingRegistrations {
+ return nil, ErrNotAcceptingRegistrations
+ }
+
+ queries := sqlc.New(s.pool)
+
+ var passwordBytes = []byte(password)
+
+ hash, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
+ if err != nil {
+ return nil, fmt.Errorf("could not hash password: %w", err)
+ }
+
+ user, err := queries.CreateUser(context.Background(), sqlc.CreateUserParams{
+ Username: strings.ToLower(username),
+ Password: string(hash),
+ })
+ if err != nil {
+ var pgErr *pgconn.PgError
+ if errors.As(err, &pgErr) && pgErr.Code == "23505" {
+ return nil, ErrUserExists
+ }
+ return nil, fmt.Errorf("could not create user: %w", err)
+ }
+
+ return &user, nil
+}
+
+func (s *service) GetUserByName(username string) (*sqlc.User, error) {
+ queries := sqlc.New(s.pool)
+
+ user, err := queries.GetUserByName(context.Background(), username)
+ if err != nil {
+ if errors.Is(err, pgx.ErrNoRows) {
+ return nil, ErrUserNotFound
+ }
+ return nil, fmt.Errorf("could not fetch user: %w", err)
+ }
+
+ return &user, nil
+}
+
+func (s *service) GetUserByID(id int32) (*sqlc.User, error) {
+ queries := sqlc.New(s.pool)
+
+ user, err := queries.GetUserByID(context.Background(), id)
+ if err != nil {
+ if errors.Is(err, pgx.ErrNoRows) {
+ return nil, ErrUserNotFound
+ }
+ return nil, fmt.Errorf("could not fetch user: %w", err)
+ }
+
+ return &user, nil
+}
+
+func (s *service) Authenticate(username string, password string) (*sqlc.User, error) {
+ random, err := bcrypt.GenerateFromPassword([]byte("00000000"), bcrypt.DefaultCost)
+ if err != nil {
+ return nil, err
+ }
+
+ user, err := s.GetUserByName(username)
+ if err != nil {
+ if errors.Is(err, ErrUserNotFound) {
+ bcrypt.CompareHashAndPassword(random, []byte(password))
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ if err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
+ if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ return user, nil
+}