git: 8eb2bee6c0f4 - main - testing: Add basic atf support to pytest.
- Go to: [ bottom of page ] [ top of archives ] [ this month ]
Date: Sat, 25 Jun 2022 19:30:46 UTC
The branch main has been updated by melifaro:

URL: https://cgit.FreeBSD.org/src/commit/?id=8eb2bee6c0f4957c6c1cea826e59cda4d18a2a64

commit 8eb2bee6c0f4957c6c1cea826e59cda4d18a2a64
Author: Alexander V. Chernikov <melifaro@FreeBSD.org>
AuthorDate: 2022-06-25 19:01:45 +0000
Commit: Alexander V. Chernikov <melifaro@FreeBSD.org>
CommitDate: 2022-06-25 19:25:15 +0000

testing: Add basic atf support to pytest.

The implementation consists of a pytest plugin implementing the ATF format and a simple C++ wrapper, which reorders the provided arguments from the ATF format to the format understood by pytest. Each test has this wrapper specified after the shebang. When kyua executes the test, the wrapper calls pytest, which loads the atf plugin, does the work and returns the result.

Additionally, a separate python "package", `/usr/tests/atf_python`, has been added to collect code that may be useful across different tests.

Current limitations:
* Opaque metadata passing via X-Name properties. Requires some fixtures to be written.
* The `-s srcdir` parameter passed by the runner is ignored.
* No `atf-c-api(3)` or similar - relying on the pytest framework & existing python libraries.
* No support for `atf_tc_<get|has>_config_var()` & `atf_tc_set_md_var()`.
Can be probably implemented with env variables & autoload fixtures Differential Revision: https://reviews.freebsd.org/D31084 Reviewed by: kp, ngie --- share/mk/atf.test.mk | 47 ++ tests/Makefile | 4 +- tests/__init__.py | 0 tests/atf_python/Makefile | 12 + tests/atf_python/__init__.py | 4 + tests/atf_python/atf_pytest.py | 218 +++++++++ tests/atf_python/sys/Makefile | 11 + tests/atf_python/sys/__init__.py | 0 tests/atf_python/sys/net/Makefile | 10 + tests/atf_python/sys/net/__init__.py | 0 tests/atf_python/sys/net/rtsock.py | 604 ++++++++++++++++++++++++ tests/atf_python/sys/net/tools.py | 33 ++ tests/atf_python/sys/net/vnet.py | 203 ++++++++ tests/conftest.py | 121 +++++ tests/freebsd_test_suite/Makefile | 13 + tests/freebsd_test_suite/atf_pytest_wrapper.cpp | 192 ++++++++ 16 files changed, 1471 insertions(+), 1 deletion(-) diff --git a/share/mk/atf.test.mk b/share/mk/atf.test.mk index e7a8c82b7a8f..c00c00bc9b3d 100644 --- a/share/mk/atf.test.mk +++ b/share/mk/atf.test.mk @@ -22,6 +22,7 @@ ATF_TESTS_C?= ATF_TESTS_CXX?= ATF_TESTS_SH?= ATF_TESTS_KSH93?= +ATF_TESTS_PYTEST?= .if !empty(ATF_TESTS_C) PROGS+= ${ATF_TESTS_C} @@ -109,3 +110,49 @@ ${_T}: ${ATF_TESTS_KSH93_SRC_${_T}} mv ${.TARGET}.tmp ${.TARGET} .endfor .endif + +.if !empty(ATF_TESTS_PYTEST) +# bsd.prog.mk SCRIPTS interface removes file extension unless +# SCRIPTSNAME is set, which is not possible to do here. +# Workaround this by appending another extension (.xtmp) to the +# file name. Use separate loop to avoid dealing with explicitly +# stating expansion for each and every variable. +# +# ATF_TESTS_PYTEST -> contains list of files as is (test_something.py ..) +# _ATF_TESTS_PYTEST -> (test_something.py.xtmp ..) +# +# Former array is iterated to construct Kyuafile, where original file +# names need to be written. +# Latter array is iterated to enable bsd.prog.mk scripts framework - +# namely, installing scripts without .xtmp prefix. 
Note: this allows to +# not bother about the fact that make target needs to be different from +# the source file. +_TESTS+= ${ATF_TESTS_PYTEST} +_ATF_TESTS_PYTEST= +.for _T in ${ATF_TESTS_PYTEST} +_ATF_TESTS_PYTEST += ${_T}.xtmp +TEST_INTERFACE.${_T}= atf +TEST_METADATA.${_T}+= required_programs="pytest" +.endfor + +SCRIPTS+= ${_ATF_TESTS_PYTEST} +.for _T in ${_ATF_TESTS_PYTEST} +SCRIPTSDIR_${_T}= ${TESTSDIR} +CLEANFILES+= ${_T} ${_T}.tmp +# TODO(jmmv): It seems to me that this SED and SRC functionality should +# exist in bsd.prog.mk along the support for SCRIPTS. Move it there if +# this proves to be useful within the tests. +ATF_TESTS_PYTEST_SED_${_T}?= # empty +ATF_TESTS_PYTEST_SRC_${_T}?= ${.CURDIR}/${_T:S,.xtmp$,,} +${_T}: + echo "#!${TESTSBASE}/atf_pytest_wrapper -P ${TESTSBASE}" > ${.TARGET}.tmp +.if empty(ATF_TESTS_PYTEST_SED_${_T}) + cat ${ATF_TESTS_PYTEST_SRC_${_T}} >>${.TARGET}.tmp +.else + cat ${ATF_TESTS_PYTEST_SRC_${_T}} \ + | sed ${ATF_TESTS_PYTEST_SED_${_T}} >>${.TARGET}.tmp +.endif + chmod +x ${.TARGET}.tmp + mv ${.TARGET}.tmp ${.TARGET} +.endfor +.endif diff --git a/tests/Makefile b/tests/Makefile index 561a0ec5fcab..cfd065d61539 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -4,12 +4,14 @@ PACKAGE= tests TESTSDIR= ${TESTSBASE} -${PACKAGE}FILES+= README +${PACKAGE}FILES+= README __init__.py conftest.py KYUAFILE= yes SUBDIR+= etc SUBDIR+= sys +SUBDIR+= atf_python +SUBDIR+= freebsd_test_suite SUBDIR_PARALLEL= diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/atf_python/Makefile b/tests/atf_python/Makefile new file mode 100644 index 000000000000..26d419743257 --- /dev/null +++ b/tests/atf_python/Makefile @@ -0,0 +1,12 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +FILES= __init__.py atf_pytest.py +SUBDIR= sys + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python + + +.include <bsd.prog.mk> diff --git a/tests/atf_python/__init__.py b/tests/atf_python/__init__.py 
new file mode 100644 index 000000000000..6d5ec22ef054 --- /dev/null +++ b/tests/atf_python/__init__.py @@ -0,0 +1,4 @@ +import pytest + +pytest.register_assert_rewrite("atf_python.sys.net.rtsock") +pytest.register_assert_rewrite("atf_python.sys.net.vnet") diff --git a/tests/atf_python/atf_pytest.py b/tests/atf_python/atf_pytest.py new file mode 100644 index 000000000000..89c0e3a515b9 --- /dev/null +++ b/tests/atf_python/atf_pytest.py @@ -0,0 +1,218 @@ +import types +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Tuple + +import pytest + + +class ATFCleanupItem(pytest.Item): + def runtest(self): + """Runs cleanup procedure for the test instead of the test""" + instance = self.parent.cls() + instance.cleanup(self.nodeid) + + def setup_method_noop(self, method): + """Overrides runtest setup method""" + pass + + def teardown_method_noop(self, method): + """Overrides runtest teardown method""" + pass + + +class ATFTestObj(object): + def __init__(self, obj, has_cleanup): + # Use nodeid without name to properly name class-derived tests + self.ident = obj.nodeid.split("::", 1)[1] + self.description = self._get_test_description(obj) + self.has_cleanup = has_cleanup + self.obj = obj + + def _get_test_description(self, obj): + """Returns first non-empty line from func docstring or func name""" + docstr = obj.function.__doc__ + if docstr: + for line in docstr.split("\n"): + if line: + return line + return obj.name + + def _convert_marks(self, obj) -> Dict[str, Any]: + wj_func = lambda x: " ".join(x) # noqa: E731 + _map: Dict[str, Dict] = { + "require_user": {"name": "require.user"}, + "require_arch": {"name": "require.arch", "fmt": wj_func}, + "require_diskspace": {"name": "require.diskspace"}, + "require_files": {"name": "require.files", "fmt": wj_func}, + "require_machine": {"name": "require.machine", "fmt": wj_func}, + "require_memory": {"name": "require.memory"}, + "require_progs": {"name": 
"require.progs", "fmt": wj_func}, + "timeout": {}, + } + ret = {} + for mark in obj.iter_markers(): + if mark.name in _map: + name = _map[mark.name].get("name", mark.name) + if "fmt" in _map[mark.name]: + val = _map[mark.name]["fmt"](mark.args[0]) + else: + val = mark.args[0] + ret[name] = val + return ret + + def as_lines(self) -> List[str]: + """Output test definition in ATF-specific format""" + ret = [] + ret.append("ident: {}".format(self.ident)) + ret.append("descr: {}".format(self._get_test_description(self.obj))) + if self.has_cleanup: + ret.append("has.cleanup: true") + for key, value in self._convert_marks(self.obj).items(): + ret.append("{}: {}".format(key, value)) + return ret + + +class ATFHandler(object): + class ReportState(NamedTuple): + state: str + reason: str + + def __init__(self): + self._tests_state_map: Dict[str, ReportStatus] = {} + + def override_runtest(self, obj): + # Override basic runtest command + obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj) + # Override class setup/teardown + obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop + obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop + + def get_object_cleanup_class(self, obj): + if hasattr(obj, "parent") and obj.parent is not None: + if hasattr(obj.parent, "cls") and obj.parent.cls is not None: + if hasattr(obj.parent.cls, "cleanup"): + return obj.parent.cls + return None + + def has_object_cleanup(self, obj): + return self.get_object_cleanup_class(obj) is not None + + def list_tests(self, tests: List[str]): + print('Content-Type: application/X-atf-tp; version="1"') + print() + for test_obj in tests: + has_cleanup = self.has_object_cleanup(test_obj) + atf_test = ATFTestObj(test_obj, has_cleanup) + for line in atf_test.as_lines(): + print(line) + print() + + def set_report_state(self, test_name: str, state: str, reason: str): + self._tests_state_map[test_name] = self.ReportState(state, reason) + + def _extract_report_reason(self, report): + data = 
report.longrepr + if data is None: + return None + if isinstance(data, Tuple): + # ('/path/to/test.py', 23, 'Skipped: unable to test') + reason = data[2] + for prefix in "Skipped: ": + if reason.startswith(prefix): + reason = reason[len(prefix):] + return reason + else: + # string/ traceback / exception report. Capture the last line + return str(data).split("\n")[-1] + return None + + def add_report(self, report): + # MAP pytest report state to the atf-desired state + # + # ATF test states: + # (1) expected_death, (2) expected_exit, (3) expected_failure + # (4) expected_signal, (5) expected_timeout, (6) passed + # (7) skipped, (8) failed + # + # Note that ATF don't have the concept of "soft xfail" - xpass + # is a failure. It also calls teardown routine in a separate + # process, thus teardown states (pytest-only) are handled as + # body continuation. + + # (stage, state, wasxfail) + + # Just a passing test: WANT: passed + # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F) + # + # Failing body test: WHAT: failed + # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F) + # + # pytest.skip test decorator: WANT: skipped + # GOT: (setup,skipped, False), (teardown, passed, False) + # + # pytest.skip call inside test function: WANT: skipped + # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F) + # + # mark.xfail decorator+pytest.xfail: WANT: expected_failure + # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F) + # + # mark.xfail decorator+pass: WANT: failed + # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F) + + test_name = report.location[2] + stage = report.when + state = report.outcome + reason = self._extract_report_reason(report) + + # We don't care about strict xfail - it gets translated to False + + if stage == "setup": + if state in ("skipped", "failed"): + # failed init -> failed test, skipped setup -> xskip + # for the whole test + self.set_report_state(test_name, state, reason) + 
elif stage == "call": + # "call" stage shouldn't matter if setup failed + if test_name in self._tests_state_map: + if self._tests_state_map[test_name].state == "failed": + return + if state == "failed": + # Record failure & override "skipped" state + self.set_report_state(test_name, state, reason) + elif state == "skipped": + if hasattr(reason, "wasxfail"): + # xfail() called in the test body + state = "expected_failure" + else: + # skip inside the body + pass + self.set_report_state(test_name, state, reason) + elif state == "passed": + if hasattr(reason, "wasxfail"): + # the test was expected to fail but didn't + # mark as hard failure + state = "failed" + self.set_report_state(test_name, state, reason) + elif stage == "teardown": + if state == "failed": + # teardown should be empty, as the cleanup + # procedures should be implemented as a separate + # function/method, so mark teardown failure as + # global failure + self.set_report_state(test_name, state, reason) + + def write_report(self, path): + if self._tests_state_map: + # If we're executing in ATF mode, there has to be just one test + # Anyway, deterministically pick the first one + first_test_name = next(iter(self._tests_state_map)) + test = self._tests_state_map[first_test_name] + if test.state == "passed": + line = test.state + else: + line = "{}: {}".format(test.state, test.reason) + with open(path, mode="w") as f: + print(line, file=f) diff --git a/tests/atf_python/sys/Makefile b/tests/atf_python/sys/Makefile new file mode 100644 index 000000000000..ff4cf17b85d2 --- /dev/null +++ b/tests/atf_python/sys/Makefile @@ -0,0 +1,11 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +FILES= __init__.py +SUBDIR= net + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys + +.include <bsd.prog.mk> diff --git a/tests/atf_python/sys/__init__.py b/tests/atf_python/sys/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/atf_python/sys/net/Makefile 
b/tests/atf_python/sys/net/Makefile new file mode 100644 index 000000000000..05b1d8afe863 --- /dev/null +++ b/tests/atf_python/sys/net/Makefile @@ -0,0 +1,10 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +FILES= __init__.py rtsock.py tools.py vnet.py + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys/net + +.include <bsd.prog.mk> diff --git a/tests/atf_python/sys/net/__init__.py b/tests/atf_python/sys/net/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/atf_python/sys/net/rtsock.py b/tests/atf_python/sys/net/rtsock.py new file mode 100755 index 000000000000..788e863f8b28 --- /dev/null +++ b/tests/atf_python/sys/net/rtsock.py @@ -0,0 +1,604 @@ +#!/usr/local/bin/python3 +import os +import socket +import struct +import sys +from ctypes import c_byte +from ctypes import c_char +from ctypes import c_int +from ctypes import c_long +from ctypes import c_uint32 +from ctypes import c_ulong +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + + +def roundup2(val: int, num: int) -> int: + if val % num: + return (val | (num - 1)) + 1 + else: + return val + + +class RtSockException(OSError): + pass + + +class RtConst: + RTM_VERSION = 5 + ALIGN = sizeof(c_long) + + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_LINK = socket.AF_LINK + + RTA_DST = 0x1 + RTA_GATEWAY = 0x2 + RTA_NETMASK = 0x4 + RTA_GENMASK = 0x8 + RTA_IFP = 0x10 + RTA_IFA = 0x20 + RTA_AUTHOR = 0x40 + RTA_BRD = 0x80 + + RTM_ADD = 1 + RTM_DELETE = 2 + RTM_CHANGE = 3 + RTM_GET = 4 + + RTF_UP = 0x1 + RTF_GATEWAY = 0x2 + RTF_HOST = 0x4 + RTF_REJECT = 0x8 + RTF_DYNAMIC = 0x10 + RTF_MODIFIED = 0x20 + RTF_DONE = 0x40 + RTF_XRESOLVE = 0x200 + RTF_LLINFO = 0x400 + RTF_LLDATA = 0x400 + RTF_STATIC = 0x800 + RTF_BLACKHOLE = 0x1000 + RTF_PROTO2 = 0x4000 + RTF_PROTO1 = 0x8000 + RTF_PROTO3 = 0x40000 + RTF_FIXEDMTU = 0x80000 + 
RTF_PINNED = 0x100000 + RTF_LOCAL = 0x200000 + RTF_BROADCAST = 0x400000 + RTF_MULTICAST = 0x800000 + RTF_STICKY = 0x10000000 + RTF_RNH_LOCKED = 0x40000000 + RTF_GWFLAG_COMPAT = 0x80000000 + + RTV_MTU = 0x1 + RTV_HOPCOUNT = 0x2 + RTV_EXPIRE = 0x4 + RTV_RPIPE = 0x8 + RTV_SPIPE = 0x10 + RTV_SSTHRESH = 0x20 + RTV_RTT = 0x40 + RTV_RTTVAR = 0x80 + RTV_WEIGHT = 0x100 + + @staticmethod + def get_props(prefix: str) -> List[str]: + return [n for n in dir(RtConst) if n.startswith(prefix)] + + @staticmethod + def get_name(prefix: str, value: int) -> str: + props = RtConst.get_props(prefix) + for prop in props: + if getattr(RtConst, prop) == value: + return prop + return "U:{}:{}".format(prefix, value) + + @staticmethod + def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]: + props = RtConst.get_props(prefix) + propmap = {getattr(RtConst, prop): prop for prop in props} + v = 1 + ret = {} + while value: + if v & value: + if v in propmap: + ret[v] = propmap[v] + else: + ret[v] = hex(v) + value -= v + v *= 2 + return ret + + @staticmethod + def get_bitmask_str(prefix: str, value: int) -> str: + bmap = RtConst.get_bitmask_map(prefix, value) + return ",".join([v for k, v in bmap.items()]) + + +class RtMetrics(Structure): + _fields_ = [ + ("rmx_locks", c_ulong), + ("rmx_mtu", c_ulong), + ("rmx_hopcount", c_ulong), + ("rmx_expire", c_ulong), + ("rmx_recvpipe", c_ulong), + ("rmx_sendpipe", c_ulong), + ("rmx_ssthresh", c_ulong), + ("rmx_rtt", c_ulong), + ("rmx_rttvar", c_ulong), + ("rmx_pksent", c_ulong), + ("rmx_weight", c_ulong), + ("rmx_nhidx", c_ulong), + ("rmx_filler", c_ulong * 2), + ] + + +class RtMsgHdr(Structure): + _fields_ = [ + ("rtm_msglen", c_ushort), + ("rtm_version", c_byte), + ("rtm_type", c_byte), + ("rtm_index", c_ushort), + ("_rtm_spare1", c_ushort), + ("rtm_flags", c_int), + ("rtm_addrs", c_int), + ("rtm_pid", c_int), + ("rtm_seq", c_int), + ("rtm_errno", c_int), + ("rtm_fmask", c_int), + ("rtm_inits", c_ulong), + ("rtm_rmx", RtMetrics), + ] + + +class 
SockaddrIn(Structure): + _fields_ = [ + ("sin_len", c_byte), + ("sin_family", c_byte), + ("sin_port", c_ushort), + ("sin_addr", c_uint32), + ("sin_zero", c_char * 8), + ] + + +class SockaddrIn6(Structure): + _fields_ = [ + ("sin6_len", c_byte), + ("sin6_family", c_byte), + ("sin6_port", c_ushort), + ("sin6_flowinfo", c_uint32), + ("sin6_addr", c_byte * 16), + ("sin6_scope_id", c_uint32), + ] + + +class SockaddrDl(Structure): + _fields_ = [ + ("sdl_len", c_byte), + ("sdl_family", c_byte), + ("sdl_index", c_ushort), + ("sdl_type", c_byte), + ("sdl_nlen", c_byte), + ("sdl_alen", c_byte), + ("sdl_slen", c_byte), + ("sdl_data", c_byte * 8), + ] + + +class SaHelper(object): + @staticmethod + def is_ipv6(ip: str) -> bool: + return ":" in ip + + @staticmethod + def ip_sa(ip: str, scopeid: int = 0) -> bytes: + if SaHelper.is_ipv6(ip): + return SaHelper.ip6_sa(ip, scopeid) + else: + return SaHelper.ip4_sa(ip) + + @staticmethod + def ip4_sa(ip: str) -> bytes: + addr_int = int.from_bytes(socket.inet_pton(2, ip), sys.byteorder) + sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int) + return bytes(sin) + + @staticmethod + def ip6_sa(ip6: str, scopeid: int) -> bytes: + addr_bytes = (c_byte * 16)() + for i, b in enumerate(socket.inet_pton(socket.AF_INET6, ip6)): + addr_bytes[i] = b + sin6 = SockaddrIn6( + sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid + ) + return bytes(sin6) + + @staticmethod + def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes: + sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype) + return bytes(sa) + + @staticmethod + def pxlen4_sa(pxlen: int) -> bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen)) + + @staticmethod + def pxlen_to_ip4(pxlen: int) -> str: + if pxlen == 32: + return "255.255.255.255" + else: + addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1) + addr_bytes = struct.pack("!I", addr) + return socket.inet_ntop(socket.AF_INET, addr_bytes) + + @staticmethod + def pxlen6_sa(pxlen: int) -> 
bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen)) + + @staticmethod + def pxlen_to_ip6(pxlen: int) -> str: + ip6_b = [0] * 16 + start = 0 + while pxlen > 8: + ip6_b[start] = 0xFF + pxlen -= 8 + start += 1 + ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1) + return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b)) + + @staticmethod + def print_sa_inet(sa: bytes): + if len(sa) < 8: + raise RtSockException("IPv4 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET, sa[4:8]) + return "{}".format(addr) + + @staticmethod + def print_sa_inet6(sa: bytes): + if len(sa) < sizeof(SockaddrIn6): + raise RtSockException("IPv6 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET6, sa[8:24]) + scopeid = struct.unpack(">I", sa[24:28])[0] + return "{} scopeid {}".format(addr, scopeid) + + @staticmethod + def print_sa_link(sa: bytes, hd: Optional[bool] = True): + if len(sa) < sizeof(SockaddrDl): + raise RtSockException("LINK sa size too small: {}".format(len(sa))) + sdl = SockaddrDl.from_buffer_copy(sa) + if sdl.sdl_index: + ifindex = "link#{} ".format(sdl.sdl_index) + else: + ifindex = "" + if sdl.sdl_nlen: + iface_offset = 8 + if sdl.sdl_nlen + iface_offset > len(sa): + raise RtSockException( + "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa)) + ) + ifname = "ifname:{} ".format( + bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen]) + ) + else: + ifname = "" + return "{}{}".format(ifindex, ifname) + + @staticmethod + def print_sa_unknown(sa: bytes): + return "unknown_type:{}".format(sa[1]) + + @classmethod + def print_sa(cls, sa: bytes, hd: Optional[bool] = False): + if sa[0] != len(sa): + raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa))) + + if len(sa) < 2: + raise Exception( + "sa type {} too short: {}".format( + RtConst.get_name("AF_", sa[1]), len(sa) + ) + ) + + if sa[1] == socket.AF_INET: + text = cls.print_sa_inet(sa) + elif sa[1] == socket.AF_INET6: + text = 
cls.print_sa_inet6(sa) + elif sa[1] == socket.AF_LINK: + text = cls.print_sa_link(sa) + else: + text = cls.print_sa_unknown(sa) + if hd: + dump = " [{!r}]".format(sa) + else: + dump = "" + return "{}{}".format(text, dump) + + +class BaseRtsockMessage(object): + def __init__(self, rtm_type): + self.rtm_type = rtm_type + self.sa = SaHelper() + + @staticmethod + def print_rtm_type(rtm_type): + return RtConst.get_name("RTM_", rtm_type) + + @property + def rtm_type_str(self): + return self.print_rtm_type(self.rtm_type) + + +class RtsockRtMessage(BaseRtsockMessage): + messages = [ + RtConst.RTM_ADD, + RtConst.RTM_DELETE, + RtConst.RTM_CHANGE, + RtConst.RTM_GET, + ] + + def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None): + super().__init__(rtm_type) + self.rtm_flags = 0 + self.rtm_seq = rtm_seq + self._attrs = {} + self.rtm_errno = 0 + self.rtm_pid = 0 + self.rtm_inits = 0 + self.rtm_rmx = RtMetrics() + self._orig_data = None + if dst_sa: + self.add_sa_attr(RtConst.RTA_DST, dst_sa) + if mask_sa: + self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa) + + def add_sa_attr(self, attr_type, attr_bytes: bytes): + self._attrs[attr_type] = attr_bytes + + def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0): + if ":" in ip_addr: + self.add_ip6_attr(attr_type, ip_addr, scopeid) + else: + self.add_ip4_attr(attr_type, ip_addr) + + def add_ip4_attr(self, attr_type, ip: str): + self.add_sa_attr(attr_type, self.sa.ip_sa(ip)) + + def add_ip6_attr(self, attr_type, ip6: str, scopeid: int): + self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid)) + + def add_link_attr(self, attr_type, ifindex: Optional[int] = 0): + self.add_sa_attr(attr_type, self.sa.link_sa(ifindex)) + + def get_sa(self, attr_type) -> bytes: + return self._attrs.get(attr_type) + + def print_message(self): + # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC> + if self._orig_data: + rtm_len = len(self._orig_data) + else: + rtm_len = len(bytes(self)) + 
print( + "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format( + self.rtm_type_str, + rtm_len, + self.rtm_pid, + self.rtm_seq, + self.rtm_errno, + RtConst.get_bitmask_str("RTF_", self.rtm_flags), + ) + ) + rtm_addrs = sum(list(self._attrs.keys())) + print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs))) + for attr in sorted(self._attrs.keys()): + sa_data = SaHelper.print_sa(self._attrs[attr]) + print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data)) + + def print_in_message(self): + print("vvvvvvvv IN vvvvvvvv") + self.print_message() + print() + + def verify_sa_inet(self, sa_data): + if len(sa_data) < 8: + raise Exception("IPv4 sa size too small: {}".format(sa_data)) + if sa_data[0] > len(sa_data): + raise Exception( + "IPv4 sin_len too big: {} vs sa size {}: {}".format( + sa_data[0], len(sa_data), sa_data + ) + ) + sin = SockaddrIn.from_buffer_copy(sa_data) + assert sin.sin_port == 0 + assert sin.sin_zero == [0] * 8 + + def compare_sa(self, sa_type, sa_data): + if len(sa_data) < 4: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception( + "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data)) + ) + our_sa = self._attrs[sa_type] + assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa) + assert len(sa_data) == len(our_sa) + assert sa_data == our_sa + + def verify(self, rtm_type: int, rtm_sa): + assert self.rtm_type_str == self.print_rtm_type(rtm_type) + assert self.rtm_errno == 0 + hdr = RtMsgHdr.from_buffer_copy(self._orig_data) + assert hdr._rtm_spare1 == 0 + for sa_type, sa_data in rtm_sa.items(): + if sa_type not in self._attrs: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception("SA type {} not present".format(sa_type_name)) + self.compare_sa(sa_type, sa_data) + + @classmethod + def from_bytes(cls, data: bytes): + if len(data) < sizeof(RtMsgHdr): + raise Exception( + "messages size {} is less than expected {}".format( + len(data), sizeof(RtMsgHdr) + ) + ) + hdr = 
RtMsgHdr.from_buffer_copy(data) + + self = cls(hdr.rtm_type) + self.rtm_flags = hdr.rtm_flags + self.rtm_seq = hdr.rtm_seq + self.rtm_errno = hdr.rtm_errno + self.rtm_pid = hdr.rtm_pid + self.rtm_inits = hdr.rtm_inits + self.rtm_rmx = hdr.rtm_rmx + self._orig_data = data + + off = sizeof(RtMsgHdr) + v = 1 + addrs_mask = hdr.rtm_addrs + while addrs_mask: + if addrs_mask & v: + addrs_mask -= v + + if off + data[off] > len(data): + raise Exception( + "SA sizeof for {} > total message length: {}+{} > {}".format( + RtConst.get_name("RTA_", v), off, data[off], len(data) + ) + ) + self._attrs[v] = data[off : off + data[off]] + off += roundup2(data[off], RtConst.ALIGN) + v *= 2 + return self + + def __bytes__(self): + sz = sizeof(RtMsgHdr) + addrs_mask = 0 + for k, v in self._attrs.items(): + sz += roundup2(len(v), RtConst.ALIGN) + addrs_mask += k + hdr = RtMsgHdr( + rtm_msglen=sz, + rtm_version=RtConst.RTM_VERSION, + rtm_type=self.rtm_type, + rtm_flags=self.rtm_flags, + rtm_seq=self.rtm_seq, + rtm_addrs=addrs_mask, + rtm_inits=self.rtm_inits, + rtm_rmx=self.rtm_rmx, + ) + buf = bytearray(sz) + buf[0 : sizeof(RtMsgHdr)] = hdr + off = sizeof(RtMsgHdr) + for attr in sorted(self._attrs.keys()): + v = self._attrs[attr] + sa_len = len(v) + buf[off : off + sa_len] = v + off += roundup2(len(v), RtConst.ALIGN) + return bytes(buf) + + +class Rtsock: + def __init__(self): + self.socket = self._setup_rtsock() + self.rtm_seq = 1 + self.msgmap = self.build_msgmap() + + def build_msgmap(self): + classes = [RtsockRtMessage] + xmap = {} + for cls in classes: + for message in cls.messages: + xmap[message] = cls + return xmap + + def get_seq(self): + ret = self.rtm_seq + self.rtm_seq += 1 + return ret + + def get_weight(self, weight) -> int: + if weight: + return weight + else: + return 1 # RT_DEFAULT_WEIGHT + + def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]): + px = prefix.split("/") + addr_sa = SaHelper.ip_sa(px[0]) + if len(px) > 1: + pxlen = int(px[1]) + if 
SaHelper.is_ipv6(px[0]): + mask_sa = SaHelper.pxlen6_sa(pxlen) + else: + mask_sa = SaHelper.pxlen4_sa(pxlen) + else: + mask_sa = None + msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa) + if isinstance(gw, bytes): + msg.add_sa_attr(RtConst.RTA_GATEWAY, gw) + else: + # String + msg.add_ip_attr(RtConst.RTA_GATEWAY, gw) + return msg + + def new_rtm_add(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw) + + def new_rtm_del(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw) + + def new_rtm_change(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw) + *** 639 LINES SKIPPED ***