From eece51b54753acd80e66c9241d126abf1e6e2096 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogdan-=C8=98tefan=20Neac=C5=9Fu?= Date: Thu, 30 Jan 2025 11:53:52 +0100 Subject: [PATCH] Port multihop tun --- go.mod | 1 + go.sum | 2 + tun/multihoptun/bind.go | 126 +++++++++ tun/multihoptun/tun.go | 307 ++++++++++++++++++++++ tun/multihoptun/tun_test.go | 507 ++++++++++++++++++++++++++++++++++++ 5 files changed, 943 insertions(+) create mode 100644 tun/multihoptun/bind.go create mode 100644 tun/multihoptun/tun.go create mode 100644 tun/multihoptun/tun_test.go diff --git a/go.mod b/go.mod index 115ae88..7737df1 100644 --- a/go.mod +++ b/go.mod @@ -14,4 +14,5 @@ require ( require ( github.com/google/btree v1.0.1 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect + golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 ) diff --git a/go.sum b/go.sum index 7b53725..9749226 100644 --- a/go.sum +++ b/go.sum @@ -12,5 +12,7 @@ golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0k golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/tun/multihoptun/bind.go b/tun/multihoptun/bind.go new file mode 100644 index 0000000..3f345ac --- /dev/null +++ b/tun/multihoptun/bind.go @@ -0,0 +1,126 @@ +package multihoptun + +import ( + "math/rand" + "net" + + "golang.zx2c4.com/wireguard/conn" + + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +type multihopBind struct { + *MultihopTun + socketShutdown chan struct{} +} + +// Close implements tun.Device +func (st *multihopBind) Close() error { + select { + case <-st.socketShutdown: + return nil + default: + close(st.socketShutdown) + } + return nil +} + +// Open implements conn.Bind. +func (st *multihopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + if port != 0 { + st.localPort = port + } else { + st.localPort = uint16(rand.Uint32()>>16) | 1 + } + // WireGuard will close existing sockets before bringing up a new device on Bind updates. + // This guarantees that the socket shutdown channel is always available. + st.socketShutdown = make(chan struct{}) + + actualPort = st.localPort + fns = []conn.ReceiveFunc{ + func(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + var batch packetBatch + var ok bool + + select { + case <-st.shutdownChan: + return 0, net.ErrClosed + case <-st.socketShutdown: + return 0, net.ErrClosed + case batch, ok = <-st.writeRecv: + break + } + if !ok { + return 0, net.ErrClosed + } + + ipVersion := header.IPVersion(batch.packet[batch.offset:]) + if ipVersion == 4 { + v4 := header.IPv4(batch.packet[batch.offset:]) + udp := header.UDP(v4.Payload()) + copy(packets[0], udp.Payload()) + sizes[0] = len(udp.Payload()) + + } else if ipVersion == 6 { + v6 := header.IPv6(batch.packet[batch.offset:]) + udp := header.UDP(v6.Payload()) + copy(packets[0], udp.Payload()) + sizes[0] = len(udp.Payload()) + } + batch.size = sizes[0] + eps[0] = st.endpoint + + batch.completion <- batch + return 1, nil + }, + } + + return fns, actualPort, nil +} + +// ParseEndpoint implements conn.Bind. +func (*multihopBind) ParseEndpoint(s string) (conn.Endpoint, error) { + return conn.NewStdNetBind().ParseEndpoint(s) +} + +// Send implements conn.Bind. +func (st *multihopBind) Send(bufs [][]byte, ep conn.Endpoint) error { + var packetBatch packetBatch + var ok bool + + for _, buf := range bufs { + select { + case <-st.shutdownChan: + return net.ErrClosed + case <-st.socketShutdown: + // it is important to return a net.ErrClosed, since it implements the + // net.Error interface and indicates that it is not a recoverable error. + // wg-go uses the net.Error interface to deduce if it should try to send + // packets again after some time or if it should give up. + return net.ErrClosed + case packetBatch, ok = <-st.readRecv: + } + + if !ok { + return net.ErrClosed + } + + targetPacket := packetBatch.packet[packetBatch.offset:] + size, err := st.writePayload(targetPacket, buf) + + packetBatch.size = size + + packetBatch.completion <- packetBatch + + if err != nil { + return err + } + } + + return nil +} + +// SetMark implements conn.Bind. +func (*multihopBind) SetMark(mark uint32) error { + return nil +} diff --git a/tun/multihoptun/tun.go b/tun/multihoptun/tun.go new file mode 100644 index 0000000..d098a6e --- /dev/null +++ b/tun/multihoptun/tun.go @@ -0,0 +1,307 @@ +package multihoptun + +import ( + "errors" + "fmt" + "io" + "math" + "math/rand" + "net/netip" + "os" + "sync/atomic" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/tun" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// This is a special implementation of `tun.Device` that allows to connect a +// `conn.Bind` from one WireGuard device to another's `tun.Device`. This way, +// we can create a multi-hop WireGuard device that can use the same private key +// and elude any MTU issues, since, for any single user packet, there is only +// ever a single read from the real tunnel device needed to send it to the +// entry hop. +// +// tun.Device.Write will push a buffer via writeRecv to be read by the recvfunc +// of conn.Bind, stripping IPv4/IPv6 + UDP headers in the process. When the +// packets have been transferred to the UDP receiver, writeDone will be used to +// return from tun.Device.Write. Conversely, conn.Bind.Send will push a buffer +// via readRecv to be read by tun.Device.Read, adding valid IPv4/IPv6 + UDP +// headers in the process. +// +// Implements tun.Device and can create instances of conn.Bind. +type MultihopTun struct { + readRecv chan packetBatch + writeRecv chan packetBatch + isIpv4 bool + localIp []byte + localPort uint16 + remoteIp []byte + remotePort uint16 + ipConnectionId uint16 + tunEvent chan tun.Event + mtu int + endpoint conn.Endpoint + closed atomic.Bool + shutdownChan chan struct{} +} + +type packetBatch struct { + packet []byte + size int + offset int + // to be used to return the packet batch back to tun.Read and tun.Write + completion chan packetBatch +} + +func (pb *packetBatch) Size() int { + return len(pb.packet) +} + +func NewMultihopTun(local, remote netip.Addr, remotePort uint16, mtu int) MultihopTun { + readRecv := make(chan packetBatch) + writeRecv := make(chan packetBatch) + endpoint, err := conn.NewStdNetBind().ParseEndpoint(netip.AddrPortFrom(remote, remotePort).String()) + if err != nil { + panic("Failed to parse endpoint") + } + + connectionId := uint16(rand.Uint32()>>16) | 1 + shutdownChan := make(chan struct{}) + + return MultihopTun{ + readRecv, + writeRecv, + local.Is4(), + local.AsSlice(), + 0, + remote.AsSlice(), + remotePort, + connectionId, + make(chan tun.Event), + mtu, + endpoint, + atomic.Bool{}, + shutdownChan, + } +} + +func (st *MultihopTun) Binder() conn.Bind { + socketShutdown := make(chan struct{}) + return &multihopBind{ + st, + socketShutdown, + } + +} + +// Events implements tun.Device. +func (st *MultihopTun) Events() <-chan tun.Event { + return st.tunEvent +} + +// File implements tun.Device. +func (*MultihopTun) File() *os.File { + return nil +} + +// MTU implements tun.Device. +func (st *MultihopTun) MTU() (int, error) { + return st.mtu, nil +} + +// Name implements tun.Device. +func (*MultihopTun) Name() (string, error) { + return "stun", nil +} + +// Write implements tun.Device. +func (st *MultihopTun) Write(packets [][]byte, offset int) (int, error) { + for i, packet := range packets { + completion := make(chan packetBatch) + packetBatch := packetBatch{ + packet: packet, + offset: offset, + size: len(packet), + completion: completion, + } + + select { + case st.writeRecv <- packetBatch: + break + case <-st.shutdownChan: + return i, io.EOF + } + + packetBatch, ok := <-completion + + if !ok { + return i, io.EOF + } + } + + return len(packets), nil +} + +// Read implements tun.Device. +func (st *MultihopTun) Read(packet [][]byte, sizes []int, offset int) (n int, err error) { + completion := make(chan packetBatch) + packetBatch := packetBatch{ + packet: packet[0], + size: 0, + offset: offset, + completion: completion, + } + + select { + case st.readRecv <- packetBatch: + break + case <-st.shutdownChan: + return 0, io.EOF + } + + var ok bool + packetBatch, ok = <-completion + + if !ok { + return 0, io.EOF + } + + sizes[0] = packetBatch.size + return 1, nil +} + +func (st *MultihopTun) writePayload(target, payload []byte) (size int, err error) { + headerSize := st.headerSize() + if headerSize+len(payload) > len(target) { + err = errors.New(fmt.Sprintf("target buffer is too small, need %d, got %d", headerSize+len(payload), len(target))) + return + } + + if st.isIpv4 { + return st.writeV4Payload(target, payload) + } else { + return st.writeV6Payload(target, payload) + } +} + +func (st *MultihopTun) writeV4Payload(target, payload []byte) (size int, err error) { + var ipv4 header.IPv4 + ipv4 = target + + size = st.headerSize() + len(payload) + src := tcpip.AddrFrom4Slice(st.localIp) + dst := tcpip.AddrFrom4Slice(st.remoteIp) + fields := header.IPv4Fields{ + // TODO: Figure out the best DSCP value, ideally would be 0x88 for handshakes and 0x00 for rest. + TOS: 0, + TotalLength: uint16(size), + ID: st.ipConnectionId, + TTL: 64, + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: src, + DstAddr: dst, + Checksum: 0, + } + ipv4.Encode(&fields) + ipv4.SetChecksum(^ipv4.CalculateChecksum()) + st.writeUdpPayload(ipv4.Payload(), payload, src, dst) + return +} + +func (st *MultihopTun) writeV6Payload(target, payload []byte) (size int, err error) { + + var ipv6 header.IPv6 + ipv6 = target + + size = st.headerSize() + len(payload) + src := tcpip.AddrFrom4Slice(st.localIp) + dst := tcpip.AddrFrom4Slice(st.remoteIp) + fields := header.IPv6Fields{ + TrafficClass: 0, + PayloadLength: uint16(len(payload)), + FlowLabel: uint32(st.ipConnectionId), + TransportProtocol: header.UDPProtocolNumber, + SrcAddr: src, + DstAddr: dst, + HopLimit: 64, + } + ipv6.Encode(&fields) + + st.writeUdpPayload(ipv6.Payload(), payload, src, dst) + return +} + +func (st *MultihopTun) writeUdpPayload(target header.UDP, payload []byte, src, dst tcpip.Address) { + target.Encode(&header.UDPFields{ + SrcPort: st.localPort, + DstPort: st.remotePort, + Length: uint16(len(payload) + header.UDPMinimumSize), + Checksum: 0, + }) + copy(target.Payload()[:], payload[:]) + + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + xsum := target.CalculateChecksum(checksum.Combine( + header.PseudoHeaderChecksum(header.UDPProtocolNumber, src, dst, uint16(len(payload)+header.UDPMinimumSize)), + checksum.Checksum(target, 0), + )) + // As per RFC 768 page 2, + // + // Checksum is the 16-bit one's complement of the one's complement sum of + // a pseudo header of information from the IP header, the UDP header, and + // the data, padded with zero octets at the end (if necessary) to make a + // multiple of two octets. + // + // The pseudo header conceptually prefixed to the UDP header contains the + // source address, the destination address, the protocol, and the UDP + // length. This information gives protection against misrouted datagrams. + // This checksum procedure is the same as is used in TCP. + // + // If the computed checksum is zero, it is transmitted as all ones (the + // equivalent in one's complement arithmetic). An all zero transmitted + // checksum value means that the transmitter generated no checksum (for + // debugging or for higher level protocols that don't care). + // + // To avoid the zero value, we only calculate the one's complement of the + // one's complement sum if the sum is not all ones. + if xsum != math.MaxUint16 { + xsum = ^xsum + } + target.SetChecksum(0) + +} + +func (st *MultihopTun) headerSize() int { + udpPacketSize := header.UDPMinimumSize + if st.isIpv4 { + return header.IPv4MinimumSize + udpPacketSize + } else { + return header.IPv6MinimumSize + udpPacketSize + } +} + +// BatchSize implements conn.Bind. +func (*MultihopTun) BatchSize() int { + return 128 +} + +// BatchSize implements conn.Bind. +func (*MultihopTun) Flush() error { + return nil +} + +// Close implements tun.Device +func (st *MultihopTun) Close() error { + if !st.closed.Load() { + st.closed.Store(true) + } + close(st.shutdownChan) + return nil +} diff --git a/tun/multihoptun/tun_test.go b/tun/multihoptun/tun_test.go new file mode 100644 index 0000000..c7a8205 --- /dev/null +++ b/tun/multihoptun/tun_test.go @@ -0,0 +1,507 @@ +package multihoptun + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "fmt" + "net" + "net/netip" + "testing" + "time" + + "golang.org/x/crypto/curve25519" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestMultihopTunBind(t *testing.T) { + stIp := netip.AddrFrom4([4]byte{192, 168, 1, 1}) + virtualIp := netip.AddrFrom4([4]byte{192, 168, 1, 11}) + remotePort := uint16(5005) + + st := NewMultihopTun(stIp, virtualIp, remotePort, 1280) + + _ = device.NewDevice(&st, st.Binder(), device.NewLogger(device.LogLevelSilent, "")) +} + +func TestMultihopTunTrafficV4(t *testing.T) { + + stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5}) + virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + remotePort := uint16(5005) + + st := NewMultihopTun(stIp, virtualIp, remotePort, 1280) + stBind := st.Binder() + + virtualTun, virtualNet, _ := netstack.CreateNetTUN([]netip.Addr{virtualIp}, []netip.Addr{}, 1280) + + // Pipe reads from virtualTun into multihop tun + go func() { + buf := make([][]byte, 1600) + sizes := make([]int, 1) + var err error + for err == nil { + _, err = virtualTun.Read(buf, sizes, 0) + _, err = st.Write(buf[:sizes[0]], 0) + } + + }() + + // Pipe reads from multihop tun into virtualTun + go func() { + buf := make([][]byte, 1600) + sizes := make([]int, 1) + var err error + for err == nil { + _, err = st.Read(buf, sizes, 0) + _, err = virtualTun.Write(buf[:sizes[0]], 0) + } + }() + + recvFunc, _, err := stBind.Open(0) + if err != nil { + t.Fatalf("Failed to open port for multihop tun: %s", err) + } + + payload := [][]byte{{1, 2, 3, 4}} + readyChan := make(chan struct{}) + // Listen on the virtual tunnel + go func() { + conn, err := virtualNet.ListenUDPAddrPort(netip.AddrPortFrom(virtualIp, remotePort)) + if err != nil { + panic(err) + } + readyChan <- struct{}{} + buff := make([]byte, 4) + n, addr, _ := conn.ReadFrom(buff) + if n == 0 { + fmt.Println("Did not receive anything") + } + + conn.WriteTo(buff, addr) + }() + _, _ = <-readyChan + + err = stBind.Send(payload, nil) + if err != nil { + t.Fatalf("Failed ot send traffic to multihop tun: %s", err) + } + + recvBuf := make([][]byte, 1600) + sizes := make([]int, 1) + eps := make([]conn.Endpoint, 1) + _, err = recvFunc[0](recvBuf, sizes, eps) + packetSize := sizes[0] + if err != nil { + t.Fatalf("Failed to receive traffic from recvFunc - %s", err) + } + if packetSize != len(payload) { + t.Fatalf("Expected to recieve %d bytes, instead received %d", len(payload), packetSize) + } + + for idx := range payload { + if payload[0][idx] != recvBuf[0][idx] { + t.Fatalf("Expected to receive %v, instead received %v", payload, recvBuf[0]) + } + } +} + +func TestReadEnd(t *testing.T) { + stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5}) + virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + remotePort := uint16(5005) + + st := NewMultihopTun(stIp, virtualIp, remotePort, 1280) + stBind := st.Binder() + otherSt := NewMultihopTun(stIp, virtualIp, remotePort, 1280) + + readerDev := device.NewDevice(&st, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "")) + otherDev := device.NewDevice(&otherSt, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "")) + + configureDevices(t, readerDev, otherDev) + + readerDev.Up() + receivers, port, err := stBind.Open(0) + if err != nil { + t.Fatalf("Failed to open UDP socket: %s", err) + } + if len(receivers) != 1 { + t.Fatalf("Expected 1 receiver func, got %v", len(receivers)) + } + + if port == 0 { + t.Fatalf("Expected a random port to be assigned, instead got 0") + } + + buf := [][]byte{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}} + + err = stBind.Send(buf, nil) + if err != nil { + t.Fatalf("Error when sending UDP traffic: %v", err) + } +} + +func TestMultihopTunWrite(t *testing.T) { + stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5}) + virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + remotePort := uint16(5005) + + st := NewMultihopTun(stIp, virtualIp, remotePort, 1280) + stBind := st.Binder() + + receivers, port, err := stBind.Open(0) + if err != nil { + t.Fatalf("Failed to open UDP socket: %s", err) + } + if len(receivers) != 1 { + t.Fatalf("Expected 1 receiver func, got %v", len(receivers)) + } + + if port == 0 { + t.Fatalf("Expected a random port to be assigned, instead got 0") + } + + udpPacket := [][]byte{{69, 0, 0, 32, 164, 27, 0, 0, 64, 17, 206, 165, 1, 2, 3, 5, 1, 2, 3, 4, 209, 129, 19, 141, 0, 12, 0, 0, 1, 2, 3, 4}} + + if err != nil { + t.Fatalf("Error when sending UDP traffic: %v", err) + } + go func() { + st.Write(udpPacket, 0) + }() + + buf := make([][]byte, 1600) + sizes := make([]int, 1) + eps := make([]conn.Endpoint, 1) + + _, err = receivers[0](buf, sizes, eps) + packetSize := sizes[0] + if err != nil { + t.Fatalf("Failed to receive packets: %s", err) + } + + expected := []byte{1, 2, 3, 4} + if len(buf[:packetSize]) != len(expected) { + t.Fatalf("Expected %v, got %v", expected, buf[0]) + } + + for b := range buf[:packetSize] { + if buf[0][b] != expected[b] { + t.Fatalf("Expected %v, got %v", expected, buf[0]) + } + } +} + +func TestMultihopTunRead(t *testing.T) { + stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5}) + virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + remotePort := uint16(5005) + + st := NewMultihopTun(stIp, virtualIp, remotePort, 1280) + stBind := st.Binder() + + _, _, err := stBind.Open(0) + if err != nil { + t.Fatalf("Failed to open UDP socket: %s", err) + } + + payload := [][]byte{{1, 2, 3, 4}} + go stBind.Send(payload, nil) + + bytes := make([][]byte, 1500, 1500) + sizes := make([]int, 1) + _, err = st.Read(bytes, sizes, 0) + bytesRead := sizes[0] + if err != nil { + t.Fatalf("Failed to read from tunnel device: %v", err) + } + + packet := header.IPv4(bytes[0][:bytesRead]) + virtualIpBytes, _ := virtualIp.MarshalBinary() + stIpBytes, _ := stIp.MarshalBinary() + + if packet.SourceAddress() != tcpip.AddrFromSlice(stIpBytes) { + t.Fatalf("expected %v, got %v", stIp, packet.SourceAddress()) + } + + if packet.DestinationAddress() != tcpip.AddrFromSlice(virtualIpBytes) { + t.Fatalf("expected %v, got %v", virtualIp, packet.DestinationAddress()) + } + +} + +func configureDevices(t testing.TB, aDev *device.Device, bDev *device.Device) { + configs, endpointConfigs, _ := genConfigs(t) + aConfig := configs[0] + endpointConfigs[0] + bConfig := configs[1] + endpointConfigs[1] + aDev.IpcSet(aConfig) + bDev.IpcSet(bConfig) +} + +func genConfigsForMultihop(t testing.TB) ([4]string, [4]uint16) { + entryConfigs, entryEndpoints, entryPorts := genConfigs(t) + exitConfigs, exitEndpoints, exitPorts := genConfigs(t) + + aExitConfig := exitConfigs[0] + exitEndpoints[0] + bExitConfig := exitConfigs[1] + exitEndpoints[1] + aEntryConfig := entryConfigs[0] + entryEndpoints[0] + bEntryConfig := entryConfigs[1] + entryEndpoints[1] + + ports := [4]uint16{entryPorts[0], exitPorts[0], exitPorts[1], entryPorts[1]} + + return [4]string{aEntryConfig, aExitConfig, bExitConfig, bEntryConfig}, ports + +} + +// genConfigs generates a pair of configs that connect to each other. +// The configs use distinct, probably-usable ports. +func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string, ports [2]uint16) { + var key1, key2 device.NoisePrivateKey + + _, err := rand.Read(key1[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + _, err = rand.Read(key2[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + + ports[0] = getFreeLocalUdpPort(tb) + ports[1] = getFreeLocalUdpPort(tb) + + pub1, pub2 := publicKey(&key1), publicKey(&key2) + + cfgs[0] = uapiCfg( + "private_key", hex.EncodeToString(key1[:]), + "listen_port", fmt.Sprintf("%d", ports[0]), + "replace_peers", "true", + "public_key", hex.EncodeToString(pub2[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "0.0.0.0/0", + ) + endpointCfgs[0] = uapiCfg( + "public_key", hex.EncodeToString(pub2[:]), + "endpoint", fmt.Sprintf("127.0.0.1:%d", ports[1]), + ) + cfgs[1] = uapiCfg( + "private_key", hex.EncodeToString(key2[:]), + "listen_port", fmt.Sprintf("%d", ports[1]), + "replace_peers", "true", + "public_key", hex.EncodeToString(pub1[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "0.0.0.0/0", + ) + endpointCfgs[1] = uapiCfg( + "public_key", hex.EncodeToString(pub1[:]), + "endpoint", fmt.Sprintf("127.0.0.1:%d", ports[0]), + ) + return +} + +func publicKey(sk *device.NoisePrivateKey) (pk device.NoisePublicKey) { + apk := (*[device.NoisePublicKeySize]byte)(&pk) + ask := (*[device.NoisePrivateKeySize]byte)(sk) + curve25519.ScalarBaseMult(apk, ask) + return +} + +func getFreeLocalUdpPort(t testing.TB) uint16 { + localAddr := netip.MustParseAddrPort("127.0.0.1:0") + udpSockAddr := net.UDPAddrFromAddrPort(localAddr) + udpConn, err := net.ListenUDP("udp4", udpSockAddr) + if err != nil { + t.Fatalf("Failed to open a UDP socket to assign an empty port") + } + defer udpConn.Close() + + port := netip.MustParseAddrPort(udpConn.LocalAddr().String()).Port() + + return port +} + +func uapiCfg(cfg ...string) string { + if len(cfg)%2 != 0 { + panic("odd number of args to uapiReader") + } + buf := new(bytes.Buffer) + for i, s := range cfg { + buf.WriteString(s) + sep := byte('\n') + if i%2 == 0 { + sep = '=' + } + buf.WriteByte(sep) + } + return buf.String() +} + +func TestShutdown(t *testing.T) { + a, b := generateTestPair(t) + b.Close() + a.Close() +} + +func TestReversedShutdown(t *testing.T) { + a, b := generateTestPair(t) + a.Close() + b.Close() +} + +func generateTestPair(t *testing.T) (*device.Device, *device.Device) { + stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5}) + virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + remotePort := uint16(5005) + + st := NewMultihopTun(stIp, virtualIp, remotePort, 1280) + stBind := st.Binder() + + virtualDev, virtualNet, _ := netstack.CreateNetTUN([]netip.Addr{virtualIp}, []netip.Addr{}, 1280) + + readerDev := device.NewDevice(virtualDev, stBind, device.NewLogger(device.LogLevelSilent, "")) + otherDev := device.NewDevice(&st, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, "")) + + configureDevices(t, readerDev, otherDev) + + readerDev.Up() + otherDev.Up() + + conn, err := virtualNet.Dial("ping4", "10.64.0.1") + requestPing := icmp.Echo{ + Seq: 345, + Data: []byte("gopher burrow"), + } + icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + conn.SetReadDeadline(time.Now().Add(time.Second * 9)) + _, err = conn.Write(icmpBytes) + if err != nil { + t.Fatal(err) + } + + return readerDev, otherDev +} + +func TestShutdownBind(t *testing.T) { + stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5}) + virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + remotePort := uint16(5005) + + st := NewMultihopTun(stIp, virtualIp, remotePort, 1280) + binder := st.Binder() + recvFunc, _, err := binder.Open(0) + if err != nil { + t.Fatalf("Failed to open a UDP socket, %v", err) + } + + st.Close() + + buf := make([][]byte, 1600) + sizes := make([]int, 1) + eps := make([]conn.Endpoint, 1) + _, err = recvFunc[0](buf, sizes, eps) + neterr, ok := err.(net.Error) + if !ok { + t.Fatalf("Expected a net.Error, instead got %v", err) + } + if neterr.Temporary() { + t.Fatalf("Expected the net error to not be temporary") + } +} + +func TestMultihopLocally(t *testing.T) { + aVirtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 5}) + bVirtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + + configsForMultihop, ports := genConfigsForMultihop(t) + + multihopA := NewMultihopTun(aVirtualIp, netip.MustParseAddr(fmt.Sprintf("127.0.0.1")), ports[3], 1280) + multihopB := NewMultihopTun(bVirtualIp, netip.MustParseAddr(fmt.Sprintf("127.0.0.1")), ports[0], 1280) + aBinder := multihopA.Binder() + bBinder := multihopB.Binder() + + virtualDevA, virtualNetA, _ := netstack.CreateNetTUN([]netip.Addr{aVirtualIp}, []netip.Addr{}, 1280) + virtualDevB, virtualNetB, _ := netstack.CreateNetTUN([]netip.Addr{bVirtualIp}, []netip.Addr{}, 1280) + + aExitDevice := device.NewDevice(virtualDevA, aBinder, device.NewLogger(device.LogLevelVerbose, "")) + aExitDevice.IpcSet(configsForMultihop[0]) + + aEntryDevice := device.NewDevice(&multihopA, conn.NewStdNetBind(), device.NewLogger(device.LogLevelVerbose, "")) + aEntryDevice.IpcSet(configsForMultihop[1]) + + bEntryDevice := device.NewDevice(&multihopB, conn.NewStdNetBind(), device.NewLogger(device.LogLevelVerbose, "")) + bEntryDevice.IpcSet(configsForMultihop[2]) + + bExitDevice := device.NewDevice(virtualDevB, bBinder, device.NewLogger(device.LogLevelVerbose, "")) + bExitDevice.IpcSet(configsForMultihop[3]) + + err := aExitDevice.Up() + if err != nil { + t.Fatalf("exit device a failed to up itself: %v", err) + } + + err = aEntryDevice.Up() + if err != nil { + t.Fatalf("entry device a failed to up itself: %v", err) + } + + err = bExitDevice.Up() + if err != nil { + t.Fatalf("exit device b failed to up itself: %v", err) + } + + err = bEntryDevice.Up() + if err != nil { + t.Fatalf("entry device b failed to up itself: %v", err) + } + + listenerAddr := netip.AddrPortFrom(bVirtualIp, 7070) + senderAddr := netip.AddrPortFrom(aVirtualIp, 4040) + listenerSocket, err := virtualNetB.ListenUDPAddrPort(netip.AddrPortFrom(bVirtualIp, 7070)) + if err != nil { + t.Fatalf("Fail to open listener socket: %v", err) + } + + senderSocket, err := virtualNetA.DialUDPAddrPort(senderAddr, listenerAddr) + if err != nil { + t.Fatalf("Failed to open sender socket: %v", err) + } + + payload := []byte{1, 2, 3, 4, 5} + + n, err := senderSocket.Write(payload) + if err != nil { + t.Fatalf("Failed to send payload: %v", err) + } + + if n != len(payload) { + t.Fatalf("Expected to send %v bytes, instead sent %v", len(payload), n) + } + + rxBuffer := []byte{1, 2, 3, 4, 5} + n, err = listenerSocket.Read(rxBuffer) + if err != nil { + t.Fatalf("Failed to receive payload: %v", err) + } + if n != len(payload) { + t.Fatalf("Expected to read %v bytes, instead read %v bytes", len(payload), n) + } + + for idx := range rxBuffer { + if rxBuffer[idx] != payload[idx] { + t.Fatalf("At index %d, expected value %d, instead got %v", idx, rxBuffer[idx], payload[idx]) + } + } + + aEntryDevice.Close() + aExitDevice.Close() + bEntryDevice.Close() + bExitDevice.Close() +}