aboutsummaryrefslogtreecommitdiffstats
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'main.go')
-rw-r--r--main.go176
1 files changed, 176 insertions, 0 deletions
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..c31f56d
--- /dev/null
+++ b/main.go
@@ -0,0 +1,176 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "os"
+ "os/signal"
+ "os/user"
+ "time"
+
+ "github.com/LMBishop/gunnel/api/handlers"
+ "github.com/LMBishop/gunnel/pkg/config"
+ "github.com/LMBishop/gunnel/pkg/store"
+ "github.com/LMBishop/gunnel/pkg/wireguard"
+ "github.com/go-co-op/gocron/v2"
+ "github.com/gorilla/mux"
+)
+
+func main() {
+ u, err := user.Current()
+ if err != nil {
+ slog.Warn("cannot verify user is root", "error", err)
+ } else if u.Uid != "0" {
+ slog.Error("this program must be run as root to manage WireGuard")
+ os.Exit(1)
+ }
+
+ _, err = os.Stat("/usr/share/dict/words")
+ if err != nil {
+ slog.Error("could not find dictionary file at /usr/share/dict/words (you need to install a wordlist first)", "error", err)
+ os.Exit(1)
+ }
+
+ if err := run(); err != nil {
+ slog.Error("Unhandled error", "error", err)
+ os.Exit(1)
+ }
+}
+
+func run() error {
+ configService := config.NewService()
+ err := configService.InitialiseConfig("/etc/gunnel/config.yaml", "config.yaml")
+ if err != nil {
+ return fmt.Errorf("failed to load config: %w", err)
+ }
+
+ wireguardService := wireguard.NewService()
+ storeService := store.NewService()
+
+ c := configService.Config()
+
+ public, err := wireguardService.Up(c.WireGuard.InterfaceName, c.WireGuard.Network, c.WireGuard.Port)
+ if err != nil {
+ return fmt.Errorf("could not bring WireGuard interface up: %w", err)
+ }
+ slog.Info("interface up", "interface", c.WireGuard.InterfaceName, "publickey", public)
+
+ r := mux.NewRouter()
+ r.Host(c.Hostname).PathPrefix("/{port:[0-9]+}").HandlerFunc(handlers.NewPeer(storeService, wireguardService, configService))
+ r.Host(c.Hostname).Path("/").HandlerFunc(handlers.Index(configService))
+ r.Host(fmt.Sprintf("{subdomain}.%s", c.Hostname)).HandlerFunc(handlers.ReverseProxy(storeService))
+
+ srv := make([]*http.Server, 1)
+ if c.TLS.Enabled {
+ srv[0] = startHttpsServer(r, c.TLS.Cert, c.TLS.Key)
+ srv = append(srv, startHttpRedirect())
+ } else {
+ srv[0] = startHttpServer(r)
+ }
+
+ slog.Info("server started", "hostname", c.Hostname, "tls", c.TLS.Enabled)
+
+ s, err := gocron.NewScheduler()
+ if err != nil {
+ return fmt.Errorf("could not create scheduler: %w", err)
+ }
+
+ // todo fix (and move to service)
+ _, err = s.NewJob(gocron.CronJob("0 * * * *", false), gocron.NewTask(func() {
+ unusedRules := storeService.GetUnusedRulesSince(time.Now().Add(-time.Duration(c.ExpireAfter)))
+
+ if len(unusedRules) == 0 {
+ return
+ }
+
+ slog.Info("removing unused tunnels", "count", len(unusedRules))
+
+ for _, rule := range unusedRules {
+ wireguardService.RemovePeer(rule.Peer)
+ storeService.RemoveForwardingRule(rule.Slug)
+ }
+ }),
+ )
+
+ channel := make(chan os.Signal, 1)
+ signal.Notify(channel, os.Interrupt)
+
+ <-channel
+
+ err = s.Shutdown()
+ if err != nil {
+ slog.Error("scheduler shutdown", "error", err)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
+ defer cancel()
+ for _, s := range srv {
+ if err := s.Shutdown(ctx); err != nil {
+ slog.Error("server shutdown", "error", err)
+ }
+ }
+
+ err = wireguardService.Down()
+ if err != nil {
+ return fmt.Errorf("could not bring WireGuard interface down %w", err)
+ }
+
+ slog.Info("interface down", "interface", c.WireGuard.InterfaceName)
+
+ return nil
+}
+
+func startHttpServer(router *mux.Router) *http.Server {
+ srv := &http.Server{
+ Handler: router,
+ Addr: ":80",
+ WriteTimeout: 15 * time.Second,
+ ReadTimeout: 15 * time.Second,
+ }
+
+ go func() {
+ if err := srv.ListenAndServe(); err != nil {
+ slog.Error("http server", "error", err)
+ }
+ }()
+
+ return srv
+}
+
+func startHttpsServer(router *mux.Router, cert string, key string) *http.Server {
+ srv := &http.Server{
+ Handler: router,
+ Addr: ":443",
+ WriteTimeout: 15 * time.Second,
+ ReadTimeout: 15 * time.Second,
+ }
+
+ go func() {
+ if err := srv.ListenAndServeTLS(cert, key); err != nil {
+ slog.Error("https server", "error", err)
+ }
+ }()
+
+ return srv
+}
+
+func startHttpRedirect() *http.Server {
+ srv := &http.Server{
+ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Redirect(w, r, "https://"+r.Host+r.URL.String(), http.StatusMovedPermanently)
+ }),
+ Addr: ":80",
+ WriteTimeout: 15 * time.Second,
+ ReadTimeout: 15 * time.Second,
+ }
+
+ go func() {
+ if err := srv.ListenAndServe(); err != nil {
+ slog.Error("http redirect server", "error", err)
+ }
+ }()
+
+ return srv
+}