Merge pull request #4 from amnezia-vpn/upstream-merge

Upstream merge
pokamest 2023-11-30 10:49:31 -08:00 committed by GitHub
commit b43118018e
27 changed files with 1110 additions and 397 deletions

conn/bind_std.go

@ -8,6 +8,7 @@ package conn
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
@ -29,16 +30,19 @@ var (
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
mu sync.Mutex // protects all fields except as specified
ipv4 *net.UDPConn
ipv6 *net.UDPConn
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
ipv4TxOffload bool
ipv4RxOffload bool
ipv6TxOffload bool
ipv6RxOffload bool
// these three fields are not guarded by mu
udpAddrPool sync.Pool
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
// these two fields are not guarded by mu
udpAddrPool sync.Pool
msgsPool sync.Pool
blackhole4 bool
blackhole6 bool
@ -54,23 +58,14 @@ func NewStdNetBind() Bind {
},
},
ipv4MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv4.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
ipv6MsgsPool: sync.Pool{
msgsPool: sync.Pool{
New: func() any {
// ipv6.Message and ipv4.Message are interchangeable as they are
// both aliases for x/net/internal/socket.Message.
msgs := make([]ipv6.Message, IdealBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
}
return &msgs
},
@ -113,7 +108,7 @@ func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx.
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
@ -179,19 +174,21 @@ again:
}
var fns []ReceiveFunc
if v4conn != nil {
if runtime.GOOS == "linux" {
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v4pc = ipv4.NewPacketConn(v4conn)
s.ipv4PC = v4pc
}
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
s.ipv4 = v4conn
}
if v6conn != nil {
if runtime.GOOS == "linux" {
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
v6pc = ipv6.NewPacketConn(v6conn)
s.ipv6PC = v6pc
}
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
s.ipv6 = v6conn
}
if len(fns) == 0 {
@ -201,76 +198,101 @@ again:
return fns, uint16(port), nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
s.msgsPool.Put(msgs)
}
func (s *StdNetBind) getMessages() *[]ipv6.Message {
return s.msgsPool.Get().(*[]ipv6.Message)
}
var (
// If compilation fails here these are no longer the same underlying type.
_ ipv6.Message = ipv4.Message{}
)
type batchReader interface {
ReadBatch([]ipv6.Message, int) (int, error)
}
type batchWriter interface {
WriteBatch([]ipv6.Message, int) (int, error)
}
func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
numMsgs, err = br.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
return numMsgs, nil
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
defer s.ipv6MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
getSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
}
}
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (s *StdNetBind) BatchSize() int {
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return IdealBatchSize
}
return 1
@ -293,28 +315,42 @@ func (s *StdNetBind) Close() error {
}
s.blackhole4 = false
s.blackhole6 = false
s.ipv4TxOffload = false
s.ipv4RxOffload = false
s.ipv6TxOffload = false
s.ipv6RxOffload = false
if err1 != nil {
return err1
}
return err2
}
type ErrUDPGSODisabled struct {
onLaddr string
RetryErr error
}
func (e ErrUDPGSODisabled) Error() string {
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}
func (e ErrUDPGSODisabled) Unwrap() error {
return e.RetryErr
}
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
var (
pc4 *ipv4.PacketConn
pc6 *ipv6.PacketConn
)
offload := s.ipv4TxOffload
br := batchWriter(s.ipv4PC)
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
pc6 = s.ipv6PC
br = s.ipv6PC
is6 = true
} else {
pc4 = s.ipv4PC
offload = s.ipv6TxOffload
}
s.mu.Unlock()
@ -324,85 +360,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
msgs := s.getMessages()
defer s.putMessages(msgs)
ua := s.udpAddrPool.Get().(*net.UDPAddr)
defer s.udpAddrPool.Put(ua)
if is6 {
return s.send6(conn, pc6, endpoint, bufs)
as16 := endpoint.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
} else {
return s.send4(conn, pc4, endpoint, bufs)
as4 := endpoint.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
}
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
var (
retried bool
err error
)
retry:
if offload {
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
err = s.send(conn, br, (*msgs)[:n])
if err != nil && offload && errShouldDisableUDPGSO(err) {
offload = false
s.mu.Lock()
if is6 {
s.ipv6TxOffload = false
} else {
s.ipv4TxOffload = false
}
s.mu.Unlock()
retried = true
goto retry
}
} else {
for i := range bufs {
(*msgs)[i].Addr = ua
(*msgs)[i].Buffers[0] = bufs[i]
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
}
err = s.send(conn, br, (*msgs)[:len(bufs)])
}
if retried {
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
}
return err
}
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
var (
n int
err error
start int
)
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
for {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil || n == len((*msgs)[start:len(bufs)]) {
n, err = pc.WriteBatch(msgs[start:], 0)
if err != nil || n == len(msgs[start:]) {
break
}
start += n
}
} else {
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
for _, msg := range msgs {
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
if err != nil {
break
}
}
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
for i, buf := range bufs {
(*msgs)[i].Buffers[0] = buf
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
const (
// Exceeding these values results in EMSGSIZE. They account for layer3 and
// layer4 headers. IPv6 does not need to account for itself as the payload
// length field is self excluding.
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
maxIPv6PayloadLen = 1<<16 - 1 - 8
// This is a hard limit imposed by the kernel.
udpSegmentMaxDatagrams = 64
)
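As an aside (not part of the diff), the interaction of these two limits can be made concrete with a hypothetical helper; maxSegmentsPerSend is illustrative only and leans on the constants above:

// Sketch: both the UDP payload cap and the kernel's 64-segment cap
// bound a single GSO send. For example, 1452-byte segments over IPv4
// allow 65507/1452 = 45 datagrams per sendmsg; tiny segments are
// instead capped at udpSegmentMaxDatagrams.
func maxSegmentsPerSend(gsoSize, maxPayloadLen int) int {
	n := maxPayloadLen / gsoSize
	if n > udpSegmentMaxDatagrams {
		n = udpSegmentMaxDatagrams
	}
	return n
}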
type setGSOFunc func(control *[]byte, gsoSize uint16)
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
var (
n int
err error
start int
base = -1 // index of msg we are currently coalescing into
gsoSize int // segmentation size of msgs[base]
dgramCnt int // number of dgrams coalesced into msgs[base]
endBatch bool // tracking flag to start a new batch on next iteration of bufs
)
if runtime.GOOS == "linux" {
for {
n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
if err != nil || n == len((*msgs)[start:len(bufs)]) {
break
maxPayloadLen := maxIPv4PayloadLen
if ep.DstIP().Is6() {
maxPayloadLen = maxIPv6PayloadLen
}
for i, buf := range bufs {
if i > 0 {
msgLen := len(buf)
baseLenBefore := len(msgs[base].Buffers[0])
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
if msgLen+baseLenBefore <= maxPayloadLen &&
msgLen <= gsoSize &&
msgLen <= freeBaseCap &&
dgramCnt < udpSegmentMaxDatagrams &&
!endBatch {
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
if i == len(bufs)-1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
dgramCnt++
if msgLen < gsoSize {
// A smaller than gsoSize packet on the tail is legal, but
// it must end the batch.
endBatch = true
}
continue
}
start += n
}
} else {
for i, buf := range bufs {
_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
if err != nil {
break
if dgramCnt > 1 {
setGSO(&msgs[base].OOB, uint16(gsoSize))
}
// Reset prior to incrementing base since we are preparing to start a
// new potential batch.
endBatch = false
base++
gsoSize = len(buf)
setSrcControl(&msgs[base].OOB, ep)
msgs[base].Buffers[0] = buf
msgs[base].Addr = addr
dgramCnt = 1
}
return base + 1
}
type getGSOFunc func(control []byte) (int, error)
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
for i := firstMsgAt; i < len(msgs); i++ {
msg := &msgs[i]
if msg.N == 0 {
return n, err
}
var (
gsoSize int
start int
end = msg.N
numToSplit = 1
)
gsoSize, err = getGSO(msg.OOB[:msg.NN])
if err != nil {
return n, err
}
if gsoSize > 0 {
numToSplit = (msg.N + gsoSize - 1) / gsoSize
end = gsoSize
}
for j := 0; j < numToSplit; j++ {
if n > i {
return n, errors.New("splitting coalesced packet resulted in overflow")
}
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
msgs[n].N = copied
msgs[n].Addr = msg.Addr
start = end
end += gsoSize
if end > msg.N {
end = msg.N
}
n++
}
if i != n-1 {
// It is legal for bytes to move within msg.Buffers[0] as a result
// of splitting, so we only zero the source msg len when it is not
// the destination of the last split operation above.
msg.N = 0
}
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
return err
return n, nil
}
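Stepping back from the diff: the two per-family pools collapse into one because ipv4.Message and ipv6.Message alias the same underlying type, which the var block above asserts at compile time. A minimal standalone sketch of the shared pool (batchSize and controlSize are illustrative stand-ins for IdealBatchSize and stickyControlSize+gsoControlSize):

package example

import (
	"sync"

	"golang.org/x/net/ipv6"
)

const (
	batchSize   = 128 // stand-in for IdealBatchSize
	controlSize = 64  // stand-in for stickyControlSize + gsoControlSize
)

var msgsPool = sync.Pool{
	New: func() any {
		msgs := make([]ipv6.Message, batchSize)
		for i := range msgs {
			msgs[i].Buffers = make([][]byte, 1)
			// Zero length, fixed capacity: receive paths reslice OOB
			// up to cap before reading, and putMessages truncates it
			// back before returning the slice to the pool.
			msgs[i].OOB = make([]byte, 0, controlSize)
		}
		return &msgs
	},
}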

conn/bind_std_test.go

@ -1,6 +1,12 @@
package conn
import "testing"
import (
"encoding/binary"
"net"
"testing"
"golang.org/x/net/ipv6"
)
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
bind := NewStdNetBind().(*StdNetBind)
@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
fn(bufs, sizes, eps)
}
}
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
*control = (*control)[:cap(*control)]
binary.LittleEndian.PutUint16(*control, gsoSize)
}
func Test_coalesceMessages(t *testing.T) {
cases := []struct {
name string
buffs [][]byte
wantLens []int
wantGSO []int
}{
{
name: "one message no coalesce",
buffs: [][]byte{
make([]byte, 1, 1),
},
wantLens: []int{1},
wantGSO: []int{0},
},
{
name: "two messages equal len coalesce",
buffs: [][]byte{
make([]byte, 1, 2),
make([]byte, 1, 1),
},
wantLens: []int{2},
wantGSO: []int{1},
},
{
name: "two messages unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
},
wantLens: []int{3},
wantGSO: []int{2},
},
{
name: "three messages second unequal len coalesce",
buffs: [][]byte{
make([]byte, 2, 3),
make([]byte, 1, 1),
make([]byte, 2, 2),
},
wantLens: []int{3, 2},
wantGSO: []int{2, 0},
},
{
name: "three messages limited cap coalesce",
buffs: [][]byte{
make([]byte, 2, 4),
make([]byte, 2, 2),
make([]byte, 2, 2),
},
wantLens: []int{4, 2},
wantGSO: []int{2, 0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1").To4(),
Port: 1,
}
msgs := make([]ipv6.Message, len(tt.buffs))
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].OOB = make([]byte, 0, 2)
}
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
if got != len(tt.wantLens) {
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
}
for i := 0; i < got; i++ {
if msgs[i].Addr != addr {
t.Errorf("msgs[%d].Addr != passed addr", i)
}
gotLen := len(msgs[i].Buffers[0])
if gotLen != tt.wantLens[i] {
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
}
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
if err != nil {
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
}
if gotGSO != tt.wantGSO[i] {
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
}
}
})
}
}
func mockGetGSOSize(control []byte) (int, error) {
if len(control) < 2 {
return 0, nil
}
return int(binary.LittleEndian.Uint16(control)), nil
}
func Test_splitCoalescedMessages(t *testing.T) {
newMsg := func(n, gso int) ipv6.Message {
msg := ipv6.Message{
Buffers: [][]byte{make([]byte, 1<<16-1)},
N: n,
OOB: make([]byte, 2),
}
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
if gso > 0 {
msg.NN = 2
}
return msg
}
cases := []struct {
name string
msgs []ipv6.Message
firstMsgAt int
wantNumEval int
wantMsgLens []int
wantErr bool
}{
{
name: "second last split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(3, 1),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 3,
wantMsgLens: []int{1, 1, 1, 0},
wantErr: false,
},
{
name: "second last no split last empty",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(0, 0),
},
firstMsgAt: 2,
wantNumEval: 1,
wantMsgLens: []int{1, 0, 0, 0},
wantErr: false,
},
{
name: "second last no split last no split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(1, 0),
},
firstMsgAt: 2,
wantNumEval: 2,
wantMsgLens: []int{1, 1, 0, 0},
wantErr: false,
},
{
name: "second last no split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(3, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last split last split",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(2, 1),
newMsg(2, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: false,
},
{
name: "second last no split last split overflow",
msgs: []ipv6.Message{
newMsg(0, 0),
newMsg(0, 0),
newMsg(1, 0),
newMsg(4, 1),
},
firstMsgAt: 2,
wantNumEval: 4,
wantMsgLens: []int{1, 1, 1, 1},
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
if err != nil && !tt.wantErr {
t.Fatalf("err: %v", err)
}
if got != tt.wantNumEval {
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
}
for i, msg := range tt.msgs {
if msg.N != tt.wantMsgLens[i] {
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
}
}
})
}
}

conn/controlfns_linux.go

@ -57,5 +57,13 @@ func init() {
}
return err
},
// Attempt to enable UDP_GRO
func(network, address string, c syscall.RawConn) error {
c.Control(func(fd uintptr) {
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
})
return nil
},
)
}
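The same opt-in can be written as a standalone program; a hedged Linux-only sketch using net.ListenConfig (enabling UDP_GRO is best-effort, as above, so the setsockopt error is deliberately ignored):

package main

import (
	"context"
	"log"
	"net"
	"syscall"

	"golang.org/x/sys/unix"
)

func main() {
	lc := net.ListenConfig{
		Control: func(network, address string, c syscall.RawConn) error {
			return c.Control(func(fd uintptr) {
				// Kernels without UDP_GRO simply reject the option.
				_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
			})
		},
	}
	pc, err := lc.ListenPacket(context.Background(), "udp4", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer pc.Close()
	log.Printf("listening on %s with UDP_GRO requested", pc.LocalAddr())
}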

conn/errors_default.go (new file, 12 lines)

@ -0,0 +1,12 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func errShouldDisableUDPGSO(err error) bool {
return false
}

conn/errors_linux.go (new file, 26 lines)

@ -0,0 +1,26 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"os"
"golang.org/x/sys/unix"
)
func errShouldDisableUDPGSO(err error) bool {
var serr *os.SyscallError
if errors.As(err, &serr) {
// EIO is returned by udp_send_skb() if the device driver does not have
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
// See:
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
return serr.Err == unix.EIO
}
return false
}

conn/features_default.go (new file, 15 lines)

@ -0,0 +1,15 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import "net"
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
return
}

conn/features_linux.go (new file, 29 lines)

@ -0,0 +1,29 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"golang.org/x/sys/unix"
)
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
rc, err := conn.SyscallConn()
if err != nil {
return
}
err = rc.Control(func(fd uintptr) {
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
txOffload = errSyscall == nil
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
rxOffload = errSyscall == nil && opt == 1
})
if err != nil {
return false, false
}
return txOffload, rxOffload
}
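A hedged usage sketch (probeOffload is hypothetical and needs the log and net imports); note that rxOffload only reports true once UDP_GRO has already been enabled on the socket, e.g. by the control function shown earlier:

func probeOffload() {
	conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero})
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	tx, rx := supportsUDPOffload(conn)
	log.Printf("UDP_SEGMENT(tx)=%v UDP_GRO(rx)=%v", tx, rx)
}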

conn/gso_default.go (new file, 21 lines)

@ -0,0 +1,21 @@
//go:build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
func setGSOSize(control *[]byte, gsoSize uint16) {
}
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
// offloading control data.
const gsoControlSize = 0

conn/gso_linux.go (new file, 65 lines)

@ -0,0 +1,65 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"unsafe"
"golang.org/x/sys/unix"
)
const (
sizeOfGSOData = 2
)
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
func getGSOSize(control []byte) (int, error) {
var (
hdr unix.Cmsghdr
data []byte
rem = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
if err != nil {
return 0, fmt.Errorf("error parsing socket control message: %w", err)
}
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
var gso uint16
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
return int(gso), nil
}
}
return 0, nil
}
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
// data in control untouched.
func setGSOSize(control *[]byte, gsoSize uint16) {
existingLen := len(*control)
avail := cap(*control) - existingLen
space := unix.CmsgSpace(sizeOfGSOData)
if avail < space {
return
}
*control = (*control)[:cap(*control)]
gsoControl := (*control)[existingLen:]
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
hdr.Level = unix.SOL_UDP
hdr.Type = unix.UDP_SEGMENT
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
*control = (*control)[:existingLen+space]
}
// gsoControlSize returns the recommended buffer size for pooling UDP
// offloading control data.
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
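Tying this back to the message pool in bind_std.go: the pooled OOB capacity has to fit both the sticky-source cmsg and the GSO cmsg, since setGSOSize appends after any existing control data. A statement-level sketch of the sizing (Linux values; oob is a hypothetical buffer):

stickySpace := unix.CmsgSpace(unix.SizeofInet6Pktinfo) // stickyControlSize on Linux
gsoSpace := unix.CmsgSpace(sizeOfGSOData)              // gsoControlSize
oob := make([]byte, 0, stickySpace+gsoSpace)           // room for both cmsgs
_ = oob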

conn/sticky_default.go

@ -21,8 +21,9 @@ func (e *StdNetEndpoint) SrcToString() string {
return ""
}
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
// use alternatively named flags and need ports and require testing.
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
// {get,set}srcControl feature set, but use alternatively named flags and need
// ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
@ -34,8 +35,8 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
// srcControlSize returns the recommended buffer size for pooling sticky control
// data.
const srcControlSize = 0
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
const stickyControlSize = 0
const StdNetSupportsStickySockets = false

conn/sticky_linux.go

@ -105,6 +105,8 @@ func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
*control = append(*control, ep.src...)
}
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
// stickyControlSize returns the recommended buffer size for pooling sticky
// offloading control data.
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
const StdNetSupportsStickySockets = true

conn/sticky_linux_test.go

@ -60,7 +60,7 @@ func Test_setSrcControl(t *testing.T) {
}
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
control := make([]byte, srcControlSize)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
@ -89,7 +89,7 @@ func Test_setSrcControl(t *testing.T) {
}
setSrc(ep, netip.MustParseAddr("::1"), 5)
control := make([]byte, srcControlSize)
control := make([]byte, stickyControlSize)
setSrcControl(&control, ep)
@ -113,7 +113,7 @@ func Test_setSrcControl(t *testing.T) {
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
control := make([]byte, unix.CmsgLen(0))
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
@ -129,7 +129,7 @@ func Test_setSrcControl(t *testing.T) {
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
@ -149,7 +149,7 @@ func Test_getSrcFromControl(t *testing.T) {
}
})
t.Run("IPv6", func(t *testing.T) {
control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
control := make([]byte, stickyControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO

device/channels.go

@ -19,13 +19,13 @@ import (
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
c chan *QueueOutboundElement
c chan *QueueOutboundElementsContainer
wg sync.WaitGroup
}
func newOutboundQueue() *outboundQueue {
q := &outboundQueue{
c: make(chan *QueueOutboundElement, QueueOutboundSize),
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
// An inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
c chan *QueueInboundElement
c chan *QueueInboundElementsContainer
wg sync.WaitGroup
}
func newInboundQueue() *inboundQueue {
q := &inboundQueue{
c: make(chan *QueueInboundElement, QueueInboundSize),
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
q.wg.Add(1)
go func() {
@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
}
type autodrainingInboundQueue struct {
c chan *[]*QueueInboundElement
c chan *QueueInboundElementsContainer
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil value.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
c: make(chan *[]*QueueInboundElement, QueueInboundSize),
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
@ -90,13 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
case elems := <-q.c:
for _, elem := range *elems {
elem.Lock()
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsSlice(elems)
device.PutInboundElementsContainer(elemsContainer)
default:
return
}
@ -104,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
}
type autodrainingOutboundQueue struct {
c chan *[]*QueueOutboundElement
c chan *QueueOutboundElementsContainer
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@ -114,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
c: make(chan *[]*QueueOutboundElement, QueueOutboundSize),
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
@ -123,13 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
case elems := <-q.c:
for _, elem := range *elems {
elem.Lock()
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsSlice(elems)
device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
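The autodraining pattern above generalizes; a minimal hedged sketch with hypothetical container and pool types (sends into such a queue must stay best-effort, e.g. via select with a default branch, because the only guaranteed receiver is the finalizer):

package example

import (
	"runtime"
	"sync"
)

type container struct{ elems []int }

var pool = sync.Pool{New: func() any { return new(container) }}

type autodrainingQueue struct{ c chan *container }

func newAutodrainingQueue() *autodrainingQueue {
	q := &autodrainingQueue{c: make(chan *container, 1024)}
	// When q becomes unreachable, the finalizer drains whatever is
	// still buffered and recycles it instead of leaking it.
	runtime.SetFinalizer(q, func(q *autodrainingQueue) {
		for {
			select {
			case c := <-q.c:
				c.elems = c.elems[:0]
				pool.Put(c)
			default:
				return
			}
		}
	})
	return q
}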

device/device.go

@ -70,11 +70,11 @@ type Device struct {
cookieChecker CookieChecker
pool struct {
outboundElementsSlice *WaitPool
inboundElementsSlice *WaitPool
messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
inboundElementsContainer *WaitPool
outboundElementsContainer *WaitPool
messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
}
queue struct {

device/peer.go

@ -45,9 +45,9 @@ type Peer struct {
}
queue struct {
staged chan *[]*QueueOutboundElement // staged packets before a handshake is available
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
inbound *autodrainingInboundQueue // sequential ordering of tun writing
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize)
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]

device/pools.go

@ -46,13 +46,13 @@ func (p *WaitPool) Put(x any) {
}
func (device *Device) PopulatePools() {
device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &s
})
device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
return &s
return &QueueInboundElementsContainer{elems: s}
})
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &QueueOutboundElementsContainer{elems: s}
})
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
@ -65,28 +65,32 @@ func (device *Device) PopulatePools() {
})
}
func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement {
return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement)
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) {
for i := range *s {
(*s)[i] = nil
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
*s = (*s)[:0]
device.pool.outboundElementsSlice.Put(s)
c.elems = c.elems[:0]
device.pool.inboundElementsContainer.Put(c)
}
func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement {
return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement)
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
c.Mutex = sync.Mutex{}
return c
}
func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) {
for i := range *s {
(*s)[i] = nil
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
for i := range c.elems {
c.elems[i] = nil
}
*s = (*s)[:0]
device.pool.inboundElementsSlice.Put(s)
c.elems = c.elems[:0]
device.pool.outboundElementsContainer.Put(c)
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {

device/queueconstants_android.go

@ -14,6 +14,6 @@ const (
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 2200
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 4096
)

device/receive.go

@ -27,7 +27,6 @@ type QueueHandshakeElement struct {
}
type QueueInboundElement struct {
sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
@ -35,6 +34,11 @@ type QueueInboundElement struct {
endpoint conn.Endpoint
}
type QueueInboundElementsContainer struct {
sync.Mutex
elems []*QueueInboundElement
}
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
@ -90,7 +94,7 @@ func (device *Device) RoutineReceiveIncoming(
count int
endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
for i := range bufsArrs {
@ -139,7 +143,7 @@ func (device *Device) RoutineReceiveIncoming(
if device.isAdvancedSecurityOn() {
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
junkSize := msgTypeToJunkSize[assumedMsgType]
// transport size can align with other header types;
// making sure we have the right msgType
msgType = binary.LittleEndian.Uint32(packet[junkSize:junkSize+4])
if msgType == assumedMsgType {
@ -153,7 +157,7 @@ func (device *Device) RoutineReceiveIncoming(
if msgType != MessageTransportType {
device.log.Verbosef("ASec: Received message with unknown type")
continue
}
}
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
@ -195,15 +199,14 @@ func (device *Device) RoutineReceiveIncoming(
elem.keypair = keypair
elem.endpoint = endpoints[i]
elem.counter = 0
elem.Mutex = sync.Mutex{}
elem.Lock()
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetInboundElementsSlice()
elemsForPeer = device.GetInboundElementsContainer()
elemsForPeer.Lock()
elemsByPeer[peer] = elemsForPeer
}
*elemsForPeer = append(*elemsForPeer, elem)
elemsForPeer.elems = append(elemsForPeer.elems, elem)
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
continue
@ -243,18 +246,16 @@ func (device *Device) RoutineReceiveIncoming(
}
}
device.aSecMux.RUnlock()
for peer, elems := range elemsByPeer {
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elems
for _, elem := range *elems {
device.queue.decryption.c <- elem
}
peer.queue.inbound.c <- elemsContainer
device.queue.decryption.c <- elemsContainer
} else {
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsSlice(elems)
device.PutInboundElementsContainer(elemsContainer)
}
delete(elemsByPeer, peer)
}
@ -267,26 +268,28 @@ func (device *Device) RoutineDecryption(id int) {
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker %d - started", id)
for elem := range device.queue.decryption.c {
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
for elemsContainer := range device.queue.decryption.c {
for _, elem := range elemsContainer.elems {
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.packet = nil
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.packet = nil
}
}
elem.Unlock()
elemsContainer.Unlock()
}
}
@ -468,12 +471,12 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
bufs := make([][]byte, 0, maxBatchSize)
for elems := range peer.queue.inbound.c {
if elems == nil {
for elemsContainer := range peer.queue.inbound.c {
if elemsContainer == nil {
return
}
for _, elem := range *elems {
elem.Lock()
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
if elem.packet == nil {
// decryption failed
continue
@ -552,11 +555,11 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
bufs = bufs[:0]
device.PutInboundElementsSlice(elems)
device.PutInboundElementsContainer(elemsContainer)
}
}
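The container types also change the locking protocol: one mutex guards the whole batch instead of one per element. A hedged standalone sketch of the handoff (names hypothetical; the real code uses QueueInboundElementsContainer with the decryption and per-peer inbound queues):

package example

import "sync"

type elem struct{ packet []byte }

type batch struct {
	sync.Mutex
	elems []*elem
}

// The producer locks the batch once and hands the same pointer to both
// the ordered per-peer queue and the parallel worker queue.
func produce(order, work chan *batch, b *batch) {
	b.Lock()
	order <- b
	work <- b
}

// The worker unlocks the batch only after every element is processed.
func worker(work chan *batch) {
	for b := range work {
		for _, e := range b.elems {
			_ = e.packet // decrypt in place (placeholder)
		}
		b.Unlock()
	}
}

// The sequential consumer's Lock blocks until the worker is done.
func consume(order chan *batch) {
	for b := range order {
		b.Lock()
		for _, e := range b.elems {
			_ = e.packet // write to TUN (placeholder)
		}
		b.Unlock()
	}
}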

device/send.go

@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amnezia-wg/conn"
"github.com/amnezia-vpn/amnezia-wg/tun"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
@ -46,7 +47,6 @@ import (
*/
type QueueOutboundElement struct {
sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
@ -54,10 +54,14 @@ type QueueOutboundElement struct {
peer *Peer // related peer
}
type QueueOutboundElementsContainer struct {
sync.Mutex
elems []*QueueOutboundElement
}
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
elem.buffer = device.GetMessageBuffer()
elem.Mutex = sync.Mutex{}
elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers.
return elem
@ -79,15 +83,15 @@ func (elem *QueueOutboundElement) clearPointers() {
func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
elems := peer.device.GetOutboundElementsSlice()
*elems = append(*elems, elem)
elemsContainer := peer.device.GetOutboundElementsContainer()
elemsContainer.elems = append(elemsContainer.elems, elem)
select {
case peer.queue.staged <- elems:
case peer.queue.staged <- elemsContainer:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
peer.device.PutOutboundElementsSlice(elems)
peer.device.PutOutboundElementsContainer(elemsContainer)
}
}
peer.SendStagedPackets()
@ -278,7 +282,7 @@ func (device *Device) RoutineReadFromTUN() {
readErr error
elems = make([]*QueueOutboundElement, batchSize)
bufs = make([][]byte, batchSize)
elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize)
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
count = 0
sizes = make([]int, batchSize)
offset = MessageTransportHeaderSize
@ -335,10 +339,10 @@ func (device *Device) RoutineReadFromTUN() {
}
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetOutboundElementsSlice()
elemsForPeer = device.GetOutboundElementsContainer()
elemsByPeer[peer] = elemsForPeer
}
*elemsForPeer = append(*elemsForPeer, elem)
elemsForPeer.elems = append(elemsForPeer.elems, elem)
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
}
@ -348,11 +352,11 @@ func (device *Device) RoutineReadFromTUN() {
peer.StagePackets(elemsForPeer)
peer.SendStagedPackets()
} else {
for _, elem := range *elemsForPeer {
for _, elem := range elemsForPeer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsSlice(elemsForPeer)
device.PutOutboundElementsContainer(elemsForPeer)
}
delete(elemsByPeer, peer)
}
@ -376,7 +380,7 @@ func (device *Device) RoutineReadFromTUN() {
}
}
func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for {
select {
case peer.queue.staged <- elems:
@ -385,11 +389,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
}
select {
case tooOld := <-peer.queue.staged:
for _, elem := range *tooOld {
for _, elem := range tooOld.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsSlice(tooOld)
peer.device.PutOutboundElementsContainer(tooOld)
default:
}
}
@ -408,54 +412,52 @@ top:
}
for {
var elemsOOO *[]*QueueOutboundElement
var elemsContainerOOO *QueueOutboundElementsContainer
select {
case elems := <-peer.queue.staged:
case elemsContainer := <-peer.queue.staged:
i := 0
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
elem.peer = peer
elem.nonce = keypair.sendNonce.Add(1) - 1
if elem.nonce >= RejectAfterMessages {
keypair.sendNonce.Store(RejectAfterMessages)
if elemsOOO == nil {
elemsOOO = peer.device.GetOutboundElementsSlice()
if elemsContainerOOO == nil {
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
*elemsOOO = append(*elemsOOO, elem)
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
continue
} else {
(*elems)[i] = elem
elemsContainer.elems[i] = elem
i++
}
elem.keypair = keypair
elem.Lock()
}
*elems = (*elems)[:i]
elemsContainer.Lock()
elemsContainer.elems = elemsContainer.elems[:i]
if elemsOOO != nil {
peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans
if elemsContainerOOO != nil {
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
}
if len(*elems) == 0 {
peer.device.PutOutboundElementsSlice(elems)
if len(elemsContainer.elems) == 0 {
peer.device.PutOutboundElementsContainer(elemsContainer)
goto top
}
// add to parallel and sequential queue
if peer.isRunning.Load() {
peer.queue.outbound.c <- elems
for _, elem := range *elems {
peer.device.queue.encryption.c <- elem
}
peer.queue.outbound.c <- elemsContainer
peer.device.queue.encryption.c <- elemsContainer
} else {
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsSlice(elems)
peer.device.PutOutboundElementsContainer(elemsContainer)
}
if elemsOOO != nil {
if elemsContainerOOO != nil {
goto top
}
default:
@ -492,12 +494,12 @@ func (peer *Peer) createJunkPackets() ([][]byte, error) {
func (peer *Peer) FlushStagedPackets() {
for {
select {
case elems := <-peer.queue.staged:
for _, elem := range *elems {
case elemsContainer := <-peer.queue.staged:
for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsSlice(elems)
peer.device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
@ -531,30 +533,34 @@ func (device *Device) RoutineEncryption(id int) {
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker %d - started", id)
for elem := range device.queue.encryption.c {
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
for elemsContainer := range device.queue.encryption.c {
for _, elem := range elemsContainer.elems {
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(
len(elem.packet),
int(device.tun.mtu.Load()),
)
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
// encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(header, nonce[:], elem.packet, nil)
elem.Unlock()
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
nil,
)
}
elemsContainer.Unlock()
}
}
@ -568,9 +574,9 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
bufs := make([][]byte, 0, maxBatchSize)
for elems := range peer.queue.outbound.c {
for elemsContainer := range peer.queue.outbound.c {
bufs = bufs[:0]
if elems == nil {
if elemsContainer == nil {
return
}
if !peer.isRunning.Load() {
@ -580,16 +586,16 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
// The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
for _, elem := range *elems {
elem.Lock()
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
continue
}
dataSent := false
for _, elem := range *elems {
elem.Lock()
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
}
@ -603,11 +609,18 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
if dataSent {
peer.timersDataSent()
}
for _, elem := range *elems {
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsSlice(elems)
device.PutOutboundElementsContainer(elemsContainer)
if err != nil {
var errGSO conn.ErrUDPGSODisabled
if errors.As(err, &errGSO) {
device.log.Verbosef(err.Error())
err = errGSO.RetryErr
}
}
if err != nil {
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue

go.mod

@ -4,14 +4,14 @@ go 1.20
require (
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.6.0
golang.org/x/net v0.7.0
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89
golang.org/x/crypto v0.13.0
golang.org/x/net v0.15.0
golang.org/x/sys v0.12.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259
)
require (
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)

go.sum

@ -2,15 +2,15 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4=
golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY=
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=

tun/checksum.go

@ -3,23 +3,99 @@ package tun
import "encoding/binary"
// TODO: Explore SIMD and/or other assembly optimizations.
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
func checksumNoFold(b []byte, initial uint64) uint64 {
ac := initial
i := 0
n := len(b)
for n >= 4 {
ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
n -= 4
i += 4
for len(b) >= 128 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
b = b[128:]
}
for n >= 2 {
ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
n -= 2
i += 2
if len(b) >= 64 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
b = b[64:]
}
if n == 1 {
ac += uint64(b[i]) << 8
if len(b) >= 32 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
b = b[32:]
}
if len(b) >= 16 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
b = b[16:]
}
if len(b) >= 8 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
b = b[8:]
}
if len(b) >= 4 {
ac += uint64(binary.BigEndian.Uint32(b))
b = b[4:]
}
if len(b) >= 2 {
ac += uint64(binary.BigEndian.Uint16(b))
b = b[2:]
}
if len(b) == 1 {
ac += uint64(b[0]) << 8
}
return ac
}
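checksumNoFold deliberately returns the unfolded 64-bit accumulator so callers can chain partial sums; a hedged sketch of the final reduction (foldChecksum is illustrative; the package's checksum helper performs the equivalent fold, and callers complement the result with ^ before writing it into a header):

func foldChecksum(ac uint64) uint16 {
	// Repeatedly add the carries back in until the sum fits 16 bits.
	for ac > 0xffff {
		ac = (ac >> 16) + (ac & 0xffff)
	}
	return uint16(ac)
}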

tun/checksum_test.go (new file, 35 lines)

@ -0,0 +1,35 @@
package tun
import (
"fmt"
"math/rand"
"testing"
)
func BenchmarkChecksum(b *testing.B) {
lengths := []int{
64,
128,
256,
512,
1024,
1500,
2048,
4096,
8192,
9000,
9001,
}
for _, length := range lengths {
b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(buf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
checksum(buf, 0)
}
})
}
}

tun/netstack/tun.go

@ -25,7 +25,7 @@ import (
"github.com/amnezia-vpn/amnezia-wg/tun"
"golang.org/x/net/dns/dnsmessage"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
@ -43,7 +43,7 @@ type netTun struct {
ep *channel.Endpoint
stack *stack.Stack
events chan tun.Event
incomingPacket chan *bufferv2.View
incomingPacket chan *buffer.View
mtu int
dnsServers []netip.Addr
hasV4, hasV6 bool
@ -61,7 +61,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 10),
incomingPacket: make(chan *bufferv2.View),
incomingPacket: make(chan *buffer.View),
dnsServers: dnsServers,
mtu: mtu,
}
@ -84,7 +84,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
@ -140,7 +140,7 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
continue
}
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
switch packet[0] >> 4 {
case 4:
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
@ -198,7 +198,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ
}
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
@ -453,7 +453,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
}
remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))
remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
return res.Count, &PingAddr{remoteAddr}, nil
}

tun/tcp_offload_linux.go

@ -269,11 +269,11 @@ func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
type coalesceResult int
const (
coalesceInsufficientCap coalesceResult = 0
coalescePSHEnding coalesceResult = 1
coalesceItemInvalidCSum coalesceResult = 2
coalescePktInvalidCSum coalesceResult = 3
coalesceSuccess coalesceResult = 4
coalesceInsufficientCap coalesceResult = iota
coalescePSHEnding
coalesceItemInvalidCSum
coalescePktInvalidCSum
coalesceSuccess
)
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
@ -339,42 +339,6 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize
if gsoSize > item.gsoSize {
item.gsoSize = gsoSize
}
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(headersLen),
gsoSize: uint16(item.gsoSize),
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
// (IPv4) header checksum.
if isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pktHead[10], pktHead[11] = 0, 0 // clear checksum field
binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
}
hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:])
// Calculate the pseudo header checksum and place it at the TCP checksum
// offset. Downstream checksum offloading will combine this with computation
// of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := bufsOffset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
item.numMerged++
return coalesceSuccess
@ -390,43 +354,52 @@ const (
maxUint16 = 1<<16 - 1
)
type tcpGROResult int
const (
tcpGROResultNoop tcpGROResult = iota
tcpGROResultTableInsert
tcpGROResultCoalesced
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It will return false when pktI is not
// coalesced, otherwise true. This indicates to the caller if bufs[pktI]
// should be written to the Device.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
// existing packets tracked in table. It returns a tcpGROResultNoop when no
// action was taken, tcpGROResultTableInsert when the evaluated packet was
// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return false
return tcpGROResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return false
return tcpGROResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return false
return tcpGROResultNoop
}
}
if len(pkt) < iphLen {
return false
return tcpGROResultNoop
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return false
return tcpGROResultNoop
}
if len(pkt) < iphLen+tcphLen {
return false
return tcpGROResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return false
return tcpGROResultNoop
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
@ -434,14 +407,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return false
return tcpGROResultNoop
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return false
return tcpGROResultNoop
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
@ -452,7 +425,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return false
return tcpGROResultNoop
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
@ -470,20 +443,20 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return true
return tcpGROResultCoalesced
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return false
return tcpGROResultNoop
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return false
return tcpGROResultTableInsert
}
func isTCP4NoIPOptions(b []byte) bool {
@ -515,6 +488,64 @@ func isTCP6NoEH(b []byte) bool {
return true
}
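The two candidate filters are largely elided by the hunk above, but their role is narrow: GRO only considers TCP packets whose IP header has a fixed, option-free layout, so all later header offsets can be computed without parsing. A plausible sketch of the IPv4 predicate under that assumption (the real body is not shown in this diff):

func isTCP4NoIPOptionsSketch(b []byte) bool {
	if len(b) < 40 { // minimal IPv4 header (20) + minimal TCP header (20)
		return false
	}
	if b[0]>>4 != 4 { // version must be 4
		return false
	}
	if b[0]&0x0F != 5 { // IHL must be 5 words, i.e. no IP options
		return false
	}
	if b[9] != 6 { // protocol must be TCP (unix.IPPROTO_TCP)
		return false
	}
	return true
}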
// applyCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(item.iphLen + item.tcphLen),
gsoSize: item.gsoSize,
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
pkt := bufs[item.bufsIndex][offset:]
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
if isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pkt[10], pkt[11] = 0, 0
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
// Calculate the pseudo header checksum and place it at the TCP
// checksum offset. Downstream checksum offloading will combine
// this with computation of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := offset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
} else {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
}
}
}
return nil
}
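applyCoalesceAccounting deliberately leaves only the pseudo header sum in the TCP checksum field: with VIRTIO_NET_HDR_F_NEEDS_CSUM set, the kernel or NIC finishes the sum over the TCP header and payload starting from csumStart/csumOffset. A sketch of what pseudoHeaderChecksumNoFold plausibly computes, assuming the unrolled accumulator earlier in this diff is checksumNoFold(b []byte, initial uint64) uint64 (the pseudo header per RFC 793: source address, destination address, a zero byte, the protocol, and the TCP length):

func pseudoHeaderChecksumNoFoldSketch(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
	sum := checksumNoFold(srcAddr, 0)
	sum = checksumNoFold(dstAddr, sum)
	sum = checksumNoFold([]byte{0, protocol}, sum) // zero pad byte + protocol
	var l [2]byte
	binary.BigEndian.PutUint16(l[:], totalLen) // TCP length = header + payload
	return checksumNoFold(l[:], sum)           // encoding/binary is already imported in this file
}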
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
@ -524,23 +555,28 @@ func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toW
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
var coalesced bool
var result tcpGROResult
switch {
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
coalesced = tcpGRO(bufs, offset, i, tcp4Table, false)
result = tcpGRO(bufs, offset, i, tcp4Table, false)
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
coalesced = tcpGRO(bufs, offset, i, tcp6Table, true)
result = tcpGRO(bufs, offset, i, tcp6Table, true)
}
if !coalesced {
switch result {
case tcpGROResultNoop:
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
fallthrough
case tcpGROResultTableInsert:
*toWrite = append(*toWrite, i)
}
}
return nil
err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
return errors.Join(err4, err6)
}
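A hypothetical call site, loosely modeled on a device write path (newTCPGROTable and reset are assumptions inferred from the lookupOrInsert/insert/deleteAt API used above): the tables and the toWrite slice are allocated once and reused per batch, which is exactly why handleGRO takes them as arguments.

func flushBatch(bufs [][]byte, offset int, tcp4, tcp6 *tcpGROTable, toWrite *[]int) error {
	*toWrite = (*toWrite)[:0] // must be empty but non-nil on entry
	tcp4.reset()              // assumed helper that empties the table
	tcp6.reset()
	if err := handleGRO(bufs, offset, tcp4, tcp6, toWrite); err != nil {
		return err
	}
	for _, i := range *toWrite {
		// write bufs[i] to the underlying device (elided)
		_ = bufs[i]
	}
	return nil
}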
// tcpTSO splits packets from in into outBuffs, writing the size of each

View file

@ -35,8 +35,8 @@ func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.
srcAs4 := srcIPPort.Addr().As4()
dstAs4 := dstIPPort.Addr().As4()
ipFields := &header.IPv4Fields{
SrcAddr: tcpip.Address(srcAs4[:]),
DstAddr: tcpip.Address(dstAs4[:]),
SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
Protocol: unix.IPPROTO_TCP,
TTL: 64,
TotalLength: uint16(totalLen),
@ -72,8 +72,8 @@ func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.
srcAs16 := srcIPPort.Addr().As16()
dstAs16 := dstIPPort.Addr().As16()
ipFields := &header.IPv6Fields{
SrcAddr: tcpip.Address(srcAs16[:]),
DstAddr: tcpip.Address(dstAs16[:]),
SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
TransportProtocol: unix.IPPROTO_TCP,
HopLimit: 64,
PayloadLength: uint16(segmentSize + 20),

View file

@ -127,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) {
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
if tun.close.Load() {
return
}
update := tun.forcedMTU != mtu
tun.forcedMTU = mtu
if update {