git: 65074f6f3115 - main - pf: fix double ruleset evaluation for fragments sent to dummynet

From: Kristof Provost <kp_at_FreeBSD.org>
Date: Thu, 03 Oct 2024 11:58:08 UTC
The branch main has been updated by kp:

URL: https://cgit.FreeBSD.org/src/commit/?id=65074f6f3115a066e5beda31787df626aed03696

commit 65074f6f3115a066e5beda31787df626aed03696
Author:     Kajetan Staszkiewicz <vegeta@tuxpowered.net>
AuthorDate: 2024-10-03 08:28:57 +0000
Commit:     Kristof Provost <kp@FreeBSD.org>
CommitDate: 2024-10-03 11:49:57 +0000

    pf: fix double ruleset evaluation for fragments sent to dummynet
    
    The function `pf_setup_pdesc()` handles ruleset evaluation for non-reassembled
    packets. Having it called before `pf_mtag` is checked for flags
    `PF_MTAG_FLAG_ROUTE_TO` and `PF_MTAG_FLAG_DUMMYNET` will cause loops for
    fragmented packets if reassembly is disabled.
    
    Move `pd` zeroing and `pf_mtag` extraction from `pf_setup_pdesc()` to a separate
    function `pf_init_pdesc()` and change the order of function calls: first
    call `pf_init_pdesc()`, then check if the currently processed packet has been
    reinjected from dummynet, finally call `pf_setup_pdesc()`.
    
    Add support for sending UDP packets, optionally fragmented, to `pft_ping.py`,
    and fix the broken IPv6 fragmentation and reassembly in the test tools.
    
    Reviewed by:    kp
    Differential Revision:  https://reviews.freebsd.org/D46880
---
 sys/netpfil/pf/pf.c                        |  62 +++++++++--------
 tests/sys/netpfil/common/pft_ping.py       | 107 +++++++++++++++++++++++++++--
 tests/sys/netpfil/common/sniffer.py        |  13 ++--
 tests/sys/netpfil/pf/fragmentation_pass.sh |  43 ++++++++++++
 4 files changed, 188 insertions(+), 37 deletions(-)

diff --git a/sys/netpfil/pf/pf.c b/sys/netpfil/pf/pf.c
index bbed285ab9f8..7edf65ae4a09 100644
--- a/sys/netpfil/pf/pf.c
+++ b/sys/netpfil/pf/pf.c
@@ -8460,6 +8460,13 @@ pf_dummynet_route(struct pf_pdesc *pd, struct pf_kstate *s,
 	return (0);
 }
 
+static void
+pf_init_pdesc(struct pf_pdesc *pd, struct mbuf *m)
+{
+	memset(pd, 0, sizeof(*pd));	/* zero the entire descriptor */
+	pd->pf_mtag = pf_find_mtag(m);	/* may be NULL; callers check before use */
+}
+
 static int
 pf_setup_pdesc(sa_family_t af, int dir, struct pf_pdesc *pd, struct mbuf **m0,
     u_short *action, u_short *reason, struct pfi_kkif *kif, struct pf_krule **a,
@@ -8469,21 +8476,18 @@ pf_setup_pdesc(sa_family_t af, int dir, struct pf_pdesc *pd, struct mbuf **m0,
 {
 	struct mbuf *m = *m0;
 
-	memset(pd, 0, sizeof(*pd));
+	pd->af = af;
 	pd->dir = dir;
 
 	TAILQ_INIT(&pd->sctp_multihome_jobs);
 	if (default_actions != NULL)
 		memcpy(&pd->act, default_actions, sizeof(pd->act));
-	pd->pf_mtag = pf_find_mtag(m);
 
 	if (pd->pf_mtag && pd->pf_mtag->dnpipe) {
 		pd->act.dnpipe = pd->pf_mtag->dnpipe;
 		pd->act.flags = pd->pf_mtag->dnflags;
 	}
 
-	pd->af = af;
-
 	switch (af) {
 #ifdef INET
 	case AF_INET: {
@@ -8918,30 +8922,7 @@ pf_test(sa_family_t af, int dir, int pflags, struct ifnet *ifp, struct mbuf **m0
 			return (PF_DROP);
 	}
 
-	if (pf_setup_pdesc(af, dir, &pd, m0, &action, &reason, kif, &a, &r,
-		&s, &ruleset, &off, &hdrlen, inp, default_actions) == -1) {
-		if (action != PF_PASS)
-			pd.act.log |= PF_LOG_FORCE;
-		goto done;
-	}
-	m = *m0;
-
-	switch (af) {
-#ifdef INET
-	case AF_INET:
-		h = mtod(m, struct ip *);
-		ttl = h->ip_ttl;
-		break;
-#endif
-#ifdef INET6
-	case AF_INET6:
-		h6 = mtod(m, struct ip6_hdr *);
-		ttl = h6->ip6_hlim;
-		break;
-#endif
-	default:
-		panic("Unknown af %d", af);
-	}
+	pf_init_pdesc(&pd, m);
 
 	if (pd.pf_mtag != NULL && (pd.pf_mtag->flags & PF_MTAG_FLAG_ROUTE_TO)) {
 		pd.pf_mtag->flags &= ~PF_MTAG_FLAG_ROUTE_TO;
@@ -8974,6 +8955,31 @@ pf_test(sa_family_t af, int dir, int pflags, struct ifnet *ifp, struct mbuf **m0
 		return (PF_PASS);
 	}
 
+	if (pf_setup_pdesc(af, dir, &pd, m0, &action, &reason, kif, &a, &r,
+		&s, &ruleset, &off, &hdrlen, inp, default_actions) == -1) {
+		if (action != PF_PASS)
+			pd.act.log |= PF_LOG_FORCE;
+		goto done;
+	}
+	m = *m0;
+
+	switch (af) {
+#ifdef INET
+	case AF_INET:
+		h = mtod(m, struct ip *);
+		ttl = h->ip_ttl;
+		break;
+#endif
+#ifdef INET6
+	case AF_INET6:
+		h6 = mtod(m, struct ip6_hdr *);
+		ttl = h6->ip6_hlim;
+		break;
+#endif
+	default:
+		panic("Unknown af %d", af);
+	}
+
 	if (__predict_false(ip_divert_ptr != NULL) &&
 	    ((mtag = m_tag_locate(m, MTAG_PF_DIVERT, 0, NULL)) != NULL)) {
 		struct pf_divert_mtag *dt = (struct pf_divert_mtag *)(mtag+1);
diff --git a/tests/sys/netpfil/common/pft_ping.py b/tests/sys/netpfil/common/pft_ping.py
index d8aafc884265..1caa26abe5f6 100644
--- a/tests/sys/netpfil/common/pft_ping.py
+++ b/tests/sys/netpfil/common/pft_ping.py
@@ -103,7 +103,7 @@ def send_icmp_ping(send_params):
         ip6 = prepare_ipv6(send_params)
         icmp = sp.ICMPv6EchoRequest(data=sp.raw(build_payload(send_length)))
         if send_frag_length:
-            for packet in sp.fragment(ip6 / icmp, fragsize=send_frag_length):
+            for packet in sp.fragment6(ip6 / icmp, fragSize=send_frag_length):
                 packets.append(ether / packet)
         else:
             packets.append(ether / ip6 / icmp)
@@ -141,6 +141,39 @@ def send_tcp_syn(send_params):
     sp.sendp(req, iface=send_params['sendif'], verbose=False)
 
 
+def send_udp(send_params):
+    LOGGER.debug('Sending UDP ping')
+    packets = []
+    send_length = send_params['length']
+    send_frag_length = send_params['frag_length']
+    ether = sp.Ether()
+    if ':' in send_params['dst_address']:
+        ip6 = prepare_ipv6(send_params)
+        udp = sp.UDP(
+            sport=send_params.get('sport'), dport=send_params.get('dport'),
+        )
+        raw = sp.Raw(load=build_payload(send_length))
+        if send_frag_length:
+            for packet in sp.fragment6(ip6 / udp / raw, fragSize=send_frag_length):
+                packets.append(ether / packet)
+        else:
+            packets.append(ether / ip6 / udp / raw)
+    else:
+        ip = prepare_ipv4(send_params)
+        udp = sp.UDP(
+            sport=send_params.get('sport'), dport=send_params.get('dport'),
+        )
+        raw = sp.Raw(load=build_payload(send_length))
+        if send_frag_length:
+            for packet in sp.fragment(ip / udp / raw, fragsize=send_frag_length):
+                packets.append(ether / packet)
+        else:
+            packets.append(ether / ip / udp / raw)
+
+    for packet in packets:
+        sp.sendp(packet, iface=send_params['sendif'], verbose=False)
+
+
 def send_ping(ping_type, send_params):
     if ping_type == 'icmp':
         send_icmp_ping(send_params)
@@ -149,8 +182,10 @@ def send_ping(ping_type, send_params):
         ping_type == 'tcp3way'
     ):
         send_tcp_syn(send_params)
+    elif ping_type == 'udp':
+        send_udp(send_params)
     else:
-        raise Exception('Unspported ping type')
+        raise Exception('Unsupported ping type')
 
 
 def check_ipv4(expect_params, packet):
@@ -345,6 +380,30 @@ def check_tcp(expect_params, packet):
     return True
 
 
+def check_udp(expect_params, packet):
+    expect_length = expect_params['length']
+    udp = packet.getlayer(sp.UDP)
+    if not udp:
+        LOGGER.debug('Packet is not UDP!')
+        return False
+    raw = packet.getlayer(sp.Raw)
+    if not raw:
+        LOGGER.debug('Packet contains no payload!')
+        return False
+    if raw.load != build_payload(expect_length):
+        LOGGER.debug(f'Payload magic does not match len {len(raw.load)} vs {expect_length}!')
+        return False
+    orig_chksum = udp.chksum
+    udp.chksum = None  # NOTE(review): udp is a layer of packet, so this clears the caller's checksum
+    newpacket = sp.Ether(sp.raw(packet[sp.Ether]))  # rebuild so scapy recomputes the cleared checksum
+    new_chksum = newpacket[sp.UDP].chksum
+    if new_chksum and orig_chksum != new_chksum:  # a falsy recomputed checksum skips the comparison
+        LOGGER.debug(f'Wrong UDP checksum {orig_chksum}, expected {new_chksum}!')
+        return False
+
+    return True
+
+
 def check_tcp_syn_request_4(expect_params, packet):
     if not check_ipv4(expect_params, packet):
         return False
@@ -391,6 +450,14 @@ def check_tcp_3way_4(args, packet):
     return False
 
 
+def check_udp_request_4(expect_params, packet):
+    if not check_ipv4(expect_params, packet):
+        return False
+    if not check_udp(expect_params, packet):
+        return False
+    return True
+
+
 def check_tcp_syn_request_6(expect_params, packet):
     if not check_ipv6(expect_params, packet):
         return False
@@ -437,6 +504,13 @@ def check_tcp_3way_6(args, packet):
     return False
 
 
+def check_udp_request_6(expect_params, packet):
+    if not check_ipv6(expect_params, packet):
+        return False
+    if not check_udp(expect_params, packet):
+        return False
+    return True
+
 def check_tcp_syn_request(args, packet):
     expect_params = args['expect_params']
     src_address = expect_params.get('src_address')
@@ -481,6 +555,21 @@ def check_tcp_3way(args, packet):
         return check_tcp_3way_4(args, packet)
 
 
+def check_udp_request(args, packet):
+    expect_params = args['expect_params']
+    src_address = expect_params.get('src_address')
+    dst_address = expect_params.get('dst_address')
+    if not (src_address or dst_address):
+        raise Exception('Source or destination address must be given to match the udp request!')
+    if (
+            (src_address and ':' in src_address) or
+            (dst_address and ':' in dst_address)
+    ):
+        return check_udp_request_6(expect_params, packet)
+    else:
+        return check_udp_request_4(expect_params, packet)
+
+
 def setup_sniffer(
         recvif, ping_type, sniff_type, expect_params, defrag, send_params,
 ):
@@ -494,8 +583,10 @@ def setup_sniffer(
         checkfn = check_tcp_syn_reply
     elif ping_type == 'tcp3way' and sniff_type == 'reply':
         checkfn = check_tcp_3way
+    elif ping_type == 'udp' and sniff_type == 'request':
+        checkfn = check_udp_request
     else:
-        raise Exception('Unspported ping or sniff type')
+        raise Exception('Unsupported ping and sniff type combination')
 
     return Sniffer(
         {'send_params': send_params, 'expect_params': expect_params},
@@ -513,7 +604,7 @@ def parse_args():
     parser.add_argument('--to', required=True,
         help='The destination IP address for the ping request')
     parser.add_argument('--ping-type',
-        choices=('icmp', 'tcpsyn', 'tcp3way'),
+        choices=('icmp', 'tcpsyn', 'tcp3way', 'udp'),
         help='Type of ping: ICMP (default) or TCP SYN or 3-way TCP handshake',
         default='icmp')
     parser.add_argument('--fromaddr',
@@ -612,7 +703,13 @@ def main():
     sniffers = []
 
     if send_params['frag_length']:
-        defrag = True
+        if (
+            (send_params['src_address'] and ':' in send_params['src_address']) or
+            (send_params['dst_address'] and ':' in send_params['dst_address'])
+        ):
+            defrag = 'IPv6'
+        else:
+            defrag = 'IPv4'
     else:
         defrag = False
 
diff --git a/tests/sys/netpfil/common/sniffer.py b/tests/sys/netpfil/common/sniffer.py
index 14305a37278c..583b27d34ca6 100644
--- a/tests/sys/netpfil/common/sniffer.py
+++ b/tests/sys/netpfil/common/sniffer.py
@@ -56,14 +56,19 @@ class Sniffer(threading.Thread):
 
 	def run(self):
 		self.packets = []
-		if self._defrag:
-			# With fragment reassembly we can't stop the sniffer after catching
-			# the good packets, as those have not been reassembled. We must
-			#  wait for sniffer to finish and check returned packets instead.
+		# With fragment reassembly we can't stop the sniffer after catching
+		# the good packets, as those have not been reassembled. We must
+		#  wait for sniffer to finish and check returned packets instead.
+		if self._defrag == 'IPv4':
 			self.packets = sp.sniff(session=sp.IPSession, iface=self._recvif,
 				timeout=self._timeout, started_callback=self._startedCb)
 			for p in self.packets:
 				self._checkPacket(p)
+		elif self._defrag == 'IPv6':
+			self.packets = sp.sniff(session=sp.DefaultSession, iface=self._recvif,
+				timeout=self._timeout, started_callback=self._startedCb)
+			for p in sp.defragment6(self.packets):
+				self._checkPacket(p)
 		else:
 			self.packets = sp.sniff(iface=self._recvif,
 				stop_filter=self._checkPacket, timeout=self._timeout,
diff --git a/tests/sys/netpfil/pf/fragmentation_pass.sh b/tests/sys/netpfil/pf/fragmentation_pass.sh
index d505accba5f2..99d2c827b239 100644
--- a/tests/sys/netpfil/pf/fragmentation_pass.sh
+++ b/tests/sys/netpfil/pf/fragmentation_pass.sh
@@ -553,6 +553,48 @@ dummynet_nat_cleanup()
 	pft_cleanup
 }
 
+atf_test_case "dummynet_fragmented" "cleanup"
+dummynet_fragmented_head()
+{
+	atf_set descr 'Test dummynet on fragmented traffic with reassembly disabled'
+	atf_set require.user root
+}
+
+dummynet_fragmented_body()
+{
+	pft_init
+	dummynet_init
+
+	# No test for IPv6. IPv6 fragment reassembly can't be disabled.
+	setup_router_dummy_ipv4
+
+	jexec router dnctl pipe 1 config delay 1
+
+	pft_set_rules router \
+		"set reassemble no" \
+		"block" \
+		"pass inet6 proto icmp6 icmp6-type { neighbrsol, neighbradv }" \
+		"pass in  on ${epair_tester}b inet  proto udp dnpipe (1, 1)" \
+		"pass out on ${epair_server}a inet  proto udp"
+
+	ping_dummy_check_request exit:0 --ping-type=udp --send-length=10000 --send-frag-length=1280
+
+	rules=$(mktemp) || exit 1
+	jexec router pfctl -qvsr > $rules
+
+	# Each of the 8 fragments (10008 bytes of UDP split into 1280-byte
+	# chunks, 10168 bytes with IPv4 headers) must hit the rule exactly once
+	# and create no state: without reassembly pf keeps no state for fragments.
+	grep -A2 "pass in on ${epair_tester}b inet proto udp all keep state dnpipe(1, 1)" $rules |
+		grep -qE 'Packets: 8\s+Bytes: 10168\s+States: 0\s+' ||
+		atf_fail "Fragmented packets not counted correctly"
+}
+
+dummynet_fragmented_cleanup()
+{
+	pft_cleanup
+}
+
 atf_init_test_cases()
 {
 	atf_add_test_case "too_many_fragments"
@@ -566,4 +608,5 @@ atf_init_test_cases()
 	atf_add_test_case "reassemble_slowpath"
 	atf_add_test_case "dummynet"
 	atf_add_test_case "dummynet_nat"
+	atf_add_test_case "dummynet_fragmented"
 }