mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-04-16 22:16:55 +02:00
Port multihop tun
This commit is contained in:
parent
2e3f7d122c
commit
eece51b547
5 changed files with 943 additions and 0 deletions
1
go.mod
1
go.mod
|
@ -14,4 +14,5 @@ require (
|
|||
require (
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
)
|
||||
|
|
2
go.sum
2
go.sum
|
@ -12,5 +12,7 @@ golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0k
|
|||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||
|
|
126
tun/multihoptun/bind.go
Normal file
126
tun/multihoptun/bind.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
package multihoptun
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
)
|
||||
|
||||
type multihopBind struct {
|
||||
*MultihopTun
|
||||
socketShutdown chan struct{}
|
||||
}
|
||||
|
||||
// Close implements tun.Device
|
||||
func (st *multihopBind) Close() error {
|
||||
select {
|
||||
case <-st.socketShutdown:
|
||||
return nil
|
||||
default:
|
||||
close(st.socketShutdown)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Open implements conn.Bind.
|
||||
func (st *multihopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||
if port != 0 {
|
||||
st.localPort = port
|
||||
} else {
|
||||
st.localPort = uint16(rand.Uint32()>>16) | 1
|
||||
}
|
||||
// WireGuard will close existing sockets before bringing up a new device on Bind updates.
|
||||
// This guarantees that the socket shutdown channel is always available.
|
||||
st.socketShutdown = make(chan struct{})
|
||||
|
||||
actualPort = st.localPort
|
||||
fns = []conn.ReceiveFunc{
|
||||
func(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||
var batch packetBatch
|
||||
var ok bool
|
||||
|
||||
select {
|
||||
case <-st.shutdownChan:
|
||||
return 0, net.ErrClosed
|
||||
case <-st.socketShutdown:
|
||||
return 0, net.ErrClosed
|
||||
case batch, ok = <-st.writeRecv:
|
||||
break
|
||||
}
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
ipVersion := header.IPVersion(batch.packet[batch.offset:])
|
||||
if ipVersion == 4 {
|
||||
v4 := header.IPv4(batch.packet[batch.offset:])
|
||||
udp := header.UDP(v4.Payload())
|
||||
copy(packets[0], udp.Payload())
|
||||
sizes[0] = len(udp.Payload())
|
||||
|
||||
} else if ipVersion == 6 {
|
||||
v6 := header.IPv6(batch.packet[batch.offset:])
|
||||
udp := header.UDP(v6.Payload())
|
||||
copy(packets[0], udp.Payload())
|
||||
sizes[0] = len(udp.Payload())
|
||||
}
|
||||
batch.size = sizes[0]
|
||||
eps[0] = st.endpoint
|
||||
|
||||
batch.completion <- batch
|
||||
return 1, nil
|
||||
},
|
||||
}
|
||||
|
||||
return fns, actualPort, nil
|
||||
}
|
||||
|
||||
// ParseEndpoint implements conn.Bind.
|
||||
func (*multihopBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
return conn.NewStdNetBind().ParseEndpoint(s)
|
||||
}
|
||||
|
||||
// Send implements conn.Bind.
|
||||
func (st *multihopBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||
var packetBatch packetBatch
|
||||
var ok bool
|
||||
|
||||
for _, buf := range bufs {
|
||||
select {
|
||||
case <-st.shutdownChan:
|
||||
return net.ErrClosed
|
||||
case <-st.socketShutdown:
|
||||
// it is important to return a net.ErrClosed, since it implements the
|
||||
// net.Error interface and indicates that it is not a recoverable error.
|
||||
// wg-go uses the net.Error interface to deduce if it should try to send
|
||||
// packets again after some time or if it should give up.
|
||||
return net.ErrClosed
|
||||
case packetBatch, ok = <-st.readRecv:
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
targetPacket := packetBatch.packet[packetBatch.offset:]
|
||||
size, err := st.writePayload(targetPacket, buf)
|
||||
|
||||
packetBatch.size = size
|
||||
|
||||
packetBatch.completion <- packetBatch
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMark implements conn.Bind.
|
||||
func (*multihopBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
}
|
307
tun/multihoptun/tun.go
Normal file
307
tun/multihoptun/tun.go
Normal file
|
@ -0,0 +1,307 @@
|
|||
package multihoptun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
)
|
||||
|
||||
// This is a special implementation of `tun.Device` that allows to connect a
|
||||
// `conn.Bind` from one WireGuard device to another's `tun.Device`. This way,
|
||||
// we can create a multi-hop WireGuard device that can use the same private key
|
||||
// and elude any MTU issues, since, for any single user packet, there is only
|
||||
// ever a single read from the real tunnel device needed to send it to the
|
||||
// entry hop.
|
||||
//
|
||||
// tun.Device.Write will push a buffer via writeRecv to be read by the recvfunc
|
||||
// of conn.Bind, stripping IPv4/IPv6 + UDP headers in the process. When the
|
||||
// packets have been transferred to the UDP receiver, writeDone will be used to
|
||||
// return from tun.Device.Write. Conversely, conn.Bind.Send will push a buffer
|
||||
// via readRecv to be read by tun.Device.Read, adding valid IPv4/IPv6 + UDP
|
||||
// headers in the process.
|
||||
//
|
||||
// Implements tun.Device and can create instances of conn.Bind.
|
||||
type MultihopTun struct {
|
||||
readRecv chan packetBatch
|
||||
writeRecv chan packetBatch
|
||||
isIpv4 bool
|
||||
localIp []byte
|
||||
localPort uint16
|
||||
remoteIp []byte
|
||||
remotePort uint16
|
||||
ipConnectionId uint16
|
||||
tunEvent chan tun.Event
|
||||
mtu int
|
||||
endpoint conn.Endpoint
|
||||
closed atomic.Bool
|
||||
shutdownChan chan struct{}
|
||||
}
|
||||
|
||||
type packetBatch struct {
|
||||
packet []byte
|
||||
size int
|
||||
offset int
|
||||
// to be used to return the packet batch back to tun.Read and tun.Write
|
||||
completion chan packetBatch
|
||||
}
|
||||
|
||||
func (pb *packetBatch) Size() int {
|
||||
return len(pb.packet)
|
||||
}
|
||||
|
||||
func NewMultihopTun(local, remote netip.Addr, remotePort uint16, mtu int) MultihopTun {
|
||||
readRecv := make(chan packetBatch)
|
||||
writeRecv := make(chan packetBatch)
|
||||
endpoint, err := conn.NewStdNetBind().ParseEndpoint(netip.AddrPortFrom(remote, remotePort).String())
|
||||
if err != nil {
|
||||
panic("Failed to parse endpoint")
|
||||
}
|
||||
|
||||
connectionId := uint16(rand.Uint32()>>16) | 1
|
||||
shutdownChan := make(chan struct{})
|
||||
|
||||
return MultihopTun{
|
||||
readRecv,
|
||||
writeRecv,
|
||||
local.Is4(),
|
||||
local.AsSlice(),
|
||||
0,
|
||||
remote.AsSlice(),
|
||||
remotePort,
|
||||
connectionId,
|
||||
make(chan tun.Event),
|
||||
mtu,
|
||||
endpoint,
|
||||
atomic.Bool{},
|
||||
shutdownChan,
|
||||
}
|
||||
}
|
||||
|
||||
func (st *MultihopTun) Binder() conn.Bind {
|
||||
socketShutdown := make(chan struct{})
|
||||
return &multihopBind{
|
||||
st,
|
||||
socketShutdown,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Events implements tun.Device.
|
||||
func (st *MultihopTun) Events() <-chan tun.Event {
|
||||
return st.tunEvent
|
||||
}
|
||||
|
||||
// File implements tun.Device.
|
||||
func (*MultihopTun) File() *os.File {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MTU implements tun.Device.
|
||||
func (st *MultihopTun) MTU() (int, error) {
|
||||
return st.mtu, nil
|
||||
}
|
||||
|
||||
// Name implements tun.Device.
|
||||
func (*MultihopTun) Name() (string, error) {
|
||||
return "stun", nil
|
||||
}
|
||||
|
||||
// Write implements tun.Device.
|
||||
func (st *MultihopTun) Write(packets [][]byte, offset int) (int, error) {
|
||||
for i, packet := range packets {
|
||||
completion := make(chan packetBatch)
|
||||
packetBatch := packetBatch{
|
||||
packet: packet,
|
||||
offset: offset,
|
||||
size: len(packet),
|
||||
completion: completion,
|
||||
}
|
||||
|
||||
select {
|
||||
case st.writeRecv <- packetBatch:
|
||||
break
|
||||
case <-st.shutdownChan:
|
||||
return i, io.EOF
|
||||
}
|
||||
|
||||
packetBatch, ok := <-completion
|
||||
|
||||
if !ok {
|
||||
return i, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
return len(packets), nil
|
||||
}
|
||||
|
||||
// Read implements tun.Device.
|
||||
func (st *MultihopTun) Read(packet [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
completion := make(chan packetBatch)
|
||||
packetBatch := packetBatch{
|
||||
packet: packet[0],
|
||||
size: 0,
|
||||
offset: offset,
|
||||
completion: completion,
|
||||
}
|
||||
|
||||
select {
|
||||
case st.readRecv <- packetBatch:
|
||||
break
|
||||
case <-st.shutdownChan:
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
var ok bool
|
||||
packetBatch, ok = <-completion
|
||||
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
sizes[0] = packetBatch.size
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (st *MultihopTun) writePayload(target, payload []byte) (size int, err error) {
|
||||
headerSize := st.headerSize()
|
||||
if headerSize+len(payload) > len(target) {
|
||||
err = errors.New(fmt.Sprintf("target buffer is too small, need %d, got %d", headerSize+len(payload), len(target)))
|
||||
return
|
||||
}
|
||||
|
||||
if st.isIpv4 {
|
||||
return st.writeV4Payload(target, payload)
|
||||
} else {
|
||||
return st.writeV6Payload(target, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func (st *MultihopTun) writeV4Payload(target, payload []byte) (size int, err error) {
|
||||
var ipv4 header.IPv4
|
||||
ipv4 = target
|
||||
|
||||
size = st.headerSize() + len(payload)
|
||||
src := tcpip.AddrFrom4Slice(st.localIp)
|
||||
dst := tcpip.AddrFrom4Slice(st.remoteIp)
|
||||
fields := header.IPv4Fields{
|
||||
// TODO: Figure out the best DSCP value, ideally would be 0x88 for handshakes and 0x00 for rest.
|
||||
TOS: 0,
|
||||
TotalLength: uint16(size),
|
||||
ID: st.ipConnectionId,
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.UDPProtocolNumber),
|
||||
SrcAddr: src,
|
||||
DstAddr: dst,
|
||||
Checksum: 0,
|
||||
}
|
||||
ipv4.Encode(&fields)
|
||||
ipv4.SetChecksum(^ipv4.CalculateChecksum())
|
||||
st.writeUdpPayload(ipv4.Payload(), payload, src, dst)
|
||||
return
|
||||
}
|
||||
|
||||
func (st *MultihopTun) writeV6Payload(target, payload []byte) (size int, err error) {
|
||||
|
||||
var ipv6 header.IPv6
|
||||
ipv6 = target
|
||||
|
||||
size = st.headerSize() + len(payload)
|
||||
src := tcpip.AddrFrom4Slice(st.localIp)
|
||||
dst := tcpip.AddrFrom4Slice(st.remoteIp)
|
||||
fields := header.IPv6Fields{
|
||||
TrafficClass: 0,
|
||||
PayloadLength: uint16(len(payload)),
|
||||
FlowLabel: uint32(st.ipConnectionId),
|
||||
TransportProtocol: header.UDPProtocolNumber,
|
||||
SrcAddr: src,
|
||||
DstAddr: dst,
|
||||
HopLimit: 64,
|
||||
}
|
||||
ipv6.Encode(&fields)
|
||||
|
||||
st.writeUdpPayload(ipv6.Payload(), payload, src, dst)
|
||||
return
|
||||
}
|
||||
|
||||
func (st *MultihopTun) writeUdpPayload(target header.UDP, payload []byte, src, dst tcpip.Address) {
|
||||
target.Encode(&header.UDPFields{
|
||||
SrcPort: st.localPort,
|
||||
DstPort: st.remotePort,
|
||||
Length: uint16(len(payload) + header.UDPMinimumSize),
|
||||
Checksum: 0,
|
||||
})
|
||||
copy(target.Payload()[:], payload[:])
|
||||
|
||||
// Set the checksum field unless TX checksum offload is enabled.
|
||||
// On IPv4, UDP checksum is optional, and a zero value indicates the
|
||||
// transmitter skipped the checksum generation (RFC768).
|
||||
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
|
||||
xsum := target.CalculateChecksum(checksum.Combine(
|
||||
header.PseudoHeaderChecksum(header.UDPProtocolNumber, src, dst, uint16(len(payload)+header.UDPMinimumSize)),
|
||||
checksum.Checksum(target, 0),
|
||||
))
|
||||
// As per RFC 768 page 2,
|
||||
//
|
||||
// Checksum is the 16-bit one's complement of the one's complement sum of
|
||||
// a pseudo header of information from the IP header, the UDP header, and
|
||||
// the data, padded with zero octets at the end (if necessary) to make a
|
||||
// multiple of two octets.
|
||||
//
|
||||
// The pseudo header conceptually prefixed to the UDP header contains the
|
||||
// source address, the destination address, the protocol, and the UDP
|
||||
// length. This information gives protection against misrouted datagrams.
|
||||
// This checksum procedure is the same as is used in TCP.
|
||||
//
|
||||
// If the computed checksum is zero, it is transmitted as all ones (the
|
||||
// equivalent in one's complement arithmetic). An all zero transmitted
|
||||
// checksum value means that the transmitter generated no checksum (for
|
||||
// debugging or for higher level protocols that don't care).
|
||||
//
|
||||
// To avoid the zero value, we only calculate the one's complement of the
|
||||
// one's complement sum if the sum is not all ones.
|
||||
if xsum != math.MaxUint16 {
|
||||
xsum = ^xsum
|
||||
}
|
||||
target.SetChecksum(0)
|
||||
|
||||
}
|
||||
|
||||
func (st *MultihopTun) headerSize() int {
|
||||
udpPacketSize := header.UDPMinimumSize
|
||||
if st.isIpv4 {
|
||||
return header.IPv4MinimumSize + udpPacketSize
|
||||
} else {
|
||||
return header.IPv6MinimumSize + udpPacketSize
|
||||
}
|
||||
}
|
||||
|
||||
// BatchSize implements conn.Bind.
|
||||
func (*MultihopTun) BatchSize() int {
|
||||
return 128
|
||||
}
|
||||
|
||||
// BatchSize implements conn.Bind.
|
||||
func (*MultihopTun) Flush() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements tun.Device
|
||||
func (st *MultihopTun) Close() error {
|
||||
if !st.closed.Load() {
|
||||
st.closed.Store(true)
|
||||
}
|
||||
close(st.shutdownChan)
|
||||
return nil
|
||||
}
|
507
tun/multihoptun/tun_test.go
Normal file
507
tun/multihoptun/tun_test.go
Normal file
|
@ -0,0 +1,507 @@
|
|||
package multihoptun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
)
|
||||
|
||||
func TestMultihopTunBind(t *testing.T) {
|
||||
stIp := netip.AddrFrom4([4]byte{192, 168, 1, 1})
|
||||
virtualIp := netip.AddrFrom4([4]byte{192, 168, 1, 11})
|
||||
remotePort := uint16(5005)
|
||||
|
||||
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
|
||||
|
||||
_ = device.NewDevice(&st, st.Binder(), device.NewLogger(device.LogLevelSilent, ""))
|
||||
}
|
||||
|
||||
func TestMultihopTunTrafficV4(t *testing.T) {
|
||||
|
||||
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
|
||||
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
remotePort := uint16(5005)
|
||||
|
||||
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
|
||||
stBind := st.Binder()
|
||||
|
||||
virtualTun, virtualNet, _ := netstack.CreateNetTUN([]netip.Addr{virtualIp}, []netip.Addr{}, 1280)
|
||||
|
||||
// Pipe reads from virtualTun into multihop tun
|
||||
go func() {
|
||||
buf := make([][]byte, 1600)
|
||||
sizes := make([]int, 1)
|
||||
var err error
|
||||
for err == nil {
|
||||
_, err = virtualTun.Read(buf, sizes, 0)
|
||||
_, err = st.Write(buf[:sizes[0]], 0)
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
// Pipe reads from multihop tun into virtualTun
|
||||
go func() {
|
||||
buf := make([][]byte, 1600)
|
||||
sizes := make([]int, 1)
|
||||
var err error
|
||||
for err == nil {
|
||||
_, err = st.Read(buf, sizes, 0)
|
||||
_, err = virtualTun.Write(buf[:sizes[0]], 0)
|
||||
}
|
||||
}()
|
||||
|
||||
recvFunc, _, err := stBind.Open(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open port for multihop tun: %s", err)
|
||||
}
|
||||
|
||||
payload := [][]byte{{1, 2, 3, 4}}
|
||||
readyChan := make(chan struct{})
|
||||
// Listen on the virtual tunnel
|
||||
go func() {
|
||||
conn, err := virtualNet.ListenUDPAddrPort(netip.AddrPortFrom(virtualIp, remotePort))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
readyChan <- struct{}{}
|
||||
buff := make([]byte, 4)
|
||||
n, addr, _ := conn.ReadFrom(buff)
|
||||
if n == 0 {
|
||||
fmt.Println("Did not receive anything")
|
||||
}
|
||||
|
||||
conn.WriteTo(buff, addr)
|
||||
}()
|
||||
_, _ = <-readyChan
|
||||
|
||||
err = stBind.Send(payload, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed ot send traffic to multihop tun: %s", err)
|
||||
}
|
||||
|
||||
recvBuf := make([][]byte, 1600)
|
||||
sizes := make([]int, 1)
|
||||
eps := make([]conn.Endpoint, 1)
|
||||
_, err = recvFunc[0](recvBuf, sizes, eps)
|
||||
packetSize := sizes[0]
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive traffic from recvFunc - %s", err)
|
||||
}
|
||||
if packetSize != len(payload) {
|
||||
t.Fatalf("Expected to recieve %d bytes, instead received %d", len(payload), packetSize)
|
||||
}
|
||||
|
||||
for idx := range payload {
|
||||
if payload[0][idx] != recvBuf[0][idx] {
|
||||
t.Fatalf("Expected to receive %v, instead received %v", payload, recvBuf[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadEnd(t *testing.T) {
|
||||
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
|
||||
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
remotePort := uint16(5005)
|
||||
|
||||
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
|
||||
stBind := st.Binder()
|
||||
otherSt := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
|
||||
|
||||
readerDev := device.NewDevice(&st, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, ""))
|
||||
otherDev := device.NewDevice(&otherSt, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, ""))
|
||||
|
||||
configureDevices(t, readerDev, otherDev)
|
||||
|
||||
readerDev.Up()
|
||||
receivers, port, err := stBind.Open(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open UDP socket: %s", err)
|
||||
}
|
||||
if len(receivers) != 1 {
|
||||
t.Fatalf("Expected 1 receiver func, got %v", len(receivers))
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
t.Fatalf("Expected a random port to be assigned, instead got 0")
|
||||
}
|
||||
|
||||
buf := [][]byte{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}}
|
||||
|
||||
err = stBind.Send(buf, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when sending UDP traffic: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultihopTunWrite(t *testing.T) {
|
||||
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
|
||||
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
remotePort := uint16(5005)
|
||||
|
||||
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
|
||||
stBind := st.Binder()
|
||||
|
||||
receivers, port, err := stBind.Open(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open UDP socket: %s", err)
|
||||
}
|
||||
if len(receivers) != 1 {
|
||||
t.Fatalf("Expected 1 receiver func, got %v", len(receivers))
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
t.Fatalf("Expected a random port to be assigned, instead got 0")
|
||||
}
|
||||
|
||||
udpPacket := [][]byte{{69, 0, 0, 32, 164, 27, 0, 0, 64, 17, 206, 165, 1, 2, 3, 5, 1, 2, 3, 4, 209, 129, 19, 141, 0, 12, 0, 0, 1, 2, 3, 4}}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error when sending UDP traffic: %v", err)
|
||||
}
|
||||
go func() {
|
||||
st.Write(udpPacket, 0)
|
||||
}()
|
||||
|
||||
buf := make([][]byte, 1600)
|
||||
sizes := make([]int, 1)
|
||||
eps := make([]conn.Endpoint, 1)
|
||||
|
||||
_, err = receivers[0](buf, sizes, eps)
|
||||
packetSize := sizes[0]
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive packets: %s", err)
|
||||
}
|
||||
|
||||
expected := []byte{1, 2, 3, 4}
|
||||
if len(buf[:packetSize]) != len(expected) {
|
||||
t.Fatalf("Expected %v, got %v", expected, buf[0])
|
||||
}
|
||||
|
||||
for b := range buf[:packetSize] {
|
||||
if buf[0][b] != expected[b] {
|
||||
t.Fatalf("Expected %v, got %v", expected, buf[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultihopTunRead(t *testing.T) {
|
||||
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
|
||||
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
remotePort := uint16(5005)
|
||||
|
||||
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
|
||||
stBind := st.Binder()
|
||||
|
||||
_, _, err := stBind.Open(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open UDP socket: %s", err)
|
||||
}
|
||||
|
||||
payload := [][]byte{{1, 2, 3, 4}}
|
||||
go stBind.Send(payload, nil)
|
||||
|
||||
bytes := make([][]byte, 1500, 1500)
|
||||
sizes := make([]int, 1)
|
||||
_, err = st.Read(bytes, sizes, 0)
|
||||
bytesRead := sizes[0]
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read from tunnel device: %v", err)
|
||||
}
|
||||
|
||||
packet := header.IPv4(bytes[0][:bytesRead])
|
||||
virtualIpBytes, _ := virtualIp.MarshalBinary()
|
||||
stIpBytes, _ := stIp.MarshalBinary()
|
||||
|
||||
if packet.SourceAddress() != tcpip.AddrFromSlice(stIpBytes) {
|
||||
t.Fatalf("expected %v, got %v", stIp, packet.SourceAddress())
|
||||
}
|
||||
|
||||
if packet.DestinationAddress() != tcpip.AddrFromSlice(virtualIpBytes) {
|
||||
t.Fatalf("expected %v, got %v", virtualIp, packet.DestinationAddress())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func configureDevices(t testing.TB, aDev *device.Device, bDev *device.Device) {
|
||||
configs, endpointConfigs, _ := genConfigs(t)
|
||||
aConfig := configs[0] + endpointConfigs[0]
|
||||
bConfig := configs[1] + endpointConfigs[1]
|
||||
aDev.IpcSet(aConfig)
|
||||
bDev.IpcSet(bConfig)
|
||||
}
|
||||
|
||||
func genConfigsForMultihop(t testing.TB) ([4]string, [4]uint16) {
|
||||
entryConfigs, entryEndpoints, entryPorts := genConfigs(t)
|
||||
exitConfigs, exitEndpoints, exitPorts := genConfigs(t)
|
||||
|
||||
aExitConfig := exitConfigs[0] + exitEndpoints[0]
|
||||
bExitConfig := exitConfigs[1] + exitEndpoints[1]
|
||||
aEntryConfig := entryConfigs[0] + entryEndpoints[0]
|
||||
bEntryConfig := entryConfigs[1] + entryEndpoints[1]
|
||||
|
||||
ports := [4]uint16{entryPorts[0], exitPorts[0], exitPorts[1], entryPorts[1]}
|
||||
|
||||
return [4]string{aEntryConfig, aExitConfig, bExitConfig, bEntryConfig}, ports
|
||||
|
||||
}
|
||||
|
||||
// genConfigs generates a pair of configs that connect to each other.
|
||||
// The configs use distinct, probably-usable ports.
|
||||
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string, ports [2]uint16) {
|
||||
var key1, key2 device.NoisePrivateKey
|
||||
|
||||
_, err := rand.Read(key1[:])
|
||||
if err != nil {
|
||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||
}
|
||||
_, err = rand.Read(key2[:])
|
||||
if err != nil {
|
||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||
}
|
||||
|
||||
ports[0] = getFreeLocalUdpPort(tb)
|
||||
ports[1] = getFreeLocalUdpPort(tb)
|
||||
|
||||
pub1, pub2 := publicKey(&key1), publicKey(&key2)
|
||||
|
||||
cfgs[0] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key1[:]),
|
||||
"listen_port", fmt.Sprintf("%d", ports[0]),
|
||||
"replace_peers", "true",
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "0.0.0.0/0",
|
||||
)
|
||||
endpointCfgs[0] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub2[:]),
|
||||
"endpoint", fmt.Sprintf("127.0.0.1:%d", ports[1]),
|
||||
)
|
||||
cfgs[1] = uapiCfg(
|
||||
"private_key", hex.EncodeToString(key2[:]),
|
||||
"listen_port", fmt.Sprintf("%d", ports[1]),
|
||||
"replace_peers", "true",
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"protocol_version", "1",
|
||||
"replace_allowed_ips", "true",
|
||||
"allowed_ip", "0.0.0.0/0",
|
||||
)
|
||||
endpointCfgs[1] = uapiCfg(
|
||||
"public_key", hex.EncodeToString(pub1[:]),
|
||||
"endpoint", fmt.Sprintf("127.0.0.1:%d", ports[0]),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func publicKey(sk *device.NoisePrivateKey) (pk device.NoisePublicKey) {
|
||||
apk := (*[device.NoisePublicKeySize]byte)(&pk)
|
||||
ask := (*[device.NoisePrivateKeySize]byte)(sk)
|
||||
curve25519.ScalarBaseMult(apk, ask)
|
||||
return
|
||||
}
|
||||
|
||||
func getFreeLocalUdpPort(t testing.TB) uint16 {
|
||||
localAddr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
udpSockAddr := net.UDPAddrFromAddrPort(localAddr)
|
||||
udpConn, err := net.ListenUDP("udp4", udpSockAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open a UDP socket to assign an empty port")
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
port := netip.MustParseAddrPort(udpConn.LocalAddr().String()).Port()
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
func uapiCfg(cfg ...string) string {
|
||||
if len(cfg)%2 != 0 {
|
||||
panic("odd number of args to uapiReader")
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
for i, s := range cfg {
|
||||
buf.WriteString(s)
|
||||
sep := byte('\n')
|
||||
if i%2 == 0 {
|
||||
sep = '='
|
||||
}
|
||||
buf.WriteByte(sep)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
a, b := generateTestPair(t)
|
||||
b.Close()
|
||||
a.Close()
|
||||
}
|
||||
|
||||
func TestReversedShutdown(t *testing.T) {
|
||||
a, b := generateTestPair(t)
|
||||
a.Close()
|
||||
b.Close()
|
||||
}
|
||||
|
||||
func generateTestPair(t *testing.T) (*device.Device, *device.Device) {
|
||||
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
|
||||
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
remotePort := uint16(5005)
|
||||
|
||||
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
|
||||
stBind := st.Binder()
|
||||
|
||||
virtualDev, virtualNet, _ := netstack.CreateNetTUN([]netip.Addr{virtualIp}, []netip.Addr{}, 1280)
|
||||
|
||||
readerDev := device.NewDevice(virtualDev, stBind, device.NewLogger(device.LogLevelSilent, ""))
|
||||
otherDev := device.NewDevice(&st, conn.NewStdNetBind(), device.NewLogger(device.LogLevelSilent, ""))
|
||||
|
||||
configureDevices(t, readerDev, otherDev)
|
||||
|
||||
readerDev.Up()
|
||||
otherDev.Up()
|
||||
|
||||
conn, err := virtualNet.Dial("ping4", "10.64.0.1")
|
||||
requestPing := icmp.Echo{
|
||||
Seq: 345,
|
||||
Data: []byte("gopher burrow"),
|
||||
}
|
||||
icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
|
||||
conn.SetReadDeadline(time.Now().Add(time.Second * 9))
|
||||
_, err = conn.Write(icmpBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return readerDev, otherDev
|
||||
}
|
||||
|
||||
func TestShutdownBind(t *testing.T) {
|
||||
stIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
|
||||
virtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
remotePort := uint16(5005)
|
||||
|
||||
st := NewMultihopTun(stIp, virtualIp, remotePort, 1280)
|
||||
binder := st.Binder()
|
||||
recvFunc, _, err := binder.Open(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open a UDP socket, %v", err)
|
||||
}
|
||||
|
||||
st.Close()
|
||||
|
||||
buf := make([][]byte, 1600)
|
||||
sizes := make([]int, 1)
|
||||
eps := make([]conn.Endpoint, 1)
|
||||
_, err = recvFunc[0](buf, sizes, eps)
|
||||
neterr, ok := err.(net.Error)
|
||||
if !ok {
|
||||
t.Fatalf("Expected a net.Error, instead got %v", err)
|
||||
}
|
||||
if neterr.Temporary() {
|
||||
t.Fatalf("Expected the net error to not be temporary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultihopLocally(t *testing.T) {
|
||||
aVirtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 5})
|
||||
bVirtualIp := netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
|
||||
configsForMultihop, ports := genConfigsForMultihop(t)
|
||||
|
||||
multihopA := NewMultihopTun(aVirtualIp, netip.MustParseAddr(fmt.Sprintf("127.0.0.1")), ports[3], 1280)
|
||||
multihopB := NewMultihopTun(bVirtualIp, netip.MustParseAddr(fmt.Sprintf("127.0.0.1")), ports[0], 1280)
|
||||
aBinder := multihopA.Binder()
|
||||
bBinder := multihopB.Binder()
|
||||
|
||||
virtualDevA, virtualNetA, _ := netstack.CreateNetTUN([]netip.Addr{aVirtualIp}, []netip.Addr{}, 1280)
|
||||
virtualDevB, virtualNetB, _ := netstack.CreateNetTUN([]netip.Addr{bVirtualIp}, []netip.Addr{}, 1280)
|
||||
|
||||
aExitDevice := device.NewDevice(virtualDevA, aBinder, device.NewLogger(device.LogLevelVerbose, ""))
|
||||
aExitDevice.IpcSet(configsForMultihop[0])
|
||||
|
||||
aEntryDevice := device.NewDevice(&multihopA, conn.NewStdNetBind(), device.NewLogger(device.LogLevelVerbose, ""))
|
||||
aEntryDevice.IpcSet(configsForMultihop[1])
|
||||
|
||||
bEntryDevice := device.NewDevice(&multihopB, conn.NewStdNetBind(), device.NewLogger(device.LogLevelVerbose, ""))
|
||||
bEntryDevice.IpcSet(configsForMultihop[2])
|
||||
|
||||
bExitDevice := device.NewDevice(virtualDevB, bBinder, device.NewLogger(device.LogLevelVerbose, ""))
|
||||
bExitDevice.IpcSet(configsForMultihop[3])
|
||||
|
||||
err := aExitDevice.Up()
|
||||
if err != nil {
|
||||
t.Fatalf("exit device a failed to up itself: %v", err)
|
||||
}
|
||||
|
||||
err = aEntryDevice.Up()
|
||||
if err != nil {
|
||||
t.Fatalf("entry device a failed to up itself: %v", err)
|
||||
}
|
||||
|
||||
err = bExitDevice.Up()
|
||||
if err != nil {
|
||||
t.Fatalf("exit device b failed to up itself: %v", err)
|
||||
}
|
||||
|
||||
err = bEntryDevice.Up()
|
||||
if err != nil {
|
||||
t.Fatalf("entry device b failed to up itself: %v", err)
|
||||
}
|
||||
|
||||
listenerAddr := netip.AddrPortFrom(bVirtualIp, 7070)
|
||||
senderAddr := netip.AddrPortFrom(aVirtualIp, 4040)
|
||||
listenerSocket, err := virtualNetB.ListenUDPAddrPort(netip.AddrPortFrom(bVirtualIp, 7070))
|
||||
if err != nil {
|
||||
t.Fatalf("Fail to open listener socket: %v", err)
|
||||
}
|
||||
|
||||
senderSocket, err := virtualNetA.DialUDPAddrPort(senderAddr, listenerAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open sender socket: %v", err)
|
||||
}
|
||||
|
||||
payload := []byte{1, 2, 3, 4, 5}
|
||||
|
||||
n, err := senderSocket.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send payload: %v", err)
|
||||
}
|
||||
|
||||
if n != len(payload) {
|
||||
t.Fatalf("Expected to send %v bytes, instead sent %v", len(payload), n)
|
||||
}
|
||||
|
||||
rxBuffer := []byte{1, 2, 3, 4, 5}
|
||||
n, err = listenerSocket.Read(rxBuffer)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive payload: %v", err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
t.Fatalf("Expected to read %v bytes, instead read %v bytes", len(payload), n)
|
||||
}
|
||||
|
||||
for idx := range rxBuffer {
|
||||
if rxBuffer[idx] != payload[idx] {
|
||||
t.Fatalf("At index %d, expected value %d, instead got %v", idx, rxBuffer[idx], payload[idx])
|
||||
}
|
||||
}
|
||||
|
||||
aEntryDevice.Close()
|
||||
aExitDevice.Close()
|
||||
bEntryDevice.Close()
|
||||
bExitDevice.Close()
|
||||
}
|
Loading…
Add table
Reference in a new issue