Engine interface with a WireGuard-specific implementation.
diff --git a/core/vpn/engine/engine.go b/core/vpn/engine/engine.go
new file mode 100644
index 0000000..68900b0
--- /dev/null
+++ b/core/vpn/engine/engine.go
@@ -0,0 +1,24 @@
+package engine
+
+import (
+ "inet.af/netaddr"
+ "tailscale.com/ipn/ipnstate"
+
+ "github.com/giolekva/pcloud/core/vpn/types"
+)
+
+// Abstracts away communication with host OS needed to setup netfwork interfaces
+// for VPN.
+type Engine interface {
+ // Reconfigures local network interfaces in accordance to the given VPN
+ // layout.
+ Configure(netMap *types.NetworkMap) error
+ // Unique public discovery key of the current device.
+ DiscoKey() types.DiscoKey
+ // Unique public endpoint of the given device.
+ // Communication between devices happen throughs such endpoints
+ // instead of IP addresses.
+ DiscoEndpoint() string
+ // Sends ping to the given IP address and invokes callback with results.
+ Ping(ip netaddr.IP, cb func(*ipnstate.PingResult))
+}
diff --git a/core/vpn/engine/wireguard.go b/core/vpn/engine/wireguard.go
new file mode 100644
index 0000000..3dbf717
--- /dev/null
+++ b/core/vpn/engine/wireguard.go
@@ -0,0 +1,168 @@
+package engine
+
+import (
+ "encoding/hex"
+ "fmt"
+ "log"
+
+ "github.com/giolekva/pcloud/core/vpn/types"
+
+ "github.com/tailscale/wireguard-go/wgcfg"
+ "inet.af/netaddr"
+ "tailscale.com/control/controlclient"
+ "tailscale.com/ipn/ipnstate"
+ "tailscale.com/tailcfg"
+ "tailscale.com/types/wgkey"
+ "tailscale.com/wgengine"
+ "tailscale.com/wgengine/router"
+)
+
+// Wireguard specific implementation of the Engine interface.
+type WireguardEngine struct {
+ wg wgengine.Engine
+ port uint16
+ privKey types.PrivateKey
+}
+
+// Creates Wireguard engine.
+func NewWireguardEngine(tunName string, port uint16, privKey types.PrivateKey) (Engine, error) {
+ e, err := wgengine.NewUserspaceEngine(log.Printf, tunName, port)
+ if err != nil {
+ return nil, err
+ }
+ return &WireguardEngine{
+ wg: e,
+ port: port,
+ privKey: privKey,
+ }, nil
+}
+
+// Used for unit testing.
+func NewFakeWireguardEngine(port uint16, privKey types.PrivateKey) (Engine, error) {
+ e, err := wgengine.NewFakeUserspaceEngine(log.Printf, port, nil)
+ if err != nil {
+ return nil, err
+ }
+ return &WireguardEngine{
+ wg: e,
+ port: port,
+ privKey: privKey,
+ }, nil
+}
+
+func genWireguardConf(privKey types.PrivateKey, port uint16, netMap *types.NetworkMap) *wgcfg.Config {
+ c := &wgcfg.Config{
+ Name: "foo",
+ PrivateKey: wgcfg.PrivateKey(privKey),
+ Addresses: []netaddr.IPPrefix{netaddr.IPPrefix{
+ IP: netMap.Self.VPNIP,
+ Bits: 32, // TODO(giolekva): adapt for IPv6
+ }},
+ ListenPort: port,
+ Peers: make([]wgcfg.Peer, 0, len(netMap.Peers)),
+ }
+ for _, peer := range netMap.Peers {
+ c.Peers = append(c.Peers, wgcfg.Peer{
+ PublicKey: wgcfg.Key(peer.PublicKey),
+ AllowedIPs: []netaddr.IPPrefix{netaddr.IPPrefix{
+ IP: peer.VPNIP,
+ Bits: 32,
+ }},
+ Endpoints: peer.DiscoEndpoint,
+ PersistentKeepalive: 15, // TODO(giolekva): make it configurable
+ })
+ }
+ return c
+}
+
+func genRouterConf(netMap *types.NetworkMap) *router.Config {
+ c := &router.Config{
+ LocalAddrs: []netaddr.IPPrefix{netaddr.IPPrefix{
+ IP: netMap.Self.VPNIP,
+ Bits: 32,
+ }},
+ Routes: make([]netaddr.IPPrefix, 0, len(netMap.Peers)),
+ }
+ for _, peer := range netMap.Peers {
+ c.Routes = append(c.Routes, netaddr.IPPrefix{
+ IP: peer.VPNIP,
+ Bits: 32,
+ })
+ }
+ return c
+}
+
+func genTailNetMap(privKey types.PrivateKey, port uint16, netMap *types.NetworkMap) *controlclient.NetworkMap {
+ fmt.Println(netMap.Self.IPPort.String())
+ c := &controlclient.NetworkMap{
+ SelfNode: &tailcfg.Node{
+ ID: 0, // TODO(giolekva): maybe IDs should be stored server side.
+ StableID: "0",
+ Name: "0",
+ Key: tailcfg.NodeKey(netMap.Self.PublicKey),
+ DiscoKey: tailcfg.DiscoKey(netMap.Self.DiscoKey),
+ Addresses: []netaddr.IPPrefix{netaddr.IPPrefix{
+ IP: netMap.Self.VPNIP,
+ Bits: 32,
+ }},
+ AllowedIPs: make([]netaddr.IPPrefix, 0, len(netMap.Peers)),
+ Endpoints: []string{netMap.Self.IPPort.String()},
+ KeepAlive: true, // TODO(giolekva): make it configurable
+ },
+ NodeKey: tailcfg.NodeKey(netMap.Self.PublicKey),
+ PrivateKey: wgkey.Private(privKey),
+ Name: "0",
+ Addresses: []netaddr.IPPrefix{netaddr.IPPrefix{
+ IP: netMap.Self.VPNIP,
+ Bits: 32,
+ }},
+ LocalPort: port,
+ Peers: make([]*tailcfg.Node, 0, len(netMap.Peers)),
+ }
+ for i, peer := range netMap.Peers {
+ c.Peers = append(c.Peers, &tailcfg.Node{
+ ID: tailcfg.NodeID(i + 1),
+ StableID: tailcfg.StableNodeID(fmt.Sprintf("%d", i+1)),
+ Name: fmt.Sprintf("%d", i+1),
+ Key: tailcfg.NodeKey(peer.PublicKey),
+ DiscoKey: tailcfg.DiscoKey(peer.DiscoKey),
+ Addresses: []netaddr.IPPrefix{netaddr.IPPrefix{
+ IP: peer.VPNIP,
+ Bits: 32,
+ }},
+ AllowedIPs: []netaddr.IPPrefix{netaddr.IPPrefix{
+ IP: netMap.Self.VPNIP,
+ Bits: 32,
+ }},
+ Endpoints: []string{peer.IPPort.String()},
+ KeepAlive: true,
+ })
+ }
+ return c
+}
+
+func (e *WireguardEngine) Configure(netMap *types.NetworkMap) error {
+ err := e.wg.Reconfig(
+ genWireguardConf(e.privKey, e.port, netMap),
+ genRouterConf(netMap))
+ if err != nil {
+ return err
+ }
+ e.wg.SetNetworkMap(genTailNetMap(e.privKey, e.port, netMap))
+ e.wg.RequestStatus()
+ return err
+}
+
+func (e *WireguardEngine) DiscoKey() types.DiscoKey {
+ return types.DiscoKey(e.wg.DiscoPublicKey())
+}
+
+func (e *WireguardEngine) DiscoEndpoint() string {
+ k := e.DiscoKey()
+ discoHex := hex.EncodeToString(k[:])
+ return fmt.Sprintf("%s%s", discoHex, controlclient.EndpointDiscoSuffix)
+}
+
+func (e *WireguardEngine) Ping(ip netaddr.IP, cb func(*ipnstate.PingResult)) {
+ e.wg.Ping(ip, cb)
+}
diff --git a/core/vpn/engine/wireguard_test.go b/core/vpn/engine/wireguard_test.go
new file mode 100644
index 0000000..39488e3
--- /dev/null
+++ b/core/vpn/engine/wireguard_test.go
@@ -0,0 +1,119 @@
+package engine
+
+import (
+ "fmt"
+ "log"
+ "testing"
+
+ "github.com/giolekva/pcloud/core/vpn/types"
+
+ "inet.af/netaddr"
+ "tailscale.com/ipn/ipnstate"
+)
+
+type node struct {
+ ip netaddr.IP
+ privKey types.PrivateKey
+ node types.Node
+ peers []types.Node
+ e Engine
+}
+
+func newNode(ip string, localPort uint16) (n *node, err error) {
+ n = &node{
+ ip: netaddr.MustParseIP(ip),
+ privKey: types.NewPrivateKey(),
+ }
+ if n.e, err = NewFakeWireguardEngine(localPort, n.privKey); err != nil {
+ return
+ }
+ n.node = types.Node{
+ PublicKey: n.privKey.Public(),
+ DiscoKey: n.e.DiscoKey(),
+ DiscoEndpoint: n.e.DiscoEndpoint(),
+ IPPort: netaddr.IPPort{
+ IP: netaddr.IPv4(127, 0, 0, 1),
+ Port: localPort,
+ },
+ VPNIP: netaddr.MustParseIP(ip),
+ }
+ return
+}
+
+func (n *node) addPeer(x types.Node) {
+ n.peers = append(n.peers, x)
+}
+
+func (n *node) configure() error {
+ return n.e.Configure(&types.NetworkMap{n.node, n.peers})
+}
+
+func (n *node) ping(ip string, ch chan<- *ipnstate.PingResult) {
+ n.e.Ping(netaddr.MustParseIP(ip), func(p *ipnstate.PingResult) {
+ ch <- p
+ })
+}
+
+func TestTwoPeers(t *testing.T) {
+ var a, b *node
+ var err error
+ if a, err = newNode("10.0.0.1", 1234); err != nil {
+ t.Fatal(err)
+ }
+ if b, err = newNode("10.0.0.2", 1235); err != nil {
+ t.Fatal(err)
+ }
+ a.addPeer(b.node)
+ b.addPeer(a.node)
+ if err := a.configure(); err != nil {
+ t.Fatal(err)
+ }
+ if err := b.configure(); err != nil {
+ t.Fatal(err)
+ }
+ ping := make(chan *ipnstate.PingResult, 0)
+ a.ping("10.0.0.2", ping)
+ b.ping("10.0.0.1", ping)
+ for i := 0; i < 2; i++ {
+ p := <-ping
+ if p.Err != "" {
+ t.Error(p.Err)
+ }
+ log.Printf("Ping received: %+v\n", p)
+ }
+}
+
+func TestTenPeers(t *testing.T) {
+ n := 10
+ nodes := make([]*node, n)
+ ping := make(chan *ipnstate.PingResult, 0)
+ for i := 0; i < n; i++ {
+ ip := fmt.Sprintf("10.0.0.%d", i+1)
+ localPort := uint16(i + 4321)
+ var err error
+ if nodes[i], err = newNode(ip, localPort); err != nil {
+ t.Fatal(err)
+ }
+ for j := 0; j < i; j++ {
+ nodes[i].addPeer(nodes[j].node)
+ nodes[j].addPeer(nodes[i].node)
+ }
+ }
+ for i := 0; i < n; i++ {
+ if err := nodes[i].configure(); err != nil {
+ t.Fatal(err)
+ }
+ for j := 0; j < i; j++ {
+ nodes[i].ping(nodes[j].ip.String(), ping)
+ nodes[j].ping(nodes[i].ip.String(), ping)
+ }
+
+ }
+ for i := 0; i < n*(n-1); i++ {
+ p := <-ping
+ if p.Err != "" {
+ t.Error(p.Err)
+ }
+ log.Printf("Ping received: %+v\n", p)
+ }
+}