git: a8280123e4c6 - main - KTLS: Add a new recrypt operation to the software backend.

From: John Baldwin <jhb_at_FreeBSD.org>
Date: Fri, 22 Apr 2022 22:55:23 UTC
The branch main has been updated by jhb:

URL: https://cgit.FreeBSD.org/src/commit/?id=a8280123e4c681f505917fdf126ea9091addab62

commit a8280123e4c681f505917fdf126ea9091addab62
Author:     John Baldwin <jhb@FreeBSD.org>
AuthorDate: 2022-04-22 22:52:50 +0000
Commit:     John Baldwin <jhb@FreeBSD.org>
CommitDate: 2022-04-22 22:52:50 +0000

    KTLS: Add a new recrypt operation to the software backend.
    
    When using NIC TLS RX, packets that are dropped and retransmitted are
    not decrypted by the NIC but are passed along as-is.  As a result, a
    received TLS record might contain a mix of encrypted and decrypted
    data.  If this occurs, the already-decrypted data needs to be
    re-encrypted so that the resulting record can then be decrypted
    normally.
    
    Implement this for sessions using AES-GCM with TLS 1.2 or TLS
    1.3.  For the recrypt operation, allocate a temporary buffer and
    encrypt the payload portion of the TLS record with AES-CTR, using an
    initial counter block built from the AES-GCM nonce with the block
    counter set to 2 (the counter GCM uses for the first payload block).
    Then fix up the original mbuf chain by copying the results from the
    temporary buffer back into any mbufs that contain decrypted data.
    
    Once it has been recrypted, the mbuf chain can then be decrypted via
    the normal software decryption path.
    
    Co-authored by: Hans Petter Selasky <hselasky@FreeBSD.org>
    Reviewed by:    hselasky
    Sponsored by:   Netflix
    Differential Revision:  https://reviews.freebsd.org/D35012
---
 sys/opencrypto/ktls.h     |   3 +
 sys/opencrypto/ktls_ocf.c | 176 +++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 178 insertions(+), 1 deletion(-)

diff --git a/sys/opencrypto/ktls.h b/sys/opencrypto/ktls.h
index b97f589fecb4..503864f87ccc 100644
--- a/sys/opencrypto/ktls.h
+++ b/sys/opencrypto/ktls.h
@@ -55,5 +55,8 @@ int ktls_ocf_encrypt(struct ktls_ocf_encrypt_state *state,
 int ktls_ocf_decrypt(struct ktls_session *tls,
     const struct tls_record_layer *hdr, struct mbuf *m, uint64_t seqno,
     int *trailer_len);
+int ktls_ocf_recrypt(struct ktls_session *tls,
+    const struct tls_record_layer *hdr, struct mbuf *m, uint64_t seqno);
+bool ktls_ocf_recrypt_supported(struct ktls_session *tls);
 
 #endif	/* !__OPENCRYPTO_KTLS_H__ */
diff --git a/sys/opencrypto/ktls_ocf.c b/sys/opencrypto/ktls_ocf.c
index 3b330bf7061c..6347ca459646 100644
--- a/sys/opencrypto/ktls_ocf.c
+++ b/sys/opencrypto/ktls_ocf.c
@@ -44,6 +44,7 @@ __FBSDID("$FreeBSD$");
 #include <vm/vm.h>
 #include <vm/pmap.h>
 #include <vm/vm_param.h>
+#include <netinet/in.h>
 #include <opencrypto/cryptodev.h>
 #include <opencrypto/ktls.h>
 
@@ -53,6 +54,11 @@ struct ktls_ocf_sw {
 	    struct ktls_session *tls, struct mbuf *m,
 	    struct iovec *outiov, int outiovcnt);
 
+	/* Re-encrypt a received TLS record that is partially decrypted. */
+	int	(*recrypt)(struct ktls_session *tls,
+	    const struct tls_record_layer *hdr, struct mbuf *m,
+	    uint64_t seqno);
+
 	/* Decrypt a received TLS record. */
 	int	(*decrypt)(struct ktls_session *tls,
 	    const struct tls_record_layer *hdr, struct mbuf *m,
@@ -63,6 +69,7 @@ struct ktls_ocf_session {
 	const struct ktls_ocf_sw *sw;
 	crypto_session_t sid;
 	crypto_session_t mac_sid;
+	crypto_session_t recrypt_sid;
 	struct mtx lock;
 	int mac_len;
 	bool implicit_iv;
@@ -109,6 +116,11 @@ SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls12_gcm_encrypts,
     CTLFLAG_RD, &ocf_tls12_gcm_encrypts,
     "Total number of OCF TLS 1.2 GCM encryption operations");
 
+static COUNTER_U64_DEFINE_EARLY(ocf_tls12_gcm_recrypts);
+SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls12_gcm_recrypts,
+    CTLFLAG_RD, &ocf_tls12_gcm_recrypts,
+    "Total number of OCF TLS 1.2 GCM re-encryption operations");
+
 static COUNTER_U64_DEFINE_EARLY(ocf_tls12_chacha20_decrypts);
 SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls12_chacha20_decrypts,
     CTLFLAG_RD, &ocf_tls12_chacha20_decrypts,
@@ -129,6 +141,11 @@ SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls13_gcm_encrypts,
     CTLFLAG_RD, &ocf_tls13_gcm_encrypts,
     "Total number of OCF TLS 1.3 GCM encryption operations");
 
+static COUNTER_U64_DEFINE_EARLY(ocf_tls13_gcm_recrypts);
+SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls13_gcm_recrypts,
+    CTLFLAG_RD, &ocf_tls13_gcm_recrypts,
+    "Total number of OCF TLS 1.3 GCM re-encryption operations");
+
 static COUNTER_U64_DEFINE_EARLY(ocf_tls13_chacha20_decrypts);
 SYSCTL_COUNTER_U64(_kern_ipc_tls_stats_ocf, OID_AUTO, tls13_chacha20_decrypts,
     CTLFLAG_RD, &ocf_tls13_chacha20_decrypts,
@@ -549,8 +566,120 @@ ktls_ocf_tls12_aead_decrypt(struct ktls_session *tls,
 	return (error);
 }

+/*
+ * Copy re-encrypted payload bytes from 'buf' back into the mbuf chain 'm'.
+ *
+ * 'buf' holds 'len' bytes that correspond to the chain data starting
+ * 'skip' bytes into 'm'.  Only mbufs flagged M_DECRYPTED are overwritten;
+ * the others still hold the original ciphertext and are left as-is.
+ */
+static void
+ktls_ocf_recrypt_fixup(struct mbuf *m, u_int skip, u_int len, char *buf)
+{
+	const char *src = buf;
+	u_int todo;
+
+	/* Find the mbuf holding the byte at offset 'skip'. */
+	while (skip >= m->m_len) {
+		skip -= m->m_len;
+		m = m->m_next;
+		KASSERT(m != NULL,
+		    ("%s: offset > size of mbuf chain", __func__));
+	}
+
+	while (len > 0) {
+		KASSERT(m != NULL,
+		    ("%s: length > size of mbuf chain", __func__));
+		todo = m->m_len - skip;
+		if (todo > len)
+			todo = len;
+
+		/*
+		 * 'src' advances for every mbuf so it stays in step with
+		 * the chain, but only decrypted mbufs are overwritten.
+		 */
+		if (m->m_flags & M_DECRYPTED)
+			memcpy(mtod(m, char *) + skip, src, todo);
+		src += todo;
+		len -= todo;
+		skip = 0;
+		m = m->m_next;
+	}
+}
+
+/*
+ * Re-encrypt a received TLS 1.2 AES-GCM record whose payload may already
+ * be partially decrypted (e.g. by NIC TLS RX).
+ *
+ * AES-GCM encrypts the payload with AES-CTR using a counter block made of
+ * the 12-byte GCM nonce followed by a 32-bit big-endian block counter
+ * that starts at 2.  For TLS 1.2 the nonce is the implicit salt from the
+ * session IV followed by the 8-byte explicit nonce that follows the
+ * record header.  Running the separate AES-ICM session over the payload
+ * with that same counter block regenerates the original ciphertext for
+ * any bytes that were decrypted.  The trailing tag is not recrypted.
+ *
+ * The output is written to a temporary buffer rather than in place so
+ * that mbufs which still hold ciphertext are not modified;
+ * ktls_ocf_recrypt_fixup() copies back only the M_DECRYPTED portions.
+ *
+ * 'seqno' is unused: TLS 1.2 records carry an explicit nonce.
+ *
+ * NOTE(review): payload_len is unsigned and not range-checked here; this
+ * assumes the caller has already rejected records shorter than the
+ * explicit nonce plus tag.
+ */
+static int
+ktls_ocf_tls12_aead_recrypt(struct ktls_session *tls,
+    const struct tls_record_layer *hdr, struct mbuf *m,
+    uint64_t seqno)
+{
+	struct cryptop crp;
+	struct ktls_ocf_session *os;
+	char *buf;
+	u_int payload_len;
+	int error;
+
+	os = tls->ocf_session;
+
+	crypto_initreq(&crp, os->recrypt_sid);
+
+	KASSERT(tls->params.cipher_algorithm == CRYPTO_AES_NIST_GCM_16,
+	    ("%s: only AES-GCM is supported", __func__));
+
+	/* Set up the IV: salt, explicit nonce, block counter 2. */
+	memcpy(crp.crp_iv, tls->params.iv, TLS_AEAD_GCM_LEN);
+	memcpy(crp.crp_iv + TLS_AEAD_GCM_LEN, hdr + 1, sizeof(uint64_t));
+	be32enc(crp.crp_iv + AES_GCM_IV_LEN, 2);
+
+	/* Record body minus the explicit nonce and the tag. */
+	payload_len = ntohs(hdr->tls_length) -
+	    (AES_GMAC_HASH_LEN + sizeof(uint64_t));
+	crp.crp_op = CRYPTO_OP_ENCRYPT;
+	crp.crp_flags = CRYPTO_F_CBIMM | CRYPTO_F_IV_SEPARATE;
+	crypto_use_mbuf(&crp, m);
+	crp.crp_payload_start = tls->params.tls_hlen;
+	crp.crp_payload_length = payload_len;
+
+	buf = malloc(payload_len, M_KTLS_OCF, M_WAITOK);
+	crypto_use_output_buf(&crp, buf, payload_len);
+
+	counter_u64_add(ocf_tls12_gcm_recrypts, 1);
+	error = ktls_ocf_dispatch(os, &crp);
+
+	crypto_destroyreq(&crp);
+
+	if (error == 0)
+		ktls_ocf_recrypt_fixup(m, tls->params.tls_hlen, payload_len,
+		    buf);
+
+	free(buf, M_KTLS_OCF);
+	return (error);
+}
+
 static const struct ktls_ocf_sw ktls_ocf_tls12_aead_sw = {
 	.encrypt = ktls_ocf_tls12_aead_encrypt,
+	.recrypt = ktls_ocf_tls12_aead_recrypt,
 	.decrypt = ktls_ocf_tls12_aead_decrypt,
 };

@@ -681,8 +774,80 @@ ktls_ocf_tls13_aead_decrypt(struct ktls_session *tls,
 	return (error);
 }

+/*
+ * Re-encrypt a received TLS 1.3 AES-GCM record whose payload may already
+ * be partially decrypted (e.g. by NIC TLS RX).
+ *
+ * TLS 1.3 records have no explicit nonce: the GCM nonce is the session IV
+ * with the big-endian record sequence number XORed into its last 8 bytes.
+ * As with TLS 1.2, GCM encrypts the payload with AES-CTR starting at
+ * block counter 2, so running the separate AES-ICM session over the
+ * payload with that counter block regenerates the original ciphertext
+ * for any bytes that were decrypted.  The recrypted range is everything
+ * between the record header and the trailing tag.
+ *
+ * The output is written to a temporary buffer so that mbufs which still
+ * hold ciphertext are not modified; ktls_ocf_recrypt_fixup() copies back
+ * only the M_DECRYPTED portions.
+ *
+ * NOTE(review): assumes tls->params.iv_len == AES_GCM_IV_LEN and that the
+ * caller has already rejected records shorter than the tag, since
+ * payload_len is unsigned and not range-checked here.
+ */
+static int
+ktls_ocf_tls13_aead_recrypt(struct ktls_session *tls,
+    const struct tls_record_layer *hdr, struct mbuf *m,
+    uint64_t seqno)
+{
+	struct cryptop crp;
+	struct ktls_ocf_session *os;
+	char *buf;
+	u_int payload_len;
+	int error;
+
+	os = tls->ocf_session;
+
+	crypto_initreq(&crp, os->recrypt_sid);
+
+	KASSERT(tls->params.cipher_algorithm == CRYPTO_AES_NIST_GCM_16,
+	    ("%s: only AES-GCM is supported", __func__));
+
+	/*
+	 * Set up the IV: the session IV with the sequence number XORed
+	 * into its last 8 bytes, followed by block counter 2.
+	 */
+	memcpy(crp.crp_iv, tls->params.iv, tls->params.iv_len);
+	*(uint64_t *)(crp.crp_iv + (AES_GCM_IV_LEN - sizeof(uint64_t))) ^=
+	    htobe64(seqno);
+	be32enc(crp.crp_iv + AES_GCM_IV_LEN, 2);
+
+	/* Record body minus the tag. */
+	payload_len = ntohs(hdr->tls_length) - AES_GMAC_HASH_LEN;
+	crp.crp_op = CRYPTO_OP_ENCRYPT;
+	crp.crp_flags = CRYPTO_F_CBIMM | CRYPTO_F_IV_SEPARATE;
+	crypto_use_mbuf(&crp, m);
+	crp.crp_payload_start = tls->params.tls_hlen;
+	crp.crp_payload_length = payload_len;
+
+	buf = malloc(payload_len, M_KTLS_OCF, M_WAITOK);
+	crypto_use_output_buf(&crp, buf, payload_len);
+
+	counter_u64_add(ocf_tls13_gcm_recrypts, 1);
+	error = ktls_ocf_dispatch(os, &crp);
+
+	crypto_destroyreq(&crp);
+
+	if (error == 0)
+		ktls_ocf_recrypt_fixup(m, tls->params.tls_hlen, payload_len,
+		    buf);
+
+	free(buf, M_KTLS_OCF);
+	return (error);
+}
+
 static const struct ktls_ocf_sw ktls_ocf_tls13_aead_sw = {
 	.encrypt = ktls_ocf_tls13_aead_encrypt,
+	.recrypt = ktls_ocf_tls13_aead_recrypt,
 	.decrypt = ktls_ocf_tls13_aead_decrypt,
 };

@@ -694,6 +834,7 @@ ktls_ocf_free(struct ktls_session *tls)
 	os = tls->ocf_session;
 	crypto_freesession(os->sid);
 	crypto_freesession(os->mac_sid);
+	crypto_freesession(os->recrypt_sid);
 	mtx_destroy(&os->lock);
 	zfree(os, M_KTLS_OCF);
 }
@@ -701,7 +842,7 @@ ktls_ocf_free(struct ktls_session *tls)
 int
 ktls_ocf_try(struct socket *so, struct ktls_session *tls, int direction)
 {
-	struct crypto_session_params csp, mac_csp;
+	struct crypto_session_params csp, mac_csp, recrypt_csp;
 	struct ktls_ocf_session *os;
 	int error, mac_len;
 
@@ -709,6 +850,8 @@ ktls_ocf_try(struct socket *so, struct ktls_session *tls, int direction)
 	memset(&mac_csp, 0, sizeof(mac_csp));
 	mac_csp.csp_mode = CSP_MODE_NONE;
 	mac_len = 0;
+	memset(&recrypt_csp, 0, sizeof(recrypt_csp));
+	recrypt_csp.csp_mode = CSP_MODE_NONE;	/* no recrypt by default */

 	switch (tls->params.cipher_algorithm) {
 	case CRYPTO_AES_NIST_GCM_16:
@@ -732,6 +875,13 @@ ktls_ocf_try(struct socket *so, struct ktls_session *tls, int direction)
 		csp.csp_cipher_key = tls->params.cipher_key;
 		csp.csp_cipher_klen = tls->params.cipher_key_len;
 		csp.csp_ivlen = AES_GCM_IV_LEN;
+
+		recrypt_csp.csp_flags |= CSP_F_SEPARATE_OUTPUT;
+		recrypt_csp.csp_mode = CSP_MODE_CIPHER;
+		recrypt_csp.csp_cipher_alg = CRYPTO_AES_ICM;
+		recrypt_csp.csp_cipher_key = tls->params.cipher_key;
+		recrypt_csp.csp_cipher_klen = tls->params.cipher_key_len;
+		recrypt_csp.csp_ivlen = AES_BLOCK_LEN;
 		break;
 	case CRYPTO_AES_CBC:
 		switch (tls->params.cipher_key_len) {
@@ -826,6 +976,16 @@ ktls_ocf_try(struct socket *so, struct ktls_session *tls, int direction)
 		os->mac_len = mac_len;
 	}
 
+	if (recrypt_csp.csp_mode != CSP_MODE_NONE) {
+		error = crypto_newsession(&os->recrypt_sid, &recrypt_csp,
+		    CRYPTO_FLAG_HARDWARE | CRYPTO_FLAG_SOFTWARE);
+		if (error) {
+			crypto_freesession(os->sid);
+			free(os, M_KTLS_OCF);
+			return (error);
+		}
+	}
+
 	mtx_init(&os->lock, "ktls_ocf", NULL, MTX_DEF);
 	tls->ocf_session = os;
 	if (tls->params.cipher_algorithm == CRYPTO_AES_NIST_GCM_16 ||
@@ -870,3 +1030,29 @@ ktls_ocf_decrypt(struct ktls_session *tls, const struct tls_record_layer *hdr,
 {
 	return (tls->ocf_session->sw->decrypt(tls, hdr, m, seqno, trailer_len));
 }
+
+/*
+ * Re-encrypt the already-decrypted portions of a received TLS record in
+ * the mbuf chain 'm' so that the whole record can then be decrypted via
+ * the normal software path.  Callers must first check that
+ * ktls_ocf_recrypt_supported() returns true for this session.
+ */
+int
+ktls_ocf_recrypt(struct ktls_session *tls, const struct tls_record_layer *hdr,
+    struct mbuf *m, uint64_t seqno)
+{
+	KASSERT(ktls_ocf_recrypt_supported(tls),
+	    ("%s: recrypt not supported", __func__));
+	return (tls->ocf_session->sw->recrypt(tls, hdr, m, seqno));
+}
+
+/*
+ * Report whether recrypt is available for this session: the backend must
+ * provide a recrypt handler and a recrypt crypto session must exist.
+ */
+bool
+ktls_ocf_recrypt_supported(struct ktls_session *tls)
+{
+	return (tls->ocf_session->sw->recrypt != NULL &&
+	    tls->ocf_session->recrypt_sid != NULL);
+}