From 1cf89f5339b549236f38ce5fbc40f7bf993d9626 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Wed, 8 Nov 2023 14:06:20 -0800 Subject: [PATCH 1/7] tun: fix Device.Read() buf length assumption on Windows The length of a packet read from the underlying TUN device may exceed the length of a supplied buffer when MTU exceeds device.MaxMessageSize. Reviewed-by: Brad Fitzpatrick Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld --- tun/tun_windows.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 34f2980..2af8e3e 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -160,11 +160,10 @@ retry: packet, err := tun.session.ReceivePacket() switch err { case nil: - packetSize := len(packet) - copy(bufs[0][offset:], packet) - sizes[0] = packetSize + n := copy(bufs[0][offset:], packet) + sizes[0] = n tun.session.ReleaseReceivePacket(packet) - tun.rate.update(uint64(packetSize)) + tun.rate.update(uint64(n)) return 1, nil case windows.ERROR_NO_MORE_ITEMS: if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { From d0bc03c707974a84a672716c718f99fab49e7eb8 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 31 Oct 2023 19:53:35 -0700 Subject: [PATCH 2/7] tun: implement UDP GSO/GRO for Linux Implement UDP GSO and GRO for the Linux tun.Device, which is made possible by virtio extensions in the kernel's TUN driver starting in v6.2. secnetperf, a QUIC benchmark utility from microsoft/msquic@8e1eb1a, is used to demonstrate the effect of this commit between two Linux computers with i5-12400 CPUs. There is roughly 13us of round-trip latency between them. secnetperf was invoked with the following command-line options: -stats:1 -exec:maxtput -test:tput -download:10000 -timed:1 -encrypt:0 The first result is from commit 2e0774f without UDP GSO/GRO on the TUN. [conn][0x55739a144980] STATS: EcnCapable=0 RTT=3973 us SendTotalPackets=55859 SendSuspectedLostPackets=61 SendSpuriousLostPackets=59 SendCongestionCount=27 SendEcnCongestionCount=0 RecvTotalPackets=2779122 RecvReorderedPackets=0 RecvDroppedPackets=0 RecvDuplicatePackets=0 RecvDecryptionFailures=0 Result: 3654977571 bytes @ 2922821 kbps (10003.972 ms). The second result is with UDP GSO/GRO on the TUN. [conn][0x56493dfd09a0] STATS: EcnCapable=0 RTT=1216 us SendTotalPackets=165033 SendSuspectedLostPackets=64 SendSpuriousLostPackets=61 SendCongestionCount=53 SendEcnCongestionCount=0 RecvTotalPackets=11845268 RecvReorderedPackets=25267 RecvDroppedPackets=0 RecvDuplicatePackets=0 RecvDecryptionFailures=0 Result: 15574671184 bytes @ 12458214 kbps (10001.222 ms). Signed-off-by: Jordan Whited Signed-off-by: Jason A.
Donenfeld --- ...{tcp_offload_linux.go => offload_linux.go} | 598 ++++++++++---- tun/offload_linux_test.go | 752 ++++++++++++++++++ tun/tcp_offload_linux_test.go | 411 ---------- ...65e4830d6dc087cab24cd1e154c2e790589a309b77 | 8 - ...6784411a8ce2e8e03aa3384105e581f2c67494700d | 8 - tun/tun_linux.go | 71 +- 6 files changed, 1258 insertions(+), 590 deletions(-) rename tun/{tcp_offload_linux.go => offload_linux.go} (50%) create mode 100644 tun/offload_linux_test.go delete mode 100644 tun/tcp_offload_linux_test.go delete mode 100644 tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 delete mode 100644 tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d diff --git a/tun/tcp_offload_linux.go b/tun/offload_linux.go similarity index 50% rename from tun/tcp_offload_linux.go rename to tun/offload_linux.go index 1afd27e..9ff7fea 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/offload_linux.go @@ -57,22 +57,23 @@ const ( virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) ) -// flowKey represents the key for a flow. -type flowKey struct { +// tcpFlowKey represents the key for a TCP flow. +type tcpFlowKey struct { srcAddr, dstAddr [16]byte srcPort, dstPort uint16 rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. + isV6 bool } -// tcpGROTable holds flow and coalescing information for the purposes of GRO. +// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. type tcpGROTable struct { - itemsByFlow map[flowKey][]tcpGROItem + itemsByFlow map[tcpFlowKey][]tcpGROItem itemsPool [][]tcpGROItem } func newTCPGROTable() *tcpGROTable { t := &tcpGROTable{ - itemsByFlow: make(map[flowKey][]tcpGROItem, conn.IdealBatchSize), + itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize), itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), } for i := range t.itemsPool { @@ -81,14 +82,15 @@ func newTCPGROTable() *tcpGROTable { return t } -func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { - key := flowKey{} - addrSize := dstAddr - srcAddr - copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) - copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) +func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { + key := tcpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + key.isV6 = addrSize == 16 return key } @@ -96,7 +98,7 @@ func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { // returning the packets found for the flow, or inserting a new one if none // is found. func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) items, ok := t.itemsByFlow[key] if ok { return items, ok @@ -108,7 +110,7 @@ func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, t // insert an item in the table for the provided packet and packet metadata. 
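A note for readers on the flow-table mechanics above (not part of the patch): lookupOrInsert derives a comparable struct key from raw packet bytes and uses it directly as a Go map key, and insert() re-derives the same key, which is the extra map lookup the TODO refers to. A minimal, self-contained sketch of the idea, where flowKey4 and keyFromPacket are hypothetical IPv4-only analogues of tcpFlowKey and newTCPFlowKey:

package main

import (
	"encoding/binary"
	"fmt"
)

// flowKey4 mirrors tcpFlowKey, reduced to IPv4 addresses and ports.
// All fields are comparable, so the struct can key a map directly.
type flowKey4 struct {
	srcAddr, dstAddr [4]byte
	srcPort, dstPort uint16
}

func keyFromPacket(pkt []byte) flowKey4 {
	const iphLen = 20 // assumes no IPv4 options, as GRO candidacy requires
	var k flowKey4
	copy(k.srcAddr[:], pkt[12:16]) // IPv4 source address offset
	copy(k.dstAddr[:], pkt[16:20]) // IPv4 destination address offset
	k.srcPort = binary.BigEndian.Uint16(pkt[iphLen:])
	k.dstPort = binary.BigEndian.Uint16(pkt[iphLen+2:])
	return k
}

func main() {
	pkt := make([]byte, 40)
	copy(pkt[12:], []byte{192, 0, 2, 1})
	copy(pkt[16:], []byte{192, 0, 2, 2})
	binary.BigEndian.PutUint16(pkt[20:], 1234) // src port
	binary.BigEndian.PutUint16(pkt[22:], 5678) // dst port
	table := map[flowKey4][]int{} // flow key -> packet indices, like itemsByFlow
	k := keyFromPacket(pkt)
	table[k] = append(table[k], 0)
	fmt.Println(table[k]) // [0]
}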
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { - key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) item := tcpGROItem{ key: key, bufsIndex: uint16(bufsIndex), @@ -131,7 +133,7 @@ func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { items[i] = item } -func (t *tcpGROTable) deleteAt(key flowKey, i int) { +func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { items, _ := t.itemsByFlow[key] items = append(items[:i], items[i+1:]...) t.itemsByFlow[key] = items @@ -140,7 +142,7 @@ func (t *tcpGROTable) deleteAt(key flowKey, i int) { // tcpGROItem represents bookkeeping data for a TCP packet during the lifetime // of a GRO evaluation across a vector of packets. type tcpGROItem struct { - key flowKey + key tcpFlowKey sentSeq uint32 // the sequence number bufsIndex uint16 // the index into the original bufs slice numMerged uint16 // the number of packets merged into this item @@ -164,6 +166,103 @@ func (t *tcpGROTable) reset() { } } +// udpFlowKey represents the key for a UDP flow. +type udpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + isV6 bool +} + +// udpGROTable holds flow and coalescing information for the purposes of UDP GRO. +type udpGROTable struct { + itemsByFlow map[udpFlowKey][]udpGROItem + itemsPool [][]udpGROItem +} + +func newUDPGROTable() *udpGROTable { + u := &udpGROTable{ + itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize), + itemsPool: make([][]udpGROItem, conn.IdealBatchSize), + } + for i := range u.itemsPool { + u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize) + } + return u +} + +func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { + key := udpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + items, ok := u.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. 
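Another aside (illustrative, not from the patch): the itemsPool field above, together with newItems() and reset(), recycles zero-length slices across GRO passes so the per-flow item slices keep their capacity rather than being reallocated for every batch. A rough sketch of that idiom under hypothetical names:

package main

import "fmt"

type pool struct {
	free [][]int // recycled zero-length slices that retain capacity
}

func newPool(n, capHint int) *pool {
	p := &pool{free: make([][]int, n)}
	for i := range p.free {
		p.free[i] = make([]int, 0, capHint)
	}
	return p
}

// get pops a recycled slice, mirroring newItems() in the patch.
func (p *pool) get() []int {
	s := p.free[len(p.free)-1]
	p.free = p.free[:len(p.free)-1]
	return s
}

// put truncates a slice and returns it to the pool, mirroring reset().
func (p *pool) put(s []int) {
	p.free = append(p.free, s[:0])
}

func main() {
	p := newPool(2, 8)
	s := p.get()
	s = append(s, 1, 2, 3)
	fmt.Println(s, cap(s)) // [1 2 3] 8
	p.put(s)
	t := p.get()
	fmt.Println(len(t), cap(t)) // 0 8: capacity reused, no fresh allocation
}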
+func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + item := udpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[udphOffset+udphLen:])), + iphLen: uint8(udphOffset), + cSumKnownInvalid: cSumKnownInvalid, + } + items, ok := u.itemsByFlow[key] + if !ok { + items = u.newItems() + } + items = append(items, item) + u.itemsByFlow[key] = items +} + +func (u *udpGROTable) updateAt(item udpGROItem, i int) { + items, _ := u.itemsByFlow[item.key] + items[i] = item +} + +// udpGROItem represents bookkeeping data for a UDP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type udpGROItem struct { + key udpFlowKey + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. +} + +func (u *udpGROTable) newItems() []udpGROItem { + var items []udpGROItem + items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] + return items +} + +func (u *udpGROTable) reset() { + for k, items := range u.itemsByFlow { + items = items[:0] + u.itemsPool = append(u.itemsPool, items) + delete(u.itemsByFlow, k) + } +} + // canCoalesce represents the outcome of checking if two TCP packets are // candidates for coalescing. type canCoalesce int @@ -174,6 +273,61 @@ const ( coalesceAppend canCoalesce = 1 ) +// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB +// meet all requirements to be merged as part of a GRO operation, otherwise it +// returns false. +func ipHeadersCanCoalesce(pktA, pktB []byte) bool { + if len(pktA) < 9 || len(pktB) < 9 { + return false + } + if pktA[0]>>4 == 6 { + if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 { + // cannot coalesce with unequal Traffic class values + return false + } + if pktA[7] != pktB[7] { + // cannot coalesce with unequal Hop limit values + return false + } + } else { + if pktA[1] != pktB[1] { + // cannot coalesce with unequal ToS values + return false + } + if pktA[6]>>5 != pktB[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return false + } + if pktA[8] != pktB[8] { + // cannot coalesce with unequal TTL values + return false + } + } + return true +} + +// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. iphLen and gsoSize describe pkt. bufs is the vector of +// packets involved in the current GRO evaluation. bufsOffset is the offset at +// which packet data begins within bufs. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend +} + // tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet // described by item. 
This function makes considerations that match the kernel's // GRO self tests, which can be found in tools/testing/selftests/net/gro.c. @@ -189,29 +343,8 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet return coalesceUnavailable } } - if pkt[0]>>4 == 6 { - if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 { - // cannot coalesce with unequal Traffic class values - return coalesceUnavailable - } - if pkt[7] != pktTarget[7] { - // cannot coalesce with unequal Hop limit values - return coalesceUnavailable - } - } else { - if pkt[1] != pktTarget[1] { - // cannot coalesce with unequal ToS values - return coalesceUnavailable - } - if pkt[6]>>5 != pktTarget[6]>>5 { - // cannot coalesce with unequal DF or reserved bits. MF is checked - // further up the stack. - return coalesceUnavailable - } - if pkt[8] != pktTarget[8] { - // cannot coalesce with unequal TTL values - return coalesceUnavailable - } + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable } // seq adjacency lhsLen := item.gsoSize @@ -252,16 +385,16 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet return coalesceUnavailable } -func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { +func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { srcAddrAt := ipv4SrcAddrOffset addrSize := 4 if isV6 { srcAddrAt = ipv6SrcAddrOffset addrSize = 16 } - tcpTotalLen := uint16(len(pkt) - int(iphLen)) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) - return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0 + lenForPseudo := uint16(len(pkt) - int(iphLen)) + cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^checksum(pkt[iphLen:], cSum) == 0 } // coalesceResult represents the result of attempting to coalesce two TCP @@ -276,8 +409,36 @@ const ( coalesceSuccess ) +// coalesceUDPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front + headersLen := item.iphLen + udphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalescePktInvalidCSum + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + + item.numMerged++ + return coalesceSuccess +} + // coalesceTCPPackets attempts to coalesce pkt with the packet described by -// item, returning the outcome. This function may swap bufs elements in the +// item, and returns the outcome. This function may swap bufs elements in the // event of a prepend as item's bufs index is already being tracked for writing // to a Device. 
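For context on checksumValid above (a generalization of the old tcpChecksumValid): for TCP and UDP alike, a transport checksum verifies when the ones'-complement sum over the pseudo header, transport header, and payload, once inverted, is zero. A rough standalone illustration of that folding arithmetic; fold16 is a hypothetical stand-in for the package's real pseudoHeaderChecksumNoFold/checksum helpers:

package main

import (
	"encoding/binary"
	"fmt"
)

// fold16 adds b to a running ones'-complement sum of big-endian 16-bit
// words, folding carries back into the low 16 bits.
func fold16(sum uint32, b []byte) uint32 {
	for len(b) >= 2 {
		sum += uint32(binary.BigEndian.Uint16(b))
		b = b[2:]
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8
	}
	for sum>>16 != 0 {
		sum = sum&0xFFFF + sum>>16
	}
	return sum
}

func main() {
	// UDP over IPv4 pseudo header: src addr, dst addr, zero, protocol, UDP length.
	src := []byte{192, 0, 2, 1}
	dst := []byte{192, 0, 2, 2}
	// UDP header (src port 1234, dst port 5678, length 10, checksum zeroed) + 2 payload bytes.
	udp := []byte{0x04, 0xD2, 0x16, 0x2E, 0x00, 0x0A, 0x00, 0x00, 0xAB, 0xCD}
	pseudo := append(append(append([]byte{}, src...), dst...), 0, 17 /* IPPROTO_UDP */, 0, byte(len(udp)))
	csum := ^uint16(fold16(fold16(0, pseudo), udp)) // sender: invert the folded sum
	binary.BigEndian.PutUint16(udp[6:], csum)
	// Receiver: the same sum, now including the checksum field, inverts to zero.
	fmt.Println(^uint16(fold16(fold16(0, pseudo), udp)) == 0) // true
}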
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { @@ -297,11 +458,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize return coalescePSHEnding } if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { return coalescePktInvalidCSum } item.sentSeq = seq @@ -319,11 +480,11 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize return coalesceInsufficientCap } if item.numMerged == 0 { - if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } - if !tcpChecksumValid(pkt, item.iphLen, isV6) { + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { return coalescePktInvalidCSum } if pshSet { @@ -354,52 +515,52 @@ const ( maxUint16 = 1<<16 - 1 ) -type tcpGROResult int +type groResult int const ( - tcpGROResultNoop tcpGROResult = iota - tcpGROResultTableInsert - tcpGROResultCoalesced + groResultNoop groResult = iota + groResultTableInsert + groResultCoalesced ) // tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It returns a tcpGROResultNoop when no -// action was taken, tcpGROResultTableInsert when the evaluated packet was -// inserted into table, and tcpGROResultCoalesced when the evaluated packet was +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult { +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. 
- return tcpGROResultNoop + return groResultNoop } iphLen := int((pkt[0] & 0x0F) * 4) if isV6 { iphLen = 40 ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) if ipv6HPayloadLen != len(pkt)-iphLen { - return tcpGROResultNoop + return groResultNoop } } else { totalLen := int(binary.BigEndian.Uint16(pkt[2:])) if totalLen != len(pkt) { - return tcpGROResultNoop + return groResultNoop } } if len(pkt) < iphLen { - return tcpGROResultNoop + return groResultNoop } tcphLen := int((pkt[iphLen+12] >> 4) * 4) if tcphLen < 20 || tcphLen > 60 { - return tcpGROResultNoop + return groResultNoop } if len(pkt) < iphLen+tcphLen { - return tcpGROResultNoop + return groResultNoop } if !isV6 { if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { // no GRO support for fragmented segments for now - return tcpGROResultNoop + return groResultNoop } } tcpFlags := pkt[iphLen+tcpFlagsOffset] @@ -407,14 +568,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) // not a candidate if any non-ACK flags (except PSH+ACK) are set if tcpFlags != tcpFlagACK { if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return tcpGROResultNoop + return groResultNoop } pshSet = true } gsoSize := uint16(len(pkt) - tcphLen - iphLen) // not a candidate if payload len is 0 if gsoSize < 1 { - return tcpGROResultNoop + return groResultNoop } seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) srcAddrOffset := ipv4SrcAddrOffset @@ -425,7 +586,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) } items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) if !existing { - return tcpGROResultNoop + return groResultTableInsert } for i := len(items) - 1; i >= 0; i-- { // In the best case of packets arriving in order iterating in reverse is @@ -443,54 +604,25 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) switch result { case coalesceSuccess: table.updateAt(item, i) - return tcpGROResultCoalesced + return groResultCoalesced case coalesceItemInvalidCSum: // delete the item with an invalid csum table.deleteAt(item.key, i) case coalescePktInvalidCSum: // no point in inserting an item that we can't coalesce - return tcpGROResultNoop + return groResultNoop default: } } } // failed to coalesce with any other packets; store the item in the flow table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return tcpGROResultTableInsert + return groResultTableInsert } -func isTCP4NoIPOptions(b []byte) bool { - if len(b) < 40 { - return false - } - if b[0]>>4 != 4 { - return false - } - if b[0]&0x0F != 5 { - return false - } - if b[9] != unix.IPPROTO_TCP { - return false - } - return true -} - -func isTCP6NoEH(b []byte) bool { - if len(b) < 60 { - return false - } - if b[0]>>4 != 6 { - return false - } - if b[6] != unix.IPPROTO_TCP { - return false - } - return true -} - -// applyCoalesceAccounting updates bufs to account for coalescing based on the +// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the // metadata found in table. 
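A brief aside on the candidacy checks in tcpGRO above (sketch only, not from the patch): both header lengths come from 4-bit fields counting 32-bit words, and each derived offset is bounds-checked before use, so a malformed packet falls out as groResultNoop rather than panicking. The same arithmetic in isolation, with headerLens as a hypothetical helper:

package main

import "fmt"

// headerLens extracts IPv4 and TCP header lengths the way tcpGRO does:
// both are 4-bit fields counting 32-bit words, bounds-checked before use.
func headerLens(pkt []byte) (iphLen, tcphLen int, ok bool) {
	if len(pkt) < 20 || pkt[0]>>4 != 4 {
		return 0, 0, false
	}
	iphLen = int(pkt[0]&0x0F) * 4 // IHL
	if iphLen < 20 || len(pkt) < iphLen+20 {
		return 0, 0, false
	}
	tcphLen = int(pkt[iphLen+12]>>4) * 4 // TCP data offset
	if tcphLen < 20 || tcphLen > 60 || len(pkt) < iphLen+tcphLen {
		return 0, 0, false
	}
	return iphLen, tcphLen, true
}

func main() {
	pkt := make([]byte, 40)
	pkt[0] = 0x45    // version 4, IHL 5 (20 bytes, no options)
	pkt[32] = 5 << 4 // data offset 5 (20 bytes, no TCP options)
	fmt.Println(headerLens(pkt)) // 20 20 true
}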
-func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { +func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { for _, items := range table.itemsByFlow { for _, item := range items { if item.numMerged > 0 { @@ -505,7 +637,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 // Recalculate the total len (IPv4) or payload len (IPv6). // Recalculate the (IPv4) header checksum. - if isV6 { + if item.key.isV6 { hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len } else { @@ -525,7 +657,7 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 // this with computation of the tcp header and payload checksum. addrLen := 4 addrOffset := ipv4SrcAddrOffset - if isV6 { + if item.key.isV6 { addrLen = 16 addrOffset = ipv6SrcAddrOffset } @@ -546,54 +678,245 @@ func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 return nil } +// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + udphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 6, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + if item.key.isV6 { + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Recalculate the UDP len field value + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + + // Calculate the pseudo header checksum and place it at the UDP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the udp header and payload checksum. 
+ addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +type groCandidateType uint8 + +const ( + notGROCandidate groCandidateType = iota + tcp4GROCandidate + tcp6GROCandidate + udp4GROCandidate + udp6GROCandidate +) + +func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType { + if len(b) < 28 { + return notGROCandidate + } + if b[0]>>4 == 4 { + if b[0]&0x0F != 5 { + // IPv4 packets w/IP options do not coalesce + return notGROCandidate + } + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 { + return tcp4GROCandidate + } + if b[9] == unix.IPPROTO_UDP && canUDPGRO { + return udp4GROCandidate + } + } else if b[0]>>4 == 6 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 { + return tcp6GROCandidate + } + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO { + return udp6GROCandidate + } + } + return notGROCandidate +} + +const ( + udphLen = 8 +) + +// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + if len(pkt) < iphLen+udphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + gsoSize := uint16(len(pkt) - udphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + if !existing { + return groResultTableInsert + } + // With UDP we only check the last item, otherwise we could reorder packets + // for a given flow. We must also always insert a new item, or successfully + // coalesce with an existing item, for the same reason. 
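To make the last-item-only rule above concrete (toy model, not from the patch): udpPacketsCanCoalesce enforces that a coalesced UDP packet always looks like N segments of exactly gsoSize bytes plus at most one shorter tail, since that is the only shape the kernel can re-split. Modeling just that check over payload lengths, with canAppend as a hypothetical helper:

package main

import "fmt"

// canAppend reports whether a payload of n bytes may be appended to a
// coalesced UDP payload of total bytes with segment size gsoSize, under
// the same rules as udpPacketsCanCoalesce.
func canAppend(total, n, gsoSize int) bool {
	if total%gsoSize != 0 {
		// A short segment was already appended; nothing may follow it.
		return false
	}
	// A larger-than-gsoSize segment may not follow smaller ones.
	return n <= gsoSize
}

func main() {
	fmt.Println(canAppend(200, 100, 100)) // true: equal-size append
	fmt.Println(canAppend(200, 40, 100))  // true: one short tail allowed
	fmt.Println(canAppend(240, 100, 100)) // false: short tail already present
	fmt.Println(canAppend(200, 150, 100)) // false: larger after smaller
}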
+ item := items[len(items)-1] + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + var pktCSumKnownInvalid bool + if can == coalesceAppend { + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, len(items)-1) + return groResultCoalesced + case coalesceItemInvalidCSum: + // If the existing item has an invalid csum we take no action. A new + // item will be stored after it, and the existing item will never be + // revisited as part of future coalescing candidacy checks. + case coalescePktInvalidCSum: + // We must insert a new item, but we also mark it as invalid csum + // to prevent a repeat checksum validation. + pktCSumKnownInvalid = true + default: + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + return groResultTableInsert +} + // handleGRO evaluates bufs for GRO, and writes the indices of the resulting -// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be +// packets into toWrite. toWrite, tcpTable, and udpTable should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset -// and recycle them across vectors of packets. -func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error { +// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is +// supported. +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } - var result tcpGROResult - switch { - case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - result = tcpGRO(bufs, offset, i, tcp4Table, false) - case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - result = tcpGRO(bufs, offset, i, tcp6Table, true) + var result groResult + switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) { + case tcp4GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, false) + case tcp6GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, true) + case udp4GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, false) + case udp6GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, true) } switch result { - case tcpGROResultNoop: + case groResultNoop: hdr := virtioNetHdr{} err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) if err != nil { return err } fallthrough - case tcpGROResultTableInsert: + case groResultTableInsert: *toWrite = append(*toWrite, i) } } - err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false) - err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true) - return errors.Join(err4, err6) + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + return errors.Join(errTCP, errUDP) } -// tcpTSO splits packets from in into outBuffs, writing the size of each +// gsoSplit splits packets from in into outBuffs, writing the size of each // element into sizes. It returns the number of buffers populated, and/or an // error. 
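Before the diff continues, a sketch of what gsoSplit (renamed from tcpTSO below) does at its core: it walks the coalesced payload in gsoSize strides and re-emits headers for each output segment. Stripped of checksums and header rewrites, the segmentation loop reduces to this illustrative helper (segment is hypothetical):

package main

import "fmt"

// segment splits payload into gsoSize-byte chunks; only the final chunk
// may be shorter, matching how gsoSplit slices the coalesced packet.
func segment(payload []byte, gsoSize int) [][]byte {
	var segs [][]byte
	for at := 0; at < len(payload); at += gsoSize {
		end := at + gsoSize
		if end > len(payload) {
			end = len(payload)
		}
		segs = append(segs, payload[at:end])
	}
	return segs
}

func main() {
	payload := make([]byte, 250)
	for i, s := range segment(payload, 100) {
		fmt.Printf("segment %d: %d bytes\n", i, len(s)) // 100, 100, 50
	}
}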
-func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) { +func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) { iphLen := int(hdr.csumStart) srcAddrOffset := ipv6SrcAddrOffset addrLen := 16 - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + if !isV6 { in[10], in[11] = 0, 0 // clear ipv4 header checksum srcAddrOffset = ipv4SrcAddrOffset addrLen = 4 } - tcpCSumAt := int(hdr.csumStart + hdr.csumOffset) - in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum - firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + transportCsumAt := int(hdr.csumStart + hdr.csumOffset) + in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + var firstTCPSeqNum uint32 + var protocol uint8 + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 { + protocol = unix.IPPROTO_TCP + firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + } else { + protocol = unix.IPPROTO_UDP + } nextSegmentDataAt := int(hdr.hdrLen) i := 0 for ; nextSegmentDataAt < len(in); i++ { @@ -610,7 +933,7 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs out := outBuffs[i][outOffset:] copy(out, in[:iphLen]) - if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + if !isV6 { // For IPv4 we are responsible for incrementing the ID field, // updating the total len field, and recalculating the header // checksum. @@ -627,25 +950,32 @@ func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffs binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) } - // TCP header + // copy transport header copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) - tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) - binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) - if nextSegmentEnd != len(in) { - // FIN and PSH should only be set on last segment - clearFlags := tcpFlagFIN | tcpFlagPSH - out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + + if protocol == unix.IPPROTO_TCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) + binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart)) } // payload copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) - // TCP checksum - tcpHLen := int(hdr.hdrLen - hdr.csumStart) - tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) - tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) - tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold) - binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) + // transport checksum + transportHeaderLen := int(hdr.hdrLen - hdr.csumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold) + binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], 
transportCSum) nextSegmentDataAt += int(hdr.gsoSize) } diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go new file mode 100644 index 0000000..ae55c8c --- /dev/null +++ b/tun/offload_linux_test.go @@ -0,0 +1,752 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "net/netip" + "testing" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + offset = virtioNetHdrLen +) + +var ( + ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") + ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") + ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") + ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") + ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") + ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") +) + +func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 28 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_UDP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + udpH := header.UDP(b[offset+20:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 48 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_UDP, + HopLimit: 64, + PayloadLength: uint16(payloadLen + udphLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + udpH := header.UDP(b[offset+40:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 40 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() 
+ ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_TCP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + tcpH := header.TCP(b[offset+20:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 60 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_TCP, + HopLimit: 64, + PayloadLength: uint16(segmentSize + 20), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + tcpH := header.TCP(b[offset+40:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func Test_handleVirtioRead(t *testing.T) { + tests := []struct { + name string + hdr virtioNetHdr + pktIn []byte + wantLens []int + wantErr bool + }{ + { + "tcp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + gsoSize: 100, + hdrLen: 40, + csumStart: 20, + csumOffset: 16, + }, + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{140, 140}, + false, + }, + { + "tcp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + gsoSize: 100, + hdrLen: 60, + csumStart: 40, + csumOffset: 16, + }, + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{160, 160}, + false, + }, + { + "udp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 28, + csumStart: 20, + csumOffset: 6, + }, + udp4Packet(ip4PortA, ip4PortB, 200), + []int{128, 128}, + false, + }, + { + "udp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 48, + csumStart: 40, + csumOffset: 6, + }, + udp6Packet(ip6PortA, ip6PortB, 200), + []int{148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := make([][]byte, 
conn.IdealBatchSize) + sizes := make([]int, conn.IdealBatchSize) + for i := range out { + out[i] = make([]byte, 65535) + } + tt.hdr.encode(tt.pktIn) + n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if n != len(tt.wantLens) { + t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) + } + for i := range tt.wantLens { + if tt.wantLens[i] != sizes[i] { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + } + } + }) + } +} + +func flipTCP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func flipUDP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func Fuzz_handleGRO(f *testing.F) { + pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) + pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) + pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) + pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) + pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) + pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) + pkt6 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt7 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt8 := udp4Packet(ip4PortA, ip4PortC, 100) + pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) { + pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} + toWrite := make([]int, 0, len(pkts)) + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) + if len(toWrite) > len(pkts) { + t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) + } + seenWriteI := make(map[int]bool) + for _, writeI := range toWrite { + if writeI < 0 || writeI > len(pkts)-1 { + t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) + } + if seenWriteI[writeI] { + t.Errorf("duplicate toWrite value: %d", writeI) + } + seenWriteI[writeI] = true + } + }) +} + +func Test_handleGRO(t *testing.T) { + tests := []struct { + name string + pktsIn [][]byte + canUDPGRO bool + wantToWrite []int + wantLens []int + wantErr bool + }{ + { + "multiple protocols and flows", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + true, + []int{0, 1, 2, 4, 5, 7, 9}, + []int{240, 228, 128, 140, 260, 160, 248}, + false, + }, + 
{ + "multiple protocols and flows no UDP GRO", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + false, + []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, + []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, + false, + }, + { + "PSH interleaved", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 + }, + true, + []int{0, 2, 4, 6}, + []int{240, 240, 260, 260}, + false, + }, + { + "coalesceItemInvalidCSum", + [][]byte{ + flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4Packet(ip4PortA, ip4PortB, 100), + }, + true, + []int{0, 1, 3, 4}, + []int{140, 240, 128, 228}, + false, + }, + { + "out of order", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + true, + []int{0}, + []int{340}, + false, + }, + { + "unequal TTL", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal ToS", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags more fragments set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, 
header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags DF set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "ipv6 unequal hop limit", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + { + "ipv6 unequal traffic class", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toWrite := make([]int, 0, len(tt.pktsIn)) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if len(toWrite) != len(tt.wantToWrite) { + t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) + } + for i, pktI := range tt.wantToWrite { + if tt.wantToWrite[i] != toWrite[i] { + t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + } + if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { + t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + } + } + }) + } +} + +func Test_packetIsGROCandidate(t *testing.T) { + tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp4TooShort := tcp4[:39] + ip4InvalidHeaderLen := make([]byte, len(tcp4)) + copy(ip4InvalidHeaderLen, tcp4) + ip4InvalidHeaderLen[0] = 0x46 + ip4InvalidProtocol := make([]byte, len(tcp4)) + copy(ip4InvalidProtocol, tcp4) + ip4InvalidProtocol[9] = unix.IPPROTO_GRE + + tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp6TooShort := tcp6[:59] + ip6InvalidProtocol := make([]byte, len(tcp6)) + copy(ip6InvalidProtocol, tcp6) + ip6InvalidProtocol[6] = unix.IPPROTO_GRE + + udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:] + udp4TooShort := udp4[:27] + + udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:] + udp6TooShort := udp6[:47] + + tests := []struct { + name string + b []byte + canUDPGRO bool + want groCandidateType + }{ + { + "tcp4", + tcp4, + true, + tcp4GROCandidate, + }, + { + 
"tcp6", + tcp6, + true, + tcp6GROCandidate, + }, + { + "udp4", + udp4, + true, + udp4GROCandidate, + }, + { + "udp4 no support", + udp4, + false, + notGROCandidate, + }, + { + "udp6", + udp6, + true, + udp6GROCandidate, + }, + { + "udp6 no support", + udp6, + false, + notGROCandidate, + }, + { + "udp4 too short", + udp4TooShort, + true, + notGROCandidate, + }, + { + "udp6 too short", + udp6TooShort, + true, + notGROCandidate, + }, + { + "tcp4 too short", + tcp4TooShort, + true, + notGROCandidate, + }, + { + "tcp6 too short", + tcp6TooShort, + true, + notGROCandidate, + }, + { + "invalid IP version", + []byte{0x00}, + true, + notGROCandidate, + }, + { + "invalid IP header len", + ip4InvalidHeaderLen, + true, + notGROCandidate, + }, + { + "ip4 invalid protocol", + ip4InvalidProtocol, + true, + notGROCandidate, + }, + { + "ip6 invalid protocol", + ip6InvalidProtocol, + true, + notGROCandidate, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want { + t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_udpPacketsCanCoalesce(t *testing.T) { + udp4a := udp4Packet(ip4PortA, ip4PortB, 100) + udp4b := udp4Packet(ip4PortA, ip4PortB, 100) + udp4c := udp4Packet(ip4PortA, ip4PortB, 110) + + type args struct { + pkt []byte + iphLen uint8 + gsoSize uint16 + item udpGROItem + bufs [][]byte + bufsOffset int + } + tests := []struct { + name string + args args + want canCoalesce + }{ + { + "coalesceAppend equal gso", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceAppend smaller gso", + args{ + pkt: udp4a[offset : len(udp4a)-90], + iphLen: 20, + gsoSize: 10, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceUnavailable smaller gso previously appended", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4c, + udp4b, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable larger following smaller", + args{ + pkt: udp4c[offset:], + iphLen: 20, + gsoSize: 110, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4c, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { + t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go deleted file mode 100644 index ddddc48..0000000 --- a/tun/tcp_offload_linux_test.go +++ /dev/null @@ -1,411 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
- */ - -package tun - -import ( - "net/netip" - "testing" - - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/conn" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - offset = virtioNetHdrLen -) - -var ( - ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") - ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") - ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") - ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") - ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") - ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") -) - -func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { - totalLen := 40 + segmentSize - b := make([]byte, offset+int(totalLen), 65535) - ipv4H := header.IPv4(b[offset:]) - srcAs4 := srcIPPort.Addr().As4() - dstAs4 := dstIPPort.Addr().As4() - ipFields := &header.IPv4Fields{ - SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), - DstAddr: tcpip.AddrFromSlice(dstAs4[:]), - Protocol: unix.IPPROTO_TCP, - TTL: 64, - TotalLength: uint16(totalLen), - } - if ipFn != nil { - ipFn(ipFields) - } - ipv4H.Encode(ipFields) - tcpH := header.TCP(b[offset+20:]) - tcpH.Encode(&header.TCPFields{ - SrcPort: srcIPPort.Port(), - DstPort: dstIPPort.Port(), - SeqNum: seq, - AckNum: 1, - DataOffset: 20, - Flags: flags, - WindowSize: 3000, - }) - ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) - pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) - tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) - return b -} - -func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { - return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) -} - -func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { - totalLen := 60 + segmentSize - b := make([]byte, offset+int(totalLen), 65535) - ipv6H := header.IPv6(b[offset:]) - srcAs16 := srcIPPort.Addr().As16() - dstAs16 := dstIPPort.Addr().As16() - ipFields := &header.IPv6Fields{ - SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), - DstAddr: tcpip.AddrFromSlice(dstAs16[:]), - TransportProtocol: unix.IPPROTO_TCP, - HopLimit: 64, - PayloadLength: uint16(segmentSize + 20), - } - if ipFn != nil { - ipFn(ipFields) - } - ipv6H.Encode(ipFields) - tcpH := header.TCP(b[offset+40:]) - tcpH.Encode(&header.TCPFields{ - SrcPort: srcIPPort.Port(), - DstPort: dstIPPort.Port(), - SeqNum: seq, - AckNum: 1, - DataOffset: 20, - Flags: flags, - WindowSize: 3000, - }) - pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) - tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) - return b -} - -func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { - return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) -} - -func Test_handleVirtioRead(t *testing.T) { - tests := []struct { - name string - hdr virtioNetHdr - pktIn []byte - wantLens []int - wantErr bool - }{ - { - "tcp4", - virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, - gsoSize: 100, - hdrLen: 40, - csumStart: 20, - csumOffset: 16, - }, - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), - []int{140, 140}, - false, 
- }, - { - "tcp6", - virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, - gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, - gsoSize: 100, - hdrLen: 60, - csumStart: 40, - csumOffset: 16, - }, - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), - []int{160, 160}, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - out := make([][]byte, conn.IdealBatchSize) - sizes := make([]int, conn.IdealBatchSize) - for i := range out { - out[i] = make([]byte, 65535) - } - tt.hdr.encode(tt.pktIn) - n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) - if err != nil { - if tt.wantErr { - return - } - t.Fatalf("got err: %v", err) - } - if n != len(tt.wantLens) { - t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) - } - for i := range tt.wantLens { - if tt.wantLens[i] != sizes[i] { - t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) - } - } - }) - } -} - -func flipTCP4Checksum(b []byte) []byte { - at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 - b[at] ^= 0xFF - b[at+1] ^= 0xFF - return b -} - -func Fuzz_handleGRO(f *testing.F) { - pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) - pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) - pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) - pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) - pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) - pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) - f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset) - f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) { - pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5} - toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) - if len(toWrite) > len(pkts) { - t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) - } - seenWriteI := make(map[int]bool) - for _, writeI := range toWrite { - if writeI < 0 || writeI > len(pkts)-1 { - t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) - } - if seenWriteI[writeI] { - t.Errorf("duplicate toWrite value: %d", writeI) - } - seenWriteI[writeI] = true - } - }) -} - -func Test_handleGRO(t *testing.T) { - tests := []struct { - name string - pktsIn [][]byte - wantToWrite []int - wantLens []int - wantErr bool - }{ - { - "multiple flows", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2 - }, - []int{0, 2, 3, 5}, - []int{240, 140, 260, 160}, - false, - }, - { - "PSH interleaved", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 - 
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 - }, - []int{0, 2, 4, 6}, - []int{240, 240, 260, 260}, - false, - }, - { - "coalesceItemInvalidCSum", - [][]byte{ - flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 - }, - []int{0, 1}, - []int{140, 240}, - false, - }, - { - "out of order", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 - }, - []int{0}, - []int{340}, - false, - }, - { - "tcp4 unequal TTL", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.TTL++ - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp4 unequal ToS", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.TOS++ - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp4 unequal flags more fragments set", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.Flags = 1 - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp4 unequal flags DF set", - [][]byte{ - tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), - tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { - fields.Flags = 2 - }), - }, - []int{0, 1}, - []int{140, 140}, - false, - }, - { - "tcp6 unequal hop limit", - [][]byte{ - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), - tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { - fields.HopLimit++ - }), - }, - []int{0, 1}, - []int{160, 160}, - false, - }, - { - "tcp6 unequal traffic class", - [][]byte{ - tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), - tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { - fields.TrafficClass++ - }), - }, - []int{0, 1}, - []int{160, 160}, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) - if err != nil { - if tt.wantErr { - return - } - t.Fatalf("got err: %v", err) - } - if len(toWrite) != len(tt.wantToWrite) { - t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) - } - for i, pktI := range tt.wantToWrite { - if tt.wantToWrite[i] != toWrite[i] { - t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) - } - if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { - t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) - } - } - }) - } -} - -func Test_isTCP4NoIPOptions(t *testing.T) { - valid := tcp4Packet(ip4PortA, 
ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] - invalidLen := valid[:39] - invalidHeaderLen := make([]byte, len(valid)) - copy(invalidHeaderLen, valid) - invalidHeaderLen[0] = 0x46 - invalidProtocol := make([]byte, len(valid)) - copy(invalidProtocol, valid) - invalidProtocol[9] = unix.IPPROTO_TCP + 1 - - tests := []struct { - name string - b []byte - want bool - }{ - { - "valid", - valid, - true, - }, - { - "invalid length", - invalidLen, - false, - }, - { - "invalid version", - []byte{0x00}, - false, - }, - { - "invalid header len", - invalidHeaderLen, - false, - }, - { - "invalid protocol", - invalidProtocol, - false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := isTCP4NoIPOptions(tt.b); got != tt.want { - t.Errorf("isTCP4NoIPOptions() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 deleted file mode 100644 index 5461e79..0000000 --- a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 +++ /dev/null @@ -1,8 +0,0 @@ -go test fuzz v1 -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -int(34) diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d deleted file mode 100644 index b441819..0000000 --- a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d +++ /dev/null @@ -1,8 +0,0 @@ -go test fuzz v1 -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -[]byte("0") -int(-48) diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 12cd49f..bd69cb5 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -38,6 +38,7 @@ type NativeTun struct { statusListenersShutdown chan struct{} batchSize int vnetHdr bool + udpGSO bool closeOnce sync.Once @@ -48,9 +49,10 @@ type NativeTun struct { readOpMu sync.Mutex // readOpMu guards readBuff readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr - writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable - toWrite []int - tcp4GROTable, tcp6GROTable *tcpGROTable + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable + toWrite []int + tcpGROTable *tcpGROTable + udpGROTable *udpGROTable } func (tun *NativeTun) File() *os.File { @@ -333,8 +335,8 @@ func (tun *NativeTun) nameSlow() (string, error) { func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { tun.writeOpMu.Lock() defer func() { - tun.tcp4GROTable.reset() - tun.tcp6GROTable.reset() + tun.tcpGROTable.reset() + tun.udpGROTable.reset() tun.writeOpMu.Unlock() }() var ( @@ -343,7 +345,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) if err != nil { return 0, err } @@ -394,37 +396,42 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e sizes[0] = n return 1, nil } - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && 
hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) } ipVersion := in[0] >> 4 switch ipVersion { case 4: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) } case 6: - if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) } default: return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) } - if len(in) <= int(hdr.csumStart+12) { - return 0, errors.New("packet is too short") - } // Don't trust hdr.hdrLen from the kernel as it can be equal to the length // of the entire first packet when the kernel is handling it as part of a - // FORWARD path. Instead, parse the TCP header length and add it onto + // FORWARD path. Instead, parse the transport header length and add it onto // csumStart, which is synonymous for IP header length. - tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) - if tcpHLen < 20 || tcpHLen > 60 { - // A TCP header must be between 20 and 60 bytes in length. - return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + hdr.hdrLen = hdr.csumStart + 8 + } else { + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } + + tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + hdr.hdrLen = hdr.csumStart + tcpHLen } - hdr.hdrLen = hdr.csumStart + tcpHLen if len(in) < int(hdr.hdrLen) { return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) @@ -438,7 +445,7 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) } - return tcpTSO(in, hdr, bufs, sizes, offset) + return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6) } func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { @@ -497,7 +504,8 @@ func (tun *NativeTun) BatchSize() int { const ( // TODO: support TSO with ECN bits - tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 ) func (tun *NativeTun) initFromFlags(name string) error { @@ -519,12 +527,17 @@ func (tun *NativeTun) initFromFlags(name string) error { } got := ifr.Uint16() if got&unix.IFF_VNET_HDR != 0 { - err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads) + // tunTCPOffloads were added in Linux v2.6. We require their support + // if IFF_VNET_HDR is set. + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) if err != nil { return } tun.vnetHdr = true tun.batchSize = conn.IdealBatchSize + // tunUDPOffloads were added in Linux v6.2. We do not return an + // error if they are unsupported at runtime. 
+ tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil } else { tun.batchSize = 1 } @@ -575,8 +588,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - tcp4GROTable: newTCPGROTable(), - tcp6GROTable: newTCPGROTable(), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), } @@ -628,12 +641,12 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - tcp4GROTable: newTCPGROTable(), - tcp6GROTable: newTCPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() if err != nil { From 4ffa9c20327b9471c3eeb142347f679b69f84648 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 20 Nov 2023 16:49:06 -0800 Subject: [PATCH 3/7] device: change Peer.endpoint locking to reduce contention Access to Peer.endpoint was previously synchronized by Peer.RWMutex. This has now moved to Peer.endpoint.Mutex. Peer.SendBuffers() is now the sole caller of Endpoint.ClearSrc(), which is signaled via a new bool, Peer.endpoint.clearSrcOnTx. Previous callers of Endpoint.ClearSrc() now set this bool, primarily via peer.markEndpointSrcForClearing(). Peer.SetEndpointFromPacket() clears Peer.endpoint.clearSrcOnTx when an updated conn.Endpoint is stored. This maintains the same event order as before: a conn.Endpoint received after peer.endpoint.clearSrcOnTx is set, but before the next Peer.SendBuffers() call, results in the latest conn.Endpoint source being used for the next packet transmission. These changes result in throughput improvements for single flow, parallel (-P n) flow, and bidirectional (--bidir) flow iperf3 TCP/UDP tests as measured on both Linux and Windows. Latency under load improves especially for high-throughput Linux scenarios. These improvements are likely realized on all platforms to some degree, as the changes are not platform-specific. Co-authored-by: James Tucker Signed-off-by: James Tucker Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld
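To make the new synchronization shape easier to follow in isolation, here is a minimal, self-contained sketch of the pattern the message above describes. The names (endpointState, set, markSrcForClearing, sendWith, and the endpoint interface) are hypothetical stand-ins, not wireguard-go's actual types; the real logic lives in Peer.SendBuffers() and friends in the diff below.

package sketch

import (
	"errors"
	"sync"
)

// endpoint is a hypothetical stand-in for conn.Endpoint.
type endpoint interface {
	ClearSrc()
}

// endpointState mirrors the shape of the new Peer.endpoint field: one
// small mutex guards the endpoint value and the deferred-clear flag.
type endpointState struct {
	sync.Mutex
	val          endpoint
	clearSrcOnTx bool // apply ClearSrc on the next transmission
}

// set stores a fresh endpoint and cancels any pending source clearing,
// mirroring Peer.SetEndpointFromPacket().
func (e *endpointState) set(ep endpoint) {
	e.Lock()
	defer e.Unlock()
	e.clearSrcOnTx = false
	e.val = ep
}

// markSrcForClearing records intent only; the actual ClearSrc() happens
// on the transmit path, mirroring peer.markEndpointSrcForClearing().
func (e *endpointState) markSrcForClearing() {
	e.Lock()
	defer e.Unlock()
	if e.val != nil {
		e.clearSrcOnTx = true
	}
}

// sendWith snapshots the endpoint under the lock, applies a pending
// clear, then sends outside the lock so the critical section stays tiny.
func (e *endpointState) sendWith(send func(endpoint) error) error {
	e.Lock()
	ep := e.val
	if ep == nil {
		e.Unlock()
		return errors.New("no known endpoint for peer")
	}
	if e.clearSrcOnTx {
		ep.ClearSrc()
		e.clearSrcOnTx = false
	}
	e.Unlock()
	return send(ep) // the send itself runs without holding the lock
}

The contention win comes from the last function: the potentially slow bind send no longer runs under any peer-wide lock, and readers of unrelated peer state no longer queue behind endpoint updates.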
--- device/device.go | 12 ++-------- device/mobilequirks.go | 6 ++--- device/peer.go | 50 ++++++++++++++++++++++++++------------ device/sticky_linux.go | 30 +++++++++++------------ device/timers.go | 12 ++-------- device/uapi.go | 54 ++++++++++++++++++++---------------------- 6 files changed, 83 insertions(+), 81 deletions(-) diff --git a/device/device.go b/device/device.go index f9557a0..ca26d00 100644 --- a/device/device.go +++ b/device/device.go @@ -461,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error { // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() @@ -515,11 +511,7 @@ func (device *Device) BindUpdate() error { // clear cached source addresses device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() diff --git a/device/mobilequirks.go b/device/mobilequirks.go index 4e5051d..0a0080e 100644 --- a/device/mobilequirks.go +++ b/device/mobilequirks.go @@ -11,9 +11,9 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { device.net.brokenRoaming = true device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - peer.disableRoaming = peer.endpoint != nil - peer.Unlock() + peer.endpoint.Lock() + peer.endpoint.disableRoaming = peer.endpoint.val != nil + peer.endpoint.Unlock() } device.peers.RUnlock() } diff --git a/device/peer.go b/device/peer.go index 2fb5da6..47a2f14 100644 --- a/device/peer.go +++ b/device/peer.go @@ -17,17 +17,20 @@ import ( type Peer struct { isRunning atomic.Bool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer keypairs Keypairs handshake Handshake device *Device - endpoint conn.Endpoint stopping sync.WaitGroup // routines pending stop txBytes atomic.Uint64 // bytes send to peer (endpoint) rxBytes atomic.Uint64 // bytes received from peer lastHandshakeNano atomic.Int64 // nano seconds since epoch - disableRoaming bool + endpoint struct { + sync.Mutex + val conn.Endpoint + clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission + disableRoaming bool + } timers struct { retransmitHandshake *Timer @@ -74,8 +77,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // create peer peer := new(Peer) - peer.Lock() - defer peer.Unlock() peer.cookieGenerator.Init(pk) peer.device = device @@ -97,7 +98,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.mutex.Unlock() // reset endpoint - peer.endpoint = nil + peer.endpoint.Lock() + peer.endpoint.val = nil + peer.endpoint.disableRoaming = false + peer.endpoint.clearSrcOnTx = false + peer.endpoint.Unlock() // init timers peer.timersInit() @@ -116,14 +121,19 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { return nil } - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { + peer.endpoint.Lock() + endpoint := peer.endpoint.val + if endpoint == nil { + peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } + if peer.endpoint.clearSrcOnTx { + endpoint.ClearSrc() + peer.endpoint.clearSrcOnTx = false + } + peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffers, peer.endpoint) + err := peer.device.net.bind.Send(buffers, endpoint) if err == nil { var
totalLen uint64 for _, b := range buffers { @@ -267,10 +277,20 @@ func (peer *Peer) Stop() { } func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { - if peer.disableRoaming { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.disableRoaming { return } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() + peer.endpoint.clearSrcOnTx = false + peer.endpoint.val = endpoint +} + +func (peer *Peer) markEndpointSrcForClearing() { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.val == nil { + return + } + peer.endpoint.clearSrcOnTx = true } diff --git a/device/sticky_linux.go b/device/sticky_linux.go index f9230f8..6057ff1 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -110,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl if !ok { break } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() + pePtr.peer.endpoint.Lock() + if &pePtr.peer.endpoint.val != pePtr.endpoint { + pePtr.peer.endpoint.Unlock() break } - if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { - pePtr.peer.Unlock() + if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { + pePtr.peer.endpoint.Unlock() break } - pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc() - pePtr.peer.Unlock() + pePtr.peer.endpoint.clearSrcOnTx = true + pePtr.peer.endpoint.Unlock() } attr = attr[attrhdr.Len:] } @@ -134,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl device.peers.RLock() i := uint32(1) for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil { - peer.RUnlock() + peer.endpoint.Lock() + if peer.endpoint.val == nil { + peer.endpoint.Unlock() continue } - nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint) + nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) if nativeEP == nil { - peer.RUnlock() + peer.endpoint.Unlock() continue } if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { - peer.RUnlock() + peer.endpoint.Unlock() break } nlmsg := struct { @@ -188,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl reqPeerLock.Lock() reqPeer[i] = peerEndpointPtr{ peer: peer, - endpoint: &peer.endpoint, + endpoint: &peer.endpoint.val, } reqPeerLock.Unlock() - peer.RUnlock() + peer.endpoint.Unlock() i++ _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) if err != nil { diff --git a/device/timers.go b/device/timers.go index e28732c..d4a4ed4 100644 --- a/device/timers.go +++ b/device/timers.go @@ -100,11 +100,7 @@ func expiredRetransmitHandshake(peer *Peer) { peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(true) } @@ -123,11 +119,7 @@ func expiredSendKeepalive(peer *Peer) { func expiredNewHandshake(peer *Peer) { peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) /* We clear the endpoint address src address, in case this is the cause of trouble. 
*/ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(false) } diff --git a/device/uapi.go b/device/uapi.go index 617dcd3..d81dae3 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -99,33 +99,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error { for _, peer := range device.peers.keyMap { // Serialize peer state. - // Do the work in an anonymous function so that we can use defer. - func() { - peer.RLock() - defer peer.RUnlock() + peer.handshake.mutex.RLock() + keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) + keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) + peer.handshake.mutex.RUnlock() + sendf("protocol_version=1") + peer.endpoint.Lock() + if peer.endpoint.val != nil { + sendf("endpoint=%s", peer.endpoint.val.DstToString()) + } + peer.endpoint.Unlock() - keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) - keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) - sendf("protocol_version=1") - if peer.endpoint != nil { - sendf("endpoint=%s", peer.endpoint.DstToString()) - } + nano := peer.lastHandshakeNano.Load() + secs := nano / time.Second.Nanoseconds() + nano %= time.Second.Nanoseconds() - nano := peer.lastHandshakeNano.Load() - secs := nano / time.Second.Nanoseconds() - nano %= time.Second.Nanoseconds() + sendf("last_handshake_time_sec=%d", secs) + sendf("last_handshake_time_nsec=%d", nano) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) - sendf("last_handshake_time_sec=%d", secs) - sendf("last_handshake_time_nsec=%d", nano) - sendf("tx_bytes=%d", peer.txBytes.Load()) - sendf("rx_bytes=%d", peer.rxBytes.Load()) - sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) - - device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { - sendf("allowed_ip=%s", prefix.String()) - return true - }) - }() + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) + return true + }) } }() @@ -262,7 +260,7 @@ func (peer *ipcSetPeer) handlePostConfig() { return } if peer.created { - peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil + peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil } if peer.device.isUp() { peer.Start() @@ -345,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) } - peer.Lock() - defer peer.Unlock() - peer.endpoint = endpoint + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + peer.endpoint.val = endpoint case "persistent_keepalive_interval": device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) From 7c20311b3d30b96576a95fec31f58e4d5e0d3234 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Tue, 7 Nov 2023 15:24:21 -0800 Subject: [PATCH 4/7] device: reduce redundant per-packet overhead in RX path Peer.RoutineSequentialReceiver() deals with packet vectors and does not need to perform timer and endpoint operations for every packet in a given vector. Changing these per-packet operations to per-vector improves throughput by as much as 10% in some environments. Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld
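As a condensed sketch of the restructuring this message describes (toy types and hypothetical names; keypair, counter, and buffer handling are elided, and the real loop is Peer.RoutineSequentialReceiver in the diff below):

package sketch

// inboundElem is a toy stand-in for wireguard-go's inbound queue element.
type inboundElem struct {
	packet   []byte // nil means decryption failed
	endpoint string // stand-in for conn.Endpoint
}

// summarizeVector records per-packet facts inside the loop and performs
// the formerly per-packet timer/endpoint operations once per vector.
func summarizeVector(elems []inboundElem, setEndpoint func(string), onAuthenticated, onData func()) {
	validTail := -1
	dataSeen := false
	for i, e := range elems {
		if e.packet == nil {
			continue // skip failed decryptions
		}
		validTail = i
		if len(e.packet) > 0 { // zero-length packets are keepalives
			dataSeen = true
		}
	}
	if validTail >= 0 {
		setEndpoint(elems[validTail].endpoint) // once, for the last valid packet
		onAuthenticated()                      // timer work, once per vector
	}
	if dataSeen {
		onData() // data-received timer, once per vector
	}
}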
--- device/receive.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/device/receive.go b/device/receive.go index 4b32dc5..98e2024 100644 --- a/device/receive.go +++ b/device/receive.go @@ -445,7 +445,9 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { return } elemsContainer.Lock() - for _, elem := range elemsContainer.elems { + validTailPacket := -1 + dataPacketReceived := false + for i, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed continue @@ -455,21 +457,19 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { continue } - peer.SetEndpointFromPacket(elem.endpoint) + validTailPacket = i if peer.ReceivedWithKeypair(elem.keypair) { + peer.SetEndpointFromPacket(elem.endpoint) peer.timersHandshakeComplete() peer.SendStagedPackets() } - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) continue } - peer.timersDataReceived() + dataPacketReceived = true switch elem.packet[0] >> 4 { case 4: @@ -512,6 +512,15 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } + if validTailPacket >= 0 { + peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + } + if dataPacketReceived { + peer.timersDataReceived() + } if len(bufs) > 0 { _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) if err != nil && !device.isClosed() { From 542e565baa776ed4c5c55b73ef9aa38d33d55197 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 11 Dec 2023 16:35:57 +0100 Subject: [PATCH 5/7] device: do atomic 64-bit add outside of vector loop Only bother updating the rxBytes counter once we've processed a whole vector, since additions are atomic. Signed-off-by: Jason A. Donenfeld
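In miniature, this is the classic hoist-the-atomic pattern: an atomic add is a read-modify-write that can bounce a cache line between cores, so accumulate into a plain local and pay for one atomic operation per vector. A brief sketch with hypothetical names:

package sketch

import "sync/atomic"

var rxBytes atomic.Uint64

// addBatch sums packet lengths into a local variable and issues a
// single atomic add per batch instead of one per packet.
func addBatch(packetLens []int, perPacketOverhead int) {
	var total uint64
	for _, n := range packetLens {
		total += uint64(n + perPacketOverhead)
	}
	rxBytes.Add(total)
}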
--- device/receive.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/device/receive.go b/device/receive.go index 98e2024..1ab3e29 100644 --- a/device/receive.go +++ b/device/receive.go @@ -447,6 +447,7 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { elemsContainer.Lock() validTailPacket := -1 dataPacketReceived := false + rxBytesLen := uint64(0) for i, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed @@ -463,7 +464,7 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { peer.timersHandshakeComplete() peer.SendStagedPackets() } - peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) @@ -512,6 +513,8 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } + + peer.rxBytes.Add(rxBytesLen) if validTailPacket >= 0 { peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) peer.keepKeyFreshReceiving() From 12269c2761734b15625017d8565745096325392f Mon Sep 17 00:00:00 2001 From: Martin Basovnik Date: Fri, 10 Nov 2023 11:10:12 +0100 Subject: [PATCH 6/7] device: fix possible deadlock in close method There is a possible deadlock in `device.Close()` when you try to close the device shortly after it starts. The problem is that two different methods acquire the same locks in a different order: 1. device.Close() - device.ipcMutex.Lock() - device.state.Lock() 2. device.changeState(deviceState) - device.state.Lock() - device.ipcMutex.Lock() Reproducer: func TestDevice_deadlock(t *testing.T) { d := randDevice(t) d.Close() } Problem: $ go clean -testcache && go test -race -timeout 3s -run TestDevice_deadlock ./device | grep -A 10 sync.runtime_SemacquireMutex sync.runtime_SemacquireMutex(0xc000117d20?, 0x94?, 0x0?) /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25 sync.(*Mutex).lockSlow(0xc000130518) /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213 sync.(*Mutex).Lock(0xc000130518) /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55 golang.zx2c4.com/wireguard/device.(*Device).Close(0xc000130500) /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:373 +0xb6 golang.zx2c4.com/wireguard/device.TestDevice_deadlock(0x0?) /Users/martin.basovnik/git/basovnik/wireguard-go/device/device_test.go:480 +0x2c testing.tRunner(0xc00014c000, 0x131d7b0) -- sync.runtime_SemacquireMutex(0xc000130564?, 0x60?, 0xc000130548?) /usr/local/opt/go/libexec/src/runtime/sema.go:77 +0x25 sync.(*Mutex).lockSlow(0xc000130750) /usr/local/opt/go/libexec/src/sync/mutex.go:171 +0x213 sync.(*Mutex).Lock(0xc000130750) /usr/local/opt/go/libexec/src/sync/mutex.go:90 +0x55 sync.(*RWMutex).Lock(0xc000130750) /usr/local/opt/go/libexec/src/sync/rwmutex.go:147 +0x45 golang.zx2c4.com/wireguard/device.(*Device).upLocked(0xc000130500) /Users/martin.basovnik/git/basovnik/wireguard-go/device/device.go:179 +0x72 golang.zx2c4.com/wireguard/device.(*Device).changeState(0xc000130500, 0x1) Signed-off-by: Martin Basovnik Signed-off-by: Jason A. Donenfeld
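The fix below is a lock-ordering rule rather than new synchronization: every path that needs both mutexes must acquire them in the same order. A toy sketch of the invariant (shortened, hypothetical names):

package sketch

import "sync"

// dev models the two locks involved in the deadlock.
type dev struct {
	stateMu sync.Mutex   // guards state transitions (device.state)
	ipcMu   sync.RWMutex // serializes UAPI operations (device.ipcMutex)
}

// close takes stateMu first and ipcMu second, matching changeState();
// two goroutines acquiring the pair in opposite orders is exactly what
// allowed the deadlock.
func (d *dev) close() {
	d.stateMu.Lock()
	defer d.stateMu.Unlock()
	d.ipcMu.Lock()
	defer d.ipcMu.Unlock()
	// teardown elided
}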
--- device/device.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/device/device.go b/device/device.go index ca26d00..83c33ee 100644 --- a/device/device.go +++ b/device/device.go @@ -368,10 +368,10 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { - device.ipcMutex.Lock() - defer device.ipcMutex.Unlock() device.state.Lock() defer device.state.Unlock() + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() if device.isClosed() { return } From e5f355e843a71a0492b9201884f028a01197473b Mon Sep 17 00:00:00 2001 From: Iurii Egorov Date: Sun, 14 Jan 2024 18:22:02 +0300 Subject: [PATCH 7/7] Fix incorrect configuration handling for zero-valued Jc --- device/send.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/device/send.go b/device/send.go index 8191f07..db4e1a1 100644 --- a/device/send.go +++ b/device/send.go @@ -137,10 +137,13 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - err = peer.SendBuffers(junks) - if err != nil { - peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err) - return err + if len(junks) > 0 { + err = peer.SendBuffers(junks) + + if err != nil { + peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err) + return err + } } peer.device.aSecMux.RLock()
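For context, Jc here is the junk-packet count from the AmneziaWG-style obfuscation configuration; with Jc set to zero the generator legitimately produces an empty slice. The shape of the guard, as a standalone sketch (sendJunk is a hypothetical wrapper, not a function in this codebase):

package sketch

// sendJunk treats an empty junk batch as success instead of handing a
// zero-length vector to the send path, where it could fail spuriously
// and abort the handshake initiation.
func sendJunk(junks [][]byte, send func([][]byte) error) error {
	if len(junks) == 0 {
		return nil // zero-valued Jc: nothing to send is not an error
	}
	return send(junks)
}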