#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# ssh.py  -  Utilities for connecting via OpenSSH and setup of keys/certs
#            Support for legacy systems (subset of sna_utils.py)
#
# Copyright (C) 2018 Thorsten Südbrock <thorsten.suedbrock@perfact.de>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#

import six
import os
import tempfile
import binascii
import argparse
import sys
from .generic import safe_syscall, to_string
from time import strftime
from io import BytesIO, StringIO
if six.PY2:
    from ConfigParser import ConfigParser
else:
    from configparser import ConfigParser
if six.PY2:
    from base64 import decodestring as decodebytes
else:
    from base64 import decodebytes


# original in /var/lib/zope2.13/instance/ema/Extensions/GONE/sna_utils.py_GONE
def generate_key(usePutty=None, authopt=None, unixuser='secconnect',
                 bits=2048):
    '''
    Generate a public/private RSA key pair for use in ssh.

    The key pair is created with ssh-keygen inside a private temporary
    directory. The private key is returned, while the public key is
    appended to the authorized_keys file of the given unixuser by calling
    the 'appendkey' binary found in that user's home directory.

    :usePutty
    If true, the private key is converted with puttygen before it is
    returned.

    :authopt
    String - optional text which is put in front of the public key
    (separated by one blank) before it is handed to 'appendkey'.

    :unixuser
    String - name of the unix user whose home directory is used.

    :bits
    Integer - key size handed to ssh-keygen.

    Returns the contents of the private key file as a string.
    The temporary directory is removed again even if one of the steps
    fails, so no private key material is left behind on disk.
    '''
    # make key in unique temp directory
    dirname = tempfile.mkdtemp()
    timestamp = strftime('%Y%m%d%H%M%S_%s')
    keyfile = os.path.join(dirname, 'keyfile')
    pubfile = os.path.join(dirname, 'keyfile.pub')
    keyfile_putty = keyfile + '_oss'

    try:
        os.chmod(dirname, 0o700)

        # create a private key
        cmd = ['/usr/bin/ssh-keygen', '-C', timestamp,
               '-f', keyfile, '-N', '', '-t', 'rsa', '-m', 'PEM',
               '-b', str(bits)]
        safe_syscall(cmd, raisemode=True)

        if usePutty:
            # determine which puttygen cmd to issue
            # is the parameter --ppk-param available or not?
            # we must enforce version=2 to support old putty versions in
            # the field
            err, out = safe_syscall(['/usr/bin/puttygen', '--help'],
                                    raisemode=True)
            cmd_add = []
            if '--ppk-param' in out:
                cmd_add = ['--ppk-param', 'version=2']
            cmd = ['/usr/bin/puttygen', keyfile, '-o', keyfile_putty]
            safe_syscall(cmd + cmd_add, raisemode=True)
            os.rename(keyfile_putty, keyfile)

        # read public/private keys
        with open(pubfile) as fh:
            # only strip the line ending, never a character of the key
            pubkey = fh.readline().rstrip('\n')

        with open(keyfile) as fh:
            privkey = fh.read()
    finally:
        # remove temporary files again, whatever happened above
        for leftover in (keyfile, pubfile, keyfile_putty):
            if os.path.exists(leftover):
                os.remove(leftover)
        os.rmdir(dirname)

    # unixuser home file
    homepath = os.path.expanduser('~'+unixuser)
    # append the generated public key to the ~/.ssh/authorized_keys using
    # the 'appendkey' binary
    appendkey = os.path.join(homepath, 'appendkey')
    cmd = [appendkey, homepath]
    if authopt:
        cmd.append(authopt + ' ' + pubkey)
    else:
        cmd.append(pubkey)
    safe_syscall(cmd, raisemode=True)

    return privkey


def getHostHash(path='/etc/ssh/ssh_host_rsa_key.pub', key=None):
    """ Return the hash value (fingerprint) of a SSH public or private key

    The key is taken either from the string 'key' or, if no key is given,
    from the file at 'path'. If 'key' is given, 'path' is ignored and does
    not need to exist. The fingerprint is the second blank-separated field
    of the output of 'ssh-keygen -l'.

    Raises an AssertionError if neither key nor path is given, if the file
    at 'path' does not exist, or if ssh-keygen returns a non-zero error
    level.

    >>> from .file import fileassets  # doctest: +SKIP
    >>> getHostHash(key=fileassets['tests.ssh_rsa_pub'])  # doctest: +SKIP
    'SHA256:1zm+8IkeokEbEjG2Y66uKm+aDV+Ta8IXDWOIGH0Ttp4'

    >>> from .file import fileassets  # doctest: +SKIP
    >>> getHostHash(key=fileassets['tests.ssh_rsa_key'])  # doctest: +SKIP
    'SHA256:1zm+8IkeokEbEjG2Y66uKm+aDV+Ta8IXDWOIGH0Ttp4'

    >>> from .file import fileassets  # doctest: +SKIP
    >>> getHostHash(path=fileassets.data['tests.ssh_rsa_pub']['filename']
    ... )  # doctest: +SKIP
    'SHA256:1zm+8IkeokEbEjG2Y66uKm+aDV+Ta8IXDWOIGH0Ttp4'
    """
    if not (key or path):
        raise AssertionError(
            "A key (string) or path (file path) must be given!")

    tmpdir = None
    if key:
        # the key string wins over the path; it is written to a tmpdir
        tmpdir = tempfile.mkdtemp()
        keypath = os.path.join(tmpdir, 'key')
    elif os.path.exists(path):
        keypath = path
    else:
        raise AssertionError("Given file path does not exist: %s" % path)

    try:
        if tmpdir is not None:
            with open(keypath, 'w') as f:
                f.write(key)
            os.chmod(keypath, 0o600)

        cmd = ['/usr/bin/ssh-keygen', '-l', '-f', keypath]
        err, out = safe_syscall(cmd, raisemode=False)
    finally:
        if tmpdir is not None:
            # remove tmpfile again, even if something went wrong
            if os.path.exists(keypath):
                os.remove(keypath)
            os.rmdir(tmpdir)

    if err != 0:
        raise AssertionError("An error occured. Errorlevel: %s - Output: %s"
                             % (err, out))
    return out.split(' ')[1]


def get_pub_key(key, inputmode='openssh'):
    """ Output the corresponding SSH public key to a given SSH private key
    Supported inputmodes are 'openssh' and 'putty'
    Default is to use 'openssh'. The output format will always be openssh

    The private key is written to a file in a temporary directory, which is
    removed again in every case (also if writing or converting fails).
    Raises an AssertionError if the conversion tool returns a non-zero
    error level.

    >>> from .file import fileassets
    >>> get_pub_key(key=fileassets['tests.ssh_rsa_key'])
    'ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQDCs8us9NI70QNwHW1boSAONhem8PwCCC76\
m/WqtGzlvl9mxrHDHChhebQByjDO0ZKcO3Cm/HY/qA84RxGW93fzLCjxpQwtU8ymyiqeF1Hm23gUc\
oAVf+C1SHf6ivMkUk5LqxPdMM1GfsFVz/2u26sd/3KcABzi7ouYyNWrXTlwaw==\\n'
    """
    # write out to tmpdir
    tmpdir = tempfile.mkdtemp()
    tmpfile = os.path.join(tmpdir, 'privkey')

    if inputmode == 'putty':
        cmd = ['/usr/bin/puttygen', tmpfile, '-O', 'public-openssh']
    else:
        cmd = ['/usr/bin/ssh-keygen', '-y', '-f', tmpfile]

    try:
        with open(tmpfile, 'w') as f:
            f.write(key)
        os.chmod(tmpfile, 0o600)
        err, out = safe_syscall(cmd, raisemode=False)
    finally:
        # remove tmpfile again, so the private key never stays on disk
        if os.path.exists(tmpfile):
            os.remove(tmpfile)
        os.rmdir(tmpdir)

    if err != 0:
        raise AssertionError(
            "An error occured. Errorlevel: %s - Output: %s" % (err, out)
        )
    return out


def toint(value):
    '''Interpret the first four items of value as a big-endian number.
    This is used to read the length prefix of the data blocks in SSH
    public keys. Items beyond the fourth one are ignored.

    >>> toint(b'\\x00\\x00\\x00\\x07ssh-rsa')
    7
    >>> toint(b'\\x02\\x03\\x04\\x05')
    33752069
    '''
    result = 0
    for chunk in value[:4]:
        # Python3 bytes already yield integers, Python2 yields single
        # characters which have to be converted with ord()
        octet = chunk if isinstance(chunk, int) else ord(chunk)
        result = result * 256 + octet
    return result


def getPuttyKnownHost(path='/etc/ssh/ssh_host_rsa_key.pub', pubkey=None):
    """Convert the given public key into the PuTTY compatible format, so it
    can be used in the known_hosts file on Windows systems.

    The key is taken from the string 'pubkey' if given, otherwise it is
    read from the file at 'path'. The second blank-separated field of the
    key is base64-decoded and split into its length-prefixed data blocks;
    the second and third block are returned as
    '0x<exponent>,0x<modulus>' in hexadecimal notation without leading
    zeros.

    Raises an AssertionError if neither pubkey nor path is given.

    >>> from .file import fileassets
    >>> from .generic import to_string
    >>> to_string(getPuttyKnownHost(pubkey=fileassets['tests.ssh_rsa_pub']))
    '0x10001,0xc2b3cbacf4d23bd103701d6d5ba1200e3617a6f0fc02082efa9bf5aab46ce5\
be5f66c6b1c31c286179b401ca30ced1929c3b70a6fc763fa80f38471196f777f32c28f1a50c2\
d53cca6ca2a9e1751e6db78147280157fe0b54877fa8af324524e4bab13dd30cd467ec155cffd\
aedbab1dff729c001ce2ee8b98c8d5ab5d39706b'

    >>> from .file import fileassets
    >>> to_string(getPuttyKnownHost(
    ...     path=fileassets.data['tests.ssh_rsa_pub']['filename']))
    '0x10001,0xc2b3cbacf4d23bd103701d6d5ba1200e3617a6f0fc02082efa9bf5aab46ce5\
be5f66c6b1c31c286179b401ca30ced1929c3b70a6fc763fa80f38471196f777f32c28f1a50c2\
d53cca6ca2a9e1751e6db78147280157fe0b54877fa8af324524e4bab13dd30cd467ec155cffd\
aedbab1dff729c001ce2ee8b98c8d5ab5d39706b'
    """
    # explicit raise instead of 'assert', which is stripped by 'python -O'
    if not (path or pubkey):
        raise AssertionError('At least key or path must be given!')

    if pubkey:
        # a given key string wins over the path; no temporary file needed
        if isinstance(pubkey, bytes):
            raw = pubkey
        else:
            raw = pubkey.encode('utf-8')
    else:
        # read out the key
        with open(path, 'rb') as f:
            raw = f.read()

    parts = raw.split(b' ')
    data = decodebytes(parts[1])

    # split the decoded data into its blocks, each of which is preceded by
    # a four byte length field
    start = 0
    sdata = []
    while start < len(data):
        load = toint(data[start:start+4])
        sdata.append(data[start+4:start+4+load])
        start = start + load + 4

    # remove leading zeros since plink fails to verify the key if there are
    # some. The modulus is treated the same way as the exponent, so a
    # modulus without a leading zero byte does not lose significant digits.
    exp = '0x' + to_string(binascii.b2a_hex(sdata[1])).lstrip('0')
    modulo = '0x' + to_string(binascii.b2a_hex(sdata[2])).lstrip('0')

    return exp + ',' + modulo


def genConfig(inputdict):
    """ Generate a remote support configuration file for SSH connections
    a.k.a 'connect.rcn' file.
    The result has an .ini file like syntax: one [RemoteService] section
    holding one 'key = value' line per entry of inputdict, ordered by key.
    The file contents are returned as a string.

    >>> genConfig({'foo': 'bar', 'blubb': 'swush'})
    '[RemoteService]\\nblubb = swush\\nfoo = bar\\n\\n'
    """
    section = 'RemoteService'
    parser = ConfigParser()
    parser.add_section(section)
    for option in sorted(inputdict.keys()):
        parser.set(section, option, inputdict[option])

    # ConfigParser writes native strings: bytes in Python2, text in Python3
    buf = BytesIO() if six.PY2 else StringIO()
    parser.write(buf)
    return buf.getvalue()


def get_ports_from_proc():
    """ Return a sorted list of tcp/udp IPv4 ports currently in use
    The information is read from /proc/net/tcp and /proc/net/udp.
    Each port is listed only once, even if it shows up several times.

    >>> type(get_ports_from_proc()) == list  # doctest: +SKIP
    True
    """
    seen = set()
    for table in ('/proc/net/tcp', '/proc/net/udp'):
        with open(table, 'r') as handle:
            rows = handle.readlines()[1:]  # first line is a header
        for row in rows:
            # second column is '<address>:<port>', the port is hexadecimal
            address = row.split()[1]
            seen.add(int(address.split(':')[1], 16))

    # the set eliminated duplicates, hand out a sorted list
    return sorted(seen)


def get_ports_from_services(path='/etc/services'):
    """ Return a set of all ports listed in the given file.
    Lines starting with # (after optional blanks) are ignored. For every
    other line with at least two fields, the part of the second field in
    front of the first '/' is taken as the port number. Lines whose port
    field is not a number are skipped, so a single malformed line does not
    make the whole scan fail.

    >>> type(get_ports_from_services()) == set  # doctest: +SKIP
    True
    """
    ports = set()
    with open(path, 'r') as myfile:
        for line in myfile:
            if line.strip().startswith('#'):
                continue
            items = line.split()
            if len(items) < 2:
                continue
            port = items[1].split('/')[0]
            try:
                ports.add(int(port))
            except ValueError:
                # not a '<port>/<protocol>' entry, ignore this line
                continue

    return ports


def get_free_port(start, stop, safe=True, blocked_portlist=None):
    """ Return the lowest free port in the given range start-stop which is
    currently unused in the system and not part of the blocked_portlist

    :start, :stop
    Integer - first and last port of the range (both inclusive)

    :safe
    If true, all ports returned by get_ports_from_services() and
    get_ports_from_proc() are treated as blocked, too.

    :blocked_portlist
    List of ports which must not be returned. Only the blocked ports which
    are inside the range start-stop reduce the number of candidates.

    Raises an AssertionError if the blocked_portlist already covers the
    complete range, or if no port is left after the 'safe' checks.

    >>> get_free_port(start=17, stop=23, safe=True)  # doctest: +SKIP
    Traceback (most recent call last):
    ...
    AssertionError: All ports are blocked by other services
    start: 17
    stop: 23

    >>> get_free_port(start=20, stop=23, blocked_portlist=[20,21,22,23]
    ... )  # doctest: +SKIP
    Traceback (most recent call last):
    ...
    AssertionError: No more free ports available!
    Maybe you need to increase start: 20 or stop: 23
    or give less blocked ports: [20, 21, 22, 23]

    >>> get_free_port(start=20, stop=23, blocked_portlist=[20,21,22],
    ...     safe=False)  # doctest: +SKIP
    23

    >>> get_free_port(start=42000, stop=42000, safe=True)  # doctest: +SKIP
    42000
    """
    # exclude some ports
    blocked_list = list(blocked_portlist) if blocked_portlist else []
    blocked_s = set(blocked_list)

    candidates = set(range(start, stop+1))

    # Compare the actual ports, not only the number of ports: blocked ports
    # outside of the range must not count against the range.
    if not candidates.difference(blocked_s):
        raise AssertionError(
            'No more free ports available!\nMaybe you need'
            ' to increase start: %s or stop: %s\nor give less'
            ' blocked ports: %s'
            % (start, stop, blocked_list)
        )

    if safe:
        # exclude other 'reserved' ports
        blocked_s.update(get_ports_from_services())

        # exclude all currently open ports
        blocked_s.update(get_ports_from_proc())

    free = candidates.difference(blocked_s)
    if not free:
        raise AssertionError(
            'All ports are blocked by other services\n'
            'start: %s\n'
            'stop: %s'
            % (start, stop)
        )
    return min(free)


def kill_sshproc(pid):
    '''Kill the ssh process with the given pid which belongs to the user
    secconnect. This is done by calling a wrapper executable via sudo,
    which needs an entry in /etc/sudoers; create this entry if necessary.
    The entry should look like:
    zope ALL=NOPASSWD:/usr/bin/perfact-killssh-secconnect

    Returns the return code and the output of the call.
    '''
    command = ['sudo', '/usr/bin/perfact-killssh-secconnect']
    command += ['-p', str(pid)]
    status, text = safe_syscall(command, raisemode=False)
    return status, text


def createWinInstaller(files, run=None, title=None, prompt=None):
    '''Creates an self extracting archive (using 7zip) for a simple
    installation without registry keys e.g. extraction of some files
    into a folder.
    Optionally a command can be executed after extraction is finished,
    using the 'run' parameter for the filename.

    The installer is assembled in a temporary directory by concatenating
    the 7zip SFX module, a generated config file and the archive. The
    temporary files are removed again even if one of the steps fails.

    :files
    Dictionary - keys are the filenames and values are the contents
    (the contents are written in binary mode)

    :run
    String - Name of file to be run after extraction

    :title
    String - Displayed name in the window title while extracting

    :prompt
    String - Bring up a message box with "OK/Cancel" buttons before
    extracting anything

    Returns the contents of the installer as delivered by safe_syscall
    (called with text=False).
    '''
    tempdir = tempfile.mkdtemp()
    archive = os.path.join(tempdir, 'archive.7z')
    conffile = os.path.join(tempdir, '__config.txt')
    filepaths = []
    try:
        # NOTE(review): the filenames are joined to the tempdir unchecked,
        # so they are expected to be plain names from a trusted caller.
        for name in files:
            filepath = os.path.join(tempdir, name)
            filepaths.append(filepath)
            with open(filepath, 'wb') as fh:
                fh.write(files[name])

        conf = ';!@Install@!UTF-8!\n'
        if title:
            conf += 'Title="'+title+'"\n'
        if prompt:
            conf += 'BeginPrompt="'+prompt+'"\n'
        if run:
            conf += 'RunProgram="'+run+'"\n'
        conf += ';!@InstallEnd@!\n'

        cmds = ['/usr/bin/7z', 'a', archive]
        cmds.extend(filepaths)
        safe_syscall(cmds, raisemode=True)

        # the config file is written after archiving; only the files
        # listed in filepaths are put into the archive
        with open(conffile, 'w') as fc:
            fc.write(conf)

        ret, inst = safe_syscall(
            ['/bin/cat',
             '/usr/lib/p7zip/7zSD.sfx',
             conffile,
             archive, ],
            raisemode=True,
            text=False,
        )
    finally:
        # remove everything that has been created so far
        for leftover in filepaths + [conffile, archive]:
            if os.path.exists(leftover):
                os.remove(leftover)
        os.rmdir(tempdir)
    return inst


def binwrapper_killssh_secconnect():
    '''Command line entry point: send SIGTERM (kill -15) to a process, but
    only if its pid is listed among the processes of the user secconnect.

    The pid is read from the command line option --pid / -p.
    The program exits with
    - 1 if no pid is given
    - 2 if the pid does not belong to a secconnect process
    - the return code of /bin/kill otherwise
    An AssertionError is raised if the process list can not be retrieved.
    '''
    cli = argparse.ArgumentParser()
    cli.add_argument('--pid', '-p', type=str,
                     help='Pid of a secconnect ssh process to be killed',
                     default=None)
    target_pid = cli.parse_args().pid

    if not target_pid:
        sys.exit(1)

    # ps exits with error level 1 if it finds no processes at all, so the
    # possible results (retcode, output) are:
    # - User exists, has no processes: (1, '')
    # - User exists, has processes: (0, '<pids with blanks and newlines>')
    # - User doesn't exist: (1, 'error: user name ...')
    # That is why raisemode=False is used and the cases are checked here.
    ps_status, ps_output = safe_syscall(
        ['/bin/ps', 'h', '-o', 'pid', '-u', 'secconnect'],
        raisemode=False
    )

    # Only a real error is reported, "no processes" is a valid result.
    no_processes = (ps_status, ps_output) == (1, '')
    if ps_status != 0 and not no_processes:
        raise AssertionError(
            'PIDs could not be retrieved: ' + repr(ps_output))

    known_pids = []
    for line in ps_output.split('\n'):
        if line:
            known_pids.append(line.strip())

    if target_pid not in known_pids:
        print('Pid %s not found in secconnect processes!' % target_pid)
        sys.exit(2)

    kill_status, kill_output = safe_syscall(
        ['/bin/kill', '-15', target_pid],
        raisemode=False
    )

    sys.exit(kill_status)
