#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details.
"""

import argparse
import os
import platform
import sys

import lib.args
import lib.base
import lib.cache
import lib.db_sqlite
import lib.disk
import lib.human
import lib.time
import lib.txt
import lib.version
from lib.globals import (STATE_OK, STATE_UNKNOWN)

# psutil is a third-party module that provides the disk and process I/O counters used below.
# If it cannot be imported the check cannot do anything useful, so print a hint and exit
# with STATE_UNKNOWN instead of failing with a traceback.
try:
    import psutil
except ImportError:
    print('Python module "psutil" is not installed.')
    sys.exit(STATE_UNKNOWN)


# Both values are interpolated into the `--version` output built in parse_args().
__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2025090703'

# Passed to argparse as the parser description in parse_args(). This is a plain string, not
# an f-string: `{COUNT}` is literal text standing for the value of `--count`.
DESCRIPTION = """Checks disk I/O bandwidth over time and alerts on sustained saturation, not
short spikes. The check records per-disk read/write counters and then derives current (R1/W1)
and period averages (R{COUNT}/W{COUNT}). It compares the period’s total bandwidth against the
maximum ever observed for that disk (RWmax). WARN/CRIT trigger if the period average exceeds
the configured percentage of RWmax for COUNT consecutive runs.

Perfdata is emitted for each disk (busy_time, read_bytes, read_time, write_bytes, write_time)
so you can graph trends. On Linux the check automatically focuses on "real" block devices with
mountpoints; on Windows it uses psutil’s disk counters. Optionally, `--top` lists the processes
that generated the most I/O traffic (read/write totals) to help identify offenders.

This check is cross-platform and works on Linux, Windows, and all psutil-supported systems.
The check stores its short trend state locally in an SQLite DB to evaluate sustained load across
runs."""


# Defaults for the command line parameters defined in parse_args().
# NOTE(review): DEFAULT_CACHE_EXPIRE is not referenced in the code visible here and its unit
# is not stated - TODO confirm whether it is still needed.
DEFAULT_CACHE_EXPIRE = 90
DEFAULT_COUNT = 5    # measurements; if check runs once per minute, this is a 5 minute interval
DEFAULT_CRIT = 90    # %
DEFAULT_MATCH = ''   # empty: no filtering by disk name
DEFAULT_TOP = 5      # number of processes listed by `--top`
DEFAULT_WARN = 80    # %


def parse_args():
    """Define the command line interface of this check and parse `sys.argv`.

    Returns:
        The `argparse.Namespace` with the attributes ALWAYS_OK, COUNT, CRIT, MATCH, TOP
        and WARN.
    """
    # `--critical` and `--warning` document themselves with exactly the same text
    threshold_help = (
        'Threshold for disk bandwidth saturation (over the last `--count` measurements) as a '
        'percentage of the maximum bandwidth the disk can support. '
        'Default: >= %(default)s'
    )

    parser = argparse.ArgumentParser(description=DESCRIPTION)

    parser.add_argument(
        '-V', '--version',
        version=f'%(prog)s: v{__version__} by {__author__}',
        action='version',
    )

    parser.add_argument(
        '--always-ok',
        dest='ALWAYS_OK',
        action='store_true',
        default=False,
        help='Always returns OK.',
    )

    parser.add_argument(
        '--count',
        dest='COUNT',
        type=int,
        default=DEFAULT_COUNT,
        help='Number of times the value must exceed specified thresholds before alerting. '
             'Default: %(default)s',
    )

    parser.add_argument(
        '--critical',
        dest='CRIT',
        type=int,
        default=DEFAULT_CRIT,
        help=threshold_help,
    )

    parser.add_argument(
        '--match',
        dest='MATCH',
        default=DEFAULT_MATCH,
        help=' '.join((
            'Match on disk names.',
            lib.args.help('--match'),
            'Default: %(default)s',
        )),
    )

    parser.add_argument(
        '--top',
        dest='TOP',
        type=int,
        default=DEFAULT_TOP,
        help='List x "Top processes that generated the most I/O traffic". '
             'Use `--top=0` to disable this feature. '
             'Default: %(default)s',
    )

    parser.add_argument(
        '--warning',
        dest='WARN',
        type=int,
        default=DEFAULT_WARN,
        help=threshold_help,
    )

    return parser.parse_args()


def get_max_bandwidth(disk, current_bandwidth):
    """Return the highest bandwidth known for the specific disk and remember it in the cache.

    The result is the largest of the value currently cached for `disk`, `current_bandwidth`
    and a fixed floor. It is written back to the cache before it is returned, so the cached
    value never decreases through this function.
    """
    cache_key = f'disk-io-{disk}-bandwidth-max'
    cache_file = 'linuxfabrik-monitoring-plugins-disk-io.db'
    # Disk should be capable of at least 10 MB/sec (if no info is provided)
    floor = 10 * 1024 * 1024

    cached = lib.cache.get(cache_key, filename=cache_file)
    candidates = (int(cached), int(current_bandwidth), floor)
    peak = max(candidates)

    lib.cache.set(cache_key, peak, filename=cache_file)
    return peak


def get_rate(ts1, ts2, r1, r2, w1, w2):
    """Derive per-second read and write rates from two counter samples.

    Parameters are the timestamps (`ts1`, `ts2`), read counters (`r1`, `r2`) and write
    counters (`w1`, `w2`) of the two samples, in any order: all differences are taken as
    absolute values, so a counter that went backwards still yields a positive rate.

    Returns:
        A tuple `(timediff, read_rate, write_rate, read_rate + write_rate)`, with the rates
        truncated to integers. If both timestamps are equal, `(0, 0, 0, 0)` is returned.
    """
    elapsed = abs(ts1 - ts2)  # in seconds
    if not elapsed:
        # no time passed between the samples, so no rate can be calculated
        return 0, 0, 0, 0
    read_rate, write_rate = (
        abs(int(float(delta) / elapsed))
        for delta in (r1 - r2, w1 - w2)
    )
    return elapsed, read_rate, write_rate, read_rate + write_rate


def _tally_io(totals, info):
    """Add the read/write bytes of one process (`info` dict) to `totals`, keyed by its name.
    """
    counters = info.get('io_counters')
    if not counters:
        # no I/O counters available for this process
        return
    pair = totals.setdefault(info.get('name') or '', [0, 0])
    # guard against missing attributes and None values
    pair[0] += getattr(counters, 'read_bytes', 0) or 0
    pair[1] += getattr(counters, 'write_bytes', 0) or 0


def top(count):
    """Return a text block listing the `count` process names with the most I/O traffic.

    Read and write bytes are summed up per process name and ranked by their total.
    An empty string is returned if `count` is zero or negative, if no I/O counters
    could be collected, or if the highest total is zero.
    """
    # nothing to list, so there is no need to scan the process table
    if count <= 0:
        return ''

    per_name = {}  # name -> [read bytes, write bytes]
    wanted = ['name', 'io_counters']

    # psutil >= 5.3.0: let process_iter() prefetch the attributes
    if lib.version.version(psutil.__version__) >= lib.version.version('5.3.0'):
        try:
            for process in psutil.process_iter(attrs=wanted, ad_value=None):
                try:
                    _tally_io(per_name, process.info)
                except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                    # process is gone or may not be inspected: skip it
                    continue
        except Exception:
            # best effort: whatever went wrong, the scan below takes over if nothing
            # has been collected so far
            pass

    # older psutil, or the scan above did not collect anything
    if not per_name:
        try:
            for process in psutil.process_iter():
                try:
                    snapshot = process.as_dict(attrs=wanted)
                except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                    continue
                _tally_io(per_name, snapshot)
        except psutil.NoSuchProcess:
            pass

    if not per_name:
        return ''

    # highest total (read + write) first, limited to the requested amount
    ranked = sorted(per_name.items(), key=lambda item: sum(item[1]), reverse=True)[:count]

    # if even the busiest entry has no traffic, there is nothing worth showing
    if not ranked or sum(ranked[0][1]) <= 0:
        return ''

    out = [f"\nTop {count} processes that generate the most I/O traffic (r/w):"]
    out.extend(
        f"{rank}. {name}: "
        f"{lib.human.bytes2human(read)}/"
        f"{lib.human.bytes2human(written)}"
        for rank, (name, (read, written)) in enumerate(ranked, start=1)
    )
    return "\n".join(out) + "\n"


def _io_row(counters, bd, dmd, mp, timestamp):
    """Build one row for the perfdata table from a psutil per-disk counter object.

    Counters that `counters` does not provide are stored as 0.
    """
    row = {'bd': bd, 'dmd': dmd, 'mp': mp}
    # read_count and write_count are the same value for all disks, so simply ignore them
    for field in (
        'busy_time',
        'read_bytes',
        'read_merged_count',
        'read_time',
        'write_bytes',
        'write_merged_count',
        'write_time',
    ):
        row[field] = getattr(counters, field, 0)
    row['timestamp'] = timestamp
    return row


def main():
    """The main function. This is where the action is.

    Workflow:

    1. Parse the command line.
    2. Read the current per-disk I/O counters from psutil and insert one row per monitored
       disk into the local SQLite table (on Windows all disks psutil reports, otherwise the
       disks returned by `lib.disk.get_real_disks()`), optionally filtered by `--match`.
    3. Trim the table to `--count` rows per monitored disk.
    4. For each disk, calculate the current rate from the two newest rows and the period
       rate from the newest and the `--count`-th newest row. The period rate is compared
       against `--warning`/`--critical` percent of the maximum bandwidth known for that disk.
    5. Print the busiest disk, a table of all evaluated disks, the top I/O processes and the
       perfdata, and exit with the worst state found.

    As long as a disk has less than two rows, or two samples share the same timestamp, the
    check exits early with "Waiting for more data.".
    """

    # parse the command line, exit with UNKNOWN if it fails
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # On Windows we can work with what psutil returns, but on Linux psutil returns too much noise
    # from devices of all kinds. There we use a different approach, but therefore we have
    # to handle both platforms separately. :-(
    # Kernel 5.5 added 2 more fields to /proc/diskstats, requiring another
    # change after the one for 4.18, which recently added 4 fields.
    # To prevent "ValueError: not sure how to interpret line",
    # we check the version of psutil first.
    if lib.base.LINUX and all([
        lib.version.version(platform.release()) >= lib.version.version('4.18.0'),
        lib.version.version(psutil.__version__) < lib.version.version('5.7.0')
    ]):
        lib.base.oao('Nothing checked. '
                     'Running Kernel >= 4.18, this check needs the Python module '
                     f'psutil v5.7.0+ (installed {psutil.__version__}; have a look at '
                     'https://github.com/giampaolo/psutil/pull/1665 '
                     'for details).',
                     STATE_OK,
                     always_ok=args.ALWAYS_OK,
        )

    # bd: block device; dmd: device mapper device, mp: mountpoint

    # create the perfdata table
    conn = lib.base.coe(
        lib.db_sqlite.connect(filename='linuxfabrik-monitoring-plugins-disk-io.db')
    )

    # Best-effort: reduce IO stalls and file locking on Windows without changing outputs
    try:
        conn.execute('PRAGMA journal_mode=WAL')
        conn.execute('PRAGMA synchronous=NORMAL')
    except Exception:
        # purely an optimization, the check works without these settings
        pass

    # same structure for Linux and Windows, makes life easier
    definition = '''
        bd TEXT NOT NULL,
        dmd TEXT,
        mp TEXT,
        busy_time INT DEFAULT 0,
        read_bytes INT DEFAULT 0,
        read_merged_count INT DEFAULT 0,
        read_time INT DEFAULT 0,
        write_bytes INT DEFAULT 0,
        write_merged_count INT DEFAULT 0,
        write_time INT DEFAULT 0,
        timestamp INT DEFAULT 0
    '''
    lib.base.coe(lib.db_sqlite.create_table(conn, definition, drop_table_first=False))
    lib.base.coe(lib.db_sqlite.create_index(conn, 'bd'))

    # init some vars
    msg = f'No I/O on `{args.MATCH}`.' if args.MATCH else 'No I/O.'
    perfdata = ''
    state = STATE_OK
    table_values = []
    compiled_regex = lib.base.coe(lib.txt.compile_regex(args.MATCH))
    now = lib.time.now()
    busiest_disk = 0  # disk with the highest sum of r/w: show this on top later on
    disks = []

    # fetch data
    try:
        disk_io_counters = psutil.disk_io_counters(perdisk=True)
    except ValueError:
        lib.base.cu('psutil raised an error')

    # analyze and enrich data, store it to database
    if lib.base.WINDOWS:
        for disk, values in disk_io_counters.items():
            # filter devices that do not match
            if args.MATCH and not lib.base.coe(lib.txt.match_regex(compiled_regex, disk)):
                continue

            # Windows knows neither device mapper devices nor mountpoints
            disks.append({'bd': disk, 'dmd': '', 'mp': ''})

            # store it to database
            lib.base.coe(lib.db_sqlite.insert(conn, _io_row(values, disk, '', '', now)))
    else:
        for disk in lib.disk.get_real_disks():
            # filter devices that do not match
            if args.MATCH \
               and not any((
                   lib.base.coe(lib.txt.match_regex(compiled_regex, disk['bd'])),
                   lib.base.coe(lib.txt.match_regex(compiled_regex, disk['dmd'])),
                   lib.base.coe(lib.txt.match_regex(compiled_regex, disk['mp'])),
               )):
                continue

            # psutil uses the plain device name (without directory) as key
            psutil_name = os.path.basename(disk['bd'])
            if psutil_name not in disk_io_counters:
                continue

            disks.append(disk)

            # store it to database
            lib.base.coe(lib.db_sqlite.insert(
                conn,
                _io_row(disk_io_counters[psutil_name], disk['bd'], disk['dmd'], disk['mp'], now),
            ))

    if not disks:
        lib.db_sqlite.close(conn)
        lib.base.oao(
            'No disks matched.' if args.MATCH else 'No disks found.',
            state,
            '',
            always_ok=args.ALWAYS_OK,
        )

    # truncate old data (just keep args.COUNT for each disk) and commit
    lib.base.coe(lib.db_sqlite.cut(conn, _max=args.COUNT * len(disks)))
    lib.base.coe(lib.db_sqlite.commit(conn))

    # from here on just working on the database
    # warn about a "count" period/amount of time, not about the current situation above
    # (what might be a peak only)
    for disk in disks:
        # get all historical data rows for a specific disk, newest item first
        data = lib.base.coe(lib.db_sqlite.select(
            conn,
            '''
            SELECT *
            FROM perfdata
            WHERE bd = :name
            ORDER BY timestamp DESC
            ''',
            {'name': disk['bd']},
        ))

        if len(data) < 2:
            lib.db_sqlite.close(conn)
            # `state` may already have been raised by a disk evaluated before,
            # so `--always-ok` has to be honored here, too
            lib.base.oao('Waiting for more data.', state, always_ok=args.ALWAYS_OK)

        # calculate current rates (like "load1")
        timediff, read_bytes_per_second1, write_bytes_per_second1, bandwidth1 = get_rate(
            data[0]['timestamp'], data[1]['timestamp'],
            data[0]['read_bytes'], data[1]['read_bytes'],
            data[0]['write_bytes'], data[1]['write_bytes'],
        )
        if timediff <= 0:  # often happens after a reboot
            lib.db_sqlite.close(conn)
            lib.base.oao('Waiting for more data.', state, always_ok=args.ALWAYS_OK)

        # get the maximum disk bandwidth in disks' history
        bandwidth_max = get_max_bandwidth(disk['bd'], bandwidth1)

        if bandwidth1 > busiest_disk:
            # get the current busiest disk for the first line of the message
            msg = f'{disk["bd"]}: ' \
                  f'{lib.human.bytes2human(read_bytes_per_second1)}/s read1, ' \
                  f'{lib.human.bytes2human(write_bytes_per_second1)}/s write1, ' \
                  f'{lib.human.bytes2human(bandwidth1)}/s total, ' \
                  f'{lib.human.bytes2human(bandwidth_max)}/s max'
            if args.MATCH:
                msg += f' (disks matching `{args.MATCH}`).'
            busiest_disk = bandwidth1

        # calculate read/write rate over the entire period (like "load15")
        # The table is trimmed to COUNT * number of disks rows in total, not per disk, so
        # a disk can temporarily own more than COUNT rows (for example after the set of
        # monitored disks changed). Only skip disks that really have too few rows; the
        # period always spans the newest COUNT rows.
        if len(data) < args.COUNT:  # not enough data yet
            continue

        timediff, read_bytes_per_second15, write_bytes_per_second15, bandwidth15 = get_rate(
            data[0]['timestamp'], data[args.COUNT - 1]['timestamp'],
            data[0]['read_bytes'], data[args.COUNT - 1]['read_bytes'],
            data[0]['write_bytes'], data[args.COUNT - 1]['write_bytes'],
        )
        if timediff <= 0:  # often happens after a reboot
            lib.db_sqlite.close(conn)
            lib.base.oao('Waiting for more data.', state, always_ok=args.ALWAYS_OK)

        # get state based on max measured I/O values
        local_state = lib.base.get_state(
            bandwidth15,
            bandwidth_max * args.WARN / 100,
            bandwidth_max * args.CRIT / 100,
        )
        state = lib.base.get_worst(local_state, state)

        bd = disk['bd'].replace('/dev/', '')
        table_values.append({
            'bd': bd,
            'dmd': disk['dmd'].replace('/dev/mapper/', ''),
            'mp': disk['mp'],
            'max': lib.human.bytes2human(bandwidth_max),
            'r1': lib.human.bytes2human(read_bytes_per_second1),
            'w1': lib.human.bytes2human(write_bytes_per_second1),
            'r15': lib.human.bytes2human(read_bytes_per_second15),
            'w15': lib.human.bytes2human(write_bytes_per_second15),
            't15': lib.human.bytes2human(bandwidth15) + lib.base.state2str(local_state, prefix=' '),
        })

        # perfdata: the raw counters of the newest row, without thresholds
        try:
            for metric in ('busy_time', 'read_bytes', 'read_time', 'write_bytes', 'write_time'):
                perfdata += lib.base.get_perfdata(
                    f'{bd}_{metric}',
                    data[0][metric],
                    uom='c',
                    warn=None,
                    crit=None,
                    _min=0,
                    _max=None,
                )
        except (KeyError, TypeError):
            # best effort: a row without usable counters must not break the check
            pass

    lib.db_sqlite.close(conn)

    # build the message
    msg = msg + '\n\n'
    if table_values:
        msg += lib.base.get_table(
            table_values,
            [
                'bd',
                'mp',
                'dmd',
                'max',
                'r1',
                'w1',
                'r15',
                'w15',
                't15',
            ],
            header=[
                'Name',
                'MntPnts',
                'DvMppr',
                'RWmax/s',
                'R1/s',
                'W1/s',
                f'R{args.COUNT}/s',
                f'W{args.COUNT}/s',
                f'RW{args.COUNT}/s',
            ],
        )

    # Top X processes that generated the most I/O traffic
    msg += top(args.TOP)

    # over and out
    lib.base.oao(msg.replace('\n\n\n', '\n\n'), state, perfdata, always_ok=args.ALWAYS_OK)


if __name__ == '__main__':
    try:
        main()
    except Exception:
        # last resort: hand anything main() did not handle itself over to lib.base.cu()
        lib.base.cu()
