git: 6ffd4aeba5b9 - main - pf tests: add a test for max-src-conn

From: Kristof Provost <kp_at_FreeBSD.org>
Date: Sat, 28 Sep 2024 16:55:21 UTC
The branch main has been updated by kp:

URL: https://cgit.FreeBSD.org/src/commit/?id=6ffd4aeba5b997c4c538711d259200da193d949d

commit 6ffd4aeba5b997c4c538711d259200da193d949d
Author:     Kajetan Staszkiewicz <vegeta@tuxpowered.net>
AuthorDate: 2024-09-28 13:37:45 +0000
Commit:     Kristof Provost <kp@FreeBSD.org>
CommitDate: 2024-09-28 16:54:50 +0000

    pf tests: add a test for max-src-conn
    
    Switch tests using pft_ping.py to inetd. Netcat can only accept a single
    connection, but we need multiple parallel connections to test max-src-conn.
    Use the discard service and modify pft_ping.py to use the proper port number.
    
    Implement a 3-way handshake test in pft_ping.py. Make send_params accessible
    to the sniffer, because answers to SYN+ACK packets should be sent with the
    same parameters as the original SYN packet.
    
    Finally add a test for max-src-conn and overload.
    
    Reviewed by:    kp
    Differential Revision:  https://reviews.freebsd.org/D46798
---
 tests/sys/netpfil/common/pft_ping.py | 185 +++++++++++++++++++++++++++--------
 tests/sys/netpfil/pf/src_track.sh    |  62 ++++++++++++
 tests/sys/netpfil/pf/utils.subr      |  10 +-
 3 files changed, 210 insertions(+), 47 deletions(-)

diff --git a/tests/sys/netpfil/common/pft_ping.py b/tests/sys/netpfil/common/pft_ping.py
index befe757406be..d8aafc884265 100644
--- a/tests/sys/netpfil/common/pft_ping.py
+++ b/tests/sys/netpfil/common/pft_ping.py
@@ -49,8 +49,17 @@ def build_payload(l):
     return ret
 
 
-def prepare_ipv6(dst_address, send_params):
+def clean_params(params):
+    # Return a copy of params without src/dst address and IP flags keys.
+    ret = copy(params)
+    ret.pop('src_address', None)
+    ret.pop('dst_address', None)
+    ret.pop('flags', None)
+    return ret
+
+def prepare_ipv6(send_params):
     src_address = send_params.get('src_address')
+    dst_address = send_params.get('dst_address')
     hlim = send_params.get('hlim')
     tc = send_params.get('tc')
     ip6 = sp.IPv6(dst=dst_address)
@@ -63,8 +72,9 @@ def prepare_ipv6(dst_address, send_params):
     return ip6
 
 
-def prepare_ipv4(dst_address, send_params):
+def prepare_ipv4(send_params):
     src_address = send_params.get('src_address')
+    dst_address = send_params.get('dst_address')
     flags = send_params.get('flags')
     tos = send_params.get('tc')
     ttl = send_params.get('hlim')
@@ -84,13 +94,13 @@ def prepare_ipv4(dst_address, send_params):
     return ip
 
 
-def send_icmp_ping(dst_address, sendif, send_params):
+def send_icmp_ping(send_params):
     send_length = send_params['length']
     send_frag_length = send_params['frag_length']
     packets = []
     ether = sp.Ether()
-    if ':' in dst_address:
-        ip6 = prepare_ipv6(dst_address, send_params)
+    if ':' in send_params['dst_address']:
+        ip6 = prepare_ipv6(send_params)
         icmp = sp.ICMPv6EchoRequest(data=sp.raw(build_payload(send_length)))
         if send_frag_length:
             for packet in sp.fragment(ip6 / icmp, fragsize=send_frag_length):
@@ -99,7 +109,7 @@ def send_icmp_ping(dst_address, sendif, send_params):
             packets.append(ether / ip6 / icmp)
 
     else:
-        ip = prepare_ipv4(dst_address, send_params)
+        ip = prepare_ipv4(send_params)
         icmp = sp.ICMP(type='echo-request')
         raw = sp.raw(build_payload(send_length))
         if send_frag_length:
@@ -108,10 +118,10 @@ def send_icmp_ping(dst_address, sendif, send_params):
         else:
             packets.append(ether / ip / icmp / raw)
     for packet in packets:
-        sp.sendp(packet, sendif, verbose=False)
+        sp.sendp(packet, iface=send_params['sendif'], verbose=False)
 
 
-def send_tcp_syn(dst_address, sendif, send_params):
+def send_tcp_syn(send_params):
     tcpopt_unaligned = send_params.get('tcpopt_unaligned')
     seq = send_params.get('seq')
     mss = send_params.get('mss')
@@ -119,23 +129,26 @@ def send_tcp_syn(dst_address, sendif, send_params):
     opts=[('Timestamp', (1, 1)), ('MSS', mss if mss else 1280)]
     if tcpopt_unaligned:
         opts = [('NOP', 0 )] + opts
-    if ':' in dst_address:
-        ip = prepare_ipv6(dst_address, send_params)
+    if ':' in send_params['dst_address']:
+        ip = prepare_ipv6(send_params)
     else:
-        ip = prepare_ipv4(dst_address, send_params)
+        ip = prepare_ipv4(send_params)
     tcp = sp.TCP(
         sport=send_params.get('sport'), dport=send_params.get('dport'),
         flags='S', options=opts, seq=seq,
     )
     req = ether / ip / tcp
-    sp.sendp(req, iface=sendif, verbose=False)
+    sp.sendp(req, iface=send_params['sendif'], verbose=False)
 
 
-def send_ping(dst_address, sendif, ping_type, send_params):
+def send_ping(ping_type, send_params):
     if ping_type == 'icmp':
-        send_icmp_ping(dst_address, sendif, send_params)
-    elif ping_type == 'tcpsyn':
-        send_tcp_syn(dst_address, sendif, send_params)
+        send_icmp_ping(send_params)
+    elif (
+        ping_type == 'tcpsyn' or
+        ping_type == 'tcp3way'
+    ):
+        send_tcp_syn(send_params)
     else:
         raise Exception('Unspported ping type')
 
@@ -147,20 +160,21 @@ def check_ipv4(expect_params, packet):
     tos = expect_params.get('tc')
     ttl = expect_params.get('hlim')
     ip = packet.getlayer(sp.IP)
+    LOGGER.debug(f'Packet: {ip}')
     if not ip:
         LOGGER.debug('Packet is not IPv4!')
         return False
     if src_address and ip.src != src_address:
-        LOGGER.debug('Source IPv4 address does not match!')
+        LOGGER.debug(f'Wrong IPv4 source {ip.src}, expected {src_address}')
         return False
     if dst_address and ip.dst != dst_address:
-        LOGGER.debug('Destination IPv4 address does not match!')
+        LOGGER.debug(f'Wrong IPv4 destination {ip.dst}, expected {dst_address}')
         return False
     chksum = ip.chksum
     ip.chksum = None
     new_chksum = sp.IP(sp.raw(ip)).chksum
     if chksum != new_chksum:
-        LOGGER.debug(f'Expected IP checksum {new_chksum} but found {chksum}')
+        LOGGER.debug(f'Wrong IPv4 checksum {chksum}, expected {new_chksum}')
         return False
     if flags and ip.flags != flags:
         LOGGER.debug(f'Wrong IP flags value {ip.flags}, expected {flags}')
@@ -185,10 +199,10 @@ def check_ipv6(expect_params, packet):
         LOGGER.debug('Packet is not IPv6!')
         return False
     if src_address and ip6.src != src_address:
-        LOGGER.debug('Source IPv6 address does not match!')
+        LOGGER.debug(f'Wrong IPv6 source {ip6.src}, expected {src_address}')
         return False
     if dst_address and ip6.dst != dst_address:
-        LOGGER.debug('Destination IPv6 address does not match!')
+        LOGGER.debug(f'Wrong IPv6 destination {ip6.dst}, expected {dst_address}')
         return False
     # IPv6 has no IP-level checksum.
     if flags:
@@ -268,32 +282,32 @@ def check_ping_reply_6(expect_params, packet):
     return True
 
 
-def check_ping_request(expect_params, packet):
-    src_address = expect_params.get('src_address')
-    dst_address = expect_params.get('dst_address')
+def check_ping_request(args, packet):
+    src_address = args['expect_params'].get('src_address')
+    dst_address = args['expect_params'].get('dst_address')
     if not (src_address or dst_address):
         raise Exception('Source or destination address must be given to match the ping request!')
     if (
         (src_address and ':' in src_address) or
         (dst_address and ':' in dst_address)
     ):
-        return check_ping_request_6(expect_params, packet)
+        return check_ping_request_6(args['expect_params'], packet)
     else:
-        return check_ping_request_4(expect_params, packet)
+        return check_ping_request_4(args['expect_params'], packet)
 
 
-def check_ping_reply(expect_params, packet):
-    src_address = expect_params.get('src_address')
-    dst_address = expect_params.get('dst_address')
+def check_ping_reply(args, packet):
+    src_address = args['expect_params'].get('src_address')
+    dst_address = args['expect_params'].get('dst_address')
     if not (src_address or dst_address):
         raise Exception('Source or destination address must be given to match the ping reply!')
     if (
         (src_address and ':' in src_address) or
         (dst_address and ':' in dst_address)
     ):
-        return check_ping_reply_6(expect_params, packet)
+        return check_ping_reply_6(args['expect_params'], packet)
     else:
-        return check_ping_reply_4(expect_params, packet)
+        return check_ping_reply_4(args['expect_params'], packet)
 
 
 def check_tcp(expect_params, packet):
@@ -308,7 +322,7 @@ def check_tcp(expect_params, packet):
     tcp.chksum = None
     newpacket = sp.Ether(sp.raw(packet[sp.Ether]))
     new_chksum = newpacket[sp.TCP].chksum
-    if chksum != new_chksum:
+    if new_chksum and chksum != new_chksum:
         LOGGER.debug(f'Wrong TCP checksum {chksum}, expected {new_chksum}!')
         return False
     if tcp_flags and tcp.flags != tcp_flags:
@@ -339,7 +353,7 @@ def check_tcp_syn_request_4(expect_params, packet):
     return True
 
 
-def check_tcp_syn_reply_4(expect_params, packet):
+def check_tcp_syn_reply_4(expect_params, packet):
     if not check_ipv4(expect_params, packet):
         return False
     if not check_tcp(expect_params | {'tcp_flags': 'SA'}, packet):
@@ -347,6 +361,36 @@ def check_tcp_syn_reply_4(expect_params, packet):
     return True
 
 
+def check_tcp_3way_4(args, packet):
+    send_params = args['send_params']
+
+    expect_params_sa = clean_params(args['expect_params'])
+    expect_params_sa['src_address'] = send_params['dst_address']
+    expect_params_sa['dst_address'] = send_params['src_address']
+
+    # Sniff incoming SYN+ACK packet
+    if (
+        check_ipv4(expect_params_sa, packet) and
+        check_tcp(expect_params_sa | {'tcp_flags': 'SA'}, packet)
+    ):
+        ether = sp.Ether()
+        ip_sa = packet.getlayer(sp.IP)
+        tcp_sa = packet.getlayer(sp.TCP)
+        reply_params = clean_params(send_params)
+        reply_params['src_address'] = ip_sa.dst
+        reply_params['dst_address'] = ip_sa.src
+        ip_a = prepare_ipv4(reply_params)
+        tcp_a = sp.TCP(
+            sport=tcp_sa.dport, dport=tcp_sa.sport, flags='A',
+            seq=tcp_sa.ack, ack=tcp_sa.seq + 1,
+        )
+        req = ether / ip_a / tcp_a
+        sp.sendp(req, iface=send_params['sendif'], verbose=False)
+        return True
+
+    return False
+
+
 def check_tcp_syn_request_6(expect_params, packet):
     if not check_ipv6(expect_params, packet):
         return False
@@ -363,7 +407,38 @@ def check_tcp_syn_reply_6(expect_params, packet):
     return True
 
 
-def check_tcp_syn_request(expect_params, packet):
+def check_tcp_3way_6(args, packet):
+    send_params = args['send_params']
+
+    expect_params_sa = clean_params(args['expect_params'])
+    expect_params_sa['src_address'] = send_params['dst_address']
+    expect_params_sa['dst_address'] = send_params['src_address']
+
+    # Sniff incoming SYN+ACK packet
+    if (
+        check_ipv6(expect_params_sa, packet) and
+        check_tcp(expect_params_sa | {'tcp_flags': 'SA'}, packet)
+    ):
+        ether = sp.Ether()
+        ip6_sa = packet.getlayer(sp.IPv6)
+        tcp_sa = packet.getlayer(sp.TCP)
+        reply_params = clean_params(send_params)
+        reply_params['src_address'] = ip6_sa.dst
+        reply_params['dst_address'] = ip6_sa.src
+        ip_a = prepare_ipv6(reply_params)
+        tcp_a = sp.TCP(
+            sport=tcp_sa.dport, dport=tcp_sa.sport, flags='A',
+            seq=tcp_sa.ack, ack=tcp_sa.seq + 1,
+        )
+        req = ether / ip_a / tcp_a
+        sp.sendp(req, iface=send_params['sendif'], verbose=False)
+        return True
+
+    return False
+
+
+def check_tcp_syn_request(args, packet):
+    expect_params = args['expect_params']
     src_address = expect_params.get('src_address')
     dst_address = expect_params.get('dst_address')
     if not (src_address or dst_address):
@@ -377,7 +452,8 @@ def check_tcp_syn_request(expect_params, packet):
         return check_tcp_syn_request_4(expect_params, packet)
 
 
-def check_tcp_syn_reply(expect_params, packet):
+def check_tcp_syn_reply(args, packet):
+    expect_params = args['expect_params']
     src_address = expect_params.get('src_address')
     dst_address = expect_params.get('dst_address')
     if not (src_address or dst_address):
@@ -390,8 +466,24 @@ def check_tcp_syn_reply(expect_params, packet):
     else:
         return check_tcp_syn_reply_4(expect_params, packet)
 
+def check_tcp_3way(args, packet):
+    expect_params = args['expect_params']
+    src_address = expect_params.get('src_address')
+    dst_address = expect_params.get('dst_address')
+    if not (src_address or dst_address):
+        raise Exception('Source or destination address must be given to match the tcp 3-way handshake!')
+    if (
+        (src_address and ':' in src_address) or
+        (dst_address and ':' in dst_address)
+    ):
+        return check_tcp_3way_6(args, packet)
+    else:
+        return check_tcp_3way_4(args, packet)
+
 
-def setup_sniffer(recvif, ping_type, sniff_type, expect_params, defrag):
+def setup_sniffer(
+        recvif, ping_type, sniff_type, expect_params, defrag, send_params,
+):
     if ping_type == 'icmp' and sniff_type == 'request':
         checkfn = check_ping_request
     elif ping_type == 'icmp' and sniff_type == 'reply':
@@ -400,10 +492,15 @@ def setup_sniffer(recvif, ping_type, sniff_type, expect_params, defrag):
         checkfn = check_tcp_syn_request
     elif ping_type == 'tcpsyn' and sniff_type == 'reply':
         checkfn = check_tcp_syn_reply
+    elif ping_type == 'tcp3way' and sniff_type == 'reply':
+        checkfn = check_tcp_3way
     else:
         raise Exception('Unspported ping or sniff type')
 
-    return Sniffer(expect_params, checkfn, recvif, defrag=defrag)
+    return Sniffer(
+        {'send_params': send_params, 'expect_params': expect_params},
+        checkfn, recvif, defrag=defrag,
+    )
 
 
 def parse_args():
@@ -416,8 +513,8 @@ def parse_args():
     parser.add_argument('--to', required=True,
         help='The destination IP address for the ping request')
     parser.add_argument('--ping-type',
-        choices=('icmp', 'tcpsyn'),
-        help='Type of ping: ICMP (default) or TCP SYN',
+        choices=('icmp', 'tcpsyn', 'tcp3way'),
+        help='Type of ping: ICMP (default) or TCP SYN or 3-way TCP handshake',
         default='icmp')
     parser.add_argument('--fromaddr',
         help='The source IP address for the ping request')
@@ -444,7 +541,7 @@ def parse_args():
         help='TCP sequence number')
     parser_send.add_argument('--send-sport', type=int,
         help='TCP source port')
-    parser_send.add_argument('--send-dport', type=int, default=666,
+    parser_send.add_argument('--send-dport', type=int, default=9,
         help='TCP destination port')
     parser_send.add_argument('--send-length', type=int, default=len(PAYLOAD_MAGIC),
         help='ICMP Echo Request payload size')
@@ -500,6 +597,8 @@ def main():
     send_params['tcpopt_unaligned'] = args.send_tcpopt_unaligned
     send_params['nop'] = args.send_nop
     send_params['src_address'] = args.fromaddr if args.fromaddr else None
+    send_params['dst_address'] = args.to
+    send_params['sendif'] = args.sendif
 
     # We may not have a default route. Tell scapy where to start looking for routes
     sp.conf.iface6 = args.sendif
@@ -525,7 +624,7 @@ def main():
             LOGGER.debug(f'Installing receive sniffer on {iface}')
             sniffers.append(
                 setup_sniffer(iface, args.ping_type, 'request',
-                              sniffer_params, defrag,
+                              sniffer_params, defrag, send_params,
             ))
 
     if args.replyif:
@@ -536,12 +635,12 @@ def main():
             LOGGER.debug(f'Installing reply sniffer on {iface}')
             sniffers.append(
                 setup_sniffer(iface, args.ping_type, 'reply',
-                              sniffer_params, defrag,
+                              sniffer_params, defrag, send_params,
             ))
 
     LOGGER.debug(f'Installed {len(sniffers)} sniffers')
 
-    send_ping(args.to, args.sendif, args.ping_type, send_params)
+    send_ping(args.ping_type, send_params)
 
     err = 0
     sniffer_num = 0
diff --git a/tests/sys/netpfil/pf/src_track.sh b/tests/sys/netpfil/pf/src_track.sh
index 27eb62abcf41..eb053dd84a90 100755
--- a/tests/sys/netpfil/pf/src_track.sh
+++ b/tests/sys/netpfil/pf/src_track.sh
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-2-Clause
 #
 # Copyright (c) 2020 Kristof Provost <kp@FreeBSD.org>
+# Copyright (c) 2024 Kajetan Staszkiewicz <vegeta@tuxpowered.net>
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -59,7 +60,68 @@ source_track_cleanup()
 	pft_cleanup
 }
 
+
+max_src_conn_rule_head()
+{
+	atf_set descr 'Max connections per source per rule'
+	atf_set require.user root
+}
+
+max_src_conn_rule_body()
+{
+	setup_router_server_ipv6
+
+	# Clients will connect from another network behind the router.
+	# This allows using multiple source addresses and keeps the tester jail
+	# from answering the server's SYN+ACKs with RST packets.
+	jexec router route add -6 2001:db8:44::0/64 2001:db8:42::2
+	jexec server route add -6 2001:db8:44::0/64 2001:db8:43::1
+
+	pft_set_rules router \
+		"block" \
+		"pass inet6 proto icmp6 icmp6-type { neighbrsol, neighbradv }" \
+		"pass in  on ${epair_tester}b inet6 proto tcp keep state (max-src-conn 3 source-track rule overload <bad_hosts>)" \
+		"pass out on ${epair_server}a inet6 proto tcp keep state"
+
+	# max-src-conn only counts connections which completed the 3-way
+	# handshake. When a handshake pushes a source over the limit, pf moves
+	# that state to CLOSED. pft_ping.py checks that each handshake completed;
+	# the pf state table is then inspected to see which states survived.
+
+	# 3 connections from host ::1 will be allowed.
+	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4201 --fromaddr 2001:db8:44::1
+	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4202 --fromaddr 2001:db8:44::1
+	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4203 --fromaddr 2001:db8:44::1
+	# The 4th connection from host ::1 exceeds the limit; its state is killed.
+	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4204 --fromaddr 2001:db8:44::1
+	# A connection from host ::2 will be allowed.
+	ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4205 --fromaddr 2001:db8:44::2
+
+	states=$(mktemp) || exit 1
+	jexec router pfctl -qss | grep 'tcp 2001:db8:43::2\[9\] <-' > $states
+
+	grep -qE '2001:db8:44::1\[4201\]\s+ESTABLISHED:ESTABLISHED' $states || atf_fail "State for port 4201 not found or not established"
+	grep -qE '2001:db8:44::1\[4202\]\s+ESTABLISHED:ESTABLISHED' $states || atf_fail "State for port 4202 not found or not established"
+	grep -qE '2001:db8:44::1\[4203\]\s+ESTABLISHED:ESTABLISHED' $states || atf_fail "State for port 4203 not found or not established"
+	grep -qE '2001:db8:44::2\[4205\]\s+ESTABLISHED:ESTABLISHED' $states || atf_fail "State for port 4205 not found or not established"
+
+	if (
+		grep -qE '2001:db8:44::1\[4204\]\s+' $states &&
+		! grep -qE '2001:db8:44::1\[4204\]\s+CLOSED:CLOSED' $states
+	); then
+		atf_fail "State for port 4204 found but not closed"
+	fi
+
+	jexec router pfctl -T test -t bad_hosts 2001:db8:44::1 || atf_fail "Host not found in overload table"
+}
+
+max_src_conn_rule_cleanup()
+{
+	pft_cleanup
+}
+
 atf_init_test_cases()
 {
 	atf_add_test_case "source_track"
+	atf_add_test_case "max_src_conn_rule"
 }
diff --git a/tests/sys/netpfil/pf/utils.subr b/tests/sys/netpfil/pf/utils.subr
index f02dfc22049f..0b7ee621e6fa 100644
--- a/tests/sys/netpfil/pf/utils.subr
+++ b/tests/sys/netpfil/pf/utils.subr
@@ -215,8 +215,9 @@ setup_router_server_ipv4()
 	vnet_mkjail server ${epair_server}b
 	jexec server ifconfig ${epair_server}b ${net_server_host_server}/${net_server_mask} up
 	jexec server route add -net ${net_tester} ${net_server_host_router}
-	jexec server nc -4l 666 &
-	sleep 1 # Give nc time to start and listen
+	inetd_conf=$(mktemp)
+	echo "discard stream tcp nowait root internal" > $inetd_conf
+	jexec server inetd -p ${PWD}/inetd-server.pid $inetd_conf
 }
 
 # Create a bare router jail.
@@ -268,8 +269,9 @@ setup_router_server_ipv6()
 	vnet_mkjail server ${epair_server}b
 	jexec server ifconfig ${epair_server}b inet6 ${net_server_host_server}/${net_server_mask} up no_dad
 	jexec server route add -6 ${net_tester} ${net_server_host_router}
-	jexec server nc -6l 666 &
-	sleep 1 # Give nc time to start and listen
+	inetd_conf=$(mktemp)
+	echo "discard stream tcp6 nowait root internal" > $inetd_conf
+	jexec server inetd -p ${PWD}/inetd-server.pid $inetd_conf
 }
 
 # Ping the dummy static NDP target.