diff options
| author | Leonardo Bishop <me@leonardobishop.com> | 2025-02-06 15:22:34 +0000 |
|---|---|---|
| committer | Leonardo Bishop <me@leonardobishop.com> | 2025-02-06 15:22:34 +0000 |
| commit | 2475f5a8b92ef0dd28e7af5f36d01b25243ed778 (patch) | |
| tree | 12f8931d241db4159f8d30f7bf2b648709a94166 /pkg | |
Initial commit
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/config/service.go | 99 | ||||
| -rw-r--r-- | pkg/store/service.go | 81 | ||||
| -rw-r--r-- | pkg/wireguard/service.go | 196 |
3 files changed, 376 insertions, 0 deletions
diff --git a/pkg/config/service.go b/pkg/config/service.go new file mode 100644 index 0000000..3c9a27e --- /dev/null +++ b/pkg/config/service.go @@ -0,0 +1,99 @@ +package config + +import ( + "fmt" + "net" + "os" + "regexp" + + validate "github.com/go-playground/validator/v10" + "gopkg.in/yaml.v3" +) + +type Config struct { + Hostname string `yaml:"host" validate:"required"` + TLS struct { + Enabled bool `yaml:"enabled"` + Cert string `yaml:"cert"` + Key string `yaml:"key"` + } `yaml:"tls"` + WireGuard struct { + Network string `yaml:"network" validate:"cidr,required"` + Port string `yaml:"port" validate:"required"` + InterfaceName string `yaml:"interfaceName" validate:"required"` + } `yaml:"wireGuard"` + ExpireAfter int `yaml:"expireAfter"` +} + +type Service interface { + InitialiseConfig(paths ...string) error + Config() *Config +} + +type service struct { + config *Config + validator *validate.Validate +} + +const InterfaceRegex = "^[a-zA-Z0-9_=+.-]{1,15}$" + +func NewService() Service { + return &service{ + validator: validate.New(validate.WithRequiredStructEnabled()), + } +} + +func (s *service) InitialiseConfig(paths ...string) error { + for _, p := range paths { + if _, err := os.Stat(p); err != nil { + continue + } + c := &Config{} + err := readConfig(p, c) + if err != nil { + return err + } + s.config = c + break + } + return nil +} + +func (s *service) Config() *Config { + return s.config +} + +func readConfig(configPath string, dst *Config) error { + config, err := os.ReadFile(configPath) + if err != nil { + return err + } + + if err := yaml.Unmarshal(config, dst); err != nil { + return err + } + return nil +} + +func (s *service) validateConfig(c *Config) error { + if err := s.validator.Struct(c); err != nil { + return err + } + + match, _ := regexp.MatchString(InterfaceRegex, c.WireGuard.InterfaceName) + if !match { + return fmt.Errorf("invalid interface name: %s", c.WireGuard.InterfaceName) + } + + ifaces, err := net.Interfaces() + if err != nil { + return fmt.Errorf("could not list network interfaces: %w", err) + } + for _, i := range ifaces { + if i.Name == c.WireGuard.InterfaceName { + return fmt.Errorf("an interface already exists with the name '%s'", i.Name) + } + } + + return nil +} diff --git a/pkg/store/service.go b/pkg/store/service.go new file mode 100644 index 0000000..3a042fb --- /dev/null +++ b/pkg/store/service.go @@ -0,0 +1,81 @@ +package store + +import ( + "strings" + "time" + + "github.com/LMBishop/gunnel/pkg/wireguard" + "github.com/tjarratt/babble" +) + +type ForwardingRule struct { + Slug string + Peer *wireguard.Peer + Port string + LastUsed time.Time +} + +type Service interface { + GetRuleBySlug(slug string) *ForwardingRule + NewForwardingRule(slug string, peer *wireguard.Peer, port string) *ForwardingRule + RemoveForwardingRule(slug string) + GetUnusedSlug() string + GetUnusedRulesSince(since time.Time) []*ForwardingRule +} + +type service struct { + forwardingRules map[string]*ForwardingRule +} + +func NewService() Service { + return &service{ + forwardingRules: make(map[string]*ForwardingRule), + } +} + +func (s *service) GetRuleBySlug(slug string) *ForwardingRule { + return s.forwardingRules[slug] +} + +func (s *service) NewForwardingRule(slug string, peer *wireguard.Peer, port string) *ForwardingRule { + if s.forwardingRules[slug] != nil { + return nil + } + + rule := &ForwardingRule{ + Slug: slug, + Peer: peer, + Port: port, + } + s.forwardingRules[slug] = rule + return rule +} + +func (s *service) GetUnusedSlug() string { + b := babble.NewBabbler() + b.Count = 3 + b.Separator = "-" + + for i := 0; i < 10; i++ { + slug := strings.Replace(strings.ToLower(b.Babble()), "'", "", -1) + if s.forwardingRules[slug] == nil { + return slug + } + } + + return "" +} + +func (s *service) GetUnusedRulesSince(since time.Time) []*ForwardingRule { + var rules []*ForwardingRule + for _, rule := range s.forwardingRules { + if rule.LastUsed.Before(since) { + rules = append(rules, rule) + } + } + return rules +} + +func (s *service) RemoveForwardingRule(slug string) { + delete(s.forwardingRules, slug) +} diff --git a/pkg/wireguard/service.go b/pkg/wireguard/service.go new file mode 100644 index 0000000..a0b67b7 --- /dev/null +++ b/pkg/wireguard/service.go @@ -0,0 +1,196 @@ +package wireguard + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" + "os/exec" + "strings" + + "golang.org/x/sys/unix" +) + +type Peer struct { + IPAddr net.IP + PublicKey string + PrivateKey string +} + +type Service interface { + // GenerateKey() (string, error) + Up(iface string, network string, listenPort string) (string, error) + Down() error + NewPeer() (*Peer, error) + RemovePeer(peer *Peer) error + PublicKey() string +} + +type service struct { + ipNet *net.IPNet + startIP uint32 + endIP uint32 + nextIP uint32 + + iface string + privateKey string + publicKey string +} + +func NewService() Service { + return &service{} +} + +func (s *service) Up(iface string, network string, listenPort string) (string, error) { + _, ipNet, err := net.ParseCIDR(network) + if err != nil { + return "", fmt.Errorf("cannot parse CIDR: %w", err) + } + + s.ipNet = ipNet + mask := binary.BigEndian.Uint32(ipNet.Mask) + s.startIP = binary.BigEndian.Uint32(ipNet.IP) + s.endIP = (s.startIP & mask) | (mask ^ 0xffffffff) + s.nextIP = s.startIP + + private, err := s.generateKey() + if err != nil { + return "", fmt.Errorf("cannot generate private key: %w", err) + } + public, err := s.getPublicKey(private) + if err != nil { + return "", fmt.Errorf("cannot get public key: %w", err) + } + + fd, err := memfile("wg", []byte(private)) + + addInterface := fmt.Sprintf("ip link add dev %s type wireguard", iface) + addAddress := fmt.Sprintf("ip addr add %s dev %s", network, iface) + setPrivateKey := fmt.Sprintf("wg set %s private-key /dev/fd/%d listen-port %s", iface, fd, listenPort) + ifaceUp := fmt.Sprintf("ip link set %s up", iface) + + cmd := exec.Command("bash", "-c", fmt.Sprintf("%s; %s; %s; %s", addInterface, addAddress, setPrivateKey, ifaceUp)) + _, err = cmd.Output() + if err != nil { + return "", fmt.Errorf("cannot bring WireGuard interface up: %w", err) + } + + s.iface = iface + s.privateKey = private + s.publicKey = public + + return public, nil +} + +func (s *service) Down() error { + cmd := exec.Command("bash", "-c", fmt.Sprintf("ip link delete dev %s", s.iface)) + _, err := cmd.Output() + if err != nil { + return fmt.Errorf("cannot bring WireGuard interface down: %w", err) + } + return nil +} + +func (s *service) NewPeer() (*Peer, error) { + private, err := s.generateKey() + if err != nil { + return nil, fmt.Errorf("cannot generate private key: %w", err) + } + public, err := s.getPublicKey(private) + if err != nil { + return nil, fmt.Errorf("cannot get public key: %w", err) + } + + ipAddress, err := s.getNextIP() + if err != nil { + return nil, fmt.Errorf("could not assign new IP: %w", err) + } + + cmd := exec.Command("bash", "-c", fmt.Sprintf("wg set %s peer %s allowed-ips %s/32", s.iface, public, ipAddress.String())) + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("cannot add peer: %s: %w", string(output), err) + } + + return &Peer{ + IPAddr: ipAddress, + PrivateKey: private, + PublicKey: public, + }, nil +} + +func (s *service) RemovePeer(peer *Peer) error { + cmd := exec.Command("bash", "-c", fmt.Sprintf("wg set %s peer %s remove", s.iface, peer.PublicKey)) + _, err := cmd.Output() + if err != nil { + return fmt.Errorf("cannot remove peer: %w", err) + } + return nil +} + +func (s *service) PublicKey() string { + return s.publicKey +} + +func (s *service) getNextIP() (net.IP, error) { + for { + if s.nextIP == s.endIP { + return net.IP{}, fmt.Errorf("no more IP addresses remaining") + } + + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, s.nextIP) + + if ip[3] != 0 && ip[3] != 255 { + s.nextIP++ + return ip, nil + } + + s.nextIP++ + } +} + +func (s *service) generateKey() (string, error) { + cmd := exec.Command("wg", "genkey") + stdout, err := cmd.Output() + if err != nil { + return "", err + } + return strings.Replace(string(stdout[:]), "\n", "", -1), nil +} + +func (s *service) getPublicKey(private string) (string, error) { + cmd := exec.Command("wg", "pubkey") + cmd.Stdin = bytes.NewBufferString(private) + stdout, err := cmd.Output() + if err != nil { + return "", err + } + return strings.Replace(string(stdout[:]), "\n", "", -1), nil +} + +func memfile(name string, b []byte) (int, error) { + fd, err := unix.MemfdCreate(name, 0) + if err != nil { + return 0, fmt.Errorf("MemfdCreate: %v", err) + } + + err = unix.Ftruncate(fd, int64(len(b))) + if err != nil { + return 0, fmt.Errorf("Ftruncate: %v", err) + } + + data, err := unix.Mmap(fd, 0, len(b), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) + if err != nil { + return 0, fmt.Errorf("Mmap: %v", err) + } + + copy(data, b) + + err = unix.Munmap(data) + if err != nil { + return 0, fmt.Errorf("Munmap: %v", err) + } + + return fd, nil +} |
