aboutsummaryrefslogtreecommitdiffstats
path: root/pkg/config/service.go
blob: 3c9a27ec44cac4a94dbeaae72a4427c7b0ad023a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package config

import (
	"fmt"
	"net"
	"os"
	"regexp"

	validate "github.com/go-playground/validator/v10"
	"gopkg.in/yaml.v3"
)

// Config is the application configuration as decoded from a YAML document.
//
// The yaml tags give the key each field is read from; the validate tags are
// rules for github.com/go-playground/validator and only take effect when a
// validator is run over the struct. Hostname (key "host") is required. The
// TLS group (key "tls") carries an enabled flag plus cert and key strings
// and has no validation rules attached.
type Config struct {
	Hostname string `yaml:"host" validate:"required"`
	TLS      struct {
		Enabled bool   `yaml:"enabled"`
		Cert    string `yaml:"cert"`
		Key     string `yaml:"key"`
	} `yaml:"tls"`
	// WireGuard groups the settings under the "wireGuard" key. All three
	// fields are required; Network must additionally satisfy the validator's
	// "cidr" rule. Port is kept as a string, not a number.
	WireGuard struct {
		Network       string `yaml:"network" validate:"cidr,required"`
		Port          string `yaml:"port" validate:"required"`
		InterfaceName string `yaml:"interfaceName" validate:"required"`
	} `yaml:"wireGuard"`
	// ExpireAfter is read from the "expireAfter" key and is optional.
	// NOTE(review): its unit and what it expires are not visible in this
	// file — confirm with the code that consumes it.
	ExpireAfter int `yaml:"expireAfter"`
}

// Service loads the application configuration from disk and gives access to
// the loaded value.
type Service interface {
	// InitialiseConfig loads the configuration from one of the given file
	// paths and reports any error encountered while doing so.
	InitialiseConfig(paths ...string) error
	// Config returns the configuration held by the service; it is nil until
	// a configuration has been loaded successfully.
	Config() *Config
}

// service is the package's Service implementation. It holds the most
// recently loaded configuration (nil until one is loaded) and the struct
// validator used by validateConfig.
type service struct {
	config    *Config
	validator *validate.Validate
}

// InterfaceRegex is the pattern a WireGuard interface name must match: 1 to
// 15 characters, each an ASCII letter, a digit, or one of "_", "=", "+", "."
// and "-".
// NOTE(review): the 15-character cap presumably mirrors the operating
// system's interface-name length limit — confirm.
const InterfaceRegex = "^[a-zA-Z0-9_=+.-]{1,15}$"

// NewService builds a Service whose validator has required-struct checking
// enabled. No configuration is loaded yet, so Config returns nil until
// InitialiseConfig succeeds.
func NewService() Service {
	v := validate.New(validate.WithRequiredStructEnabled())
	return &service{validator: v}
}

// InitialiseConfig loads the configuration from the first of paths that
// exists on disk, validates it, and stores it on the service.
//
// Candidates are tried in order; a path that cannot be stat'ed is skipped.
// The first path that can be stat'ed is decisive: if reading, decoding or
// validating it fails, the error is returned (wrapped with the offending
// path) and later candidates are not tried. The stored configuration is only
// replaced once a candidate has passed validation, so a failed call leaves
// any previously loaded configuration untouched.
//
// An error is also returned when none of the paths exists, because a
// configuration with its required fields missing can never be valid and
// callers would otherwise be handed a nil *Config.
func (s *service) InitialiseConfig(paths ...string) error {
	for _, p := range paths {
		// A missing or inaccessible candidate is expected: callers pass a
		// list of possible locations, so move on to the next one.
		if _, err := os.Stat(p); err != nil {
			continue
		}

		c := &Config{}
		if err := readConfig(p, c); err != nil {
			return fmt.Errorf("reading config %q: %w", p, err)
		}
		// Validate before publishing so Config never exposes a value that
		// violates the struct's validation rules.
		if err := s.validateConfig(c); err != nil {
			return fmt.Errorf("validating config %q: %w", p, err)
		}

		s.config = c
		return nil
	}

	return fmt.Errorf("no config file found in %q", paths)
}

// Config returns the configuration currently held by the service. The result
// is nil when no configuration has been stored yet.
func (s *service) Config() *Config { return s.config }

// readConfig reads the file at configPath and decodes its YAML content into
// dst. A failure to read the file or to decode it is returned unchanged.
func readConfig(configPath string, dst *Config) error {
	raw, err := os.ReadFile(configPath)
	if err != nil {
		return err
	}
	return yaml.Unmarshal(raw, dst)
}

func (s *service) validateConfig(c *Config) error {
	if err := s.validator.Struct(c); err != nil {
		return err
	}

	match, _ := regexp.MatchString(InterfaceRegex, c.WireGuard.InterfaceName)
	if !match {
		return fmt.Errorf("invalid interface name: %s", c.WireGuard.InterfaceName)
	}

	ifaces, err := net.Interfaces()
	if err != nil {
		return fmt.Errorf("could not list network interfaces: %w", err)
	}
	for _, i := range ifaces {
		if i.Name == c.WireGuard.InterfaceName {
			return fmt.Errorf("an interface already exists with the name '%s'", i.Name)
		}
	}

	return nil
}