#!/usr/bin/env python

import logging
import MySQLdb
import os
import rfc822
import subprocess
import sys

# Credentials and schema used by Database to reach the local MySQL server.
DB_USER = 'libdspam7-drv-my'
DB_PASS = 'xxxxx'
DB_NAME = 'libdspam7drvmysql'
# Table holding one row per seen signature, plus the two watermark rows.
DB_TABLE = 'auto_folder'

# User name passed to the dspam command line (--user) by classify().
DSPAM_USER = 'pierre'
# Value stored in the `uid` column of every row written by this script.
DSPAM_UID = 1

# Reserved `signature` values of the rows that store the scan watermarks:
# the number of the last scan started and of the last scan finished.
LAST_STARTED_SIGNATURE = '__last_started__'
LAST_FINISHED_SIGNATURE = '__last_finished__'

# Mail header from which the dspam signature of a message is read.
SIGNATURE_HEADER = 'X-DSPAM-Signature'

# File receiving the log output. Looking up HOME raises KeyError at import
# time when the variable is not set in the environment.
LOG_FILE = '%s/Maildir/dspam_auto.log' % os.environ['HOME']

# When True, classify() only logs the dspam command instead of running it.
DRY_RUN = False


def classify(signature, shall_be_here):
    """Ask dspam to reclassify the message identified by a signature.

    Args:
      signature: string, the dspam signature of the message.
      shall_be_here: if true the signature is reported as 'spam', otherwise
        it is reported as 'innocent'.

    Nothing is returned: failures to run the command and non-zero return
    codes are only logged. When DRY_RUN is set the command is logged but not
    executed.
    """
    if shall_be_here:
        kind = 'spam'
    else:
        kind = 'innocent'

    cmd = ['/usr/bin/dspam',
           '--signature=%s' % signature,
           '--class=%s' % kind,
           '--source=error',
           '--client',
           '--user', DSPAM_USER]
    cmd_line = ' '.join(cmd)

    if DRY_RUN:
        logging.info('[dryrun] Classify command: %s', cmd_line)
        return

    logging.info('Classify command: %s', cmd_line)
    try:
        ret = subprocess.call(cmd)
    except Exception:
        # Best effort: a failure to launch dspam must not abort the caller.
        # The message needs a placeholder for the command, otherwise the
        # logging call itself fails to format the record.
        logging.exception('Exception while running classify: %s', cmd_line)
        return

    if ret:
        logging.error('Classify command, non zero return code: %s', ret)


def MailIterator(basepath):
    """Iterate over the messages stored in a Maildir-style directory.

    Python 2.4 mailbox module is not rich enough, so basic directory
    traversing is reimplemented here.

    Args:
      basepath: path of a directory containing 'new' and 'cur'
        subdirectories.
    Yields:
      One rfc822.Message per readable file. Files whose flags (the text
      after the last comma of the file name) contain 'T' are skipped.
      The underlying file stays open while the consumer handles the message
      and is closed when the iteration resumes.
    Raises:
      OSError: if one of the subdirectories cannot be listed.
    """
    for subdir in ['new', 'cur']:
        path = os.path.join(basepath, subdir)
        for filename in os.listdir(path):
            try:
                flags = filename.rsplit(',', 1)[1]
            except IndexError:
                flags = ''
            if 'T' in flags:
                continue

            fullname = os.path.join(path, filename)
            fp = None
            try:
                fp = open(fullname)
                message = rfc822.Message(fp)
            except IOError:
                # The file may have vanished or be unreadable: skip it, but
                # do not leave a half-used descriptor behind.
                if fp is not None:
                    fp.close()
                logging.warning('IOError on file %s', fullname)
                continue

            yield message
            # No try/finally around the yield, to stay compatible with
            # Python 2.4. If the consumer abandons the iteration, the file
            # is released by garbage collection as before.
            fp.close()


class Database(object):
    """MySQL-backed store of seen dspam signatures and scan watermarks.

    Rows of the table are keyed by (uid, signature); `last_seen` holds the
    watermark of the last scan that touched the signature. Two rows with
    reserved signatures (LAST_STARTED_SIGNATURE and LAST_FINISHED_SIGNATURE)
    hold the high and low watermarks themselves.
    """

    def __init__(self):
        """Connect to the local MySQL server using the module-level settings."""
        self.conn = MySQLdb.connect(
            host='localhost',
            user=DB_USER,
            passwd=DB_PASS,
            db=DB_NAME)
        self.table = DB_TABLE
        self.cursor = self.conn.cursor()

        self.uid = DSPAM_UID

    def __del__(self):
        self.close()

    def close(self):
        """Commit pending changes and release the cursor and the connection.

        Safe to call several times, and on an instance whose constructor
        failed before the connection was established.
        """
        conn = getattr(self, 'conn', None)
        cursor = getattr(self, 'cursor', None)
        # Forget the handles first so that a second call (for instance from
        # __del__ after an explicit close) does nothing.
        self.conn = None
        self.cursor = None

        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.commit()
            conn.close()

    def bootstrap(self):
        """Create the signature table."""
        req = """
            CREATE TABLE `%s` (
                `uid` smallint(5) unsigned NOT NULL,
                `signature` char(32) NOT NULL,
                `last_seen` int(10) unsigned default NULL,
                PRIMARY KEY  (`uid`,`signature`)
            );
        """ % self.table
        self.cursor.execute(req)

    def _initWatermarks(self):
        """Insert both watermark rows with an initial value of 0."""
        self.cursor.execute(
            'INSERT INTO %s (uid, signature, last_seen) VALUES (%%s, %%s, %%s),(%%s, %%s, %%s)' % self.table,
            (self.uid, LAST_STARTED_SIGNATURE, 0,
             self.uid, LAST_FINISHED_SIGNATURE, 0))

    def _getWatermark(self, name):
        """Return the value of the watermark row `name`.

        The watermark rows are created on the fly when the requested one is
        not found.
        """
        self.cursor.execute(
            'SELECT last_seen FROM %s WHERE uid=%%s AND signature=%%s' % self.table,
            (self.uid, name))

        if not self.cursor.rowcount:
            self._initWatermarks()
            return self._getWatermark(name)
        else:
            return self.cursor.fetchone()[0]

    def hwm(self):
        """Return the high watermark: the number of the last scan started."""
        return self._getWatermark(LAST_STARTED_SIGNATURE)

    def lwm(self):
        """Return the low watermark: the number of the last scan finished."""
        return self._getWatermark(LAST_FINISHED_SIGNATURE)

    def start(self):
        """Reserve a new scan number by incrementing the high watermark.

        Returns:
          The new high watermark, to be passed to touch() and finish().
        """
        rowcount = 0

        # Atomic update of high watermark. That's an optimistic approach:
        # the UPDATE only matches if nobody changed the value since it was
        # read. Starving is possible, but we suppose that we don't have many
        # instances running anyway.
        # NOTE(review): if the table uses a transactional engine with
        # snapshot reads, the re-read below may keep returning the stale
        # value inside the same transaction -- confirm the engine in use.
        while not rowcount:
            hwm = self.hwm()
            new_hwm = hwm + 1
            self.cursor.execute(
                'UPDATE %s SET last_seen=%%s WHERE uid=%%s AND signature=%%s AND last_seen=%%s' % self.table,
                (new_hwm, self.uid, LAST_STARTED_SIGNATURE, hwm))
            rowcount = self.cursor.rowcount

        return new_hwm

    def finish(self, wm):
        """Raise the low watermark to `wm`.

        The row is only updated when its current value is lower than `wm`,
        so the low watermark never moves backwards.
        """
        self.cursor.execute(
            'UPDATE %s SET last_seen=%%s WHERE uid=%%s AND signature=%%s AND last_seen<%%s' % self.table,
            (wm, self.uid, LAST_FINISHED_SIGNATURE, wm))

    def touch(self, signature, wm):
        """Create or update the given signature.

        Args:
          signature: string representing the signature.
          wm: watermark to store as the `last_seen` value of the signature.
        Returns:
          True if the entry was created, False if already existing.
        Raises:
          RuntimeError: if MySQL reports an unexpected affected rows count.
        """
        self.cursor.execute(
            'INSERT %s (uid, signature, last_seen) VALUES (%%s, %%s, %%s) '
            'ON DUPLICATE KEY UPDATE last_seen=%%s' % self.table,
            (self.uid, signature, wm, wm))

        r = self.conn.affected_rows()
        # For a INSERT / ON DUPLICATE, affected rows is 1 if this is a new
        # content, 2 if it's an update, and 0 if the existing row already
        # holds the same value (signature touched twice at the same
        # watermark).
        if r not in (0, 1, 2):
            raise RuntimeError(
                'Unexpected affected rows count for touch: %r' % (r,))
        return r == 1

    def remove(self, signature):
        """Remove the given signature.

        Args:
          signature: string of the signature.
        Returns:
          True if the entry was present.
        """
        self.cursor.execute(
            'DELETE FROM %s WHERE uid=%%s AND signature=%%s' % self.table,
            (self.uid, signature))
        r = self.conn.affected_rows()
        return r == 1

    def gc(self, delta=3):
        """Yield the signatures that were not seen by recent scans.

        Args:
          delta: number of scans below the low watermark a signature may
            lag before being reported.
        Yields:
          Signatures whose `last_seen` is lower than (low watermark - delta).
        """
        lwm = self.lwm()

        # Use a delta to avoid marking a FP when a one-time OSError happen.
        lwm -= delta

        # TODO: handle case where hwm < lwm (counter looping)
        self.cursor.execute('SELECT signature FROM %s WHERE uid=%%s AND last_seen < %%s' % self.table,
                            (self.uid, lwm))

        # All rows are fetched before the first one is yielded, so the
        # consumer may reuse the shared cursor (e.g. through remove()) while
        # iterating.
        for row in self.cursor.fetchall():
            signature = row[0]
            yield signature


def scanMbox(db, wm, boxpath):
    """Touch the signature of every message of a mailbox directory.

    Signatures seen for the first time are reported to dspam as spam through
    classify(). Messages without a dspam signature header are logged and
    skipped.

    Args:
      db: Database instance used to record the signatures.
      wm: watermark of the current scan, as returned by db.start().
      boxpath: path of the mailbox directory, as expected by MailIterator.
    """
    for message in MailIterator(boxpath):
        # getheader() returns None instead of raising KeyError, so a message
        # lacking both the signature and the Message-Id cannot abort the
        # scan.
        signature = message.getheader(SIGNATURE_HEADER)
        if signature is None:
            logging.info('Message %s do not have dspam signature.',
                         message.getheader('Message-Id'))
            continue

        new_msg = db.touch(signature, wm)
        if new_msg:
            logging.info('New signature found: %s', signature)
            classify(signature, True)


def scan(db, args):
    """Run one scan over the given mailbox directories.

    A new watermark is reserved with db.start(), every directory is scanned
    with it, and the watermark is finally recorded as finished. If a
    directory scan raises, db.finish() is not reached.

    Args:
      db: Database instance.
      args: list of mailbox directory paths; they are made absolute first.
    """
    wm = db.start()

    for directory in [os.path.abspath(path) for path in args]:
        logging.info('Scanning (wm=%s) %s...', wm, directory)
        scanMbox(db, wm, directory)
        logging.info('Finished (wm=%s) %s.', wm, directory)

    db.finish(wm)


def gc(db):
    """Reclassify and forget the signatures reported by db.gc().

    Each such signature is reported to dspam as innocent through classify()
    and then removed from the database.

    Args:
      db: Database instance.
    """
    logging.info('Starting garbage collecting...')

    stale_signatures = db.gc()
    for stale in stale_signatures:
        logging.info('Garbage collecting %s', stale)
        classify(stale, False)
        db.remove(stale)

    logging.info('Finished garbage collecting.')
 

def push(db):
    """Touch the signature of the message read from standard input.

    The signature is recorded with the current high watermark. A message
    without the signature header raises KeyError.

    Args:
      db: Database instance.
    """
    signature = rfc822.Message(sys.stdin)[SIGNATURE_HEADER]

    # Race condition: a touched signature shall be part of the last
    # successful scan. However, if there is a scan in progress, we can
    # potentially miss it.
    current_hwm = db.hwm()
    logging.info('Touching %s at hwm %s', signature, current_hwm)
    db.touch(signature, current_hwm)


def test(db):
    """Exercise the database by hand and print the results.

    Prints both watermarks, then starts a scan, touches and removes a dummy
    signature twice (printing each result), and finishes the scan.

    Args:
      db: Database instance.
    """
    dummy = 'test'

    print('%s %s' % (db.hwm(), db.lwm()))

    wm = db.start()
    # Created on the first call, already existing on the second one.
    print(db.touch(dummy, wm))
    print(db.touch(dummy, wm))
    # Present on the first call, already gone on the second one.
    print(db.remove(dummy))
    print(db.remove(dummy))
    db.finish(wm)


def usage(msg=None):
    """Print the command line synopsis on stderr and exit with status 1.

    Args:
      msg: optional error message, logged at the fatal level before the
        synopsis is printed.
    """
    if msg:
        logging.fatal(msg)

    program = sys.argv[0]
    synopsis = 'Usage: %s <command> [args...]\n' % program
    sys.stderr.write(synopsis)
    raise SystemExit(1)


def dispatch(command, args):
    """Open the database and run the requested command.

    Args:
      command: one of 'init', 'test', 'scan', 'gc', 'update' or 'push'.
        Any other value ends the program through usage().
      args: remaining command line arguments; used as mailbox directories
        by 'scan' and 'update'.
    """
    db = Database()

    def update():
        # A full update is a scan followed by a garbage collection.
        scan(db, args)
        gc(db)

    handlers = {
        'init': db.bootstrap,
        'test': lambda: test(db),
        'scan': lambda: scan(db, args),
        'gc': lambda: gc(db),
        'update': update,
        'push': lambda: push(db),
    }

    handler = handlers.get(command)
    if handler is None:
        usage('Unknown command name')
    else:
        handler()


if __name__ == '__main__':
    # Everything from DEBUG up goes to the log file...
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(levelname)s %(asctime)s %(message)s',
        filename=LOG_FILE)

    # ...while only warnings and errors are echoed on the console.
    console = logging.StreamHandler()
    console.setLevel(logging.WARNING)
    console.setFormatter(logging.Formatter('%(levelname)s %(asctime)s %(message)s'))
    logging.getLogger('').addHandler(console)

    if len(sys.argv) < 2:
        usage('Missing command name')

    try:
        command = sys.argv[1]
        args = sys.argv[2:]
        dispatch(command, args)
    except SystemExit:
        # usage() exits on purpose (e.g. unknown command name): let it
        # through instead of logging it as an unhandled error.
        raise
    except Exception:
        logging.exception('Unhandeld error.')
        sys.exit(1)
