git: b18abe2e1f22 - stable/13 - tests: add more netlink tests for neighbors/routes

From: Alexander V. Chernikov <melifaro_at_FreeBSD.org>
Date: Sat, 08 Apr 2023 19:45:05 UTC
The branch stable/13 has been updated by melifaro:

URL: https://cgit.FreeBSD.org/src/commit/?id=b18abe2e1f22394717d4847f2c7b491929cee92b

commit b18abe2e1f22394717d4847f2c7b491929cee92b
Author:     Alexander V. Chernikov <melifaro@FreeBSD.org>
AuthorDate: 2023-03-07 17:30:35 +0000
Commit:     Alexander V. Chernikov <melifaro@FreeBSD.org>
CommitDate: 2023-04-08 19:15:16 +0000

    tests: add more netlink tests for neighbors/routes
    
    Differential Revision: https://reviews.freebsd.org/D38912
    MFC after:      2 weeks
    
    (cherry picked from commit c57dfd92c876fabc04e94945dd9534468520bbbf)
---
 tests/atf_python/sys/net/netlink.py  | 121 ++++++++++++++++++++++++++++++-----
 tests/sys/netlink/Makefile           |   1 +
 tests/sys/netlink/test_rtnl_neigh.py |  53 +++++++++++++++
 tests/sys/netlink/test_rtnl_route.py |  23 +++++++
 4 files changed, 183 insertions(+), 15 deletions(-)

diff --git a/tests/atf_python/sys/net/netlink.py b/tests/atf_python/sys/net/netlink.py
index ec5a7feef317..bfbf3217d52a 100644
--- a/tests/atf_python/sys/net/netlink.py
+++ b/tests/atf_python/sys/net/netlink.py
@@ -29,6 +29,12 @@ def align4(val: int) -> int:
     return roundup2(val, 4)
 
 
+def enum_or_int(val) -> int:
+    # Enum members are unwrapped to .value; anything else passes through.
+    raw = val.value if isinstance(val, Enum) else val
+    return raw
+
+
 class SockaddrNl(Structure):
     _fields_ = [
         ("nl_len", c_ubyte),
@@ -125,8 +131,8 @@ class NlRtMsgType(Enum):
     RTM_DELROUTE = 25
     RTM_GETROUTE = 26
     RTM_NEWNEIGH = 28
-    RTM_DELNEIGH = 27
-    RTM_GETNEIGH = 28
+    RTM_DELNEIGH = 29
+    RTM_GETNEIGH = 30
     RTM_NEWRULE = 32
     RTM_DELRULE = 33
     RTM_GETRULE = 34
@@ -491,6 +497,39 @@ class IfattrType(Enum):
     IFA_TARGET_NETNSID = auto()
 
 
+class NdMsg(Structure):
+    _fields_ = [  # field order/widths follow the rtnetlink `struct ndmsg`
+        ("ndm_family", c_ubyte),
+        ("ndm_pad1", c_ubyte),
+        ("ndm_pad2", c_ushort),  # 16-bit pad: ndm_ifindex sits at offset 4
+        ("ndm_ifindex", c_uint),
+        ("ndm_state", c_ushort),
+        ("ndm_flags", c_ubyte),
+        ("ndm_type", c_ubyte),
+    ]
+
+
+class NdAttrType(Enum):
+    NDA_UNSPEC = 0  # the members below count up from here: 1, 2, ... 17
+    NDA_DST = auto()
+    NDA_LLADDR = auto()
+    NDA_CACHEINFO = auto()
+    NDA_PROBES = auto()
+    NDA_VLAN = auto()
+    NDA_PORT = auto()
+    NDA_VNI = auto()
+    NDA_IFINDEX = auto()
+    NDA_MASTER = auto()
+    NDA_LINK_NETNSID = auto()
+    NDA_SRC_VNI = auto()
+    NDA_PROTOCOL = auto()
+    NDA_NH_ID = auto()
+    NDA_FDB_EXT_ATTRS = auto()
+    NDA_FLAGS_EXT = auto()
+    NDA_NDM_STATE_MASK = auto()
+    NDA_NDM_FLAGS_MASK = auto()
+
+
 class GenlMsgHdr(Structure):
     _fields_ = [
         ("cmd", c_ubyte),
@@ -702,7 +741,7 @@ class NlAttrNested(NlAttr):
 
 class NlAttrU32(NlAttr):
     def __init__(self, nla_type, val):
-        self.u32 = val
+        self.u32 = enum_or_int(val)
         super().__init__(nla_type, b"")
 
     @property
@@ -729,7 +768,7 @@ class NlAttrU32(NlAttr):
 
 class NlAttrU16(NlAttr):
     def __init__(self, nla_type, val):
-        self.u16 = val
+        self.u16 = enum_or_int(val)
         super().__init__(nla_type, b"")
 
     @property
@@ -756,7 +795,7 @@ class NlAttrU16(NlAttr):
 
 class NlAttrU8(NlAttr):
     def __init__(self, nla_type, val):
-        self.u8 = val
+        self.u8 = enum_or_int(val)
         super().__init__(nla_type, b"")
 
     @property
@@ -842,6 +881,11 @@ class NlAttrIfindex(NlAttrU32):
         return " iface=if#{}".format(self.u32)
 
 
+class NlAttrMac(NlAttr):
+    def _print_attr_value(self):  # NOTE(review): assumes payload bytes are in self._data -- confirm
+        return " mac=" + ":".join("{:02x}".format(b) for b in self._data)
+
+
 class NlAttrTable(NlAttrU32):
     def _print_attr_value(self):
         return " rtable={}".format(self.u32)
@@ -1067,26 +1111,44 @@ rtnl_ifa_attrs = prepare_attrs_map(
 )
 
 
+rtnl_nd_attrs = prepare_attrs_map(  # NDA_* descriptors; used as NetlinkNdMessage.nl_attrs_map
+    [
+        AttrDescr(NdAttrType.NDA_DST, NlAttrIp),
+        AttrDescr(NdAttrType.NDA_IFINDEX, NlAttrIfindex),
+        AttrDescr(NdAttrType.NDA_FLAGS_EXT, NlAttrU32),
+        AttrDescr(NdAttrType.NDA_LLADDR, NlAttrMac),
+    ]
+)
+
+
 class BaseNetlinkMessage(object):
     def __init__(self, helper, nlmsg_type):
-        self.nlmsg_type = nlmsg_type
+        self.nlmsg_type = enum_or_int(nlmsg_type)
         self.ut = unittest.TestCase()
         self.nla_list = []
         self._orig_data = None
         self.helper = helper
         self.nl_hdr = Nlmsghdr(
-            nlmsg_type=nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
+            nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
         )
         self.base_hdr = None
 
+    def set_request(self, need_ack=True):
+        # NLM_F_REQUEST is always set; NLM_F_ACK only when need_ack is true.
+        ack = [NlmBaseFlags.NLM_F_ACK] if need_ack else []
+        self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST] + ack)
+
+    def add_nlflags(self, flags: List):
+        # OR each flag (Enum member or plain int) into the header's flags.
+        hdr = self.nl_hdr
+        for item in flags:
+            hdr.nlmsg_flags |= enum_or_int(item)
+
     def add_nla(self, nla):
         self.nla_list.append(nla)
 
     def _get_nla(self, nla_list, nla_type):
-        if isinstance(nla_type, Enum):
-            nla_type_raw = nla_type.value
-        else:
-            nla_type_raw = nla_type
+        nla_type_raw = enum_or_int(nla_type)
         for nla in nla_list:
             if nla.nla_type == nla_type_raw:
                 return nla
@@ -1102,10 +1164,7 @@ class BaseNetlinkMessage(object):
         return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)
 
     def is_type(self, nlmsg_type):
-        if isinstance(nlmsg_type, Enum):
-            nlmsg_type_raw = nlmsg_type.value
-        else:
-            nlmsg_type_raw = nlmsg_type
+        nlmsg_type_raw = enum_or_int(nlmsg_type)
         return nlmsg_type_raw == self.nl_hdr.nlmsg_type
 
     def is_reply(self, hdr):
@@ -1422,6 +1481,37 @@ class NetlinkIfaMessage(BaseNetlinkRtMessage):
         )
 
 
+class NetlinkNdMessage(BaseNetlinkRtMessage):
+    """
+    Neighbor (RTM_*NEIGH) message: NdMsg base header plus NDA_* attributes.
+    """
+
+    nl_attrs_map = rtnl_nd_attrs
+    messages = [
+        NlRtMsgType.RTM_NEWNEIGH.value,
+        NlRtMsgType.RTM_DELNEIGH.value,
+        NlRtMsgType.RTM_GETNEIGH.value,
+    ]
+
+    def __init__(self, helper, nlm_type):
+        super().__init__(helper, nlm_type)
+        self.base_hdr = NdMsg()
+
+    def parse_base_header(self, data):
+        """Return (NdMsg, header size); ValueError if data is too short."""
+        hdr_len = sizeof(NdMsg)
+        if len(data) < hdr_len:
+            raise ValueError("length less than NdMsg header")
+        return (NdMsg.from_buffer_copy(data), hdr_len)
+
+    def print_base_header(self, hdr, prepend=""):
+        """Print a one-line summary of hdr (ndm_type is not included)."""
+        fmt = "{}family={}, ndm_ifindex={}, ndm_state={}, ndm_flags={}"
+        af_name = self.helper.get_af_name(hdr.ndm_family)
+        fields = (hdr.ndm_ifindex, hdr.ndm_state, hdr.ndm_flags)
+        print(fmt.format(prepend, af_name, *fields))
+
+
 class Nlsock:
     def __init__(self, family, helper):
         self.helper = helper
@@ -1435,6 +1525,7 @@ class Nlsock:
             NetlinkRtMessage,
             NetlinkIflaMessage,
             NetlinkIfaMessage,
+            NetlinkNdMessage,
             NetlinkDoneMessage,
             NetlinkErrorMessage,
         ]
diff --git a/tests/sys/netlink/Makefile b/tests/sys/netlink/Makefile
index cbec7b2d8b5d..16559f0e9d3d 100644
--- a/tests/sys/netlink/Makefile
+++ b/tests/sys/netlink/Makefile
@@ -9,6 +9,7 @@ ATF_TESTS_C +=	test_snl test_snl_generic
 ATF_TESTS_PYTEST +=	test_nl_core.py
 ATF_TESTS_PYTEST +=	test_rtnl_iface.py
 ATF_TESTS_PYTEST +=	test_rtnl_ifaddr.py
+ATF_TESTS_PYTEST +=	test_rtnl_neigh.py
 ATF_TESTS_PYTEST +=	test_rtnl_route.py
 
 CFLAGS+=	-I${.CURDIR:H:H:H}
diff --git a/tests/sys/netlink/test_rtnl_neigh.py b/tests/sys/netlink/test_rtnl_neigh.py
new file mode 100644
index 000000000000..6d6f95098d14
--- /dev/null
+++ b/tests/sys/netlink/test_rtnl_neigh.py
@@ -0,0 +1,53 @@
+import socket
+import pytest
+
+from atf_python.sys.net.netlink import NdAttrType
+from atf_python.sys.net.netlink import NetlinkNdMessage
+from atf_python.sys.net.netlink import NetlinkTestTemplate
+from atf_python.sys.net.netlink import NlConst
+from atf_python.sys.net.netlink import NlRtMsgType
+from atf_python.sys.net.vnet import SingleVnetTestTemplate
+
+
+class TestRtNlNeigh(NetlinkTestTemplate, SingleVnetTestTemplate):
+    def setup_method(self, method):
+        method_name = method.__name__  # "4"/"6" in the name picks the prefixes
+        if "4" in method_name:
+            self.IPV4_PREFIXES = ["192.0.2.1/24"]
+        if "6" in method_name:
+            self.IPV6_PREFIXES = ["2001:db8::1/64"]
+        super().setup_method(method)
+        self.setup_netlink(NlConst.NETLINK_ROUTE)
+
+    def filter_iface(self, family, num_items):
+        """Query `family` neighbors on if1; expect `num_items` replies."""
+        epair_ifname = self.vnet.iface_alias_map["if1"].name
+        epair_ifindex = socket.if_nametoindex(epair_ifname)
+
+        msg = NetlinkNdMessage(self.helper, NlRtMsgType.RTM_GETNEIGH)
+        msg.set_request()
+        msg.base_hdr.ndm_family = family
+        msg.base_hdr.ndm_ifindex = epair_ifindex
+        self.write_message(msg)
+
+        ret = []
+        reply_type = NlRtMsgType.RTM_NEWNEIGH
+        for rx_msg in self.read_msg_list(msg.nl_hdr.nlmsg_seq, reply_type):
+            ifname = socket.if_indextoname(rx_msg.base_hdr.ndm_ifindex)
+            # Check against the requested family; never rebind `family` here.
+            assert ifname == epair_ifname
+            assert rx_msg.base_hdr.ndm_family == family
+            assert rx_msg.get_nla(NdAttrType.NDA_DST) is not None
+            assert rx_msg.get_nla(NdAttrType.NDA_LLADDR) is not None
+            ret.append(rx_msg)
+        assert len(ret) == num_items
+
+    @pytest.mark.timeout(5)
+    def test_6_filter_iface(self):
+        """Tests that listing outputs all nd6 records"""
+        return self.filter_iface(socket.AF_INET6, 2)
+
+    @pytest.mark.timeout(5)
+    def test_4_filter_iface(self):
+        """Tests that listing outputs all arp records"""
+        return self.filter_iface(socket.AF_INET, 1)
diff --git a/tests/sys/netlink/test_rtnl_route.py b/tests/sys/netlink/test_rtnl_route.py
index 71125343166a..b64fce57a518 100644
--- a/tests/sys/netlink/test_rtnl_route.py
+++ b/tests/sys/netlink/test_rtnl_route.py
@@ -2,9 +2,11 @@ import ipaddress
 import socket
 
 import pytest
+from atf_python.sys.net.tools import ToolsHelper
 from atf_python.sys.net.netlink import NetlinkRtMessage
 from atf_python.sys.net.netlink import NetlinkTestTemplate
 from atf_python.sys.net.netlink import NlAttrIp
+from atf_python.sys.net.netlink import NlAttrU32
 from atf_python.sys.net.netlink import NlConst
 from atf_python.sys.net.netlink import NlmBaseFlags
 from atf_python.sys.net.netlink import NlmGetFlags
@@ -22,6 +24,27 @@ class TestRtNlRoute(NetlinkTestTemplate, SingleVnetTestTemplate):
         super().setup_method(method)
         self.setup_netlink(NlConst.NETLINK_ROUTE)
 
+    @pytest.mark.timeout(5)
+    def test_add_route6_ll_gw(self):
+        """Tests adding a v6 route via a link-local gateway and RTA_OIF"""
+        oif = socket.if_nametoindex(self.vnet.iface_alias_map["if1"].name)
+
+        msg = NetlinkRtMessage(self.helper, NlRtMsgType.RTM_NEWROUTE)
+        msg.base_hdr.rtm_family = socket.AF_INET6
+        msg.base_hdr.rtm_dst_len = 64
+        msg.set_request()
+        msg.add_nlflags([NlmNewFlags.NLM_F_CREATE])
+        msg.add_nla(NlAttrIp(RtattrType.RTA_DST, "2001:db8:2::"))
+        msg.add_nla(NlAttrIp(RtattrType.RTA_GATEWAY, "fe80::1"))
+        msg.add_nla(NlAttrU32(RtattrType.RTA_OIF, oif))
+
+        reply = self.get_reply(msg)
+        assert reply.is_type(NlMsgType.NLMSG_ERROR)
+        assert reply.error_code == 0
+
+        ToolsHelper.print_net_debug()
+        ToolsHelper.print_output("netstat -6onW")
+
     @pytest.mark.timeout(20)
     def test_buffer_override(self):
         msg_flags = (