Headscale: Sync users and update ACLs
Change-Id: Ie3488f6296567f5e2301476912d79de845299708
diff --git a/core/headscale/main.go b/core/headscale/main.go
index de0c2ff..da0ff6c 100644
--- a/core/headscale/main.go
+++ b/core/headscale/main.go
@@ -1,15 +1,22 @@
package main
import (
+ "bytes"
"encoding/json"
+ "errors"
"flag"
"fmt"
+ "io"
+ "io/ioutil"
"log"
"net"
"net/http"
+	"net/url"
"os"
"strings"
+	"sync"
"text/template"
+ "time"
+
"github.com/gorilla/mux"
+ "golang.org/x/exp/rand"
)
@@ -18,6 +25,8 @@
var config = flag.String("config", "", "Path to headscale config")
var acls = flag.String("acls", "", "Path to the headscale acls file")
var ipSubnet = flag.String("ip-subnet", "10.1.0.0/24", "IP subnet of the private network")
+// fetchUsersAddr is the endpoint syncUsers queries for the list of users to
+// create and to grant per-user ACL rules to.
+var fetchUsersAddr = flag.String("fetch-users-addr", "", "API endpoint to fetch user data")
+// self is this server's own base address; "<self>/sync-users" is sent to
+// fetchUsersAddr as the selfAddress query parameter, presumably so that
+// service can call back to trigger a sync (TODO confirm).
+var self = flag.String("self", "", "Self address")
// TODO(gio): make internal network cidr and proxy user configurable
+// defaultACLs is the text/template for the ACL policy written by updateACLs,
+// rendered with "cidrs" and "users". The output uses comments and trailing
+// commas (HuJSON), and values are interpolated verbatim inside quoted strings.
const defaultACLs = `
@@ -37,27 +46,59 @@
"dst": ["{{ . }}:*", "private-network-proxy:0"],
},
{{- end }}
+ {{- range .users }}
+ { // Each user can reach every port on their own nodes
+ "action": "accept",
+ "src": ["{{ . }}"],
+ "dst": ["{{ . }}:*"],
+ },
+ {{- end }}
],
}
`
+// server serves the HTTP API registered in start and keeps headscale users and
+// the ACL file in sync with the users listed by fetchUsersAddr.
type server struct {
- port int
- client *client
+ port int
+ client *client
+	// fetchUsersAddr is the endpoint syncUsers fetches the user list from.
+ fetchUsersAddr string
+	// self is this server's base address, sent to fetchUsersAddr as part of
+	// the selfAddress query parameter.
+ self string
+	// aclsPath is the ACL policy file rendered by updateACLs.
+ aclsPath string
+	// aclsReloadPath is a marker file created by start and removed by
+	// syncUsers whenever the rendered ACLs change; presumably an external
+	// process watches it to reload headscale (TODO confirm).
+ aclsReloadPath string
+	// cidrs are rendered into the ACL template's "cidrs" range.
+ cidrs []string
}
-func newServer(port int, client *client) *server {
+// newServer returns a server configured with the given settings. The ACL
+// reload marker path is derived from aclsPath by appending "-reload".
+//
+// Fields are set by name so that adding or reordering struct fields cannot
+// silently assign a value to the wrong field.
+func newServer(port int, client *client, fetchUsersAddr, self, aclsPath string, cidrs []string) *server {
return &server{
- port,
- client,
+		port:           port,
+		client:         client,
+		fetchUsersAddr: fetchUsersAddr,
+		self:           self,
+		aclsPath:       aclsPath,
+		aclsReloadPath: aclsPath + "-reload", // TODO(gio): take from the flag
+		cidrs:          cidrs,
}
}
+// start creates the ACL reload marker file, registers the HTTP routes, launches
+// the background user-sync loop and then serves HTTP on s.port. It returns the
+// error from creating the marker, or otherwise whatever ListenAndServe returns.
+//
+// NOTE(review): main no longer renders the ACLs at startup, so the ACL file is
+// only (re)written by a successful syncUsers run.
func (s *server) start() error {
+	reloadMarker, err := os.Create(s.aclsReloadPath)
+	if err != nil {
+		return err
+	}
+	reloadMarker.Close()
r := mux.NewRouter()
+	r.HandleFunc("/sync-users", s.handleSyncUsers).Methods(http.MethodGet)
r.HandleFunc("/user/{user}/preauthkey", s.createReusablePreAuthKey).Methods(http.MethodPost)
r.HandleFunc("/user", s.createUser).Methods(http.MethodPost)
r.HandleFunc("/routes/{id}/enable", s.enableRoute).Methods(http.MethodPost)
+	go s.resyncUsersPeriodically()
return http.ListenAndServe(fmt.Sprintf(":%d", s.port), r)
}
+
+// resyncUsersPeriodically calls syncUsers right away and then again after each
+// random pause of 60 to 119 seconds. It never returns.
+func (s *server) resyncUsersPeriodically() {
+	rand.Seed(uint64(time.Now().UnixNano()))
+	for {
+		s.syncUsers()
+		pause := time.Duration(60+rand.Intn(60)) * time.Second
+		time.Sleep(pause)
+	}
+}
@@ -91,6 +132,49 @@
}
}
+// handleSyncUsers schedules a user sync in the background and returns at once,
+// so the caller gets an empty response with the default 200 status.
+func (s *server) handleSyncUsers(w http.ResponseWriter, r *http.Request) {
+	go func() {
+		s.syncUsers()
+	}()
+}
+
+// user is one element of the JSON array returned by fetchUsersAddr.
+type user struct{ Username string `json:"username"` }
+
+// syncMu serializes syncUsers. It runs both from the periodic loop started in
+// start and from every /sync-users request, and overlapping runs would race on
+// the ACL file and the reload marker.
+var syncMu sync.Mutex
+
+// usersHTTPClient fetches the user list. The timeout keeps a hung upstream from
+// blocking syncing (while holding syncMu) indefinitely.
+var usersHTTPClient = &http.Client{Timeout: 30 * time.Second}
+
+// fetchUsers requests the user list from s.fetchUsersAddr, passing
+// "<self>/sync-users" (URL-encoded) as the selfAddress query parameter, and
+// decodes the JSON array in the response body.
+func (s *server) fetchUsers() ([]user, error) {
+	endpoint, err := url.Parse(s.fetchUsersAddr)
+	if err != nil {
+		return nil, fmt.Errorf("parsing fetch-users address %q: %w", s.fetchUsersAddr, err)
+	}
+	query := endpoint.Query()
+	query.Set("selfAddress", s.self+"/sync-users")
+	endpoint.RawQuery = query.Encode()
+	resp, err := usersHTTPClient.Get(endpoint.String())
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("fetching users from %s: unexpected status %s", endpoint, resp.Status)
+	}
+	var users []user
+	if err := json.NewDecoder(resp.Body).Decode(&users); err != nil {
+		return nil, fmt.Errorf("decoding users from %s: %w", endpoint, err)
+	}
+	return users, nil
+}
+
+// syncUsers creates every user returned by fetchUsersAddr via s.client, then
+// re-renders the ACL file with one self-access rule per existing user. It
+// removes the reload marker (s.aclsReloadPath) when the rendered ACLs differ
+// from what was on disk. Failures are logged, and the next scheduled or
+// requested sync retries.
+func (s *server) syncUsers() {
+	syncMu.Lock()
+	defer syncMu.Unlock()
+
+	users, err := s.fetchUsers()
+	if err != nil {
+		log.Printf("syncing users: %v", err)
+		return
+	}
+	usernames := make([]string, 0, len(users))
+	for _, u := range users {
+		if err := s.client.createUser(u.Username); err != nil && !errors.Is(err, ErrorAlreadyExists) {
+			// Users that could not be created are left out of the ACLs.
+			log.Printf("syncing users: creating user %q: %v", u.Username, err)
+			continue
+		}
+		usernames = append(usernames, u.Username)
+	}
+	// A missing file (e.g. on first run) simply counts as "ACLs changed".
+	currentACLs, err := ioutil.ReadFile(s.aclsPath)
+	if err != nil && !errors.Is(err, os.ErrNotExist) {
+		log.Printf("syncing users: reading current ACLs: %v", err)
+	}
+	newACLs, err := updateACLs(s.aclsPath, s.cidrs, usernames)
+	if err != nil {
+		// Log instead of panicking: this runs in a background goroutine, where
+		// a panic would take down the whole server.
+		log.Printf("syncing users: writing ACLs: %v", err)
+		return
+	}
+	if bytes.Equal(currentACLs, newACLs) {
+		return
+	}
+	// Removing the marker signals that the ACLs changed; presumably an external
+	// watcher reloads headscale and recreates it (not visible here, TODO confirm).
+	if err := os.Remove(s.aclsReloadPath); err != nil && !errors.Is(err, os.ErrNotExist) {
+		log.Printf("syncing users: removing ACL reload marker: %v", err)
+	}
+}
+
func (s *server) enableRoute(w http.ResponseWriter, r *http.Request) {
id, ok := mux.Vars(r)["id"]
if !ok {
@@ -103,22 +187,24 @@
}
}
-func updateACLs(cidrs []string, aclsPath string) error {
+// updateACLs renders defaultACLs with the given CIDRs and users, writes the
+// result to aclsPath and returns the rendered bytes.
+//
+// The whole policy is rendered in memory first and then written to
+// aclsPath+".tmp", which is renamed over aclsPath. A rendering or write
+// failure therefore never leaves aclsPath truncated or half-written, and
+// readers never see a partially written policy. Usernames are JSON-escaped
+// because the template interpolates them inside quoted strings.
+func updateACLs(aclsPath string, cidrs []string, users []string) ([]byte, error) {
tmpl, err := template.New("acls").Parse(defaultACLs)
if err != nil {
- return err
+		return nil, err
}
- out, err := os.Create(aclsPath)
- if err != nil {
- return err
- }
- defer out.Close()
- tmpl.Execute(os.Stdout, map[string]any{
- "cidrs": cidrs,
- })
- return tmpl.Execute(out, map[string]any{
- "cidrs": cidrs,
- })
+	escapedUsers := make([]string, len(users))
+	for i, u := range users {
+		escapedUsers[i] = escapeACLString(u)
+	}
+	var rendered bytes.Buffer
+	if err := tmpl.Execute(&rendered, map[string]any{
+		"cidrs": cidrs,
+		"users": escapedUsers,
+	}); err != nil {
+		return nil, err
+	}
+	data := rendered.Bytes()
+	if err := replaceFile(aclsPath, data); err != nil {
+		return nil, err
+	}
+	return data, nil
}
+
+// escapeACLString escapes s for use inside a double-quoted JSON string, so
+// externally supplied names cannot break out of the quoted ACL entries.
+func escapeACLString(s string) string {
+	// json.Marshal never fails for a string; its output is s in double quotes.
+	quoted, _ := json.Marshal(s)
+	return string(quoted[1 : len(quoted)-1])
+}
+
+// replaceFile writes data to path+".tmp" and renames it over path. The
+// temporary file lives next to path (same filesystem), so on POSIX systems
+// the rename replaces path in a single step.
+func replaceFile(path string, data []byte) error {
+	tmpPath := path + ".tmp"
+	out, err := os.Create(tmpPath)
+	if err != nil {
+		return err
+	}
+	if _, err := io.Copy(out, bytes.NewReader(data)); err != nil {
+		out.Close()
+		os.Remove(tmpPath)
+		return err
+	}
+	// Close can report a deferred write error, so check it before renaming.
+	if err := out.Close(); err != nil {
+		os.Remove(tmpPath)
+		return err
+	}
+	if err := os.Rename(tmpPath, path); err != nil {
+		os.Remove(tmpPath)
+		return err
+	}
+	return nil
+}
func main() {
@@ -131,8 +217,7 @@
}
cidrs = append(cidrs, cidr.String())
}
- updateACLs(cidrs, *acls)
c := newClient(*config)
- s := newServer(*port, c)
+ s := newServer(*port, c, *fetchUsersAddr, *self, *acls, cidrs)
log.Fatal(s.start())
}