#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import sys

import lib.args
import lib.base
import lib.db_mysql
import lib.txt
from lib.globals import STATE_OK, STATE_UNKNOWN

__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
# Reported by `--version`.
__version__ = '2026040801'

# Passed to argparse as the parser description, i.e. shown by `--help`.
DESCRIPTION = """Connects to a MySQL/MariaDB database and runs configurable SQL queries for warning
and critical conditions. The query result - either a row count or a specific value -
is checked against Nagios range expressions. Useful for custom application-level
monitoring."""

# Default for --defaults-file: cnf file that user, host and password are read from.
DEFAULT_DEFAULTS_FILE = '/var/spool/icinga2/.my.cnf'
# Default for --defaults-group: the group inside that cnf file to use.
DEFAULT_DEFAULTS_GROUP = 'client'
# Default for --timeout, in seconds.
DEFAULT_TIMEOUT = 3


def parse_args():
    """Parse command line arguments using argparse.

    Unknown arguments are ignored (`parse_known_args()`), not rejected.

    :returns: argparse.Namespace; option values are stored under the `dest`
        names below (ALWAYS_OK, CRIT, CRITICAL_QUERY, DEFAULTS_FILE,
        DEFAULTS_GROUP, TIMEOUT, WARN, WARNING_QUERY).
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)

    parser.add_argument(
        '-V',
        '--version',
        action='version',
        version=f'%(prog)s: v{__version__} by {__author__}',
    )

    parser.add_argument(
        '--always-ok',
        help=lib.args.help('--always-ok'),
        dest='ALWAYS_OK',
        action='store_true',
        default=False,
    )

    parser.add_argument(
        '-c',
        '--critical',
        help='CRIT threshold as a Nagios range expression.',
        dest='CRIT',
    )

    # The help text mirrors get_state_and_value(): only a result of exactly one
    # row with exactly one column is checked by value; everything else by row count.
    parser.add_argument(
        '--critical-query',
        help='`SELECT` statement to evaluate for CRIT. '
        'If the result is exactly one row with exactly one column, that single value '
        'is checked against `--critical`. '
        'Otherwise the number of returned rows is checked.',
        dest='CRITICAL_QUERY',
    )

    parser.add_argument(
        '--defaults-file',
        help='MySQL/MariaDB cnf file to read user, host and password from. '
        'Example: `--defaults-file=/var/spool/icinga2/.my.cnf`. '
        'Default: %(default)s',
        dest='DEFAULTS_FILE',
        default=DEFAULT_DEFAULTS_FILE,
    )

    parser.add_argument(
        '--defaults-group',
        help=lib.args.help('--defaults-group') + ' Default: %(default)s',
        dest='DEFAULTS_GROUP',
        default=DEFAULT_DEFAULTS_GROUP,
    )

    parser.add_argument(
        '--timeout',
        help=lib.args.help('--timeout') + ' Default: %(default)s (seconds)',
        dest='TIMEOUT',
        type=int,
        default=DEFAULT_TIMEOUT,
    )

    parser.add_argument(
        '-w',
        '--warning',
        help='WARN threshold as a Nagios range expression.',
        dest='WARN',
    )

    parser.add_argument(
        '--warning-query',
        help='`SELECT` statement to evaluate for WARN. '
        'If the result is exactly one row with exactly one column, that single value '
        'is checked against `--warning`. '
        'Otherwise the number of returned rows is checked.',
        dest='WARNING_QUERY',
    )

    # parse_known_args(): unrecognized arguments are silently dropped instead of
    # making argparse exit with an error.
    args, _ = parser.parse_known_args()
    return args


def get_state_and_value(conn, query, threshold, _type):
    """Execute an SQL query, derive a value from its result and check it against a threshold.

    How the checked value is derived:
    * Exactly one row with exactly one column: that single value (for example the
      result of a `SELECT COUNT(*) FROM ...`).
    * Anything else, including no rows at all: the number of returned rows.

    :param conn: Database connection as returned by `lib.db_mysql.connect()`.
    :param query: SQL statement to run. If empty or None, nothing is executed and
        `(STATE_OK, 0, [], False)` is returned.
    :param threshold: Nagios range expression, or None.
    :param _type: 'warn' reports a threshold violation as WARN; any other value
        reports it as CRIT.
    :returns: Tuple `(state, value, result, shortened)`. `result` is the list of
        returned rows (mappings of column to value). If more than 10 rows came back,
        it is reduced to the first 5 and the last 5 rows and `shortened` is True.
    """
    state = STATE_OK
    value = 0
    result = []
    shortened = False
    if not query:
        return state, value, result, shortened

    # `or []` keeps the row handling below safe in case select() hands back None.
    result = lib.base.coe(lib.db_mysql.select(conn, query)) or []
    if len(result) == 1 and len(result[0]) == 1:
        # one row, one column: could be a "select count(*) from ..." result
        # NOTE(review): an SQL NULL arrives here as None; confirm that
        # lib.base.get_state() handles that the way users expect.
        value = next(iter(result[0].values()))
    else:
        # zero or more rows (not the single-value case), so count them
        value = len(result)
        # shorten the result if there are too many rows
        if len(result) > 10:
            result = result[:5] + result[-5:]
            shortened = True

    # Always evaluate the threshold, even for an empty result. "No rows" is a count
    # of 0, and ranges such as `1:` ("expect at least one row") must be able to fire.
    if _type == 'warn':
        state = lib.base.get_state(value, threshold, None, _operator='range')
    else:
        state = lib.base.get_state(value, None, threshold, _operator='range')

    return state, value, result, shortened


def _describe_query(cnt, kind, query, state):
    """Return the one-line summary for one query, e.g. '3 results from warning query `...` [WARNING]'."""
    return (
        f'{cnt}'
        f' {lib.txt.pluralize("result", cnt)}'
        f' from {kind} query `{query}`'
        f'{lib.base.state2str(state, prefix=" ")}'
    )


def _format_result_table(result, shortened):
    """Render the returned rows as a table for the long output, or '' if there are no rows.

    :param result: List of rows (mappings of column to value) from get_state_and_value().
    :param shortened: True if get_state_and_value() reduced the rows to the first/last 5.
    """
    if not result:
        return ''
    table = ''
    if shortened:
        table += '\nAttention: Table below is truncated, showing the 5 first and the 5 last items.\n'
    # the columns of the first row serve as both the selected keys and the header
    keys = result[0].keys()
    return table + '\n' + lib.base.get_table(result, keys, header=keys)


def main():
    """The main function. This is where the magic happens."""

    # parse the command line
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # An empty query string is just as useless as a missing one.
    if not args.WARNING_QUERY and not args.CRITICAL_QUERY:
        lib.base.cu('Nothing to check, no queries provided.')

    # init some vars
    state = STATE_OK

    mysql_connection = {
        'defaults_file': args.DEFAULTS_FILE,
        'defaults_group': args.DEFAULTS_GROUP,
        'timeout': args.TIMEOUT,
    }
    conn = lib.base.coe(lib.db_mysql.connect(mysql_connection))
    lib.base.coe(lib.db_mysql.check_select_privileges(conn))

    # analyze data
    state_warn, cnt_warn, result_warn, shortened_warn = get_state_and_value(
        conn,
        args.WARNING_QUERY,
        args.WARN,
        'warn',
    )
    state = lib.base.get_worst(state, state_warn)
    state_crit, cnt_crit, result_crit, shortened_crit = get_state_and_value(
        conn,
        args.CRITICAL_QUERY,
        args.CRIT,
        'crit',
    )
    state = lib.base.get_worst(state, state_crit)

    lib.db_mysql.close(conn)

    # build the message: one summary per given query, joined by ' and '
    summaries = []
    if args.WARNING_QUERY:
        summaries.append(_describe_query(cnt_warn, 'warning', args.WARNING_QUERY, state_warn))
    if args.CRITICAL_QUERY:
        summaries.append(_describe_query(cnt_crit, 'critical', args.CRITICAL_QUERY, state_crit))
    msg = ' and '.join(summaries) + '\n'

    # Tables are appended only when rows came back. Rendering errors are no longer
    # swallowed; they reach the top-level handler.
    msg += _format_result_table(result_warn, shortened_warn)
    msg += _format_result_table(result_crit, shortened_crit)

    # Both counters are always emitted (0 for a query that was not given) so that
    # the perfdata layout stays stable.
    perfdata = lib.base.get_perfdata(
        'cnt_warn',
        cnt_warn,
        warn=args.WARN,
    )
    perfdata += lib.base.get_perfdata(
        'cnt_crit',
        cnt_crit,
        crit=args.CRIT,
    )

    # over and out
    lib.base.oao(msg, state, perfdata, always_ok=args.ALWAYS_OK)


if __name__ == '__main__':
    # Safety net for the plugin run. Any exception escaping main() is handed to
    # lib.base.cu() instead of ending in a bare Python traceback. SystemExit, raised
    # by sys.exit() inside main(), is not an Exception subclass and passes through
    # untouched.
    try:
        main()
    except Exception:  # pylint: disable=broad-exception-caught
        lib.base.cu()
