git: a600aabe9b04 - main - inpcb: Close some SO_REUSEPORT_LB races

From: Mark Johnston <markj_at_FreeBSD.org>
Date: Thu, 12 Dec 2024 14:27:33 UTC

The branch main has been updated by markj:

URL: https://cgit.FreeBSD.org/src/commit/?id=a600aabe9b04f0906069a8fb1f8d696ad186080f

commit a600aabe9b04f0906069a8fb1f8d696ad186080f
Author:     Mark Johnston <markj@FreeBSD.org>
AuthorDate: 2024-12-12 14:02:12 +0000
Commit:     Mark Johnston <markj@FreeBSD.org>
CommitDate: 2024-12-12 14:02:12 +0000

    inpcb: Close some SO_REUSEPORT_LB races
    
    For a long time, the inpcb lookup path has been lockless in the common
    case: we use net_epoch to synchronize lookups.  However, the routines
    which update lbgroups were not careful to synchronize with unlocked
    lookups.  I believe that in the worst case this can result in spurious
    connection aborts (I have a regression test case to exercise this), but
    it's hard to be certain.
    
    Modify in_pcblbgroup* routines to synchronize with unlocked lookup:
    - When removing inpcbs from an lbgroup, do not shrink the array.
      Instead, move the last entry into the vacated slot and decrement the
      counter with a release store.  The maximum number of lbgroup entries
      is INPCBLBGROUP_SIZMAX (256), and it doesn't seem worth the
      complexity to shrink the array when a socket is removed.
    - When resizing an lbgroup, do not insert it into the hash table until
      it is fully initialized; otherwise lookups may observe a partially
      constructed lbgroup.
    - When adding an inpcb to the group, increment the counter after adding
      the array entry, using a release store.  Otherwise it's possible for
      lookups to observe a null array slot.
    - When looking up an entry, use a corresponding acquire load.
    
    Reviewed by:    ae, glebius
    MFC after:      1 month
    Sponsored by:   Klara, Inc.
    Sponsored by:   Stormshield
    Differential Revision:  https://reviews.freebsd.org/D48020
---
 sys/netinet/in_pcb.c   | 94 ++++++++++++++++++++++++++++----------------------
 sys/netinet6/in6_pcb.c | 13 +++++--
 2 files changed, 63 insertions(+), 44 deletions(-)

diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c
index 87489e8f457c..cfe3fd65e032 100644
--- a/sys/netinet/in_pcb.c
+++ b/sys/netinet/in_pcb.c
@@ -253,9 +253,8 @@ static void in_pcbremhash(struct inpcb *);
  */
 
 static struct inpcblbgroup *
-in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, struct ucred *cred,
-    u_char vflag, uint16_t port, const union in_dependaddr *addr, int size,
-    uint8_t numa_domain)
+in_pcblbgroup_alloc(struct ucred *cred, u_char vflag, uint16_t port,
+    const union in_dependaddr *addr, int size, uint8_t numa_domain)
 {
 	struct inpcblbgroup *grp;
 	size_t bytes;
@@ -270,7 +269,6 @@ in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, struct ucred *cred,
 	grp->il_numa_domain = numa_domain;
 	grp->il_dependladdr = *addr;
 	grp->il_inpsiz = size;
-	CK_LIST_INSERT_HEAD(hdr, grp, il_list);
 	return (grp);
 }
 
@@ -292,6 +290,24 @@ in_pcblbgroup_free(struct inpcblbgroup *grp)
 	NET_EPOCH_CALL(in_pcblbgroup_free_deferred, &grp->il_epoch_ctx);
 }
 
+static void
+in_pcblbgroup_insert(struct inpcblbgroup *grp, struct inpcb *inp)
+{
+	KASSERT(grp->il_inpcnt < grp->il_inpsiz,
+	    ("invalid local group size %d and count %d", grp->il_inpsiz,
+	    grp->il_inpcnt));
+	INP_WLOCK_ASSERT(inp);
+
+	inp->inp_flags |= INP_INLBGROUP;
+	grp->il_inp[grp->il_inpcnt] = inp;
+
+	/*
+	 * Synchronize with in_pcblookup_lbgroup(): make sure that we don't
+	 * expose a null slot to the lookup path.
+	 */
+	atomic_store_rel_int(&grp->il_inpcnt, grp->il_inpcnt + 1);
+}
+
 static struct inpcblbgroup *
 in_pcblbgroup_resize(struct inpcblbgrouphead *hdr,
     struct inpcblbgroup *old_grp, int size)
@@ -299,7 +315,7 @@ in_pcblbgroup_resize(struct inpcblbgrouphead *hdr,
 	struct inpcblbgroup *grp;
 	int i;
 
-	grp = in_pcblbgroup_alloc(hdr, old_grp->il_cred, old_grp->il_vflag,
+	grp = in_pcblbgroup_alloc(old_grp->il_cred, old_grp->il_vflag,
 	    old_grp->il_lport, &old_grp->il_dependladdr, size,
 	    old_grp->il_numa_domain);
 	if (grp == NULL)
@@ -312,34 +328,11 @@ in_pcblbgroup_resize(struct inpcblbgrouphead *hdr,
 	for (i = 0; i < old_grp->il_inpcnt; ++i)
 		grp->il_inp[i] = old_grp->il_inp[i];
 	grp->il_inpcnt = old_grp->il_inpcnt;
+	CK_LIST_INSERT_HEAD(hdr, grp, il_list);
 	in_pcblbgroup_free(old_grp);
 	return (grp);
 }
 
-/*
- * PCB at index 'i' is removed from the group. Pull up the ones below il_inp[i]
- * and shrink group if possible.
- */
-static void
-in_pcblbgroup_reorder(struct inpcblbgrouphead *hdr, struct inpcblbgroup **grpp,
-    int i)
-{
-	struct inpcblbgroup *grp, *new_grp;
-
-	grp = *grpp;
-	for (; i + 1 < grp->il_inpcnt; ++i)
-		grp->il_inp[i] = grp->il_inp[i + 1];
-	grp->il_inpcnt--;
-
-	if (grp->il_inpsiz > INPCBLBGROUP_SIZMIN &&
-	    grp->il_inpcnt <= grp->il_inpsiz / 4) {
-		/* Shrink this group. */
-		new_grp = in_pcblbgroup_resize(hdr, grp, grp->il_inpsiz / 2);
-		if (new_grp != NULL)
-			*grpp = new_grp;
-	}
-}
-
 /*
  * Add PCB to load balance group for SO_REUSEPORT_LB option.
  */
@@ -384,11 +377,13 @@ in_pcbinslbgrouphash(struct inpcb *inp, uint8_t numa_domain)
 	}
 	if (grp == NULL) {
 		/* Create new load balance group. */
-		grp = in_pcblbgroup_alloc(hdr, inp->inp_cred, inp->inp_vflag,
+		grp = in_pcblbgroup_alloc(inp->inp_cred, inp->inp_vflag,
 		    inp->inp_lport, &inp->inp_inc.inc_ie.ie_dependladdr,
 		    INPCBLBGROUP_SIZMIN, numa_domain);
 		if (grp == NULL)
 			return (ENOBUFS);
+		in_pcblbgroup_insert(grp, inp);
+		CK_LIST_INSERT_HEAD(hdr, grp, il_list);
 	} else if (grp->il_inpcnt == grp->il_inpsiz) {
 		if (grp->il_inpsiz >= INPCBLBGROUP_SIZMAX) {
 			if (ratecheck(&lastprint, &interval))
@@ -401,15 +396,10 @@ in_pcbinslbgrouphash(struct inpcb *inp, uint8_t numa_domain)
 		grp = in_pcblbgroup_resize(hdr, grp, grp->il_inpsiz * 2);
 		if (grp == NULL)
 			return (ENOBUFS);
+		in_pcblbgroup_insert(grp, inp);
+	} else {
+		in_pcblbgroup_insert(grp, inp);
 	}
-
-	KASSERT(grp->il_inpcnt < grp->il_inpsiz,
-	    ("invalid local group size %d and count %d", grp->il_inpsiz,
-	    grp->il_inpcnt));
-
-	grp->il_inp[grp->il_inpcnt] = inp;
-	grp->il_inpcnt++;
-	inp->inp_flags |= INP_INLBGROUP;
 	return (0);
 }
 
@@ -441,8 +431,17 @@ in_pcbremlbgrouphash(struct inpcb *inp)
 				/* We are the last, free this local group. */
 				in_pcblbgroup_free(grp);
 			} else {
-				/* Pull up inpcbs, shrink group if possible. */
-				in_pcblbgroup_reorder(hdr, &grp, i);
+				KASSERT(grp->il_inpcnt >= 2,
+				    ("invalid local group count %d",
+				    grp->il_inpcnt));
+				grp->il_inp[i] =
+				    grp->il_inp[grp->il_inpcnt - 1];
+
+				/*
+				 * Synchronize with in_pcblookup_lbgroup().
+				 */
+				atomic_store_rel_int(&grp->il_inpcnt,
+				    grp->il_inpcnt - 1);
 			}
 			inp->inp_flags &= ~INP_INLBGROUP;
 			return;
@@ -2068,8 +2067,11 @@ in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
 	const struct inpcblbgrouphead *hdr;
 	struct inpcblbgroup *grp;
 	struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild;
+	struct inpcb *inp;
+	u_int count;
 
 	INP_HASH_LOCK_ASSERT(pcbinfo);
+	NET_EPOCH_ASSERT();
 
 	hdr = &pcbinfo->ipi_lbgrouphashbase[
 	    INP_PCBPORTHASH(lport, pcbinfo->ipi_lbgrouphashmask)];
@@ -2128,9 +2130,17 @@ in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
 		grp = local_wild;
 	if (grp == NULL)
 		return (NULL);
+
 out:
-	return (grp->il_inp[INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) %
-	    grp->il_inpcnt]);
+	/*
+	 * Synchronize with in_pcblbgroup_insert().
+	 */
+	count = atomic_load_acq_int(&grp->il_inpcnt);
+	if (count == 0)
+		return (NULL);
+	inp = grp->il_inp[INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) % count];
+	KASSERT(inp != NULL, ("%s: inp == NULL", __func__));
+	return (inp);
 }
 
 static bool
diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c
index 4f7c0912e675..ada5058e56b3 100644
--- a/sys/netinet6/in6_pcb.c
+++ b/sys/netinet6/in6_pcb.c
@@ -893,6 +893,8 @@ in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
 	const struct inpcblbgrouphead *hdr;
 	struct inpcblbgroup *grp;
 	struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild;
+	struct inpcb *inp;
+	u_int count;
 
 	INP_HASH_LOCK_ASSERT(pcbinfo);
 
@@ -954,8 +956,15 @@ in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
 	if (grp == NULL)
 		return (NULL);
 out:
-	return (grp->il_inp[INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) %
-	    grp->il_inpcnt]);
+	/*
+	 * Synchronize with in_pcblbgroup_insert().
+	 */
+	count = atomic_load_acq_int(&grp->il_inpcnt);
+	if (count == 0)
+		return (NULL);
+	inp = grp->il_inp[INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) % count];
+	KASSERT(inp != NULL, ("%s: inp == NULL", __func__));
+	return (inp);
 }
 
 static bool