Headscale: Sync users and update ACLs

Change-Id: Ie3488f6296567f5e2301476912d79de845299708
diff --git a/core/headscale/main.go b/core/headscale/main.go
index de0c2ff..da0ff6c 100644
--- a/core/headscale/main.go
+++ b/core/headscale/main.go
@@ -1,15 +1,22 @@
 package main
 
 import (
+	"bytes"
 	"encoding/json"
+	"errors"
 	"flag"
 	"fmt"
+	"io"
+	"io/ioutil"
 	"log"
 	"net"
 	"net/http"
 	"os"
 	"strings"
 	"text/template"
+	"time"
+
+	"golang.org/x/exp/rand"
 
 	"github.com/gorilla/mux"
 )
@@ -18,6 +25,8 @@
 var config = flag.String("config", "", "Path to headscale config")
 var acls = flag.String("acls", "", "Path to the headscale acls file")
 var ipSubnet = flag.String("ip-subnet", "10.1.0.0/24", "IP subnet of the private network")
+var fetchUsersAddr = flag.String("fetch-users-addr", "", "API endpoint to fetch user data")
+var self = flag.String("self", "", "Self address")
 
 // TODO(gio): make internal network cidr and proxy user configurable
 const defaultACLs = `
@@ -37,27 +46,59 @@
       "dst": ["{{ . }}:*", "private-network-proxy:0"],
     },
     {{- end }}
+    {{- range .users }}
+    { // Everyone has passthough access to private-network-proxy node
+      "action": "accept",
+      "src": ["{{ . }}"],
+      "dst": ["{{ . }}:*"],
+    },
+    {{- end }}
   ],
 }
 `
 
 type server struct {
-	port   int
-	client *client
+	port           int
+	client         *client
+	fetchUsersAddr string
+	self           string
+	aclsPath       string
+	aclsReloadPath string
+	cidrs          []string
 }
 
-func newServer(port int, client *client) *server {
+func newServer(port int, client *client, fetchUsersAddr, self, aclsPath string, cidrs []string) *server {
 	return &server{
 		port,
 		client,
+		fetchUsersAddr,
+		self,
+		aclsPath,
+		fmt.Sprintf("%s-reload", aclsPath), // TODO(gio): take from the flag
+		cidrs,
 	}
 }
 
 func (s *server) start() error {
+	f, err := os.Create(s.aclsReloadPath)
+	if err != nil {
+		return err
+	}
+	f.Close()
 	r := mux.NewRouter()
+	r.HandleFunc("/sync-users", s.handleSyncUsers).Methods(http.MethodGet)
 	r.HandleFunc("/user/{user}/preauthkey", s.createReusablePreAuthKey).Methods(http.MethodPost)
 	r.HandleFunc("/user", s.createUser).Methods(http.MethodPost)
 	r.HandleFunc("/routes/{id}/enable", s.enableRoute).Methods(http.MethodPost)
+	go func() {
+		rand.Seed(uint64(time.Now().UnixNano()))
+		s.syncUsers()
+		for {
+			delay := time.Duration(rand.Intn(60)+60) * time.Second
+			time.Sleep(delay)
+			s.syncUsers()
+		}
+	}()
 	return http.ListenAndServe(fmt.Sprintf(":%d", s.port), r)
 }
 
@@ -91,6 +132,49 @@
 	}
 }
 
+func (s *server) handleSyncUsers(_ http.ResponseWriter, _ *http.Request) {
+	go s.syncUsers()
+}
+
+type user struct {
+	Username string `json:"username"`
+}
+
+func (s *server) syncUsers() {
+	resp, err := http.Get(fmt.Sprintf("%s?selfAddress=%s/sync-users", s.fetchUsersAddr, s.self))
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+	users := []user{}
+	if err := json.NewDecoder(resp.Body).Decode(&users); err != nil {
+		fmt.Println(err)
+		return
+	}
+	var usernames []string
+	for _, u := range users {
+		usernames = append(usernames, u.Username)
+		if err := s.client.createUser(u.Username); err != nil && !errors.Is(err, ErrorAlreadyExists) {
+			fmt.Println(err)
+			continue
+		}
+	}
+	currentACLs, err := ioutil.ReadFile(s.aclsPath)
+	if err != nil {
+		fmt.Println(err)
+	}
+	newACLs, err := updateACLs(s.aclsPath, s.cidrs, usernames)
+	if err != nil {
+		fmt.Println(err)
+		panic(err)
+	}
+	if !bytes.Equal(currentACLs, newACLs) {
+		if err := os.Remove(s.aclsReloadPath); err != nil {
+			fmt.Println(err)
+		}
+	}
+}
+
 func (s *server) enableRoute(w http.ResponseWriter, r *http.Request) {
 	id, ok := mux.Vars(r)["id"]
 	if !ok {
@@ -103,22 +187,24 @@
 	}
 }
 
-func updateACLs(cidrs []string, aclsPath string) error {
+func updateACLs(aclsPath string, cidrs []string, users []string) ([]byte, error) {
 	tmpl, err := template.New("acls").Parse(defaultACLs)
 	if err != nil {
-		return err
+		return nil, err
 	}
 	out, err := os.Create(aclsPath)
 	if err != nil {
-		return err
+		return nil, err
 	}
 	defer out.Close()
-	tmpl.Execute(os.Stdout, map[string]any{
+	var ret bytes.Buffer
+	if err := tmpl.Execute(io.MultiWriter(out, &ret), map[string]any{
 		"cidrs": cidrs,
-	})
-	return tmpl.Execute(out, map[string]any{
-		"cidrs": cidrs,
-	})
+		"users": users,
+	}); err != nil {
+		return nil, err
+	}
+	return ret.Bytes(), nil
 }
 
 func main() {
@@ -131,8 +217,7 @@
 		}
 		cidrs = append(cidrs, cidr.String())
 	}
-	updateACLs(cidrs, *acls)
 	c := newClient(*config)
-	s := newServer(*port, c)
+	s := newServer(*port, c, *fetchUsersAddr, *self, *acls, cidrs)
 	log.Fatal(s.start())
 }