import logging
import re  # for input validation
import urllib
import urllib.parse

import psycopg2  # for first run tasks
import requests

from helpers import run


def make_set_db_statements(set_db):
    '''Using a list of dictionaries, return a list of update statements
    to send to the database.

    Each dictionary describes one update with the identifier keys 'table',
    'key_col' and 'val_col' and the value keys 'key', 'new_val' and, unless
    'ignore_old' is true, 'old_val'. Column names are prefixed with the
    table name, following the DB-Utils naming convention.

    String escaping happens according to the SQL standard. No characters
    are escaped, except single quotes. NOTE(review): this relies on the
    server treating backslashes literally (standard_conforming_strings
    on) -- confirm for the target database. Values are converted with
    str(); only None is rendered as the empty string.

    >>> make_set_db_statements([{
    ...     'table': 'mytable',
    ...     'key_col': 'mykey_id',
    ...     'val_col': 'myval',
    ...     'key': "ke'y",
    ...     'old_val': 'oldval\\n\\\\',
    ...     'new_val': 'newval',
    ... }]) == [
    ...     "update mytable\\n"
    ...     "   set mytable_modtime = now(),\\n"
    ...     "       mytable_author = '__first_run__',\\n"
    ...     "       mytable_myval = 'newval'\\n"
    ...     " where mytable_myval = 'oldval\\n\\\\'\\n"
    ...     "   and mytable_mykey_id = 'ke''y'\\n"
    ... ]
    True

    Literals are only accepted if they follow DB-Utils naming conventions,
    otherwise an AssertionError is raised. The whole literal must match,
    so a trailing newline is rejected as well.

    >>> make_set_db_statements([{'table': 'illegal;'}])
    Traceback (most recent call last):
      ...
    AssertionError: illegal literal
    >>> make_set_db_statements([{'table': 'mytable\\n'}])
    Traceback (most recent call last):
      ...
    AssertionError: illegal literal

    With 'ignore_old', the current value is not part of the condition:

    >>> make_set_db_statements([{
    ...     'table': 'mytable',
    ...     'key_col': 'mykey_id',
    ...     'val_col': 'myval',
    ...     'key': "ke'y",
    ...     'ignore_old': True,
    ...     'new_val': 'newval',
    ... }]) == [
    ...     "update mytable\\n"
    ...     "   set mytable_modtime = now(),\\n"
    ...     "       mytable_author = '__first_run__',\\n"
    ...     "       mytable_myval = 'newval'\\n"
    ...     " where mytable_mykey_id = 'ke''y'\\n"
    ... ]
    True

    Empty strings are still correctly escaped:

    >>> make_set_db_statements([{
    ...     'table': 'mytable',
    ...     'key_col': 'mykey_id',
    ...     'val_col': 'myval',
    ...     'key': "ke'y",
    ...     'old_val': 'oldval\\n\\\\',
    ...     'new_val': '',
    ... }]) == [
    ...     "update mytable\\n"
    ...     "   set mytable_modtime = now(),\\n"
    ...     "       mytable_author = '__first_run__',\\n"
    ...     "       mytable_myval = ''\\n"
    ...     " where mytable_myval = 'oldval\\n\\\\'\\n"
    ...     "   and mytable_mykey_id = 'ke''y'\\n"
    ... ]
    True

    Numeric zero is rendered as '0', not as an empty string:

    >>> make_set_db_statements([{
    ...     'table': 'mytable',
    ...     'key_col': 'mykey_id',
    ...     'val_col': 'myval',
    ...     'key': 0,
    ...     'ignore_old': True,
    ...     'new_val': 0,
    ... }]) == [
    ...     "update mytable\\n"
    ...     "   set mytable_modtime = now(),\\n"
    ...     "       mytable_author = '__first_run__',\\n"
    ...     "       mytable_myval = '0'\\n"
    ...     " where mytable_mykey_id = '0'\\n"
    ... ]
    True
    '''
    statements = []
    for item in set_db:
        params = dict(item)
        for key in ('table', 'key_col', 'val_col'):
            # fullmatch instead of match('^...$'): '$' also matches right
            # before a trailing newline, which would let '\n' into the SQL.
            # Raise explicitly so the check is not stripped by python -O.
            if not re.fullmatch(r'[a-z][a-z0-9_]*', params[key]):
                raise AssertionError('illegal literal')

        for key in ('key', 'old_val', 'new_val'):
            if key in params:
                value = params[key]
                # only None means "empty"; 0 or False keep their str() form
                text = '' if value is None else str(value)
                params[key] = "'{}'".format(text.replace("'", "''"))

        if params.get('ignore_old'):
            template = '''\
update {table}
   set {table}_modtime = now(),
       {table}_author = '__first_run__',
       {table}_{val_col} = {new_val}
 where {table}_{key_col} = {key}
'''
        else:
            # the old value is part of the condition, so the row is only
            # updated if it still holds the expected value
            template = '''\
update {table}
   set {table}_modtime = now(),
       {table}_author = '__first_run__',
       {table}_{val_col} = {new_val}
 where {table}_{val_col} = {old_val}
   and {table}_{key_col} = {key}
'''
        statements.append(template.format(**params))

    return statements


def send_db_statements(dbconn_string, statements):
    '''Send a sequence of statements to the database in a single transaction.

    During execution, all output tuple lists are collected, along with the
    number of rows affected, in a list of dictionaries:
    {
        'statement': <the statement as executed>,
        'rowcount': <number of rows affected>,
        'tuples': <list of tuples, or () if there were no results>,
    }
    If a psycopg2.OperationalError occurs, the transaction is rolled back
    and the returned list is empty. Any other exception propagates to the
    caller without a commit. In every case the connection is closed before
    the function returns or raises.
    '''
    outputs = []
    conn = psycopg2.connect(dbconn_string)
    try:
        cur = conn.cursor()
        try:
            for statement in statements:
                cur.execute(statement)
                # read the rowcount right after execute(), before fetching
                rows = cur.rowcount
                try:
                    results = cur.fetchall()
                except psycopg2.ProgrammingError:
                    # raised when there are no results to fetch at all
                    results = ()
                outputs.append({
                    'statement': statement,
                    'rowcount': rows,
                    'tuples': results,
                })
        except psycopg2.OperationalError:
            conn.rollback()
            outputs = []
        else:
            conn.commit()
    finally:
        # Previously the connection was leaked. Closing it without a commit
        # also discards the pending transaction when an exception other
        # than OperationalError escapes the block above.
        conn.close()
    return outputs


def perfact_result_parser(res):
    '''Classify a response object from the 'requests' module.

    Returns a dict with the single key 'error'. It is True when the
    response is not OK (``res.ok`` is false), or when the response is a
    redirect whose Location header contains 'standard_error_message', the
    PerFact specific error page. No login-page detection is performed here.
    '''
    marker = 'standard_error_message'
    failed = not res.ok
    if res.is_redirect and marker in res.headers.get('Location'):
        failed = True
    return {'error': failed}


def _is_same_origin(url, scheme, host, port):
    '''Return True if the absolute ``url`` targets ``scheme://host[:port]``.

    The URL is parsed instead of prefix-matched, so look-alikes such as
    ``https://localhost.example``, ``https://localhost@evil.example`` or
    ``https://localhost:4430`` are rejected. A URL without an explicit port
    is accepted (as before), because servers may omit default ports like
    :443 or :80 in redirect locations.
    '''
    parts = urllib.parse.urlsplit(url)
    try:
        url_port = parts.port
    except ValueError:
        # malformed port in the location
        return False
    if parts.scheme != scheme.lower():
        return False
    # hostname is lower-cased and has IPv6 brackets removed by urlsplit
    if (parts.hostname or '') != host.strip('[]').lower():
        return False
    return url_port is None or not port or str(url_port) == str(port)


def url_calls(commands, scheme='https', tls_verify=False, redirects=False,
              redirects_max=10, host='localhost', port='443',
              netrc='.netrc-cron', username=None, password=None,
              resultparsefunc=None):
    '''Invoke URLs on the local ZOPE service instance (GET by default).

    Authentication information is added automatically.

    :commands is a list of dictionaries, the following keys are respected:
        :url    - Mandatory path on the server to invoke
        :params - Optional parameters for the request
        :method - Optional: Set to POST or GET (default)
        :files  - Optional 'files' argument passed on to requests

    :scheme is a string to indicate either plain HTTP or HTTPS

    :tls_verify is passed as 'verify' to requests and controls if HTTPS
        certificate validation should take place (default: off)

    :redirects controls if redirects should be followed or not. Only
        redirects to the same scheme and host (and the same port, if the
        location states one) are followed. Relative locations are resolved
        against the URL that was just requested. Request parameters are not
        passed on to the redirect target.

    :redirects_max is the maximum number of redirects followed per command

    :host is the hostname/IP of the server

    :port is an optional port number e.g. '443' (str or int)

    :netrc is a relative (!) path to a .netrc file in the ~/ home directory of
        the calling user, containing username/password for BASIC authentication
        against the given host. The 'host' value must have a match in the
        .netrc file for this to work! Note that this rebinds
        requests.utils.NETRC_FILES, a module-level attribute of requests.

    :username for manual BASIC auth if no netrc file is given
    :password for manual BASIC auth if no netrc file is given

    :resultparsefunc is an optional callable receiving the response and
        returning a dict with at least the key 'error'. Without it, a
        response is an error if 'res.ok' is false.

    Returns a list with one dict per command:
        {'url': <last URL requested, or the redirect target skipped when
                 the limit was hit>,
         'status': <status code of the last response, None if none>}

    Raises AssertionError if a response is classified as an error.
    '''
    # base without port because :443 or :80 may be omitted on redirects
    origin = scheme + '://' + host
    base = (origin + ':' + str(port)) if port else origin

    if netrc:
        # this *must* be a tuple!
        requests.utils.NETRC_FILES = (netrc,)

    auth = None
    if username and password:
        auth = requests.auth.HTTPBasicAuth(username, password)

    output = []
    for item in commands:
        url = urllib.parse.urljoin(base, item['url'])
        params = item.get('params')
        method = item.get('method', 'GET')
        status_code = None
        count = 0

        # A session stores cookies across the manual redirect hops; the
        # with-block closes it (and its connection pool) afterwards.
        with requests.Session() as session:
            while True:
                if count > redirects_max:
                    logging.error(
                        'Maximum redirect limit of %s reached! '
                        'Skipping URL %s', redirects_max, url)
                    break
                count += 1

                # lazy %-args: eager '%' formatting breaks for tuple params
                logging.info('Calling URL: %s', url)
                logging.info('Parameters: %s', params)
                res = session.request(
                    method=method,
                    url=url,
                    params=params,
                    allow_redirects=False,  # redirecting is done manually
                    verify=tls_verify,  # previously hard-coded to False
                    auth=auth,
                    files=item.get('files'),
                )
                status_code = res.status_code

                # custom parser for results
                if callable(resultparsefunc):
                    logging.debug('Running custom result parser: %s',
                                  resultparsefunc.__name__)
                    parsed = resultparsefunc(res)
                else:
                    parsed = {'error': not res.ok}

                # explicit raise (same exception type as the former assert)
                # so the check is not stripped by python -O
                if parsed['error']:
                    raise AssertionError(
                        'Request failed with (%s): %s [%s]'
                        % (status_code, url, params)
                    )

                # A location could be:
                #  some_page
                #  /some_page
                #  ../../some_page
                #  https://host:port/some_page
                # Relative forms are resolved against the current URL and
                # the resolved absolute URL is what gets requested next.
                target = None
                location = res.headers.get('Location')
                if location:
                    candidate = urllib.parse.urljoin(url, location)
                    if _is_same_origin(candidate, scheme, host, port):
                        target = candidate

                if not (redirects and res.is_redirect and target):
                    # proceed with next url in commands
                    break

                url = target
                # do not pass the parameters to the new location!
                params = None
                logging.info('Following redirect to: %s', url)

        output.append({
            'url': url,
            'status': status_code,
        })
    return output


def run_commands(cmds):
    '''Run each command entry of ``cmds`` in order via helpers.run.

    :cmds is an iterable of dictionaries with the keys
        :cmd     - Mandatory command, passed unchanged as first argument
                   to run()
        :options - Optional dict of keyword arguments for run(); a missing
                   or None value means no extra arguments
    '''
    for entry in cmds:
        cmd = entry['cmd']
        options = entry.get('options') or {}
        # lazy %-args: eager "'%s' % cmd" raises TypeError for a tuple cmd
        logging.info('Executing command: %s', cmd)
        run(cmd, **options)
