blob: 3696342022c9463b3cd714537f5c13ea0fc51d99 [file] [log] [blame]
package main
import (
	"bytes"
	crand "crypto/rand"
	"encoding/base64"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/giolekva/pcloud/core/installer"
	"github.com/giolekva/pcloud/core/installer/soft"
	"golang.org/x/crypto/ssh"
)
// Tunables for reservation secrets and for the range of ports this service
// may hand out.
const (
	// secretLength is the number of random bytes in a reservation secret,
	// before base64 encoding.
	secretLength = 20

	// start and end bound the candidate port range; newRepoClient iterates
	// it as the half-open interval [start, end).
	start, end = 49152, 65535
)
// Command-line flags. They are read by main after flag.Parse.
var (
	port                  = flag.Int("port", 8080, "Port to listen on")
	repoAddr              = flag.String("repo-addr", "", "Git repository address where Helm releases are stored")
	sshKey                = flag.String("ssh-key", "", "Path to SSH key used to connect with Git repository")
	ingressNginxPath      = flag.String("ingress-nginx-path", "", "Path to the ingress-nginx Helm release")
	minPreOpenPorts       = flag.Int("min-pre-open-ports", 5, "Minimum number of pre-open ports to keep in reserve")
	preOpenPortsBatchSize = flag.Int("pre-open-ports-batch-size", 10, "Number of new ports to open at a time")
)
// client is the port-forwarding backend the HTTP handlers delegate to.
type client interface {
	// ReservePort sets aside a port, optionally one served through the remote
	// TCP/UDP proxy, and returns it with the secret needed to claim it.
	ReservePort(remoteProxy bool) (int, string, error)
	// AddPortForwarding claims the reservation for port, authenticated by
	// secret, and forwards protocol traffic arriving on it to dest.
	AddPortForwarding(protocol string, port int, secret, dest string) error
	// RemovePortForwarding tears down a forwarding previously set up for port.
	RemovePortForwarding(protocol string, port int) error
	// ReleaseReservedPort gives back reservations that were never claimed.
	ReleaseReservedPort(port ...int)
}
// Reservation is the bookkeeping kept for a port that has been reserved but
// not yet claimed. It is persisted inside the state file (see state.Reserve).
// Field order matters: it is constructed with positional literals.
type Reservation struct {
	// Secret must match the secret passed to AddPortForwarding to claim the port.
	Secret string `json:"secret"`
	// IsRemoteProxy records that the port was taken from the proxy pre-open pool.
	IsRemoteProxy bool `json:"isRemoteProxy"`
}
// repoClient implements client by editing an ingress-nginx Helm release (and,
// when configured, a TCP/UDP proxy backend) stored in a Git repository. Its
// bookkeeping is persisted next to the release as "<path>-state.json".
// Methods take l before touching the mutable fields below.
type repoClient struct {
	// l serialises access to the client's state.
	l sync.Locker

	// Location of the release and the optional proxy configurator
	// (proxyCfg is nil when no proxy backend config was found).
	repo     soft.RepoIO
	path     string
	proxyCfg *installer.NginxProxyConfigurator

	// Source of reservation secrets.
	secretGenerator SecretGenerator

	// Pool sizing: refill a pool with a batch once it drops below the minimum.
	minPreOpenPorts       int
	preOpenPortsBatchSize int

	// Port bookkeeping.
	preOpenPorts      []int
	proxyPreOpenPorts []int
	availablePorts    []int
	blocklist         map[int]struct{}
	reserve           map[int]Reservation
}
// getProxyBackendConfigPath returns the path of the optional
// "proxy-backend-config.yaml" that sits in the same directory as the release
// at path, or "" when the repository has no such file.
//
// Only a "does not exist" error means the proxy backend is absent; any other
// failure to open the file is returned instead of being silently mistaken for
// a missing config.
func getProxyBackendConfigPath(repo soft.RepoIO, path string) (string, error) {
	cfgPath := filepath.Join(filepath.Dir(path), "proxy-backend-config.yaml")
	inp, err := repo.Reader(cfgPath)
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			// The proxy backend is optional.
			return "", nil
		}
		return "", fmt.Errorf("checking for %s: %w", cfgPath, err)
	}
	// Only the file's existence is checked here.
	defer inp.Close()
	return cfgPath, nil
}
// newRepoClient builds a repoClient for the ingress-nginx release stored at
// path inside repo.
//
// It restores previously persisted state (both pre-open pools, the blocklist
// and outstanding reservations), computes the ports still available, tops up
// the pre-open pools, and schedules the release of restored reservations after
// 30 minutes — the same window handleReserve gives new reservations, whose
// in-memory timers do not survive a restart.
func newRepoClient(
	repo soft.RepoIO,
	path string,
	minPreOpenPorts int,
	preOpenPortsBatchSize int,
	secretGenerator SecretGenerator,
) (client, error) {
	proxyCfg, err := getProxyBackendConfigPath(repo, path)
	if err != nil {
		return nil, err
	}
	var cnc *installer.NginxProxyConfigurator
	if proxyCfg != "" {
		cnc = &installer.NginxProxyConfigurator{
			Repo:        repo,
			ConfigPath:  proxyCfg,
			ServicePath: filepath.Join(filepath.Dir(proxyCfg), "proxy-backend-service.yaml"),
		}
	}
	ret := &repoClient{
		l:                     &sync.Mutex{},
		repo:                  repo,
		path:                  path,
		secretGenerator:       secretGenerator,
		proxyCfg:              cnc,
		minPreOpenPorts:       minPreOpenPorts,
		preOpenPortsBatchSize: preOpenPortsBatchSize,
		preOpenPorts:          []int{},
		blocklist:             map[int]struct{}{},
		reserve:               map[int]Reservation{},
		availablePorts:        []int{},
	}
	st, err := ret.readState(repo)
	switch {
	case err == nil:
		ret.preOpenPorts = st.PreOpenPorts
		// Restore the proxy pool as well; otherwise its ports remain
		// blocklisted but unusable and a fresh batch is opened on every restart.
		ret.proxyPreOpenPorts = st.ProxyPreOpenPorts
		// Keep the empty maps if the stored ones are null: both are written to later.
		if st.Blocklist != nil {
			ret.blocklist = st.Blocklist
		}
		if st.Reserve != nil {
			ret.reserve = st.Reserve
		}
	case !errors.Is(err, fs.ErrNotExist):
		return nil, err
	}
	// Every port in [start, end) not already handed out is a candidate.
	for i := start; i < end; i++ {
		if _, ok := ret.blocklist[i]; !ok {
			ret.availablePorts = append(ret.availablePorts, i)
		}
	}
	if err := ret.preOpenNewPorts(); err != nil {
		return nil, err
	}
	reservedPorts := make([]int, 0, len(ret.reserve))
	for p := range ret.reserve {
		reservedPorts = append(reservedPorts, p)
	}
	go func() {
		time.Sleep(30 * time.Minute)
		ret.ReleaseReservedPort(reservedPorts...)
	}()
	return ret, nil
}
// ReservePort takes the next port from a pre-open pool and records a
// reservation for it under a freshly generated secret. The proxy pool is used
// when remoteProxy is set (which requires a configured proxy), the direct pool
// otherwise.
//
// The port is removed from its pool only after the secret has been generated,
// so a generator failure leaves the pool untouched instead of losing the port.
//
// Returns the port and its secret, or (-1, "", err).
func (c *repoClient) ReservePort(remoteProxy bool) (int, string, error) {
	c.l.Lock()
	defer c.l.Unlock()
	pool := &c.preOpenPorts
	if remoteProxy {
		if c.proxyCfg == nil {
			return -1, "", fmt.Errorf("does not support TCP/UDP proxy")
		}
		if len(c.proxyPreOpenPorts) == 0 {
			return -1, "", fmt.Errorf("no proxy pre-open ports are available")
		}
		pool = &c.proxyPreOpenPorts
	} else if len(c.preOpenPorts) == 0 {
		return -1, "", fmt.Errorf("no pre-open ports are available")
	}
	secret, err := c.secretGenerator()
	if err != nil {
		return -1, "", err
	}
	port := (*pool)[0]
	*pool = (*pool)[1:]
	c.reserve[port] = Reservation{secret, remoteProxy}
	return port, secret, nil
}
// ReleaseReservedPort cancels the reservations for the given ports, returns
// each port to the pre-open pool it came from, and persists the new state.
// Ports that are not (or no longer) reserved are ignored. If none of them is
// reserved — the usual case once AddPortForwarding has claimed the port —
// the repository is not touched at all.
//
// This runs from background timers, so a failure to persist is logged rather
// than panicking and taking down the whole process. Any in-memory changes
// already made are written out by the next successful writeState.
func (c *repoClient) ReleaseReservedPort(port ...int) {
	if len(port) == 0 {
		return
	}
	c.l.Lock()
	defer c.l.Unlock()
	pending := false
	for _, p := range port {
		if _, ok := c.reserve[p]; ok {
			pending = true
			break
		}
	}
	if !pending {
		return
	}
	if _, err := c.repo.Do(func(fs soft.RepoFS) (string, error) {
		for _, p := range port {
			r, ok := c.reserve[p]
			if !ok {
				// Already claimed, already released, or listed twice.
				continue
			}
			delete(c.reserve, p)
			if r.IsRemoteProxy {
				c.proxyPreOpenPorts = append(c.proxyPreOpenPorts, p)
			} else {
				c.preOpenPorts = append(c.preOpenPorts, p)
			}
		}
		if err := c.writeState(fs); err != nil {
			return "", err
		}
		return fmt.Sprintf("Released port reservations: %+v", port), nil
	}); err != nil {
		log.Printf("releasing reserved ports %v: %v", port, err)
	}
}
// oldState is the legacy layout of the state file, in which a reservation was
// stored as its bare secret. readState falls back to it when the current
// layout does not decode.
type oldState struct {
	PreOpenPorts      []int            `json:"preOpenPorts"`
	ProxyPreOpenPorts []int            `json:"proxyPreOpenPorts"`
	Blocklist         map[int]struct{} `json:"blocklist"`
	// Reserve maps a reserved port to its secret.
	Reserve map[int]string `json:"reserve"`
}

// state is the current layout of "<path>-state.json".
// Field order matters: writeState builds it with a positional literal.
type state struct {
	// PreOpenPorts are opened and waiting to be reserved for direct forwarding.
	PreOpenPorts []int `json:"preOpenPorts"`
	// ProxyPreOpenPorts are opened and waiting to be reserved via the proxy.
	ProxyPreOpenPorts []int `json:"proxyPreOpenPorts"`
	// Blocklist holds every port moved into a pre-open pool; these ports are
	// excluded from the available range on startup.
	Blocklist map[int]struct{} `json:"blocklist"`
	// Reserve maps reserved-but-unclaimed ports to their reservation.
	Reserve map[int]Reservation `json:"reserve"`
}
// preOpenNewPorts tops up the pre-open pools. Whenever the direct pool (or, if
// a proxy is configured, the proxy pool) holds fewer than minPreOpenPorts
// ports, preOpenPortsBatchSize random ports are moved from availablePorts into
// it and blocklisted. The new state is then committed. When the controller
// service is an enabled NodePort service, the new ports are also added to its
// tcp and udp nodePorts maps.
//
// Availability is checked before anything is mutated, so running out of ports
// returns an error without leaving half-allocated, unpersisted ports behind.
func (c *repoClient) preOpenNewPorts() error {
	c.l.Lock()
	defer c.l.Unlock()
	direct := 0
	if len(c.preOpenPorts) < c.minPreOpenPorts && c.preOpenPortsBatchSize > 0 {
		direct = c.preOpenPortsBatchSize
	}
	proxy := 0
	if c.proxyCfg != nil && len(c.proxyPreOpenPorts) < c.minPreOpenPorts && c.preOpenPortsBatchSize > 0 {
		proxy = c.preOpenPortsBatchSize
	}
	if direct+proxy > len(c.availablePorts) {
		return fmt.Errorf("could not open new port")
	}
	// Direct ports are picked first, then proxy ports.
	ports := c.takeAvailablePorts(direct + proxy)
	c.preOpenPorts = append(c.preOpenPorts, ports[:direct]...)
	c.proxyPreOpenPorts = append(c.proxyPreOpenPorts, ports[direct:]...)
	if len(ports) == 0 {
		return nil
	}
	_, err := c.repo.Do(func(fs soft.RepoFS) (string, error) {
		if err := c.writeState(fs); err != nil {
			return "", err
		}
		rel, err := c.readRelease(fs)
		if err != nil {
			return "", err
		}
		svcType := ""
		svcEnabled, err := extractBool(rel, "spec.values.controller.service.enabled")
		if err != nil {
			return "", err
		}
		if svcEnabled {
			svcType, err = extractString(rel, "spec.values.controller.service.type")
			if err != nil {
				return "", err
			}
		}
		// A NodePort service must also expose each new port on the nodes,
		// for both protocols.
		if svcType == "NodePort" {
			tcp, err := extractPorts(rel, "spec.values.controller.service.nodePorts.tcp")
			if err != nil {
				return "", err
			}
			udp, err := extractPorts(rel, "spec.values.controller.service.nodePorts.udp")
			if err != nil {
				return "", err
			}
			for _, p := range ports {
				ps := strconv.Itoa(p)
				tcp[ps] = p
				udp[ps] = p
			}
			if err := c.writeRelease(fs, rel); err != nil {
				return "", err
			}
		}
		fmt.Printf("Pre opened new ports: %+v\n", ports)
		return "preopen new ports", nil
	})
	return err
}

// takeAvailablePorts removes n randomly chosen ports from availablePorts,
// blocklists them and returns them in the order picked. The caller must hold
// c.l and must have checked that at least n ports are available.
func (c *repoClient) takeAvailablePorts(n int) []int {
	taken := make([]int, 0, n)
	for ; n > 0; n-- {
		i := rand.Intn(len(c.availablePorts))
		last := len(c.availablePorts) - 1
		p := c.availablePorts[i]
		// Swap-remove: the order of availablePorts does not matter.
		c.availablePorts[i] = c.availablePorts[last]
		c.availablePorts = c.availablePorts[:last]
		c.blocklist[p] = struct{}{}
		taken = append(taken, p)
	}
	return taken
}
// AddPortForwarding claims the reservation for port, which must match secret.
// It then forwards protocol traffic ("tcp" or "udp", case-insensitive) on that
// port to dest.
//
// For a remote-proxy reservation the forwarding is first registered with the
// proxy configurator, and the controller's extraPorts entry maps the port to
// itself. Otherwise the release's top-level tcp/udp map points port at dest.
//
// The reservation is only consumed after the request has been validated (and,
// for proxy ports, after the proxy accepted it). A rejected request therefore
// keeps its reservation, so it can be retried or released by its timer
// instead of losing the port.
//
// On the way out the pre-open pools are topped up. A failure there is logged
// rather than panicking, because the forwarding itself has already been
// handled.
func (c *repoClient) AddPortForwarding(protocol string, port int, secret, dest string) error {
	protocol = strings.ToLower(protocol)
	defer func() {
		// Runs after the deferred Unlock below, so it can take c.l itself.
		if err := c.preOpenNewPorts(); err != nil {
			log.Printf("pre-opening new ports: %v", err)
		}
	}()
	c.l.Lock()
	defer c.l.Unlock()
	r, ok := c.reserve[port]
	if !ok || r.Secret != secret {
		return fmt.Errorf("wrong secret")
	}
	if r.IsRemoteProxy && c.proxyCfg == nil {
		return fmt.Errorf("does not support TCP/UDP proxy")
	}
	if protocol != "tcp" && protocol != "udp" {
		return fmt.Errorf("unknown protocol: %s", protocol)
	}
	if r.IsRemoteProxy {
		var err error
		if protocol == "tcp" {
			_, err = c.proxyCfg.AddProxy(port, dest, installer.ProtocolTCP)
		} else {
			_, err = c.proxyCfg.AddProxy(port, dest, installer.ProtocolUDP)
		}
		if err != nil {
			return err
		}
	}
	delete(c.reserve, port)
	_, err := c.repo.Do(func(fs soft.RepoFS) (string, error) {
		if err := c.writeState(fs); err != nil {
			return "", err
		}
		rel, err := c.readRelease(fs)
		if err != nil {
			return "", err
		}
		portStr := strconv.Itoa(port)
		base := "spec.values"
		if r.IsRemoteProxy {
			base = "spec.values.controller.service.extraPorts"
			dest = portStr
		}
		// protocol was validated above, so it is exactly "tcp" or "udp".
		portMap, err := extractPorts(rel, fmt.Sprintf("%s.%s", base, protocol))
		if err != nil {
			return "", err
		}
		portMap[portStr] = dest
		if err := c.writeRelease(fs, rel); err != nil {
			return "", err
		}
		return fmt.Sprintf("ingress: port %s map %d %s", protocol, port, dest), nil
	})
	return err
}
// RemovePortForwarding deletes the mapping for port from the release's
// top-level tcp or udp map (protocol is case-insensitive). When the controller
// service is an enabled NodePort service, it also drops the port from both the
// tcp and udp nodePorts maps. An unknown protocol is reported as an error
// instead of panicking.
//
// NOTE(review): forwardings added through the remote proxy live under
// controller.service.extraPorts and in the proxy config, neither of which is
// touched here, and the port stays blocklisted. Confirm this is intended.
func (c *repoClient) RemovePortForwarding(protocol string, port int) error {
	protocol = strings.ToLower(protocol)
	if protocol != "tcp" && protocol != "udp" {
		return fmt.Errorf("unknown protocol: %s", protocol)
	}
	c.l.Lock()
	defer c.l.Unlock()
	_, err := c.repo.Do(func(fs soft.RepoFS) (string, error) {
		rel, err := c.readRelease(fs)
		if err != nil {
			return "", err
		}
		fwd, err := extractPorts(rel, "spec.values."+protocol)
		if err != nil {
			return "", err
		}
		if err := removePort(fwd, port); err != nil {
			return "", err
		}
		svcType := ""
		svcEnabled, err := extractBool(rel, "spec.values.controller.service.enabled")
		if err != nil {
			return "", err
		}
		if svcEnabled {
			svcType, err = extractString(rel, "spec.values.controller.service.type")
			if err != nil {
				return "", err
			}
		}
		if svcType == "NodePort" {
			svcTCP, err := extractPorts(rel, "spec.values.controller.service.nodePorts.tcp")
			if err != nil {
				return "", err
			}
			svcUDP, err := extractPorts(rel, "spec.values.controller.service.nodePorts.udp")
			if err != nil {
				return "", err
			}
			if err := removePort(svcTCP, port); err != nil {
				return "", err
			}
			if err := removePort(svcUDP, port); err != nil {
				return "", err
			}
		}
		if err := c.writeRelease(fs, rel); err != nil {
			return "", err
		}
		return fmt.Sprintf("ingress: remove %s port map %d", protocol, port), nil
	})
	return err
}
// readState loads the persisted bookkeeping from "<path>-state.json" in fs.
//
// The current layout is tried first. If it does not decode, the file is parsed
// in the legacy oldState layout (reservations stored as bare secrets) and
// converted, with every legacy reservation marked as a direct, non-proxy one.
// Nil maps in the result are replaced by empty ones because callers write to
// them. Errors from opening the file are returned unchanged; newRepoClient
// relies on that to recognise a missing state file.
func (c *repoClient) readState(fs soft.RepoFS) (state, error) {
	r, err := fs.Reader(fmt.Sprintf("%s-state.json", c.path))
	if err != nil {
		return state{}, err
	}
	defer r.Close()
	buf, err := io.ReadAll(r)
	if err != nil {
		return state{}, err
	}
	var ret state
	if err := json.NewDecoder(bytes.NewReader(buf)).Decode(&ret); err != nil {
		var old oldState
		if err := json.NewDecoder(bytes.NewReader(buf)).Decode(&old); err != nil {
			return state{}, err
		}
		ret = state{
			PreOpenPorts:      old.PreOpenPorts,
			ProxyPreOpenPorts: []int{},
			Blocklist:         old.Blocklist,
			Reserve:           make(map[int]Reservation, len(old.Reserve)),
		}
		for port, secret := range old.Reserve {
			ret.Reserve[port] = Reservation{secret, false}
		}
	}
	if ret.Blocklist == nil {
		ret.Blocklist = map[int]struct{}{}
	}
	if ret.Reserve == nil {
		ret.Reserve = map[int]Reservation{}
	}
	return ret, nil
}
// writeState serialises the client's current bookkeeping as JSON to
// "<path>-state.json" in fs. The writer is closed explicitly so that an error
// reported by Close is returned instead of being discarded by a defer.
func (c *repoClient) writeState(fs soft.RepoFS) error {
	w, err := fs.Writer(fmt.Sprintf("%s-state.json", c.path))
	if err != nil {
		return err
	}
	st := state{
		PreOpenPorts:      c.preOpenPorts,
		ProxyPreOpenPorts: c.proxyPreOpenPorts,
		Blocklist:         c.blocklist,
		Reserve:           c.reserve,
	}
	if err := json.NewEncoder(w).Encode(st); err != nil {
		// The encode error is the more informative one; still release the writer.
		w.Close()
		return err
	}
	return w.Close()
}
// readRelease parses the release YAML at c.path in fs into a generic map.
// On failure it returns a nil map together with the error.
func (c *repoClient) readRelease(fs soft.RepoFS) (map[string]any, error) {
	var rel = make(map[string]any)
	err := soft.ReadYaml(fs, c.path, &rel)
	if err != nil {
		rel = nil
	}
	return rel, err
}
// writeRelease serialises rel as YAML back to c.path in fs.
func (c *repoClient) writeRelease(fs soft.RepoFS, rel map[string]any) error {
	if err := soft.WriteYaml(fs, c.path, rel); err != nil {
		return err
	}
	return nil
}
// server exposes a client over HTTP.
// Field order matters: newServer builds it with a positional literal.
type server struct {
	s      *http.Server   // the listening HTTP server
	r      *http.ServeMux // routes, registered by Start
	client client         // backend the handlers delegate to
}
// newServer returns a server that will listen on the given TCP port and
// delegate requests to client. Routes are registered later by Start.
//
// ReadHeaderTimeout bounds how long a connection may take to send its request
// headers, so slow or idle clients cannot hold connections open indefinitely
// (Slowloris). No overall read/write timeout is set because handlers commit
// to a Git repository and can legitimately take a while.
func newServer(port int, client client) *server {
	r := http.NewServeMux()
	s := &http.Server{
		Addr:              fmt.Sprintf(":%d", port),
		Handler:           r,
		ReadHeaderTimeout: 10 * time.Second,
	}
	return &server{s, r, client}
}
// Start registers the API routes and serves HTTP until the server stops.
// A shutdown via Close (http.ErrServerClosed) is not reported as an error.
func (s *server) Start() error {
	routes := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{"/api/reserve", s.handleReserve},
		{"/api/allocate", s.handleAllocate},
		{"/api/remove", s.handleRemove},
	}
	for _, rt := range routes {
		s.r.HandleFunc(rt.pattern, rt.handler)
	}
	err := s.s.ListenAndServe()
	if err == nil || err == http.ErrServerClosed {
		return nil
	}
	return err
}
// Close immediately closes the underlying HTTP server's listeners and
// connections, which makes a running Start return.
func (s *server) Close() error {
	srv := s.s
	return srv.Close()
}
// Request bodies accepted by /api/allocate and /api/remove.
type (
	allocateReq struct {
		Protocol      string `json:"protocol"`
		SourcePort    int    `json:"sourcePort"`
		TargetService string `json:"targetService"`
		TargetPort    int    `json:"targetPort"`
		Secret        string `json:"secret"`
	}
	removeReq struct {
		Protocol      string `json:"protocol"`
		SourcePort    int    `json:"sourcePort"`
		TargetService string `json:"targetService"`
		TargetPort    int    `json:"targetPort"`
	}
)

// extractAllocateReq decodes an allocateReq from r, lower-cases its protocol
// and rejects anything other than "tcp" or "udp".
func extractAllocateReq(r io.Reader) (allocateReq, error) {
	var req allocateReq
	if err := json.NewDecoder(r).Decode(&req); err != nil {
		return allocateReq{}, err
	}
	switch req.Protocol = strings.ToLower(req.Protocol); req.Protocol {
	case "tcp", "udp":
		return req, nil
	default:
		return allocateReq{}, fmt.Errorf("Unexpected protocol %s", req.Protocol)
	}
}

// extractRemoveReq decodes a removeReq from r, lower-cases its protocol and
// rejects anything other than "tcp" or "udp".
func extractRemoveReq(r io.Reader) (removeReq, error) {
	var req removeReq
	if err := json.NewDecoder(r).Decode(&req); err != nil {
		return removeReq{}, err
	}
	switch req.Protocol = strings.ToLower(req.Protocol); req.Protocol {
	case "tcp", "udp":
		return req, nil
	default:
		return removeReq{}, fmt.Errorf("Unexpected protocol %s", req.Protocol)
	}
}
// reserveResp is the JSON reply to a successful /api/reserve request.
// Field order matters: handleReserve builds it with a positional literal.
type reserveResp struct {
	Port   int    `json:"port"`   // the reserved port
	Secret string `json:"secret"` // must be sent back in the /api/allocate request to claim the port
}
// extractField walks data along a dot-separated path such as
// "spec.values.tcp" and returns the value found at its end. Every value
// visited before the last key must be a map[string]any.
func extractField(data map[string]any, path string) (any, error) {
	var cur any = data
	keys := strings.Split(path, ".")
	for idx := 0; idx < len(keys); idx++ {
		key := keys[idx]
		m, isMap := cur.(map[string]any)
		if !isMap {
			return nil, fmt.Errorf("expected map, %s", key)
		}
		next, found := m[key]
		if !found {
			return nil, fmt.Errorf("%s not found", key)
		}
		cur = next
	}
	return cur, nil
}
// extractPorts returns the map stored at path in data. The map is shared with
// data, so callers that modify it are modifying data in place.
func extractPorts(data map[string]any, path string) (map[string]any, error) {
	v, err := extractField(data, path)
	if err != nil {
		return nil, err
	}
	if m, ok := v.(map[string]any); ok {
		return m, nil
	}
	return nil, fmt.Errorf("expected map")
}
// extractString returns the string stored at path in data.
func extractString(data map[string]any, path string) (string, error) {
	v, err := extractField(data, path)
	if err != nil {
		return "", err
	}
	if s, ok := v.(string); ok {
		return s, nil
	}
	return "", fmt.Errorf("expected string")
}
// extractBool returns the boolean stored at path in data.
func extractBool(data map[string]any, path string) (bool, error) {
	v, err := extractField(data, path)
	if err != nil {
		return false, err
	}
	b, isBool := v.(bool)
	if isBool {
		return b, nil
	}
	return false, fmt.Errorf("expected boolean")
}
func addPort(pm map[string]any, sourcePort int, targetService string, targetPort int) error {
sourcePortStr := strconv.Itoa(sourcePort)
if _, ok := pm[sourcePortStr]; ok || sourcePort == 80 || sourcePort == 443 || sourcePort == 22 {
return fmt.Errorf("port %d is already taken", sourcePort)
}
pm[sourcePortStr] = fmt.Sprintf("%s:%d", targetService, targetPort)
return nil
}
// removePort deletes port from pm and fails if pm has no entry for it.
func removePort(pm map[string]any, port int) error {
	key := strconv.Itoa(port)
	if _, present := pm[key]; present {
		delete(pm, key)
		return nil
	}
	return fmt.Errorf("port %d is not open to remove", port)
}
// handleAllocate serves POST /api/allocate. It decodes an allocateReq and asks
// the backend to forward SourcePort to "TargetService:TargetPort", authorised
// by Secret. Invalid requests get 400 and backend failures get 500. On success
// nothing is written, so the client sees an empty 200 response.
func (s *server) handleAllocate(w http.ResponseWriter, r *http.Request) {
	fail := func(err error, status int) {
		fmt.Println(err.Error())
		http.Error(w, err.Error(), status)
	}
	if r.Method != http.MethodPost {
		http.Error(w, "only post method is supported", http.StatusBadRequest)
		return
	}
	req, err := extractAllocateReq(r.Body)
	if err != nil {
		fail(err, http.StatusBadRequest)
		return
	}
	dest := fmt.Sprintf("%s:%d", req.TargetService, req.TargetPort)
	if err := s.client.AddPortForwarding(req.Protocol, req.SourcePort, req.Secret, dest); err != nil {
		fail(err, http.StatusInternalServerError)
	}
}
// reserveReq is the body accepted by POST /api/reserve.
type reserveReq struct {
	RemoteProxy bool `json:"remoteProxy"`
}

// handleReserve serves POST /api/reserve. It reserves a port (through the
// remote proxy when requested) and replies with the port and its secret as
// JSON. A reservation still unclaimed after 30 minutes is released.
//
// A body that is not valid JSON is a client error and gets 400, matching the
// other handlers. Encoding reserveResp cannot fail to marshal, so an Encode
// error means writing the reply failed after the status line went out. Such
// an error is only logged, because sending another error status is no longer
// possible.
func (s *server) handleReserve(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "only post method is supported", http.StatusBadRequest)
		return
	}
	var req reserveReq
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		fmt.Println(err.Error())
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	port, secret, err := s.client.ReservePort(req.RemoteProxy)
	if err != nil {
		fmt.Println(err.Error())
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	go func() {
		time.Sleep(30 * time.Minute)
		s.client.ReleaseReservedPort(port)
	}()
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(reserveResp{port, secret}); err != nil {
		fmt.Println(err.Error())
	}
}
// handleRemove serves POST /api/remove. It decodes a removeReq and removes the
// forwarding for SourcePort. Invalid requests get 400 and backend failures
// get 500; success is an empty 200 response.
func (s *server) handleRemove(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "only post method is supported", http.StatusBadRequest)
		return
	}
	req, err := extractRemoveReq(r.Body)
	status := http.StatusBadRequest
	if err == nil {
		err = s.client.RemovePortForwarding(req.Protocol, req.SourcePort)
		status = http.StatusInternalServerError
	}
	if err != nil {
		fmt.Println(err.Error())
		http.Error(w, err.Error(), status)
	}
}
// createRepoClient clones the Git repository at addr, authenticating with the
// SSH private key stored at keyPath, and returns a soft.RepoIO for it. Errors
// are wrapped with the step that failed so startup failures are diagnosable.
//
// Locals are named so they do not shadow the package-level sshKey and repoAddr
// flags.
//
// TODO(gio): deduplicate
func createRepoClient(addr string, keyPath string) (soft.RepoIO, error) {
	keyBytes, err := os.ReadFile(keyPath)
	if err != nil {
		return nil, fmt.Errorf("reading SSH key: %w", err)
	}
	signer, err := ssh.ParsePrivateKey(keyBytes)
	if err != nil {
		return nil, fmt.Errorf("parsing SSH key %q: %w", keyPath, err)
	}
	parsedAddr, err := soft.ParseRepositoryAddress(addr)
	if err != nil {
		return nil, fmt.Errorf("parsing repository address %q: %w", addr, err)
	}
	repo, err := soft.CloneRepository(parsedAddr, signer)
	if err != nil {
		return nil, fmt.Errorf("cloning repository %q: %w", addr, err)
	}
	return soft.NewRepoIO(repo, signer)
}
// SecretGenerator produces the secret that guards a port reservation.
type SecretGenerator func() (string, error)

// generateSecret returns secretLength random bytes, base64 encoded.
//
// The secret is the only thing that authorises a caller to claim a reserved
// port, so it must be unpredictable. It therefore comes from crypto/rand.
// math/rand, which this file uses for port selection, is not
// cryptographically secure.
func generateSecret() (string, error) {
	b := make([]byte, secretLength)
	if _, err := crand.Read(b); err != nil {
		return "", fmt.Errorf("error generating secret: %w", err)
	}
	return base64.StdEncoding.EncodeToString(b), nil
}
// main parses the flags, connects to the Git repository holding the
// ingress-nginx release, and serves the port-allocation API. Any setup or
// serving error is fatal.
func main() {
	flag.Parse()
	repo, err := createRepoClient(*repoAddr, *sshKey)
	if err != nil {
		log.Fatal(err)
	}
	c, err := newRepoClient(repo, *ingressNginxPath, *minPreOpenPorts, *preOpenPortsBatchSize, generateSecret)
	if err != nil {
		log.Fatal(err)
	}
	log.Fatal(newServer(*port, c).Start())
}