Port multihop tun

This commit is contained in:
Bogdan-Ștefan Neacşu 2025-01-30 11:53:52 +01:00
parent 2e3f7d122c
commit eece51b547
5 changed files with 943 additions and 0 deletions

1
go.mod
View file

@ -14,4 +14,5 @@ require (
require (
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
)

2
go.sum
View file

@ -12,5 +12,7 @@ golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0k
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=

126
tun/multihoptun/bind.go Normal file
View file

@ -0,0 +1,126 @@
package multihoptun
import (
"math/rand"
"net"
"golang.zx2c4.com/wireguard/conn"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
type multihopBind struct {
*MultihopTun
socketShutdown chan struct{}
}
// Close implements tun.Device
func (st *multihopBind) Close() error {
select {
case <-st.socketShutdown:
return nil
default:
close(st.socketShutdown)
}
return nil
}
// Open implements conn.Bind.
func (st *multihopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
if port != 0 {
st.localPort = port
} else {
st.localPort = uint16(rand.Uint32()>>16) | 1
}
// WireGuard will close existing sockets before bringing up a new device on Bind updates.
// This guarantees that the socket shutdown channel is always available.
st.socketShutdown = make(chan struct{})
actualPort = st.localPort
fns = []conn.ReceiveFunc{
func(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
var batch packetBatch
var ok bool
select {
case <-st.shutdownChan:
return 0, net.ErrClosed
case <-st.socketShutdown:
return 0, net.ErrClosed
case batch, ok = <-st.writeRecv:
break
}
if !ok {
return 0, net.ErrClosed
}
ipVersion := header.IPVersion(batch.packet[batch.offset:])
if ipVersion == 4 {
v4 := header.IPv4(batch.packet[batch.offset:])
udp := header.UDP(v4.Payload())
copy(packets[0], udp.Payload())
sizes[0] = len(udp.Payload())
} else if ipVersion == 6 {
v6 := header.IPv6(batch.packet[batch.offset:])
udp := header.UDP(v6.Payload())
copy(packets[0], udp.Payload())
sizes[0] = len(udp.Payload())
}
batch.size = sizes[0]
eps[0] = st.endpoint
batch.completion <- batch
return 1, nil
},
}
return fns, actualPort, nil
}
// ParseEndpoint implements conn.Bind.
func (*multihopBind) ParseEndpoint(s string) (conn.Endpoint, error) {
return conn.NewStdNetBind().ParseEndpoint(s)
}
// Send implements conn.Bind.
func (st *multihopBind) Send(bufs [][]byte, ep conn.Endpoint) error {
var packetBatch packetBatch
var ok bool
for _, buf := range bufs {
select {
case <-st.shutdownChan:
return net.ErrClosed
case <-st.socketShutdown:
// it is important to return a net.ErrClosed, since it implements the
// net.Error interface and indicates that it is not a recoverable error.
// wg-go uses the net.Error interface to deduce if it should try to send
// packets again after some time or if it should give up.
return net.ErrClosed
case packetBatch, ok = <-st.readRecv:
}
if !ok {
return net.ErrClosed
}
targetPacket := packetBatch.packet[packetBatch.offset:]
size, err := st.writePayload(targetPacket, buf)
packetBatch.size = size
packetBatch.completion <- packetBatch
if err != nil {
return err
}
}
return nil
}
// SetMark implements conn.Bind.
func (*multihopBind) SetMark(mark uint32) error {
return nil
}

307
tun/multihoptun/tun.go Normal file
View file

@ -0,0 +1,307 @@
package multihoptun
import (
"errors"
"fmt"
"io"
"math"
"math/rand"
"net/netip"
"os"
"sync/atomic"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// This is a special implementation of `tun.Device` that allows to connect a
// `conn.Bind` from one WireGuard device to another's `tun.Device`. This way,
// we can create a multi-hop WireGuard device that can use the same private key
// and elude any MTU issues, since, for any single user packet, there is only
// ever a single read from the real tunnel device needed to send it to the
// entry hop.
//
// tun.Device.Write will push a buffer via writeRecv to be read by the recvfunc
// of conn.Bind, stripping IPv4/IPv6 + UDP headers in the process. When the
// packets have been transferred to the UDP receiver, writeDone will be used to
// return from tun.Device.Write. Conversely, conn.Bind.Send will push a buffer
// via readRecv to be read by tun.Device.Read, adding valid IPv4/IPv6 + UDP
// headers in the process.
//
// Implements tun.Device and can create instances of conn.Bind.
type MultihopTun struct {
readRecv chan packetBatch
writeRecv chan packetBatch
isIpv4 bool
localIp []byte
localPort uint16
remoteIp []byte
remotePort uint16
ipConnectionId uint16
tunEvent chan tun.Event
mtu int
endpoint conn.Endpoint
closed atomic.Bool
shutdownChan chan struct{}
}
type packetBatch struct {
packet []byte
size int
offset int
// to be used to return the packet batch back to tun.Read and tun.Write
completion chan packetBatch
}
func (pb *packetBatch) Size() int {
return len(pb.packet)
}
func NewMultihopTun(local, remote netip.Addr, remotePort uint16, mtu int) MultihopTun {
readRecv := make(chan packetBatch)
writeRecv := make(chan packetBatch)
endpoint, err := conn.NewStdNetBind().ParseEndpoint(netip.AddrPortFrom(remote, remotePort).String())
if err != nil {
panic("Failed to parse endpoint")
}
connectionId := uint16(rand.Uint32()>>16) | 1
shutdownChan := make(chan struct{})
return MultihopTun{
readRecv,
writeRecv,
local.Is4(),
local.AsSlice(),
0,
remote.AsSlice(),
remotePort,
connectionId,
make(chan tun.Event),
mtu,
endpoint,
atomic.Bool{},
shutdownChan,
}
}
func (st *MultihopTun) Binder() conn.Bind {
socketShutdown := make(chan struct{})
return &multihopBind{
st,
socketShutdown,
}
}
// Events implements tun.Device.
func (st *MultihopTun) Events() <-chan tun.Event {
return st.tunEvent
}
// File implements tun.Device.
func (*MultihopTun) File() *os.File {
return nil
}
// MTU implements tun.Device.
func (st *MultihopTun) MTU() (int, error) {
return st.mtu, nil
}
// Name implements tun.Device.
func (*MultihopTun) Name() (string, error) {
return "stun", nil
}
// Write implements tun.Device.
func (st *MultihopTun) Write(packets [][]byte, offset int) (int, error) {
for i, packet := range packets {
completion := make(chan packetBatch)
packetBatch := packetBatch{
packet: packet,
offset: offset,
size: len(packet),
completion: completion,
}
select {
case st.writeRecv <- packetBatch:
break
case <-st.shutdownChan:
return i, io.EOF
}
packetBatch, ok := <-completion
if !ok {
return i, io.EOF
}
}
return len(packets), nil
}
// Read implements tun.Device.
func (st *MultihopTun) Read(packet [][]byte, sizes []int, offset int) (n int, err error) {
completion := make(chan packetBatch)
packetBatch := packetBatch{
packet: packet[0],
size: 0,
offset: offset,
completion: completion,
}
select {
case st.readRecv <- packetBatch:
break
case <-st.shutdownChan:
return 0, io.EOF
}
var ok bool
packetBatch, ok = <-completion
if !ok {
return 0, io.EOF
}
sizes[0] = packetBatch.size
return 1, nil
}
func (st *MultihopTun) writePayload(target, payload []byte) (size int, err error) {
headerSize := st.headerSize()
if headerSize+len(payload) > len(target) {
err = errors.New(fmt.Sprintf("target buffer is too small, need %d, got %d", headerSize+len(payload), len(target)))
return
}
if st.isIpv4 {
return st.writeV4Payload(target, payload)
} else {
return st.writeV6Payload(target, payload)
}
}
func (st *MultihopTun) writeV4Payload(target, payload []byte) (size int, err error) {
var ipv4 header.IPv4
ipv4 = target
size = st.headerSize() + len(payload)
src := tcpip.AddrFrom4Slice(st.localIp)
dst := tcpip.AddrFrom4Slice(st.remoteIp)
fields := header.IPv4Fields{
// TODO: Figure out the best DSCP value, ideally would be 0x88 for handshakes and 0x00 for rest.
TOS: 0,
TotalLength: uint16(size),
ID: st.ipConnectionId,
TTL: 64,
Protocol: uint8(header.UDPProtocolNumber),
SrcAddr: src,
DstAddr: dst,
Checksum: 0,
}
ipv4.Encode(&fields)
ipv4.SetChecksum(^ipv4.CalculateChecksum())
st.writeUdpPayload(ipv4.Payload(), payload, src, dst)
return
}
func (st *MultihopTun) writeV6Payload(target, payload []byte) (size int, err error) {
var ipv6 header.IPv6
ipv6 = target
size = st.headerSize() + len(payload)
src := tcpip.AddrFrom4Slice(st.localIp)
dst := tcpip.AddrFrom4Slice(st.remoteIp)
fields := header.IPv6Fields{
TrafficClass: 0,
PayloadLength: uint16(len(payload)),
FlowLabel: uint32(st.ipConnectionId),
TransportProtocol: header.UDPProtocolNumber,
SrcAddr: src,
DstAddr: dst,
HopLimit: 64,
}
ipv6.Encode(&fields)
st.writeUdpPayload(ipv6.Payload(), payload, src, dst)
return
}
func (st *MultihopTun) writeUdpPayload(target header.UDP, payload []byte, src, dst tcpip.Address) {
target.Encode(&header.UDPFields{
SrcPort: st.localPort,
DstPort: st.remotePort,
Length: uint16(len(payload) + header.UDPMinimumSize),
Checksum: 0,
})
copy(target.Payload()[:], payload[:])
// Set the checksum field unless TX checksum offload is enabled.
// On IPv4, UDP checksum is optional, and a zero value indicates the
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
xsum := target.CalculateChecksum(checksum.Combine(
header.PseudoHeaderChecksum(header.UDPProtocolNumber, src, dst, uint16(len(payload)+header.UDPMinimumSize)),
checksum.Checksum(target, 0),
))
// As per RFC 768 page 2,
//
// Checksum is the 16-bit one's complement of the one's complement sum of
// a pseudo header of information from the IP header, the UDP header, and
// the data, padded with zero octets at the end (if necessary) to make a
// multiple of two octets.
//
// The pseudo header conceptually prefixed to the UDP header contains the
// source address, the destination address, the protocol, and the UDP
// length. This information gives protection against misrouted datagrams.
// This checksum procedure is the same as is used in TCP.
//
// If the computed checksum is zero, it is transmitted as all ones (the
// equivalent in one's complement arithmetic). An all zero transmitted
// checksum value means that the transmitter generated no checksum (for
// debugging or for higher level protocols that don't care).
//
// To avoid the zero value, we only calculate the one's complement of the
// one's complement sum if the sum is not all ones.
if xsum != math.MaxUint16 {
xsum = ^xsum
}
target.SetChecksum(0)
}
func (st *MultihopTun) headerSize() int {
udpPacketSize := header.UDPMinimumSize
if st.isIpv4 {
return header.IPv4MinimumSize + udpPacketSize
} else {
return header.IPv6MinimumSize + udpPacketSize
}
}
// BatchSize implements conn.Bind.
func (*MultihopTun) BatchSize() int {
return 128
}
// BatchSize implements conn.Bind.
func (*MultihopTun) Flush() error {
return nil
}
// Close implements tun.Device
func (st *MultihopTun) Close() error {
if !st.closed.Load() {
st.closed.Store(true)
}
close(st.shutdownChan)
return nil
}

507
tun/multihoptun/tun_test.go Normal file
View file

@ -0,0 +1,507 @@
package multihoptun
import (
"bytes"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/netip"
"testing"
"time"
"golang.org/x/crypto/curve25519"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
func TestMultihopTunBind(t *testing.T) {
stIp := netip.AddrFrom4([4]byte{192, 168, 1, 1})
virtualIp := netip.AddrFrom4([4]byte{192, 168, 1, 11})
remotePort := uint16(5005)
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
_ = device.NewDevice(&st, st.Binder(), device.NewLogger(device.LogLevelSilent, ""))
}
func TestMultihopTunTrafficV4(t *testing.T) {
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
remotePort := uint16(5005)
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
stBind := st.Binder()
virtualTun, virtualNet, _ := netstack.CreateNetTUN([]netip.Addr{virtualIp}, []netip.Addr{}, 1280)
// Pipe reads from virtualTun into multihop tun
go func() {
buf := make([][]byte, 1600)
sizes := make([]int, 1)
var err error
for err == nil {
_, err = virtualTun.Read(buf, sizes, 0)
_, err = st.Write(buf[:sizes[0]], 0)
}
}()
// Pipe reads from multihop tun into virtualTun
go func() {
buf := make([][]byte, 1600)
sizes := make([]int, 1)
var err error
for err == nil {
_, err = st.Read(buf, sizes, 0)
_, err = virtualTun.Write(buf[:sizes[0]], 0)
}
}()
recvFunc, _, err := stBind.Open(0)
if err != nil {
t.Fatalf("Failed to open port for multihop tun: %s", err)
}
payload := [][]byte{{1, 2, 3, 4}}
readyChan := make(chan struct{})
// Listen on the virtual tunnel
go func() {
conn, err := virtualNet.ListenUDPAddrPort(netip.AddrPortFrom(virtualIp, remotePort))
if err != nil {
panic(err)
}
readyChan <- struct{}{}
buff := make([]byte, 4)
n, addr, _ := conn.ReadFrom(buff)
if n == 0 {
fmt.Println("Did not receive anything")
}
conn.WriteTo(buff, addr)
}()
_, _ = <-readyChan
err = stBind.Send(payload, nil)
if err != nil {
t.Fatalf("Failed ot send traffic to multihop tun: %s", err)
}
recvBuf := make([][]byte, 1600)
sizes := make([]int, 1)
eps := make([]conn.Endpoint, 1)
_, err = recvFunc[0](recvBuf, sizes, eps)
packetSize := sizes[0]
if err != nil {
t.Fatalf("Failed to receive traffic from recvFunc - %s", err)
}
if packetSize != len(payload) {
t.Fatalf("Expected to recieve %d bytes, instead received %d", len(payload), packetSize)
}
for idx := range payload {
if payload[0][idx] != recvBuf[0][idx] {
t.Fatalf("Expected to receive %v, instead received %v", payload, recvBuf[0])
}
}
}
func TestReadEnd(t *testing.T) {
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
remotePort := uint16(5005)
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
stBind := st.Binder()
otherSt := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
readerDev := device.NewDevice(&st, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, ""))
otherDev := device.NewDevice(&otherSt, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, ""))
configureDevices(t, readerDev, otherDev)
readerDev.Up()
receivers, port, err := stBind.Open(0)
if err != nil {
t.Fatalf("Failed to open UDP socket: %s", err)
}
if len(receivers) != 1 {
t.Fatalf("Expected 1 receiver func, got %v", len(receivers))
}
if port == 0 {
t.Fatalf("Expected a random port to be assigned, instead got 0")
}
buf := [][]byte{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}}
err = stBind.Send(buf, nil)
if err != nil {
t.Fatalf("Error when sending UDP traffic: %v", err)
}
}
func TestMultihopTunWrite(t *testing.T) {
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
remotePort := uint16(5005)
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
stBind := st.Binder()
receivers, port, err := stBind.Open(0)
if err != nil {
t.Fatalf("Failed to open UDP socket: %s", err)
}
if len(receivers) != 1 {
t.Fatalf("Expected 1 receiver func, got %v", len(receivers))
}
if port == 0 {
t.Fatalf("Expected a random port to be assigned, instead got 0")
}
udpPacket := [][]byte{{69, 0, 0, 32, 164, 27, 0, 0, 64, 17, 206, 165, 1, 2, 3, 5, 1, 2, 3, 4, 209, 129, 19, 141, 0, 12, 0, 0, 1, 2, 3, 4}}
if err != nil {
t.Fatalf("Error when sending UDP traffic: %v", err)
}
go func() {
st.Write(udpPacket, 0)
}()
buf := make([][]byte, 1600)
sizes := make([]int, 1)
eps := make([]conn.Endpoint, 1)
_, err = receivers[0](buf, sizes, eps)
packetSize := sizes[0]
if err != nil {
t.Fatalf("Failed to receive packets: %s", err)
}
expected := []byte{1, 2, 3, 4}
if len(buf[:packetSize]) != len(expected) {
t.Fatalf("Expected %v, got %v", expected, buf[0])
}
for b := range buf[:packetSize] {
if buf[0][b] != expected[b] {
t.Fatalf("Expected %v, got %v", expected, buf[0])
}
}
}
func TestMultihopTunRead(t *testing.T) {
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
remotePort := uint16(5005)
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
stBind := st.Binder()
_, _, err := stBind.Open(0)
if err != nil {
t.Fatalf("Failed to open UDP socket: %s", err)
}
payload := [][]byte{{1, 2, 3, 4}}
go stBind.Send(payload, nil)
bytes := make([][]byte, 1500, 1500)
sizes := make([]int, 1)
_, err = st.Read(bytes, sizes, 0)
bytesRead := sizes[0]
if err != nil {
t.Fatalf("Failed to read from tunnel device: %v", err)
}
packet := header.IPv4(bytes[0][:bytesRead])
virtualIpBytes, _ := virtualIp.MarshalBinary()
stIpBytes, _ := stIp.MarshalBinary()
if packet.SourceAddress() != tcpip.AddrFromSlice(stIpBytes) {
t.Fatalf("expected %v, got %v", stIp, packet.SourceAddress())
}
if packet.DestinationAddress() != tcpip.AddrFromSlice(virtualIpBytes) {
t.Fatalf("expected %v, got %v", virtualIp, packet.DestinationAddress())
}
}
func configureDevices(t testing.TB, aDev *device.Device, bDev *device.Device) {
configs, endpointConfigs, _ := genConfigs(t)
aConfig := configs[0] + endpointConfigs[0]
bConfig := configs[1] + endpointConfigs[1]
aDev.IpcSet(aConfig)
bDev.IpcSet(bConfig)
}
func genConfigsForMultihop(t testing.TB) ([4]string, [4]uint16) {
entryConfigs, entryEndpoints, entryPorts := genConfigs(t)
exitConfigs, exitEndpoints, exitPorts := genConfigs(t)
aExitConfig := exitConfigs[0] + exitEndpoints[0]
bExitConfig := exitConfigs[1] + exitEndpoints[1]
aEntryConfig := entryConfigs[0] + entryEndpoints[0]
bEntryConfig := entryConfigs[1] + entryEndpoints[1]
ports := [4]uint16{entryPorts[0], exitPorts[0], exitPorts[1], entryPorts[1]}
return [4]string{aEntryConfig, aExitConfig, bExitConfig, bEntryConfig}, ports
}
// genConfigs generates a pair of configs that connect to each other.
// The configs use distinct, probably-usable ports.
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string, ports [2]uint16) {
var key1, key2 device.NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
_, err = rand.Read(key2[:])
if err != nil {
tb.Errorf("unable to generate private key random bytes: %v", err)
}
ports[0] = getFreeLocalUdpPort(tb)
ports[1] = getFreeLocalUdpPort(tb)
pub1, pub2 := publicKey(&key1), publicKey(&key2)
cfgs[0] = uapiCfg(
"private_key", hex.EncodeToString(key1[:]),
"listen_port", fmt.Sprintf("%d", ports[0]),
"replace_peers", "true",
"public_key", hex.EncodeToString(pub2[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "0.0.0.0/0",
)
endpointCfgs[0] = uapiCfg(
"public_key", hex.EncodeToString(pub2[:]),
"endpoint", fmt.Sprintf("127.0.0.1:%d", ports[1]),
)
cfgs[1] = uapiCfg(
"private_key", hex.EncodeToString(key2[:]),
"listen_port", fmt.Sprintf("%d", ports[1]),
"replace_peers", "true",
"public_key", hex.EncodeToString(pub1[:]),
"protocol_version", "1",
"replace_allowed_ips", "true",
"allowed_ip", "0.0.0.0/0",
)
endpointCfgs[1] = uapiCfg(
"public_key", hex.EncodeToString(pub1[:]),
"endpoint", fmt.Sprintf("127.0.0.1:%d", ports[0]),
)
return
}
func publicKey(sk *device.NoisePrivateKey) (pk device.NoisePublicKey) {
apk := (*[device.NoisePublicKeySize]byte)(&pk)
ask := (*[device.NoisePrivateKeySize]byte)(sk)
curve25519.ScalarBaseMult(apk, ask)
return
}
func getFreeLocalUdpPort(t testing.TB) uint16 {
localAddr := netip.MustParseAddrPort("127.0.0.1:0")
udpSockAddr := net.UDPAddrFromAddrPort(localAddr)
udpConn, err := net.ListenUDP("udp4", udpSockAddr)
if err != nil {
t.Fatalf("Failed to open a UDP socket to assign an empty port")
}
defer udpConn.Close()
port := netip.MustParseAddrPort(udpConn.LocalAddr().String()).Port()
return port
}
func uapiCfg(cfg ...string) string {
if len(cfg)%2 != 0 {
panic("odd number of args to uapiReader")
}
buf := new(bytes.Buffer)
for i, s := range cfg {
buf.WriteString(s)
sep := byte('\n')
if i%2 == 0 {
sep = '='
}
buf.WriteByte(sep)
}
return buf.String()
}
func TestShutdown(t *testing.T) {
a, b := generateTestPair(t)
b.Close()
a.Close()
}
func TestReversedShutdown(t *testing.T) {
a, b := generateTestPair(t)
a.Close()
b.Close()
}
func generateTestPair(t *testing.T) (*device.Device, *device.Device) {
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
remotePort := uint16(5005)
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
stBind := st.Binder()
virtualDev, virtualNet, _ := netstack.CreateNetTUN([]netip.Addr{virtualIp}, []netip.Addr{}, 1280)
readerDev := device.NewDevice(virtualDev, stBind, device.NewLogger(device.LogLevelSilent, ""))
otherDev := device.NewDevice(&st, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, ""))
configureDevices(t, readerDev, otherDev)
readerDev.Up()
otherDev.Up()
conn, err := virtualNet.Dial("ping4", "10.64.0.1")
requestPing := icmp.Echo{
Seq: 345,
Data: []byte("gopher burrow"),
}
icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
conn.SetReadDeadline(time.Now().Add(time.Second * 9))
_, err = conn.Write(icmpBytes)
if err != nil {
t.Fatal(err)
}
return readerDev, otherDev
}
func TestShutdownBind(t *testing.T) {
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
remotePort := uint16(5005)
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
binder := st.Binder()
recvFunc, _, err := binder.Open(0)
if err != nil {
t.Fatalf("Failed to open a UDP socket, %v", err)
}
st.Close()
buf := make([][]byte, 1600)
sizes := make([]int, 1)
eps := make([]conn.Endpoint, 1)
_, err = recvFunc[0](buf, sizes, eps)
neterr, ok := err.(net.Error)
if !ok {
t.Fatalf("Expected a net.Error, instead got %v", err)
}
if neterr.Temporary() {
t.Fatalf("Expected the net error to not be temporary")
}
}
func TestMultihopLocally(t *testing.T) {
aVirtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
bVirtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
configsForMultihop, ports := genConfigsForMultihop(t)
multihopA := NewMultihopTun(aVirtualIp, netip.MustParseAddr(fmt.Sprintf("127.0.0.1")), ports[3], 1280)
multihopB := NewMultihopTun(bVirtualIp, netip.MustParseAddr(fmt.Sprintf("127.0.0.1")), ports[0], 1280)
aBinder := multihopA.Binder()
bBinder := multihopB.Binder()
virtualDevA, virtualNetA, _ := netstack.CreateNetTUN([]netip.Addr{aVirtualIp}, []netip.Addr{}, 1280)
virtualDevB, virtualNetB, _ := netstack.CreateNetTUN([]netip.Addr{bVirtualIp}, []netip.Addr{}, 1280)
aExitDevice := device.NewDevice(virtualDevA, aBinder, device.NewLogger(device.LogLevelVerbose, ""))
aExitDevice.IpcSet(configsForMultihop[0])
aEntryDevice := device.NewDevice(&multihopA, conn.NewStdNetBind(), device.NewLogger(device.LogLevelVerbose, ""))
aEntryDevice.IpcSet(configsForMultihop[1])
bEntryDevice := device.NewDevice(&multihopB, conn.NewStdNetBind(), device.NewLogger(device.LogLevelVerbose, ""))
bEntryDevice.IpcSet(configsForMultihop[2])
bExitDevice := device.NewDevice(virtualDevB, bBinder, device.NewLogger(device.LogLevelVerbose, ""))
bExitDevice.IpcSet(configsForMultihop[3])
err := aExitDevice.Up()
if err != nil {
t.Fatalf("exit device a failed to up itself: %v", err)
}
err = aEntryDevice.Up()
if err != nil {
t.Fatalf("entry device a failed to up itself: %v", err)
}
err = bExitDevice.Up()
if err != nil {
t.Fatalf("exit device b failed to up itself: %v", err)
}
err = bEntryDevice.Up()
if err != nil {
t.Fatalf("entry device b failed to up itself: %v", err)
}
listenerAddr := netip.AddrPortFrom(bVirtualIp, 7070)
senderAddr := netip.AddrPortFrom(aVirtualIp, 4040)
listenerSocket, err := virtualNetB.ListenUDPAddrPort(netip.AddrPortFrom(bVirtualIp, 7070))
if err != nil {
t.Fatalf("Fail to open listener socket: %v", err)
}
senderSocket, err := virtualNetA.DialUDPAddrPort(senderAddr, listenerAddr)
if err != nil {
t.Fatalf("Failed to open sender socket: %v", err)
}
payload := []byte{1, 2, 3, 4, 5}
n, err := senderSocket.Write(payload)
if err != nil {
t.Fatalf("Failed to send payload: %v", err)
}
if n != len(payload) {
t.Fatalf("Expected to send %v bytes, instead sent %v", len(payload), n)
}
rxBuffer := []byte{1, 2, 3, 4, 5}
n, err = listenerSocket.Read(rxBuffer)
if err != nil {
t.Fatalf("Failed to receive payload: %v", err)
}
if n != len(payload) {
t.Fatalf("Expected to read %v bytes, instead read %v bytes", len(payload), n)
}
for idx := range rxBuffer {
if rxBuffer[idx] != payload[idx] {
t.Fatalf("At index %d, expected value %d, instead got %v", idx, rxBuffer[idx], payload[idx])
}
}
aEntryDevice.Close()
aExitDevice.Close()
bEntryDevice.Close()
bExitDevice.Close()
}