#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import sys

import lib.args
import lib.base
import lib.lftest
import lib.shell
from lib.globals import STATE_OK, STATE_UNKNOWN

__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2026041001'

DESCRIPTION = """Counts the number of currently logged-in users by session type: tty (console) and pts
(SSH on Linux, RDP on Windows). On Windows, also counts disconnected sessions (closed
connections without logging out). Alerts when the total user count exceeds the
configured thresholds."""

# Default thresholds. parse_args() combines them into the default lists for
# `--warning` and `--critical`, always in the order tty, pts, disc.
# NOTE(review): a value of None is presumably treated as "no threshold" by
# lib.args.int_or_none() / lib.base.get_state() - confirm in lib.
DEFAULT_WARN_PTS = 20  # pts sessions (SSH on Linux, RDP on Windows)
DEFAULT_WARN_DISC = 1  # disconnected sessions (only evaluated for Windows data)
DEFAULT_WARN_TTY = 1  # tty (console) sessions
DEFAULT_CRIT_PTS = None
DEFAULT_CRIT_DISC = None
DEFAULT_CRIT_TTY = None


def parse_args():
    """Build the command line parser and return the parsed arguments.

    Defines `--version`, `--critical`, `--test` and `--warning`. Unknown
    arguments are tolerated (and dropped), because `parse_known_args()` is
    used instead of `parse_args()`.
    """

    def threshold_help(example):
        # --warning and --critical share the same help text, only the example differs
        return (
            'Threshold for logged-in tty/pts users, in the format "tty,pts". '
            'On Windows, you can additionally specify a threshold for disconnected users '
            'in the format "tty,pts,disc". '
            f'Example: `{example}`. '
            'Default: %(default)s'
        )

    parser = argparse.ArgumentParser(description=DESCRIPTION)

    parser.add_argument(
        '-V',
        '--version',
        action='version',
        version=f'%(prog)s: v{__version__} by {__author__}',
    )

    # thresholds are lists in the order tty, pts, disc
    crit_defaults = [DEFAULT_CRIT_TTY, DEFAULT_CRIT_PTS, DEFAULT_CRIT_DISC]
    parser.add_argument(
        '-c',
        '--critical',
        dest='CRIT',
        type=lib.args.csv,
        default=crit_defaults,
        help=threshold_help('--critical 3,10'),
    )

    parser.add_argument(
        '--test',
        dest='TEST',
        type=lib.args.csv,
        help=lib.args.help('--test'),
    )

    warn_defaults = [DEFAULT_WARN_TTY, DEFAULT_WARN_PTS, DEFAULT_WARN_DISC]
    parser.add_argument(
        '-w',
        '--warning',
        dest='WARN',
        type=lib.args.csv,
        default=warn_defaults,
        help=threshold_help('--warning 1,5'),
    )

    # only the namespace of known arguments is of interest
    return parser.parse_known_args()[0]


def parse_linux_output(s):
    """Count local (tty) and remote (pts) sessions in the output of `w`.

    The first line of the output (uptime summary) is discarded. The following
    header line is used to locate the TTY column, so the parser does not
    depend on the presence of the FROM column or on fixed column widths.

    Returns a tuple `(lines, count_tty, count_pts)`, where `lines` is the
    output without its first line (header plus one line per session, with
    every `|` replaced by `!`).
    """
    # pipes would break the perfdata section of the plugin output
    rows = s.strip().replace('|', '!').splitlines()[1:]
    if not rows:
        return rows, 0, 0

    header, sessions = rows[0], rows[1:]
    col_from = header.find('TTY')
    if col_from < 0:
        # no TTY column, nothing we could classify
        return rows, 0, 0

    # the TTY column reaches up to the start of the next column header
    remainder = header[col_from + 3 :]
    padding = len(remainder) - len(remainder.lstrip())
    col_to = col_from + 3 + padding

    local_sessions = 0
    for row in sessions:
        # slicing a row that is too short simply yields an empty cell
        cell = row[col_from:col_to].strip()
        # tty = local terminal, ":0" = local X display
        if cell.startswith(('tty', ':')):
            local_sessions += 1

    # everything else (pts, empty, unknown) is counted as pts
    return rows, local_sessions, len(sessions) - local_sessions


def parse_windows_output(s):
    """Count console, RDP and disconnected sessions in the output of `query user`.

    Each line is split on whitespace and the second token is inspected:

    * `console`: local session, counted as tty
    * contains `rdp-`: remote desktop session, counted as pts
    * digits only: the SESSIONNAME column is blank, so the second token is
      already the numeric session ID. This is how a disconnected session
      shows up, independent of the display language.

    The header line matches none of these rules. Lines with fewer than two
    tokens (e.g. blank lines) are skipped instead of raising an IndexError.

    Returns a tuple `(lines, count_tty, count_pts, count_disc)`, where `lines`
    is the stripped output split into lines.
    """
    lines = s.strip().splitlines()
    count_tty, count_pts, count_disc = 0, 0, 0
    for line in lines:
        tokens = line.split()
        if len(tokens) < 2:
            continue
        value = tokens[1]
        if value == 'console':
            count_tty += 1
        elif 'rdp-' in value:
            count_pts += 1
        elif value.isdigit():
            # str.split() never returns empty tokens, so a blank session name
            # cannot be detected by comparing against ''
            count_disc += 1

    return lines, count_tty, count_pts, count_disc


def _get_disc_threshold(values, default):
    """Return the optional third ("disc") value of a "tty,pts[,disc]" list.

    The help text of `--warning` and `--critical` allows the two-value form
    "tty,pts", so fall back to `default` if the third value was omitted.
    """
    if len(values) < 3:
        return default
    return lib.args.int_or_none(values[2])


def main():
    """Count logged-in users per session type and report state and perfdata.

    Runs `query user` (Windows) or `/usr/bin/w` (everything else), or reads
    the fixture given by `--test`, counts tty, pts and (for Windows data)
    disconnected sessions, compares the counts against the thresholds and
    hands message, state and perfdata over to `lib.base.oao()`.
    """
    # parse the command line
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # init some vars
    msg = ''
    perfdata = ''
    state = STATE_OK
    count_tty, count_pts, count_disc = 0, 0, 0

    # Decide once which kind of data is processed. In test mode this depends
    # on the fixture, not on the host, so that a Windows fixture gets its
    # disconnected sessions evaluated and reported on any platform.
    if args.TEST is None:
        is_windows = lib.base.WINDOWS
    else:
        is_windows = any('windows' in s for s in args.TEST)

    WARN_DISC = None
    CRIT_DISC = None
    try:
        WARN_TTY = lib.args.int_or_none(args.WARN[0])
        WARN_PTS = lib.args.int_or_none(args.WARN[1])
        CRIT_TTY = lib.args.int_or_none(args.CRIT[0])
        CRIT_PTS = lib.args.int_or_none(args.CRIT[1])
        if is_windows:
            # the third value is optional, `--warning 1,5` must not fail
            WARN_DISC = _get_disc_threshold(args.WARN, DEFAULT_WARN_DISC)
            CRIT_DISC = _get_disc_threshold(args.CRIT, DEFAULT_CRIT_DISC)
    except Exception:
        lib.base.cu('Unexpected parameter values for --warning and/or --critical.')

    # fetch data
    if args.TEST is None:
        if is_windows:
            cmd = 'query user'
            stdout, _, _ = lib.base.coe(lib.shell.shell_exec(cmd))
            # Could not find any documentation in the return codes of 'query user'
            # (%ERRORLEVEL% is always 1).
            lines, count_tty, count_pts, count_disc = parse_windows_output(stdout)
        else:
            cmd = '/usr/bin/w'
            stdout, stderr, retc = lib.base.coe(lib.shell.shell_exec(cmd))
            if stderr or retc != 0:
                lib.base.cu(stderr)
            lines, count_tty, count_pts = parse_linux_output(stdout)
    else:
        # do not call the command, put in test data
        stdout, _, _ = lib.lftest.test(args.TEST)
        if is_windows:
            lines, count_tty, count_pts, count_disc = parse_windows_output(stdout)
        else:
            lines, count_tty, count_pts = parse_linux_output(stdout)

    # analyze data and build the message
    if count_tty == 0 and count_pts == 0 and len(lines) == 1:
        # nothing but the header line
        msg = 'No one is logged in.'
    else:
        tty_state = lib.base.get_state(count_tty, WARN_TTY, CRIT_TTY)
        state = lib.base.get_worst(state, tty_state)

        # build the message
        msg += f'TTY: {count_tty}' + lib.base.state2str(tty_state, prefix=' ')

        pts_state = lib.base.get_state(count_pts, WARN_PTS, CRIT_PTS)
        state = lib.base.get_worst(state, pts_state)
        msg += f', PTS: {count_pts}' + lib.base.state2str(pts_state, prefix=' ')

        if is_windows:
            disc_state = lib.base.get_state(count_disc, WARN_DISC, CRIT_DISC)
            state = lib.base.get_worst(state, disc_state)
            msg += f', Disconnected: {count_disc}' + lib.base.state2str(
                disc_state, prefix=' '
            )
            perfdata += lib.base.get_perfdata(
                'disc',
                count_disc,
                uom=None,
                warn=WARN_DISC,
                crit=CRIT_DISC,
                _min=0,
                _max=None,
            )

        # append the raw command output as details
        msg += '\n\n' + '\n'.join(lines)

    perfdata += lib.base.get_perfdata(
        'tty',
        count_tty,
        uom=None,
        warn=WARN_TTY,
        crit=CRIT_TTY,
        _min=0,
        _max=None,
    )
    perfdata += lib.base.get_perfdata(
        'pts',
        count_pts,
        uom=None,
        warn=WARN_PTS,
        crit=CRIT_PTS,
        _min=0,
        _max=None,
    )

    # over and out
    lib.base.oao(msg, state, perfdata)


# Script entry point: run the check only when executed directly, not on import.
if __name__ == '__main__':
    try:
        main()
    except Exception:
        # Last-resort handler for every exception main() did not handle itself.
        # NOTE(review): lib.base.cu() presumably reports the error and exits
        # with STATE_UNKNOWN - confirm in lib.base.
        lib.base.cu()
