git: a600aabe9b04 - main - inpcb: Close some SO_REUSEPORT_LB races
- Go to: [ bottom of page ] [ top of archives ] [ this month ]
Date: Thu, 12 Dec 2024 14:27:33 UTC
The branch main has been updated by markj:

URL: https://cgit.FreeBSD.org/src/commit/?id=a600aabe9b04f0906069a8fb1f8d696ad186080f

commit a600aabe9b04f0906069a8fb1f8d696ad186080f
Author: Mark Johnston <markj@FreeBSD.org>
AuthorDate: 2024-12-12 14:02:12 +0000
Commit: Mark Johnston <markj@FreeBSD.org>
CommitDate: 2024-12-12 14:02:12 +0000

inpcb: Close some SO_REUSEPORT_LB races

For a long time, the inpcb lookup path has been lockless in the common case: we use net_epoch to synchronize lookups. However, the routines which update lbgroups were not careful to synchronize with unlocked lookups. I believe that in the worst case this can result in spurious connection aborts (I have a regression test case to exercise this), but it's hard to be certain.

Modify in_pcblbgroup* routines to synchronize with unlocked lookup:
- When removing inpcbs from an lbgroup, do not shrink the array. The maximum number of lbgroup entries is INPCBLBGROUP_SIZMAX (256), and it doesn't seem worth the complexity to shrink the array when a socket is removed.
- When resizing an lbgroup, do not insert it into the hash table until it is fully initialized; otherwise lookups may observe a partially constructed lbgroup.
- When adding an inpcb to the group, increment the counter after adding the array entry, using a release store. Otherwise it's possible for lookups to observe a null array slot.
- When looking up an entry, use a corresponding acquire load.

Reviewed by: ae, glebius
MFC after: 1 month
Sponsored by: Klara, Inc.
Sponsored by: Stormshield Differential Revision: https://reviews.freebsd.org/D48020 --- sys/netinet/in_pcb.c | 94 ++++++++++++++++++++++++++++---------------------- sys/netinet6/in6_pcb.c | 13 +++++-- 2 files changed, 63 insertions(+), 44 deletions(-) diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c index 87489e8f457c..cfe3fd65e032 100644 --- a/sys/netinet/in_pcb.c +++ b/sys/netinet/in_pcb.c @@ -253,9 +253,8 @@ static void in_pcbremhash(struct inpcb *); */ static struct inpcblbgroup * -in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, struct ucred *cred, - u_char vflag, uint16_t port, const union in_dependaddr *addr, int size, - uint8_t numa_domain) +in_pcblbgroup_alloc(struct ucred *cred, u_char vflag, uint16_t port, + const union in_dependaddr *addr, int size, uint8_t numa_domain) { struct inpcblbgroup *grp; size_t bytes; @@ -270,7 +269,6 @@ in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, struct ucred *cred, grp->il_numa_domain = numa_domain; grp->il_dependladdr = *addr; grp->il_inpsiz = size; - CK_LIST_INSERT_HEAD(hdr, grp, il_list); return (grp); } @@ -292,6 +290,24 @@ in_pcblbgroup_free(struct inpcblbgroup *grp) NET_EPOCH_CALL(in_pcblbgroup_free_deferred, &grp->il_epoch_ctx); } +static void +in_pcblbgroup_insert(struct inpcblbgroup *grp, struct inpcb *inp) +{ + KASSERT(grp->il_inpcnt < grp->il_inpsiz, + ("invalid local group size %d and count %d", grp->il_inpsiz, + grp->il_inpcnt)); + INP_WLOCK_ASSERT(inp); + + inp->inp_flags |= INP_INLBGROUP; + grp->il_inp[grp->il_inpcnt] = inp; + + /* + * Synchronize with in_pcblookup_lbgroup(): make sure that we don't + * expose a null slot to the lookup path. 
+ */ + atomic_store_rel_int(&grp->il_inpcnt, grp->il_inpcnt + 1); +} + static struct inpcblbgroup * in_pcblbgroup_resize(struct inpcblbgrouphead *hdr, struct inpcblbgroup *old_grp, int size) @@ -299,7 +315,7 @@ in_pcblbgroup_resize(struct inpcblbgrouphead *hdr, struct inpcblbgroup *grp; int i; - grp = in_pcblbgroup_alloc(hdr, old_grp->il_cred, old_grp->il_vflag, + grp = in_pcblbgroup_alloc(old_grp->il_cred, old_grp->il_vflag, old_grp->il_lport, &old_grp->il_dependladdr, size, old_grp->il_numa_domain); if (grp == NULL) @@ -312,34 +328,11 @@ in_pcblbgroup_resize(struct inpcblbgrouphead *hdr, for (i = 0; i < old_grp->il_inpcnt; ++i) grp->il_inp[i] = old_grp->il_inp[i]; grp->il_inpcnt = old_grp->il_inpcnt; + CK_LIST_INSERT_HEAD(hdr, grp, il_list); in_pcblbgroup_free(old_grp); return (grp); } -/* - * PCB at index 'i' is removed from the group. Pull up the ones below il_inp[i] - * and shrink group if possible. - */ -static void -in_pcblbgroup_reorder(struct inpcblbgrouphead *hdr, struct inpcblbgroup **grpp, - int i) -{ - struct inpcblbgroup *grp, *new_grp; - - grp = *grpp; - for (; i + 1 < grp->il_inpcnt; ++i) - grp->il_inp[i] = grp->il_inp[i + 1]; - grp->il_inpcnt--; - - if (grp->il_inpsiz > INPCBLBGROUP_SIZMIN && - grp->il_inpcnt <= grp->il_inpsiz / 4) { - /* Shrink this group. */ - new_grp = in_pcblbgroup_resize(hdr, grp, grp->il_inpsiz / 2); - if (new_grp != NULL) - *grpp = new_grp; - } -} - /* * Add PCB to load balance group for SO_REUSEPORT_LB option. */ @@ -384,11 +377,13 @@ in_pcbinslbgrouphash(struct inpcb *inp, uint8_t numa_domain) } if (grp == NULL) { /* Create new load balance group. 
*/ - grp = in_pcblbgroup_alloc(hdr, inp->inp_cred, inp->inp_vflag, + grp = in_pcblbgroup_alloc(inp->inp_cred, inp->inp_vflag, inp->inp_lport, &inp->inp_inc.inc_ie.ie_dependladdr, INPCBLBGROUP_SIZMIN, numa_domain); if (grp == NULL) return (ENOBUFS); + in_pcblbgroup_insert(grp, inp); + CK_LIST_INSERT_HEAD(hdr, grp, il_list); } else if (grp->il_inpcnt == grp->il_inpsiz) { if (grp->il_inpsiz >= INPCBLBGROUP_SIZMAX) { if (ratecheck(&lastprint, &interval)) @@ -401,15 +396,10 @@ in_pcbinslbgrouphash(struct inpcb *inp, uint8_t numa_domain) grp = in_pcblbgroup_resize(hdr, grp, grp->il_inpsiz * 2); if (grp == NULL) return (ENOBUFS); + in_pcblbgroup_insert(grp, inp); + } else { + in_pcblbgroup_insert(grp, inp); } - - KASSERT(grp->il_inpcnt < grp->il_inpsiz, - ("invalid local group size %d and count %d", grp->il_inpsiz, - grp->il_inpcnt)); - - grp->il_inp[grp->il_inpcnt] = inp; - grp->il_inpcnt++; - inp->inp_flags |= INP_INLBGROUP; return (0); } @@ -441,8 +431,17 @@ in_pcbremlbgrouphash(struct inpcb *inp) /* We are the last, free this local group. */ in_pcblbgroup_free(grp); } else { - /* Pull up inpcbs, shrink group if possible. */ - in_pcblbgroup_reorder(hdr, &grp, i); + KASSERT(grp->il_inpcnt >= 2, + ("invalid local group count %d", + grp->il_inpcnt)); + grp->il_inp[i] = + grp->il_inp[grp->il_inpcnt - 1]; + + /* + * Synchronize with in_pcblookup_lbgroup(). 
+ */ + atomic_store_rel_int(&grp->il_inpcnt, + grp->il_inpcnt - 1); } inp->inp_flags &= ~INP_INLBGROUP; return; @@ -2068,8 +2067,11 @@ in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo, const struct inpcblbgrouphead *hdr; struct inpcblbgroup *grp; struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild; + struct inpcb *inp; + u_int count; INP_HASH_LOCK_ASSERT(pcbinfo); + NET_EPOCH_ASSERT(); hdr = &pcbinfo->ipi_lbgrouphashbase[ INP_PCBPORTHASH(lport, pcbinfo->ipi_lbgrouphashmask)]; @@ -2128,9 +2130,17 @@ in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo, grp = local_wild; if (grp == NULL) return (NULL); + out: - return (grp->il_inp[INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) % - grp->il_inpcnt]); + /* + * Synchronize with in_pcblbgroup_insert(). + */ + count = atomic_load_acq_int(&grp->il_inpcnt); + if (count == 0) + return (NULL); + inp = grp->il_inp[INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) % count]; + KASSERT(inp != NULL, ("%s: inp == NULL", __func__)); + return (inp); } static bool diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c index 4f7c0912e675..ada5058e56b3 100644 --- a/sys/netinet6/in6_pcb.c +++ b/sys/netinet6/in6_pcb.c @@ -893,6 +893,8 @@ in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo, const struct inpcblbgrouphead *hdr; struct inpcblbgroup *grp; struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild; + struct inpcb *inp; + u_int count; INP_HASH_LOCK_ASSERT(pcbinfo); @@ -954,8 +956,15 @@ in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo, if (grp == NULL) return (NULL); out: - return (grp->il_inp[INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) % - grp->il_inpcnt]); + /* + * Synchronize with in_pcblbgroup_insert(). + */ + count = atomic_load_acq_int(&grp->il_inpcnt); + if (count == 0) + return (NULL); + inp = grp->il_inp[INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) % count]; + KASSERT(inp != NULL, ("%s: inp == NULL", __func__)); + return (inp); } static bool