#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details."""

import argparse
import sys

import lib.args
import lib.base
import lib.db_mysql
import lib.shell
import lib.time
import lib.txt
from lib.globals import STATE_OK, STATE_UNKNOWN

__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2026040801'

# Shown by `--help`; the check evaluates both certificates and CRLs.
DESCRIPTION = """Checks the expiration dates of certificates and the next update dates of CRLs
stored in an XCA-managed MySQL/MariaDB database. Alerts when they are about to expire or have
already expired."""

DEFAULT_CRIT = 5  # days
DEFAULT_DEFAULTS_FILE = '/var/spool/icinga2/.my.cnf'
DEFAULT_DEFAULTS_GROUP = 'client'
DEFAULT_TIMEOUT = 3  # seconds
DEFAULT_WARN = 14  # days


def parse_args():
    """Parse command line arguments using argparse.

    Unknown arguments are ignored (`parse_known_args`). All values used to
    build SQL statements get a defined default, so an omitted `--prefix`
    yields an empty string instead of `None` (which would otherwise be
    rendered as the literal text "None" inside the table names).

    Returns:
        argparse.Namespace with the attributes CRIT, DEFAULTS_FILE,
        DEFAULTS_GROUP, PREFIX, TIMEOUT and WARN.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)

    parser.add_argument(
        '-V',
        '--version',
        action='version',
        version=f'%(prog)s: v{__version__} by {__author__}',
    )

    parser.add_argument(
        '-c',
        '--critical',
        help='CRIT threshold for certificate expiration in days. Default: %(default)s',
        dest='CRIT',
        type=int,
        default=DEFAULT_CRIT,
    )

    parser.add_argument(
        '--defaults-file',
        help=lib.args.help('--defaults-file') + ' '
        '(for MySQL/MariaDB cnf-style files). '
        'Example: `/var/spool/icinga2/.my.cnf`. '
        'Default: %(default)s',
        dest='DEFAULTS_FILE',
        default=DEFAULT_DEFAULTS_FILE,
    )

    parser.add_argument(
        '--defaults-group',
        help=lib.args.help('--defaults-group') + ' Default: %(default)s',
        dest='DEFAULTS_GROUP',
        default=DEFAULT_DEFAULTS_GROUP,
    )

    parser.add_argument(
        '--prefix',
        help='Table name prefix used in the XCA database.',
        dest='PREFIX',
        type=str,
        # Without a default argparse stores None, and the f-string SQL would
        # then query e.g. `xca.Noneview_certs`.
        default='',
    )

    parser.add_argument(
        '--timeout',
        help=lib.args.help('--timeout') + ' Default: %(default)s (seconds)',
        dest='TIMEOUT',
        type=int,
        default=DEFAULT_TIMEOUT,
    )

    parser.add_argument(
        '-w',
        '--warning',
        help='WARN threshold for certificate expiration in days. Default: %(default)s',
        dest='WARN',
        type=int,
        default=DEFAULT_WARN,
    )

    args, _ = parser.parse_known_args()
    return args


def check_crts(conn, args, state):
    """Check certificates for expiration.

    Selects all rows of `xca.<PREFIX>view_certs` whose `invaldate` column is
    NULL. For each row, `openssl x509 -dates` prints the certificate dates.
    The number of days left until `notAfter` is compared against the
    WARN/CRIT thresholds.

    Args:
        conn: Open database connection as returned by `lib.db_mysql.connect()`.
        args: Parsed command line; PREFIX, WARN and CRIT are used.
        state: Overall state so far. It is worsened (`lib.base.get_worst`)
            by every certificate whose local state is not OK.

    Returns:
        Tuple `(state, crts, expiring_crts, table_crts)`:
        - `crts` is the number of rows checked.
        - `expiring_crts` is the number of them in a non-OK state.
        - `table_crts` is a list of dicts (keys name, ca, serial, state,
          timestr) for `lib.base.get_table()`.
    """
    # PREFIX is a trusted admin-controlled SQL identifier (Icinga check config), not user input.
    # An omitted --prefix may arrive as None; treat it as "no prefix" instead of the text "None".
    prefix = args.PREFIX or ''
    result = lib.base.coe(
        lib.db_mysql.select(
            conn,
            f'select name, cert, ca, {prefix}certs_serial '  # nosec B608
            f'from xca.{prefix}view_certs '
            'where invaldate is null',
        )
    )

    crts = 0
    expiring_crts = 0
    table_crts = []
    for row in result:
        # each row is a dict; unpacking it yields its column names (keys) in select order
        name, cert, ca, serial = row
        crts += 1

        # execute the shell command and return its result and exit code
        stdout, _stderr, _retc = lib.base.coe(
            lib.shell.shell_exec(
                'openssl x509 -noout -dates -in /dev/stdin',
                # add line break after 10th character, for the openssl command to always work
                stdin=(
                    f'-----BEGIN CERTIFICATE-----\n'
                    f'{row[cert][:10]}\n'
                    f'{row[cert][10:]}\n'
                    f'-----END CERTIFICATE-----'
                ),
            )
        )
        for line in stdout.splitlines():
            if not line.startswith('notAfter='):
                continue
            # e.g. "notAfter=Oct 29 08:41:00 2028 GMT"
            not_after = lib.time.timestr2datetime(
                line.partition('=')[2],
                pattern='%b %d %H:%M:%S %Y %Z',
            )
            delta = not_after - lib.time.now(as_type='datetime')
            local_state = lib.base.get_state(
                delta.days, args.WARN, args.CRIT, _operator='le'
            )
            table_crts.append(
                {
                    'name': row[name],
                    'ca': 'y' if row[ca] == 1 else 'n',
                    'serial': row[serial],
                    'state': lib.base.state2str(local_state, empty_ok=False),
                    'timestr': f'{not_after}',
                }
            )
            if local_state != STATE_OK:
                expiring_crts += 1
                state = lib.base.get_worst(state, local_state)
            # only the first notAfter line is evaluated
            break
    return state, crts, expiring_crts, table_crts


def check_crls(conn, args, state):
    """Check CRLs for their next update.

    Selects the newest row (ordered by `date`, `limit 1`) of
    `xca.<PREFIX>view_crls`. `openssl crl -nextupdate` prints its
    `nextUpdate` field, and the number of days left is compared against the
    WARN/CRIT thresholds.

    Args:
        conn: Open database connection as returned by `lib.db_mysql.connect()`.
        args: Parsed command line; PREFIX, WARN and CRIT are used.
        state: Overall state so far. It is worsened (`lib.base.get_worst`)
            by every CRL whose local state is not OK.

    Returns:
        Tuple `(state, crls, expiring_crls, table_crls)`:
        - `crls` is the number of rows checked.
        - `expiring_crls` is the number of them in a non-OK state.
        - `table_crls` is a list of dicts (keys name, state, timestr) for
          `lib.base.get_table()`.
    """
    # PREFIX is a trusted admin-controlled SQL identifier (Icinga check config), not user input.
    # An omitted --prefix may arrive as None; treat it as "no prefix" instead of the text "None".
    prefix = args.PREFIX or ''
    result = lib.base.coe(
        lib.db_mysql.select(
            conn,
            f'select name,crl from xca.{prefix}view_crls order by date desc limit 1',  # nosec B608
        )
    )

    crls = 0
    expiring_crls = 0
    table_crls = []
    for row in result:
        # each row is a dict; unpacking it yields its column names (keys) in select order
        name, crl = row
        crls += 1

        # execute the shell command and return its result and exit code
        stdout, _stderr, _retc = lib.base.coe(
            lib.shell.shell_exec(
                'openssl crl -noout -nextupdate -in /dev/stdin',
                stdin=f'-----BEGIN X509 CRL-----\n{row[crl]}\n-----END X509 CRL-----',
            )
        )
        for line in stdout.splitlines():
            if not line.startswith('nextUpdate='):
                continue
            # e.g. "nextUpdate=Oct 29 08:41:00 2028 GMT"
            next_update = lib.time.timestr2datetime(
                line.partition('=')[2],
                pattern='%b %d %H:%M:%S %Y %Z',
            )
            delta = next_update - lib.time.now(as_type='datetime')
            local_state = lib.base.get_state(
                delta.days, args.WARN, args.CRIT, _operator='le'
            )
            table_crls.append(
                {
                    'name': row[name],
                    'state': lib.base.state2str(local_state, empty_ok=False),
                    'timestr': f'{next_update}',
                }
            )
            if local_state != STATE_OK:
                expiring_crls += 1
                state = lib.base.get_worst(state, local_state)
            # only the first nextUpdate line is evaluated
            break
    return state, crls, expiring_crls, table_crls


def main():
    """Run the check: query the XCA database, evaluate certificates and CRLs,
    and report the combined result via `lib.base.oao()`."""

    # parse the command line; any argparse exit is turned into UNKNOWN
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # fetch data
    conn = lib.base.coe(
        lib.db_mysql.connect(
            {
                'defaults_file': args.DEFAULTS_FILE,
                'defaults_group': args.DEFAULTS_GROUP,
                'timeout': args.TIMEOUT,
            }
        )
    )

    # analyze data; the CRL check continues from the state left by the certificate check
    state, crts, expiring_crts, table_crts = check_crts(conn, args, STATE_OK)
    state, crls, expiring_crls, table_crls = check_crls(conn, args, state)
    lib.db_mysql.close(conn)

    # summary line
    msg = (
        f'{crts} {lib.txt.pluralize("Certificate", crts)} '
        f'and {crls} {lib.txt.pluralize("CRL", crls)} checked. '
    )

    # one sentence per object type that has non-OK entries
    for count, noun, trailer in (
        (expiring_crts, 'Certificate', ' expiring. '),
        (expiring_crls, 'CRL', ' expiring.'),
    ):
        if count > 0:
            msg += (
                f'{count} {lib.txt.pluralize(noun, count)}'
                f' {lib.txt.pluralize("", count, "is,are")}{trailer}'
            )

    # detail tables
    if table_crts:
        msg += '\n\nCertificates:\n' + lib.base.get_table(
            table_crts,
            ['name', 'ca', 'serial', 'timestr', 'state'],
            header=['commonName', 'CA', 'Serial', 'Expiry date', 'State'],
        )
    if table_crls:
        msg += '\nCRLs:\n' + lib.base.get_table(
            table_crls,
            ['name', 'timestr', 'state'],
            header=['commonName', 'Expiry date', 'State'],
        )

    # over and out
    lib.base.oao(msg, state)


if __name__ == '__main__':
    # Run the check. Any unexpected exception is handed to lib.base.cu().
    # NOTE(review): presumably cu() prints the traceback and exits UNKNOWN;
    # its behavior is not visible here. SystemExit derives from
    # BaseException, not Exception, so regular sys.exit() calls pass through
    # this handler untouched.
    try:
        main()
    except Exception:
        lib.base.cu()
