#!/usr/bin/python
# -*- coding: utf-8 -*-

"""
Stressant is a simple yet complete stress-testing tool that forces
a computer to perform a series of tests using well-known Linux software
in order to detect possible design or construction failures.
"""

# Copyright (C) 2017 Antoine Beaupré <anarcat@debian.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import getpass
import logging
from logging.handlers import SMTPHandler, MemoryHandler
import multiprocessing
import os
import os.path

import argparse
import smtplib
import socket
import subprocess
import tempfile
import time

try:
    import colorlog
    if 'StreamHandler' not in dir(colorlog):
        colorlog = False
except ImportError:
    colorlog = False
import humanize

try:
    from setuptools_scm import get_version
    __version__ = get_version()
except (ImportError, LookupError):
    # try the local generated version
    #
    # XXX: this may load an arbitrary version in another package!
    try:
        from __version import __version__
    except ImportError:
        __version__ = '???'


class NegateAction(argparse.Action):
    '''argparse action implementing a --foo / --no-foo toggle pair

    This behaves like 'store_true' / 'store_false', except that both
    spellings are registered on the same destination. The default is
    derived from the *first* option string passed to add_argument():
    if it starts with the negative prefix (``negative``, '--no' by
    default) the default is False, otherwise it is True. On the
    command line, whichever spelling is used decides the stored value.

    Note that the test is a plain prefix match, so any option that
    merely starts with '--no' counts as the negative spelling.
    '''

    # prefix that marks the "off" spelling of a toggle
    negative = '--no'

    def __init__(self, option_strings, *args, **kwargs):
        '''derive the default value from the first option string'''
        firstOption = option_strings[0]
        startsNegative = firstOption.startswith(self.negative)
        super(NegateAction, self).__init__(option_strings, *args,
                                           default=not startsNegative,
                                           nargs=0, **kwargs)

    def __call__(self, parser, ns, values, option):
        '''store True unless the spelling used is the negative one'''
        isNegative = option.startswith(self.negative)
        setattr(ns, self.dest, not isNegative)


def parseArgs(argv=None):
    '''parse commandline arguments and set defaults

    :param argv: list of arguments to parse, without the program name.
                 The default (None) makes argparse read sys.argv[1:],
                 as before.
    :returns: an argparse.Namespace with one attribute per option
    '''
    # ArgumentParser's version= keyword is deprecated in Python 2.7 and
    # gone in Python 3. Register an explicit version action instead,
    # with the same -v/--version flags 2.7 used to add.
    parser = argparse.ArgumentParser(epilog=__doc__)
    parser.add_argument('-v', '--version', action='version',
                        version=__version__)
    parser.add_argument('--logfile', default=None,
                        help='write reports to the given logfile (default: %(default)s)')
    parser.add_argument('--email', help='send report by email to given address')
    # the port is interpolated here, so the resulting help string holds
    # no '%' left for argparse's own %(default)s-style expansion
    parser.add_argument('--smtpserver',
                        help=('SMTP server to use, use a colon to specify '
                              'the port number if non-default (%(port)d).'
                              ' will attempt to use STARTTLS to secure '
                              'the connection and fail if unsupported '
                              '(default: the domain part of the --email '
                              'address)') %
                        {'port': smtplib.SMTP_PORT})
    parser.add_argument('--smtpuser',
                        help=('username for the SMTP server '
                              '(default: no user)'))
    parser.add_argument('--smtppass',
                        help=('password for the SMTP server '
                              '(default: prompted, if --smtpuser is '
                              'specified)'))
    # NegateAction toggles: the first spelling decides the default
    parser.add_argument('--information', '--no-information', action=NegateAction,
                        help='gather basic information (default: %(default)s)')
    parser.add_argument('--disk', '--no-disk', dest='disk', action=NegateAction,
                        help='run disk tests (default: %(default)s)')
    parser.add_argument('--no-smart', '--smart', dest='smart', action=NegateAction,
                        help='run SMART tests (default: %(default)s)')
    # XXX: disk detection could be done in a number of ways:
    #
    # * psutil.disk_partitions() - only lists mounted, but psutil also has
    #   features like checking amount of RAM, sensors and network..
    #
    # * parsing /proc/partitions
    #
    # * glob !/sys/block/%s/device/block/%s/removable
    parser.add_argument('--diskDevice', default='/dev/sda',
                        help='device to benchmark (default: %(default)s)')
    parser.add_argument('--overwrite', action='store_true',
                        help='actually destroy the given device (default: %(default)s)')
    parser.add_argument('--diskPercent', default='0%',
                        help='how much of the disk to trash (default: %(default)s)')
    parser.add_argument('--directory', default=None,
                        help='directory to perform file tests in, created if missing (default: %(default)s)')
    parser.add_argument('--fileSize', default='100M',
                        help='file size for I/O benchmarks (default: %(default)s)')
    parser.add_argument('--cpu', '--no-cpu', action=NegateAction,
                        help='run CPU tests (default: %(default)s)')
    parser.add_argument('--cpuBurnTime', default='1m',
                        help='timeout for CPU burn-in (default: %(default)s)')
    parser.add_argument('--network', '--no-network', action=NegateAction,
                        help='run network tests (default: %(default)s)')
    # see also https://iperf.fr/iperf-servers.php
    # XXX: we chose he.net because they are nice, but ideally we:
    # 1. would ask first
    # 2. have a DNS round-robin for this, like NTP
    parser.add_argument('--iperfServer', default='iperf.he.net',
                        help='iperf server to use (default: %(default)s)')
    parser.add_argument('--iperfTime', default=str(60),
                        help='timeout for iperf test, in seconds (default: %(default)s)')
    return parser.parse_args(argv)


class BufferedSMTPHandler(SMTPHandler, MemoryHandler):
    """A handler class which sends records only when the buffer reaches
    capacity. The object is constructed with the arguments from
    SMTPHandler and MemoryHandler and basically behaves as a merge
    between the two classes.

    Records are accumulated by MemoryHandler and sent as a single
    email, one formatted record per line, whenever MemoryHandler
    decides to flush (buffer full, or a record at or above
    ``flushLevel``) or when flush() is called explicitly.

    If the server refuses a recipient with a 450 (temporary, typically
    greylisting) error, delivery is retried after five minutes, at
    most ``retries`` times.

    The SMTPHandler.emit() implementation was copy-pasted here because
    it is not flexible enough to be overridden. We could possibly
    override the format() function to instead look at the internal
    buffer, but that would have possibly undesirable side-effects.
    """

    def __init__(self, mailhost, fromaddr, toaddrs, subject,
                 credentials=None, secure=None,
                 capacity=5000, flushLevel=logging.ERROR, retries=1):
        # forward credentials and secure: they used to be replaced by
        # None here, which silently disabled SMTP AUTH and STARTTLS
        SMTPHandler.__init__(self, mailhost, fromaddr, toaddrs, subject,
                             credentials=credentials, secure=secure)
        self.retries = retries
        # last delivery error that led to a retry or a give-up
        self.lastException = None
        MemoryHandler.__init__(self, capacity=capacity, flushLevel=flushLevel)

    def emit(self, record):
        '''buffer the record in the MemoryHandler'''
        MemoryHandler.emit(self, record)

    def _deliver(self, msg):
        '''open an SMTP session, optionally STARTTLS and login, send msg

        the connection is always closed, even if a step fails
        '''
        port = self.mailport or smtplib.SMTP_PORT
        # SMTPHandler stores its timeout as _timeout on Python 2 and as
        # timeout on Python 3
        timeout = getattr(self, 'timeout', getattr(self, '_timeout', 5.0))
        smtp = smtplib.SMTP(self.mailhost, port, timeout=timeout)
        try:
            if self.secure is not None:
                smtp.ehlo()
                smtp.starttls(*self.secure)
                smtp.ehlo()
            if self.username:
                smtp.login(self.username, self.password)
            smtp.sendmail(self.fromaddr, self.toaddrs, msg)
            smtp.quit()
        finally:
            smtp.close()

    def flush(self):
        """Send all buffered records as one email.

        The only change from SMTPHandler.emit() here is the way the
        email body is created: one formatted record per line. The
        subject is computed from the last buffered record.

        On success the buffer is emptied. On a temporary (450)
        recipient refusal, waits five minutes and retries while
        ``self.retries`` is positive; once retries are exhausted the
        error is logged and the buffer is discarded so later flushes
        do not retry forever. Any other failure is reported through
        handleError() and the buffer is kept for a later flush.
        """
        if not self.buffer:
            return
        lastRecord = self.buffer[-1]
        try:
            from email.utils import formatdate
            body = ''.join(self.format(record) + "\n"
                           for record in self.buffer)
            msg = ("From: %s\r\nTo: %s\r\nSubject: %s\r\nDate: %s\r\n"
                   "MIME-Version: 1.0\r\n"
                   "Content-Type: text/plain; charset=utf-8\r\n\r\n%s") % (
                self.fromaddr,
                ",".join(self.toaddrs),
                self.getSubject(lastRecord),
                formatdate(), body)
            # sendmail() cannot encode non-ASCII text by itself; send
            # UTF-8 bytes (a no-op for Python 2 byte strings)
            if not isinstance(msg, bytes):
                msg = msg.encode('utf-8')
            self._deliver(msg)
            logging.info('sent email to %s using %s', self.toaddrs, self.mailhost)
            self.buffer = []
        except (KeyboardInterrupt, SystemExit):
            raise
        except smtplib.SMTPRecipientsRefused as e:
            self.lastException = e
            # e.recipients maps each address to its (code, message)
            greylisted = any(error[0] == 450
                             for error in e.recipients.values())
            if not greylisted:
                self.handleError(lastRecord)
            elif self.retries > 0:
                logging.info('got temporary error, waiting 5 minutes for email')
                self.retries -= 1
                time.sleep(5*60)
                self.flush()
            else:
                logging.error('Could not send email: %s', self.lastException)
                # give up: drop the records so later flushes stop retrying
                self.buffer = []
        except Exception:
            self.handleError(lastRecord)


def setupLogging(logfile=None, email=None,
                 smtpserver=None, smtpuser=None, smtppass=None,
                 **args):
    '''setup standard Python logging facilities

    we create a new level called "OUTPUT" (logging.OUTPUT, just above
    INFO) for command output, to avoid coloring command output and to
    distinguish it from our own output

    the root logger is set to DEBUG and gets a console handler (colored
    when colorlog is usable), plus the handlers requested on the
    commandline:

    :param logfile: also write the log to this file
    :param email: buffer the log and mail it to this address
    :param smtpserver: "host" or "host:port" of the SMTP server;
                       defaults to the domain part of ``email``
    :param smtpuser: SMTP username; enables authentication
    :param smtppass: SMTP password; prompted for when ``smtpuser`` is
                     given and this is empty
    :param args: ignored, so the whole argparse namespace can be passed
    '''
    defaultFormat = '%(levelname)s: %(message)s'
    logging.OUTPUT = logging.INFO + 1
    logging.addLevelName(logging.OUTPUT, 'OUTPUT')
    if colorlog:
        handler = colorlog.StreamHandler()
        handler.setFormatter(colorlog.ColoredFormatter('%(log_color)s' + defaultFormat))
        logger = colorlog.getLogger('')
    else:
        # without colorlog, still log to the console: this branch used
        # to install no console handler at all
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(defaultFormat))
        logger = logging.getLogger('')
    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    if logfile:
        handler = logging.FileHandler(logfile)
        handler.setFormatter(logging.Formatter(defaultFormat))
        logger.addHandler(handler)
    if email:
        if not smtpserver:
            _, smtpserver = email.split('@', 1)
        # XXX: need to do MX discovery
        # honour the "host:port" syntax documented for --smtpserver:
        # SMTPHandler accepts a (host, port) tuple for that. bare IPv6
        # addresses (more than one colon) are passed through untouched.
        mailhost = smtpserver
        host, sep, port = smtpserver.rpartition(':')
        if sep and port.isdigit() and ':' not in host:
            mailhost = (host, int(port))
        fromaddr = getpass.getuser() + '@' + socket.getfqdn()
        subject = 'Stressant report'
        credentials = None
        if smtpuser:
            # prompt only when no password was given (the test used to
            # be inverted, sending None as password)
            if not smtppass:
                smtppass = getpass.getpass('enter SMTP password for server %s: ' % smtpserver)
            credentials = (smtpuser, smtppass)
        # secure=() requests STARTTLS without key/cert files
        handler = BufferedSMTPHandler(mailhost,
                                      fromaddr,
                                      email,
                                      subject,
                                      secure=(),
                                      credentials=credentials,
                                      flushLevel=logging.CRITICAL)
        handler.setFormatter(logging.Formatter(defaultFormat))
        logger.addHandler(handler)


def collectCmd(args):
    '''collect output from the given command and feed it into the logging system

    stdout and stderr are merged, and every line is logged at the
    OUTPUT level (logging.OUTPUT, registered by setupLogging()) as soon
    as it is read. A non-zero exit status, or a command that cannot be
    started at all (e.g. not installed), is logged as an error instead
    of raising, so that the remaining tests still run.

    :param args: command and arguments, as a list of strings
    '''
    logging.debug('Calling %s', ' '.join(args))
    try:
        proc = subprocess.Popen(args, stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT)
    except OSError as e:
        logging.error("Command failed: could not run '%s': %s",
                      ' '.join(args), e)
        return
    try:
        # readline() rather than "for line in proc.stdout": Python 2
        # file iteration uses a read-ahead buffer, which delays the
        # output of long-running commands
        for line in iter(proc.stdout.readline, b''):
            # decode so Python 3 logs text rather than a bytes repr;
            # rstrip() keeps the indentation of the command output
            logging.log(logging.OUTPUT,
                        line.decode('utf-8', 'replace').rstrip())
    finally:
        proc.stdout.close()
    returnCode = proc.wait()
    if returnCode != 0:
        logging.error("Command failed: Command '%s' returned non-zero exit status %d",
                      ' '.join(args), returnCode)


def collectCmdWithTmp(args):
    '''run a command that writes its report to a file, then log that file

    A temporary file is created and its path is appended to the last
    argument (e.g. "--output=" becomes "--output=/tmp/tmpXXXX"). The
    goal is to be able to run commands interactively: fio, for
    example, shows stuff on stderr that needs to be unbuffered and
    shouldn't show up in logs. Once the command exits, each line of the
    file is logged at the OUTPUT level, and the file is always removed.
    A failing or missing command is logged as an error, not raised.

    :param args: command and arguments; the caller's list is not modified
    '''
    fd, tmpfile = tempfile.mkstemp()
    # only the path is needed: the command opens the file itself
    os.close(fd)
    try:
        # build a new list instead of mutating the caller's
        args = list(args[:-1]) + [args[-1] + tmpfile]
        logging.debug('Calling %s', ' '.join(args))
        try:
            subprocess.check_call(args)
        except subprocess.CalledProcessError as e:
            logging.error("Command failed: %s", e)
        except OSError as e:
            logging.error("Command failed: could not run '%s': %s",
                          ' '.join(args), e)
            return
        with open(tmpfile) as report:
            for line in report:
                logging.log(logging.OUTPUT, line.rstrip())
    finally:
        os.unlink(tmpfile)


def gatherInfo(diskDevice='/dev/null', **args):
    '''log a basic inventory of the machine

    reports the number of CPU cores, the total physical memory
    (page size times number of physical pages), a short lshw hardware
    listing and the smartctl identity information of ``diskDevice``,
    run with -qnoserial so the serial number stays out of the log.
    Remaining keyword arguments are accepted and ignored.
    '''
    logging.info("CPU cores: %d", multiprocessing.cpu_count())

    totalBytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
    logging.info("Memory: %s (%d bytes)",
                 humanize.naturalsize(totalBytes, binary=True, format="%.0f"),
                 totalBytes)

    # (log message and its arguments, command whose output follows it)
    probes = [
        (("Hardware inventory",), ["lshw", "-short"]),
        (("SMART information for %s", diskDevice),
         ["smartctl", "-qnoserial", "-i", diskDevice]),
    ]
    for message, command in probes:
        logging.info(*message)
        collectCmd(command)


def testDrive(overwrite=False, diskDevice='/dev/null',
              diskPercent='0%', fileSize='0', reportDir='.',
              smart=True, directory=None, **args):
    '''disk tests

    runs, in order:

    * a long SMART self-test on ``diskDevice`` if ``smart`` is true
      (the test is only started here, not waited for)
    * a 512MB dd write then read test in ``directory``, if given
    * hdparm -Tt on ``diskDevice``
    * a fio random read/write stress test, either directly on
      ``diskDevice`` (destroying its contents) when ``overwrite`` is
      true, or on files in ``directory``; skipped with an error if
      neither is given

    :param overwrite: let fio write to the raw device
    :param diskDevice: block device to test
    :param diskPercent: fio --size when writing to the device
    :param fileSize: fio --size when writing in ``directory``
    :param reportDir: unused
    :param smart: start a long SMART self-test
    :param directory: where file-based tests run; created if missing
    :param args: ignored, so the whole argparse namespace can be passed
    '''
    # logging.info("How long a test takes")
    # XXX: i often need to use -d sat on external drives
    if smart:
        # XXX: need to parse this and wait and do magic
        # XXX: this is already available in -a, above
        # collectCmd(["smartctl", "-c", diskDevice])
        logging.info("Starting long SMART test")
        collectCmd(["smartctl", "-t", "long", diskDevice])
        # the above says:
        # Please wait 10 minutes for test to complete.
        # Test will complete after Wed Jan  4 21:28:11 2017
        # in 10 minutes:
        # smartctl -l selftest $disk
        # smartctl -a $disk says:
        # Self-test execution status:      ( 249) Self-test routine in progress...
        #                                               90% of test remaining.

    if directory:
        # --directory is documented as "created if missing": create it
        # before its first use, which is this dd test, not fio
        if not os.path.exists(directory):
            os.makedirs(directory)
        fd, testFile = tempfile.mkstemp(dir=directory)
        # dd opens the file by name; don't leak the descriptor
        os.close(fd)
        try:
            logging.info("Basic disk bandwidth tests")
            # bs=1M count=512: 512MB
            logging.info("Writing 512MB file %s", testFile)
            collectCmd(["dd", "bs=1M", "count=512", "conv=fdatasync",
                        "if=/dev/zero", "of=" + testFile])
            # NOTE(review): the file was just written, so this read may
            # be served partly from the page cache rather than the disk
            logging.info("Reading 512MB file %s", testFile)
            collectCmd(["dd", "bs=1M", "count=512", "of=/dev/null", "if=" + testFile])
        finally:
            os.unlink(testFile)
    else:
        logging.warning('no dd test ran, provide --directory to run')
    logging.info("hdparm test on %s", diskDevice)
    collectCmd(["hdparm", "-Tt", diskDevice])

    logging.info("Disk stress test")
    # --readwrite=randrw, cargo-culted from #grml, random mix of read/write
    # --numjobs=4, more parallel jobs to load a bit more the heads on spinning disks
    # --sync=1, not sure, cargo-culted
    # --direct=1, to bypass disk caches so we don't need to worry about RAM
    # --size=100M, unsure - cargo-culted as well. can be bytes or %,
    # defaults to 100%, can also use --runtime
    # --group_reporting, give only one report, not one per job
    cmd = ["fio", "--name=stressant", "--readwrite=randrw",
           "--numjobs=4", "--sync=1", "--direct=1", "--group_reporting"]
    if overwrite:
        # XXX: this is supposed to wipe the drive, but is that enough?
        # see also https://www.backblaze.com/blog/how-to-securely-recycle-or-dispose-of-your-ssd/
        cmd += ["--filename=" + diskDevice, "--size=" + diskPercent]
    elif directory:
        # the directory was created before the dd test above
        cmd += ["--size=" + fileSize, "--directory=" + directory]
    else:
        logging.error('--overwrite or --directory not specified, no fio test ran')
        return
    # collectCmdWithTmp() appends a temporary file name to this last
    # argument and logs that file once fio is done
    cmd += ["--output="]
    # more ideas:
    # https://wiki.mikejung.biz/Benchmarking#Fio_Test_Options_and_Examples
    # https://gist.github.com/tcooper/9417014
    # https://github.com/GoogleCloudPlatform/PerfKitBenchmarker
    # how to precondition for SSD benchmarks:
    # https://www.spinics.net/lists/fio/msg02496.html
    collectCmdWithTmp(cmd)


def testCpu(cpuBurnTime=None, reportDir='.', **args):
    '''stress-test the CPU with stress-ng

    :param cpuBurnTime: how long to run, in a form stress-ng's --timeout
                        accepts (e.g. '1m', or a number of seconds,
                        given as a string or an int). None falls back
                        to '1m', the --cpuBurnTime command-line default,
                        instead of putting None in the command line
                        (which made subprocess raise TypeError)
    :param reportDir: unused
    :param args: ignored, so the whole argparse namespace can be passed
    '''
    if cpuBurnTime is None:
        cpuBurnTime = '1m'
    # subprocess arguments must be strings
    cpuBurnTime = str(cpuBurnTime)
    logging.info("CPU stress test for %s", cpuBurnTime)
    # --cpu 0: one CPU stressor per online CPU
    cmd = ["stress-ng", "--timeout", cpuBurnTime,
           "--cpu", "0", "--ignite-cpu",
           "--metrics-brief", "--log-brief",
           "--tz", "--times", "--aggressive"]
    collectCmd(cmd)
    # --matrix 0 is apparently the best way to heat up the CPU
    #
    # --verify would be important, not sure it works with CPU.
    #
    # according to this it works with --vm:
    # https://wiki.ubuntu.com/Kernel/Reference/stress-ng
    # also, i7z is useful to show the status of the CPU, including temperatures

    # similar tools:
    # linpack: not in debian
    # mprime: not in debian, not free software
    # systester: not in debian


def testNetwork(iperfServer=None, iperfTime=None, **args):
    '''basic network tests: run an iperf3 client against iperfServer

    :param iperfServer: host running an iperf3 server. When missing, the
                        test is skipped with an error instead of putting
                        None in the command line (a TypeError in
                        subprocess)
    :param iperfTime: test duration in seconds, as a string or an int;
                      None leaves iperf3's own default duration
    :param args: ignored, so the whole argparse namespace can be passed
    '''
    if not iperfServer:
        logging.error('no iperf server specified, no network test ran')
        return
    logging.info('Running network benchmark')
    # we use iperf, but apparently netperf is more effective:
    # https://www.bufferbloat.net/projects/cerowrt/wiki/Netperf/
    # see also this article:
    # http://iwl.com/white-papers/iperf
    cmd = ['iperf3', '-c', iperfServer]
    if iperfTime is not None:
        # subprocess arguments must be strings
        cmd += ['-t', str(iperfTime)]
    collectCmd(cmd)


def main():
    '''entry point: parse arguments, set up logging, run the selected tests'''
    args = parseArgs()
    options = vars(args)
    setupLogging(**options)
    logging.info('Starting tests')
    # each test runs only if its --foo/--no-foo toggle is on, always in
    # this order; every test receives the whole set of options
    steps = (
        (args.information, gatherInfo),
        (args.disk, testDrive),
        (args.cpu, testCpu),
        (args.network, testNetwork),
    )
    for enabled, step in steps:
        if enabled:
            step(**options)
    logging.info("all done")
    # flush and close every handler now, so the buffered email report
    # is sent before we return
    logging.shutdown()


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logging.error("Interrupted")
        # exit non-zero (128 + SIGINT, the shell convention) so callers
        # can tell the run did not complete; logging's own exit hook
        # still flushes the handlers
        raise SystemExit(130)
