#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import json
import sys

import lib.args
import lib.base
import lib.human
import lib.lftest
import lib.version
from lib.globals import STATE_OK, STATE_UNKNOWN

# psutil is a third-party module that provides the live memory and process
# data for this check. If it is missing, say so on stdout and exit with
# UNKNOWN instead of dying with an ImportError traceback.
try:
    import psutil
except ImportError:
    print('Python module "psutil" is not installed.')
    sys.exit(STATE_UNKNOWN)


__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2026041301'

DESCRIPTION = """Monitors system memory usage and alerts when the overall usage percentage exceeds the
configured thresholds. Reports total, used, available, and free memory plus shared,
buffers, and cached values. Optionally lists the top memory-consuming processes via
--top to help identify the source of high usage."""

# Defaults for the command line options defined in parse_args(). The memory
# usage percentage is compared against the thresholds with "greater than or
# equal" (see the `_operator='ge'` call in main()).
DEFAULT_CRIT = 95  # % memory usage, default for -c/--critical
DEFAULT_TOP = 5  # number of processes listed, default for --top
DEFAULT_WARN = 90  # % memory usage, default for -w/--warning


def parse_args():
    """Build the argument parser for this check and return the parsed options.

    Options are registered in the order they should appear in `--help`.
    Unknown command line arguments are tolerated and silently discarded,
    because `parse_known_args()` is used instead of `parse_args()`.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    option = parser.add_argument

    option(
        '-V',
        '--version',
        action='version',
        version=f'%(prog)s: v{__version__} by {__author__}',
    )

    # Behaviour switch, its help text comes from the shared library.
    option(
        '--always-ok',
        dest='ALWAYS_OK',
        action='store_true',
        default=False,
        help=lib.args.help('--always-ok'),
    )

    # CRIT threshold, in percent of memory used.
    option(
        '-c',
        '--critical',
        dest='CRIT',
        type=int,
        default=DEFAULT_CRIT,
        help='CRIT threshold for memory usage in percent. Default: %(default)s',
    )

    # Length of the "top processes" list; 0 switches the list off.
    option(
        '--top',
        dest='TOP',
        type=int,
        default=DEFAULT_TOP,
        help=(
            'Number of top memory-consuming processes to list. Use `--top=0` to disable. '
            'Default: %(default)s'
        ),
    )

    # Test fixture selector, parsed as a comma separated list.
    option(
        '--test',
        dest='TEST',
        type=lib.args.csv,
        help=lib.args.help('--test'),
    )

    # WARN threshold, in percent of memory used.
    option(
        '-w',
        '--warning',
        dest='WARN',
        type=int,
        default=DEFAULT_WARN,
        help='WARN threshold for memory usage in percent. Default: %(default)s',
    )

    known, _unknown = parser.parse_known_args()
    return known


def _load_memory_fixture(raw_json):
    """Convert a test fixture into the shape the plugin expects from
    `psutil.virtual_memory()`. The plugin uses `getattr(virt, field, 0)`
    so an object with plain attributes is enough. Any fields not
    present in the fixture default to 0.

    Fixture shape:

        {"total": <bytes>, "used": <bytes>, "available": <bytes>,
         "free": <bytes>, "percent": <0..100>, ...}
    """

    class _Virt:
        pass

    data = json.loads(raw_json)
    virt = _Virt()
    defaults = {
        'total': 0,
        'available': 0,
        'percent': 0,
        'used': 0,
        'free': 0,
        'active': 0,
        'inactive': 0,
        'buffers': 0,
        'cached': 0,
        'shared': 0,
        'slab': 0,
    }
    for key, default in defaults.items():
        setattr(virt, key, data.get(key, default))
    return virt


def _add_process_sample(procs, info):
    """Add one process's memory figures to the per-name totals in `procs`.

    `info` is a dict with the keys 'name', 'memory_percent' and
    'memory_info'. Any of the values may be None (for example when psutil
    was denied access), in which case it counts as '' / 0.0 / 0.
    """
    name = info.get('name') or ''
    mem_pct = float(info.get('memory_percent') or 0.0)
    # getattr() with a default also covers memory_info being None
    rss = getattr(info.get('memory_info'), 'rss', 0)
    entry = procs.setdefault(name, {'%': 0.0, 'rss': 0})
    entry['%'] += mem_pct
    entry['rss'] += rss


def top(count):
    """Get top X most memory consuming processes.

    Processes that share a name are summed up (memory percent and RSS) and
    the names are ranked by their summed memory percent, highest first.

    Parameters:
        count: Maximum number of entries to list. Zero or less disables the
            scan entirely.

    Returns:
        A multi-line string starting and ending with a newline, or '' if
        `count` is not positive or no process could be read.
    """
    # Fast path: nothing to print, so nothing to scan
    if count <= 0:
        return ''

    attrs = ['name', 'memory_percent', 'memory_info']
    errors = (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess)
    procs = {}  # name > {'%': float, 'rss': int}

    # Prefer attrs path (psutil >= 5.3.0): fewer syscalls, fewer exceptions
    if lib.version.version(psutil.__version__) >= lib.version.version('5.3.0'):
        try:
            for proc in psutil.process_iter(attrs=attrs, ad_value=None):
                try:
                    _add_process_sample(procs, proc.info)
                except errors:
                    continue
        except Exception:
            # Defensive: if the attrs path misbehaves on some platform/version,
            # throw away whatever was collected so far. Otherwise the partial
            # data would suppress the fallback below and an incomplete
            # ranking would be reported.
            procs = {}

    # Legacy / fallback path
    if not procs:
        try:
            for proc in psutil.process_iter():
                try:
                    info = proc.as_dict(attrs=attrs)
                except errors:
                    # process vanished or is not readable, skip it
                    continue
                _add_process_sample(procs, info)
        except psutil.NoSuchProcess:
            pass

    if not procs:
        return ''

    # Sort by percentage, desc; produce up to 'count'
    ranked = sorted(procs.items(), key=lambda kv: kv[1]['%'], reverse=True)[:count]
    lines = [f'\nTop {count} most memory consuming processes:']
    for i, (name, agg) in enumerate(ranked, start=1):
        lines.append(
            f'{i}. {name}: {lib.human.bytes2human(agg["rss"])} ({agg["%"]:.1f}%)'
        )
    return '\n'.join(lines) + '\n'


def main():
    """Run the check from start to finish.

    Parses the command line, reads the memory counters (from psutil, or from
    a fixture when --test is given), derives the state from the usage
    percentage and hands message, state and perfdata to `lib.base.oao()`.
    """

    # parse the command line
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # fetch data
    if args.TEST is not None:
        stdout, _, _ = lib.lftest.test(args.TEST)
        virt = _load_memory_fixture(stdout)
        # a fixture run must not walk the real process table
        args.TOP = 0
    else:
        virt = psutil.virtual_memory()

    def counter(field):
        """Return one memory counter, or 0 if `virt` has no such attribute."""
        return getattr(virt, field, 0)

    def labelled(fields):
        """Render the given counters as 'name: human readable size, ...'."""
        return ', '.join(
            f'{field}: {lib.human.bytes2human(counter(field))}' for field in fields
        )

    # analyze data
    usage_percent = float(counter('percent'))
    state = lib.base.get_state(usage_percent, args.WARN, args.CRIT, _operator='ge')

    # build the message: first line is percentage + main sizes, second line
    # holds the extended fields, followed by the optional top list
    msg = f'{usage_percent:.0f}%{lib.base.state2str(state, prefix=" ")}'
    msg += ' - ' + labelled(('total', 'used', 'available', 'free'))
    msg += '\n' + labelled(('shared', 'buffers', 'cached'))
    msg += '\n' + top(args.TOP)

    # perfdata: the percentage carries the thresholds, every byte counter is
    # reported without thresholds and scaled against total memory
    total = counter('total')
    perfdata_parts = [
        lib.base.get_perfdata(
            'usage_percent',
            usage_percent,
            uom='%',
            warn=args.WARN,
            crit=args.CRIT,
            _min=0,
            _max=100,
        )
    ]
    for field in ('total', 'used', 'available', 'free', 'shared', 'buffers', 'cached'):
        perfdata_parts.append(
            lib.base.get_perfdata(
                field,
                counter(field),
                uom='B',
                warn=None,
                crit=None,
                _min=0,
                _max=total,
            )
        )
    perfdata = ''.join(perfdata_parts)

    # over and out
    lib.base.oao(msg, state, perfdata, always_ok=args.ALWAYS_OK)


# Script entry point: run the check only when executed directly, not on import.
if __name__ == '__main__':
    try:
        main()
    except Exception:
        # Anything main() did not handle itself is passed on to lib.base.cu()
        # (presumably prints the traceback and exits with UNKNOWN - confirm
        # in lib.base).
        lib.base.cu()
