From 6a84778f2ca810f5fb5cb078e001494f08d9085f Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 2 Oct 2023 13:53:07 -0700 Subject: [PATCH 01/13] conn, device: use UDP GSO and GRO on Linux StdNetBind probes for UDP GSO and GRO support at runtime. UDP GSO is dependent on checksum offload support on the egress netdev. UDP GSO will be disabled in the event sendmmsg() returns EIO, which is a strong signal that the egress netdev does not support checksum offload. The iperf3 results below demonstrate the effect of this commit between two Linux computers with i5-12400 CPUs. There is roughly ~13us of round trip latency between them. The first result is from commit 052af4a without UDP GSO or GRO. Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 9.85 GBytes 8.46 Gbits/sec 1139 3.01 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 9.85 GBytes 8.46 Gbits/sec 1139 sender [ 5] 0.00-10.04 sec 9.85 GBytes 8.42 Gbits/sec receiver The second result is with UDP GSO and GRO. Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 3.15 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 sender [ 5] 0.00-10.04 sec 12.3 GBytes 10.6 Gbits/sec receiver Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld --- conn/bind_std.go | 397 ++++++++++++------ conn/bind_std_test.go | 230 +++++++++- .../{sticky_default.go => control_default.go} | 20 +- conn/{sticky_linux.go => control_linux.go} | 51 ++- ...ky_linux_test.go => control_linux_test.go} | 10 +- conn/controlfns_linux.go | 8 + conn/errors_default.go | 12 + conn/errors_linux.go | 26 ++ conn/features_default.go | 15 + conn/features_linux.go | 35 ++ device/send.go | 8 + go.mod | 2 +- go.sum | 4 +- 13 files changed, 672 insertions(+), 146 deletions(-) rename conn/{sticky_default.go => control_default.go} (56%) rename conn/{sticky_linux.go => control_linux.go} (65%) rename conn/{sticky_linux_test.go => control_linux_test.go} (96%) create mode 100644 conn/errors_default.go create mode 100644 conn/errors_linux.go create mode 100644 conn/features_default.go create mode 100644 conn/features_linux.go diff --git a/conn/bind_std.go b/conn/bind_std.go index c701ef8..9886c91 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -8,6 +8,7 @@ package conn import ( "context" "errors" + "fmt" "net" "net/netip" "runtime" @@ -29,16 +30,19 @@ var ( // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. 
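//
// A sketch of the feature probing this patch introduces (the field names
// are those added to the struct that follows): at Open() time each socket
// is probed for UDP_SEGMENT and UDP_GRO support, and the result is cached
// per address family. TX offload is later cleared at runtime if sendmmsg()
// returns EIO, the kernel's signal that the egress netdev lacks checksum
// offload.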
type StdNetBind struct { - mu sync.Mutex // protects all fields except as specified - ipv4 *net.UDPConn - ipv6 *net.UDPConn - ipv4PC *ipv4.PacketConn // will be nil on non-Linux - ipv6PC *ipv6.PacketConn // will be nil on non-Linux + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool - // these three fields are not guarded by mu - udpAddrPool sync.Pool - ipv4MsgsPool sync.Pool - ipv6MsgsPool sync.Pool + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool blackhole4 bool blackhole6 bool @@ -54,23 +58,14 @@ func NewStdNetBind() Bind { }, }, - ipv4MsgsPool: sync.Pool{ - New: func() any { - msgs := make([]ipv4.Message, IdealBatchSize) - for i := range msgs { - msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) - } - return &msgs - }, - }, - - ipv6MsgsPool: sync.Pool{ + msgsPool: sync.Pool{ New: func() any { + // ipv6.Message and ipv4.Message are interchangeable as they are + // both aliases for x/net/internal/socket.Message. msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) + msgs[i].OOB = make([]byte, controlSize) } return &msgs }, @@ -113,7 +108,7 @@ func (e *StdNetEndpoint) DstIP() netip.Addr { return e.AddrPort.Addr() } -// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx. +// See control_default,linux, etc for implementations of SrcIP and SrcIfidx. func (e *StdNetEndpoint) DstToBytes() []byte { b, _ := e.AddrPort.MarshalBinary() @@ -179,19 +174,21 @@ again: } var fns []ReceiveFunc if v4conn != nil { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) if runtime.GOOS == "linux" { v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } - fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) s.ipv4 = v4conn } if v6conn != nil { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) if runtime.GOOS == "linux" { v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } - fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) s.ipv6 = v6conn } if len(fns) == 0 { @@ -201,69 +198,93 @@ again: return fns, uint16(port), nil } -func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - defer s.ipv4MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. 
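+	// The assignment below is a compile-time assertion: it type-checks only
+	// while ipv4.Message and ipv6.Message remain aliases of the same
+	// underlying x/net/internal/socket.Message type, which is what allows a
+	// single msgsPool of []ipv6.Message to serve both address families.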
+ _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" { + if rxOffload { + readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) if err != nil { return 0, err } } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } - numMsgs = 1 } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err } - return numMsgs, nil + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue + } + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep + } + return numMsgs, nil +} + +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } -func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - defer s.ipv6MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) - if err != nil { - return 0, err - } - numMsgs = 1 - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep - } - return numMsgs, nil + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } @@ -293,28 +314,42 @@ func (s *StdNetBind) Close() error { } s.blackhole4 = false s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false if err1 != nil { return err1 } return err2 } +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum 
offload", e.onLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 - var ( - pc4 *ipv4.PacketConn - pc6 *ipv6.PacketConn - ) + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 - pc6 = s.ipv6PC + br = s.ipv6PC is6 = true - } else { - pc4 = s.ipv4PC + offload = s.ipv6TxOffload } s.mu.Unlock() @@ -324,25 +359,56 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) if is6 { - return s.send6(conn, pc6, endpoint, bufs) + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] } else { - return s.send4(conn, pc4, endpoint, bufs) + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } + } else { + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} + } + return err } -func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as4 := ep.DstIP().As4() - copy(ua.IP, as4[:]) - ua.IP = ua.IP[:4] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { var ( n int err error @@ -350,59 +416,128 @@ func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, ) if runtime.GOOS == "linux" { for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { break } start += n } } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) if err != nil { break } } } - s.udpAddrPool.Put(ua) - s.ipv4MsgsPool.Put(msgs) return err } -func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as16 := ep.DstIP().As16() - copy(ua.IP, as16[:]) - ua.IP = ua.IP[:16] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +const ( + // Exceeding these values results in EMSGSIZE. 
They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. + udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( - n int - err error - start int + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs ) - if runtime.GOOS == "linux" { - for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { - break + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue } - start += n } - } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) - if err != nil { - break + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. 
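+			//
+			// Hypothetical walk-through: with firstMsgAt == 2 and msgs[2]
+			// arriving with N == 3 and a GSO size of 1, the three splits
+			// land in msgs[0], msgs[1], and msgs[2]; the final copy writes
+			// into msgs[2] itself, so i == n-1 and its N must keep the
+			// value set by that copy.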
+ msg.N = 0 } } - s.udpAddrPool.Put(ua) - s.ipv6MsgsPool.Put(msgs) - return err + return n, nil } diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 1e46776..34a3c9a 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -1,6 +1,12 @@ package conn -import "testing" +import ( + "encoding/binary" + "net" + "testing" + + "golang.org/x/net/ipv6" +) func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { bind := NewStdNetBind().(*StdNetBind) @@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { fn(bufs, sizes, eps) } } + +func mockSetGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func Test_coalesceMessages(t *testing.T) { + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1").To4(), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := mockGetGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +} + +func mockGetGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_splitCoalescedMessages(t *testing.T) { + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1<<16-1)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no 
split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} diff --git a/conn/sticky_default.go b/conn/control_default.go similarity index 56% rename from conn/sticky_default.go rename to conn/control_default.go index 1fa8a0c..9459da5 100644 --- a/conn/sticky_default.go +++ b/conn/control_default.go @@ -21,8 +21,9 @@ func (e *StdNetEndpoint) SrcToString() string { return "" } -// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but -// use alternatively named flags and need ports and require testing. +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// {get,set}srcControl feature set, but use alternatively named flags and need +// ports and require testing. // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. @@ -34,8 +35,17 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { func setSrcControl(control *[]byte, ep *StdNetEndpoint) { } -// srcControlSize returns the recommended buffer size for pooling sticky control -// data. -const srcControlSize = 0 +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// controlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const controlSize = 0 const StdNetSupportsStickySockets = false diff --git a/conn/sticky_linux.go b/conn/control_linux.go similarity index 65% rename from conn/sticky_linux.go rename to conn/control_linux.go index a30ccc7..44a94e6 100644 --- a/conn/sticky_linux.go +++ b/conn/control_linux.go @@ -8,6 +8,7 @@ package conn import ( + "fmt" "net/netip" "unsafe" @@ -105,6 +106,54 @@ func setSrcControl(control *[]byte, ep *StdNetEndpoint) { *control = append(*control, ep.src...) 
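	// ep.src is assumed to hold the raw PKTINFO control message captured on
	// receive, so appending it verbatim echoes the sticky source address and
	// interface index back to the kernel on send.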
} -var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// controlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +var controlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData) const StdNetSupportsStickySockets = true diff --git a/conn/sticky_linux_test.go b/conn/control_linux_test.go similarity index 96% rename from conn/sticky_linux_test.go rename to conn/control_linux_test.go index 679213a..96f9da2 100644 --- a/conn/sticky_linux_test.go +++ b/conn/control_linux_test.go @@ -60,7 +60,7 @@ func Test_setSrcControl(t *testing.T) { } setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, controlSize) setSrcControl(&control, ep) @@ -89,7 +89,7 @@ func Test_setSrcControl(t *testing.T) { } setSrc(ep, netip.MustParseAddr("::1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, controlSize) setSrcControl(&control, ep) @@ -113,7 +113,7 @@ func Test_setSrcControl(t *testing.T) { }) t.Run("ClearOnNoSrc", func(t *testing.T) { - control := make([]byte, unix.CmsgLen(0)) + control := make([]byte, controlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = 1 hdr.Type = 2 @@ -129,7 +129,7 @@ func Test_setSrcControl(t *testing.T) { func Test_getSrcFromControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + control := make([]byte, controlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -149,7 +149,7 @@ func Test_getSrcFromControl(t *testing.T) { } }) t.Run("IPv6", func(t *testing.T) { - control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + control := make([]byte, controlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IPV6 hdr.Type = unix.IPV6_PKTINFO diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go index a2396fe..f6ab1d2 100644 --- a/conn/controlfns_linux.go +++ b/conn/controlfns_linux.go @@ -57,5 +57,13 @@ func init() { } return err }, + + // Attempt to enable 
UDP_GRO + func(network, address string, c syscall.RawConn) error { + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + }) + return nil + }, ) } diff --git a/conn/errors_default.go b/conn/errors_default.go new file mode 100644 index 0000000..f1e5b90 --- /dev/null +++ b/conn/errors_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func errShouldDisableUDPGSO(err error) bool { + return false +} diff --git a/conn/errors_linux.go b/conn/errors_linux.go new file mode 100644 index 0000000..8e61000 --- /dev/null +++ b/conn/errors_linux.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func errShouldDisableUDPGSO(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have + // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. + // See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/conn/features_default.go b/conn/features_default.go new file mode 100644 index 0000000..d53ff5f --- /dev/null +++ b/conn/features_default.go @@ -0,0 +1,15 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net" + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + return +} diff --git a/conn/features_linux.go b/conn/features_linux.go new file mode 100644 index 0000000..e1fb57f --- /dev/null +++ b/conn/features_linux.go @@ -0,0 +1,35 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package conn + +import ( + "net" + + "golang.org/x/sys/unix" +) + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + rc, err := conn.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + if errSyscall != nil { + return + } + txOffload = true + opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) + if errSyscall != nil { + return + } + rxOffload = opt == 1 + }) + if err != nil { + return false, false + } + return txOffload, rxOffload +} diff --git a/device/send.go b/device/send.go index d22bf26..cd8a2a0 100644 --- a/device/send.go +++ b/device/send.go @@ -17,6 +17,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/tun" ) @@ -525,6 +526,13 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device.PutOutboundElement(elem) } device.PutOutboundElementsSlice(elems) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr + } + } if err != nil { device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue diff --git a/go.mod b/go.mod index c04e1bb..758dcde 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( golang.org/x/crypto v0.6.0 golang.org/x/net v0.7.0 - golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 + golang.org/x/sys v0.12.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 ) diff --git a/go.sum b/go.sum index cfeaee6..fe4ca7e 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= From 4201e08f1dbb521e5555d96a3b6464a860466f5f Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 2 Oct 2023 14:41:04 -0700 Subject: [PATCH 02/13] device: distribute crypto work as slice of elements After reducing UDP stack traversal overhead via GSO and GRO, runtime.chanrecv() began to account for a high percentage (20% in one environment) of perf samples during a throughput benchmark. The individual packet channel ops with the crypto goroutines was the primary contributor to this overhead. Updating these channels to pass vectors, which the device package already handles at its ends, reduced this overhead substantially, and improved throughput. The iperf3 results below demonstrate the effect of this commit between two Linux computers with i5-12400 CPUs. 
There is roughly ~13us of round trip latency between them. The first result is with UDP GSO and GRO, and with single element channels. Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 3.15 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 sender [ 5] 0.00-10.04 sec 12.3 GBytes 10.6 Gbits/sec receiver The second result is with channels updated to pass a slice of elements. Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 13.2 GBytes 11.3 Gbits/sec 182 3.15 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 13.2 GBytes 11.3 Gbits/sec 182 sender [ 5] 0.00-10.04 sec 13.2 GBytes 11.3 Gbits/sec receiver Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld --- device/channels.go | 8 ++++---- device/receive.go | 42 ++++++++++++++++++++-------------------- device/send.go | 48 +++++++++++++++++++++++----------------------- 3 files changed, 49 insertions(+), 49 deletions(-) diff --git a/device/channels.go b/device/channels.go index 039d8df..40ee5c9 100644 --- a/device/channels.go +++ b/device/channels.go @@ -19,13 +19,13 @@ import ( // call wg.Done to remove the initial reference. // When the refcount hits 0, the queue's channel is closed. type outboundQueue struct { - c chan *QueueOutboundElement + c chan *[]*QueueOutboundElement wg sync.WaitGroup } func newOutboundQueue() *outboundQueue { q := &outboundQueue{ - c: make(chan *QueueOutboundElement, QueueOutboundSize), + c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), } q.wg.Add(1) go func() { @@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue { // A inboundQueue is similar to an outboundQueue; see those docs. 
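//
// Passing a *[]*QueueInboundElement rather than a single element is the
// point of this change: one channel send/receive now moves a whole vector,
// so the runtime.chanrecv() overhead (about 20% of perf samples in one
// environment) is paid once per batch instead of once per packet.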
type inboundQueue struct { - c chan *QueueInboundElement + c chan *[]*QueueInboundElement wg sync.WaitGroup } func newInboundQueue() *inboundQueue { q := &inboundQueue{ - c: make(chan *QueueInboundElement, QueueInboundSize), + c: make(chan *[]*QueueInboundElement, QueueInboundSize), } q.wg.Add(1) go func() { diff --git a/device/receive.go b/device/receive.go index e24d29f..f0f37a1 100644 --- a/device/receive.go +++ b/device/receive.go @@ -220,9 +220,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive for peer, elems := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elems - for _, elem := range *elems { - device.queue.decryption.c <- elem - } + device.queue.decryption.c <- elems } else { for _, elem := range *elems { device.PutMessageBuffer(elem.buffer) @@ -241,26 +239,28 @@ func (device *Device) RoutineDecryption(id int) { defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) device.log.Verbosef("Routine: decryption worker %d - started", id) - for elem := range device.queue.decryption.c { - // split message into fields - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] + for elems := range device.queue.decryption.c { + for _, elem := range *elems { + // split message into fields + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] - // decrypt and release to consumer - var err error - elem.counter = binary.LittleEndian.Uint64(counter) - // copy counter to nonce - binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) - elem.packet, err = elem.keypair.receive.Open( - content[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.packet = nil + // decrypt and release to consumer + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) + elem.packet, err = elem.keypair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.packet = nil + } + elem.Unlock() } - elem.Unlock() } } diff --git a/device/send.go b/device/send.go index cd8a2a0..e838c4e 100644 --- a/device/send.go +++ b/device/send.go @@ -385,9 +385,7 @@ top: // add to parallel and sequential queue if peer.isRunning.Load() { peer.queue.outbound.c <- elems - for _, elem := range *elems { - peer.device.queue.encryption.c <- elem - } + peer.device.queue.encryption.c <- elems } else { for _, elem := range *elems { peer.device.PutMessageBuffer(elem.buffer) @@ -447,32 +445,34 @@ func (device *Device) RoutineEncryption(id int) { defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) - for elem := range device.queue.encryption.c { - // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] + for elems := range device.queue.encryption.c { + for _, elem := range *elems { + // populate header fields + header := elem.buffer[:MessageTransportHeaderSize] - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + 
binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - // pad content to multiple of 16 - paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) - elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) + // pad content to multiple of 16 + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) - // encrypt content and release to consumer + // encrypt content and release to consumer - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keypair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.Unlock() + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + elem.Unlock() + } } } From 895d6c23cd60bd0522c5b6598a69ad6c5f1ab3a7 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 2 Oct 2023 14:43:56 -0700 Subject: [PATCH 03/13] tun: unwind summing loop in checksumNoFold() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit $ benchstat old.txt new.txt goos: linux goarch: amd64 pkg: golang.zx2c4.com/wireguard/tun cpu: 12th Gen Intel(R) Core(TM) i5-12400 │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ Checksum/64-12 10.670n ± 2% 4.769n ± 0% -55.30% (p=0.000 n=10) Checksum/128-12 19.665n ± 2% 8.032n ± 0% -59.16% (p=0.000 n=10) Checksum/256-12 37.68n ± 1% 16.06n ± 0% -57.37% (p=0.000 n=10) Checksum/512-12 76.61n ± 3% 32.13n ± 0% -58.06% (p=0.000 n=10) Checksum/1024-12 160.55n ± 4% 64.25n ± 0% -59.98% (p=0.000 n=10) Checksum/1500-12 231.05n ± 7% 94.12n ± 0% -59.26% (p=0.000 n=10) Checksum/2048-12 309.5n ± 3% 128.5n ± 0% -58.48% (p=0.000 n=10) Checksum/4096-12 603.8n ± 4% 257.2n ± 0% -57.41% (p=0.000 n=10) Checksum/8192-12 1185.0n ± 3% 515.5n ± 0% -56.50% (p=0.000 n=10) Checksum/9000-12 1328.5n ± 5% 564.8n ± 0% -57.49% (p=0.000 n=10) Checksum/9001-12 1340.5n ± 3% 564.8n ± 0% -57.87% (p=0.000 n=10) geomean 185.3n 77.99n -57.92% Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld --- tun/checksum.go | 100 +++++++++++++++++++++++++++++++++++++------ tun/checksum_test.go | 35 +++++++++++++++ 2 files changed, 123 insertions(+), 12 deletions(-) create mode 100644 tun/checksum_test.go diff --git a/tun/checksum.go b/tun/checksum.go index f4f8471..29a8fc8 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -3,23 +3,99 @@ package tun import "encoding/binary" // TODO: Explore SIMD and/or other assembly optimizations. +// TODO: Test native endian loads. See RFC 1071 section 2 part B. 
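+// (Per RFC 1071 section 2 part B, the one's complement sum is byte-order
+// independent: the words could be loaded in native endianness and the
+// folded result byte-swapped once at the end, avoiding a byte swap per
+// load on little-endian machines. The loads below remain big-endian for
+// now.)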
func checksumNoFold(b []byte, initial uint64) uint64 { ac := initial - i := 0 - n := len(b) - for n >= 4 { - ac += uint64(binary.BigEndian.Uint32(b[i : i+4])) - n -= 4 - i += 4 + + for len(b) >= 128 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac += uint64(binary.BigEndian.Uint32(b[64:68])) + ac += uint64(binary.BigEndian.Uint32(b[68:72])) + ac += uint64(binary.BigEndian.Uint32(b[72:76])) + ac += uint64(binary.BigEndian.Uint32(b[76:80])) + ac += uint64(binary.BigEndian.Uint32(b[80:84])) + ac += uint64(binary.BigEndian.Uint32(b[84:88])) + ac += uint64(binary.BigEndian.Uint32(b[88:92])) + ac += uint64(binary.BigEndian.Uint32(b[92:96])) + ac += uint64(binary.BigEndian.Uint32(b[96:100])) + ac += uint64(binary.BigEndian.Uint32(b[100:104])) + ac += uint64(binary.BigEndian.Uint32(b[104:108])) + ac += uint64(binary.BigEndian.Uint32(b[108:112])) + ac += uint64(binary.BigEndian.Uint32(b[112:116])) + ac += uint64(binary.BigEndian.Uint32(b[116:120])) + ac += uint64(binary.BigEndian.Uint32(b[120:124])) + ac += uint64(binary.BigEndian.Uint32(b[124:128])) + b = b[128:] } - for n >= 2 { - ac += uint64(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 + if len(b) >= 64 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + b = b[64:] } - if n == 1 { - ac += uint64(b[i]) << 8 + if len(b) >= 32 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] } + if len(b) >= 16 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + b = b[16:] + } + if len(b) >= 8 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) 
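+		// Accumulating 32-bit loads into the 64-bit ac defers all carry
+		// folding to the caller (hence "NoFold"); packet-sized inputs are
+		// far too small to overflow the accumulator.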
+ b = b[8:] + } + if len(b) >= 4 { + ac += uint64(binary.BigEndian.Uint32(b)) + b = b[4:] + } + if len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + return ac } diff --git a/tun/checksum_test.go b/tun/checksum_test.go new file mode 100644 index 0000000..c1ccff5 --- /dev/null +++ b/tun/checksum_test.go @@ -0,0 +1,35 @@ +package tun + +import ( + "fmt" + "math/rand" + "testing" +) + +func BenchmarkChecksum(b *testing.B) { + lengths := []int{ + 64, + 128, + 256, + 512, + 1024, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, + } + + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum(buf, 0) + } + }) + } +} From 8a015f7c766564c21f6bef6fdddedce7e2ede830 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 2 Oct 2023 14:46:13 -0700 Subject: [PATCH 04/13] tun: reduce redundant checksumming in tcpGRO() IPv4 header and pseudo header checksums were being computed on every merge operation. Additionally, virtioNetHdr was being written at the same time. This delays those operations until after all coalescing has occurred. Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld --- tun/tcp_offload_linux.go | 162 ++++++++++++++++++++++++--------------- 1 file changed, 99 insertions(+), 63 deletions(-) diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go index 39a7180..1afd27e 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/tcp_offload_linux.go @@ -269,11 +269,11 @@ func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { type coalesceResult int const ( - coalesceInsufficientCap coalesceResult = 0 - coalescePSHEnding coalesceResult = 1 - coalesceItemInvalidCSum coalesceResult = 2 - coalescePktInvalidCSum coalesceResult = 3 - coalesceSuccess coalesceResult = 4 + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess ) // coalesceTCPPackets attempts to coalesce pkt with the packet described by @@ -339,42 +339,6 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize if gsoSize > item.gsoSize { item.gsoSize = gsoSize } - hdr := virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(headersLen), - gsoSize: uint16(item.gsoSize), - csumStart: uint16(item.iphLen), - csumOffset: 16, - } - - // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the - // (IPv4) header checksum. - if isV6 { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len - } else { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - pktHead[10], pktHead[11] = 0, 0 // clear checksum field - binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length - iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum - binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field - } - hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:]) - - // Calculate the pseudo header checksum and place it at the TCP checksum - // offset. Downstream checksum offloading will combine this with computation - // of the tcp header and payload checksum. 
- addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := bufsOffset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen))) - binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) item.numMerged++ return coalesceSuccess @@ -390,43 +354,52 @@ const ( maxUint16 = 1<<16 - 1 ) +type tcpGROResult int + +const ( + tcpGROResultNoop tcpGROResult = iota + tcpGROResultTableInsert + tcpGROResultCoalesced +) + // tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It will return false when pktI is not -// coalesced, otherwise true. This indicates to the caller if bufs[pktI] -// should be written to the Device. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) { +// existing packets tracked in table. It returns a tcpGROResultNoop when no +// action was taken, tcpGROResultTableInsert when the evaluated packet was +// inserted into table, and tcpGROResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. - return false + return tcpGROResultNoop } iphLen := int((pkt[0] & 0x0F) * 4) if isV6 { iphLen = 40 ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) if ipv6HPayloadLen != len(pkt)-iphLen { - return false + return tcpGROResultNoop } } else { totalLen := int(binary.BigEndian.Uint16(pkt[2:])) if totalLen != len(pkt) { - return false + return tcpGROResultNoop } } if len(pkt) < iphLen { - return false + return tcpGROResultNoop } tcphLen := int((pkt[iphLen+12] >> 4) * 4) if tcphLen < 20 || tcphLen > 60 { - return false + return tcpGROResultNoop } if len(pkt) < iphLen+tcphLen { - return false + return tcpGROResultNoop } if !isV6 { if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { // no GRO support for fragmented segments for now - return false + return tcpGROResultNoop } } tcpFlags := pkt[iphLen+tcpFlagsOffset] @@ -434,14 +407,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) // not a candidate if any non-ACK flags (except PSH+ACK) are set if tcpFlags != tcpFlagACK { if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return false + return tcpGROResultNoop } pshSet = true } gsoSize := uint16(len(pkt) - tcphLen - iphLen) // not a candidate if payload len is 0 if gsoSize < 1 { - return false + return tcpGROResultNoop } seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) srcAddrOffset := ipv4SrcAddrOffset @@ -452,7 +425,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) } items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) if !existing { - return false + return tcpGROResultNoop } for i := len(items) - 1; i >= 0; i-- { // In the best case of packets arriving in order iterating in reverse is @@ -470,20 +443,20 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) switch result { case coalesceSuccess: table.updateAt(item, i) - return true + return tcpGROResultCoalesced case 
coalesceItemInvalidCSum: // delete the item with an invalid csum table.deleteAt(item.key, i) case coalescePktInvalidCSum: // no point in inserting an item that we can't coalesce - return false + return tcpGROResultNoop default: } } } // failed to coalesce with any other packets; store the item in the flow table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return false + return tcpGROResultTableInsert } func isTCP4NoIPOptions(b []byte) bool { @@ -515,6 +488,64 @@ func isTCP6NoEH(b []byte) bool { return true } +// applyCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + if isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + // handleGRO evaluates bufs for GRO, and writes the indices of the resulting // packets into toWrite. 
toWrite, tcp4Table, and tcp6Table should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset @@ -524,23 +555,28 @@ func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toW if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } - var coalesced bool + var result tcpGROResult switch { case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp4Table, false) + result = tcpGRO(bufs, offset, i, tcp4Table, false) case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp6Table, true) + result = tcpGRO(bufs, offset, i, tcp6Table, true) } - if !coalesced { + switch result { + case tcpGROResultNoop: hdr := virtioNetHdr{} err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) if err != nil { return err } + fallthrough + case tcpGROResultTableInsert: *toWrite = append(*toWrite, i) } } - return nil + err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false) + err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true) + return errors.Join(err4, err6) } // tcpTSO splits packets from in into outBuffs, writing the size of each From 1ec454f253c068f74ba7a7aea34546c9819493c0 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 2 Oct 2023 14:48:28 -0700 Subject: [PATCH 05/13] device: move Queue{In,Out}boundElement Mutex to container type Queue{In,Out}boundElement locking can contribute to significant overhead via sync.Mutex.lockSlow() in some environments. These types are passed throughout the device package as elements in a slice, so move the per-element Mutex to a container around the slice. Reviewed-by: Maisem Ali Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld --- device/channels.go | 32 ++++++++-------- device/device.go | 10 ++--- device/peer.go | 8 ++-- device/pools.go | 44 +++++++++++---------- device/receive.go | 43 +++++++++++---------- device/send.go | 95 ++++++++++++++++++++++++---------------------- 6 files changed, 121 insertions(+), 111 deletions(-) diff --git a/device/channels.go b/device/channels.go index 40ee5c9..e526f6b 100644 --- a/device/channels.go +++ b/device/channels.go @@ -19,13 +19,13 @@ import ( // call wg.Done to remove the initial reference. // When the refcount hits 0, the queue's channel is closed. type outboundQueue struct { - c chan *[]*QueueOutboundElement + c chan *QueueOutboundElementsContainer wg sync.WaitGroup } func newOutboundQueue() *outboundQueue { q := &outboundQueue{ - c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } q.wg.Add(1) go func() { @@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue { // A inboundQueue is similar to an outboundQueue; see those docs. type inboundQueue struct { - c chan *[]*QueueInboundElement + c chan *QueueInboundElementsContainer wg sync.WaitGroup } func newInboundQueue() *inboundQueue { q := &inboundQueue{ - c: make(chan *[]*QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } q.wg.Add(1) go func() { @@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue { } type autodrainingInboundQueue struct { - c chan *[]*QueueInboundElement + c chan *QueueInboundElementsContainer } // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. 
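// (The draining relies on runtime.SetFinalizer, registered in the
// constructor below: when the queue wrapper becomes unreachable,
// flushInboundQueue runs and returns any remaining element containers to
// the device's pools.)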
@@ -81,7 +81,7 @@ type autodrainingInboundQueue struct { // some other means, such as sending a sentinel nil values. func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { q := &autodrainingInboundQueue{ - c: make(chan *[]*QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } runtime.SetFinalizer(q, device.flushInboundQueue) return q @@ -90,13 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) default: return } @@ -104,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { } type autodrainingOutboundQueue struct { - c chan *[]*QueueOutboundElement + c chan *QueueOutboundElementsContainer } // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. @@ -114,7 +114,7 @@ type autodrainingOutboundQueue struct { // All sends to the channel must be best-effort, because there may be no receivers. func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { q := &autodrainingOutboundQueue{ - c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } runtime.SetFinalizer(q, device.flushOutboundQueue) return q @@ -123,13 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) default: return } diff --git a/device/device.go b/device/device.go index 1af9fe0..f9557a0 100644 --- a/device/device.go +++ b/device/device.go @@ -68,11 +68,11 @@ type Device struct { cookieChecker CookieChecker pool struct { - outboundElementsSlice *WaitPool - inboundElementsSlice *WaitPool - messageBuffers *WaitPool - inboundElements *WaitPool - outboundElements *WaitPool + inboundElementsContainer *WaitPool + outboundElementsContainer *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { diff --git a/device/peer.go b/device/peer.go index 0ac4896..2fb5da6 100644 --- a/device/peer.go +++ b/device/peer.go @@ -45,9 +45,9 @@ type Peer struct { } queue struct { - staged chan *[]*QueueOutboundElement // staged packets before a handshake is available - outbound *autodrainingOutboundQueue // sequential ordering of udp transmission - inbound *autodrainingInboundQueue // sequential ordering of tun writing + staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } cookieGenerator CookieGenerator @@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) 
{ peer.device = device peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.inbound = newAutodrainingInboundQueue(device) - peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize) + peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key _, ok := device.peers.keyMap[pk] diff --git a/device/pools.go b/device/pools.go index 02a5d6a..94f3dc7 100644 --- a/device/pools.go +++ b/device/pools.go @@ -46,13 +46,13 @@ func (p *WaitPool) Put(x any) { } func (device *Device) PopulatePools() { - device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { - s := make([]*QueueOutboundElement, 0, device.BatchSize()) - return &s - }) - device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { s := make([]*QueueInboundElement, 0, device.BatchSize()) - return &s + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} }) device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new([MaxMessageSize]byte) @@ -65,28 +65,32 @@ func (device *Device) PopulatePools() { }) } -func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement { - return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement) +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.outboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) } -func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement { - return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement) +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.inboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) } func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { diff --git a/device/receive.go b/device/receive.go index f0f37a1..4b32dc5 100644 --- a/device/receive.go +++ b/device/receive.go @@ -27,7 +27,6 @@ type QueueHandshakeElement struct { } type QueueInboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte packet []byte counter uint64 @@ -35,6 +34,11 @@ type QueueInboundElement struct { endpoint conn.Endpoint } +type QueueInboundElementsContainer struct { + sync.Mutex + elems []*QueueInboundElement +} + // clearPointers clears elem fields that contain pointers. 
// This makes the garbage collector's life easier and // avoids accidentally keeping other objects around unnecessarily. @@ -87,7 +91,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int - elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize) + elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) for i := range bufsArrs { @@ -170,15 +174,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetInboundElementsSlice() + elemsForPeer = device.GetInboundElementsContainer() + elemsForPeer.Lock() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) bufsArrs[i] = device.GetMessageBuffer() bufs[i] = bufsArrs[i][:] continue @@ -217,16 +220,16 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive default: } } - for peer, elems := range elemsByPeer { + for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { - peer.queue.inbound.c <- elems - device.queue.decryption.c <- elems + peer.queue.inbound.c <- elemsContainer + device.queue.decryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } delete(elemsByPeer, peer) } @@ -239,8 +242,8 @@ func (device *Device) RoutineDecryption(id int) { defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) device.log.Verbosef("Routine: decryption worker %d - started", id) - for elems := range device.queue.decryption.c { - for _, elem := range *elems { + for elemsContainer := range device.queue.decryption.c { + for _, elem := range elemsContainer.elems { // split message into fields counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] @@ -259,8 +262,8 @@ func (device *Device) RoutineDecryption(id int) { if err != nil { elem.packet = nil } - elem.Unlock() } + elemsContainer.Unlock() } } @@ -437,12 +440,12 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.inbound.c { - if elems == nil { + for elemsContainer := range peer.queue.inbound.c { + if elemsContainer == nil { return } - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed continue @@ -515,11 +518,11 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } bufs = bufs[:0] - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index e838c4e..769720a 100644 --- a/device/send.go +++ b/device/send.go @@ -46,7 +46,6 @@ import ( */ type QueueOutboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte // 
slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption @@ -54,10 +53,14 @@ type QueueOutboundElement struct { peer *Peer // related peer } +type QueueOutboundElementsContainer struct { + sync.Mutex + elems []*QueueOutboundElement +} + func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem @@ -79,15 +82,15 @@ func (elem *QueueOutboundElement) clearPointers() { func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() - elems := peer.device.GetOutboundElementsSlice() - *elems = append(*elems, elem) + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case peer.queue.staged <- elems: + case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } } peer.SendStagedPackets() @@ -219,7 +222,7 @@ func (device *Device) RoutineReadFromTUN() { readErr error elems = make([]*QueueOutboundElement, batchSize) bufs = make([][]byte, batchSize) - elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 sizes = make([]int, batchSize) offset = MessageTransportHeaderSize @@ -276,10 +279,10 @@ func (device *Device) RoutineReadFromTUN() { } elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetOutboundElementsSlice() + elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) elems[i] = device.NewOutboundElement() bufs[i] = elems[i].buffer[:] } @@ -289,11 +292,11 @@ func (device *Device) RoutineReadFromTUN() { peer.StagePackets(elemsForPeer) peer.SendStagedPackets() } else { - for _, elem := range *elemsForPeer { + for _, elem := range elemsForPeer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elemsForPeer) + device.PutOutboundElementsContainer(elemsForPeer) } delete(elemsByPeer, peer) } @@ -317,7 +320,7 @@ func (device *Device) RoutineReadFromTUN() { } } -func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { +func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { for { select { case peer.queue.staged <- elems: @@ -326,11 +329,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { } select { case tooOld := <-peer.queue.staged: - for _, elem := range *tooOld { + for _, elem := range tooOld.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(tooOld) + peer.device.PutOutboundElementsContainer(tooOld) default: } } @@ -349,52 +352,52 @@ top: } for { - var elemsOOO *[]*QueueOutboundElement + var elemsContainerOOO *QueueOutboundElementsContainer select { - case elems := <-peer.queue.staged: + case elemsContainer := <-peer.queue.staged: i := 0 - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { elem.peer = peer 
elem.nonce = keypair.sendNonce.Add(1) - 1 if elem.nonce >= RejectAfterMessages { keypair.sendNonce.Store(RejectAfterMessages) - if elemsOOO == nil { - elemsOOO = peer.device.GetOutboundElementsSlice() + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() } - *elemsOOO = append(*elemsOOO, elem) + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) continue } else { - (*elems)[i] = elem + elemsContainer.elems[i] = elem i++ } elem.keypair = keypair - elem.Lock() } - *elems = (*elems)[:i] + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - if elemsOOO != nil { - peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans } - if len(*elems) == 0 { - peer.device.PutOutboundElementsSlice(elems) + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) goto top } // add to parallel and sequential queue if peer.isRunning.Load() { - peer.queue.outbound.c <- elems - peer.device.queue.encryption.c <- elems + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } - if elemsOOO != nil { + if elemsContainerOOO != nil { goto top } default: @@ -406,12 +409,12 @@ top: func (peer *Peer) FlushStagedPackets() { for { select { - case elems := <-peer.queue.staged: - for _, elem := range *elems { + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) default: return } @@ -445,8 +448,8 @@ func (device *Device) RoutineEncryption(id int) { defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) - for elems := range device.queue.encryption.c { - for _, elem := range *elems { + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { // populate header fields header := elem.buffer[:MessageTransportHeaderSize] @@ -471,8 +474,8 @@ func (device *Device) RoutineEncryption(id int) { elem.packet, nil, ) - elem.Unlock() } + elemsContainer.Unlock() } } @@ -486,9 +489,9 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.outbound.c { + for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] - if elems == nil { + if elemsContainer == nil { return } if !peer.isRunning.Load() { @@ -498,16 +501,16 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { // The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. 
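A side note on the nonce handling above: keypair.sendNonce.Add(1) - 1 claims the next counter value atomically, and any element whose claimed value reaches RejectAfterMessages is diverted into the out-of-order container to await a fresh keypair. A minimal sketch of that claim-then-check pattern, with a toy limit standing in for the real constant:

package main

import (
	"fmt"
	"sync/atomic"
)

// Toy limit for illustration; the real RejectAfterMessages is vastly larger.
const rejectAfterMessages = 4

var sendNonce atomic.Uint64

// claimNonce reserves the next nonce. Add(1)-1 yields the pre-increment
// value, so concurrent senders each obtain a unique counter without a lock.
func claimNonce() (uint64, bool) {
	n := sendNonce.Add(1) - 1
	if n >= rejectAfterMessages {
		// Park the counter at the limit so it cannot wrap around.
		sendNonce.Store(rejectAfterMessages)
		return 0, false
	}
	return n, true
}

func main() {
	for i := 0; i < 6; i++ {
		if n, ok := claimNonce(); ok {
			fmt.Println("claimed nonce", n)
		} else {
			fmt.Println("limit hit; packet must wait for a fresh keypair")
		}
	}
}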
- for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } continue } dataSent := false - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true } @@ -521,11 +524,11 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { if dataSent { peer.timersDataSent() } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) if err != nil { var errGSO conn.ErrUDPGSODisabled if errors.As(err, &errGSO) { From ec8f6f82c20c617a3ea94478f2b5e4d49c6d3c2c Mon Sep 17 00:00:00 2001 From: James Tucker Date: Wed, 27 Sep 2023 14:52:21 -0700 Subject: [PATCH 06/13] tun: fix crash when ForceMTU is called after close Close closes the events channel, resulting in a panic from send on closed channel. Reported-By: Brad Fitzpatrick Signed-off-by: James Tucker Signed-off-by: Jason A. Donenfeld --- tun/tun_windows.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 0cb4ce1..34f2980 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -127,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } update := tun.forcedMTU != mtu tun.forcedMTU = mtu if update { From 42ec952eadc297efadc70b9911d5a59bcd2db4a6 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Wed, 27 Sep 2023 16:15:09 -0700 Subject: [PATCH 07/13] go.mod,tun/netstack: bump gvisor Signed-off-by: James Tucker Signed-off-by: Jason A. 
Donenfeld --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- tun/netstack/tun.go | 14 +++++++------- tun/tcp_offload_linux_test.go | 8 ++++---- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/go.mod b/go.mod index 758dcde..919dc49 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,14 @@ module golang.zx2c4.com/wireguard go 1.20 require ( - golang.org/x/crypto v0.6.0 - golang.org/x/net v0.7.0 + golang.org/x/crypto v0.13.0 + golang.org/x/net v0.15.0 golang.org/x/sys v0.12.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 + gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 ) require ( github.com/google/btree v1.0.1 // indirect - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) diff --git a/go.sum b/go.sum index fe4ca7e..6bcecea 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,14 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 596cfcd..2b73054 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -25,7 +25,7 @@ import ( "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" - "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,7 +43,7 @@ type netTun struct { ep *channel.Endpoint stack *stack.Stack events chan tun.Event - incomingPacket chan *bufferv2.View + incomingPacket chan *buffer.View mtu int 
dnsServers []netip.Addr hasV4, hasV6 bool @@ -61,7 +61,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, ep: channel.New(1024, uint32(mtu), ""), stack: stack.New(opts), events: make(chan tun.Event, 10), - incomingPacket: make(chan *bufferv2.View), + incomingPacket: make(chan *buffer.View), dnsServers: dnsServers, mtu: mtu, } @@ -84,7 +84,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, } protoAddr := tcpip.ProtocolAddress{ Protocol: protoNumber, - AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), } tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) if tcpipErr != nil { @@ -140,7 +140,7 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { continue } - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) switch packet[0] >> 4 { case 4: tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) @@ -198,7 +198,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ } return tcpip.FullAddress{ NIC: 1, - Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), Port: endpoint.Port(), }, protoNumber } @@ -453,7 +453,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) } - remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr)) + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) return res.Count, &PingAddr{remoteAddr}, nil } diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go index 9160e18..ddddc48 100644 --- a/tun/tcp_offload_linux_test.go +++ b/tun/tcp_offload_linux_test.go @@ -35,8 +35,8 @@ func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header. srcAs4 := srcIPPort.Addr().As4() dstAs4 := dstIPPort.Addr().As4() ipFields := &header.IPv4Fields{ - SrcAddr: tcpip.Address(srcAs4[:]), - DstAddr: tcpip.Address(dstAs4[:]), + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), Protocol: unix.IPPROTO_TCP, TTL: 64, TotalLength: uint16(totalLen), @@ -72,8 +72,8 @@ func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header. srcAs16 := srcIPPort.Addr().As16() dstAs16 := dstIPPort.Addr().As16() ipFields := &header.IPv6Fields{ - SrcAddr: tcpip.Address(srcAs16[:]), - DstAddr: tcpip.Address(dstAs16[:]), + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), TransportProtocol: unix.IPPROTO_TCP, HopLimit: 64, PayloadLength: uint16(segmentSize + 20), From 177caa7e4419d1b95bbf0423f6be6230c7101504 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 18 Oct 2023 21:02:52 +0200 Subject: [PATCH 08/13] conn: simplify supportsUDPOffload This allows a kernel to support UDP_GRO while not supporting UDP_SEGMENT. Signed-off-by: Jason A. 
Donenfeld --- conn/features_linux.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/conn/features_linux.go b/conn/features_linux.go index e1fb57f..8959d93 100644 --- a/conn/features_linux.go +++ b/conn/features_linux.go @@ -18,15 +18,9 @@ func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { } err = rc.Control(func(fd uintptr) { _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) - if errSyscall != nil { - return - } - txOffload = true + txOffload = errSyscall == nil opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) - if errSyscall != nil { - return - } - rxOffload = opt == 1 + rxOffload = errSyscall == nil && opt == 1 }) if err != nil { return false, false From 24ea13351eb7a06c3760f2eae18a484ce009fcf9 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 18 Oct 2023 21:14:13 +0200 Subject: [PATCH 09/13] conn: harmonize GOOS checks between "linux" and "android" Otherwise GRO gets enabled on Android, but the conn doesn't use it, resulting in bundled packets being discarded. Signed-off-by: Jason A. Donenfeld --- conn/bind_std.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/conn/bind_std.go b/conn/bind_std.go index 9886c91..5a00f34 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -175,7 +175,7 @@ again: var fns []ReceiveFunc if v4conn != nil { s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } @@ -184,7 +184,7 @@ again: } if v6conn != nil { s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } @@ -237,7 +237,7 @@ func (s *StdNetBind) receiveIP( } defer s.putMessages(msgs) var numMsgs int - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { if rxOffload { readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) @@ -291,7 +291,7 @@ func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxO // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // rename the IdealBatchSize constant to BatchSize. func (s *StdNetBind) BatchSize() int { - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { return IdealBatchSize } return 1 @@ -414,7 +414,7 @@ func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message err error start int ) - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { for { n, err = pc.WriteBatch(msgs[start:], 0) if err != nil || n == len(msgs[start:]) { From 5d37bd24e14e3fff6c1ce61e299480beb3d68c00 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 21 Oct 2023 18:41:27 +0200 Subject: [PATCH 10/13] conn: separate gso and sticky control Android wants GSO but not sticky. Signed-off-by: Jason A. 
Donenfeld --- conn/bind_std.go | 2 +- conn/gso_default.go | 21 ++++++ conn/gso_linux.go | 65 +++++++++++++++++++ .../{control_default.go => sticky_default.go} | 13 +--- conn/{control_linux.go => sticky_linux.go} | 51 +-------------- ...rol_linux_test.go => sticky_linux_test.go} | 10 +-- 6 files changed, 96 insertions(+), 66 deletions(-) create mode 100644 conn/gso_default.go create mode 100644 conn/gso_linux.go rename conn/{control_default.go => sticky_default.go} (72%) rename conn/{control_linux.go => sticky_linux.go} (66%) rename conn/{control_linux_test.go => sticky_linux_test.go} (96%) diff --git a/conn/bind_std.go b/conn/bind_std.go index 5a00f34..e1bcbd1 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -65,7 +65,7 @@ func NewStdNetBind() Bind { msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, controlSize) + msgs[i].OOB = make([]byte, stickyControlSize+gsoControlSize) } return &msgs }, diff --git a/conn/gso_default.go b/conn/gso_default.go new file mode 100644 index 0000000..57780db --- /dev/null +++ b/conn/gso_default.go @@ -0,0 +1,21 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// gsoControlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const gsoControlSize = 0 diff --git a/conn/gso_linux.go b/conn/gso_linux.go new file mode 100644 index 0000000..b8599ce --- /dev/null +++ b/conn/gso_linux.go @@ -0,0 +1,65 @@ +//go:build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// gsoControlSize returns the recommended buffer size for pooling UDP +// offloading control data. 
+var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) diff --git a/conn/control_default.go b/conn/sticky_default.go similarity index 72% rename from conn/control_default.go rename to conn/sticky_default.go index 9459da5..0b21386 100644 --- a/conn/control_default.go +++ b/conn/sticky_default.go @@ -35,17 +35,8 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { func setSrcControl(control *[]byte, ep *StdNetEndpoint) { } -// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. -func getGSOSize(control []byte) (int, error) { - return 0, nil -} - -// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. -func setGSOSize(control *[]byte, gsoSize uint16) { -} - -// controlSize returns the recommended buffer size for pooling sticky and UDP +// stickyControlSize returns the recommended buffer size for pooling sticky // offloading control data. -const controlSize = 0 +const stickyControlSize = 0 const StdNetSupportsStickySockets = false diff --git a/conn/control_linux.go b/conn/sticky_linux.go similarity index 66% rename from conn/control_linux.go rename to conn/sticky_linux.go index 44a94e6..8e206e9 100644 --- a/conn/control_linux.go +++ b/conn/sticky_linux.go @@ -8,7 +8,6 @@ package conn import ( - "fmt" "net/netip" "unsafe" @@ -106,54 +105,8 @@ func setSrcControl(control *[]byte, ep *StdNetEndpoint) { *control = append(*control, ep.src...) } -const ( - sizeOfGSOData = 2 -) - -// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. -func getGSOSize(control []byte) (int, error) { - var ( - hdr unix.Cmsghdr - data []byte - rem = control - err error - ) - - for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) - if err != nil { - return 0, fmt.Errorf("error parsing socket control message: %w", err) - } - if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { - var gso uint16 - copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) - return int(gso), nil - } - } - return 0, nil -} - -// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing -// data in control untouched. -func setGSOSize(control *[]byte, gsoSize uint16) { - existingLen := len(*control) - avail := cap(*control) - existingLen - space := unix.CmsgSpace(sizeOfGSOData) - if avail < space { - return - } - *control = (*control)[:cap(*control)] - gsoControl := (*control)[existingLen:] - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) - hdr.Level = unix.SOL_UDP - hdr.Type = unix.UDP_SEGMENT - hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) - copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) - *control = (*control)[:existingLen+space] -} - -// controlSize returns the recommended buffer size for pooling sticky and UDP +// stickyControlSize returns the recommended buffer size for pooling sticky // offloading control data. 
-var controlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData) +var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) const StdNetSupportsStickySockets = true diff --git a/conn/control_linux_test.go b/conn/sticky_linux_test.go similarity index 96% rename from conn/control_linux_test.go rename to conn/sticky_linux_test.go index 96f9da2..d2bd584 100644 --- a/conn/control_linux_test.go +++ b/conn/sticky_linux_test.go @@ -60,7 +60,7 @@ func Test_setSrcControl(t *testing.T) { } setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) - control := make([]byte, controlSize) + control := make([]byte, stickyControlSize) setSrcControl(&control, ep) @@ -89,7 +89,7 @@ func Test_setSrcControl(t *testing.T) { } setSrc(ep, netip.MustParseAddr("::1"), 5) - control := make([]byte, controlSize) + control := make([]byte, stickyControlSize) setSrcControl(&control, ep) @@ -113,7 +113,7 @@ func Test_setSrcControl(t *testing.T) { }) t.Run("ClearOnNoSrc", func(t *testing.T) { - control := make([]byte, controlSize) + control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = 1 hdr.Type = 2 @@ -129,7 +129,7 @@ func Test_setSrcControl(t *testing.T) { func Test_getSrcFromControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - control := make([]byte, controlSize) + control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -149,7 +149,7 @@ func Test_getSrcFromControl(t *testing.T) { } }) t.Run("IPv6", func(t *testing.T) { - control := make([]byte, controlSize) + control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IPV6 hdr.Type = unix.IPV6_PKTINFO From f502ec3fad116d11109529bcf283e464f4822c18 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 21 Oct 2023 19:06:38 +0200 Subject: [PATCH 11/13] conn: fix cmsg data padding calculation for gso Signed-off-by: Jason A. Donenfeld --- conn/gso_linux.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn/gso_linux.go b/conn/gso_linux.go index b8599ce..8596b29 100644 --- a/conn/gso_linux.go +++ b/conn/gso_linux.go @@ -56,7 +56,7 @@ func setGSOSize(control *[]byte, gsoSize uint16) { hdr.Level = unix.SOL_UDP hdr.Type = unix.UDP_SEGMENT hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) - copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) *control = (*control)[:existingLen+space] } From b3df23dcd40ba4568572f338f9fd16b87053fc29 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 21 Oct 2023 19:32:07 +0200 Subject: [PATCH 12/13] conn: set unused OOB to zero length Otherwise in the event that we're using GSO without sticky sockets, we pass garbage OOB buffers to sendmmsg, causing an EINVAL when GSO doesn't set its header. Signed-off-by: Jason A.
Donenfeld --- conn/bind_std.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conn/bind_std.go b/conn/bind_std.go index e1bcbd1..46df7fd 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -65,7 +65,7 @@ func NewStdNetBind() Bind { msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, stickyControlSize+gsoControlSize) + msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) } return &msgs }, @@ -200,6 +200,7 @@ again: func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { for i := range *msgs { + (*msgs)[i].OOB = (*msgs)[i].OOB[:0] (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} } s.msgsPool.Put(msgs) From 2e0774f246fb4fc1bd5cb44584d033038c89174e Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sun, 22 Oct 2023 02:12:13 +0200 Subject: [PATCH 13/13] device: ratchet up max segment size on android GRO requires big allocations to be efficient. This isn't great, as there might be Android memory usage issues. So we should revisit this commit. But at least it gets things working again. Signed-off-by: Jason A. Donenfeld --- device/queueconstants_android.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index 3d80ead..25f700a 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -14,6 +14,6 @@ const ( QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 - MaxSegmentSize = 2200 + MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram PreallocatedBuffersPerPool = 4096 )
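A closing note on the new constant: (1 << 16) - 1 is the ceiling set by the 16-bit length fields of IP and UDP, and a GRO-enabled read may return that much coalesced data in a single buffer, so per-packet buffers must be sized for the worst case. A quick check of the arithmetic, assuming a minimal IPv4 header with no options:

package main

import "fmt"

func main() {
	const (
		maxIPPacket   = 1<<16 - 1 // IP total/payload length fields are 16 bits
		ipv4HeaderMin = 20        // minimal IPv4 header, no options
		udpHeader     = 8
	)
	fmt.Println("MaxSegmentSize:", maxIPPacket)                                        // 65535
	fmt.Println("largest UDP payload over IPv4:", maxIPPacket-ipv4HeaderMin-udpHeader) // 65507
	// A GRO read can return many wire datagrams coalesced into one buffer
	// approaching this size, far beyond the old 2200-byte MaxSegmentSize.
}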