tun: use add-with-carry in checksumNoFold()

Use parallel summation with native byte order per RFC 1071.
add-with-carry operation is used to add 4 words per operation.  Byteswap
is performed before and after checksumming for compatibility with old
`checksumNoFold()`.  With this we get a 30-80% speedup in `checksum()`
depending on packet sizes.

Add unit tests with comparison to a per-word implementation.

**Intel(R) Xeon(R) Silver 4210R CPU @ 2.40GHz**

| Size | OldTime | NewTime | Speedup  |
|------|---------|---------|----------|
| 64   | 12.64   | 9.183   | 1.376456 |
| 128  | 18.52   | 12.72   | 1.455975 |
| 256  | 31.01   | 18.13   | 1.710425 |
| 512  | 54.46   | 29.03   | 1.87599  |
| 1024 | 102     | 52.2    | 1.954023 |
| 1500 | 146.8   | 81.36   | 1.804326 |
| 2048 | 196.9   | 102.5   | 1.920976 |
| 4096 | 389.8   | 200.8   | 1.941235 |
| 8192 | 767.3   | 413.3   | 1.856521 |
| 9000 | 851.7   | 448.8   | 1.897727 |
| 9001 | 854.8   | 451.9   | 1.891569 |

**AMD EPYC 7352 24-Core Processor**

| Size | OldTime | NewTime | Speedup  |
|------|---------|---------|----------|
| 64   | 9.159   | 6.949   | 1.318031 |
| 128  | 13.59   | 10.59   | 1.283286 |
| 256  | 22.37   | 14.91   | 1.500335 |
| 512  | 41.42   | 24.22   | 1.710157 |
| 1024 | 81.59   | 45.05   | 1.811099 |
| 1500 | 120.4   | 68.35   | 1.761522 |
| 2048 | 162.8   | 90.14   | 1.806079 |
| 4096 | 321.4   | 180.3   | 1.782585 |
| 8192 | 650.4   | 360.8   | 1.802661 |
| 9000 | 706.3   | 398.1   | 1.774177 |
| 9001 | 712.4   | 398.2   | 1.789051 |

Signed-off-by: Tu Dinh Ngoc <dinhngoc.tu@irit.fr>
[Jason: simplified and cleaned up unit tests]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Tu Dinh Ngoc 2024-06-20 13:28:38 +00:00 committed by jmwample
parent ac8a885a03
commit 75d6c67a67
No known key found for this signature in database
2 changed files with 116 additions and 69 deletions

View file

@ -1,102 +1,86 @@
package tun
import "encoding/binary"
import (
"encoding/binary"
"math/bits"
)
// TODO: Explore SIMD and/or other assembly optimizations.
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
func checksumNoFold(b []byte, initial uint64) uint64 {
ac := initial
tmp := make([]byte, 8)
binary.NativeEndian.PutUint64(tmp, initial)
ac := binary.BigEndian.Uint64(tmp)
var carry uint64
for len(b) >= 128 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
ac += uint64(binary.BigEndian.Uint32(b[64:68]))
ac += uint64(binary.BigEndian.Uint32(b[68:72]))
ac += uint64(binary.BigEndian.Uint32(b[72:76]))
ac += uint64(binary.BigEndian.Uint32(b[76:80]))
ac += uint64(binary.BigEndian.Uint32(b[80:84]))
ac += uint64(binary.BigEndian.Uint32(b[84:88]))
ac += uint64(binary.BigEndian.Uint32(b[88:92]))
ac += uint64(binary.BigEndian.Uint32(b[92:96]))
ac += uint64(binary.BigEndian.Uint32(b[96:100]))
ac += uint64(binary.BigEndian.Uint32(b[100:104]))
ac += uint64(binary.BigEndian.Uint32(b[104:108]))
ac += uint64(binary.BigEndian.Uint32(b[108:112]))
ac += uint64(binary.BigEndian.Uint32(b[112:116]))
ac += uint64(binary.BigEndian.Uint32(b[116:120]))
ac += uint64(binary.BigEndian.Uint32(b[120:124]))
ac += uint64(binary.BigEndian.Uint32(b[124:128]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[120:128]), carry)
ac += carry
b = b[128:]
}
if len(b) >= 64 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[56:64]), carry)
ac += carry
b = b[64:]
}
if len(b) >= 32 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[24:32]), carry)
ac += carry
b = b[32:]
}
if len(b) >= 16 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[8:16]), carry)
ac += carry
b = b[16:]
}
if len(b) >= 8 {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac, carry = bits.Add64(ac, binary.NativeEndian.Uint64(b[:8]), 0)
ac += carry
b = b[8:]
}
if len(b) >= 4 {
ac += uint64(binary.BigEndian.Uint32(b))
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint32(b[:4])), 0)
ac += carry
b = b[4:]
}
if len(b) >= 2 {
ac += uint64(binary.BigEndian.Uint16(b))
ac, carry = bits.Add64(ac, uint64(binary.NativeEndian.Uint16(b[:2])), 0)
ac += carry
b = b[2:]
}
if len(b) == 1 {
ac += uint64(b[0]) << 8
tmp := binary.NativeEndian.Uint16([]byte{b[0], 0})
ac, carry = bits.Add64(ac, uint64(tmp), 0)
ac += carry
}
return ac
binary.NativeEndian.PutUint64(tmp, ac)
return binary.BigEndian.Uint64(tmp)
}
func checksum(b []byte, initial uint64) uint16 {

View file

@ -1,11 +1,74 @@
package tun
import (
"encoding/binary"
"fmt"
"math/rand"
"testing"
"golang.org/x/sys/unix"
)
func checksumRef(b []byte, initial uint16) uint16 {
ac := uint64(initial)
for len(b) >= 2 {
ac += uint64(binary.BigEndian.Uint16(b))
b = b[2:]
}
if len(b) == 1 {
ac += uint64(b[0]) << 8
}
for (ac >> 16) > 0 {
ac = (ac >> 16) + (ac & 0xffff)
}
return uint16(ac)
}
func pseudoHeaderChecksumRefNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
sum := checksumRef(srcAddr, 0)
sum = checksumRef(dstAddr, sum)
sum = checksumRef([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumRef(tmp, sum)
}
func TestChecksum(t *testing.T) {
for length := 0; length <= 9001; length++ {
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(buf)
csum := checksum(buf, 0x1234)
csumRef := checksumRef(buf, 0x1234)
if csum != csumRef {
t.Error("Expected checksum", csumRef, "got", csum)
}
}
}
func TestPseudoHeaderChecksum(t *testing.T) {
for _, addrLen := range []int{4, 16} {
for length := 0; length <= 9001; length++ {
srcAddr := make([]byte, addrLen)
dstAddr := make([]byte, addrLen)
buf := make([]byte, length)
rng := rand.New(rand.NewSource(1))
rng.Read(srcAddr)
rng.Read(dstAddr)
rng.Read(buf)
phSum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
csum := checksum(buf, phSum)
phSumRef := pseudoHeaderChecksumRefNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(length))
csumRef := checksumRef(buf, phSumRef)
if csum != csumRef {
t.Error("Expected checksumRef", csumRef, "got", csum)
}
}
}
}
func BenchmarkChecksum(b *testing.B) {
lengths := []int{
64,