git: 7cbb6b6e28db - main - inpcb: Close some SO_REUSEPORT_LB races, part 2

From: Mark Johnston <markj_at_FreeBSD.org>
Date: Thu, 23 Jan 2025 17:12:27 UTC
The branch main has been updated by markj:

URL: https://cgit.FreeBSD.org/src/commit/?id=7cbb6b6e28db33095a1cf7a8887921a5ec969824

commit 7cbb6b6e28db33095a1cf7a8887921a5ec969824
Author:     Mark Johnston <markj@FreeBSD.org>
AuthorDate: 2025-01-23 17:00:11 +0000
Commit:     Mark Johnston <markj@FreeBSD.org>
CommitDate: 2025-01-23 17:12:10 +0000

    inpcb: Close some SO_REUSEPORT_LB races, part 2
    
    Suppose a thread adds a socket to an existing TCP lbgroup that is
    actively accepting connections.  It has to do the following operations:
    1. set SO_REUSEPORT_LB on the socket
    2. bind() the socket to the shared address/port
    3. call listen()
    
    Step 2 makes the inpcb visible to incoming connection requests.
    However, at this point the inpcb cannot accept new connections.  If
    in_pcblookup() matches it, the remote end will see ECONNREFUSED even
    when other listening sockets are present in the lbgroup.  This means
    that dynamically adding inpcbs to an lbgroup (e.g., by starting up new
    workers) can trigger spurious connection failures for no good reason.
    (A similar problem exists when removing inpcbs from an lbgroup, but that
    is harder to fix and is not addressed by this patch; see the review for
    a bit more commentary.)
    
    Fix this by augmenting each lbgroup with a linked list of inpcbs that
    are pending a listen() call.  When adding an inpcb to an lbgroup, keep
    the inpcb on this list if listen() hasn't been called, so it is not yet
    visible to the lookup path.  Then, add a new in_pcblisten() routine which
    makes the inpcb visible within the lbgroup now that it's safe to let it
    handle new connections.
    
    Add a regression test which verifies that we don't get spurious
    connection errors while adding sockets to an LB group.
    
    Reviewed by:    glebius
    MFC after:      1 month
    Sponsored by:   Klara, Inc.
    Sponsored by:   Stormshield
    Differential Revision:  https://reviews.freebsd.org/D48544
---
 sys/kern/uipc_domain.c                   |   2 +-
 sys/netinet/in_pcb.c                     | 131 ++++++++++++++++++++--------
 sys/netinet/in_pcb.h                     |   7 +-
 sys/netinet/tcp_usrreq.c                 |   4 +
 sys/sys/socketvar.h                      |   2 +
 tests/sys/netinet/Makefile               |   2 +
 tests/sys/netinet/so_reuseport_lb_test.c | 143 +++++++++++++++++++++++++++++++
 7 files changed, 252 insertions(+), 39 deletions(-)

diff --git a/sys/kern/uipc_domain.c b/sys/kern/uipc_domain.c
index 43bdd44a09bf..c5296f12ba94 100644
--- a/sys/kern/uipc_domain.c
+++ b/sys/kern/uipc_domain.c
@@ -109,7 +109,8 @@ pr_disconnect_notsupp(struct socket *so)
 	return (EOPNOTSUPP);
 }
 
-static int
+/* Exported so that callers can test whether a protocol supports listen(). */
+int
 pr_listen_notsupp(struct socket *so, int backlog, struct thread *td)
 {
 	return (EOPNOTSUPP);
diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c
index 11bc68a3915a..c50e39f20c16 100644
--- a/sys/netinet/in_pcb.c
+++ b/sys/netinet/in_pcb.c
@@ -263,6 +263,7 @@ in_pcblbgroup_alloc(struct ucred *cred, u_char vflag, uint16_t port,
 	grp = malloc(bytes, M_PCB, M_ZERO | M_NOWAIT);
 	if (grp == NULL)
 		return (NULL);
+	LIST_INIT(&grp->il_pending);
 	grp->il_cred = crhold(cred);
 	grp->il_vflag = vflag;
 	grp->il_lport = port;
@@ -285,11 +286,51 @@ in_pcblbgroup_free_deferred(epoch_context_t ctx)
 static void
 in_pcblbgroup_free(struct inpcblbgroup *grp)
 {
+	KASSERT(LIST_EMPTY(&grp->il_pending),
+	    ("local group %p still has pending inps", grp));
 
 	CK_LIST_REMOVE(grp, il_list);
 	NET_EPOCH_CALL(in_pcblbgroup_free_deferred, &grp->il_epoch_ctx);
 }
 
+/*
+ * Return the load balance group containing inp, whether inp is active (in
+ * il_inp[]) or still waiting for listen() on il_pending, or NULL if no group
+ * in inp's port hash bucket contains it.  The caller must hold the inpcb and
+ * hash locks, and inp must have INP_INLBGROUP set.
+ */
+static struct inpcblbgroup *
+in_pcblbgroup_find(struct inpcb *inp)
+{
+	struct inpcbinfo *pcbinfo;
+	struct inpcblbgroup *grp;
+	struct inpcblbgrouphead *hdr;
+
+	INP_LOCK_ASSERT(inp);
+
+	pcbinfo = inp->inp_pcbinfo;
+	INP_HASH_LOCK_ASSERT(pcbinfo);
+	KASSERT((inp->inp_flags & INP_INLBGROUP) != 0,
+	    ("inpcb %p is not in a load balance group", inp));
+
+	hdr = &pcbinfo->ipi_lbgrouphashbase[
+	    INP_PCBPORTHASH(inp->inp_lport, pcbinfo->ipi_lbgrouphashmask)];
+	CK_LIST_FOREACH(grp, hdr, il_list) {
+		struct inpcb *inp1;
+
+		for (unsigned int i = 0; i < grp->il_inpcnt; i++) {
+			if (inp == grp->il_inp[i])
+				goto found;
+		}
+		LIST_FOREACH(inp1, &grp->il_pending, inp_lbgroup_list) {
+			if (inp == inp1)
+				goto found;
+		}
+	}
+found:
+	return (grp);
+}
+
 static void
 in_pcblbgroup_insert(struct inpcblbgroup *grp, struct inpcb *inp)
 {
@@ -298,14 +339,25 @@ in_pcblbgroup_insert(struct inpcblbgroup *grp, struct inpcb *inp)
 	    grp->il_inpcnt));
 	INP_WLOCK_ASSERT(inp);
 
-	inp->inp_flags |= INP_INLBGROUP;
-	grp->il_inp[grp->il_inpcnt] = inp;
+	if (inp->inp_socket->so_proto->pr_listen != pr_listen_notsupp &&
+	    !SOLISTENING(inp->inp_socket)) {
+		/*
+		 * If this is a TCP socket, it should not be visible to lbgroup
+		 * lookups until listen() has been called; in_pcblisten() then
+		 * moves it from il_pending into il_inp[].
+		 */
+		LIST_INSERT_HEAD(&grp->il_pending, inp, inp_lbgroup_list);
+	} else {
+		grp->il_inp[grp->il_inpcnt] = inp;
 
-	/*
-	 * Synchronize with in_pcblookup_lbgroup(): make sure that we don't
-	 * expose a null slot to the lookup path.
-	 */
-	atomic_store_rel_int(&grp->il_inpcnt, grp->il_inpcnt + 1);
+		/*
+		 * Synchronize with in_pcblookup_lbgroup(): make sure that we
+		 * don't expose a null slot to the lookup path.
+		 */
+		atomic_store_rel_int(&grp->il_inpcnt, grp->il_inpcnt + 1);
+	}
+
+	inp->inp_flags |= INP_INLBGROUP;
 }
 
 static struct inpcblbgroup *
@@ -329,6 +381,12 @@ in_pcblbgroup_resize(struct inpcblbgrouphead *hdr,
 		grp->il_inp[i] = old_grp->il_inp[i];
 	grp->il_inpcnt = old_grp->il_inpcnt;
 	CK_LIST_INSERT_HEAD(hdr, grp, il_list);
+	/*
+	 * Move inpcbs still waiting for listen() to the new group; the old
+	 * group must have an empty pending list when it is freed.
+	 */
+	LIST_SWAP(&old_grp->il_pending, &grp->il_pending, inpcb,
+	    inp_lbgroup_list);
 	in_pcblbgroup_free(old_grp);
 	return (grp);
 }
@@ -412,6 +459,7 @@ in_pcbremlbgrouphash(struct inpcb *inp)
 	struct inpcbinfo *pcbinfo;
 	struct inpcblbgrouphead *hdr;
 	struct inpcblbgroup *grp;
+	struct inpcb *inp1;
 	int i;
 
 	pcbinfo = inp->inp_pcbinfo;
@@ -427,13 +475,11 @@ in_pcbremlbgrouphash(struct inpcb *inp)
 			if (grp->il_inp[i] != inp)
 				continue;
 
-			if (grp->il_inpcnt == 1) {
+			if (grp->il_inpcnt == 1 &&
+			    LIST_EMPTY(&grp->il_pending)) {
 				/* We are the last, free this local group. */
 				in_pcblbgroup_free(grp);
 			} else {
-				KASSERT(grp->il_inpcnt >= 2,
-				    ("invalid local group count %d",
-				    grp->il_inpcnt));
 				grp->il_inp[i] =
 				    grp->il_inp[grp->il_inpcnt - 1];
 
@@ -446,17 +492,23 @@ in_pcbremlbgrouphash(struct inpcb *inp)
 			inp->inp_flags &= ~INP_INLBGROUP;
 			return;
 		}
+		LIST_FOREACH(inp1, &grp->il_pending, inp_lbgroup_list) {
+			if (inp == inp1) {
+				LIST_REMOVE(inp, inp_lbgroup_list);
+				inp->inp_flags &= ~INP_INLBGROUP;
+				return;
+			}
+		}
 	}
-	KASSERT(0, ("%s: did not find %p", __func__, inp));
+	__assert_unreachable();
 }
 
 int
 in_pcblbgroup_numa(struct inpcb *inp, int arg)
 {
 	struct inpcbinfo *pcbinfo;
-	struct inpcblbgrouphead *hdr;
 	struct inpcblbgroup *grp;
-	int err, i;
+	int error;
 	uint8_t numa_domain;
 
 	switch (arg) {
@@ -472,33 +524,32 @@ in_pcblbgroup_numa(struct inpcb *inp, int arg)
 		numa_domain = arg;
 	}
 
-	err = 0;
 	pcbinfo = inp->inp_pcbinfo;
 	INP_WLOCK_ASSERT(inp);
 	INP_HASH_WLOCK(pcbinfo);
-	hdr = &pcbinfo->ipi_lbgrouphashbase[
-	    INP_PCBPORTHASH(inp->inp_lport, pcbinfo->ipi_lbgrouphashmask)];
-	CK_LIST_FOREACH(grp, hdr, il_list) {
-		for (i = 0; i < grp->il_inpcnt; ++i) {
-			if (grp->il_inp[i] != inp)
-				continue;
-
-			if (grp->il_numa_domain == numa_domain) {
-				goto abort_with_hash_wlock;
-			}
-
-			/* Remove it from the old group. */
-			in_pcbremlbgrouphash(inp);
-
-			/* Add it to the new group based on numa domain. */
-			in_pcbinslbgrouphash(inp, numa_domain);
-			goto abort_with_hash_wlock;
-		}
+	/*
+	 * in_pcblbgroup_find() asserts that the inpcb is in a group, but an
+	 * inpcb that is in no group must be reported with ENOENT, so test
+	 * INP_INLBGROUP first.
+	 */
+	grp = (inp->inp_flags & INP_INLBGROUP) != 0 ?
+	    in_pcblbgroup_find(inp) : NULL;
+	if (grp == NULL) {
+		error = ENOENT;
+	} else if (grp->il_numa_domain == numa_domain) {
+		/* Already in a group for the requested domain. */
+		error = 0;
+	} else {
+		/* Remove it from the old group. */
+		in_pcbremlbgrouphash(inp);
+		/*
+		 * Add it to the new group based on numa domain, reporting any
+		 * error from the insertion rather than dropping it.
+		 */
+		error = in_pcbinslbgrouphash(inp, numa_domain);
 	}
-	err = ENOENT;
-abort_with_hash_wlock:
 	INP_HASH_WUNLOCK(pcbinfo);
-	return (err);
+	return (error);
 }
 
 /* Make sure it is safe to use hashinit(9) on CK_LIST. */
@@ -1437,6 +1475,46 @@ in_pcbdisconnect(struct inpcb *inp)
 }
 #endif /* INET */
 
+/*
+ * Make an inpcb in a load balance group visible to lbgroup lookups once
+ * listen() has been called on its socket.  in_pcblbgroup_insert() keeps such
+ * inpcbs on the group's il_pending list until then.
+ */
+void
+in_pcblisten(struct inpcb *inp)
+{
+	struct inpcbinfo *pcbinfo;
+	struct inpcblbgroup *grp;
+	struct inpcb *inp1;
+
+	INP_WLOCK_ASSERT(inp);
+
+	if ((inp->inp_flags & INP_INLBGROUP) == 0)
+		return;
+
+	pcbinfo = inp->inp_pcbinfo;
+	INP_HASH_WLOCK(pcbinfo);
+	grp = in_pcblbgroup_find(inp);
+	/*
+	 * listen(2) may be called again on a socket that is already listening,
+	 * in which case the inpcb is already in il_inp[] and its
+	 * inp_lbgroup_list linkage is stale.  Only move it if it is actually
+	 * on the pending list.
+	 */
+	LIST_FOREACH(inp1, &grp->il_pending, inp_lbgroup_list) {
+		if (inp1 == inp) {
+			LIST_REMOVE(inp, inp_lbgroup_list);
+			/*
+			 * NOTE(review): pending inpcbs do not occupy il_inp[]
+			 * slots; confirm that the group always has room here.
+			 */
+			in_pcblbgroup_insert(grp, inp);
+			break;
+		}
+	}
+	INP_HASH_WUNLOCK(pcbinfo);
+}
+
 /*
  * inpcb hash lookups are protected by SMR section.
  *
diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h
index 0cf5ca017963..f00629266c58 100644
--- a/sys/netinet/in_pcb.h
+++ b/sys/netinet/in_pcb.h
@@ -167,7 +167,16 @@ struct inpcbpolicy;
 struct m_snd_tag;
 struct inpcb {
 	/* Cache line #1 (amd64) */
-	CK_LIST_ENTRY(inpcb) inp_hash_exact;	/* hash table linkage */
+	/*
+	 * inp_lbgroup_list links the inpcb into an lbgroup's il_pending list
+	 * while it waits for listen().  NOTE(review): sharing storage with
+	 * inp_hash_exact assumes such an inpcb is never also linked into the
+	 * exact-match hash (e.g. by connect()); confirm.
+	 */
+	union {
+		CK_LIST_ENTRY(inpcb) inp_hash_exact;	/* hash table linkage */
+		LIST_ENTRY(inpcb) inp_lbgroup_list;	/* lb group list */
+	};
 	CK_LIST_ENTRY(inpcb) inp_hash_wild;	/* hash table linkage */
 	struct rwlock	inp_lock;
 	/* Cache line #2 (amd64) */
@@ -428,6 +437,7 @@ SYSUNINIT(prot##_inpcbstorage_uninit, SI_SUB_PROTO_DOMAIN,		\
  */
 struct inpcblbgroup {
 	CK_LIST_ENTRY(inpcblbgroup) il_list;
+	LIST_HEAD(, inpcb) il_pending;	/* PCBs waiting for listen() */
 	struct epoch_context il_epoch_ctx;
 	struct ucred	*il_cred;
 	uint16_t	il_lport;			/* (c) */
@@ -671,6 +681,7 @@ int	in_pcbinshash(struct inpcb *);
 int	in_pcbladdr(struct inpcb *, struct in_addr *, struct in_addr *,
 	    struct ucred *);
 int	in_pcblbgroup_numa(struct inpcb *, int arg);
+void	in_pcblisten(struct inpcb *);
 struct inpcb *
 	in_pcblookup(struct inpcbinfo *, struct in_addr, u_int,
 	    struct in_addr, u_int, int, struct ifnet *);
diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c
index acc3e2ea2942..3e73e448a9f7 100644
--- a/sys/netinet/tcp_usrreq.c
+++ b/sys/netinet/tcp_usrreq.c
@@ -391,6 +391,9 @@ tcp_usr_listen(struct socket *so, int backlog, struct thread *td)
 	}
 	SOCK_UNLOCK(so);
 
+	/* If the inpcb is in an lbgroup, make it visible to lookups now. */
+	if (error == 0)
+		in_pcblisten(inp);
 	if (tp->t_flags & TF_FASTOPEN)
 		tp->t_tfo_pending = tcp_fastopen_alloc_counter();
 
@@ -448,6 +451,9 @@ tcp6_usr_listen(struct socket *so, int backlog, struct thread *td)
 	}
 	SOCK_UNLOCK(so);
 
+	/* If the inpcb is in an lbgroup, make it visible to lookups now. */
+	if (error == 0)
+		in_pcblisten(inp);
 	if (tp->t_flags & TF_FASTOPEN)
 		tp->t_tfo_pending = tcp_fastopen_alloc_counter();
 
diff --git a/sys/sys/socketvar.h b/sys/sys/socketvar.h
index fda8d23f5649..e818fd3fc757 100644
--- a/sys/sys/socketvar.h
+++ b/sys/sys/socketvar.h
@@ -596,6 +596,9 @@ SYSCTL_DECL(_net_inet_accf);
 int	accept_filt_generic_mod_event(module_t mod, int event, void *data);
 #endif
 
+/* EOPNOTSUPP stub; compare pr_listen against it to detect listen() support. */
+int	pr_listen_notsupp(struct socket *so, int backlog, struct thread *td);
+
 #endif /* _KERNEL */
 
 /*
 /*
diff --git a/tests/sys/netinet/Makefile b/tests/sys/netinet/Makefile
index 9fac7152e137..6faaf8ac1df1 100644
--- a/tests/sys/netinet/Makefile
+++ b/tests/sys/netinet/Makefile
@@ -27,6 +27,9 @@ ATF_TESTS_SH=	arp \
 ATF_TESTS_PYTEST+=	carp.py
 ATF_TESTS_PYTEST+=	igmp.py
 
+# so_reuseport_lb_test creates threads in its concurrent_add test case.
+LIBADD.so_reuseport_lb_test=	pthread
+
 # Some of the arp tests look for log messages in the dmesg buffer, so run them
 # serially to avoid problems with interleaved output.
 TEST_METADATA.arp+=	is_exclusive="true"
diff --git a/tests/sys/netinet/so_reuseport_lb_test.c b/tests/sys/netinet/so_reuseport_lb_test.c
index 64fe0b53618d..3ce09fcf5794 100644
--- a/tests/sys/netinet/so_reuseport_lb_test.c
+++ b/tests/sys/netinet/so_reuseport_lb_test.c
@@ -28,12 +28,15 @@
  */
 
 #include <sys/param.h>
+#include <sys/event.h>
 #include <sys/socket.h>
 
 #include <netinet/in.h>
 
 #include <err.h>
 #include <errno.h>
+#include <pthread.h>
 #include <stdlib.h>
+#include <string.h>
 #include <unistd.h>
 
@@ -235,10 +239,153 @@ ATF_TC_BODY(basic_ipv6, tc)
 	}
 }
 
+struct concurrent_add_softc {
+	struct sockaddr_storage ss;
+	int socks[128];
+	int kq;
+};
+
+/* Accept connections reported by the kqueue and send each a single byte. */
+static void *
+listener(void *arg)
+{
+	for (struct concurrent_add_softc *sc = arg;;) {
+		struct kevent kev;
+		ssize_t n;
+		int error, count, cs, s;
+		uint8_t b;
+
+		count = kevent(sc->kq, NULL, 0, &kev, 1, NULL);
+		ATF_REQUIRE_MSG(count == 1,
+		    "kevent() failed: %s", strerror(errno));
+
+		s = (int)kev.ident;
+		cs = accept(s, NULL, NULL);
+		ATF_REQUIRE_MSG(cs >= 0,
+		    "accept() failed: %s", strerror(errno));
+
+		b = 'M';
+		n = write(cs, &b, sizeof(b));
+		ATF_REQUIRE_MSG(n >= 0, "write() failed: %s", strerror(errno));
+		ATF_REQUIRE(n == 1);
+
+		error = close(cs);
+		ATF_REQUIRE_MSG(error == 0 || errno == ECONNRESET,
+		    "close() failed: %s", strerror(errno));
+	}
+}
+
+/* Repeatedly connect to the shared address and check the byte sent back. */
+static void *
+connector(void *arg)
+{
+	for (struct concurrent_add_softc *sc = arg;;) {
+		ssize_t n;
+		int error, s;
+		uint8_t b;
+
+		s = socket(sc->ss.ss_family, SOCK_STREAM, 0);
+		ATF_REQUIRE_MSG(s >= 0, "socket() failed: %s", strerror(errno));
+
+		error = setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (int[]){1},
+		    sizeof(int));
+		ATF_REQUIRE_MSG(error == 0,
+		    "setsockopt(SO_REUSEADDR) failed: %s", strerror(errno));
+
+		error = connect(s, (struct sockaddr *)&sc->ss, sc->ss.ss_len);
+		ATF_REQUIRE_MSG(error == 0, "connect() failed: %s",
+		    strerror(errno));
+
+		n = read(s, &b, sizeof(b));
+		ATF_REQUIRE_MSG(n >= 0, "read() failed: %s",
+		    strerror(errno));
+		ATF_REQUIRE(n == 1);
+		ATF_REQUIRE(b == 'M');
+		error = close(s);
+		ATF_REQUIRE_MSG(error == 0,
+		    "close() failed: %s", strerror(errno));
+	}
+}
+
+/*
+ * One thread accepts connections from listening sockets on a kqueue, while
+ * three others make connections.  Meanwhile, the main thread slowly adds
+ * sockets to the LB group.  This is meant to help flush out race conditions.
+ */
+ATF_TC_WITHOUT_HEAD(concurrent_add);
+ATF_TC_BODY(concurrent_add, tc)
+{
+	struct concurrent_add_softc sc;
+	struct sockaddr_in *sin;
+	pthread_t threads[4];
+	int error;
+
+	sc.kq = kqueue();
+	ATF_REQUIRE_MSG(sc.kq >= 0, "kqueue() failed: %s", strerror(errno));
+
+	error = pthread_create(&threads[0], NULL, listener, &sc);
+	ATF_REQUIRE_MSG(error == 0, "pthread_create() failed: %s",
+	    strerror(error));
+
+	sin = (struct sockaddr_in *)&sc.ss;
+	memset(sin, 0, sizeof(*sin));
+	sin->sin_len = sizeof(*sin);
+	sin->sin_family = AF_INET;
+	sin->sin_port = htons(0);
+	sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+
+	for (size_t i = 0; i < nitems(sc.socks); i++) {
+		struct kevent kev;
+		int s;
+
+		sc.socks[i] = s = socket(AF_INET, SOCK_STREAM, 0);
+		ATF_REQUIRE_MSG(s >= 0, "socket() failed: %s", strerror(errno));
+
+		error = setsockopt(s, SOL_SOCKET, SO_REUSEPORT_LB, (int[]){1},
+		    sizeof(int));
+		ATF_REQUIRE_MSG(error == 0,
+		    "setsockopt(SO_REUSEPORT_LB) failed: %s", strerror(errno));
+
+		error = bind(s, (struct sockaddr *)sin, sizeof(*sin));
+		ATF_REQUIRE_MSG(error == 0, "bind() failed: %s",
+		    strerror(errno));
+
+		error = listen(s, 5);
+		ATF_REQUIRE_MSG(error == 0, "listen() failed: %s",
+		    strerror(errno));
+
+		EV_SET(&kev, s, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, 0);
+		error = kevent(sc.kq, &kev, 1, NULL, 0, NULL);
+		ATF_REQUIRE_MSG(error == 0, "kevent() failed: %s",
+		    strerror(errno));
+
+		if (i == 0) {
+			socklen_t slen = sizeof(sc.ss);
+
+			error = getsockname(sc.socks[i],
+			    (struct sockaddr *)&sc.ss, &slen);
+			ATF_REQUIRE_MSG(error == 0, "getsockname() failed: %s",
+			    strerror(errno));
+			ATF_REQUIRE(sc.ss.ss_family == AF_INET);
+
+			for (size_t j = 1; j < nitems(threads); j++) {
+				error = pthread_create(&threads[j], NULL,
+				    connector, &sc);
+				ATF_REQUIRE_MSG(error == 0,
+				    "pthread_create() failed: %s",
+				    strerror(error));
+			}
+		}
+
+		usleep(20000);
+	}
+}
+
 ATF_TP_ADD_TCS(tp)
 {
 	ATF_TP_ADD_TC(tp, basic_ipv4);
 	ATF_TP_ADD_TC(tp, basic_ipv6);
+	ATF_TP_ADD_TC(tp, concurrent_add);
 
 	return (atf_no_error());
 }