git: 989453da0589 - main - sctp: cleanup the SCTP_MAXSEG socket option.

From: Michael Tuexen <tuexen_at_FreeBSD.org>
Date: Mon, 27 Dec 2021 22:43:40 UTC
The branch main has been updated by tuexen:

URL: https://cgit.FreeBSD.org/src/commit/?id=989453da0589b8dc5c1948fd81f986a37ea385eb

commit 989453da0589b8dc5c1948fd81f986a37ea385eb
Author:     Michael Tuexen <tuexen@FreeBSD.org>
AuthorDate: 2021-12-27 22:40:31 +0000
Commit:     Michael Tuexen <tuexen@FreeBSD.org>
CommitDate: 2021-12-27 22:40:31 +0000

    sctp: cleanup the SCTP_MAXSEG socket option.
    
    This patch makes the handling of the SCTP_MAXSEG socket option
    compliant with RFC 6458 (SCTP socket API) and fixes an issue
    found by syzkaller.
    
    Reported by:    syzbot+a2791b89ab99121e3333@syzkaller.appspotmail.com
    MFC after:      3 days
---
 sys/netinet/sctp_constants.h |  2 -
 sys/netinet/sctp_output.c    | 93 ++++++++++++++++++++++++--------------------
 sys/netinet/sctp_output.h    |  2 +-
 sys/netinet/sctp_pcb.c       |  2 +-
 sys/netinet/sctp_usrreq.c    | 37 +++---------------
 sys/netinet/sctputil.c       |  2 +-
 6 files changed, 59 insertions(+), 79 deletions(-)

diff --git a/sys/netinet/sctp_constants.h b/sys/netinet/sctp_constants.h
index 1ff3f3918ef6..66f2cca5ab6d 100644
--- a/sys/netinet/sctp_constants.h
+++ b/sys/netinet/sctp_constants.h
@@ -673,8 +673,6 @@ __FBSDID("$FreeBSD$");
 /* amount peer is obligated to have in rwnd or I will abort */
 #define SCTP_MIN_RWND	1500
 
-#define SCTP_DEFAULT_MAXSEGMENT 65535
-
 #define SCTP_CHUNK_BUFFER_SIZE	512
 #define SCTP_PARAM_BUFFER_SIZE	512
 
diff --git a/sys/netinet/sctp_output.c b/sys/netinet/sctp_output.c
index 65767f9f73a9..f6597bc6cbdc 100644
--- a/sys/netinet/sctp_output.c
+++ b/sys/netinet/sctp_output.c
@@ -6217,43 +6217,48 @@ sctp_prune_prsctp(struct sctp_tcb *stcb,
 	}			/* if enabled in asoc */
 }
 
-int
-sctp_get_frag_point(struct sctp_tcb *stcb,
-    struct sctp_association *asoc)
+/* Max DATA/I-DATA chunk payload for the association (honors SCTP_MAXSEG). */
+uint32_t
+sctp_get_frag_point(struct sctp_tcb *stcb)
 {
-	int siz, ovh;
+	struct sctp_association *asoc;
+	uint32_t frag_point, overhead;
 
-	/*
-	 * For endpoints that have both v6 and v4 addresses we must reserve
-	 * room for the ipv6 header, for those that are only dealing with V4
-	 * we use a larger frag point.
-	 */
+	asoc = &stcb->asoc;
+	/* Consider IP header and SCTP common header. */
 	if (stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_BOUND_V6) {
-		ovh = SCTP_MIN_OVERHEAD;
+		overhead = SCTP_MIN_OVERHEAD;
 	} else {
-		ovh = SCTP_MIN_V4_OVERHEAD;
+		overhead = SCTP_MIN_V4_OVERHEAD;
 	}
-	ovh += SCTP_DATA_CHUNK_OVERHEAD(stcb);
-	if (stcb->asoc.sctp_frag_point > asoc->smallest_mtu)
-		siz = asoc->smallest_mtu - ovh;
-	else
-		siz = (stcb->asoc.sctp_frag_point - ovh);
-	/*
-	 * if (siz > (MCLBYTES-sizeof(struct sctp_data_chunk))) {
-	 */
-	/* A data chunk MUST fit in a cluster */
-	/* siz = (MCLBYTES - sizeof(struct sctp_data_chunk)); */
-	/* } */
-
-	/* adjust for an AUTH chunk if DATA requires auth */
-	if (sctp_auth_is_required_chunk(SCTP_DATA, stcb->asoc.peer_auth_chunks))
-		siz -= sctp_get_auth_chunk_len(stcb->asoc.peer_hmac_id);
+	/* Consider DATA/IDATA chunk header and AUTH header, if needed. */
+	if (asoc->idata_supported) {
+		overhead += sizeof(struct sctp_idata_chunk);
+		if (sctp_auth_is_required_chunk(SCTP_IDATA, asoc->peer_auth_chunks)) {
+			overhead += sctp_get_auth_chunk_len(asoc->peer_hmac_id);
+		}
+	} else {
+		overhead += sizeof(struct sctp_data_chunk);
+		if (sctp_auth_is_required_chunk(SCTP_DATA, asoc->peer_auth_chunks)) {
+			overhead += sctp_get_auth_chunk_len(asoc->peer_hmac_id);
+		}
+	}
+	/* Must be checked before the padding adjustment below. */
+	KASSERT(overhead % 4 == 0,
+	    ("overhead (%u) not a multiple of 4", overhead));
+	/* Pad so that smallest_mtu - overhead is a multiple of 4. */
+	overhead += asoc->smallest_mtu % 4;
+	KASSERT(asoc->smallest_mtu > overhead,
+	    ("Association MTU (%u) too small for overhead (%u)",
+	    asoc->smallest_mtu, overhead));
 
-	if (siz % 4) {
-		/* make it an even word boundary please */
-		siz -= (siz % 4);
+	frag_point = asoc->smallest_mtu - overhead;
+	/* Honor the SCTP_MAXSEG socket option; 0 means no limit. */
+	if ((asoc->sctp_frag_point > 0) &&
+	    (asoc->sctp_frag_point < frag_point)) {
+		frag_point = asoc->sctp_frag_point;
 	}
-	return (siz);
+	return (frag_point);
 }
 
 static void
@@ -6571,7 +6576,8 @@ sctp_med_chunk_output(struct sctp_inpcb *inp,
     int *num_out,
     int *reason_code,
     int control_only, int from_where,
-    struct timeval *now, int *now_filled, int frag_point, int so_locked);
+    struct timeval *now, int *now_filled,
+    uint32_t frag_point, int so_locked);
 
 static void
 sctp_sendall_iterator(struct sctp_inpcb *inp, struct sctp_tcb *stcb, void *ptr,
@@ -6740,13 +6746,13 @@ sctp_sendall_iterator(struct sctp_inpcb *inp, struct sctp_tcb *stcb, void *ptr,
 	if (do_chunk_output)
 		sctp_chunk_output(inp, stcb, SCTP_OUTPUT_FROM_USR_SEND, SCTP_SO_NOT_LOCKED);
 	else if (added_control) {
-		int num_out, reason, now_filled = 0;
 		struct timeval now;
-		int frag_point;
+		int num_out, reason, now_filled = 0;
 
-		frag_point = sctp_get_frag_point(stcb, &stcb->asoc);
 		(void)sctp_med_chunk_output(inp, stcb, &stcb->asoc, &num_out,
-		    &reason, 1, 1, &now, &now_filled, frag_point, SCTP_SO_NOT_LOCKED);
+		    &reason, 1, 1, &now, &now_filled,
+		    sctp_get_frag_point(stcb),
+		    SCTP_SO_NOT_LOCKED);
 	}
 no_chunk_output:
 	if (ret) {
@@ -7674,8 +7680,9 @@ out_of:
 }
 
 static void
-sctp_fill_outqueue(struct sctp_tcb *stcb, struct sctp_nets *net, int frag_point,
-    int eeor_mode, int *quit_now, int so_locked)
+sctp_fill_outqueue(struct sctp_tcb *stcb, struct sctp_nets *net,
+    uint32_t frag_point, int eeor_mode, int *quit_now,
+    int so_locked)
 {
 	struct sctp_association *asoc;
 	struct sctp_stream_out *strq;
@@ -7794,7 +7801,8 @@ sctp_med_chunk_output(struct sctp_inpcb *inp,
     int *num_out,
     int *reason_code,
     int control_only, int from_where,
-    struct timeval *now, int *now_filled, int frag_point, int so_locked)
+    struct timeval *now, int *now_filled,
+    uint32_t frag_point, int so_locked)
 {
 	/**
 	 * Ok this is the generic chunk service queue. we must do the
@@ -9975,7 +9983,7 @@ sctp_chunk_output(struct sctp_inpcb *inp,
 	struct timeval now;
 	int now_filled = 0;
 	int nagle_on;
-	int frag_point = sctp_get_frag_point(stcb, &stcb->asoc);
+	uint32_t frag_point = sctp_get_frag_point(stcb);
 	int un_sent = 0;
 	int fr_done;
 	unsigned int tot_frs = 0;
@@ -13663,16 +13671,17 @@ skip_out_eof:
 		}
 		sctp_chunk_output(inp, stcb, SCTP_OUTPUT_FROM_USR_SEND, SCTP_SO_LOCKED);
 	} else if (some_on_control) {
-		int num_out, reason, frag_point;
+		int num_out, reason;
 
 		/* Here we do control only */
 		if (hold_tcblock == 0) {
 			hold_tcblock = 1;
 			SCTP_TCB_LOCK(stcb);
 		}
-		frag_point = sctp_get_frag_point(stcb, &stcb->asoc);
 		(void)sctp_med_chunk_output(inp, stcb, &stcb->asoc, &num_out,
-		    &reason, 1, 1, &now, &now_filled, frag_point, SCTP_SO_LOCKED);
+		    &reason, 1, 1, &now, &now_filled,
+		    sctp_get_frag_point(stcb),
+		    SCTP_SO_LOCKED);
 	}
 	NET_EPOCH_EXIT(et);
 	SCTPDBG(SCTP_DEBUG_OUTPUT1, "USR Send complete qo:%d prw:%d unsent:%d tf:%d cooq:%d toqs:%d err:%d\n",
diff --git a/sys/netinet/sctp_output.h b/sys/netinet/sctp_output.h
index 7d2cdc4071d8..e6ee80c41f1a 100644
--- a/sys/netinet/sctp_output.h
+++ b/sys/netinet/sctp_output.h
@@ -117,7 +117,7 @@ void sctp_send_asconf(struct sctp_tcb *, struct sctp_nets *, int addr_locked);
 
 void sctp_send_asconf_ack(struct sctp_tcb *);
 
-int sctp_get_frag_point(struct sctp_tcb *, struct sctp_association *);
+uint32_t sctp_get_frag_point(struct sctp_tcb *);
 
 void sctp_toss_old_cookies(struct sctp_tcb *, struct sctp_association *);
 
diff --git a/sys/netinet/sctp_pcb.c b/sys/netinet/sctp_pcb.c
index b4a742c11629..7ad651ec377f 100644
--- a/sys/netinet/sctp_pcb.c
+++ b/sys/netinet/sctp_pcb.c
@@ -2422,7 +2422,7 @@ sctp_inpcb_alloc(struct socket *so, uint32_t vrf_id)
 #endif
 	inp->sctp_associd_counter = 1;
 	inp->partial_delivery_point = SCTP_SB_LIMIT_RCV(so) >> SCTP_PARTIAL_DELIVERY_SHIFT;
-	inp->sctp_frag_point = SCTP_DEFAULT_MAXSEGMENT;
+	inp->sctp_frag_point = 0;
 	inp->max_cwnd = 0;
 	inp->sctp_cmt_on_off = SCTP_BASE_SYSCTL(sctp_cmt_on_off);
 	inp->ecn_supported = (uint8_t)SCTP_BASE_SYSCTL(sctp_ecn_enable);
diff --git a/sys/netinet/sctp_usrreq.c b/sys/netinet/sctp_usrreq.c
index f218950feef9..bb84d3b7083f 100644
--- a/sys/netinet/sctp_usrreq.c
+++ b/sys/netinet/sctp_usrreq.c
@@ -2032,13 +2032,12 @@ flags_out:
 	case SCTP_MAXSEG:
 		{
 			struct sctp_assoc_value *av;
-			int ovh;
 
 			SCTP_CHECK_AND_CAST(av, optval, struct sctp_assoc_value, *optsize);
 			SCTP_FIND_STCB(inp, stcb, av->assoc_id);
 
 			if (stcb) {
-				av->assoc_value = sctp_get_frag_point(stcb, &stcb->asoc);
+				av->assoc_value = stcb->asoc.sctp_frag_point;
 				SCTP_TCB_UNLOCK(stcb);
 			} else {
 				if ((inp->sctp_flags & SCTP_PCB_FLAGS_TCPTYPE) ||
@@ -2046,15 +2045,7 @@ flags_out:
 				    ((inp->sctp_flags & SCTP_PCB_FLAGS_UDPTYPE) &&
 				    (av->assoc_id == SCTP_FUTURE_ASSOC))) {
 					SCTP_INP_RLOCK(inp);
-					if (inp->sctp_flags & SCTP_PCB_FLAGS_BOUND_V6) {
-						ovh = SCTP_MED_OVERHEAD;
-					} else {
-						ovh = SCTP_MED_V4_OVERHEAD;
-					}
-					if (inp->sctp_frag_point >= SCTP_DEFAULT_MAXSEGMENT)
-						av->assoc_value = 0;
-					else
-						av->assoc_value = inp->sctp_frag_point - ovh;
+					av->assoc_value = inp->sctp_frag_point;
 					SCTP_INP_RUNLOCK(inp);
 				} else {
 					SCTP_LTRACE_ERR_RET(inp, NULL, NULL, SCTP_FROM_SCTP_USRREQ, EINVAL);
@@ -2623,7 +2614,7 @@ flags_out:
 			    stcb->asoc.cnt_on_all_streams);
 			sstat->sstat_instrms = stcb->asoc.streamincnt;
 			sstat->sstat_outstrms = stcb->asoc.streamoutcnt;
-			sstat->sstat_fragmentation_point = sctp_get_frag_point(stcb, &stcb->asoc);
+			sstat->sstat_fragmentation_point = sctp_get_frag_point(stcb);
 			net = stcb->asoc.primary_destination;
 			if (net != NULL) {
 				memcpy(&sstat->sstat_primary.spinfo_address,
@@ -4977,22 +4968,12 @@ sctp_setopt(struct socket *so, int optname, void *optval, size_t optsize,
 	case SCTP_MAXSEG:
 		{
 			struct sctp_assoc_value *av;
-			int ovh;
 
 			SCTP_CHECK_AND_CAST(av, optval, struct sctp_assoc_value, optsize);
 			SCTP_FIND_STCB(inp, stcb, av->assoc_id);
 
-			if (inp->sctp_flags & SCTP_PCB_FLAGS_BOUND_V6) {
-				ovh = SCTP_MED_OVERHEAD;
-			} else {
-				ovh = SCTP_MED_V4_OVERHEAD;
-			}
 			if (stcb) {
-				if (av->assoc_value) {
-					stcb->asoc.sctp_frag_point = (av->assoc_value + ovh);
-				} else {
-					stcb->asoc.sctp_frag_point = SCTP_DEFAULT_MAXSEGMENT;
-				}
+				stcb->asoc.sctp_frag_point = av->assoc_value;
 				SCTP_TCB_UNLOCK(stcb);
 			} else {
 				if ((inp->sctp_flags & SCTP_PCB_FLAGS_TCPTYPE) ||
@@ -5000,15 +4981,7 @@ sctp_setopt(struct socket *so, int optname, void *optval, size_t optsize,
 				    ((inp->sctp_flags & SCTP_PCB_FLAGS_UDPTYPE) &&
 				    (av->assoc_id == SCTP_FUTURE_ASSOC))) {
 					SCTP_INP_WLOCK(inp);
-					/*
-					 * FIXME MT: I think this is not in
-					 * tune with the API ID
-					 */
-					if (av->assoc_value) {
-						inp->sctp_frag_point = (av->assoc_value + ovh);
-					} else {
-						inp->sctp_frag_point = SCTP_DEFAULT_MAXSEGMENT;
-					}
+					inp->sctp_frag_point = av->assoc_value;
 					SCTP_INP_WUNLOCK(inp);
 				} else {
 					SCTP_LTRACE_ERR_RET(inp, NULL, NULL, SCTP_FROM_SCTP_USRREQ, EINVAL);
diff --git a/sys/netinet/sctputil.c b/sys/netinet/sctputil.c
index 6c58ad47f274..df3768ca2a35 100644
--- a/sys/netinet/sctputil.c
+++ b/sys/netinet/sctputil.c
@@ -1248,7 +1248,7 @@ sctp_init_asoc(struct sctp_inpcb *inp, struct sctp_tcb *stcb,
 	asoc->my_rwnd = max(SCTP_SB_LIMIT_RCV(inp->sctp_socket), SCTP_MINIMAL_RWND);
 	asoc->peers_rwnd = SCTP_SB_LIMIT_RCV(inp->sctp_socket);
 
-	asoc->smallest_mtu = inp->sctp_frag_point;
+	asoc->smallest_mtu = 0;
 	asoc->minrto = inp->sctp_ep.sctp_minrto;
 	asoc->maxrto = inp->sctp_ep.sctp_maxrto;