From 07a1e55270bd34ee526ad328d597fa01a8e17619 Mon Sep 17 00:00:00 2001
From: Jordan Whited <jordan@tailscale.com>
Date: Tue, 14 Mar 2023 20:28:07 -0700
Subject: [PATCH] conn: fix getSrcFromControl() iteration

We only expect a single control message in the normal case, but this
would loop infinitely if there were more.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
---
 conn/sticky_linux.go      |  2 +-
 conn/sticky_linux_test.go | 28 ++++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
index 342e739..278eb19 100644
--- a/conn/sticky_linux.go
+++ b/conn/sticky_linux.go
@@ -25,7 +25,7 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
 	)
 
 	for len(rem) > unix.SizeofCmsghdr {
-		hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
+		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
 		if err != nil {
 			return
 		}
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
index 672b67e..503c342 100644
--- a/conn/sticky_linux_test.go
+++ b/conn/sticky_linux_test.go
@@ -150,6 +150,34 @@ func Test_getSrcFromControl(t *testing.T) {
 			t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
 		}
 	})
+	t.Run("Multiple", func(t *testing.T) {
+		zeroControl := make([]byte, unix.CmsgSpace(0))
+		zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
+		zeroHdr.SetLen(unix.CmsgLen(0))
+
+		control := make([]byte, srcControlSize)
+		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+		hdr.Level = unix.IPPROTO_IP
+		hdr.Type = unix.IP_PKTINFO
+		hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+		info.Spec_dst = [4]byte{127, 0, 0, 1}
+		info.Ifindex = 5
+
+		combined := make([]byte, 0)
+		combined = append(combined, zeroControl...)
+		combined = append(combined, control...)
+
+		ep := &StdNetEndpoint{}
+		getSrcFromControl(combined, ep)
+
+		if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
+			t.Errorf("unexpected address: %v", ep.src.Addr)
+		}
+		if ep.src.ifidx != 5 {
+			t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+		}
+	})
 }
 
 func Test_listenConfig(t *testing.T) {