git: 7b92493ab1d4 - main - inpcb: Avoid inp_cred dereferences in SMR-protected lookup

From: Mark Johnston <markj_at_FreeBSD.org>
Date: Thu, 20 Apr 2023 16:13:52 UTC
The branch main has been updated by markj:

URL: https://cgit.FreeBSD.org/src/commit/?id=7b92493ab1d464263ccdf4494b187edbe19864dc

commit 7b92493ab1d464263ccdf4494b187edbe19864dc
Author:     Mark Johnston <markj@FreeBSD.org>
AuthorDate: 2023-04-20 15:48:19 +0000
Commit:     Mark Johnston <markj@FreeBSD.org>
CommitDate: 2023-04-20 16:13:06 +0000

    inpcb: Avoid inp_cred dereferences in SMR-protected lookup
    
    The SMR-protected inpcb lookup algorithm currently has to check whether
    a matching inpcb belongs to a jail, in order to prioritize jailed
    bound sockets.  To do this it dereferences inp_cred, and for this to be
    safe, the inpcb's ucred reference can't be released until the UMA
    destructor is called, which will not happen within any bounded time
    period.
    
    Changing SMR to periodically recycle garbage is not trivial.  Instead,
    let's implement SMR-synchronized lookup without needing to dereference
    inp_cred.  This will allow the inpcb code to free the inp_cred reference
    immediately when a PCB is freed, ensuring that ucred (and thus jail)
    references are released promptly.
    
    Commit 220d89212943 ("inpcb: immediately return matching pcb on lookup")
    gets us part of the way there.  This patch goes further to handle
    lookups of unconnected sockets.  Here, the strategy is to maintain a
    well-defined order of items within a hash chain so that a wild lookup
    can simply return the first match and preserve existing semantics.  This
    makes insertion of listening sockets more complicated in order to make
    lookup simpler, which seems like the right tradeoff anyway given that
    bind() is already a fairly expensive operation and lookups are more
    common.
    
    In particular, when inserting an unconnected socket, in_pcbinshash() now
    keeps the following ordering within each wildcard hash chain:
    - jailed sockets before non-jailed sockets,
    - within each of those groups, sockets bound to a specific local address
      before sockets bound to the unspecified (wildcard) address.
    
    Most of the change adds a separate SMR-based lookup path for inpcb hash
    lookups.  When a match is found, we try to lock the inpcb and
    re-validate its connection info.  In the common case, this works well
    and we can simply return the inpcb.  If this fails, typically because
    something is concurrently modifying the inpcb, we go to the slow path,
    which performs a serialized lookup.
    
    Note that I did not touch lbgroup lookup, since there the credential
    reference is formally synchronized by net_epoch, not SMR.  Moreover,
    lbgroups are rarely allocated or freed.
    
    I think it is possible to simplify in_pcblookup_hash_wild_locked() now,
    but I didn't do it in this patch.
    
    Discussed with: glebius
    Tested by:      glebius
    Sponsored by:   Klara, Inc.
    Sponsored by:   Modirum MDPay
    Differential Revision:  https://reviews.freebsd.org/D38572
---
 sys/netinet/in_pcb.c     | 256 ++++++++++++++++++++++++++++++++++++++++++-----
 sys/netinet/in_pcb.h     |   1 +
 sys/netinet/in_pcb_var.h |   3 +
 sys/netinet6/in6_pcb.c   | 125 ++++++++++++++++++++---
 sys/netinet6/in6_pcb.h   |   5 +-
 5 files changed, 351 insertions(+), 39 deletions(-)

diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c
index e301c307a413..ea36d684a82b 100644
--- a/sys/netinet/in_pcb.c
+++ b/sys/netinet/in_pcb.c
@@ -222,7 +222,6 @@ SYSCTL_COUNTER_U64(_net_inet_ip_rl, OID_AUTO, newrl, CTLFLAG_RD,
    &rate_limit_new, "Total Rate limit new attempts");
 SYSCTL_COUNTER_U64(_net_inet_ip_rl, OID_AUTO, chgrl, CTLFLAG_RD,
    &rate_limit_chg, "Total Rate limited change attempts");
-
 #endif /* RATELIMIT */
 
 #endif /* INET */
@@ -1450,7 +1449,7 @@ in_pcbdetach(struct inpcb *inp)
  * here because SMR is a critical section.
  * In 99%+ cases inp_smr_lock() would obtain the lock immediately.
  */
-static inline void
+void
 inp_lock(struct inpcb *inp, const inp_lookup_t lock)
 {
 
@@ -1458,7 +1457,7 @@ inp_lock(struct inpcb *inp, const inp_lookup_t lock)
 	    rw_rlock(&inp->inp_lock) : rw_wlock(&inp->inp_lock);
 }
 
-static inline void
+void
 inp_unlock(struct inpcb *inp, const inp_lookup_t lock)
 {
 
@@ -1466,7 +1465,7 @@ inp_unlock(struct inpcb *inp, const inp_lookup_t lock)
 	    rw_runlock(&inp->inp_lock) : rw_wunlock(&inp->inp_lock);
 }
 
-static inline int
+int
 inp_trylock(struct inpcb *inp, const inp_lookup_t lock)
 {
 
@@ -1474,14 +1473,6 @@ inp_trylock(struct inpcb *inp, const inp_lookup_t lock)
 	    rw_try_rlock(&inp->inp_lock) : rw_try_wlock(&inp->inp_lock));
 }
 
-static inline bool
-in_pcbrele(struct inpcb *inp, const inp_lookup_t lock)
-{
-
-	return (lock == INPLOOKUP_RLOCKPCB ?
-	    in_pcbrele_rlocked(inp) : in_pcbrele_wlocked(inp));
-}
-
 static inline bool
 _inp_smr_lock(struct inpcb *inp, const inp_lookup_t lock, const int ignflags)
 {
@@ -1725,6 +1716,14 @@ in_pcbrele_wlocked(struct inpcb *inp)
 	return (true);
 }
 
+bool
+in_pcbrele(struct inpcb *inp, const inp_lookup_t lock)
+{
+
+	return (lock == INPLOOKUP_RLOCKPCB ?
+	    in_pcbrele_rlocked(inp) : in_pcbrele_wlocked(inp));
+}
+
 /*
  * Unconditionally schedule an inpcb to be freed by decrementing its
  * reference count, which should occur only after the inpcb has been detached
@@ -2239,6 +2238,64 @@ in_pcblookup_wild_match(const struct inpcb *inp, struct in_addr laddr,
 	return (INPLOOKUP_MATCH_NONE);
 }
 
+#define	INP_LOOKUP_AGAIN	((struct inpcb *)(uintptr_t)-1)
+
+/*
+ * Find an unconnected inpcb matching laddr:lport from within an SMR read
+ * section.  The wildcard hash chains are kept sorted by _in_pcbinshash_wild()
+ * so that the first match is also the highest-ranking one.
+ *
+ * Returns the locked inpcb, or NULL if nothing matched, in which case the
+ * caller is still inside the SMR section.  INP_LOOKUP_AGAIN means that the
+ * SMR section has been exited and the caller must retry with a serialized
+ * lookup.
+ */
+static struct inpcb *
+in_pcblookup_hash_wild_smr(struct inpcbinfo *pcbinfo, struct in_addr faddr,
+    u_short fport, struct in_addr laddr, u_short lport,
+    const inp_lookup_t lockflags)
+{
+	struct inpcbhead *head;
+	struct inpcb *inp;
+
+	KASSERT(SMR_ENTERED(pcbinfo->ipi_smr),
+	    ("%s: not in SMR read section", __func__));
+
+	head = &pcbinfo->ipi_hash_wild[INP_PCBHASH_WILD(lport,
+	    pcbinfo->ipi_hashmask)];
+	CK_LIST_FOREACH(inp, head, inp_hash_wild) {
+		inp_lookup_match_t match;
+
+		/* Unlocked pre-check; inp_cred must not be touched here. */
+		match = in_pcblookup_wild_match(inp, laddr, lport);
+		if (match == INPLOOKUP_MATCH_NONE)
+			continue;
+
+		if (__predict_true(inp_smr_lock(inp, lockflags))) {
+			/*
+			 * With the inpcb locked, its fields and credential are
+			 * stable.  Revalidate the match, and do not let a
+			 * jailed wildcard socket accept traffic for a local
+			 * address that its jail does not own.
+			 */
+			match = in_pcblookup_wild_match(inp, laddr, lport);
+			if (__predict_true(match != INPLOOKUP_MATCH_NONE &&
+			    prison_check_ip4_locked(inp->inp_cred->cr_prison,
+			    &laddr) == 0))
+				return (inp);
+			inp_unlock(inp, lockflags);
+		}
+
+		/*
+		 * The matching socket disappeared out from under us, or its
+		 * jail does not own laddr.  inp_smr_lock() has exited the SMR
+		 * section either way, so fall back to a serialized lookup.
+		 */
+		return (INP_LOOKUP_AGAIN);
+	}
+	return (NULL);
+}
+
 static struct inpcb *
 in_pcblookup_hash_wild_locked(struct inpcbinfo *pcbinfo, struct in_addr faddr,
     u_short fport, struct in_addr laddr, u_short lport)
@@ -2332,15 +2369,15 @@ in_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in_addr faddr,
 	    ("%s: invalid foreign address", __func__));
 	KASSERT(laddr.s_addr != INADDR_ANY,
 	    ("%s: invalid local address", __func__));
-	INP_HASH_LOCK_ASSERT(pcbinfo);
+	INP_HASH_WLOCK_ASSERT(pcbinfo);
 
 	inp = in_pcblookup_hash_exact(pcbinfo, faddr, fport, laddr, lport);
 	if (inp != NULL)
 		return (inp);
 
 	if ((lookupflags & INPLOOKUP_WILDCARD) != 0) {
-		inp = in_pcblookup_lbgroup(pcbinfo, &faddr, fport, &laddr,
-		    lport, numa_domain);
+		inp = in_pcblookup_lbgroup(pcbinfo, &faddr, fport,
+		    &laddr, lport, numa_domain);
 		if (inp == NULL) {
 			inp = in_pcblookup_hash_wild_locked(pcbinfo, faddr,
 			    fport, laddr, lport);
@@ -2351,11 +2388,61 @@ in_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in_addr faddr,
 }
 
+/*
+ * Serialized lookup: search the hash tables with the hash write lock held and
+ * return the matching inpcb, locked as requested by the INPLOOKUP_*LOCKPCB
+ * flag in lookupflags, or NULL if no live inpcb matches.  This is the
+ * fallback used when the SMR-protected lookup cannot produce a stable result.
+ */
 static struct inpcb *
-in_pcblookup_hash_smr(struct inpcbinfo *pcbinfo, struct in_addr faddr,
+in_pcblookup_hash(struct inpcbinfo *pcbinfo, struct in_addr faddr,
     u_int fport, struct in_addr laddr, u_int lport, int lookupflags,
     uint8_t numa_domain)
 {
 	struct inpcb *inp;
+	const inp_lookup_t lockflags = lookupflags & INPLOOKUP_LOCKMASK;
+
+	KASSERT((lookupflags & (INPLOOKUP_RLOCKPCB | INPLOOKUP_WLOCKPCB)) != 0,
+	    ("%s: LOCKPCB not set", __func__));
+
+	INP_HASH_WLOCK(pcbinfo);
+	inp = in_pcblookup_hash_locked(pcbinfo, faddr, fport, laddr, lport,
+	    lookupflags & ~INPLOOKUP_LOCKMASK, numa_domain);
+	if (inp != NULL && !inp_trylock(inp, lockflags)) {
+		in_pcbref(inp);
+		INP_HASH_WUNLOCK(pcbinfo);
+		inp_lock(inp, lockflags);
+		if (in_pcbrele(inp, lockflags)) {
+			/* XXX-MJ or retry until we get a negative match? */
+			inp = NULL;
+		} else if (__predict_false((inp->inp_flags & INP_FREED) != 0)) {
+			/*
+			 * Other references keep the inpcb alive, but it was
+			 * marked INP_FREED while we waited for its lock.
+			 * Don't hand a freed inpcb back to the caller.
+			 */
+			inp_unlock(inp, lockflags);
+			inp = NULL;
+		}
+	} else {
+		INP_HASH_WUNLOCK(pcbinfo);
+	}
+	return (inp);
+}
+
+/*
+ * SMR-protected lookup.  Try an exact (connected) match first, then, if
+ * INPLOOKUP_WILDCARD is set, a load-balancing group and finally the wildcard
+ * hash chain.  A candidate is returned only after it has been locked and
+ * revalidated; otherwise fall back to the serialized in_pcblookup_hash().
+ */
+static struct inpcb *
+in_pcblookup_hash_smr(struct inpcbinfo *pcbinfo, struct in_addr faddr,
+    u_int fport_arg, struct in_addr laddr, u_int lport_arg, int lookupflags,
+    uint8_t numa_domain)
+{
+	struct inpcb *inp;
+	const inp_lookup_t lockflags = lookupflags & INPLOOKUP_LOCKMASK;
+	const u_short fport = fport_arg, lport = lport_arg;
 
 	KASSERT((lookupflags & ~INPLOOKUP_MASK) == 0,
 	    ("%s: invalid lookup flags %d", __func__, lookupflags));
@@ -2363,13 +2429,49 @@ in_pcblookup_hash_smr(struct inpcbinfo *pcbinfo, struct in_addr faddr,
 	    ("%s: LOCKPCB not set", __func__));
 
 	smr_enter(pcbinfo->ipi_smr);
-	inp = in_pcblookup_hash_locked(pcbinfo, faddr, fport, laddr, lport,
-	    lookupflags & INPLOOKUP_WILDCARD, numa_domain);
+	inp = in_pcblookup_hash_exact(pcbinfo, faddr, fport, laddr, lport);
 	if (inp != NULL) {
-		if (__predict_false(inp_smr_lock(inp,
-		    (lookupflags & INPLOOKUP_LOCKMASK)) == false))
-			inp = NULL;
-	} else
+		if (__predict_true(inp_smr_lock(inp, lockflags))) {
+			/*
+			 * Revalidate the 4-tuple, the socket could have been
+			 * disconnected.
+			 */
+			if (__predict_true(in_pcblookup_exact_match(inp,
+			    faddr, fport, laddr, lport)))
+				return (inp);
+			inp_unlock(inp, lockflags);
+		}
+
+		/*
+		 * We failed to lock the inpcb, or its connection state changed
+		 * out from under us.  Fall back to a precise search.
+		 */
+		return (in_pcblookup_hash(pcbinfo, faddr, fport, laddr, lport,
+		    lookupflags, numa_domain));
+	}
+
+	if ((lookupflags & INPLOOKUP_WILDCARD) != 0) {
+		inp = in_pcblookup_lbgroup(pcbinfo, &faddr, fport,
+		    &laddr, lport, numa_domain);
+		if (inp != NULL) {
+			if (__predict_true(inp_smr_lock(inp, lockflags))) {
+				if (__predict_true(in_pcblookup_wild_match(inp,
+				    laddr, lport) != INPLOOKUP_MATCH_NONE))
+					return (inp);
+				inp_unlock(inp, lockflags);
+			}
+			inp = INP_LOOKUP_AGAIN;
+		} else {
+			inp = in_pcblookup_hash_wild_smr(pcbinfo, faddr, fport,
+			    laddr, lport, lockflags);
+		}
+		if (inp == INP_LOOKUP_AGAIN) {
+			return (in_pcblookup_hash(pcbinfo, faddr, fport, laddr,
+			    lport, lookupflags, numa_domain));
+		}
+	}
+
+	if (inp == NULL)
 		smr_exit(pcbinfo->ipi_smr);
 
 	return (inp);
@@ -2398,6 +2500,134 @@ in_pcblookup_mbuf(struct inpcbinfo *pcbinfo, struct in_addr faddr,
 }
 #endif /* INET */
 
+/*
+ * Return true if the prison of the PCB's credential has "flag" set.  Callers
+ * pass PR_IP4 or PR_IP6 to ask whether the PCB is jailed for that address
+ * family.  This dereferences inp_cred, so it is used only when inserting into
+ * the hash, with the hash write lock held, and not from SMR-protected lookups.
+ */
+static bool
+in_pcbjailed(const struct inpcb *inp, unsigned int flag)
+{
+	return (prison_flag(inp->inp_cred, flag) != 0);
+}
+
+/*
+ * Insert the PCB into a hash chain using ordering rules which ensure that
+ * in_pcblookup_hash_wild_*() always encounter the highest-ranking PCB first.
+ *
+ * Specifically, keep jailed PCBs in front of non-jailed PCBs, and keep PCBs
+ * with exact local addresses ahead of wildcard PCBs.
+ */
+static void
+_in_pcbinshash_wild(struct inpcbhead *pcbhash, struct inpcb *inp)
+{
+	struct inpcb *last;
+	bool bound, injail;
+
+	INP_HASH_WLOCK_ASSERT(inp->inp_pcbinfo);
+
+	last = NULL;
+	bound = inp->inp_laddr.s_addr != INADDR_ANY;
+	injail = in_pcbjailed(inp, PR_IP4);
+	if (!injail) {
+		/*
+		 * Skip over the jailed PCBs at the head of the chain; if there
+		 * is nothing else, append.
+		 */
+		CK_LIST_FOREACH(last, pcbhash, inp_hash_wild) {
+			if (!in_pcbjailed(last, PR_IP4))
+				break;
+			if (CK_LIST_NEXT(last, inp_hash_wild) == NULL) {
+				CK_LIST_INSERT_AFTER(last, inp, inp_hash_wild);
+				return;
+			}
+		}
+	} else if (!CK_LIST_EMPTY(pcbhash) &&
+	    !in_pcbjailed(CK_LIST_FIRST(pcbhash), PR_IP4)) {
+		/* The chain holds no jailed PCBs yet, so this one leads. */
+		CK_LIST_INSERT_HEAD(pcbhash, inp, inp_hash_wild);
+		return;
+	}
+	if (!bound) {
+		/*
+		 * Place a wildcard PCB after the bound PCBs of its own group.
+		 * A jailed PCB must stop at the first non-jailed PCB instead
+		 * of being appended after non-jailed ones.
+		 */
+		CK_LIST_FOREACH_FROM(last, pcbhash, inp_hash_wild) {
+			if (last->inp_laddr.s_addr == INADDR_ANY ||
+			    (injail && !in_pcbjailed(last, PR_IP4)))
+				break;
+			if (CK_LIST_NEXT(last, inp_hash_wild) == NULL) {
+				CK_LIST_INSERT_AFTER(last, inp, inp_hash_wild);
+				return;
+			}
+		}
+	}
+	if (last == NULL)
+		CK_LIST_INSERT_HEAD(pcbhash, inp, inp_hash_wild);
+	else
+		CK_LIST_INSERT_BEFORE(last, inp, inp_hash_wild);
+}
+
+#ifdef INET6
+/*
+ * See the comment above _in_pcbinshash_wild().
+ */
+static void
+_in6_pcbinshash_wild(struct inpcbhead *pcbhash, struct inpcb *inp)
+{
+	struct inpcb *last;
+	bool bound, injail;
+
+	INP_HASH_WLOCK_ASSERT(inp->inp_pcbinfo);
+
+	last = NULL;
+	bound = !IN6_IS_ADDR_UNSPECIFIED(&inp->in6p_laddr);
+	injail = in_pcbjailed(inp, PR_IP6);
+	if (!injail) {
+		/*
+		 * Skip over the jailed PCBs at the head of the chain; if there
+		 * is nothing else, append.
+		 */
+		CK_LIST_FOREACH(last, pcbhash, inp_hash_wild) {
+			if (!in_pcbjailed(last, PR_IP6))
+				break;
+			if (CK_LIST_NEXT(last, inp_hash_wild) == NULL) {
+				CK_LIST_INSERT_AFTER(last, inp, inp_hash_wild);
+				return;
+			}
+		}
+	} else if (!CK_LIST_EMPTY(pcbhash) &&
+	    !in_pcbjailed(CK_LIST_FIRST(pcbhash), PR_IP6)) {
+		/* The chain holds no jailed PCBs yet, so this one leads. */
+		CK_LIST_INSERT_HEAD(pcbhash, inp, inp_hash_wild);
+		return;
+	}
+	if (!bound) {
+		/*
+		 * Place a wildcard PCB after the bound PCBs of its own group.
+		 * A jailed PCB must stop at the first non-jailed PCB instead
+		 * of being appended after non-jailed ones.
+		 */
+		CK_LIST_FOREACH_FROM(last, pcbhash, inp_hash_wild) {
+			if (IN6_IS_ADDR_UNSPECIFIED(&last->in6p_laddr) ||
+			    (injail && !in_pcbjailed(last, PR_IP6)))
+				break;
+			if (CK_LIST_NEXT(last, inp_hash_wild) == NULL) {
+				CK_LIST_INSERT_AFTER(last, inp, inp_hash_wild);
+				return;
+			}
+		}
+	}
+	if (last == NULL)
+		CK_LIST_INSERT_HEAD(pcbhash, inp, inp_hash_wild);
+	else
+		CK_LIST_INSERT_BEFORE(last, inp, inp_hash_wild);
+}
+#endif
+
 /*
  * Insert PCB onto various hash lists.
  */
@@ -2484,8 +2686,14 @@ in_pcbinshash(struct inpcb *inp)
 
 	if (connected)
 		CK_LIST_INSERT_HEAD(pcbhash, inp, inp_hash_exact);
-	else
-		CK_LIST_INSERT_HEAD(pcbhash, inp, inp_hash_wild);
+	else {
+#ifdef INET6
+		if ((inp->inp_vflag & INP_IPV6) != 0)
+			_in6_pcbinshash_wild(pcbhash, inp);
+		else
+#endif
+			_in_pcbinshash_wild(pcbhash, inp);
+	}
 	inp->inp_flags |= INP_INHASHLIST;
 
 	return (0);
diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h
index 179d706381a7..984cb9e26561 100644
--- a/sys/netinet/in_pcb.h
+++ b/sys/netinet/in_pcb.h
@@ -764,6 +764,7 @@ void	in_pcbnotifyall(struct inpcbinfo *pcbinfo, struct in_addr,
 void	in_pcbref(struct inpcb *);
 void	in_pcbrehash(struct inpcb *);
 void	in_pcbremhash_locked(struct inpcb *);
+bool	in_pcbrele(struct inpcb *, inp_lookup_t);
 bool	in_pcbrele_rlocked(struct inpcb *);
 bool	in_pcbrele_wlocked(struct inpcb *);
 
diff --git a/sys/netinet/in_pcb_var.h b/sys/netinet/in_pcb_var.h
index 31214b6092f3..51fae58ea6a7 100644
--- a/sys/netinet/in_pcb_var.h
+++ b/sys/netinet/in_pcb_var.h
@@ -47,6 +47,9 @@
 VNET_DECLARE(uint32_t, in_pcbhashseed);
 #define	V_in_pcbhashseed	VNET(in_pcbhashseed)
 
+void	inp_lock(struct inpcb *inp, const inp_lookup_t lock);
+void	inp_unlock(struct inpcb *inp, const inp_lookup_t lock);
+int	inp_trylock(struct inpcb *inp, const inp_lookup_t lock);
 bool	inp_smr_lock(struct inpcb *, const inp_lookup_t);
 int	in_pcb_lport(struct inpcb *, struct in_addr *, u_short *,
 	    struct ucred *, int);
diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c
index 508d264809e3..da7ed5ca79e0 100644
--- a/sys/netinet6/in6_pcb.c
+++ b/sys/netinet6/in6_pcb.c
@@ -954,8 +954,9 @@ in6_pcblookup_exact_match(const struct inpcb *inp, const struct in6_addr *faddr,
 }
 
 static struct inpcb *
-in6_pcblookup_hash_exact(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
-    u_short fport, struct in6_addr *laddr, u_short lport)
+in6_pcblookup_hash_exact(struct inpcbinfo *pcbinfo,
+    const struct in6_addr *faddr, u_short fport,
+    const struct in6_addr *laddr, u_short lport)
 {
 	struct inpcbhead *head;
 	struct inpcb *inp;
@@ -997,9 +998,67 @@ in6_pcblookup_wild_match(const struct inpcb *inp, const struct in6_addr *laddr,
 	return (INPLOOKUP_MATCH_NONE);
 }
 
+#define	INP_LOOKUP_AGAIN	((struct inpcb *)(uintptr_t)-1)
+
+/*
+ * Find an unconnected inpcb matching laddr:lport from within an SMR read
+ * section.  The wildcard hash chains are kept sorted at insertion time so
+ * that the first match is also the highest-ranking one.
+ *
+ * Returns the locked inpcb, or NULL if nothing matched, in which case the
+ * caller is still inside the SMR section.  INP_LOOKUP_AGAIN means that the
+ * SMR section has been exited and the caller must retry with a serialized
+ * lookup.
+ */
+static struct inpcb *
+in6_pcblookup_hash_wild_smr(struct inpcbinfo *pcbinfo,
+    const struct in6_addr *faddr, u_short fport, const struct in6_addr *laddr,
+    u_short lport, const inp_lookup_t lockflags)
+{
+	struct inpcbhead *head;
+	struct inpcb *inp;
+
+	KASSERT(SMR_ENTERED(pcbinfo->ipi_smr),
+	    ("%s: not in SMR read section", __func__));
+
+	head = &pcbinfo->ipi_hash_wild[INP_PCBHASH_WILD(lport,
+	    pcbinfo->ipi_hashmask)];
+	CK_LIST_FOREACH(inp, head, inp_hash_wild) {
+		inp_lookup_match_t match;
+
+		/* Unlocked pre-check; inp_cred must not be touched here. */
+		match = in6_pcblookup_wild_match(inp, laddr, lport);
+		if (match == INPLOOKUP_MATCH_NONE)
+			continue;
+
+		if (__predict_true(inp_smr_lock(inp, lockflags))) {
+			/*
+			 * With the inpcb locked, its fields and credential are
+			 * stable.  Revalidate the match, and do not let a
+			 * jailed wildcard socket accept traffic for a local
+			 * address that its jail does not own.
+			 */
+			match = in6_pcblookup_wild_match(inp, laddr, lport);
+			if (__predict_true(match != INPLOOKUP_MATCH_NONE &&
+			    prison_check_ip6_locked(inp->inp_cred->cr_prison,
+			    laddr) == 0))
+				return (inp);
+			inp_unlock(inp, lockflags);
+		}
+
+		/*
+		 * The matching socket disappeared out from under us, or its
+		 * jail does not own laddr.  inp_smr_lock() has exited the SMR
+		 * section either way, so fall back to a serialized lookup.
+		 */
+		return (INP_LOOKUP_AGAIN);
+	}
+	return (NULL);
+}
+
 static struct inpcb *
 in6_pcblookup_hash_wild_locked(struct inpcbinfo *pcbinfo,
-    struct in6_addr *faddr, u_short fport, struct in6_addr *laddr,
+    const struct in6_addr *faddr, u_short fport, const struct in6_addr *laddr,
     u_short lport)
 {
 	struct inpcbhead *head;
@@ -1058,8 +1097,9 @@ in6_pcblookup_hash_wild_locked(struct inpcbinfo *pcbinfo,
 }
 
 struct inpcb *
-in6_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
-    u_int fport_arg, struct in6_addr *laddr, u_int lport_arg,
+in6_pcblookup_hash_locked(struct inpcbinfo *pcbinfo,
+    const struct in6_addr *faddr, u_int fport_arg,
+    const struct in6_addr *laddr, u_int lport_arg,
     int lookupflags, uint8_t numa_domain)
 {
 	struct inpcb *inp;
@@ -1071,7 +1111,6 @@ in6_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
 	    ("%s: invalid foreign address", __func__));
 	KASSERT(!IN6_IS_ADDR_UNSPECIFIED(laddr),
 	    ("%s: invalid local address", __func__));
-
 	INP_HASH_LOCK_ASSERT(pcbinfo);
 
 	inp = in6_pcblookup_hash_exact(pcbinfo, faddr, fport, laddr, lport);
@@ -1089,12 +1128,62 @@ in6_pcblookup_hash_locked(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
 	return (inp);
 }
 
+/*
+ * Serialized lookup: search the hash tables with the hash write lock held and
+ * return the matching inpcb, locked as requested by the INPLOOKUP_*LOCKPCB
+ * flag in lookupflags, or NULL if no live inpcb matches.  This is the
+ * fallback used when the SMR-protected lookup cannot produce a stable result.
+ */
+static struct inpcb *
+in6_pcblookup_hash(struct inpcbinfo *pcbinfo, const struct in6_addr *faddr,
+    u_int fport, const struct in6_addr *laddr, u_int lport, int lookupflags,
+    uint8_t numa_domain)
+{
+	struct inpcb *inp;
+	const inp_lookup_t lockflags = lookupflags & INPLOOKUP_LOCKMASK;
+
+	KASSERT((lookupflags & (INPLOOKUP_RLOCKPCB | INPLOOKUP_WLOCKPCB)) != 0,
+	    ("%s: LOCKPCB not set", __func__));
+
+	INP_HASH_WLOCK(pcbinfo);
+	inp = in6_pcblookup_hash_locked(pcbinfo, faddr, fport, laddr, lport,
+	    lookupflags & ~INPLOOKUP_LOCKMASK, numa_domain);
+	if (inp != NULL && !inp_trylock(inp, lockflags)) {
+		in_pcbref(inp);
+		INP_HASH_WUNLOCK(pcbinfo);
+		inp_lock(inp, lockflags);
+		if (in_pcbrele(inp, lockflags)) {
+			/* XXX-MJ or retry until we get a negative match? */
+			inp = NULL;
+		} else if (__predict_false((inp->inp_flags & INP_FREED) != 0)) {
+			/*
+			 * Other references keep the inpcb alive, but it was
+			 * marked INP_FREED while we waited for its lock.
+			 * Don't hand a freed inpcb back to the caller.
+			 */
+			inp_unlock(inp, lockflags);
+			inp = NULL;
+		}
+	} else {
+		INP_HASH_WUNLOCK(pcbinfo);
+	}
+	return (inp);
+}
+
+/*
+ * SMR-protected lookup.  Try an exact (connected) match first, then, if
+ * INPLOOKUP_WILDCARD is set, a load-balancing group and finally the wildcard
+ * hash chain.  A candidate is returned only after it has been locked and
+ * revalidated; otherwise fall back to the serialized in6_pcblookup_hash().
+ */
 static struct inpcb *
 in6_pcblookup_hash_smr(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
-    u_int fport, struct in6_addr *laddr, u_int lport, int lookupflags,
+    u_int fport_arg, struct in6_addr *laddr, u_int lport_arg, int lookupflags,
     uint8_t numa_domain)
 {
 	struct inpcb *inp;
+	const inp_lookup_t lockflags = lookupflags & INPLOOKUP_LOCKMASK;
+	const u_short fport = fport_arg, lport = lport_arg;
 
 	KASSERT((lookupflags & ~INPLOOKUP_MASK) == 0,
 	    ("%s: invalid lookup flags %d", __func__, lookupflags));
@@ -1102,13 +1170,44 @@ in6_pcblookup_hash_smr(struct inpcbinfo *pcbinfo, struct in6_addr *faddr,
 	    ("%s: LOCKPCB not set", __func__));
 
 	smr_enter(pcbinfo->ipi_smr);
-	inp = in6_pcblookup_hash_locked(pcbinfo, faddr, fport, laddr, lport,
-	    lookupflags & INPLOOKUP_WILDCARD, numa_domain);
+	inp = in6_pcblookup_hash_exact(pcbinfo, faddr, fport, laddr, lport);
 	if (inp != NULL) {
-		if (__predict_false(inp_smr_lock(inp,
-		    (lookupflags & INPLOOKUP_LOCKMASK)) == false))
-			inp = NULL;
-	} else
+		if (__predict_true(inp_smr_lock(inp, lockflags))) {
+			if (__predict_true(in6_pcblookup_exact_match(inp,
+			    faddr, fport, laddr, lport)))
+				return (inp);
+			inp_unlock(inp, lockflags);
+		}
+		/*
+		 * We failed to lock the inpcb, or its connection state changed
+		 * out from under us.  Fall back to a precise search.
+		 */
+		return (in6_pcblookup_hash(pcbinfo, faddr, fport, laddr, lport,
+		    lookupflags, numa_domain));
+	}
+
+	if ((lookupflags & INPLOOKUP_WILDCARD) != 0) {
+		inp = in6_pcblookup_lbgroup(pcbinfo, faddr, fport,
+		    laddr, lport, numa_domain);
+		if (inp != NULL) {
+			if (__predict_true(inp_smr_lock(inp, lockflags))) {
+				if (__predict_true(in6_pcblookup_wild_match(inp,
+				    laddr, lport) != INPLOOKUP_MATCH_NONE))
+					return (inp);
+				inp_unlock(inp, lockflags);
+			}
+			inp = INP_LOOKUP_AGAIN;
+		} else {
+			inp = in6_pcblookup_hash_wild_smr(pcbinfo, faddr, fport,
+			    laddr, lport, lockflags);
+		}
+		if (inp == INP_LOOKUP_AGAIN) {
+			return (in6_pcblookup_hash(pcbinfo, faddr, fport, laddr,
+			    lport, lookupflags, numa_domain));
+		}
+	}
+
+	if (inp == NULL)
 		smr_exit(pcbinfo->ipi_smr);
 
 	return (inp);
diff --git a/sys/netinet6/in6_pcb.h b/sys/netinet6/in6_pcb.h
index 91131d1968bc..92cec00bce3b 100644
--- a/sys/netinet6/in6_pcb.h
+++ b/sys/netinet6/in6_pcb.h
@@ -83,8 +83,9 @@ struct	inpcb *
 				 struct ucred *);
 struct inpcb *
 	in6_pcblookup_hash_locked(struct inpcbinfo *pcbinfo,
-	    struct in6_addr *faddr, u_int fport_arg, struct in6_addr *laddr,
-	    u_int lport_arg, int lookupflags, uint8_t);
+	    const struct in6_addr *faddr, u_int fport_arg,
+	    const struct in6_addr *laddr, u_int lport_arg,
+	    int lookupflags, uint8_t);
 struct	inpcb *
 	in6_pcblookup(struct inpcbinfo *, struct in6_addr *,
 			   u_int, struct in6_addr *, u_int, int,