aboutsummaryrefslogtreecommitdiffstats
path: root/pkg
diff options
context:
space:
mode:
authorLeonardo Bishop <me@leonardobishop.com>2025-02-06 15:22:34 +0000
committerLeonardo Bishop <me@leonardobishop.com>2025-02-06 15:22:34 +0000
commit2475f5a8b92ef0dd28e7af5f36d01b25243ed778 (patch)
tree12f8931d241db4159f8d30f7bf2b648709a94166 /pkg
Initial commit
Diffstat (limited to 'pkg')
-rw-r--r--pkg/config/service.go99
-rw-r--r--pkg/store/service.go81
-rw-r--r--pkg/wireguard/service.go196
3 files changed, 376 insertions, 0 deletions
diff --git a/pkg/config/service.go b/pkg/config/service.go
new file mode 100644
index 0000000..3c9a27e
--- /dev/null
+++ b/pkg/config/service.go
@@ -0,0 +1,99 @@
+package config
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "regexp"
+
+ validate "github.com/go-playground/validator/v10"
+ "gopkg.in/yaml.v3"
+)
+
+type Config struct {
+ Hostname string `yaml:"host" validate:"required"`
+ TLS struct {
+ Enabled bool `yaml:"enabled"`
+ Cert string `yaml:"cert"`
+ Key string `yaml:"key"`
+ } `yaml:"tls"`
+ WireGuard struct {
+ Network string `yaml:"network" validate:"cidr,required"`
+ Port string `yaml:"port" validate:"required"`
+ InterfaceName string `yaml:"interfaceName" validate:"required"`
+ } `yaml:"wireGuard"`
+ ExpireAfter int `yaml:"expireAfter"`
+}
+
+type Service interface {
+ InitialiseConfig(paths ...string) error
+ Config() *Config
+}
+
+type service struct {
+ config *Config
+ validator *validate.Validate
+}
+
+const InterfaceRegex = "^[a-zA-Z0-9_=+.-]{1,15}$"
+
+func NewService() Service {
+ return &service{
+ validator: validate.New(validate.WithRequiredStructEnabled()),
+ }
+}
+
+func (s *service) InitialiseConfig(paths ...string) error {
+ for _, p := range paths {
+ if _, err := os.Stat(p); err != nil {
+ continue
+ }
+ c := &Config{}
+ err := readConfig(p, c)
+ if err != nil {
+ return err
+ }
+ s.config = c
+ break
+ }
+ return nil
+}
+
+func (s *service) Config() *Config {
+ return s.config
+}
+
+func readConfig(configPath string, dst *Config) error {
+ config, err := os.ReadFile(configPath)
+ if err != nil {
+ return err
+ }
+
+ if err := yaml.Unmarshal(config, dst); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *service) validateConfig(c *Config) error {
+ if err := s.validator.Struct(c); err != nil {
+ return err
+ }
+
+ match, _ := regexp.MatchString(InterfaceRegex, c.WireGuard.InterfaceName)
+ if !match {
+ return fmt.Errorf("invalid interface name: %s", c.WireGuard.InterfaceName)
+ }
+
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return fmt.Errorf("could not list network interfaces: %w", err)
+ }
+ for _, i := range ifaces {
+ if i.Name == c.WireGuard.InterfaceName {
+ return fmt.Errorf("an interface already exists with the name '%s'", i.Name)
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/store/service.go b/pkg/store/service.go
new file mode 100644
index 0000000..3a042fb
--- /dev/null
+++ b/pkg/store/service.go
@@ -0,0 +1,81 @@
+package store
+
+import (
+ "strings"
+ "time"
+
+ "github.com/LMBishop/gunnel/pkg/wireguard"
+ "github.com/tjarratt/babble"
+)
+
+type ForwardingRule struct {
+ Slug string
+ Peer *wireguard.Peer
+ Port string
+ LastUsed time.Time
+}
+
+type Service interface {
+ GetRuleBySlug(slug string) *ForwardingRule
+ NewForwardingRule(slug string, peer *wireguard.Peer, port string) *ForwardingRule
+ RemoveForwardingRule(slug string)
+ GetUnusedSlug() string
+ GetUnusedRulesSince(since time.Time) []*ForwardingRule
+}
+
+type service struct {
+ forwardingRules map[string]*ForwardingRule
+}
+
+func NewService() Service {
+ return &service{
+ forwardingRules: make(map[string]*ForwardingRule),
+ }
+}
+
+func (s *service) GetRuleBySlug(slug string) *ForwardingRule {
+ return s.forwardingRules[slug]
+}
+
+func (s *service) NewForwardingRule(slug string, peer *wireguard.Peer, port string) *ForwardingRule {
+ if s.forwardingRules[slug] != nil {
+ return nil
+ }
+
+ rule := &ForwardingRule{
+ Slug: slug,
+ Peer: peer,
+ Port: port,
+ }
+ s.forwardingRules[slug] = rule
+ return rule
+}
+
+func (s *service) GetUnusedSlug() string {
+ b := babble.NewBabbler()
+ b.Count = 3
+ b.Separator = "-"
+
+ for i := 0; i < 10; i++ {
+ slug := strings.Replace(strings.ToLower(b.Babble()), "'", "", -1)
+ if s.forwardingRules[slug] == nil {
+ return slug
+ }
+ }
+
+ return ""
+}
+
+func (s *service) GetUnusedRulesSince(since time.Time) []*ForwardingRule {
+ var rules []*ForwardingRule
+ for _, rule := range s.forwardingRules {
+ if rule.LastUsed.Before(since) {
+ rules = append(rules, rule)
+ }
+ }
+ return rules
+}
+
+func (s *service) RemoveForwardingRule(slug string) {
+ delete(s.forwardingRules, slug)
+}
diff --git a/pkg/wireguard/service.go b/pkg/wireguard/service.go
new file mode 100644
index 0000000..a0b67b7
--- /dev/null
+++ b/pkg/wireguard/service.go
@@ -0,0 +1,196 @@
+package wireguard
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "net"
+ "os/exec"
+ "strings"
+
+ "golang.org/x/sys/unix"
+)
+
+type Peer struct {
+ IPAddr net.IP
+ PublicKey string
+ PrivateKey string
+}
+
+type Service interface {
+ // GenerateKey() (string, error)
+ Up(iface string, network string, listenPort string) (string, error)
+ Down() error
+ NewPeer() (*Peer, error)
+ RemovePeer(peer *Peer) error
+ PublicKey() string
+}
+
+type service struct {
+ ipNet *net.IPNet
+ startIP uint32
+ endIP uint32
+ nextIP uint32
+
+ iface string
+ privateKey string
+ publicKey string
+}
+
+func NewService() Service {
+ return &service{}
+}
+
+func (s *service) Up(iface string, network string, listenPort string) (string, error) {
+ _, ipNet, err := net.ParseCIDR(network)
+ if err != nil {
+ return "", fmt.Errorf("cannot parse CIDR: %w", err)
+ }
+
+ s.ipNet = ipNet
+ mask := binary.BigEndian.Uint32(ipNet.Mask)
+ s.startIP = binary.BigEndian.Uint32(ipNet.IP)
+ s.endIP = (s.startIP & mask) | (mask ^ 0xffffffff)
+ s.nextIP = s.startIP
+
+ private, err := s.generateKey()
+ if err != nil {
+ return "", fmt.Errorf("cannot generate private key: %w", err)
+ }
+ public, err := s.getPublicKey(private)
+ if err != nil {
+ return "", fmt.Errorf("cannot get public key: %w", err)
+ }
+
+ fd, err := memfile("wg", []byte(private))
+
+ addInterface := fmt.Sprintf("ip link add dev %s type wireguard", iface)
+ addAddress := fmt.Sprintf("ip addr add %s dev %s", network, iface)
+ setPrivateKey := fmt.Sprintf("wg set %s private-key /dev/fd/%d listen-port %s", iface, fd, listenPort)
+ ifaceUp := fmt.Sprintf("ip link set %s up", iface)
+
+ cmd := exec.Command("bash", "-c", fmt.Sprintf("%s; %s; %s; %s", addInterface, addAddress, setPrivateKey, ifaceUp))
+ _, err = cmd.Output()
+ if err != nil {
+ return "", fmt.Errorf("cannot bring WireGuard interface up: %w", err)
+ }
+
+ s.iface = iface
+ s.privateKey = private
+ s.publicKey = public
+
+ return public, nil
+}
+
+func (s *service) Down() error {
+ cmd := exec.Command("bash", "-c", fmt.Sprintf("ip link delete dev %s", s.iface))
+ _, err := cmd.Output()
+ if err != nil {
+ return fmt.Errorf("cannot bring WireGuard interface down: %w", err)
+ }
+ return nil
+}
+
+func (s *service) NewPeer() (*Peer, error) {
+ private, err := s.generateKey()
+ if err != nil {
+ return nil, fmt.Errorf("cannot generate private key: %w", err)
+ }
+ public, err := s.getPublicKey(private)
+ if err != nil {
+ return nil, fmt.Errorf("cannot get public key: %w", err)
+ }
+
+ ipAddress, err := s.getNextIP()
+ if err != nil {
+ return nil, fmt.Errorf("could not assign new IP: %w", err)
+ }
+
+ cmd := exec.Command("bash", "-c", fmt.Sprintf("wg set %s peer %s allowed-ips %s/32", s.iface, public, ipAddress.String()))
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ return nil, fmt.Errorf("cannot add peer: %s: %w", string(output), err)
+ }
+
+ return &Peer{
+ IPAddr: ipAddress,
+ PrivateKey: private,
+ PublicKey: public,
+ }, nil
+}
+
+func (s *service) RemovePeer(peer *Peer) error {
+ cmd := exec.Command("bash", "-c", fmt.Sprintf("wg set %s peer %s remove", s.iface, peer.PublicKey))
+ _, err := cmd.Output()
+ if err != nil {
+ return fmt.Errorf("cannot remove peer: %w", err)
+ }
+ return nil
+}
+
+func (s *service) PublicKey() string {
+ return s.publicKey
+}
+
+func (s *service) getNextIP() (net.IP, error) {
+ for {
+ if s.nextIP == s.endIP {
+ return net.IP{}, fmt.Errorf("no more IP addresses remaining")
+ }
+
+ ip := make(net.IP, 4)
+ binary.BigEndian.PutUint32(ip, s.nextIP)
+
+ if ip[3] != 0 && ip[3] != 255 {
+ s.nextIP++
+ return ip, nil
+ }
+
+ s.nextIP++
+ }
+}
+
+func (s *service) generateKey() (string, error) {
+ cmd := exec.Command("wg", "genkey")
+ stdout, err := cmd.Output()
+ if err != nil {
+ return "", err
+ }
+ return strings.Replace(string(stdout[:]), "\n", "", -1), nil
+}
+
+func (s *service) getPublicKey(private string) (string, error) {
+ cmd := exec.Command("wg", "pubkey")
+ cmd.Stdin = bytes.NewBufferString(private)
+ stdout, err := cmd.Output()
+ if err != nil {
+ return "", err
+ }
+ return strings.Replace(string(stdout[:]), "\n", "", -1), nil
+}
+
+func memfile(name string, b []byte) (int, error) {
+ fd, err := unix.MemfdCreate(name, 0)
+ if err != nil {
+ return 0, fmt.Errorf("MemfdCreate: %v", err)
+ }
+
+ err = unix.Ftruncate(fd, int64(len(b)))
+ if err != nil {
+ return 0, fmt.Errorf("Ftruncate: %v", err)
+ }
+
+ data, err := unix.Mmap(fd, 0, len(b), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
+ if err != nil {
+ return 0, fmt.Errorf("Mmap: %v", err)
+ }
+
+ copy(data, b)
+
+ err = unix.Munmap(data)
+ if err != nil {
+ return 0, fmt.Errorf("Munmap: %v", err)
+ }
+
+ return fd, nil
+}