From 2fcdaf979915be4702bf8aba4a90ac3c3ae0796b Mon Sep 17 00:00:00 2001
From: Jordan Whited <jordan@tailscale.com>
Date: Mon, 6 Mar 2023 15:58:32 -0800
Subject: [PATCH] conn: fix StdNetBind fallback on Windows

If RIO is unavailable, NewWinRingBind() falls back to StdNetBind.
StdNetBind uses x/net/ipv{4,6}.PacketConn for sending and receiving
datagrams, specifically via the {Read,Write}Batch methods.
These methods are unimplemented on Windows and will return runtime
errors as a result. Additionally, only Linux benefits from these
x/net types for reading and writing, so we update StdNetBind to fall
back to the standard library net package for all platforms other than
Linux.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
---
 conn/bind_std.go      | 192 ++++++++++++++++++++++++++++--------------
 conn/bind_std_test.go |  22 +++++
 2 files changed, 150 insertions(+), 64 deletions(-)
 create mode 100644 conn/bind_std_test.go

diff --git a/conn/bind_std.go b/conn/bind_std.go
index b9da4c3..a842b12 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -10,6 +10,7 @@ import (
 	"errors"
 	"net"
 	"net/netip"
+	"runtime"
 	"strconv"
 	"sync"
 	"syscall"
@@ -22,16 +23,21 @@ var (
 	_ Bind = (*StdNetBind)(nil)
 )
 
-// StdNetBind implements Bind for all platforms except Windows.
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
 type StdNetBind struct {
-	mu           sync.Mutex // protects following fields
-	ipv4         *net.UDPConn
-	ipv6         *net.UDPConn
-	blackhole4   bool
-	blackhole6   bool
-	ipv4PC       *ipv4.PacketConn
-	ipv6PC       *ipv6.PacketConn
-	udpAddrPool  sync.Pool
+	mu         sync.Mutex // protects following fields
+	ipv4       *net.UDPConn
+	ipv6       *net.UDPConn
+	blackhole4 bool
+	blackhole6 bool
+	ipv4PC     *ipv4.PacketConn // will be nil on non-Linux
+	ipv6PC     *ipv6.PacketConn // will be nil on non-Linux
+
+	udpAddrPool  sync.Pool // following fields are not guarded by mu
 	ipv4MsgsPool sync.Pool
 	ipv6MsgsPool sync.Pool
 }
@@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
 again:
 	port := int(uport)
 	var v4conn, v6conn *net.UDPConn
+	var v4pc *ipv4.PacketConn
+	var v6pc *ipv6.PacketConn
 
 	v4conn, port, err = listenNet("udp4", port)
 	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
@@ -173,63 +181,92 @@ again:
 	}
 	var fns []ReceiveFunc
 	if v4conn != nil {
-		fns = append(fns, s.receiveIPv4)
+		if runtime.GOOS == "linux" {
+			v4pc = ipv4.NewPacketConn(v4conn)
+			s.ipv4PC = v4pc
+		}
+		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
 		s.ipv4 = v4conn
 	}
 	if v6conn != nil {
-		fns = append(fns, s.receiveIPv6)
+		if runtime.GOOS == "linux" {
+			v6pc = ipv6.NewPacketConn(v6conn)
+			s.ipv6PC = v6pc
+		}
+		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
 		s.ipv6 = v6conn
 	}
 	if len(fns) == 0 {
 		return nil, 0, syscall.EAFNOSUPPORT
 	}
 
-	s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
-	s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
-
 	return fns, uint16(port), nil
 }
 
-func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
-	msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
-	defer s.ipv4MsgsPool.Put(msgs)
-	for i := range buffs {
-		(*msgs)[i].Buffers[0] = buffs[i]
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
+	return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+		msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+		defer s.ipv4MsgsPool.Put(msgs)
+		for i := range buffs {
+			(*msgs)[i].Buffers[0] = buffs[i]
+		}
+		var numMsgs int
+		if runtime.GOOS == "linux" {
+			numMsgs, err = pc.ReadBatch(*msgs, 0)
+			if err != nil {
+				return 0, err
+			}
+		} else {
+			msg := &(*msgs)[0]
+			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+			if err != nil {
+				return 0, err
+			}
+			numMsgs = 1
+		}
+		for i := 0; i < numMsgs; i++ {
+			msg := &(*msgs)[i]
+			sizes[i] = msg.N
+			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+			ep := asEndpoint(addrPort)
+			getSrcFromControl(msg.OOB, ep)
+			eps[i] = ep
+		}
+		return numMsgs, nil
 	}
-	numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
-	if err != nil {
-		return 0, err
-	}
-	for i := 0; i < numMsgs; i++ {
-		msg := &(*msgs)[i]
-		sizes[i] = msg.N
-		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
-		ep := asEndpoint(addrPort)
-		getSrcFromControl(msg.OOB, ep)
-		eps[i] = ep
-	}
-	return numMsgs, nil
 }
 
-func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
-	msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
-	defer s.ipv6MsgsPool.Put(msgs)
-	for i := range buffs {
-		(*msgs)[i].Buffers[0] = buffs[i]
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
+	return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+		msgs := s.ipv4MsgsPool.Get().(*[]ipv6.Message)
+		defer s.ipv4MsgsPool.Put(msgs)
+		for i := range buffs {
+			(*msgs)[i].Buffers[0] = buffs[i]
+		}
+		var numMsgs int
+		if runtime.GOOS == "linux" {
+			numMsgs, err = pc.ReadBatch(*msgs, 0)
+			if err != nil {
+				return 0, err
+			}
+		} else {
+			msg := &(*msgs)[0]
+			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+			if err != nil {
+				return 0, err
+			}
+			numMsgs = 1
+		}
+		for i := 0; i < numMsgs; i++ {
+			msg := &(*msgs)[i]
+			sizes[i] = msg.N
+			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+			ep := asEndpoint(addrPort)
+			getSrcFromControl(msg.OOB, ep)
+			eps[i] = ep
+		}
+		return numMsgs, nil
 	}
-	numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
-	if err != nil {
-		return 0, err
-	}
-	for i := 0; i < numMsgs; i++ {
-		msg := &(*msgs)[i]
-		sizes[i] = msg.N
-		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
-		ep := asEndpoint(addrPort)
-		getSrcFromControl(msg.OOB, ep)
-		eps[i] = ep
-	}
-	return numMsgs, nil
 }
 
 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
@@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error {
 	if s.ipv4 != nil {
 		err1 = s.ipv4.Close()
 		s.ipv4 = nil
+		s.ipv4PC = nil
 	}
 	if s.ipv6 != nil {
 		err2 = s.ipv6.Close()
 		s.ipv6 = nil
+		s.ipv6PC = nil
 	}
 	s.blackhole4 = false
 	s.blackhole6 = false
@@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
 	s.mu.Lock()
 	blackhole := s.blackhole4
 	conn := s.ipv4
+	var (
+		pc4 *ipv4.PacketConn
+		pc6 *ipv6.PacketConn
+	)
 	is6 := false
 	if endpoint.DstIP().Is6() {
 		blackhole = s.blackhole6
 		conn = s.ipv6
+		pc6 = s.ipv6PC
 		is6 = true
+	} else {
+		pc4 = s.ipv4PC
 	}
 	s.mu.Unlock()
 
@@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
 		return syscall.EAFNOSUPPORT
 	}
 	if is6 {
-		return s.send6(s.ipv6PC, endpoint, buffs)
+		return s.send6(conn, pc6, endpoint, buffs)
 	} else {
-		return s.send4(s.ipv4PC, endpoint, buffs)
+		return s.send4(conn, pc4, endpoint, buffs)
 	}
 }
 
-func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
+func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
 	ua := s.udpAddrPool.Get().(*net.UDPAddr)
 	as4 := ep.DstIP().As4()
 	copy(ua.IP, as4[:])
@@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e
 		err   error
 		start int
 	)
-	for {
-		n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
-		if err != nil || n == len((*msgs)[start:len(buffs)]) {
-			break
+	if runtime.GOOS == "linux" {
+		for {
+			n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
+			if err != nil || n == len((*msgs)[start:len(buffs)]) {
+				break
+			}
+			start += n
+		}
+	} else {
+		for i, buff := range buffs {
+			_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
+			if err != nil {
+				break
+			}
 		}
-		start += n
 	}
 	s.udpAddrPool.Put(ua)
 	s.ipv4MsgsPool.Put(msgs)
 	return err
 }
 
-func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
+func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
 	ua := s.udpAddrPool.Get().(*net.UDPAddr)
 	as16 := ep.DstIP().As16()
 	copy(ua.IP, as16[:])
@@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e
 		err   error
 		start int
 	)
-	for {
-		n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
-		if err != nil || n == len((*msgs)[start:len(buffs)]) {
-			break
+	if runtime.GOOS == "linux" {
+		for {
+			n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
+			if err != nil || n == len((*msgs)[start:len(buffs)]) {
+				break
+			}
+			start += n
+		}
+	} else {
+		for i, buff := range buffs {
+			_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
+			if err != nil {
+				break
+			}
 		}
-		start += n
 	}
 	s.udpAddrPool.Put(ua)
 	s.ipv6MsgsPool.Put(msgs)
diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go
new file mode 100644
index 0000000..76afa30
--- /dev/null
+++ b/conn/bind_std_test.go
@@ -0,0 +1,22 @@
+package conn
+
+import "testing"
+
+func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
+	bind := NewStdNetBind().(*StdNetBind)
+	fns, _, err := bind.Open(0)
+	if err != nil {
+		t.Fatal(err)
+	}
+	bind.Close()
+	buffs := make([][]byte, 1)
+	buffs[0] = make([]byte, 1)
+	sizes := make([]int, 1)
+	eps := make([]Endpoint, 1)
+	for _, fn := range fns {
+		// The ReceiveFuncs must not access conn-related fields on StdNetBind
+		// unguarded. Close() nils the conn-related fields resulting in a panic
+		// if they violate the mutex.
+		fn(buffs, sizes, eps)
+	}
+}