#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import os
import shlex
import sys

import lib.args
import lib.base
import lib.cache
import lib.disk
import lib.human
import lib.lftest
import lib.shell
from lib.globals import STATE_CRIT, STATE_OK, STATE_UNKNOWN, STATE_WARN

__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2026042301'

# Used as the description of the argparse help output.
DESCRIPTION = """Queries SNMP OIDs defined in a CSV file and checks the returned values against optional
warning and critical thresholds. Supports SNMP v1, v2c, and v3 with authentication and
privacy protocols."""


# Zero-based column positions within one row of the device CSV file.
CSV_COL_OID = 0  # OID handed to snmpget
CSV_COL_NAME = 1  # display name; main() falls back to the OID if empty
CSV_COL_RECALC = 2  # Python expression eval()'ed to recalculate the value
CSV_COL_UNIT = 3  # unit label; "display,perfdata" splits it into two units
CSV_COL_WARN = 4  # Python expression eval()'ed; a truthy result means WARN
CSV_COL_CRIT = 5  # Python expression eval()'ed; a truthy result means CRIT
CSV_COL_SIFL = 6  # boolean: show the value in the first output line
CSV_COL_RCA = 7  # report a changed value as CRIT if it starts with "crit", else WARN
CSV_COL_IGNPERF = 8  # boolean: omit the value from perfdata; added 2024052901
CSV_COL_PERFTHRSHLD = 9  # expression eval()'ed to a (warn, crit) perfdata tuple; added 2024052901
CSV_COL_SKIPOUTPUT = 10  # boolean: omit the value from the table; added 2025052101
# the last (non-existent) column contains the snmp result

# Default of the `--hide-table` command line switch.
DEFAULT_HIDE_TABLE = False

# Number of OIDs main() puts into a single snmpget call.
MAX_OIDS_PER_REQUEST = 25


def parse_args():
    """Define the command line interface of the plugin and parse `sys.argv`.

    Returns the namespace of recognised options. Arguments the parser does
    not know are tolerated and dropped, because `parse_known_args()` is used.
    """
    argparser = argparse.ArgumentParser(description=DESCRIPTION)
    add = argparser.add_argument

    add(
        '-V',
        '--version',
        version=f'%(prog)s: v{__version__} by {__author__}',
        action='version',
    )

    add(
        '--community',
        dest='COMMUNITY',
        default='public',
        help='SNMP v1/v2c community string. Default: %(default)s',
    )

    add(
        '--device',
        dest='DEVICE',
        default='any-any-any.csv',
        help='Name of a CSV file containing the SNMP OIDs, located under `./device-oids`. '
        'The recommended naming convention is `class-vendor-model.csv`. '
        '`any-any-any.csv` is a good starting point showing some features. '
        'Example: `--device switch-fs-s3900.csv`. '
        'Default: %(default)s',
    )

    add(
        '--hide-ok',
        dest='HIDEOK',
        action='store_true',
        default=False,
        help='Suppress OIDs with OK state from output. Default: %(default)s',
    )

    add(
        '--hide-table',
        dest='HIDE_TABLE',
        action='store_true',
        default=DEFAULT_HIDE_TABLE,
        help='Suppress the table from output. Default: %(default)s',
    )

    add(
        '-H',
        '--hostname',
        dest='HOSTNAME',
        required=True,
        help='SNMP appliance hostname or IP address.',
    )

    add(
        '--mib',
        dest='MIB',
        help='MIB(s) to load, behaves like the `-m` option of `snmpget`. '
        'Example: `--mib "+FS-MIB"` or `--mib "FS-MIB:BROTHER-MIB"`.',
    )

    add(
        '--mib-dir',
        dest='MIB_DIR',
        default='$HOME/.snmp/mibs:/usr/share/snmp/mibs',
        help='Colon-separated list of directories to search for MIBs, behaves like the `-M` option of `snmpget`. '
        'Default: %(default)s',
    )

    add(
        '--snmp-version',
        dest='SNMP_VERSION',
        default='2c',
        choices=['1', '2c', '3'],
        help='SNMP version to use. Default: %(default)s',
    )

    add(
        '--test',
        dest='TEST',
        type=lib.args.csv,
        help=lib.args.help('--test'),
    )

    add(
        '-t',
        '--timeout',
        dest='TIMEOUT',
        default=7,
        type=int,
        help='Network timeout in seconds. Default: %(default)s (seconds)',
    )

    # --- SNMPv3 only ---
    add(
        '--v3-auth-prot',
        dest='V3_AUTH_PROT',
        choices=['MD5', 'SHA', 'SHA-224', 'SHA-256', 'SHA-384', 'SHA-512'],
        help='SNMPv3 authentication protocol.',
    )

    add(
        '--v3-auth-prot-password',
        dest='V3_AUTH_PROT_PASSWORD',
        help='SNMPv3 authentication protocol passphrase.',
    )

    add(
        '--v3-boots-time',
        dest='V3_BOOTS_TIME',
        help='SNMPv3 destination engine boots/time.',
    )

    add(
        '--v3-context',
        dest='V3_CONTEXT',
        help='SNMPv3 context name. Example: `--v3-context bridge1`.',
    )

    add(
        '--v3-context-engine-id',
        dest='V3_CONTEXT_ENGINE_ID',
        help='SNMPv3 context engine ID. '
        'Example: `--v3-context-engine-id 800000020109840301`.',
    )

    add(
        '--v3-level',
        dest='V3_LEVEL',
        choices=['noAuthNoPriv', 'authNoPriv', 'authPriv'],
        help='SNMPv3 security level.',
    )

    add(
        '--v3-priv-prot',
        dest='V3_PRIV_PROT',
        choices=['DES', 'AES', 'AES-192', 'AES-256'],
        help='SNMPv3 privacy protocol.',
    )

    add(
        '--v3-priv-prot-password',
        dest='V3_PRIV_PROT_PASSWORD',
        help='SNMPv3 privacy protocol passphrase.',
    )

    add(
        '--v3-security-engine-id',
        dest='V3_SECURITY_ENGINE_ID',
        help='SNMPv3 security engine ID. '
        'Example: `--v3-security-engine-id 800000020109840301`.',
    )

    add(
        '--v3-username',
        dest='V3_USERNAME',
        help='SNMPv3 security name (username). Example: `--v3-username bert`.',
    )

    # only the namespace is of interest, unknown arguments are discarded
    known_args = argparser.parse_known_args()[0]
    return known_args


def _shell_option(flag, value):
    """Return `'<flag> <value> '` with the value quoted for a POSIX shell,
    or an empty string if the value is unset/empty."""
    if not value:
        return ''
    return f'{flag} {shlex.quote(str(value))} '


def build_snmpget_call(args):
    """Build the snmpget command line (a string meant to be run by a shell).

    The returned string ends with the placeholder `_OIDS_`, which the caller
    replaces with the list of OIDs to query.

    Fixed snmpget parameters:
    -r 0: set the number of retries to zero
    -O: Toggle various defaults controlling output display:
        q:  quick print for easier parsing
        S:  print MIB module-id plus last element
        t:  print timeticks unparsed as numeric integers
        U:  don't print units

    All values taken from `args` are passed through `shlex.quote()`, so
    community strings, passphrases, context names etc. may contain spaces,
    quotes or other shell metacharacters without breaking (or altering) the
    command.

    Parameters:
        args: the namespace returned by `parse_args()`.

    Returns:
        The command line as a string.
    """
    if args.SNMP_VERSION in ['1', '2c']:
        cmd = f'snmpget -v {args.SNMP_VERSION} '
        cmd += _shell_option('-c', args.COMMUNITY)
    else:
        cmd = 'snmpget -v 3 '
        cmd += _shell_option('-a', args.V3_AUTH_PROT)
        cmd += _shell_option('-A', args.V3_AUTH_PROT_PASSWORD)
        cmd += _shell_option('-e', args.V3_SECURITY_ENGINE_ID)
        cmd += _shell_option('-E', args.V3_CONTEXT_ENGINE_ID)
        cmd += _shell_option('-l', args.V3_LEVEL)
        cmd += _shell_option('-n', args.V3_CONTEXT)
        cmd += _shell_option('-u', args.V3_USERNAME)
        cmd += _shell_option('-x', args.V3_PRIV_PROT)
        cmd += _shell_option('-X', args.V3_PRIV_PROT_PASSWORD)
        cmd += _shell_option('-Z', args.V3_BOOTS_TIME)
    cmd += '-OSqtU -r 0 '
    cmd += _shell_option('-t', args.TIMEOUT)
    if args.MIB_DIR:
        # The default contains `$HOME`, which used to be expanded by the
        # shell. Quoting would suppress that, so expand environment variables
        # and a leading `~` here before quoting.
        mib_dir = os.path.expanduser(os.path.expandvars(args.MIB_DIR))
        cmd += _shell_option('-M', mib_dir)
    cmd += _shell_option('-m', args.MIB)
    cmd += f'{shlex.quote(args.HOSTNAME)} ' if args.HOSTNAME else ''
    cmd += '_OIDS_'

    return cmd


def index_of_oid(snmp_objects, oid, start=0):
    """Return the position of the first row, searching from `start` onwards,
    whose first element equals `oid` - or None if no such row exists.

    Example: looking for "b" in [ ["a", "x"], ["b", "y"], ["c", "z"] ]
    returns 1.
    """
    candidates = enumerate(snmp_objects[start:], start)
    return next(
        (position for position, row in candidates if row[0] == oid),
        None,
    )


def _humanize(value, unit):
    """Format `value` for display according to the display unit label.

    The units `s`, `b` and `bps` are handed to the matching `lib.human`
    converter (`b` and `bps` are matched case-insensitively, `s` is not);
    any other unit label is simply appended to the value.
    """
    if unit == 's':
        return f'{lib.human.seconds2human(value)}'
    if unit.lower() == 'b':
        return f'{lib.human.bytes2human(value)}'
    if unit.lower() == 'bps':
        return f'{lib.human.bps2human(value)}'
    return f'{value}{unit}'


def _perfdata_uom(perfdata_unit):
    """Map a CSV unit label to the `(uom, _max)` pair used for perfdata."""
    if perfdata_unit == '%':
        return '%', 100
    if perfdata_unit.upper() in ('B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'):
        return perfdata_unit.upper(), None
    if perfdata_unit.lower() in ('c', 's', 'ms', 'us'):
        return perfdata_unit.lower(), None
    # unknown perfdata suffixes, so do not use them
    return None, None


def main():
    """The main function. This is where the magic happens.

    Workflow:
    1. Read the OID definitions from the device CSV (or, in test mode, from
       the CSV in the unit-test directory if one exists).
    2. Query the OIDs with snmpget in chunks of MAX_OIDS_PER_REQUEST (or take
       the snmpget output from the test data).
    3. Attach each returned value to its CSV row, then apply the optional
       recalculation, the warn/crit expressions and the change detection.
    4. Print the first line, the table and the perfdata via `lib.base.oao()`.
    """

    # parse the command line
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # read oid list for our device from the CSV file
    plugin_path = os.path.dirname(os.path.realpath(__file__))
    snmp_objects = lib.base.coe(
        lib.disk.read_csv(
            os.path.join(plugin_path, f'device-oids/{args.DEVICE}'),
            as_dict=False,
            skip_empty_rows=True,
        )
    )
    if args.TEST:
        # in test mode, override with CSV from unit-test directory (if any), which
        # has to be in unit-test directory
        test_csvfile = os.path.join(
            plugin_path,
            'unit-test',
            os.path.basename(args.TEST[0]) + '.csv',
        )
        if lib.disk.file_exists(test_csvfile):
            snmp_objects = lib.base.coe(
                lib.disk.read_csv(
                    test_csvfile,
                    as_dict=False,
                    skip_empty_rows=True,
                )
            )

    if not snmp_objects:
        # not even a header row; without this guard the header lookup below
        # would fail with an IndexError
        lib.base.cu('The CSV file contains no rows.')

    # Max. 128 object identifiers are allowed in one snmp get request, so divide them into chunks.
    # To avoid "Error in packet. Reason: (tooBig) Response message would have been too large.", we
    # only ask for 25 OIDs per request.
    all_oids = [
        snmp_object[CSV_COL_OID]
        for snmp_object in snmp_objects[1:]  # ignore the header row in csv
    ]
    oids = [
        ' '.join(all_oids[i : i + MAX_OIDS_PER_REQUEST])
        for i in range(0, len(all_oids), MAX_OIDS_PER_REQUEST)
    ]

    # fetch data
    stdout = ''
    if args.TEST is None:
        cmd = build_snmpget_call(args)
        # cmd is built from admin-controlled SNMP parameters (Icinga check config),
        # not from end-user input; shell=True is required so snmpget accepts the
        # escaped OID list as a single argument
        for oid_chunk in oids:
            chunk_stdout, stderr, retc = lib.base.coe(
                lib.shell.shell_exec(cmd.replace('_OIDS_', oid_chunk), shell=True)  # nosec B604
            )
            if stderr or retc != 0:
                lib.base.cu(stderr)
            stdout += chunk_stdout
    else:
        # do not call the command, put in test data
        stdout, _, _ = lib.lftest.test(args.TEST)
        args.DEVICE = args.TEST[0]

    # the last (non-existent) column should contain the result
    csv_col_value = len(snmp_objects[0])

    # enrich snmp_objects with results from snmpget run
    enriched = set()  # indexes of the rows that got a result
    oid_index = 0
    for line in stdout.splitlines():
        if not line:
            continue
        # the first space char is the delimiter used by snmpget; a line
        # without any space is treated as an OID with an empty value
        oid, _, raw_value = line.partition(' ')
        oid_index = index_of_oid(snmp_objects, oid, oid_index + 1)
        if oid_index is None:
            # nothing found, restart the search at the top for the next line
            oid_index = 0
            continue
        if oid_index in enriched:
            # keep the first result for a row
            continue
        enriched.add(oid_index)
        target = snmp_objects[oid_index]
        # Bring the row to exactly the width of the header before appending
        # the result. Otherwise, for a row with fewer cells than the header,
        # the value returned by the device would end up in one of the
        # configuration columns (for example the recalc column, which is
        # eval()'ed).
        del target[csv_col_value:]
        target.extend([''] * (csv_col_value - len(target)))
        target.append(
            raw_value.replace('Wrong Type (should be Timeticks): ', '').strip()
        )

    # init some vars
    msg, msg_header = '', ''
    state = STATE_OK
    perfdata = ''
    values = {}  # all values in a single dict, for re-calc column
    table_values = []

    # evaluate results
    for row_index, snmp_object in enumerate(snmp_objects[1:], start=1):
        if len(snmp_object) <= 1:
            # definitely an invalid csv line, ignore
            continue
        # None means: we got no value from snmpget
        value = snmp_object[csv_col_value] if row_index in enriched else None

        name = snmp_object[CSV_COL_NAME] or snmp_object[CSV_COL_OID]
        oid = snmp_object[CSV_COL_OID]

        # snmpget error handling
        if value is not None and value.lower().startswith('no such '):
            # No Such Instance currently exists at this OID
            # No Such Object available on this agent at this OID
            table_values.append(
                {
                    'name': oid,
                    'value': value,
                    'state': lib.base.state2str(STATE_UNKNOWN),
                }
            )
            state = lib.base.get_worst(state, STATE_UNKNOWN)
            continue

        recalc, unit, warn, crit = '', '', '', ''
        show_in_first_line, report_change = False, False
        try:
            recalc = snmp_object[CSV_COL_RECALC]
            unit = snmp_object[CSV_COL_UNIT]
            warn = snmp_object[CSV_COL_WARN]
            crit = snmp_object[CSV_COL_CRIT]
            show_in_first_line = lib.base.str2bool(snmp_object[CSV_COL_SIFL])
            report_change = snmp_object[CSV_COL_RCA]
        except IndexError:
            # short row (without a result): keep the defaults for the
            # columns that are missing
            pass

        # manage different CSV formats
        # v1: last column is CSV_COL_RCA
        # v2: last column is CSV_COL_PERFTHRSHLD, added 2024052901
        # v3: last column is CSV_COL_SKIPOUTPUT, added 2025052101
        skip_output = False
        ignore_perfdata = False
        perf_thresholds = False
        if csv_col_value > CSV_COL_SKIPOUTPUT:
            # v3
            try:
                skip_output = lib.base.str2bool(snmp_object[CSV_COL_SKIPOUTPUT])
            except IndexError:
                # invalid csv definition
                pass
        if csv_col_value > CSV_COL_RCA:
            # v2
            try:
                ignore_perfdata = lib.base.str2bool(snmp_object[CSV_COL_IGNPERF])
                perf_thresholds = snmp_object[CSV_COL_PERFTHRSHLD]
            except IndexError:
                # invalid csv definition
                pass

        if recalc:
            # we got a formula
            # eval() is the documented snmp plugin feature: admins provide
            # arithmetic recalculation formulas (e.g. 'value * 8') in the check
            # config; ast.literal_eval cannot evaluate arithmetic
            try:
                value = eval(recalc, {}, {'value': value, 'values': values})  # nosec B307
            except Exception as e:
                table_values.append(
                    {
                        'name': name,
                        'value': f"The recalc in {args.DEVICE} failed with {type(e).__name__}: '{e}'.",
                        'state': lib.base.state2str(STATE_UNKNOWN),
                    }
                )
                state = lib.base.get_worst(state, STATE_UNKNOWN)
                continue
        values[name] = value

        # check the state
        # eval() is the documented snmp plugin feature: admins provide
        # boolean threshold expressions (e.g. 'value > 80') in the check config
        eval_locals = {'value': value, 'values': values}
        try:
            # crit wins over warn; warn is only evaluated if crit did not match
            is_crit = bool(crit) and bool(eval(crit, {}, eval_locals))  # nosec B307
            is_warn = (
                not is_crit and bool(warn) and bool(eval(warn, {}, eval_locals))  # nosec B307
            )
        except Exception as e:
            # a failing threshold expression (for example comparing a missing
            # or non-numeric value) is reported for this row only, in the same
            # way as a failing recalc, instead of aborting the whole check
            table_values.append(
                {
                    'name': name,
                    'value': f"The threshold in {args.DEVICE} failed with {type(e).__name__}: '{e}'.",
                    'state': lib.base.state2str(STATE_UNKNOWN),
                }
            )
            state = lib.base.get_worst(state, STATE_UNKNOWN)
            continue

        value_state = STATE_OK
        if is_crit:
            value_state = STATE_CRIT
            msg_header += f'{name}: {value}{unit} {lib.base.state2str(STATE_CRIT)}, '
        elif is_warn:
            value_state = STATE_WARN
            msg_header += f'{name}: {value}{unit} {lib.base.state2str(STATE_WARN)}, '
        elif report_change:
            # NOTE(review): the cache key consists of the device CSV name and
            # the OID only, so hosts sharing a device CSV also share the
            # remembered value - confirm this is intended.
            # The remembered value is only written when there is none yet.
            cache_key = f'{args.DEVICE}::{oid}'
            cache_value = lib.cache.get(
                cache_key,
                filename='linuxfabrik-monitoring-plugins-snmp.db',
            )
            if cache_value:  # there is a previous value
                if cache_value != value:  # that is different from the current one
                    if report_change.lower().startswith('crit'):
                        value_state = STATE_CRIT
                    else:
                        value_state = STATE_WARN
                    msg_header += (
                        f'{name}: {value}{unit} '
                        f'({lib.base.state2str(value_state)}'
                        f', changed from "{cache_value}"), '
                    )
            else:
                lib.cache.set(
                    cache_key, value, filename='linuxfabrik-monitoring-plugins-snmp.db'
                )
        state = lib.base.get_worst(state, value_state)

        # create message body (the table)
        if ',' in unit:
            # example: "b,c" - convert the first part to human readable bytes,
            # but suffix the perfdata as a continous counter
            # (split once only, so that a surplus comma does not raise)
            unit, perfdata_unit = unit.split(',', 1)
        else:
            perfdata_unit = unit
        # the value is only converted when it is actually printed
        if show_in_first_line:
            msg_header += (
                f'{name}: '
                f'{_humanize(value, unit)}'
                f'{lib.base.state2str(value_state, prefix=" ")}'
                f', '
            )
        if not skip_output and (not args.HIDEOK or value_state):
            table_values.append(
                {
                    'name': name,
                    'value': _humanize(value, unit),
                    'state': lib.base.state2str(value_state, empty_ok=False),
                }
            )

        # create perfdata for numeric values
        value_type = lib.base.guess_type(value)
        if not ignore_perfdata and isinstance(value_type, (int, float)):
            w, c = None, None
            if perf_thresholds:
                try:
                    # performance thresholds should return a tuple
                    # eval() is the documented feature: admins provide perf thresholds
                    # as Python tuple expressions in the check config
                    w, c = eval(perf_thresholds, {})  # nosec B307
                except Exception:
                    # best effort: an unusable definition just means perfdata
                    # without thresholds
                    w, c = None, None
            uom, perfdata_max = _perfdata_uom(perfdata_unit)
            perfdata += lib.base.get_perfdata(
                name,
                value,
                uom=uom,
                warn=w,
                crit=c,
                _min=0,
                _max=perfdata_max,
            )

    # create output
    if msg_header:
        # build the message, without the trailing ', '
        msg += msg_header[:-2]
    if not args.HIDE_TABLE and table_values:
        msg += '\n\n' + lib.base.get_table(
            table_values,
            ['name', 'value', 'state'],
            header=['Key', 'Value', 'State'],
        )

    # over and out
    lib.base.oao(msg, state, perfdata)


if __name__ == '__main__':
    # Run the check only when the file is executed as a script. Any exception
    # that escapes main() is handed to lib.base.cu() instead of surfacing as a
    # raw traceback (presumably cu() reports it and exits with UNKNOWN -
    # verify against lib.base).
    try:
        main()
    except Exception:
        lib.base.cu()
