diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index 70ea609..9eec384 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
 
 // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
 type LinuxSocketBind struct {
-	sock4    int
-	sock6    int
-	lastMark uint32
-	closing  sync.RWMutex
+	// mu guards sock4 and sock6 and the associated fds.
+	// As long as someone holds mu (read or write), the associated fds are valid.
+	mu    sync.RWMutex
+	sock4 int
+	sock6 int
 }
 
 func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
@@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
 	return nil, errors.New("invalid IP address")
 }
 
-func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) {
+func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
+	bind.mu.Lock()
+	defer bind.mu.Unlock()
+
 	var err error
 	var newPort uint16
 	var tries int
 
 	if bind.sock4 != -1 || bind.sock6 != -1 {
-		return 0, ErrBindAlreadyOpen
+		return nil, 0, ErrBindAlreadyOpen
 	}
 
 	originalPort := port
 
 again:
 	port = originalPort
+	var sock4, sock6 int
 	// Attempt ipv6 bind, update port if successful.
-	bind.sock6, newPort, err = create6(port)
+	sock6, newPort, err = create6(port)
 	if err != nil {
-		if err != syscall.EAFNOSUPPORT {
-			return 0, err
+		if !errors.Is(err, syscall.EAFNOSUPPORT) {
+			return nil, 0, err
 		}
 	} else {
 		port = newPort
 	}
 
 	// Attempt ipv4 bind, update port if successful.
-	bind.sock4, newPort, err = create4(port)
+	sock4, newPort, err = create4(port)
 	if err != nil {
-		if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 {
-			unix.Close(bind.sock6)
+		if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
+			unix.Close(sock6)
 			tries++
 			goto again
 		}
-		if err != syscall.EAFNOSUPPORT {
-			unix.Close(bind.sock6)
-			return 0, err
+		if !errors.Is(err, syscall.EAFNOSUPPORT) {
+			unix.Close(sock6)
+			return nil, 0, err
 		}
 	} else {
 		port = newPort
 	}
 
-	if bind.sock4 == -1 && bind.sock6 == -1 {
-		return 0, syscall.EAFNOSUPPORT
+	var fns []ReceiveFunc
+	if sock4 != -1 {
+		fns = append(fns, makeReceiveIPv4(sock4))
+		bind.sock4 = sock4
 	}
-	return port, nil
+	if sock6 != -1 {
+		fns = append(fns, makeReceiveIPv6(sock6))
+		bind.sock6 = sock6
+	}
+	if len(fns) == 0 {
+		return nil, 0, syscall.EAFNOSUPPORT
+	}
+	return fns, port, nil
 }
 
 func (bind *LinuxSocketBind) SetMark(value uint32) error {
-	bind.closing.RLock()
-	defer bind.closing.RUnlock()
+	bind.mu.RLock()
+	defer bind.mu.RUnlock()
 
 	if bind.sock6 != -1 {
 		err := unix.SetsockoptInt(
@@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
 		}
 	}
 
-	bind.lastMark = value
 	return nil
 }
 
 func (bind *LinuxSocketBind) Close() error {
-	var err1, err2 error
-	bind.closing.RLock()
+	// Take a readlock to shut down the sockets...
+	bind.mu.RLock()
 	if bind.sock6 != -1 {
 		unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
 	}
 	if bind.sock4 != -1 {
 		unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
 	}
-	bind.closing.RUnlock()
-	bind.closing.Lock()
+	bind.mu.RUnlock()
+	// ...and a write lock to close the fd.
+	// This ensures that no one else is using the fd.
+	bind.mu.Lock()
+	defer bind.mu.Unlock()
+	var err1, err2 error
 	if bind.sock6 != -1 {
 		err1 = unix.Close(bind.sock6)
 		bind.sock6 = -1
@@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error {
 		err2 = unix.Close(bind.sock4)
 		bind.sock4 = -1
 	}
-	bind.closing.Unlock()
 
 	if err1 != nil {
 		return err1
@@ -208,46 +224,29 @@ func (bind *LinuxSocketBind) Close() error {
 	return err2
 }
 
-func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
-	bind.closing.RLock()
-	defer bind.closing.RUnlock()
-
-	var end LinuxSocketEndpoint
-	if bind.sock6 == -1 {
-		return 0, nil, net.ErrClosed
+func makeReceiveIPv6(sock int) ReceiveFunc {
+	return func(buff []byte) (int, Endpoint, error) {
+		var end LinuxSocketEndpoint
+		n, err := receive6(sock, buff, &end)
+		return n, &end, err
 	}
-	n, err := receive6(
-		bind.sock6,
-		buff,
-		&end,
-	)
-	return n, &end, err
 }
 
-func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
-	bind.closing.RLock()
-	defer bind.closing.RUnlock()
-
-	var end LinuxSocketEndpoint
-	if bind.sock4 == -1 {
-		return 0, nil, net.ErrClosed
+func makeReceiveIPv4(sock int) ReceiveFunc {
+	return func(buff []byte) (int, Endpoint, error) {
+		var end LinuxSocketEndpoint
+		n, err := receive4(sock, buff, &end)
+		return n, &end, err
 	}
-	n, err := receive4(
-		bind.sock4,
-		buff,
-		&end,
-	)
-	return n, &end, err
 }
 
 func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
-	bind.closing.RLock()
-	defer bind.closing.RUnlock()
-
 	nend, ok := end.(*LinuxSocketEndpoint)
 	if !ok {
 		return ErrWrongEndpointType
 	}
+	bind.mu.RLock()
+	defer bind.mu.RUnlock()
 	if !nend.isV6 {
 		if bind.sock4 == -1 {
 			return net.ErrClosed
diff --git a/conn/bind_std.go b/conn/bind_std.go
index f8b8a1b..5261779 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -8,6 +8,7 @@ package conn
 import (
 	"errors"
 	"net"
+	"sync"
 	"syscall"
 )
 
@@ -16,6 +17,7 @@ import (
 // It uses the Go's net package to implement networking.
 // See LinuxSocketBind for a proper implementation on the Linux platform.
 type StdNetBind struct {
+	mu         sync.Mutex // protects following fields
 	ipv4       *net.UDPConn
 	ipv6       *net.UDPConn
 	blackhole4 bool
@@ -81,12 +83,15 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
 	return conn, uaddr.Port, nil
 }
 
-func (bind *StdNetBind) Open(uport uint16) (uint16, error) {
+func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+	bind.mu.Lock()
+	defer bind.mu.Unlock()
+
 	var err error
 	var tries int
 
 	if bind.ipv4 != nil || bind.ipv6 != nil {
-		return 0, ErrBindAlreadyOpen
+		return nil, 0, ErrBindAlreadyOpen
 	}
 
 	// Attempt to open ipv4 and ipv6 listeners on the same port.
@@ -97,7 +102,7 @@ again:
 
 	ipv4, port, err = listenNet("udp4", port)
 	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
-		return 0, err
+		return nil, 0, err
 	}
 
 	// Listen on the same port as we're using for ipv4.
@@ -109,17 +114,27 @@ again:
 	}
 	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
 		ipv4.Close()
-		return 0, err
+		return nil, 0, err
 	}
-	if ipv4 == nil && ipv6 == nil {
-		return 0, syscall.EAFNOSUPPORT
+	var fns []ReceiveFunc
+	if ipv4 != nil {
+		fns = append(fns, makeReceiveFunc(ipv4, true))
+		bind.ipv4 = ipv4
 	}
-	bind.ipv4 = ipv4
-	bind.ipv6 = ipv6
-	return uint16(port), nil
+	if ipv6 != nil {
+		fns = append(fns, makeReceiveFunc(ipv6, false))
+		bind.ipv6 = ipv6
+	}
+	if len(fns) == 0 {
+		return nil, 0, syscall.EAFNOSUPPORT
+	}
+	return fns, uint16(port), nil
 }
 
 func (bind *StdNetBind) Close() error {
+	bind.mu.Lock()
+	defer bind.mu.Unlock()
+
 	var err1, err2 error
 	if bind.ipv4 != nil {
 		err1 = bind.ipv4.Close()
@@ -137,23 +152,14 @@ func (bind *StdNetBind) Close() error {
 	return err2
 }
 
-func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
-	if bind.ipv4 == nil {
-		return 0, nil, syscall.EAFNOSUPPORT
+func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc {
+	return func(buff []byte) (int, Endpoint, error) {
+		n, endpoint, err := conn.ReadFromUDP(buff)
+		if isIPv4 && endpoint != nil {
+			endpoint.IP = endpoint.IP.To4()
+		}
+		return n, (*StdNetEndpoint)(endpoint), err
 	}
-	n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
-	if endpoint != nil {
-		endpoint.IP = endpoint.IP.To4()
-	}
-	return n, (*StdNetEndpoint)(endpoint), err
-}
-
-func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
-	if bind.ipv6 == nil {
-		return 0, nil, syscall.EAFNOSUPPORT
-	}
-	n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
-	return n, (*StdNetEndpoint)(endpoint), err
 }
 
 func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
@@ -162,15 +168,16 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
 	if !ok {
 		return ErrWrongEndpointType
 	}
-	var conn *net.UDPConn
-	var blackhole bool
-	if nend.IP.To4() != nil {
-		blackhole = bind.blackhole4
-		conn = bind.ipv4
-	} else {
+
+	bind.mu.Lock()
+	blackhole := bind.blackhole4
+	conn := bind.ipv4
+	if nend.IP.To4() == nil {
 		blackhole = bind.blackhole6
 		conn = bind.ipv6
 	}
+	bind.mu.Unlock()
+
 	if blackhole {
 		return nil
 	}
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index 1e2712e..6cabee1 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -266,7 +266,7 @@ func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sock
 	return sa, nil
 }
 
-func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
+func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
 	bind.mu.Lock()
 	defer bind.mu.Unlock()
 	defer func() {
@@ -275,30 +275,30 @@ func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
 		}
 	}()
 	if atomic.LoadUint32(&bind.isOpen) != 0 {
-		return 0, ErrBindAlreadyOpen
+		return nil, 0, ErrBindAlreadyOpen
 	}
 	var sa windows.Sockaddr
 	sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
 	if err != nil {
-		return 0, err
+		return nil, 0, err
 	}
 	sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
 	if err != nil {
-		return 0, err
+		return nil, 0, err
 	}
 	selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
 	for i := 0; i < packetsPerRing; i++ {
 		err = bind.v4.InsertReceiveRequest()
 		if err != nil {
-			return 0, err
+			return nil, 0, err
 		}
 		err = bind.v6.InsertReceiveRequest()
 		if err != nil {
-			return 0, err
+			return nil, 0, err
 		}
 	}
 	atomic.StoreUint32(&bind.isOpen, 1)
-	return
+	return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
 }
 
 func (bind *WinRingBind) Close() error {
@@ -395,13 +395,13 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
 	return n, &ep, nil
 }
 
-func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
 	bind.mu.RLock()
 	defer bind.mu.RUnlock()
 	return bind.v4.Receive(buf, &bind.isOpen)
 }
 
-func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
 	bind.mu.RLock()
 	defer bind.mu.RUnlock()
 	return bind.v6.Receive(buf, &bind.isOpen)
@@ -482,6 +482,8 @@ func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
 }
 
 func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+	bind.mu.Lock()
+	defer bind.mu.Unlock()
 	sysconn, err := bind.ipv4.SyscallConn()
 	if err != nil {
 		return err
@@ -500,6 +502,8 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
 }
 
 func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+	bind.mu.Lock()
+	defer bind.mu.Unlock()
 	sysconn, err := bind.ipv6.SyscallConn()
 	if err != nil {
 		return err
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index ad8fa05..7d43fb3 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -65,12 +65,14 @@ func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
 
 func (c ChannelEndpoint) SrcIP() net.IP { return nil }
 
-func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) {
+func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
 	c.closeSignal = make(chan bool)
+	fns = append(fns, c.makeReceiveFunc(*c.rx4))
+	fns = append(fns, c.makeReceiveFunc(*c.rx6))
 	if rand.Uint32()&1 == 0 {
-		return uint16(c.source4), nil
+		return fns, uint16(c.source4), nil
 	} else {
-		return uint16(c.source6), nil
+		return fns, uint16(c.source6), nil
 	}
 }
 
@@ -87,21 +89,14 @@ func (c *ChannelBind) Close() error {
 
 func (c *ChannelBind) SetMark(mark uint32) error { return nil }
 
-func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) {
-	select {
-	case <-c.closeSignal:
-		return 0, nil, net.ErrClosed
-	case rx := <-*c.rx6:
-		return copy(b, rx), c.target6, nil
-	}
-}
-
-func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
-	select {
-	case <-c.closeSignal:
-		return 0, nil, net.ErrClosed
-	case rx := <-*c.rx4:
-		return copy(b, rx), c.target4, nil
+func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
+	return func(b []byte) (n int, ep conn.Endpoint, err error) {
+		select {
+		case <-c.closeSignal:
+			return 0, nil, net.ErrClosed
+		case rx := <-ch:
+			return copy(b, rx), c.target6, nil
+		}
 	}
 }
 
diff --git a/conn/conn.go b/conn/conn.go
index 6fd232f..3c7fcd0 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -12,6 +12,11 @@ import (
 	"strings"
 )
 
+// A ReceiveFunc receives a single inbound packet from the network.
+// It writes the data into b. n is the length of the packet.
+// ep is the remote endpoint.
+type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
+
 // A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
 //
 // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
@@ -19,23 +24,17 @@ import (
 type Bind interface {
 	// Open puts the Bind into a listening state on a given port and reports the actual
 	// port that it bound to. Passing zero results in a random selection.
-	Open(port uint16) (actualPort uint16, err error)
+	// fns is the set of functions that will be called to receive packets.
+	Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
 
 	// Close closes the Bind listener.
+	// All fns returned by Open must return net.ErrClosed after a call to Close.
 	Close() error
 
 	// SetMark sets the mark for each packet sent through this Bind.
 	// This mark is passed to the kernel as the socket option SO_MARK.
 	SetMark(mark uint32) error
 
-	// ReceiveIPv6 reads an IPv6 UDP packet into b.  It reports the number of bytes read,
-	// n, the packet source address ep, and any error.
-	ReceiveIPv6(b []byte) (n int, ep Endpoint, err error)
-
-	// ReceiveIPv4 reads an IPv4 UDP packet into b. It reports the number of bytes read,
-	// n, the packet source address ep, and any error.
-	ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
-
 	// Send writes a packet b to address ep.
 	Send(b []byte, ep Endpoint) error
 
diff --git a/device/device.go b/device/device.go
index 1e32db6..a635e68 100644
--- a/device/device.go
+++ b/device/device.go
@@ -11,9 +11,6 @@ import (
 	"sync/atomic"
 	"time"
 
-	"golang.org/x/net/ipv4"
-	"golang.org/x/net/ipv6"
-
 	"golang.zx2c4.com/wireguard/conn"
 	"golang.zx2c4.com/wireguard/ratelimiter"
 	"golang.zx2c4.com/wireguard/rwcancel"
@@ -468,8 +465,9 @@ func (device *Device) BindUpdate() error {
 
 	// bind to new port
 	var err error
+	var recvFns []conn.ReceiveFunc
 	netc := &device.net
-	netc.port, err = netc.bind.Open(netc.port)
+	recvFns, netc.port, err = netc.bind.Open(netc.port)
 	if err != nil {
 		netc.port = 0
 		return err
@@ -501,11 +499,12 @@ func (device *Device) BindUpdate() error {
 	device.peers.RUnlock()
 
 	// start receiving routines
-	device.net.stopping.Add(2)
-	device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
-	device.queue.handshake.wg.Add(2)  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
-	go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
-	go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
+	device.net.stopping.Add(len(recvFns))
+	device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
+	device.queue.handshake.wg.Add(len(recvFns))  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+	for _, fn := range recvFns {
+		go device.RoutineReceiveIncoming(fn)
+	}
 
 	device.log.Verbosef("UDP bind has been updated")
 	return nil
diff --git a/device/receive.go b/device/receive.go
index 5ddb66c..fa5c0a6 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -68,15 +68,15 @@ func (peer *Peer) keepKeyFreshReceiving() {
  * Every time the bind is updated a new routine is started for
  * IPv4 and IPv6 (separately)
  */
-func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
+func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
 	defer func() {
-		device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
+		device.log.Verbosef("Routine: receive incoming %p - stopped", recv)
 		device.queue.decryption.wg.Done()
 		device.queue.handshake.wg.Done()
 		device.net.stopping.Done()
 	}()
 
-	device.log.Verbosef("Routine: receive incoming IPv%d - started", IP)
+	device.log.Verbosef("Routine: receive incoming %p - started", recv)
 
 	// receive datagrams until conn is closed
 
@@ -90,14 +90,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
 	)
 
 	for {
-		switch IP {
-		case ipv4.Version:
-			size, endpoint, err = bind.ReceiveIPv4(buffer[:])
-		case ipv6.Version:
-			size, endpoint, err = bind.ReceiveIPv6(buffer[:])
-		default:
-			panic("invalid IP version")
-		}
+		size, endpoint, err = recv(buffer[:])
 
 		if err != nil {
 			device.PutMessageBuffer(buffer)