#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import os
import platform
import sys

import lib.args
import lib.base
import lib.cache
import lib.db_sqlite
import lib.disk
import lib.human
import lib.time
import lib.txt
import lib.version
from lib.globals import STATE_CRIT, STATE_OK, STATE_UNKNOWN, STATE_WARN

try:
    import psutil
except ImportError:
    print('Python module "psutil" is not installed.')
    sys.exit(STATE_UNKNOWN)


__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2026040701'

DESCRIPTION = """Checks disk I/O bandwidth over time and alerts on sustained saturation, not
short spikes. The check records per-disk read/write counters and then derives current (R1/W1)
and period averages (R{COUNT}/W{COUNT}). It compares the period's total bandwidth against the
maximum ever observed for that disk (RWmax). WARN/CRIT trigger if the period average exceeds
the configured percentage of RWmax for COUNT consecutive runs.

On Linux, the check also monitors the system-wide iowait percentage (CPU time spent waiting for
I/O). The raw iowait value is normalized by multiplying it with the number of logical CPUs, so
that 100% always means one CPU core is fully I/O-saturated, regardless of the total number of
CPUs. This makes the default thresholds (80/90%) work consistently across different hardware.
Like bandwidth alerts, iowait alerts require COUNT consecutive threshold violations.

Perfdata is emitted for each disk (busy_time, read_bytes, read_time, write_bytes, write_time)
and for iowait, so you can graph trends. On Linux the check automatically focuses on "real"
block devices with mountpoints; on Windows it uses psutil's disk counters. Optionally, `--top`
lists the processes that generated the most I/O traffic (read/write totals) to help identify
offenders.

This check is cross-platform and works on Linux, Windows, and all psutil-supported systems.
The check stores its short trend state locally in an SQLite DB to evaluate sustained load across
runs."""


DEFAULT_CACHE_EXPIRE = 90
DEFAULT_COUNT = (
    5  # measurements; if check runs once per minute, this is a 5 minute interval
)
DEFAULT_CRIT = 90  # %
DEFAULT_MATCH = ''
DEFAULT_TOP = 5
DEFAULT_WARN = 80  # %

DEFAULT_IOWAIT_CRIT = 90  # %; normalized so that 100% = one fully I/O-saturated core
DEFAULT_IOWAIT_WARN = 80  # %
# number of logical CPUs, used in main() to normalize iowait%; falls back to 1
# because os.cpu_count() may return None on exotic platforms
cpu_count = os.cpu_count() or 1


def parse_args():
    """Parse command line arguments using argparse."""
    parser = argparse.ArgumentParser(description=DESCRIPTION)

    parser.add_argument(
        '-V',
        '--version',
        action='version',
        version=f'%(prog)s: v{__version__} by {__author__}',
    )

    # Table-driven registration: one (flags, options) entry per argument.
    # Entries are kept in alphabetical order so the generated `--help`
    # output is identical to registering each argument individually.
    specifications = (
        (
            ('--always-ok',),
            dict(
                help=lib.args.help('--always-ok'),
                dest='ALWAYS_OK',
                action='store_true',
                default=False,
            ),
        ),
        (
            ('--count',),
            dict(
                help=lib.args.help('--count') + ' Default: %(default)s',
                dest='COUNT',
                type=int,
                default=DEFAULT_COUNT,
            ),
        ),
        (
            ('--critical',),
            dict(
                help='CRIT threshold for disk bandwidth saturation as a percentage of the observed maximum, '
                'measured over the last `--count` runs. '
                'Default: >= %(default)s',
                dest='CRIT',
                type=int,
                default=DEFAULT_CRIT,
            ),
        ),
        (
            ('--iowait-critical',),
            dict(
                help='CRIT threshold for normalized iowait in percent (Linux only). '
                'The iowait value is normalized so that 100%% means one CPU core is fully I/O-saturated. '
                'Values above 100%% indicate that more than one core is waiting for I/O. '
                'Default: >= %(default)s',
                dest='IOWAIT_CRIT',
                type=int,
                default=DEFAULT_IOWAIT_CRIT,
            ),
        ),
        (
            ('--iowait-warning',),
            dict(
                help='WARN threshold for normalized iowait in percent (Linux only). '
                'The iowait value is normalized so that 100%% means one CPU core is fully I/O-saturated. '
                'Values above 100%% indicate that more than one core is waiting for I/O. '
                'Default: >= %(default)s',
                dest='IOWAIT_WARN',
                type=int,
                default=DEFAULT_IOWAIT_WARN,
            ),
        ),
        (
            ('--match',),
            dict(
                help='Filter by disk name. '
                + lib.args.help('--match')
                + ' '
                + 'Default: %(default)s',
                dest='MATCH',
                default=DEFAULT_MATCH,
            ),
        ),
        (
            ('--top',),
            dict(
                help='Number of top processes to list by I/O traffic. '
                'Use `--top=0` to disable. '
                'Default: %(default)s',
                dest='TOP',
                type=int,
                default=DEFAULT_TOP,
            ),
        ),
        (
            ('--warning',),
            dict(
                help='WARN threshold for disk bandwidth saturation as a percentage of the observed maximum, '
                'measured over the last `--count` runs. '
                'Default: >= %(default)s',
                dest='WARN',
                type=int,
                default=DEFAULT_WARN,
            ),
        ),
    )
    for flags, options in specifications:
        parser.add_argument(*flags, **options)

    # parse_known_args(): tolerate (and ignore) unknown extra arguments
    args, _ = parser.parse_known_args()
    return args


def get_max_bandwidth(disk, current_bandwidth):
    """Return and persist the maximum bandwidth ever measured for the specific disk.

    Reads the historic maximum from the cache table, takes the max of that,
    the current measurement and a 10 MB/s floor, stores the result back and
    returns it.

    ### Parameters
    * `disk` (str): Block device name, used as part of the cache key.
    * `current_bandwidth` (int/float): Current bandwidth measurement in bytes/s.

    ### Returns
    * int: The (possibly updated) maximum bandwidth in bytes/s.
    """
    key = f'disk-io-{disk}-bandwidth-max'
    historic_bandwidth = lib.cache.get(
        key,
        filename='linuxfabrik-monitoring-plugins-disk-io.db',
    )
    # On a cache miss the backend returns a falsy value (e.g. False/None);
    # `or 0` keeps int() from raising a TypeError on None.
    # Disk should be capable of at least 10 MB/sec (if no info is provided)
    max_bandwidth = max(
        int(historic_bandwidth or 0),
        int(current_bandwidth),
        10 * 1024 * 1024,
    )
    lib.cache.set(
        key,
        max_bandwidth,
        filename='linuxfabrik-monitoring-plugins-disk-io.db',
    )
    return max_bandwidth


def get_rate(ts1, ts2, r1, r2, w1, w2):
    """Derive read/write rates and total bandwidth from two counter snapshots.

    ### Parameters
    * `ts1`, `ts2`: Timestamps of the two snapshots (seconds).
    * `r1`, `r2`: Read byte counters at the two snapshots.
    * `w1`, `w2`: Write byte counters at the two snapshots.

    ### Returns
    * tuple: `(timediff, read_rate, write_rate, read_rate + write_rate)`,
      or `(0, 0, 0, 0)` if both timestamps are identical.
    """
    elapsed = abs(ts1 - ts2)  # seconds between the two samples
    if not elapsed:
        # identical timestamps: no rate can be derived
        return 0, 0, 0, 0
    read_rate = abs(int((r1 - r2) / elapsed))
    write_rate = abs(int((w1 - w2) / elapsed))
    return elapsed, read_rate, write_rate, read_rate + write_rate


def get_iowait(conn):
    """Compute the system-wide iowait percentage non-blockingly.

    Stores a raw psutil.cpu_times() snapshot in the ``iowait_raw`` table on
    each call and derives iowait% from the delta to the previous snapshot.
    Returns None on the first run (no previous snapshot) or on platforms that
    do not expose iowait (e.g. Windows).

    ### Parameters
    * `conn`: Open SQLite connection to the plugin's state database.

    ### Returns
    * float or None: iowait% (rounded to one decimal) of the interval between
      the previous and the current call, or None if not computable.
    """
    if not lib.base.LINUX:
        return None

    # one row per cpu_times() field; REAL because psutil reports seconds as floats
    definition = """
        ts REAL NOT NULL,
        guest REAL DEFAULT 0,
        guest_nice REAL DEFAULT 0,
        idle REAL DEFAULT 0,
        iowait REAL DEFAULT 0,
        irq REAL DEFAULT 0,
        nice REAL DEFAULT 0,
        softirq REAL DEFAULT 0,
        steal REAL DEFAULT 0,
        system REAL DEFAULT 0,
        user REAL DEFAULT 0
    """
    lib.base.coe(lib.db_sqlite.create_table(conn, definition, table='iowait_raw'))

    # fetch the previous snapshot; the table only ever holds a single row
    # (see the DELETE below), so LIMIT 1 without ORDER BY is sufficient
    last = lib.base.coe(
        lib.db_sqlite.select(
            conn,
            'SELECT * FROM iowait_raw LIMIT 1',
            fetchone=True,
        )
    )

    now_ct = psutil.cpu_times()
    now = lib.time.now()
    fields = (
        'guest',
        'guest_nice',
        'idle',
        'iowait',
        'irq',
        'nice',
        'softirq',
        'steal',
        'system',
        'user',
    )
    # getattr with default 0.0: not every platform/psutil version exposes all fields
    now_d = {f: getattr(now_ct, f, 0.0) for f in fields}

    if last:
        # sum the per-field deltas; clamp at 0.0 to ignore counter resets (reboot)
        total = 0.0
        iowait_delta = 0.0
        for k, v in now_d.items():
            dv = max(0.0, v - float(last.get(k, 0.0)))
            total += dv
            if k == 'iowait':
                iowait_delta = dv

        # replace the stored snapshot with the current one
        lib.base.coe(lib.db_sqlite.delete(conn, 'DELETE FROM iowait_raw WHERE 1=1'))
        lib.base.coe(
            lib.db_sqlite.insert(conn, {'ts': now, **now_d}, table='iowait_raw')
        )
        lib.base.coe(lib.db_sqlite.commit(conn))

        if total <= 0.0:
            # no CPU time elapsed between snapshots; cannot derive a percentage
            return None

        return round((iowait_delta / total) * 100.0, 1)

    # first run: store snapshot, no result yet
    lib.base.coe(lib.db_sqlite.insert(conn, {'ts': now, **now_d}, table='iowait_raw'))
    lib.base.coe(lib.db_sqlite.commit(conn))
    return None


def get_iowait_from_db(conn, threshold):
    """Return the number of iowait_trend rows where iowait exceeds the given threshold.

    ### Parameters
    * `conn`: Open SQLite connection to the plugin's state database.
    * `threshold` (int/float): iowait percentage to compare against.

    ### Returns
    * int: Count of stored samples strictly above `threshold`.
    """
    sql = 'SELECT count(*) as cnt FROM iowait_trend WHERE iowait > :threshold'
    row = lib.base.coe(
        lib.db_sqlite.select(conn, sql, {'threshold': threshold}, fetchone=True)
    )
    return int(row['cnt'])


def top(count):
    """Get top X processes that generated the most I/O traffic.

    Aggregates per-process read/write byte counters by process name and
    returns a human-readable multi-line summary of the `count` heaviest
    names, or an empty string if disabled, nothing was readable, or all
    counters are zero.

    ### Parameters
    * `count` (int): Number of process names to list; <= 0 disables the scan.

    ### Returns
    * str: Formatted summary (with leading newline) or '' if nothing to show.
    """
    # Fast path: nothing to print, so nothing to scan
    if count <= 0:
        return ''

    totals = {}  # name -> {'r': bytes, 'w': bytes}
    msg = ''

    # Prefer attrs path (psutil >= 5.3.0): fewer syscalls, fewer exceptions
    if lib.version.version(psutil.__version__) >= lib.version.version('5.3.0'):
        try:
            # ad_value=None: unreadable attrs come back as None instead of raising
            for p in psutil.process_iter(attrs=['name', 'io_counters'], ad_value=None):
                try:
                    info = p.info
                    name = info.get('name') or ''
                    ioc = info.get('io_counters')
                    if not ioc:
                        continue
                    entry = totals.setdefault(name, {'r': 0, 'w': 0})
                    # accumulate read/write bytes; guard against None
                    entry['r'] += getattr(ioc, 'read_bytes', 0) or 0
                    entry['w'] += getattr(ioc, 'write_bytes', 0) or 0
                except (
                    psutil.NoSuchProcess,
                    psutil.AccessDenied,
                    psutil.ZombieProcess,
                ):
                    # process vanished or denied: skip and continue
                    continue
        except Exception:
            # Defensive: if attrs/ad_value path misbehaves anywhere, fall back below.
            pass

    # Legacy / fallback path
    if not totals:
        try:
            for proc in psutil.process_iter():
                try:
                    info = proc.as_dict(attrs=['name', 'io_counters'])
                except (
                    psutil.NoSuchProcess,
                    psutil.AccessDenied,
                    psutil.ZombieProcess,
                ):
                    continue
                name = info.get('name') or ''
                ioc = info.get('io_counters')
                if not ioc:
                    continue
                entry = totals.setdefault(name, {'r': 0, 'w': 0})
                entry['r'] += getattr(ioc, 'read_bytes', 0) or 0
                entry['w'] += getattr(ioc, 'write_bytes', 0) or 0
        except psutil.NoSuchProcess:
            pass

    if not totals:
        return msg

    # Sort by total bytes (read+write) desc and show the top N
    ranked = sorted(
        totals.items(), key=lambda kv: kv[1]['r'] + kv[1]['w'], reverse=True
    )[:count]

    # If everything is truly zero, keep output empty
    if ranked and (ranked[0][1]['r'] + ranked[0][1]['w'] > 0):
        lines = [f'\nTop {count} processes that generate the most I/O traffic (r/w):']
        for i, (name, io) in enumerate(ranked, start=1):
            lines.append(
                f'{i}. {name}: '
                f'{lib.human.bytes2human(io["r"])}/'
                f'{lib.human.bytes2human(io["w"])}'
            )
        msg = '\n'.join(lines) + '\n'
    return msg


def main():
    """The main function. This is where the magic happens.

    Collects per-disk I/O counters (and iowait on Linux), persists them in
    the plugin's SQLite state DB, evaluates sustained saturation over the
    last `--count` runs against WARN/CRIT thresholds, and emits the plugin
    result (message, state, perfdata) via lib.base.oao().
    """

    # parse the command line
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # On Windows we can work with what psutil returns, but on Linux psutil returns too much noise
    # from devices of all kinds. There we use a different approach, but therefore we have
    # to handle both platforms separately. :-(
    # Kernel 5.5 added 2 more fields to /proc/diskstats, requiring another
    # change after the one for 4.18, which recently added 4 fields.
    # To prevent "ValueError: not sure how to interpret line",
    # we check the version of psutil first.
    if lib.base.LINUX and all(
        [
            lib.version.version(platform.release()) >= lib.version.version('4.18.0'),
            lib.version.version(psutil.__version__) < lib.version.version('5.7.0'),
        ]
    ):
        lib.base.oao(
            'Nothing checked. '
            'Running Kernel >= 4.18, this check needs the Python module '
            f'psutil v5.7.0+ (installed {psutil.__version__}; have a look at '
            'https://github.com/giampaolo/psutil/pull/1665 '
            'for details).',
            STATE_OK,
            always_ok=args.ALWAYS_OK,
        )

    # bd: block device; dmd: device mapper device, mp: mountpoint

    # create the perfdata table
    conn = lib.base.coe(
        lib.db_sqlite.connect(filename='linuxfabrik-monitoring-plugins-disk-io.db')
    )

    # Best-effort: reduce IO stalls and file locking on Windows without changing outputs
    try:
        conn.execute('PRAGMA journal_mode=WAL')
        conn.execute('PRAGMA synchronous=NORMAL')
    except Exception:
        pass

    # same structure for Linux and Windows, makes life easier
    definition = """
        bd TEXT NOT NULL,
        dmd TEXT,
        mp TEXT,
        busy_time INT DEFAULT 0,
        read_bytes INT DEFAULT 0,
        read_merged_count INT DEFAULT 0,
        read_time INT DEFAULT 0,
        write_bytes INT DEFAULT 0,
        write_merged_count INT DEFAULT 0,
        write_time INT DEFAULT 0,
        timestamp INT DEFAULT 0
    """
    lib.base.coe(lib.db_sqlite.create_table(conn, definition, drop_table_first=False))
    lib.base.coe(lib.db_sqlite.create_index(conn, 'bd'))

    # iowait tracking (Linux only)
    lib.base.coe(
        lib.db_sqlite.create_table(
            conn,
            'iowait REAL NOT NULL',
            table='iowait_trend',
        )
    )
    # get_iowait() returns None on Windows and on the very first run
    iowait_pct = get_iowait(conn)
    if iowait_pct is not None:
        # normalize: multiply by cpu_count so that 100% = one fully I/O-saturated core
        iowait_pct = round(iowait_pct * cpu_count, 1)
        lib.base.coe(
            lib.db_sqlite.insert(
                conn,
                {'iowait': iowait_pct},
                table='iowait_trend',
            )
        )

    # init some vars
    msg = f'No I/O on `{args.MATCH}`.' if args.MATCH else 'No I/O.'
    perfdata = ''
    state = STATE_OK
    table_values = []
    compiled_regex = lib.base.coe(lib.txt.compile_regex(args.MATCH))
    now = lib.time.now()
    busiest_disk = 0  # disk with the highest sum of r/w: show this on top later on
    disks = []

    # fetch data
    try:
        disk_io_counters = psutil.disk_io_counters(perdisk=True)
    except ValueError:
        lib.base.cu('psutil raised an error')

    # analyze and enrich data, store it to database
    if lib.base.WINDOWS:
        for disk, values in disk_io_counters.items():
            # filter devices that do not match
            if args.MATCH and not lib.base.coe(
                lib.txt.match_regex(compiled_regex, disk)
            ):
                continue

            # read_count and write_count are the same value for all disks, so simply ignore them
            data = {}
            data['bd'] = disk
            data['dmd'] = ''
            data['mp'] = ''
            # getattr with default 0: not every psutil version/platform exposes all counters
            data['busy_time'] = getattr(values, 'busy_time', 0)
            data['read_bytes'] = getattr(values, 'read_bytes', 0)
            # data['read_count'] = getattr(values, 'read_count', 0)
            data['read_merged_count'] = getattr(values, 'read_merged_count', 0)
            data['read_time'] = getattr(values, 'read_time', 0)
            data['write_bytes'] = getattr(values, 'write_bytes', 0)
            # data['write_count'] = getattr(values, 'write_count', 0)
            data['write_merged_count'] = getattr(values, 'write_merged_count', 0)
            data['write_time'] = getattr(values, 'write_time', 0)
            data['timestamp'] = now
            disks.append({'bd': disk, 'dmd': '', 'mp': ''})

            # store it to database
            lib.base.coe(lib.db_sqlite.insert(conn, data))
    else:
        real_disks = lib.disk.get_real_disks()
        for disk in real_disks:
            # filter devices that do not match
            # match against block device, device mapper device AND mountpoint
            if args.MATCH and not any(
                (
                    lib.base.coe(lib.txt.match_regex(compiled_regex, disk['bd'])),
                    lib.base.coe(lib.txt.match_regex(compiled_regex, disk['dmd'])),
                    lib.base.coe(lib.txt.match_regex(compiled_regex, disk['mp'])),
                )
            ):
                continue

            # psutil keys its counters by bare device name (e.g. 'sda', not '/dev/sda')
            psutil_name = os.path.basename(disk['bd'])
            if psutil_name not in disk_io_counters:
                continue

            data = {}
            data['bd'] = disk['bd']
            data['dmd'] = disk['dmd']
            data['mp'] = disk['mp']
            # read_count and write_count are the same value over all disks, so simply ignore them
            data['busy_time'] = getattr(disk_io_counters[psutil_name], 'busy_time', 0)
            data['read_bytes'] = getattr(disk_io_counters[psutil_name], 'read_bytes', 0)
            data['read_merged_count'] = getattr(
                disk_io_counters[psutil_name],
                'read_merged_count',
                0,
            )
            data['read_time'] = getattr(disk_io_counters[psutil_name], 'read_time', 0)
            data['write_bytes'] = getattr(
                disk_io_counters[psutil_name], 'write_bytes', 0
            )
            data['write_merged_count'] = getattr(
                disk_io_counters[psutil_name],
                'write_merged_count',
                0,
            )
            data['write_time'] = getattr(disk_io_counters[psutil_name], 'write_time', 0)
            data['timestamp'] = now
            disks.append(disk)

            # store it to database
            lib.base.coe(lib.db_sqlite.insert(conn, data))

    if not disks:
        lib.db_sqlite.close(conn)
        lib.base.oao(
            'No disks matched.' if args.MATCH else 'No disks found.',
            state,
            '',
            always_ok=args.ALWAYS_OK,
        )

    # truncate old data (just keep args.COUNT for each disk) and commit
    # NOTE(review): the cut is global (COUNT * number of disks), not per disk,
    # so a single disk may temporarily hold more or fewer than COUNT rows
    lib.base.coe(lib.db_sqlite.cut(conn, _max=args.COUNT * len(disks)))
    lib.base.coe(lib.db_sqlite.cut(conn, table='iowait_trend', _max=args.COUNT))
    lib.base.coe(lib.db_sqlite.commit(conn))

    # from here on just working on the database
    # warn about a "count" period/amount of time, not about the current situation above
    # (what might be a peak only)
    for disk in disks:
        # get all historical data rows for a specific disk, newest item first
        data = lib.base.coe(
            lib.db_sqlite.select(
                conn,
                """
            SELECT *
            FROM perfdata
            WHERE bd = :name
            ORDER BY timestamp DESC
            """,
                {'name': disk['bd']},
            )
        )

        # need at least two snapshots to compute any rate
        if len(data) < 2:
            lib.db_sqlite.close(conn)
            lib.base.oao('Waiting for more data.', state)

        # calculate current rates (like "load1")
        timediff, read_bytes_per_second1, write_bytes_per_second1, bandwidth1 = (
            get_rate(
                data[0]['timestamp'],
                data[1]['timestamp'],
                data[0]['read_bytes'],
                data[1]['read_bytes'],
                data[0]['write_bytes'],
                data[1]['write_bytes'],
            )
        )
        if timediff <= 0:  # often happens after a reboot
            lib.db_sqlite.close(conn)
            lib.base.oao('Waiting for more data.', state)

        # get the maximum disk bandwidth in disks' history
        bandwidth_max = get_max_bandwidth(disk['bd'], bandwidth1)

        if bandwidth1 > busiest_disk:
            # get the current busiest disk for the first line of the message
            msg = (
                f'{disk["bd"]}: '
                f'{lib.human.bytes2human(read_bytes_per_second1)}/s read1, '
                f'{lib.human.bytes2human(write_bytes_per_second1)}/s write1, '
                f'{lib.human.bytes2human(bandwidth1)}/s total, '
                f'{lib.human.bytes2human(bandwidth_max)}/s max'
            )
            if args.MATCH:
                msg += f' (disks matching `{args.MATCH}`).'
            busiest_disk = bandwidth1

        # calculate read/write rate over the entire period (like "load15")
        # NOTE(review): this equality check skips the period calculation both when
        # there are fewer AND when there are more than COUNT rows for this disk
        # (possible due to the global cut above) — confirm `!=` is intended over `<`
        if len(data) != args.COUNT:  # not enough data yet
            continue

        timediff, read_bytes_per_second15, write_bytes_per_second15, bandwidth15 = (
            get_rate(
                data[0]['timestamp'],
                data[args.COUNT - 1]['timestamp'],
                data[0]['read_bytes'],
                data[args.COUNT - 1]['read_bytes'],
                data[0]['write_bytes'],
                data[args.COUNT - 1]['write_bytes'],
            )
        )
        if timediff <= 0:  # often happens after a reboot
            lib.db_sqlite.close(conn)
            lib.base.oao('Waiting for more data.', state)

        # get state based on max measured I/O values
        local_state = lib.base.get_state(
            bandwidth15,
            bandwidth_max * args.WARN / 100,
            bandwidth_max * args.CRIT / 100,
        )
        state = lib.base.get_worst(local_state, state)

        # strip '/dev/' for compact table and perfdata labels
        bd = disk['bd'].replace('/dev/', '')
        table_values.append(
            {
                'bd': bd,
                'dmd': disk['dmd'].replace('/dev/mapper/', ''),
                'mp': disk['mp'],
                'max': lib.human.bytes2human(bandwidth_max),
                'r1': lib.human.bytes2human(read_bytes_per_second1),
                'w1': lib.human.bytes2human(write_bytes_per_second1),
                'r15': lib.human.bytes2human(read_bytes_per_second15),
                'w15': lib.human.bytes2human(write_bytes_per_second15),
                't15': lib.human.bytes2human(bandwidth15)
                + lib.base.state2str(local_state, prefix=' '),
            }
        )

        # perfdata
        # uom='c': these are monotonically increasing counters
        try:
            perfdata += lib.base.get_perfdata(
                f'{bd}_busy_time',
                data[0]['busy_time'],
                uom='c',
                warn=None,
                crit=None,
                _min=0,
                _max=None,
            )
            perfdata += lib.base.get_perfdata(
                f'{bd}_read_bytes',
                data[0]['read_bytes'],
                uom='c',
                warn=None,
                crit=None,
                _min=0,
                _max=None,
            )
            perfdata += lib.base.get_perfdata(
                f'{bd}_read_time',
                data[0]['read_time'],
                uom='c',
                warn=None,
                crit=None,
                _min=0,
                _max=None,
            )
            perfdata += lib.base.get_perfdata(
                f'{bd}_write_bytes',
                data[0]['write_bytes'],
                uom='c',
                warn=None,
                crit=None,
                _min=0,
                _max=None,
            )
            perfdata += lib.base.get_perfdata(
                f'{bd}_write_time',
                data[0]['write_time'],
                uom='c',
                warn=None,
                crit=None,
                _min=0,
                _max=None,
            )
        except (KeyError, TypeError):
            pass

    # check iowait alerting (only if enough consecutive values collected)
    # alert only if ALL of the last COUNT samples exceeded the threshold
    iowait_state = STATE_OK
    if iowait_pct is not None:
        if get_iowait_from_db(conn, args.IOWAIT_CRIT) == args.COUNT:
            iowait_state = STATE_CRIT
        elif get_iowait_from_db(conn, args.IOWAIT_WARN) == args.COUNT:
            iowait_state = STATE_WARN
        state = lib.base.get_worst(iowait_state, state)

    lib.db_sqlite.close(conn)

    # prepend iowait to the message
    if iowait_pct is not None:
        iowait_msg = (
            f'iowait: {iowait_pct}%{lib.base.state2str(iowait_state, prefix=" ")}'
        )
        msg = f'{iowait_msg}. {msg}'
        perfdata += lib.base.get_perfdata(
            'iowait',
            iowait_pct,
            uom='%',
            warn=args.IOWAIT_WARN,
            crit=args.IOWAIT_CRIT,
            _min=0,
            _max=None,
        )

    # build the message
    msg = msg + '\n\n'
    if table_values:
        msg += lib.base.get_table(
            table_values,
            [
                'bd',
                'mp',
                'dmd',
                'max',
                'r1',
                'w1',
                'r15',
                'w15',
                't15',
            ],
            header=[
                'Name',
                'MntPnts',
                'DvMppr',
                'RWmax/s',
                'R1/s',
                'W1/s',
                f'R{args.COUNT}/s',
                f'W{args.COUNT}/s',
                f'RW{args.COUNT}/s',
            ],
        )

    # Top X processes that generated the most I/O traffic
    msg += top(args.TOP)

    # over and out
    lib.base.oao(
        msg.replace('\n\n\n', '\n\n'), state, perfdata, always_ok=args.ALWAYS_OK
    )


if __name__ == '__main__':
    try:
        main()
    except Exception:
        # catch-all boundary: any unexpected error is converted by lib.base.cu()
        # into a proper plugin result instead of a raw traceback
        lib.base.cu()
