#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import sys

import lib.args
import lib.base
import lib.cache
import lib.db_sqlite
import lib.human
import lib.time
from lib.globals import STATE_OK, STATE_UNKNOWN

try:
    import psutil
except ImportError:
    print('Python module "psutil" is not installed.')
    sys.exit(STATE_UNKNOWN)


__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2026040801'

# Text shown by argparse as the program description (see parse_args()).
DESCRIPTION = '\n'.join(
    (
        'Monitors network I/O throughput per interface over time. Calculates bytes per second',
        'from cumulative counters using SQLite state persistence between runs. Alerts only if',
        'bandwidth thresholds have been exceeded for a configurable number of consecutive check',
        'runs (default: 5), suppressing short spikes. Also reports packet rates, errors, and',
        'drops per interface.',
    )
)

DEFAULT_CACHE_EXPIRE = 90
# Number of measurements that form the alerting window; if the check runs
# once per minute, this is a 5 minute interval.
DEFAULT_COUNT = 5
DEFAULT_WARN = 80  # percent of the maximum throughput
DEFAULT_CRIT = 90  # percent of the maximum throughput
# Interfaces whose name starts with one of these strings are skipped.
DEFAULT_IGNORE = ['lo']


def parse_args():
    """Define the plugin's command line options and return the parsed values.

    Arguments the parser does not know are tolerated and discarded, because
    `parse_known_args()` is used instead of `parse_args()`.
    """
    # WARN and CRIT share the same wording apart from the leading keyword.
    threshold_help = (
        '{} threshold for network I/O rx/tx rate over the entire period as a percentage of the maximum network I/O rate. '
        'Default: >= %(default)s'
    )
    ignore_help = (
        'Ignore network interfaces starting with this string. Can be specified multiple times. '
        'Example: `--ignore tun`. '
        'Default: %(default)s'
    )

    cli = argparse.ArgumentParser(description=DESCRIPTION)

    cli.add_argument(
        '-V',
        '--version',
        action='version',
        version=f'%(prog)s: v{__version__} by {__author__}',
    )

    cli.add_argument(
        '--always-ok',
        action='store_true',
        default=False,
        dest='ALWAYS_OK',
        help=lib.args.help('--always-ok'),
    )

    cli.add_argument(
        '--count',
        type=int,
        default=DEFAULT_COUNT,
        dest='COUNT',
        help=lib.args.help('--count') + ' Default: %(default)s',
    )

    cli.add_argument(
        '--critical',
        type=int,
        default=DEFAULT_CRIT,
        dest='CRIT',
        help=threshold_help.format('CRIT'),
    )

    # With action='append', values given on the command line are added after
    # the entries of the default list.
    cli.add_argument(
        '--ignore',
        action='append',
        default=DEFAULT_IGNORE,
        dest='IGNORE',
        help=ignore_help,
    )

    cli.add_argument(
        '--warning',
        type=int,
        default=DEFAULT_WARN,
        dest='WARN',
        help=threshold_help.format('WARN'),
    )

    # parse_known_args() returns (namespace, leftover_argv); only the
    # namespace is of interest here.
    return cli.parse_known_args()[0]


def main():
    """Sample per-interface network counters, persist them and report throughput.

    Each run stores the current cumulative psutil counters of every
    non-ignored interface in a SQLite table and then derives, per interface:

    * the current rx/tx rate from the two newest rows, and
    * the rx/tx rate over the newest `--count` rows. Only this value is
      compared against the WARN/CRIT percentages of the highest throughput
      measured so far for that interface (kept in the cache).

    The result (message, state, perfdata) is handed to `lib.base.oao()`.
    """

    # parse the command line
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # the history table and the per-interface max-throughput cache use the same file
    db_filename = 'linuxfabrik-monitoring-plugins-network-io.db'
    # cumulative counters read from psutil and stored for every interface
    counter_fields = (
        'bytes_sent',
        'bytes_recv',
        'packets_sent',
        'packets_recv',
        'errin',
        'errout',
        'dropin',
        'dropout',
    )

    conn = lib.base.coe(lib.db_sqlite.connect(filename=db_filename))

    # create the perfdata table
    definition = """
            name                TEXT NOT NULL,
            bytes_sent          INT NOT NULL,
            bytes_recv          INT NOT NULL,
            packets_sent        INT NOT NULL,
            packets_recv        INT NOT NULL,
            errin               INT NOT NULL,
            errout              INT NOT NULL,
            dropin              INT NOT NULL,
            dropout             INT NOT NULL,
            timestamp           INT NOT NULL
        """
    lib.base.coe(lib.db_sqlite.create_table(conn, definition, drop_table_first=False))
    lib.base.coe(lib.db_sqlite.create_index(conn, 'name'))

    # get interface data and store it to database
    try:
        net_io_counters = psutil.net_io_counters(pernic=True, nowrap=True)
    except ValueError:
        lib.base.cu('psutil raised an error')

    now = lib.time.now()

    interfaces = []
    for interface, values in net_io_counters.items():
        if not interface or any(interface.startswith(i) for i in args.IGNORE):
            continue
        data = {'name': interface}
        for field in counter_fields:
            # counters missing on this platform are stored as 0
            data[field] = getattr(values, field, 0)
        data['timestamp'] = now
        interfaces.append(interface)
        lib.base.coe(lib.db_sqlite.insert(conn, data))

    # keep the table at COUNT rows per currently known interface (global limit)
    lib.base.coe(lib.db_sqlite.cut(conn, _max=args.COUNT * len(interfaces)))
    lib.base.coe(lib.db_sqlite.commit(conn))

    # init some vars
    msg = 'No I/O.'
    perfdata = ''
    state = STATE_OK
    table_values = []

    max_rw = 0  # interface with the highest sum of rx/tx: show this on top later on
    # we warn about a "count" period/amount of time, not about the current situation above
    # (what might be a peak only)
    for interface in sorted(interfaces):
        # get all historical data rows for a specific interface, newest item first
        interfacedata = lib.base.coe(
            lib.db_sqlite.select(
                conn,
                """
            SELECT *
            FROM perfdata
            WHERE name = :name
            ORDER BY timestamp DESC
            """,
                {'name': interface},
            )
        )
        if len(interfacedata) < 2:
            lib.db_sqlite.close(conn)
            # `state` may already be non-OK from an interface handled earlier
            # in this loop, so --always-ok has to be honoured here as well
            lib.base.oao(
                'Waiting for more data.', state, always_ok=args.ALWAYS_OK
            )

        newest = interfacedata[0]
        previous = interfacedata[1]

        # calculate current rates
        timestamp_diff = newest['timestamp'] - previous['timestamp']  # in seconds
        if timestamp_diff == 0:
            timestamp_diff = 1  # avoid a division by zero
        bytes_recv_per_second1 = int(
            float(newest['bytes_recv'] - previous['bytes_recv']) / timestamp_diff
        )
        bytes_sent_per_second1 = int(
            float(newest['bytes_sent'] - previous['bytes_sent']) / timestamp_diff
        )
        throughput1 = bytes_recv_per_second1 + bytes_sent_per_second1
        if any(
            [
                timestamp_diff < 0,
                bytes_recv_per_second1 < 0,
                bytes_sent_per_second1 < 0,
                throughput1 < 0,
            ]
        ):
            # happens after a reboot
            lib.db_sqlite.close(conn)
            lib.base.oao(
                'Waiting for more data.', state, always_ok=args.ALWAYS_OK
            )

        # store the max. measured throughput in cache
        cache_key = f'network-io-{interface}-throughput-max'
        throughput_db = lib.cache.get(cache_key, filename=db_filename)
        if throughput_db is False:
            # unknown interface, no value
            throughput_max = (
                10 * 1024 * 1024
            )  # interface should be capable of 10 MB/sec
            lib.cache.set(cache_key, throughput_max, filename=db_filename)
        elif throughput1 > int(throughput_db):
            throughput_max = throughput1
            lib.cache.set(cache_key, throughput_max, filename=db_filename)
        else:
            throughput_max = int(throughput_db)

        # perfdata: byte counters, each followed by its current rate
        # (recv first, then sent - the order of the perfdata items is kept stable)
        for direction, rate in (
            ('recv', bytes_recv_per_second1),
            ('sent', bytes_sent_per_second1),
        ):
            perfdata += lib.base.get_perfdata(
                f'{interface}_bytes_{direction}',
                newest[f'bytes_{direction}'],
                uom='c',
                _min=0,
            )
            perfdata += lib.base.get_perfdata(
                f'{interface}_bytes_{direction}_per_second1',
                rate,
                uom='B',
                _min=0,
                _max=throughput_max,
            )
        # perfdata: remaining cumulative counters
        for field in (
            'packets_sent',
            'packets_recv',
            'errin',
            'errout',
            'dropin',
            'dropout',
        ):
            perfdata += lib.base.get_perfdata(
                f'{interface}_{field}',
                newest[field],
                uom='c',
                _min=0,
            )

        perfdata += lib.base.get_perfdata(
            f'{interface}_throughput1',
            throughput1,
            _min=0,
            _max=throughput_max,
        )

        if throughput1 > max_rw:
            # the busiest interface so far provides the first line of the output
            msg = (
                f'{interface}:'
                f' {lib.human.bytes2human(bytes_recv_per_second1)}/s'
                f' rx,'
                f' {lib.human.bytes2human(bytes_sent_per_second1)}/s'
                f' tx (current)'
            )
            max_rw = throughput1

        # calculate rx/tx rate over the entire period
        # The table is trimmed globally (COUNT * number of interfaces), so an
        # interface can temporarily hold MORE than COUNT rows when the set of
        # interfaces changes. Such an interface must still be evaluated; only
        # skip when there are fewer than COUNT rows.
        if len(interfacedata) < args.COUNT:
            # not enough data yet
            continue
        window_start = interfacedata[args.COUNT - 1]
        timestamp_diff = newest['timestamp'] - window_start['timestamp']  # in seconds
        if timestamp_diff == 0:
            timestamp_diff = 1
        bytes_recv_per_second15 = float(
            (newest['bytes_recv'] - window_start['bytes_recv']) / timestamp_diff
        )
        bytes_recv_per_second15 = max(bytes_recv_per_second15, 0)
        bytes_sent_per_second15 = float(
            (newest['bytes_sent'] - window_start['bytes_sent']) / timestamp_diff
        )
        bytes_sent_per_second15 = max(bytes_sent_per_second15, 0)
        throughput15 = (
            bytes_recv_per_second15 + bytes_sent_per_second15
        )  # let's just call it like in cpu-usage

        # thresholds are percentages of the highest throughput seen so far
        warn_threshold = throughput_max * args.WARN / 100
        crit_threshold = throughput_max * args.CRIT / 100

        perfdata += lib.base.get_perfdata(
            f'{interface}_bytes_recv_per_second15',
            bytes_recv_per_second15,
            uom='B',
            _min=0,
            _max=throughput_max,
        )
        perfdata += lib.base.get_perfdata(
            f'{interface}_bytes_sent_per_second15',
            bytes_sent_per_second15,
            uom='B',
            _min=0,
            _max=throughput_max,
        )
        perfdata += lib.base.get_perfdata(
            f'{interface}_throughput15',
            throughput15,
            warn=warn_threshold,
            crit=crit_threshold,
            _min=0,
            _max=throughput_max,
        )

        # get state based on max measured I/O values
        interface_state = lib.base.get_state(
            throughput15,
            warn_threshold,
            crit_threshold,
        )
        state = lib.base.get_worst(interface_state, state)

        table_values.append(
            {
                'name': interface,
                'max': lib.human.bytes2human(throughput_max),
                'rx1': lib.human.bytes2human(bytes_recv_per_second1),
                'tx1': lib.human.bytes2human(bytes_sent_per_second1),
                'rx15': lib.human.bytes2human(bytes_recv_per_second15),
                'tx15': lib.human.bytes2human(bytes_sent_per_second15),
                't15': lib.human.bytes2human(throughput15)
                + lib.base.state2str(interface_state, prefix=' '),
            }
        )

    lib.db_sqlite.close(conn)

    msg = msg + '\n\n'
    if table_values:
        # build the message
        msg += lib.base.get_table(
            table_values,
            [
                'name',
                'max',
                'rx1',
                'tx1',
                'rx15',
                'tx15',
                't15',
            ],
            header=[
                'Interface',
                'rxtx-max/s',
                'rx1/s',
                'tx1/s',
                f'rx{args.COUNT}/s',
                f'tx{args.COUNT}/s',
                f'rxtx{args.COUNT}/s',
            ],
        )

    # over and out
    lib.base.oao(msg, state, perfdata, always_ok=args.ALWAYS_OK)


if __name__ == '__main__':
    # Script entry point. Any exception escaping main() is passed on to
    # lib.base.cu() (presumably it reports the error and exits with UNKNOWN -
    # confirm in lib.base). sys.exit() calls inside main() are not caught
    # here, because SystemExit does not derive from Exception.
    try:
        main()
    except Exception:
        lib.base.cu()
