#!/usr/bin/env python3
# -*- coding: utf-8; py-indent-offset: 4 -*-
#
# Author:  Linuxfabrik GmbH, Zurich, Switzerland
# Contact: info (at) linuxfabrik (dot) ch
#          https://www.linuxfabrik.ch/
# License: The Unlicense, see LICENSE file.

# https://github.com/Linuxfabrik/monitoring-plugins/blob/main/CONTRIBUTING.md

"""See the check's README for more details.
"""

import argparse  # pylint: disable=C0413
import sys  # pylint: disable=C0413

import lib.base  # pylint: disable=C0413
import lib.db_mysql  # pylint: disable=C0413
import lib.shell  # pylint: disable=C0413
import lib.time  # pylint: disable=C0413
import lib.txt  # pylint: disable=C0413
from lib.globals import (STATE_CRIT, STATE_OK,  # pylint: disable=C0413
                          STATE_UNKNOWN, STATE_WARN)

# plugin metadata, printed by `--version`
__author__ = 'Linuxfabrik GmbH, Zurich/Switzerland'
__version__ = '2023112901'

# one-line summary, used as the argparse description (shown by `--help`)
DESCRIPTION = 'Checks expiration date of certificates in a XCA based MySQL/MariaDB database.'

# defaults for the command line parameters, see parse_args()
DEFAULT_CRIT = 5  # days (default for `--critical`)
DEFAULT_DEFAULTS_FILE = '/var/spool/icinga2/.my.cnf'  # default for `--defaults-file`
DEFAULT_DEFAULTS_GROUP = 'client'  # default for `--defaults-group`
DEFAULT_TIMEOUT = 3  # seconds (default for `--timeout`)
DEFAULT_WARN = 14 # days (default for `--warning`)


def parse_args():
    """Parse command line arguments using argparse.

    Returns the `argparse.Namespace` with the attributes `CRIT`, `DEFAULTS_FILE`,
    `DEFAULTS_GROUP`, `PREFIX`, `TIMEOUT` and `WARN`.

    argparse raises `SystemExit` on invalid arguments as well as after handling
    `--help` and `--version`; the caller is expected to deal with that.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)

    parser.add_argument(
        '-V', '--version',
        action='version',
        version='{0}: v{1} by {2}'.format('%(prog)s', __version__, __author__)
    )

    parser.add_argument(
        '-c', '--critical',
        help='Set the critical for the expiration date in days. Default: %(default)s',
        dest='CRIT',
        type=int,
        default=DEFAULT_CRIT,
    )

    parser.add_argument(
        '--defaults-file',
        help='Specifies a cnf file to read parameters like user, host and password from '
             '(instead of specifying them on the command line), '
             'for example `/var/spool/icinga2/.my.cnf`. Default: %(default)s',
        dest='DEFAULTS_FILE',
        default=DEFAULT_DEFAULTS_FILE,
    )

    parser.add_argument(
        '--defaults-group',
        help='Group/section to read from in the cnf file. Default: %(default)s',
        dest='DEFAULTS_GROUP',
        default=DEFAULT_DEFAULTS_GROUP,
    )

    parser.add_argument(
        '--prefix',
        help='Set the table prefix of the XCA database.',
        dest='PREFIX',
        type=str,
        # The prefix is interpolated into the table and column names of the SQL
        # statements. Without an explicit default argparse would hand over `None`,
        # which `str.format()` renders as the literal text "None".
        default='',
    )

    parser.add_argument(
        '--timeout',
        help='Network timeout in seconds. Default: %(default)s (seconds)',
        dest='TIMEOUT',
        type=int,
        default=DEFAULT_TIMEOUT,
    )

    parser.add_argument(
        '-w', '--warning',
        help='Set the warning for the expiration date in days. Default: %(default)s',
        dest='WARN',
        type=int,
        default=DEFAULT_WARN,
    )

    return parser.parse_args()


def check_crts(conn, args, state):
    """Check certificates for expiration.

    Selects all rows of the `view_certs` view (schema `xca`, optional table prefix
    from `args.PREFIX`) whose `invaldate` is NULL, pipes each certificate through
    `openssl x509 -noout -dates` and compares the number of days until `notAfter`
    against `args.WARN` and `args.CRIT` (`lib.base.get_state()` with operator 'le').

    Parameters:
    - conn: database connection, passed on to `lib.db_mysql.select()`.
    - args: parsed command line; `PREFIX`, `WARN` and `CRIT` are read.
    - state: the state determined so far.

    Returns a tuple `(state, crts, expiring_crts, table_crts)`:
    - state: worst of the incoming state and all non-OK certificate states.
    - crts: number of rows fetched.
    - expiring_crts: number of certificates whose state is not OK.
    - table_crts: list of dicts with the keys `name`, `ca`, `serial`, `state` and
      `timestr`, one per certificate for which openssl reported a `notAfter` date.
    """
    # `--prefix` is optional. If it is missing (`None`), `str.format()` would put the
    # literal text "None" into the table and column names.
    prefix = args.PREFIX or ''
    result = lib.base.coe(
        lib.db_mysql.select(
            conn,
            'select name, cert, ca, {0}certs_serial '
            'from xca.{0}view_certs '
            'where invaldate is null'.format(prefix)
        )
    )

    crts = 0
    expiring_crts = 0
    table_crts = []
    for row in result:
        # each row is a dict; iterating it yields its keys in column order, so this
        # works regardless of how the (prefixed) serial column is actually named
        name_key, cert_key, ca_key, serial_key = row
        crts += 1

        # the database holds the bare base64 body, so wrap it in PEM armor;
        # add line break after 10th character, for the openssl command to always work
        body = row[cert_key]
        # NOTE(review): stderr and the return code of openssl are not evaluated. A
        # certificate that openssl cannot parse produces no `notAfter=` line; it is
        # counted in `crts`, but neither listed in the table nor alerted on.
        stdout, _, _ = lib.base.coe(
            lib.shell.shell_exec(
                'openssl x509 -noout -dates -in /dev/stdin',
                stdin='-----BEGIN CERTIFICATE-----\n{}\n{}\n-----END CERTIFICATE-----'.format(
                    body[:10],
                    body[10:],
                ),
            )
        )

        # only the first `notAfter=` line is of interest
        not_after = next(
            (line for line in stdout.splitlines() if line.startswith('notAfter=')),
            None,
        )
        if not_after is None:
            continue

        expiry = lib.time.timestr2datetime(
            not_after.split('=', 1)[1],       # Oct 29 08:41:00 2028 GMT
            pattern='%b %d %H:%M:%S %Y %Z',
        )
        delta = expiry - lib.time.now(as_type='datetime')
        local_state = lib.base.get_state(delta.days, args.WARN, args.CRIT, _operator='le')
        table_crts.append({
            'name': row[name_key],
            'ca': 'y' if row[ca_key] == 1 else 'n',
            'serial': row[serial_key],
            'state': lib.base.state2str(local_state, empty_ok=False),
            'timestr': '{}'.format(expiry),
        })
        if local_state != STATE_OK:
            expiring_crts += 1
            state = lib.base.get_worst(state, local_state)
    return state, crts, expiring_crts, table_crts


def check_crls(conn, args, state):
    """Check CRLs for their next update.

    Selects the newest row (ordered by `date`, limited to one) of the `view_crls`
    view (schema `xca`, optional table prefix from `args.PREFIX`), pipes the CRL
    through `openssl crl -noout -nextupdate` and compares the number of days until
    `nextUpdate` against `args.WARN` and `args.CRIT` (`lib.base.get_state()` with
    operator 'le').

    Parameters:
    - conn: database connection, passed on to `lib.db_mysql.select()`.
    - args: parsed command line; `PREFIX`, `WARN` and `CRIT` are read.
    - state: the state determined so far.

    Returns a tuple `(state, crls, expiring_crls, table_crls)`:
    - state: worst of the incoming state and all non-OK CRL states.
    - crls: number of rows fetched.
    - expiring_crls: number of CRLs whose state is not OK.
    - table_crls: list of dicts with the keys `name`, `state` and `timestr`, one per
      CRL for which openssl reported a `nextUpdate` date.
    """
    # `--prefix` is optional. If it is missing (`None`), `str.format()` would put the
    # literal text "None" into the table name.
    prefix = args.PREFIX or ''
    result = lib.base.coe(
        lib.db_mysql.select(
            conn,
            'select name,crl '
            'from xca.{}view_crls '
            'order by date desc '
            'limit 1'.format(prefix),
        )
    )

    crls = 0
    expiring_crls = 0
    table_crls = []
    for row in result:
        # each row is a dict; iterating it yields its keys in column order
        name_key, crl_key = row
        crls += 1

        # the database holds the bare base64 body, so wrap it in PEM armor
        # NOTE(review): stderr and the return code of openssl are not evaluated. A
        # CRL that openssl cannot parse produces no `nextUpdate=` line; it is
        # counted in `crls`, but neither listed in the table nor alerted on.
        stdout, _, _ = lib.base.coe(
            lib.shell.shell_exec(
                'openssl crl -noout -nextupdate -in /dev/stdin',
                stdin='-----BEGIN X509 CRL-----\n{}\n-----END X509 CRL-----'.format(
                    row[crl_key],
                ),
            )
        )

        # only the first `nextUpdate=` line is of interest
        next_update = next(
            (line for line in stdout.splitlines() if line.startswith('nextUpdate=')),
            None,
        )
        if next_update is None:
            continue

        expiry = lib.time.timestr2datetime(
            next_update.split('=', 1)[1],       # Oct 29 08:41:00 2028 GMT
            pattern='%b %d %H:%M:%S %Y %Z',
        )
        delta = expiry - lib.time.now(as_type='datetime')
        local_state = lib.base.get_state(delta.days, args.WARN, args.CRIT, _operator='le')
        table_crls.append({
            'name': row[name_key],
            'state': lib.base.state2str(local_state, empty_ok=False),
            'timestr': '{}'.format(expiry),
        })
        if local_state != STATE_OK:
            expiring_crls += 1
            state = lib.base.get_worst(state, local_state)
    return state, crls, expiring_crls, table_crls


def main():
    """The main function: parse the command line, connect to the database, check
    certificates and CRLs, then print the result and exit with the resulting state.
    """

    # a failing command line parser has to end in UNKNOWN
    try:
        args = parse_args()
    except SystemExit:
        sys.exit(STATE_UNKNOWN)

    # connect to the database
    conn = lib.base.coe(lib.db_mysql.connect({
        'defaults_file':  args.DEFAULTS_FILE,
        'defaults_group': args.DEFAULTS_GROUP,
        'timeout':        args.TIMEOUT,
    }))

    # run both checks; the second one continues with the state of the first
    state, crts, expiring_crts, table_crts = check_crts(conn, args, STATE_OK)
    state, crls, expiring_crls, table_crls = check_crls(conn, args, state)

    lib.db_mysql.close(conn)

    # collect the message fragments: the summary first ...
    fragments = [
        '{} {} and {} {} checked. '.format(
            crts,
            lib.txt.pluralize('Certificate', crts),
            crls,
            lib.txt.pluralize('CRL', crls),
        ),
    ]

    # ... then what is about to expire ...
    if expiring_crts > 0:
        fragments.append('{} {} {} expiring. '.format(
            expiring_crts,
            lib.txt.pluralize('Certificate', expiring_crts),
            lib.txt.pluralize('', expiring_crts, 'is,are'),
        ))
    if expiring_crls > 0:
        fragments.append('{} {} {} expiring.'.format(
            expiring_crls,
            lib.txt.pluralize('CRL', expiring_crls),
            lib.txt.pluralize('', expiring_crls, 'is,are'),
        ))

    # ... and finally the detail tables
    if table_crts:
        fragments.append('\n\nCertificates:\n')
        fragments.append(lib.base.get_table(
            table_crts,
            ['name', 'ca', 'serial', 'timestr', 'state'],
            header=['commonName', 'CA', 'Serial', 'Expiry date', 'State'],
        ))
    if table_crls:
        fragments.append('\nCRLs:\n')
        fragments.append(lib.base.get_table(
            table_crls,
            ['name', 'timestr', 'state'],
            header=['commonName', 'Expiry date', 'State'],
        ))

    # over and out
    lib.base.oao(''.join(fragments), state)


# run the check only when executed as a script, not when imported
if __name__ == '__main__':
    try:
        main()
    except Exception:   # pylint: disable=W0703
        # last-resort handler for any exception main() did not deal with itself
        # NOTE(review): lib.base.cu() presumably reports the error and exits with
        # UNKNOWN - confirm in lib.base
        lib.base.cu()
