#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details.
"""

import argparse  # pylint: disable=C0413
import sys  # pylint: disable=C0413

import lib.base  # pylint: disable=C0413
import lib.cache  # pylint: disable=C0413
import lib.db_sqlite  # pylint: disable=C0413
import lib.human  # pylint: disable=C0413
import lib.time  # pylint: disable=C0413
from lib.globals import (STATE_CRIT, STATE_OK,  # pylint: disable=C0413
                          STATE_UNKNOWN, STATE_WARN)

# psutil is a third-party module; without it the check cannot do anything useful,
# so report that on stdout and exit with the UNKNOWN state.
try:
    import psutil
except ImportError:
    print('Python module "psutil" is not installed.')
    sys.exit(STATE_UNKNOWN)


__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2025020401'

DESCRIPTION = """Checks network IO."""

# NOTE(review): not referenced anywhere in the visible part of this file - confirm
# whether it is still needed.
DEFAULT_CACHE_EXPIRE = 90
DEFAULT_COUNT = 5       # measurements; if check runs once per minute, this is a 5 minute interval
DEFAULT_WARN = 80       # % of the maximum network I/O rate known for the interface
DEFAULT_CRIT = 90       # % of the maximum network I/O rate known for the interface
# Interface name prefixes that are skipped by default (default of `--ignore`).
DEFAULT_IGNORE = [
    'lo',
]


def parse_args():
    """Parse command line arguments using argparse.

    Returns:
        argparse.Namespace: The parsed arguments with the attributes `ALWAYS_OK` (bool),
        `COUNT` (int), `CRIT` (int), `IGNORE` (list of str) and `WARN` (int).

    Raises:
        SystemExit: Raised by argparse on invalid arguments, and also after it has
        handled `--help` or `--version`.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)

    parser.add_argument(
        '-V', '--version',
        action='version',
        version='{0}: v{1} by {2}'.format('%(prog)s', __version__, __author__)
    )

    parser.add_argument(
        '--always-ok',
        help='Always returns OK.',
        dest='ALWAYS_OK',
        action='store_true',
        default=False,
    )

    parser.add_argument(
        '--count',
        help='Number of times the value must exceed specified thresholds before alerting. '
             'Default: %(default)s',
        dest='COUNT',
        type=int,
        default=DEFAULT_COUNT,
    )

    parser.add_argument(
        '--critical',
        help='Set the CRIT threshold for network I/O rx/tx rate over the entire period as a '
             'percentage of the maximum network I/O rate. Default: >= %(default)s',
        dest='CRIT',
        type=int,
        default=DEFAULT_CRIT,
    )

    # `append` combined with a non-empty default: values given on the command line are
    # added after the default entries, they do not replace them.
    parser.add_argument(
        '--ignore',
        help='Ignore network interfaces starting with a string like "tun" (repeating). '
             'Default: %(default)s',
        dest='IGNORE',
        default=DEFAULT_IGNORE,
        action='append',
    )

    # the help text used to say "CRIT threshold" here (copied from --critical)
    parser.add_argument(
        '--warning',
        help='Set the WARN threshold for network I/O rx/tx rate over the entire period as a '
             'percentage of the maximum network I/O rate. Default: >= %(default)s',
        dest='WARN',
        type=int,
        default=DEFAULT_WARN,
    )

    return parser.parse_args()


def main():
    """Measure per-interface network I/O and alert on sustained high throughput.

    Reads the per-NIC counters from psutil, appends them to a local SQLite table and
    then, for every interface that is not ignored, derives

    * the current rx/tx rate from the two newest rows, and
    * the rx/tx rate over the newest `--count` rows.

    Only the rate over the whole period is compared against `--warning` and
    `--critical`. Both are percentages of the highest current throughput measured so
    far for that interface, which is kept in the plugin cache and starts at 10 MiB/s.

    The function returns nothing; the result is handed to `lib.base.oao()`.
    """

    # parse the command line, exit with UNKNOWN if it fails
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # the measurement table and the throughput cache are addressed by the same file name
    db_filename = 'linuxfabrik-monitoring-plugins-network-io.db'

    conn = lib.base.coe(lib.db_sqlite.connect(filename=db_filename))

    # create the perfdata table
    definition = '''
            name                TEXT NOT NULL,
            bytes_sent          INT NOT NULL,
            bytes_recv          INT NOT NULL,
            packets_sent        INT NOT NULL,
            packets_recv        INT NOT NULL,
            errin               INT NOT NULL,
            errout              INT NOT NULL,
            dropin              INT NOT NULL,
            dropout             INT NOT NULL,
            timestamp           INT NOT NULL
        '''
    lib.base.coe(lib.db_sqlite.create_table(conn, definition, drop_table_first=False))
    lib.base.coe(lib.db_sqlite.create_index(conn, 'name'))

    # get interface data and store it to database
    try:
        net_io_counters = psutil.net_io_counters(pernic=True, nowrap=True)
    except ValueError:
        # NOTE(review): relies on cu() terminating the process, otherwise
        # `net_io_counters` would be unbound below - confirm against lib.base
        lib.base.cu('psutil raised an error')

    now = lib.time.now()

    counter_names = (
        'bytes_sent',
        'bytes_recv',
        'packets_sent',
        'packets_recv',
        'errin',
        'errout',
        'dropin',
        'dropout',
    )
    interfaces = []
    for interface, values in net_io_counters.items():
        if not interface or any(interface.startswith(i) for i in args.IGNORE):
            continue
        data = {'name': interface}
        for counter in counter_names:
            # a counter attribute that psutil does not provide is stored as 0
            data[counter] = getattr(values, counter, 0)
        data['timestamp'] = now
        interfaces.append(interface)
        lib.base.coe(lib.db_sqlite.insert(conn, data))

    # the row limit is a total over all interfaces, not a limit per interface
    lib.base.coe(lib.db_sqlite.cut(conn, _max=args.COUNT * len(interfaces)))
    lib.base.coe(lib.db_sqlite.commit(conn))

    # init some vars
    msg = 'No I/O.'
    perfdata = ''
    state = STATE_OK
    table_values = []

    max_rw = 0       # interface with the highest sum of rx/tx: show this on top later on
    # we warn about a "count" period/amount of time, not about the current situation above
    # (what might be a peak only)
    for interface in sorted(interfaces):
        # get all historical data rows for a specific interface, newest item first
        interfacedata = lib.base.coe(lib.db_sqlite.select(
            conn,
            '''
            SELECT *
            FROM perfdata
            WHERE name = :name
            ORDER BY timestamp DESC
            ''',
            {'name': interface},
        ))
        if len(interfacedata) < 2:
            # A rate needs two measurements. `state` may already be WARN/CRIT from an
            # interface handled earlier in this loop, so --always-ok has to be honoured
            # here as well.
            lib.db_sqlite.close(conn)
            lib.base.oao('Waiting for more data.', state, always_ok=args.ALWAYS_OK)

        latest = interfacedata[0]
        previous = interfacedata[1]

        # calculate current rates
        timestamp_diff = latest['timestamp'] - previous['timestamp']  # in seconds
        if timestamp_diff == 0:
            # avoid a division by zero if both rows carry the same timestamp
            timestamp_diff = 1
        bytes_recv_per_second1 = int(
            float(latest['bytes_recv'] - previous['bytes_recv']) / timestamp_diff
        )
        bytes_sent_per_second1 = int(
            float(latest['bytes_sent'] - previous['bytes_sent']) / timestamp_diff
        )
        throughput1 = bytes_recv_per_second1 + bytes_sent_per_second1
        if any([
            timestamp_diff < 0,
            bytes_recv_per_second1 < 0,
            bytes_sent_per_second1 < 0,
            throughput1 < 0,
        ]):
            # happens after a reboot
            lib.db_sqlite.close(conn)
            lib.base.oao('Waiting for more data.', state, always_ok=args.ALWAYS_OK)

        # store the max. measured throughput in cache
        cache_key = 'network-io-{}-throughput-max'.format(interface)
        throughput_db = lib.cache.get(cache_key, filename=db_filename)
        if throughput_db is False:
            # unknown interface, no value
            throughput_max = 10 * 1024 * 1024      # interface should be capable of 10 MB/sec
            lib.cache.set(cache_key, throughput_max, filename=db_filename)
        elif throughput1 > int(throughput_db):
            throughput_max = throughput1
            lib.cache.set(cache_key, throughput_max, filename=db_filename)
        else:
            throughput_max = int(throughput_db)

        # perfdata
        perfdata += lib.base.get_perfdata(
            '{}_bytes_recv'.format(interface), latest['bytes_recv'], 'c', None, None, 0, None,
        )
        perfdata += lib.base.get_perfdata(
            '{}_bytes_recv_per_second1'.format(interface),
            bytes_recv_per_second1, 'B', None, None, 0, throughput_max,
        )
        perfdata += lib.base.get_perfdata(
            '{}_bytes_sent'.format(interface), latest['bytes_sent'], 'c', None, None, 0, None,
        )
        perfdata += lib.base.get_perfdata(
            '{}_bytes_sent_per_second1'.format(interface),
            bytes_sent_per_second1, 'B', None, None, 0, throughput_max,
        )
        for counter in ('packets_sent', 'packets_recv', 'errin', 'errout', 'dropin', 'dropout'):
            perfdata += lib.base.get_perfdata(
                '{}_{}'.format(interface, counter), latest[counter], 'c', None, None, 0, None,
            )

        perfdata += lib.base.get_perfdata(
            '{}_throughput1'.format(interface), throughput1, None, None, None, 0, throughput_max,
        )

        if throughput1 > max_rw:
            msg = '{}: {}/s rx, {}/s tx (current)'.format(
                interface,
                lib.human.bytes2human(bytes_recv_per_second1),
                lib.human.bytes2human(bytes_sent_per_second1),
            )
            max_rw = throughput1

        # calculate rx/tx rate over the entire period
        if len(interfacedata) < args.COUNT:
            # not enough data yet
            continue
        # Since the row limit above is a total over all interfaces, an interface may
        # temporarily hold more than COUNT rows; always use the newest COUNT of them.
        oldest = interfacedata[args.COUNT - 1]
        timestamp_diff = latest['timestamp'] - oldest['timestamp']  # in seconds
        if timestamp_diff == 0:
            timestamp_diff = 1
        # negative rates (counter smaller than before) are clamped to 0
        bytes_recv_per_second15 = float(
            (latest['bytes_recv'] - oldest['bytes_recv']) / timestamp_diff
        )
        bytes_recv_per_second15 = max(bytes_recv_per_second15, 0)
        bytes_sent_per_second15 = float(
            (latest['bytes_sent'] - oldest['bytes_sent']) / timestamp_diff
        )
        bytes_sent_per_second15 = max(bytes_sent_per_second15, 0)
        # let's just call it like in cpu-usage
        throughput15 = bytes_recv_per_second15 + bytes_sent_per_second15

        # thresholds are percentages of the highest throughput known for this interface
        warn_threshold = throughput_max * args.WARN / 100
        crit_threshold = throughput_max * args.CRIT / 100

        perfdata += lib.base.get_perfdata(
            '{}_bytes_recv_per_second15'.format(interface),
            bytes_recv_per_second15, 'B', None, None, 0, throughput_max,
        )
        perfdata += lib.base.get_perfdata(
            '{}_bytes_sent_per_second15'.format(interface),
            bytes_sent_per_second15, 'B', None, None, 0, throughput_max,
        )
        perfdata += lib.base.get_perfdata(
            '{}_throughput15'.format(interface),
            throughput15, None, warn_threshold, crit_threshold, 0, throughput_max,
        )

        # get state based on max measured I/O values
        interface_state = lib.base.get_state(throughput15, warn_threshold, crit_threshold)
        state = lib.base.get_worst(interface_state, state)

        table_values.append({
            'name': interface,
            'max': lib.human.bytes2human(throughput_max),
            'rx1': lib.human.bytes2human(bytes_recv_per_second1),
            'tx1': lib.human.bytes2human(bytes_sent_per_second1),
            'rx15': lib.human.bytes2human(bytes_recv_per_second15),
            'tx15': lib.human.bytes2human(bytes_sent_per_second15),
            't15': (
                lib.human.bytes2human(throughput15)
                + lib.base.state2str(interface_state, prefix=' ')
            ),
        })

    lib.db_sqlite.close(conn)

    msg = msg + '\n\n'
    if table_values:
        msg += lib.base.get_table(
            table_values,
            [
                'name',
                'max',
                'rx1',
                'tx1',
                'rx15',
                'tx15',
                't15',
            ],
            header=[
                'Interface',
                'rxtx-max/s',
                'rx1/s',
                'tx1/s',
                'rx{}/s'.format(args.COUNT),
                'tx{}/s'.format(args.COUNT),
                'rxtx{}/s'.format(args.COUNT)
            ],
        )

    lib.base.oao(msg, state, perfdata, always_ok=args.ALWAYS_OK)


if __name__ == '__main__':
    try:
        main()
    except Exception:  # pylint: disable=W0703
        # Top-level boundary: any unexpected error is handed to lib.base.cu() instead of
        # letting a raw traceback escape from the check.
        lib.base.cu()
