git: b9c0321d54e9 - main - pf: Fix source node locking

From: Kajetan Staszkiewicz <ks_at_FreeBSD.org>
Date: Thu, 28 Nov 2024 17:34:16 UTC
The branch main has been updated by ks:

URL: https://cgit.FreeBSD.org/src/commit/?id=b9c0321d54e96d0a6591e9c609c7581916d3ddd3

commit b9c0321d54e96d0a6591e9c609c7581916d3ddd3
Author:     Kajetan Staszkiewicz <ks@FreeBSD.org>
AuthorDate: 2024-11-23 21:21:22 +0000
Commit:     Kajetan Staszkiewicz <ks@FreeBSD.org>
CommitDate: 2024-11-28 17:31:55 +0000

    pf: Fix source node locking
    
    Source nodes are created quite early in pf_create_state(), even before
    the state is allocated, locked and inserted into its hash row. They are
    prone to being freed by source node killing or clearing ioctl while
    pf_create_state() is still running.
    
    The function pf_map_addr_sn() can be called in two very different paths.
    
    One is for filter rules where it is called from
    pf_create_state() after pf_insert_src_node(). In this case it is called
    with a given source node and does not perform its own search and must
    return the source node.
    
    The other one is for NAT rules where it is called from
    pf_get_translation() or its descendants. In this case it is called with
    no known source node and performs its own search for source nodes. This
    source node is then passed back to pf_create_state() without locking.
    
    The states counter of a source node is incremented in pf_find_src_node(),
    which allows the counter to increase when a packet matches a NAT rule
    but not a pass keep state rule.
    
    The function pf_map_addr() operates on an unlocked source node.
    
    Modify pf_find_src_node() to return locked when a source node is found,
    so that any subsequent operations can access the source node safely.
    
    Move sn->states++ counter increase to pf_insert_src_node() to ensure
    that it's called only from pf_create_state() and not from NAT ruleset
    path, and have it increased only if the source node has really been
    inserted or found, simplifying the cleanup.
    
    Add locking in pf_src_connlimit() and pf_map_addr(). Sprinkle mutex
    assertions in pf_map_addr().
    
    Add a function pf_src_node_exists() to check that a known source node is
    still valid. Use it in pf_create_state() where it's impossible to hold
    locks from pf_insert_src_node() because that would cause LoR (nodes
    first, then state) against pf_src_connlimit() (state first, then node).
    
    Don't propagate the source node found while parsing the NAT ruleset to
    pf_create_state() because it must be found again and locked or created.
    
    Reviewed by:            kp
    Approved by:            kp (mentor)
    Sponsored by:           InnoGames GmbH
    Differential Revision:  https://reviews.freebsd.org/D47770
---
 sys/net/pfvar.h                   |   8 ++-
 sys/netpfil/pf/pf.c               | 148 ++++++++++++++++++++++----------------
 sys/netpfil/pf/pf_lb.c            |  66 ++++++++++-------
 tests/sys/netpfil/pf/src_track.sh |  15 ++--
 4 files changed, 143 insertions(+), 94 deletions(-)

diff --git a/sys/net/pfvar.h b/sys/net/pfvar.h
index e00101ba2b78..51f525c7383b 100644
--- a/sys/net/pfvar.h
+++ b/sys/net/pfvar.h
@@ -2334,6 +2334,9 @@ extern int			 pf_udp_mapping_insert(struct pf_udp_mapping
 				    *mapping);
 extern void			 pf_udp_mapping_release(struct pf_udp_mapping
 				    *mapping);
+uint32_t			 pf_hashsrc(struct pf_addr *, sa_family_t);
+extern bool			 pf_src_node_exists(struct pf_ksrc_node **,
+				    struct pf_srchash *);
 extern struct pf_ksrc_node	*pf_find_src_node(struct pf_addr *,
 				    struct pf_krule *, sa_family_t,
 				    struct pf_srchash **, bool);
@@ -2622,10 +2625,9 @@ u_short			 pf_map_addr(u_int8_t, struct pf_krule *,
 u_short			 pf_map_addr_sn(u_int8_t, struct pf_krule *,
 			    struct pf_addr *, struct pf_addr *,
 			    struct pfi_kkif **nkif, struct pf_addr *,
-			    struct pf_ksrc_node **);
+			    struct pf_ksrc_node **, struct pf_srchash **);
 u_short			 pf_get_translation(struct pf_pdesc *,
-			    int, struct pf_ksrc_node **,
-			    struct pf_state_key **, struct pf_state_key **,
+			    int, struct pf_state_key **, struct pf_state_key **,
 			    struct pf_addr *, struct pf_addr *,
 			    uint16_t, uint16_t, struct pf_kanchor_stackframe *,
 			    struct pf_krule **,
diff --git a/sys/netpfil/pf/pf.c b/sys/netpfil/pf/pf.c
index 9436a4247411..9f8fec51e420 100644
--- a/sys/netpfil/pf/pf.c
+++ b/sys/netpfil/pf/pf.c
@@ -332,8 +332,7 @@ static int		 pf_test_rule(struct pf_krule **, struct pf_kstate **,
 			    struct pf_kruleset **, struct inpcb *);
 static int		 pf_create_state(struct pf_krule *, struct pf_krule *,
 			    struct pf_krule *, struct pf_pdesc *,
-			    struct pf_ksrc_node *, struct pf_state_key *,
-			    struct pf_state_key *,
+			    struct pf_state_key *, struct pf_state_key *,
 			    u_int16_t, u_int16_t, int *,
 			    struct pf_kstate **, int, u_int16_t, u_int16_t,
 			    struct pf_krule_slist *, struct pf_udp_mapping *);
@@ -372,14 +371,15 @@ static void		 pf_patch_8(struct mbuf *, u_int16_t *, u_int8_t *, u_int8_t,
 			    bool, u_int8_t);
 static struct pf_kstate	*pf_find_state(struct pfi_kkif *,
 			    const struct pf_state_key_cmp *, u_int);
-static int		 pf_src_connlimit(struct pf_kstate *);
+static bool		 pf_src_connlimit(struct pf_kstate *);
 static int		 pf_match_rcvif(struct mbuf *, struct pf_krule *);
 static void		 pf_counters_inc(int, struct pf_pdesc *,
 			    struct pf_kstate *, struct pf_krule *,
 			    struct pf_krule *);
 static void		 pf_overload_task(void *v, int pending);
 static u_short		 pf_insert_src_node(struct pf_ksrc_node **,
-			    struct pf_krule *, struct pf_addr *, sa_family_t);
+			    struct pf_srchash **, struct pf_krule *,
+			    struct pf_addr *, sa_family_t);
 static u_int		 pf_purge_expired_states(u_int, int);
 static void		 pf_purge_unlinked_rules(void);
 static int		 pf_mtag_uminit(void *, int, int);
@@ -701,7 +701,7 @@ pf_hashkey(const struct pf_state_key *sk)
 	return (h & V_pf_hashmask);
 }
 
-static __inline uint32_t
+__inline uint32_t
 pf_hashsrc(struct pf_addr *addr, sa_family_t af)
 {
 	uint32_t h;
@@ -812,17 +812,14 @@ pf_check_threshold(struct pf_threshold *threshold)
 	return (threshold->count > threshold->limit);
 }
 
-static int
+static bool
 pf_src_connlimit(struct pf_kstate *state)
 {
 	struct pf_overload_entry *pfoe;
-	int bad = 0;
+	bool limited = false;
 
 	PF_STATE_LOCK_ASSERT(state);
-	/*
-	 * XXXKS: The src node is accessed unlocked!
-	 * PF_SRC_NODE_LOCK_ASSERT(state->src_node);
-	 */
+	PF_SRC_NODE_LOCK(state->src_node);
 
 	state->src_node->conn++;
 	state->src.tcp_est = 1;
@@ -832,29 +829,29 @@ pf_src_connlimit(struct pf_kstate *state)
 	    state->rule->max_src_conn <
 	    state->src_node->conn) {
 		counter_u64_add(V_pf_status.lcounters[LCNT_SRCCONN], 1);
-		bad++;
+		limited = true;
 	}
 
 	if (state->rule->max_src_conn_rate.limit &&
 	    pf_check_threshold(&state->src_node->conn_rate)) {
 		counter_u64_add(V_pf_status.lcounters[LCNT_SRCCONNRATE], 1);
-		bad++;
+		limited = true;
 	}
 
-	if (!bad)
-		return (0);
+	if (!limited)
+		goto done;
 
 	/* Kill this state. */
 	state->timeout = PFTM_PURGE;
 	pf_set_protostate(state, PF_PEER_BOTH, TCPS_CLOSED);
 
 	if (state->rule->overload_tbl == NULL)
-		return (1);
+		goto done;
 
 	/* Schedule overloading and flushing task. */
 	pfoe = malloc(sizeof(*pfoe), M_PFTEMP, M_NOWAIT);
 	if (pfoe == NULL)
-		return (1);	/* too bad :( */
+		goto done;  /* too bad :( */
 
 	bcopy(&state->src_node->addr, &pfoe->addr, sizeof(pfoe->addr));
 	pfoe->af = state->key[PF_SK_WIRE]->af;
@@ -865,7 +862,9 @@ pf_src_connlimit(struct pf_kstate *state)
 	PF_OVERLOADQ_UNLOCK();
 	taskqueue_enqueue(taskqueue_swi, &V_pf_overloadtask);
 
-	return (1);
+done:
+	PF_SRC_NODE_UNLOCK(state->src_node);
+	return (limited);
 }
 
 static void
@@ -962,8 +961,7 @@ pf_overload_task(void *v, int pending)
 }
 
 /*
- * Can return locked on failure, so that we can consistently
- * allocate and insert a new one.
+ * When a node is found, always returns locked. When not found, it is configurable.
  */
 struct pf_ksrc_node *
 pf_find_src_node(struct pf_addr *src, struct pf_krule *rule, sa_family_t af,
@@ -981,15 +979,34 @@ pf_find_src_node(struct pf_addr *src, struct pf_krule *rule, sa_family_t af,
 		    (af == AF_INET6 && bcmp(&n->addr, src, sizeof(*src)) == 0)))
 			break;
 
-	if (n != NULL) {
-		n->states++;
-		PF_HASHROW_UNLOCK(*sh);
-	} else if (returnlocked == false)
+	if (n == NULL && !returnlocked)
 		PF_HASHROW_UNLOCK(*sh);
 
 	return (n);
 }
 
+bool
+pf_src_node_exists(struct pf_ksrc_node **sn, struct pf_srchash *sh)
+{
+	struct pf_ksrc_node	*cur;
+
+	if ((*sn) == NULL)
+		return (false);
+
+	KASSERT(sh != NULL, ("%s: sh is NULL", __func__));
+
+	counter_u64_add(V_pf_status.scounters[SCNT_SRC_NODE_SEARCH], 1);
+	PF_HASHROW_LOCK(sh);
+	LIST_FOREACH(cur, &(sh->nodes), entry) {
+		if (cur == (*sn) &&
+		    cur->expire != 1) /* Ignore nodes being killed */
+			return (true);
+	}
+	PF_HASHROW_UNLOCK(sh);
+	(*sn) = NULL;
+	return (false);
+}
+
 static void
 pf_free_src_node(struct pf_ksrc_node *sn)
 {
@@ -1002,33 +1019,33 @@ pf_free_src_node(struct pf_ksrc_node *sn)
 }
 
 static u_short
-pf_insert_src_node(struct pf_ksrc_node **sn, struct pf_krule *rule,
-    struct pf_addr *src, sa_family_t af)
+pf_insert_src_node(struct pf_ksrc_node **sn, struct pf_srchash **sh,
+    struct pf_krule *rule, struct pf_addr *src, sa_family_t af)
 {
 	u_short			 reason = 0;
-	struct pf_srchash	*sh = NULL;
 
 	KASSERT((rule->rule_flag & PFRULE_SRCTRACK ||
 	    rule->rpool.opts & PF_POOL_STICKYADDR),
 	    ("%s for non-tracking rule %p", __func__, rule));
 
+	/*
+	 * Request the sh to always be locked, as we might insert a new sn.
+	 */
 	if (*sn == NULL)
-		*sn = pf_find_src_node(src, rule, af, &sh, true);
+		*sn = pf_find_src_node(src, rule, af, sh, true);
 
 	if (*sn == NULL) {
-		PF_HASHROW_ASSERT(sh);
+		PF_HASHROW_ASSERT(*sh);
 
 		if (rule->max_src_nodes &&
 		    counter_u64_fetch(rule->src_nodes) >= rule->max_src_nodes) {
 			counter_u64_add(V_pf_status.lcounters[LCNT_SRCNODES], 1);
-			PF_HASHROW_UNLOCK(sh);
 			reason = PFRES_SRCLIMIT;
 			goto done;
 		}
 
 		(*sn) = uma_zalloc(V_pf_sources_z, M_NOWAIT | M_ZERO);
 		if ((*sn) == NULL) {
-			PF_HASHROW_UNLOCK(sh);
 			reason = PFRES_MEMORY;
 			goto done;
 		}
@@ -1039,7 +1056,6 @@ pf_insert_src_node(struct pf_ksrc_node **sn, struct pf_krule *rule,
 
 			if ((*sn)->bytes[i] == NULL || (*sn)->packets[i] == NULL) {
 				pf_free_src_node(*sn);
-				PF_HASHROW_UNLOCK(sh);
 				reason = PFRES_MEMORY;
 				goto done;
 			}
@@ -1050,18 +1066,16 @@ pf_insert_src_node(struct pf_ksrc_node **sn, struct pf_krule *rule,
 		    rule->max_src_conn_rate.seconds);
 
 		MPASS((*sn)->lock == NULL);
-		(*sn)->lock = &sh->lock;
+		(*sn)->lock = &(*sh)->lock;
 
 		(*sn)->af = af;
 		(*sn)->rule = rule;
 		PF_ACPY(&(*sn)->addr, src, af);
-		LIST_INSERT_HEAD(&sh->nodes, *sn, entry);
+		LIST_INSERT_HEAD(&(*sh)->nodes, *sn, entry);
 		(*sn)->creation = time_uptime;
 		(*sn)->ruletype = rule->action;
-		(*sn)->states = 1;
 		if ((*sn)->rule != NULL)
 			counter_u64_add((*sn)->rule->src_nodes, 1);
-		PF_HASHROW_UNLOCK(sh);
 		counter_u64_add(V_pf_status.scounters[SCNT_SRC_NODE_INSERT], 1);
 	} else {
 		if (rule->max_src_states &&
@@ -1073,6 +1087,12 @@ pf_insert_src_node(struct pf_ksrc_node **sn, struct pf_krule *rule,
 		}
 	}
 done:
+	if (reason == 0)
+		(*sn)->states++;
+	else
+		(*sn) = NULL;
+
+	PF_HASHROW_UNLOCK(*sh);
 	return (reason);
 }
 
@@ -4880,7 +4900,6 @@ pf_test_rule(struct pf_krule **rm, struct pf_kstate **sm,
 	struct pf_kruleset	*ruleset = NULL;
 	struct pf_krule_slist	 match_rules;
 	struct pf_krule_item	*ri;
-	struct pf_ksrc_node	*nsn = NULL;
 	struct tcphdr		*th = &pd->hdr.tcp;
 	struct pf_state_key	*sk = NULL, *nk = NULL;
 	u_short			 reason, transerror;
@@ -4960,8 +4979,8 @@ pf_test_rule(struct pf_krule **rm, struct pf_kstate **sm,
 	r = TAILQ_FIRST(pf_main_ruleset.rules[PF_RULESET_FILTER].active.ptr);
 
 	/* check packet for BINAT/NAT/RDR */
-	transerror = pf_get_translation(pd, pd->off, &nsn, &sk,
-	    &nk, saddr, daddr, sport, dport, anchor_stack, &nr, &udp_mapping);
+	transerror = pf_get_translation(pd, pd->off, &sk, &nk, saddr, daddr,
+	    sport, dport, anchor_stack, &nr, &udp_mapping);
 	switch (transerror) {
 	default:
 		/* A translation error occurred. */
@@ -5290,7 +5309,7 @@ nextrule:
 	   (!state_icmp && (r->keep_state || nr != NULL ||
 	    (pd->flags & PFDESC_TCP_NORM)))) {
 		int action;
-		action = pf_create_state(r, nr, a, pd, nsn, nk, sk,
+		action = pf_create_state(r, nr, a, pd, nk, sk,
 		    sport, dport, &rewrite, sm, tag, bproto_sum, bip_sum,
 		    &match_rules, udp_mapping);
 		if (action != PF_PASS) {
@@ -5345,14 +5364,16 @@ cleanup:
 
 static int
 pf_create_state(struct pf_krule *r, struct pf_krule *nr, struct pf_krule *a,
-    struct pf_pdesc *pd, struct pf_ksrc_node *nsn, struct pf_state_key *nk,
-    struct pf_state_key *sk, u_int16_t sport,
-    u_int16_t dport, int *rewrite, struct pf_kstate **sm,
+    struct pf_pdesc *pd, struct pf_state_key *nk, struct pf_state_key *sk,
+    u_int16_t sport, u_int16_t dport, int *rewrite, struct pf_kstate **sm,
     int tag, u_int16_t bproto_sum, u_int16_t bip_sum,
     struct pf_krule_slist *match_rules, struct pf_udp_mapping *udp_mapping)
 {
 	struct pf_kstate	*s = NULL;
 	struct pf_ksrc_node	*sn = NULL;
+	struct pf_srchash	*snh = NULL;
+	struct pf_ksrc_node	*nsn = NULL;
+	struct pf_srchash	*nsnh = NULL;
 	struct tcphdr		*th = &pd->hdr.tcp;
 	u_int16_t		 mss = V_tcp_mssdflt;
 	u_short			 reason, sn_reason;
@@ -5368,13 +5389,13 @@ pf_create_state(struct pf_krule *r, struct pf_krule *nr, struct pf_krule *a,
 	/* src node for filter rule */
 	if ((r->rule_flag & PFRULE_SRCTRACK ||
 	    r->rpool.opts & PF_POOL_STICKYADDR) &&
-	    (sn_reason = pf_insert_src_node(&sn, r, pd->src, pd->af)) != 0) {
+	    (sn_reason = pf_insert_src_node(&sn, &snh, r, pd->src, pd->af)) != 0) {
 		REASON_SET(&reason, sn_reason);
 		goto csfailed;
 	}
 	/* src node for translation rule */
 	if (nr != NULL && (nr->rpool.opts & PF_POOL_STICKYADDR) &&
-	    (sn_reason = pf_insert_src_node(&nsn, nr, &sk->addr[pd->sidx],
+	    (sn_reason = pf_insert_src_node(&nsn, &nsnh, nr, &sk->addr[pd->sidx],
 	    pd->af)) != 0 ) {
 		REASON_SET(&reason, sn_reason);
 		goto csfailed;
@@ -5468,20 +5489,13 @@ pf_create_state(struct pf_krule *r, struct pf_krule *nr, struct pf_krule *a,
 	if (r->rt) {
 		/* pf_map_addr increases the reason counters */
 		if ((reason = pf_map_addr_sn(pd->af, r, pd->src, &s->rt_addr,
-		    &s->rt_kif, NULL, &sn)) != 0)
+		    &s->rt_kif, NULL, &sn, &snh)) != 0)
 			goto csfailed;
 		s->rt = r->rt;
 	}
 
 	s->creation = s->expire = pf_get_uptime();
 
-	if (sn != NULL)
-		s->src_node = sn;
-	if (nsn != NULL) {
-		/* XXX We only modify one side for now. */
-		PF_ACPY(&nsn->raddr, &nk->addr[1], pd->af);
-		s->nat_src_node = nsn;
-	}
 	if (pd->proto == IPPROTO_TCP) {
 		if (s->state_flags & PFSTATE_SCRUB_TCP &&
 		    pf_normalize_tcp_init(pd, th, &s->src, &s->dst)) {
@@ -5528,6 +5542,20 @@ pf_create_state(struct pf_krule *r, struct pf_krule *nr, struct pf_krule *a,
 	} else
 		*sm = s;
 
+	/*
+	 * Lock order is important: first state, then source node.
+	 */
+	if (pf_src_node_exists(&sn, snh)) {
+		s->src_node = sn;
+		PF_HASHROW_UNLOCK(snh);
+	}
+	if (pf_src_node_exists(&nsn, nsnh)) {
+		/* XXX We only modify one side for now. */
+		PF_ACPY(&nsn->raddr, &nk->addr[1], pd->af);
+		s->nat_src_node = nsn;
+		PF_HASHROW_UNLOCK(nsnh);
+	}
+
 	if (tag > 0)
 		s->tag = tag;
 	if (pd->proto == IPPROTO_TCP && (th->th_flags & (TH_SYN|TH_ACK)) ==
@@ -5578,26 +5606,24 @@ csfailed:
 	uma_zfree(V_pf_state_key_z, sk);
 	uma_zfree(V_pf_state_key_z, nk);
 
-	if (sn != NULL) {
-		PF_SRC_NODE_LOCK(sn);
+	if (pf_src_node_exists(&sn, snh)) {
 		if (--sn->states == 0 && sn->expire == 0) {
 			pf_unlink_src_node(sn);
-			uma_zfree(V_pf_sources_z, sn);
+			pf_free_src_node(sn);
 			counter_u64_add(
 			    V_pf_status.scounters[SCNT_SRC_NODE_REMOVALS], 1);
 		}
-		PF_SRC_NODE_UNLOCK(sn);
+		PF_HASHROW_UNLOCK(snh);
 	}
 
-	if (nsn != sn && nsn != NULL) {
-		PF_SRC_NODE_LOCK(nsn);
+	if (sn != nsn && pf_src_node_exists(&nsn, nsnh)) {
 		if (--nsn->states == 0 && nsn->expire == 0) {
 			pf_unlink_src_node(nsn);
-			uma_zfree(V_pf_sources_z, nsn);
+			pf_free_src_node(nsn);
 			counter_u64_add(
 			    V_pf_status.scounters[SCNT_SRC_NODE_REMOVALS], 1);
 		}
-		PF_SRC_NODE_UNLOCK(nsn);
+		PF_HASHROW_UNLOCK(nsnh);
 	}
 
 drop:
diff --git a/sys/netpfil/pf/pf_lb.c b/sys/netpfil/pf/pf_lb.c
index 5777cf19b067..e180f87d2998 100644
--- a/sys/netpfil/pf/pf_lb.c
+++ b/sys/netpfil/pf/pf_lb.c
@@ -69,7 +69,7 @@ static struct pf_krule	*pf_match_translation(struct pf_pdesc *,
 			    struct pf_kanchor_stackframe *);
 static int pf_get_sport(sa_family_t, uint8_t, struct pf_krule *,
     struct pf_addr *, uint16_t, struct pf_addr *, uint16_t, struct pf_addr *,
-    uint16_t *, uint16_t, uint16_t, struct pf_ksrc_node **,
+    uint16_t *, uint16_t, uint16_t, struct pf_ksrc_node **, struct pf_srchash**,
     struct pf_udp_mapping **);
 static bool		 pf_islinklocal(const sa_family_t, const struct pf_addr *);
 
@@ -225,12 +225,11 @@ static int
 pf_get_sport(sa_family_t af, u_int8_t proto, struct pf_krule *r,
     struct pf_addr *saddr, uint16_t sport, struct pf_addr *daddr,
     uint16_t dport, struct pf_addr *naddr, uint16_t *nport, uint16_t low,
-    uint16_t high, struct pf_ksrc_node **sn,
+    uint16_t high, struct pf_ksrc_node **sn, struct pf_srchash **sh,
     struct pf_udp_mapping **udp_mapping)
 {
 	struct pf_state_key_cmp	key;
 	struct pf_addr		init_addr;
-	struct pf_srchash	*sh = NULL;
 
 	bzero(&init_addr, sizeof(init_addr));
 
@@ -255,7 +254,9 @@ pf_get_sport(sa_family_t af, u_int8_t proto, struct pf_krule *r,
 			/* Try to find a src_node as per pf_map_addr(). */
 			if (*sn == NULL && r->rpool.opts & PF_POOL_STICKYADDR &&
 			    (r->rpool.opts & PF_POOL_TYPEMASK) != PF_POOL_NONE)
-				*sn = pf_find_src_node(saddr, r, af, &sh, 0);
+				*sn = pf_find_src_node(saddr, r, af, sh, false);
+			if (*sn != NULL)
+				PF_SRC_NODE_UNLOCK(*sn);
 			return (0);
 		} else {
 			*udp_mapping = pf_udp_mapping_create(af, saddr, sport, &init_addr, 0);
@@ -264,7 +265,7 @@ pf_get_sport(sa_family_t af, u_int8_t proto, struct pf_krule *r,
 		}
 	}
 
-	if (pf_map_addr_sn(af, r, saddr, naddr, NULL, &init_addr, sn))
+	if (pf_map_addr_sn(af, r, saddr, naddr, NULL, &init_addr, sn, sh))
 		goto failed;
 
 	if (proto == IPPROTO_ICMP) {
@@ -385,7 +386,8 @@ pf_get_sport(sa_family_t af, u_int8_t proto, struct pf_krule *r,
 			 * pick a different source address since we're out
 			 * of free port choices for the current one.
 			 */
-			if (pf_map_addr_sn(af, r, saddr, naddr, NULL, &init_addr, sn))
+			(*sn) = NULL;
+			if (pf_map_addr_sn(af, r, saddr, naddr, NULL, &init_addr, sn, sh))
 				return (1);
 			break;
 		case PF_POOL_NONE:
@@ -414,7 +416,8 @@ static int
 pf_get_mape_sport(sa_family_t af, u_int8_t proto, struct pf_krule *r,
     struct pf_addr *saddr, uint16_t sport, struct pf_addr *daddr,
     uint16_t dport, struct pf_addr *naddr, uint16_t *nport,
-    struct pf_ksrc_node **sn, struct pf_udp_mapping **udp_mapping)
+    struct pf_ksrc_node **sn, struct pf_srchash **sh,
+    struct pf_udp_mapping **udp_mapping)
 {
 	uint16_t psmask, low, highmask;
 	uint16_t i, ahigh, cut;
@@ -434,13 +437,13 @@ pf_get_mape_sport(sa_family_t af, u_int8_t proto, struct pf_krule *r,
 	for (i = cut; i <= ahigh; i++) {
 		low = (i << ashift) | psmask;
 		if (!pf_get_sport(af, proto, r, saddr, sport, daddr, dport,
-		    naddr, nport, low, low | highmask, sn, udp_mapping))
+		    naddr, nport, low, low | highmask, sn, sh, udp_mapping))
 			return (0);
 	}
 	for (i = cut - 1; i > 0; i--) {
 		low = (i << ashift) | psmask;
 		if (!pf_get_sport(af, proto, r, saddr, sport, daddr, dport,
-		    naddr, nport, low, low | highmask, sn, udp_mapping))
+		    naddr, nport, low, low | highmask, sn, sh, udp_mapping))
 			return (0);
 	}
 	return (1);
@@ -623,23 +626,31 @@ done_pool_mtx:
 u_short
 pf_map_addr_sn(sa_family_t af, struct pf_krule *r, struct pf_addr *saddr,
     struct pf_addr *naddr, struct pfi_kkif **nkif, struct pf_addr *init_addr,
-    struct pf_ksrc_node **sn)
+    struct pf_ksrc_node **sn, struct pf_srchash **sh)
 {
 	u_short			 reason = 0;
 	struct pf_kpool		*rpool = &r->rpool;
-	struct pf_srchash	*sh = NULL;
 
-	/* Try to find a src_node if none was given and this
-	   is a sticky-address rule. */
-	if (*sn == NULL && r->rpool.opts & PF_POOL_STICKYADDR &&
-	    (r->rpool.opts & PF_POOL_TYPEMASK) != PF_POOL_NONE)
-		*sn = pf_find_src_node(saddr, r, af, &sh, false);
+	/*
+	 * Try to find a src_node if none was given and this is
+	 * a sticky-address rule. Request the sh to be unlocked if
+	 * sn was not found, as here we never insert a new sn.
+	 */
+	if (*sn == NULL) {
+		if (r->rpool.opts & PF_POOL_STICKYADDR &&
+		    (r->rpool.opts & PF_POOL_TYPEMASK) != PF_POOL_NONE)
+			*sn = pf_find_src_node(saddr, r, af, sh, false);
+	} else {
+		pf_src_node_exists(sn, *sh);
+	}
 
 	/* If a src_node was found or explicitly given and it has a non-zero
 	   route address, use this address. A zeroed address is found if the
 	   src node was created just a moment ago in pf_create_state and it
 	   needs to be filled in with routing decision calculated here. */
 	if (*sn != NULL && !PF_AZERO(&(*sn)->raddr, af)) {
+		PF_SRC_NODE_LOCK_ASSERT(*sn);
+
 		/* If the supplied address is the same as the current one we've
 		 * been asked before, so tell the caller that there's no other
 		 * address to be had. */
@@ -673,6 +684,8 @@ pf_map_addr_sn(sa_family_t af, struct pf_krule *r, struct pf_addr *saddr,
 	}
 
 	if (*sn != NULL) {
+		PF_SRC_NODE_LOCK_ASSERT(*sn);
+
 		PF_ACPY(&(*sn)->raddr, naddr, af);
 		if (nkif)
 			(*sn)->rkif = *nkif;
@@ -688,6 +701,9 @@ pf_map_addr_sn(sa_family_t af, struct pf_krule *r, struct pf_addr *saddr,
 	}
 
 done:
+	if ((*sn) != NULL)
+		PF_SRC_NODE_UNLOCK(*sn);
+
 	if (reason) {
 		counter_u64_add(V_pf_status.counters[reason], 1);
 	}
@@ -697,14 +713,15 @@ done:
 
 u_short
 pf_get_translation(struct pf_pdesc *pd, int off,
-    struct pf_ksrc_node **sn, struct pf_state_key **skp,
-    struct pf_state_key **nkp, struct pf_addr *saddr, struct pf_addr *daddr,
-    uint16_t sport, uint16_t dport, struct pf_kanchor_stackframe *anchor_stack,
-    struct pf_krule **rp,
+    struct pf_state_key **skp, struct pf_state_key **nkp, struct pf_addr *saddr,
+    struct pf_addr *daddr, uint16_t sport, uint16_t dport,
+    struct pf_kanchor_stackframe *anchor_stack, struct pf_krule **rp,
     struct pf_udp_mapping **udp_mapping)
 {
 	struct pf_krule	*r = NULL;
 	struct pf_addr	*naddr;
+	struct pf_ksrc_node	*sn = NULL;
+	struct pf_srchash	*sh = NULL;
 	uint16_t	*nportp;
 	uint16_t	 low, high;
 	u_short		 reason;
@@ -765,7 +782,8 @@ pf_get_translation(struct pf_pdesc *pd, int off,
 		}
 		if (r->rpool.mape.offset > 0) {
 			if (pf_get_mape_sport(pd->af, pd->proto, r, saddr,
-			    sport, daddr, dport, naddr, nportp, sn, udp_mapping)) {
+			    sport, daddr, dport, naddr, nportp, &sn, &sh,
+			    udp_mapping)) {
 				DPFPRINTF(PF_DEBUG_MISC,
 				    ("pf: MAP-E port allocation (%u/%u/%u)"
 				    " failed\n",
@@ -776,7 +794,8 @@ pf_get_translation(struct pf_pdesc *pd, int off,
 				goto notrans;
 			}
 		} else if (pf_get_sport(pd->af, pd->proto, r, saddr, sport,
-		    daddr, dport, naddr, nportp, low, high, sn, udp_mapping)) {
+		    daddr, dport, naddr, nportp, low, high, &sn, &sh,
+		    udp_mapping)) {
 			DPFPRINTF(PF_DEBUG_MISC,
 			    ("pf: NAT proxy port allocation (%u-%u) failed\n",
 			    r->rpool.proxy_port[0], r->rpool.proxy_port[1]));
@@ -863,7 +882,7 @@ pf_get_translation(struct pf_pdesc *pd, int off,
 		int tries;
 		uint16_t cut, low, high, nport;
 
-		reason = pf_map_addr_sn(pd->af, r, saddr, naddr, NULL, NULL, sn);
+		reason = pf_map_addr_sn(pd->af, r, saddr, naddr, NULL, NULL, &sn, &sh);
 		if (reason != 0)
 			goto notrans;
 		if ((r->rpool.opts & PF_POOL_TYPEMASK) == PF_POOL_BITMASK)
@@ -970,7 +989,6 @@ notrans:
 	uma_zfree(V_pf_state_key_z, *nkp);
 	uma_zfree(V_pf_state_key_z, *skp);
 	*skp = *nkp = NULL;
-	*sn = NULL;
 
 	return (reason);
 }
diff --git a/tests/sys/netpfil/pf/src_track.sh b/tests/sys/netpfil/pf/src_track.sh
index 5349e61ec76b..620f1353f9fe 100755
--- a/tests/sys/netpfil/pf/src_track.sh
+++ b/tests/sys/netpfil/pf/src_track.sh
@@ -217,28 +217,31 @@ max_src_states_rule_body()
 	# 2 connections from host ::1 matching rule_A will be allowed, 1 will fail to create a state.
 	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4211 --fromaddr 2001:db8:44::1
 	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4212 --fromaddr 2001:db8:44::1
-	ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4213 --fromaddr 2001:db8:44::1
+	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4213 --fromaddr 2001:db8:44::1
+	ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4214 --fromaddr 2001:db8:44::1
 
 	# 2 connections from host ::1 matching rule_B will be allowed, 1 will fail to create a state.
 	# Limits from rule_A don't interfere with rule_B.
 	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4221 --fromaddr 2001:db8:44::1
 	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4222 --fromaddr 2001:db8:44::1
-	ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4223 --fromaddr 2001:db8:44::1
+	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4223 --fromaddr 2001:db8:44::1
+	ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4224 --fromaddr 2001:db8:44::1
 
 	# 2 connections from host ::2 matching rule_B will be allowed, 1 will fail to create a state.
 	# Limits for host ::1 will not interfere with host ::2.
 	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4224 --fromaddr 2001:db8:44::2
 	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4225 --fromaddr 2001:db8:44::2
-	ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4226 --fromaddr 2001:db8:44::2
+	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4226 --fromaddr 2001:db8:44::2
+	ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4227 --fromaddr 2001:db8:44::2
 
 	# We will check the resulting source nodes, though.
 	# Order of source nodes in output is not guaranteed, find each one separately.
 	nodes=$(mktemp) || exit 1
 	jexec router pfctl -qvsS | normalize_pfctl_s > $nodes
 	for node_regexp in \
-		'2001:db8:44::1 -> :: \( states 2, connections 2, rate [0-9/\.]+s \) age [0-9:]+, 6 pkts, [0-9]+ bytes, filter rule 3$' \
-		'2001:db8:44::1 -> :: \( states 2, connections 2, rate [0-9/\.]+s \) age [0-9:]+, 6 pkts, [0-9]+ bytes, filter rule 4$' \
-		'2001:db8:44::2 -> :: \( states 2, connections 2, rate [0-9/\.]+s \) age [0-9:]+, 6 pkts, [0-9]+ bytes, filter rule 4$' \
+		'2001:db8:44::1 -> :: \( states 3, connections 3, rate [0-9/\.]+s \) age [0-9:]+, 9 pkts, [0-9]+ bytes, filter rule 3$' \
+		'2001:db8:44::1 -> :: \( states 3, connections 3, rate [0-9/\.]+s \) age [0-9:]+, 9 pkts, [0-9]+ bytes, filter rule 4$' \
+		'2001:db8:44::2 -> :: \( states 3, connections 3, rate [0-9/\.]+s \) age [0-9:]+, 9 pkts, [0-9]+ bytes, filter rule 4$' \
 	; do
 		grep -qE "$node_regexp" $nodes || atf_fail "Source nodes not matching expected output"
 	done