#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import json
import os
import re
import socket
import sys

import lib.args
import lib.base
import lib.human
import lib.lftest
import lib.time
import lib.txt
from lib.globals import STATE_OK, STATE_UNKNOWN, STATE_WARN

try:
    import vici
except ImportError:
    print('Python module "vici" is not installed.')
    sys.exit(STATE_UNKNOWN)


# Plugin metadata; both values are interpolated into the `--version` output
# built in parse_args().
__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2026041407'

# Text handed to argparse as the program description in parse_args().
DESCRIPTION = """Checks IPSec connection states on a strongSwan VPN gateway. Connects to the charon
daemon via the VICI interface to retrieve IKE SA and CHILD SA states. Alerts on
connections that are not in the expected established state. Connection names can be
filtered out with --ignore, which is useful for gateways that mix permanent site-to-site
peers with transient remote-access clients where only the site-to-site peers should
drive the alert.
Requires root or sudo."""

# Defaults for the `--lengthy` and `--socket` command line options.
DEFAULT_LENGTHY = False
DEFAULT_SOCKET = '/run/strongswan/charon.vici'


def parse_args():
    """Build the command line interface and return the parsed options.

    Options are read with `parse_known_args()`, so arguments the parser
    does not know are tolerated rather than rejected. The returned
    namespace carries the upper-case attributes ALWAYS_OK, IGNORE,
    LENGTHY, MATCH, SOCKET and TEST.
    """
    # The long help texts are kept apart from the option declarations
    # so that the declarations below stay easy to scan.
    ignore_help = (
        'Ignore connections whose VICI key matches this Python regular '
        'expression. Case-sensitive by default; use `(?i)` for case-insensitive '
        'matching. Can be specified multiple times. Example: `--ignore="^RA_"` '
        'to skip transient remote-access clients on a VPN gateway that also '
        'carries permanent site-to-site peers. Example: `--ignore="(?i)test"` '
        '(case-insensitive) to skip any connection with "test" in its name. '
        'Default: %(default)s'
    )
    match_help = (
        'Only check connections whose VICI key matches this Python regular '
        'expression. Case-sensitive by default; use `(?i)` for case-insensitive '
        'matching. Can be specified multiple times. If both `--match` and '
        '`--ignore` are given, a connection must match `--match` AND not match '
        '`--ignore` to be checked (include first, exclude second). '
        'Example: `--match="^S2S_SITE-XY$"` to pin an Icinga service to one '
        'specific site-to-site peer. Example: `--match="(?i)^s2s_"` '
        '(case-insensitive) to check every site-to-site peer on a gateway. '
        'Default: %(default)s'
    )
    socket_help = (
        'Path to the Versatile IKE Control Interface (VICI) socket. '
        'Default: %(default)s'
    )

    cli = argparse.ArgumentParser(description=DESCRIPTION)
    add = cli.add_argument

    # Declaration order is kept alphabetical (after --version) because
    # it determines the order in the generated --help output.
    add(
        '-V',
        '--version',
        action='version',
        version=f'%(prog)s: v{__version__} by {__author__}',
    )
    add(
        '--always-ok',
        action='store_true',
        default=False,
        dest='ALWAYS_OK',
        help=lib.args.help('--always-ok'),
    )
    add(
        '--ignore',
        action='append',
        default=None,
        dest='IGNORE',
        help=ignore_help,
    )
    add(
        '--lengthy',
        action='store_true',
        default=DEFAULT_LENGTHY,
        dest='LENGTHY',
        help=lib.args.help('--lengthy'),
    )
    add(
        '--match',
        action='append',
        default=None,
        dest='MATCH',
        help=match_help,
    )
    add(
        '--socket',
        default=DEFAULT_SOCKET,
        dest='SOCKET',
        help=socket_help,
    )
    add(
        '--test',
        dest='TEST',
        help=lib.args.help('--test'),
        type=lib.args.csv,
    )

    # Only the recognised options are of interest; leftovers are dropped.
    known, _unknown = cli.parse_known_args()
    return known


def _collect_keys(entries):
    """Gather every key of every mapping in `entries` into one sorted list.

    `entries` may be any iterable of dict-like objects (a list or a
    generator); it is consumed exactly once. Duplicate keys are kept,
    so the result can contain the same name several times.
    """
    return sorted(name for entry in entries for name in entry.keys())


def keep_connection(name, match_patterns, ignore_patterns):
    """Decide whether the connection `name` survives the filters.

    Inclusion is evaluated first: when `match_patterns` is non-empty,
    at least one of its compiled patterns has to be found in `name`.
    Exclusion comes second: a hit on any `ignore_patterns` entry drops
    the name. Returns True to keep the connection, False to drop it.
    """
    if match_patterns:
        for pattern in match_patterns:
            if pattern.search(name):
                break
        else:
            # an include list was given, but nothing in it matched
            return False

    for pattern in ignore_patterns:
        if pattern.search(name):
            return False
    return True


def format_sas_data(sas):
    """Re-format the IKE SA details of a single connection.

    Copies all entries of `sas` into a new dict, converts byte values
    to text and adds the derived display fields `encr`,
    `established-hr`, `local`, `remote` and `rekey-time-hr`. `state`
    is shortened (`ESTABLISHED` -> `EST`) and `version` gets a `v`
    prefix. Also handles different VICI versions: a `reauth-time`
    entry (v5.7) is mirrored into `rekey-time` (v5.9).

    VICI returns all scalars as `bytes` at runtime, while the JSON
    fixtures used by the unit tests hold already-decoded `str`
    values. Both shapes are accepted: `bytes` and `bytearray` are
    converted via `lib.txt.to_text()`, anything else (including
    `str`, lists and nested dicts such as `child-sas`) is passed
    through unchanged.

    Parameters
    ----------
    sas : dict
        SA details of one connection. Must provide `encr-alg`,
        `encr-keysize`, `prf-alg`, `dh-group`, `established`,
        `local-host`, `local-id`, `local-port`, `remote-host`,
        `remote-id`, `remote-port`, `state`, `version` and either
        `rekey-time` or `reauth-time`. `integ-alg` is optional.

    Returns
    -------
    dict
        The converted entries plus the derived fields.

    Raises
    ------
    KeyError
        If one of the required entries is missing.
    """
    data = {}
    for key, value in sas.items():
        if key == 'integ-alg' and not value:
            # An empty integrity algorithm is handled like a missing one.
            # This has to happen before the value is stored, otherwise the
            # 'None' placeholder below is never applied.
            continue
        if isinstance(value, bytes):
            data[key] = lib.txt.to_text(value)
        elif isinstance(value, bytearray):
            # Iterating a bytearray yields ints, so it cannot be fed to
            # bytes.join(); convert it as a whole instead.
            data[key] = lib.txt.to_text(bytes(value))
        else:
            data[key] = value
        # handle different versions:
        if key == 'reauth-time':  # v5.7
            data['rekey-time'] = data['reauth-time']  # v5.9

    if 'integ-alg' not in data:
        # If a connection uses AES GCM encryption (or probably another AEAD algorithm) the vici
        # interface does not return an "integ-alg" key at all. There is no key without value but
        # the key is missing in the dictionary.
        data['integ-alg'] = 'None'
    data['encr'] = (
        f'{data["encr-alg"]}-{data["encr-keysize"]}'
        f'/{data["integ-alg"]}'
        f'/{data["prf-alg"]}'
        f'/{data["dh-group"]}'
    )

    # Take the current time once so both derived timestamps share the
    # same reference point.
    now = lib.time.now(as_type='epoch')
    # `established` is subtracted from, `rekey-time` added to "now";
    # presumably both are relative durations in seconds - confirm
    # against the VICI documentation.
    data['established-hr'] = lib.time.epoch2iso(now - int(data['established']))
    data['rekey-time-hr'] = lib.time.epoch2iso(now + int(data['rekey-time']))

    # Only mention the identity if it differs from the host address.
    if data['local-id'] != data['local-host']:
        data['local'] = (
            f'{data["local-host"]}:{data["local-port"]} ("{data["local-id"]}")'
        )
    else:
        data['local'] = f'{data["local-host"]}:{data["local-port"]}'
    if data['remote-id'] != data['remote-host']:
        data['remote'] = (
            f'{data["remote-host"]}:{data["remote-port"]} ("{data["remote-id"]}")'
        )
    else:
        data['remote'] = f'{data["remote-host"]}:{data["remote-port"]}'

    data['state'] = data['state'].replace('ESTABLISHED', 'EST')
    data['version'] = f'v{data["version"]}'

    return data


def _join_traffic_selectors(values):
    """Render a `local-ts` / `remote-ts` list as one comma-separated string.

    Every element is passed through `lib.txt.to_text()` before joining,
    so `bytes` elements (VICI at runtime) and `str` elements (JSON test
    fixtures) are both accepted. A missing or empty list yields an
    empty string.
    """
    selectors = values or ()
    return ', '.join([lib.txt.to_text(item) for item in selectors])


def format_child_data(child):
    """Re-format the details of a single CHILD SA.

    The set of keys delivered for a child depends heavily on the state
    of the connection, so every expected key is read explicitly and
    falls back to an empty default (`''`, `0`, or `'None'` for the
    integrity algorithm) when it is absent. All result keys carry a
    `child-` prefix; the `-hr` variants are the human-readable forms.
    """

    def text(field, fallback=''):
        # textual field, with a fallback for missing keys
        return lib.txt.to_text(child.get(field, fallback))

    def number(field):
        # numeric field, 0 if the key is missing
        return int(lib.txt.to_text(child.get(field, 0)))

    # raw values first, in a fixed order
    data = {
        'child-bytes-in': number('bytes-in'),
        'child-bytes-out': number('bytes-out'),
        'child-dh-group': text('dh-group'),
        'child-encr-alg': text('encr-alg'),
        'child-encr-keysize': text('encr-keysize'),
        'child-install-time': number('install-time'),
        'child-integ-alg': text('integ-alg', 'None'),
        'child-life-time': number('life-time'),
        'child-local-ts': _join_traffic_selectors(child.get('local-ts')),
        'child-mode': text('mode'),
        'child-name': text('name'),
        'child-protocol': text('protocol'),
        'child-rekey-time': number('rekey-time'),
        'child-remote-ts': _join_traffic_selectors(child.get('remote-ts')),
        'child-state': text('state'),
    }

    # derived, human-readable values
    data['child-bytes-in-hr'] = lib.human.bytes2human(data['child-bytes-in'])
    data['child-bytes-out-hr'] = lib.human.bytes2human(data['child-bytes-out'])
    data['child-encr'] = '{}:{}-{}/{}/{}'.format(
        data['child-protocol'],
        data['child-encr-alg'],
        data['child-encr-keysize'],
        data['child-integ-alg'],
        data['child-dh-group'],
    )

    installed_at = lib.time.now(as_type='epoch') - data['child-install-time']
    data['child-install-time-hr'] = lib.time.epoch2iso(installed_at)

    expires_at = lib.time.now(as_type='epoch') + data['child-life-time']
    data['child-life-time-hr'] = lib.time.epoch2iso(expires_at)

    data['child-mode-state'] = ':'.join((data['child-mode'], data['child-state']))

    rekey_at = lib.time.now(as_type='epoch') + data['child-rekey-time']
    data['child-rekey-time-hr'] = lib.time.epoch2iso(rekey_at)

    return data


def main():
    """Run the check: gather the VICI data, evaluate it and report.

    Steps:

    1. Parse the command line and compile the `--match` / `--ignore`
       regular expressions.
    2. Read `list_conns` (configured connections) and `list_sas`
       (active IKE SAs) either from the VICI socket or, with
       `--test`, from a JSON fixture file.
    3. Filter both lists by connection name.
    4. Warn if a configured connection has no active SA, or if an
       active SA has no CHILD SA.
    5. Emit message, optional table and perfdata via `lib.base.oao()`.

    The function does not return a value; the result is reported
    through `lib.base.oao()` / `lib.base.cu()`.
    """

    # parse the command line
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    if args.IGNORE is None:
        args.IGNORE = []
    if args.MATCH is None:
        args.MATCH = []

    # compile --match and --ignore patterns (case-sensitive by default,
    # matching the lib.args convention; the user can opt into
    # case-insensitive matching with the inline `(?i)` flag)
    try:
        match_patterns = [re.compile(p) for p in args.MATCH]
        ignore_patterns = [re.compile(p) for p in args.IGNORE]
    except re.error as e:
        lib.base.cu(f'Invalid regular expression: {e}')

    # fetch data
    if args.TEST is None:
        s = socket.socket(socket.AF_UNIX)
        try:
            s.connect(args.SOCKET)
        except OSError as e:
            s.close()
            lib.base.cu(f'Failed to connect to the VICI socket at {args.SOCKET}: {e}')
        try:
            session = vici.Session(s)
            list_conns = list(session.list_conns())
            # list_sas() returns a generator backed by the VICI
            # socket; materialise it to a list now so we can close
            # the socket cleanly before processing the data.
            list_sas = list(session.list_sas())
        finally:
            s.close()
    else:
        # Single-file test fixture: a JSON object with `list_conns`
        # and `list_sas` keys, both holding the raw VICI "list of
        # single-key dicts" shape. `active_connection_keys` is
        # derived from `list_sas` below via the same `_collect_keys`
        # helper the production path uses, so the test and the
        # production paths cannot drift.
        fixture_path = args.TEST[0] if args.TEST else ''
        if not fixture_path or not os.path.isfile(fixture_path):
            # Help users who still pass one of the file names of the
            # former three-file fixture layout.
            hint = ''
            for legacy in (
                '-possible_connection_keys',
                '-active_connection_keys',
                '-list_sas',
            ):
                if fixture_path.endswith(legacy):
                    hint = (
                        f' (the three-file convention was removed; drop the '
                        f'`{legacy}` suffix and pass '
                        f'`{fixture_path[: -len(legacy)]}` instead)'
                    )
                    break
            lib.base.cu(f'Test fixture not found: {fixture_path}{hint}')
        stdout, _stderr, _retc = lib.lftest.test(args.TEST)
        try:
            fixture = json.loads(stdout)
        except json.JSONDecodeError as e:
            lib.base.cu(f'Malformed JSON in test fixture {fixture_path}: {e}')
        list_conns = fixture.get('list_conns', [])
        list_sas = fixture.get('list_sas', [])
    possible_connection_keys = _collect_keys(list_conns)
    active_connection_keys = _collect_keys(list_sas)

    # Apply the --match / --ignore filter to both the configured and
    # the active connection lists. Filtering before the "configured
    # but not active" comparison means an ignored connection that is
    # configured but currently down does not trigger the warning,
    # which is exactly the point of --ignore.
    possible_connection_keys = [
        k
        for k in possible_connection_keys
        if keep_connection(k, match_patterns, ignore_patterns)
    ]
    active_connection_keys = [
        k
        for k in active_connection_keys
        if keep_connection(k, match_patterns, ignore_patterns)
    ]

    if not possible_connection_keys:
        lib.base.oao(
            'No connections configured.', STATE_UNKNOWN, always_ok=args.ALWAYS_OK
        )
    if not active_connection_keys:
        lib.base.oao(
            'There are no active connections at all.',
            STATE_WARN,
            always_ok=args.ALWAYS_OK,
        )

    # init some vars
    msg = ''
    state = STATE_OK
    perfdata = ''
    table_data = []

    # `list_sas` has one entry per IKE SA, so a connection name shows
    # up several times in `active_connection_keys` as soon as the
    # connection has more than one SA. Comparing the two lists for
    # equality would then warn although nothing is missing; compare
    # the sets of names and only look for configured connections
    # without any active SA.
    active_keyset = set(active_connection_keys)
    if set(possible_connection_keys) - active_keyset:
        conf_state = STATE_WARN
        state = lib.base.get_worst(state, conf_state)
        msg += (
            f'One or more connections are configured '
            f'but not active'
            f'{lib.base.state2str(conf_state, prefix=" ")}. '
        )

    # analyze data. `list_sas` is a list of single-key dicts where
    # each key is the connection name and the value is the SA
    # details; iterate the dict entries directly and skip everything
    # that was filtered out above.
    for sas in list_sas:
        for key, details in sas.items():
            if key not in active_keyset:
                continue
            row = format_sas_data(details)
            row['conn'] = key
            perfdata += lib.base.get_perfdata(
                f'{key}_established',
                row['established'],
                uom='s',
                _min=0,
            )
            perfdata += lib.base.get_perfdata(
                f'{key}_rekey-time',
                row['rekey-time'],
                uom='s',
                _min=0,
            )

            # a missing, None or empty `child-sas` section all mean
            # "no CHILD SA"
            children = details.get('child-sas')
            if not children:
                child_state = STATE_WARN
                state = lib.base.get_worst(state, child_state)
                msg += (
                    f'{key} not connected at child level'
                    f'{lib.base.state2str(child_state, prefix=" ")}. '
                )
                continue

            for child in children.values():
                child_row = format_child_data(child)
                # one table row per CHILD SA, prefixed with the
                # columns of its IKE SA
                table_data.append({**row, **child_row})
                for metric, uom in (
                    ('bytes-in', 'B'),
                    ('bytes-out', 'B'),
                    ('install-time', 's'),
                    ('life-time', 's'),
                    ('rekey-time', 's'),
                ):
                    perfdata += lib.base.get_perfdata(
                        f'{key}_{child_row["child-name"]}_{metric}',
                        child_row[f'child-{metric}'],
                        uom=uom,
                        _min=0,
                    )

    # build the message
    if state == STATE_OK:
        msg = 'Everything is ok.'

    # over and out
    if table_data:
        # `keys` selects the row fields, `headers` holds the column
        # titles in the same order
        if not args.LENGTHY:
            keys = [
                'conn',
                'state',
                'rekey-time-hr',
                'child-name',
                'child-mode-state',
                'child-rekey-time-hr',
                'child-life-time-hr',
                'child-bytes-in-hr',
                'child-bytes-out-hr',
            ]
            headers = [
                'Conn.',
                'State',
                'Re-Authentication',
                'Child',
                'Mode:State',
                'Re-Keying',
                'Expires',
                'Rx',
                'Tx',
            ]
        else:
            keys = [
                'conn',
                'state',
                'established-hr',
                'rekey-time-hr',
                'version',
                'local',
                'remote',
                'encr',
                'child-name',
                'child-mode-state',
                'child-local-ts',
                'child-remote-ts',
                'child-encr',
                'child-install-time-hr',
                'child-rekey-time-hr',
                'child-life-time-hr',
                'child-bytes-in-hr',
                'child-bytes-out-hr',
            ]
            headers = [
                'Conn.',
                'State',
                'Established',
                'Re-Authentication',
                'IKE',
                'Local',
                'Remote',
                'Encryption/Integrity/Pseudo Random/DH',
                'Child',
                'Mode:State',
                'Local',
                'Remote',
                'Prot:Encryption/Integrity/DH',
                'Installed',
                'Re-Keying',
                'Expires',
                'Rx',
                'Tx',
            ]
        lib.base.oao(
            f'{msg}\n\n{lib.base.get_table(table_data, keys, header=headers)}',
            state,
            perfdata,
            always_ok=args.ALWAYS_OK,
        )
    else:
        lib.base.oao(msg, state, perfdata, always_ok=args.ALWAYS_OK)


# Script entry point: run the check only when the file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    try:
        main()
    except Exception:
        # Top-level boundary: any exception that main() did not handle
        # itself is handed to lib.base.cu() (presumably reports the
        # error and exits with UNKNOWN - confirm in lib.base).
        lib.base.cu()
