git: a3a6dc24f34a - main - rpcsec_tls/client: use netlink RPC client to talk to rpc.tlsclntd(8)

From: Gleb Smirnoff <glebius_at_FreeBSD.org>
Date: Sat, 01 Feb 2025 09:02:15 UTC
The branch main has been updated by glebius:

URL: https://cgit.FreeBSD.org/src/commit/?id=a3a6dc24f34a1b0522bd0dd8fcb6b84c69686251

commit a3a6dc24f34a1b0522bd0dd8fcb6b84c69686251
Author:     Gleb Smirnoff <glebius@FreeBSD.org>
AuthorDate: 2025-02-01 01:02:32 +0000
Commit:     Gleb Smirnoff <glebius@FreeBSD.org>
CommitDate: 2025-02-01 09:00:26 +0000

    rpcsec_tls/client: use netlink RPC client to talk to rpc.tlsclntd(8)
    
    In addition to using a netlink(4) socket instead of a unix(4) socket to
    pass rpctlscd_* RPC commands to rpc.tlsclntd(8), the logic of passing the
    file descriptor is also changed.  Since clnt_nl provides us with all the
    needed parallelism and waits on individual RPC xids, we don't need to
    store the socket in a global variable and serialize all communication
    with the daemon.  Instead, we augment the rpctlscd_connect arguments with
    a cookie that is basically a pointer to the socket that we keep for the
    daemon.  While sleeping on the request, we keep a database of all sockets
    associated with rpctlscd_connect RPCs that we have sent to userland.  The
    daemon then sends us back the cookie in the
    rpctls_syscall(RPCTLS_SYSC_CLSOCKET) argument and we find and return
    the socket for this upcall.
    
    This will be optimized further in a separate commit that will also touch
    clnt_vc.c and other krpc files.  This commit is intentionally kept minimal,
    so that it is easier to understand what changes with netlink(4) transport.
    
    Reviewed by:            rmacklem
    Differential Revision:  https://reviews.freebsd.org/D48559
---
 sys/rpc/rpcsec_tls.h             |   2 -
 sys/rpc/rpcsec_tls/rpctls_impl.c | 257 +++++++++++++++------------------------
 sys/rpc/rpcsec_tls/rpctlscd.x    |   1 +
 3 files changed, 96 insertions(+), 164 deletions(-)

diff --git a/sys/rpc/rpcsec_tls.h b/sys/rpc/rpcsec_tls.h
index e3eed64863a1..4c4fa20dde31 100644
--- a/sys/rpc/rpcsec_tls.h
+++ b/sys/rpc/rpcsec_tls.h
@@ -29,9 +29,7 @@
 #define	_RPC_RPCSEC_TLS_H_
 
 /* Operation values for rpctls syscall. */
-#define	RPCTLS_SYSC_CLSETPATH	1
 #define	RPCTLS_SYSC_CLSOCKET	2
-#define	RPCTLS_SYSC_CLSHUTDOWN	3
 #define	RPCTLS_SYSC_SRVSETPATH	4
 #define	RPCTLS_SYSC_SRVSOCKET	5
 #define	RPCTLS_SYSC_SRVSHUTDOWN	6
diff --git a/sys/rpc/rpcsec_tls/rpctls_impl.c b/sys/rpc/rpcsec_tls/rpctls_impl.c
index 64111eed62c0..327233f63f1d 100644
--- a/sys/rpc/rpcsec_tls/rpctls_impl.c
+++ b/sys/rpc/rpcsec_tls/rpctls_impl.c
@@ -49,6 +49,7 @@
 #include <sys/syscallsubr.h>
 #include <sys/sysent.h>
 #include <sys/sysproto.h>
+#include <sys/tree.h>
 
 #include <net/vnet.h>
 
@@ -71,16 +72,14 @@ static struct syscall_helper_data rpctls_syscalls[] = {
 	SYSCALL_INIT_LAST
 };
 
-static CLIENT		*rpctls_connect_handle;
 static struct mtx	rpctls_connect_lock;
-static struct socket	*rpctls_connect_so = NULL;
-static CLIENT		*rpctls_connect_cl = NULL;
 static struct mtx	rpctls_server_lock;
 static struct opaque_auth rpctls_null_verf;
 
 KRPC_VNET_DECLARE(uint64_t, svc_vc_tls_handshake_success);
 KRPC_VNET_DECLARE(uint64_t, svc_vc_tls_handshake_failed);
 
+KRPC_VNET_DEFINE_STATIC(CLIENT *, rpctls_connect_handle);
 KRPC_VNET_DEFINE_STATIC(CLIENT **, rpctls_server_handle);
 KRPC_VNET_DEFINE_STATIC(struct socket *, rpctls_server_so) = NULL;
 KRPC_VNET_DEFINE_STATIC(SVCXPRT *, rpctls_server_xprt) = NULL;
@@ -88,7 +87,20 @@ KRPC_VNET_DEFINE_STATIC(bool, rpctls_srv_newdaemon) = false;
 KRPC_VNET_DEFINE_STATIC(int, rpctls_srv_prevproc) = 0;
 KRPC_VNET_DEFINE_STATIC(bool *, rpctls_server_busy);
 
-static CLIENT		*rpctls_connect_client(void);
+struct upsock {
+	RB_ENTRY(upsock) tree;
+	struct socket *so;
+	CLIENT *cl;
+};
+
+static RB_HEAD(upsock_t, upsock) upcall_sockets;
+static intptr_t
+upsock_compare(const struct upsock *a, const struct upsock *b)
+{
+	return ((intptr_t)((uintptr_t)a->so/2 - (uintptr_t)b->so/2));
+}
+RB_GENERATE_STATIC(upsock_t, upsock, tree, upsock_compare);
+
 static CLIENT		*rpctls_server_client(int procpos);
 static enum clnt_stat	rpctls_server(SVCXPRT *xprt, struct socket *so,
 			    uint32_t *flags, uint64_t *sslp,
@@ -98,6 +110,7 @@ static enum clnt_stat	rpctls_server(SVCXPRT *xprt, struct socket *so,
 static void
 rpctls_vnetinit(const void *unused __unused)
 {
+	CLIENT *cl;
 	int i;
 
 	KRPC_VNET(rpctls_server_handle) = malloc(sizeof(CLIENT *) *
@@ -106,6 +119,22 @@ rpctls_vnetinit(const void *unused __unused)
 	    RPCTLS_SRV_MAXNPROCS, M_RPC, M_WAITOK | M_ZERO);
 	for (i = 0; i < RPCTLS_SRV_MAXNPROCS; i++)
 		KRPC_VNET(rpctls_server_busy)[i] = false;
+
+	cl = client_nl_create("tlsclnt", RPCTLSCD, RPCTLSCDVERS);
+	KASSERT(cl, ("%s: netlink client already exist", __func__));
+	/*
+	 * Set the try_count to 1 so that no retries of the RPC occur.  Since
+	 * it is an upcall to a local daemon, requests should not be lost and
+	 * doing one of these RPCs multiple times is not correct.  If the
+	 * server is not working correctly, the daemon can get stuck in
+	 * SSL_connect() trying to read data from the socket during the upcall.
+	 * Set a timeout (currently 15sec) and assume the daemon is hung when
+	 * the timeout occurs.
+	 */
+	clnt_control(cl, CLSET_RETRIES, &(int){1});
+	clnt_control(cl, CLSET_TIMEOUT, &(struct timeval){.tv_sec = 15});
+	clnt_control(cl, CLSET_WAITCHAN, "tlsclntd");
+	KRPC_VNET(rpctls_connect_handle) = cl;
 }
 VNET_SYSINIT(rpctls_vnetinit, SI_SUB_VNET_DONE, SI_ORDER_ANY,
     rpctls_vnetinit, NULL);
@@ -147,10 +176,11 @@ sys_rpctls_syscall(struct thread *td, struct rpctls_syscall_args *uap)
         struct netconfig *nconf;
 	struct file *fp;
 	struct socket *so;
+	struct upsock *ups;
 	SVCXPRT *xprt;
 	char path[MAXPATHLEN];
 	int fd = -1, error, i, try_count;
-	CLIENT *cl, *oldcl[RPCTLS_SRV_MAXNPROCS], *concl;
+	CLIENT *cl, *oldcl[RPCTLS_SRV_MAXNPROCS];
 	uint64_t ssl[3];
 	struct timeval timeo;
 #ifdef KERN_TLS
@@ -186,65 +216,6 @@ sys_rpctls_syscall(struct thread *td, struct rpctls_syscall_args *uap)
 			}
 		}
 		break;
-	case RPCTLS_SYSC_CLSETPATH:
-		if (jailed(curthread->td_ucred))
-			error = EPERM;
-		if (error == 0)
-			error = copyinstr(uap->path, path, sizeof(path), NULL);
-		if (error == 0) {
-			error = ENXIO;
-#ifdef KERN_TLS
-			if (rpctls_getinfo(&maxlen, false, false))
-				error = 0;
-#endif
-		}
-		if (error == 0 && (strlen(path) + 1 > sizeof(sun.sun_path) ||
-		    strlen(path) == 0))
-			error = EINVAL;
-	
-		cl = NULL;
-		if (error == 0) {
-			sun.sun_family = AF_LOCAL;
-			strlcpy(sun.sun_path, path, sizeof(sun.sun_path));
-			sun.sun_len = SUN_LEN(&sun);
-			
-			nconf = getnetconfigent("local");
-			cl = clnt_reconnect_create(nconf,
-			    (struct sockaddr *)&sun, RPCTLSCD, RPCTLSCDVERS,
-			    RPC_MAXDATASIZE, RPC_MAXDATASIZE);
-			/*
-			 * The number of retries defaults to INT_MAX, which
-			 * effectively means an infinite, uninterruptable loop. 
-			 * Set the try_count to 1 so that no retries of the
-			 * RPC occur.  Since it is an upcall to a local daemon,
-			 * requests should not be lost and doing one of these
-			 * RPCs multiple times is not correct.
-			 * If the server is not working correctly, the
-			 * daemon can get stuck in SSL_connect() trying
-			 * to read data from the socket during the upcall.
-			 * Set a timeout (currently 15sec) and assume the
-			 * daemon is hung when the timeout occurs.
-			 */
-			if (cl != NULL) {
-				try_count = 1;
-				CLNT_CONTROL(cl, CLSET_RETRIES, &try_count);
-				timeo.tv_sec = 15;
-				timeo.tv_usec = 0;
-				CLNT_CONTROL(cl, CLSET_TIMEOUT, &timeo);
-			} else
-				error = EINVAL;
-		}
-	
-		mtx_lock(&rpctls_connect_lock);
-		oldcl[0] = rpctls_connect_handle;
-		rpctls_connect_handle = cl;
-		mtx_unlock(&rpctls_connect_lock);
-	
-		if (oldcl[0] != NULL) {
-			CLNT_CLOSE(oldcl[0]);
-			CLNT_RELEASE(oldcl[0]);
-		}
-		break;
 	case RPCTLS_SYSC_SRVSETPATH:
 		if (jailed(curthread->td_ucred) &&
 		    !prison_check_nfsd(curthread->td_ucred))
@@ -327,17 +298,6 @@ sys_rpctls_syscall(struct thread *td, struct rpctls_syscall_args *uap)
 			}
 		}
 		break;
-	case RPCTLS_SYSC_CLSHUTDOWN:
-		mtx_lock(&rpctls_connect_lock);
-		oldcl[0] = rpctls_connect_handle;
-		rpctls_connect_handle = NULL;
-		mtx_unlock(&rpctls_connect_lock);
-	
-		if (oldcl[0] != NULL) {
-			CLNT_CLOSE(oldcl[0]);
-			CLNT_RELEASE(oldcl[0]);
-		}
-		break;
 	case RPCTLS_SYSC_SRVSHUTDOWN:
 		mtx_lock(&rpctls_server_lock);
 		for (i = 0; i < RPCTLS_SRV_MAXNPROCS; i++) {
@@ -356,30 +316,33 @@ sys_rpctls_syscall(struct thread *td, struct rpctls_syscall_args *uap)
 		break;
 	case RPCTLS_SYSC_CLSOCKET:
 		mtx_lock(&rpctls_connect_lock);
-		so = rpctls_connect_so;
-		rpctls_connect_so = NULL;
-		concl = rpctls_connect_cl;
-		rpctls_connect_cl = NULL;
+		ups = RB_FIND(upsock_t, &upcall_sockets,
+		    &(struct upsock){
+		    .so = __DECONST(struct socket *, uap->path) });
+		if (__predict_true(ups != NULL))
+			RB_REMOVE(upsock_t, &upcall_sockets, ups);
 		mtx_unlock(&rpctls_connect_lock);
-		if (so != NULL) {
-			error = falloc(td, &fp, &fd, 0);
-			if (error == 0) {
-				/*
-				 * Set ssl refno so that clnt_vc_destroy() will
-				 * not close the socket and will leave that for
-				 * the daemon to do.
-				 */
-				soref(so);
-				ssl[0] = ssl[1] = 0;
-				ssl[2] = RPCTLS_REFNO_HANDSHAKE;
-				CLNT_CONTROL(concl, CLSET_TLS, ssl);
-				finit(fp, FREAD | FWRITE, DTYPE_SOCKET, so,
-				    &socketops);
-				fdrop(fp, td);	/* Drop fp reference. */
-				td->td_retval[0] = fd;
-			}
-		} else
+		if (ups == NULL) {
+			printf("%s: socket lookup failed\n", __func__);
 			error = EPERM;
+			break;
+		}
+		error = falloc(td, &fp, &fd, 0);
+		if (error == 0) {
+			/*
+			 * Set ssl refno so that clnt_vc_destroy() will
+			 * not close the socket and will leave that for
+			 * the daemon to do.
+			 */
+			soref(ups->so);
+			ssl[0] = ssl[1] = 0;
+			ssl[2] = RPCTLS_REFNO_HANDSHAKE;
+			CLNT_CONTROL(ups->cl, CLSET_TLS, ssl);
+			finit(fp, FREAD | FWRITE, DTYPE_SOCKET, ups->so,
+			    &socketops);
+			fdrop(fp, td);	/* Drop fp reference. */
+			td->td_retval[0] = fd;
+		}
 		break;
 	case RPCTLS_SYSC_SRVSOCKET:
 		mtx_lock(&rpctls_server_lock);
@@ -416,23 +379,6 @@ sys_rpctls_syscall(struct thread *td, struct rpctls_syscall_args *uap)
 	return (error);
 }
 
-/*
- * Acquire the rpctls_connect_handle and return it with a reference count,
- * if it is available.
- */
-static CLIENT *
-rpctls_connect_client(void)
-{
-	CLIENT *cl;
-
-	mtx_lock(&rpctls_connect_lock);
-	cl = rpctls_connect_handle;
-	if (cl != NULL)
-		CLNT_ACQUIRE(cl);
-	mtx_unlock(&rpctls_connect_lock);
-	return (cl);
-}
-
 /*
  * Acquire the rpctls_server_handle and return it with a reference count,
  * if it is available.
@@ -462,13 +408,12 @@ rpctls_connect(CLIENT *newclient, char *certname, struct socket *so,
 	struct rpc_callextra ext;
 	struct timeval utimeout;
 	enum clnt_stat stat;
-	CLIENT *cl;
+	struct upsock ups = {
+		.so = so,
+		.cl = newclient,
+	};
+	CLIENT *cl = KRPC_VNET(rpctls_connect_handle);
 	int val;
-	static bool rpctls_connect_busy = false;
-
-	cl = rpctls_connect_client();
-	if (cl == NULL)
-		return (RPC_AUTHERROR);
 
 	/* First, do the AUTH_TLS NULL RPC. */
 	memset(&ext, 0, sizeof(ext));
@@ -483,14 +428,8 @@ rpctls_connect(CLIENT *newclient, char *certname, struct socket *so,
 	if (stat != RPC_SUCCESS)
 		return (RPC_SYSTEMERROR);
 
-	/* Serialize the connect upcalls. */
 	mtx_lock(&rpctls_connect_lock);
-	while (rpctls_connect_busy)
-		msleep(&rpctls_connect_busy, &rpctls_connect_lock, PVFS,
-		    "rtlscn", 0);
-	rpctls_connect_busy = true;
-	rpctls_connect_so = so;
-	rpctls_connect_cl = newclient;
+	RB_INSERT(upsock_t, &upcall_sockets, &ups);
 	mtx_unlock(&rpctls_connect_lock);
 
 	/* Temporarily block reception during the handshake upcall. */
@@ -503,37 +442,47 @@ rpctls_connect(CLIENT *newclient, char *certname, struct socket *so,
 		arg.certname.certname_val = certname;
 	} else
 		arg.certname.certname_len = 0;
+	arg.socookie = (uintptr_t)so;
 	stat = rpctlscd_connect_1(&arg, &res, cl);
 	if (stat == RPC_SUCCESS) {
+#ifdef INVARIANTS
+		MPASS((RB_FIND(upsock_t, &upcall_sockets, &ups) == NULL));
+#endif
 		*reterr = res.reterr;
 		if (res.reterr == 0) {
 			*sslp++ = res.sec;
 			*sslp++ = res.usec;
 			*sslp = res.ssl;
 		}
-	} else if (stat == RPC_TIMEDOUT) {
-		/*
-		 * Do a shutdown on the socket, since the daemon is probably
-		 * stuck in SSL_connect() trying to read the socket.
-		 * Do not soclose() the socket, since the daemon will close()
-		 * the socket after SSL_connect() returns an error.
-		 */
-		soshutdown(so, SHUT_RD);
+	} else {
+		mtx_lock(&rpctls_connect_lock);
+		if (RB_FIND(upsock_t, &upcall_sockets, &ups)) {
+			struct upsock *removed __diagused;
+
+			removed = RB_REMOVE(upsock_t, &upcall_sockets, &ups);
+			mtx_unlock(&rpctls_connect_lock);
+			MPASS(removed == &ups);
+			/*
+			 * Do a shutdown on the socket, since the daemon is
+			 * probably stuck in SSL_connect() trying to read the
+			 * socket.  Do not soclose() the socket, since the
+			 * daemon will close() the socket after SSL_connect()
+			 * returns an error.
+			 */
+			soshutdown(so, SHUT_RD);
+		} else {
+			/*
+			 * The daemon has taken the socket from the tree, but
+			 * failed to do the handshake.
+			 */
+			mtx_unlock(&rpctls_connect_lock);
+		}
 	}
-	CLNT_RELEASE(cl);
 
 	/* Unblock reception. */
 	val = 0;
 	CLNT_CONTROL(newclient, CLSET_BLOCKRCV, &val);
 
-	/* Once the upcall is done, the daemon is done with the fp and so. */
-	mtx_lock(&rpctls_connect_lock);
-	rpctls_connect_so = NULL;
-	rpctls_connect_cl = NULL;
-	rpctls_connect_busy = false;
-	wakeup(&rpctls_connect_busy);
-	mtx_unlock(&rpctls_connect_lock);
-
 	return (stat);
 }
 
@@ -545,20 +494,13 @@ rpctls_cl_handlerecord(uint64_t sec, uint64_t usec, uint64_t ssl,
 	struct rpctlscd_handlerecord_arg arg;
 	struct rpctlscd_handlerecord_res res;
 	enum clnt_stat stat;
-	CLIENT *cl;
-
-	cl = rpctls_connect_client();
-	if (cl == NULL) {
-		*reterr = RPCTLSERR_NOSSL;
-		return (RPC_SUCCESS);
-	}
+	CLIENT *cl = KRPC_VNET(rpctls_connect_handle);
 
 	/* Do the handlerecord upcall. */
 	arg.sec = sec;
 	arg.usec = usec;
 	arg.ssl = ssl;
 	stat = rpctlscd_handlerecord_1(&arg, &res, cl);
-	CLNT_RELEASE(cl);
 	if (stat == RPC_SUCCESS)
 		*reterr = res.reterr;
 	return (stat);
@@ -598,20 +540,13 @@ rpctls_cl_disconnect(uint64_t sec, uint64_t usec, uint64_t ssl,
 	struct rpctlscd_disconnect_arg arg;
 	struct rpctlscd_disconnect_res res;
 	enum clnt_stat stat;
-	CLIENT *cl;
-
-	cl = rpctls_connect_client();
-	if (cl == NULL) {
-		*reterr = RPCTLSERR_NOSSL;
-		return (RPC_SUCCESS);
-	}
+	CLIENT *cl = KRPC_VNET(rpctls_connect_handle);
 
 	/* Do the disconnect upcall. */
 	arg.sec = sec;
 	arg.usec = usec;
 	arg.ssl = ssl;
 	stat = rpctlscd_disconnect_1(&arg, &res, cl);
-	CLNT_RELEASE(cl);
 	if (stat == RPC_SUCCESS)
 		*reterr = res.reterr;
 	return (stat);
@@ -854,8 +789,6 @@ rpctls_getinfo(u_int *maxlenp, bool rpctlscd_run, bool rpctlssd_run)
 	    &maxlen, &siz, NULL, 0, NULL, 0);
 	if (error != 0)
 		return (false);
-	if (rpctlscd_run && rpctls_connect_handle == NULL)
-		return (false);
 	KRPC_CURVNET_SET_QUIET(KRPC_TD_TO_VNET(curthread));
 	if (rpctlssd_run && KRPC_VNET(rpctls_server_handle)[0] == NULL) {
 		KRPC_CURVNET_RESTORE();
diff --git a/sys/rpc/rpcsec_tls/rpctlscd.x b/sys/rpc/rpcsec_tls/rpctlscd.x
index 9a7fb46181c5..5c323445079b 100644
--- a/sys/rpc/rpcsec_tls/rpctlscd.x
+++ b/sys/rpc/rpcsec_tls/rpctlscd.x
@@ -29,6 +29,7 @@
 
 
 struct rpctlscd_connect_arg {
+	uint64_t socookie;
 	char certname<>;
 };