git: d93ec8cb1324 - main - inpcb: Allow SO_REUSEPORT_LB to be used in jails

From: Mark Johnston <markj_at_FreeBSD.org>
Date: Wed, 02 Nov 2022 17:47:14 UTC
The branch main has been updated by markj:

URL: https://cgit.FreeBSD.org/src/commit/?id=d93ec8cb1324d04d7cae19fb7fa98ade2ff33c80

commit d93ec8cb1324d04d7cae19fb7fa98ade2ff33c80
Author:     Mark Johnston <markj@FreeBSD.org>
AuthorDate: 2022-11-02 17:08:07 +0000
Commit:     Mark Johnston <markj@FreeBSD.org>
CommitDate: 2022-11-02 17:46:24 +0000

    inpcb: Allow SO_REUSEPORT_LB to be used in jails
    
    Currently SO_REUSEPORT_LB silently does nothing when set by a jailed
    process.  It is trivial to support this option in VNET jails, but it's
    also useful in traditional jails.
    
    This patch enables LB groups in jails with the following semantics:
    - all PCBs in a group must belong to the same jail, and
    - PCB lookup prefers jailed groups to non-jailed groups.
    
    This is a straightforward extension of the semantics used for individual
    listening sockets.  One pre-existing quirk of the lbgroup implementation
    is that non-jailed lbgroups are searched before jailed listening
    sockets; that is preserved with this change.
    
    Discussed with: glebius
    MFC after:      1 month
    Sponsored by:   Modirum MDPay
    Sponsored by:   Klara, Inc.
    Differential Revision:  https://reviews.freebsd.org/D37029
---
 sys/netinet/in_pcb.c   | 126 ++++++++++++++++++++++++++++---------------------
 sys/netinet/in_pcb.h   |   3 +-
 sys/netinet6/in6_pcb.c |  99 +++++++++++++++++++++++---------------
 3 files changed, 136 insertions(+), 92 deletions(-)

diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c
index af3f8d8d9d4d..ea8bbea1b5ff 100644
--- a/sys/netinet/in_pcb.c
+++ b/sys/netinet/in_pcb.c
@@ -250,8 +250,8 @@ static void in_pcbremhash(struct inpcb *);
  */
 
 static struct inpcblbgroup *
-in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, u_char vflag,
-    uint16_t port, const union in_dependaddr *addr, int size,
+in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, struct ucred *cred,
+    u_char vflag, uint16_t port, const union in_dependaddr *addr, int size,
     uint8_t numa_domain)
 {
 	struct inpcblbgroup *grp;
@@ -259,8 +259,9 @@ in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, u_char vflag,
 
 	bytes = __offsetof(struct inpcblbgroup, il_inp[size]);
 	grp = malloc(bytes, M_PCB, M_ZERO | M_NOWAIT);
-	if (!grp)
+	if (grp == NULL)
 		return (NULL);
+	grp->il_cred = crhold(cred);
 	grp->il_vflag = vflag;
 	grp->il_lport = port;
 	grp->il_numa_domain = numa_domain;
@@ -276,6 +277,7 @@ in_pcblbgroup_free_deferred(epoch_context_t ctx)
 	struct inpcblbgroup *grp;
 
 	grp = __containerof(ctx, struct inpcblbgroup, il_epoch_ctx);
+	crfree(grp->il_cred);
 	free(grp, M_PCB);
 }
 
@@ -294,7 +296,7 @@ in_pcblbgroup_resize(struct inpcblbgrouphead *hdr,
 	struct inpcblbgroup *grp;
 	int i;
 
-	grp = in_pcblbgroup_alloc(hdr, old_grp->il_vflag,
+	grp = in_pcblbgroup_alloc(hdr, old_grp->il_cred, old_grp->il_vflag,
 	    old_grp->il_lport, &old_grp->il_dependladdr, size,
 	    old_grp->il_numa_domain);
 	if (grp == NULL)
@@ -353,12 +355,6 @@ in_pcbinslbgrouphash(struct inpcb *inp, uint8_t numa_domain)
 	INP_WLOCK_ASSERT(inp);
 	INP_HASH_WLOCK_ASSERT(pcbinfo);
 
-	/*
-	 * Don't allow jailed socket to join local group.
-	 */
-	if (inp->inp_socket != NULL && jailed(inp->inp_socket->so_cred))
-		return (0);
-
 #ifdef INET6
 	/*
 	 * Don't allow IPv4 mapped INET6 wild socket.
@@ -373,17 +369,19 @@ in_pcbinslbgrouphash(struct inpcb *inp, uint8_t numa_domain)
 	idx = INP_PCBPORTHASH(inp->inp_lport, pcbinfo->ipi_lbgrouphashmask);
 	hdr = &pcbinfo->ipi_lbgrouphashbase[idx];
 	CK_LIST_FOREACH(grp, hdr, il_list) {
-		if (grp->il_vflag == inp->inp_vflag &&
+		if (grp->il_cred->cr_prison == inp->inp_cred->cr_prison &&
+		    grp->il_vflag == inp->inp_vflag &&
 		    grp->il_lport == inp->inp_lport &&
 		    grp->il_numa_domain == numa_domain &&
 		    memcmp(&grp->il_dependladdr,
 		    &inp->inp_inc.inc_ie.ie_dependladdr,
-		    sizeof(grp->il_dependladdr)) == 0)
+		    sizeof(grp->il_dependladdr)) == 0) {
 			break;
+		}
 	}
 	if (grp == NULL) {
 		/* Create new load balance group. */
-		grp = in_pcblbgroup_alloc(hdr, inp->inp_vflag,
+		grp = in_pcblbgroup_alloc(hdr, inp->inp_cred, inp->inp_vflag,
 		    inp->inp_lport, &inp->inp_inc.inc_ie.ie_dependladdr,
 		    INPCBLBGROUP_SIZMIN, numa_domain);
 		if (grp == NULL)
@@ -2145,15 +2143,20 @@ in_pcblookup_local(struct inpcbinfo *pcbinfo, struct in_addr laddr,
 }
 #undef INP_LOOKUP_MAPPED_PCB_COST
 
+static bool
+in_pcblookup_lb_numa_match(const struct inpcblbgroup *grp, int domain)
+{
+	return (domain == M_NODOM || domain == grp->il_numa_domain);
+}
+
 static struct inpcb *
 in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
     const struct in_addr *laddr, uint16_t lport, const struct in_addr *faddr,
-    uint16_t fport, int lookupflags, int numa_domain)
+    uint16_t fport, int lookupflags, int domain)
 {
-	struct inpcb *local_wild, *numa_wild;
 	const struct inpcblbgrouphead *hdr;
 	struct inpcblbgroup *grp;
-	uint32_t idx;
+	struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild;
 
 	INP_HASH_LOCK_ASSERT(pcbinfo);
 
@@ -2161,17 +2164,15 @@ in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
 	    INP_PCBPORTHASH(lport, pcbinfo->ipi_lbgrouphashmask)];
 
 	/*
-	 * Order of socket selection:
-	 * 1. non-wild.
-	 * 2. wild (if lookupflags contains INPLOOKUP_WILDCARD).
-	 *
-	 * NOTE:
-	 * - Load balanced group does not contain jailed sockets
-	 * - Load balanced group does not contain IPv4 mapped INET6 wild sockets
+	 * Search for an LB group match based on the following criteria:
+	 * - prefer jailed groups to non-jailed groups
+	 * - prefer exact source address matches to wildcard matches
+	 * - prefer groups bound to the specified NUMA domain
 	 */
-	local_wild = NULL;
-	numa_wild = NULL;
+	jail_exact = jail_wild = local_exact = local_wild = NULL;
 	CK_LIST_FOREACH(grp, hdr, il_list) {
+		bool injail;
+
 #ifdef INET6
 		if (!(grp->il_vflag & INP_IPV4))
 			continue;
@@ -2179,27 +2180,47 @@ in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
 		if (grp->il_lport != lport)
 			continue;
 
-		idx = INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) %
-		    grp->il_inpcnt;
+		injail = prison_flag(grp->il_cred, PR_IP4) != 0;
+		if (injail && prison_check_ip4_locked(grp->il_cred->cr_prison,
+		    laddr) != 0)
+			continue;
+
 		if (grp->il_laddr.s_addr == laddr->s_addr) {
-			if (numa_domain == M_NODOM ||
-			    grp->il_numa_domain == numa_domain) {
-				return (grp->il_inp[idx]);
-			} else {
-				numa_wild = grp->il_inp[idx];
+			if (injail) {
+				jail_exact = grp;
+				if (in_pcblookup_lb_numa_match(grp, domain))
+					/* This is a perfect match. */
+					goto out;
+			} else if (local_exact == NULL ||
+			    in_pcblookup_lb_numa_match(grp, domain)) {
+				local_exact = grp;
+			}
+		} else if (grp->il_laddr.s_addr == INADDR_ANY &&
+		    (lookupflags & INPLOOKUP_WILDCARD) != 0) {
+			if (injail) {
+				if (jail_wild == NULL ||
+				    in_pcblookup_lb_numa_match(grp, domain))
+					jail_wild = grp;
+			} else if (local_wild == NULL ||
+			    in_pcblookup_lb_numa_match(grp, domain)) {
+				local_wild = grp;
 			}
-		}
-		if (grp->il_laddr.s_addr == INADDR_ANY &&
-		    (lookupflags & INPLOOKUP_WILDCARD) != 0 &&
-		    (local_wild == NULL || numa_domain == M_NODOM ||
-			grp->il_numa_domain == numa_domain)) {
-			local_wild = grp->il_inp[idx];
 		}
 	}
-	if (numa_wild != NULL)
-		return (numa_wild);
 
-	return (local_wild);
+	if (jail_exact != NULL)
+		grp = jail_exact;
+	else if (jail_wild != NULL)
+		grp = jail_wild;
+	else if (local_exact != NULL)
+		grp = local_exact;
+	else
+		grp = local_wild;
+	if (grp == NULL)
+		return (NULL);
+out:
+	return (grp->il_inp[INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) %
+	    grp->il_inpcnt]);
 }
 
 /*
@@ -2251,16 +2272,6 @@ in_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in_addr faddr,
 	if (tmpinp != NULL)
 		return (tmpinp);
 
-	/*
-	 * Then look in lb group (for wildcard match).
-	 */
-	if ((lookupflags & INPLOOKUP_WILDCARD) != 0) {
-		inp = in_pcblookup_lbgroup(pcbinfo, &laddr, lport, &faddr,
-		    fport, lookupflags, numa_domain);
-		if (inp != NULL)
-			return (inp);
-	}
-
 	/*
 	 * Then look for a wildcard match, if requested.
 	 */
@@ -2272,6 +2283,15 @@ in_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in_addr faddr,
 		struct inpcb *jail_wild = NULL;
 		int injail;
 
+		/*
+		 * First see if an LB group matches the request before scanning
+		 * all sockets on this port.
+		 */
+		inp = in_pcblookup_lbgroup(pcbinfo, &laddr, lport, &faddr,
+		    fport, lookupflags, numa_domain);
+		if (inp != NULL)
+			return (inp);
+
 		/*
 		 * Order of socket selection - we always prefer jails.
 		 *      1. jailed, non-wild.
@@ -2472,8 +2492,8 @@ in_pcbremhash(struct inpcb *inp)
 	MPASS(inp->inp_flags & INP_INHASHLIST);
 
 	INP_HASH_WLOCK(inp->inp_pcbinfo);
-	/* XXX: Only do if SO_REUSEPORT_LB set? */
-	in_pcbremlbgrouphash(inp);
+	if ((inp->inp_flags2 & INP_REUSEPORT_LB) != 0)
+		in_pcbremlbgrouphash(inp);
 	CK_LIST_REMOVE(inp, inp_hash);
 	CK_LIST_REMOVE(inp, inp_portlist);
 	if (CK_LIST_FIRST(&phd->phd_pcblist) == NULL) {
diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h
index 29a042942ead..e8e6fbd86708 100644
--- a/sys/netinet/in_pcb.h
+++ b/sys/netinet/in_pcb.h
@@ -500,9 +500,10 @@ SYSUNINIT(prot##_inpcbstorage_uninit, SI_SUB_PROTO_DOMAIN,		\
 struct inpcblbgroup {
 	CK_LIST_ENTRY(inpcblbgroup) il_list;
 	struct epoch_context il_epoch_ctx;
+	struct ucred	*il_cred;
 	uint16_t	il_lport;			/* (c) */
 	u_char		il_vflag;			/* (c) */
-	u_int8_t		il_numa_domain;
+	uint8_t		il_numa_domain;
 	uint32_t	il_pad2;
 	union in_dependaddr il_dependladdr;		/* (c) */
 #define	il_laddr	il_dependladdr.id46_addr.ia46_addr4
diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c
index 3a3c4043f749..3e829f00f51a 100644
--- a/sys/netinet6/in6_pcb.c
+++ b/sys/netinet6/in6_pcb.c
@@ -887,15 +887,20 @@ in6_rtchange(struct inpcb *inp, int errno __unused)
 	return inp;
 }
 
+static bool
+in6_pcblookup_lb_numa_match(const struct inpcblbgroup *grp, int domain)
+{
+	return (domain == M_NODOM || domain == grp->il_numa_domain);
+}
+
 static struct inpcb *
 in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
     const struct in6_addr *laddr, uint16_t lport, const struct in6_addr *faddr,
-    uint16_t fport, int lookupflags, uint8_t numa_domain)
+    uint16_t fport, int lookupflags, uint8_t domain)
 {
-	struct inpcb *local_wild, *numa_wild;
 	const struct inpcblbgrouphead *hdr;
 	struct inpcblbgroup *grp;
-	uint32_t idx;
+	struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild;
 
 	INP_HASH_LOCK_ASSERT(pcbinfo);
 
@@ -903,17 +908,15 @@ in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
 	    INP_PCBPORTHASH(lport, pcbinfo->ipi_lbgrouphashmask)];
 
 	/*
-	 * Order of socket selection:
-	 * 1. non-wild.
-	 * 2. wild (if lookupflags contains INPLOOKUP_WILDCARD).
-	 *
-	 * NOTE:
-	 * - Load balanced group does not contain jailed sockets.
-	 * - Load balanced does not contain IPv4 mapped INET6 wild sockets.
+	 * Search for an LB group match based on the following criteria:
+	 * - prefer jailed groups to non-jailed groups
+	 * - prefer exact source address matches to wildcard matches
+	 * - prefer groups bound to the specified NUMA domain 
 	 */
-	local_wild = NULL;
-	numa_wild = NULL;
+	jail_exact = jail_wild = local_exact = local_wild = NULL;
 	CK_LIST_FOREACH(grp, hdr, il_list) {
+		bool injail;
+
 #ifdef INET
 		if (!(grp->il_vflag & INP_IPV6))
 			continue;
@@ -921,26 +924,47 @@ in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo,
 		if (grp->il_lport != lport)
 			continue;
 
-		idx = INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) %
-		    grp->il_inpcnt;
+		injail = prison_flag(grp->il_cred, PR_IP6) != 0;
+		if (injail && prison_check_ip6_locked(grp->il_cred->cr_prison,
+		    laddr) != 0)
+			continue;
+
 		if (IN6_ARE_ADDR_EQUAL(&grp->il6_laddr, laddr)) {
-			if (numa_domain == M_NODOM ||
-			    grp->il_numa_domain == numa_domain) {
-				return (grp->il_inp[idx]);
+			if (injail) {
+				jail_exact = grp;
+				if (in6_pcblookup_lb_numa_match(grp, domain))
+					/* This is a perfect match. */
+					goto out;
+			} else if (local_exact == NULL ||
+			    in6_pcblookup_lb_numa_match(grp, domain)) {
+				local_exact = grp;
+			}
+		} else if (IN6_IS_ADDR_UNSPECIFIED(&grp->il6_laddr) &&
+		    (lookupflags & INPLOOKUP_WILDCARD) != 0) {
+			if (injail) {
+				if (jail_wild == NULL ||
+				    in6_pcblookup_lb_numa_match(grp, domain))
+					jail_wild = grp;
+			} else if (local_wild == NULL ||
+			    in6_pcblookup_lb_numa_match(grp, domain)) {
+				local_wild = grp;
 			}
-			else
-				numa_wild = grp->il_inp[idx];
-		}
-		if (IN6_IS_ADDR_UNSPECIFIED(&grp->il6_laddr) &&
-		    (lookupflags & INPLOOKUP_WILDCARD) != 0 &&
-		    (local_wild == NULL || numa_domain == M_NODOM ||
-			grp->il_numa_domain == numa_domain)) {
-			local_wild = grp->il_inp[idx];
 		}
 	}
-	if (numa_wild != NULL)
-		return (numa_wild);
-	return (local_wild);
+
+	if (jail_exact != NULL)
+		grp = jail_exact;
+	else if (jail_wild != NULL)
+		grp = jail_wild;
+	else if (local_exact != NULL)
+		grp = local_exact;
+	else
+		grp = local_wild;
+	if (grp == NULL)
+		return (NULL);
+out:
+	return (grp->il_inp[INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) %
+	    grp->il_inpcnt]);
 }
 
 /*
@@ -988,16 +1012,6 @@ in6_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
 	if (tmpinp != NULL)
 		return (tmpinp);
 
-	/*
-	 * Then look in lb group (for wildcard match).
-	 */
-	if ((lookupflags & INPLOOKUP_WILDCARD) != 0) {
-		inp = in6_pcblookup_lbgroup(pcbinfo, laddr, lport, faddr,
-		    fport, lookupflags, numa_domain);
-		if (inp != NULL)
-			return (inp);
-	}
-
 	/*
 	 * Then look for a wildcard match, if requested.
 	 */
@@ -1006,6 +1020,15 @@ in6_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
 		struct inpcb *jail_wild = NULL;
 		int injail;
 
+		/*
+		 * First see if an LB group matches the request before scanning
+		 * all sockets on this port.
+		 */
+		inp = in6_pcblookup_lbgroup(pcbinfo, laddr, lport, faddr,
+		    fport, lookupflags, numa_domain);
+		if (inp != NULL)
+			return (inp);
+
 		/*
 		 * Order of socket selection - we always prefer jails.
 		 *      1. jailed, non-wild.