git: 286b49d4ea3a - stable/13 - netlink: allow more than 64 groups per netlink socket.

From: Alexander V. Chernikov <melifaro_at_FreeBSD.org>
Date: Mon, 23 Jan 2023 22:11:58 UTC
The branch stable/13 has been updated by melifaro:

URL: https://cgit.FreeBSD.org/src/commit/?id=286b49d4ea3a4dda6627d7f6cea90d41d880f076

commit 286b49d4ea3a4dda6627d7f6cea90d41d880f076
Author:     Alexander V. Chernikov <melifaro@FreeBSD.org>
AuthorDate: 2022-11-03 16:44:07 +0000
Commit:     Alexander V. Chernikov <melifaro@FreeBSD.org>
CommitDate: 2023-01-23 22:04:03 +0000

    netlink: allow more than 64 groups per netlink socket.
    
    (cherry picked from commit 4dfd380e06c515349c5cc55a1f05effbf3a44ba1)
---
 sys/netlink/netlink_domain.c  | 71 ++++++++++++++++++++++++++++++++-----------
 sys/netlink/netlink_generic.c |  2 +-
 sys/netlink/netlink_var.h     |  4 ++-
 3 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/sys/netlink/netlink_domain.c b/sys/netlink/netlink_domain.c
index 44b5fb732896..b0ff84401c84 100644
--- a/sys/netlink/netlink_domain.c
+++ b/sys/netlink/netlink_domain.c
@@ -55,6 +55,10 @@
 #include <netlink/netlink_debug.h>
 _DECLARE_DEBUG(LOG_DEBUG);
 
+_Static_assert((NLP_MAX_GROUPS % 64) == 0,
+    "NLP_MAX_GROUPS has to be multiple of 64");	/* whole uint64_t words */
+_Static_assert(NLP_MAX_GROUPS >= 64,
+    "NLP_MAX_GROUPS has to be at least 64");	/* non-empty nl_groups[] */
 
 #define	NLCTL_TRACKER		struct rm_priotracker nl_tracker
 #define	NLCTL_RLOCK(_ctl)	rm_rlock(&((_ctl)->ctl_lock), &nl_tracker)
@@ -97,12 +101,43 @@ nl_port_lookup(uint32_t port_id)
 }
 
 static void
-nl_update_groups_locked(struct nlpcb *nlp, uint64_t nl_groups)
+nl_add_group_locked(struct nlpcb *nlp, unsigned int group_id)
 {
-	/* Update group mask */
-	NL_LOG(LOG_DEBUG2, "socket %p, groups 0x%X -> 0x%X",
-	    nlp->nl_socket, (uint32_t)nlp->nl_groups, (uint32_t)nl_groups);
-	nlp->nl_groups = nl_groups;
+	/* Group ids are 1-based: group N is bit (N - 1) of nl_groups[]. */
+	MPASS(group_id > 0 && group_id <= NLP_MAX_GROUPS);
+	--group_id;
+	nlp->nl_groups[group_id / 64] |= (uint64_t)1 << (group_id % 64);
+}
+
+static void
+nl_del_group_locked(struct nlpcb *nlp, unsigned int group_id)
+{
+	/* Group ids are 1-based: group N is bit (N - 1) of nl_groups[]. */
+	MPASS(group_id > 0 && group_id <= NLP_MAX_GROUPS);
+	--group_id;
+	nlp->nl_groups[group_id / 64] &= ~((uint64_t)1 << (group_id % 64));
+}
+
+static bool
+nl_isset_group_locked(struct nlpcb *nlp, unsigned int group_id)
+{
+	/* Group ids are 1-based: group N is bit (N - 1) of nl_groups[]. */
+	MPASS(group_id > 0 && group_id <= NLP_MAX_GROUPS);
+	--group_id;
+	return (nlp->nl_groups[group_id / 64] & ((uint64_t)1 << (group_id % 64)));
+}
+
+static uint32_t
+nl_get_groups_compat(struct nlpcb *nlp)
+{
+	uint32_t groups_mask = 0;
+
+	/* Bit i is group (i + 1); groups above 32 cannot be represented. */
+	for (int i = 0; i < 32; i++) {
+		if (nl_isset_group_locked(nlp, i + 1))
+			groups_mask |= ((uint32_t)1 << i);
+	}
+	return (groups_mask);
 }
 
 /*
@@ -134,10 +169,9 @@ nl_send_group(struct mbuf *m, int num_messages, int proto, int group_id)
 	NLCTL_RLOCK(ctl);
 
 	int io_flags = NL_IOF_UNTRANSLATED;
-	uint64_t groups_mask = 1 << ((uint64_t)group_id - 1);
 
 	CK_LIST_FOREACH(nlp, &ctl->ctl_pcb_head, nl_next) {
-		if (nlp->nl_groups & groups_mask && nlp->nl_proto == proto) {
+		if (nl_isset_group_locked(nlp, group_id) && nlp->nl_proto == proto) {
 			if (nlp_last != NULL) {
 				struct mbuf *m_copy;
 				m_copy = m_copym(m, 0, M_COPYALL, M_NOWAIT);
@@ -213,7 +247,12 @@ nl_bind_locked(struct nlpcb *nlp, struct sockaddr_nl *snl)
 		nlp->nl_bound = true;
 		CK_LIST_INSERT_HEAD(&V_nl_ctl->ctl_port_head, nlp, nl_port_next);
 	}
-	nl_update_groups_locked(nlp, snl->nl_groups);
+	for (int i = 0; i < 32; i++) {
+		if (snl->nl_groups & ((uint32_t)1 << i))
+			nl_add_group_locked(nlp, i + 1);
+		else
+			nl_del_group_locked(nlp, i + 1);
+	}
 
 	return (0);
 }
@@ -324,7 +363,7 @@ nl_assign_port(struct nlpcb *nlp, uint32_t port_id)
 
 	NLCTL_WLOCK(ctl);
 	NLP_LOCK(nlp);
-	snl.nl_groups = nlp->nl_groups;
+	snl.nl_groups = nl_get_groups_compat(nlp);
 	error = nl_bind_locked(nlp, &snl);
 	NLP_UNLOCK(nlp);
 	NLCTL_WUNLOCK(ctl);
@@ -562,7 +601,6 @@ nl_ctloutput(struct socket *so, struct sockopt *sopt)
 	struct nl_control *ctl = atomic_load_ptr(&V_nl_ctl);
 	struct nlpcb *nlp = sotonlpcb(so);
 	uint32_t flag;
-	uint64_t groups, group_mask;
 	int optval, error = 0;
 	NLCTL_TRACKER;
 
@@ -575,20 +613,17 @@ nl_ctloutput(struct socket *so, struct sockopt *sopt)
 		case NETLINK_ADD_MEMBERSHIP:
 		case NETLINK_DROP_MEMBERSHIP:
 			sooptcopyin(sopt, &optval, sizeof(optval), sizeof(optval));
-			if (optval <= 0 || optval >= 64) {
+			if (optval <= 0 || optval > NLP_MAX_GROUPS) {
 				error = ERANGE;
 				break;
 			}
-			group_mask = (uint64_t)1 << (optval - 1);
-			NL_LOG(LOG_DEBUG2, "ADD/DEL group %d mask (%X)",
-			    (uint32_t)optval, (uint32_t)group_mask);
+			NL_LOG(LOG_DEBUG2, "ADD/DEL group %d", (uint32_t)optval);
 
 			NLCTL_WLOCK(ctl);
 			if (sopt->sopt_name == NETLINK_ADD_MEMBERSHIP)
-				groups = nlp->nl_groups | group_mask;
+				nl_add_group_locked(nlp, optval);
 			else
-				groups = nlp->nl_groups & ~group_mask;
-			nl_update_groups_locked(nlp, groups);
+				nl_del_group_locked(nlp, optval);
 			NLCTL_WUNLOCK(ctl);
 			break;
 		case NETLINK_CAP_ACK:
@@ -613,7 +648,7 @@ nl_ctloutput(struct socket *so, struct sockopt *sopt)
 		switch (sopt->sopt_name) {
 		case NETLINK_LIST_MEMBERSHIPS:
 			NLCTL_RLOCK(ctl);
-			optval = nlp->nl_groups;
+			optval = nl_get_groups_compat(nlp);
 			NLCTL_RUNLOCK(ctl);
 			error = sooptcopyout(sopt, &optval, sizeof(optval));
 			break;
diff --git a/sys/netlink/netlink_generic.c b/sys/netlink/netlink_generic.c
index 5d074640ad60..de45048ff519 100644
--- a/sys/netlink/netlink_generic.c
+++ b/sys/netlink/netlink_generic.c
@@ -47,7 +47,7 @@ __FBSDID("$FreeBSD$");
 _DECLARE_DEBUG(LOG_DEBUG3);
 
 #define	MAX_FAMILIES	20
-#define	MAX_GROUPS	20
+#define	MAX_GROUPS	64	/* NOTE(review): ids must stay <= NLP_MAX_GROUPS */
 
 #define	MIN_GROUP_NUM	48
 
diff --git a/sys/netlink/netlink_var.h b/sys/netlink/netlink_var.h
index 130f3d40a1a3..ed19008248e9 100644
--- a/sys/netlink/netlink_var.h
+++ b/sys/netlink/netlink_var.h
@@ -47,9 +47,11 @@ struct nl_io_queue {
 	int			hiwat;
 };
 
+#define	NLP_MAX_GROUPS		128	/* must be a multiple of 64 */
+
 struct nlpcb {
         struct socket           *nl_socket;
-	uint64_t	        nl_groups;
+	uint64_t	        nl_groups[NLP_MAX_GROUPS / 64];
 	uint32_t                nl_port;
 	uint32_t	        nl_flags;
 	uint32_t	        nl_process_id;